From d940034084c535497187d4a50f77e5066c40dc4b Mon Sep 17 00:00:00 2001 From: Makonike Date: Thu, 24 Aug 2023 00:40:10 +0800 Subject: [PATCH 01/13] feat: new saga grpc with context --- client/dtmcli/trans_saga.go | 8 ++++++ client/dtmgrpc/options_test.go | 50 ++++++++++++++++++++++++++++++++++ client/dtmgrpc/saga.go | 12 ++++++++ 3 files changed, 70 insertions(+) diff --git a/client/dtmcli/trans_saga.go b/client/dtmcli/trans_saga.go index 04c8124..db79b4f 100644 --- a/client/dtmcli/trans_saga.go +++ b/client/dtmcli/trans_saga.go @@ -7,6 +7,7 @@ package dtmcli import ( + "context" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" ) @@ -21,6 +22,13 @@ func NewSaga(server string, gid string) *Saga { return &Saga{TransBase: *dtmimp.NewTransBase(gid, "saga", server, ""), orders: map[int][]int{}} } +// NewSagaWithContext create a saga with context +func NewSagaWithContext(ctx context.Context, server string, gid string) *Saga { + saga := NewSaga(server, gid) + saga.TransBase.Context = ctx + return saga +} + // Add add a saga step func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga { s.Steps = append(s.Steps, map[string]string{"action": action, "compensate": compensate}) diff --git a/client/dtmgrpc/options_test.go b/client/dtmgrpc/options_test.go index 011c36b..6a4ddf1 100644 --- a/client/dtmgrpc/options_test.go +++ b/client/dtmgrpc/options_test.go @@ -1,6 +1,7 @@ package dtmgrpc import ( + "context" "reflect" "testing" @@ -102,3 +103,52 @@ func TestNewSagaGrpc(t *testing.T) { }) } } + +// TestNewSagaGrpcWithContext ut for NewSagaGrpcWithContext +func TestNewSagaGrpcWithContext(t *testing.T) { + var ( + ctx = context.Background() + server = "dmt_server_address" + gidNoOptions = "msg_no_setup_options" + gidTraceIDXXX = "msg_setup_options_trace_id_xxx" + sagaWithTraceIDXXX = &SagaGrpc{Saga: *dtmcli.NewSagaWithContext(ctx, server, gidTraceIDXXX)} + traceIDHeaders = map[string]string{ + "x-trace-id": "xxx", + } + ) + sagaWithTraceIDXXX.BranchHeaders = traceIDHeaders + type args struct { + gid string + opts []TransBaseOption + } + tests := []struct { + name string + args args + want *SagaGrpc + }{ + { + name: "no setup options", + args: args{gid: gidNoOptions}, + want: &SagaGrpc{Saga: *dtmcli.NewSaga(server, gidNoOptions)}, + }, + { + name: "msg with trace_id", + args: args{ + gid: gidTraceIDXXX, + opts: []TransBaseOption{ + WithBranchHeaders(traceIDHeaders), + }, + }, + want: sagaWithTraceIDXXX, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewSagaGrpcWithContext(ctx, server, tt.args.gid, tt.args.opts...) + t.Logf("TestNewSagaGrpc %s got %+v\n", tt.name, got) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSagaGrpc() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/client/dtmgrpc/saga.go b/client/dtmgrpc/saga.go index cc13df8..59766ef 100644 --- a/client/dtmgrpc/saga.go +++ b/client/dtmgrpc/saga.go @@ -7,6 +7,7 @@ package dtmgrpc import ( + "context" "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" "google.golang.org/protobuf/proto" @@ -28,6 +29,17 @@ func NewSagaGrpc(server string, gid string, opts ...TransBaseOption) *SagaGrpc { return sg } +// NewSagaGrpcWithContext create a saga with context +func NewSagaGrpcWithContext(ctx context.Context, server string, gid string, opts ...TransBaseOption) *SagaGrpc { + sg := &SagaGrpc{Saga: *dtmcli.NewSagaWithContext(ctx, server, gid)} + + for _, opt := range opts { + opt(&sg.TransBase) + } + + return sg +} + // Add add a saga step func (s *SagaGrpc) Add(action string, compensate string, payload proto.Message) *SagaGrpc { s.Steps = append(s.Steps, map[string]string{"action": action, "compensate": compensate}) From 34e5b3d50bcd2738e626cd3ac8659b7204bea1d2 Mon Sep 17 00:00:00 2001 From: Makonike Date: Sun, 27 Aug 2023 18:27:37 +0800 Subject: [PATCH 02/13] feat: copy context --- dtmsvr/trans_class.go | 1 + dtmsvr/utils.go | 47 +++++++++++++++++++++++++++++++++++++++++++ dtmsvr/utils_test.go | 40 ++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/dtmsvr/trans_class.go b/dtmsvr/trans_class.go index a311879..2038bdf 100644 --- a/dtmsvr/trans_class.go +++ b/dtmsvr/trans_class.go @@ -103,6 +103,7 @@ func TransFromDtmRequest(ctx context.Context, c *dtmgpb.DtmRequest) *TransGlobal }, }} r.ReqExtra = c.ReqExtra + r.Context = CopyContext(ctx) if c.Steps != "" { dtmimp.MustUnmarshalString(c.Steps, &r.Steps) } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 6ab5ffe..6fac299 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -7,14 +7,18 @@ package dtmsvr import ( + "context" "fmt" + "reflect" "time" + "unsafe" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" "github.com/dtm-labs/dtm/dtmsvr/storage" "github.com/dtm-labs/dtm/dtmsvr/storage/registry" "github.com/lithammer/shortuuid/v3" + "google.golang.org/grpc/metadata" ) type branchStatus struct { @@ -50,3 +54,46 @@ func GetTransGlobal(gid string) *TransGlobal { //nolint:staticcheck return &TransGlobal{TransGlobalStore: *trans} } + +type iface struct { + itab, data uintptr +} + +type valueCtx struct { + context.Context + key, value interface{} +} + +// CopyContext copy context with value and grpc metadata +// if raw context is nil, return nil +func CopyContext(ctx context.Context) context.Context { + if ctx == nil { + return ctx + } + newCtx := context.Background() + kv := make(map[interface{}]interface{}) + getKeyValues(ctx, kv) + for k, v := range kv { + newCtx = context.WithValue(newCtx, k, v) + } + if md, ok := metadata.FromIncomingContext(ctx); ok { + newCtx = metadata.NewIncomingContext(newCtx, md) + } + return newCtx +} + +func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { + rtType := reflect.TypeOf(ctx).String() + if rtType == "*context.emptyCtx" { + return + } + ictx := *(*iface)(unsafe.Pointer(&ctx)) + if ictx.data == 0 { + return + } + valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) + if valCtx != nil && valCtx.key != nil && valCtx.value != nil { + kv[valCtx.key] = valCtx.value + } + getKeyValues(valCtx.Context, kv) +} diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index 15f8635..5a78931 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -7,9 +7,11 @@ package dtmsvr import ( + "context" "testing" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" ) func TestUtils(t *testing.T) { @@ -29,3 +31,41 @@ func TestSetNextCron(t *testing.T) { tg.TimeoutToFail = 3 assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) } + +func TestCopyContext(t *testing.T) { + ctxWithValue := context.WithValue(context.Background(), "key", "value") + newCtx := CopyContext(ctxWithValue) + assert.Equal(t, ctxWithValue.Value("key"), newCtx.Value("key")) + + var ctx context.Context + newCtx = CopyContext(ctx) + assert.Nil(t, newCtx) +} + +func TestCopyContextRecursive(t *testing.T) { + ctxWithValue := context.WithValue(context.Background(), "key", "value") + nestedCtx := context.WithValue(ctxWithValue, "nested_key", "nested_value") + newCtx := CopyContext(nestedCtx) + + assert.Equal(t, nestedCtx.Value("nested_key"), newCtx.Value("nested_key")) + assert.Equal(t, nestedCtx.Value("key"), newCtx.Value("key")) +} + +func TestCopyContextWithMetadata(t *testing.T) { + md := metadata.New(map[string]string{"key": "value"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + newCtx := CopyContext(ctx) + + copiedMD, ok := metadata.FromIncomingContext(newCtx) + assert.True(t, ok) + assert.Equal(t, 1, len(copiedMD["key"])) + assert.Equal(t, "value", copiedMD["key"][0]) +} + +func BenchmarkCopyContext(b *testing.B) { + ctx := context.WithValue(context.Background(), "key", "value") + b.ResetTimer() + for i := 0; i < b.N; i++ { + CopyContext(ctx) + } +} From 55087390163d9b0c8eb94d47bf3f476d9844733c Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 00:49:36 +0800 Subject: [PATCH 03/13] feat: copy context when client call --- dtmsvr/trans_class.go | 2 +- dtmsvr/trans_status.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dtmsvr/trans_class.go b/dtmsvr/trans_class.go index 2038bdf..aad32e4 100644 --- a/dtmsvr/trans_class.go +++ b/dtmsvr/trans_class.go @@ -103,7 +103,7 @@ func TransFromDtmRequest(ctx context.Context, c *dtmgpb.DtmRequest) *TransGlobal }, }} r.ReqExtra = c.ReqExtra - r.Context = CopyContext(ctx) + r.Context = ctx if c.Steps != "" { dtmimp.MustUnmarshalString(c.Steps, &r.Steps) } diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index baa51dd..4cb2a96 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -200,7 +200,7 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo } conn := dtmgimp.MustGetGrpcConn(server, true) - ctx := dtmgimp.TransInfo2Ctx(t.Context, t.Gid, t.TransType, branchID, op, "") + ctx := dtmgimp.TransInfo2Ctx(CopyContext(t.Context), t.Gid, t.TransType, branchID, op, "") kvs := dtmgimp.Map2Kvs(t.Ext.Headers) kvs = append(kvs, dtmgimp.Map2Kvs(t.BranchHeaders)...) ctx = metadata.AppendToOutgoingContext(ctx, kvs...) From 938122af6b0491c73521450b84284e5771f65182 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 12:15:54 +0800 Subject: [PATCH 04/13] fix --- dtmsvr/utils.go | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 6fac299..fb13f2e 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -9,16 +9,13 @@ package dtmsvr import ( "context" "fmt" - "reflect" - "time" - "unsafe" - "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" "github.com/dtm-labs/dtm/dtmsvr/storage" "github.com/dtm-labs/dtm/dtmsvr/storage/registry" "github.com/lithammer/shortuuid/v3" "google.golang.org/grpc/metadata" + "time" ) type branchStatus struct { @@ -71,29 +68,9 @@ func CopyContext(ctx context.Context) context.Context { return ctx } newCtx := context.Background() - kv := make(map[interface{}]interface{}) - getKeyValues(ctx, kv) - for k, v := range kv { - newCtx = context.WithValue(newCtx, k, v) - } + // TODO: copy value in context if md, ok := metadata.FromIncomingContext(ctx); ok { newCtx = metadata.NewIncomingContext(newCtx, md) } return newCtx } - -func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { - rtType := reflect.TypeOf(ctx).String() - if rtType == "*context.emptyCtx" { - return - } - ictx := *(*iface)(unsafe.Pointer(&ctx)) - if ictx.data == 0 { - return - } - valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) - if valCtx != nil && valCtx.key != nil && valCtx.value != nil { - kv[valCtx.key] = valCtx.value - } - getKeyValues(valCtx.Context, kv) -} From c63e778aa9a89699ef47b4624e5b607c9c94ca98 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 12:16:30 +0800 Subject: [PATCH 05/13] fix --- dtmsvr/utils.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index fb13f2e..8e6a8ba 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -9,13 +9,14 @@ package dtmsvr import ( "context" "fmt" + "time" + "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" "github.com/dtm-labs/dtm/dtmsvr/storage" "github.com/dtm-labs/dtm/dtmsvr/storage/registry" "github.com/lithammer/shortuuid/v3" "google.golang.org/grpc/metadata" - "time" ) type branchStatus struct { From 5c79adf760753c9dd9e9d41e976b23a8d4fd3fe0 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 12:21:12 +0800 Subject: [PATCH 06/13] fix --- dtmsvr/utils.go | 27 ++++++++++++++++++++++++++- dtmsvr/utils_test.go | 5 +++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 8e6a8ba..b1ae5e0 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -9,7 +9,9 @@ package dtmsvr import ( "context" "fmt" + "reflect" "time" + "unsafe" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" @@ -69,9 +71,32 @@ func CopyContext(ctx context.Context) context.Context { return ctx } newCtx := context.Background() - // TODO: copy value in context + kv := make(map[interface{}]interface{}) + getKeyValues(ctx, kv) + for k, v := range kv { + newCtx = context.WithValue(newCtx, k, v) + } if md, ok := metadata.FromIncomingContext(ctx); ok { newCtx = metadata.NewIncomingContext(newCtx, md) } + if md, ok := metadata.FromOutgoingContext(ctx); ok { + newCtx = metadata.NewOutgoingContext(newCtx, md) + } return newCtx } + +func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { + rtType := reflect.TypeOf(ctx).String() + if rtType == "*context.emptyCtx" || rtType == "*context.timerCtx" { + return + } + ictx := *(*iface)(unsafe.Pointer(&ctx)) + if ictx.data == 0 { + return + } + valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) + if valCtx != nil && valCtx.key != nil && valCtx.value != nil { + kv[valCtx.key] = valCtx.value + } + getKeyValues(valCtx.Context, kv) +} diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index 5a78931..d91f7a9 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -54,12 +54,17 @@ func TestCopyContextRecursive(t *testing.T) { func TestCopyContextWithMetadata(t *testing.T) { md := metadata.New(map[string]string{"key": "value"}) ctx := metadata.NewIncomingContext(context.Background(), md) + ctx = metadata.NewOutgoingContext(ctx, md) newCtx := CopyContext(ctx) copiedMD, ok := metadata.FromIncomingContext(newCtx) assert.True(t, ok) assert.Equal(t, 1, len(copiedMD["key"])) assert.Equal(t, "value", copiedMD["key"][0]) + copiedMD, ok = metadata.FromOutgoingContext(newCtx) + assert.True(t, ok) + assert.Equal(t, 1, len(copiedMD["key"])) + assert.Equal(t, "value", copiedMD["key"][0]) } func BenchmarkCopyContext(b *testing.B) { From dfcfaae8688fe83284419b840285dc2d51217111 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 28 Aug 2023 12:52:06 +0800 Subject: [PATCH 07/13] fix --- dtmsvr/utils.go | 39 ++++++++++++++++++++++++++++++++++++--- dtmsvr/utils_test.go | 24 +++++++++++++++++------- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index b1ae5e0..ac60f3c 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -61,7 +61,35 @@ type iface struct { type valueCtx struct { context.Context - key, value interface{} + key, value any +} + +type cancelCtx struct { + context.Context +} + +type timerCtx struct { + cancelCtx *cancelCtx +} + +func (*timerCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*timerCtx) Done() <-chan struct{} { + return nil +} + +func (*timerCtx) Err() error { + return nil +} + +func (*timerCtx) Value(key any) any { + return nil +} + +func (e *timerCtx) String() string { + return "" } // CopyContext copy context with value and grpc metadata @@ -87,7 +115,7 @@ func CopyContext(ctx context.Context) context.Context { func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { rtType := reflect.TypeOf(ctx).String() - if rtType == "*context.emptyCtx" || rtType == "*context.timerCtx" { + if rtType == "*context.emptyCtx" { return } ictx := *(*iface)(unsafe.Pointer(&ctx)) @@ -95,8 +123,13 @@ func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { return } valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) - if valCtx != nil && valCtx.key != nil && valCtx.value != nil { + if valCtx.key != nil && valCtx.value != nil && rtType == "*context.valueCtx" { kv[valCtx.key] = valCtx.value } + if rtType == "*context.timerCtx" { + tCtx := (*timerCtx)(unsafe.Pointer(ictx.data)) + getKeyValues(tCtx.cancelCtx, kv) + return + } getKeyValues(valCtx.Context, kv) } diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index d91f7a9..cd3007f 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -32,10 +32,14 @@ func TestSetNextCron(t *testing.T) { assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) } +type testContextType string + func TestCopyContext(t *testing.T) { - ctxWithValue := context.WithValue(context.Background(), "key", "value") + var key testContextType = "key" + var value testContextType = "value" + ctxWithValue := context.WithValue(context.Background(), key, value) newCtx := CopyContext(ctxWithValue) - assert.Equal(t, ctxWithValue.Value("key"), newCtx.Value("key")) + assert.Equal(t, ctxWithValue.Value(key), newCtx.Value(key)) var ctx context.Context newCtx = CopyContext(ctx) @@ -43,12 +47,16 @@ func TestCopyContext(t *testing.T) { } func TestCopyContextRecursive(t *testing.T) { - ctxWithValue := context.WithValue(context.Background(), "key", "value") - nestedCtx := context.WithValue(ctxWithValue, "nested_key", "nested_value") + var key testContextType = "key" + var value testContextType = "value" + var nestedKey testContextType = "nested_key" + var nestedValue testContextType = "nested_value" + ctxWithValue := context.WithValue(context.Background(), key, value) + nestedCtx := context.WithValue(ctxWithValue, nestedKey, nestedValue) newCtx := CopyContext(nestedCtx) - assert.Equal(t, nestedCtx.Value("nested_key"), newCtx.Value("nested_key")) - assert.Equal(t, nestedCtx.Value("key"), newCtx.Value("key")) + assert.Equal(t, nestedCtx.Value(nestedKey), newCtx.Value(nestedKey)) + assert.Equal(t, nestedCtx.Value(key), newCtx.Value(key)) } func TestCopyContextWithMetadata(t *testing.T) { @@ -68,7 +76,9 @@ func TestCopyContextWithMetadata(t *testing.T) { } func BenchmarkCopyContext(b *testing.B) { - ctx := context.WithValue(context.Background(), "key", "value") + var key testContextType = "key" + var value testContextType = "value" + ctx := context.WithValue(context.Background(), key, value) b.ResetTimer() for i := 0; i < b.N; i++ { CopyContext(ctx) From 2f4e60c970a4f9d82fa08d1353d92f88509e7996 Mon Sep 17 00:00:00 2001 From: Makonike Date: Tue, 29 Aug 2023 01:15:16 +0800 Subject: [PATCH 08/13] fix --- dtmsvr/trans_process.go | 7 ++++--- dtmsvr/trans_status.go | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dtmsvr/trans_process.go b/dtmsvr/trans_process.go index 887a0b1..789f3d0 100644 --- a/dtmsvr/trans_process.go +++ b/dtmsvr/trans_process.go @@ -7,6 +7,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "time" @@ -31,14 +32,14 @@ func (t *TransGlobal) process(branches []TransBranch) error { if t.ExtData != "" { dtmimp.MustUnmarshalString(t.ExtData, &t.Ext) } - if !t.WaitResult { - go func() { + go func(ctx context.Context) { + t.Context = CopyContext(ctx) err := t.processInner(branches) if err != nil && !errors.Is(err, dtmimp.ErrOngoing) { logger.Errorf("processInner err: %v", err) } - }() + }(t.Context) return nil } submitting := t.Status == dtmcli.StatusSubmitted diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index 4cb2a96..baa51dd 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -200,7 +200,7 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo } conn := dtmgimp.MustGetGrpcConn(server, true) - ctx := dtmgimp.TransInfo2Ctx(CopyContext(t.Context), t.Gid, t.TransType, branchID, op, "") + ctx := dtmgimp.TransInfo2Ctx(t.Context, t.Gid, t.TransType, branchID, op, "") kvs := dtmgimp.Map2Kvs(t.Ext.Headers) kvs = append(kvs, dtmgimp.Map2Kvs(t.BranchHeaders)...) ctx = metadata.AppendToOutgoingContext(ctx, kvs...) From 56796afdbbae65b3bd48b180ffa2591406fb04fd Mon Sep 17 00:00:00 2001 From: Makonike Date: Sun, 3 Sep 2023 18:31:38 +0800 Subject: [PATCH 09/13] fix --- dtmsvr/trans_class.go | 2 +- dtmsvr/trans_process.go | 12 ++++++------ dtmsvr/trans_status.go | 17 +++++++++-------- dtmsvr/trans_type_msg.go | 18 ++++++++++-------- dtmsvr/trans_type_saga.go | 13 ++++++++----- dtmsvr/trans_type_tcc.go | 5 +++-- dtmsvr/trans_type_workflow.go | 5 +++-- dtmsvr/trans_type_xa.go | 5 +++-- 8 files changed, 43 insertions(+), 34 deletions(-) diff --git a/dtmsvr/trans_class.go b/dtmsvr/trans_class.go index aad32e4..999c5b3 100644 --- a/dtmsvr/trans_class.go +++ b/dtmsvr/trans_class.go @@ -42,7 +42,7 @@ type TransBranch = storage.TransBranchStore type transProcessor interface { GenBranches() []TransBranch - ProcessOnce(branches []TransBranch) error + ProcessOnce(ctx context.Context, branches []TransBranch) error } type processorCreator func(*TransGlobal) transProcessor diff --git a/dtmsvr/trans_process.go b/dtmsvr/trans_process.go index 789f3d0..d39898b 100644 --- a/dtmsvr/trans_process.go +++ b/dtmsvr/trans_process.go @@ -33,17 +33,17 @@ func (t *TransGlobal) process(branches []TransBranch) error { dtmimp.MustUnmarshalString(t.ExtData, &t.Ext) } if !t.WaitResult { + ctx := CopyContext(t.Context) go func(ctx context.Context) { - t.Context = CopyContext(ctx) - err := t.processInner(branches) + err := t.processInner(ctx, branches) if err != nil && !errors.Is(err, dtmimp.ErrOngoing) { logger.Errorf("processInner err: %v", err) } - }(t.Context) + }(ctx) return nil } submitting := t.Status == dtmcli.StatusSubmitted - err := t.processInner(branches) + err := t.processInner(t.Context, branches) if err != nil { return err } @@ -57,7 +57,7 @@ func (t *TransGlobal) process(branches []TransBranch) error { return nil } -func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) { +func (t *TransGlobal) processInner(ctx context.Context, branches []TransBranch) (rerr error) { defer handlePanic(&rerr) defer func() { if rerr != nil && !errors.Is(rerr, dtmcli.ErrOngoing) { @@ -71,7 +71,7 @@ func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) { }() logger.Debugf("processing: %s status: %s", t.Gid, t.Status) t.lastTouched = time.Now() - rerr = t.getProcessor().ProcessOnce(branches) + rerr = t.getProcessor().ProcessOnce(ctx, branches) return } diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index baa51dd..9879b87 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "math" @@ -127,7 +128,7 @@ func (t *TransGlobal) needProcess() bool { return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout() } -func (t *TransGlobal) getURLResult(uri string, branchID, op string, branchPayload []byte) error { +func (t *TransGlobal) getURLResult(ctx context.Context, uri string, branchID, op string, branchPayload []byte) error { if uri == "" { // empty url is success return nil } @@ -137,7 +138,7 @@ func (t *TransGlobal) getURLResult(uri string, branchID, op string, branchPayloa } return t.getHTTPResult(uri, branchID, op, branchPayload) } - return t.getGrpcResult(uri, branchID, op, branchPayload) + return t.getGrpcResult(ctx, uri, branchID, op, branchPayload) } func (t *TransGlobal) getHTTPResult(uri string, branchID, op string, branchPayload []byte) error { @@ -192,7 +193,7 @@ func (t *TransGlobal) getJSONRPCResult(uri string, branchID, op string, branchPa return err } -func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPayload []byte) error { +func (t *TransGlobal) getGrpcResult(ctx context.Context, uri string, branchID, op string, branchPayload []byte) error { // grpc handler server, method, err := dtmdriver.GetDriver().ParseServerMethod(uri) if err != nil { @@ -200,7 +201,7 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo } conn := dtmgimp.MustGetGrpcConn(server, true) - ctx := dtmgimp.TransInfo2Ctx(t.Context, t.Gid, t.TransType, branchID, op, "") + ctx = dtmgimp.TransInfo2Ctx(ctx, t.Gid, t.TransType, branchID, op, "") kvs := dtmgimp.Map2Kvs(t.Ext.Headers) kvs = append(kvs, dtmgimp.Map2Kvs(t.BranchHeaders)...) ctx = metadata.AppendToOutgoingContext(ctx, kvs...) @@ -212,8 +213,8 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo return dtmgrpc.GrpcError2DtmError(err) } -func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { - err := t.getURLResult(branch.URL, branch.BranchID, branch.Op, branch.BinData) +func (t *TransGlobal) getBranchResult(ctx context.Context, branch *TransBranch) (string, error) { + err := t.getURLResult(ctx, branch.URL, branch.BranchID, branch.Op, branch.BinData) if err == nil { return dtmcli.StatusSucceed, nil } else if t.TransType == "saga" && branch.Op == dtmimp.OpAction && errors.Is(err, dtmcli.ErrFailure) { @@ -225,8 +226,8 @@ func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { return "", fmt.Errorf("your http/grpc result should be specified as in:\nhttp://d.dtm.pub/practice/arch.html#proto\nunkown result will be retried: %w", err) } -func (t *TransGlobal) execBranch(branch *TransBranch, branchPos int) error { - status, err := t.getBranchResult(branch) +func (t *TransGlobal) execBranch(ctx context.Context, branch *TransBranch, branchPos int) error { + status, err := t.getBranchResult(ctx, branch) if status != "" { t.changeBranchStatus(branch, status, branchPos) } diff --git a/dtmsvr/trans_type_msg.go b/dtmsvr/trans_type_msg.go index 8037150..0def41d 100644 --- a/dtmsvr/trans_type_msg.go +++ b/dtmsvr/trans_type_msg.go @@ -7,6 +7,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "strings" @@ -51,11 +52,11 @@ type cMsgCustom struct { Delay uint64 //delay call branch, unit second } -func (t *TransGlobal) mayQueryPrepared() { +func (t *TransGlobal) mayQueryPrepared(ctx context.Context) { if !t.needProcess() || t.Status == dtmcli.StatusSubmitted { return } - err := t.getURLResult(t.QueryPrepared, "00", "msg", nil) + err := t.getURLResult(ctx, t.QueryPrepared, "00", "msg", nil) if err == nil { t.changeStatus(dtmcli.StatusSubmitted) } else if errors.Is(err, dtmcli.ErrFailure) { @@ -68,8 +69,8 @@ func (t *TransGlobal) mayQueryPrepared() { } } -func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { - t.mayQueryPrepared() +func (t *transMsgProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { + t.mayQueryPrepared(ctx) if !t.needProcess() || t.Status == dtmcli.StatusPrepared { return nil } @@ -91,12 +92,13 @@ func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { continue } if t.Concurrent { + copyCtx := CopyContext(ctx) started++ - go func(pos int) { - resultsChan <- t.execBranch(b, pos) - }(i) + go func(ctx context.Context, pos int) { + resultsChan <- t.execBranch(ctx, b, pos) + }(copyCtx, i) } else { - err = t.execBranch(b, i) + err = t.execBranch(ctx, b, i) if err != nil { break } diff --git a/dtmsvr/trans_type_saga.go b/dtmsvr/trans_type_saga.go index 44c8667..cd60a56 100644 --- a/dtmsvr/trans_type_saga.go +++ b/dtmsvr/trans_type_saga.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "time" @@ -52,7 +53,7 @@ type branchResult struct { err error } -func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transSagaProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { // when saga tasks is fetched, it always need to process logger.Debugf("status: %s timeout: %t", t.Status, t.isTimeout()) if t.Status == dtmcli.StatusSubmitted && t.isTimeout() { @@ -121,7 +122,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { return true } resultChan := make(chan branchResult, n) - asyncExecBranch := func(i int) { + asyncExecBranch := func(ctx context.Context, i int) { var err error defer func() { if x := recover(); x != nil { @@ -132,7 +133,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { logger.Errorf("exec branch %s %s %s error: %v", branches[i].BranchID, branches[i].Op, branches[i].URL, err) } }() - err = t.execBranch(&branches[i], i) + err = t.execBranch(ctx, &branches[i], i) } pickToRunActions := func() []int { toRun := []int{} @@ -162,7 +163,8 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { if branchResults[b].op == dtmimp.OpAction { rsAStarted++ } - go asyncExecBranch(b) + copyCtx := CopyContext(ctx) + go asyncExecBranch(copyCtx, b) } } waitDoneOnce := func() { @@ -178,7 +180,8 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { t.RetryCount++ logger.Infof("Retrying branch %s %s %s, t.RetryLimit: %d, t.RetryCount: %d", branches[r.index].BranchID, branches[r.index].Op, branches[r.index].URL, t.RetryLimit, t.RetryCount) - go asyncExecBranch(r.index) + copyCtx := CopyContext(ctx) + go asyncExecBranch(copyCtx, r.index) break } // if t.RetryCount = t.RetryLimit, trans will be aborted diff --git a/dtmsvr/trans_type_tcc.go b/dtmsvr/trans_type_tcc.go index 767f17b..578e5a9 100644 --- a/dtmsvr/trans_type_tcc.go +++ b/dtmsvr/trans_type_tcc.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "fmt" "github.com/dtm-labs/dtm/client/dtmcli" @@ -20,7 +21,7 @@ func (t *transTccProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transTccProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { if !t.needProcess() { return nil } @@ -31,7 +32,7 @@ func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error { for current := len(branches) - 1; current >= 0; current-- { if branches[current].Op == op && branches[current].Status == dtmcli.StatusPrepared { logger.Debugf("branch info: current: %d ID: %d", current, branches[current].ID) - err := t.execBranch(&branches[current], current) + err := t.execBranch(ctx, &branches[current], current) if err != nil { return err } diff --git a/dtmsvr/trans_type_workflow.go b/dtmsvr/trans_type_workflow.go index a350081..a57f28b 100644 --- a/dtmsvr/trans_type_workflow.go +++ b/dtmsvr/trans_type_workflow.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" @@ -24,7 +25,7 @@ type cWorkflowCustom struct { Data []byte `json:"data"` } -func (t *transWorkflowProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transWorkflowProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed { return nil } @@ -36,5 +37,5 @@ func (t *transWorkflowProcessor) ProcessOnce(branches []TransBranch) error { wd := wfpb.WorkflowData{Data: cmc.Data} data = dtmgimp.MustProtoMarshal(&wd) } - return t.getURLResult(t.QueryPrepared, "00", cmc.Name, data) + return t.getURLResult(ctx, t.QueryPrepared, "00", cmc.Name, data) } diff --git a/dtmsvr/trans_type_xa.go b/dtmsvr/trans_type_xa.go index fdb7a01..3e63949 100644 --- a/dtmsvr/trans_type_xa.go +++ b/dtmsvr/trans_type_xa.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "fmt" "github.com/dtm-labs/dtm/client/dtmcli" @@ -19,7 +20,7 @@ func (t *transXaProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transXaProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { if !t.needProcess() { return nil } @@ -29,7 +30,7 @@ func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error { currentType := dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmimp.OpCommit, dtmimp.OpRollback).(string) for i, branch := range branches { if branch.Op == currentType && branch.Status != dtmcli.StatusSucceed { - err := t.execBranch(&branch, i) + err := t.execBranch(ctx, &branch, i) if err != nil { return err } From edf7a75045f83189112bf8d666e7116a7d756be7 Mon Sep 17 00:00:00 2001 From: Makonike Date: Mon, 4 Sep 2023 14:23:12 +0800 Subject: [PATCH 10/13] fix --- dtmsvr/utils.go | 29 +---------------------------- dtmsvr/utils_test.go | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 33 deletions(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index ac60f3c..16f28f2 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -18,7 +18,6 @@ import ( "github.com/dtm-labs/dtm/dtmsvr/storage" "github.com/dtm-labs/dtm/dtmsvr/storage/registry" "github.com/lithammer/shortuuid/v3" - "google.golang.org/grpc/metadata" ) type branchStatus struct { @@ -69,27 +68,7 @@ type cancelCtx struct { } type timerCtx struct { - cancelCtx *cancelCtx -} - -func (*timerCtx) Deadline() (deadline time.Time, ok bool) { - return -} - -func (*timerCtx) Done() <-chan struct{} { - return nil -} - -func (*timerCtx) Err() error { - return nil -} - -func (*timerCtx) Value(key any) any { - return nil -} - -func (e *timerCtx) String() string { - return "" + cancelCtx } // CopyContext copy context with value and grpc metadata @@ -104,12 +83,6 @@ func CopyContext(ctx context.Context) context.Context { for k, v := range kv { newCtx = context.WithValue(newCtx, k, v) } - if md, ok := metadata.FromIncomingContext(ctx); ok { - newCtx = metadata.NewIncomingContext(newCtx, md) - } - if md, ok := metadata.FromOutgoingContext(ctx); ok { - newCtx = metadata.NewOutgoingContext(newCtx, md) - } return newCtx } diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index cd3007f..3349ad0 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -9,6 +9,7 @@ package dtmsvr import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" "google.golang.org/grpc/metadata" @@ -32,8 +33,6 @@ func TestSetNextCron(t *testing.T) { assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) } -type testContextType string - func TestCopyContext(t *testing.T) { var key testContextType = "key" var value testContextType = "value" @@ -46,17 +45,31 @@ func TestCopyContext(t *testing.T) { assert.Nil(t, newCtx) } +type testContextType string + func TestCopyContextRecursive(t *testing.T) { var key testContextType = "key" + var key2 testContextType = "key2" + var key3 testContextType = "key3" var value testContextType = "value" + var value2 testContextType = "value2" + var value3 testContextType = "value3" var nestedKey testContextType = "nested_key" var nestedValue testContextType = "nested_value" ctxWithValue := context.WithValue(context.Background(), key, value) nestedCtx := context.WithValue(ctxWithValue, nestedKey, nestedValue) - newCtx := CopyContext(nestedCtx) + cancelCtxx, cancel := context.WithCancel(nestedCtx) + defer cancel() + timerCtxx, cancel2 := context.WithTimeout(cancelCtxx, time.Duration(10)*time.Second) + defer cancel2() + timer2 := context.WithValue(timerCtxx, key2, value2) + timer3 := context.WithValue(timer2, key3, value3) + newCtx := CopyContext(timer3) - assert.Equal(t, nestedCtx.Value(nestedKey), newCtx.Value(nestedKey)) - assert.Equal(t, nestedCtx.Value(key), newCtx.Value(key)) + assert.Equal(t, timer3.Value(nestedKey), newCtx.Value(nestedKey)) + assert.Equal(t, timer3.Value(key), newCtx.Value(key)) + assert.Equal(t, timer3.Value(key2), newCtx.Value(key2)) + assert.Equal(t, timer3.Value(key3), newCtx.Value(key3)) } func TestCopyContextWithMetadata(t *testing.T) { From 2726af38fd9abc3785795e76e3c6a27546529a1b Mon Sep 17 00:00:00 2001 From: Makonike Date: Tue, 5 Sep 2023 17:58:44 +0800 Subject: [PATCH 11/13] fix: using async context instead of CopyContext --- dtmsvr/trans_process.go | 2 +- dtmsvr/trans_type_msg.go | 2 +- dtmsvr/trans_type_saga.go | 4 +-- dtmsvr/utils.go | 58 +++++++++++---------------------------- dtmsvr/utils_test.go | 33 +++++++++++++++++----- test/saga_grpc_test.go | 3 +- 6 files changed, 48 insertions(+), 54 deletions(-) diff --git a/dtmsvr/trans_process.go b/dtmsvr/trans_process.go index d39898b..267b5d6 100644 --- a/dtmsvr/trans_process.go +++ b/dtmsvr/trans_process.go @@ -33,7 +33,7 @@ func (t *TransGlobal) process(branches []TransBranch) error { dtmimp.MustUnmarshalString(t.ExtData, &t.Ext) } if !t.WaitResult { - ctx := CopyContext(t.Context) + ctx := NewAsyncContext(t.Context) go func(ctx context.Context) { err := t.processInner(ctx, branches) if err != nil && !errors.Is(err, dtmimp.ErrOngoing) { diff --git a/dtmsvr/trans_type_msg.go b/dtmsvr/trans_type_msg.go index 0def41d..016d7f5 100644 --- a/dtmsvr/trans_type_msg.go +++ b/dtmsvr/trans_type_msg.go @@ -92,7 +92,7 @@ func (t *transMsgProcessor) ProcessOnce(ctx context.Context, branches []TransBra continue } if t.Concurrent { - copyCtx := CopyContext(ctx) + copyCtx := NewAsyncContext(ctx) started++ go func(ctx context.Context, pos int) { resultsChan <- t.execBranch(ctx, b, pos) diff --git a/dtmsvr/trans_type_saga.go b/dtmsvr/trans_type_saga.go index cd60a56..e984b27 100644 --- a/dtmsvr/trans_type_saga.go +++ b/dtmsvr/trans_type_saga.go @@ -163,7 +163,7 @@ func (t *transSagaProcessor) ProcessOnce(ctx context.Context, branches []TransBr if branchResults[b].op == dtmimp.OpAction { rsAStarted++ } - copyCtx := CopyContext(ctx) + copyCtx := NewAsyncContext(ctx) go asyncExecBranch(copyCtx, b) } } @@ -180,7 +180,7 @@ func (t *transSagaProcessor) ProcessOnce(ctx context.Context, branches []TransBr t.RetryCount++ logger.Infof("Retrying branch %s %s %s, t.RetryLimit: %d, t.RetryCount: %d", branches[r.index].BranchID, branches[r.index].Op, branches[r.index].URL, t.RetryLimit, t.RetryCount) - copyCtx := CopyContext(ctx) + copyCtx := NewAsyncContext(ctx) go asyncExecBranch(copyCtx, r.index) break } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 16f28f2..fdf7d81 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -9,9 +9,7 @@ package dtmsvr import ( "context" "fmt" - "reflect" "time" - "unsafe" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmsvr/config" @@ -54,55 +52,31 @@ func GetTransGlobal(gid string) *TransGlobal { return &TransGlobal{TransGlobalStore: *trans} } -type iface struct { - itab, data uintptr +type asyncCtx struct { + parent context.Context } -type valueCtx struct { - context.Context - key, value any +func (a *asyncCtx) Deadline() (deadline time.Time, ok bool) { + return } -type cancelCtx struct { - context.Context +func (a *asyncCtx) Done() <-chan struct{} { + return nil } -type timerCtx struct { - cancelCtx +func (a *asyncCtx) Err() error { + return a.parent.Err() } -// CopyContext copy context with value and grpc metadata -// if raw context is nil, return nil -func CopyContext(ctx context.Context) context.Context { - if ctx == nil { - return ctx - } - newCtx := context.Background() - kv := make(map[interface{}]interface{}) - getKeyValues(ctx, kv) - for k, v := range kv { - newCtx = context.WithValue(newCtx, k, v) - } - return newCtx +func (a *asyncCtx) Value(key any) any { + return a.parent.Value(key) } -func getKeyValues(ctx context.Context, kv map[interface{}]interface{}) { - rtType := reflect.TypeOf(ctx).String() - if rtType == "*context.emptyCtx" { - return - } - ictx := *(*iface)(unsafe.Pointer(&ctx)) - if ictx.data == 0 { - return - } - valCtx := (*valueCtx)(unsafe.Pointer(ictx.data)) - if valCtx.key != nil && valCtx.value != nil && rtType == "*context.valueCtx" { - kv[valCtx.key] = valCtx.value - } - if rtType == "*context.timerCtx" { - tCtx := (*timerCtx)(unsafe.Pointer(ictx.data)) - getKeyValues(tCtx.cancelCtx, kv) - return +// NewAsyncContext create a new async context +// the context will not be canceled when the parent context is canceled +func NewAsyncContext(ctx context.Context) context.Context { + if ctx == nil { + return nil } - getKeyValues(valCtx.Context, kv) + return &asyncCtx{parent: ctx} } diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index 3349ad0..060c1c2 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -33,21 +33,40 @@ func TestSetNextCron(t *testing.T) { assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset)) } -func TestCopyContext(t *testing.T) { +func TestNewAsyncContext(t *testing.T) { var key testContextType = "key" var value testContextType = "value" ctxWithValue := context.WithValue(context.Background(), key, value) - newCtx := CopyContext(ctxWithValue) + newCtx := NewAsyncContext(ctxWithValue) assert.Equal(t, ctxWithValue.Value(key), newCtx.Value(key)) var ctx context.Context - newCtx = CopyContext(ctx) + newCtx = NewAsyncContext(ctx) assert.Nil(t, newCtx) } +func TestAsyncContext(t *testing.T) { + ctx := context.Background() + cancelCtx2, cancel := context.WithCancel(ctx) + async := NewAsyncContext(cancelCtx2) + cancelCtx3, cancel2 := context.WithCancel(async) + defer cancel2() + cancel() + select { + case <-cancelCtx2.Done(): + default: + assert.Failf(t, "context should be canceled", "context should be canceled") + } + select { + case <-cancelCtx3.Done(): + assert.Failf(t, "context should not be canceled", "context should not be canceled") + default: + } +} + type testContextType string -func TestCopyContextRecursive(t *testing.T) { +func TestAsyncContextRecursive(t *testing.T) { var key testContextType = "key" var key2 testContextType = "key2" var key3 testContextType = "key3" @@ -64,7 +83,7 @@ func TestCopyContextRecursive(t *testing.T) { defer cancel2() timer2 := context.WithValue(timerCtxx, key2, value2) timer3 := context.WithValue(timer2, key3, value3) - newCtx := CopyContext(timer3) + newCtx := NewAsyncContext(timer3) assert.Equal(t, timer3.Value(nestedKey), newCtx.Value(nestedKey)) assert.Equal(t, timer3.Value(key), newCtx.Value(key)) @@ -76,7 +95,7 @@ func TestCopyContextWithMetadata(t *testing.T) { md := metadata.New(map[string]string{"key": "value"}) ctx := metadata.NewIncomingContext(context.Background(), md) ctx = metadata.NewOutgoingContext(ctx, md) - newCtx := CopyContext(ctx) + newCtx := NewAsyncContext(ctx) copiedMD, ok := metadata.FromIncomingContext(newCtx) assert.True(t, ok) @@ -94,6 +113,6 @@ func BenchmarkCopyContext(b *testing.B) { ctx := context.WithValue(context.Background(), key, value) b.ResetTimer() for i := 0; i < b.N; i++ { - CopyContext(ctx) + NewAsyncContext(ctx) } } diff --git a/test/saga_grpc_test.go b/test/saga_grpc_test.go index 17806a3..e97a8d7 100644 --- a/test/saga_grpc_test.go +++ b/test/saga_grpc_test.go @@ -7,6 +7,7 @@ package test import ( + "context" "testing" "github.com/dtm-labs/dtm/client/dtmcli" @@ -94,7 +95,7 @@ func TestSagaGrpcEmptyUrl(t *testing.T) { // nolint: unparam func genSagaGrpc(gid string, outFailed bool, inFailed bool) *dtmgrpc.SagaGrpc { - saga := dtmgrpc.NewSagaGrpc(dtmutil.DefaultGrpcServer, gid) + saga := dtmgrpc.NewSagaGrpcWithContext(context.Background(), dtmutil.DefaultGrpcServer, gid) req := busi.GenReqGrpc(30, outFailed, inFailed) saga.Add(busi.BusiGrpc+"/busi.Busi/TransOut", busi.BusiGrpc+"/busi.Busi/TransOutRevert", req) saga.Add(busi.BusiGrpc+"/busi.Busi/TransIn", busi.BusiGrpc+"/busi.Busi/TransInRevert", req) From 34370e192b20dd3b31d82b873c6217fd44ecb599 Mon Sep 17 00:00:00 2001 From: Makonike Date: Tue, 5 Sep 2023 18:04:46 +0800 Subject: [PATCH 12/13] style: lint --- client/dtmcli/trans_saga.go | 1 + client/dtmgrpc/saga.go | 1 + dtmsvr/trans_type_workflow.go | 1 + 3 files changed, 3 insertions(+) diff --git a/client/dtmcli/trans_saga.go b/client/dtmcli/trans_saga.go index db79b4f..927a660 100644 --- a/client/dtmcli/trans_saga.go +++ b/client/dtmcli/trans_saga.go @@ -8,6 +8,7 @@ package dtmcli import ( "context" + "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" ) diff --git a/client/dtmgrpc/saga.go b/client/dtmgrpc/saga.go index 59766ef..0d6aaf1 100644 --- a/client/dtmgrpc/saga.go +++ b/client/dtmgrpc/saga.go @@ -8,6 +8,7 @@ package dtmgrpc import ( "context" + "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" "google.golang.org/protobuf/proto" diff --git a/dtmsvr/trans_type_workflow.go b/dtmsvr/trans_type_workflow.go index a57f28b..925323c 100644 --- a/dtmsvr/trans_type_workflow.go +++ b/dtmsvr/trans_type_workflow.go @@ -2,6 +2,7 @@ package dtmsvr import ( "context" + "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" From 035f0db6981f9db2aa515268fafdae6b7b90138b Mon Sep 17 00:00:00 2001 From: Makonike Date: Tue, 5 Sep 2023 20:09:45 +0800 Subject: [PATCH 13/13] fix: use anonymous member --- dtmsvr/utils.go | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index fdf7d81..d83e092 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -53,7 +53,7 @@ func GetTransGlobal(gid string) *TransGlobal { } type asyncCtx struct { - parent context.Context + context.Context } func (a *asyncCtx) Deadline() (deadline time.Time, ok bool) { @@ -64,19 +64,11 @@ func (a *asyncCtx) Done() <-chan struct{} { return nil } -func (a *asyncCtx) Err() error { - return a.parent.Err() -} - -func (a *asyncCtx) Value(key any) any { - return a.parent.Value(key) -} - // NewAsyncContext create a new async context // the context will not be canceled when the parent context is canceled func NewAsyncContext(ctx context.Context) context.Context { if ctx == nil { return nil } - return &asyncCtx{parent: ctx} + return &asyncCtx{Context: ctx} }