Browse Source

Merge pull request #459 from Makonike/feature/context

feat: new saga grpc with context
pull/465/head
yedf2 3 years ago
committed by GitHub
parent
commit
387946fe9b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      client/dtmcli/trans_saga.go
  2. 50
      client/dtmgrpc/options_test.go
  3. 13
      client/dtmgrpc/saga.go
  4. 3
      dtmsvr/trans_class.go
  5. 15
      dtmsvr/trans_process.go
  6. 17
      dtmsvr/trans_status.go
  7. 18
      dtmsvr/trans_type_msg.go
  8. 13
      dtmsvr/trans_type_saga.go
  9. 5
      dtmsvr/trans_type_tcc.go
  10. 6
      dtmsvr/trans_type_workflow.go
  11. 5
      dtmsvr/trans_type_xa.go
  12. 22
      dtmsvr/utils.go
  13. 87
      dtmsvr/utils_test.go
  14. 3
      test/saga_grpc_test.go

9
client/dtmcli/trans_saga.go

@ -7,6 +7,8 @@
package dtmcli
import (
"context"
"github.com/dtm-labs/dtm/client/dtmcli/dtmimp"
)
@ -21,6 +23,13 @@ func NewSaga(server string, gid string) *Saga {
return &Saga{TransBase: *dtmimp.NewTransBase(gid, "saga", server, ""), orders: map[int][]int{}}
}
// NewSagaWithContext create a saga with context
func NewSagaWithContext(ctx context.Context, server string, gid string) *Saga {
saga := NewSaga(server, gid)
saga.TransBase.Context = ctx
return saga
}
// Add add a saga step
func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga {
s.Steps = append(s.Steps, map[string]string{"action": action, "compensate": compensate})

50
client/dtmgrpc/options_test.go

@ -1,6 +1,7 @@
package dtmgrpc
import (
"context"
"reflect"
"testing"
@ -102,3 +103,52 @@ func TestNewSagaGrpc(t *testing.T) {
})
}
}
// TestNewSagaGrpcWithContext ut for NewSagaGrpcWithContext
func TestNewSagaGrpcWithContext(t *testing.T) {
var (
ctx = context.Background()
server = "dmt_server_address"
gidNoOptions = "msg_no_setup_options"
gidTraceIDXXX = "msg_setup_options_trace_id_xxx"
sagaWithTraceIDXXX = &SagaGrpc{Saga: *dtmcli.NewSagaWithContext(ctx, server, gidTraceIDXXX)}
traceIDHeaders = map[string]string{
"x-trace-id": "xxx",
}
)
sagaWithTraceIDXXX.BranchHeaders = traceIDHeaders
type args struct {
gid string
opts []TransBaseOption
}
tests := []struct {
name string
args args
want *SagaGrpc
}{
{
name: "no setup options",
args: args{gid: gidNoOptions},
want: &SagaGrpc{Saga: *dtmcli.NewSaga(server, gidNoOptions)},
},
{
name: "msg with trace_id",
args: args{
gid: gidTraceIDXXX,
opts: []TransBaseOption{
WithBranchHeaders(traceIDHeaders),
},
},
want: sagaWithTraceIDXXX,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewSagaGrpcWithContext(ctx, server, tt.args.gid, tt.args.opts...)
t.Logf("TestNewSagaGrpc %s got %+v\n", tt.name, got)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewSagaGrpc() = %v, want %v", got, tt.want)
}
})
}
}

13
client/dtmgrpc/saga.go

@ -7,6 +7,8 @@
package dtmgrpc
import (
"context"
"github.com/dtm-labs/dtm/client/dtmcli"
"github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp"
"google.golang.org/protobuf/proto"
@ -28,6 +30,17 @@ func NewSagaGrpc(server string, gid string, opts ...TransBaseOption) *SagaGrpc {
return sg
}
// NewSagaGrpcWithContext create a saga with context
func NewSagaGrpcWithContext(ctx context.Context, server string, gid string, opts ...TransBaseOption) *SagaGrpc {
sg := &SagaGrpc{Saga: *dtmcli.NewSagaWithContext(ctx, server, gid)}
for _, opt := range opts {
opt(&sg.TransBase)
}
return sg
}
// Add add a saga step
func (s *SagaGrpc) Add(action string, compensate string, payload proto.Message) *SagaGrpc {
s.Steps = append(s.Steps, map[string]string{"action": action, "compensate": compensate})

3
dtmsvr/trans_class.go

@ -42,7 +42,7 @@ type TransBranch = storage.TransBranchStore
type transProcessor interface {
GenBranches() []TransBranch
ProcessOnce(branches []TransBranch) error
ProcessOnce(ctx context.Context, branches []TransBranch) error
}
type processorCreator func(*TransGlobal) transProcessor
@ -103,6 +103,7 @@ func TransFromDtmRequest(ctx context.Context, c *dtmgpb.DtmRequest) *TransGlobal
},
}}
r.ReqExtra = c.ReqExtra
r.Context = ctx
if c.Steps != "" {
dtmimp.MustUnmarshalString(c.Steps, &r.Steps)
}

15
dtmsvr/trans_process.go

@ -7,6 +7,7 @@
package dtmsvr
import (
"context"
"errors"
"fmt"
"time"
@ -31,18 +32,18 @@ func (t *TransGlobal) process(branches []TransBranch) error {
if t.ExtData != "" {
dtmimp.MustUnmarshalString(t.ExtData, &t.Ext)
}
if !t.WaitResult {
go func() {
err := t.processInner(branches)
ctx := NewAsyncContext(t.Context)
go func(ctx context.Context) {
err := t.processInner(ctx, branches)
if err != nil && !errors.Is(err, dtmimp.ErrOngoing) {
logger.Errorf("processInner err: %v", err)
}
}()
}(ctx)
return nil
}
submitting := t.Status == dtmcli.StatusSubmitted
err := t.processInner(branches)
err := t.processInner(t.Context, branches)
if err != nil {
return err
}
@ -56,7 +57,7 @@ func (t *TransGlobal) process(branches []TransBranch) error {
return nil
}
func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) {
func (t *TransGlobal) processInner(ctx context.Context, branches []TransBranch) (rerr error) {
defer handlePanic(&rerr)
defer func() {
if rerr != nil && !errors.Is(rerr, dtmcli.ErrOngoing) {
@ -70,7 +71,7 @@ func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) {
}()
logger.Debugf("processing: %s status: %s", t.Gid, t.Status)
t.lastTouched = time.Now()
rerr = t.getProcessor().ProcessOnce(branches)
rerr = t.getProcessor().ProcessOnce(ctx, branches)
return
}

17
dtmsvr/trans_status.go

@ -1,6 +1,7 @@
package dtmsvr
import (
"context"
"errors"
"fmt"
"math"
@ -127,7 +128,7 @@ func (t *TransGlobal) needProcess() bool {
return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout()
}
func (t *TransGlobal) getURLResult(uri string, branchID, op string, branchPayload []byte) error {
func (t *TransGlobal) getURLResult(ctx context.Context, uri string, branchID, op string, branchPayload []byte) error {
if uri == "" { // empty url is success
return nil
}
@ -137,7 +138,7 @@ func (t *TransGlobal) getURLResult(uri string, branchID, op string, branchPayloa
}
return t.getHTTPResult(uri, branchID, op, branchPayload)
}
return t.getGrpcResult(uri, branchID, op, branchPayload)
return t.getGrpcResult(ctx, uri, branchID, op, branchPayload)
}
func (t *TransGlobal) getHTTPResult(uri string, branchID, op string, branchPayload []byte) error {
@ -192,7 +193,7 @@ func (t *TransGlobal) getJSONRPCResult(uri string, branchID, op string, branchPa
return err
}
func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPayload []byte) error {
func (t *TransGlobal) getGrpcResult(ctx context.Context, uri string, branchID, op string, branchPayload []byte) error {
// grpc handler
server, method, err := dtmdriver.GetDriver().ParseServerMethod(uri)
if err != nil {
@ -200,7 +201,7 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo
}
conn := dtmgimp.MustGetGrpcConn(server, true)
ctx := dtmgimp.TransInfo2Ctx(t.Context, t.Gid, t.TransType, branchID, op, "")
ctx = dtmgimp.TransInfo2Ctx(ctx, t.Gid, t.TransType, branchID, op, "")
kvs := dtmgimp.Map2Kvs(t.Ext.Headers)
kvs = append(kvs, dtmgimp.Map2Kvs(t.BranchHeaders)...)
ctx = metadata.AppendToOutgoingContext(ctx, kvs...)
@ -212,8 +213,8 @@ func (t *TransGlobal) getGrpcResult(uri string, branchID, op string, branchPaylo
return dtmgrpc.GrpcError2DtmError(err)
}
func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) {
err := t.getURLResult(branch.URL, branch.BranchID, branch.Op, branch.BinData)
func (t *TransGlobal) getBranchResult(ctx context.Context, branch *TransBranch) (string, error) {
err := t.getURLResult(ctx, branch.URL, branch.BranchID, branch.Op, branch.BinData)
if err == nil {
return dtmcli.StatusSucceed, nil
} else if t.TransType == "saga" && branch.Op == dtmimp.OpAction && errors.Is(err, dtmcli.ErrFailure) {
@ -225,8 +226,8 @@ func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) {
return "", fmt.Errorf("your http/grpc result should be specified as in:\nhttp://d.dtm.pub/practice/arch.html#proto\nunkown result will be retried: %w", err)
}
func (t *TransGlobal) execBranch(branch *TransBranch, branchPos int) error {
status, err := t.getBranchResult(branch)
func (t *TransGlobal) execBranch(ctx context.Context, branch *TransBranch, branchPos int) error {
status, err := t.getBranchResult(ctx, branch)
if status != "" {
t.changeBranchStatus(branch, status, branchPos)
}

18
dtmsvr/trans_type_msg.go

@ -7,6 +7,7 @@
package dtmsvr
import (
"context"
"errors"
"fmt"
"strings"
@ -51,11 +52,11 @@ type cMsgCustom struct {
Delay uint64 //delay call branch, unit second
}
func (t *TransGlobal) mayQueryPrepared() {
func (t *TransGlobal) mayQueryPrepared(ctx context.Context) {
if !t.needProcess() || t.Status == dtmcli.StatusSubmitted {
return
}
err := t.getURLResult(t.QueryPrepared, "00", "msg", nil)
err := t.getURLResult(ctx, t.QueryPrepared, "00", "msg", nil)
if err == nil {
t.changeStatus(dtmcli.StatusSubmitted)
} else if errors.Is(err, dtmcli.ErrFailure) {
@ -68,8 +69,8 @@ func (t *TransGlobal) mayQueryPrepared() {
}
}
func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error {
t.mayQueryPrepared()
func (t *transMsgProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error {
t.mayQueryPrepared(ctx)
if !t.needProcess() || t.Status == dtmcli.StatusPrepared {
return nil
}
@ -91,12 +92,13 @@ func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error {
continue
}
if t.Concurrent {
copyCtx := NewAsyncContext(ctx)
started++
go func(pos int) {
resultsChan <- t.execBranch(b, pos)
}(i)
go func(ctx context.Context, pos int) {
resultsChan <- t.execBranch(ctx, b, pos)
}(copyCtx, i)
} else {
err = t.execBranch(b, i)
err = t.execBranch(ctx, b, i)
if err != nil {
break
}

13
dtmsvr/trans_type_saga.go

@ -1,6 +1,7 @@
package dtmsvr
import (
"context"
"errors"
"fmt"
"time"
@ -52,7 +53,7 @@ type branchResult struct {
err error
}
func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error {
func (t *transSagaProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error {
// when saga tasks is fetched, it always need to process
logger.Debugf("status: %s timeout: %t", t.Status, t.isTimeout())
if t.Status == dtmcli.StatusSubmitted && t.isTimeout() {
@ -121,7 +122,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error {
return true
}
resultChan := make(chan branchResult, n)
asyncExecBranch := func(i int) {
asyncExecBranch := func(ctx context.Context, i int) {
var err error
defer func() {
if x := recover(); x != nil {
@ -132,7 +133,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error {
logger.Errorf("exec branch %s %s %s error: %v", branches[i].BranchID, branches[i].Op, branches[i].URL, err)
}
}()
err = t.execBranch(&branches[i], i)
err = t.execBranch(ctx, &branches[i], i)
}
pickToRunActions := func() []int {
toRun := []int{}
@ -162,7 +163,8 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error {
if branchResults[b].op == dtmimp.OpAction {
rsAStarted++
}
go asyncExecBranch(b)
copyCtx := NewAsyncContext(ctx)
go asyncExecBranch(copyCtx, b)
}
}
waitDoneOnce := func() {
@ -178,7 +180,8 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error {
t.RetryCount++
logger.Infof("Retrying branch %s %s %s, t.RetryLimit: %d, t.RetryCount: %d",
branches[r.index].BranchID, branches[r.index].Op, branches[r.index].URL, t.RetryLimit, t.RetryCount)
go asyncExecBranch(r.index)
copyCtx := NewAsyncContext(ctx)
go asyncExecBranch(copyCtx, r.index)
break
}
// if t.RetryCount = t.RetryLimit, trans will be aborted

5
dtmsvr/trans_type_tcc.go

@ -1,6 +1,7 @@
package dtmsvr
import (
"context"
"fmt"
"github.com/dtm-labs/dtm/client/dtmcli"
@ -20,7 +21,7 @@ func (t *transTccProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error {
func (t *transTccProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error {
if !t.needProcess() {
return nil
}
@ -31,7 +32,7 @@ func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error {
for current := len(branches) - 1; current >= 0; current-- {
if branches[current].Op == op && branches[current].Status == dtmcli.StatusPrepared {
logger.Debugf("branch info: current: %d ID: %d", current, branches[current].ID)
err := t.execBranch(&branches[current], current)
err := t.execBranch(ctx, &branches[current], current)
if err != nil {
return err
}

6
dtmsvr/trans_type_workflow.go

@ -1,6 +1,8 @@
package dtmsvr
import (
"context"
"github.com/dtm-labs/dtm/client/dtmcli"
"github.com/dtm-labs/dtm/client/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/client/dtmgrpc/dtmgimp"
@ -24,7 +26,7 @@ type cWorkflowCustom struct {
Data []byte `json:"data"`
}
func (t *transWorkflowProcessor) ProcessOnce(branches []TransBranch) error {
func (t *transWorkflowProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error {
if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed {
return nil
}
@ -36,5 +38,5 @@ func (t *transWorkflowProcessor) ProcessOnce(branches []TransBranch) error {
wd := wfpb.WorkflowData{Data: cmc.Data}
data = dtmgimp.MustProtoMarshal(&wd)
}
return t.getURLResult(t.QueryPrepared, "00", cmc.Name, data)
return t.getURLResult(ctx, t.QueryPrepared, "00", cmc.Name, data)
}

5
dtmsvr/trans_type_xa.go

@ -1,6 +1,7 @@
package dtmsvr
import (
"context"
"fmt"
"github.com/dtm-labs/dtm/client/dtmcli"
@ -19,7 +20,7 @@ func (t *transXaProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error {
func (t *transXaProcessor) ProcessOnce(ctx context.Context, branches []TransBranch) error {
if !t.needProcess() {
return nil
}
@ -29,7 +30,7 @@ func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error {
currentType := dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmimp.OpCommit, dtmimp.OpRollback).(string)
for i, branch := range branches {
if branch.Op == currentType && branch.Status != dtmcli.StatusSucceed {
err := t.execBranch(&branch, i)
err := t.execBranch(ctx, &branch, i)
if err != nil {
return err
}

22
dtmsvr/utils.go

@ -7,6 +7,7 @@
package dtmsvr
import (
"context"
"fmt"
"time"
@ -50,3 +51,24 @@ func GetTransGlobal(gid string) *TransGlobal {
//nolint:staticcheck
return &TransGlobal{TransGlobalStore: *trans}
}
type asyncCtx struct {
context.Context
}
func (a *asyncCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (a *asyncCtx) Done() <-chan struct{} {
return nil
}
// NewAsyncContext create a new async context
// the context will not be canceled when the parent context is canceled
func NewAsyncContext(ctx context.Context) context.Context {
if ctx == nil {
return nil
}
return &asyncCtx{Context: ctx}
}

87
dtmsvr/utils_test.go

@ -7,9 +7,12 @@
package dtmsvr
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)
func TestUtils(t *testing.T) {
@ -29,3 +32,87 @@ func TestSetNextCron(t *testing.T) {
tg.TimeoutToFail = 3
assert.Equal(t, int64(3), tg.getNextCronInterval(cronReset))
}
func TestNewAsyncContext(t *testing.T) {
var key testContextType = "key"
var value testContextType = "value"
ctxWithValue := context.WithValue(context.Background(), key, value)
newCtx := NewAsyncContext(ctxWithValue)
assert.Equal(t, ctxWithValue.Value(key), newCtx.Value(key))
var ctx context.Context
newCtx = NewAsyncContext(ctx)
assert.Nil(t, newCtx)
}
func TestAsyncContext(t *testing.T) {
ctx := context.Background()
cancelCtx2, cancel := context.WithCancel(ctx)
async := NewAsyncContext(cancelCtx2)
cancelCtx3, cancel2 := context.WithCancel(async)
defer cancel2()
cancel()
select {
case <-cancelCtx2.Done():
default:
assert.Failf(t, "context should be canceled", "context should be canceled")
}
select {
case <-cancelCtx3.Done():
assert.Failf(t, "context should not be canceled", "context should not be canceled")
default:
}
}
type testContextType string
func TestAsyncContextRecursive(t *testing.T) {
var key testContextType = "key"
var key2 testContextType = "key2"
var key3 testContextType = "key3"
var value testContextType = "value"
var value2 testContextType = "value2"
var value3 testContextType = "value3"
var nestedKey testContextType = "nested_key"
var nestedValue testContextType = "nested_value"
ctxWithValue := context.WithValue(context.Background(), key, value)
nestedCtx := context.WithValue(ctxWithValue, nestedKey, nestedValue)
cancelCtxx, cancel := context.WithCancel(nestedCtx)
defer cancel()
timerCtxx, cancel2 := context.WithTimeout(cancelCtxx, time.Duration(10)*time.Second)
defer cancel2()
timer2 := context.WithValue(timerCtxx, key2, value2)
timer3 := context.WithValue(timer2, key3, value3)
newCtx := NewAsyncContext(timer3)
assert.Equal(t, timer3.Value(nestedKey), newCtx.Value(nestedKey))
assert.Equal(t, timer3.Value(key), newCtx.Value(key))
assert.Equal(t, timer3.Value(key2), newCtx.Value(key2))
assert.Equal(t, timer3.Value(key3), newCtx.Value(key3))
}
func TestCopyContextWithMetadata(t *testing.T) {
md := metadata.New(map[string]string{"key": "value"})
ctx := metadata.NewIncomingContext(context.Background(), md)
ctx = metadata.NewOutgoingContext(ctx, md)
newCtx := NewAsyncContext(ctx)
copiedMD, ok := metadata.FromIncomingContext(newCtx)
assert.True(t, ok)
assert.Equal(t, 1, len(copiedMD["key"]))
assert.Equal(t, "value", copiedMD["key"][0])
copiedMD, ok = metadata.FromOutgoingContext(newCtx)
assert.True(t, ok)
assert.Equal(t, 1, len(copiedMD["key"]))
assert.Equal(t, "value", copiedMD["key"][0])
}
func BenchmarkCopyContext(b *testing.B) {
var key testContextType = "key"
var value testContextType = "value"
ctx := context.WithValue(context.Background(), key, value)
b.ResetTimer()
for i := 0; i < b.N; i++ {
NewAsyncContext(ctx)
}
}

3
test/saga_grpc_test.go

@ -7,6 +7,7 @@
package test
import (
"context"
"testing"
"github.com/dtm-labs/dtm/client/dtmcli"
@ -94,7 +95,7 @@ func TestSagaGrpcEmptyUrl(t *testing.T) {
// nolint: unparam
func genSagaGrpc(gid string, outFailed bool, inFailed bool) *dtmgrpc.SagaGrpc {
saga := dtmgrpc.NewSagaGrpc(dtmutil.DefaultGrpcServer, gid)
saga := dtmgrpc.NewSagaGrpcWithContext(context.Background(), dtmutil.DefaultGrpcServer, gid)
req := busi.GenReqGrpc(30, outFailed, inFailed)
saga.Add(busi.BusiGrpc+"/busi.Busi/TransOut", busi.BusiGrpc+"/busi.Busi/TransOutRevert", req)
saga.Add(busi.BusiGrpc+"/busi.Busi/TransIn", busi.BusiGrpc+"/busi.Busi/TransInRevert", req)

Loading…
Cancel
Save