Browse Source

add concurrent saga

pull/44/head
yedf2 4 years ago
parent
commit
a3d8bf4be4
  1. 32
      common/types.go
  2. 23
      common/types_test.go
  3. 5
      common/utils.go
  4. 29
      dtmcli/concurrent_saga.go
  5. 28
      dtmcli/types.go
  6. 18
      dtmcli/utils.go
  7. 7
      dtmcli/utils_test.go
  8. 17
      dtmgrpc/barrier.go
  9. 9
      dtmgrpc/type.go
  10. 22
      dtmsvr/api.go
  11. 4
      dtmsvr/api_grpc.go
  12. 17
      dtmsvr/api_http.go
  13. 12
      dtmsvr/cron.go
  14. 2
      dtmsvr/dtmsvr.go
  15. 143
      dtmsvr/trans.go
  16. 175
      dtmsvr/trans_concurrent_saga.go
  17. 22
      dtmsvr/trans_msg.go
  18. 29
      dtmsvr/trans_saga.go
  19. 12
      dtmsvr/trans_tcc.go
  20. 12
      dtmsvr/trans_xa.go
  21. 6
      dtmsvr/utils.go
  22. 12
      dtmsvr/utils_test.go
  23. 6
      examples/base_grpc.go
  24. 9
      examples/base_http.go
  25. 2
      examples/grpc_saga_barrier.go
  26. 2
      examples/http_msg.go
  27. 2
      examples/http_saga.go
  28. 53
      test/base_test.go
  29. 36
      test/dtmsvr_test.go
  30. 12
      test/grpc_msg_test.go
  31. 12
      test/grpc_saga_test.go
  32. 4
      test/grpc_tcc_test.go
  33. 3
      test/main_test.go
  34. 22
      test/msg_test.go
  35. 72
      test/saga_concurrent_test.go
  36. 12
      test/saga_test.go
  37. 4
      test/tcc_test.go
  38. 9
      test/types.go
  39. 14
      test/wait_saga_test.go

32
common/types.go

@ -22,7 +22,7 @@ import (
// ModelBase model base for gorm to provide base fields // ModelBase model base for gorm to provide base fields
type ModelBase struct { type ModelBase struct {
ID uint ID uint64
CreateTime *time.Time `gorm:"autoCreateTime"` CreateTime *time.Time `gorm:"autoCreateTime"`
UpdateTime *time.Time `gorm:"autoUpdateTime"` UpdateTime *time.Time `gorm:"autoUpdateTime"`
} }
@ -123,7 +123,9 @@ func DbGet(conf map[string]string) *DB {
} }
type dtmConfigType struct { type dtmConfigType struct {
TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务 TransCronInterval int64 `yaml:"TransCronInterval"`
TimeoutToFail int64 `yaml:"TimeoutToFail"`
RetryInterval int64 `yaml:"RetryInterval"`
DB map[string]string `yaml:"DB"` DB map[string]string `yaml:"DB"`
DisableLocalhost int64 `yaml:"DisableLocalhost"` DisableLocalhost int64 `yaml:"DisableLocalhost"`
UpdateBranchSync int64 `yaml:"UpdateBranchSync"` UpdateBranchSync int64 `yaml:"UpdateBranchSync"`
@ -140,7 +142,9 @@ func init() {
if len(os.Args) == 1 { if len(os.Args) == 1 {
return return
} }
DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "10") DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "3")
DtmConfig.TimeoutToFail = getIntEnv("TIMEOUT_TO_FAIL", "10")
DtmConfig.RetryInterval = getIntEnv("RETRY_INTERVAL", "10")
DtmConfig.DB = map[string]string{ DtmConfig.DB = map[string]string{
"driver": dtmcli.OrString(os.Getenv("DB_DRIVER"), "mysql"), "driver": dtmcli.OrString(os.Getenv("DB_DRIVER"), "mysql"),
"host": os.Getenv("DB_HOST"), "host": os.Getenv("DB_HOST"),
@ -166,6 +170,24 @@ func init() {
err := yaml.Unmarshal(cont, &DtmConfig) err := yaml.Unmarshal(cont, &DtmConfig)
dtmcli.FatalIfError(err) dtmcli.FatalIfError(err)
} }
dtmcli.LogIfFatalf(DtmConfig.DB["driver"] == "" || DtmConfig.DB["user"] == "", errStr := checkConfig()
"dtm配置错误. 请访问 http://dtm.pub 查看部署运维环节. check you env, and conf.yml/conf.sample.yml in current and parent path: %s. config is: \n%v", MustGetwd(), DtmConfig) dtmcli.LogIfFatalf(errStr != "",
`config error: '%s'.
check you env, and conf.yml/conf.sample.yml in current and parent path: %s.
please visit http://d.dtm.pub to see the config document.
loaded config is:
%v`, MustGetwd(), DtmConfig)
}
func checkConfig() string {
if DtmConfig.DB["driver"] == "" {
return "db driver empty"
} else if DtmConfig.DB["user"] == "" || DtmConfig.DB["host"] == "" {
return "db config not valid"
} else if DtmConfig.RetryInterval < 10 {
return "RetryInterval should not be less than 10"
} else if DtmConfig.TimeoutToFail < DtmConfig.RetryInterval {
return "TimeoutToFail should not be less than RetryInterval"
}
return ""
} }

23
common/types_test.go

@ -30,3 +30,26 @@ func TestDbAlone(t *testing.T) {
_, err = dtmcli.DBExec(db, "select 1") _, err = dtmcli.DBExec(db, "select 1")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
func TestConfig(t *testing.T) {
testConfigStringField(DtmConfig.DB, "driver", "", t)
testConfigStringField(DtmConfig.DB, "user", "", t)
testConfigIntField(&DtmConfig.RetryInterval, 9, t)
testConfigIntField(&DtmConfig.TimeoutToFail, 9, t)
}
func testConfigStringField(m map[string]string, key string, val string, t *testing.T) {
old := m[key]
m[key] = val
str := checkConfig()
assert.NotEqual(t, "", str)
m[key] = old
}
func testConfigIntField(fd *int64, val int64, t *testing.T) {
old := *fd
*fd = val
str := checkConfig()
assert.NotEqual(t, "", str)
*fd = old
}

5
common/utils.go

@ -43,7 +43,10 @@ func GetGinApp() *gin.Engine {
// WrapHandler name is clear // WrapHandler name is clear
func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
r, err := fn(c) r, err := func() (r interface{}, rerr error) {
defer dtmcli.P2E(&rerr)
return fn(c)
}()
var b = []byte{} var b = []byte{}
if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理 if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理
b = resp.Body() b = resp.Body()

29
dtmcli/concurrent_saga.go

@ -0,0 +1,29 @@
package dtmcli
import "fmt"
// ConcurrentSaga struct of concurrent saga
type ConcurrentSaga struct {
Saga
orders map[int][]int
}
// NewConcurrentSaga create a concurrent saga
func NewConcurrentSaga(server string, gid string) *ConcurrentSaga {
return &ConcurrentSaga{Saga: Saga{TransBase: *NewTransBase(gid, "csaga", server, "")}, orders: map[int][]int{}}
}
// AddStepOrder specify that step should be after preSteps. Step is larger than all the element in preSteps
func (s *ConcurrentSaga) AddStepOrder(step int, preSteps []int) *ConcurrentSaga {
PanicIf(step > len(s.Steps), fmt.Errorf("step value: %d is invalid. which cannot be larger than total steps: %d", step, len(s.Steps)))
s.orders[step] = preSteps
return s
}
// Submit submit the saga trans
func (s *ConcurrentSaga) Submit() error {
if len(s.orders) > 0 {
s.CustomData = MustMarshalString(M{"orders": s.orders})
}
return s.callDtm(s, "submit")
}

28
dtmcli/types.go

@ -51,14 +51,26 @@ type TransResult struct {
Message string Message string
} }
// TransOptions transaction options
type TransOptions struct {
WaitResult bool `json:"wait_result,omitempty" gorm:"-"`
TimeoutToFail int64 `json:"timeout_to_fail,omitempty" gorm:"-"` // for trans type: xa, tcc
RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc
}
// TransBase 事务的基础类 // TransBase 事务的基础类
type TransBase struct { type TransBase struct {
Gid string `json:"gid"` Gid string `json:"gid"`
TransType string `json:"trans_type"` TransType string `json:"trans_type"`
Dtm string `json:"-"`
CustomData string `json:"custom_data,omitempty"`
IDGenerator IDGenerator
Dtm string TransOptions
// WaitResult 是否等待全局事务的最终结果 }
WaitResult bool
// SetOptions set options
func (tb *TransBase) SetOptions(options *TransOptions) {
tb.TransOptions = *options
} }
// NewTransBase 1 // NewTransBase 1
@ -95,10 +107,10 @@ func (tb *TransBase) callDtm(body interface{}, operation string) error {
} }
// ErrFailure 表示返回失败,要求回滚 // ErrFailure 表示返回失败,要求回滚
var ErrFailure = errors.New("transaction FAILURE") var ErrFailure = errors.New("FAILURE")
// ErrPending 表示暂时失败,要求重试 // ErrOngoing 表示暂时失败,要求重试
var ErrPending = errors.New("transaction PENDING") var ErrOngoing = errors.New("ONGOING")
// MapSuccess 表示返回成功,可以进行下一步 // MapSuccess 表示返回成功,可以进行下一步
var MapSuccess = M{"dtm_result": ResultSuccess} var MapSuccess = M{"dtm_result": ResultSuccess}

18
dtmcli/utils.go

@ -17,14 +17,18 @@ import (
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
) )
// AsError wrap a panic value as an error
func AsError(x interface{}) error {
if e, ok := x.(error); ok {
return e
}
return fmt.Errorf("%v", x)
}
// P2E panic to error // P2E panic to error
func P2E(perr *error) { func P2E(perr *error) {
if x := recover(); x != nil { if x := recover(); x != nil {
if e, ok := x.(error); ok { *perr = AsError(x)
*perr = e
} else {
panic(x)
}
} }
} }
@ -262,8 +266,8 @@ func CheckResult(res interface{}, err error) error {
str := MustMarshalString(res) str := MustMarshalString(res)
if strings.Contains(str, ResultFailure) { if strings.Contains(str, ResultFailure) {
return ErrFailure return ErrFailure
} else if strings.Contains(str, "PENDING") { } else if strings.Contains(str, ResultOngoing) {
return ErrPending return ErrOngoing
} }
} }
return err return err

7
dtmcli/utils_test.go

@ -25,13 +25,10 @@ func TestEP(t *testing.T) {
}) })
assert.Equal(t, "err2", err.Error()) assert.Equal(t, "err2", err.Error())
err = func() (rerr error) { err = func() (rerr error) {
defer func() {
x := recover()
assert.Equal(t, 1, x)
}()
defer P2E(&rerr) defer P2E(&rerr)
panic(1) panic("raw_string")
}() }()
assert.Equal(t, "raw_string", err.Error())
} }
func TestTernary(t *testing.T) { func TestTernary(t *testing.T) {

17
dtmgrpc/barrier.go

@ -2,8 +2,6 @@ package dtmgrpc
import ( import (
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
"google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
) )
// BranchBarrier 子事务屏障 // BranchBarrier 子事务屏障
@ -11,21 +9,6 @@ type BranchBarrier struct {
*dtmcli.BranchBarrier *dtmcli.BranchBarrier
} }
// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
// db: 本地数据库
// transInfo: 事务信息
// bisiCall: 业务函数,仅在必要时被调用
// 返回值:
// 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚
// 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行
func (bb *BranchBarrier) Call(tx dtmcli.Tx, busiCall dtmcli.BusiFunc) (rerr error) {
err := bb.BranchBarrier.Call(tx, busiCall)
if err == dtmcli.ErrFailure {
return status.New(codes.Aborted, "user rollback").Err()
}
return err
}
// BarrierFromGrpc 从BusiRequest生成一个Barrier // BarrierFromGrpc 从BusiRequest生成一个Barrier
func BarrierFromGrpc(in *BusiRequest) (*BranchBarrier, error) { func BarrierFromGrpc(in *BusiRequest) (*BranchBarrier, error) {
b, err := dtmcli.BarrierFrom(in.Info.TransType, in.Info.Gid, in.Info.BranchID, in.Info.BranchType) b, err := dtmcli.BarrierFrom(in.Info.TransType, in.Info.Gid, in.Info.BranchID, in.Info.BranchType)

9
dtmgrpc/type.go

@ -9,7 +9,7 @@ import (
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes" codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
) )
@ -89,9 +89,10 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c
func Result2Error(res interface{}, err error) error { func Result2Error(res interface{}, err error) error {
e := dtmcli.CheckResult(res, err) e := dtmcli.CheckResult(res, err)
if e == dtmcli.ErrFailure { if e == dtmcli.ErrFailure {
return status.New(codes.Aborted, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err() dtmcli.LogRedf("failure: res: %v, err: %v", res, e)
} else if e == dtmcli.ErrPending { return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
return status.New(codes.Unavailable, fmt.Sprintf("failure: res: %v, err: %s", res, e.Error())).Err() } else if e == dtmcli.ErrOngoing {
return status.New(codes.Aborted, dtmcli.ResultOngoing).Err()
} }
return e return e
} }

22
dtmsvr/api.go

@ -7,7 +7,7 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
func svcSubmit(t *TransGlobal, waitResult bool) (interface{}, error) { func svcSubmit(t *TransGlobal) (interface{}, error) {
db := dbGet() db := dbGet()
t.Status = dtmcli.StatusSubmitted t.Status = dtmcli.StatusSubmitted
err := t.saveNew(db) err := t.saveNew(db)
@ -15,13 +15,13 @@ func svcSubmit(t *TransGlobal, waitResult bool) (interface{}, error) {
if err == errUniqueConflict { if err == errUniqueConflict {
dbt := TransFromDb(db, t.Gid) dbt := TransFromDb(db, t.Gid)
if dbt.Status == dtmcli.StatusPrepared { if dbt.Status == dtmcli.StatusPrepared {
updates := t.setNextCron(config.TransCronInterval) updates := t.setNextCron(cronReset)
db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t)
} else if dbt.Status != dtmcli.StatusSubmitted { } else if dbt.Status != dtmcli.StatusSubmitted {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot sumbmit", dbt.Status)}, nil return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil
} }
} }
return t.Process(db, waitResult), nil return t.Process(db), nil
} }
func svcPrepare(t *TransGlobal) (interface{}, error) { func svcPrepare(t *TransGlobal) (interface{}, error) {
@ -30,19 +30,19 @@ func svcPrepare(t *TransGlobal) (interface{}, error) {
if err == errUniqueConflict { if err == errUniqueConflict {
dbt := TransFromDb(dbGet(), t.Gid) dbt := TransFromDb(dbGet(), t.Gid)
if dbt.Status != dtmcli.StatusPrepared { if dbt.Status != dtmcli.StatusPrepared {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status %s, cannot prepare", dbt.Status)}, nil return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil
} }
} }
return dtmcli.MapSuccess, nil return dtmcli.MapSuccess, nil
} }
func svcAbort(t *TransGlobal, waitResult bool) (interface{}, error) { func svcAbort(t *TransGlobal) (interface{}, error) {
db := dbGet() db := dbGet()
dbt := TransFromDb(db, t.Gid) dbt := TransFromDb(db, t.Gid)
if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != "aborting" { if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting {
return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: %s current status %s, cannot abort", dbt.TransType, dbt.Status)}, nil return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil
} }
return dbt.Process(db, waitResult), nil return dbt.Process(db), nil
} }
func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, error) { func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, error) {
@ -62,7 +62,7 @@ func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, err
DoNothing: true, DoNothing: true,
}).Create(branches) }).Create(branches)
global := TransGlobal{Gid: branch.Gid} global := TransGlobal{Gid: branch.Gid}
global.touch(dbGet(), config.TransCronInterval) global.touch(dbGet(), cronKeep)
return dtmcli.MapSuccess, nil return dtmcli.MapSuccess, nil
} }
@ -80,6 +80,6 @@ func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) {
DoNothing: true, DoNothing: true,
}).Create(branches) }).Create(branches)
global := TransGlobal{Gid: branch.Gid} global := TransGlobal{Gid: branch.Gid}
global.touch(db, config.TransCronInterval) global.touch(db, cronKeep)
return dtmcli.MapSuccess, nil return dtmcli.MapSuccess, nil
} }

4
dtmsvr/api_grpc.go

@ -19,7 +19,7 @@ func (s *dtmServer) NewGid(ctx context.Context, in *emptypb.Empty) (*dtmgrpc.Dtm
} }
func (s *dtmServer) Submit(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { func (s *dtmServer) Submit(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) {
r, err := svcSubmit(TransFromDtmRequest(in), in.WaitResult) r, err := svcSubmit(TransFromDtmRequest(in))
return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err) return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err)
} }
@ -29,7 +29,7 @@ func (s *dtmServer) Prepare(ctx context.Context, in *pb.DtmRequest) (*emptypb.Em
} }
func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) {
r, err := svcAbort(TransFromDtmRequest(in), in.WaitResult) r, err := svcAbort(TransFromDtmRequest(in))
return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err) return &emptypb.Empty{}, dtmgrpc.Result2Error(r, err)
} }

17
dtmsvr/api_http.go

@ -2,6 +2,7 @@ package dtmsvr
import ( import (
"errors" "errors"
"math"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
@ -10,14 +11,14 @@ import (
) )
func addRoute(engine *gin.Engine) { func addRoute(engine *gin.Engine) {
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid))
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare)) engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare))
engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit)) engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort))
engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch)) engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch))
engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch)) engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort))
engine.GET("/api/dtmsvr/query", common.WrapHandler(query)) engine.GET("/api/dtmsvr/query", common.WrapHandler(query))
engine.GET("/api/dtmsvr/all", common.WrapHandler(all)) engine.GET("/api/dtmsvr/all", common.WrapHandler(all))
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid))
} }
func newGid(c *gin.Context) (interface{}, error) { func newGid(c *gin.Context) (interface{}, error) {
@ -29,11 +30,11 @@ func prepare(c *gin.Context) (interface{}, error) {
} }
func submit(c *gin.Context) (interface{}, error) { func submit(c *gin.Context) (interface{}, error) {
return svcSubmit(TransFromContext(c), c.Query("wait_result") == "1") return svcSubmit(TransFromContext(c))
} }
func abort(c *gin.Context) (interface{}, error) { func abort(c *gin.Context) (interface{}, error) {
return svcAbort(TransFromContext(c), c.Query("wait_result") == "1") return svcAbort(TransFromContext(c))
} }
func registerXaBranch(c *gin.Context) (interface{}, error) { func registerXaBranch(c *gin.Context) (interface{}, error) {
@ -74,11 +75,11 @@ func query(c *gin.Context) (interface{}, error) {
} }
func all(c *gin.Context) (interface{}, error) { func all(c *gin.Context) (interface{}, error) {
lastId := c.Query("last_id") lastID := c.Query("last_id")
if lastId == "" { lid := math.MaxInt64
lastId = "2000000000" if lastID != "" {
lid = dtmcli.MustAtoi(lastID)
} }
lid := dtmcli.MustAtoi(lastId)
trans := []TransGlobal{} trans := []TransGlobal{}
dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(100).Find(&trans) dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(100).Find(&trans)
return M{"transactions": trans}, nil return M{"transactions": trans}, nil

12
dtmsvr/cron.go

@ -10,7 +10,10 @@ import (
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
) )
// CronForwardDuration will be set in test, cron will fetch trans which expire in CronForwardDuration // NowForwardDuration will be set in test, trans may be timeout
var NowForwardDuration time.Duration = time.Duration(0)
// CronForwardDuration will be set in test. cron will fetch trans which expire in CronForwardDuration
var CronForwardDuration time.Duration = time.Duration(0) var CronForwardDuration time.Duration = time.Duration(0)
// CronTransOnce cron expired trans. use expireIn as expire time // CronTransOnce cron expired trans. use expireIn as expire time
@ -24,7 +27,8 @@ func CronTransOnce() (hasTrans bool) {
if TransProcessedTestChan != nil { if TransProcessedTestChan != nil {
defer WaitTransProcessed(trans.Gid) defer WaitTransProcessed(trans.Gid)
} }
trans.Process(dbGet(), true) trans.WaitResult = true
trans.Process(dbGet())
return return
} }
@ -44,7 +48,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal {
db := dbGet() db := dbGet()
getTime := dtmcli.GetDBSpecial().TimestampAdd getTime := dtmcli.GetDBSpecial().TimestampAdd
expire := int(expireIn / time.Second) expire := int(expireIn / time.Second)
whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3)) whereTime := fmt.Sprintf("next_cron_time < %s and update_time < %s", getTime(expire), getTime(expire-3))
// 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢 // 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢
// 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出 // 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出
dbr := db.Must().Model(&trans). dbr := db.Must().Model(&trans).
@ -53,7 +57,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal {
return nil return nil
} }
dbr = db.Must().Where("owner=?", owner).Find(&trans) dbr = db.Must().Where("owner=?", owner).Find(&trans)
updates := trans.setNextCron(trans.NextCronInterval * 2) // 下次被cron的间隔加倍 updates := trans.setNextCron(cronKeep)
db.Must().Model(&trans).Select(updates).Updates(&trans) db.Must().Model(&trans).Select(updates).Updates(&trans)
return &trans return &trans
} }

2
dtmsvr/dtmsvr.go

@ -70,7 +70,7 @@ func updateBranchAsync() {
updates = append(updates, TransBranch{ updates = append(updates, TransBranch{
ModelBase: common.ModelBase{ID: updateBranch.id}, ModelBase: common.ModelBase{ID: updateBranch.id},
Status: updateBranch.status, Status: updateBranch.status,
FinishTime: updateBranch.finish_time, FinishTime: updateBranch.finishTime,
}) })
case <-time.After(checkInterval): case <-time.After(checkInterval):
} }

143
dtmsvr/trans.go

@ -32,9 +32,12 @@ type TransGlobal struct {
CommitTime *time.Time CommitTime *time.Time
FinishTime *time.Time FinishTime *time.Time
RollbackTime *time.Time RollbackTime *time.Time
Options string
CustomData string `json:"custom_data"`
NextCronInterval int64 NextCronInterval int64
NextCronTime *time.Time NextCronTime *time.Time
processStarted time.Time // record the start time of process dtmcli.TransOptions
processStarted time.Time // record the start time of process
} }
// TableName TableName // TableName TableName
@ -44,12 +47,12 @@ func (*TransGlobal) TableName() string {
type transProcessor interface { type transProcessor interface {
GenBranches() []TransBranch GenBranches() []TransBranch
ProcessOnce(db *common.DB, branches []TransBranch) ProcessOnce(db *common.DB, branches []TransBranch) error
} }
func (t *TransGlobal) touch(db *common.DB, interval int64) *gorm.DB { func (t *TransGlobal) touch(db *common.DB, ctype cronType) *gorm.DB {
writeTransLog(t.Gid, "touch trans", "", "", "") writeTransLog(t.Gid, "touch trans", "", "", "")
updates := t.setNextCron(interval) updates := t.setNextCron(ctype)
return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Select(updates).Updates(t) return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Select(updates).Updates(t)
} }
@ -57,7 +60,7 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB {
writeTransLog(t.Gid, "change status", status, "", "") writeTransLog(t.Gid, "change status", status, "", "")
old := t.Status old := t.Status
t.Status = status t.Status = status
updates := t.setNextCron(config.TransCronInterval) updates := t.setNextCron(cronReset)
updates = append(updates, "status") updates = append(updates, "status")
now := time.Now() now := time.Now()
if status == dtmcli.StatusSucceed { if status == dtmcli.StatusSucceed {
@ -72,6 +75,28 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB {
return dbr return dbr
} }
func (t *TransGlobal) isTimeout() bool {
timeout := t.TimeoutToFail
if t.TimeoutToFail == 0 && t.TransType != "saga" && t.TransType != "csaga" {
timeout = config.TimeoutToFail
}
if timeout == 0 {
return false
}
return time.Since(*t.CreateTime)+NowForwardDuration >= time.Duration(timeout)*time.Second
}
func (t *TransGlobal) getRetryInterval() int64 {
if t.RetryInterval > 0 {
return t.RetryInterval
}
return config.RetryInterval
}
func (t *TransGlobal) needProcess() bool {
return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout()
}
// TransBranch branch transaction // TransBranch branch transaction
type TransBranch struct { type TransBranch struct {
common.ModelBase common.ModelBase
@ -124,14 +149,14 @@ func (t *TransGlobal) getProcessor() transProcessor {
} }
// Process process global transaction once // Process process global transaction once
func (t *TransGlobal) Process(db *common.DB, waitResult bool) dtmcli.M { func (t *TransGlobal) Process(db *common.DB) dtmcli.M {
r := t.process(db, waitResult) r := t.process(db)
transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess) transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess)
return r return r
} }
func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M { func (t *TransGlobal) process(db *common.DB) dtmcli.M {
if !waitResult { if !t.WaitResult {
go t.processInner(db) go t.processInner(db)
return dtmcli.MapSuccess return dtmcli.MapSuccess
} }
@ -149,6 +174,9 @@ func (t *TransGlobal) process(db *common.DB, waitResult bool) dtmcli.M {
func (t *TransGlobal) processInner(db *common.DB) (rerr error) { func (t *TransGlobal) processInner(db *common.DB) (rerr error) {
defer handlePanic(&rerr) defer handlePanic(&rerr)
defer func() { defer func() {
if rerr != nil {
dtmcli.LogRedf("processInner got error: %s", rerr.Error())
}
if TransProcessedTestChan != nil { if TransProcessedTestChan != nil {
dtmcli.Logf("processed: %s", t.Gid) dtmcli.Logf("processed: %s", t.Gid)
TransProcessedTestChan <- t.Gid TransProcessedTestChan <- t.Gid
@ -157,23 +185,40 @@ func (t *TransGlobal) processInner(db *common.DB) (rerr error) {
}() }()
dtmcli.Logf("processing: %s status: %s", t.Gid, t.Status) dtmcli.Logf("processing: %s status: %s", t.Gid, t.Status)
if t.Status == dtmcli.StatusPrepared && t.TransType != "msg" { if t.Status == dtmcli.StatusPrepared && t.TransType != "msg" {
t.changeStatus(db, "aborting") t.changeStatus(db, dtmcli.StatusAborting)
} }
branches := []TransBranch{} branches := []TransBranch{}
db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches) db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches)
t.processStarted = time.Now() t.processStarted = time.Now()
t.getProcessor().ProcessOnce(db, branches) rerr = t.getProcessor().ProcessOnce(db, branches)
return return
} }
func (t *TransGlobal) setNextCron(expireIn int64) []string { type cronType int
t.NextCronInterval = expireIn
const (
cronBackoff cronType = iota
cronReset
cronKeep
)
func (t *TransGlobal) setNextCron(ctype cronType) []string {
if ctype == cronBackoff {
t.NextCronInterval = t.NextCronInterval * 2
} else if ctype == cronKeep {
// do nothing
} else if t.RetryInterval != 0 {
t.NextCronInterval = t.RetryInterval
} else {
t.NextCronInterval = config.RetryInterval
}
next := time.Now().Add(time.Duration(t.NextCronInterval) * time.Second) next := time.Now().Add(time.Duration(t.NextCronInterval) * time.Second)
t.NextCronTime = &next t.NextCronTime = &next
return []string{"next_cron_interval", "next_cron_time"} return []string{"next_cron_interval", "next_cron_time"}
} }
func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) string { func (t *TransGlobal) getURLResult(url string, branchID, branchType string, branchData []byte) (string, error) {
if t.Protocol == "grpc" { if t.Protocol == "grpc" {
dtmcli.PanicIf(strings.HasPrefix(url, "http"), fmt.Errorf("bad url for grpc: %s", url)) dtmcli.PanicIf(strings.HasPrefix(url, "http"), fmt.Errorf("bad url for grpc: %s", url))
server, method := dtmgrpc.GetServerAndMethod(url) server, method := dtmgrpc.GetServerAndMethod(url)
@ -188,11 +233,17 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran
BusiData: branchData, BusiData: branchData,
}, &emptypb.Empty{}) }, &emptypb.Empty{})
if err == nil { if err == nil {
return dtmcli.ResultSuccess return dtmcli.ResultSuccess, nil
} else if status.Code(err) == codes.Aborted { }
return dtmcli.ResultFailure st, ok := status.FromError(err)
if ok && st.Code() == codes.Aborted {
if st.Message() == dtmcli.ResultOngoing {
return dtmcli.ResultOngoing, nil
} else if st.Message() == dtmcli.ResultFailure {
return dtmcli.ResultFailure, nil
}
} }
return err.Error() return "", err
} }
dtmcli.PanicIf(!strings.HasPrefix(url, "http"), fmt.Errorf("bad url for http: %s", url)) dtmcli.PanicIf(!strings.HasPrefix(url, "http"), fmt.Errorf("bad url for http: %s", url))
resp, err := dtmcli.RestyClient.R().SetBody(string(branchData)). resp, err := dtmcli.RestyClient.R().SetBody(string(branchData)).
@ -204,36 +255,49 @@ func (t *TransGlobal) getURLResult(url string, branchID, branchType string, bran
}). }).
SetHeader("Content-type", "application/json"). SetHeader("Content-type", "application/json").
Execute(dtmcli.If(branchData == nil, "GET", "POST").(string), url) Execute(dtmcli.If(branchData == nil, "GET", "POST").(string), url)
e2p(err) if err != nil {
return resp.String() return "", err
}
return resp.String(), nil
} }
func (t *TransGlobal) getBranchResult(branch *TransBranch) string { func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) {
return t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data)) body, err := t.getURLResult(branch.URL, branch.BranchID, branch.BranchType, []byte(branch.Data))
if err != nil {
return "", err
}
if strings.Contains(body, dtmcli.ResultSuccess) {
return dtmcli.StatusSucceed, nil
} else if strings.HasSuffix(t.TransType, "saga") && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) {
return dtmcli.StatusFailed, nil
} else if strings.Contains(body, dtmcli.ResultOngoing) {
return "", dtmcli.ErrOngoing
}
return "", fmt.Errorf("http result should contains SUCCESS|FAILURE|ONGOING. grpc error should return nil|Aborted with message(FAILURE|ONGOING). \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body)
} }
func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) { func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) error {
body := t.getBranchResult(branch) status, err := t.getBranchResult(branch)
status := "" if status != "" {
if strings.Contains(body, dtmcli.ResultSuccess) { branch.changeStatus(db, status)
status = dtmcli.StatusSucceed
} else if t.TransType == "saga" && branch.BranchType == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) {
status = dtmcli.StatusFailed
} else {
panic(fmt.Errorf("http result should contains SUCCESS|FAILURE. grpc error should return nil|Aborted. \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body))
} }
branchMetrics(t, branch, status == dtmcli.StatusSucceed) branchMetrics(t, branch, status == dtmcli.StatusSucceed)
// 如果一次处理超过1500ms,那么touch一下TransGlobal,避免被Cron取出 // if time pass 1500ms and NextCronInterval is not default, then reset NextCronInterval
if time.Since(t.processStarted)+CronForwardDuration >= 1500*time.Millisecond || t.NextCronInterval > config.TransCronInterval { if err == nil && time.Since(t.processStarted)+NowForwardDuration >= 1500*time.Millisecond ||
t.touch(db, config.TransCronInterval) t.NextCronInterval > config.RetryInterval && t.NextCronInterval > t.RetryInterval {
t.touch(db, cronReset)
} else if err == dtmcli.ErrOngoing {
t.touch(db, cronKeep)
} else {
t.touch(db, cronBackoff)
} }
branch.changeStatus(db, status) return err
} }
func (t *TransGlobal) saveNew(db *common.DB) error { func (t *TransGlobal) saveNew(db *common.DB) error {
return db.Transaction(func(db1 *gorm.DB) error { return db.Transaction(func(db1 *gorm.DB) error {
db := &common.DB{DB: db1} db := &common.DB{DB: db1}
t.setNextCron(config.TransCronInterval) t.setNextCron(cronReset)
writeTransLog(t.Gid, "create trans", t.Status, "", t.Data) writeTransLog(t.Gid, "create trans", t.Status, "", t.Data)
dbr := db.Must().Clauses(clause.OnConflict{ dbr := db.Must().Clauses(clause.OnConflict{
DoNothing: true, DoNothing: true,
@ -265,6 +329,10 @@ func TransFromContext(c *gin.Context) *TransGlobal {
} }
m := TransGlobal{} m := TransGlobal{}
dtmcli.MustRemarshal(data, &m) dtmcli.MustRemarshal(data, &m)
m.Options = dtmcli.MustMarshalString(m.TransOptions)
if m.Options == "{}" {
m.Options = ""
}
m.Protocol = "http" m.Protocol = "http"
return &m return &m
} }
@ -288,6 +356,9 @@ func TransFromDb(db *common.DB, gid string) *TransGlobal {
return nil return nil
} }
e2p(dbr.Error) e2p(dbr.Error)
if m.Options != "" {
dtmcli.MustUnmarshalString(m.Options, &m.TransOptions)
}
return &m return &m
} }

175
dtmsvr/trans_concurrent_saga.go

@ -0,0 +1,175 @@
package dtmsvr
import (
"time"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"gorm.io/gorm/clause"
)
type transCSagaProcessor struct {
*TransGlobal
}
func init() {
registorProcessorCreator("csaga", func(trans *TransGlobal) transProcessor { return &transCSagaProcessor{TransGlobal: trans} })
}
func (t *transCSagaProcessor) GenBranches() []TransBranch {
return genSagaBranches(t.TransGlobal)
}
type cSagaCustom struct {
Orders map[int][]int `json:"orders"`
}
func isPreconditionsSucceed(branches []TransBranch, pres []int) bool {
for _, pre := range pres {
if branches[pre].Status != dtmcli.StatusSucceed {
return false
}
}
return true
}
type branchResult struct {
index int
status string
started bool
branchType string
}
func (t *transCSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed {
return nil
}
n := len(branches)
orders := map[int][]int{}
if t.CustomData != "" {
csc := cSagaCustom{Orders: map[int][]int{}}
dtmcli.MustUnmarshalString(t.CustomData, &csc)
for k, v := range csc.Orders { // new branches is doubled, so change the order value
orders[2*k+1] = []int{}
for j := 0; j < len(v); j++ {
orders[2*k+1] = append(orders[2*k+1], csc.Orders[k][j]*2+1)
}
}
}
// resultStats
var rsActionToStart, rsActionDone, rsActionFailed, rsActionSucceed, rsCompensateToStart, rsCompensateDone, rsCompensateSucceed int
branchResults := make([]branchResult, n) // save the branch result
for i := 0; i < n; i++ {
b := branches[i]
if b.BranchType == dtmcli.BranchAction {
if b.Status == dtmcli.StatusPrepared || b.Status == dtmcli.StatusDoing {
rsActionToStart++
} else if b.Status == dtmcli.StatusFailed {
rsActionFailed++
}
}
branchResults[i] = branchResult{status: branches[i].Status, branchType: branches[i].BranchType}
}
stopChan := make(chan branchResult, n)
asyncExecBranch := func(i int) {
var err error
defer func() {
if x := recover(); x != nil {
err = dtmcli.AsError(x)
}
stopChan <- branchResult{index: i, status: branches[i].Status, branchType: branches[i].BranchType}
if err != nil {
dtmcli.LogRedf("exec branch error: %v", err)
}
}()
err = t.execBranch(db, &branches[i])
}
needRollback := func(i int) bool {
br := &branchResults[i]
return !br.started && br.branchType == dtmcli.BranchCompensate && br.status != dtmcli.StatusSucceed && branchResults[i+1].branchType == dtmcli.BranchAction && branchResults[i+1].status != dtmcli.StatusPrepared
}
pickAndRun := func(branchType string) {
toRun := []int{}
for current := 0; current < n; current++ {
br := &branchResults[current]
if br.branchType == branchType && branchType == dtmcli.BranchAction {
if (br.status == dtmcli.StatusPrepared || br.status == dtmcli.StatusDoing) &&
!br.started && isPreconditionsSucceed(branches, orders[current]) {
br.status = dtmcli.StatusDoing
toRun = append(toRun, current)
}
} else if br.branchType == branchType && branchType == dtmcli.BranchCompensate {
if needRollback(current) {
toRun = append(toRun, current)
}
}
}
if branchType == dtmcli.BranchAction && len(toRun) > 0 {
updates := make([]TransBranch, len(toRun))
for i, b := range toRun {
updates[i].ID = branches[b].ID
branches[b].Status = dtmcli.StatusDoing
updates[i].Status = dtmcli.StatusDoing
}
dbGet().Must().Clauses(clause.OnConflict{
OnConstraint: "trans_branch_pkey",
DoUpdates: clause.AssignmentColumns([]string{"status"}),
}).Create(updates)
} else if branchType == dtmcli.BranchCompensate {
rsCompensateToStart = len(toRun)
}
for _, b := range toRun {
branchResults[b].started = true
go asyncExecBranch(b)
}
}
processorTimeout := func() bool {
return time.Since(t.processStarted)+NowForwardDuration > time.Duration(t.getRetryInterval()-3)*time.Second
}
waitOnceForDone := func() {
select {
case r := <-stopChan:
br := &branchResults[r.index]
br.status = r.status
if r.branchType == dtmcli.BranchAction {
rsActionDone++
if r.status == dtmcli.StatusFailed {
rsActionFailed++
} else if r.status == dtmcli.StatusSucceed {
rsActionSucceed++
}
} else {
rsCompensateDone++
if r.status == dtmcli.StatusSucceed {
rsCompensateSucceed++
}
}
dtmcli.Logf("branch done: %v", r)
case <-time.After(time.Duration(time.Second * 3)):
dtmcli.Logf("wait once for done")
}
}
for t.Status == dtmcli.StatusSubmitted && !t.isTimeout() && rsActionFailed == 0 && rsActionDone != rsActionToStart && !processorTimeout() {
pickAndRun(dtmcli.BranchAction)
waitOnceForDone()
}
if t.Status == dtmcli.StatusSubmitted && rsActionFailed == 0 && rsActionToStart == rsActionSucceed {
t.changeStatus(db, dtmcli.StatusSucceed)
return nil
}
if t.Status == dtmcli.StatusSubmitted && (rsActionFailed > 0 || t.isTimeout()) {
t.changeStatus(db, dtmcli.StatusAborting)
}
if t.Status == dtmcli.StatusAborting {
pickAndRun(dtmcli.BranchCompensate)
for rsCompensateDone != rsCompensateToStart && !processorTimeout() {
waitOnceForDone()
}
}
if (t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting) && rsActionFailed > 0 && rsCompensateToStart == rsCompensateSucceed {
t.changeStatus(db, dtmcli.StatusFailed)
}
return nil
}

22
dtmsvr/trans_msg.go

@ -33,23 +33,26 @@ func (t *transMsgProcessor) GenBranches() []TransBranch {
} }
func (t *TransGlobal) mayQueryPrepared(db *common.DB) { func (t *TransGlobal) mayQueryPrepared(db *common.DB) {
if t.Status != dtmcli.StatusPrepared { if !t.needProcess() || t.Status == dtmcli.StatusSubmitted {
return return
} }
body := t.getURLResult(t.QueryPrepared, "", "", nil) body, err := t.getURLResult(t.QueryPrepared, "", "", nil)
if strings.Contains(body, dtmcli.ResultSuccess) { if strings.Contains(body, dtmcli.ResultSuccess) {
t.changeStatus(db, dtmcli.StatusSubmitted) t.changeStatus(db, dtmcli.StatusSubmitted)
} else if strings.Contains(body, dtmcli.ResultFailure) { } else if strings.Contains(body, dtmcli.ResultFailure) {
t.changeStatus(db, dtmcli.StatusFailed) t.changeStatus(db, dtmcli.StatusFailed)
} else if strings.Contains(body, dtmcli.ResultOngoing) {
t.touch(db, cronReset)
} else { } else {
t.touch(db, t.NextCronInterval*2) dtmcli.LogRedf("getting result failed for %s. error: %s", t.QueryPrepared, err.Error())
t.touch(db, cronBackoff)
} }
} }
func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
t.mayQueryPrepared(db) t.mayQueryPrepared(db)
if t.Status != dtmcli.StatusSubmitted { if !t.needProcess() || t.Status == dtmcli.StatusPrepared {
return return nil
} }
current := 0 // 当前正在处理的步骤 current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ { for ; current < len(branches); current++ {
@ -57,14 +60,17 @@ func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if branch.BranchType != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared { if branch.BranchType != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared {
continue continue
} }
t.execBranch(db, branch) err := t.execBranch(db, branch)
if err != nil {
return err
}
if branch.Status != dtmcli.StatusSucceed { if branch.Status != dtmcli.StatusSucceed {
break break
} }
} }
if current == len(branches) { // msg 事务完成 if current == len(branches) { // msg 事务完成
t.changeStatus(db, dtmcli.StatusSucceed) t.changeStatus(db, dtmcli.StatusSucceed)
return return nil
} }
panic("msg go pass all branch") panic("msg go pass all branch")
} }

29
dtmsvr/trans_saga.go

@ -15,7 +15,7 @@ func init() {
registorProcessorCreator("saga", func(trans *TransGlobal) transProcessor { return &transSagaProcessor{TransGlobal: trans} }) registorProcessorCreator("saga", func(trans *TransGlobal) transProcessor { return &transSagaProcessor{TransGlobal: trans} })
} }
func (t *transSagaProcessor) GenBranches() []TransBranch { func genSagaBranches(t *TransGlobal) []TransBranch {
branches := []TransBranch{} branches := []TransBranch{}
steps := []M{} steps := []M{}
dtmcli.MustUnmarshalString(t.Data, &steps) dtmcli.MustUnmarshalString(t.Data, &steps)
@ -35,9 +35,13 @@ func (t *transSagaProcessor) GenBranches() []TransBranch {
return branches return branches
} }
func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transSagaProcessor) GenBranches() []TransBranch {
if t.Status == dtmcli.StatusFailed || t.Status == dtmcli.StatusSucceed { return genSagaBranches(t.TransGlobal)
return }
func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
if !t.needProcess() {
return nil
} }
current := 0 // 当前正在处理的步骤 current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ { for ; current < len(branches); current++ {
@ -47,7 +51,10 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch)
} }
// 找到了一个非succeed的action // 找到了一个非succeed的action
if branch.Status == dtmcli.StatusPrepared { if branch.Status == dtmcli.StatusPrepared {
t.execBranch(db, branch) err := t.execBranch(db, branch)
if err != nil {
return err
}
} }
if branch.Status != dtmcli.StatusSucceed { if branch.Status != dtmcli.StatusSucceed {
break break
@ -55,17 +62,21 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch)
} }
if current == len(branches) { // saga 事务完成 if current == len(branches) { // saga 事务完成
t.changeStatus(db, dtmcli.StatusSucceed) t.changeStatus(db, dtmcli.StatusSucceed)
return return nil
} }
if t.Status != "aborting" && t.Status != dtmcli.StatusFailed { if t.Status != dtmcli.StatusAborting && t.Status != dtmcli.StatusFailed {
t.changeStatus(db, "aborting") t.changeStatus(db, dtmcli.StatusAborting)
} }
for current = current - 1; current >= 0; current-- { for current = current - 1; current >= 0; current-- {
branch := &branches[current] branch := &branches[current]
if branch.BranchType != dtmcli.BranchCompensate || branch.Status != dtmcli.StatusPrepared { if branch.BranchType != dtmcli.BranchCompensate || branch.Status != dtmcli.StatusPrepared {
continue continue
} }
t.execBranch(db, branch) err := t.execBranch(db, branch)
if err != nil {
return err
}
} }
t.changeStatus(db, dtmcli.StatusFailed) t.changeStatus(db, dtmcli.StatusFailed)
return nil
} }

12
dtmsvr/trans_tcc.go

@ -17,15 +17,19 @@ func (t *transTccProcessor) GenBranches() []TransBranch {
return []TransBranch{} return []TransBranch{}
} }
func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
if t.Status == dtmcli.StatusSucceed || t.Status == dtmcli.StatusFailed { if !t.needProcess() {
return return nil
} }
branchType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchConfirm, dtmcli.BranchCancel).(string) branchType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchConfirm, dtmcli.BranchCancel).(string)
for current := len(branches) - 1; current >= 0; current-- { for current := len(branches) - 1; current >= 0; current-- {
if branches[current].BranchType == branchType && branches[current].Status == dtmcli.StatusPrepared { if branches[current].BranchType == branchType && branches[current].Status == dtmcli.StatusPrepared {
t.execBranch(db, &branches[current]) err := t.execBranch(db, &branches[current])
if err != nil {
return err
}
} }
} }
t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string))
return nil
} }

12
dtmsvr/trans_xa.go

@ -17,15 +17,19 @@ func (t *transXaProcessor) GenBranches() []TransBranch {
return []TransBranch{} return []TransBranch{}
} }
func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error {
if t.Status == dtmcli.StatusSucceed { if !t.needProcess() {
return return nil
} }
currentType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchCommit, dtmcli.BranchRollback).(string) currentType := dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchCommit, dtmcli.BranchRollback).(string)
for _, branch := range branches { for _, branch := range branches {
if branch.BranchType == currentType && branch.Status != dtmcli.StatusSucceed { if branch.BranchType == currentType && branch.Status != dtmcli.StatusSucceed {
t.execBranch(db, &branch) err := t.execBranch(db, &branch)
if err != nil {
return err
}
} }
} }
t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) t.changeStatus(db, dtmcli.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string))
return nil
} }

6
dtmsvr/utils.go

@ -16,9 +16,9 @@ import (
type M = map[string]interface{} type M = map[string]interface{}
type branchStatus struct { type branchStatus struct {
id uint id uint64
status string status string
finish_time *time.Time finishTime *time.Time
} }
var p2e = dtmcli.P2E var p2e = dtmcli.P2E

12
dtmsvr/utils_test.go

@ -31,3 +31,15 @@ func TestCheckLocalHost(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestSetNextCron(t *testing.T) {
tg := TransGlobal{}
tg.RetryInterval = 15
tg.setNextCron(cronReset)
assert.Equal(t, int64(15), tg.NextCronInterval)
tg.RetryInterval = 0
tg.setNextCron(cronReset)
assert.Equal(t, config.RetryInterval, tg.NextCronInterval)
tg.setNextCron(cronBackoff)
assert.Equal(t, config.RetryInterval*2, tg.NextCronInterval)
}

6
examples/base_grpc.go

@ -45,7 +45,7 @@ func handleGrpcBusiness(in *dtmgrpc.BusiRequest, result1 string, result2 string,
if res == dtmcli.ResultSuccess { if res == dtmcli.ResultSuccess {
return nil return nil
} else if res == dtmcli.ResultFailure { } else if res == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user want to rollback").Err() return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
} }
return status.New(codes.Internal, fmt.Sprintf("unknow result %s", res)).Err() return status.New(codes.Internal, fmt.Sprintf("unknow result %s", res)).Err()
} }
@ -113,7 +113,7 @@ func (s *busiServer) TransInXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*d
dtmcli.MustUnmarshal(in.BusiData, &req) dtmcli.MustUnmarshal(in.BusiData, &req)
return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error {
if req.TransInResult == dtmcli.ResultFailure { if req.TransInResult == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user return failure").Err() return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
} }
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2) _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
return err return err
@ -125,7 +125,7 @@ func (s *busiServer) TransOutXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*
dtmcli.MustUnmarshal(in.BusiData, &req) dtmcli.MustUnmarshal(in.BusiData, &req)
return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error { return &dtmgrpc.BusiReply{BusiData: []byte("reply")}, XaGrpcClient.XaLocalTransaction(in, func(db *sql.DB, xa *dtmgrpc.XaGrpc) error {
if req.TransOutResult == dtmcli.ResultFailure { if req.TransOutResult == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user return failure").Err() return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
} }
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1) _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
return err return err

9
examples/base_http.go

@ -2,6 +2,7 @@ package examples
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"time" "time"
@ -145,4 +146,12 @@ func BaseAddRoute(app *gin.Engine) {
}) })
})) }))
app.POST(BusiAPI+"/TestPanic", common.WrapHandler(func(c *gin.Context) (interface{}, error) {
if c.Query("panic_error") != "" {
panic(errors.New("panic_error"))
} else if c.Query("panic_string") != "" {
panic("panic_string")
}
return "SUCCESS", nil
}))
} }

2
examples/grpc_saga_barrier.go

@ -24,7 +24,7 @@ func init() {
func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error {
if result == dtmcli.ResultFailure { if result == dtmcli.ResultFailure {
return status.New(codes.Aborted, "user rollback").Err() return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
} }
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return err return err

2
examples/http_msg.go

@ -11,7 +11,7 @@ func init() {
msg := dtmcli.NewMsg(DtmServer, dtmcli.MustGenGid(DtmServer)). msg := dtmcli.NewMsg(DtmServer, dtmcli.MustGenGid(DtmServer)).
Add(Busi+"/TransOut", req). Add(Busi+"/TransOut", req).
Add(Busi+"/TransIn", req) Add(Busi+"/TransIn", req)
err := msg.Prepare(Busi + "/TransQuery") err := msg.Prepare(Busi + "/query")
dtmcli.FatalIfError(err) dtmcli.FatalIfError(err)
dtmcli.Logf("busi trans submit") dtmcli.Logf("busi trans submit")
err = msg.Submit() err = msg.Submit()

2
examples/http_saga.go

@ -23,7 +23,7 @@ func init() {
saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)).
Add(Busi+"/TransOut", Busi+"/TransOutRevert", req). Add(Busi+"/TransOut", Busi+"/TransOutRevert", req).
Add(Busi+"/TransIn", Busi+"/TransInRevert", req) Add(Busi+"/TransIn", Busi+"/TransInRevert", req)
saga.WaitResult = true // 设置为等待结果模式,后面的submit调用,会等待服务器处理这个事务。如果Submit正常返回,那么整个全局事务已成功完成 saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit() err := saga.Submit()
dtmcli.Logf("result gid is: %s", saga.Gid) dtmcli.Logf("result gid is: %s", saga.Gid)
dtmcli.FatalIfError(err) dtmcli.FatalIfError(err)

53
test/base_test.go

@ -0,0 +1,53 @@
package test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/examples"
)
func TestSqlDB(t *testing.T) {
asserts := assert.New(t)
db := common.DbGet(config.DB)
barrier := &dtmcli.BranchBarrier{
TransType: "saga",
Gid: "gid2",
BranchID: "branch_id2",
BranchType: dtmcli.BranchAction,
}
db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
tx, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx, func(db dtmcli.DB) error {
dtmcli.Logf("rollback gid2")
return fmt.Errorf("gid2 error")
})
asserts.Error(err, fmt.Errorf("gid2 error"))
dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
barrier.BarrierID = 0
tx2, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx2, func(db dtmcli.DB) error {
dtmcli.Logf("submit gid2")
return nil
})
asserts.Nil(err)
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
}
func TestHttp(t *testing.T) {
resp, err := dtmcli.RestyClient.R().SetQueryParam("panic_string", "1").Post(examples.Busi + "/TestPanic")
assert.Nil(t, err)
assert.Contains(t, resp.String(), "panic_string")
resp, err = dtmcli.RestyClient.R().SetQueryParam("panic_error", "1").Post(examples.Busi + "/TestPanic")
assert.Nil(t, err)
assert.Contains(t, resp.String(), "panic_error")
}

36
test/dtmsvr_test.go

@ -1,7 +1,6 @@
package test package test
import ( import (
"fmt"
"testing" "testing"
"time" "time"
@ -83,43 +82,10 @@ func transQuery(t *testing.T, gid string) {
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestSqlDB(t *testing.T) {
asserts := assert.New(t)
db := common.DbGet(config.DB)
barrier := &dtmcli.BranchBarrier{
TransType: "saga",
Gid: "gid2",
BranchID: "branch_id2",
BranchType: dtmcli.BranchAction,
}
db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
tx, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx, func(db dtmcli.DB) error {
dtmcli.Logf("rollback gid2")
return fmt.Errorf("gid2 error")
})
asserts.Error(err, fmt.Errorf("gid2 error"))
dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
barrier.BarrierID = 0
tx2, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx2, func(db dtmcli.DB) error {
dtmcli.Logf("submit gid2")
return nil
})
asserts.Nil(err)
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
}
func TestUpdateBranchAsync(t *testing.T) { func TestUpdateBranchAsync(t *testing.T) {
common.DtmConfig.UpdateBranchSync = 0 common.DtmConfig.UpdateBranchSync = 0
saga := genSaga("gid-update-branch-async", false, false) saga := genSaga("gid-update-branch-async", false, false)
saga.WaitResult = true saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit() err := saga.Submit()
assert.Nil(t, err) assert.Nil(t, err)
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)

12
test/grpc_msg_test.go

@ -12,7 +12,7 @@ import (
func TestGrpcMsg(t *testing.T) { func TestGrpcMsg(t *testing.T) {
grpcMsgNormal(t) grpcMsgNormal(t)
grpcMsgPending(t) grpcMsgOngoing(t)
} }
func grpcMsgNormal(t *testing.T) { func grpcMsgNormal(t *testing.T) {
@ -23,15 +23,15 @@ func grpcMsgNormal(t *testing.T) {
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))
} }
func grpcMsgPending(t *testing.T) { func grpcMsgOngoing(t *testing.T) {
msg := genGrpcMsg("grpc-msg-pending") msg := genGrpcMsg("grpc-msg-pending")
err := msg.Prepare(fmt.Sprintf("%s/examples.Busi/CanSubmit", examples.BusiGrpc)) err := msg.Prepare(fmt.Sprintf("%s/examples.Busi/CanSubmit", examples.BusiGrpc))
assert.Nil(t, err) assert.Nil(t, err)
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing)
CronTransOnce() cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.TransInResult.SetOnce("PENDING") examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing)
CronTransOnce() cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid))
CronTransOnce() CronTransOnce()
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))

12
test/grpc_saga_test.go

@ -11,7 +11,7 @@ import (
func TestGrpcSaga(t *testing.T) { func TestGrpcSaga(t *testing.T) {
sagaGrpcNormal(t) sagaGrpcNormal(t)
sagaGrpcCommittedPending(t) sagaGrpcCommittedOngoing(t)
sagaGrpcRollback(t) sagaGrpcRollback(t)
} }
@ -24,9 +24,9 @@ func sagaGrpcNormal(t *testing.T) {
transQuery(t, saga.Gid) transQuery(t, saga.Gid)
} }
func sagaGrpcCommittedPending(t *testing.T) { func sagaGrpcCommittedOngoing(t *testing.T) {
saga := genSagaGrpc("gid-committedPendingGrpc", false, false) saga := genSagaGrpc("gid-committedOngoingGrpc", false, false)
examples.MainSwitch.TransOutResult.SetOnce("PENDING") examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
saga.Submit() saga.Submit()
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid))
@ -37,10 +37,10 @@ func sagaGrpcCommittedPending(t *testing.T) {
func sagaGrpcRollback(t *testing.T) { func sagaGrpcRollback(t *testing.T) {
saga := genSagaGrpc("gid-rollbackSaga2Grpc", false, true) saga := genSagaGrpc("gid-rollbackSaga2Grpc", false, true)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
saga.Submit() saga.Submit()
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, "aborting", getTransStatus(saga.Gid)) assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid))
CronTransOnce() CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid))

4
test/grpc_tcc_test.go

@ -53,13 +53,13 @@ func tccGrpcRollback(t *testing.T) {
err := dtmgrpc.TccGlobalTransaction(examples.DtmGrpcServer, gid, func(tcc *dtmgrpc.TccGrpc) error { err := dtmgrpc.TccGlobalTransaction(examples.DtmGrpcServer, gid, func(tcc *dtmgrpc.TccGrpc) error {
_, err := tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransOutTcc", examples.BusiGrpc+"/examples.Busi/TransOutConfirm", examples.BusiGrpc+"/examples.Busi/TransOutRevert") _, err := tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransOutTcc", examples.BusiGrpc+"/examples.Busi/TransOutConfirm", examples.BusiGrpc+"/examples.Busi/TransOutRevert")
assert.Nil(t, err) assert.Nil(t, err)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
_, err = tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransInTcc", examples.BusiGrpc+"/examples.Busi/TransInConfirm", examples.BusiGrpc+"/examples.Busi/TransInRevert") _, err = tcc.CallBranch(data, examples.BusiGrpc+"/examples.Busi/TransInTcc", examples.BusiGrpc+"/examples.Busi/TransInConfirm", examples.BusiGrpc+"/examples.Busi/TransInRevert")
return err return err
}) })
assert.Error(t, err) assert.Error(t, err)
WaitTransProcessed(gid) WaitTransProcessed(gid)
assert.Equal(t, "aborting", getTransStatus(gid)) assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid))
CronTransOnce() CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid))
} }

3
test/main_test.go

@ -14,7 +14,8 @@ import (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"])
dtmsvr.TransProcessedTestChan = make(chan string, 1) dtmsvr.TransProcessedTestChan = make(chan string, 1)
dtmsvr.CronForwardDuration = 60 * time.Second dtmsvr.NowForwardDuration = 0 * time.Second
dtmsvr.CronForwardDuration = 180 * time.Second
common.DtmConfig.UpdateBranchSync = 1 common.DtmConfig.UpdateBranchSync = 1
dtmsvr.PopulateDB(false) dtmsvr.PopulateDB(false)
examples.PopulateDB(false) examples.PopulateDB(false)

22
test/msg_test.go

@ -11,8 +11,8 @@ import (
func TestMsg(t *testing.T) { func TestMsg(t *testing.T) {
msgNormal(t) msgNormal(t)
msgPending(t) msgOngoing(t)
msgPendingFailed(t) msgOngoingFailed(t)
} }
func msgNormal(t *testing.T) { func msgNormal(t *testing.T) {
@ -24,30 +24,30 @@ func msgNormal(t *testing.T) {
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))
} }
func msgPending(t *testing.T) { func msgOngoing(t *testing.T) {
msg := genMsg("gid-msg-normal-pending") msg := genMsg("gid-msg-normal-pending")
msg.Prepare("") msg.Prepare("")
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing)
CronTransOnce() cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.TransInResult.SetOnce("PENDING") examples.MainSwitch.TransInResult.SetOnce(dtmcli.ResultOngoing)
CronTransOnce() cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(msg.Gid))
CronTransOnce() CronTransOnce()
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid))
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid))
} }
func msgPendingFailed(t *testing.T) { func msgOngoingFailed(t *testing.T) {
msg := genMsg("gid-msg-pending-failed") msg := genMsg("gid-msg-pending-failed")
msg.Prepare("") msg.Prepare("")
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing)
CronTransOnce() cronTransOnceForwardNow(180)
assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultFailure) examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultFailure)
CronTransOnce() cronTransOnceForwardNow(180)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(msg.Gid)) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(msg.Gid))
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(msg.Gid))
} }

72
test/saga_concurrent_test.go

@ -0,0 +1,72 @@
package test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/examples"
)
func TestCSaga(t *testing.T) {
csagaNormal(t)
csagaRollback(t)
csagaRollback2(t)
csagaCommittedOngoing(t)
}
func genCSaga(gid string, outFailed bool, inFailed bool) *dtmcli.ConcurrentSaga {
dtmcli.Logf("beginning a concurrent saga test ---------------- %s", gid)
csaga := dtmcli.NewConcurrentSaga(examples.DtmServer, gid)
req := examples.GenTransReq(30, outFailed, inFailed)
csaga.Add(examples.Busi+"/TransOut", examples.Busi+"/TransOutRevert", &req)
csaga.Add(examples.Busi+"/TransIn", examples.Busi+"/TransInRevert", &req)
return csaga
}
func csagaNormal(t *testing.T) {
csaga := genCSaga("gid-noraml-csaga", false, false)
csaga.Submit()
WaitTransProcessed(csaga.Gid)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid))
}
func csagaRollback(t *testing.T) {
csaga := genCSaga("gid-rollback-csaga", true, false)
examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
err := csaga.Submit()
assert.Nil(t, err)
WaitTransProcessed(csaga.Gid)
assert.Equal(t, dtmcli.StatusAborting, getTransStatus(csaga.Gid))
CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
err = csaga.Submit()
assert.Error(t, err)
}
func csagaRollback2(t *testing.T) {
csaga := genCSaga("gid-rollback-csaga2", true, false)
csaga.AddStepOrder(1, []int{0})
err := csaga.Submit()
assert.Nil(t, err)
WaitTransProcessed(csaga.Gid)
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(csaga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusFailed, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(csaga.Gid))
err = csaga.Submit()
assert.Error(t, err)
}
func csagaCommittedOngoing(t *testing.T) {
csaga := genCSaga("gid-committed-ongoing-csaga", false, false)
examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
csaga.Submit()
WaitTransProcessed(csaga.Gid)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusDoing, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(csaga.Gid))
CronTransOnce()
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusSucceed, dtmcli.StatusPrepared, dtmcli.StatusSucceed}, getBranchesStatus(csaga.Gid))
assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(csaga.Gid))
}

12
test/saga_test.go

@ -10,7 +10,7 @@ import (
func TestSaga(t *testing.T) { func TestSaga(t *testing.T) {
sagaNormal(t) sagaNormal(t)
sagaCommittedPending(t) sagaCommittedOngoing(t)
sagaRollback(t) sagaRollback(t)
} }
@ -25,9 +25,9 @@ func sagaNormal(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
func sagaCommittedPending(t *testing.T) { func sagaCommittedOngoing(t *testing.T) {
saga := genSaga("gid-committedPending", false, false) saga := genSaga("gid-committedOngoing", false, false)
examples.MainSwitch.TransOutResult.SetOnce("PENDING") examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
saga.Submit() saga.Submit()
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared, dtmcli.StatusPrepared}, getBranchesStatus(saga.Gid))
@ -38,11 +38,11 @@ func sagaCommittedPending(t *testing.T) {
func sagaRollback(t *testing.T) { func sagaRollback(t *testing.T) {
saga := genSaga("gid-rollbackSaga2", false, true) saga := genSaga("gid-rollbackSaga2", false, true)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
err := saga.Submit() err := saga.Submit()
assert.Nil(t, err) assert.Nil(t, err)
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, "aborting", getTransStatus(saga.Gid)) assert.Equal(t, dtmcli.StatusAborting, getTransStatus(saga.Gid))
CronTransOnce() CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(saga.Gid))
assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusSucceed, dtmcli.StatusFailed}, getBranchesStatus(saga.Gid))

4
test/tcc_test.go

@ -32,12 +32,12 @@ func tccRollback(t *testing.T) {
err := dtmcli.TccGlobalTransaction(examples.DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { err := dtmcli.TccGlobalTransaction(examples.DtmServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) {
_, rerr := tcc.CallBranch(data, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert") _, rerr := tcc.CallBranch(data, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert")
assert.Nil(t, rerr) assert.Nil(t, rerr)
examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing)
return tcc.CallBranch(data, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") return tcc.CallBranch(data, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert")
}) })
assert.Error(t, err) assert.Error(t, err)
WaitTransProcessed(gid) WaitTransProcessed(gid)
assert.Equal(t, "aborting", getTransStatus(gid)) assert.Equal(t, dtmcli.StatusAborting, getTransStatus(gid))
CronTransOnce() CronTransOnce()
assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid)) assert.Equal(t, dtmcli.StatusFailed, getTransStatus(gid))
} }

9
test/types.go

@ -1,6 +1,8 @@
package test package test
import ( import (
"time"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmsvr" "github.com/yedf/dtm/dtmsvr"
@ -27,3 +29,10 @@ type TransBranch = dtmsvr.TransBranch
// M alias // M alias
type M = dtmcli.M type M = dtmcli.M
func cronTransOnceForwardNow(seconds int) {
old := dtmsvr.NowForwardDuration
dtmsvr.NowForwardDuration = time.Duration(seconds) * time.Second
CronTransOnce()
dtmsvr.NowForwardDuration = old
}

14
test/wait_saga_test.go

@ -11,13 +11,13 @@ import (
func TestWaitSaga(t *testing.T) { func TestWaitSaga(t *testing.T) {
sagaNormalWait(t) sagaNormalWait(t)
sagaCommittedPendingWait(t) sagaCommittedOngoingWait(t)
sagaRollbackWait(t) sagaRollbackWait(t)
} }
func sagaNormalWait(t *testing.T) { func sagaNormalWait(t *testing.T) {
saga := genSaga("gid-noramlSagaWait", false, false) saga := genSaga("gid-noramlSagaWait", false, false)
saga.WaitResult = true saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit() err := saga.Submit()
assert.Nil(t, err) assert.Nil(t, err)
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
@ -26,10 +26,10 @@ func sagaNormalWait(t *testing.T) {
transQuery(t, saga.Gid) transQuery(t, saga.Gid)
} }
func sagaCommittedPendingWait(t *testing.T) { func sagaCommittedOngoingWait(t *testing.T) {
saga := genSaga("gid-committedPendingWait", false, false) saga := genSaga("gid-committedOngoingWait", false, false)
examples.MainSwitch.TransOutResult.SetOnce("PENDING") examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing)
saga.WaitResult = true saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit() err := saga.Submit()
assert.Error(t, err) assert.Error(t, err)
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
@ -41,7 +41,7 @@ func sagaCommittedPendingWait(t *testing.T) {
func sagaRollbackWait(t *testing.T) { func sagaRollbackWait(t *testing.T) {
saga := genSaga("gid-rollbackSaga2Wait", false, true) saga := genSaga("gid-rollbackSaga2Wait", false, true)
saga.WaitResult = true saga.SetOptions(&dtmcli.TransOptions{WaitResult: true})
err := saga.Submit() err := saga.Submit()
assert.Error(t, err) assert.Error(t, err)
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)

Loading…
Cancel
Save