Browse Source

fix

pull/459/head
Makonike 3 years ago
parent
commit
dfcfaae868
  1. 39
      dtmsvr/utils.go
  2. 24
      dtmsvr/utils_test.go

39
dtmsvr/utils.go

@ -61,7 +61,35 @@ type iface struct {
type valueCtx struct { type valueCtx struct {
context.Context context.Context
key, value interface{} key, value any
}
type cancelCtx struct {
context.Context
}
type timerCtx struct {
cancelCtx *cancelCtx
}
func (*timerCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*timerCtx) Done() <-chan struct{} {
return nil
}
func (*timerCtx) Err() error {
return nil
}
func (*timerCtx) Value(key any) any {
return nil
}
func (e *timerCtx) String() string {
return ""
} }
// CopyContext copy context with value and grpc metadata // CopyContext copy context with value and grpc metadata
@ -87,7 +115,7 @@ func CopyContext(ctx context.Context) context.Context {
func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) {
rtType := reflect.TypeOf(ctx).String() rtType := reflect.TypeOf(ctx).String()
if rtType == "*context.emptyCtx" || rtType == "*context.timerCtx" { if rtType == "*context.emptyCtx" {
return return
} }
ictx := *(*iface)(unsafe.Pointer(&ctx)) ictx := *(*iface)(unsafe.Pointer(&ctx))
@ -95,8 +123,13 @@ func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) {
return return
} }
valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) valCtx := (*valueCtx)(unsafe.Pointer(ictx.data))
if valCtx != nil && valCtx.key != nil && valCtx.value != nil { if valCtx.key != nil && valCtx.value != nil && rtType == "*context.valueCtx" {
kv[valCtx.key] = valCtx.value kv[valCtx.key] = valCtx.value
} }
if rtType == "*context.timerCtx" {
tCtx := (*timerCtx)(unsafe.Pointer(ictx.data))
getKeyValues(tCtx.cancelCtx, kv)
return
}
getKeyValues(valCtx.Context, kv) getKeyValues(valCtx.Context, kv)
} }

24
dtmsvr/utils_test.go

@ -32,10 +32,14 @@ func TestSetNextCron(t *testing.T) {
assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset))
} }
type testContextType string
func TestCopyContext(t *testing.T) { func TestCopyContext(t *testing.T) {
ctxWithValue := context.WithValue(context.Background(), "key", "value") var key testContextType = "key"
var value testContextType = "value"
ctxWithValue := context.WithValue(context.Background(), key, value)
newCtx := CopyContext(ctxWithValue) newCtx := CopyContext(ctxWithValue)
assert.Equal(t, ctxWithValue.Value("key"), newCtx.Value("key")) assert.Equal(t, ctxWithValue.Value(key), newCtx.Value(key))
var ctx context.Context var ctx context.Context
newCtx = CopyContext(ctx) newCtx = CopyContext(ctx)
@ -43,12 +47,16 @@ func TestCopyContext(t *testing.T) {
} }
func TestCopyContextRecursive(t *testing.T) { func TestCopyContextRecursive(t *testing.T) {
ctxWithValue := context.WithValue(context.Background(), "key", "value") var key testContextType = "key"
nestedCtx := context.WithValue(ctxWithValue, "nested_key", "nested_value") var value testContextType = "value"
var nestedKey testContextType = "nested_key"
var nestedValue testContextType = "nested_value"
ctxWithValue := context.WithValue(context.Background(), key, value)
nestedCtx := context.WithValue(ctxWithValue, nestedKey, nestedValue)
newCtx := CopyContext(nestedCtx) newCtx := CopyContext(nestedCtx)
assert.Equal(t, nestedCtx.Value("nested_key"), newCtx.Value("nested_key")) assert.Equal(t, nestedCtx.Value(nestedKey), newCtx.Value(nestedKey))
assert.Equal(t, nestedCtx.Value("key"), newCtx.Value("key")) assert.Equal(t, nestedCtx.Value(key), newCtx.Value(key))
} }
func TestCopyContextWithMetadata(t *testing.T) { func TestCopyContextWithMetadata(t *testing.T) {
@ -68,7 +76,9 @@ func TestCopyContextWithMetadata(t *testing.T) {
} }
func BenchmarkCopyContext(b *testing.B) { func BenchmarkCopyContext(b *testing.B) {
ctx := context.WithValue(context.Background(), "key", "value") var key testContextType = "key"
var value testContextType = "value"
ctx := context.WithValue(context.Background(), key, value)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
CopyContext(ctx) CopyContext(ctx)

Loading…
Cancel
Save