Browse Source

Merge pull request #44 from yedf/alpha

concurrent saga
pull/46/head
yedf2 4 years ago
committed by GitHub
parent
commit
50becd29fc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      .gitignore
  2. 5
      .vscode/launch.sample.json
  3. 0
      .vscode/settings.sample.json
  4. 32
      common/types.go
  5. 23
      common/types_test.go
  6. 5
      common/utils.go
  7. 6
      conf.sample.yml
  8. 13
      dtmcli/barrier.mysql.sql
  9. 12
      dtmcli/barrier.postgres.sql
  10. 12
      dtmcli/consts.go
  11. 24
      dtmcli/saga.go
  12. 34
      dtmcli/types.go
  13. 18
      dtmcli/utils.go
  14. 7
      dtmcli/utils_test.go
  15. 17
      dtmgrpc/barrier.go
  16. 9
      dtmgrpc/type.go
  17. 33
      dtmsvr/api.go
  18. 4
      dtmsvr/api_grpc.go
  19. 25
      dtmsvr/api_http.go
  20. 25
      dtmsvr/cron.go
  21. 2
      dtmsvr/dtmsvr.go
  22. 8
      dtmsvr/dtmsvr.mysql.sql
  23. 8
      dtmsvr/dtmsvr.postgres.sql
  24. 160
      dtmsvr/trans.go
  25. 22
      dtmsvr/trans_msg.go
  26. 160
      dtmsvr/trans_saga.go
  27. 15
      dtmsvr/trans_tcc.go
  28. 15
      dtmsvr/trans_xa.go
  29. 30
      dtmsvr/utils.go
  30. 14
      dtmsvr/utils_test.go
  31. 6
      examples/base_grpc.go
  32. 12
      examples/base_http.go
  33. 2
      examples/grpc_saga_barrier.go
  34. 2
      examples/http_msg.go
  35. 20
      examples/http_saga.go
  36. 4
      test/barrier_tcc_test.go
  37. 53
      test/base_test.go
  38. 36
      test/dtmsvr_test.go
  39. 12
      test/grpc_msg_test.go
  40. 12
      test/grpc_saga_test.go
  41. 4
      test/grpc_tcc_test.go
  42. 3
      test/main_test.go
  43. 27
      test/msg_test.go
  44. 73
      test/saga_concurrent_test.go
  45. 39
      test/saga_test.go
  46. 4
      test/tcc_test.go
  47. 9
      test/types.go
  48. 14
      test/wait_saga_test.go
  49. 17
      test/xa_test.go

1
.gitignore

@ -5,3 +5,4 @@ conf.yml
main
dist
.idea/**
.vscode/**

5
.vscode/launch.json → .vscode/launch.sample.json

@ -1,5 +1,5 @@
{
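// Use IntelliSense to learn about possible attributes.
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387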
"version": "0.2.0",
@ -10,10 +10,11 @@
"request": "launch",
"mode": "debug",
"program": "${workspaceFolder}/app/main.go",
"cwd": "${workspaceFolder}",
"env": {
// "GIN_MODE": "release"
},
"args": []
"args": ["grpc_saga"]
},
{
"name": "Test",

0
.vscode/settings.json → .vscode/settings.sample.json

32
common/types.go

@ -22,7 +22,7 @@ import (
// ModelBase model base for gorm to provide base fields
type ModelBase struct {
ID uint
ID uint64
CreateTime *time.Time `gorm:"autoCreateTime"`
UpdateTime *time.Time `gorm:"autoUpdateTime"`
}
@ -123,7 +123,9 @@ func DbGet(conf map[string]string) *DB {
}
type dtmConfigType struct {
TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务
TransCronInterval int64 `yaml:"TransCronInterval"`
TimeoutToFail int64 `yaml:"TimeoutToFail"`
RetryInterval int64 `yaml:"RetryInterval"`
DB map[string]string `yaml:"DB"`
DisableLocalhost int64 `yaml:"DisableLocalhost"`
UpdateBranchSync int64 `yaml:"UpdateBranchSync"`
@ -140,7 +142,9 @@ func init() {
if len(os.Args) == 1 {
return
}
DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "10")
DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "3")
DtmConfig.TimeoutToFail = getIntEnv("TIMEOUT_TO_FAIL", "35")
DtmConfig.RetryInterval = getIntEnv("RETRY_INTERVAL", "10")
DtmConfig.DB = map[string]string{
"driver": dtmcli.OrString(os.Getenv("DB_DRIVER"), "mysql"),
"host": os.Getenv("DB_HOST"),
@ -166,6 +170,24 @@ func init() {
err := yaml.Unmarshal(cont, &DtmConfig)
dtmcli.FatalIfError(err)
}
dtmcli.LogIfFatalf(DtmConfig.DB["driver"] == "" || DtmConfig.DB["user"] == "",
"dtm配置错误. 请访问 http://dtm.pub 查看部署运维环节. check you env, and conf.yml/conf.sample.yml in current and parent path: %s. config is: \n%v", MustGetwd(), DtmConfig)
// Validate the loaded configuration; LogIfFatalf is given the condition
// errStr != "" together with a message describing the problem.
errStr := checkConfig()
// The format string has three verbs, in order: the validation error, the
// working directory, and the loaded config. All three must be supplied.
dtmcli.LogIfFatalf(errStr != "",
`config error: '%s'.
check you env, and conf.yml/conf.sample.yml in current and parent path: %s.
please visit http://d.dtm.pub to see the config document.
loaded config is:
%v`, errStr, MustGetwd(), DtmConfig)
}
// checkConfig validates the package-level DtmConfig and returns a short
// description of the first problem found, or "" when every check passes.
// Checks run in order: db driver, db user/host, RetryInterval lower bound,
// then TimeoutToFail against RetryInterval.
func checkConfig() string {
	db := DtmConfig.DB
	switch {
	case db["driver"] == "":
		return "db driver empty"
	case db["user"] == "" || db["host"] == "":
		return "db config not valid"
	case DtmConfig.RetryInterval < 10:
		return "RetryInterval should not be less than 10"
	case DtmConfig.TimeoutToFail < DtmConfig.RetryInterval:
		return "TimeoutToFail should not be less than RetryInterval"
	default:
		return ""
	}
}

23
common/types_test.go

@ -30,3 +30,26 @@ func TestDbAlone(t *testing.T) {
_, err = dtmcli.DBExec(db, "select 1")
assert.NotEqual(t, nil, err)
}
// TestConfig verifies that checkConfig rejects an empty db driver, an empty
// db user, and RetryInterval / TimeoutToFail values of 9.
func TestConfig(t *testing.T) {
	for _, key := range []string{"driver", "user"} {
		testConfigStringField(DtmConfig.DB, key, "", t)
	}
	for _, field := range []*int64{&DtmConfig.RetryInterval, &DtmConfig.TimeoutToFail} {
		testConfigIntField(field, 9, t)
	}
}
// testConfigStringField temporarily stores val under key in m, runs
// checkConfig, puts the previous value back, and asserts that checkConfig
// reported a non-empty error string.
func testConfigStringField(m map[string]string, key string, val string, t *testing.T) {
	saved := m[key]
	m[key] = val
	result := checkConfig()
	m[key] = saved
	assert.NotEqual(t, "", result)
}
// testConfigIntField temporarily writes val through fd, runs checkConfig,
// restores the previous value, and asserts that checkConfig reported a
// non-empty error string.
func testConfigIntField(fd *int64, val int64, t *testing.T) {
	saved := *fd
	*fd = val
	result := checkConfig()
	*fd = saved
	assert.NotEqual(t, "", result)
}

5
common/utils.go

@ -43,7 +43,10 @@ func GetGinApp() *gin.Engine {
// WrapHandler name is clear
func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
return func(c *gin.Context) {
r, err := fn(c)
r, err := func() (r interface{}, rerr error) {
defer dtmcli.P2E(&rerr)
return fn(c)
}()
var b = []byte{}
if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理
b = resp.Body()

6
conf.sample.yml

@ -9,4 +9,8 @@ DB:
# user: 'postgres'
# password: 'mysecretpassword'
# port: '5432'
TransCronInterval: 10 # 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮重试处理,包括prepared中的任务和commited的任务
# the unit of following configurations is second
# TransCronInterval: 3 # the interval to poll unfinished global transaction for every dtm process
# TimeoutToFail: 35 # timeout for XA, TCC to fail. saga's timeout default to infinite, which can be overwritten in saga options
# RetryInterval: 10 # the subtrans branch will be retried after this interval

13
dtmcli/barrier.mysql.sql

@ -1,10 +1,11 @@
create database if not exists dtm_barrier /*!40100 DEFAULT CHARACTER SET utf8mb4 */;
create database if not exists dtm_barrier
/*!40100 DEFAULT CHARACTER SET utf8mb4 */
;
drop table if exists dtm_barrier.barrier;
create table if not exists dtm_barrier.barrier(
id int(11) PRIMARY KEY AUTO_INCREMENT,
trans_type varchar(45) default '' ,
gid varchar(128) default'',
id bigint(22) PRIMARY KEY AUTO_INCREMENT,
trans_type varchar(45) default '',
gid varchar(128) default '',
branch_id varchar(128) default '',
branch_type varchar(45) default '',
barrier_id varchar(45) default '',
@ -14,4 +15,4 @@ create table if not exists dtm_barrier.barrier(
key(create_time),
key(update_time),
UNIQUE key(gid, branch_id, branch_type, barrier_id)
);
);

12
dtmcli/barrier.postgres.sql

@ -1,13 +1,10 @@
create schema if not exists dtm_barrier;
drop table if exists dtm_barrier.barrier;
CREATE SEQUENCE if not EXISTS dtm_barrier.barrier_seq;
create table if not exists dtm_barrier.barrier(
id int NOT NULL DEFAULT NEXTVAL ('dtm_barrier.barrier_seq'),
trans_type varchar(45) default '' ,
gid varchar(128) default'',
id bigint NOT NULL DEFAULT NEXTVAL ('dtm_barrier.barrier_seq'),
trans_type varchar(45) default '',
gid varchar(128) default '',
branch_id varchar(128) default '',
branch_type varchar(45) default '',
barrier_id varchar(45) default '',
@ -16,5 +13,4 @@ create table if not exists dtm_barrier.barrier(
update_time timestamp(0) DEFAULT NULL,
PRIMARY KEY(id),
CONSTRAINT uniq_barrier unique(gid, branch_id, branch_type, barrier_id)
);
);

12
dtmcli/consts.go

@ -1,14 +1,16 @@
package dtmcli
const (
// StatusPrepared status for global trans status. exists only in tran message
// StatusPrepared status for global/branch trans status.
StatusPrepared = "prepared"
// StatusSubmitted StatusSubmitted status for global trans status.
// StatusSubmitted status for global trans status.
StatusSubmitted = "submitted"
// StatusSucceed status for global trans status.
// StatusSucceed status for global/branch trans status.
StatusSucceed = "succeed"
// StatusFailed status for global trans status.
// StatusFailed status for global/branch trans status.
StatusFailed = "failed"
// StatusAborting status for global trans status.
StatusAborting = "aborting"
// BranchTry branch type for TCC
BranchTry = "try"
@ -29,6 +31,8 @@ const (
ResultSuccess = "SUCCESS"
// ResultFailure for result of a trans/trans branch
ResultFailure = "FAILURE"
// ResultOngoing for result of a trans/trans branch
ResultOngoing = "ONGOING"
// DBTypeMysql const for driver mysql
DBTypeMysql = "mysql"

24
dtmcli/saga.go

@ -1,9 +1,13 @@
package dtmcli
import "fmt"
// Saga struct of saga
type Saga struct {
TransBase
Steps []SagaStep `json:"steps"`
Steps []SagaStep `json:"steps"`
orders map[int][]int
concurrent bool
}
// SagaStep one step of saga
@ -15,7 +19,7 @@ type SagaStep struct {
// NewSaga create a saga
func NewSaga(server string, gid string) *Saga {
return &Saga{TransBase: *NewTransBase(gid, "saga", server, "")}
return &Saga{TransBase: *NewTransBase(gid, "saga", server, ""), orders: map[int][]int{}}
}
// Add add a saga step
@ -28,7 +32,23 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga
return s
}
// AddStepOrder records preSteps as the steps that step should run after and
// returns the saga for chaining. The mapping is stored in s.orders (the map
// created by NewSaga) and is sent by Submit when concurrency is enabled.
// It calls PanicIf with the condition step > len(s.Steps).
// Step is expected to be larger than every element of preSteps; that is not
// checked here.
func (s *Saga) AddStepOrder(step int, preSteps []int) *Saga {
	total := len(s.Steps)
	outOfRange := step > total
	PanicIf(outOfRange, fmt.Errorf("step value: %d is invalid. which cannot be larger than total steps: %d", step, total))
	s.orders[step] = preSteps
	return s
}
// EnableConcurrent marks the saga as concurrent and returns it for chaining.
// Submit only attaches the step orders and the concurrent flag to CustomData
// when this flag is set.
func (s *Saga) EnableConcurrent() *Saga {
	s.concurrent = true
	return s
}
// Submit sends the saga to the dtm server with the "submit" operation and
// returns the error reported by callDtm. For a concurrent saga the step
// orders and the concurrent flag are first serialized into CustomData.
func (s *Saga) Submit() error {
	if !s.concurrent {
		return s.callDtm(s, "submit")
	}
	s.CustomData = MustMarshalString(M{"orders": s.orders, "concurrent": s.concurrent})
	return s.callDtm(s, "submit")
}

34
dtmcli/types.go

@ -51,14 +51,26 @@ type TransResult struct {
Message string
}
// TransOptions groups the optional settings of a global transaction.
// It is embedded in TransBase and can be replaced as a whole via SetOptions.
// Every field carries a json name with omitempty and the tag gorm:"-".
type TransOptions struct {
// WaitResult: whether to wait for the final result of the global transaction.
WaitResult bool `json:"wait_result,omitempty" gorm:"-"`
// TimeoutToFail: timeout in seconds; zero means "not set by the caller".
TimeoutToFail int64 `json:"timeout_to_fail,omitempty" gorm:"-"` // for trans type: xa, tcc
// RetryInterval: retry interval in seconds; zero means "not set by the caller".
RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc
}
// TransBase 事务的基础类
type TransBase struct {
Gid string `json:"gid"`
TransType string `json:"trans_type"`
Gid string `json:"gid"`
TransType string `json:"trans_type"`
Dtm string `json:"-"`
CustomData string `json:"custom_data,omitempty"`
IDGenerator
Dtm string
// WaitResult 是否等待全局事务的最终结果
WaitResult bool
TransOptions
}
// SetOptions replaces the transaction's options with a copy of *options.
// A nil options pointer is ignored and the current options are kept, rather
// than panicking on the dereference.
func (tb *TransBase) SetOptions(options *TransOptions) {
	if options == nil {
		return
	}
	tb.TransOptions = *options
}
// NewTransBase 1
@ -78,11 +90,7 @@ func TransBaseFromQuery(qs url.Values) *TransBase {
// callDtm 调用dtm服务器,返回事务的状态
func (tb *TransBase) callDtm(body interface{}, operation string) error {
params := MS{}
if tb.WaitResult {
params["wait_result"] = "1"
}
resp, err := RestyClient.R().SetQueryParams(params).
resp, err := RestyClient.R().
SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation))
if err != nil {
return err
@ -95,10 +103,10 @@ func (tb *TransBase) callDtm(body interface{}, operation string) error {
}
// ErrFailure 表示返回失败,要求回滚
var ErrFailure = errors.New("transaction FAILURE")
var ErrFailure = errors.New("FAILURE")
// ErrPending 表示暂时失败,要求重试
var ErrPending = errors.New("transaction PENDING")
// ErrOngoing 表示暂时失败,要求重试
var ErrOngoing = errors.New("ONGOING")
// MapSuccess 表示返回成功,可以进行下一步
var MapSuccess = M{"dtm_result": ResultSuccess}

18
dtmcli/utils.go

@ -17,14 +17,18 @@ import (
"github.com/go-resty/resty/v2"
)
// AsError converts an arbitrary value (typically one obtained from recover)
// into an error. A value that already implements error is returned as is;
// anything else is formatted with %v into a new error.
func AsError(x interface{}) error {
	switch v := x.(type) {
	case error:
		return v
	default:
		return fmt.Errorf("%v", v)
	}
}
// P2E panic to error
func P2E(perr *error) {
if x := recover(); x != nil {
if e, ok := x.(error); ok {
*perr = e
} else {
panic(x)
}
*perr = AsError(x)
}
}
@ -262,8 +266,8 @@ func CheckResult(res interface{}, err error) error {
str := MustMarshalString(res)
if strings.Contains(str, ResultFailure) {
return ErrFailure
} else if strings.Contains(str, "PENDING") {
return ErrPending
} else if strings.Contains(str, ResultOngoing) {
return ErrOngoing
}
}
return err

7
dtmcli/utils_test.go

@ -25,13 +25,10 @@ func TestEP(t *testing.T) {
})
assert.Equal(t, "err2", err.Error())
err = func() (rerr error) {
defer func() {
x := recover()
assert.Equal(t, 1, x)
}()
defer P2E(&rerr)
panic(1)
panic("raw_string")
}()
assert.Equal(t, "raw_string", err.Error())
}
func TestTernary(t *testing.T) {

17
dtmgrpc/barrier.go

@ -2,8 +2,6 @@ package dtmgrpc
import (
"github.com/yedf/dtm/dtmcli"
"google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// BranchBarrier 子事务屏障
@ -11,21 +9,6 @@ type BranchBarrier struct {
*dtmcli.BranchBarrier
}
// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
// db: 本地数据库
// transInfo: 事务信息
// bisiCall: 业务函数,仅在必要时被调用
// 返回值:
// 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚
// 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行
func (bb *BranchBarrier) Call(tx dtmcli.Tx, busiCall dtmcli.BusiFunc) (rerr error) {
err := bb.BranchBarrier.Call(tx, busiCall)
if err == dtmcli.ErrFailure {
return status.New(codes.Aborted, "user rollback").Err()
}
return err
}
// BarrierFromGrpc 从BusiRequest生成一个Barrier
func BarrierFromGrpc(in *BusiRequest) (*BranchBarrier, error) {
b, err := dtmcli.BarrierFrom(in.Info.TransType, in.Info.Gid, in.Info.BranchID, in.Info.BranchType)

9
dtmgrpc/type.go

@ -9,7 +9,7 @@ import (
"github.com/yedf/dtm/dtmcli"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
)
@ -89,9 +89,10 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c
func Result2Error(res interface{}, err error) error {
e := dtmcli.CheckResult(res, err)
if e == dtmcli.ErrFailure {
return status.New(codes.Aborted, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err()
} else if e == dtmcli.ErrPending {
return status.New(codes.Unavailable, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err()
dtmcli.LogRedf("failure: res: %v, err: %v", res, e)
return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
} else if e == dtmcli.ErrOngoing {
return status.New(codes.Aborted, dtmcli.ResultOngoing).Err()
}
return e
}

33
dtmsvr/api.go

@ -7,47 +7,48 @@ import (
"gorm.io/gorm/clause"
)
func svcSubmit(t *TransGlobal, waitResult bool) (interface{}, error) {
func svcSubmit(t *TransGlobal) (interface{}, error) {
db := dbGet()
t.Status = dtmcli.StatusSubmitted
err := t.saveNew(db)
if err == errUniqueConflict {
dbt := TransFromDb(db, t.Gid)
dbt := transFromDb(db, t.Gid)
if dbt.Status == dtmcli.StatusPrepared {
updates := t.setNextCron(config.TransCronInterval)
updates := t.setNextCron(cronReset)
db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t)
} else if dbt.Status != dtmcli.StatusSubmitted {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot sumbmit", dbt.Status)}, nil
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil
}
}
return t.Process(db, waitResult), nil
return t.Process(db), nil
}
func svcPrepare(t *TransGlobal) (interface{}, error) {
t.Status = dtmcli.StatusPrepared
err := t.saveNew(dbGet())
if err == errUniqueConflict {
dbt := TransFromDb(dbGet(), t.Gid)
dbt := transFromDb(dbGet(), t.Gid)
if dbt.Status != dtmcli.StatusPrepared {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot prepare", dbt.Status)}, nil
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil
}
}
return dtmcli.MapSuccess, nil
}
func svcAbort(t *TransGlobal, waitResult bool) (interface{}, error) {
func svcAbort(t *TransGlobal) (interface{}, error) {
db := dbGet()
dbt := TransFromDb(db, t.Gid)
if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != "aborting" {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: %s current status %s, cannot abort", dbt.TransType, dbt.Status)}, nil
dbt := transFromDb(db, t.Gid)
if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil
}
return dbt.Process(db, waitResult), nil
dbt.changeStatus(db, dtmcli.StatusAborting)
return dbt.Process(db), nil
}
func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, error) {
db := dbGet()
dbt := TransFromDb(db, branch.Gid)
dbt := transFromDb(db, branch.Gid)
if dbt.Status != dtmcli.StatusPrepared {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil
}
@ -62,14 +63,14 @@ func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, err
DoNothing: true,
}).Create(branches)
global := TransGlobal{Gid: branch.Gid}
global.touch(dbGet(), config.TransCronInterval)
global.touch(dbGet(), cronKeep)
return dtmcli.MapSuccess, nil
}
func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) {
branch.Status = dtmcli.StatusPrepared
db := dbGet()
dbt := TransFromDb(db, branch.Gid)
dbt := transFromDb(db, branch.Gid)
if dbt.Status != dtmcli.StatusPrepared {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil
}
@ -80,6 +81,6 @@ func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) {
DoNothing: true,
}).Create(branches)
global := TransGlobal{Gid: branch.Gid}
global.touch(db, config.TransCronInterval)
global.touch(db, cronKeep)
return dtmcli.MapSuccess, nil
}

4
dtmsvr/api_grpc.go

@ -19,7 +19,7 @@ func (s *dtmServer) NewGid(ctx context.Context, in *emptypb.Empty) (*dtmgrpc.Dtm
}
func (s *dtmServer) Submit(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) {
r, err := svcSubmit(TransFromDtmRequest(in), in.WaitResult)
r, err := svcSubmit(TransFromDtmRequest(in))
return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err)
}
@ -29,7 +29,7 @@ func (s *dtmServer) Prepare(ctx context.Context, in *pb.DtmRequest) (*emptypb.Em
}
func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) {
r, err := svcAbort(TransFromDtmRequest(in), in.WaitResult)
r, err := svcAbort(TransFromDtmRequest(in))
return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err)
}

25
dtmsvr/api_http.go

@ -2,22 +2,22 @@ package dtmsvr
import (
"errors"
"math"
"github.com/gin-gonic/gin"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"gorm.io/gorm"
)
func addRoute(engine *gin.Engine) {
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid))
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare))
engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort))
engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch))
engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort))
engine.GET("/api/dtmsvr/query", common.WrapHandler(query))
engine.GET("/api/dtmsvr/all", common.WrapHandler(all))
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid))
}
func newGid(c *gin.Context) (interface{}, error) {
@ -29,11 +29,11 @@ func prepare(c *gin.Context) (interface{}, error) {
}
func submit(c *gin.Context) (interface{}, error) {
return svcSubmit(TransFromContext(c), c.Query("wait_result") == "1")
return svcSubmit(TransFromContext(c))
}
func abort(c *gin.Context) (interface{}, error) {
return svcAbort(TransFromContext(c), c.Query("wait_result") == "1")
return svcAbort(TransFromContext(c))
}
func registerXaBranch(c *gin.Context) (interface{}, error) {
@ -61,24 +61,19 @@ func query(c *gin.Context) (interface{}, error) {
if gid == "" {
return nil, errors.New("no gid specified")
}
trans := TransGlobal{}
db := dbGet()
db.Begin()
dbr := db.Must().Where("gid", gid).First(&trans)
if dbr.Error == gorm.ErrRecordNotFound {
return M{"transaction": nil, "branches": [0]int{}}, nil
}
trans := transFromDb(db, gid)
branches := []TransBranch{}
db.Must().Where("gid", gid).Find(&branches)
return M{"transaction": trans, "branches": branches}, nil
}
func all(c *gin.Context) (interface{}, error) {
lastId := c.Query("last_id")
if lastId == "" {
lastId = "2000000000"
lastID := c.Query("last_id")
lid := math.MaxInt64
if lastID != "" {
lid = dtmcli.MustAtoi(lastID)
}
lid := dtmcli.MustAtoi(lastId)
trans := []TransGlobal{}
dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(100).Find(&trans)
return M{"transactions": trans}, nil

25
dtmsvr/cron.go

@ -2,7 +2,6 @@ package dtmsvr
import (
"fmt"
"math"
"math/rand"
"runtime/debug"
"time"
@ -10,7 +9,10 @@ import (
"github.com/yedf/dtm/dtmcli"
)
// CronForwardDuration will be set in test, cron will fetch trans which expire in CronForwardDuration
// NowForwardDuration will be set in test, trans may be timeout
var NowForwardDuration time.Duration = time.Duration(0)
// CronForwardDuration will be set in test. cron will fetch trans which expire in CronForwardDuration
var CronForwardDuration time.Duration = time.Duration(0)
// CronTransOnce cron expired trans. use expireIn as expire time
@ -24,7 +26,8 @@ func CronTransOnce() (hasTrans bool) {
if TransProcessedTestChan != nil {
defer WaitTransProcessed(trans.Gid)
}
trans.Process(dbGet(), true)
trans.WaitResult = true
trans.Process(dbGet())
return
}
@ -33,7 +36,7 @@ func CronExpiredTrans(num int) {
for i := 0; i < num || num == -1; i++ {
hasTrans := CronTransOnce()
if !hasTrans && num != 1 {
sleepCronTime(0)
sleepCronTime()
}
}
}
@ -44,7 +47,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal {
db := dbGet()
getTime := dtmcli.GetDBSpecial().TimestampAdd
expire := int(expireIn / time.Second)
whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3))
whereTime := fmt.Sprintf("next_cron_time < %s and update_time < %s", getTime(expire), getTime(expire-3))
// 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢
// 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出
dbr := db.Must().Model(&trans).
@ -53,7 +56,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal {
return nil
}
dbr = db.Must().Where("owner=?", owner).Find(&trans)
updates := trans.setNextCron(trans.NextCronInterval * 2) // 下次被cron的间隔加倍
updates := trans.setNextCron(cronKeep)
db.Must().Model(&trans).Select(updates).Updates(&trans)
return &trans
}
@ -67,9 +70,9 @@ func handlePanic(perr *error) {
}
}
func sleepCronTime(milli int) {
delta := math.Min(3, float64(config.TransCronInterval))
interval := time.Duration((float64(config.TransCronInterval) - rand.Float64()*delta) * float64(time.Second))
dtmcli.Logf("sleeping for %v pass in %d milli", interval, milli)
time.Sleep(dtmcli.If(milli == 0, interval, time.Duration(milli*int(time.Millisecond))).(time.Duration))
// sleepCronTime pauses the cron loop before the next poll.
// The normal pause is config.TransCronInterval seconds minus a random
// fraction of a second. When CronForwardDuration is positive (it is set in
// tests) the pause is shortened to one millisecond.
func sleepCronTime() {
	normal := time.Duration((float64(config.TransCronInterval) - rand.Float64()) * float64(time.Second))
	interval := dtmcli.If(CronForwardDuration > 0, 1*time.Millisecond, normal).(time.Duration)
	// Log the pause as a plain millisecond count. Dividing by time.Microsecond
	// and printing the resulting Duration with %v reported a value 1000x too
	// small (a 2.5s sleep was logged as "2.5ms milli").
	dtmcli.Logf("sleeping for %d milli", interval/time.Millisecond)
	time.Sleep(interval)
}

2
dtmsvr/dtmsvr.go

@ -70,7 +70,7 @@ func updateBranchAsync() {
updates = append(updates, TransBranch{
ModelBase: common.ModelBase{ID: updateBranch.id},
Status: updateBranch.status,
FinishTime: updateBranch.finish_time,
FinishTime: updateBranch.finishTime,
})
case <-time.After(checkInterval):
}

8
dtmsvr/dtmsvr.mysql.sql

@ -3,7 +3,7 @@ CREATE DATABASE IF NOT EXISTS dtm
;
drop table IF EXISTS dtm.trans_global;
CREATE TABLE if not EXISTS dtm.trans_global (
`id` int(11) NOT NULL AUTO_INCREMENT,
`id` bigint(22) NOT NULL AUTO_INCREMENT,
`gid` varchar(128) NOT NULL COMMENT '事务全局id',
`trans_type` varchar(45) not null COMMENT '事务类型: saga | xa | tcc | msg',
-- `data` TEXT COMMENT '事务携带的数据', -- 影响性能,不必要存储
@ -15,6 +15,8 @@ CREATE TABLE if not EXISTS dtm.trans_global (
`commit_time` datetime DEFAULT NULL,
`finish_time` datetime DEFAULT NULL,
`rollback_time` datetime DEFAULT NULL,
`options` varchar(256) DEFAULT '',
`custom_data` varchar(256) DEFAULT '',
`next_cron_interval` int(11) default null comment '下次定时处理的间隔',
`next_cron_time` datetime default null comment '下次定时处理的时间',
`owner` varchar(128) not null default '' comment '正在处理全局事务的锁定者',
@ -27,7 +29,7 @@ CREATE TABLE if not EXISTS dtm.trans_global (
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;
drop table IF EXISTS dtm.trans_branch;
CREATE TABLE IF NOT EXISTS dtm.trans_branch (
`id` int(11) NOT NULL AUTO_INCREMENT,
`id` bigint(22) NOT NULL AUTO_INCREMENT,
`gid` varchar(128) NOT NULL COMMENT '事务全局id',
`url` varchar(128) NOT NULL COMMENT '动作关联的url',
`data` TEXT COMMENT '请求所携带的数据',
@ -45,7 +47,7 @@ CREATE TABLE IF NOT EXISTS dtm.trans_branch (
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;
drop table IF EXISTS dtm.trans_log;
CREATE TABLE IF NOT EXISTS dtm.trans_log (
`id` int(11) NOT NULL AUTO_INCREMENT,
`id` bigint(22) NOT NULL AUTO_INCREMENT,
`gid` varchar(128) NOT NULL COMMENT '事务全局id',
`branch_id` varchar(128) DEFAULT NULL COMMENT '事务分支',
`action` varchar(45) DEFAULT NULL COMMENT '行为',

8
dtmsvr/dtmsvr.postgres.sql

@ -5,7 +5,7 @@ drop table IF EXISTS dtm.trans_global;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
CREATE SEQUENCE if not EXISTS dtm.trans_global_seq;
CREATE TABLE if not EXISTS dtm.trans_global (
id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_global_seq'),
id bigint NOT NULL DEFAULT NEXTVAL ('dtm.trans_global_seq'),
gid varchar(128) NOT NULL,
trans_type varchar(45) not null,
status varchar(45) NOT NULL,
@ -16,6 +16,8 @@ CREATE TABLE if not EXISTS dtm.trans_global (
commit_time timestamp(0) DEFAULT NULL,
finish_time timestamp(0) DEFAULT NULL,
rollback_time timestamp(0) DEFAULT NULL,
options varchar(256) DEFAULT '',
custom_data varchar(256) DEFAULT '',
next_cron_interval int default null,
next_cron_time timestamp(0) default null,
owner varchar(128) not null default '',
@ -30,7 +32,7 @@ drop table IF EXISTS dtm.trans_branch;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
CREATE SEQUENCE if not EXISTS dtm.trans_branch_seq;
CREATE TABLE IF NOT EXISTS dtm.trans_branch (
id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_branch_seq'),
id bigint NOT NULL DEFAULT NEXTVAL ('dtm.trans_branch_seq'),
gid varchar(128) NOT NULL,
url varchar(128) NOT NULL,
data TEXT,
@ -50,7 +52,7 @@ drop table IF EXISTS dtm.trans_log;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
CREATE SEQUENCE if not EXISTS dtm.trans_log_seq;
CREATE TABLE IF NOT EXISTS dtm.trans_log (
id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_log_seq'),
id bigint NOT NULL DEFAULT NEXTVAL ('dtm.trans_log_seq'),
gid varchar(128) NOT NULL,
branch_id varchar(128) DEFAULT NULL,
action varchar(45) DEFAULT NULL,

160
dtmsvr/trans.go

@ -32,9 +32,12 @@ type TransGlobal struct {
CommitTime *time.Time
FinishTime *time.Time
RollbackTime *time.Time
Options string
CustomData string `json:"custom_data"`
NextCronInterval int64
NextCronTime *time.Time
processStarted time.Time // record the start time of process
dtmcli.TransOptions
processStarted time.Time // record the start time of process
}
// TableName TableName
@ -44,12 +47,12 @@ func (*TransGlobal) TableName() string {
type transProcessor interface {
GenBranches() []TransBranch
ProcessOnce(db *common.DB, branches []TransBranch)
ProcessOnce(db *common.DB, branches []TransBranch) error
}
func (t *TransGlobal) touch(db *common.DB, interval int64) *gorm.DB {
func (t *TransGlobal) touch(db *common.DB, ctype cronType) *gorm.DB {
writeTransLog(t.Gid, "touch trans", "", "", "")
updates := t.setNextCron(interval)
updates := t.setNextCron(ctype)
return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Select(updates).Updates(t)
}
@ -57,7 +60,7 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB {
writeTransLog(t.Gid, "change status", status, "", "")
old := t.Status
t.Status = status
updates := t.setNextCron(config.TransCronInterval)
updates := t.setNextCron(cronReset)
updates = append(updates, "status")
now := time.Now()
if status == dtmcli.StatusSucceed {
@ -72,6 +75,21 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB {
return dbr
}
// isTimeout reports whether the transaction has existed for at least its
// timeout. The limit is t.TimeoutToFail (seconds); when that is zero and the
// transaction type is not "saga", config.TimeoutToFail is used instead.
// A limit of zero means the transaction never times out.
// NowForwardDuration is added to the elapsed time (it is set in tests).
// NOTE(review): dereferences t.CreateTime unconditionally — assumes it is
// non-nil for every transaction reaching this point; confirm.
func (t *TransGlobal) isTimeout() bool {
	limit := t.TimeoutToFail
	if limit == 0 && t.TransType != "saga" {
		limit = config.TimeoutToFail
	}
	if limit == 0 {
		return false
	}
	elapsed := time.Since(*t.CreateTime) + NowForwardDuration
	return elapsed >= time.Duration(limit)*time.Second
}
// needProcess reports whether the transaction still needs processing:
// always for status submitted or aborting, and for status prepared only
// once the transaction has timed out.
func (t *TransGlobal) needProcess() bool {
	switch t.Status {
	case dtmcli.StatusSubmitted, dtmcli.StatusAborting:
		return true
	case dtmcli.StatusPrepared:
		return t.isTimeout()
	default:
		return false
	}
}
// TransBranch branch transaction
type TransBranch struct {
common.ModelBase
@ -124,14 +142,18 @@ func (t *TransGlobal) getProcessor() transProcessor {
}
// Process process global transaction once
func (t *TransGlobal) Process(db *common.DB, waitResult bool) dtmcli.M {
r := t.process(db, waitResult)
func (t *TransGlobal) Process(db *common.DB) dtmcli.M {
r := t.process(db)
transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess)
return r
}
func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M {
if !waitResult {
func (t *TransGlobal) process(db *common.DB) dtmcli.M {
if t.Options != "" {
dtmcli.MustUnmarshalString(t.Options, &t.TransOptions)
}
if !t.WaitResult {
go t.processInner(db)
return dtmcli.MapSuccess
}
@ -149,6 +171,9 @@ func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M {
func (t *TransGlobal) processInner(db *common.DB) (rerr error) {
defer handlePanic(&rerr)
defer func() {
if rerr != nil {
dtmcli.LogRedf("processInner got error: %s", rerr.Error())
}
if TransProcessedTestChan != nil {
dtmcli.Logf("processed: %s", t.Gid)
TransProcessedTestChan <- t.Gid
@ -156,24 +181,38 @@ func (t *TransGlobal) processInner(db *common.DB) (rerr error) {
}
}()
dtmcli.Logf("processing: %s status: %s", t.Gid, t.Status)
if t.Status == dtmcli.StatusPrepared && t.TransType != "msg" {
t.changeStatus(db, "aborting")
}
branches := []TransBranch{}
db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches)
t.processStarted = time.Now()
t.getProcessor().ProcessOnce(db, branches)
rerr = t.getProcessor().ProcessOnce(db, branches)
return
}
func (t *TransGlobal) setNextCron(expireIn int64) []string {
t.NextCronInterval = expireIn
// cronType selects how setNextCron adjusts a transaction's NextCronInterval.
type cronType int
const (
// cronBackoff doubles the current NextCronInterval.
cronBackoff cronType = iota
// cronReset sets NextCronInterval to the transaction's RetryInterval, or to
// config.RetryInterval when the transaction's RetryInterval is zero.
cronReset
// cronKeep leaves NextCronInterval unchanged.
cronKeep
)
// setNextCron updates t.NextCronInterval according to ctype and schedules
// t.NextCronTime that many seconds from now. It returns the column names
// that the caller should include in its update statement.
//
//   - cronBackoff: the interval is doubled
//   - cronKeep:    the interval is left as is
//   - otherwise:   the interval becomes t.RetryInterval, or
//     config.RetryInterval when t.RetryInterval is zero
func (t *TransGlobal) setNextCron(ctype cronType) []string {
	switch {
	case ctype == cronBackoff:
		t.NextCronInterval *= 2
	case ctype == cronKeep:
		// interval stays unchanged; only the next cron time is refreshed
	case t.RetryInterval != 0:
		t.NextCronInterval = t.RetryInterval
	default:
		t.NextCronInterval = config.RetryInterval
	}
	scheduled := time.Now().Add(time.Duration(t.NextCronInterval) * time.Second)
	t.NextCronTime = &scheduled
	return []string{"next_cron_interval", "next_cron_time"}
}
func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) string {
func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) (string, error) {
if t.Protocol == "grpc" {
dtmcli.PanicIf(strings.HasPrefix(url, "http"), fmt.Errorf("bad url for grpc: %s", url))
server, method := dtmgrpc.GetServerAndMethod(url)
@ -188,11 +227,17 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran
BusiData: branchData,
}, &emptypb.Empty{})
if err == nil {
return dtmcli.ResultSuccess
} else if status.Code(err) == codes.Aborted {
return dtmcli.ResultFailure
return dtmcli.ResultSuccess, nil
}
st, ok := status.FromError(err)
if ok && st.Code() == codes.Aborted {
if st.Message() == dtmcli.ResultOngoing {
return dtmcli.ResultOngoing, nil
} else if st.Message() == dtmcli.ResultFailure {
return dtmcli.ResultFailure, nil
}
}
return err.Error()
return "", err
}
dtmcli.PanicIf(!strings.HasPrefix(url, "http"), fmt.Errorf("bad url for http: %s", url))
resp, err := dtmcli.RestyClient.R().SetBody(string(branchData)).
@ -204,36 +249,53 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran
}).
SetHeader("Content-type", "application/json").
Execute(dtmcli.If(branchData == nil, "GET", "POST").(string), url)
e2p(err)
return resp.String()
if err != nil {
return "", err
}
return resp.String(), nil
}
func (t *TransGlobal) getBranchResult(branch *TransBranch) string {
return t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data))
// getBranchResult calls the branch URL and maps the response body to a branch status.
//
// Returns:
//   - dtmcli.StatusSucceed when the body contains dtmcli.ResultSuccess;
//   - dtmcli.StatusFailed when the transaction type ends with "saga", the branch
//     is an action, and the body contains dtmcli.ResultFailure;
//   - "" with dtmcli.ErrOngoing when the body contains dtmcli.ResultOngoing;
//   - "" with a descriptive error for any other body, or the error from getURLResult.
func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) {
	payload := []byte(branch.Data)
	body, err := t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, payload)
	if err != nil {
		return "", err
	}
	switch {
	case strings.Contains(body, dtmcli.ResultSuccess):
		return dtmcli.StatusSucceed, nil
	case strings.HasSuffix(t.TransType, "saga") && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure):
		// only saga actions may report a business failure
		return dtmcli.StatusFailed, nil
	case strings.Contains(body, dtmcli.ResultOngoing):
		return "", dtmcli.ErrOngoing
	default:
		return "", fmt.Errorf("http result should contains SUCCESS|FAILURE|ONGOING. grpc error should return nil|Aborted with message(FAILURE|ONGOING). \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body)
	}
}
func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) {
body := t.getBranchResult(branch)
status := ""
if strings.Contains(body, dtmcli.ResultSuccess) {
status = dtmcli.StatusSucceed
} else if t.TransType == "saga" && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) {
status = dtmcli.StatusFailed
} else {
panic(fmt.Errorf("http result should contains SUCCESS|FAILURE. grpc error should return nil|Aborted. \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body))
func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) error {
status, err := t.getBranchResult(branch)
if status != "" {
branch.changeStatus(db, status)
}
branchMetrics(t, branch, status == dtmcli.StatusSucceed)
// 如果一次处理超过1500ms,那么touch一下TransGlobal,避免被Cron取出
if time.Since(t.processStarted)+CronForwardDuration >= 1500*time.Millisecond || t.NextCronInterval > config.TransCronInterval {
t.touch(db, config.TransCronInterval)
// if time pass 1500ms and NextCronInterval is not default, then reset NextCronInterval
if err == nil && time.Since(t.processStarted)+NowForwardDuration >= 1500*time.Millisecond ||
t.NextCronInterval > config.RetryInterval && t.NextCronInterval > t.RetryInterval {
t.touch(db, cronReset)
} else if err == dtmcli.ErrOngoing {
t.touch(db, cronKeep)
} else {
t.touch(db, cronBackoff)
}
branch.changeStatus(db, status)
return err
}
func (t *TransGlobal) saveNew(db *common.DB) error {
return db.Transaction(func(db1 *gorm.DB) error {
db := &common.DB{DB: db1}
t.setNextCron(config.TransCronInterval)
t.setNextCron(cronReset)
t.Options = dtmcli.MustMarshalString(t.TransOptions)
if t.Options == "{}" {
t.Options = ""
}
writeTransLog(t.Gid, "create trans", t.Status, "", t.Data)
dbr := db.Must().Clauses(clause.OnConflict{
DoNothing: true,
@ -279,25 +341,3 @@ func TransFromDtmRequest(c *dtmgrpc.DtmRequest) *TransGlobal {
Protocol: "grpc",
}
}
// TransFromDb loads the TransGlobal record whose gid column equals gid.
// It returns nil when the query reports gorm.ErrRecordNotFound; any other
// query error is handed to e2p (presumably turned into a panic — confirm).
func TransFromDb(db *common.DB, gid string) *TransGlobal {
m := TransGlobal{}
dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m)
if dbr.Error == gorm.ErrRecordNotFound {
return nil
}
e2p(dbr.Error)
return &m
}
// checkLocalhost panics if localhost URLs are disabled (config.DisableLocalhost
// is non-zero) and any branch URL starts with "http://localhost" or "localhost".
// When config.DisableLocalhost is 0 no check is performed.
func checkLocalhost(branches []TransBranch) {
if config.DisableLocalhost == 0 {
return
}
for _, branch := range branches {
if strings.HasPrefix(branch.URL, "http://localhost") || strings.HasPrefix(branch.URL, "localhost") {
panic(errors.New("url for localhost is disabled. check for your config"))
}
}
}

22
dtmsvr/trans_msg.go

@ -33,23 +33,26 @@ func (t *transMsgProcessor) GenBranches() []TransBranch {
}
func (t *TransGlobal) mayQueryPrepared(db *common.DB) {
if t.Status != dtmcli.StatusPrepared {
if !t.needProcess() || t.Status == dtmcli.StatusSubmitted {
return
}
body := t.getURLResult(t.QueryPrepared, "", "", nil)
body, err := t.getURLResult(t.QueryPrepared, "", "", nil)
if strings.Contains(body, dtmcli.ResultSuccess) {
t.changeStatus(db, dtmcli.StatusSubmitted)
} else if strings.Contains(body, dtmcli.ResultFailure) {
t.changeStatus(db, dtmcli.StatusFailed)
} else if strings.Contains(body, dtmcli.ResultOngoing) {
t.touch(db, cronReset)
} else {
t.touch(db, t.NextCronInterval*2)
dtmcli.LogRedf("getting result failed for %s. error: %s", t.QueryPrepared, err.Error())
t.touch(db, cronBackoff)
}
}
func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
t.mayQueryPrepared(db)
if t.Status != dtmcli.StatusSubmitted {
return
if !t.needProcess() || t.Status == dtmcli.StatusPrepared {
return nil
}
current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ {
@ -57,14 +60,17 @@ func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if branch.BranchType != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared {
continue
}
t.execBranch(db, branch)
err := t.execBranch(db, branch)
if err != nil {
return err
}
if branch.Status != dtmcli.StatusSucceed {
break
}
}
if current == len(branches) { // msg 事务完成
t.changeStatus(db, dtmcli.StatusSucceed)
return
return nil
}
panic("msg go pass all branch")
}

160
dtmsvr/trans_saga.go

@ -2,6 +2,7 @@ package dtmsvr
import (
"fmt"
"time"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
@ -35,37 +36,152 @@ func (t *transSagaProcessor) GenBranches() []TransBranch {
return branches
}
func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed {
return
// cSagaCustom holds the saga options decoded from the transaction's CustomData JSON.
type cSagaCustom struct {
// Orders maps a step index to the step indexes it is keyed against under "orders".
Orders map[int][]int `json:"orders"`
// Concurrent is decoded from the "concurrent" key.
Concurrent bool `json:"concurrent"`
}
// branchResult is the per-branch bookkeeping record used while processing a saga.
type branchResult struct {
index int // position of the branch in the branches slice
status string // branch status
started bool // whether the branch has been started
branchType string // branch type, e.g. action or compensate
}
func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
// when saga tasks is fetched, it always need to process
dtmcli.Logf("status: %s timeout: %t", t.Status, t.isTimeout())
if t.Status == dtmcli.StatusSubmitted && t.isTimeout() {
t.changeStatus(db, dtmcli.StatusAborting)
}
n := len(branches)
csc := cSagaCustom{Orders: map[int][]int{}}
if t.CustomData != "" {
dtmcli.MustUnmarshalString(t.CustomData, &csc)
}
// resultStats
var rsAToStart, rsAStarted, rsADone, rsAFailed, rsASucceed, rsCToStart, rsCDone, rsCSucceed int
branchResults := make([]branchResult, n) // save the branch result
for i := 0; i < n; i++ {
b := branches[i]
if b.BranchType == dtmcli.BranchAction {
if b.Status == dtmcli.StatusPrepared {
rsAToStart++
} else if b.Status == dtmcli.StatusFailed {
rsAFailed++
}
}
branchResults[i] = branchResult{status: branches[i].Status, branchType: branches[i].BranchType}
}
isPreconditionsSucceed := func(current int) bool {
// if !csc.Concurrent,then check the branch in previous step is succeed
if !csc.Concurrent && current >= 2 && branches[current-2].Status != dtmcli.StatusSucceed {
return false
}
// if csc.concurrent, then check the Orders. origin one step correspond to 2 step in dtmsvr
for _, pre := range csc.Orders[current/2] {
if branches[pre*2+1].Status != dtmcli.StatusSucceed {
return false
}
}
return true
}
current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ {
branch := &branches[current]
if branch.BranchType != dtmcli.BranchAction || branch.Status == dtmcli.StatusSucceed {
continue
resultChan := make(chan branchResult, n)
asyncExecBranch := func(i int) {
var err error
defer func() {
if x := recover(); x != nil {
err = dtmcli.AsError(x)
}
resultChan <- branchResult{index: i, status: branches[i].Status, branchType: branches[i].BranchType}
if err != nil {
dtmcli.LogRedf("exec branch error: %v", err)
}
}()
err = t.execBranch(db, &branches[i])
}
pickToRunActions := func() []int {
toRun := []int{}
for current := 0; current < n; current++ {
br := &branchResults[current]
if br.branchType == dtmcli.BranchAction && !br.started && isPreconditionsSucceed(current) && br.status == dtmcli.StatusPrepared {
toRun = append(toRun, current)
}
}
dtmcli.Logf("toRun picked for action is: %v", toRun)
return toRun
}
runBranches := func(toRun []int) {
for _, b := range toRun {
branchResults[b].started = true
if branchResults[b].branchType == dtmcli.BranchAction {
rsAStarted++
}
go asyncExecBranch(b)
}
}
pickAndRunCompensates := func(toRunActions []int) {
for _, b := range toRunActions {
// these branches may have run. so flag them to status succeed, then run the corresponding compensate
branchResults[b].status = dtmcli.StatusSucceed
}
// 找到了一个非succeed的action
if branch.Status == dtmcli.StatusPrepared {
t.execBranch(db, branch)
for i, b := range branchResults {
if b.branchType == dtmcli.BranchCompensate && b.status != dtmcli.StatusSucceed && branchResults[i+1].status != dtmcli.StatusPrepared {
rsCToStart++
go asyncExecBranch(i)
}
}
if branch.Status != dtmcli.StatusSucceed {
}
waitDoneOnce := func() {
select {
case r := <-resultChan:
br := &branchResults[r.index]
br.status = r.status
if r.branchType == dtmcli.BranchAction {
rsADone++
if r.status == dtmcli.StatusFailed {
rsAFailed++
} else if r.status == dtmcli.StatusSucceed {
rsASucceed++
}
} else {
rsCDone++
if r.status == dtmcli.StatusSucceed {
rsCSucceed++
}
}
dtmcli.Logf("branch done: %v", r)
case <-time.After(time.Duration(time.Second * 3)):
dtmcli.Logf("wait once for done")
}
}
for t.Status == dtmcli.StatusSubmitted && !t.isTimeout() && rsAFailed == 0 && rsADone != rsAToStart {
toRun := pickToRunActions()
runBranches(toRun)
if rsADone == rsAStarted { // no branch is running, so break
break
}
waitDoneOnce()
}
if current == len(branches) { // saga 事务完成
if t.Status == dtmcli.StatusSubmitted && rsAFailed == 0 && rsAToStart == rsASucceed {
t.changeStatus(db, dtmcli.StatusSucceed)
return
return nil
}
if t.Status != "aborting" && t.Status != dtmcli.StatusFailed {
t.changeStatus(db, "aborting")
if t.Status == dtmcli.StatusSubmitted && (rsAFailed > 0 || t.isTimeout()) {
t.changeStatus(db, dtmcli.StatusAborting)
}
for current = current - 1; current >= 0; current-- {
branch := &branches[current]
if branch.BranchType != dtmcli.BranchCompensate || branch.Status != dtmcli.StatusPrepared {
continue
if t.Status == dtmcli.StatusAborting {
toRun := pickToRunActions()
pickAndRunCompensates(toRun)
for rsCDone != rsCToStart {
waitDoneOnce()
}
t.execBranch(db, branch)
}
t.changeStatus(db, dtmcli.StatusFailed)
if t.Status == dtmcli.StatusAborting && rsCToStart == rsCSucceed {
t.changeStatus(db, dtmcli.StatusFailed)
}
return nil
}

15
dtmsvr/trans_tcc.go

@ -17,15 +17,22 @@ func (t *transTccProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == dtmcli.StatusSucceed || t.Status == dtmcli.StatusFailed {
return
func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
if !t.needProcess() {
return nil
}
if t.Status == dtmcli.StatusPrepared && t.isTimeout() {
t.changeStatus(db, dtmcli.StatusAborting)
}
branchType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchConfirm, dtmcli.BranchCancel).(string)
for current := len(branches) - 1; current >= 0; current-- {
if branches[current].BranchType == branchType && branches[current].Status == dtmcli.StatusPrepared {
t.execBranch(db, &branches[current])
err := t.execBranch(db, &branches[current])
if err != nil {
return err
}
}
}
t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string))
return nil
}

15
dtmsvr/trans_xa.go

@ -17,15 +17,22 @@ func (t *transXaProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == dtmcli.StatusSucceed {
return
func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
if !t.needProcess() {
return nil
}
if t.Status == dtmcli.StatusPrepared && t.isTimeout() {
t.changeStatus(db, dtmcli.StatusAborting)
}
currentType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchCommit, dtmcli.BranchRollback).(string)
for _, branch := range branches {
if branch.BranchType == currentType && branch.Status != dtmcli.StatusSucceed {
t.execBranch(db, &branch)
err := t.execBranch(db, &branch)
if err != nil {
return err
}
}
}
t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string))
return nil
}

30
dtmsvr/utils.go

@ -2,6 +2,7 @@ package dtmsvr
import (
"encoding/hex"
"errors"
"fmt"
"net"
"strings"
@ -10,15 +11,16 @@ import (
"github.com/bwmarrin/snowflake"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"gorm.io/gorm"
)
// M a short name
type M = map[string]interface{}
type branchStatus struct {
id uint
status string
finish_time *time.Time
id uint64
status string
finishTime *time.Time
}
var p2e = dtmcli.P2E
@ -91,3 +93,25 @@ func getOneHexIP() string {
fmt.Printf("err is: %s", err.Error())
return "" // 获取不到IP,则直接返回空
}
// transFromDb fetches the global transaction identified by gid from db.
// It returns nil when no matching record exists; every other query error is
// passed to e2p.
func transFromDb(db *common.DB, gid string) *TransGlobal {
	var trans TransGlobal
	err := db.Must().Model(&trans).Where("gid=?", gid).First(&trans).Error
	if err == gorm.ErrRecordNotFound {
		return nil
	}
	e2p(err)
	return &trans
}
// checkLocalhost rejects branches that point at localhost when the
// configuration disables them (config.DisableLocalhost != 0). It panics on
// the first URL that begins with "http://localhost" or "localhost".
func checkLocalhost(branches []TransBranch) {
	if config.DisableLocalhost == 0 {
		return
	}
	for i := range branches {
		target := branches[i].URL
		if !strings.HasPrefix(target, "http://localhost") && !strings.HasPrefix(target, "localhost") {
			continue
		}
		panic(errors.New("url for localhost is disabled. check for your config"))
	}
}

14
dtmsvr/utils_test.go

@ -16,7 +16,7 @@ func TestUtils(t *testing.T) {
assert.Error(t, err)
CronExpiredTrans(1)
sleepCronTime(10)
sleepCronTime()
}
func TestCheckLocalHost(t *testing.T) {
@ -31,3 +31,15 @@ func TestCheckLocalHost(t *testing.T) {
})
assert.Nil(t, err)
}
// TestSetNextCron checks the interval chosen by setNextCron for cronReset
// (with and without a per-transaction RetryInterval) and for cronBackoff.
func TestSetNextCron(t *testing.T) {
	var trans TransGlobal

	// A non-zero per-transaction RetryInterval is used on reset.
	trans.RetryInterval = 15
	trans.setNextCron(cronReset)
	assert.Equal(t, int64(15), trans.NextCronInterval)

	// With a zero RetryInterval, reset uses the configured default.
	trans.RetryInterval = 0
	trans.setNextCron(cronReset)
	base := config.RetryInterval
	assert.Equal(t, base, trans.NextCronInterval)

	// Backoff doubles the current interval.
	trans.setNextCron(cronBackoff)
	assert.Equal(t, base*2, trans.NextCronInterval)
}

6
examples/base_grpc.go

@ -45,7 +45,7 @@ func handleGrpcBusiness(in *dtmgrpc.BusiRequest, result1 string, result2 string,
if res == dtmcli.ResultSuccess {
return nil
} else if res == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user want to rollback").Err()
return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
}
return status.New(codes.Internal, fmt.Sprintf("unknow result %s", res)).Err()
}
@ -113,7 +113,7 @@ func (s *busiServer) TransInXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*d
dtmcli.MustUnmarshal(in.BusiData, &req)
return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error {
if req.TransInResult == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user return failure").Err()
return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
}
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
return err
@ -125,7 +125,7 @@ func (s *busiServer) TransOutXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*
dtmcli.MustUnmarshal(in.BusiData, &req)
return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error {
if req.TransOutResult == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user return failure").Err()
return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
}
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
return err

12
examples/base_http.go

@ -2,6 +2,7 @@ package examples
import (
"database/sql"
"errors"
"fmt"
"time"
@ -79,6 +80,9 @@ func handleGeneralBusiness(c *gin.Context, result1 string, result2 string, busi
info := infoFromContext(c)
res := dtmcli.OrString(result1, result2, dtmcli.ResultSuccess)
dtmcli.Logf("%s %s result: %s", busi, info.String(), res)
if res == "ERROR" {
return nil, errors.New("ERROR from user")
}
return M{"dtm_result": res}, nil
}
@ -145,4 +149,12 @@ func BaseAddRoute(app *gin.Engine) {
})
}))
app.POST(BusiAPI+"/TestPanic", common.WrapHandler(func(c *gin.Context) (interface{}, error) {
if c.Query("panic_error") != "" {
panic(errors.New("panic_error"))
} else if c.Query("panic_string") != "" {
panic("panic_string")
}
return "SUCCESS", nil
}))
}

2
examples/grpc_saga_barrier.go

@ -24,7 +24,7 @@ func init() {
func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error {
if result == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user rollback").Err()
return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
}
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return err

2
examples/http_msg.go

@ -11,7 +11,7 @@ func init() {
msg := dtmcli.NewMsg(DtmServer, dtmcli.MustGenGid(DtmServer)).
Add(Busi+"/TransOut", req).
Add(Busi+"/TransIn", req)
err := msg.Prepare(Busi + "/TransQuery")
err := msg.Prepare(Busi + "/query")
dtmcli.FatalIfError(err)
dtmcli.Logf("busi trans submit")
err = msg.Submit()

20
examples/http_saga.go

@ -23,11 +23,27 @@ func init() {
saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)).
Add(Busi+"/TransOut", Busi+"/TransOutRevert", req).
Add(Busi+"/TransIn", Busi+"/TransInRevert", req)
saga.WaitResult = true // 设置为等待结果模式,后面的submit调用,会等待服务器处理这个事务。如果Submit正常返回,那么整个全局事务已成功完成
saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit()
dtmcli.Logf("result gid is: %s", saga.Gid)
dtmcli.FatalIfError(err)
return saga.Gid
})
addSample("concurrent_saga", func() string {
dtmcli.Logf("a concurrent saga busi transaction begin")
req := &TransReq{Amount: 30}
csaga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)).
Add(Busi+"/TransOut", Busi+"/TransOutRevert", req).
Add(Busi+"/TransOut", Busi+"/TransOutRevert", req).
Add(Busi+"/TransIn", Busi+"/TransInRevert", req).
Add(Busi+"/TransIn", Busi+"/TransInRevert", req).
EnableConcurrent().
AddStepOrder(2, []int{0, 1}).
AddStepOrder(3, []int{0, 1})
dtmcli.Logf("concurrent saga busi trans submit")
err := csaga.Submit()
dtmcli.Logf("result gid is: %s", csaga.Gid)
dtmcli.FatalIfError(err)
return csaga.Gid
})
}

4
test/barrier_tcc_test.go

@ -102,10 +102,10 @@ func tccBarrierDisorder(t *testing.T) {
finishedChan <- "1"
}()
dtmcli.Logf("cron to timeout and then call cancel")
go CronTransOnce()
go cronTransOnceForwardNow(300)
time.Sleep(100 * time.Millisecond)
dtmcli.Logf("cron to timeout and then call cancelled twice")
CronTransOnce()
cronTransOnceForwardNow(300)
timeoutChan <- "wake"
timeoutChan <- "wake"
<-finishedChan

53
test/base_test.go

@ -0,0 +1,53 @@
package test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/examples"
)
// TestSqlDB exercises BranchBarrier.Call against the barrier table:
// a callback that returns an error must leave no gid2 row, and a callback
// that returns nil must leave exactly one.
func TestSqlDB(t *testing.T) {
asserts := assert.New(t)
db := common.DbGet(config.DB)
barrier := &dtmcli.BranchBarrier{
TransType: "saga",
Gid: "gid2",
BranchID: "branch_id2",
BranchType: dtmcli.BranchAction,
}
// Insert an unrelated gid1 row; it is expected to still be there after the failed call below.
db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
tx, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx, func(db dtmcli.DB) error {
dtmcli.Logf("rollback gid2")
return fmt.Errorf("gid2 error")
})
// NOTE(review): Error only asserts err != nil; the second argument is used as the
// failure message, not compared against err.
asserts.Error(err, fmt.Errorf("gid2 error"))
dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
// Reset the barrier id before reusing the same barrier for a second call.
barrier.BarrierID = 0
tx2, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx2, func(db dtmcli.DB) error {
dtmcli.Logf("submit gid2")
return nil
})
asserts.Nil(err)
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
}
func TestHttp(t *testing.T) {
resp, err := dtmcli.RestyClient.R().SetQueryParam("panic_string", "1").Post(examples.Busi + "/TestPanic")
assert.Nil(t, err)
assert.Contains(t, resp.String(), "panic_string")
resp, err = dtmcli.RestyClient.R().SetQueryParam("panic_error", "1").Post(examples.Busi + "/TestPanic")
assert.Nil(t, err)
assert.Contains(t, resp.String(), "panic_error")
}

36
test/dtmsvr_test.go

@ -1,7 +1,6 @@
package test
import (
"fmt"
"testing"
"time"
@ -83,43 +82,10 @@ func transQuery(t *testing.T, gid string) {
assert.Nil(t, err)
}
// TestSqlDB exercises BranchBarrier.Call against the barrier table:
// a callback that returns an error must leave no gid2 row, and a callback
// that returns nil must leave exactly one.
func TestSqlDB(t *testing.T) {
asserts := assert.New(t)
db := common.DbGet(config.DB)
barrier := &dtmcli.BranchBarrier{
TransType: "saga",
Gid: "gid2",
BranchID: "branch_id2",
BranchType: dtmcli.BranchAction,
}
// Insert an unrelated gid1 row; it is expected to still be there after the failed call below.
db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
tx, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx, func(db dtmcli.DB) error {
dtmcli.Logf("rollback gid2")
return fmt.Errorf("gid2 error")
})
// NOTE(review): Error only asserts err != nil; the second argument is used as the
// failure message, not compared against err.
asserts.Error(err, fmt.Errorf("gid2 error"))
dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
// Reset the barrier id before reusing the same barrier for a second call.
barrier.BarrierID = 0
tx2, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx2, func(db dtmcli.DB) error {
dtmcli.Logf("submit gid2")
return nil
})
asserts.Nil(err)
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
}
func TestUpdateBranchAsync(t *testing.T) {
common.DtmConfig.UpdateBranchSync = 0
saga := genSaga("gid-update-branch-async", false, false)
saga.WaitResult = true
saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit()
assert.Nil(t, err)
WaitTransProcessed(saga.Gid)

12
test/grpc_msg_test.go

@ -12,7 +12,7 @@ import (
func TestGrpcMsg(t *testing.T) {
grpcMsgNormal(t)
grpcMsgPending(t)
grpcMsgOngoing(t)
}
func grpcMsgNormal(t *testing.T) {
@ -23,15 +23,15 @@ func grpcMsgNormal(t *testing.T) {
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))
}
func grpcMsgPending(t *testing.T) {
func grpcMsgOngoing(t *testing.T) {
msg := genGrpcMsg("grpc-msg-pending")
err := msg.Prepare(fmt.Sprintf("%s/examples.Busi/CanSubmit", examples.BusiGrpc))
assert.Nil(t, err)
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING")
CronTransOnce()
examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing)
cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.TransInResult.SetOnce("PENDING")
CronTransOnce()
examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing)
cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid))
CronTransOnce()
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))

12
test/grpc_saga_test.go

@ -11,7 +11,7 @@ import (
func TestGrpcSaga(t *testing.T) {
sagaGrpcNormal(t)
sagaGrpcCommittedPending(t)
sagaGrpcCommittedOngoing(t)
sagaGrpcRollback(t)
}
@ -24,9 +24,9 @@ func sagaGrpcNormal(t *testing.T) {
transQuery(t, saga.Gid)
}
func sagaGrpcCommittedPending(t *testing.T) {
saga := genSagaGrpc("gid-committedPendingGrpc", false, false)
examples.MainSwitch.TransOutResult.SetOnce("PENDING")
func sagaGrpcCommittedOngoing(t *testing.T) {
saga := genSagaGrpc("gid-committedOngoingGrpc", false, false)
examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
saga.Submit()
WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid))
@ -37,10 +37,10 @@ func sagaGrpcCommittedPending(t *testing.T) {
func sagaGrpcRollback(t *testing.T) {
saga := genSagaGrpc("gid-rollbackSaga2Grpc", false, true)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING")
examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
saga.Submit()
WaitTransProcessed(saga.Gid)
assert.Equal(t, "aborting", getTransStatus(saga.Gid))
assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid))
CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid))

4
test/grpc_tcc_test.go

@ -53,13 +53,13 @@ func tccGrpcRollback(t *testing.T) {
err := dtmgrpc.TccGlobalTransaction(examples.DtmGrpcServer, gid, func(tcc *dtmgrpc.TccGrpc) error {
_, err := tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransOutTcc", examples.BusiGrpc+"/examples.Busi/TransOutConfirm", examples.BusiGrpc+"/examples.Busi/TransOutRevert")
assert.Nil(t, err)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING")
examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
_, err = tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransInTcc", examples.BusiGrpc+"/examples.Busi/TransInConfirm", examples.BusiGrpc+"/examples.Busi/TransInRevert")
return err
})
assert.Error(t, err)
WaitTransProcessed(gid)
assert.Equal(t, "aborting", getTransStatus(gid))
assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid))
CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid))
}

3
test/main_test.go

@ -14,7 +14,8 @@ import (
func TestMain(m *testing.M) {
dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"])
dtmsvr.TransProcessedTestChan = make(chan string, 1)
dtmsvr.CronForwardDuration = 60 * time.Second
dtmsvr.NowForwardDuration = 0 * time.Second
dtmsvr.CronForwardDuration = 180 * time.Second
common.DtmConfig.UpdateBranchSync = 1
dtmsvr.PopulateDB(false)
examples.PopulateDB(false)

27
test/msg_test.go

@ -11,8 +11,8 @@ import (
func TestMsg(t *testing.T) {
msgNormal(t)
msgPending(t)
msgPendingFailed(t)
msgOngoing(t)
msgOngoingFailed(t)
}
func msgNormal(t *testing.T) {
@ -22,32 +22,37 @@ func msgNormal(t *testing.T) {
WaitTransProcessed(msg.Gid)
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid))
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))
CronTransOnce()
}
func msgPending(t *testing.T) {
func msgOngoing(t *testing.T) {
msg := genMsg("gid-msg-normal-pending")
msg.Prepare("")
err := msg.Prepare("") // additional prepare to go conflict key path
assert.Nil(t, err)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING")
CronTransOnce()
examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing)
cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.TransInResult.SetOnce("PENDING")
CronTransOnce()
examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing)
cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid))
CronTransOnce()
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid))
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))
err = msg.Prepare("")
assert.Error(t, err)
}
func msgPendingFailed(t *testing.T) {
func msgOngoingFailed(t *testing.T) {
msg := genMsg("gid-msg-pending-failed")
msg.Prepare("")
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING")
CronTransOnce()
examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing)
cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultFailure)
CronTransOnce()
cronTransOnceForwardNow(180)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(msg.Gid))
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(msg.Gid))
}

73
test/saga_concurrent_test.go

@ -0,0 +1,73 @@
package test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/examples"
)
// TestCSaga runs the concurrent-saga scenarios one after another, in a fixed order.
func TestCSaga(t *testing.T) {
	scenarios := []func(*testing.T){
		csagaNormal,
		csagaRollback,
		csagaRollback2,
		csagaCommittedOngoing,
	}
	for _, scenario := range scenarios {
		scenario(t)
	}
}
// genCSaga builds a two-step saga (TransOut, then TransIn, each with its
// revert URL) for the given gid and enables concurrent execution on it.
// outFailed and inFailed are forwarded to examples.GenTransReq.
func genCSaga(gid string, outFailed bool, inFailed bool) *dtmcli.Saga {
	dtmcli.Logf("beginning a concurrent saga test ---------------- %s", gid)
	payload := examples.GenTransReq(30, outFailed, inFailed)
	return dtmcli.NewSaga(examples.DtmServer, gid).
		Add(examples.Busi+"/TransOut", examples.Busi+"/TransOutRevert", &payload).
		Add(examples.Busi+"/TransIn", examples.Busi+"/TransInRevert", &payload).
		EnableConcurrent()
}
// csagaNormal submits a concurrent saga with no injected failures and expects
// both action branches and the global transaction to end up succeeded, while
// the compensate branches stay prepared.
func csagaNormal(t *testing.T) {
	saga := genCSaga("gid-noraml-csaga", false, false)
	saga.Submit()
	gid := saga.Gid
	WaitTransProcessed(gid)

	wantBranches := []string{
		dtmcli.StatusPrepared,
		dtmcli.StatusSucceed,
		dtmcli.StatusPrepared,
		dtmcli.StatusSucceed,
	}
	assert.Equal(t, wantBranches, getBranchesStatus(gid))
	assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(gid))
}
// csagaRollback submits a concurrent saga built with outFailed=true while the
// TransOutRevert result is set once to ONGOING. It expects the transaction to
// be aborting after the first pass, failed after one more cron pass, and a
// second Submit of the same saga to return an error.
func csagaRollback(t *testing.T) {
csaga := genCSaga("gid-rollback-csaga", true, false)
// The one-shot ONGOING result is expected to keep the first pass from finishing the rollback.
examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
err := csaga.Submit()
assert.Nil(t, err)
WaitTransProcessed(csaga.Gid)
assert.Equal(t, dtmcli.StatusAborting, getTransStatus(csaga.Gid))
// A further cron pass is expected to complete the rollback.
CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
// Submitting the already-failed transaction again must be rejected.
err = csaga.Submit()
assert.Error(t, err)
}
// csagaRollback2 submits a concurrent saga built with outFailed=true and a
// step order making step 1 depend on step 0. It expects the transaction to be
// failed after the first pass with the second step's branches still prepared,
// and a second Submit to return an error.
func csagaRollback2(t *testing.T) {
csaga := genCSaga("gid-rollback-csaga2", true, false)
// Step 1 may only start after step 0.
csaga.AddStepOrder(1, []int{0})
err := csaga.Submit()
assert.Nil(t, err)
WaitTransProcessed(csaga.Gid)
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(csaga.Gid))
// Submitting the already-failed transaction again must be rejected.
err = csaga.Submit()
assert.Error(t, err)
}
// csagaCommittedOngoing submits a concurrent saga while the TransOut result is
// set once to ONGOING. After the first pass it expects the transaction to be
// still submitted with only the second action succeeded; after one cron pass
// it expects both actions and the transaction to be succeeded.
func csagaCommittedOngoing(t *testing.T) {
csaga := genCSaga("gid-committed-ongoing-csaga", false, false)
examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
csaga.Submit()
WaitTransProcessed(csaga.Gid)
// First pass: the first action is still prepared, the second has succeeded.
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(csaga.Gid))
// Cron pass: the remaining action is expected to complete.
CronTransOnce()
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid))
}

39
test/saga_test.go

@ -10,8 +10,10 @@ import (
func TestSaga(t *testing.T) {
sagaNormal(t)
sagaCommittedPending(t)
sagaCommittedOngoing(t)
sagaRollback(t)
sagaRollback2(t)
sagaTimeout(t)
}
func sagaNormal(t *testing.T) {
@ -25,9 +27,9 @@ func sagaNormal(t *testing.T) {
assert.Error(t, err)
}
func sagaCommittedPending(t *testing.T) {
saga := genSaga("gid-committedPending", false, false)
examples.MainSwitch.TransOutResult.SetOnce("PENDING")
func sagaCommittedOngoing(t *testing.T) {
saga := genSaga("gid-committedOngoing", false, false)
examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
saga.Submit()
WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid))
@ -37,12 +39,12 @@ func sagaCommittedPending(t *testing.T) {
}
func sagaRollback(t *testing.T) {
saga := genSaga("gid-rollbackSaga2", false, true)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING")
saga := genSaga("gid-rollback-saga", false, true)
examples.MainSwitch.TransOutRevertResult.SetOnce("ERROR")
err := saga.Submit()
assert.Nil(t, err)
WaitTransProcessed(saga.Gid)
assert.Equal(t, "aborting", getTransStatus(saga.Gid))
assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid))
CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid))
@ -50,6 +52,29 @@ func sagaRollback(t *testing.T) {
assert.Error(t, err)
}
// sagaRollback2 submits the saga "gid-rollback-saga2" with TimeoutToFail set
// to 1800 while the TransIn action answers with dtmcli.ResultOngoing once.
// It then runs one cron pass with the server clock forwarded by 3600 seconds
// and expects the global transaction to be failed.
func sagaRollback2(t *testing.T) {
	s := genSaga("gid-rollback-saga2", false, false)
	s.TimeoutToFail = 1800
	examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing)

	assert.Nil(t, s.Submit())
	WaitTransProcessed(s.Gid)

	// Forwarded time (3600) exceeds TimeoutToFail (1800).
	cronTransOnceForwardNow(3600)

	wantBranches := []string{
		dtmcli.StatusSucceed,
		dtmcli.StatusSucceed,
		dtmcli.StatusSucceed,
		dtmcli.StatusPrepared,
	}
	assert.Equal(t, dtmcli.StatusFailed, getTransStatus(s.Gid))
	assert.Equal(t, wantBranches, getBranchesStatus(s.Gid))
}
// sagaTimeout submits the saga "gid-timeout-saga" with TimeoutToFail set to
// 1800 while the TransOut action answers once with the literal "UNKOWN".
// The transaction is expected to stay submitted until a cron pass runs with
// the server clock forwarded by 3600 seconds, after which it must be failed.
//
// NOTE(review): "UNKOWN" is presumably meant as a result value the server
// does not recognise — confirm before touching the spelling.
func sagaTimeout(t *testing.T) {
	s := genSaga("gid-timeout-saga", false, false)
	s.TimeoutToFail = 1800
	examples.MainSwitch.TransOutResult.SetOnce("UNKOWN")

	// The submit result is not asserted here; the status checks below
	// carry the verification.
	_ = s.Submit()
	WaitTransProcessed(s.Gid)
	statusBefore := getTransStatus(s.Gid)
	assert.Equal(t, dtmcli.StatusSubmitted, statusBefore)

	cronTransOnceForwardNow(3600)
	statusAfter := getTransStatus(s.Gid)
	assert.Equal(t, dtmcli.StatusFailed, statusAfter)
}
func genSaga(gid string, outFailed bool, inFailed bool) *dtmcli.Saga {
dtmcli.Logf("beginning a saga test ---------------- %s", gid)
saga := dtmcli.NewSaga(examples.DtmServer, gid)

4
test/tcc_test.go

@ -32,12 +32,12 @@ func tccRollback(t *testing.T) {
err := dtmcli.TccGlobalTransaction(examples.DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) {
_, rerr := tcc.CallBranch(data, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert")
assert.Nil(t, rerr)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING")
examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
return tcc.CallBranch(data, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert")
})
assert.Error(t, err)
WaitTransProcessed(gid)
assert.Equal(t, "aborting", getTransStatus(gid))
assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid))
CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid))
}

9
test/types.go

@ -1,6 +1,8 @@
package test
import (
"time"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmsvr"
@ -27,3 +29,10 @@ type TransBranch = dtmsvr.TransBranch
// M alias
type M = dtmcli.M
// cronTransOnceForwardNow runs a single CronTransOnce pass while
// dtmsvr.NowForwardDuration is set to the given number of seconds, then puts
// the previous value back.
//
// The restore happens in a defer so that the package-level setting is reset
// even if CronTransOnce panics; otherwise the forwarded clock would leak into
// every test that runs afterwards.
func cronTransOnceForwardNow(seconds int) {
	old := dtmsvr.NowForwardDuration
	dtmsvr.NowForwardDuration = time.Duration(seconds) * time.Second
	defer func() { dtmsvr.NowForwardDuration = old }()
	CronTransOnce()
}

14
test/wait_saga_test.go

@ -11,13 +11,13 @@ import (
// TestWaitSaga runs the wait-result saga scenarios one after another, all
// reporting through the same *testing.T.
func TestWaitSaga(t *testing.T) {
sagaNormalWait(t)
sagaCommittedPendingWait(t)
sagaCommittedOngoingWait(t)
sagaRollbackWait(t)
}
func sagaNormalWait(t *testing.T) {
saga := genSaga("gid-noramlSagaWait", false, false)
saga.WaitResult = true
saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit()
assert.Nil(t, err)
WaitTransProcessed(saga.Gid)
@ -26,10 +26,10 @@ func sagaNormalWait(t *testing.T) {
transQuery(t, saga.Gid)
}
func sagaCommittedPendingWait(t *testing.T) {
saga := genSaga("gid-committedPendingWait", false, false)
examples.MainSwitch.TransOutResult.SetOnce("PENDING")
saga.WaitResult = true
func sagaCommittedOngoingWait(t *testing.T) {
saga := genSaga("gid-committedOngoingWait", false, false)
examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit()
assert.Error(t, err)
WaitTransProcessed(saga.Gid)
@ -41,7 +41,7 @@ func sagaCommittedPendingWait(t *testing.T) {
func sagaRollbackWait(t *testing.T) {
saga := genSaga("gid-rollbackSaga2Wait", false, true)
saga.WaitResult = true
saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit()
assert.Error(t, err)
WaitTransProcessed(saga.Gid)

17
test/xa_test.go

@ -16,6 +16,7 @@ func TestXa(t *testing.T) {
xaNormal(t)
xaDuplicate(t)
xaRollback(t)
xaTimeout(t)
}
func xaLocalError(t *testing.T) {
@ -75,3 +76,19 @@ func xaRollback(t *testing.T) {
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusPrepared}, getBranchesStatus(gid))
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid))
}
// xaTimeout starts the XA global transaction "xaTimeout" and, from inside its
// callback, runs two cron passes on a separate goroutine with the server clock
// forwarded by 1 and then 300 seconds. The callback blocks until both passes
// are done and returns (nil, nil); XaGlobalTransaction is then expected to
// report an error.
//
// NOTE(review): presumably the 300-second pass moves the transaction past its
// timeout so that the final commit is refused — confirm against dtmsvr.
func xaTimeout(t *testing.T) {
	xc := examples.XaClient
	gid := "xaTimeout"
	// Pure completion signal: no value is carried, so an empty struct channel
	// that is closed by the worker is sufficient.
	done := make(chan struct{})
	err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) {
		go func() {
			defer close(done)
			cronTransOnceForwardNow(1)
			cronTransOnceForwardNow(300)
		}()
		<-done
		return nil, nil
	})
	assert.Error(t, err)
}

Loading…
Cancel
Save