diff --git a/dtmcli/dtmimp/trans_base.go b/dtmcli/dtmimp/trans_base.go index 3024ead..fba19cc 100644 --- a/dtmcli/dtmimp/trans_base.go +++ b/dtmcli/dtmimp/trans_base.go @@ -48,6 +48,7 @@ type TransOptions struct { RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc PassthroughHeaders []string `json:"passthrough_headers,omitempty" gorm:"-"` BranchHeaders map[string]string `json:"branch_headers,omitempty" gorm:"-"` + Concurrent bool `json:"concurrent" gorm:"-"` // for trans type: saga msg } // TransBase base for all trans diff --git a/dtmcli/saga.go b/dtmcli/saga.go index 5fcd331..a705fbe 100644 --- a/dtmcli/saga.go +++ b/dtmcli/saga.go @@ -13,8 +13,7 @@ import ( // Saga struct of saga type Saga struct { dtmimp.TransBase - orders map[int][]int - concurrent bool + orders map[int][]int } // NewSaga create a saga @@ -35,9 +34,9 @@ func (s *Saga) AddBranchOrder(branch int, preBranches []int) *Saga { return s } -// EnableConcurrent enable the concurrent exec of sub trans -func (s *Saga) EnableConcurrent() *Saga { - s.concurrent = true +// SetConcurrent enable the concurrent exec of sub trans +func (s *Saga) SetConcurrent() *Saga { + s.Concurrent = true return s } @@ -49,7 +48,7 @@ func (s *Saga) Submit() error { // BuildCustomOptions add custom options to the request context func (s *Saga) BuildCustomOptions() { - if s.concurrent { - s.CustomData = dtmimp.MustMarshalString(map[string]interface{}{"orders": s.orders, "concurrent": s.concurrent}) + if s.Concurrent { + s.CustomData = dtmimp.MustMarshalString(map[string]interface{}{"orders": s.orders, "concurrent": s.Concurrent}) } } diff --git a/dtmgrpc/saga.go b/dtmgrpc/saga.go index 4d206fe..a0cc9d4 100644 --- a/dtmgrpc/saga.go +++ b/dtmgrpc/saga.go @@ -37,7 +37,7 @@ func (s *SagaGrpc) AddBranchOrder(branch int, preBranches []int) *SagaGrpc { // EnableConcurrent enable the concurrent exec of sub trans func (s *SagaGrpc) EnableConcurrent() *SagaGrpc { - s.Saga.EnableConcurrent() + s.Saga.SetConcurrent() return s } diff --git a/dtmsvr/trans_type_msg.go b/dtmsvr/trans_type_msg.go index 6cb4679..8343753 100644 --- a/dtmsvr/trans_type_msg.go +++ b/dtmsvr/trans_type_msg.go @@ -65,7 +65,6 @@ func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { if !t.needProcess() || t.Status == dtmcli.StatusPrepared { return nil } - cmc := cMsgCustom{Delay: 0} if t.CustomData != "" { dtmimp.MustUnmarshalString(t.CustomData, &cmc) @@ -75,22 +74,66 @@ func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { t.touchCronTime(cronKeep, cmc.Delay) return nil } - - current := 0 // 当前正在处理的步骤 - for ; current < len(branches); current++ { + execBranch := func(current int) (bool, error) { branch := &branches[current] if branch.Op != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared { - continue + return true, nil } err := t.execBranch(branch, current) if err != nil { - return err + if !errors.Is(err, dtmcli.ErrOngoing) { + logger.Errorf("exec branch error: %v", err) + } + return false, err } if branch.Status != dtmcli.StatusSucceed { - break + return false, nil + } + return true, nil + } + type branchResult struct { + success bool + err error + } + waitChan := make(chan branchResult, len(branches)) + consumeWork := func(i int) error { + success, err := execBranch(i) + waitChan <- branchResult{ + success: success, + err: err, + } + return err + } + produceWork := func() { + for i := 0; i < len(branches); i++ { + if t.Concurrent { + go func(i int) { + _ = consumeWork(i) + }(i) + continue + } + err := consumeWork(i) + if err != nil { + return + } + } + } + go produceWork() + successCnt := 0 + var err error + for i := 0; i < len(branches); i++ { + result := <-waitChan + if result.err != nil { + err = result.err + if !t.Concurrent { + return err + } + } + if result.success { + successCnt++ } } - if current == len(branches) { // msg 事务完成 + if successCnt == len(branches) { // msg 事务完成 t.changeStatus(dtmcli.StatusSucceed) return nil } diff --git a/test/msg_options_test.go b/test/msg_options_test.go index 3d1ded8..7597ec5 100644 --- a/test/msg_options_test.go +++ b/test/msg_options_test.go @@ -51,3 +51,13 @@ func TestMsgOptionsTimeoutFailed(t *testing.T) { cronTransOnceForwardNow(t, gid, 180) assert.Equal(t, StatusFailed, getTransStatus(msg.Gid)) } + +func TestMsgConcurrent(t *testing.T) { + msg := genMsg(dtmimp.GetFuncName()) + msg.Concurrent = true + msg.Submit() + assert.Equal(t, StatusSubmitted, getTransStatus(msg.Gid)) + waitTransProcessed(msg.Gid) + assert.Equal(t, []string{StatusSucceed, StatusSucceed}, getBranchesStatus(msg.Gid)) + assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid)) +} diff --git a/test/saga_concurrent_test.go b/test/saga_concurrent_test.go index f130a8b..be7a57a 100644 --- a/test/saga_concurrent_test.go +++ b/test/saga_concurrent_test.go @@ -16,7 +16,7 @@ import ( ) func genSagaCon(gid string, outFailed bool, inFailed bool) *dtmcli.Saga { - return genSaga(gid, outFailed, inFailed).EnableConcurrent() + return genSaga(gid, outFailed, inFailed).SetConcurrent() } func TestSagaConNormal(t *testing.T) {