diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index c264bd6..c2b66d1 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -13,6 +13,7 @@ import ( "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" + "github.com/go-redis/redis/v8" ) // BarrierBusiFunc type for busi func @@ -76,12 +77,12 @@ func (bb *BranchBarrier) Call(tx *sql.Tx, busiCall BarrierBusiFunc) (rerr error) } }() ti := bb - originType := map[string]string{ + originOp := map[string]string{ BranchCancel: BranchTry, BranchCompensate: BranchAction, }[ti.Op] - originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, bid, ti.Op) + originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originOp, bid, ti.Op) currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.Op, bid, ti.Op) logger.Debugf("originAffected: %d currentAffected: %d", originAffected, currentAffected) if (ti.Op == BranchCancel || ti.Op == BranchCompensate) && originAffected > 0 || // 这个是空补偿 @@ -114,3 +115,44 @@ func (bb *BranchBarrier) QueryPrepared(db *sql.DB) error { } return err } + +// RedisCheckAdjustAmount check the value of key is valid and >= amount. then adjust the amount +func (bb *BranchBarrier) RedisCheckAdjustAmount(rd *redis.Client, key string, amount int, barrierExpire int) error { + bkey1 := fmt.Sprintf("%s-%s-%s-%s-%02d", key, bb.Gid, bb.BranchID, bb.Op, bb.BarrierID) + originOp := map[string]string{ + BranchCancel: BranchTry, + BranchCompensate: BranchAction, + }[bb.Op] + bkey2 := fmt.Sprintf("%s-%s-%s-%s-%02d", key, bb.Gid, bb.BranchID, originOp, bb.BarrierID) + v, err := rd.Eval(rd.Context(), ` -- RedisCheckAdjustAmount +local v = redis.call('GET', KEYS[1]) +local e1 = redis.call('GET', KEYS[2]) + +if v == false or v + ARGV[1] < 0 then + return 'FAILURE' +end + +if e1 ~= false then + return +end + +redis.call('SET', KEYS[2], 'op', 'EX', ARGV[3]) + +if ARGV[2] ~= '' then + local e2 = redis.call('GET', KEYS[3]) + if e2 == false then + redis.call('SET', KEYS[3], 'rollback', 'EX', ARGV[3]) + return + end +end +redis.call('INCRBY', KEYS[1], ARGV[1]) +`, []string{key, bkey1, bkey2}, amount, originOp, barrierExpire).Result() + logger.Debugf("lua return v: %v err: %v", v, err) + if err == redis.Nil { + err = nil + } + if err == nil && v == ResultFailure { + err = ErrFailure + } + return err +} diff --git a/dtmsvr/trans_type_saga.go b/dtmsvr/trans_type_saga.go index 15f8c6b..513190c 100644 --- a/dtmsvr/trans_type_saga.go +++ b/dtmsvr/trans_type_saga.go @@ -193,11 +193,13 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { } } prepareToCompensate := func() { - toRunActions := pickToRunActions() - for _, b := range toRunActions { + _ = pickToRunActions() // flag started + for i := 1; i < len(branchResults); i += 2 { // these branches may have run. so flag them to status succeed, then run the corresponding // compensate - branchResults[b].status = dtmcli.StatusSucceed + if branchResults[i].started && branchResults[i].status == dtmcli.StatusPrepared { + branchResults[i].status = dtmcli.StatusSucceed + } } for i, b := range branchResults { if b.op == dtmcli.BranchCompensate && b.status != dtmcli.StatusSucceed && diff --git a/test/busi/barrier.go b/test/busi/barrier.go index 46463f7..9d8c30f 100644 --- a/test/busi/barrier.go +++ b/test/busi/barrier.go @@ -70,16 +70,35 @@ func init() { return tccAdjustTrading(tx, TransInUID, -reqFrom(c).Amount) }) })) + app.POST(BusiAPI+"/SagaRedisTransIn", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransInUID), reqFrom(c).Amount, 7*86400) + })) + app.POST(BusiAPI+"/SagaRedisTransInCom", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransInUID), -reqFrom(c).Amount, 7*86400) + })) + app.POST(BusiAPI+"/SagaRedisTransOut", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), -reqFrom(c).Amount, 7*86400) + })) + app.POST(BusiAPI+"/SagaRedisTransOutCom", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), reqFrom(c).Amount, 7*86400) + })) app.POST(BusiAPI+"/TccBTransOutTry", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { req := reqFrom(c) if req.TransOutResult != "" { return dtmcli.String2DtmError(req.TransOutResult) } + if req.Store == "redis" { + return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), req.Amount, 7*86400) + } + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustTrading(tx, TransOutUID, -req.Amount) }) })) app.POST(BusiAPI+"/TccBTransOutConfirm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + if reqFrom(c).Store == "redis" { + return nil + } return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount) }) @@ -90,6 +109,10 @@ func init() { // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) interface{} { + req := reqFrom(c) + if req.Store == "redis" { + return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), -req.Amount, 7*86400) + } return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustTrading(tx, TransOutUID, reqFrom(c).Amount) }) diff --git a/test/busi/base_types.go b/test/busi/base_types.go index 5aaaed9..38b332d 100644 --- a/test/busi/base_types.go +++ b/test/busi/base_types.go @@ -32,7 +32,13 @@ func (*UserAccount) TableName() string { return "dtm_busi.user_account" } -func GetBalanceByUid(uid int) int { +func GetBalanceByUid(uid int, store string) int { + if store == "redis" { + rd := RedisGet() + accA, err := rd.Get(rd.Context(), getRedisAccountKey(uid)).Result() + dtmimp.E2P(err) + return dtmimp.MustAtoi(accA) + } ua := UserAccount{} _ = dbGet().Must().Model(&ua).Where("user_id=?", uid).First(&ua) return dtmimp.MustAtoi(ua.Balance[:len(ua.Balance)-3]) @@ -43,6 +49,7 @@ type TransReq struct { Amount int `json:"amount"` TransInResult string `json:"trans_in_result"` TransOutResult string `json:"trans_out_Result"` + Store string `json:"store"` // default mysql, value can be mysql|redis } func (t *TransReq) String() string { @@ -119,3 +126,7 @@ type mainSwitchType struct { // MainSwitch controls busi success or fail var MainSwitch mainSwitchType + +func getRedisAccountKey(uid int) string { + return fmt.Sprintf("{a}-redis-account-key-%d", uid) +} diff --git a/test/busi/startup.go b/test/busi/startup.go index 0c6e329..2b40a5c 100644 --- a/test/busi/startup.go +++ b/test/busi/startup.go @@ -1,8 +1,10 @@ package busi import ( + "context" "fmt" + "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmutil" "github.com/gin-gonic/gin" ) @@ -20,4 +22,6 @@ func PopulateDB(skipDrop bool) { dtmutil.RunSQLScript(BusiConf, file, skipDrop) file = fmt.Sprintf("%s/dtmcli.barrier.%s.sql", dtmutil.GetSQLDir(), BusiConf.Driver) dtmutil.RunSQLScript(BusiConf, file, skipDrop) + _, err := RedisGet().FlushAll(context.Background()).Result() // redis barrier need clear + dtmimp.E2P(err) } diff --git a/test/busi/utils.go b/test/busi/utils.go index bc5a77f..9441d07 100644 --- a/test/busi/utils.go +++ b/test/busi/utils.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "strings" + sync "sync" "time" "github.com/dtm-labs/dtm/dtmcli" @@ -15,6 +16,7 @@ import ( "github.com/dtm-labs/dtm/dtmgrpc/dtmgpb" "github.com/dtm-labs/dtm/dtmutil" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "github.com/go-resty/resty/v2" grpc "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -113,3 +115,28 @@ func oldWrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc } } } + +var ( + rdb *redis.Client + once sync.Once +) + +func RedisGet() *redis.Client { + once.Do(func() { + logger.Debugf("connecting to client redis") + rdb = redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Username: "root", + Password: "", + }) + }) + return rdb +} + +func SetRedisBothAccount(accountA int, accountB int) { + rd := RedisGet() + _, err := rd.Set(rd.Context(), getRedisAccountKey(TransOutUID), accountA, 0).Result() + dtmimp.E2P(err) + _, err = rd.Set(rd.Context(), getRedisAccountKey(TransInUID), accountB, 0).Result() + dtmimp.E2P(err) +} diff --git a/test/msg_barrier_test.go b/test/msg_barrier_test.go index fe586e0..cd274fb 100644 --- a/test/msg_barrier_test.go +++ b/test/msg_barrier_test.go @@ -15,7 +15,7 @@ import ( ) func TestMsgPrepareAndSubmit(t *testing.T) { - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenTransReq(30, false, false) msg := dtmcli.NewMsg(DtmServer, gid). @@ -27,11 +27,11 @@ func TestMsgPrepareAndSubmit(t *testing.T) { waitTransProcessed(msg.Gid) assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid)) - assertNotSameBalance(t, before) + assertNotSameBalance(t, before, "mysql") } func TestMsgPrepareAndSubmitBusiFailed(t *testing.T) { - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenTransReq(30, false, false) msg := dtmcli.NewMsg(DtmServer, gid). @@ -40,11 +40,11 @@ func TestMsgPrepareAndSubmitBusiFailed(t *testing.T) { return errors.New("an error") }) assert.Error(t, err) - assertSameBalance(t, before) + assertSameBalance(t, before, "mysql") } func TestMsgPrepareAndSubmitPrepareFailed(t *testing.T) { - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenTransReq(30, false, false) msg := dtmcli.NewMsg(DtmServer+"not-exists", gid). @@ -53,14 +53,14 @@ func TestMsgPrepareAndSubmitPrepareFailed(t *testing.T) { return busi.SagaAdjustBalance(tx, busi.TransOutUID, -req.Amount, "SUCCESS") }) assert.Error(t, err) - assertSameBalance(t, before) + assertSameBalance(t, before, "mysql") } func TestMsgPrepareAndSubmitCommitFailed(t *testing.T) { if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit return } - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenTransReq(30, false, false) msg := dtmcli.NewMsg(DtmServer, gid). @@ -77,14 +77,14 @@ func TestMsgPrepareAndSubmitCommitFailed(t *testing.T) { g.Unpatch() assert.Error(t, err) cronTransOnceForwardNow(180) - assertSameBalance(t, before) + assertSameBalance(t, before, "mysql") } func TestMsgPrepareAndSubmitCommitAfterFailed(t *testing.T) { if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit return } - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenTransReq(30, false, false) msg := dtmcli.NewMsg(DtmServer, gid). @@ -101,5 +101,5 @@ func TestMsgPrepareAndSubmitCommitAfterFailed(t *testing.T) { }) assert.Error(t, err) cronTransOnceForwardNow(180) - assertNotSameBalance(t, before) + assertNotSameBalance(t, before, "mysql") } diff --git a/test/msg_grpc_barrier_test.go b/test/msg_grpc_barrier_test.go index ca1b93d..50edd6c 100644 --- a/test/msg_grpc_barrier_test.go +++ b/test/msg_grpc_barrier_test.go @@ -14,7 +14,7 @@ import ( ) func TestMsgGrpcPrepareAndSubmit(t *testing.T) { - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenBusiReq(30, false, false) msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid). @@ -26,14 +26,14 @@ func TestMsgGrpcPrepareAndSubmit(t *testing.T) { waitTransProcessed(msg.Gid) assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid)) - assertNotSameBalance(t, before) + assertNotSameBalance(t, before, "mysql") } func TestMsgGrpcPrepareAndSubmitCommitAfterFailed(t *testing.T) { if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit return } - before := getBeforeBalances() + before := getBeforeBalances("mysql") gid := dtmimp.GetFuncName() req := busi.GenBusiReq(30, false, false) msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid). @@ -50,5 +50,5 @@ func TestMsgGrpcPrepareAndSubmitCommitAfterFailed(t *testing.T) { }) assert.Error(t, err) cronTransOnceForwardNow(180) - assertNotSameBalance(t, before) + assertNotSameBalance(t, before, "mysql") } diff --git a/test/saga_barrier_redis_test.go b/test/saga_barrier_redis_test.go new file mode 100644 index 0000000..c37dd29 --- /dev/null +++ b/test/saga_barrier_redis_test.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 yedf. All rights reserved. + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +package test + +import ( + "testing" + + "github.com/dtm-labs/dtm/dtmcli" + "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/dtm-labs/dtm/test/busi" + "github.com/stretchr/testify/assert" +) + +func TestSagaBarrierRedisNormal(t *testing.T) { + busi.SetRedisBothAccount(100, 100) + before := getBeforeBalances("redis") + saga := genSagaBarrierRedis(dtmimp.GetFuncName()) + err := saga.Submit() + assert.Nil(t, err) + waitTransProcessed(saga.Gid) + assert.Equal(t, []string{StatusPrepared, StatusSucceed, StatusPrepared, StatusSucceed}, getBranchesStatus(saga.Gid)) + assert.Equal(t, StatusSucceed, getTransStatus(saga.Gid)) + assertNotSameBalance(t, before, "redis") +} + +func TestSagaBarrierRedisRollback(t *testing.T) { + busi.SetRedisBothAccount(20, 20) + before := getBeforeBalances("redis") + saga := genSagaBarrierRedis(dtmimp.GetFuncName()) + err := saga.Submit() + assert.Nil(t, err) + waitTransProcessed(saga.Gid) + assert.Equal(t, StatusFailed, getTransStatus(saga.Gid)) + assert.Equal(t, []string{StatusSucceed, StatusSucceed, StatusSucceed, StatusFailed}, getBranchesStatus(saga.Gid)) + assertSameBalance(t, before, "redis") +} + +func genSagaBarrierRedis(gid string) *dtmcli.Saga { + req := busi.GenTransReq(30, false, false) + req.Store = "redis" + return dtmcli.NewSaga(DtmServer, gid). + Add(Busi+"/SagaRedisTransIn", Busi+"/SagaRedisTransInCom", req). + Add(Busi+"/SagaRedisTransOut", Busi+"/SagaRedisTransOutCom", req) +} diff --git a/test/tcc_barrier_test.go b/test/tcc_barrier_test.go index 2a4ec94..8a1ae57 100644 --- a/test/tcc_barrier_test.go +++ b/test/tcc_barrier_test.go @@ -50,14 +50,23 @@ func TestTccBarrierRollback(t *testing.T) { assert.Equal(t, []string{StatusSucceed, StatusPrepared, StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) } -func TestTccBarrierDisorder(t *testing.T) { - before := getBeforeBalances() +func TestTccBarrierDisorderMysql(t *testing.T) { + runTestTccBarrierDisorder(t, "mysql") +} + +func TestTccBarrierDisorderRedis(t *testing.T) { + busi.SetRedisBothAccount(200, 200) + runTestTccBarrierDisorder(t, "redis") +} + +func runTestTccBarrierDisorder(t *testing.T, store string) { + before := getBeforeBalances(store) cancelFinishedChan := make(chan string, 2) cancelCanReturnChan := make(chan string, 2) - gid := dtmimp.GetFuncName() + gid := dtmimp.GetFuncName() + store cronFinished := make(chan string, 2) err := dtmcli.TccGlobalTransaction(DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { - body := &busi.TransReq{Amount: 30} + body := &busi.TransReq{Amount: 30, Store: store} tryURL := Busi + "/TccBTransOutTry" confirmURL := Busi + "/TccBTransOutConfirm" cancelURL := Busi + "/TccBSleepCancel" @@ -122,7 +131,7 @@ func TestTccBarrierDisorder(t *testing.T) { assert.Error(t, err, fmt.Errorf("a cancelled tcc")) assert.Equal(t, []string{StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) assert.Equal(t, StatusFailed, getTransStatus(gid)) - assertSameBalance(t, before) + assertSameBalance(t, before, store) } func TestTccBarrierPanic(t *testing.T) { diff --git a/test/types.go b/test/types.go index 0db969d..adff1a6 100644 --- a/test/types.go +++ b/test/types.go @@ -83,22 +83,22 @@ const ( StatusAborting = dtmcli.StatusAborting ) -func getBeforeBalances() []int { - b1 := busi.GetBalanceByUid(busi.TransOutUID) - b2 := busi.GetBalanceByUid(busi.TransInUID) +func getBeforeBalances(store string) []int { + b1 := busi.GetBalanceByUid(busi.TransOutUID, store) + b2 := busi.GetBalanceByUid(busi.TransInUID, store) return []int{b1, b2} } -func assertSameBalance(t *testing.T, before []int) { - b1 := busi.GetBalanceByUid(busi.TransOutUID) - b2 := busi.GetBalanceByUid(busi.TransInUID) +func assertSameBalance(t *testing.T, before []int, store string) { + b1 := busi.GetBalanceByUid(busi.TransOutUID, store) + b2 := busi.GetBalanceByUid(busi.TransInUID, store) assert.Equal(t, before[0], b1) assert.Equal(t, before[1], b2) } -func assertNotSameBalance(t *testing.T, before []int) { - b1 := busi.GetBalanceByUid(busi.TransOutUID) - b2 := busi.GetBalanceByUid(busi.TransInUID) +func assertNotSameBalance(t *testing.T, before []int, store string) { + b1 := busi.GetBalanceByUid(busi.TransOutUID, store) + b2 := busi.GetBalanceByUid(busi.TransInUID, store) assert.NotEqual(t, before[0], b1) assert.Equal(t, before[0]+before[1], b1+b2) }