diff --git a/app/main.go b/app/main.go index c63dcd8..e9f4094 100644 --- a/app/main.go +++ b/app/main.go @@ -27,7 +27,7 @@ Available commands: ` func main() { - dtmcli.DBDriver = common.DtmConfig.DB["driver"] + dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) if len(os.Args) == 1 { fmt.Println(usage) for name := range examples.Samples { diff --git a/common/types.go b/common/types.go index 90573d4..5ada90b 100644 --- a/common/types.go +++ b/common/types.go @@ -28,10 +28,10 @@ type ModelBase struct { } func getGormDialetor(driver string, dsn string) gorm.Dialector { - if driver == dtmcli.DriverPostgres { + if driver == dtmcli.DBTypePostgres { return postgres.Open(dsn) } - dtmcli.PanicIf(driver != dtmcli.DriverMysql, fmt.Errorf("unkown driver: %s", driver)) + dtmcli.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unkown driver: %s", driver)) return mysql.Open(dsn) } diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 59b6582..dbc47f5 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -44,10 +44,7 @@ func insertBarrier(tx Tx, transType string, gid string, branchID string, branchT if branchType == "" { return 0, nil } - sql := map[string]string{ - "mysql": "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", - "postgres": "insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?) on conflict ON CONSTRAINT uniq_barrier do nothing", - }[DBDriver] + sql := GetDBSpecial().GetInsertIgnoreTemplate("dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier") return DBExec(tx, sql, transType, gid, branchID, branchType, barrierID, reason) } diff --git a/dtmcli/consts.go b/dtmcli/consts.go index 78aead8..8d2b1d7 100644 --- a/dtmcli/consts.go +++ b/dtmcli/consts.go @@ -30,11 +30,8 @@ const ( // ResultFailure for result of a trans/trans branch ResultFailure = "FAILURE" - // DriverMysql const for driver mysql - DriverMysql = "mysql" - // DriverPostgres const for driver postgres - DriverPostgres = "postgres" + // DBTypeMysql const for driver mysql + DBTypeMysql = "mysql" + // DBTypePostgres const for driver postgres + DBTypePostgres = "postgres" ) - -// DBDriver dtm和dtmcli可以支持mysql和postgres,但不支持混合,通过全局变量指定当前要支持的驱动 -var DBDriver = DriverMysql diff --git a/dtmcli/db_special.go b/dtmcli/db_special.go index 5807d7a..12769eb 100644 --- a/dtmcli/db_special.go +++ b/dtmcli/db_special.go @@ -1,7 +1,6 @@ package dtmcli import ( - "errors" "fmt" "strings" ) @@ -10,9 +9,13 @@ import ( type DBSpecial interface { TimestampAdd(second int) string GetPlaceHoldSQL(sql string) string + GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string GetXaSQL(command string, xid string) string } +var dbSpecials = map[string]DBSpecial{} +var currentDBType = DBTypeMysql + type mysqlDBSpecial struct{} func (*mysqlDBSpecial) TimestampAdd(second int) string { @@ -27,6 +30,14 @@ func (*mysqlDBSpecial) GetXaSQL(command string, xid string) string { return fmt.Sprintf("xa %s '%s'", command, xid) } +func (*mysqlDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string { + return fmt.Sprintf("insert ignore into %s", tableAndValues) +} + +func init() { + dbSpecials[DBTypeMysql] = &mysqlDBSpecial{} +} + type postgresDBSpecial struct{} func (*postgresDBSpecial) TimestampAdd(second int) string { @@ -59,12 +70,26 @@ func (*postgresDBSpecial) GetPlaceHoldSQL(sql string) string { return strings.Join(parts, "") } -// GetDBSpecial get DBSpecial for DBDriver +func (*postgresDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string { + return fmt.Sprintf("insert into %s on conflict ON CONSTRAINT %s do nothing", tableAndValues, pgConstraint) +} +func init() { + dbSpecials[DBTypePostgres] = &postgresDBSpecial{} +} + +// GetDBSpecial get DBSpecial for currentDBType func GetDBSpecial() DBSpecial { - if DBDriver == DriverMysql { - return &mysqlDBSpecial{} - } else if DBDriver == DriverPostgres { - return &postgresDBSpecial{} - } - panic(errors.New("unknown DBDriver, please set it to a valid driver: " + DBDriver)) + return dbSpecials[currentDBType] +} + +// SetCurrentDBType set currentDBType +func SetCurrentDBType(dbType string) { + spec := dbSpecials[dbType] + PanicIf(spec == nil, fmt.Errorf("unknown db type %s", dbType)) + currentDBType = dbType +} + +// GetCurrentDBType get currentDBType +func GetCurrentDBType() string { + return currentDBType } diff --git a/dtmcli/db_special_test.go b/dtmcli/db_special_test.go index 4ed5607..586f64c 100644 --- a/dtmcli/db_special_test.go +++ b/dtmcli/db_special_test.go @@ -7,21 +7,20 @@ import ( ) func TestDBSpecial(t *testing.T) { - old := DBDriver - DBDriver = "no-driver" + old := currentDBType assert.Error(t, CatchP(func() { - GetDBSpecial() + SetCurrentDBType("no-driver") })) - DBDriver = DriverMysql + SetCurrentDBType(DBTypeMysql) sp := GetDBSpecial() assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?")) assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1")) assert.Equal(t, "date_add(now(), interval 1000 second)", sp.TimestampAdd(1000)) - DBDriver = DriverPostgres + SetCurrentDBType(DBTypePostgres) sp = GetDBSpecial() assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?")) assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1")) assert.Equal(t, "current_timestamp + interval '1000 second'", sp.TimestampAdd(1000)) - DBDriver = old + SetCurrentDBType(old) } diff --git a/examples/base_http.go b/examples/base_http.go index 68e14cc..049119a 100644 --- a/examples/base_http.go +++ b/examples/base_http.go @@ -131,9 +131,9 @@ func BaseAddRoute(app *gin.Engine) { return dtmcli.MapFailure, nil } var dia gorm.Dialector = nil - if dtmcli.DBDriver == dtmcli.DriverMysql { + if dtmcli.GetCurrentDBType() == dtmcli.DBTypeMysql { dia = mysql.New(mysql.Config{Conn: db}) - } else if dtmcli.DBDriver == dtmcli.DriverPostgres { + } else if dtmcli.GetCurrentDBType() == dtmcli.DBTypePostgres { dia = postgres.New(postgres.Config{Conn: db}) } gdb, err := gorm.Open(dia, &gorm.Config{}) diff --git a/helper/sync-dtmcli.sh b/helper/sync-dtmcli.sh new file mode 100644 index 0000000..2fabffd --- /dev/null +++ b/helper/sync-dtmcli.sh @@ -0,0 +1,54 @@ +#! /bin/bash +set -x +if [ x$1 == x ]; then + echo please specify you version like vx.x.x; + exit 1; +fi + +if [ ${1:1:1} != v ]; then + echo please specify you version like vx.x.x; + exit 1; +fi + +cd ../dtmcli +cp ../dtm/dtmcli/*.go ./ +rm -f *_test.go +go mod tidy +go build || exit 1 + +git add . +git commit -m'update from dtm' +git push +# git tag $1 +# git push --tags + +cd ../dtmcli-go-sample +go get -u github.com/yedf/dtmcli +go mod tidy +go build || exit 1 +git add . +git commit -m'update from dtm' +git push + +cd ../dtmgrpc +cp ../dtm/dtmgrpc/*.go ./ +cp ../dtm/dtmgrpc/*.proto ./ + +sed -i '' -e 's/yedf\/dtm\//yedf\//g' *.go *.proto +rm -rf *_test.go +go get -u github.com/yedf/dtmcli +go mod tidy +go build || exit 1 +git add . +git commit -m'update from dtm' +git push +# git tag $1 +# git push --tags + +cd ../dtmgrpc-go-sample +go get -u github.com/yedf/dtmcli +go get -u github.com/yedf/dtmgrpc +go build || exit 1 +git add . +git commit -m'update from dtm' +git push \ No newline at end of file diff --git a/test/main_test.go b/test/main_test.go index 0eebb53..5b472fe 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -12,7 +12,7 @@ import ( ) func TestMain(m *testing.M) { - dtmcli.DBDriver = common.DtmConfig.DB["driver"] + dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) dtmsvr.TransProcessedTestChan = make(chan string, 1) dtmsvr.CronForwardDuration = 60 * time.Second common.DtmConfig.UpdateBranchSync = 1