diff --git a/dtmgrpc/workflow/imp.go b/dtmgrpc/workflow/imp.go index 9a4bf3e..81a33d5 100644 --- a/dtmgrpc/workflow/imp.go +++ b/dtmgrpc/workflow/imp.go @@ -162,8 +162,7 @@ func (wf *Workflow) callPhase2(branchID string, fn WfPhase2Func) error { func (wf *Workflow) recordedDo(fn func(bb *dtmcli.BranchBarrier) *stepResult) *stepResult { sr := wf.recordedDoInner(fn) - // if options not enabled, only successful branch need to be compensated - if !wf.Options.CompensateErrorBranch && wf.currentRollbackItem != nil && sr.Status == dtmcli.StatusSucceed { + if wf.currentRollbackItem != nil && (sr.Status == dtmcli.StatusSucceed || sr.Status == dtmcli.StatusFailed && wf.Options.CompensateErrorBranch) { wf.failedOps = append(wf.failedOps, *wf.currentRollbackItem) } wf.currentRollbackItem = nil diff --git a/dtmgrpc/workflow/workflow.go b/dtmgrpc/workflow/workflow.go index 29116aa..b0b78cf 100644 --- a/dtmgrpc/workflow/workflow.go +++ b/dtmgrpc/workflow/workflow.go @@ -115,30 +115,27 @@ func (wf *Workflow) NewBranchCtx() context.Context { return wf.NewBranch().Context } -// OnBranchRollback will define a saga branch transaction -// param compensate specify a function for the compensation of next workflow action -func (wf *Workflow) OnBranchRollback(compensate WfPhase2Func) *Workflow { +// OnRollback will set the callback for current branch when rollback happen. +// If you are writing a saga transaction, then you should write the compensation here +// If you are writing a tcc transaction, then you should write the cancel operation here +func (wf *Workflow) OnRollback(compensate WfPhase2Func) *Workflow { branchID := wf.currentBranch - dtmimp.PanicIf(wf.currentRollbackAdded, fmt.Errorf("on branch can only add one rollback callback")) + dtmimp.PanicIf(wf.currentRollbackAdded, fmt.Errorf("one branch can only add one rollback callback")) wf.currentRollbackAdded = true item := workflowPhase2Item{ branchID: branchID, op: dtmimp.OpRollback, fn: compensate, } - if wf.Options.CompensateErrorBranch { - wf.failedOps = append(wf.failedOps, item) - } else { - wf.currentRollbackItem = &item - } + wf.currentRollbackItem = &item return wf } -// OnBranchCommit will define a saga branch transaction -// param compensate specify a function for the compensation of next workflow action -func (wf *Workflow) OnBranchCommit(fn WfPhase2Func) *Workflow { +// OnCommit will will set the callback for current branch when commit happen. +// If you are writing a tcc transaction, then you should write the confirm operation here +func (wf *Workflow) OnCommit(fn WfPhase2Func) *Workflow { branchID := wf.currentBranch - dtmimp.PanicIf(wf.currentCommitAdded, fmt.Errorf("on branch can only add one commit callback")) + dtmimp.PanicIf(wf.currentCommitAdded, fmt.Errorf("one branch can only add one commit callback")) wf.currentCommitAdded = true wf.failedOps = append(wf.succeededOps, workflowPhase2Item{ branchID: branchID, diff --git a/test/workflow_grpc_test.go b/test/workflow_grpc_test.go index a3db107..856e76b 100644 --- a/test/workflow_grpc_test.go +++ b/test/workflow_grpc_test.go @@ -45,7 +45,7 @@ func TestWorkflowGrpcNormal(t *testing.T) { workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { var req busi.BusiReq dtmgimp.MustProtoUnmarshal(data, &req) - wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransOutRevertBSaga(wf.Context, &req) return err }) @@ -53,7 +53,7 @@ func TestWorkflowGrpcNormal(t *testing.T) { if err != nil { return err } - wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransInRevertBSaga(wf.Context, &req) return err }) @@ -74,7 +74,7 @@ func TestWorkflowMixed(t *testing.T) { var req busi.BusiReq dtmgimp.MustProtoUnmarshal(data, &req) - wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransOutRevertBSaga(wf.Context, &req) return err }) @@ -83,10 +83,10 @@ func TestWorkflowMixed(t *testing.T) { return err } - _, err = wf.NewBranch().OnBranchCommit(func(bb *dtmcli.BranchBarrier) error { + _, err = wf.NewBranch().OnCommit(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransInConfirm(wf.Context, &req) return err - }).OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + }).OnRollback(func(bb *dtmcli.BranchBarrier) error { req2 := &busi.ReqHTTP{Amount: 30} _, err := wf.NewRequest().SetBody(req2).Post(Busi + "/TransInRevert") return err diff --git a/test/workflow_http_test.go b/test/workflow_http_test.go index ff85c5e..18c18a9 100644 --- a/test/workflow_http_test.go +++ b/test/workflow_http_test.go @@ -51,7 +51,7 @@ func TestWorkflowRollback(t *testing.T) { workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { var req busi.ReqHTTP dtmimp.MustUnmarshal(data, &req) - _, err := wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + _, err := wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { _, err := wf.NewRequest().SetBody(req).Post(Busi + "/SagaBTransOutCom") return err }).Do(func(bb *dtmcli.BranchBarrier) ([]byte, error) { @@ -62,7 +62,7 @@ func TestWorkflowRollback(t *testing.T) { if err != nil { return err } - _, err = wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + _, err = wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { return bb.CallWithDB(dbGet().ToSQLDB(), func(tx *sql.Tx) error { return busi.SagaAdjustBalance(tx, busi.TransInUID, -req.Amount, "") }) diff --git a/test/workflow_ongoing_test.go b/test/workflow_ongoing_test.go index eafa5ba..a42cdc8 100644 --- a/test/workflow_ongoing_test.go +++ b/test/workflow_ongoing_test.go @@ -64,7 +64,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { if fetchOngoingStep(0) { return dtmcli.ErrOngoing } - wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { if fetchOngoingStep(4) { return dtmcli.ErrOngoing } @@ -78,7 +78,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { if err != nil { return err } - wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnRollback(func(bb *dtmcli.BranchBarrier) error { if fetchOngoingStep(3) { return dtmcli.ErrOngoing }