Browse Source

barrier add RedisCheckAdjustAmount

pull/179/head
yedf2 4 years ago
parent
commit
4e11fffd86
  1. 45
      dtmcli/barrier.go
  2. 23
      test/busi/barrier.go
  3. 13
      test/busi/base_types.go
  4. 27
      test/busi/utils.go
  5. 20
      test/msg_barrier_test.go
  6. 8
      test/msg_grpc_barrier_test.go
  7. 48
      test/saga_barrier_redis_test.go
  8. 19
      test/tcc_barrier_test.go
  9. 18
      test/types.go

45
dtmcli/barrier.go

@ -13,6 +13,7 @@ import (
"github.com/dtm-labs/dtm/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/dtmcli/logger"
"github.com/go-redis/redis/v8"
)
// BarrierBusiFunc type for busi func
@ -76,12 +77,12 @@ func (bb *BranchBarrier) Call(tx *sql.Tx, busiCall BarrierBusiFunc) (rerr error)
}
}()
ti := bb
originType := map[string]string{
originOp := map[string]string{
BranchCancel: BranchTry,
BranchCompensate: BranchAction,
}[ti.Op]
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, bid, ti.Op)
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originOp, bid, ti.Op)
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.Op, bid, ti.Op)
logger.Debugf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
if (ti.Op == BranchCancel || ti.Op == BranchCompensate) && originAffected > 0 || // 这个是空补偿
@ -114,3 +115,43 @@ func (bb *BranchBarrier) QueryPrepared(db *sql.DB) error {
}
return err
}
// RedisCheckAdjustAmount check the value of key is valid and >= amount. then adjust the amount
func (bb *BranchBarrier) RedisCheckAdjustAmount(rd *redis.Client, key string, amount int, barrierExpire int) error {
bkey1 := fmt.Sprintf("%s-%s-%s-%s-%02d", key, bb.Gid, bb.BranchID, bb.Op, bb.BarrierID)
originOp := map[string]string{
BranchCancel: BranchTry,
BranchCompensate: BranchAction,
}[bb.Op]
bkey2 := fmt.Sprintf("%s-%s-%s-%s-%02d", key, bb.Gid, bb.BranchID, originOp, bb.BarrierID)
v, err := rd.Eval(rd.Context(), ` -- RedisCheckAdjustAmount
local v = redis.call('GET', KEYS[1])
local e1 = redis.call('GET', KEYS[2])
if v == false or v + ARGV[1] < 0 then
return 'FAILURE'
end
if e1 ~= false then
return
end
if ARGV[2] ~= '' then
local e2 = redis.call('GET', KEYS[3])
if e2 ~= false then
return
end
redis.call('SET', KEYS[3], 'origin', 'EX', ARGV[3])
end
redis.call('INCRBY', KEYS[1], ARGV[1])
redis.call('SET', KEYS[2], 'op', 'EX', ARGV[3])
`, []string{key, bkey1, bkey2}, amount, originOp, barrierExpire).Result()
logger.Debugf("lua return v: %v err: %v", v, err)
if err == redis.Nil {
err = nil
}
if err == nil && v == ResultFailure {
err = ErrFailure
}
return err
}

23
test/busi/barrier.go

@ -70,16 +70,35 @@ func init() {
return tccAdjustTrading(tx, TransInUID, -reqFrom(c).Amount)
})
}))
app.POST(BusiAPI+"/SagaRedisTransIn", dtmutil.WrapHandler2(func(c *gin.Context) interface{} {
return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransInUID), reqFrom(c).Amount, 7*86400)
}))
app.POST(BusiAPI+"/SagaRedisTransInCom", dtmutil.WrapHandler2(func(c *gin.Context) interface{} {
return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransInUID), -reqFrom(c).Amount, 7*86400)
}))
app.POST(BusiAPI+"/SagaRedisTransOut", dtmutil.WrapHandler2(func(c *gin.Context) interface{} {
return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), -reqFrom(c).Amount, 7*86400)
}))
app.POST(BusiAPI+"/SagaRedisTransOutCom", dtmutil.WrapHandler2(func(c *gin.Context) interface{} {
return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), reqFrom(c).Amount, 7*86400)
}))
app.POST(BusiAPI+"/TccBTransOutTry", dtmutil.WrapHandler2(func(c *gin.Context) interface{} {
req := reqFrom(c)
if req.TransOutResult != "" {
return dtmcli.String2DtmError(req.TransOutResult)
}
if req.Store == "redis" {
return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), req.Amount, 7*86400)
}
return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustTrading(tx, TransOutUID, -req.Amount)
})
}))
app.POST(BusiAPI+"/TccBTransOutConfirm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} {
if reqFrom(c).Store == "redis" {
return nil
}
return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount)
})
@ -90,6 +109,10 @@ func init() {
// TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) interface{} {
req := reqFrom(c)
if req.Store == "redis" {
return MustBarrierFromGin(c).RedisCheckAdjustAmount(RedisGet(), getRedisAccountKey(TransOutUID), -req.Amount, 7*86400)
}
return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustTrading(tx, TransOutUID, reqFrom(c).Amount)
})

13
test/busi/base_types.go

@ -32,7 +32,13 @@ func (*UserAccount) TableName() string {
return "dtm_busi.user_account"
}
func GetBalanceByUid(uid int) int {
func GetBalanceByUid(uid int, store string) int {
if store == "redis" {
rd := RedisGet()
accA, err := rd.Get(rd.Context(), getRedisAccountKey(uid)).Result()
dtmimp.E2P(err)
return dtmimp.MustAtoi(accA)
}
ua := UserAccount{}
_ = dbGet().Must().Model(&ua).Where("user_id=?", uid).First(&ua)
return dtmimp.MustAtoi(ua.Balance[:len(ua.Balance)-3])
@ -43,6 +49,7 @@ type TransReq struct {
Amount int `json:"amount"`
TransInResult string `json:"trans_in_result"`
TransOutResult string `json:"trans_out_Result"`
Store string `json:"store"` // default mysql, value can be mysql|redis
}
func (t *TransReq) String() string {
@ -119,3 +126,7 @@ type mainSwitchType struct {
// MainSwitch controls busi success or fail
var MainSwitch mainSwitchType
func getRedisAccountKey(uid int) string {
return fmt.Sprintf("{a}-redis-account-key-%d", uid)
}

27
test/busi/utils.go

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"strings"
sync "sync"
"time"
"github.com/dtm-labs/dtm/dtmcli"
@ -15,6 +16,7 @@ import (
"github.com/dtm-labs/dtm/dtmgrpc/dtmgpb"
"github.com/dtm-labs/dtm/dtmutil"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/go-resty/resty/v2"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@ -113,3 +115,28 @@ func oldWrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc
}
}
}
var (
rdb *redis.Client
once sync.Once
)
func RedisGet() *redis.Client {
once.Do(func() {
logger.Debugf("connecting to client redis")
rdb = redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Username: "root",
Password: "",
})
})
return rdb
}
func SetRedisBothAccount(accountA int, accountB int) {
rd := RedisGet()
_, err := rd.Set(rd.Context(), getRedisAccountKey(TransOutUID), accountA, 0).Result()
dtmimp.E2P(err)
_, err = rd.Set(rd.Context(), getRedisAccountKey(TransInUID), accountB, 0).Result()
dtmimp.E2P(err)
}

20
test/msg_barrier_test.go

@ -15,7 +15,7 @@ import (
)
func TestMsgPrepareAndSubmit(t *testing.T) {
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenTransReq(30, false, false)
msg := dtmcli.NewMsg(DtmServer, gid).
@ -27,11 +27,11 @@ func TestMsgPrepareAndSubmit(t *testing.T) {
waitTransProcessed(msg.Gid)
assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid))
assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid))
assertNotSameBalance(t, before)
assertNotSameBalance(t, before, "mysql")
}
func TestMsgPrepareAndSubmitBusiFailed(t *testing.T) {
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenTransReq(30, false, false)
msg := dtmcli.NewMsg(DtmServer, gid).
@ -40,11 +40,11 @@ func TestMsgPrepareAndSubmitBusiFailed(t *testing.T) {
return errors.New("an error")
})
assert.Error(t, err)
assertSameBalance(t, before)
assertSameBalance(t, before, "mysql")
}
func TestMsgPrepareAndSubmitPrepareFailed(t *testing.T) {
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenTransReq(30, false, false)
msg := dtmcli.NewMsg(DtmServer+"not-exists", gid).
@ -53,14 +53,14 @@ func TestMsgPrepareAndSubmitPrepareFailed(t *testing.T) {
return busi.SagaAdjustBalance(tx, busi.TransOutUID, -req.Amount, "SUCCESS")
})
assert.Error(t, err)
assertSameBalance(t, before)
assertSameBalance(t, before, "mysql")
}
func TestMsgPrepareAndSubmitCommitFailed(t *testing.T) {
if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit
return
}
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenTransReq(30, false, false)
msg := dtmcli.NewMsg(DtmServer, gid).
@ -77,14 +77,14 @@ func TestMsgPrepareAndSubmitCommitFailed(t *testing.T) {
g.Unpatch()
assert.Error(t, err)
cronTransOnceForwardNow(180)
assertSameBalance(t, before)
assertSameBalance(t, before, "mysql")
}
func TestMsgPrepareAndSubmitCommitAfterFailed(t *testing.T) {
if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit
return
}
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenTransReq(30, false, false)
msg := dtmcli.NewMsg(DtmServer, gid).
@ -101,5 +101,5 @@ func TestMsgPrepareAndSubmitCommitAfterFailed(t *testing.T) {
})
assert.Error(t, err)
cronTransOnceForwardNow(180)
assertNotSameBalance(t, before)
assertNotSameBalance(t, before, "mysql")
}

8
test/msg_grpc_barrier_test.go

@ -14,7 +14,7 @@ import (
)
func TestMsgGrpcPrepareAndSubmit(t *testing.T) {
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenBusiReq(30, false, false)
msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid).
@ -26,14 +26,14 @@ func TestMsgGrpcPrepareAndSubmit(t *testing.T) {
waitTransProcessed(msg.Gid)
assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid))
assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid))
assertNotSameBalance(t, before)
assertNotSameBalance(t, before, "mysql")
}
func TestMsgGrpcPrepareAndSubmitCommitAfterFailed(t *testing.T) {
if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit
return
}
before := getBeforeBalances()
before := getBeforeBalances("mysql")
gid := dtmimp.GetFuncName()
req := busi.GenBusiReq(30, false, false)
msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid).
@ -50,5 +50,5 @@ func TestMsgGrpcPrepareAndSubmitCommitAfterFailed(t *testing.T) {
})
assert.Error(t, err)
cronTransOnceForwardNow(180)
assertNotSameBalance(t, before)
assertNotSameBalance(t, before, "mysql")
}

48
test/saga_barrier_redis_test.go

@ -0,0 +1,48 @@
/*
* Copyright (c) 2021 yedf. All rights reserved.
* Use of this source code is governed by a BSD-style
* license that can be found in the LICENSE file.
*/
package test
import (
"testing"
"github.com/dtm-labs/dtm/dtmcli"
"github.com/dtm-labs/dtm/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/test/busi"
"github.com/stretchr/testify/assert"
)
func TestSagaBarrierRedisNormal(t *testing.T) {
busi.SetRedisBothAccount(100, 100)
before := getBeforeBalances("redis")
saga := genSagaBarrierRedis(dtmimp.GetFuncName())
err := saga.Submit()
assert.Nil(t, err)
waitTransProcessed(saga.Gid)
assert.Equal(t, []string{StatusPrepared, StatusSucceed, StatusPrepared, StatusSucceed}, getBranchesStatus(saga.Gid))
assert.Equal(t, StatusSucceed, getTransStatus(saga.Gid))
assertNotSameBalance(t, before, "redis")
}
func TestSagaBarrierRedisRollback(t *testing.T) {
busi.SetRedisBothAccount(20, 20)
before := getBeforeBalances("redis")
saga := genSagaBarrierRedis(dtmimp.GetFuncName())
err := saga.Submit()
assert.Nil(t, err)
waitTransProcessed(saga.Gid)
assert.Equal(t, StatusFailed, getTransStatus(saga.Gid))
assert.Equal(t, []string{StatusSucceed, StatusSucceed, StatusSucceed, StatusFailed}, getBranchesStatus(saga.Gid))
assertSameBalance(t, before, "redis")
}
func genSagaBarrierRedis(gid string) *dtmcli.Saga {
req := busi.GenTransReq(30, false, false)
req.Store = "redis"
return dtmcli.NewSaga(DtmServer, gid).
Add(Busi+"/SagaRedisTransIn", Busi+"/SagaRedisTransInCom", req).
Add(Busi+"/SagaRedisTransOut", Busi+"/SagaRedisTransOutCom", req)
}

19
test/tcc_barrier_test.go

@ -50,14 +50,23 @@ func TestTccBarrierRollback(t *testing.T) {
assert.Equal(t, []string{StatusSucceed, StatusPrepared, StatusSucceed, StatusPrepared}, getBranchesStatus(gid))
}
func TestTccBarrierDisorder(t *testing.T) {
before := getBeforeBalances()
func TestTccBarrierDisorderMysql(t *testing.T) {
runTestTccBarrierDisorder(t, "mysql")
}
func TestTccBarrierDisorderRedis(t *testing.T) {
busi.SetRedisBothAccount(200, 200)
runTestTccBarrierDisorder(t, "redis")
}
func runTestTccBarrierDisorder(t *testing.T, store string) {
before := getBeforeBalances(store)
cancelFinishedChan := make(chan string, 2)
cancelCanReturnChan := make(chan string, 2)
gid := dtmimp.GetFuncName()
gid := dtmimp.GetFuncName() + store
cronFinished := make(chan string, 2)
err := dtmcli.TccGlobalTransaction(DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) {
body := &busi.TransReq{Amount: 30}
body := &busi.TransReq{Amount: 30, Store: store}
tryURL := Busi + "/TccBTransOutTry"
confirmURL := Busi + "/TccBTransOutConfirm"
cancelURL := Busi + "/TccBSleepCancel"
@ -122,7 +131,7 @@ func TestTccBarrierDisorder(t *testing.T) {
assert.Error(t, err, fmt.Errorf("a cancelled tcc"))
assert.Equal(t, []string{StatusSucceed, StatusPrepared}, getBranchesStatus(gid))
assert.Equal(t, StatusFailed, getTransStatus(gid))
assertSameBalance(t, before)
assertSameBalance(t, before, store)
}
func TestTccBarrierPanic(t *testing.T) {

18
test/types.go

@ -83,22 +83,22 @@ const (
StatusAborting = dtmcli.StatusAborting
)
func getBeforeBalances() []int {
b1 := busi.GetBalanceByUid(busi.TransOutUID)
b2 := busi.GetBalanceByUid(busi.TransInUID)
func getBeforeBalances(store string) []int {
b1 := busi.GetBalanceByUid(busi.TransOutUID, store)
b2 := busi.GetBalanceByUid(busi.TransInUID, store)
return []int{b1, b2}
}
func assertSameBalance(t *testing.T, before []int) {
b1 := busi.GetBalanceByUid(busi.TransOutUID)
b2 := busi.GetBalanceByUid(busi.TransInUID)
func assertSameBalance(t *testing.T, before []int, store string) {
b1 := busi.GetBalanceByUid(busi.TransOutUID, store)
b2 := busi.GetBalanceByUid(busi.TransInUID, store)
assert.Equal(t, before[0], b1)
assert.Equal(t, before[1], b2)
}
func assertNotSameBalance(t *testing.T, before []int) {
b1 := busi.GetBalanceByUid(busi.TransOutUID)
b2 := busi.GetBalanceByUid(busi.TransInUID)
func assertNotSameBalance(t *testing.T, before []int, store string) {
b1 := busi.GetBalanceByUid(busi.TransOutUID, store)
b2 := busi.GetBalanceByUid(busi.TransInUID, store)
assert.NotEqual(t, before[0], b1)
assert.Equal(t, before[0]+before[1], b1+b2)
}

Loading…
Cancel
Save