diff --git a/dtmgrpc/dtmgimp/utils.go b/dtmgrpc/dtmgimp/utils.go index 12c7e8d..17aaa11 100644 --- a/dtmgrpc/dtmgimp/utils.go +++ b/dtmgrpc/dtmgimp/utils.go @@ -8,9 +8,6 @@ package dtmgimp import ( context "context" - "time" - - "google.golang.org/grpc" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" @@ -29,13 +26,6 @@ func MustProtoMarshal(msg proto.Message) []byte { // DtmGrpcCall make a convenient call to dtm func DtmGrpcCall(s *dtmimp.TransBase, operation string) error { - if s.RequestTimeout != 0 { - ClientInterceptors = append(ClientInterceptors, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx2, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second) - defer cancel() - return invoker(ctx2, method, req, reply, cc, opts...) - }) - } reply := emptypb.Empty{} return MustGetGrpcConn(s.Dtm, false).Invoke(context.Background(), "/dtmgimp.Dtm/"+operation, &dtmgpb.DtmRequest{ Gid: s.Gid, @@ -111,3 +101,19 @@ func GetMetaFromContext(ctx context.Context, name string) string { md, _ := metadata.FromIncomingContext(ctx) return mdGet(md, name) } + +type requestTimeoutKey struct{} + +// RequestTimeoutFromContext returns requestTime of transOption option +func RequestTimeoutFromContext(ctx context.Context) int64 { + if v, ok := ctx.Value(requestTimeoutKey{}).(int64); ok { + return v + } + + return 0 +} + +// RequestTimeoutNewContext sets requestTimeout of transOption option to context +func RequestTimeoutNewContext(ctx context.Context, requestTimeout int64) context.Context { + return context.WithValue(ctx, requestTimeoutKey{}, requestTimeout) +} diff --git a/dtmsvr/svr.go b/dtmsvr/svr.go index b897ad9..5d97fc2 100644 --- a/dtmsvr/svr.go +++ b/dtmsvr/svr.go @@ -7,10 +7,13 @@ package dtmsvr import ( + "context" "fmt" "net" "time" + "github.com/dtm-labs/dtm/dtmgrpc" + "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/logger" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" @@ -26,11 +29,15 @@ func StartSvr() { setServerInfoMetrics() dtmcli.GetRestyClient().SetTimeout(time.Duration(conf.RequestTimeout) * time.Second) - //dtmgrpc.AddUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - // ctx2, cancel := context.WithTimeout(ctx, time.Duration(conf.RequestTimeout)*time.Second) - // defer cancel() - // return invoker(ctx2, method, req, reply, cc, opts...) - //}) + dtmgrpc.AddUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + timeout := conf.RequestTimeout + if v := dtmgimp.RequestTimeoutFromContext(ctx); v != 0 { + timeout = v + } + ctx2, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + return invoker(ctx2, method, req, reply, cc, opts...) + }) // start gin server app := dtmutil.GetGinApp() diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index bda6330..00c0c58 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -7,14 +7,11 @@ package dtmsvr import ( - "context" "errors" "fmt" "strings" "time" - "google.golang.org/grpc" - "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" @@ -122,20 +119,13 @@ func (t *TransGlobal) getURLResult(url string, branchID, op string, branchPayloa if err != nil { return err } - dtmgimp.ClientInterceptors = append(dtmgimp.ClientInterceptors, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - timeout := conf.RequestTimeout - if t.RequestTimeout != 0 { - timeout = conf.RequestTimeout - } - ctx2, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - return invoker(ctx2, method, req, reply, cc, opts...) - }) + conn := dtmgimp.MustGetGrpcConn(server, true) ctx := dtmgimp.TransInfo2Ctx(t.Gid, t.TransType, branchID, op, "") kvs := dtmgimp.Map2Kvs(t.Ext.Headers) kvs = append(kvs, dtmgimp.Map2Kvs(t.BranchHeaders)...) ctx = metadata.AppendToOutgoingContext(ctx, kvs...) + ctx = dtmgimp.RequestTimeoutNewContext(ctx, t.RequestTimeout) err = conn.Invoke(ctx, method, branchPayload, &[]byte{}) if err == nil { return nil