diff --git a/app/main.go b/app/main.go index a4e04bd..c63dcd8 100644 --- a/app/main.go +++ b/app/main.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/yedf/dtm/bench" + "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmsvr" "github.com/yedf/dtm/examples" @@ -26,6 +27,7 @@ Available commands: ` func main() { + dtmcli.DBDriver = common.DtmConfig.DB["driver"] if len(os.Args) == 1 { fmt.Println(usage) for name := range examples.Samples { diff --git a/dtmcli/consts.go b/dtmcli/consts.go index bbb3455..78aead8 100644 --- a/dtmcli/consts.go +++ b/dtmcli/consts.go @@ -29,4 +29,12 @@ const ( ResultSuccess = "SUCCESS" // ResultFailure for result of a trans/trans branch ResultFailure = "FAILURE" + + // DriverMysql const for driver mysql + DriverMysql = "mysql" + // DriverPostgres const for driver postgres + DriverPostgres = "postgres" ) + +// DBDriver dtm和dtmcli可以支持mysql和postgres,但不支持混合,通过全局变量指定当前要支持的驱动 +var DBDriver = DriverMysql diff --git a/dtmcli/utils.go b/dtmcli/utils.go index f0df2cf..f29d82e 100644 --- a/dtmcli/utils.go +++ b/dtmcli/utils.go @@ -210,6 +210,7 @@ func StandaloneDB(conf map[string]string) (*sql.DB, error) { // DBExec use raw db to exec func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) { + sql = makeSqlCompatible(sql) r, rerr := db.Exec(sql, values...) if rerr == nil { affected, rerr = r.RowsAffected() @@ -222,6 +223,7 @@ func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr erro // DBQueryRow use raw tx to query row func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row { + query = makeSqlCompatible(query) Logf("querying: "+query, args...) return db.QueryRow(query, args...) } @@ -268,3 +270,23 @@ func CheckResult(res interface{}, err error) error { } return err } + +func makeSqlCompatible(sql string) string { + if DBDriver == DriverMysql { + return sql + } else if DBDriver == DriverPostgres { + pos := 1 + parts := []string{} + b := 0 + for i := 0; i < len(sql); i++ { + if sql[i] == '?' { + parts = append(parts, sql[b:i]) + b = i + 1 + parts = append(parts, fmt.Sprintf("$%d", pos)) + pos++ + } + } + return strings.Join(parts, "") + } + panic(fmt.Sprintf("unknown driver %s", DBDriver)) +} diff --git a/dtmcli/utils_test.go b/dtmcli/utils_test.go index a59d628..b7c9554 100644 --- a/dtmcli/utils_test.go +++ b/dtmcli/utils_test.go @@ -89,3 +89,12 @@ func TestFatal(t *testing.T) { }) assert.Error(t, err, fmt.Errorf("fatal")) } + +func TestMakeSqlCompatible(t *testing.T) { + old := DBDriver + DBDriver = DriverMysql + assert.Equal(t, "? ?", makeSqlCompatible("? ?")) + DBDriver = DriverPostgres + assert.Equal(t, "$1 $2", makeSqlCompatible("? ?")) + DBDriver = old +} diff --git a/test/main_test.go b/test/main_test.go index cf3bf77..0eebb53 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -6,11 +6,13 @@ import ( "time" "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmsvr" "github.com/yedf/dtm/examples" ) func TestMain(m *testing.M) { + dtmcli.DBDriver = common.DtmConfig.DB["driver"] dtmsvr.TransProcessedTestChan = make(chan string, 1) dtmsvr.CronForwardDuration = 60 * time.Second common.DtmConfig.UpdateBranchSync = 1