Browse Source

feat: copy context

pull/459/head
Makonike 3 years ago
parent
commit
34e5b3d50b
  1. 1
      dtmsvr/trans_class.go
  2. 47
      dtmsvr/utils.go
  3. 40
      dtmsvr/utils_test.go

1
dtmsvr/trans_class.go

@ -103,6 +103,7 @@ func TransFromDtmRequest(ctx context.Context, c *dtmgpb.DtmRequest) *TransGlobal
},
}}
r.ReqExtra = c.ReqExtra
r.Context = CopyContext(ctx)
if c.Steps != "" {
dtmimp.MustUnmarshalString(c.Steps, &r.Steps)
}

47
dtmsvr/utils.go

@ -7,14 +7,18 @@
package dtmsvr
import (
"context"
"fmt"
"reflect"
"time"
"unsafe"
"github.com/dtm-labs/dtm/client/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/dtmsvr/config"
"github.com/dtm-labs/dtm/dtmsvr/storage"
"github.com/dtm-labs/dtm/dtmsvr/storage/registry"
"github.com/lithammer/shortuuid/v3"
"google.golang.org/grpc/metadata"
)
type branchStatus struct {
@ -50,3 +54,46 @@ func GetTransGlobal(gid string) *TransGlobal {
//nolint:staticcheck
return &TransGlobal{TransGlobalStore: *trans}
}
type iface struct {
itab, data uintptr
}
type valueCtx struct {
context.Context
key, value interface{}
}
// CopyContext copy context with value and grpc metadata
// if raw context is nil, return nil
func CopyContext(ctx context.Context) context.Context {
if ctx == nil {
return ctx
}
newCtx := context.Background()
kv := make(map[interface{}]interface{})
getKeyValues(ctx, kv)
for k, v := range kv {
newCtx = context.WithValue(newCtx, k, v)
}
if md, ok := metadata.FromIncomingContext(ctx); ok {
newCtx = metadata.NewIncomingContext(newCtx, md)
}
return newCtx
}
func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) {
rtType := reflect.TypeOf(ctx).String()
if rtType == "*context.emptyCtx" {
return
}
ictx := *(*iface)(unsafe.Pointer(&ctx))
if ictx.data == 0 {
return
}
valCtx := (*valueCtx)(unsafe.Pointer(ictx.data))
if valCtx != nil && valCtx.key != nil && valCtx.value != nil {
kv[valCtx.key] = valCtx.value
}
getKeyValues(valCtx.Context, kv)
}

40
dtmsvr/utils_test.go

@ -7,9 +7,11 @@
package dtmsvr
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)
func TestUtils(t *testing.T) {
@ -29,3 +31,41 @@ func TestSetNextCron(t *testing.T) {
tg.TimeoutToFail = 3
assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset))
}
func TestCopyContext(t *testing.T) {
ctxWithValue := context.WithValue(context.Background(), "key", "value")
newCtx := CopyContext(ctxWithValue)
assert.Equal(t, ctxWithValue.Value("key"), newCtx.Value("key"))
var ctx context.Context
newCtx = CopyContext(ctx)
assert.Nil(t, newCtx)
}
func TestCopyContextRecursive(t *testing.T) {
ctxWithValue := context.WithValue(context.Background(), "key", "value")
nestedCtx := context.WithValue(ctxWithValue, "nested_key", "nested_value")
newCtx := CopyContext(nestedCtx)
assert.Equal(t, nestedCtx.Value("nested_key"), newCtx.Value("nested_key"))
assert.Equal(t, nestedCtx.Value("key"), newCtx.Value("key"))
}
func TestCopyContextWithMetadata(t *testing.T) {
md := metadata.New(map[string]string{"key": "value"})
ctx := metadata.NewIncomingContext(context.Background(), md)
newCtx := CopyContext(ctx)
copiedMD, ok := metadata.FromIncomingContext(newCtx)
assert.True(t, ok)
assert.Equal(t, 1, len(copiedMD["key"]))
assert.Equal(t, "value", copiedMD["key"][0])
}
func BenchmarkCopyContext(b *testing.B) {
ctx := context.WithValue(context.Background(), "key", "value")
b.ResetTimer()
for i := 0; i < b.N; i++ {
CopyContext(ctx)
}
}

Loading…
Cancel
Save