From aa9358028fa1c367c2c5b73019e6bccb21fedc5f Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Wed, 5 Jan 2022 23:08:00 +0800 Subject: [PATCH] first test case ok --- dtmcli/barrier.go | 33 ++++++--------------------------- dtmcli/msg.go | 23 ++++++++++++++++++++++- dtmutil/utils.go | 2 +- test/busi/barrier.go | 32 ++++++++++++++++---------------- test/busi/base_grpc.go | 4 ++-- test/busi/base_http.go | 12 +++++++++--- test/busi/base_types.go | 11 +++-------- test/busi/busi.go | 6 +++--- test/msg_barrier_test.go | 26 ++++++++++++++++++++++++++ test/tcc_barrier_test.go | 6 ++---- test/types.go | 22 ++++++++++++++++++++++ 11 files changed, 112 insertions(+), 65 deletions(-) create mode 100644 test/msg_barrier_test.go diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index f515550..92aab5b 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -102,35 +102,14 @@ func (bb *BranchBarrier) CallWithDB(db *sql.DB, busiCall BarrierBusiFunc) error } func (bb *BranchBarrier) QueryPrepared(db *sql.DB) error { - affected, err := insertBarrier(db, bb.TransType, bb.Gid, bb.BranchID, BranchAction, bb.BranchID, bb.Op) - if err != nil { - return err - } - if affected > 0 { - return ErrFailure - } - return nil -} - -func (bb *BranchBarrier) PrepareAndSubmit(msg *Msg, queryPrepared string, db *sql.DB, busiCall BarrierBusiFunc) (err error) { - var tx *sql.Tx - tx, err = db.Begin() + _, err := insertBarrier(db, bb.TransType, bb.Gid, bb.BranchID, BranchAction, bb.BranchID, "rollback") + var reason string if err == nil { - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - err = busiCall(tx) + sql := fmt.Sprintf("select reason from %s where gid=? and branch_id=? and op=? and barrier_id=?", dtmimp.BarrierTableName) + err = db.QueryRow(sql, bb.Gid, bb.BranchID, bb.Op, bb.BarrierID).Scan(&reason) } - if err == nil { - err = msg.Prepare(queryPrepared) - } - if err == nil { - err = tx.Commit() - } - if err == nil { - return msg.Submit() // should not assign err. or else defer may try to rollback a committed tx + if reason == "rollback" { + return ErrFailure } return err } diff --git a/dtmcli/msg.go b/dtmcli/msg.go index a461ea2..815b97d 100644 --- a/dtmcli/msg.go +++ b/dtmcli/msg.go @@ -6,7 +6,11 @@ package dtmcli -import "github.com/dtm-labs/dtm/dtmcli/dtmimp" +import ( + "database/sql" + + "github.com/dtm-labs/dtm/dtmcli/dtmimp" +) // Msg reliable msg type type Msg struct { @@ -35,3 +39,20 @@ func (s *Msg) Prepare(queryPrepared string) error { func (s *Msg) Submit() error { return dtmimp.TransCallDtm(&s.TransBase, s, "submit") } + +func (s *Msg) PrepareAndSubmit(queryPrepared string, db *sql.DB, busiCall BarrierBusiFunc) error { + bb, err := BarrierFrom(s.TransType, s.Gid, "00", "msg") // a special barrier for msg QueryPrepared + if err == nil { + err = bb.CallWithDB(db, func(tx *sql.Tx) error { + err := busiCall(tx) + if err == nil { + err = s.Prepare(queryPrepared) + } + return err + }) + } + if err == nil { + err = s.Submit() + } + return err +} diff --git a/dtmutil/utils.go b/dtmutil/utils.go index 4fbbc7d..0b08915 100644 --- a/dtmutil/utils.go +++ b/dtmutil/utils.go @@ -38,7 +38,7 @@ func GetGinApp() *gin.Engine { c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(rb)) } } - logger.Debugf("begin %s %s query: %s body: %s", c.Request.Method, c.FullPath(), c.Request.URL.RawQuery, body) + logger.Debugf("begin %s %s body: %s", c.Request.Method, c.Request.URL, body) c.Next() }) app.Any("/api/ping", func(c *gin.Context) { c.JSON(200, map[string]interface{}{"msg": "pong"}) }) diff --git a/test/busi/barrier.go b/test/busi/barrier.go index df30145..8ca19de 100644 --- a/test/busi/barrier.go +++ b/test/busi/barrier.go @@ -17,29 +17,29 @@ import ( ) func init() { - setupFuncs["TccBarrierSetup"] = func(app *gin.Engine) { + setupFuncs["BarrierSetup"] = func(app *gin.Engine) { app.POST(BusiAPI+"/SagaBTransIn", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaAdjustBalance(tx, transInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) + return SagaAdjustBalance(tx, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) }) })) app.POST(BusiAPI+"/SagaBTransInCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaAdjustBalance(tx, transInUID, -reqFrom(c).Amount, "") + return SagaAdjustBalance(tx, TransInUID, -reqFrom(c).Amount, "") }) })) app.POST(BusiAPI+"/SagaBTransOut", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaAdjustBalance(tx, transOutUID, -reqFrom(c).Amount, reqFrom(c).TransOutResult) + return SagaAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount, reqFrom(c).TransOutResult) }) })) app.POST(BusiAPI+"/SagaBTransOutCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaAdjustBalance(tx, transOutUID, reqFrom(c).Amount, "") + return SagaAdjustBalance(tx, TransOutUID, reqFrom(c).Amount, "") }) })) app.POST(BusiAPI+"/SagaBTransOutGorm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { @@ -47,7 +47,7 @@ func init() { barrier := MustBarrierFromGin(c) tx := dbGet().DB.Begin() return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(tx1 *sql.Tx) error { - return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, transOutUID).Error + return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, TransOutUID).Error }) })) @@ -57,17 +57,17 @@ func init() { return req.TransInResult, nil } return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { - return tccAdjustTrading(tx, transInUID, req.Amount) + return tccAdjustTrading(tx, TransInUID, req.Amount) }) })) app.POST(BusiAPI+"/TccBTransInConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { - return tccAdjustBalance(tx, transInUID, reqFrom(c).Amount) + return tccAdjustBalance(tx, TransInUID, reqFrom(c).Amount) }) })) app.POST(BusiAPI+"/TccBTransInCancel", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { - return tccAdjustTrading(tx, transInUID, -reqFrom(c).Amount) + return tccAdjustTrading(tx, TransInUID, -reqFrom(c).Amount) }) })) app.POST(BusiAPI+"/TccBTransOutTry", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { @@ -76,12 +76,12 @@ func init() { return req.TransOutResult, nil } return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { - return tccAdjustTrading(tx, transOutUID, -req.Amount) + return tccAdjustTrading(tx, TransOutUID, -req.Amount) }) })) app.POST(BusiAPI+"/TccBTransOutConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { - return tccAdjustBalance(tx, transOutUID, -reqFrom(c).Amount) + return tccAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount) }) })) app.POST(BusiAPI+"/TccBTransOutCancel", dtmutil.WrapHandler(TccBarrierTransOutCancel)) @@ -91,34 +91,34 @@ func init() { // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { - return tccAdjustTrading(tx, transOutUID, reqFrom(c).Amount) + return tccAdjustTrading(tx, TransOutUID, reqFrom(c).Amount) }) } func (s *busiServer) TransInBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaGrpcAdjustBalance(tx, transInUID, in.Amount, in.TransInResult) + return sagaGrpcAdjustBalance(tx, TransInUID, in.Amount, in.TransInResult) }) } func (s *busiServer) TransOutBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaGrpcAdjustBalance(tx, transOutUID, -in.Amount, in.TransOutResult) + return sagaGrpcAdjustBalance(tx, TransOutUID, -in.Amount, in.TransOutResult) }) } func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaGrpcAdjustBalance(tx, transInUID, -in.Amount, "") + return sagaGrpcAdjustBalance(tx, TransInUID, -in.Amount, "") }) } func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { - return sagaGrpcAdjustBalance(tx, transOutUID, in.Amount, "") + return sagaGrpcAdjustBalance(tx, TransOutUID, in.Amount, "") }) } diff --git a/test/busi/base_grpc.go b/test/busi/base_grpc.go index b1da93e..c858a1f 100644 --- a/test/busi/base_grpc.go +++ b/test/busi/base_grpc.go @@ -102,13 +102,13 @@ func (s *busiServer) TransOutTcc(ctx context.Context, in *BusiReq) (*emptypb.Emp func (s *busiServer) TransInXa(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { return &emptypb.Empty{}, XaGrpcClient.XaLocalTransaction(ctx, in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { - return sagaGrpcAdjustBalance(db, transInUID, in.Amount, in.TransInResult) + return sagaGrpcAdjustBalance(db, TransInUID, in.Amount, in.TransInResult) }) } func (s *busiServer) TransOutXa(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { return &emptypb.Empty{}, XaGrpcClient.XaLocalTransaction(ctx, in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { - return sagaGrpcAdjustBalance(db, transOutUID, in.Amount, in.TransOutResult) + return sagaGrpcAdjustBalance(db, TransOutUID, in.Amount, in.TransOutResult) }) } diff --git a/test/busi/base_http.go b/test/busi/base_http.go index 7a9838b..83da433 100644 --- a/test/busi/base_http.go +++ b/test/busi/base_http.go @@ -103,15 +103,21 @@ func BaseAddRoute(app *gin.Engine) { logger.Debugf("%s QueryPrepared", c.Query("gid")) return dtmimp.OrString(MainSwitch.QueryPreparedResult.Fetch(), dtmcli.ResultSuccess), nil })) + app.GET(BusiAPI+"/QueryPreparedB", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + logger.Debugf("%s QueryPreparedB", c.Query("gid")) + bb := MustBarrierFromGin(c) + db := dbGet().ToSQLDB() + return error2Resp(bb.QueryPrepared(db)) + })) app.POST(BusiAPI+"/TransInXa", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { - return sagaAdjustBalance(db, transInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) + return SagaAdjustBalance(db, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) }) return error2Resp(err) })) app.POST(BusiAPI+"/TransOutXa", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { - return sagaAdjustBalance(db, transOutUID, reqFrom(c).Amount, reqFrom(c).TransOutResult) + return SagaAdjustBalance(db, TransOutUID, reqFrom(c).Amount, reqFrom(c).TransOutResult) }) return error2Resp(err) })) @@ -137,7 +143,7 @@ func BaseAddRoute(app *gin.Engine) { if err != nil { return err } - dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, transOutUID) + dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, TransOutUID) return dbr.Error }) return error2Resp(err) diff --git a/test/busi/base_types.go b/test/busi/base_types.go index f0d29fb..5aaaed9 100644 --- a/test/busi/base_types.go +++ b/test/busi/base_types.go @@ -32,15 +32,10 @@ func (*UserAccount) TableName() string { return "dtm_busi.user_account" } -func GetUserAccountByUid(uid int) *UserAccount { +func GetBalanceByUid(uid int) int { ua := UserAccount{} - dbr := dbGet().Must().Model(&ua).Where("user_id=?", uid).First(&ua) - dtmimp.E2P(dbr.Error) - return &ua -} - -func IsEqual(ua1, ua2 *UserAccount) bool { - return ua1.UserId == ua2.UserId && ua1.Balance == ua2.Balance && ua1.TradingBalance == ua2.TradingBalance + _ = dbGet().Must().Model(&ua).Where("user_id=?", uid).First(&ua) + return dtmimp.MustAtoi(ua.Balance[:len(ua.Balance)-3]) } // TransReq transaction request payload diff --git a/test/busi/busi.go b/test/busi/busi.go index 13eb7e8..7d6c064 100644 --- a/test/busi/busi.go +++ b/test/busi/busi.go @@ -13,8 +13,8 @@ import ( status "google.golang.org/grpc/status" ) -const transOutUID = 1 -const transInUID = 2 +const TransOutUID = 1 +const TransInUID = 2 func handleGrpcBusiness(in *BusiReq, result1 string, result2 string, busi string) error { res := dtmimp.OrString(result1, result2, dtmcli.ResultSuccess) @@ -59,7 +59,7 @@ func sagaGrpcAdjustBalance(db dtmcli.DB, uid int, amount int64, result string) e } -func sagaAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { +func SagaAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { if strings.Contains(result, dtmcli.ResultFailure) { return dtmcli.ErrFailure } diff --git a/test/msg_barrier_test.go b/test/msg_barrier_test.go new file mode 100644 index 0000000..7a77b7a --- /dev/null +++ b/test/msg_barrier_test.go @@ -0,0 +1,26 @@ +package test + +import ( + "database/sql" + "testing" + + "github.com/dtm-labs/dtm/dtmcli" + "github.com/dtm-labs/dtm/test/busi" + "github.com/stretchr/testify/assert" +) + +func TestMsgPrepareAndSubmit(t *testing.T) { + before := getBeforeBalances() + gid := dtmcli.MustGenGid(DtmServer) + req := busi.GenTransReq(30, false, false) + msg := dtmcli.NewMsg(DtmServer, gid). + Add(busi.Busi+"/SagaBTransIn1", req) + err := msg.PrepareAndSubmit(Busi+"/QueryPreparedB", dbGet().ToSQLDB(), func(tx *sql.Tx) error { + return busi.SagaAdjustBalance(tx, busi.TransOutUID, -req.Amount, "SUCCESS") + }) + assert.Nil(t, err) + waitTransProcessed(msg.Gid) + assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid)) + assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid)) + assertNotSameBalance(t, before) +} diff --git a/test/tcc_barrier_test.go b/test/tcc_barrier_test.go index 87c3aea..376b3ca 100644 --- a/test/tcc_barrier_test.go +++ b/test/tcc_barrier_test.go @@ -51,8 +51,7 @@ func TestTccBarrierRollback(t *testing.T) { } func TestTccBarrierDisorder(t *testing.T) { - ua1 := busi.GetUserAccountByUid(1) - ua2 := busi.GetUserAccountByUid(2) + before := getBeforeBalances() cancelFinishedChan := make(chan string, 2) cancelCanReturnChan := make(chan string, 2) gid := dtmimp.GetFuncName() @@ -123,8 +122,7 @@ func TestTccBarrierDisorder(t *testing.T) { assert.Error(t, err, fmt.Errorf("a cancelled tcc")) assert.Equal(t, []string{StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) assert.Equal(t, StatusFailed, getTransStatus(gid)) - assert.True(t, busi.IsEqual(ua1, busi.GetUserAccountByUid(1))) - assert.True(t, busi.IsEqual(ua2, busi.GetUserAccountByUid(2))) + assertSameBalance(t, before) } func TestTccBarrierPanic(t *testing.T) { diff --git a/test/types.go b/test/types.go index 4fcbe34..c583a28 100644 --- a/test/types.go +++ b/test/types.go @@ -7,6 +7,7 @@ package test import ( + "testing" "time" "github.com/dtm-labs/dtm/dtmcli" @@ -16,6 +17,7 @@ import ( "github.com/dtm-labs/dtm/dtmsvr/config" "github.com/dtm-labs/dtm/dtmutil" "github.com/dtm-labs/dtm/test/busi" + "github.com/stretchr/testify/assert" ) var conf = &config.Config @@ -80,3 +82,23 @@ const ( // StatusAborting status for global trans status. StatusAborting = dtmcli.StatusAborting ) + +func getBeforeBalances() []int { + b1 := busi.GetBalanceByUid(busi.TransOutUID) + b2 := busi.GetBalanceByUid(busi.TransInUID) + return []int{b1, b2} +} + +func assertSameBalance(t *testing.T, before []int) { + b1 := busi.GetBalanceByUid(busi.TransOutUID) + b2 := busi.GetBalanceByUid(busi.TransInUID) + assert.Equal(t, before[0], b1) + assert.Equal(t, before[1], b2) +} + +func assertNotSameBalance(t *testing.T, before []int) { + b1 := busi.GetBalanceByUid(busi.TransOutUID) + b2 := busi.GetBalanceByUid(busi.TransInUID) + assert.NotEqual(t, before[0], b1) + assert.Equal(t, before[0]+before[1], b1+b2) +}