diff --git a/client/workflow/factory.go b/client/workflow/factory.go index 51e9f12..69b26ae 100644 --- a/client/workflow/factory.go +++ b/client/workflow/factory.go @@ -1,6 +1,7 @@ package workflow import ( + "context" "fmt" "github.com/dtm-labs/logger" @@ -19,12 +20,12 @@ var defaultFac = workflowFactory{ handlers: map[string]*wfItem{}, } -func (w *workflowFactory) execute(name string, gid string, data []byte) ([]byte, error) { +func (w *workflowFactory) execute(ctx context.Context, name string, gid string, data []byte) ([]byte, error) { handler := w.handlers[name] if handler == nil { return nil, fmt.Errorf("workflow '%s' not registered. please register at startup", name) } - wf := w.newWorkflow(name, gid, data) + wf := w.newWorkflow(ctx, name, gid, data) for _, fn := range handler.custom { fn(wf) } diff --git a/client/workflow/imp.go b/client/workflow/imp.go index 7b9d003..f277d65 100644 --- a/client/workflow/imp.go +++ b/client/workflow/imp.go @@ -47,7 +47,7 @@ func (wf *Workflow) initProgress(progresses []*dtmgpb.DtmProgress) { type wfMeta struct{} -func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Workflow { +func (w *workflowFactory) newWorkflow(ctx context.Context, name string, gid string, data []byte) *Workflow { wf := &Workflow{ TransBase: dtmimp.NewTransBase(gid, "workflow", "not inited", ""), Name: name, @@ -58,6 +58,7 @@ func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Wor currentOp: dtmimp.OpAction, }, } + wf.Context = ctx wf.Protocol = w.protocol if w.protocol == dtmimp.ProtocolGRPC { wf.Dtm = w.grpcDtm diff --git a/client/workflow/server.go b/client/workflow/server.go index 441d2d0..376fdf2 100644 --- a/client/workflow/server.go +++ b/client/workflow/server.go @@ -21,6 +21,6 @@ func (s *workflowServer) Execute(ctx context.Context, wd *wfpb.WorkflowData) (*e return nil, status.Errorf(codes.Internal, "workflow server not inited. please call workflow.InitGrpc first") } tb := dtmgimp.TransBaseFromGrpc(ctx) - _, err := defaultFac.execute(tb.Op, tb.Gid, wd.Data) + _, err := defaultFac.execute(ctx, tb.Op, tb.Gid, wd.Data) return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(err) } diff --git a/client/workflow/workflow.go b/client/workflow/workflow.go index a72bee2..700ba40 100644 --- a/client/workflow/workflow.go +++ b/client/workflow/workflow.go @@ -56,21 +56,21 @@ func Register2(name string, handler WfFunc2, custom ...func(wf *Workflow)) error // Execute will execute a workflow with the gid and specified params // if the workflow with the gid does not exist, then create a new workflow and execute it // if the workflow with the gid exists, resume to execute it -func Execute(name string, gid string, data []byte) error { - _, err := defaultFac.execute(name, gid, data) +func Execute(ctx context.Context, name string, gid string, data []byte) error { + _, err := defaultFac.execute(ctx, name, gid, data) return err } // Execute2 is the same as Execute, but workflow func can return result -func Execute2(name string, gid string, data []byte) ([]byte, error) { - return defaultFac.execute(name, gid, data) +func Execute2(ctx context.Context, name string, gid string, data []byte) ([]byte, error) { + return defaultFac.execute(ctx, name, gid, data) } // ExecuteByQS is like Execute, but name and gid will be obtained from qs -func ExecuteByQS(qs url.Values, body []byte) error { +func ExecuteByQS(ctx context.Context, qs url.Values, body []byte) error { name := qs.Get("op") gid := qs.Get("gid") - _, err := defaultFac.execute(name, gid, body) + _, err := defaultFac.execute(ctx, name, gid, body) return err } diff --git a/client/workflow/workflow_test.go b/client/workflow/workflow_test.go index 93cc7b0..530d8b7 100644 --- a/client/workflow/workflow_test.go +++ b/client/workflow/workflow_test.go @@ -10,7 +10,7 @@ import ( func TestAbnormal(t *testing.T) { fname := dtmimp.GetFuncName() - _, err := defaultFac.execute(fname, fname, nil) + _, err := defaultFac.execute(context.Background(), fname, fname, nil) assert.Error(t, err) err = defaultFac.register(fname, func(wf *Workflow, data []byte) ([]byte, error) { return nil, nil }) diff --git a/test/busi/base_http.go b/test/busi/base_http.go index 44bbabc..2da780e 100644 --- a/test/busi/base_http.go +++ b/test/busi/base_http.go @@ -88,7 +88,7 @@ func BaseAddRoute(app *gin.Engine) { app.POST(BusiAPI+"/workflow/resume", dtmutil.WrapHandler(func(ctx *gin.Context) interface{} { data, err := ioutil.ReadAll(ctx.Request.Body) logger.FatalIfError(err) - return workflow.ExecuteByQS(ctx.Request.URL.Query(), data) + return workflow.ExecuteByQS(ctx, ctx.Request.URL.Query(), data) })) app.POST(BusiAPI+"/TransIn", dtmutil.WrapHandler(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransInResult.Fetch(), reqFrom(c).TransInResult, "transIn") diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index 497a449..a482234 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -70,7 +70,7 @@ func TestUpdateBranchAsync(t *testing.T) { return err }) assert.Nil(t, err) - err = workflow.Execute(gid, gid, nil) + err = workflow.Execute(context.Background(), gid, gid, nil) assert.Nil(t, err) time.Sleep(dtmsvr.UpdateBranchAsyncInterval) diff --git a/test/workflow_grpc_test.go b/test/workflow_grpc_test.go index 1385cd2..cf90eb6 100644 --- a/test/workflow_grpc_test.go +++ b/test/workflow_grpc_test.go @@ -7,6 +7,7 @@ package test import ( + "context" "database/sql" "testing" @@ -33,7 +34,7 @@ func TestWorkflowGrpcSimple(t *testing.T) { _, err = busi.BusiCli.TransInBSaga(wf.NewBranchCtx(), &req) return err }) - err := workflow.Execute(gid, gid, dtmgimp.MustProtoMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmgimp.MustProtoMarshal(req)) assert.Error(t, err) assert.Equal(t, StatusFailed, getTransStatus(gid)) } @@ -61,7 +62,7 @@ func TestWorkflowGrpcRollback(t *testing.T) { return err }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, dtmgimp.MustProtoMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmgimp.MustProtoMarshal(req)) assert.Error(t, err, dtmcli.ErrFailure) assert.Equal(t, StatusFailed, getTransStatus(gid)) assertSameBalance(t, before, "mysql") @@ -106,7 +107,7 @@ func TestWorkflowMixed(t *testing.T) { assert.Nil(t, err) before := getBeforeBalances("mysql") req := &busi.ReqGrpc{Amount: 30} - err = workflow.Execute(gid, gid, dtmgimp.MustProtoMarshal(req)) + err = workflow.Execute(context.Background(), gid, gid, dtmgimp.MustProtoMarshal(req)) assert.Nil(t, err) assert.Equal(t, StatusSucceed, getTransStatus(gid)) assertNotSameBalance(t, before, "mysql") @@ -127,7 +128,7 @@ func TestWorkflowGrpcError(t *testing.T) { _, err = busi.BusiCli.TransIn(wf.NewBranchCtx(), &req) return err }) - err := workflow.Execute(gid, gid, dtmgimp.MustProtoMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmgimp.MustProtoMarshal(req)) assert.Error(t, err) cronTransOnceForwardCron(t, gid, 1000) assert.Equal(t, StatusSucceed, getTransStatus(gid)) diff --git a/test/workflow_http_ret_test.go b/test/workflow_http_ret_test.go index 434653a..a0747b8 100644 --- a/test/workflow_http_ret_test.go +++ b/test/workflow_http_ret_test.go @@ -1,6 +1,7 @@ package test import ( + "context" "testing" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" @@ -21,13 +22,13 @@ func TestWorkflowRet(t *testing.T) { return []byte("result of workflow"), err }) - ret, err := workflow.Execute2(gid, gid, dtmimp.MustMarshal(req)) + ret, err := workflow.Execute2(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Nil(t, err) assert.Equal(t, "result of workflow", string(ret)) assert.Equal(t, StatusSucceed, getTransStatus(gid)) // the second execute will return result directly - ret, err = workflow.Execute2(gid, gid, dtmimp.MustMarshal(req)) + ret, err = workflow.Execute2(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Nil(t, err) assert.Equal(t, "result of workflow", string(ret)) assert.Equal(t, StatusSucceed, getTransStatus(gid)) diff --git a/test/workflow_http_test.go b/test/workflow_http_test.go index 991a315..8b4cb26 100644 --- a/test/workflow_http_test.go +++ b/test/workflow_http_test.go @@ -7,6 +7,7 @@ package test import ( + "context" "database/sql" "testing" @@ -41,7 +42,7 @@ func TestWorkflowNormal(t *testing.T) { return nil }) - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Nil(t, err) assert.Equal(t, StatusSucceed, getTransStatus(gid)) } @@ -82,7 +83,7 @@ func TestWorkflowRollback(t *testing.T) { }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Error(t, err, dtmcli.ErrFailure) assert.Equal(t, StatusFailed, getTransStatus(gid)) assertSameBalance(t, before, "mysql") @@ -120,7 +121,7 @@ func TestWorkflowTcc(t *testing.T) { }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Nil(t, err) assert.Equal(t, StatusSucceed, getTransStatus(gid)) assertNotSameBalance(t, before, "mysql") @@ -158,7 +159,7 @@ func TestWorkflowTccRollback(t *testing.T) { }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Error(t, err) assert.Equal(t, StatusFailed, getTransStatus(gid)) assertSameBalance(t, before, "mysql") @@ -177,7 +178,7 @@ func TestWorkflowError(t *testing.T) { return err }) - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Error(t, err) cronTransOnceForwardCron(t, gid, 1000) assert.Equal(t, StatusSucceed, getTransStatus(gid)) @@ -196,7 +197,7 @@ func TestWorkflowOngoing(t *testing.T) { return err }) - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Error(t, err) cronTransOnceForwardCron(t, gid, 1000) assert.Equal(t, StatusSucceed, getTransStatus(gid)) @@ -224,7 +225,7 @@ func TestWorkflowResumeSkip(t *testing.T) { return err }) - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Error(t, err) cronTransOnceForwardCron(t, gid, 1000) assert.Equal(t, StatusSucceed, getTransStatus(gid)) diff --git a/test/workflow_ongoing_test.go b/test/workflow_ongoing_test.go index 4d6dba4..e4454bd 100644 --- a/test/workflow_ongoing_test.go +++ b/test/workflow_ongoing_test.go @@ -7,6 +7,7 @@ package test import ( + "context" "database/sql" "testing" @@ -47,7 +48,7 @@ func TestWorkflowSimpleResume(t *testing.T) { return err }) - err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmimp.MustMarshal(req)) assert.Error(t, err) cronTransOnceForwardNow(t, gid, 1000) assert.Equal(t, StatusSucceed, getTransStatus(gid)) @@ -94,7 +95,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { }) before := getBeforeBalances("mysql") req := &busi.ReqGrpc{Amount: 30, TransInResult: "FAILURE"} - err := workflow.Execute(gid, gid, dtmgimp.MustProtoMarshal(req)) + err := workflow.Execute(context.Background(), gid, gid, dtmgimp.MustProtoMarshal(req)) assert.Error(t, err, dtmcli.ErrOngoing) assert.Equal(t, StatusPrepared, getTransStatus(gid)) cronTransOnceForwardNow(t, gid, 1000) @@ -140,7 +141,7 @@ func TestWorkflowXaResume(t *testing.T) { return err }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, nil) + err := workflow.Execute(context.Background(), gid, gid, nil) assert.Equal(t, dtmcli.ErrOngoing, err) cronTransOnceForwardNow(t, gid, 1000) diff --git a/test/workflow_xa_test.go b/test/workflow_xa_test.go index dcf55b1..0881062 100644 --- a/test/workflow_xa_test.go +++ b/test/workflow_xa_test.go @@ -7,6 +7,7 @@ package test import ( + "context" "database/sql" "testing" @@ -34,7 +35,7 @@ func TestWorkflowXaAction(t *testing.T) { return err }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, nil) + err := workflow.Execute(context.Background(), gid, gid, nil) assert.Nil(t, err) assert.Equal(t, StatusSucceed, getTransStatus(gid)) assertNotSameBalance(t, before, "mysql") @@ -58,7 +59,7 @@ func TestWorkflowXaRollback(t *testing.T) { return err }) before := getBeforeBalances("mysql") - err := workflow.Execute(gid, gid, nil) + err := workflow.Execute(context.Background(), gid, gid, nil) assert.Equal(t, dtmcli.ErrFailure, err) assert.Equal(t, StatusFailed, getTransStatus(gid)) assertSameBalance(t, before, "mysql")