From ea15146423d5d942b56e86d2db8fb05ae866bc44 Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Wed, 29 Jun 2022 21:16:09 +0800 Subject: [PATCH] branch conflict detect ok --- dtmgrpc/workflow/workflow.go | 6 ++--- dtmsvr/storage/boltdb/boltdb.go | 20 +++++++++----- dtmsvr/storage/redis/redis.go | 11 ++++++++ dtmsvr/storage/sql/sql.go | 6 ++++- test/busi/base_http.go | 2 +- test/busi/base_types.go | 16 ++++++------ test/busi/utils.go | 2 ++ test/tcc_barrier_test.go | 2 +- test/workflow_test.go | 46 ++++++++++++++++++++++++++------- 9 files changed, 81 insertions(+), 30 deletions(-) diff --git a/dtmgrpc/workflow/workflow.go b/dtmgrpc/workflow/workflow.go index 98b1081..ed6f714 100644 --- a/dtmgrpc/workflow/workflow.go +++ b/dtmgrpc/workflow/workflow.go @@ -85,9 +85,9 @@ func (wf *Workflow) NewRequest() *resty.Request { return wf.restyClient.R().SetContext(wf.Context) } -// DefineSagaPhase2 will define a saga branch transaction +// AddSagaPhase2 will define a saga branch transaction // param compensate specify a function for the compensation of next workflow action -func (wf *Workflow) DefineSagaPhase2(compensate WfPhase2Func) { +func (wf *Workflow) AddSagaPhase2(compensate WfPhase2Func) { branchID := wf.currentBranch wf.failedOps = append(wf.failedOps, workflowPhase2Item{ branchID: branchID, @@ -98,7 +98,7 @@ func (wf *Workflow) DefineSagaPhase2(compensate WfPhase2Func) { // DefineSagaPhase2 will define a tcc branch transaction // param confirm, concel specify the confirm and cancel operation of next workflow action -func (wf *Workflow) DefineTccPhase2(confirm, cancel WfPhase2Func) { +func (wf *Workflow) AddTccPhase2(confirm, cancel WfPhase2Func) { branchID := wf.currentBranch wf.failedOps = append(wf.failedOps, workflowPhase2Item{ branchID: branchID, diff --git a/dtmsvr/storage/boltdb/boltdb.go b/dtmsvr/storage/boltdb/boltdb.go index a628ca0..c1f646d 100644 --- a/dtmsvr/storage/boltdb/boltdb.go +++ b/dtmsvr/storage/boltdb/boltdb.go @@ -69,12 +69,12 @@ func initializeBuckets(db *bolt.DB) error { // cleanupExpiredData will clean the expired data in boltdb, the // expired time is configurable. -func cleanupExpiredData(expiredSeconds time.Duration, db *bolt.DB) error { - if expiredSeconds <= 0 { +func cleanupExpiredData(expire time.Duration, db *bolt.DB) error { + if expire <= 0 { return nil } - lastKeepTime := time.Now().Add(-expiredSeconds) + lastKeepTime := time.Now().Add(-expire) return db.Update(func(t *bolt.Tx) error { globalBucket := t.Bucket(bucketGlobal) if globalBucket == nil { @@ -209,9 +209,15 @@ func tPutGlobal(t *bolt.Tx, global *storage.TransGlobalStore) { dtmimp.E2P(err) } -func tPutBranches(t *bolt.Tx, branches []storage.TransBranchStore, start int64) { +func tPutBranches(t *bolt.Tx, branches []storage.TransBranchStore, start int64) error { if start == -1 { - bs := tGetBranches(t, branches[0].Gid) + b0 := &branches[0] + bs := tGetBranches(t, b0.Gid) + for _, b := range bs { + if b.BranchID == b0.BranchID && b.Op == b0.Op { + return storage.ErrUniqueConflict + } + } start = int64(len(bs)) } for i, b := range branches { @@ -220,6 +226,7 @@ func tPutBranches(t *bolt.Tx, branches []storage.TransBranchStore, start int64) err := t.Bucket(bucketBranches).Put([]byte(k), []byte(v)) dtmimp.E2P(err) } + return nil } func tDelIndex(t *bolt.Tx, unix int64, gid string) { @@ -323,8 +330,7 @@ func (s *Store) LockGlobalSaveBranches(gid string, status string, branches []sto if g.Status != status { return storage.ErrNotFound } - tPutBranches(t, branches, int64(branchStart)) - return nil + return tPutBranches(t, branches, int64(branchStart)) }) dtmimp.E2P(err) } diff --git a/dtmsvr/storage/redis/redis.go b/dtmsvr/storage/redis/redis.go index d24d658..6ff5869 100644 --- a/dtmsvr/storage/redis/redis.go +++ b/dtmsvr/storage/redis/redis.go @@ -198,6 +198,17 @@ if old ~= ARGV[3] then return 'NOT_FOUND' end local start = ARGV[4] +-- check duplicates for workflow +if start == "-1" then + local t = cjson.decode(ARGV[5]) + local bs = redis.call('LRANGE', KEYS[2], 0, -1) + for i = 1, table.getn(bs) do + local c = cjson.decode(bs[i]) + if t['branch_id'] == c['branch_id'] and t['op'] == c['op'] then + return 'UNIQUE_CONFLICT' + end + end +end for k = 5, table.getn(ARGV) do if start == "-1" then redis.call('RPUSH', KEYS[2], ARGV[k]) diff --git a/dtmsvr/storage/sql/sql.go b/dtmsvr/storage/sql/sql.go index 0ad3565..7628623 100644 --- a/dtmsvr/storage/sql/sql.go +++ b/dtmsvr/storage/sql/sql.go @@ -89,7 +89,11 @@ func (s *Store) LockGlobalSaveBranches(gid string, status string, branches []sto g := &storage.TransGlobalStore{} dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g) if dbr.Error == nil { - dbr = tx.Save(branches) + if branchStart == -1 { + dbr = tx.Create(branches) + } else { + dbr = tx.Save(branches) + } } return wrapError(dbr.Error) }) diff --git a/test/busi/base_http.go b/test/busi/base_http.go index efa2633..d0d6a4b 100644 --- a/test/busi/base_http.go +++ b/test/busi/base_http.go @@ -158,7 +158,7 @@ func BaseAddRoute(app *gin.Engine) { tcc, err := dtmcli.TccFromQuery(c.Request.URL.Query()) logger.FatalIfError(err) logger.Debugf("TransInTccNested ") - resp, err := tcc.CallBranch(&TransReq{Amount: reqFrom(c).Amount}, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") + resp, err := tcc.CallBranch(&ReqHttp{Amount: reqFrom(c).Amount}, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") if err != nil { return err } diff --git a/test/busi/base_types.go b/test/busi/base_types.go index 0d4171d..9cb1c28 100644 --- a/test/busi/base_types.go +++ b/test/busi/base_types.go @@ -60,21 +60,21 @@ func GetBalanceByUID(uid int, store string) int { return dtmimp.MustAtoi(ua.Balance[:len(ua.Balance)-3]) } -// TransReq transaction request payload -type TransReq struct { +// ReqHttp transaction request payload +type ReqHttp struct { Amount int `json:"amount"` TransInResult string `json:"trans_in_result"` TransOutResult string `json:"trans_out_Result"` Store string `json:"store"` // default mysql, value can be mysql|redis } -func (t *TransReq) String() string { +func (t *ReqHttp) String() string { return fmt.Sprintf("amount: %d transIn: %s transOut: %s", t.Amount, t.TransInResult, t.TransOutResult) } // GenTransReq 1 -func GenTransReq(amount int, outFailed bool, inFailed bool) *TransReq { - return &TransReq{ +func GenTransReq(amount int, outFailed bool, inFailed bool) *ReqHttp { + return &ReqHttp{ Amount: amount, TransOutResult: dtmimp.If(outFailed, dtmcli.ResultFailure, "").(string), TransInResult: dtmimp.If(inFailed, dtmcli.ResultFailure, "").(string), @@ -90,16 +90,16 @@ func GenBusiReq(amount int, outFailed bool, inFailed bool) *BusiReq { } } -func reqFrom(c *gin.Context) *TransReq { +func reqFrom(c *gin.Context) *ReqHttp { v, ok := c.Get("trans_req") if !ok { - req := TransReq{} + req := ReqHttp{} err := c.BindJSON(&req) logger.FatalIfError(err) c.Set("trans_req", &req) v = &req } - return v.(*TransReq) + return v.(*ReqHttp) } func infoFromContext(c *gin.Context) *dtmcli.BranchBarrier { diff --git a/test/busi/utils.go b/test/busi/utils.go index db07113..49b02e2 100644 --- a/test/busi/utils.go +++ b/test/busi/utils.go @@ -25,6 +25,8 @@ import ( "google.golang.org/grpc/metadata" ) +type ReqGrpc = BusiReq + func dbGet() *dtmutil.DB { return dtmutil.DbGet(BusiConf) } diff --git a/test/tcc_barrier_test.go b/test/tcc_barrier_test.go index fb2beef..bfa5e90 100644 --- a/test/tcc_barrier_test.go +++ b/test/tcc_barrier_test.go @@ -69,7 +69,7 @@ func runTestTccBarrierDisorder(t *testing.T, store string) { gid := dtmimp.GetFuncName() + store cronFinished := make(chan string, 2) err := dtmcli.TccGlobalTransaction(DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { - body := &busi.TransReq{Amount: 30, Store: store} + body := &busi.ReqHttp{Amount: 30, Store: store} tryURL := Busi + "/TccBTransOutTry" confirmURL := Busi + "/TccBTransOutConfirm" cancelURL := Busi + "/SleepCancel" diff --git a/test/workflow_test.go b/test/workflow_test.go index 6a1d9ef..2b1f7bd 100644 --- a/test/workflow_test.go +++ b/test/workflow_test.go @@ -9,12 +9,15 @@ package test import ( "database/sql" "testing" + "time" "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" "github.com/dtm-labs/dtm/dtmgrpc/workflow" + "github.com/dtm-labs/dtm/dtmsvr" + "github.com/dtm-labs/dtm/dtmsvr/storage" "github.com/dtm-labs/dtm/test/busi" "github.com/stretchr/testify/assert" ) @@ -25,7 +28,7 @@ func TestWorkflowNormal(t *testing.T) { gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { - var req busi.TransReq + var req busi.ReqHttp dtmimp.MustUnmarshal(data, &req) _, err := wf.NewRequest().SetBody(req).Post(Busi + "/TransOut") if err != nil { @@ -47,13 +50,13 @@ func TestWorkflowNormal(t *testing.T) { func TestWorkflowRollback(t *testing.T) { workflow.SetProtocolForTest(dtmimp.ProtocolHTTP) - req := busi.GenTransReq(30, false, true) + req := &busi.ReqHttp{Amount: 30, TransInResult: dtmimp.ResultFailure} gid := dtmimp.GetFuncName() workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { - var req busi.TransReq + var req busi.ReqHttp dtmimp.MustUnmarshal(data, &req) - wf.DefineSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { _, err := wf.NewRequest().SetBody(req).Post(Busi + "/SagaBTransOutCom") return err }) @@ -65,7 +68,7 @@ func TestWorkflowRollback(t *testing.T) { if err != nil { return err } - wf.DefineSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { return bb.CallWithDB(dbGet().ToSQLDB(), func(tx *sql.Tx) error { return busi.SagaAdjustBalance(tx, busi.TransInUID, -req.Amount, "") }) @@ -90,7 +93,7 @@ func TestWorkflowGrpcNormal(t *testing.T) { workflow.Register(gid, func(wf *workflow.Workflow, data []byte) error { var req busi.BusiReq dtmgimp.MustProtoUnmarshal(data, &req) - wf.DefineSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransOutRevertBSaga(wf.Context, &req) return err }) @@ -98,7 +101,7 @@ func TestWorkflowGrpcNormal(t *testing.T) { if err != nil { return err } - wf.DefineSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { _, err := busi.BusiCli.TransInRevertBSaga(wf.Context, &req) return err }) @@ -133,7 +136,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { if fetchOngoingStep(0) { return dtmcli.ErrOngoing } - wf.DefineSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { if fetchOngoingStep(4) { return dtmcli.ErrOngoing } @@ -147,7 +150,7 @@ func TestWorkflowGrpcRollbackResume(t *testing.T) { if err != nil { return err } - wf.DefineSagaPhase2(func(bb *dtmcli.BranchBarrier) error { + wf.AddSagaPhase2(func(bb *dtmcli.BranchBarrier) error { if fetchOngoingStep(3) { return dtmcli.ErrOngoing } @@ -263,3 +266,28 @@ func TestWorkflowXaResume(t *testing.T) { cronTransOnceForwardNow(t, gid, 1000) assert.Equal(t, StatusSucceed, getTransStatus(gid)) } + +func TestWorkflowBranchConflict(t *testing.T) { + gid := dtmimp.GetFuncName() + store := dtmsvr.GetStore() + now := time.Now() + g := &storage.TransGlobalStore{ + Gid: gid, + Status: dtmcli.StatusPrepared, + NextCronTime: &now, + } + err := store.MaySaveNewTrans(g, []storage.TransBranchStore{ + { + BranchID: "00", + Op: dtmimp.OpAction, + }, + }) + assert.Nil(t, err) + err = dtmimp.CatchP(func() { + store.LockGlobalSaveBranches(gid, dtmcli.StatusPrepared, []storage.TransBranchStore{ + {BranchID: "00", Op: dtmimp.OpAction}, + }, -1) + }) + assert.Equal(t, storage.ErrUniqueConflict, err) + store.ChangeGlobalStatus(g, StatusSucceed, []string{}, true) +}