diff --git a/dtmgrpc/workflow/dummyReadCloser.go b/dtmgrpc/workflow/dummyReadCloser.go index 5a38a51..8f07042 100644 --- a/dtmgrpc/workflow/dummyReadCloser.go +++ b/dtmgrpc/workflow/dummyReadCloser.go @@ -3,51 +3,23 @@ package workflow import ( "bytes" "io" - "strings" ) -// NewRespBodyFromString creates an io.ReadCloser from a string that -// is suitable for use as an http response body. -// -// To pass the content of an existing file as body use httpmock.File as in: -// httpmock.NewRespBodyFromString(httpmock.File("body.txt").String()) -func NewRespBodyFromString(body string) io.ReadCloser { - return &dummyReadCloser{orig: body} -} - // NewRespBodyFromBytes creates an io.ReadCloser from a byte slice // that is suitable for use as an http response body. -// -// To pass the content of an existing file as body use httpmock.File as in: -// httpmock.NewRespBodyFromBytes(httpmock.File("body.txt").Bytes()) func NewRespBodyFromBytes(body []byte) io.ReadCloser { - return &dummyReadCloser{orig: body} + return &dummyReadCloser{body: bytes.NewReader(body)} } type dummyReadCloser struct { - orig interface{} // string or []byte - body io.ReadSeeker // instanciated on demand from orig -} - -// setup ensures d.body is correctly initialized. -func (d *dummyReadCloser) setup() { - if d.body == nil { - switch body := d.orig.(type) { - case string: - d.body = strings.NewReader(body) - case []byte: - d.body = bytes.NewReader(body) - } - } + body io.ReadSeeker } func (d *dummyReadCloser) Read(p []byte) (n int, err error) { - d.setup() return d.body.Read(p) } func (d *dummyReadCloser) Close() error { - d.setup() - d.body.Seek(0, io.SeekEnd) // nolint: errcheck + _, _ = d.body.Seek(0, io.SeekEnd) return nil } diff --git a/dtmgrpc/workflow/imp.go b/dtmgrpc/workflow/imp.go index 5fd8c7a..1391eba 100644 --- a/dtmgrpc/workflow/imp.go +++ b/dtmgrpc/workflow/imp.go @@ -3,6 +3,7 @@ package workflow import ( "context" "errors" + "fmt" "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" @@ -12,13 +13,16 @@ import ( ) type workflowImp struct { - restyClient *resty.Client //nolint - idGen dtmimp.BranchIDGen - currentBranch string //nolint - progresses map[string]*stepResult //nolint - currentOp string - succeededOps []workflowPhase2Item - failedOps []workflowPhase2Item + restyClient *resty.Client //nolint + idGen dtmimp.BranchIDGen + currentBranch string //nolint + currentActionAdded bool //nolint + currentCommitAdded bool //nolint + currentRollbackAdded bool //nolint + progresses map[string]*stepResult //nolint + currentOp string + succeededOps []workflowPhase2Item + failedOps []workflowPhase2Item } type workflowPhase2Item struct { @@ -61,7 +65,6 @@ func (w *workflowFactory) newWorkflow(name string, gid string, data []byte) *Wor wf.Dtm = w.httpDtm wf.QueryPrepared = w.httpCallback } - wf.newBranch() wf.CustomData = dtmimp.MustMarshalString(map[string]interface{}{ "name": wf.Name, "data": data, @@ -155,10 +158,11 @@ func (wf *Workflow) callPhase2(branchID string, fn WfPhase2Func) error { func (wf *Workflow) recordedDo(fn func(bb *dtmcli.BranchBarrier) *stepResult) *stepResult { branchID := wf.currentBranch - r := wf.getStepResult() - if wf.currentOp == dtmimp.OpAction { // for action steps, an action will start a new branch - wf.newBranch() + if wf.currentOp == dtmimp.OpAction { + dtmimp.PanicIf(wf.currentActionAdded, fmt.Errorf("one branch can have only on action")) + wf.currentActionAdded = true } + r := wf.getStepResult() if r != nil { logger.Debugf("progress restored: %s %s %v %s %s", branchID, wf.currentOp, r.Error, r.Status, r.Data) return r @@ -177,11 +181,6 @@ func (wf *Workflow) recordedDo(fn func(bb *dtmcli.BranchBarrier) *stepResult) *s return r } -func (wf *Workflow) newBranch() { - wf.idGen.NewSubBranchID() - wf.currentBranch = wf.idGen.CurrentSubBranchID() -} - func (wf *Workflow) getStepResult() *stepResult { logger.Debugf("getStepResult: %s %v", wf.currentBranch+"-"+wf.currentOp, wf.progresses[wf.currentBranch+"-"+wf.currentOp]) return wf.progresses[wf.currentBranch+"-"+wf.currentOp] diff --git a/dtmgrpc/workflow/rpc.go b/dtmgrpc/workflow/rpc.go index a2204d5..4e07fd6 100644 --- a/dtmgrpc/workflow/rpc.go +++ b/dtmgrpc/workflow/rpc.go @@ -2,7 +2,6 @@ package workflow import ( "context" - "strings" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" @@ -68,11 +67,3 @@ func (wf *Workflow) registerBranch(res []byte, branchID string, op string, statu }) return err } - -func (wf *Workflow) prepare() error { - operation := "prepare" - if wf.Protocol == dtmimp.ProtocolGRPC { - return dtmgimp.DtmGrpcCall(wf.TransBase, strings.Title(operation)) - } - return dtmimp.TransCallDtm(wf.TransBase, operation) -} diff --git a/dtmgrpc/workflow/workflow.go b/dtmgrpc/workflow/workflow.go index 5c1fc20..2418cba 100644 --- a/dtmgrpc/workflow/workflow.go +++ b/dtmgrpc/workflow/workflow.go @@ -85,35 +85,52 @@ func (wf *Workflow) NewRequest() *resty.Request { return wf.restyClient.R().SetContext(wf.Context) } -// AddSagaPhase2 will define a saga branch transaction +// NewBranch will start a new branch transaction +func (wf *Workflow) NewBranch() *Workflow { + dtmimp.PanicIf(wf.currentOp != dtmimp.OpAction, fmt.Errorf("should not call NewBranch() in Branch callbacks")) + wf.idGen.NewSubBranchID() + wf.currentBranch = wf.idGen.CurrentSubBranchID() + wf.currentActionAdded = false + wf.currentCommitAdded = false + wf.currentRollbackAdded = false + return wf +} + +// NewBranchCtx will call NewBranch and return a workflow context +func (wf *Workflow) NewBranchCtx() context.Context { + return wf.NewBranch().Context +} + +// OnBranchRollback will define a saga branch transaction // param compensate specify a function for the compensation of next workflow action -func (wf *Workflow) AddSagaPhase2(compensate WfPhase2Func) { +func (wf *Workflow) OnBranchRollback(compensate WfPhase2Func) *Workflow { branchID := wf.currentBranch + dtmimp.PanicIf(wf.currentRollbackAdded, fmt.Errorf("on branch can only add one rollback callback")) + wf.currentRollbackAdded = true wf.failedOps = append(wf.failedOps, workflowPhase2Item{ branchID: branchID, op: dtmimp.OpRollback, fn: compensate, }) + return wf } -// AddTccPhase2 will define a tcc branch transaction -// param confirm, concel specify the confirm and cancel operation of next workflow action -func (wf *Workflow) AddTccPhase2(confirm, cancel WfPhase2Func) { +// OnBranchCommit will define a saga branch transaction +// param compensate specify a function for the compensation of next workflow action +func (wf *Workflow) OnBranchCommit(fn WfPhase2Func) *Workflow { branchID := wf.currentBranch - wf.failedOps = append(wf.failedOps, workflowPhase2Item{ - branchID: branchID, - op: dtmimp.OpRollback, - fn: cancel, - }) - wf.succeededOps = append(wf.succeededOps, workflowPhase2Item{ + dtmimp.PanicIf(wf.currentCommitAdded, fmt.Errorf("on branch can only add one commit callback")) + wf.currentCommitAdded = true + wf.failedOps = append(wf.succeededOps, workflowPhase2Item{ branchID: branchID, op: dtmimp.OpCommit, - fn: confirm, + fn: fn, }) + return wf } -// DoAction will do an action which will be recored -func (wf *Workflow) DoAction(fn func(bb *dtmcli.BranchBarrier) ([]byte, error)) ([]byte, error) { +// Do will do an action which will be recored +func (wf *Workflow) Do(fn func(bb *dtmcli.BranchBarrier) ([]byte, error)) ([]byte, error) { res := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult { r, e := fn(bb) return stepResultFromLocal(r, e) @@ -121,9 +138,9 @@ func (wf *Workflow) DoAction(fn func(bb *dtmcli.BranchBarrier) ([]byte, error)) return stepResultToLocal(res) } -// DoXaAction will begin a local xa transaction +// DoXa will begin a local xa transaction // after the return of workflow function, xa commit/rollback will be called -func (wf *Workflow) DoXaAction(dbConf dtmcli.DBConf, fn func(db *sql.DB) ([]byte, error)) ([]byte, error) { +func (wf *Workflow) DoXa(dbConf dtmcli.DBConf, fn func(db *sql.DB) ([]byte, error)) ([]byte, error) { branchID := wf.currentBranch res := wf.recordedDo(func(bb *dtmcli.BranchBarrier) *stepResult { sBusi := "business" diff --git a/dtmgrpc/workflow/workflow_test.go b/dtmgrpc/workflow/workflow_test.go new file mode 100644 index 0000000..f4c81e7 --- /dev/null +++ b/dtmgrpc/workflow/workflow_test.go @@ -0,0 +1,24 @@ +package workflow + +import ( + "context" + "testing" + + "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/stretchr/testify/assert" +) + +func TestAbnormal(t *testing.T) { + fname := dtmimp.GetFuncName() + err := defaultFac.execute(fname, fname, nil) + assert.Error(t, err) + + err = defaultFac.register(fname, func(wf *Workflow, data []byte) error { return nil }) + assert.Nil(t, err) + err = defaultFac.register(fname, nil) + assert.Error(t, err) + + ws := &workflowServer{} + _, err = ws.Execute(context.Background(), nil) + assert.Contains(t, err.Error(), "call workflow.InitGrpc first") +} diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 614cdac..ae8c32d 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -56,7 +56,7 @@ func svcPrepareWorkflow(t *TransGlobal) ([]TransBranch, error) { if err == storage.ErrUniqueConflict { // transaction exists, query the branches return GetStore().FindBranches(t.Gid), nil } - return []TransBranch{}, nil + return []TransBranch{}, err } func svcAbort(t *TransGlobal) interface{} { diff --git a/test/workflow_test.go b/test/workflow_test.go index e02fb77..1ef3b1b 100644 --- a/test/workflow_test.go +++ b/test/workflow_test.go @@ -30,11 +30,11 @@ func TestWorkflowNormal(t *testing.T) { workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { var req busi.ReqHTTP dtmimp.MustUnmarshal(data, &req) - _, err := wf.NewRequest().SetBody(req).Post(Busi + "/TransOut") + _, err := wf.NewBranch().NewRequest().SetBody(req).Post(Busi + "/TransOut") if err != nil { return err } - _, err = wf.NewRequest().SetBody(req).Post(Busi + "/TransIn") + _, err = wf.NewBranch().NewRequest().SetBody(req).Post(Busi + "/TransIn") if err != nil { return err } @@ -47,6 +47,29 @@ func TestWorkflowNormal(t *testing.T) { assert.Equal(t, StatusSucceed, getTransStatus(gid)) } +func TestWorkflowSimpleResume(t *testing.T) { + workflow.SetProtocolForTest(dtmimp.ProtocolHTTP) + req := busi.GenTransReq(30, false, false) + gid := dtmimp.GetFuncName() + ongoingStep = 0 + + workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { + if fetchOngoingStep(0) { + return dtmcli.ErrOngoing + } + var req busi.ReqHTTP + dtmimp.MustUnmarshal(data, &req) + _, err := wf.NewBranch().NewRequest().SetBody(req).Post(Busi + "/TransOut") + return err + }) + + err := workflow.Execute(gid, gid, dtmimp.MustMarshal(req)) + assert.Error(t, err) + go waitTransProcessed(gid) + cronTransOnceForwardNow(t, gid, 1000) + assert.Equal(t, StatusSucceed, getTransStatus(gid)) +} + func TestWorkflowRollback(t *testing.T) { workflow.SetProtocolForTest(dtmimp.ProtocolHTTP) @@ -56,11 +79,10 @@ func TestWorkflowRollback(t *testing.T) { workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { var req busi.ReqHTTP dtmimp.MustUnmarshal(data, &req) - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + _, err := wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { _, err := wf.NewRequest().SetBody(req).Post(Busi + "/SagaBTransOutCom") return err - }) - _, err := wf.DoAction(func(bb *dtmcli.BranchBarrier) ([]byte, error) { + }).Do(func(bb *dtmcli.BranchBarrier) ([]byte, error) { return nil, bb.CallWithDB(dbGet().ToSQLDB(), func(tx *sql.Tx) error { return busi.SagaAdjustBalance(tx, busi.TransOutUID, -req.Amount, "") }) @@ -68,12 +90,11 @@ func TestWorkflowRollback(t *testing.T) { if err != nil { return err } - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + _, err = wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { return bb.CallWithDB(dbGet().ToSQLDB(), func(tx *sql.Tx) error { return busi.SagaAdjustBalance(tx, busi.TransInUID, -req.Amount, "") }) - }) - _, err = wf.NewRequest().SetBody(req).Post(Busi + "/SagaBTransIn") + }).NewRequest().SetBody(req).Post(Busi + "/SagaBTransIn") if err != nil { return err } @@ -93,7 +114,7 @@ func TestWorkflowGrpcNormal(t *testing.T) { workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { var req busi.BusiReq dtmgimp.MustProtoUnmarshal(data, &req) - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransOutRevertBSaga(wf.Context, &req) return err }) @@ -101,7 +122,7 @@ func TestWorkflowGrpcNormal(t *testing.T) { if err != nil { return err } - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransInRevertBSaga(wf.Context, &req) return err }) @@ -136,7 +157,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { if fetchOngoingStep(0) { return dtmcli.ErrOngoing } - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { if fetchOngoingStep(4) { return dtmcli.ErrOngoing } @@ -150,7 +171,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { if err != nil { return err } - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { if fetchOngoingStep(3) { return dtmcli.ErrOngoing } @@ -185,13 +206,13 @@ func TestWorkflowXaAction(t *testing.T) { workflow.SetProtocolForTest(dtmimp.ProtocolGRPC) gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { - _, err := wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err := wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { return nil, busi.SagaAdjustBalance(db, busi.TransOutUID, -30, dtmcli.ResultSuccess) }) if err != nil { return err } - _, err = wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err = wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { return nil, busi.SagaAdjustBalance(db, busi.TransInUID, 30, dtmcli.ResultSuccess) }) return err @@ -206,13 +227,13 @@ func TestWorkflowXaRollback(t *testing.T) { workflow.SetProtocolForTest(dtmimp.ProtocolGRPC) gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { - _, err := wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err := wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { return nil, busi.SagaAdjustBalance(db, busi.TransOutUID, -30, dtmcli.ResultSuccess) }) if err != nil { return err } - _, err = wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err = wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { e := busi.SagaAdjustBalance(db, busi.TransInUID, 30, dtmcli.ResultSuccess) logger.FatalIfError(e) return nil, dtmcli.ErrFailure @@ -230,7 +251,7 @@ func TestWorkflowXaResume(t *testing.T) { ongoingStep = 0 gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { - _, err := wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err := wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { if fetchOngoingStep(0) { return nil, dtmcli.ErrOngoing } @@ -239,7 +260,7 @@ func TestWorkflowXaResume(t *testing.T) { if err != nil { return err } - _, err = wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err = wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { if fetchOngoingStep(1) { return nil, dtmcli.ErrOngoing } @@ -300,7 +321,7 @@ func TestWorkflowMixed(t *testing.T) { var req busi.BusiReq dtmgimp.MustProtoUnmarshal(data, &req) - wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.NewBranch().OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransOutRevertBSaga(wf.Context, &req) return err }) @@ -309,22 +330,21 @@ func TestWorkflowMixed(t *testing.T) { return err } - wf.AddTccPhase2(func(bb *dtmcli.BranchBarrier) error { + _, err = wf.NewBranch().OnBranchCommit(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransInConfirm(wf.Context, &req) return err - }, func(bb *dtmcli.BranchBarrier) error { + }).OnBranchRollback(func(bb *dtmcli.BranchBarrier) error { req2 := &busi.ReqHTTP{Amount: 30} _, err := wf.NewRequest().SetBody(req2).Post(Busi + "/TransInRevert") return err - }) - _, err = wf.DoAction(func(bb *dtmcli.BranchBarrier) ([]byte, error) { + }).Do(func(bb *dtmcli.BranchBarrier) ([]byte, error) { err := busi.SagaAdjustBalance(dbGet().ToSQLDB(), busi.TransInUID, int(req.Amount), "") return nil, err }) if err != nil { return err } - _, err = wf.DoXaAction(busi.BusiConf, func(db *sql.DB) ([]byte, error) { + _, err = wf.NewBranch().DoXa(busi.BusiConf, func(db *sql.DB) ([]byte, error) { return nil, busi.SagaAdjustBalance(db, busi.TransInUID, 0, dtmcli.ResultSuccess) }) return err @@ -333,5 +353,4 @@ func TestWorkflowMixed(t *testing.T) { assert.Nil(t, err) assert.Equal(t, StatusSucceed, getTransStatus(gid)) waitTransProcessed(gid) - }