From dfcfaae8688fe83284419b840285dc2d51217111 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 12:52:06 +0800 Subject: [PATCH] fix --- dtmsvr/utils.go | 39 ++++++++++++++++++++++++++++++++++++--- dtmsvr/utils_test.go | 24 +++++++++++++++++------- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index b1ae5e0..ac60f3c 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -61,7 +61,35 @@ type iface struct { type valueCtx struct { context.Context - key, value interface{} + key, value any +} + +type cancelCtx struct { + context.Context +} + +type timerCtx struct { + cancelCtx *cancelCtx +} + +func (*timerCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*timerCtx) Done() <-chan struct{} { + return nil +} + +func (*timerCtx) Err() error { + return nil +} + +func (*timerCtx) Value(key any) any { + return nil +} + +func (e *timerCtx) String() string { + return "" } // CopyContext copy context with value and grpc metadata @@ -87,7 +115,7 @@ func CopyContext(ctx context.Context) context.Context { func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { rtType := reflect.TypeOf(ctx).String() - if rtType == "*context.emptyCtx" || rtType == "*context.timerCtx" { + if rtType == "*context.emptyCtx" { return } ictx := *(*iface)(unsafe.Pointer(&ctx)) @@ -95,8 +123,13 @@ func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { return } valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) - if valCtx != nil && valCtx.key != nil && valCtx.value != nil { + if valCtx.key != nil && valCtx.value != nil && rtType == "*context.valueCtx" { kv[valCtx.key] = valCtx.value } + if rtType == "*context.timerCtx" { + tCtx := (*timerCtx)(unsafe.Pointer(ictx.data)) + getKeyValues(tCtx.cancelCtx, kv) + return + } getKeyValues(valCtx.Context, kv) } diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index d91f7a9..cd3007f 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -32,10 +32,14 @@ func TestSetNextCron(t *testing.T) { assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) } +type testContextType string + func TestCopyContext(t *testing.T) { - ctxWithValue := context.WithValue(context.Background(), "key", "value") + var key testContextType = "key" + var value testContextType = "value" + ctxWithValue := context.WithValue(context.Background(), key, value) newCtx := CopyContext(ctxWithValue) - assert.Equal(t, ctxWithValue.Value("key"), newCtx.Value("key")) + assert.Equal(t, ctxWithValue.Value(key), newCtx.Value(key)) var ctx context.Context newCtx = CopyContext(ctx) @@ -43,12 +47,16 @@ func TestCopyContext(t *testing.T) { } func TestCopyContextRecursive(t *testing.T) { - ctxWithValue := context.WithValue(context.Background(), "key", "value") - nestedCtx := context.WithValue(ctxWithValue, "nested_key", "nested_value") + var key testContextType = "key" + var value testContextType = "value" + var nestedKey testContextType = "nested_key" + var nestedValue testContextType = "nested_value" + ctxWithValue := context.WithValue(context.Background(), key, value) + nestedCtx := context.WithValue(ctxWithValue, nestedKey, nestedValue) newCtx := CopyContext(nestedCtx) - assert.Equal(t, nestedCtx.Value("nested_key"), newCtx.Value("nested_key")) - assert.Equal(t, nestedCtx.Value("key"), newCtx.Value("key")) + assert.Equal(t, nestedCtx.Value(nestedKey), newCtx.Value(nestedKey)) + assert.Equal(t, nestedCtx.Value(key), newCtx.Value(key)) } func TestCopyContextWithMetadata(t *testing.T) { @@ -68,7 +76,9 @@ func TestCopyContextWithMetadata(t *testing.T) { } func BenchmarkCopyContext(b *testing.B) { - ctx := context.WithValue(context.Background(), "key", "value") + var key testContextType = "key" + var value testContextType = "value" + ctx := context.WithValue(context.Background(), key, value) b.ResetTimer() for i := 0; i < b.N; i++ { CopyContext(ctx)