diff --git a/bench/http.go b/bench/http.go index a60cf02..1d1bd4a 100644 --- a/bench/http.go +++ b/bench/http.go @@ -97,7 +97,7 @@ func qsAdjustBalance(uid int, amount int, c *gin.Context) (interface{}, error) { return dtmcli.MapSuccess, nil } tb := dtmimp.TransBaseFromQuery(c.Request.URL.Query()) - f := func(tx dtmcli.DB) error { + f := func(tx *sql.Tx) error { for i := 0; i < sqls; i++ { _, err := dtmimp.DBExec(tx, "insert into dtm_busi.user_account_log(user_id, delta, gid, branch_id, op, reason) values(?,?,?,?,?,?)", uid, amount, tb.Gid, c.Query("branch_id"), tb.TransType, fmt.Sprintf("inserted by dtm transaction %s %s", tb.Gid, c.Query("branch_id"))) diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index fbd9215..480ae17 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -7,6 +7,7 @@ package dtmcli import ( + "database/sql" "fmt" "net/url" @@ -14,7 +15,7 @@ import ( ) // BarrierBusiFunc type for busi func -type BarrierBusiFunc func(db DB) error +type BarrierBusiFunc func(tx *sql.Tx) error // BranchBarrier every branch info type BranchBarrier struct { @@ -48,7 +49,7 @@ func BarrierFrom(transType, gid, branchID, op string) (*BranchBarrier, error) { return ti, nil } -func insertBarrier(tx Tx, transType string, gid string, branchID string, op string, barrierID string, reason string) (int64, error) { +func insertBarrier(tx DB, transType string, gid string, branchID string, op string, barrierID string, reason string) (int64, error) { if op == "" { return 0, nil } @@ -59,7 +60,7 @@ func insertBarrier(tx Tx, transType string, gid string, branchID string, op stri // Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 // tx: 本地数据库的事务对象,允许子事务屏障进行事务操作 // busiCall: 业务函数,仅在必要时被调用 -func (bb *BranchBarrier) Call(tx Tx, busiCall BarrierBusiFunc) (rerr error) { +func (bb *BranchBarrier) Call(tx *sql.Tx, busiCall BarrierBusiFunc) (rerr error) { bb.BarrierID = bb.BarrierID + 1 bid := fmt.Sprintf("%02d", bb.BarrierID) defer func() { @@ -89,3 +90,12 @@ func (bb *BranchBarrier) Call(tx Tx, busiCall BarrierBusiFunc) (rerr error) { rerr = busiCall(tx) return } + +// CallWithDB the same as Call, but with *sql.DB +func (bb *BranchBarrier) CallWithDB(db *sql.DB, busiCall BarrierBusiFunc) error { + tx, err := db.Begin() + if err != nil { + return err + } + return bb.Call(tx, busiCall) +} diff --git a/dtmcli/dtmimp/types.go b/dtmcli/dtmimp/types.go index 632a072..3848fb0 100644 --- a/dtmcli/dtmimp/types.go +++ b/dtmcli/dtmimp/types.go @@ -13,10 +13,3 @@ type DB interface { Exec(query string, args ...interface{}) (sql.Result, error) QueryRow(query string, args ...interface{}) *sql.Row } - -// Tx interface of dtmcli tx -type Tx interface { - Rollback() error - Commit() error - DB -} diff --git a/dtmcli/types.go b/dtmcli/types.go index 4b502c8..64b4ab6 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -25,9 +25,6 @@ func MustGenGid(server string) string { // DB interface type DB = dtmimp.DB -// Tx interface -type Tx = dtmimp.Tx - // TransOptions transaction option type TransOptions = dtmimp.TransOptions diff --git a/examples/grpc_saga_barrier.go b/examples/grpc_saga_barrier.go index c89a2d5..777b236 100644 --- a/examples/grpc_saga_barrier.go +++ b/examples/grpc_saga_barrier.go @@ -8,6 +8,7 @@ package examples import ( "context" + "database/sql" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" @@ -41,28 +42,28 @@ func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int64, result st func (s *busiServer) TransInBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) - return &emptypb.Empty{}, barrier.Call(txGet(), func(tx dtmcli.DB) error { + return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { return sagaGrpcBarrierAdjustBalance(tx, 2, in.Amount, in.TransInResult) }) } func (s *busiServer) TransOutBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) - return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaGrpcBarrierAdjustBalance(db, 1, -in.Amount, in.TransOutResult) + return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(tx, 1, -in.Amount, in.TransOutResult) }) } func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) - return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaGrpcBarrierAdjustBalance(db, 2, -in.Amount, "") + return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(tx, 2, -in.Amount, "") }) } func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) - return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaGrpcBarrierAdjustBalance(db, 1, in.Amount, "") + return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(tx, 1, in.Amount, "") }) } diff --git a/examples/http_saga_barrier.go b/examples/http_saga_barrier.go index 7c9e916..c68aeb8 100644 --- a/examples/http_saga_barrier.go +++ b/examples/http_saga_barrier.go @@ -7,6 +7,8 @@ package examples import ( + "database/sql" + "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" @@ -45,15 +47,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaBarrierAdjustBalance(db, 1, req.Amount) + return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaBarrierAdjustBalance(tx, 1, req.Amount) }) } func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaBarrierAdjustBalance(db, 1, -reqFrom(c).Amount) + return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaBarrierAdjustBalance(tx, 1, -reqFrom(c).Amount) }) } @@ -63,14 +65,14 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { return req.TransOutResult, nil } barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaBarrierAdjustBalance(db, 2, -req.Amount) + return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaBarrierAdjustBalance(tx, 2, -req.Amount) }) } func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { - return sagaBarrierAdjustBalance(db, 2, reqFrom(c).Amount) + return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return sagaBarrierAdjustBalance(tx, 2, reqFrom(c).Amount) }) } diff --git a/examples/http_saga_gorm_barrier.go b/examples/http_saga_gorm_barrier.go index 8652a74..1f81c44 100644 --- a/examples/http_saga_gorm_barrier.go +++ b/examples/http_saga_gorm_barrier.go @@ -37,7 +37,7 @@ func sagaGormBarrierTransOut(c *gin.Context) (interface{}, error) { req := reqFrom(c) barrier := MustBarrierFromGin(c) tx := dbGet().DB.Begin() - return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(db dtmcli.DB) error { + return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(tx1 *sql.Tx) error { return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, 2).Error }) } diff --git a/examples/http_tcc_barrier.go b/examples/http_tcc_barrier.go index 089a4f6..85e65da 100644 --- a/examples/http_tcc_barrier.go +++ b/examples/http_tcc_barrier.go @@ -7,6 +7,7 @@ package examples import ( + "database/sql" "fmt" "github.com/gin-gonic/gin" @@ -68,20 +69,20 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error { - return adjustTrading(db, transInUID, req.Amount) + return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return adjustTrading(tx, transInUID, req.Amount) }) } func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error { - return adjustBalance(db, transInUID, reqFrom(c).Amount) + return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return adjustBalance(tx, transInUID, reqFrom(c).Amount) }) } func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error { - return adjustTrading(db, transInUID, -reqFrom(c).Amount) + return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return adjustTrading(tx, transInUID, -reqFrom(c).Amount) }) } @@ -90,20 +91,20 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) { if req.TransOutResult != "" { return req.TransOutResult, nil } - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error { - return adjustTrading(db, transOutUID, -req.Amount) + return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return adjustTrading(tx, transOutUID, -req.Amount) }) } func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error { - return adjustBalance(db, transOutUID, -reqFrom(c).Amount) + return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return adjustBalance(tx, transOutUID, -reqFrom(c).Amount) }) } // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error { - return adjustTrading(db, transOutUID, reqFrom(c).Amount) + return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return adjustTrading(tx, transOutUID, reqFrom(c).Amount) }) } diff --git a/test/base_test.go b/test/base_test.go index a6dc3f4..c3922ef 100644 --- a/test/base_test.go +++ b/test/base_test.go @@ -7,6 +7,7 @@ package test import ( + "database/sql" "fmt" "testing" @@ -38,7 +39,7 @@ func TestBaseSqlDB(t *testing.T) { db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, op, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") tx, err := db.ToSQLDB().Begin() asserts.Nil(err) - err = barrier.Call(tx, func(db dtmcli.DB) error { + err = barrier.Call(tx, func(tx *sql.Tx) error { dtmimp.Logf("rollback gid2") return fmt.Errorf("gid2 error") }) @@ -50,7 +51,7 @@ func TestBaseSqlDB(t *testing.T) { barrier.BarrierID = 0 tx2, err := db.ToSQLDB().Begin() asserts.Nil(err) - err = barrier.Call(tx2, func(db dtmcli.DB) error { + err = barrier.Call(tx2, func(tx *sql.Tx) error { dtmimp.Logf("submit gid2") return nil }) diff --git a/test/tcc_barrier_test.go b/test/tcc_barrier_test.go index 66184ec..3fb5339 100644 --- a/test/tcc_barrier_test.go +++ b/test/tcc_barrier_test.go @@ -126,7 +126,7 @@ func TestTccBarrierPanic(t *testing.T) { func() { defer dtmimp.P2E(&err) tx, _ := dbGet().ToSQLDB().BeginTx(context.Background(), &sql.TxOptions{}) - bb.Call(tx, func(db dtmcli.DB) error { + bb.Call(tx, func(tx *sql.Tx) error { panic(fmt.Errorf("an error")) }) }()