From b8f369c3248bf77bd878a276bebbfd1a6c3542c5 Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Fri, 15 Oct 2021 13:41:23 +0800 Subject: [PATCH] split db specific op to a stand alone file --- dtmcli/db_special.go | 70 +++++++++++++++++++++++++++++++++++++++ dtmcli/db_special_test.go | 27 +++++++++++++++ dtmcli/utils.go | 36 +------------------- dtmcli/utils_test.go | 11 ------ dtmcli/xa_base.go | 8 ++--- dtmsvr/cron.go | 7 +--- 6 files changed, 103 insertions(+), 56 deletions(-) create mode 100644 dtmcli/db_special.go create mode 100644 dtmcli/db_special_test.go diff --git a/dtmcli/db_special.go b/dtmcli/db_special.go new file mode 100644 index 0000000..5807d7a --- /dev/null +++ b/dtmcli/db_special.go @@ -0,0 +1,70 @@ +package dtmcli + +import ( + "errors" + "fmt" + "strings" +) + +// DBSpecial db specific operations +type DBSpecial interface { + TimestampAdd(second int) string + GetPlaceHoldSQL(sql string) string + GetXaSQL(command string, xid string) string +} + +type mysqlDBSpecial struct{} + +func (*mysqlDBSpecial) TimestampAdd(second int) string { + return fmt.Sprintf("date_add(now(), interval %d second)", second) +} + +func (*mysqlDBSpecial) GetPlaceHoldSQL(sql string) string { + return sql +} + +func (*mysqlDBSpecial) GetXaSQL(command string, xid string) string { + return fmt.Sprintf("xa %s '%s'", command, xid) +} + +type postgresDBSpecial struct{} + +func (*postgresDBSpecial) TimestampAdd(second int) string { + return fmt.Sprintf("current_timestamp + interval '%d second'", second) +} + +func (*postgresDBSpecial) GetXaSQL(command string, xid string) string { + return map[string]string{ + "end": "", + "start": "begin", + "prepare": fmt.Sprintf("prepare transaction '%s'", xid), + "commit": fmt.Sprintf("commit prepared '%s'", xid), + "rollback": fmt.Sprintf("rollback prepared '%s'", xid), + }[command] +} + +func (*postgresDBSpecial) GetPlaceHoldSQL(sql string) string { + pos := 1 + parts := []string{} + b := 0 + for i := 0; i < len(sql); i++ { + if sql[i] == '?' { + parts = append(parts, sql[b:i]) + b = i + 1 + parts = append(parts, fmt.Sprintf("$%d", pos)) + pos++ + } + } + parts = append(parts, sql[b:]) + return strings.Join(parts, "") +} + +// GetDBSpecial get DBSpecial for DBDriver +func GetDBSpecial() DBSpecial { + if DBDriver == DriverMysql { + return &mysqlDBSpecial{} + } else if DBDriver == DriverPostgres { + return &postgresDBSpecial{} + } + panic(errors.New("unknown DBDriver, please set it to a valid driver: " + DBDriver)) +} diff --git a/dtmcli/db_special_test.go b/dtmcli/db_special_test.go new file mode 100644 index 0000000..4ed5607 --- /dev/null +++ b/dtmcli/db_special_test.go @@ -0,0 +1,27 @@ +package dtmcli + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDBSpecial(t *testing.T) { + old := DBDriver + DBDriver = "no-driver" + assert.Error(t, CatchP(func() { + GetDBSpecial() + })) + DBDriver = DriverMysql + sp := GetDBSpecial() + + assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?")) + assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1")) + assert.Equal(t, "date_add(now(), interval 1000 second)", sp.TimestampAdd(1000)) + DBDriver = DriverPostgres + sp = GetDBSpecial() + assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?")) + assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1")) + assert.Equal(t, "current_timestamp + interval '1000 second'", sp.TimestampAdd(1000)) + DBDriver = old +} diff --git a/dtmcli/utils.go b/dtmcli/utils.go index 3b800a1..8fec13e 100644 --- a/dtmcli/utils.go +++ b/dtmcli/utils.go @@ -215,7 +215,7 @@ func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr erro if sql == "" { return 0, nil } - sql = makeSQLCompatible(sql) + sql = GetDBSpecial().GetPlaceHoldSQL(sql) r, rerr := db.Exec(sql, values...) if rerr == nil { affected, rerr = r.RowsAffected() @@ -268,37 +268,3 @@ func CheckResult(res interface{}, err error) error { } return err } - -func makeSQLCompatible(sql string) string { - if DBDriver == DriverPostgres { - pos := 1 - parts := []string{} - b := 0 - for i := 0; i < len(sql); i++ { - if sql[i] == '?' { - parts = append(parts, sql[b:i]) - b = i + 1 - parts = append(parts, fmt.Sprintf("$%d", pos)) - pos++ - } - } - parts = append(parts, sql[b:]) - return strings.Join(parts, "") - } - PanicIf(DBDriver != DriverMysql, fmt.Errorf("unkown db driver: %s", DBDriver)) - return sql -} - -func getXaSQL(action string, xid string) string { - if DBDriver == DriverPostgres { - return map[string]string{ - "end": "", - "start": "begin", - "prepare": fmt.Sprintf("prepare transaction '%s'", xid), - "commit": fmt.Sprintf("commit prepared '%s'", xid), - "rollback": fmt.Sprintf("rollback prepared '%s'", xid), - }[action] - } - PanicIf(DBDriver != DriverMysql, fmt.Errorf("unkown db driver: %s", DBDriver)) - return fmt.Sprintf("xa %s '%s'", action, xid) -} diff --git a/dtmcli/utils_test.go b/dtmcli/utils_test.go index 364283e..a59d628 100644 --- a/dtmcli/utils_test.go +++ b/dtmcli/utils_test.go @@ -89,14 +89,3 @@ func TestFatal(t *testing.T) { }) assert.Error(t, err, fmt.Errorf("fatal")) } - -func TestCompatible(t *testing.T) { - old := DBDriver - DBDriver = DriverMysql - assert.Equal(t, "? ?", makeSQLCompatible("? ?")) - assert.Equal(t, "xa start 'xa1'", getXaSQL("start", "xa1")) - DBDriver = DriverPostgres - assert.Equal(t, "$1 $2", makeSQLCompatible("? ?")) - assert.Equal(t, "begin", getXaSQL("start", "xa1")) - DBDriver = old -} diff --git a/dtmcli/xa_base.go b/dtmcli/xa_base.go index e1e31f8..a8c3ceb 100644 --- a/dtmcli/xa_base.go +++ b/dtmcli/xa_base.go @@ -20,7 +20,7 @@ func (xc *XaClientBase) HandleCallback(gid string, branchID string, action strin } defer db.Close() xaID := gid + "-" + branchID - _, err = DBExec(db, getXaSQL(action, xaID)) + _, err = DBExec(db, GetDBSpecial().GetXaSQL(action, xaID)) if err != nil && (strings.Contains(err.Error(), "Error 1397: XAER_NOTA") || strings.Contains(err.Error(), "does not exist")) { // 重复commit/rollback同一个id,报这个错误,忽略 err = nil @@ -39,9 +39,9 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) (interf defer func() { db.Close() }() defer func() { x := recover() - _, err := DBExec(db, getXaSQL("end", xaBranch)) + _, err := DBExec(db, GetDBSpecial().GetXaSQL("end", xaBranch)) if x == nil && rerr == nil && err == nil { - _, err = DBExec(db, getXaSQL("prepare", xaBranch)) + _, err = DBExec(db, GetDBSpecial().GetXaSQL("prepare", xaBranch)) } if rerr == nil { rerr = err @@ -50,7 +50,7 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) (interf panic(x) } }() - _, rerr = DBExec(db, getXaSQL("start", xaBranch)) + _, rerr = DBExec(db, GetDBSpecial().GetXaSQL("start", xaBranch)) if rerr != nil { return } diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index f301754..86cd0aa 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -41,12 +41,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal { trans := TransGlobal{} owner := GenGid() db := dbGet() - getTime := func(second int) string { - return fmt.Sprintf(map[string]string{ - "mysql": "date_add(now(), interval %d second)", - "postgres": "current_timestamp + interval '%d second'", - }[dtmcli.DBDriver], second) - } + getTime := dtmcli.GetDBSpecial().TimestampAdd expire := int(expireIn / time.Second) whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3)) // 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢