diff --git a/dtmcli/dtmimp/trans_base.go b/dtmcli/dtmimp/trans_base.go index a1a136f..ec30d6b 100644 --- a/dtmcli/dtmimp/trans_base.go +++ b/dtmcli/dtmimp/trans_base.go @@ -142,15 +142,19 @@ func TransRequestBranch(t *TransBase, method string, body interface{}, branchID if url == "" { return nil, nil } + query := map[string]string{ + "dtm": t.Dtm, + "gid": t.Gid, + "branch_id": branchID, + "trans_type": t.TransType, + "op": op, + } + if t.TransType == "xa" { // xa trans will add notify_url + query["phase2_url"] = url + } resp, err := RestyClient.R(). SetBody(body). - SetQueryParams(map[string]string{ - "dtm": t.Dtm, - "gid": t.Gid, - "branch_id": branchID, - "trans_type": t.TransType, - "op": op, - }). + SetQueryParams(query). SetHeaders(t.BranchHeaders). Execute(method, url) if err == nil { diff --git a/dtmcli/dtmimp/trans_xa_base.go b/dtmcli/dtmimp/trans_xa_base.go index 6d633c7..6404c31 100644 --- a/dtmcli/dtmimp/trans_xa_base.go +++ b/dtmcli/dtmimp/trans_xa_base.go @@ -11,36 +11,29 @@ import ( "strings" ) -// XaClientBase XaClient/XaGrpcClient base. shared by http and grpc -type XaClientBase struct { - Server string - Conf DBConf - NotifyURL string -} - -// HandleCallback Handle the callback of commit/rollback -func (xc *XaClientBase) HandleCallback(gid string, branchID string, action string) error { - db, err := PooledDB(xc.Conf) +// XaHandlePhase2 Handle the callback of commit/rollback +func XaHandlePhase2(gid string, dbConf DBConf, branchID string, op string) error { + db, err := PooledDB(dbConf) if err != nil { return err } xaID := gid + "-" + branchID - _, err = DBExec(db, GetDBSpecial().GetXaSQL(action, xaID)) + _, err = DBExec(db, GetDBSpecial().GetXaSQL(op, xaID)) if err != nil && (strings.Contains(err.Error(), "XAER_NOTA") || strings.Contains(err.Error(), "does not exist")) { // Repeat commit/rollback with the same id, report this error, ignore err = nil } - if action == OpRollback && err == nil { + if op == OpRollback && err == nil { // rollback insert a row after prepare. no-error means prepare has finished. - _, err = InsertBarrier(db, "xa", gid, branchID, OpAction, XaBarrier1, action) + _, err = InsertBarrier(db, "xa", gid, branchID, OpAction, XaBarrier1, op) } return err } -// HandleLocalTrans public handler of LocalTransaction via http/grpc -func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) error) (rerr error) { +// XaHandleLocalTrans public handler of LocalTransaction via http/grpc +func XaHandleLocalTrans(xa *TransBase, dbConf DBConf, cb func(*sql.DB) error) (rerr error) { xaBranch := xa.Gid + "-" + xa.BranchID - db, rerr := StandaloneDB(xc.Conf) + db, rerr := StandaloneDB(dbConf) if rerr != nil { return } @@ -66,8 +59,8 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) error) return } -// HandleGlobalTrans http/grpc GlobalTransaction shared func -func (xc *XaClientBase) HandleGlobalTrans(xa *TransBase, callDtm func(string) error, callBusi func() error) (rerr error) { +// XaHandleGlobalTrans http/grpc GlobalTransaction shared func +func XaHandleGlobalTrans(xa *TransBase, callDtm func(string) error, callBusi func() error) (rerr error) { rerr = callDtm("prepare") if rerr != nil { return diff --git a/dtmcli/trans_test.go b/dtmcli/trans_test.go index 05e6ab4..66a0459 100644 --- a/dtmcli/trans_test.go +++ b/dtmcli/trans_test.go @@ -23,8 +23,3 @@ func TestQuery(t *testing.T) { _, err = BarrierFromQuery(qs) assert.Error(t, err) } - -func TestXa(t *testing.T) { - _, err := NewXaClient("http://localhost:36789", DBConf{}, ":::::", nil) - assert.Error(t, err) -} diff --git a/dtmcli/xa.go b/dtmcli/xa.go index e2b8d01..662f8d6 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -21,76 +21,54 @@ type XaGlobalFunc func(xa *Xa) (*resty.Response, error) // XaLocalFunc type of xa local function type XaLocalFunc func(db *sql.DB, xa *Xa) error -// XaRegisterCallback type of xa register callback handler -type XaRegisterCallback func(path string, xa *XaClient) - -// XaClient xa client -type XaClient struct { - dtmimp.XaClientBase -} - // Xa xa transaction type Xa struct { dtmimp.TransBase + Phase2URL string } // XaFromQuery construct xa info from request func XaFromQuery(qs url.Values) (*Xa, error) { xa := &Xa{TransBase: *dtmimp.TransBaseFromQuery(qs)} - if xa.Gid == "" || xa.BranchID == "" { - return nil, fmt.Errorf("bad xa info: gid: %s branchid: %s", xa.Gid, xa.BranchID) + xa.Op = dtmimp.EscapeGet(qs, "op") + xa.Phase2URL = dtmimp.EscapeGet(qs, "phase2_url") + if xa.Gid == "" || xa.BranchID == "" || xa.Op == "" { + return nil, fmt.Errorf("bad xa info: gid: %s branchid: %s op: %s phase2_url: %s", xa.Gid, xa.BranchID, xa.Op, xa.Phase2URL) } return xa, nil } -// NewXaClient construct a xa client -func NewXaClient(server string, mysqlConf DBConf, notifyURL string, register XaRegisterCallback) (*XaClient, error) { - xa := &XaClient{XaClientBase: dtmimp.XaClientBase{ - Server: server, - Conf: mysqlConf, - NotifyURL: notifyURL, - }} - u, err := url.Parse(notifyURL) - if err != nil { - return nil, err - } - register(u.Path, xa) - return xa, nil -} - -// HandleCallback handle commit/rollback callback -func (xc *XaClient) HandleCallback(gid string, branchID string, action string) interface{} { - return xc.XaClientBase.HandleCallback(gid, branchID, action) -} - // XaLocalTransaction start a xa local transaction -func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) error { +func XaLocalTransaction(qs url.Values, dbConf DBConf, xaFunc XaLocalFunc) error { xa, err := XaFromQuery(qs) if err != nil { return err } - return xc.HandleLocalTrans(&xa.TransBase, func(db *sql.DB) error { + if xa.Op == dtmimp.OpCommit || xa.Op == dtmimp.OpRollback { + return dtmimp.XaHandlePhase2(xa.Gid, dbConf, xa.BranchID, xa.Op) + } + return dtmimp.XaHandleLocalTrans(&xa.TransBase, dbConf, func(db *sql.DB) error { err := xaFunc(db, xa) if err != nil { return err } return dtmimp.TransRegisterBranch(&xa.TransBase, map[string]string{ - "url": xc.XaClientBase.NotifyURL, + "url": xa.Phase2URL, "branch_id": xa.BranchID, }, "registerBranch") }) } // XaGlobalTransaction start a xa global transaction -func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr error) { - return xc.XaGlobalTransaction2(gid, func(x *Xa) {}, xaFunc) +func XaGlobalTransaction(server string, gid string, xaFunc XaGlobalFunc) error { + return XaGlobalTransaction2(server, gid, func(x *Xa) {}, xaFunc) } -// XaGlobalTransaction2 start a xa global transaction -func (xc *XaClient) XaGlobalTransaction2(gid string, custom func(*Xa), xaFunc XaGlobalFunc) (rerr error) { - xa := &Xa{TransBase: *dtmimp.NewTransBase(gid, "xa", xc.XaClientBase.Server, "")} +// XaGlobalTransaction2 start a xa global transaction with xa custom function +func XaGlobalTransaction2(server string, gid string, custom func(*Xa), xaFunc XaGlobalFunc) (rerr error) { + xa := &Xa{TransBase: *dtmimp.NewTransBase(gid, "xa", server, "")} custom(xa) - return xc.HandleGlobalTrans(&xa.TransBase, func(action string) error { + return dtmimp.XaHandleGlobalTrans(&xa.TransBase, func(action string) error { return dtmimp.TransCallDtm(&xa.TransBase, xa, action) }, func() error { _, rerr := xaFunc(xa) diff --git a/dtmgrpc/dtmgimp/types.go b/dtmgrpc/dtmgimp/types.go index 2b6a653..bb7fcaf 100644 --- a/dtmgrpc/dtmgimp/types.go +++ b/dtmgrpc/dtmgimp/types.go @@ -58,5 +58,8 @@ func InvokeBranch(t *dtmimp.TransBase, isRaw bool, msg proto.Message, url string } ctx := TransInfo2Ctx(t.Gid, t.TransType, branchID, op, t.Dtm) ctx = metadata.AppendToOutgoingContext(ctx, Map2Kvs(t.BranchHeaders)...) + if t.TransType == "xa" { // xa branch need addtional phase2_url + ctx = metadata.AppendToOutgoingContext(ctx, Map2Kvs(map[string]string{dtmpre + "phase2_url": url})...) + } return MustGetGrpcConn(server, isRaw).Invoke(ctx, method, msg, reply) } diff --git a/dtmgrpc/dtmgimp/utils.go b/dtmgrpc/dtmgimp/utils.go index 17aaa11..0e46c70 100644 --- a/dtmgrpc/dtmgimp/utils.go +++ b/dtmgrpc/dtmgimp/utils.go @@ -102,6 +102,11 @@ func GetMetaFromContext(ctx context.Context, name string) string { return mdGet(md, name) } +func GetDtmMetaFromContext(ctx context.Context, name string) string { + md, _ := metadata.FromIncomingContext(ctx) + return dtmGet(md, name) +} + type requestTimeoutKey struct{} // RequestTimeoutFromContext returns requestTime of transOption option diff --git a/dtmgrpc/xa.go b/dtmgrpc/xa.go index 37f7cfc..7ab34d7 100644 --- a/dtmgrpc/xa.go +++ b/dtmgrpc/xa.go @@ -26,14 +26,10 @@ type XaGrpcGlobalFunc func(xa *XaGrpc) error // XaGrpcLocalFunc type of xa local function type XaGrpcLocalFunc func(db *sql.DB, xa *XaGrpc) error -// XaGrpcClient xa client -type XaGrpcClient struct { - dtmimp.XaClientBase -} - // XaGrpc xa transaction type XaGrpc struct { dtmimp.TransBase + Phase2URL string } // XaGrpcFromRequest construct xa info from request @@ -41,39 +37,23 @@ func XaGrpcFromRequest(ctx context.Context) (*XaGrpc, error) { xa := &XaGrpc{ TransBase: *dtmgimp.TransBaseFromGrpc(ctx), } - if xa.Gid == "" || xa.BranchID == "" { - return nil, fmt.Errorf("bad xa info: gid: %s branchid: %s", xa.Gid, xa.BranchID) + xa.Phase2URL = dtmgimp.GetDtmMetaFromContext(ctx, "phase2_url") + if xa.Gid == "" || xa.BranchID == "" || xa.Op == "" { + return nil, fmt.Errorf("bad xa info: gid: %s branchid: %s op: %s phase2_url: %s", xa.Gid, xa.BranchID, xa.Op, xa.Phase2URL) } return xa, nil } -// NewXaGrpcClient construct a xa client -func NewXaGrpcClient(server string, mysqlConf dtmcli.DBConf, notifyURL string) *XaGrpcClient { - xa := &XaGrpcClient{XaClientBase: dtmimp.XaClientBase{ - Server: server, - Conf: mysqlConf, - NotifyURL: notifyURL, - }} - return xa -} - -// HandleCallback handle commit/rollback callback -func (xc *XaGrpcClient) HandleCallback(ctx context.Context) (*emptypb.Empty, error) { - tb := dtmgimp.TransBaseFromGrpc(ctx) - return &emptypb.Empty{}, xc.XaClientBase.HandleCallback(tb.Gid, tb.BranchID, tb.Op) -} - // XaLocalTransaction start a xa local transaction -func (xc *XaGrpcClient) XaLocalTransaction(ctx context.Context, msg proto.Message, xaFunc XaGrpcLocalFunc) error { +func XaLocalTransaction(ctx context.Context, dbConf dtmcli.DBConf, xaFunc XaGrpcLocalFunc) error { xa, err := XaGrpcFromRequest(ctx) if err != nil { return err } - data, err := proto.Marshal(msg) - if err != nil { - return err + if xa.Op == dtmimp.OpCommit || xa.Op == dtmimp.OpRollback { + return dtmimp.XaHandlePhase2(xa.Gid, dbConf, xa.BranchID, xa.Op) } - return xc.HandleLocalTrans(&xa.TransBase, func(db *sql.DB) error { + return dtmimp.XaHandleLocalTrans(&xa.TransBase, dbConf, func(db *sql.DB) error { err := xaFunc(db, xa) if err != nil { return err @@ -82,28 +62,28 @@ func (xc *XaGrpcClient) XaLocalTransaction(ctx context.Context, msg proto.Messag Gid: xa.Gid, BranchID: xa.BranchID, TransType: xa.TransType, - BusiPayload: data, - Data: map[string]string{"url": xc.NotifyURL}, + BusiPayload: nil, + Data: map[string]string{"url": xa.Phase2URL}, }) return err }) } // XaGlobalTransaction start a xa global transaction -func (xc *XaGrpcClient) XaGlobalTransaction(gid string, xaFunc XaGrpcGlobalFunc) error { - return xc.XaGlobalTransaction2(gid, func(xg *XaGrpc) {}, xaFunc) +func XaGlobalTransaction(server string, gid string, xaFunc XaGrpcGlobalFunc) error { + return XaGlobalTransaction2(server, gid, func(xg *XaGrpc) {}, xaFunc) } // XaGlobalTransaction2 new version of XaGlobalTransaction. support custom -func (xc *XaGrpcClient) XaGlobalTransaction2(gid string, custom func(*XaGrpc), xaFunc XaGrpcGlobalFunc) error { - xa := &XaGrpc{TransBase: *dtmimp.NewTransBase(gid, "xa", xc.Server, "")} +func XaGlobalTransaction2(server string, gid string, custom func(*XaGrpc), xaFunc XaGrpcGlobalFunc) error { + xa := &XaGrpc{TransBase: *dtmimp.NewTransBase(gid, "xa", server, "")} custom(xa) dc := dtmgimp.MustGetDtmClient(xa.Dtm) req := &dtmgpb.DtmRequest{ Gid: gid, TransType: xa.TransType, } - return xc.HandleGlobalTrans(&xa.TransBase, func(action string) error { + return dtmimp.XaHandleGlobalTrans(&xa.TransBase, func(action string) error { f := map[string]func(context.Context, *dtmgpb.DtmRequest, ...grpc.CallOption) (*emptypb.Empty, error){ "prepare": dc.Prepare, "submit": dc.Submit, diff --git a/go.mod b/go.mod index 4df247e..5d9206a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/BurntSushi/toml v0.4.1 // indirect github.com/dtm-labs/dtmdriver v0.0.1 github.com/dtm-labs/dtmdriver-gozero v0.0.2 - github.com/dtm-labs/dtmdriver-kratos v0.0.4 // indirect + github.com/dtm-labs/dtmdriver-kratos v0.0.4 github.com/dtm-labs/dtmdriver-polaris v0.0.4 github.com/dtm-labs/dtmdriver-protocol1 v0.0.1 github.com/gin-gonic/gin v1.7.7 diff --git a/go.sum b/go.sum index 725f08d..c67d3e1 100644 --- a/go.sum +++ b/go.sum @@ -107,14 +107,6 @@ github.com/dtm-labs/dtmdriver v0.0.1 h1:dHUZQ6g2ZN6eRUqds9kKq/3K7u9bcUGatUlbthD9 github.com/dtm-labs/dtmdriver v0.0.1/go.mod h1:fLiEeD2BPwM9Yq96TfcP9KpbTwFsn5nTxa/PP0jmFuk= github.com/dtm-labs/dtmdriver-gozero v0.0.2 h1:T+JH9kwVNMmISPU1BNviiTrvPdMA7UMFD+nfTqGPSyA= github.com/dtm-labs/dtmdriver-gozero v0.0.2/go.mod h1:5AAKwYok5f56e0kATOXvc+DAsfu4elISDuCV+G3+fYE= -github.com/dtm-labs/dtmdriver-kratos v0.0.0-20220318113458-787275b94ed2 h1:oTh5EWgcZ0eW2qjBscPc0SLK+IMbrEbrwqHowmSeP4c= -github.com/dtm-labs/dtmdriver-kratos v0.0.0-20220318113458-787275b94ed2/go.mod h1:MjrFIa2A191ATVb/xy2vnA2ZKqMK9zC/1m3pjxXwkac= -github.com/dtm-labs/dtmdriver-kratos v0.0.1 h1:JP3qnY9b+jE0RJ1ax20tKBJHwZrhrqYg0M8eNxcpuIw= -github.com/dtm-labs/dtmdriver-kratos v0.0.1/go.mod h1:MjrFIa2A191ATVb/xy2vnA2ZKqMK9zC/1m3pjxXwkac= -github.com/dtm-labs/dtmdriver-kratos v0.0.2 h1:/Tw1X9lvGOVXjc+cY6omMoODr16b4V5cim+w19ZeGAA= -github.com/dtm-labs/dtmdriver-kratos v0.0.2/go.mod h1:MjrFIa2A191ATVb/xy2vnA2ZKqMK9zC/1m3pjxXwkac= -github.com/dtm-labs/dtmdriver-kratos v0.0.3 h1:a09mvcGEqXf0DzjHOVR/UJnOGAMAwjfJ3LMG6z9092Q= -github.com/dtm-labs/dtmdriver-kratos v0.0.3/go.mod h1:MjrFIa2A191ATVb/xy2vnA2ZKqMK9zC/1m3pjxXwkac= github.com/dtm-labs/dtmdriver-kratos v0.0.4 h1:jDVvrwiw8GwVrampIxhoXZ9TewwQKHFpcDcQXyU2Qyc= github.com/dtm-labs/dtmdriver-kratos v0.0.4/go.mod h1:MjrFIa2A191ATVb/xy2vnA2ZKqMK9zC/1m3pjxXwkac= github.com/dtm-labs/dtmdriver-polaris v0.0.4 h1:yli0YmAsEgl47ymJHTxIzULeNe5dnmfN2ixLJRWm2Ok= diff --git a/test/busi/base_grpc.go b/test/busi/base_grpc.go index cb1032f..ff7daf1 100644 --- a/test/busi/base_grpc.go +++ b/test/busi/base_grpc.go @@ -18,7 +18,6 @@ import ( "github.com/dtm-labs/dtm/dtmcli/logger" "github.com/dtm-labs/dtm/dtmgrpc" "github.com/dtm-labs/dtm/dtmutil" - "github.com/gin-gonic/gin" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" "github.com/dtm-labs/dtm/dtmgrpc/dtmgpb" @@ -33,15 +32,6 @@ var BusiGrpc = fmt.Sprintf("localhost:%d", BusiGrpcPort) // DtmClient grpc client for dtm var DtmClient dtmgpb.DtmClient -// XaGrpcClient XA client connection -var XaGrpcClient *dtmgrpc.XaGrpcClient - -func init() { - setupFuncs["XaGrpcSetup"] = func(app *gin.Engine) { - XaGrpcClient = dtmgrpc.NewXaGrpcClient(dtmutil.DefaultGrpcServer, BusiConf, BusiGrpc+"/busi.Busi/XaNotify") - } -} - // GrpcStartup for grpc func GrpcStartup() { conn, err := grpc.Dial(dtmutil.DefaultGrpcServer, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithUnaryInterceptor(dtmgimp.GrpcClientLog)) @@ -105,13 +95,13 @@ func (s *busiServer) TransOutTcc(ctx context.Context, in *BusiReq) (*emptypb.Emp } func (s *busiServer) TransInXa(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { - return &emptypb.Empty{}, XaGrpcClient.XaLocalTransaction(ctx, in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { + return &emptypb.Empty{}, dtmgrpc.XaLocalTransaction(ctx, BusiConf, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { return sagaGrpcAdjustBalance(db, TransInUID, in.Amount, in.TransInResult) }) } func (s *busiServer) TransOutXa(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { - return &emptypb.Empty{}, XaGrpcClient.XaLocalTransaction(ctx, in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { + return &emptypb.Empty{}, dtmgrpc.XaLocalTransaction(ctx, BusiConf, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { return sagaGrpcAdjustBalance(db, TransOutUID, in.Amount, in.TransOutResult) }) } @@ -125,10 +115,6 @@ func (s *busiServer) TransInTccNested(ctx context.Context, in *BusiReq) (*emptyp return r, handleGrpcBusiness(in, MainSwitch.TransInResult.Fetch(), in.TransInResult, dtmimp.GetFuncName()) } -func (s *busiServer) XaNotify(ctx context.Context, in *emptypb.Empty) (*emptypb.Empty, error) { - return XaGrpcClient.HandleCallback(ctx) -} - func (s *busiServer) TransOutHeaderYes(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { meta := dtmgimp.GetMetaFromContext(ctx, "test_header") if meta == "" { diff --git a/test/busi/base_http.go b/test/busi/base_http.go index 6608603..6018e69 100644 --- a/test/busi/base_http.go +++ b/test/busi/base_http.go @@ -37,9 +37,6 @@ var setupFuncs = map[string]setupFunc{} // Busi busi service url prefix var Busi = fmt.Sprintf("http://localhost:%d%s", BusiPort, BusiAPI) -// XaClient 1 -var XaClient *dtmcli.XaClient - // SleepCancelHandler 1 type SleepCancelHandler func(c *gin.Context) interface{} @@ -63,13 +60,6 @@ func BaseAppStartup() *gin.Engine { } c.Next() }) - var err error - XaClient, err = dtmcli.NewXaClient(dtmutil.DefaultHTTPServer, BusiConf, Busi+"/xa", func(path string, xa *dtmcli.XaClient) { - app.POST(path, dtmutil.WrapHandler2(func(c *gin.Context) interface{} { - return xa.HandleCallback(c.Query("gid"), c.Query("branch_id"), c.Query("op")) - })) - }) - logger.FatalIfError(err) BaseAddRoute(app) addJrpcRoute(app) @@ -144,12 +134,12 @@ func BaseAddRoute(app *gin.Engine) { return bb.MongoQueryPrepared(MongoGet()) })) app.POST(BusiAPI+"/TransInXa", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { - return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { + return dtmcli.XaLocalTransaction(c.Request.URL.Query(), BusiConf, func(db *sql.DB, xa *dtmcli.Xa) error { return SagaAdjustBalance(db, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) }) })) app.POST(BusiAPI+"/TransOutXa", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { - return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { + return dtmcli.XaLocalTransaction(c.Request.URL.Query(), BusiConf, func(db *sql.DB, xa *dtmcli.Xa) error { return SagaAdjustBalance(db, TransOutUID, reqFrom(c).Amount, reqFrom(c).TransOutResult) }) })) @@ -167,7 +157,7 @@ func BaseAddRoute(app *gin.Engine) { return resp })) app.POST(BusiAPI+"/TransOutXaGorm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { - return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { + return dtmcli.XaLocalTransaction(c.Request.URL.Query(), BusiConf, func(db *sql.DB, xa *dtmcli.Xa) error { if reqFrom(c).TransOutResult == dtmcli.ResultFailure { return dtmcli.ErrFailure } diff --git a/test/xa_cover_test.go b/test/xa_cover_test.go index 673a58b..6836582 100644 --- a/test/xa_cover_test.go +++ b/test/xa_cover_test.go @@ -11,39 +11,36 @@ import ( ) func TestXaCoverDBError(t *testing.T) { - oldDriver := getXc().Conf.Driver + oldDriver := busi.BusiConf.Driver gid := dtmimp.GetFuncName() - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { req := busi.GenTransReq(30, false, false) _, err := xa.CallBranch(req, busi.Busi+"/TransOutXa") assert.Nil(t, err) - getXc().Conf.Driver = "no-driver" + busi.BusiConf.Driver = "no-driver" _, err = xa.CallBranch(req, busi.Busi+"/TransInXa") assert.Error(t, err) return nil, err }) assert.Error(t, err) waitTransProcessed(gid) - getXc().Conf.Driver = oldDriver + busi.BusiConf.Driver = oldDriver cronTransOnceForwardNow(t, gid, 500) // rollback succeeded here assert.Equal(t, StatusFailed, getTransStatus(gid)) assert.Equal(t, []string{StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) } func TestXaCoverDTMError(t *testing.T) { - oldServer := getXc().Server - getXc().Server = "localhost:01" gid := dtmimp.GetFuncName() - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction("localhost:01", gid, func(xa *dtmcli.Xa) (*resty.Response, error) { return nil, nil }) assert.Error(t, err) - getXc().Server = oldServer } func TestXaCoverGidError(t *testing.T) { gid := dtmimp.GetFuncName() + "-' '" - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { req := busi.GenTransReq(30, false, false) _, err := xa.CallBranch(req, busi.Busi+"/TransOutXa") assert.Error(t, err) diff --git a/test/xa_grpc_test.go b/test/xa_grpc_test.go index 890d129..678dac6 100644 --- a/test/xa_grpc_test.go +++ b/test/xa_grpc_test.go @@ -18,12 +18,9 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -func getXcg() *dtmgrpc.XaGrpcClient { - return busi.XaGrpcClient -} func TestXaGrpcNormal(t *testing.T) { gid := dtmimp.GetFuncName() - err := getXcg().XaGlobalTransaction(gid, func(xa *dtmgrpc.XaGrpc) error { + err := dtmgrpc.XaGlobalTransaction(DtmGrpcServer, gid, func(xa *dtmgrpc.XaGrpc) error { req := busi.GenBusiReq(30, false, false) r := &emptypb.Empty{} err := xa.CallBranch(req, busi.BusiGrpc+"/busi.Busi/TransOutXa", r) @@ -40,7 +37,7 @@ func TestXaGrpcNormal(t *testing.T) { func TestXaGrpcRollback(t *testing.T) { gid := dtmimp.GetFuncName() - err := getXcg().XaGlobalTransaction(gid, func(xa *dtmgrpc.XaGrpc) error { + err := dtmgrpc.XaGlobalTransaction(DtmGrpcServer, gid, func(xa *dtmgrpc.XaGrpc) error { req := busi.GenBusiReq(30, false, true) r := &emptypb.Empty{} err := xa.CallBranch(req, busi.BusiGrpc+"/busi.Busi/TransOutXa", r) @@ -60,11 +57,11 @@ func TestXaGrpcType(t *testing.T) { _, err := dtmgrpc.XaGrpcFromRequest(context.Background()) assert.Error(t, err) - err = busi.XaGrpcClient.XaLocalTransaction(context.Background(), nil, nil) + err = dtmgrpc.XaLocalTransaction(context.Background(), busi.BusiConf, nil) assert.Error(t, err) err = dtmimp.CatchP(func() { - busi.XaGrpcClient.XaGlobalTransaction(gid, func(xa *dtmgrpc.XaGrpc) error { panic(fmt.Errorf("hello")) }) + dtmgrpc.XaGlobalTransaction(DtmGrpcServer, gid, func(xa *dtmgrpc.XaGrpc) error { panic(fmt.Errorf("hello")) }) }) assert.Error(t, err) waitTransProcessed(gid) @@ -72,8 +69,7 @@ func TestXaGrpcType(t *testing.T) { func TestXaGrpcLocalError(t *testing.T) { gid := dtmimp.GetFuncName() - xc := busi.XaGrpcClient - err := xc.XaGlobalTransaction(gid, func(xa *dtmgrpc.XaGrpc) error { + err := dtmgrpc.XaGlobalTransaction(DtmGrpcServer, gid, func(xa *dtmgrpc.XaGrpc) error { return fmt.Errorf("an error") }) assert.Error(t, err, fmt.Errorf("an error")) diff --git a/test/xa_test.go b/test/xa_test.go index 6da797a..4980a0a 100644 --- a/test/xa_test.go +++ b/test/xa_test.go @@ -12,18 +12,15 @@ import ( "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/dtm-labs/dtm/dtmutil" "github.com/dtm-labs/dtm/test/busi" "github.com/go-resty/resty/v2" "github.com/stretchr/testify/assert" ) -func getXc() *dtmcli.XaClient { - return busi.XaClient -} - func TestXaNormal(t *testing.T) { gid := dtmimp.GetFuncName() - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(dtmutil.DefaultHTTPServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { req := busi.GenTransReq(30, false, false) resp, err := xa.CallBranch(req, busi.Busi+"/TransOutXa") if err != nil { @@ -39,7 +36,7 @@ func TestXaNormal(t *testing.T) { func TestXaDuplicate(t *testing.T) { gid := dtmimp.GetFuncName() - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { req := busi.GenTransReq(30, false, false) _, err := xa.CallBranch(req, busi.Busi+"/TransOutXa") assert.Nil(t, err) @@ -61,7 +58,7 @@ func TestXaDuplicate(t *testing.T) { func TestXaRollback(t *testing.T) { gid := dtmimp.GetFuncName() - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { req := busi.GenTransReq(30, false, true) resp, err := xa.CallBranch(req, busi.Busi+"/TransOutXa") if err != nil { @@ -77,7 +74,7 @@ func TestXaRollback(t *testing.T) { func TestXaLocalError(t *testing.T) { gid := dtmimp.GetFuncName() - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { return nil, fmt.Errorf("an error") }) assert.Error(t, err, fmt.Errorf("an error")) @@ -87,7 +84,7 @@ func TestXaLocalError(t *testing.T) { func TestXaTimeout(t *testing.T) { gid := dtmimp.GetFuncName() timeoutChan := make(chan int, 1) - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { go func() { cronTransOnceForwardNow(t, gid, 300) timeoutChan <- 0 @@ -103,7 +100,7 @@ func TestXaTimeout(t *testing.T) { func TestXaNotTimeout(t *testing.T) { gid := dtmimp.GetFuncName() timeoutChan := make(chan int, 1) - err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + err := dtmcli.XaGlobalTransaction(DtmServer, gid, func(xa *dtmcli.Xa) (*resty.Response, error) { go func() { cronTransOnceForwardNow(t, gid, 0) // not timeout, timeoutChan <- 0