Browse Source

first test case ok

pull/159/head
yedf2 4 years ago
parent
commit
aa9358028f
  1. 33
      dtmcli/barrier.go
  2. 23
      dtmcli/msg.go
  3. 2
      dtmutil/utils.go
  4. 32
      test/busi/barrier.go
  5. 4
      test/busi/base_grpc.go
  6. 12
      test/busi/base_http.go
  7. 11
      test/busi/base_types.go
  8. 6
      test/busi/busi.go
  9. 26
      test/msg_barrier_test.go
  10. 6
      test/tcc_barrier_test.go
  11. 22
      test/types.go

33
dtmcli/barrier.go

@ -102,35 +102,14 @@ func (bb *BranchBarrier) CallWithDB(db *sql.DB, busiCall BarrierBusiFunc) error
}
func (bb *BranchBarrier) QueryPrepared(db *sql.DB) error {
affected, err := insertBarrier(db, bb.TransType, bb.Gid, bb.BranchID, BranchAction, bb.BranchID, bb.Op)
if err != nil {
return err
}
if affected > 0 {
return ErrFailure
}
return nil
}
func (bb *BranchBarrier) PrepareAndSubmit(msg *Msg, queryPrepared string, db *sql.DB, busiCall BarrierBusiFunc) (err error) {
var tx *sql.Tx
tx, err = db.Begin()
_, err := insertBarrier(db, bb.TransType, bb.Gid, bb.BranchID, BranchAction, bb.BranchID, "rollback")
var reason string
if err == nil {
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
err = busiCall(tx)
sql := fmt.Sprintf("select reason from %s where gid=? and branch_id=? and op=? and barrier_id=?", dtmimp.BarrierTableName)
err = db.QueryRow(sql, bb.Gid, bb.BranchID, bb.Op, bb.BarrierID).Scan(&reason)
}
if err == nil {
err = msg.Prepare(queryPrepared)
}
if err == nil {
err = tx.Commit()
}
if err == nil {
return msg.Submit() // should not assign err. or else defer may try to rollback a committed tx
if reason == "rollback" {
return ErrFailure
}
return err
}

23
dtmcli/msg.go

@ -6,7 +6,11 @@
package dtmcli
import "github.com/dtm-labs/dtm/dtmcli/dtmimp"
import (
"database/sql"
"github.com/dtm-labs/dtm/dtmcli/dtmimp"
)
// Msg reliable msg type
type Msg struct {
@ -35,3 +39,20 @@ func (s *Msg) Prepare(queryPrepared string) error {
func (s *Msg) Submit() error {
return dtmimp.TransCallDtm(&s.TransBase, s, "submit")
}
func (s *Msg) PrepareAndSubmit(queryPrepared string, db *sql.DB, busiCall BarrierBusiFunc) error {
bb, err := BarrierFrom(s.TransType, s.Gid, "00", "msg") // a special barrier for msg QueryPrepared
if err == nil {
err = bb.CallWithDB(db, func(tx *sql.Tx) error {
err := busiCall(tx)
if err == nil {
err = s.Prepare(queryPrepared)
}
return err
})
}
if err == nil {
err = s.Submit()
}
return err
}

2
dtmutil/utils.go

@ -38,7 +38,7 @@ func GetGinApp() *gin.Engine {
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(rb))
}
}
logger.Debugf("begin %s %s query: %s body: %s", c.Request.Method, c.FullPath(), c.Request.URL.RawQuery, body)
logger.Debugf("begin %s %s body: %s", c.Request.Method, c.Request.URL, body)
c.Next()
})
app.Any("/api/ping", func(c *gin.Context) { c.JSON(200, map[string]interface{}{"msg": "pong"}) })

32
test/busi/barrier.go

@ -17,29 +17,29 @@ import (
)
func init() {
setupFuncs["TccBarrierSetup"] = func(app *gin.Engine) {
setupFuncs["BarrierSetup"] = func(app *gin.Engine) {
app.POST(BusiAPI+"/SagaBTransIn", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaAdjustBalance(tx, transInUID, reqFrom(c).Amount, reqFrom(c).TransInResult)
return SagaAdjustBalance(tx, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult)
})
}))
app.POST(BusiAPI+"/SagaBTransInCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaAdjustBalance(tx, transInUID, -reqFrom(c).Amount, "")
return SagaAdjustBalance(tx, TransInUID, -reqFrom(c).Amount, "")
})
}))
app.POST(BusiAPI+"/SagaBTransOut", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaAdjustBalance(tx, transOutUID, -reqFrom(c).Amount, reqFrom(c).TransOutResult)
return SagaAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount, reqFrom(c).TransOutResult)
})
}))
app.POST(BusiAPI+"/SagaBTransOutCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaAdjustBalance(tx, transOutUID, reqFrom(c).Amount, "")
return SagaAdjustBalance(tx, TransOutUID, reqFrom(c).Amount, "")
})
}))
app.POST(BusiAPI+"/SagaBTransOutGorm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
@ -47,7 +47,7 @@ func init() {
barrier := MustBarrierFromGin(c)
tx := dbGet().DB.Begin()
return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(tx1 *sql.Tx) error {
return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, transOutUID).Error
return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, TransOutUID).Error
})
}))
@ -57,17 +57,17 @@ func init() {
return req.TransInResult, nil
}
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustTrading(tx, transInUID, req.Amount)
return tccAdjustTrading(tx, TransInUID, req.Amount)
})
}))
app.POST(BusiAPI+"/TccBTransInConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustBalance(tx, transInUID, reqFrom(c).Amount)
return tccAdjustBalance(tx, TransInUID, reqFrom(c).Amount)
})
}))
app.POST(BusiAPI+"/TccBTransInCancel", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustTrading(tx, transInUID, -reqFrom(c).Amount)
return tccAdjustTrading(tx, TransInUID, -reqFrom(c).Amount)
})
}))
app.POST(BusiAPI+"/TccBTransOutTry", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
@ -76,12 +76,12 @@ func init() {
return req.TransOutResult, nil
}
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustTrading(tx, transOutUID, -req.Amount)
return tccAdjustTrading(tx, TransOutUID, -req.Amount)
})
}))
app.POST(BusiAPI+"/TccBTransOutConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustBalance(tx, transOutUID, -reqFrom(c).Amount)
return tccAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount)
})
}))
app.POST(BusiAPI+"/TccBTransOutCancel", dtmutil.WrapHandler(TccBarrierTransOutCancel))
@ -91,34 +91,34 @@ func init() {
// TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return tccAdjustTrading(tx, transOutUID, reqFrom(c).Amount)
return tccAdjustTrading(tx, TransOutUID, reqFrom(c).Amount)
})
}
func (s *busiServer) TransInBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcAdjustBalance(tx, transInUID, in.Amount, in.TransInResult)
return sagaGrpcAdjustBalance(tx, TransInUID, in.Amount, in.TransInResult)
})
}
func (s *busiServer) TransOutBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcAdjustBalance(tx, transOutUID, -in.Amount, in.TransOutResult)
return sagaGrpcAdjustBalance(tx, TransOutUID, -in.Amount, in.TransOutResult)
})
}
func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcAdjustBalance(tx, transInUID, -in.Amount, "")
return sagaGrpcAdjustBalance(tx, TransInUID, -in.Amount, "")
})
}
func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcAdjustBalance(tx, transOutUID, in.Amount, "")
return sagaGrpcAdjustBalance(tx, TransOutUID, in.Amount, "")
})
}

4
test/busi/base_grpc.go

@ -102,13 +102,13 @@ func (s *busiServer) TransOutTcc(ctx context.Context, in *BusiReq) (*emptypb.Emp
func (s *busiServer) TransInXa(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
return &emptypb.Empty{}, XaGrpcClient.XaLocalTransaction(ctx, in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error {
return sagaGrpcAdjustBalance(db, transInUID, in.Amount, in.TransInResult)
return sagaGrpcAdjustBalance(db, TransInUID, in.Amount, in.TransInResult)
})
}
func (s *busiServer) TransOutXa(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
return &emptypb.Empty{}, XaGrpcClient.XaLocalTransaction(ctx, in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error {
return sagaGrpcAdjustBalance(db, transOutUID, in.Amount, in.TransOutResult)
return sagaGrpcAdjustBalance(db, TransOutUID, in.Amount, in.TransOutResult)
})
}

12
test/busi/base_http.go

@ -103,15 +103,21 @@ func BaseAddRoute(app *gin.Engine) {
logger.Debugf("%s QueryPrepared", c.Query("gid"))
return dtmimp.OrString(MainSwitch.QueryPreparedResult.Fetch(), dtmcli.ResultSuccess), nil
}))
app.GET(BusiAPI+"/QueryPreparedB", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
logger.Debugf("%s QueryPreparedB", c.Query("gid"))
bb := MustBarrierFromGin(c)
db := dbGet().ToSQLDB()
return error2Resp(bb.QueryPrepared(db))
}))
app.POST(BusiAPI+"/TransInXa", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error {
return sagaAdjustBalance(db, transInUID, reqFrom(c).Amount, reqFrom(c).TransInResult)
return SagaAdjustBalance(db, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult)
})
return error2Resp(err)
}))
app.POST(BusiAPI+"/TransOutXa", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error {
return sagaAdjustBalance(db, transOutUID, reqFrom(c).Amount, reqFrom(c).TransOutResult)
return SagaAdjustBalance(db, TransOutUID, reqFrom(c).Amount, reqFrom(c).TransOutResult)
})
return error2Resp(err)
}))
@ -137,7 +143,7 @@ func BaseAddRoute(app *gin.Engine) {
if err != nil {
return err
}
dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, transOutUID)
dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, TransOutUID)
return dbr.Error
})
return error2Resp(err)

11
test/busi/base_types.go

@ -32,15 +32,10 @@ func (*UserAccount) TableName() string {
return "dtm_busi.user_account"
}
func GetUserAccountByUid(uid int) *UserAccount {
func GetBalanceByUid(uid int) int {
ua := UserAccount{}
dbr := dbGet().Must().Model(&ua).Where("user_id=?", uid).First(&ua)
dtmimp.E2P(dbr.Error)
return &ua
}
func IsEqual(ua1, ua2 *UserAccount) bool {
return ua1.UserId == ua2.UserId && ua1.Balance == ua2.Balance && ua1.TradingBalance == ua2.TradingBalance
_ = dbGet().Must().Model(&ua).Where("user_id=?", uid).First(&ua)
return dtmimp.MustAtoi(ua.Balance[:len(ua.Balance)-3])
}
// TransReq transaction request payload

6
test/busi/busi.go

@ -13,8 +13,8 @@ import (
status "google.golang.org/grpc/status"
)
const transOutUID = 1
const transInUID = 2
const TransOutUID = 1
const TransInUID = 2
func handleGrpcBusiness(in *BusiReq, result1 string, result2 string, busi string) error {
res := dtmimp.OrString(result1, result2, dtmcli.ResultSuccess)
@ -59,7 +59,7 @@ func sagaGrpcAdjustBalance(db dtmcli.DB, uid int, amount int64, result string) e
}
func sagaAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error {
func SagaAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error {
if strings.Contains(result, dtmcli.ResultFailure) {
return dtmcli.ErrFailure
}

26
test/msg_barrier_test.go

@ -0,0 +1,26 @@
package test
import (
"database/sql"
"testing"
"github.com/dtm-labs/dtm/dtmcli"
"github.com/dtm-labs/dtm/test/busi"
"github.com/stretchr/testify/assert"
)
func TestMsgPrepareAndSubmit(t *testing.T) {
before := getBeforeBalances()
gid := dtmcli.MustGenGid(DtmServer)
req := busi.GenTransReq(30, false, false)
msg := dtmcli.NewMsg(DtmServer, gid).
Add(busi.Busi+"/SagaBTransIn1", req)
err := msg.PrepareAndSubmit(Busi+"/QueryPreparedB", dbGet().ToSQLDB(), func(tx *sql.Tx) error {
return busi.SagaAdjustBalance(tx, busi.TransOutUID, -req.Amount, "SUCCESS")
})
assert.Nil(t, err)
waitTransProcessed(msg.Gid)
assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid))
assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid))
assertNotSameBalance(t, before)
}

6
test/tcc_barrier_test.go

@ -51,8 +51,7 @@ func TestTccBarrierRollback(t *testing.T) {
}
func TestTccBarrierDisorder(t *testing.T) {
ua1 := busi.GetUserAccountByUid(1)
ua2 := busi.GetUserAccountByUid(2)
before := getBeforeBalances()
cancelFinishedChan := make(chan string, 2)
cancelCanReturnChan := make(chan string, 2)
gid := dtmimp.GetFuncName()
@ -123,8 +122,7 @@ func TestTccBarrierDisorder(t *testing.T) {
assert.Error(t, err, fmt.Errorf("a cancelled tcc"))
assert.Equal(t, []string{StatusSucceed, StatusPrepared}, getBranchesStatus(gid))
assert.Equal(t, StatusFailed, getTransStatus(gid))
assert.True(t, busi.IsEqual(ua1, busi.GetUserAccountByUid(1)))
assert.True(t, busi.IsEqual(ua2, busi.GetUserAccountByUid(2)))
assertSameBalance(t, before)
}
func TestTccBarrierPanic(t *testing.T) {

22
test/types.go

@ -7,6 +7,7 @@
package test
import (
"testing"
"time"
"github.com/dtm-labs/dtm/dtmcli"
@ -16,6 +17,7 @@ import (
"github.com/dtm-labs/dtm/dtmsvr/config"
"github.com/dtm-labs/dtm/dtmutil"
"github.com/dtm-labs/dtm/test/busi"
"github.com/stretchr/testify/assert"
)
var conf = &config.Config
@ -80,3 +82,23 @@ const (
// StatusAborting status for global trans status.
StatusAborting = dtmcli.StatusAborting
)
func getBeforeBalances() []int {
b1 := busi.GetBalanceByUid(busi.TransOutUID)
b2 := busi.GetBalanceByUid(busi.TransInUID)
return []int{b1, b2}
}
func assertSameBalance(t *testing.T, before []int) {
b1 := busi.GetBalanceByUid(busi.TransOutUID)
b2 := busi.GetBalanceByUid(busi.TransInUID)
assert.Equal(t, before[0], b1)
assert.Equal(t, before[1], b2)
}
func assertNotSameBalance(t *testing.T, before []int) {
b1 := busi.GetBalanceByUid(busi.TransOutUID)
b2 := busi.GetBalanceByUid(busi.TransInUID)
assert.NotEqual(t, before[0], b1)
assert.Equal(t, before[0]+before[1], b1+b2)
}

Loading…
Cancel
Save