From 34e5b3d50bcd2738e626cd3ac8659b7204bea1d2 Mon Sep 17 00:00:00 2001 From: Makonike Date: Sun, 27 Aug 2023 18:27:37 +0800 Subject: [PATCH] feat: copy context --- dtmsvr/trans_class.go | 1 + dtmsvr/utils.go | 47 +++++++++++++++++++++++++++++++++++++++++++ dtmsvr/utils_test.go | 40 ++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/dtmsvr/trans_class.go b/dtmsvr/trans_class.go index a311879..2038bdf 100644 --- a/dtmsvr/trans_class.go +++ b/dtmsvr/trans_class.go @@ -103,6 +103,7 @@ func TransFromDtmRequest(ctx context.Context, c *dtmgpb.DtmRequest) *TransGlobal }, }} r.ReqExtra = c.ReqExtra + r.Context = CopyContext(ctx) if c.Steps != "" { dtmimp.MustUnmarshalString(c.Steps, &r.Steps) } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 6ab5ffe..6fac299 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -7,14 +7,18 @@ package dtmsvr import ( + "context" "fmt" + "reflect" "time" + "unsafe" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" "github.com/dtm-labs/dtm/dtmsvr/storage" "github.com/dtm-labs/dtm/dtmsvr/storage/registry" "github.com/lithammer/shortuuid/v3" + "google.golang.org/grpc/metadata" ) type branchStatus struct { @@ -50,3 +54,46 @@ func GetTransGlobal(gid string) *TransGlobal { //nolint:staticcheck return &TransGlobal{TransGlobalStore: *trans} } + +type iface struct { + itab, data uintptr +} + +type valueCtx struct { + context.Context + key, value interface{} +} + +// CopyContext copy context with value and grpc metadata +// if raw context is nil, return nil +func CopyContext(ctx context.Context) context.Context { + if ctx == nil { + return ctx + } + newCtx := context.Background() + kv := make(map[interface{}]interface{}) + getKeyValues(ctx, kv) + for k, v := range kv { + newCtx = context.WithValue(newCtx, k, v) + } + if md, ok := metadata.FromIncomingContext(ctx); ok { + newCtx = metadata.NewIncomingContext(newCtx, md) + } + return newCtx +} + +func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { + rtType := reflect.TypeOf(ctx).String() + if rtType == "*context.emptyCtx" { + return + } + ictx := *(*iface)(unsafe.Pointer(&ctx)) + if ictx.data == 0 { + return + } + valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) + if valCtx != nil && valCtx.key != nil && valCtx.value != nil { + kv[valCtx.key] = valCtx.value + } + getKeyValues(valCtx.Context, kv) +} diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index 15f8635..5a78931 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -7,9 +7,11 @@ package dtmsvr import ( + "context" "testing" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" ) func TestUtils(t *testing.T) { @@ -29,3 +31,41 @@ func TestSetNextCron(t *testing.T) { tg.TimeoutToFail = 3 assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) } + +func TestCopyContext(t *testing.T) { + ctxWithValue := context.WithValue(context.Background(), "key", "value") + newCtx := CopyContext(ctxWithValue) + assert.Equal(t, ctxWithValue.Value("key"), newCtx.Value("key")) + + var ctx context.Context + newCtx = CopyContext(ctx) + assert.Nil(t, newCtx) +} + +func TestCopyContextRecursive(t *testing.T) { + ctxWithValue := context.WithValue(context.Background(), "key", "value") + nestedCtx := context.WithValue(ctxWithValue, "nested_key", "nested_value") + newCtx := CopyContext(nestedCtx) + + assert.Equal(t, nestedCtx.Value("nested_key"), newCtx.Value("nested_key")) + assert.Equal(t, nestedCtx.Value("key"), newCtx.Value("key")) +} + +func TestCopyContextWithMetadata(t *testing.T) { + md := metadata.New(map[string]string{"key": "value"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + newCtx := CopyContext(ctx) + + copiedMD, ok := metadata.FromIncomingContext(newCtx) + assert.True(t, ok) + assert.Equal(t, 1, len(copiedMD["key"])) + assert.Equal(t, "value", copiedMD["key"][0]) +} + +func BenchmarkCopyContext(b *testing.B) { + ctx := context.WithValue(context.Background(), "key", "value") + b.ResetTimer() + for i := 0; i < b.N; i++ { + CopyContext(ctx) + } +}