From 5c79adf760753c9dd9e9d41e976b23a8d4fd3fe0 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 12:21:12 +0800 Subject: [PATCH] fix --- dtmsvr/utils.go | 27 ++++++++++++++++++++++++++- dtmsvr/utils_test.go | 5 +++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 8e6a8ba..b1ae5e0 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -9,7 +9,9 @@ package dtmsvr import ( "context" "fmt" + "reflect" "time" + "unsafe" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" @@ -69,9 +71,32 @@ func CopyContext(ctx context.Context) context.Context { return ctx } newCtx := context.Background() - // TODO: copy value in context + kv := make(map[interface{}]interface{}) + getKeyValues(ctx, kv) + for k, v := range kv { + newCtx = context.WithValue(newCtx, k, v) + } if md, ok := metadata.FromIncomingContext(ctx); ok { newCtx = metadata.NewIncomingContext(newCtx, md) } + if md, ok := metadata.FromOutgoingContext(ctx); ok { + newCtx = metadata.NewOutgoingContext(newCtx, md) + } return newCtx } + +func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { + rtType := reflect.TypeOf(ctx).String() + if rtType == "*context.emptyCtx" || rtType == "*context.timerCtx" { + return + } + ictx := *(*iface)(unsafe.Pointer(&ctx)) + if ictx.data == 0 { + return + } + valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) + if valCtx != nil && valCtx.key != nil && valCtx.value != nil { + kv[valCtx.key] = valCtx.value + } + getKeyValues(valCtx.Context, kv) +} diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index 5a78931..d91f7a9 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -54,12 +54,17 @@ func TestCopyContextRecursive(t *testing.T) { func TestCopyContextWithMetadata(t *testing.T) { md := metadata.New(map[string]string{"key": "value"}) ctx := metadata.NewIncomingContext(context.Background(), md) + ctx = metadata.NewOutgoingContext(ctx, md) newCtx := CopyContext(ctx) copiedMD, ok := metadata.FromIncomingContext(newCtx) assert.True(t, ok) assert.Equal(t, 1, len(copiedMD["key"])) assert.Equal(t, "value", copiedMD["key"][0]) + copiedMD, ok = metadata.FromOutgoingContext(newCtx) + assert.True(t, ok) + assert.Equal(t, 1, len(copiedMD["key"])) + assert.Equal(t, "value", copiedMD["key"][0]) } func BenchmarkCopyContext(b *testing.B) {