diff --git a/dtmgrpc/dtmgimp/types.go b/dtmgrpc/dtmgimp/types.go index 2df228a..4c0436d 100644 --- a/dtmgrpc/dtmgimp/types.go +++ b/dtmgrpc/dtmgimp/types.go @@ -51,12 +51,12 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c } // InvokeURL invoke a url for trans -func InvokeBranch(t *dtmimp.TransBase, msg proto.Message, url string, reply interface{}, branchID string, op string) error { +func InvokeBranch(t *dtmimp.TransBase, isRaw bool, msg proto.Message, url string, reply interface{}, branchID string, op string) error { server, method, err := dtmdriver.GetDriver().ParseServerMethod(url) if err != nil { return err } ctx := TransInfo2Ctx(t.Gid, t.TransType, branchID, op, t.Dtm) ctx = metadata.AppendToOutgoingContext(ctx, Map2Kvs(t.BranchHeaders)...) - return MustGetGrpcConn(server, false).Invoke(ctx, method, msg, reply) + return MustGetGrpcConn(server, isRaw).Invoke(ctx, method, msg, reply) } diff --git a/dtmgrpc/msg.go b/dtmgrpc/msg.go index 33ab7a3..575ec35 100644 --- a/dtmgrpc/msg.go +++ b/dtmgrpc/msg.go @@ -60,8 +60,7 @@ func (s *MsgGrpc) Do(queryPrepared string, busiCall func(bb *dtmcli.BranchBarrie if err == nil { err = busiCall(bb) if err != nil && !errors.Is(err, dtmcli.ErrFailure) { - var reply interface{} - err = dtmgimp.InvokeBranch(&s.TransBase, nil, queryPrepared, &reply, bb.BranchID, bb.Op) + err = dtmgimp.InvokeBranch(&s.TransBase, true, nil, queryPrepared, &[]byte{}, bb.BranchID, bb.Op) err = GrpcError2DtmError(err) } if errors.Is(err, dtmcli.ErrFailure) { diff --git a/dtmgrpc/tcc.go b/dtmgrpc/tcc.go index 11fb1b4..54a9e05 100644 --- a/dtmgrpc/tcc.go +++ b/dtmgrpc/tcc.go @@ -85,5 +85,5 @@ func (t *TccGrpc) CallBranch(busiMsg proto.Message, tryURL string, confirmURL st if err != nil { return err } - return dtmgimp.InvokeBranch(&t.TransBase, busiMsg, tryURL, reply, branchID, "try") + return dtmgimp.InvokeBranch(&t.TransBase, false, busiMsg, tryURL, reply, branchID, "try") } diff --git a/dtmgrpc/xa.go b/dtmgrpc/xa.go index dcdb9b1..a7bbebb 100644 --- a/dtmgrpc/xa.go +++ b/dtmgrpc/xa.go @@ -118,5 +118,5 @@ func (xc *XaGrpcClient) XaGlobalTransaction2(gid string, custom func(*XaGrpc), x // CallBranch call a xa branch func (x *XaGrpc) CallBranch(msg proto.Message, url string, reply interface{}) error { - return dtmgimp.InvokeBranch(&x.TransBase, msg, url, reply, x.NewSubBranchID(), "action") + return dtmgimp.InvokeBranch(&x.TransBase, false, msg, url, reply, x.NewSubBranchID(), "action") } diff --git a/test/busi/barrier.go b/test/busi/barrier.go index a0734e9..63c24c2 100644 --- a/test/busi/barrier.go +++ b/test/busi/barrier.go @@ -11,6 +11,7 @@ import ( "database/sql" "github.com/dtm-labs/dtm/dtmcli" + "github.com/dtm-labs/dtm/dtmgrpc" "github.com/dtm-labs/dtm/dtmsvr/config" "github.com/dtm-labs/dtm/dtmutil" "github.com/gin-gonic/gin" @@ -149,5 +150,6 @@ func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *BusiReq) (*emp func (s *busiServer) QueryPreparedB(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { barrier := MustBarrierFromGrpc(ctx) - return &emptypb.Empty{}, barrier.QueryPrepared(dbGet().ToSQLDB()) + err := barrier.QueryPrepared(dbGet().ToSQLDB()) + return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(err) } diff --git a/test/msg_grpc_barrier_test.go b/test/msg_grpc_barrier_test.go index 9b57633..aba9a65 100644 --- a/test/msg_grpc_barrier_test.go +++ b/test/msg_grpc_barrier_test.go @@ -8,6 +8,7 @@ import ( "bou.ke/monkey" "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/dtm-labs/dtm/dtmcli/logger" "github.com/dtm-labs/dtm/dtmgrpc" "github.com/dtm-labs/dtm/test/busi" "github.com/stretchr/testify/assert" @@ -48,7 +49,30 @@ func TestMsgGrpcPrepareAndSubmitCommitAfterFailed(t *testing.T) { }) return err }) - assert.Error(t, err) - cronTransOnceForwardNow(t, gid, 180) + assert.Nil(t, err) + waitTransProcessed(gid) assertNotSameBalance(t, before, "mysql") } + +func TestMsgGrpcPrepareAndSubmitCommitFailed(t *testing.T) { + if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit + return + } + before := getBeforeBalances("mysql") + gid := dtmimp.GetFuncName() + req := busi.GenBusiReq(30, false, false) + msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid). + Add(busi.Busi+"/SagaBTransIn", req) + var g *monkey.PatchGuard + err := msg.PrepareAndSubmit(busi.BusiGrpc+"/busi.Busi/QueryPreparedB", dbGet().ToSQLDB(), func(tx *sql.Tx) error { + g = monkey.PatchInstanceMethod(reflect.TypeOf(tx), "Commit", func(tx *sql.Tx) error { + logger.Debugf("tx.Commit rollback and return error in test") + _ = tx.Rollback() + return errors.New("test error for patch") + }) + return busi.SagaAdjustBalance(tx, busi.TransOutUID, -int(req.Amount), "SUCCESS") + }) + g.Unpatch() + assert.Error(t, err) + assertSameBalance(t, before, "mysql") +}