Browse Source

Merge pull request #303 from dtm-labs/alpha

Support db in server config
pull/305/head
yedf2 4 years ago
committed by GitHub
parent
commit
4f1ce7480d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      conf.sample.yml
  2. 18
      dtmcli/barrier.go
  3. 4
      dtmcli/dtmimp/db_special.go
  4. 4
      dtmcli/dtmimp/db_special_test.go
  5. 12
      dtmcli/dtmimp/trans_xa_base.go
  6. 16
      dtmcli/dtmimp/utils.go
  7. 4
      dtmsvr/config/config.go
  8. 5
      dtmsvr/storage/trans.go
  9. 2
      dtmutil/utils.go
  10. 14
      helper/bench/svr/http.go
  11. 3
      test/busi/base_http.go
  12. 9
      test/busi/busi.go
  13. 6
      test/common_test.go
  14. 2
      test/main_test.go
  15. 4
      test/xa_test.go

3
conf.sample.yml

@ -11,6 +11,7 @@
# User: 'root' # User: 'root'
# Password: '' # Password: ''
# Port: 3306 # Port: 3306
# Db: 'dtm'
# Driver: 'boltdb' # default store engine # Driver: 'boltdb' # default store engine
@ -30,8 +31,6 @@
# MaxOpenConns: 500 # MaxOpenConns: 500
# MaxIdleConns: 500 # MaxIdleConns: 500
# ConnMaxLifeTime: 5 # default value is 5 (minutes) # ConnMaxLifeTime: 5 # default value is 5 (minutes)
# TransGlobalTable: 'dtm.trans_global'
# TransBranchOpTable: 'dtm.trans_branch_op'
### flollowing config is only for some Driver ### flollowing config is only for some Driver
# DataExpire: 604800 # Trans data will expire in 7 days. only for redis/boltdb. # DataExpire: 604800 # Trans data will expire in 7 days. only for redis/boltdb.

18
dtmcli/barrier.go

@ -20,11 +20,13 @@ type BarrierBusiFunc func(tx *sql.Tx) error
// BranchBarrier every branch info // BranchBarrier every branch info
type BranchBarrier struct { type BranchBarrier struct {
TransType string TransType string
Gid string Gid string
BranchID string BranchID string
Op string Op string
BarrierID int BarrierID int
DBType string // DBTypeMysql | DBTypePostgres
BarrierTableName string
} }
func (bb *BranchBarrier) String() string { func (bb *BranchBarrier) String() string {
@ -70,8 +72,8 @@ func (bb *BranchBarrier) Call(tx *sql.Tx, busiCall BarrierBusiFunc) (rerr error)
dtmimp.OpCompensate: dtmimp.OpAction, dtmimp.OpCompensate: dtmimp.OpAction,
}[bb.Op] }[bb.Op]
originAffected, oerr := dtmimp.InsertBarrier(tx, bb.TransType, bb.Gid, bb.BranchID, originOp, bid, bb.Op) originAffected, oerr := dtmimp.InsertBarrier(tx, bb.TransType, bb.Gid, bb.BranchID, originOp, bid, bb.Op, bb.DBType, bb.BarrierTableName)
currentAffected, rerr := dtmimp.InsertBarrier(tx, bb.TransType, bb.Gid, bb.BranchID, bb.Op, bid, bb.Op) currentAffected, rerr := dtmimp.InsertBarrier(tx, bb.TransType, bb.Gid, bb.BranchID, bb.Op, bid, bb.Op, bb.DBType, bb.BarrierTableName)
logger.Debugf("originAffected: %d currentAffected: %d", originAffected, currentAffected) logger.Debugf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
if rerr == nil && bb.Op == dtmimp.MsgDoOp && currentAffected == 0 { // for msg's DoAndSubmit, repeated insert should be rejected. if rerr == nil && bb.Op == dtmimp.MsgDoOp && currentAffected == 0 { // for msg's DoAndSubmit, repeated insert should be rejected.
@ -103,7 +105,7 @@ func (bb *BranchBarrier) CallWithDB(db *sql.DB, busiCall BarrierBusiFunc) error
// QueryPrepared queries prepared data // QueryPrepared queries prepared data
func (bb *BranchBarrier) QueryPrepared(db *sql.DB) error { func (bb *BranchBarrier) QueryPrepared(db *sql.DB) error {
_, err := dtmimp.InsertBarrier(db, bb.TransType, bb.Gid, dtmimp.MsgDoBranch0, dtmimp.MsgDoOp, dtmimp.MsgDoBarrier1, dtmimp.OpRollback) _, err := dtmimp.InsertBarrier(db, bb.TransType, bb.Gid, dtmimp.MsgDoBranch0, dtmimp.MsgDoOp, dtmimp.MsgDoBarrier1, dtmimp.OpRollback, bb.DBType, bb.BarrierTableName)
var reason string var reason string
if err == nil { if err == nil {
sql := fmt.Sprintf("select reason from %s where gid=? and branch_id=? and op=? and barrier_id=?", dtmimp.BarrierTableName) sql := fmt.Sprintf("select reason from %s where gid=? and branch_id=? and op=? and barrier_id=?", dtmimp.BarrierTableName)

4
dtmcli/dtmimp/db_special.go

@ -75,8 +75,8 @@ func init() {
} }
// GetDBSpecial get DBSpecial for currentDBType // GetDBSpecial get DBSpecial for currentDBType
func GetDBSpecial() DBSpecial { func GetDBSpecial(dbType string) DBSpecial {
return dbSpecials[currentDBType] return dbSpecials[dbType]
} }
// SetCurrentDBType set currentDBType // SetCurrentDBType set currentDBType

4
dtmcli/dtmimp/db_special_test.go

@ -18,13 +18,13 @@ func TestDBSpecial(t *testing.T) {
SetCurrentDBType("no-driver") SetCurrentDBType("no-driver")
})) }))
SetCurrentDBType(DBTypeMysql) SetCurrentDBType(DBTypeMysql)
sp := GetDBSpecial() sp := GetDBSpecial(DBTypeMysql)
assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?")) assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?"))
assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1")) assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1"))
assert.Equal(t, "insert ignore into a(f) values(?)", sp.GetInsertIgnoreTemplate("a(f) values(?)", "c")) assert.Equal(t, "insert ignore into a(f) values(?)", sp.GetInsertIgnoreTemplate("a(f) values(?)", "c"))
SetCurrentDBType(DBTypePostgres) SetCurrentDBType(DBTypePostgres)
sp = GetDBSpecial() sp = GetDBSpecial(DBTypePostgres)
assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?")) assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?"))
assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1")) assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1"))
assert.Equal(t, "insert into a(f) values(?) on conflict ON CONSTRAINT c do nothing", sp.GetInsertIgnoreTemplate("a(f) values(?)", "c")) assert.Equal(t, "insert into a(f) values(?) on conflict ON CONSTRAINT c do nothing", sp.GetInsertIgnoreTemplate("a(f) values(?)", "c"))

12
dtmcli/dtmimp/trans_xa_base.go

@ -18,14 +18,14 @@ func XaHandlePhase2(gid string, dbConf DBConf, branchID string, op string) error
return err return err
} }
xaID := gid + "-" + branchID xaID := gid + "-" + branchID
_, err = DBExec(db, GetDBSpecial().GetXaSQL(op, xaID)) _, err = DBExec(dbConf.Driver, db, GetDBSpecial(dbConf.Driver).GetXaSQL(op, xaID))
if err != nil && if err != nil &&
(strings.Contains(err.Error(), "XAER_NOTA") || strings.Contains(err.Error(), "does not exist")) { // Repeat commit/rollback with the same id, report this error, ignore (strings.Contains(err.Error(), "XAER_NOTA") || strings.Contains(err.Error(), "does not exist")) { // Repeat commit/rollback with the same id, report this error, ignore
err = nil err = nil
} }
if op == OpRollback && err == nil { if op == OpRollback && err == nil {
// rollback insert a row after prepare. no-error means prepare has finished. // rollback insert a row after prepare. no-error means prepare has finished.
_, err = InsertBarrier(db, "xa", gid, branchID, OpAction, XaBarrier1, op) _, err = InsertBarrier(db, "xa", gid, branchID, OpAction, XaBarrier1, op, dbConf.Driver, "")
} }
return err return err
} }
@ -39,20 +39,20 @@ func XaHandleLocalTrans(xa *TransBase, dbConf DBConf, cb func(*sql.DB) error) (r
} }
defer func() { _ = db.Close() }() defer func() { _ = db.Close() }()
defer DeferDo(&rerr, func() error { defer DeferDo(&rerr, func() error {
_, err := DBExec(db, GetDBSpecial().GetXaSQL("prepare", xaBranch)) _, err := DBExec(dbConf.Driver, db, GetDBSpecial(dbConf.Driver).GetXaSQL("prepare", xaBranch))
return err return err
}, func() error { }, func() error {
return nil return nil
}) })
_, rerr = DBExec(db, GetDBSpecial().GetXaSQL("start", xaBranch)) _, rerr = DBExec(dbConf.Driver, db, GetDBSpecial(dbConf.Driver).GetXaSQL("start", xaBranch))
if rerr != nil { if rerr != nil {
return return
} }
defer func() { defer func() {
_, _ = DBExec(db, GetDBSpecial().GetXaSQL("end", xaBranch)) _, _ = DBExec(dbConf.Driver, db, GetDBSpecial(dbConf.Driver).GetXaSQL("end", xaBranch))
}() }()
// prepare and rollback both insert a row // prepare and rollback both insert a row
_, rerr = InsertBarrier(db, xa.TransType, xa.Gid, xa.BranchID, OpAction, XaBarrier1, OpAction) _, rerr = InsertBarrier(db, xa.TransType, xa.Gid, xa.BranchID, OpAction, XaBarrier1, OpAction, dbConf.Driver, "")
if rerr == nil { if rerr == nil {
rerr = cb(db) rerr = cb(db)
} }

16
dtmcli/dtmimp/utils.go

@ -187,12 +187,12 @@ func XaDB(conf DBConf) (*sql.DB, error) {
} }
// DBExec use raw db to exec // DBExec use raw db to exec
func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) { func DBExec(dbType string, db DB, sql string, values ...interface{}) (affected int64, rerr error) {
if sql == "" { if sql == "" {
return 0, nil return 0, nil
} }
began := time.Now() began := time.Now()
sql = GetDBSpecial().GetPlaceHoldSQL(sql) sql = GetDBSpecial(dbType).GetPlaceHoldSQL(sql)
r, rerr := db.Exec(sql, values...) r, rerr := db.Exec(sql, values...)
used := time.Since(began) / time.Millisecond used := time.Since(began) / time.Millisecond
if rerr == nil { if rerr == nil {
@ -262,10 +262,16 @@ func EscapeGet(qs url.Values, key string) string {
} }
// InsertBarrier insert a record to barrier // InsertBarrier insert a record to barrier
func InsertBarrier(tx DB, transType string, gid string, branchID string, op string, barrierID string, reason string) (int64, error) { func InsertBarrier(tx DB, transType string, gid string, branchID string, op string, barrierID string, reason string, dbType string, barrierTableName string) (int64, error) {
if op == "" { if op == "" {
return 0, nil return 0, nil
} }
sql := GetDBSpecial().GetInsertIgnoreTemplate(BarrierTableName+"(trans_type, gid, branch_id, op, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier") if dbType == "" {
return DBExec(tx, sql, transType, gid, branchID, op, barrierID, reason) dbType = currentDBType
}
if barrierTableName == "" {
barrierTableName = BarrierTableName
}
sql := GetDBSpecial(dbType).GetInsertIgnoreTemplate(barrierTableName+"(trans_type, gid, branch_id, op, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier")
return DBExec(dbType, tx, sql, transType, gid, branchID, op, barrierID, reason)
} }

4
dtmsvr/config/config.go

@ -53,14 +53,13 @@ type Store struct {
Port int64 `yaml:"Port"` Port int64 `yaml:"Port"`
User string `yaml:"User"` User string `yaml:"User"`
Password string `yaml:"Password"` Password string `yaml:"Password"`
Db string `yaml:"Db" default:"dtm"`
MaxOpenConns int64 `yaml:"MaxOpenConns" default:"500"` MaxOpenConns int64 `yaml:"MaxOpenConns" default:"500"`
MaxIdleConns int64 `yaml:"MaxIdleConns" default:"500"` MaxIdleConns int64 `yaml:"MaxIdleConns" default:"500"`
ConnMaxLifeTime int64 `yaml:"ConnMaxLifeTime" default:"5"` ConnMaxLifeTime int64 `yaml:"ConnMaxLifeTime" default:"5"`
DataExpire int64 `yaml:"DataExpire" default:"604800"` // Trans data will expire in 7 days. only for redis/boltdb. DataExpire int64 `yaml:"DataExpire" default:"604800"` // Trans data will expire in 7 days. only for redis/boltdb.
FinishedDataExpire int64 `yaml:"FinishedDataExpire" default:"86400"` // finished Trans data will expire in 1 days. only for redis. FinishedDataExpire int64 `yaml:"FinishedDataExpire" default:"86400"` // finished Trans data will expire in 1 days. only for redis.
RedisPrefix string `yaml:"RedisPrefix" default:"{a}"` // Redis storage prefix. store data to only one slot in cluster RedisPrefix string `yaml:"RedisPrefix" default:"{a}"` // Redis storage prefix. store data to only one slot in cluster
TransGlobalTable string `yaml:"TransGlobalTable" default:"dtm.trans_global"`
TransBranchOpTable string `yaml:"TransBranchOpTable" default:"dtm.trans_branch_op"`
} }
// IsDB checks config driver is mysql or postgres // IsDB checks config driver is mysql or postgres
@ -76,6 +75,7 @@ func (s *Store) GetDBConf() dtmcli.DBConf {
Port: s.Port, Port: s.Port,
User: s.User, User: s.User,
Password: s.Password, Password: s.Password,
Db: s.Db,
} }
} }

5
dtmsvr/storage/trans.go

@ -11,7 +11,6 @@ import (
"github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli"
"github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/dtmsvr/config"
"github.com/dtm-labs/dtm/dtmutil" "github.com/dtm-labs/dtm/dtmutil"
) )
@ -45,7 +44,7 @@ type TransGlobalStore struct {
// TableName TableName // TableName TableName
func (g *TransGlobalStore) TableName() string { func (g *TransGlobalStore) TableName() string {
return config.Config.Store.TransGlobalTable return "trans_global"
} }
func (g *TransGlobalStore) String() string { func (g *TransGlobalStore) String() string {
@ -67,7 +66,7 @@ type TransBranchStore struct {
// TableName TableName // TableName TableName
func (b *TransBranchStore) TableName() string { func (b *TransBranchStore) TableName() string {
return config.Config.Store.TransBranchOpTable return "trans_branch_op"
} }
func (b *TransBranchStore) String() string { func (b *TransBranchStore) String() string {

2
dtmutil/utils.go

@ -168,7 +168,7 @@ func RunSQLScript(conf dtmcli.DBConf, script string, skipDrop bool) {
if s == "" || (skipDrop && strings.Contains(s, "drop")) { if s == "" || (skipDrop && strings.Contains(s, "drop")) {
continue continue
} }
_, err = dtmimp.DBExec(con, s) _, err = dtmimp.DBExec(conf.Driver, con, s)
logger.FatalIfError(err) logger.FatalIfError(err)
logger.Infof("sql scripts finished: %s", s) logger.Infof("sql scripts finished: %s", s)
} }

14
helper/bench/svr/http.go

@ -53,7 +53,7 @@ func reloadData() {
db := pdbGet() db := pdbGet()
tables := []string{"dtm_busi.user_account", "dtm_busi.user_account_log", "dtm.trans_global", "dtm.trans_branch_op", "dtm_barrier.barrier"} tables := []string{"dtm_busi.user_account", "dtm_busi.user_account_log", "dtm.trans_global", "dtm.trans_branch_op", "dtm_barrier.barrier"}
for _, t := range tables { for _, t := range tables {
_, err := dtmimp.DBExec(db, fmt.Sprintf("truncate %s", t)) _, err := dtmimp.DBExec(busi.BusiConf.Driver, db, fmt.Sprintf("truncate %s", t))
logger.FatalIfError(err) logger.FatalIfError(err)
} }
s := "insert ignore into dtm_busi.user_account(user_id, balance) values " s := "insert ignore into dtm_busi.user_account(user_id, balance) values "
@ -61,7 +61,7 @@ func reloadData() {
for i := 1; i <= total; i++ { for i := 1; i <= total; i++ {
ss = append(ss, fmt.Sprintf("(%d, 1000000)", i)) ss = append(ss, fmt.Sprintf("(%d, 1000000)", i))
} }
_, err := dtmimp.DBExec(db, s+strings.Join(ss, ",")) _, err := dtmimp.DBExec(busi.BusiConf.Driver, db, s+strings.Join(ss, ","))
logger.FatalIfError(err) logger.FatalIfError(err)
logger.Debugf("%d users inserted. used: %dms", total, time.Since(began).Milliseconds()) logger.Debugf("%d users inserted. used: %dms", total, time.Since(began).Milliseconds())
} }
@ -73,11 +73,11 @@ var sqls = 1
// PrepareBenchDB prepares db data for bench // PrepareBenchDB prepares db data for bench
func PrepareBenchDB() { func PrepareBenchDB() {
db := pdbGet() db := pdbGet()
_, err := dtmimp.DBExec(db, "CREATE DATABASE if not exists dtm_busi") _, err := dtmimp.DBExec(busi.BusiConf.Driver, db, "CREATE DATABASE if not exists dtm_busi")
logger.FatalIfError(err) logger.FatalIfError(err)
_, err = dtmimp.DBExec(db, "drop table if exists dtm_busi.user_account_log") _, err = dtmimp.DBExec(busi.BusiConf.Driver, db, "drop table if exists dtm_busi.user_account_log")
logger.FatalIfError(err) logger.FatalIfError(err)
_, err = dtmimp.DBExec(db, `create table if not exists dtm_busi.user_account_log ( _, err = dtmimp.DBExec(busi.BusiConf.Driver, db, `create table if not exists dtm_busi.user_account_log (
id INT(11) AUTO_INCREMENT PRIMARY KEY, id INT(11) AUTO_INCREMENT PRIMARY KEY,
user_id INT(11) NOT NULL, user_id INT(11) NOT NULL,
delta DECIMAL(11, 2) not null, delta DECIMAL(11, 2) not null,
@ -111,10 +111,10 @@ func qsAdjustBalance(uid int, amount int, c *gin.Context) error { // nolint: unp
tb := dtmimp.TransBaseFromQuery(c.Request.URL.Query()) tb := dtmimp.TransBaseFromQuery(c.Request.URL.Query())
f := func(tx *sql.Tx) error { f := func(tx *sql.Tx) error {
for i := 0; i < sqls; i++ { for i := 0; i < sqls; i++ {
_, err := dtmimp.DBExec(tx, "insert into dtm_busi.user_account_log(user_id, delta, gid, branch_id, op, reason) values(?,?,?,?,?,?)", _, err := dtmimp.DBExec(busi.BusiConf.Driver, tx, "insert into dtm_busi.user_account_log(user_id, delta, gid, branch_id, op, reason) values(?,?,?,?,?,?)",
uid, amount, tb.Gid, c.Query("branch_id"), tb.TransType, fmt.Sprintf("inserted by dtm transaction %s %s", tb.Gid, c.Query("branch_id"))) uid, amount, tb.Gid, c.Query("branch_id"), tb.TransType, fmt.Sprintf("inserted by dtm transaction %s %s", tb.Gid, c.Query("branch_id")))
logger.FatalIfError(err) logger.FatalIfError(err)
_, err = dtmimp.DBExec(tx, "update dtm_busi.user_account set balance = balance + ?, update_time = now() where user_id = ?", amount, uid) _, err = dtmimp.DBExec(busi.BusiConf.Driver, tx, "update dtm_busi.user_account set balance = balance + ?, update_time = now() where user_id = ?", amount, uid)
logger.FatalIfError(err) logger.FatalIfError(err)
} }
return nil return nil

3
test/busi/base_http.go

@ -69,7 +69,8 @@ func BaseAppStartup() *gin.Engine {
} }
logger.Debugf("Starting busi at: %d", BusiPort) logger.Debugf("Starting busi at: %d", BusiPort)
go func() { go func() {
_ = app.Run(fmt.Sprintf(":%d", BusiPort)) err := app.Run(fmt.Sprintf(":%d", BusiPort))
dtmimp.FatalIfError(err)
}() }()
return app return app
} }

9
test/busi/busi.go

@ -66,7 +66,7 @@ func sagaGrpcAdjustBalance(db dtmcli.DB, uid int, amount int64, result string) e
if result == dtmcli.ResultFailure { if result == dtmcli.ResultFailure {
return status.New(codes.Aborted, dtmcli.ResultFailure).Err() return status.New(codes.Aborted, dtmcli.ResultFailure).Err()
} }
_, err := dtmimp.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) _, err := dtmimp.DBExec(BusiConf.Driver, db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return err return err
} }
@ -75,7 +75,7 @@ func SagaAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error {
if strings.Contains(result, dtmcli.ResultFailure) { if strings.Contains(result, dtmcli.ResultFailure) {
return dtmcli.ErrFailure return dtmcli.ErrFailure
} }
_, err := dtmimp.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) _, err := dtmimp.DBExec(BusiConf.Driver, db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return err return err
} }
@ -102,11 +102,10 @@ func SagaMongoAdjustBalance(ctx context.Context, mc *mongo.Client, uid int, amou
return fmt.Errorf("balance not enough %w", dtmcli.ErrFailure) return fmt.Errorf("balance not enough %w", dtmcli.ErrFailure)
} }
return nil return nil
} }
func tccAdjustTrading(db dtmcli.DB, uid int, amount int) error { func tccAdjustTrading(db dtmcli.DB, uid int, amount int) error {
affected, err := dtmimp.DBExec(db, `update dtm_busi.user_account affected, err := dtmimp.DBExec(BusiConf.Driver, db, `update dtm_busi.user_account
set trading_balance=trading_balance+? set trading_balance=trading_balance+?
where user_id=? and trading_balance + ? + balance >= 0`, amount, uid, amount) where user_id=? and trading_balance + ? + balance >= 0`, amount, uid, amount)
if err == nil && affected == 0 { if err == nil && affected == 0 {
@ -116,7 +115,7 @@ func tccAdjustTrading(db dtmcli.DB, uid int, amount int) error {
} }
func tccAdjustBalance(db dtmcli.DB, uid int, amount int) error { func tccAdjustBalance(db dtmcli.DB, uid int, amount int) error {
affected, err := dtmimp.DBExec(db, `update dtm_busi.user_account affected, err := dtmimp.DBExec(BusiConf.Driver, db, `update dtm_busi.user_account
set trading_balance=trading_balance-?, set trading_balance=trading_balance-?,
balance=balance+? where user_id=?`, amount, amount, uid) balance=balance+? where user_id=?`, amount, amount, uid)
if err == nil && affected == 0 { if err == nil && affected == 0 {

6
test/common_test.go

@ -33,12 +33,12 @@ func testSql(t *testing.T) {
func testDbAlone(t *testing.T) { func testDbAlone(t *testing.T) {
db, err := dtmimp.StandaloneDB(conf.Store.GetDBConf()) db, err := dtmimp.StandaloneDB(conf.Store.GetDBConf())
assert.Nil(t, err) assert.Nil(t, err)
_, err = dtmimp.DBExec(db, "select 1") _, err = dtmimp.DBExec(conf.Store.Driver, db, "select 1")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
_, err = dtmimp.DBExec(db, "") _, err = dtmimp.DBExec(conf.Store.Driver, db, "")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
db.Close() db.Close()
_, err = dtmimp.DBExec(db, "select 1") _, err = dtmimp.DBExec(conf.Store.Driver, db, "select 1")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }

2
test/main_test.go

@ -30,7 +30,6 @@ func exitIf(code int) {
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
config.MustLoadConfig("") config.MustLoadConfig("")
logger.InitLog("debug") logger.InitLog("debug")
dtmcli.SetCurrentDBType(busi.BusiConf.Driver)
dtmsvr.TransProcessedTestChan = make(chan string, 1) dtmsvr.TransProcessedTestChan = make(chan string, 1)
dtmsvr.NowForwardDuration = 0 * time.Second dtmsvr.NowForwardDuration = 0 * time.Second
dtmsvr.CronForwardDuration = 180 * time.Second dtmsvr.CronForwardDuration = 180 * time.Second
@ -59,6 +58,7 @@ func TestMain(m *testing.M) {
registry.WaitStoreUp() registry.WaitStoreUp()
dtmsvr.PopulateDB(false) dtmsvr.PopulateDB(false)
conf.Store.Db = "dtm" // after populateDB, set current db to dtm
go dtmsvr.StartSvr() go dtmsvr.StartSvr()
busi.PopulateDB(false) busi.PopulateDB(false)

4
test/xa_test.go

@ -43,10 +43,10 @@ func TestXaDuplicate(t *testing.T) {
sdb, err := dtmimp.StandaloneDB(busi.BusiConf) sdb, err := dtmimp.StandaloneDB(busi.BusiConf)
assert.Nil(t, err) assert.Nil(t, err)
if dtmcli.GetCurrentDBType() == dtmcli.DBTypeMysql { if dtmcli.GetCurrentDBType() == dtmcli.DBTypeMysql {
_, err = dtmimp.DBExec(sdb, "xa recover") _, err = dtmimp.DBExec(busi.BusiConf.Driver, sdb, "xa recover")
assert.Nil(t, err) assert.Nil(t, err)
} }
_, err = dtmimp.DBExec(sdb, dtmimp.GetDBSpecial().GetXaSQL("commit", gid+"-01")) // simulate repeated request _, err = dtmimp.DBExec(busi.BusiConf.Driver, sdb, dtmimp.GetDBSpecial(busi.BusiConf.Driver).GetXaSQL("commit", gid+"-01")) // simulate repeated request
assert.Nil(t, err) assert.Nil(t, err)
return xa.CallBranch(req, busi.Busi+"/TransInXa") return xa.CallBranch(req, busi.Busi+"/TransInXa")
}) })

Loading…
Cancel
Save