From a3d8bf4be4d052b87e138486f2d71c902f9cc2ff Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Fri, 29 Oct 2021 10:20:49 +0800 Subject: [PATCH] add concurrent saga --- common/types.go | 32 +++++- common/types_test.go | 23 +++++ common/utils.go | 5 +- dtmcli/concurrent_saga.go | 29 ++++++ dtmcli/types.go | 28 +++-- dtmcli/utils.go | 18 ++-- dtmcli/utils_test.go | 7 +- dtmgrpc/barrier.go | 17 ---- dtmgrpc/type.go | 9 +- dtmsvr/api.go | 22 ++-- dtmsvr/api_grpc.go | 4 +- dtmsvr/api_http.go | 17 ++-- dtmsvr/cron.go | 12 ++- dtmsvr/dtmsvr.go | 2 +- dtmsvr/trans.go | 143 +++++++++++++++++++------- dtmsvr/trans_concurrent_saga.go | 175 ++++++++++++++++++++++++++++++++ dtmsvr/trans_msg.go | 22 ++-- dtmsvr/trans_saga.go | 29 ++++-- dtmsvr/trans_tcc.go | 12 ++- dtmsvr/trans_xa.go | 12 ++- dtmsvr/utils.go | 6 +- dtmsvr/utils_test.go | 12 +++ examples/base_grpc.go | 6 +- examples/base_http.go | 9 ++ examples/grpc_saga_barrier.go | 2 +- examples/http_msg.go | 2 +- examples/http_saga.go | 2 +- test/base_test.go | 53 ++++++++++ test/dtmsvr_test.go | 36 +------ test/grpc_msg_test.go | 12 +-- test/grpc_saga_test.go | 12 +-- test/grpc_tcc_test.go | 4 +- test/main_test.go | 3 +- test/msg_test.go | 22 ++-- test/saga_concurrent_test.go | 72 +++++++++++++ test/saga_test.go | 12 +-- test/tcc_test.go | 4 +- test/types.go | 9 ++ test/wait_saga_test.go | 14 +-- 39 files changed, 691 insertions(+), 219 deletions(-) create mode 100644 dtmcli/concurrent_saga.go create mode 100644 dtmsvr/trans_concurrent_saga.go create mode 100644 test/base_test.go create mode 100644 test/saga_concurrent_test.go diff --git a/common/types.go b/common/types.go index 5ada90b..fc4c429 100644 --- a/common/types.go +++ b/common/types.go @@ -22,7 +22,7 @@ import ( // ModelBase model base for gorm to provide base fields type ModelBase struct { - ID uint + ID uint64 CreateTime *time.Time `gorm:"autoCreateTime"` UpdateTime *time.Time `gorm:"autoUpdateTime"` } @@ -123,7 +123,9 @@ func DbGet(conf map[string]string) *DB { } type dtmConfigType struct { - TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务 + TransCronInterval int64 `yaml:"TransCronInterval"` + TimeoutToFail int64 `yaml:"TimeoutToFail"` + RetryInterval int64 `yaml:"RetryInterval"` DB map[string]string `yaml:"DB"` DisableLocalhost int64 `yaml:"DisableLocalhost"` UpdateBranchSync int64 `yaml:"UpdateBranchSync"` @@ -140,7 +142,9 @@ func init() { if len(os.Args) == 1 { return } - DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "10") + DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "3") + DtmConfig.TimeoutToFail = getIntEnv("TIMEOUT_TO_FAIL", "10") + DtmConfig.RetryInterval = getIntEnv("RETRY_INTERVAL", "10") DtmConfig.DB = map[string]string{ "driver": dtmcli.OrString(os.Getenv("DB_DRIVER"), "mysql"), "host": os.Getenv("DB_HOST"), @@ -166,6 +170,24 @@ func init() { err := yaml.Unmarshal(cont, &DtmConfig) dtmcli.FatalIfError(err) } - dtmcli.LogIfFatalf(DtmConfig.DB["driver"] == "" || DtmConfig.DB["user"] == "", - "dtm配置错误. 请访问 http://dtm.pub 查看部署运维环节. check you env, and conf.yml/conf.sample.yml in current and parent path: %s. config is: \n%v", MustGetwd(), DtmConfig) + errStr := checkConfig() + dtmcli.LogIfFatalf(errStr != "", + `config error: '%s'. +check you env, and conf.yml/conf.sample.yml in current and parent path: %s. +please visit http://d.dtm.pub to see the config document. +loaded config is: +%v`, MustGetwd(), DtmConfig) +} + +func checkConfig() string { + if DtmConfig.DB["driver"] == "" { + return "db driver empty" + } else if DtmConfig.DB["user"] == "" || DtmConfig.DB["host"] == "" { + return "db config not valid" + } else if DtmConfig.RetryInterval < 10 { + return "RetryInterval should not be less than 10" + } else if DtmConfig.TimeoutToFail < DtmConfig.RetryInterval { + return "TimeoutToFail should not be less than RetryInterval" + } + return "" } diff --git a/common/types_test.go b/common/types_test.go index 20ccc38..4dd9c0b 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -30,3 +30,26 @@ func TestDbAlone(t *testing.T) { _, err = dtmcli.DBExec(db, "select 1") assert.NotEqual(t, nil, err) } + +func TestConfig(t *testing.T) { + testConfigStringField(DtmConfig.DB, "driver", "", t) + testConfigStringField(DtmConfig.DB, "user", "", t) + testConfigIntField(&DtmConfig.RetryInterval, 9, t) + testConfigIntField(&DtmConfig.TimeoutToFail, 9, t) +} + +func testConfigStringField(m map[string]string, key string, val string, t *testing.T) { + old := m[key] + m[key] = val + str := checkConfig() + assert.NotEqual(t, "", str) + m[key] = old +} + +func testConfigIntField(fd *int64, val int64, t *testing.T) { + old := *fd + *fd = val + str := checkConfig() + assert.NotEqual(t, "", str) + *fd = old +} diff --git a/common/utils.go b/common/utils.go index 14a89c0..5c9b174 100644 --- a/common/utils.go +++ b/common/utils.go @@ -43,7 +43,10 @@ func GetGinApp() *gin.Engine { // WrapHandler name is clear func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { return func(c *gin.Context) { - r, err := fn(c) + r, err := func() (r interface{}, rerr error) { + defer dtmcli.P2E(&rerr) + return fn(c) + }() var b = []byte{} if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理 b = resp.Body() diff --git a/dtmcli/concurrent_saga.go b/dtmcli/concurrent_saga.go new file mode 100644 index 0000000..202fd24 --- /dev/null +++ b/dtmcli/concurrent_saga.go @@ -0,0 +1,29 @@ +package dtmcli + +import "fmt" + +// ConcurrentSaga struct of concurrent saga +type ConcurrentSaga struct { + Saga + orders map[int][]int +} + +// NewConcurrentSaga create a concurrent saga +func NewConcurrentSaga(server string, gid string) *ConcurrentSaga { + return &ConcurrentSaga{Saga: Saga{TransBase: *NewTransBase(gid, "csaga", server, "")}, orders: map[int][]int{}} +} + +// AddStepOrder specify that step should be after preSteps. Step is larger than all the element in preSteps +func (s *ConcurrentSaga) AddStepOrder(step int, preSteps []int) *ConcurrentSaga { + PanicIf(step > len(s.Steps), fmt.Errorf("step value: %d is invalid. which cannot be larger than total steps: %d", step, len(s.Steps))) + s.orders[step] = preSteps + return s +} + +// Submit submit the saga trans +func (s *ConcurrentSaga) Submit() error { + if len(s.orders) > 0 { + s.CustomData = MustMarshalString(M{"orders": s.orders}) + } + return s.callDtm(s, "submit") +} diff --git a/dtmcli/types.go b/dtmcli/types.go index e21994b..a8b30b9 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -51,14 +51,26 @@ type TransResult struct { Message string } +// TransOptions transaction options +type TransOptions struct { + WaitResult bool `json:"wait_result,omitempty" gorm:"-"` + TimeoutToFail int64 `json:"timeout_to_fail,omitempty" gorm:"-"` // for trans type: xa, tcc + RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc +} + // TransBase 事务的基础类 type TransBase struct { - Gid string `json:"gid"` - TransType string `json:"trans_type"` + Gid string `json:"gid"` + TransType string `json:"trans_type"` + Dtm string `json:"-"` + CustomData string `json:"custom_data,omitempty"` IDGenerator - Dtm string - // WaitResult 是否等待全局事务的最终结果 - WaitResult bool + TransOptions +} + +// SetOptions set options +func (tb *TransBase) SetOptions(options *TransOptions) { + tb.TransOptions = *options } // NewTransBase 1 @@ -95,10 +107,10 @@ func (tb *TransBase) callDtm(body interface{}, operation string) error { } // ErrFailure 表示返回失败,要求回滚 -var ErrFailure = errors.New("transaction FAILURE") +var ErrFailure = errors.New("FAILURE") -// ErrPending 表示暂时失败,要求重试 -var ErrPending = errors.New("transaction PENDING") +// ErrOngoing 表示暂时失败,要求重试 +var ErrOngoing = errors.New("ONGOING") // MapSuccess 表示返回成功,可以进行下一步 var MapSuccess = M{"dtm_result": ResultSuccess} diff --git a/dtmcli/utils.go b/dtmcli/utils.go index 8fec13e..e9d511a 100644 --- a/dtmcli/utils.go +++ b/dtmcli/utils.go @@ -17,14 +17,18 @@ import ( "github.com/go-resty/resty/v2" ) +// AsError wrap a panic value as an error +func AsError(x interface{}) error { + if e, ok := x.(error); ok { + return e + } + return fmt.Errorf("%v", x) +} + // P2E panic to error func P2E(perr *error) { if x := recover(); x != nil { - if e, ok := x.(error); ok { - *perr = e - } else { - panic(x) - } + *perr = AsError(x) } } @@ -262,8 +266,8 @@ func CheckResult(res interface{}, err error) error { str := MustMarshalString(res) if strings.Contains(str, ResultFailure) { return ErrFailure - } else if strings.Contains(str, "PENDING") { - return ErrPending + } else if strings.Contains(str, ResultOngoing) { + return ErrOngoing } } return err diff --git a/dtmcli/utils_test.go b/dtmcli/utils_test.go index a59d628..c0954c3 100644 --- a/dtmcli/utils_test.go +++ b/dtmcli/utils_test.go @@ -25,13 +25,10 @@ func TestEP(t *testing.T) { }) assert.Equal(t, "err2", err.Error()) err = func() (rerr error) { - defer func() { - x := recover() - assert.Equal(t, 1, x) - }() defer P2E(&rerr) - panic(1) + panic("raw_string") }() + assert.Equal(t, "raw_string", err.Error()) } func TestTernary(t *testing.T) { diff --git a/dtmgrpc/barrier.go b/dtmgrpc/barrier.go index 1389da6..f24fc46 100644 --- a/dtmgrpc/barrier.go +++ b/dtmgrpc/barrier.go @@ -2,8 +2,6 @@ package dtmgrpc import ( "github.com/yedf/dtm/dtmcli" - "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" ) // BranchBarrier 子事务屏障 @@ -11,21 +9,6 @@ type BranchBarrier struct { *dtmcli.BranchBarrier } -// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 -// db: 本地数据库 -// transInfo: 事务信息 -// bisiCall: 业务函数,仅在必要时被调用 -// 返回值: -// 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚 -// 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行 -func (bb *BranchBarrier) Call(tx dtmcli.Tx, busiCall dtmcli.BusiFunc) (rerr error) { - err := bb.BranchBarrier.Call(tx, busiCall) - if err == dtmcli.ErrFailure { - return status.New(codes.Aborted, "user rollback").Err() - } - return err -} - // BarrierFromGrpc 从BusiRequest生成一个Barrier func BarrierFromGrpc(in *BusiRequest) (*BranchBarrier, error) { b, err := dtmcli.BarrierFrom(in.Info.TransType, in.Info.Gid, in.Info.BranchID, in.Info.BranchType) diff --git a/dtmgrpc/type.go b/dtmgrpc/type.go index b0539fe..fdc21db 100644 --- a/dtmgrpc/type.go +++ b/dtmgrpc/type.go @@ -9,7 +9,7 @@ import ( "github.com/yedf/dtm/dtmcli" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" ) @@ -89,9 +89,10 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c func Result2Error(res interface{}, err error) error { e := dtmcli.CheckResult(res, err) if e == dtmcli.ErrFailure { - return status.New(codes.Aborted, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err() - } else if e == dtmcli.ErrPending { - return status.New(codes.Unavailable, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err() + dtmcli.LogRedf("failure: res: %v, err: %v", res, e) + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() + } else if e == dtmcli.ErrOngoing { + return status.New(codes.Aborted, dtmcli.ResultOngoing).Err() } return e } diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 3c87ed1..beb5d46 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm/clause" ) -func svcSubmit(t *TransGlobal, waitResult bool) (interface{}, error) { +func svcSubmit(t *TransGlobal) (interface{}, error) { db := dbGet() t.Status = dtmcli.StatusSubmitted err := t.saveNew(db) @@ -15,13 +15,13 @@ func svcSubmit(t *TransGlobal, waitResult bool) (interface{}, error) { if err == errUniqueConflict { dbt := TransFromDb(db, t.Gid) if dbt.Status == dtmcli.StatusPrepared { - updates := t.setNextCron(config.TransCronInterval) + updates := t.setNextCron(cronReset) db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) } else if dbt.Status != dtmcli.StatusSubmitted { - return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot sumbmit", dbt.Status)}, nil + return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil } } - return t.Process(db, waitResult), nil + return t.Process(db), nil } func svcPrepare(t *TransGlobal) (interface{}, error) { @@ -30,19 +30,19 @@ func svcPrepare(t *TransGlobal) (interface{}, error) { if err == errUniqueConflict { dbt := TransFromDb(dbGet(), t.Gid) if dbt.Status != dtmcli.StatusPrepared { - return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot prepare", dbt.Status)}, nil + return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil } } return dtmcli.MapSuccess, nil } -func svcAbort(t *TransGlobal, waitResult bool) (interface{}, error) { +func svcAbort(t *TransGlobal) (interface{}, error) { db := dbGet() dbt := TransFromDb(db, t.Gid) - if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != "aborting" { - return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: %s current status %s, cannot abort", dbt.TransType, dbt.Status)}, nil + if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting { + return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil } - return dbt.Process(db, waitResult), nil + return dbt.Process(db), nil } func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, error) { @@ -62,7 +62,7 @@ func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, err DoNothing: true, }).Create(branches) global := TransGlobal{Gid: branch.Gid} - global.touch(dbGet(), config.TransCronInterval) + global.touch(dbGet(), cronKeep) return dtmcli.MapSuccess, nil } @@ -80,6 +80,6 @@ func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) { DoNothing: true, }).Create(branches) global := TransGlobal{Gid: branch.Gid} - global.touch(db, config.TransCronInterval) + global.touch(db, cronKeep) return dtmcli.MapSuccess, nil } diff --git a/dtmsvr/api_grpc.go b/dtmsvr/api_grpc.go index 074faf6..4b44cdf 100644 --- a/dtmsvr/api_grpc.go +++ b/dtmsvr/api_grpc.go @@ -19,7 +19,7 @@ func (s *dtmServer) NewGid(ctx context.Context, in *emptypb.Empty) (*dtmgrpc.Dtm } func (s *dtmServer) Submit(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcSubmit(TransFromDtmRequest(in), in.WaitResult) + r, err := svcSubmit(TransFromDtmRequest(in)) return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err) } @@ -29,7 +29,7 @@ func (s *dtmServer) Prepare(ctx context.Context, in *pb.DtmRequest) (*emptypb.Em } func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcAbort(TransFromDtmRequest(in), in.WaitResult) + r, err := svcAbort(TransFromDtmRequest(in)) return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err) } diff --git a/dtmsvr/api_http.go b/dtmsvr/api_http.go index ec59e85..7d2b5af 100644 --- a/dtmsvr/api_http.go +++ b/dtmsvr/api_http.go @@ -2,6 +2,7 @@ package dtmsvr import ( "errors" + "math" "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" @@ -10,14 +11,14 @@ import ( ) func addRoute(engine *gin.Engine) { + engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid)) engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare)) engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit)) + engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort)) engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch)) engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch)) - engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort)) engine.GET("/api/dtmsvr/query", common.WrapHandler(query)) engine.GET("/api/dtmsvr/all", common.WrapHandler(all)) - engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid)) } func newGid(c *gin.Context) (interface{}, error) { @@ -29,11 +30,11 @@ func prepare(c *gin.Context) (interface{}, error) { } func submit(c *gin.Context) (interface{}, error) { - return svcSubmit(TransFromContext(c), c.Query("wait_result") == "1") + return svcSubmit(TransFromContext(c)) } func abort(c *gin.Context) (interface{}, error) { - return svcAbort(TransFromContext(c), c.Query("wait_result") == "1") + return svcAbort(TransFromContext(c)) } func registerXaBranch(c *gin.Context) (interface{}, error) { @@ -74,11 +75,11 @@ func query(c *gin.Context) (interface{}, error) { } func all(c *gin.Context) (interface{}, error) { - lastId := c.Query("last_id") - if lastId == "" { - lastId = "2000000000" + lastID := c.Query("last_id") + lid := math.MaxInt64 + if lastID != "" { + lid = dtmcli.MustAtoi(lastID) } - lid := dtmcli.MustAtoi(lastId) trans := []TransGlobal{} dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(100).Find(&trans) return M{"transactions": trans}, nil diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 8880511..7aeda41 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -10,7 +10,10 @@ import ( "github.com/yedf/dtm/dtmcli" ) -// CronForwardDuration will be set in test, cron will fetch trans which expire in CronForwardDuration +// NowForwardDuration will be set in test, trans may be timeout +var NowForwardDuration time.Duration = time.Duration(0) + +// CronForwardDuration will be set in test. cron will fetch trans which expire in CronForwardDuration var CronForwardDuration time.Duration = time.Duration(0) // CronTransOnce cron expired trans. use expireIn as expire time @@ -24,7 +27,8 @@ func CronTransOnce() (hasTrans bool) { if TransProcessedTestChan != nil { defer WaitTransProcessed(trans.Gid) } - trans.Process(dbGet(), true) + trans.WaitResult = true + trans.Process(dbGet()) return } @@ -44,7 +48,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal { db := dbGet() getTime := dtmcli.GetDBSpecial().TimestampAdd expire := int(expireIn / time.Second) - whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3)) + whereTime := fmt.Sprintf("next_cron_time < %s and update_time < %s", getTime(expire), getTime(expire-3)) // 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢 // 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出 dbr := db.Must().Model(&trans). @@ -53,7 +57,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal { return nil } dbr = db.Must().Where("owner=?", owner).Find(&trans) - updates := trans.setNextCron(trans.NextCronInterval * 2) // 下次被cron的间隔加倍 + updates := trans.setNextCron(cronKeep) db.Must().Model(&trans).Select(updates).Updates(&trans) return &trans } diff --git a/dtmsvr/dtmsvr.go b/dtmsvr/dtmsvr.go index 216bb3a..ee0b35d 100644 --- a/dtmsvr/dtmsvr.go +++ b/dtmsvr/dtmsvr.go @@ -70,7 +70,7 @@ func updateBranchAsync() { updates = append(updates, TransBranch{ ModelBase: common.ModelBase{ID: updateBranch.id}, Status: updateBranch.status, - FinishTime: updateBranch.finish_time, + FinishTime: updateBranch.finishTime, }) case <-time.After(checkInterval): } diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 1948beb..4a4b2d2 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -32,9 +32,12 @@ type TransGlobal struct { CommitTime *time.Time FinishTime *time.Time RollbackTime *time.Time + Options string + CustomData string `json:"custom_data"` NextCronInterval int64 NextCronTime *time.Time - processStarted time.Time // record the start time of process + dtmcli.TransOptions + processStarted time.Time // record the start time of process } // TableName TableName @@ -44,12 +47,12 @@ func (*TransGlobal) TableName() string { type transProcessor interface { GenBranches() []TransBranch - ProcessOnce(db *common.DB, branches []TransBranch) + ProcessOnce(db *common.DB, branches []TransBranch) error } -func (t *TransGlobal) touch(db *common.DB, interval int64) *gorm.DB { +func (t *TransGlobal) touch(db *common.DB, ctype cronType) *gorm.DB { writeTransLog(t.Gid, "touch trans", "", "", "") - updates := t.setNextCron(interval) + updates := t.setNextCron(ctype) return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Select(updates).Updates(t) } @@ -57,7 +60,7 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { writeTransLog(t.Gid, "change status", status, "", "") old := t.Status t.Status = status - updates := t.setNextCron(config.TransCronInterval) + updates := t.setNextCron(cronReset) updates = append(updates, "status") now := time.Now() if status == dtmcli.StatusSucceed { @@ -72,6 +75,28 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { return dbr } +func (t *TransGlobal) isTimeout() bool { + timeout := t.TimeoutToFail + if t.TimeoutToFail == 0 && t.TransType != "saga" && t.TransType != "csaga" { + timeout = config.TimeoutToFail + } + if timeout == 0 { + return false + } + return time.Since(*t.CreateTime)+NowForwardDuration >= time.Duration(timeout)*time.Second +} + +func (t *TransGlobal) getRetryInterval() int64 { + if t.RetryInterval > 0 { + return t.RetryInterval + } + return config.RetryInterval +} + +func (t *TransGlobal) needProcess() bool { + return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout() +} + // TransBranch branch transaction type TransBranch struct { common.ModelBase @@ -124,14 +149,14 @@ func (t *TransGlobal) getProcessor() transProcessor { } // Process process global transaction once -func (t *TransGlobal) Process(db *common.DB, waitResult bool) dtmcli.M { - r := t.process(db, waitResult) +func (t *TransGlobal) Process(db *common.DB) dtmcli.M { + r := t.process(db) transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess) return r } -func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M { - if !waitResult { +func (t *TransGlobal) process(db *common.DB) dtmcli.M { + if !t.WaitResult { go t.processInner(db) return dtmcli.MapSuccess } @@ -149,6 +174,9 @@ func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M { func (t *TransGlobal) processInner(db *common.DB) (rerr error) { defer handlePanic(&rerr) defer func() { + if rerr != nil { + dtmcli.LogRedf("processInner got error: %s", rerr.Error()) + } if TransProcessedTestChan != nil { dtmcli.Logf("processed: %s", t.Gid) TransProcessedTestChan <- t.Gid @@ -157,23 +185,40 @@ func (t *TransGlobal) processInner(db *common.DB) (rerr error) { }() dtmcli.Logf("processing: %s status: %s", t.Gid, t.Status) if t.Status == dtmcli.StatusPrepared && t.TransType != "msg" { - t.changeStatus(db, "aborting") + t.changeStatus(db, dtmcli.StatusAborting) } branches := []TransBranch{} db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches) t.processStarted = time.Now() - t.getProcessor().ProcessOnce(db, branches) + rerr = t.getProcessor().ProcessOnce(db, branches) return } -func (t *TransGlobal) setNextCron(expireIn int64) []string { - t.NextCronInterval = expireIn +type cronType int + +const ( + cronBackoff cronType = iota + cronReset + cronKeep +) + +func (t *TransGlobal) setNextCron(ctype cronType) []string { + if ctype == cronBackoff { + t.NextCronInterval = t.NextCronInterval * 2 + } else if ctype == cronKeep { + // do nothing + } else if t.RetryInterval != 0 { + t.NextCronInterval = t.RetryInterval + } else { + t.NextCronInterval = config.RetryInterval + } + next := time.Now().Add(time.Duration(t.NextCronInterval) * time.Second) t.NextCronTime = &next return []string{"next_cron_interval", "next_cron_time"} } -func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) string { +func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) (string, error) { if t.Protocol == "grpc" { dtmcli.PanicIf(strings.HasPrefix(url, "http"), fmt.Errorf("bad url for grpc: %s", url)) server, method := dtmgrpc.GetServerAndMethod(url) @@ -188,11 +233,17 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran BusiData: branchData, }, &emptypb.Empty{}) if err == nil { - return dtmcli.ResultSuccess - } else if status.Code(err) == codes.Aborted { - return dtmcli.ResultFailure + return dtmcli.ResultSuccess, nil + } + st, ok := status.FromError(err) + if ok && st.Code() == codes.Aborted { + if st.Message() == dtmcli.ResultOngoing { + return dtmcli.ResultOngoing, nil + } else if st.Message() == dtmcli.ResultFailure { + return dtmcli.ResultFailure, nil + } } - return err.Error() + return "", err } dtmcli.PanicIf(!strings.HasPrefix(url, "http"), fmt.Errorf("bad url for http: %s", url)) resp, err := dtmcli.RestyClient.R().SetBody(string(branchData)). @@ -204,36 +255,49 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran }). SetHeader("Content-type", "application/json"). Execute(dtmcli.If(branchData == nil, "GET", "POST").(string), url) - e2p(err) - return resp.String() + if err != nil { + return "", err + } + return resp.String(), nil } -func (t *TransGlobal) getBranchResult(branch *TransBranch) string { - return t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data)) +func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { + body, err := t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data)) + if err != nil { + return "", err + } + if strings.Contains(body, dtmcli.ResultSuccess) { + return dtmcli.StatusSucceed, nil + } else if strings.HasSuffix(t.TransType, "saga") && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) { + return dtmcli.StatusFailed, nil + } else if strings.Contains(body, dtmcli.ResultOngoing) { + return "", dtmcli.ErrOngoing + } + return "", fmt.Errorf("http result should contains SUCCESS|FAILURE|ONGOING. grpc error should return nil|Aborted with message(FAILURE|ONGOING). \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body) } -func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) { - body := t.getBranchResult(branch) - status := "" - if strings.Contains(body, dtmcli.ResultSuccess) { - status = dtmcli.StatusSucceed - } else if t.TransType == "saga" && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) { - status = dtmcli.StatusFailed - } else { - panic(fmt.Errorf("http result should contains SUCCESS|FAILURE. grpc error should return nil|Aborted. \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body)) +func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) error { + status, err := t.getBranchResult(branch) + if status != "" { + branch.changeStatus(db, status) } branchMetrics(t, branch, status == dtmcli.StatusSucceed) - // 如果一次处理超过1500ms,那么touch一下TransGlobal,避免被Cron取出 - if time.Since(t.processStarted)+CronForwardDuration >= 1500*time.Millisecond || t.NextCronInterval > config.TransCronInterval { - t.touch(db, config.TransCronInterval) + // if time pass 1500ms and NextCronInterval is not default, then reset NextCronInterval + if err == nil && time.Since(t.processStarted)+NowForwardDuration >= 1500*time.Millisecond || + t.NextCronInterval > config.RetryInterval && t.NextCronInterval > t.RetryInterval { + t.touch(db, cronReset) + } else if err == dtmcli.ErrOngoing { + t.touch(db, cronKeep) + } else { + t.touch(db, cronBackoff) } - branch.changeStatus(db, status) + return err } func (t *TransGlobal) saveNew(db *common.DB) error { return db.Transaction(func(db1 *gorm.DB) error { db := &common.DB{DB: db1} - t.setNextCron(config.TransCronInterval) + t.setNextCron(cronReset) writeTransLog(t.Gid, "create trans", t.Status, "", t.Data) dbr := db.Must().Clauses(clause.OnConflict{ DoNothing: true, @@ -265,6 +329,10 @@ func TransFromContext(c *gin.Context) *TransGlobal { } m := TransGlobal{} dtmcli.MustRemarshal(data, &m) + m.Options = dtmcli.MustMarshalString(m.TransOptions) + if m.Options == "{}" { + m.Options = "" + } m.Protocol = "http" return &m } @@ -288,6 +356,9 @@ func TransFromDb(db *common.DB, gid string) *TransGlobal { return nil } e2p(dbr.Error) + if m.Options != "" { + dtmcli.MustUnmarshalString(m.Options, &m.TransOptions) + } return &m } diff --git a/dtmsvr/trans_concurrent_saga.go b/dtmsvr/trans_concurrent_saga.go new file mode 100644 index 0000000..3d64000 --- /dev/null +++ b/dtmsvr/trans_concurrent_saga.go @@ -0,0 +1,175 @@ +package dtmsvr + +import ( + "time" + + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" + "gorm.io/gorm/clause" +) + +type transCSagaProcessor struct { + *TransGlobal +} + +func init() { + registorProcessorCreator("csaga", func(trans *TransGlobal) transProcessor { return &transCSagaProcessor{TransGlobal: trans} }) +} + +func (t *transCSagaProcessor) GenBranches() []TransBranch { + return genSagaBranches(t.TransGlobal) +} + +type cSagaCustom struct { + Orders map[int][]int `json:"orders"` +} + +func isPreconditionsSucceed(branches []TransBranch, pres []int) bool { + for _, pre := range pres { + if branches[pre].Status != dtmcli.StatusSucceed { + return false + } + } + return true +} + +type branchResult struct { + index int + status string + started bool + branchType string +} + +func (t *transCSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed { + return nil + } + n := len(branches) + + orders := map[int][]int{} + if t.CustomData != "" { + csc := cSagaCustom{Orders: map[int][]int{}} + dtmcli.MustUnmarshalString(t.CustomData, &csc) + for k, v := range csc.Orders { // new branches is doubled, so change the order value + orders[2*k+1] = []int{} + for j := 0; j < len(v); j++ { + orders[2*k+1] = append(orders[2*k+1], csc.Orders[k][j]*2+1) + } + } + } + // resultStats + var rsActionToStart, rsActionDone, rsActionFailed, rsActionSucceed, rsCompensateToStart, rsCompensateDone, rsCompensateSucceed int + branchResults := make([]branchResult, n) // save the branch result + for i := 0; i < n; i++ { + b := branches[i] + if b.BranchType == dtmcli.BranchAction { + if b.Status == dtmcli.StatusPrepared || b.Status == dtmcli.StatusDoing { + rsActionToStart++ + } else if b.Status == dtmcli.StatusFailed { + rsActionFailed++ + } + } + branchResults[i] = branchResult{status: branches[i].Status, branchType: branches[i].BranchType} + } + stopChan := make(chan branchResult, n) + asyncExecBranch := func(i int) { + var err error + defer func() { + if x := recover(); x != nil { + err = dtmcli.AsError(x) + } + stopChan <- branchResult{index: i, status: branches[i].Status, branchType: branches[i].BranchType} + if err != nil { + dtmcli.LogRedf("exec branch error: %v", err) + } + }() + err = t.execBranch(db, &branches[i]) + } + needRollback := func(i int) bool { + br := &branchResults[i] + return !br.started && br.branchType == dtmcli.BranchCompensate && br.status != dtmcli.StatusSucceed && branchResults[i+1].branchType == dtmcli.BranchAction && branchResults[i+1].status != dtmcli.StatusPrepared + } + pickAndRun := func(branchType string) { + toRun := []int{} + for current := 0; current < n; current++ { + br := &branchResults[current] + if br.branchType == branchType && branchType == dtmcli.BranchAction { + if (br.status == dtmcli.StatusPrepared || br.status == dtmcli.StatusDoing) && + !br.started && isPreconditionsSucceed(branches, orders[current]) { + br.status = dtmcli.StatusDoing + toRun = append(toRun, current) + } + } else if br.branchType == branchType && branchType == dtmcli.BranchCompensate { + if needRollback(current) { + toRun = append(toRun, current) + } + } + } + if branchType == dtmcli.BranchAction && len(toRun) > 0 { + updates := make([]TransBranch, len(toRun)) + for i, b := range toRun { + updates[i].ID = branches[b].ID + branches[b].Status = dtmcli.StatusDoing + updates[i].Status = dtmcli.StatusDoing + } + dbGet().Must().Clauses(clause.OnConflict{ + OnConstraint: "trans_branch_pkey", + DoUpdates: clause.AssignmentColumns([]string{"status"}), + }).Create(updates) + } else if branchType == dtmcli.BranchCompensate { + rsCompensateToStart = len(toRun) + } + for _, b := range toRun { + branchResults[b].started = true + go asyncExecBranch(b) + } + } + processorTimeout := func() bool { + return time.Since(t.processStarted)+NowForwardDuration > time.Duration(t.getRetryInterval()-3)*time.Second + } + waitOnceForDone := func() { + select { + case r := <-stopChan: + br := &branchResults[r.index] + br.status = r.status + if r.branchType == dtmcli.BranchAction { + rsActionDone++ + if r.status == dtmcli.StatusFailed { + rsActionFailed++ + } else if r.status == dtmcli.StatusSucceed { + rsActionSucceed++ + } + } else { + rsCompensateDone++ + if r.status == dtmcli.StatusSucceed { + rsCompensateSucceed++ + } + } + dtmcli.Logf("branch done: %v", r) + case <-time.After(time.Duration(time.Second * 3)): + dtmcli.Logf("wait once for done") + } + } + + for t.Status == dtmcli.StatusSubmitted && !t.isTimeout() && rsActionFailed == 0 && rsActionDone != rsActionToStart && !processorTimeout() { + pickAndRun(dtmcli.BranchAction) + waitOnceForDone() + } + if t.Status == dtmcli.StatusSubmitted && rsActionFailed == 0 && rsActionToStart == rsActionSucceed { + t.changeStatus(db, dtmcli.StatusSucceed) + return nil + } + if t.Status == dtmcli.StatusSubmitted && (rsActionFailed > 0 || t.isTimeout()) { + t.changeStatus(db, dtmcli.StatusAborting) + } + if t.Status == dtmcli.StatusAborting { + pickAndRun(dtmcli.BranchCompensate) + for rsCompensateDone != rsCompensateToStart && !processorTimeout() { + waitOnceForDone() + } + } + if (t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting) && rsActionFailed > 0 && rsCompensateToStart == rsCompensateSucceed { + t.changeStatus(db, dtmcli.StatusFailed) + } + return nil +} diff --git a/dtmsvr/trans_msg.go b/dtmsvr/trans_msg.go index 0ca83ee..1c4e0d0 100644 --- a/dtmsvr/trans_msg.go +++ b/dtmsvr/trans_msg.go @@ -33,23 +33,26 @@ func (t *transMsgProcessor) GenBranches() []TransBranch { } func (t *TransGlobal) mayQueryPrepared(db *common.DB) { - if t.Status != dtmcli.StatusPrepared { + if !t.needProcess() || t.Status == dtmcli.StatusSubmitted { return } - body := t.getURLResult(t.QueryPrepared, "", "", nil) + body, err := t.getURLResult(t.QueryPrepared, "", "", nil) if strings.Contains(body, dtmcli.ResultSuccess) { t.changeStatus(db, dtmcli.StatusSubmitted) } else if strings.Contains(body, dtmcli.ResultFailure) { t.changeStatus(db, dtmcli.StatusFailed) + } else if strings.Contains(body, dtmcli.ResultOngoing) { + t.touch(db, cronReset) } else { - t.touch(db, t.NextCronInterval*2) + dtmcli.LogRedf("getting result failed for %s. error: %s", t.QueryPrepared, err.Error()) + t.touch(db, cronBackoff) } } -func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { +func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { t.mayQueryPrepared(db) - if t.Status != dtmcli.StatusSubmitted { - return + if !t.needProcess() || t.Status == dtmcli.StatusPrepared { + return nil } current := 0 // 当前正在处理的步骤 for ; current < len(branches); current++ { @@ -57,14 +60,17 @@ func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { if branch.BranchType != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared { continue } - t.execBranch(db, branch) + err := t.execBranch(db, branch) + if err != nil { + return err + } if branch.Status != dtmcli.StatusSucceed { break } } if current == len(branches) { // msg 事务完成 t.changeStatus(db, dtmcli.StatusSucceed) - return + return nil } panic("msg go pass all branch") } diff --git a/dtmsvr/trans_saga.go b/dtmsvr/trans_saga.go index a8c27a5..2fbcc51 100644 --- a/dtmsvr/trans_saga.go +++ b/dtmsvr/trans_saga.go @@ -15,7 +15,7 @@ func init() { registorProcessorCreator("saga", func(trans *TransGlobal) transProcessor { return &transSagaProcessor{TransGlobal: trans} }) } -func (t *transSagaProcessor) GenBranches() []TransBranch { +func genSagaBranches(t *TransGlobal) []TransBranch { branches := []TransBranch{} steps := []M{} dtmcli.MustUnmarshalString(t.Data, &steps) @@ -35,9 +35,13 @@ func (t *transSagaProcessor) GenBranches() []TransBranch { return branches } -func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { - if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed { - return +func (t *transSagaProcessor) GenBranches() []TransBranch { + return genSagaBranches(t.TransGlobal) +} + +func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + if !t.needProcess() { + return nil } current := 0 // 当前正在处理的步骤 for ; current < len(branches); current++ { @@ -47,7 +51,10 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) } // 找到了一个非succeed的action if branch.Status == dtmcli.StatusPrepared { - t.execBranch(db, branch) + err := t.execBranch(db, branch) + if err != nil { + return err + } } if branch.Status != dtmcli.StatusSucceed { break @@ -55,17 +62,21 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) } if current == len(branches) { // saga 事务完成 t.changeStatus(db, dtmcli.StatusSucceed) - return + return nil } - if t.Status != "aborting" && t.Status != dtmcli.StatusFailed { - t.changeStatus(db, "aborting") + if t.Status != dtmcli.StatusAborting && t.Status != dtmcli.StatusFailed { + t.changeStatus(db, dtmcli.StatusAborting) } for current = current - 1; current >= 0; current-- { branch := &branches[current] if branch.BranchType != dtmcli.BranchCompensate || branch.Status != dtmcli.StatusPrepared { continue } - t.execBranch(db, branch) + err := t.execBranch(db, branch) + if err != nil { + return err + } } t.changeStatus(db, dtmcli.StatusFailed) + return nil } diff --git a/dtmsvr/trans_tcc.go b/dtmsvr/trans_tcc.go index 43f564a..8e4afc4 100644 --- a/dtmsvr/trans_tcc.go +++ b/dtmsvr/trans_tcc.go @@ -17,15 +17,19 @@ func (t *transTccProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { - if t.Status == dtmcli.StatusSucceed || t.Status == dtmcli.StatusFailed { - return +func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + if !t.needProcess() { + return nil } branchType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchConfirm, dtmcli.BranchCancel).(string) for current := len(branches) - 1; current >= 0; current-- { if branches[current].BranchType == branchType && branches[current].Status == dtmcli.StatusPrepared { - t.execBranch(db, &branches[current]) + err := t.execBranch(db, &branches[current]) + if err != nil { + return err + } } } t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) + return nil } diff --git a/dtmsvr/trans_xa.go b/dtmsvr/trans_xa.go index e4a79a4..8c6b578 100644 --- a/dtmsvr/trans_xa.go +++ b/dtmsvr/trans_xa.go @@ -17,15 +17,19 @@ func (t *transXaProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { - if t.Status == dtmcli.StatusSucceed { - return +func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + if !t.needProcess() { + return nil } currentType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchCommit, dtmcli.BranchRollback).(string) for _, branch := range branches { if branch.BranchType == currentType && branch.Status != dtmcli.StatusSucceed { - t.execBranch(db, &branch) + err := t.execBranch(db, &branch) + if err != nil { + return err + } } } t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) + return nil } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index ee07452..abd5fb8 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -16,9 +16,9 @@ import ( type M = map[string]interface{} type branchStatus struct { - id uint - status string - finish_time *time.Time + id uint64 + status string + finishTime *time.Time } var p2e = dtmcli.P2E diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index d6f7732..c678e7c 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -31,3 +31,15 @@ func TestCheckLocalHost(t *testing.T) { }) assert.Nil(t, err) } + +func TestSetNextCron(t *testing.T) { + tg := TransGlobal{} + tg.RetryInterval = 15 + tg.setNextCron(cronReset) + assert.Equal(t, int64(15), tg.NextCronInterval) + tg.RetryInterval = 0 + tg.setNextCron(cronReset) + assert.Equal(t, config.RetryInterval, tg.NextCronInterval) + tg.setNextCron(cronBackoff) + assert.Equal(t, config.RetryInterval*2, tg.NextCronInterval) +} diff --git a/examples/base_grpc.go b/examples/base_grpc.go index 8b9f0d2..d449015 100644 --- a/examples/base_grpc.go +++ b/examples/base_grpc.go @@ -45,7 +45,7 @@ func handleGrpcBusiness(in *dtmgrpc.BusiRequest, result1 string, result2 string, if res == dtmcli.ResultSuccess { return nil } else if res == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user want to rollback").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } return status.New(codes.Internal, fmt.Sprintf("unknow result %s", res)).Err() } @@ -113,7 +113,7 @@ func (s *busiServer) TransInXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*d dtmcli.MustUnmarshal(in.BusiData, &req) return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { if req.TransInResult == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user return failure").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2) return err @@ -125,7 +125,7 @@ func (s *busiServer) TransOutXa(ctx context.Context, in *dtmgrpc.BusiRequest) (* dtmcli.MustUnmarshal(in.BusiData, &req) return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { if req.TransOutResult == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user return failure").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1) return err diff --git a/examples/base_http.go b/examples/base_http.go index 049119a..a643bc2 100644 --- a/examples/base_http.go +++ b/examples/base_http.go @@ -2,6 +2,7 @@ package examples import ( "database/sql" + "errors" "fmt" "time" @@ -145,4 +146,12 @@ func BaseAddRoute(app *gin.Engine) { }) })) + app.POST(BusiAPI+"/TestPanic", common.WrapHandler(func(c *gin.Context) (interface{}, error) { + if c.Query("panic_error") != "" { + panic(errors.New("panic_error")) + } else if c.Query("panic_string") != "" { + panic("panic_string") + } + return "SUCCESS", nil + })) } diff --git a/examples/grpc_saga_barrier.go b/examples/grpc_saga_barrier.go index f48bfce..a1ee504 100644 --- a/examples/grpc_saga_barrier.go +++ b/examples/grpc_saga_barrier.go @@ -24,7 +24,7 @@ func init() { func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { if result == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user rollback").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) return err diff --git a/examples/http_msg.go b/examples/http_msg.go index d304754..044f541 100644 --- a/examples/http_msg.go +++ b/examples/http_msg.go @@ -11,7 +11,7 @@ func init() { msg := dtmcli.NewMsg(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/TransOut", req). Add(Busi+"/TransIn", req) - err := msg.Prepare(Busi + "/TransQuery") + err := msg.Prepare(Busi + "/query") dtmcli.FatalIfError(err) dtmcli.Logf("busi trans submit") err = msg.Submit() diff --git a/examples/http_saga.go b/examples/http_saga.go index d014eee..40df77a 100644 --- a/examples/http_saga.go +++ b/examples/http_saga.go @@ -23,7 +23,7 @@ func init() { saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/TransOut", Busi+"/TransOutRevert", req). Add(Busi+"/TransIn", Busi+"/TransInRevert", req) - saga.WaitResult = true // 设置为等待结果模式,后面的submit调用,会等待服务器处理这个事务。如果Submit正常返回,那么整个全局事务已成功完成 + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() dtmcli.Logf("result gid is: %s", saga.Gid) dtmcli.FatalIfError(err) diff --git a/test/base_test.go b/test/base_test.go new file mode 100644 index 0000000..ba50e94 --- /dev/null +++ b/test/base_test.go @@ -0,0 +1,53 @@ +package test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" + "github.com/yedf/dtm/examples" +) + +func TestSqlDB(t *testing.T) { + asserts := assert.New(t) + db := common.DbGet(config.DB) + barrier := &dtmcli.BranchBarrier{ + TransType: "saga", + Gid: "gid2", + BranchID: "branch_id2", + BranchType: dtmcli.BranchAction, + } + db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") + tx, err := db.ToSQLDB().Begin() + asserts.Nil(err) + err = barrier.Call(tx, func(db dtmcli.DB) error { + dtmcli.Logf("rollback gid2") + return fmt.Errorf("gid2 error") + }) + asserts.Error(err, fmt.Errorf("gid2 error")) + dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{}) + asserts.Equal(dbr.RowsAffected, int64(1)) + dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) + asserts.Equal(dbr.RowsAffected, int64(0)) + barrier.BarrierID = 0 + tx2, err := db.ToSQLDB().Begin() + asserts.Nil(err) + err = barrier.Call(tx2, func(db dtmcli.DB) error { + dtmcli.Logf("submit gid2") + return nil + }) + asserts.Nil(err) + dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) + asserts.Equal(dbr.RowsAffected, int64(1)) +} + +func TestHttp(t *testing.T) { + resp, err := dtmcli.RestyClient.R().SetQueryParam("panic_string", "1").Post(examples.Busi + "/TestPanic") + assert.Nil(t, err) + assert.Contains(t, resp.String(), "panic_string") + resp, err = dtmcli.RestyClient.R().SetQueryParam("panic_error", "1").Post(examples.Busi + "/TestPanic") + assert.Nil(t, err) + assert.Contains(t, resp.String(), "panic_error") +} diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index f5d5524..e185a60 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -1,7 +1,6 @@ package test import ( - "fmt" "testing" "time" @@ -83,43 +82,10 @@ func transQuery(t *testing.T, gid string) { assert.Nil(t, err) } -func TestSqlDB(t *testing.T) { - asserts := assert.New(t) - db := common.DbGet(config.DB) - barrier := &dtmcli.BranchBarrier{ - TransType: "saga", - Gid: "gid2", - BranchID: "branch_id2", - BranchType: dtmcli.BranchAction, - } - db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") - tx, err := db.ToSQLDB().Begin() - asserts.Nil(err) - err = barrier.Call(tx, func(db dtmcli.DB) error { - dtmcli.Logf("rollback gid2") - return fmt.Errorf("gid2 error") - }) - asserts.Error(err, fmt.Errorf("gid2 error")) - dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{}) - asserts.Equal(dbr.RowsAffected, int64(1)) - dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) - asserts.Equal(dbr.RowsAffected, int64(0)) - barrier.BarrierID = 0 - tx2, err := db.ToSQLDB().Begin() - asserts.Nil(err) - err = barrier.Call(tx2, func(db dtmcli.DB) error { - dtmcli.Logf("submit gid2") - return nil - }) - asserts.Nil(err) - dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) - asserts.Equal(dbr.RowsAffected, int64(1)) -} - func TestUpdateBranchAsync(t *testing.T) { common.DtmConfig.UpdateBranchSync = 0 saga := genSaga("gid-update-branch-async", false, false) - saga.WaitResult = true + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) diff --git a/test/grpc_msg_test.go b/test/grpc_msg_test.go index 233aaca..c6a127e 100644 --- a/test/grpc_msg_test.go +++ b/test/grpc_msg_test.go @@ -12,7 +12,7 @@ import ( func TestGrpcMsg(t *testing.T) { grpcMsgNormal(t) - grpcMsgPending(t) + grpcMsgOngoing(t) } func grpcMsgNormal(t *testing.T) { @@ -23,15 +23,15 @@ func grpcMsgNormal(t *testing.T) { assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) } -func grpcMsgPending(t *testing.T) { +func grpcMsgOngoing(t *testing.T) { msg := genGrpcMsg("grpc-msg-pending") err := msg.Prepare(fmt.Sprintf("%s/examples.Busi/CanSubmit", examples.BusiGrpc)) assert.Nil(t, err) - examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.TransInResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) diff --git a/test/grpc_saga_test.go b/test/grpc_saga_test.go index 3bfd57b..393a1bc 100644 --- a/test/grpc_saga_test.go +++ b/test/grpc_saga_test.go @@ -11,7 +11,7 @@ import ( func TestGrpcSaga(t *testing.T) { sagaGrpcNormal(t) - sagaGrpcCommittedPending(t) + sagaGrpcCommittedOngoing(t) sagaGrpcRollback(t) } @@ -24,9 +24,9 @@ func sagaGrpcNormal(t *testing.T) { transQuery(t, saga.Gid) } -func sagaGrpcCommittedPending(t *testing.T) { - saga := genSagaGrpc("gid-committedPendingGrpc", false, false) - examples.MainSwitch.TransOutResult.SetOnce("PENDING") +func sagaGrpcCommittedOngoing(t *testing.T) { + saga := genSagaGrpc("gid-committedOngoingGrpc", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) saga.Submit() WaitTransProcessed(saga.Gid) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) @@ -37,10 +37,10 @@ func sagaGrpcCommittedPending(t *testing.T) { func sagaGrpcRollback(t *testing.T) { saga := genSagaGrpc("gid-rollbackSaga2Grpc", false, true) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) saga.Submit() WaitTransProcessed(saga.Gid) - assert.Equal(t, "aborting", getTransStatus(saga.Gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid)) diff --git a/test/grpc_tcc_test.go b/test/grpc_tcc_test.go index a894cd3..574d0f0 100644 --- a/test/grpc_tcc_test.go +++ b/test/grpc_tcc_test.go @@ -53,13 +53,13 @@ func tccGrpcRollback(t *testing.T) { err := dtmgrpc.TccGlobalTransaction(examples.DtmGrpcServer, gid, func(tcc *dtmgrpc.TccGrpc) error { _, err := tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransOutTcc", examples.BusiGrpc+"/examples.Busi/TransOutConfirm", examples.BusiGrpc+"/examples.Busi/TransOutRevert") assert.Nil(t, err) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) _, err = tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransInTcc", examples.BusiGrpc+"/examples.Busi/TransInConfirm", examples.BusiGrpc+"/examples.Busi/TransInRevert") return err }) assert.Error(t, err) WaitTransProcessed(gid) - assert.Equal(t, "aborting", getTransStatus(gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) } diff --git a/test/main_test.go b/test/main_test.go index 5b472fe..328cb4c 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -14,7 +14,8 @@ import ( func TestMain(m *testing.M) { dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) dtmsvr.TransProcessedTestChan = make(chan string, 1) - dtmsvr.CronForwardDuration = 60 * time.Second + dtmsvr.NowForwardDuration = 0 * time.Second + dtmsvr.CronForwardDuration = 180 * time.Second common.DtmConfig.UpdateBranchSync = 1 dtmsvr.PopulateDB(false) examples.PopulateDB(false) diff --git a/test/msg_test.go b/test/msg_test.go index 9e03679..85f6cb6 100644 --- a/test/msg_test.go +++ b/test/msg_test.go @@ -11,8 +11,8 @@ import ( func TestMsg(t *testing.T) { msgNormal(t) - msgPending(t) - msgPendingFailed(t) + msgOngoing(t) + msgOngoingFailed(t) } func msgNormal(t *testing.T) { @@ -24,30 +24,30 @@ func msgNormal(t *testing.T) { assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) } -func msgPending(t *testing.T) { +func msgOngoing(t *testing.T) { msg := genMsg("gid-msg-normal-pending") msg.Prepare("") assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.TransInResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid)) CronTransOnce() assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) } -func msgPendingFailed(t *testing.T) { +func msgOngoingFailed(t *testing.T) { msg := genMsg("gid-msg-pending-failed") msg.Prepare("") assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultFailure) - CronTransOnce() + cronTransOnceForwardNow(180) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(msg.Gid)) } diff --git a/test/saga_concurrent_test.go b/test/saga_concurrent_test.go new file mode 100644 index 0000000..5d2020d --- /dev/null +++ b/test/saga_concurrent_test.go @@ -0,0 +1,72 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yedf/dtm/dtmcli" + "github.com/yedf/dtm/examples" +) + +func TestCSaga(t *testing.T) { + csagaNormal(t) + csagaRollback(t) + csagaRollback2(t) + csagaCommittedOngoing(t) +} + +func genCSaga(gid string, outFailed bool, inFailed bool) *dtmcli.ConcurrentSaga { + dtmcli.Logf("beginning a concurrent saga test ---------------- %s", gid) + csaga := dtmcli.NewConcurrentSaga(examples.DtmServer, gid) + req := examples.GenTransReq(30, outFailed, inFailed) + csaga.Add(examples.Busi+"/TransOut", examples.Busi+"/TransOutRevert", &req) + csaga.Add(examples.Busi+"/TransIn", examples.Busi+"/TransInRevert", &req) + return csaga +} + +func csagaNormal(t *testing.T) { + csaga := genCSaga("gid-noraml-csaga", false, false) + csaga.Submit() + WaitTransProcessed(csaga.Gid) + assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid)) +} + +func csagaRollback(t *testing.T) { + csaga := genCSaga("gid-rollback-csaga", true, false) + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) + err := csaga.Submit() + assert.Nil(t, err) + WaitTransProcessed(csaga.Gid) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(csaga.Gid)) + CronTransOnce() + assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid)) + assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + err = csaga.Submit() + assert.Error(t, err) +} + +func csagaRollback2(t *testing.T) { + csaga := genCSaga("gid-rollback-csaga2", true, false) + csaga.AddStepOrder(1, []int{0}) + err := csaga.Submit() + assert.Nil(t, err) + WaitTransProcessed(csaga.Gid) + assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid)) + assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(csaga.Gid)) + err = csaga.Submit() + assert.Error(t, err) +} + +func csagaCommittedOngoing(t *testing.T) { + csaga := genCSaga("gid-committed-ongoing-csaga", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) + csaga.Submit() + WaitTransProcessed(csaga.Gid) + assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusDoing, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(csaga.Gid)) + + CronTransOnce() + assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid)) +} diff --git a/test/saga_test.go b/test/saga_test.go index 6627f2b..a7f05e8 100644 --- a/test/saga_test.go +++ b/test/saga_test.go @@ -10,7 +10,7 @@ import ( func TestSaga(t *testing.T) { sagaNormal(t) - sagaCommittedPending(t) + sagaCommittedOngoing(t) sagaRollback(t) } @@ -25,9 +25,9 @@ func sagaNormal(t *testing.T) { assert.Error(t, err) } -func sagaCommittedPending(t *testing.T) { - saga := genSaga("gid-committedPending", false, false) - examples.MainSwitch.TransOutResult.SetOnce("PENDING") +func sagaCommittedOngoing(t *testing.T) { + saga := genSaga("gid-committedOngoing", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) saga.Submit() WaitTransProcessed(saga.Gid) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) @@ -38,11 +38,11 @@ func sagaCommittedPending(t *testing.T) { func sagaRollback(t *testing.T) { saga := genSaga("gid-rollbackSaga2", false, true) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) - assert.Equal(t, "aborting", getTransStatus(saga.Gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid)) diff --git a/test/tcc_test.go b/test/tcc_test.go index cb98f3b..2b5588d 100644 --- a/test/tcc_test.go +++ b/test/tcc_test.go @@ -32,12 +32,12 @@ func tccRollback(t *testing.T) { err := dtmcli.TccGlobalTransaction(examples.DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { _, rerr := tcc.CallBranch(data, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert") assert.Nil(t, rerr) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) return tcc.CallBranch(data, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") }) assert.Error(t, err) WaitTransProcessed(gid) - assert.Equal(t, "aborting", getTransStatus(gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) } diff --git a/test/types.go b/test/types.go index f069a89..c9b84bb 100644 --- a/test/types.go +++ b/test/types.go @@ -1,6 +1,8 @@ package test import ( + "time" + "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmsvr" @@ -27,3 +29,10 @@ type TransBranch = dtmsvr.TransBranch // M alias type M = dtmcli.M + +func cronTransOnceForwardNow(seconds int) { + old := dtmsvr.NowForwardDuration + dtmsvr.NowForwardDuration = time.Duration(seconds) * time.Second + CronTransOnce() + dtmsvr.NowForwardDuration = old +} diff --git a/test/wait_saga_test.go b/test/wait_saga_test.go index 87c4213..9166c55 100644 --- a/test/wait_saga_test.go +++ b/test/wait_saga_test.go @@ -11,13 +11,13 @@ import ( func TestWaitSaga(t *testing.T) { sagaNormalWait(t) - sagaCommittedPendingWait(t) + sagaCommittedOngoingWait(t) sagaRollbackWait(t) } func sagaNormalWait(t *testing.T) { saga := genSaga("gid-noramlSagaWait", false, false) - saga.WaitResult = true + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) @@ -26,10 +26,10 @@ func sagaNormalWait(t *testing.T) { transQuery(t, saga.Gid) } -func sagaCommittedPendingWait(t *testing.T) { - saga := genSaga("gid-committedPendingWait", false, false) - examples.MainSwitch.TransOutResult.SetOnce("PENDING") - saga.WaitResult = true +func sagaCommittedOngoingWait(t *testing.T) { + saga := genSaga("gid-committedOngoingWait", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Error(t, err) WaitTransProcessed(saga.Gid) @@ -41,7 +41,7 @@ func sagaCommittedPendingWait(t *testing.T) { func sagaRollbackWait(t *testing.T) { saga := genSaga("gid-rollbackSaga2Wait", false, true) - saga.WaitResult = true + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Error(t, err) WaitTransProcessed(saga.Gid)