diff --git a/dtmgrpc/workflow/imp.go b/dtmgrpc/workflow/imp.go index 81a33d5..26b62a3 100644 --- a/dtmgrpc/workflow/imp.go +++ b/dtmgrpc/workflow/imp.go @@ -8,7 +8,6 @@ import ( "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" - "github.com/dtm-labs/dtm/dtmgrpc" "github.com/go-resty/resty/v2" ) @@ -76,7 +75,7 @@ func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Wor }) wf.Context = context.WithValue(wf.Context, wfMeta{}, wf) wf.Options.HTTPResp2DtmError = HTTPResp2DtmError - wf.Options.GRPCError2DtmError = dtmgrpc.GrpcError2DtmError + wf.Options.GRPCError2DtmError = GrpcError2DtmError wf.initRestyClient() return wf } @@ -104,7 +103,7 @@ func (wf *Workflow) process(handler WfFunc, data []byte) (err error) { err = wf.loadProgresses() if err == nil { err = handler(wf, data) - err = dtmgrpc.GrpcError2DtmError(err) + err = wf.Options.GRPCError2DtmError(err) if err != nil && !errors.Is(err, dtmcli.ErrFailure) { return err } diff --git a/dtmgrpc/workflow/utils.go b/dtmgrpc/workflow/utils.go index a4d4d36..5c5720e 100644 --- a/dtmgrpc/workflow/utils.go +++ b/dtmgrpc/workflow/utils.go @@ -11,6 +11,8 @@ import ( "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -82,6 +84,17 @@ func HTTPResp2DtmError(resp *http.Response) ([]byte, error) { return data, err } +// GrpcError2DtmError translate grpc error to dtm error +func GrpcError2DtmError(err error) error { + st, _ := status.FromError(err) + if st != nil && st.Code() == codes.Aborted { + return fmt.Errorf("%s. %w", st.Message(), dtmcli.ErrFailure) + } else if st != nil && st.Code() == codes.FailedPrecondition { + return fmt.Errorf("%s. %w", st.Message(), dtmcli.ErrOngoing) + } + return err +} + func (wf *Workflow) stepResultFromLocal(data []byte, err error) *stepResult { return &stepResult{ Error: err, diff --git a/dtmgrpc/workflow/workflow.go b/dtmgrpc/workflow/workflow.go index b0b78cf..f00a30b 100644 --- a/dtmgrpc/workflow/workflow.go +++ b/dtmgrpc/workflow/workflow.go @@ -61,10 +61,10 @@ func ExecuteByQS(qs url.Values, body []byte) error { // Options is for specifying workflow options type Options struct { - // Default: Code 409 => ErrFailure; Code 425 => ErrOngoing + // Default == HTTPResp2DtmError : Code 409 => ErrFailure; Code 425 => ErrOngoing HTTPResp2DtmError func(*http.Response) ([]byte, error) - // Default: Code Aborted => ErrFailure; Code FailedPrecondition => ErrOngoing + // Default == GrpcError2DtmError: Code Aborted => ErrFailure; Code FailedPrecondition => ErrOngoing GRPCError2DtmError func(error) error // This Option specify whether a branch returning ErrFailure should be compensated on rollback. @@ -145,6 +145,15 @@ func (wf *Workflow) OnCommit(fn WfPhase2Func) *Workflow { return wf } +// OnFinish will both set the callback for OnCommit and OnRollback +func (wf *Workflow) OnFinish(fn func(bb *dtmcli.BranchBarrier, isRollback bool) error) *Workflow { + return wf.OnCommit(func(bb *dtmcli.BranchBarrier) error { + return fn(bb, false) + }).OnRollback(func(bb *dtmcli.BranchBarrier) error { + return fn(bb, true) + }) +} + // Do will do an action which will be recored func (wf *Workflow) Do(fn func(bb *dtmcli.BranchBarrier) ([]byte, error)) ([]byte, error) { res := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult { diff --git a/test/workflow_http_test.go b/test/workflow_http_test.go index 18c18a9..e8b8de2 100644 --- a/test/workflow_http_test.go +++ b/test/workflow_http_test.go @@ -12,6 +12,7 @@ import ( "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/dtm-labs/dtm/dtmcli/logger" "github.com/dtm-labs/dtm/dtmgrpc/workflow" "github.com/dtm-labs/dtm/test/busi" "github.com/stretchr/testify/assert" @@ -23,6 +24,10 @@ func TestWorkflowNormal(t *testing.T) { gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { + wf.NewBranch().OnFinish(func(bb *dtmcli.BranchBarrier, isRollback bool) error { + logger.Debugf("OnFinish isRollback: %v", isRollback) + return nil + }) var req busi.ReqHTTP dtmimp.MustUnmarshal(data, &req) _, err := wf.NewBranch().NewRequest().SetBody(req).Post(Busi + "/TransOut") @@ -49,6 +54,10 @@ func TestWorkflowRollback(t *testing.T) { gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { + wf.NewBranch().OnFinish(func(bb *dtmcli.BranchBarrier, isRollback bool) error { + logger.Debugf("OnFinish isRollback: %v", isRollback) + return nil + }) var req busi.ReqHTTP dtmimp.MustUnmarshal(data, &req) _, err := wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error {