diff --git a/dtmsvr/trans_class.go b/dtmsvr/trans_class.go index aad32e4..999c5b3 100644 --- a/dtmsvr/trans_class.go +++ b/dtmsvr/trans_class.go @@ -42,7 +42,7 @@ type TransBranch = storage.TransBranchStore type transProcessor interface { GenBranches() []TransBranch - ProcessOnce(branches []TransBranch) error + ProcessOnce(ctx context.Context, branches []TransBranch) error } type processorCreator func(*TransGlobal) transProcessor diff --git a/dtmsvr/trans_process.go b/dtmsvr/trans_process.go index 789f3d0..d39898b 100644 --- a/dtmsvr/trans_process.go +++ b/dtmsvr/trans_process.go @@ -33,17 +33,17 @@ func (t *TransGlobal) process(branches []TransBranch) error { dtmimp.MustUnmarshalString(t.ExtData, &t.Ext) } if !t.WaitResult { + ctx := CopyContext(t.Context) go func(ctx context.Context) { - t.Context = CopyContext(ctx) - err := t.processInner(branches) + err := t.processInner(ctx, branches) if err != nil && !errors.Is(err, dtmimp.ErrOngoing) { logger.Errorf("processInner err: %v", err) } - }(t.Context) + }(ctx) return nil } submitting := t.Status == dtmcli.StatusSubmitted - err := t.processInner(branches) + err := t.processInner(t.Context, branches) if err != nil { return err } @@ -57,7 +57,7 @@ func (t *TransGlobal) process(branches []TransBranch) error { return nil } -func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) { +func (t *TransGlobal) processInner(ctx context.Context, branches []TransBranch) (rerr error) { defer handlePanic(&rerr) defer func() { if rerr != nil && !errors.Is(rerr, dtmcli.ErrOngoing) { @@ -71,7 +71,7 @@ func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) { }() logger.Debugf("processing: %s status: %s", t.Gid, t.Status) t.lastTouched = time.Now() - rerr = t.getProcessor().ProcessOnce(branches) + rerr = t.getProcessor().ProcessOnce(ctx, branches) return } diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index baa51dd..9879b87 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "math" @@ -127,7 +128,7 @@ func (t *TransGlobal) needProcess() bool { return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout() } -func (t *TransGlobal) getURLResult(uri string, branchID, op string, branchPayload []byte) error { +func (t *TransGlobal) getURLResult(ctx context.Context, uri string, branchID, op string, branchPayload []byte) error { if uri == "" { // empty url is success return nil } @@ -137,7 +138,7 @@ func (t *TransGlobal) getURLResult(uri string, branchID, op string, branchPayloa } return t.getHTTPResult(uri, branchID, op, branchPayload) } - return t.getGrpcResult(uri, branchID, op, branchPayload) + return t.getGrpcResult(ctx, uri, branchID, op, branchPayload) } func (t *TransGlobal) getHTTPResult(uri string, branchID, op string, branchPayload []byte) error { @@ -192,7 +193,7 @@ func (t *TransGlobal) getJSONRPCResult(uri string, branchID, op string, branchPa return err } -func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPayload []byte) error { +func (t *TransGlobal) getGrpcResult(ctx context.Context, uri string, branchID, op string, branchPayload []byte) error { // grpc handler server, method, err := dtmdriver.GetDriver().ParseServerMethod(uri) if err != nil { @@ -200,7 +201,7 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo } conn := dtmgimp.MustGetGrpcConn(server, true) - ctx := dtmgimp.TransInfo2Ctx(t.Context, t.Gid, t.TransType, branchID, op, "") + ctx = dtmgimp.TransInfo2Ctx(ctx, t.Gid, t.TransType, branchID, op, "") kvs := dtmgimp.Map2Kvs(t.Ext.Headers) kvs = append(kvs, dtmgimp.Map2Kvs(t.BranchHeaders)...) ctx = metadata.AppendToOutgoingContext(ctx, kvs...) @@ -212,8 +213,8 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo return dtmgrpc.GrpcError2DtmError(err) } -func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { - err := t.getURLResult(branch.URL, branch.BranchID, branch.Op, branch.BinData) +func (t *TransGlobal) getBranchResult(ctx context.Context, branch *TransBranch) (string, error) { + err := t.getURLResult(ctx, branch.URL, branch.BranchID, branch.Op, branch.BinData) if err == nil { return dtmcli.StatusSucceed, nil } else if t.TransType == "saga" && branch.Op == dtmimp.OpAction && errors.Is(err, dtmcli.ErrFailure) { @@ -225,8 +226,8 @@ func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { return "", fmt.Errorf("your http/grpc result should be specified as in:\nhttp://d.dtm.pub/practice/arch.html#proto\nunkown result will be retried: %w", err) } -func (t *TransGlobal) execBranch(branch *TransBranch, branchPos int) error { - status, err := t.getBranchResult(branch) +func (t *TransGlobal) execBranch(ctx context.Context, branch *TransBranch, branchPos int) error { + status, err := t.getBranchResult(ctx, branch) if status != "" { t.changeBranchStatus(branch, status, branchPos) } diff --git a/dtmsvr/trans_type_msg.go b/dtmsvr/trans_type_msg.go index 8037150..0def41d 100644 --- a/dtmsvr/trans_type_msg.go +++ b/dtmsvr/trans_type_msg.go @@ -7,6 +7,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "strings" @@ -51,11 +52,11 @@ type cMsgCustom struct { Delay uint64 //delay call branch, unit second } -func (t *TransGlobal) mayQueryPrepared() { +func (t *TransGlobal) mayQueryPrepared(ctx context.Context) { if !t.needProcess() || t.Status == dtmcli.StatusSubmitted { return } - err := t.getURLResult(t.QueryPrepared, "00", "msg", nil) + err := t.getURLResult(ctx, t.QueryPrepared, "00", "msg", nil) if err == nil { t.changeStatus(dtmcli.StatusSubmitted) } else if errors.Is(err, dtmcli.ErrFailure) { @@ -68,8 +69,8 @@ func (t *TransGlobal) mayQueryPrepared() { } } -func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { - t.mayQueryPrepared() +func (t *transMsgProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { + t.mayQueryPrepared(ctx) if !t.needProcess() || t.Status == dtmcli.StatusPrepared { return nil } @@ -91,12 +92,13 @@ func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { continue } if t.Concurrent { + copyCtx := CopyContext(ctx) started++ - go func(pos int) { - resultsChan <- t.execBranch(b, pos) - }(i) + go func(ctx context.Context, pos int) { + resultsChan <- t.execBranch(ctx, b, pos) + }(copyCtx, i) } else { - err = t.execBranch(b, i) + err = t.execBranch(ctx, b, i) if err != nil { break } diff --git a/dtmsvr/trans_type_saga.go b/dtmsvr/trans_type_saga.go index 44c8667..cd60a56 100644 --- a/dtmsvr/trans_type_saga.go +++ b/dtmsvr/trans_type_saga.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "errors" "fmt" "time" @@ -52,7 +53,7 @@ type branchResult struct { err error } -func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transSagaProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { // when saga tasks is fetched, it always need to process logger.Debugf("status: %s timeout: %t", t.Status, t.isTimeout()) if t.Status == dtmcli.StatusSubmitted && t.isTimeout() { @@ -121,7 +122,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { return true } resultChan := make(chan branchResult, n) - asyncExecBranch := func(i int) { + asyncExecBranch := func(ctx context.Context, i int) { var err error defer func() { if x := recover(); x != nil { @@ -132,7 +133,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { logger.Errorf("exec branch %s %s %s error: %v", branches[i].BranchID, branches[i].Op, branches[i].URL, err) } }() - err = t.execBranch(&branches[i], i) + err = t.execBranch(ctx, &branches[i], i) } pickToRunActions := func() []int { toRun := []int{} @@ -162,7 +163,8 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { if branchResults[b].op == dtmimp.OpAction { rsAStarted++ } - go asyncExecBranch(b) + copyCtx := CopyContext(ctx) + go asyncExecBranch(copyCtx, b) } } waitDoneOnce := func() { @@ -178,7 +180,8 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { t.RetryCount++ logger.Infof("Retrying branch %s %s %s, t.RetryLimit: %d, t.RetryCount: %d", branches[r.index].BranchID, branches[r.index].Op, branches[r.index].URL, t.RetryLimit, t.RetryCount) - go asyncExecBranch(r.index) + copyCtx := CopyContext(ctx) + go asyncExecBranch(copyCtx, r.index) break } // if t.RetryCount = t.RetryLimit, trans will be aborted diff --git a/dtmsvr/trans_type_tcc.go b/dtmsvr/trans_type_tcc.go index 767f17b..578e5a9 100644 --- a/dtmsvr/trans_type_tcc.go +++ b/dtmsvr/trans_type_tcc.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "fmt" "github.com/dtm-labs/dtm/client/dtmcli" @@ -20,7 +21,7 @@ func (t *transTccProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transTccProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { if !t.needProcess() { return nil } @@ -31,7 +32,7 @@ func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error { for current := len(branches) - 1; current >= 0; current-- { if branches[current].Op == op && branches[current].Status == dtmcli.StatusPrepared { logger.Debugf("branch info: current: %d ID: %d", current, branches[current].ID) - err := t.execBranch(&branches[current], current) + err := t.execBranch(ctx, &branches[current], current) if err != nil { return err } diff --git a/dtmsvr/trans_type_workflow.go b/dtmsvr/trans_type_workflow.go index a350081..a57f28b 100644 --- a/dtmsvr/trans_type_workflow.go +++ b/dtmsvr/trans_type_workflow.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" @@ -24,7 +25,7 @@ type cWorkflowCustom struct { Data []byte `json:"data"` } -func (t *transWorkflowProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transWorkflowProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed { return nil } @@ -36,5 +37,5 @@ func (t *transWorkflowProcessor) ProcessOnce(branches []TransBranch) error { wd := wfpb.WorkflowData{Data: cmc.Data} data = dtmgimp.MustProtoMarshal(&wd) } - return t.getURLResult(t.QueryPrepared, "00", cmc.Name, data) + return t.getURLResult(ctx, t.QueryPrepared, "00", cmc.Name, data) } diff --git a/dtmsvr/trans_type_xa.go b/dtmsvr/trans_type_xa.go index fdb7a01..3e63949 100644 --- a/dtmsvr/trans_type_xa.go +++ b/dtmsvr/trans_type_xa.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "context" "fmt" "github.com/dtm-labs/dtm/client/dtmcli" @@ -19,7 +20,7 @@ func (t *transXaProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error { +func (t *transXaProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error { if !t.needProcess() { return nil } @@ -29,7 +30,7 @@ func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error { currentType := dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmimp.OpCommit, dtmimp.OpRollback).(string) for i, branch := range branches { if branch.Op == currentType && branch.Status != dtmcli.StatusSucceed { - err := t.execBranch(&branch, i) + err := t.execBranch(ctx, &branch, i) if err != nil { return err }