From 262f68976b2ea08b888014e69fc58761f40eee3a Mon Sep 17 00:00:00 2001 From: qintang Date: Tue, 16 Aug 2022 14:28:45 +0800 Subject: [PATCH] =?UTF-8?q?feature:=20callBranch=20=E5=A2=9E=E5=8A=A0callO?= =?UTF-8?q?ption=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/dtmgrpc/dtmgimp/types.go | 4 ++-- client/dtmgrpc/msg.go | 5 +++-- client/dtmgrpc/tcc.go | 5 +++-- client/dtmgrpc/xa.go | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/client/dtmgrpc/dtmgimp/types.go b/client/dtmgrpc/dtmgimp/types.go index 431aa05..e6aebc2 100644 --- a/client/dtmgrpc/dtmgimp/types.go +++ b/client/dtmgrpc/dtmgimp/types.go @@ -55,7 +55,7 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c } // InvokeBranch invoke a url for trans -func InvokeBranch(t *dtmimp.TransBase, isRaw bool, msg proto.Message, url string, reply interface{}, branchID string, op string) error { +func InvokeBranch(t *dtmimp.TransBase, isRaw bool, msg proto.Message, url string, reply interface{}, branchID string, op string, opts ...grpc.CallOption) error { server, method, err := dtmdriver.GetDriver().ParseServerMethod(url) if err != nil { return err @@ -65,5 +65,5 @@ func InvokeBranch(t *dtmimp.TransBase, isRaw bool, msg proto.Message, url string if t.TransType == "xa" { // xa branch need additional phase2_url ctx = metadata.AppendToOutgoingContext(ctx, Map2Kvs(map[string]string{dtmpre + "phase2_url": url})...) } - return MustGetGrpcConn(server, isRaw).Invoke(ctx, method, msg, reply) + return MustGetGrpcConn(server, isRaw).Invoke(ctx, method, msg, reply, opts...) } diff --git a/client/dtmgrpc/msg.go b/client/dtmgrpc/msg.go index ec622ab..5c4943f 100644 --- a/client/dtmgrpc/msg.go +++ b/client/dtmgrpc/msg.go @@ -13,6 +13,7 @@ import ( "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" + grpc "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) @@ -62,7 +63,7 @@ func (s *MsgGrpc) DoAndSubmitDB(queryPrepared string, db *sql.DB, busiCall dtmcl // the error returned by busiCall will be returned // if busiCall return ErrFailure, then abort is called directly // if busiCall return not nil error other than ErrFailure, then DoAndSubmit will call queryPrepared to get the result -func (s *MsgGrpc) DoAndSubmit(queryPrepared string, busiCall func(bb *dtmcli.BranchBarrier) error) error { +func (s *MsgGrpc) DoAndSubmit(queryPrepared string, busiCall func(bb *dtmcli.BranchBarrier) error, opts ...grpc.CallOption) error { bb, err := dtmcli.BarrierFrom(s.TransType, s.Gid, dtmimp.MsgDoBranch0, dtmimp.MsgDoOp) // a special barrier for msg QueryPrepared if err == nil { err = s.Prepare(queryPrepared) @@ -70,7 +71,7 @@ func (s *MsgGrpc) DoAndSubmit(queryPrepared string, busiCall func(bb *dtmcli.Bra if err == nil { errb := busiCall(bb) if errb != nil && !errors.Is(errb, dtmcli.ErrFailure) { - err = dtmgimp.InvokeBranch(&s.TransBase, true, nil, queryPrepared, &[]byte{}, bb.BranchID, bb.Op) + err = dtmgimp.InvokeBranch(&s.TransBase, true, nil, queryPrepared, &[]byte{}, bb.BranchID, bb.Op, opts...) err = GrpcError2DtmError(err) } if errors.Is(errb, dtmcli.ErrFailure) || errors.Is(err, dtmcli.ErrFailure) { diff --git a/client/dtmgrpc/tcc.go b/client/dtmgrpc/tcc.go index 09f242f..029eab9 100644 --- a/client/dtmgrpc/tcc.go +++ b/client/dtmgrpc/tcc.go @@ -13,6 +13,7 @@ import ( "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp" "github.com/dtm-labs/dtm/client/dtmgrpc/dtmgpb" + grpc "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) @@ -61,7 +62,7 @@ func TccFromGrpc(ctx context.Context) (*TccGrpc, error) { } // CallBranch call a tcc branch -func (t *TccGrpc) CallBranch(busiMsg proto.Message, tryURL string, confirmURL string, cancelURL string, reply interface{}) error { +func (t *TccGrpc) CallBranch(busiMsg proto.Message, tryURL string, confirmURL string, cancelURL string, reply interface{}, opts ...grpc.CallOption) error { branchID := t.NewSubBranchID() bd, err := proto.Marshal(busiMsg) if err == nil { @@ -76,5 +77,5 @@ func (t *TccGrpc) CallBranch(busiMsg proto.Message, tryURL string, confirmURL st if err != nil { return err } - return dtmgimp.InvokeBranch(&t.TransBase, false, busiMsg, tryURL, reply, branchID, "try") + return dtmgimp.InvokeBranch(&t.TransBase, false, busiMsg, tryURL, reply, branchID, "try", opts...) } diff --git a/client/dtmgrpc/xa.go b/client/dtmgrpc/xa.go index 96add2b..f23e3bf 100644 --- a/client/dtmgrpc/xa.go +++ b/client/dtmgrpc/xa.go @@ -97,6 +97,6 @@ func XaGlobalTransaction2(server string, gid string, custom func(*XaGrpc), xaFun } // CallBranch call a xa branch -func (x *XaGrpc) CallBranch(msg proto.Message, url string, reply interface{}) error { - return dtmgimp.InvokeBranch(&x.TransBase, false, msg, url, reply, x.NewSubBranchID(), "action") +func (x *XaGrpc) CallBranch(msg proto.Message, url string, reply interface{}, opts ...grpc.CallOption) error { + return dtmgimp.InvokeBranch(&x.TransBase, false, msg, url, reply, x.NewSubBranchID(), "action", opts...) }