Browse Source

split db specific op to a stand alone file

pull/43/head
yedf2 5 years ago
parent
commit
b8f369c324
  1. 70
      dtmcli/db_special.go
  2. 27
      dtmcli/db_special_test.go
  3. 36
      dtmcli/utils.go
  4. 11
      dtmcli/utils_test.go
  5. 8
      dtmcli/xa_base.go
  6. 7
      dtmsvr/cron.go

70
dtmcli/db_special.go

@ -0,0 +1,70 @@
package dtmcli
import (
"errors"
"fmt"
"strings"
)
// DBSpecial db specific operations
type DBSpecial interface {
TimestampAdd(second int) string
GetPlaceHoldSQL(sql string) string
GetXaSQL(command string, xid string) string
}
type mysqlDBSpecial struct{}
func (*mysqlDBSpecial) TimestampAdd(second int) string {
return fmt.Sprintf("date_add(now(), interval %d second)", second)
}
func (*mysqlDBSpecial) GetPlaceHoldSQL(sql string) string {
return sql
}
func (*mysqlDBSpecial) GetXaSQL(command string, xid string) string {
return fmt.Sprintf("xa %s '%s'", command, xid)
}
type postgresDBSpecial struct{}
func (*postgresDBSpecial) TimestampAdd(second int) string {
return fmt.Sprintf("current_timestamp + interval '%d second'", second)
}
func (*postgresDBSpecial) GetXaSQL(command string, xid string) string {
return map[string]string{
"end": "",
"start": "begin",
"prepare": fmt.Sprintf("prepare transaction '%s'", xid),
"commit": fmt.Sprintf("commit prepared '%s'", xid),
"rollback": fmt.Sprintf("rollback prepared '%s'", xid),
}[command]
}
func (*postgresDBSpecial) GetPlaceHoldSQL(sql string) string {
pos := 1
parts := []string{}
b := 0
for i := 0; i < len(sql); i++ {
if sql[i] == '?' {
parts = append(parts, sql[b:i])
b = i + 1
parts = append(parts, fmt.Sprintf("$%d", pos))
pos++
}
}
parts = append(parts, sql[b:])
return strings.Join(parts, "")
}
// GetDBSpecial get DBSpecial for DBDriver
func GetDBSpecial() DBSpecial {
if DBDriver == DriverMysql {
return &mysqlDBSpecial{}
} else if DBDriver == DriverPostgres {
return &postgresDBSpecial{}
}
panic(errors.New("unknown DBDriver, please set it to a valid driver: " + DBDriver))
}

27
dtmcli/db_special_test.go

@ -0,0 +1,27 @@
package dtmcli
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDBSpecial(t *testing.T) {
old := DBDriver
DBDriver = "no-driver"
assert.Error(t, CatchP(func() {
GetDBSpecial()
}))
DBDriver = DriverMysql
sp := GetDBSpecial()
assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?"))
assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1"))
assert.Equal(t, "date_add(now(), interval 1000 second)", sp.TimestampAdd(1000))
DBDriver = DriverPostgres
sp = GetDBSpecial()
assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?"))
assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1"))
assert.Equal(t, "current_timestamp + interval '1000 second'", sp.TimestampAdd(1000))
DBDriver = old
}

36
dtmcli/utils.go

@ -215,7 +215,7 @@ func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr erro
if sql == "" {
return 0, nil
}
sql = makeSQLCompatible(sql)
sql = GetDBSpecial().GetPlaceHoldSQL(sql)
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
@ -268,37 +268,3 @@ func CheckResult(res interface{}, err error) error {
}
return err
}
func makeSQLCompatible(sql string) string {
if DBDriver == DriverPostgres {
pos := 1
parts := []string{}
b := 0
for i := 0; i < len(sql); i++ {
if sql[i] == '?' {
parts = append(parts, sql[b:i])
b = i + 1
parts = append(parts, fmt.Sprintf("$%d", pos))
pos++
}
}
parts = append(parts, sql[b:])
return strings.Join(parts, "")
}
PanicIf(DBDriver != DriverMysql, fmt.Errorf("unkown db driver: %s", DBDriver))
return sql
}
func getXaSQL(action string, xid string) string {
if DBDriver == DriverPostgres {
return map[string]string{
"end": "",
"start": "begin",
"prepare": fmt.Sprintf("prepare transaction '%s'", xid),
"commit": fmt.Sprintf("commit prepared '%s'", xid),
"rollback": fmt.Sprintf("rollback prepared '%s'", xid),
}[action]
}
PanicIf(DBDriver != DriverMysql, fmt.Errorf("unkown db driver: %s", DBDriver))
return fmt.Sprintf("xa %s '%s'", action, xid)
}

11
dtmcli/utils_test.go

@ -89,14 +89,3 @@ func TestFatal(t *testing.T) {
})
assert.Error(t, err, fmt.Errorf("fatal"))
}
func TestCompatible(t *testing.T) {
old := DBDriver
DBDriver = DriverMysql
assert.Equal(t, "? ?", makeSQLCompatible("? ?"))
assert.Equal(t, "xa start 'xa1'", getXaSQL("start", "xa1"))
DBDriver = DriverPostgres
assert.Equal(t, "$1 $2", makeSQLCompatible("? ?"))
assert.Equal(t, "begin", getXaSQL("start", "xa1"))
DBDriver = old
}

8
dtmcli/xa_base.go

@ -20,7 +20,7 @@ func (xc *XaClientBase) HandleCallback(gid string, branchID string, action strin
}
defer db.Close()
xaID := gid + "-" + branchID
_, err = DBExec(db, getXaSQL(action, xaID))
_, err = DBExec(db, GetDBSpecial().GetXaSQL(action, xaID))
if err != nil &&
(strings.Contains(err.Error(), "Error 1397: XAER_NOTA") || strings.Contains(err.Error(), "does not exist")) { // 重复commit/rollback同一个id,报这个错误,忽略
err = nil
@ -39,9 +39,9 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) (interf
defer func() { db.Close() }()
defer func() {
x := recover()
_, err := DBExec(db, getXaSQL("end", xaBranch))
_, err := DBExec(db, GetDBSpecial().GetXaSQL("end", xaBranch))
if x == nil && rerr == nil && err == nil {
_, err = DBExec(db, getXaSQL("prepare", xaBranch))
_, err = DBExec(db, GetDBSpecial().GetXaSQL("prepare", xaBranch))
}
if rerr == nil {
rerr = err
@ -50,7 +50,7 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) (interf
panic(x)
}
}()
_, rerr = DBExec(db, getXaSQL("start", xaBranch))
_, rerr = DBExec(db, GetDBSpecial().GetXaSQL("start", xaBranch))
if rerr != nil {
return
}

7
dtmsvr/cron.go

@ -41,12 +41,7 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal {
trans := TransGlobal{}
owner := GenGid()
db := dbGet()
getTime := func(second int) string {
return fmt.Sprintf(map[string]string{
"mysql": "date_add(now(), interval %d second)",
"postgres": "current_timestamp + interval '%d second'",
}[dtmcli.DBDriver], second)
}
getTime := dtmcli.GetDBSpecial().TimestampAdd
expire := int(expireIn / time.Second)
whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3))
// 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢

Loading…
Cancel
Save