diff --git a/client/workflow/factory.go b/client/workflow/factory.go index 51e9f12..69b26ae 100644 --- a/client/workflow/factory.go +++ b/client/workflow/factory.go @@ -1,6 +1,7 @@ package workflow import ( + "context" "fmt" "github.com/dtm-labs/logger" @@ -19,12 +20,12 @@ var defaultFac = workflowFactory{ handlers: map[string]*wfItem{}, } -func (w *workflowFactory) execute(name string, gid string, data []byte) ([]byte, error) { +func (w *workflowFactory) execute(ctx context.Context, name string, gid string, data []byte) ([]byte, error) { handler := w.handlers[name] if handler == nil { return nil, fmt.Errorf("workflow '%s' not registered. please register at startup", name) } - wf := w.newWorkflow(name, gid, data) + wf := w.newWorkflow(ctx, name, gid, data) for _, fn := range handler.custom { fn(wf) } diff --git a/client/workflow/imp.go b/client/workflow/imp.go index 7b9d003..f277d65 100644 --- a/client/workflow/imp.go +++ b/client/workflow/imp.go @@ -47,7 +47,7 @@ func (wf *Workflow) initProgress(progresses []*dtmgpb.DtmProgress) { type wfMeta struct{} -func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Workflow { +func (w *workflowFactory) newWorkflow(ctx context.Context, name string, gid string, data []byte) *Workflow { wf := &Workflow{ TransBase: dtmimp.NewTransBase(gid, "workflow", "not inited", ""), Name: name, @@ -58,6 +58,7 @@ func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Wor currentOp: dtmimp.OpAction, }, } + wf.Context = ctx wf.Protocol = w.protocol if w.protocol == dtmimp.ProtocolGRPC { wf.Dtm = w.grpcDtm diff --git a/client/workflow/server.go b/client/workflow/server.go index 441d2d0..376fdf2 100644 --- a/client/workflow/server.go +++ b/client/workflow/server.go @@ -21,6 +21,6 @@ func (s *workflowServer) Execute(ctx context.Context, wd *wfpb.WorkflowData) (*e return nil, status.Errorf(codes.Internal, "workflow server not inited. please call workflow.InitGrpc first") } tb := dtmgimp.TransBaseFromGrpc(ctx) - _, err := defaultFac.execute(tb.Op, tb.Gid, wd.Data) + _, err := defaultFac.execute(ctx, tb.Op, tb.Gid, wd.Data) return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(err) } diff --git a/client/workflow/workflow.go b/client/workflow/workflow.go index a72bee2..02b38a2 100644 --- a/client/workflow/workflow.go +++ b/client/workflow/workflow.go @@ -53,24 +53,39 @@ func Register2(name string, handler WfFunc2, custom ...func(wf *Workflow)) error return defaultFac.register(name, handler, custom...) } -// Execute will execute a workflow with the gid and specified params +// Execute is the same as ExecuteCtx, but with context.Background +func Execute(name string, gid string, data []byte) error { + return ExecuteCtx(context.Background(), name, gid, data) +} + +// ExecuteCtx will execute a workflow with the gid and specified params // if the workflow with the gid does not exist, then create a new workflow and execute it // if the workflow with the gid exists, resume to execute it -func Execute(name string, gid string, data []byte) error { - _, err := defaultFac.execute(name, gid, data) +func ExecuteCtx(ctx context.Context, name string, gid string, data []byte) error { + _, err := defaultFac.execute(ctx, name, gid, data) return err } // Execute2 is the same as Execute, but workflow func can return result func Execute2(name string, gid string, data []byte) ([]byte, error) { - return defaultFac.execute(name, gid, data) + return Execute2Ctx(context.Background(), name, gid, data) +} + +// Execute2Ctx is the same as Execute2, but with context.Background +func Execute2Ctx(ctx context.Context, name string, gid string, data []byte) ([]byte, error) { + return defaultFac.execute(ctx, name, gid, data) } // ExecuteByQS is like Execute, but name and gid will be obtained from qs func ExecuteByQS(qs url.Values, body []byte) error { + return ExecuteByQSCtx(context.Background(), qs, body) +} + +// ExecuteByQSCtx is the same as ExecuteByQS, but with context.Background +func ExecuteByQSCtx(ctx context.Context, qs url.Values, body []byte) error { name := qs.Get("op") gid := qs.Get("gid") - _, err := defaultFac.execute(name, gid, body) + _, err := defaultFac.execute(ctx, name, gid, body) return err } diff --git a/client/workflow/workflow_test.go b/client/workflow/workflow_test.go index 93cc7b0..530d8b7 100644 --- a/client/workflow/workflow_test.go +++ b/client/workflow/workflow_test.go @@ -10,7 +10,7 @@ import ( func TestAbnormal(t *testing.T) { fname := dtmimp.GetFuncName() - _, err := defaultFac.execute(fname, fname, nil) + _, err := defaultFac.execute(context.Background(), fname, fname, nil) assert.Error(t, err) err = defaultFac.register(fname, func(wf *Workflow, data []byte) ([]byte, error) { return nil, nil })