diff --git a/dtmgrpc/workflow/factory.go b/dtmgrpc/workflow/factory.go index 348dc19..fa65a9a 100644 --- a/dtmgrpc/workflow/factory.go +++ b/dtmgrpc/workflow/factory.go @@ -13,11 +13,11 @@ type workflowFactory struct { httpCallback string grpcDtm string grpcCallback string - handlers map[string]WfFunc + handlers map[string]*wfItem } var defaultFac = workflowFactory{ - handlers: map[string]WfFunc{}, + handlers: map[string]*wfItem{}, } func (w *workflowFactory) execute(name string, gid string, data []byte) error { @@ -26,7 +26,10 @@ func (w *workflowFactory) execute(name string, gid string, data []byte) error { return fmt.Errorf("workflow '%s' not registered. please register at startup", name) } wf := w.newWorkflow(name, gid, data) - return wf.process(handler, data) + for _, fn := range handler.custom { + fn(wf) + } + return wf.process(handler.fn, data) } func (w *workflowFactory) executeByQS(qs url.Values, body []byte) error { @@ -35,12 +38,15 @@ func (w *workflowFactory) executeByQS(qs url.Values, body []byte) error { return w.execute(name, gid, body) } -func (w *workflowFactory) register(name string, handler WfFunc) error { +func (w *workflowFactory) register(name string, handler WfFunc, custom ...func(wf *Workflow)) error { e := w.handlers[name] if e != nil { return fmt.Errorf("a handler already exists for %s", name) } logger.Debugf("workflow '%s' registered.", name) - w.handlers[name] = handler + w.handlers[name] = &wfItem{ + fn: handler, + custom: custom, + } return nil } diff --git a/dtmgrpc/workflow/imp.go b/dtmgrpc/workflow/imp.go index 1391eba..3477151 100644 --- a/dtmgrpc/workflow/imp.go +++ b/dtmgrpc/workflow/imp.go @@ -19,6 +19,7 @@ type workflowImp struct { currentActionAdded bool //nolint currentCommitAdded bool //nolint currentRollbackAdded bool //nolint + currentRollbackItem *workflowPhase2Item // nolint progresses map[string]*stepResult //nolint currentOp string succeededOps []workflowPhase2Item @@ -157,6 +158,15 @@ func (wf *Workflow) callPhase2(branchID string, fn WfPhase2Func) error { } func (wf *Workflow) recordedDo(fn func(bb *dtmcli.BranchBarrier) *stepResult) *stepResult { + sr := wf.recordedDoInner(fn) + // if options not enabled, only successful branch need to be compensated + if !wf.Options.CompensateErrorBranch && wf.currentRollbackItem != nil && sr.Status == dtmcli.ResultSuccess { + wf.failedOps = append(wf.failedOps, *wf.currentRollbackItem) + } + return sr +} + +func (wf *Workflow) recordedDoInner(fn func(bb *dtmcli.BranchBarrier) *stepResult) *stepResult { branchID := wf.currentBranch if wf.currentOp == dtmimp.OpAction { dtmimp.PanicIf(wf.currentActionAdded, fmt.Errorf("one branch can have only on action")) diff --git a/dtmgrpc/workflow/workflow.go b/dtmgrpc/workflow/workflow.go index 2418cba..6fdd131 100644 --- a/dtmgrpc/workflow/workflow.go +++ b/dtmgrpc/workflow/workflow.go @@ -41,8 +41,8 @@ func SetProtocolForTest(protocol string) { } // Register will register a workflow with the specified name -func Register(name string, handler WfFunc) error { - return defaultFac.register(name, handler) +func Register(name string, handler WfFunc, custom ...func(wf *Workflow)) error { + return defaultFac.register(name, handler, custom...) } // Execute will execute a workflow with the gid and specified params @@ -59,9 +59,13 @@ func ExecuteByQS(qs url.Values, body []byte) error { // Options is for specifying workflow options type Options struct { + // default false: Workflow's restyClient will convert http response to error if status code is not 200 // if this flag is set true, then Workflow's restyClient will keep the origin http response - // or else, Workflow's restyClient will convert http response to error if status code is not 200 DisalbeAutoError bool + + // default false: fn registered by OnBranchRollback will not be called for FAILURE branch + // if this flag is set true, then fn will be called. the user should handle null rollback and dangling + CompensateErrorBranch bool } // Workflow is the type for a workflow @@ -73,6 +77,11 @@ type Workflow struct { workflowImp } +type wfItem struct { + fn WfFunc + custom []func(*Workflow) +} + // WfFunc is the type for workflow function type WfFunc func(wf *Workflow, data []byte) error @@ -107,11 +116,16 @@ func (wf *Workflow) OnBranchRollback(compensate WfPhase2Func) *Workflow { branchID := wf.currentBranch dtmimp.PanicIf(wf.currentRollbackAdded, fmt.Errorf("on branch can only add one rollback callback")) wf.currentRollbackAdded = true - wf.failedOps = append(wf.failedOps, workflowPhase2Item{ + item := workflowPhase2Item{ branchID: branchID, op: dtmimp.OpRollback, fn: compensate, - }) + } + if wf.Options.CompensateErrorBranch { + wf.failedOps = append(wf.failedOps, item) + } else { + wf.currentRollbackItem = &item + } return wf } diff --git a/helper/test-cover.sh b/helper/test-cover.sh index 9223328..8d6723e 100755 --- a/helper/test-cover.sh +++ b/helper/test-cover.sh @@ -1,5 +1,5 @@ set -x -echo "mode: count" coverage.txt +echo "mode: count" > coverage.txt for store in redis boltdb mysql postgres; do for d in $(go list ./... | grep -v vendor); do TEST_STORE=$store go test -failfast -covermode count -coverprofile=profile.out -coverpkg=github.com/dtm-labs/dtm/dtmcli,github.com/dtm-labs/dtm/dtmcli/dtmimp,github.com/dtm-labs/dtm/dtmcli/logger,github.com/dtm-labs/dtm/dtmgrpc,github.com/dtm-labs/dtm/dtmgrpc/workflow,github.com/dtm-labs/dtm/dtmgrpc/dtmgimp,github.com/dtm-labs/dtm/dtmsvr,github.com/dtm-labs/dtm/dtmsvr/config,github.com/dtm-labs/dtm/dtmsvr/storage,github.com/dtm-labs/dtm/dtmsvr/storage/boltdb,github.com/dtm-labs/dtm/dtmsvr/storage/redis,github.com/dtm-labs/dtm/dtmsvr/storage/registry,github.com/dtm-labs/dtm/dtmsvr/storage/sql,github.com/dtm-labs/dtm/dtmutil -gcflags=-l $d || exit 1 diff --git a/test/workflow_ongoing_test.go b/test/workflow_ongoing_test.go index 8a6fcec..f956554 100644 --- a/test/workflow_ongoing_test.go +++ b/test/workflow_ongoing_test.go @@ -90,6 +90,8 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { return dtmcli.ErrOngoing } return err + }, func(wf *workflow.Workflow) { + wf.Options.CompensateErrorBranch = true }) req := &busi.BusiReq{Amount: 30, TransInResult: "FAILURE"} err := workflow.Execute(gid, gid, dtmgimp.MustProtoMarshal(req))