diff --git a/.gitignore b/.gitignore index 9304d6c..7bec55c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ conf.yml main dist .idea/** +.vscode/** diff --git a/.vscode/launch.json b/.vscode/launch.sample.json similarity index 85% rename from .vscode/launch.json rename to .vscode/launch.sample.json index 7c8dc8b..7cf371c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.sample.json @@ -1,5 +1,5 @@ { - // 使用 IntelliSense 了解相关属性。 + // 使用 IntelliSense 了解相关属性。 // 悬停以查看现有属性的描述。 // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", @@ -10,10 +10,11 @@ "request": "launch", "mode": "debug", "program": "${workspaceFolder}/app/main.go", + "cwd": "${workspaceFolder}", "env": { // "GIN_MODE": "release" }, - "args": [] + "args": ["grpc_saga"] }, { "name": "Test", diff --git a/.vscode/settings.json b/.vscode/settings.sample.json similarity index 100% rename from .vscode/settings.json rename to .vscode/settings.sample.json diff --git a/common/types.go b/common/types.go index 5ada90b..e080313 100644 --- a/common/types.go +++ b/common/types.go @@ -22,7 +22,7 @@ import ( // ModelBase model base for gorm to provide base fields type ModelBase struct { - ID uint + ID uint64 CreateTime *time.Time `gorm:"autoCreateTime"` UpdateTime *time.Time `gorm:"autoUpdateTime"` } @@ -123,7 +123,9 @@ func DbGet(conf map[string]string) *DB { } type dtmConfigType struct { - TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务 + TransCronInterval int64 `yaml:"TransCronInterval"` + TimeoutToFail int64 `yaml:"TimeoutToFail"` + RetryInterval int64 `yaml:"RetryInterval"` DB map[string]string `yaml:"DB"` DisableLocalhost int64 `yaml:"DisableLocalhost"` UpdateBranchSync int64 `yaml:"UpdateBranchSync"` @@ -140,7 +142,9 @@ func init() { if len(os.Args) == 1 { return } - DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "10") + DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "3") + DtmConfig.TimeoutToFail = getIntEnv("TIMEOUT_TO_FAIL", "35") + DtmConfig.RetryInterval = getIntEnv("RETRY_INTERVAL", "10") DtmConfig.DB = map[string]string{ "driver": dtmcli.OrString(os.Getenv("DB_DRIVER"), "mysql"), "host": os.Getenv("DB_HOST"), @@ -166,6 +170,24 @@ func init() { err := yaml.Unmarshal(cont, &DtmConfig) dtmcli.FatalIfError(err) } - dtmcli.LogIfFatalf(DtmConfig.DB["driver"] == "" || DtmConfig.DB["user"] == "", - "dtm配置错误. 请访问 http://dtm.pub 查看部署运维环节. check you env, and conf.yml/conf.sample.yml in current and parent path: %s. config is: \n%v", MustGetwd(), DtmConfig) + errStr := checkConfig() + dtmcli.LogIfFatalf(errStr != "", + `config error: '%s'. +check you env, and conf.yml/conf.sample.yml in current and parent path: %s. +please visit http://d.dtm.pub to see the config document. +loaded config is: +%v`, MustGetwd(), DtmConfig) +} + +func checkConfig() string { + if DtmConfig.DB["driver"] == "" { + return "db driver empty" + } else if DtmConfig.DB["user"] == "" || DtmConfig.DB["host"] == "" { + return "db config not valid" + } else if DtmConfig.RetryInterval < 10 { + return "RetryInterval should not be less than 10" + } else if DtmConfig.TimeoutToFail < DtmConfig.RetryInterval { + return "TimeoutToFail should not be less than RetryInterval" + } + return "" } diff --git a/common/types_test.go b/common/types_test.go index 20ccc38..4dd9c0b 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -30,3 +30,26 @@ func TestDbAlone(t *testing.T) { _, err = dtmcli.DBExec(db, "select 1") assert.NotEqual(t, nil, err) } + +func TestConfig(t *testing.T) { + testConfigStringField(DtmConfig.DB, "driver", "", t) + testConfigStringField(DtmConfig.DB, "user", "", t) + testConfigIntField(&DtmConfig.RetryInterval, 9, t) + testConfigIntField(&DtmConfig.TimeoutToFail, 9, t) +} + +func testConfigStringField(m map[string]string, key string, val string, t *testing.T) { + old := m[key] + m[key] = val + str := checkConfig() + assert.NotEqual(t, "", str) + m[key] = old +} + +func testConfigIntField(fd *int64, val int64, t *testing.T) { + old := *fd + *fd = val + str := checkConfig() + assert.NotEqual(t, "", str) + *fd = old +} diff --git a/common/utils.go b/common/utils.go index 14a89c0..5c9b174 100644 --- a/common/utils.go +++ b/common/utils.go @@ -43,7 +43,10 @@ func GetGinApp() *gin.Engine { // WrapHandler name is clear func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { return func(c *gin.Context) { - r, err := fn(c) + r, err := func() (r interface{}, rerr error) { + defer dtmcli.P2E(&rerr) + return fn(c) + }() var b = []byte{} if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理 b = resp.Body() diff --git a/conf.sample.yml b/conf.sample.yml index 55d8e63..f79e598 100644 --- a/conf.sample.yml +++ b/conf.sample.yml @@ -9,4 +9,8 @@ DB: # user: 'postgres' # password: 'mysecretpassword' # port: '5432' -TransCronInterval: 10 # 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮重试处理,包括prepared中的任务和commited的任务 +# the unit of following configurations is second + +# TransCronInterval: 3 # the interval to poll unfinished global transaction for every dtm process +# TimeoutToFail: 35 # timeout for XA, TCC to fail. saga's timeout default to infinite, which can be overwritten in saga options +# RetryInterval: 10 # the subtrans branch will be retried after this interval diff --git a/dtmcli/barrier.mysql.sql b/dtmcli/barrier.mysql.sql index a676f34..60b828f 100644 --- a/dtmcli/barrier.mysql.sql +++ b/dtmcli/barrier.mysql.sql @@ -1,10 +1,11 @@ -create database if not exists dtm_barrier /*!40100 DEFAULT CHARACTER SET utf8mb4 */; - +create database if not exists dtm_barrier +/*!40100 DEFAULT CHARACTER SET utf8mb4 */ +; drop table if exists dtm_barrier.barrier; create table if not exists dtm_barrier.barrier( - id int(11) PRIMARY KEY AUTO_INCREMENT, - trans_type varchar(45) default '' , - gid varchar(128) default'', + id bigint(22) PRIMARY KEY AUTO_INCREMENT, + trans_type varchar(45) default '', + gid varchar(128) default '', branch_id varchar(128) default '', branch_type varchar(45) default '', barrier_id varchar(45) default '', @@ -14,4 +15,4 @@ create table if not exists dtm_barrier.barrier( key(create_time), key(update_time), UNIQUE key(gid, branch_id, branch_type, barrier_id) -); +); \ No newline at end of file diff --git a/dtmcli/barrier.postgres.sql b/dtmcli/barrier.postgres.sql index dc3d069..62b8f5a 100644 --- a/dtmcli/barrier.postgres.sql +++ b/dtmcli/barrier.postgres.sql @@ -1,13 +1,10 @@ create schema if not exists dtm_barrier; - drop table if exists dtm_barrier.barrier; - CREATE SEQUENCE if not EXISTS dtm_barrier.barrier_seq; - create table if not exists dtm_barrier.barrier( - id int NOT NULL DEFAULT NEXTVAL ('dtm_barrier.barrier_seq'), - trans_type varchar(45) default '' , - gid varchar(128) default'', + id bigint NOT NULL DEFAULT NEXTVAL ('dtm_barrier.barrier_seq'), + trans_type varchar(45) default '', + gid varchar(128) default '', branch_id varchar(128) default '', branch_type varchar(45) default '', barrier_id varchar(45) default '', @@ -16,5 +13,4 @@ create table if not exists dtm_barrier.barrier( update_time timestamp(0) DEFAULT NULL, PRIMARY KEY(id), CONSTRAINT uniq_barrier unique(gid, branch_id, branch_type, barrier_id) -); - +); \ No newline at end of file diff --git a/dtmcli/consts.go b/dtmcli/consts.go index 8d2b1d7..438deed 100644 --- a/dtmcli/consts.go +++ b/dtmcli/consts.go @@ -1,14 +1,16 @@ package dtmcli const ( - // StatusPrepared status for global trans status. exists only in tran message + // StatusPrepared status for global/branch trans status. StatusPrepared = "prepared" - // StatusSubmitted StatusSubmitted status for global trans status. + // StatusSubmitted status for global trans status. StatusSubmitted = "submitted" - // StatusSucceed status for global trans status. + // StatusSucceed status for global/branch trans status. StatusSucceed = "succeed" - // StatusFailed status for global trans status. + // StatusFailed status for global/branch trans status. StatusFailed = "failed" + // StatusAborting status for global trans status. + StatusAborting = "aborting" // BranchTry branch type for TCC BranchTry = "try" @@ -29,6 +31,8 @@ const ( ResultSuccess = "SUCCESS" // ResultFailure for result of a trans/trans branch ResultFailure = "FAILURE" + // ResultOngoing for result of a trans/trans branch + ResultOngoing = "ONGOING" // DBTypeMysql const for driver mysql DBTypeMysql = "mysql" diff --git a/dtmcli/saga.go b/dtmcli/saga.go index f6681d8..fa9de44 100644 --- a/dtmcli/saga.go +++ b/dtmcli/saga.go @@ -1,9 +1,13 @@ package dtmcli +import "fmt" + // Saga struct of saga type Saga struct { TransBase - Steps []SagaStep `json:"steps"` + Steps []SagaStep `json:"steps"` + orders map[int][]int + concurrent bool } // SagaStep one step of saga @@ -15,7 +19,7 @@ type SagaStep struct { // NewSaga create a saga func NewSaga(server string, gid string) *Saga { - return &Saga{TransBase: *NewTransBase(gid, "saga", server, "")} + return &Saga{TransBase: *NewTransBase(gid, "saga", server, ""), orders: map[int][]int{}} } // Add add a saga step @@ -28,7 +32,23 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga return s } +// AddStepOrder specify that step should be after preSteps. Step is larger than all the element in preSteps +func (s *Saga) AddStepOrder(step int, preSteps []int) *Saga { + PanicIf(step > len(s.Steps), fmt.Errorf("step value: %d is invalid. which cannot be larger than total steps: %d", step, len(s.Steps))) + s.orders[step] = preSteps + return s +} + +// EnableConcurrent enable the concurrent exec of sub trans +func (s *Saga) EnableConcurrent() *Saga { + s.concurrent = true + return s +} + // Submit submit the saga trans func (s *Saga) Submit() error { + if s.concurrent { + s.CustomData = MustMarshalString(M{"orders": s.orders, "concurrent": s.concurrent}) + } return s.callDtm(s, "submit") } diff --git a/dtmcli/types.go b/dtmcli/types.go index e21994b..06e8d75 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -51,14 +51,26 @@ type TransResult struct { Message string } +// TransOptions transaction options +type TransOptions struct { + WaitResult bool `json:"wait_result,omitempty" gorm:"-"` + TimeoutToFail int64 `json:"timeout_to_fail,omitempty" gorm:"-"` // for trans type: xa, tcc + RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc +} + // TransBase 事务的基础类 type TransBase struct { - Gid string `json:"gid"` - TransType string `json:"trans_type"` + Gid string `json:"gid"` + TransType string `json:"trans_type"` + Dtm string `json:"-"` + CustomData string `json:"custom_data,omitempty"` IDGenerator - Dtm string - // WaitResult 是否等待全局事务的最终结果 - WaitResult bool + TransOptions +} + +// SetOptions set options +func (tb *TransBase) SetOptions(options *TransOptions) { + tb.TransOptions = *options } // NewTransBase 1 @@ -78,11 +90,7 @@ func TransBaseFromQuery(qs url.Values) *TransBase { // callDtm 调用dtm服务器,返回事务的状态 func (tb *TransBase) callDtm(body interface{}, operation string) error { - params := MS{} - if tb.WaitResult { - params["wait_result"] = "1" - } - resp, err := RestyClient.R().SetQueryParams(params). + resp, err := RestyClient.R(). SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation)) if err != nil { return err @@ -95,10 +103,10 @@ func (tb *TransBase) callDtm(body interface{}, operation string) error { } // ErrFailure 表示返回失败,要求回滚 -var ErrFailure = errors.New("transaction FAILURE") +var ErrFailure = errors.New("FAILURE") -// ErrPending 表示暂时失败,要求重试 -var ErrPending = errors.New("transaction PENDING") +// ErrOngoing 表示暂时失败,要求重试 +var ErrOngoing = errors.New("ONGOING") // MapSuccess 表示返回成功,可以进行下一步 var MapSuccess = M{"dtm_result": ResultSuccess} diff --git a/dtmcli/utils.go b/dtmcli/utils.go index 8fec13e..e9d511a 100644 --- a/dtmcli/utils.go +++ b/dtmcli/utils.go @@ -17,14 +17,18 @@ import ( "github.com/go-resty/resty/v2" ) +// AsError wrap a panic value as an error +func AsError(x interface{}) error { + if e, ok := x.(error); ok { + return e + } + return fmt.Errorf("%v", x) +} + // P2E panic to error func P2E(perr *error) { if x := recover(); x != nil { - if e, ok := x.(error); ok { - *perr = e - } else { - panic(x) - } + *perr = AsError(x) } } @@ -262,8 +266,8 @@ func CheckResult(res interface{}, err error) error { str := MustMarshalString(res) if strings.Contains(str, ResultFailure) { return ErrFailure - } else if strings.Contains(str, "PENDING") { - return ErrPending + } else if strings.Contains(str, ResultOngoing) { + return ErrOngoing } } return err diff --git a/dtmcli/utils_test.go b/dtmcli/utils_test.go index a59d628..c0954c3 100644 --- a/dtmcli/utils_test.go +++ b/dtmcli/utils_test.go @@ -25,13 +25,10 @@ func TestEP(t *testing.T) { }) assert.Equal(t, "err2", err.Error()) err = func() (rerr error) { - defer func() { - x := recover() - assert.Equal(t, 1, x) - }() defer P2E(&rerr) - panic(1) + panic("raw_string") }() + assert.Equal(t, "raw_string", err.Error()) } func TestTernary(t *testing.T) { diff --git a/dtmgrpc/barrier.go b/dtmgrpc/barrier.go index 1389da6..f24fc46 100644 --- a/dtmgrpc/barrier.go +++ b/dtmgrpc/barrier.go @@ -2,8 +2,6 @@ package dtmgrpc import ( "github.com/yedf/dtm/dtmcli" - "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" ) // BranchBarrier 子事务屏障 @@ -11,21 +9,6 @@ type BranchBarrier struct { *dtmcli.BranchBarrier } -// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 -// db: 本地数据库 -// transInfo: 事务信息 -// bisiCall: 业务函数,仅在必要时被调用 -// 返回值: -// 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚 -// 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行 -func (bb *BranchBarrier) Call(tx dtmcli.Tx, busiCall dtmcli.BusiFunc) (rerr error) { - err := bb.BranchBarrier.Call(tx, busiCall) - if err == dtmcli.ErrFailure { - return status.New(codes.Aborted, "user rollback").Err() - } - return err -} - // BarrierFromGrpc 从BusiRequest生成一个Barrier func BarrierFromGrpc(in *BusiRequest) (*BranchBarrier, error) { b, err := dtmcli.BarrierFrom(in.Info.TransType, in.Info.Gid, in.Info.BranchID, in.Info.BranchType) diff --git a/dtmgrpc/type.go b/dtmgrpc/type.go index b0539fe..fdc21db 100644 --- a/dtmgrpc/type.go +++ b/dtmgrpc/type.go @@ -9,7 +9,7 @@ import ( "github.com/yedf/dtm/dtmcli" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" ) @@ -89,9 +89,10 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c func Result2Error(res interface{}, err error) error { e := dtmcli.CheckResult(res, err) if e == dtmcli.ErrFailure { - return status.New(codes.Aborted, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err() - } else if e == dtmcli.ErrPending { - return status.New(codes.Unavailable, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err() + dtmcli.LogRedf("failure: res: %v, err: %v", res, e) + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() + } else if e == dtmcli.ErrOngoing { + return status.New(codes.Aborted, dtmcli.ResultOngoing).Err() } return e } diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 3c87ed1..f2923f0 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -7,47 +7,48 @@ import ( "gorm.io/gorm/clause" ) -func svcSubmit(t *TransGlobal, waitResult bool) (interface{}, error) { +func svcSubmit(t *TransGlobal) (interface{}, error) { db := dbGet() t.Status = dtmcli.StatusSubmitted err := t.saveNew(db) if err == errUniqueConflict { - dbt := TransFromDb(db, t.Gid) + dbt := transFromDb(db, t.Gid) if dbt.Status == dtmcli.StatusPrepared { - updates := t.setNextCron(config.TransCronInterval) + updates := t.setNextCron(cronReset) db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) } else if dbt.Status != dtmcli.StatusSubmitted { - return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot sumbmit", dbt.Status)}, nil + return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil } } - return t.Process(db, waitResult), nil + return t.Process(db), nil } func svcPrepare(t *TransGlobal) (interface{}, error) { t.Status = dtmcli.StatusPrepared err := t.saveNew(dbGet()) if err == errUniqueConflict { - dbt := TransFromDb(dbGet(), t.Gid) + dbt := transFromDb(dbGet(), t.Gid) if dbt.Status != dtmcli.StatusPrepared { - return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot prepare", dbt.Status)}, nil + return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil } } return dtmcli.MapSuccess, nil } -func svcAbort(t *TransGlobal, waitResult bool) (interface{}, error) { +func svcAbort(t *TransGlobal) (interface{}, error) { db := dbGet() - dbt := TransFromDb(db, t.Gid) - if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != "aborting" { - return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: %s current status %s, cannot abort", dbt.TransType, dbt.Status)}, nil + dbt := transFromDb(db, t.Gid) + if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting { + return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil } - return dbt.Process(db, waitResult), nil + dbt.changeStatus(db, dtmcli.StatusAborting) + return dbt.Process(db), nil } func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, error) { db := dbGet() - dbt := TransFromDb(db, branch.Gid) + dbt := transFromDb(db, branch.Gid) if dbt.Status != dtmcli.StatusPrepared { return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil } @@ -62,14 +63,14 @@ func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, err DoNothing: true, }).Create(branches) global := TransGlobal{Gid: branch.Gid} - global.touch(dbGet(), config.TransCronInterval) + global.touch(dbGet(), cronKeep) return dtmcli.MapSuccess, nil } func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) { branch.Status = dtmcli.StatusPrepared db := dbGet() - dbt := TransFromDb(db, branch.Gid) + dbt := transFromDb(db, branch.Gid) if dbt.Status != dtmcli.StatusPrepared { return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil } @@ -80,6 +81,6 @@ func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) { DoNothing: true, }).Create(branches) global := TransGlobal{Gid: branch.Gid} - global.touch(db, config.TransCronInterval) + global.touch(db, cronKeep) return dtmcli.MapSuccess, nil } diff --git a/dtmsvr/api_grpc.go b/dtmsvr/api_grpc.go index 074faf6..4b44cdf 100644 --- a/dtmsvr/api_grpc.go +++ b/dtmsvr/api_grpc.go @@ -19,7 +19,7 @@ func (s *dtmServer) NewGid(ctx context.Context, in *emptypb.Empty) (*dtmgrpc.Dtm } func (s *dtmServer) Submit(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcSubmit(TransFromDtmRequest(in), in.WaitResult) + r, err := svcSubmit(TransFromDtmRequest(in)) return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err) } @@ -29,7 +29,7 @@ func (s *dtmServer) Prepare(ctx context.Context, in *pb.DtmRequest) (*emptypb.Em } func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcAbort(TransFromDtmRequest(in), in.WaitResult) + r, err := svcAbort(TransFromDtmRequest(in)) return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err) } diff --git a/dtmsvr/api_http.go b/dtmsvr/api_http.go index ec59e85..07aec3a 100644 --- a/dtmsvr/api_http.go +++ b/dtmsvr/api_http.go @@ -2,22 +2,22 @@ package dtmsvr import ( "errors" + "math" "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" - "gorm.io/gorm" ) func addRoute(engine *gin.Engine) { + engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid)) engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare)) engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit)) + engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort)) engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch)) engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch)) - engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort)) engine.GET("/api/dtmsvr/query", common.WrapHandler(query)) engine.GET("/api/dtmsvr/all", common.WrapHandler(all)) - engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid)) } func newGid(c *gin.Context) (interface{}, error) { @@ -29,11 +29,11 @@ func prepare(c *gin.Context) (interface{}, error) { } func submit(c *gin.Context) (interface{}, error) { - return svcSubmit(TransFromContext(c), c.Query("wait_result") == "1") + return svcSubmit(TransFromContext(c)) } func abort(c *gin.Context) (interface{}, error) { - return svcAbort(TransFromContext(c), c.Query("wait_result") == "1") + return svcAbort(TransFromContext(c)) } func registerXaBranch(c *gin.Context) (interface{}, error) { @@ -61,24 +61,19 @@ func query(c *gin.Context) (interface{}, error) { if gid == "" { return nil, errors.New("no gid specified") } - trans := TransGlobal{} db := dbGet() - db.Begin() - dbr := db.Must().Where("gid", gid).First(&trans) - if dbr.Error == gorm.ErrRecordNotFound { - return M{"transaction": nil, "branches": [0]int{}}, nil - } + trans := transFromDb(db, gid) branches := []TransBranch{} db.Must().Where("gid", gid).Find(&branches) return M{"transaction": trans, "branches": branches}, nil } func all(c *gin.Context) (interface{}, error) { - lastId := c.Query("last_id") - if lastId == "" { - lastId = "2000000000" + lastID := c.Query("last_id") + lid := math.MaxInt64 + if lastID != "" { + lid = dtmcli.MustAtoi(lastID) } - lid := dtmcli.MustAtoi(lastId) trans := []TransGlobal{} dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(100).Find(&trans) return M{"transactions": trans}, nil diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 8880511..78e0546 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -2,7 +2,6 @@ package dtmsvr import ( "fmt" - "math" "math/rand" "runtime/debug" "time" @@ -10,7 +9,10 @@ import ( "github.com/yedf/dtm/dtmcli" ) -// CronForwardDuration will be set in test, cron will fetch trans which expire in CronForwardDuration +// NowForwardDuration will be set in test, trans may be timeout +var NowForwardDuration time.Duration = time.Duration(0) + +// CronForwardDuration will be set in test. cron will fetch trans which expire in CronForwardDuration var CronForwardDuration time.Duration = time.Duration(0) // CronTransOnce cron expired trans. use expireIn as expire time @@ -24,7 +26,8 @@ func CronTransOnce() (hasTrans bool) { if TransProcessedTestChan != nil { defer WaitTransProcessed(trans.Gid) } - trans.Process(dbGet(), true) + trans.WaitResult = true + trans.Process(dbGet()) return } @@ -33,7 +36,7 @@ func CronExpiredTrans(num int) { for i := 0; i < num || num == -1; i++ { hasTrans := CronTransOnce() if !hasTrans && num != 1 { - sleepCronTime(0) + sleepCronTime() } } } @@ -44,7 +47,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal { db := dbGet() getTime := dtmcli.GetDBSpecial().TimestampAdd expire := int(expireIn / time.Second) - whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3)) + whereTime := fmt.Sprintf("next_cron_time < %s and update_time < %s", getTime(expire), getTime(expire-3)) // 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢 // 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出 dbr := db.Must().Model(&trans). @@ -53,7 +56,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal { return nil } dbr = db.Must().Where("owner=?", owner).Find(&trans) - updates := trans.setNextCron(trans.NextCronInterval * 2) // 下次被cron的间隔加倍 + updates := trans.setNextCron(cronKeep) db.Must().Model(&trans).Select(updates).Updates(&trans) return &trans } @@ -67,9 +70,9 @@ func handlePanic(perr *error) { } } -func sleepCronTime(milli int) { - delta := math.Min(3, float64(config.TransCronInterval)) - interval := time.Duration((float64(config.TransCronInterval) - rand.Float64()*delta) * float64(time.Second)) - dtmcli.Logf("sleeping for %v pass in %d milli", interval, milli) - time.Sleep(dtmcli.If(milli == 0, interval, time.Duration(milli*int(time.Millisecond))).(time.Duration)) +func sleepCronTime() { + normal := time.Duration((float64(config.TransCronInterval) - rand.Float64()) * float64(time.Second)) + interval := dtmcli.If(CronForwardDuration > 0, 1*time.Millisecond, normal).(time.Duration) + dtmcli.Logf("sleeping for %v milli", interval/time.Microsecond) + time.Sleep(interval) } diff --git a/dtmsvr/dtmsvr.go b/dtmsvr/dtmsvr.go index 216bb3a..ee0b35d 100644 --- a/dtmsvr/dtmsvr.go +++ b/dtmsvr/dtmsvr.go @@ -70,7 +70,7 @@ func updateBranchAsync() { updates = append(updates, TransBranch{ ModelBase: common.ModelBase{ID: updateBranch.id}, Status: updateBranch.status, - FinishTime: updateBranch.finish_time, + FinishTime: updateBranch.finishTime, }) case <-time.After(checkInterval): } diff --git a/dtmsvr/dtmsvr.mysql.sql b/dtmsvr/dtmsvr.mysql.sql index 0c7c550..20ce7b9 100644 --- a/dtmsvr/dtmsvr.mysql.sql +++ b/dtmsvr/dtmsvr.mysql.sql @@ -3,7 +3,7 @@ CREATE DATABASE IF NOT EXISTS dtm ; drop table IF EXISTS dtm.trans_global; CREATE TABLE if not EXISTS dtm.trans_global ( - `id` int(11) NOT NULL AUTO_INCREMENT, + `id` bigint(22) NOT NULL AUTO_INCREMENT, `gid` varchar(128) NOT NULL COMMENT '事务全局id', `trans_type` varchar(45) not null COMMENT '事务类型: saga | xa | tcc | msg', -- `data` TEXT COMMENT '事务携带的数据', -- 影响性能,不必要存储 @@ -15,6 +15,8 @@ CREATE TABLE if not EXISTS dtm.trans_global ( `commit_time` datetime DEFAULT NULL, `finish_time` datetime DEFAULT NULL, `rollback_time` datetime DEFAULT NULL, + `options` varchar(256) DEFAULT '', + `custom_data` varchar(256) DEFAULT '', `next_cron_interval` int(11) default null comment '下次定时处理的间隔', `next_cron_time` datetime default null comment '下次定时处理的时间', `owner` varchar(128) not null default '' comment '正在处理全局事务的锁定者', @@ -27,7 +29,7 @@ CREATE TABLE if not EXISTS dtm.trans_global ( ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; drop table IF EXISTS dtm.trans_branch; CREATE TABLE IF NOT EXISTS dtm.trans_branch ( - `id` int(11) NOT NULL AUTO_INCREMENT, + `id` bigint(22) NOT NULL AUTO_INCREMENT, `gid` varchar(128) NOT NULL COMMENT '事务全局id', `url` varchar(128) NOT NULL COMMENT '动作关联的url', `data` TEXT COMMENT '请求所携带的数据', @@ -45,7 +47,7 @@ CREATE TABLE IF NOT EXISTS dtm.trans_branch ( ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; drop table IF EXISTS dtm.trans_log; CREATE TABLE IF NOT EXISTS dtm.trans_log ( - `id` int(11) NOT NULL AUTO_INCREMENT, + `id` bigint(22) NOT NULL AUTO_INCREMENT, `gid` varchar(128) NOT NULL COMMENT '事务全局id', `branch_id` varchar(128) DEFAULT NULL COMMENT '事务分支', `action` varchar(45) DEFAULT NULL COMMENT '行为', diff --git a/dtmsvr/dtmsvr.postgres.sql b/dtmsvr/dtmsvr.postgres.sql index 6bc598a..ed8b273 100644 --- a/dtmsvr/dtmsvr.postgres.sql +++ b/dtmsvr/dtmsvr.postgres.sql @@ -5,7 +5,7 @@ drop table IF EXISTS dtm.trans_global; -- SQLINES LICENSE FOR EVALUATION USE ONLY CREATE SEQUENCE if not EXISTS dtm.trans_global_seq; CREATE TABLE if not EXISTS dtm.trans_global ( - id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_global_seq'), + id bigint NOT NULL DEFAULT NEXTVAL ('dtm.trans_global_seq'), gid varchar(128) NOT NULL, trans_type varchar(45) not null, status varchar(45) NOT NULL, @@ -16,6 +16,8 @@ CREATE TABLE if not EXISTS dtm.trans_global ( commit_time timestamp(0) DEFAULT NULL, finish_time timestamp(0) DEFAULT NULL, rollback_time timestamp(0) DEFAULT NULL, + options varchar(256) DEFAULT '', + custom_data varchar(256) DEFAULT '', next_cron_interval int default null, next_cron_time timestamp(0) default null, owner varchar(128) not null default '', @@ -30,7 +32,7 @@ drop table IF EXISTS dtm.trans_branch; -- SQLINES LICENSE FOR EVALUATION USE ONLY CREATE SEQUENCE if not EXISTS dtm.trans_branch_seq; CREATE TABLE IF NOT EXISTS dtm.trans_branch ( - id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_branch_seq'), + id bigint NOT NULL DEFAULT NEXTVAL ('dtm.trans_branch_seq'), gid varchar(128) NOT NULL, url varchar(128) NOT NULL, data TEXT, @@ -50,7 +52,7 @@ drop table IF EXISTS dtm.trans_log; -- SQLINES LICENSE FOR EVALUATION USE ONLY CREATE SEQUENCE if not EXISTS dtm.trans_log_seq; CREATE TABLE IF NOT EXISTS dtm.trans_log ( - id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_log_seq'), + id bigint NOT NULL DEFAULT NEXTVAL ('dtm.trans_log_seq'), gid varchar(128) NOT NULL, branch_id varchar(128) DEFAULT NULL, action varchar(45) DEFAULT NULL, diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 1948beb..1eee285 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -32,9 +32,12 @@ type TransGlobal struct { CommitTime *time.Time FinishTime *time.Time RollbackTime *time.Time + Options string + CustomData string `json:"custom_data"` NextCronInterval int64 NextCronTime *time.Time - processStarted time.Time // record the start time of process + dtmcli.TransOptions + processStarted time.Time // record the start time of process } // TableName TableName @@ -44,12 +47,12 @@ func (*TransGlobal) TableName() string { type transProcessor interface { GenBranches() []TransBranch - ProcessOnce(db *common.DB, branches []TransBranch) + ProcessOnce(db *common.DB, branches []TransBranch) error } -func (t *TransGlobal) touch(db *common.DB, interval int64) *gorm.DB { +func (t *TransGlobal) touch(db *common.DB, ctype cronType) *gorm.DB { writeTransLog(t.Gid, "touch trans", "", "", "") - updates := t.setNextCron(interval) + updates := t.setNextCron(ctype) return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Select(updates).Updates(t) } @@ -57,7 +60,7 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { writeTransLog(t.Gid, "change status", status, "", "") old := t.Status t.Status = status - updates := t.setNextCron(config.TransCronInterval) + updates := t.setNextCron(cronReset) updates = append(updates, "status") now := time.Now() if status == dtmcli.StatusSucceed { @@ -72,6 +75,21 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { return dbr } +func (t *TransGlobal) isTimeout() bool { + timeout := t.TimeoutToFail + if t.TimeoutToFail == 0 && t.TransType != "saga" { + timeout = config.TimeoutToFail + } + if timeout == 0 { + return false + } + return time.Since(*t.CreateTime)+NowForwardDuration >= time.Duration(timeout)*time.Second +} + +func (t *TransGlobal) needProcess() bool { + return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout() +} + // TransBranch branch transaction type TransBranch struct { common.ModelBase @@ -124,14 +142,18 @@ func (t *TransGlobal) getProcessor() transProcessor { } // Process process global transaction once -func (t *TransGlobal) Process(db *common.DB, waitResult bool) dtmcli.M { - r := t.process(db, waitResult) +func (t *TransGlobal) Process(db *common.DB) dtmcli.M { + r := t.process(db) transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess) return r } -func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M { - if !waitResult { +func (t *TransGlobal) process(db *common.DB) dtmcli.M { + if t.Options != "" { + dtmcli.MustUnmarshalString(t.Options, &t.TransOptions) + } + + if !t.WaitResult { go t.processInner(db) return dtmcli.MapSuccess } @@ -149,6 +171,9 @@ func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M { func (t *TransGlobal) processInner(db *common.DB) (rerr error) { defer handlePanic(&rerr) defer func() { + if rerr != nil { + dtmcli.LogRedf("processInner got error: %s", rerr.Error()) + } if TransProcessedTestChan != nil { dtmcli.Logf("processed: %s", t.Gid) TransProcessedTestChan <- t.Gid @@ -156,24 +181,38 @@ func (t *TransGlobal) processInner(db *common.DB) (rerr error) { } }() dtmcli.Logf("processing: %s status: %s", t.Gid, t.Status) - if t.Status == dtmcli.StatusPrepared && t.TransType != "msg" { - t.changeStatus(db, "aborting") - } branches := []TransBranch{} db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches) t.processStarted = time.Now() - t.getProcessor().ProcessOnce(db, branches) + rerr = t.getProcessor().ProcessOnce(db, branches) return } -func (t *TransGlobal) setNextCron(expireIn int64) []string { - t.NextCronInterval = expireIn +type cronType int + +const ( + cronBackoff cronType = iota + cronReset + cronKeep +) + +func (t *TransGlobal) setNextCron(ctype cronType) []string { + if ctype == cronBackoff { + t.NextCronInterval = t.NextCronInterval * 2 + } else if ctype == cronKeep { + // do nothing + } else if t.RetryInterval != 0 { + t.NextCronInterval = t.RetryInterval + } else { + t.NextCronInterval = config.RetryInterval + } + next := time.Now().Add(time.Duration(t.NextCronInterval) * time.Second) t.NextCronTime = &next return []string{"next_cron_interval", "next_cron_time"} } -func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) string { +func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) (string, error) { if t.Protocol == "grpc" { dtmcli.PanicIf(strings.HasPrefix(url, "http"), fmt.Errorf("bad url for grpc: %s", url)) server, method := dtmgrpc.GetServerAndMethod(url) @@ -188,11 +227,17 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran BusiData: branchData, }, &emptypb.Empty{}) if err == nil { - return dtmcli.ResultSuccess - } else if status.Code(err) == codes.Aborted { - return dtmcli.ResultFailure + return dtmcli.ResultSuccess, nil + } + st, ok := status.FromError(err) + if ok && st.Code() == codes.Aborted { + if st.Message() == dtmcli.ResultOngoing { + return dtmcli.ResultOngoing, nil + } else if st.Message() == dtmcli.ResultFailure { + return dtmcli.ResultFailure, nil + } } - return err.Error() + return "", err } dtmcli.PanicIf(!strings.HasPrefix(url, "http"), fmt.Errorf("bad url for http: %s", url)) resp, err := dtmcli.RestyClient.R().SetBody(string(branchData)). @@ -204,36 +249,53 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran }). SetHeader("Content-type", "application/json"). Execute(dtmcli.If(branchData == nil, "GET", "POST").(string), url) - e2p(err) - return resp.String() + if err != nil { + return "", err + } + return resp.String(), nil } -func (t *TransGlobal) getBranchResult(branch *TransBranch) string { - return t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data)) +func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { + body, err := t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data)) + if err != nil { + return "", err + } + if strings.Contains(body, dtmcli.ResultSuccess) { + return dtmcli.StatusSucceed, nil + } else if strings.HasSuffix(t.TransType, "saga") && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) { + return dtmcli.StatusFailed, nil + } else if strings.Contains(body, dtmcli.ResultOngoing) { + return "", dtmcli.ErrOngoing + } + return "", fmt.Errorf("http result should contains SUCCESS|FAILURE|ONGOING. grpc error should return nil|Aborted with message(FAILURE|ONGOING). \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body) } -func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) { - body := t.getBranchResult(branch) - status := "" - if strings.Contains(body, dtmcli.ResultSuccess) { - status = dtmcli.StatusSucceed - } else if t.TransType == "saga" && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) { - status = dtmcli.StatusFailed - } else { - panic(fmt.Errorf("http result should contains SUCCESS|FAILURE. grpc error should return nil|Aborted. \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body)) +func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) error { + status, err := t.getBranchResult(branch) + if status != "" { + branch.changeStatus(db, status) } branchMetrics(t, branch, status == dtmcli.StatusSucceed) - // 如果一次处理超过1500ms,那么touch一下TransGlobal,避免被Cron取出 - if time.Since(t.processStarted)+CronForwardDuration >= 1500*time.Millisecond || t.NextCronInterval > config.TransCronInterval { - t.touch(db, config.TransCronInterval) + // if time pass 1500ms and NextCronInterval is not default, then reset NextCronInterval + if err == nil && time.Since(t.processStarted)+NowForwardDuration >= 1500*time.Millisecond || + t.NextCronInterval > config.RetryInterval && t.NextCronInterval > t.RetryInterval { + t.touch(db, cronReset) + } else if err == dtmcli.ErrOngoing { + t.touch(db, cronKeep) + } else { + t.touch(db, cronBackoff) } - branch.changeStatus(db, status) + return err } func (t *TransGlobal) saveNew(db *common.DB) error { return db.Transaction(func(db1 *gorm.DB) error { db := &common.DB{DB: db1} - t.setNextCron(config.TransCronInterval) + t.setNextCron(cronReset) + t.Options = dtmcli.MustMarshalString(t.TransOptions) + if t.Options == "{}" { + t.Options = "" + } writeTransLog(t.Gid, "create trans", t.Status, "", t.Data) dbr := db.Must().Clauses(clause.OnConflict{ DoNothing: true, @@ -279,25 +341,3 @@ func TransFromDtmRequest(c *dtmgrpc.DtmRequest) *TransGlobal { Protocol: "grpc", } } - -// TransFromDb construct trans from db -func TransFromDb(db *common.DB, gid string) *TransGlobal { - m := TransGlobal{} - dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) - if dbr.Error == gorm.ErrRecordNotFound { - return nil - } - e2p(dbr.Error) - return &m -} - -func checkLocalhost(branches []TransBranch) { - if config.DisableLocalhost == 0 { - return - } - for _, branch := range branches { - if strings.HasPrefix(branch.URL, "http://localhost") || strings.HasPrefix(branch.URL, "localhost") { - panic(errors.New("url for localhost is disabled. check for your config")) - } - } -} diff --git a/dtmsvr/trans_msg.go b/dtmsvr/trans_msg.go index 0ca83ee..1c4e0d0 100644 --- a/dtmsvr/trans_msg.go +++ b/dtmsvr/trans_msg.go @@ -33,23 +33,26 @@ func (t *transMsgProcessor) GenBranches() []TransBranch { } func (t *TransGlobal) mayQueryPrepared(db *common.DB) { - if t.Status != dtmcli.StatusPrepared { + if !t.needProcess() || t.Status == dtmcli.StatusSubmitted { return } - body := t.getURLResult(t.QueryPrepared, "", "", nil) + body, err := t.getURLResult(t.QueryPrepared, "", "", nil) if strings.Contains(body, dtmcli.ResultSuccess) { t.changeStatus(db, dtmcli.StatusSubmitted) } else if strings.Contains(body, dtmcli.ResultFailure) { t.changeStatus(db, dtmcli.StatusFailed) + } else if strings.Contains(body, dtmcli.ResultOngoing) { + t.touch(db, cronReset) } else { - t.touch(db, t.NextCronInterval*2) + dtmcli.LogRedf("getting result failed for %s. error: %s", t.QueryPrepared, err.Error()) + t.touch(db, cronBackoff) } } -func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { +func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { t.mayQueryPrepared(db) - if t.Status != dtmcli.StatusSubmitted { - return + if !t.needProcess() || t.Status == dtmcli.StatusPrepared { + return nil } current := 0 // 当前正在处理的步骤 for ; current < len(branches); current++ { @@ -57,14 +60,17 @@ func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { if branch.BranchType != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared { continue } - t.execBranch(db, branch) + err := t.execBranch(db, branch) + if err != nil { + return err + } if branch.Status != dtmcli.StatusSucceed { break } } if current == len(branches) { // msg 事务完成 t.changeStatus(db, dtmcli.StatusSucceed) - return + return nil } panic("msg go pass all branch") } diff --git a/dtmsvr/trans_saga.go b/dtmsvr/trans_saga.go index a8c27a5..dc74237 100644 --- a/dtmsvr/trans_saga.go +++ b/dtmsvr/trans_saga.go @@ -2,6 +2,7 @@ package dtmsvr import ( "fmt" + "time" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" @@ -35,37 +36,152 @@ func (t *transSagaProcessor) GenBranches() []TransBranch { return branches } -func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { - if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed { - return +type cSagaCustom struct { + Orders map[int][]int `json:"orders"` + Concurrent bool `json:"concurrent"` +} + +type branchResult struct { + index int + status string + started bool + branchType string +} + +func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + // when saga tasks is fetched, it always need to process + dtmcli.Logf("status: %s timeout: %t", t.Status, t.isTimeout()) + if t.Status == dtmcli.StatusSubmitted && t.isTimeout() { + t.changeStatus(db, dtmcli.StatusAborting) + } + n := len(branches) + + csc := cSagaCustom{Orders: map[int][]int{}} + if t.CustomData != "" { + dtmcli.MustUnmarshalString(t.CustomData, &csc) + } + // resultStats + var rsAToStart, rsAStarted, rsADone, rsAFailed, rsASucceed, rsCToStart, rsCDone, rsCSucceed int + branchResults := make([]branchResult, n) // save the branch result + for i := 0; i < n; i++ { + b := branches[i] + if b.BranchType == dtmcli.BranchAction { + if b.Status == dtmcli.StatusPrepared { + rsAToStart++ + } else if b.Status == dtmcli.StatusFailed { + rsAFailed++ + } + } + branchResults[i] = branchResult{status: branches[i].Status, branchType: branches[i].BranchType} + } + isPreconditionsSucceed := func(current int) bool { + // if !csc.Concurrent,then check the branch in previous step is succeed + if !csc.Concurrent && current >= 2 && branches[current-2].Status != dtmcli.StatusSucceed { + return false + } + // if csc.concurrent, then check the Orders. origin one step correspond to 2 step in dtmsvr + for _, pre := range csc.Orders[current/2] { + if branches[pre*2+1].Status != dtmcli.StatusSucceed { + return false + } + } + return true } - current := 0 // 当前正在处理的步骤 - for ; current < len(branches); current++ { - branch := &branches[current] - if branch.BranchType != dtmcli.BranchAction || branch.Status == dtmcli.StatusSucceed { - continue + + resultChan := make(chan branchResult, n) + asyncExecBranch := func(i int) { + var err error + defer func() { + if x := recover(); x != nil { + err = dtmcli.AsError(x) + } + resultChan <- branchResult{index: i, status: branches[i].Status, branchType: branches[i].BranchType} + if err != nil { + dtmcli.LogRedf("exec branch error: %v", err) + } + }() + err = t.execBranch(db, &branches[i]) + } + pickToRunActions := func() []int { + toRun := []int{} + for current := 0; current < n; current++ { + br := &branchResults[current] + if br.branchType == dtmcli.BranchAction && !br.started && isPreconditionsSucceed(current) && br.status == dtmcli.StatusPrepared { + toRun = append(toRun, current) + } + } + dtmcli.Logf("toRun picked for action is: %v", toRun) + return toRun + } + runBranches := func(toRun []int) { + for _, b := range toRun { + branchResults[b].started = true + if branchResults[b].branchType == dtmcli.BranchAction { + rsAStarted++ + } + go asyncExecBranch(b) + } + } + pickAndRunCompensates := func(toRunActions []int) { + for _, b := range toRunActions { + // these branches may have run. so flag them to status succeed, then run the corresponding compensate + branchResults[b].status = dtmcli.StatusSucceed } - // 找到了一个非succeed的action - if branch.Status == dtmcli.StatusPrepared { - t.execBranch(db, branch) + for i, b := range branchResults { + if b.branchType == dtmcli.BranchCompensate && b.status != dtmcli.StatusSucceed && branchResults[i+1].status != dtmcli.StatusPrepared { + rsCToStart++ + go asyncExecBranch(i) + } } - if branch.Status != dtmcli.StatusSucceed { + } + waitDoneOnce := func() { + select { + case r := <-resultChan: + br := &branchResults[r.index] + br.status = r.status + if r.branchType == dtmcli.BranchAction { + rsADone++ + if r.status == dtmcli.StatusFailed { + rsAFailed++ + } else if r.status == dtmcli.StatusSucceed { + rsASucceed++ + } + } else { + rsCDone++ + if r.status == dtmcli.StatusSucceed { + rsCSucceed++ + } + } + dtmcli.Logf("branch done: %v", r) + case <-time.After(time.Duration(time.Second * 3)): + dtmcli.Logf("wait once for done") + } + } + + for t.Status == dtmcli.StatusSubmitted && !t.isTimeout() && rsAFailed == 0 && rsADone != rsAToStart { + toRun := pickToRunActions() + runBranches(toRun) + if rsADone == rsAStarted { // no branch is running, so break break } + waitDoneOnce() } - if current == len(branches) { // saga 事务完成 + if t.Status == dtmcli.StatusSubmitted && rsAFailed == 0 && rsAToStart == rsASucceed { t.changeStatus(db, dtmcli.StatusSucceed) - return + return nil } - if t.Status != "aborting" && t.Status != dtmcli.StatusFailed { - t.changeStatus(db, "aborting") + if t.Status == dtmcli.StatusSubmitted && (rsAFailed > 0 || t.isTimeout()) { + t.changeStatus(db, dtmcli.StatusAborting) } - for current = current - 1; current >= 0; current-- { - branch := &branches[current] - if branch.BranchType != dtmcli.BranchCompensate || branch.Status != dtmcli.StatusPrepared { - continue + if t.Status == dtmcli.StatusAborting { + toRun := pickToRunActions() + pickAndRunCompensates(toRun) + for rsCDone != rsCToStart { + waitDoneOnce() } - t.execBranch(db, branch) } - t.changeStatus(db, dtmcli.StatusFailed) + if t.Status == dtmcli.StatusAborting && rsCToStart == rsCSucceed { + t.changeStatus(db, dtmcli.StatusFailed) + } + return nil } diff --git a/dtmsvr/trans_tcc.go b/dtmsvr/trans_tcc.go index 43f564a..80be0ee 100644 --- a/dtmsvr/trans_tcc.go +++ b/dtmsvr/trans_tcc.go @@ -17,15 +17,22 @@ func (t *transTccProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { - if t.Status == dtmcli.StatusSucceed || t.Status == dtmcli.StatusFailed { - return +func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + if !t.needProcess() { + return nil + } + if t.Status == dtmcli.StatusPrepared && t.isTimeout() { + t.changeStatus(db, dtmcli.StatusAborting) } branchType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchConfirm, dtmcli.BranchCancel).(string) for current := len(branches) - 1; current >= 0; current-- { if branches[current].BranchType == branchType && branches[current].Status == dtmcli.StatusPrepared { - t.execBranch(db, &branches[current]) + err := t.execBranch(db, &branches[current]) + if err != nil { + return err + } } } t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) + return nil } diff --git a/dtmsvr/trans_xa.go b/dtmsvr/trans_xa.go index e4a79a4..007aa76 100644 --- a/dtmsvr/trans_xa.go +++ b/dtmsvr/trans_xa.go @@ -17,15 +17,22 @@ func (t *transXaProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { - if t.Status == dtmcli.StatusSucceed { - return +func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { + if !t.needProcess() { + return nil + } + if t.Status == dtmcli.StatusPrepared && t.isTimeout() { + t.changeStatus(db, dtmcli.StatusAborting) } currentType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchCommit, dtmcli.BranchRollback).(string) for _, branch := range branches { if branch.BranchType == currentType && branch.Status != dtmcli.StatusSucceed { - t.execBranch(db, &branch) + err := t.execBranch(db, &branch) + if err != nil { + return err + } } } t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) + return nil } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index ee07452..f498a5f 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -2,6 +2,7 @@ package dtmsvr import ( "encoding/hex" + "errors" "fmt" "net" "strings" @@ -10,15 +11,16 @@ import ( "github.com/bwmarrin/snowflake" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" + "gorm.io/gorm" ) // M a short name type M = map[string]interface{} type branchStatus struct { - id uint - status string - finish_time *time.Time + id uint64 + status string + finishTime *time.Time } var p2e = dtmcli.P2E @@ -91,3 +93,25 @@ func getOneHexIP() string { fmt.Printf("err is: %s", err.Error()) return "" // 获取不到IP,则直接返回空 } + +// transFromDb construct trans from db +func transFromDb(db *common.DB, gid string) *TransGlobal { + m := TransGlobal{} + dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) + if dbr.Error == gorm.ErrRecordNotFound { + return nil + } + e2p(dbr.Error) + return &m +} + +func checkLocalhost(branches []TransBranch) { + if config.DisableLocalhost == 0 { + return + } + for _, branch := range branches { + if strings.HasPrefix(branch.URL, "http://localhost") || strings.HasPrefix(branch.URL, "localhost") { + panic(errors.New("url for localhost is disabled. check for your config")) + } + } +} diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index d6f7732..ede2ec6 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -16,7 +16,7 @@ func TestUtils(t *testing.T) { assert.Error(t, err) CronExpiredTrans(1) - sleepCronTime(10) + sleepCronTime() } func TestCheckLocalHost(t *testing.T) { @@ -31,3 +31,15 @@ func TestCheckLocalHost(t *testing.T) { }) assert.Nil(t, err) } + +func TestSetNextCron(t *testing.T) { + tg := TransGlobal{} + tg.RetryInterval = 15 + tg.setNextCron(cronReset) + assert.Equal(t, int64(15), tg.NextCronInterval) + tg.RetryInterval = 0 + tg.setNextCron(cronReset) + assert.Equal(t, config.RetryInterval, tg.NextCronInterval) + tg.setNextCron(cronBackoff) + assert.Equal(t, config.RetryInterval*2, tg.NextCronInterval) +} diff --git a/examples/base_grpc.go b/examples/base_grpc.go index 8b9f0d2..d449015 100644 --- a/examples/base_grpc.go +++ b/examples/base_grpc.go @@ -45,7 +45,7 @@ func handleGrpcBusiness(in *dtmgrpc.BusiRequest, result1 string, result2 string, if res == dtmcli.ResultSuccess { return nil } else if res == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user want to rollback").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } return status.New(codes.Internal, fmt.Sprintf("unknow result %s", res)).Err() } @@ -113,7 +113,7 @@ func (s *busiServer) TransInXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*d dtmcli.MustUnmarshal(in.BusiData, &req) return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { if req.TransInResult == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user return failure").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2) return err @@ -125,7 +125,7 @@ func (s *busiServer) TransOutXa(ctx context.Context, in *dtmgrpc.BusiRequest) (* dtmcli.MustUnmarshal(in.BusiData, &req) return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { if req.TransOutResult == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user return failure").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1) return err diff --git a/examples/base_http.go b/examples/base_http.go index 049119a..8dcebc7 100644 --- a/examples/base_http.go +++ b/examples/base_http.go @@ -2,6 +2,7 @@ package examples import ( "database/sql" + "errors" "fmt" "time" @@ -79,6 +80,9 @@ func handleGeneralBusiness(c *gin.Context, result1 string, result2 string, busi info := infoFromContext(c) res := dtmcli.OrString(result1, result2, dtmcli.ResultSuccess) dtmcli.Logf("%s %s result: %s", busi, info.String(), res) + if res == "ERROR" { + return nil, errors.New("ERROR from user") + } return M{"dtm_result": res}, nil } @@ -145,4 +149,12 @@ func BaseAddRoute(app *gin.Engine) { }) })) + app.POST(BusiAPI+"/TestPanic", common.WrapHandler(func(c *gin.Context) (interface{}, error) { + if c.Query("panic_error") != "" { + panic(errors.New("panic_error")) + } else if c.Query("panic_string") != "" { + panic("panic_string") + } + return "SUCCESS", nil + })) } diff --git a/examples/grpc_saga_barrier.go b/examples/grpc_saga_barrier.go index f48bfce..a1ee504 100644 --- a/examples/grpc_saga_barrier.go +++ b/examples/grpc_saga_barrier.go @@ -24,7 +24,7 @@ func init() { func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { if result == dtmcli.ResultFailure { - return status.New(codes.Aborted, "user rollback").Err() + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() } _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) return err diff --git a/examples/http_msg.go b/examples/http_msg.go index d304754..044f541 100644 --- a/examples/http_msg.go +++ b/examples/http_msg.go @@ -11,7 +11,7 @@ func init() { msg := dtmcli.NewMsg(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/TransOut", req). Add(Busi+"/TransIn", req) - err := msg.Prepare(Busi + "/TransQuery") + err := msg.Prepare(Busi + "/query") dtmcli.FatalIfError(err) dtmcli.Logf("busi trans submit") err = msg.Submit() diff --git a/examples/http_saga.go b/examples/http_saga.go index d014eee..2558586 100644 --- a/examples/http_saga.go +++ b/examples/http_saga.go @@ -23,11 +23,27 @@ func init() { saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/TransOut", Busi+"/TransOutRevert", req). Add(Busi+"/TransIn", Busi+"/TransInRevert", req) - saga.WaitResult = true // 设置为等待结果模式,后面的submit调用,会等待服务器处理这个事务。如果Submit正常返回,那么整个全局事务已成功完成 + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() dtmcli.Logf("result gid is: %s", saga.Gid) dtmcli.FatalIfError(err) return saga.Gid }) - + addSample("concurrent_saga", func() string { + dtmcli.Logf("a concurrent saga busi transaction begin") + req := &TransReq{Amount: 30} + csaga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). + Add(Busi+"/TransOut", Busi+"/TransOutRevert", req). + Add(Busi+"/TransOut", Busi+"/TransOutRevert", req). + Add(Busi+"/TransIn", Busi+"/TransInRevert", req). + Add(Busi+"/TransIn", Busi+"/TransInRevert", req). + EnableConcurrent(). + AddStepOrder(2, []int{0, 1}). + AddStepOrder(3, []int{0, 1}) + dtmcli.Logf("concurrent saga busi trans submit") + err := csaga.Submit() + dtmcli.Logf("result gid is: %s", csaga.Gid) + dtmcli.FatalIfError(err) + return csaga.Gid + }) } diff --git a/test/barrier_tcc_test.go b/test/barrier_tcc_test.go index 8bb656b..8b358a4 100644 --- a/test/barrier_tcc_test.go +++ b/test/barrier_tcc_test.go @@ -102,10 +102,10 @@ func tccBarrierDisorder(t *testing.T) { finishedChan <- "1" }() dtmcli.Logf("cron to timeout and then call cancel") - go CronTransOnce() + go cronTransOnceForwardNow(300) time.Sleep(100 * time.Millisecond) dtmcli.Logf("cron to timeout and then call cancelled twice") - CronTransOnce() + cronTransOnceForwardNow(300) timeoutChan <- "wake" timeoutChan <- "wake" <-finishedChan diff --git a/test/base_test.go b/test/base_test.go new file mode 100644 index 0000000..ba50e94 --- /dev/null +++ b/test/base_test.go @@ -0,0 +1,53 @@ +package test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" + "github.com/yedf/dtm/examples" +) + +func TestSqlDB(t *testing.T) { + asserts := assert.New(t) + db := common.DbGet(config.DB) + barrier := &dtmcli.BranchBarrier{ + TransType: "saga", + Gid: "gid2", + BranchID: "branch_id2", + BranchType: dtmcli.BranchAction, + } + db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") + tx, err := db.ToSQLDB().Begin() + asserts.Nil(err) + err = barrier.Call(tx, func(db dtmcli.DB) error { + dtmcli.Logf("rollback gid2") + return fmt.Errorf("gid2 error") + }) + asserts.Error(err, fmt.Errorf("gid2 error")) + dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{}) + asserts.Equal(dbr.RowsAffected, int64(1)) + dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) + asserts.Equal(dbr.RowsAffected, int64(0)) + barrier.BarrierID = 0 + tx2, err := db.ToSQLDB().Begin() + asserts.Nil(err) + err = barrier.Call(tx2, func(db dtmcli.DB) error { + dtmcli.Logf("submit gid2") + return nil + }) + asserts.Nil(err) + dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) + asserts.Equal(dbr.RowsAffected, int64(1)) +} + +func TestHttp(t *testing.T) { + resp, err := dtmcli.RestyClient.R().SetQueryParam("panic_string", "1").Post(examples.Busi + "/TestPanic") + assert.Nil(t, err) + assert.Contains(t, resp.String(), "panic_string") + resp, err = dtmcli.RestyClient.R().SetQueryParam("panic_error", "1").Post(examples.Busi + "/TestPanic") + assert.Nil(t, err) + assert.Contains(t, resp.String(), "panic_error") +} diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index f5d5524..e185a60 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -1,7 +1,6 @@ package test import ( - "fmt" "testing" "time" @@ -83,43 +82,10 @@ func transQuery(t *testing.T, gid string) { assert.Nil(t, err) } -func TestSqlDB(t *testing.T) { - asserts := assert.New(t) - db := common.DbGet(config.DB) - barrier := &dtmcli.BranchBarrier{ - TransType: "saga", - Gid: "gid2", - BranchID: "branch_id2", - BranchType: dtmcli.BranchAction, - } - db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") - tx, err := db.ToSQLDB().Begin() - asserts.Nil(err) - err = barrier.Call(tx, func(db dtmcli.DB) error { - dtmcli.Logf("rollback gid2") - return fmt.Errorf("gid2 error") - }) - asserts.Error(err, fmt.Errorf("gid2 error")) - dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{}) - asserts.Equal(dbr.RowsAffected, int64(1)) - dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) - asserts.Equal(dbr.RowsAffected, int64(0)) - barrier.BarrierID = 0 - tx2, err := db.ToSQLDB().Begin() - asserts.Nil(err) - err = barrier.Call(tx2, func(db dtmcli.DB) error { - dtmcli.Logf("submit gid2") - return nil - }) - asserts.Nil(err) - dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) - asserts.Equal(dbr.RowsAffected, int64(1)) -} - func TestUpdateBranchAsync(t *testing.T) { common.DtmConfig.UpdateBranchSync = 0 saga := genSaga("gid-update-branch-async", false, false) - saga.WaitResult = true + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) diff --git a/test/grpc_msg_test.go b/test/grpc_msg_test.go index 233aaca..c6a127e 100644 --- a/test/grpc_msg_test.go +++ b/test/grpc_msg_test.go @@ -12,7 +12,7 @@ import ( func TestGrpcMsg(t *testing.T) { grpcMsgNormal(t) - grpcMsgPending(t) + grpcMsgOngoing(t) } func grpcMsgNormal(t *testing.T) { @@ -23,15 +23,15 @@ func grpcMsgNormal(t *testing.T) { assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) } -func grpcMsgPending(t *testing.T) { +func grpcMsgOngoing(t *testing.T) { msg := genGrpcMsg("grpc-msg-pending") err := msg.Prepare(fmt.Sprintf("%s/examples.Busi/CanSubmit", examples.BusiGrpc)) assert.Nil(t, err) - examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.TransInResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) diff --git a/test/grpc_saga_test.go b/test/grpc_saga_test.go index 3bfd57b..393a1bc 100644 --- a/test/grpc_saga_test.go +++ b/test/grpc_saga_test.go @@ -11,7 +11,7 @@ import ( func TestGrpcSaga(t *testing.T) { sagaGrpcNormal(t) - sagaGrpcCommittedPending(t) + sagaGrpcCommittedOngoing(t) sagaGrpcRollback(t) } @@ -24,9 +24,9 @@ func sagaGrpcNormal(t *testing.T) { transQuery(t, saga.Gid) } -func sagaGrpcCommittedPending(t *testing.T) { - saga := genSagaGrpc("gid-committedPendingGrpc", false, false) - examples.MainSwitch.TransOutResult.SetOnce("PENDING") +func sagaGrpcCommittedOngoing(t *testing.T) { + saga := genSagaGrpc("gid-committedOngoingGrpc", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) saga.Submit() WaitTransProcessed(saga.Gid) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) @@ -37,10 +37,10 @@ func sagaGrpcCommittedPending(t *testing.T) { func sagaGrpcRollback(t *testing.T) { saga := genSagaGrpc("gid-rollbackSaga2Grpc", false, true) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) saga.Submit() WaitTransProcessed(saga.Gid) - assert.Equal(t, "aborting", getTransStatus(saga.Gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid)) diff --git a/test/grpc_tcc_test.go b/test/grpc_tcc_test.go index a894cd3..574d0f0 100644 --- a/test/grpc_tcc_test.go +++ b/test/grpc_tcc_test.go @@ -53,13 +53,13 @@ func tccGrpcRollback(t *testing.T) { err := dtmgrpc.TccGlobalTransaction(examples.DtmGrpcServer, gid, func(tcc *dtmgrpc.TccGrpc) error { _, err := tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransOutTcc", examples.BusiGrpc+"/examples.Busi/TransOutConfirm", examples.BusiGrpc+"/examples.Busi/TransOutRevert") assert.Nil(t, err) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) _, err = tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransInTcc", examples.BusiGrpc+"/examples.Busi/TransInConfirm", examples.BusiGrpc+"/examples.Busi/TransInRevert") return err }) assert.Error(t, err) WaitTransProcessed(gid) - assert.Equal(t, "aborting", getTransStatus(gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) } diff --git a/test/main_test.go b/test/main_test.go index 5b472fe..328cb4c 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -14,7 +14,8 @@ import ( func TestMain(m *testing.M) { dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) dtmsvr.TransProcessedTestChan = make(chan string, 1) - dtmsvr.CronForwardDuration = 60 * time.Second + dtmsvr.NowForwardDuration = 0 * time.Second + dtmsvr.CronForwardDuration = 180 * time.Second common.DtmConfig.UpdateBranchSync = 1 dtmsvr.PopulateDB(false) examples.PopulateDB(false) diff --git a/test/msg_test.go b/test/msg_test.go index 9e03679..cac6ecb 100644 --- a/test/msg_test.go +++ b/test/msg_test.go @@ -11,8 +11,8 @@ import ( func TestMsg(t *testing.T) { msgNormal(t) - msgPending(t) - msgPendingFailed(t) + msgOngoing(t) + msgOngoingFailed(t) } func msgNormal(t *testing.T) { @@ -22,32 +22,37 @@ func msgNormal(t *testing.T) { WaitTransProcessed(msg.Gid) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) + CronTransOnce() } -func msgPending(t *testing.T) { +func msgOngoing(t *testing.T) { msg := genMsg("gid-msg-normal-pending") msg.Prepare("") + err := msg.Prepare("") // additional prepare to go conflict key path + assert.Nil(t, err) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.TransInResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid)) CronTransOnce() assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) + err = msg.Prepare("") + assert.Error(t, err) } -func msgPendingFailed(t *testing.T) { +func msgOngoingFailed(t *testing.T) { msg := genMsg("gid-msg-pending-failed") msg.Prepare("") assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) - examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce() + examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) + cronTransOnceForwardNow(180) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultFailure) - CronTransOnce() + cronTransOnceForwardNow(180) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(msg.Gid)) } diff --git a/test/saga_concurrent_test.go b/test/saga_concurrent_test.go new file mode 100644 index 0000000..e71da70 --- /dev/null +++ b/test/saga_concurrent_test.go @@ -0,0 +1,73 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yedf/dtm/dtmcli" + "github.com/yedf/dtm/examples" +) + +func TestCSaga(t *testing.T) { + csagaNormal(t) + csagaRollback(t) + csagaRollback2(t) + csagaCommittedOngoing(t) +} + +func genCSaga(gid string, outFailed bool, inFailed bool) *dtmcli.Saga { + dtmcli.Logf("beginning a concurrent saga test ---------------- %s", gid) + req := examples.GenTransReq(30, outFailed, inFailed) + csaga := dtmcli.NewSaga(examples.DtmServer, gid). + Add(examples.Busi+"/TransOut", examples.Busi+"/TransOutRevert", &req). + Add(examples.Busi+"/TransIn", examples.Busi+"/TransInRevert", &req). + EnableConcurrent() + return csaga +} + +func csagaNormal(t *testing.T) { + csaga := genCSaga("gid-noraml-csaga", false, false) + csaga.Submit() + WaitTransProcessed(csaga.Gid) + assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid)) +} + +func csagaRollback(t *testing.T) { + csaga := genCSaga("gid-rollback-csaga", true, false) + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) + err := csaga.Submit() + assert.Nil(t, err) + WaitTransProcessed(csaga.Gid) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(csaga.Gid)) + CronTransOnce() + assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid)) + assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + err = csaga.Submit() + assert.Error(t, err) +} + +func csagaRollback2(t *testing.T) { + csaga := genCSaga("gid-rollback-csaga2", true, false) + csaga.AddStepOrder(1, []int{0}) + err := csaga.Submit() + assert.Nil(t, err) + WaitTransProcessed(csaga.Gid) + assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid)) + assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(csaga.Gid)) + err = csaga.Submit() + assert.Error(t, err) +} + +func csagaCommittedOngoing(t *testing.T) { + csaga := genCSaga("gid-committed-ongoing-csaga", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) + csaga.Submit() + WaitTransProcessed(csaga.Gid) + assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(csaga.Gid)) + + CronTransOnce() + assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid)) + assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid)) +} diff --git a/test/saga_test.go b/test/saga_test.go index 6627f2b..d4b6eb8 100644 --- a/test/saga_test.go +++ b/test/saga_test.go @@ -10,8 +10,10 @@ import ( func TestSaga(t *testing.T) { sagaNormal(t) - sagaCommittedPending(t) + sagaCommittedOngoing(t) sagaRollback(t) + sagaRollback2(t) + sagaTimeout(t) } func sagaNormal(t *testing.T) { @@ -25,9 +27,9 @@ func sagaNormal(t *testing.T) { assert.Error(t, err) } -func sagaCommittedPending(t *testing.T) { - saga := genSaga("gid-committedPending", false, false) - examples.MainSwitch.TransOutResult.SetOnce("PENDING") +func sagaCommittedOngoing(t *testing.T) { + saga := genSaga("gid-committedOngoing", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) saga.Submit() WaitTransProcessed(saga.Gid) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) @@ -37,12 +39,12 @@ func sagaCommittedPending(t *testing.T) { } func sagaRollback(t *testing.T) { - saga := genSaga("gid-rollbackSaga2", false, true) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + saga := genSaga("gid-rollback-saga", false, true) + examples.MainSwitch.TransOutRevertResult.SetOnce("ERROR") err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) - assert.Equal(t, "aborting", getTransStatus(saga.Gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid)) @@ -50,6 +52,29 @@ func sagaRollback(t *testing.T) { assert.Error(t, err) } +func sagaRollback2(t *testing.T) { + saga := genSaga("gid-rollback-saga2", false, false) + saga.TimeoutToFail = 1800 + examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing) + err := saga.Submit() + assert.Nil(t, err) + WaitTransProcessed(saga.Gid) + cronTransOnceForwardNow(3600) + assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) + assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) +} + +func sagaTimeout(t *testing.T) { + saga := genSaga("gid-timeout-saga", false, false) + saga.TimeoutToFail = 1800 + examples.MainSwitch.TransOutResult.SetOnce("UNKOWN") + saga.Submit() + WaitTransProcessed(saga.Gid) + assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(saga.Gid)) + cronTransOnceForwardNow(3600) + assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) +} + func genSaga(gid string, outFailed bool, inFailed bool) *dtmcli.Saga { dtmcli.Logf("beginning a saga test ---------------- %s", gid) saga := dtmcli.NewSaga(examples.DtmServer, gid) diff --git a/test/tcc_test.go b/test/tcc_test.go index cb98f3b..2b5588d 100644 --- a/test/tcc_test.go +++ b/test/tcc_test.go @@ -32,12 +32,12 @@ func tccRollback(t *testing.T) { err := dtmcli.TccGlobalTransaction(examples.DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { _, rerr := tcc.CallBranch(data, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert") assert.Nil(t, rerr) - examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") + examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) return tcc.CallBranch(data, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") }) assert.Error(t, err) WaitTransProcessed(gid) - assert.Equal(t, "aborting", getTransStatus(gid)) + assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid)) CronTransOnce() assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) } diff --git a/test/types.go b/test/types.go index f069a89..c9b84bb 100644 --- a/test/types.go +++ b/test/types.go @@ -1,6 +1,8 @@ package test import ( + "time" + "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmsvr" @@ -27,3 +29,10 @@ type TransBranch = dtmsvr.TransBranch // M alias type M = dtmcli.M + +func cronTransOnceForwardNow(seconds int) { + old := dtmsvr.NowForwardDuration + dtmsvr.NowForwardDuration = time.Duration(seconds) * time.Second + CronTransOnce() + dtmsvr.NowForwardDuration = old +} diff --git a/test/wait_saga_test.go b/test/wait_saga_test.go index 87c4213..9166c55 100644 --- a/test/wait_saga_test.go +++ b/test/wait_saga_test.go @@ -11,13 +11,13 @@ import ( func TestWaitSaga(t *testing.T) { sagaNormalWait(t) - sagaCommittedPendingWait(t) + sagaCommittedOngoingWait(t) sagaRollbackWait(t) } func sagaNormalWait(t *testing.T) { saga := genSaga("gid-noramlSagaWait", false, false) - saga.WaitResult = true + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) @@ -26,10 +26,10 @@ func sagaNormalWait(t *testing.T) { transQuery(t, saga.Gid) } -func sagaCommittedPendingWait(t *testing.T) { - saga := genSaga("gid-committedPendingWait", false, false) - examples.MainSwitch.TransOutResult.SetOnce("PENDING") - saga.WaitResult = true +func sagaCommittedOngoingWait(t *testing.T) { + saga := genSaga("gid-committedOngoingWait", false, false) + examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Error(t, err) WaitTransProcessed(saga.Gid) @@ -41,7 +41,7 @@ func sagaCommittedPendingWait(t *testing.T) { func sagaRollbackWait(t *testing.T) { saga := genSaga("gid-rollbackSaga2Wait", false, true) - saga.WaitResult = true + saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() assert.Error(t, err) WaitTransProcessed(saga.Gid) diff --git a/test/xa_test.go b/test/xa_test.go index 793039b..8133085 100644 --- a/test/xa_test.go +++ b/test/xa_test.go @@ -16,6 +16,7 @@ func TestXa(t *testing.T) { xaNormal(t) xaDuplicate(t) xaRollback(t) + xaTimeout(t) } func xaLocalError(t *testing.T) { @@ -75,3 +76,19 @@ func xaRollback(t *testing.T) { assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusPrepared}, getBranchesStatus(gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) } + +func xaTimeout(t *testing.T) { + xc := examples.XaClient + gid := "xaTimeout" + timeoutChan := make(chan int, 1) + err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + go func() { + cronTransOnceForwardNow(1) + cronTransOnceForwardNow(300) + timeoutChan <- 0 + }() + _ = <-timeoutChan + return nil, nil + }) + assert.Error(t, err) +}