Browse Source

add driver and compatible

pull/43/head
yedf2 5 years ago
parent
commit
b2dda03339
  1. 2
      app/main.go
  2. 8
      dtmcli/consts.go
  3. 22
      dtmcli/utils.go
  4. 9
      dtmcli/utils_test.go
  5. 2
      test/main_test.go

2
app/main.go

@ -6,6 +6,7 @@ import (
"strings"
"github.com/yedf/dtm/bench"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmsvr"
"github.com/yedf/dtm/examples"
@ -26,6 +27,7 @@ Available commands:
`
func main() {
dtmcli.DBDriver = common.DtmConfig.DB["driver"]
if len(os.Args) == 1 {
fmt.Println(usage)
for name := range examples.Samples {

8
dtmcli/consts.go

@ -29,4 +29,12 @@ const (
ResultSuccess = "SUCCESS"
// ResultFailure for result of a trans/trans branch
ResultFailure = "FAILURE"
// DriverMysql const for driver mysql
DriverMysql = "mysql"
// DriverPostgres const for driver postgres
DriverPostgres = "postgres"
)
// DBDriver dtm和dtmcli可以支持mysql和postgres,但不支持混合,通过全局变量指定当前要支持的驱动
var DBDriver = DriverMysql

22
dtmcli/utils.go

@ -210,6 +210,7 @@ func StandaloneDB(conf map[string]string) (*sql.DB, error) {
// DBExec use raw db to exec
func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) {
sql = makeSqlCompatible(sql)
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
@ -222,6 +223,7 @@ func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr erro
// DBQueryRow use raw tx to query row
func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row {
query = makeSqlCompatible(query)
Logf("querying: "+query, args...)
return db.QueryRow(query, args...)
}
@ -268,3 +270,23 @@ func CheckResult(res interface{}, err error) error {
}
return err
}
func makeSqlCompatible(sql string) string {
if DBDriver == DriverMysql {
return sql
} else if DBDriver == DriverPostgres {
pos := 1
parts := []string{}
b := 0
for i := 0; i < len(sql); i++ {
if sql[i] == '?' {
parts = append(parts, sql[b:i])
b = i + 1
parts = append(parts, fmt.Sprintf("$%d", pos))
pos++
}
}
return strings.Join(parts, "")
}
panic(fmt.Sprintf("unknown driver %s", DBDriver))
}

9
dtmcli/utils_test.go

@ -89,3 +89,12 @@ func TestFatal(t *testing.T) {
})
assert.Error(t, err, fmt.Errorf("fatal"))
}
func TestMakeSqlCompatible(t *testing.T) {
old := DBDriver
DBDriver = DriverMysql
assert.Equal(t, "? ?", makeSqlCompatible("? ?"))
DBDriver = DriverPostgres
assert.Equal(t, "$1 $2", makeSqlCompatible("? ?"))
DBDriver = old
}

2
test/main_test.go

@ -6,11 +6,13 @@ import (
"time"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmsvr"
"github.com/yedf/dtm/examples"
)
func TestMain(m *testing.M) {
dtmcli.DBDriver = common.DtmConfig.DB["driver"]
dtmsvr.TransProcessedTestChan = make(chan string, 1)
dtmsvr.CronForwardDuration = 60 * time.Second
common.DtmConfig.UpdateBranchSync = 1

Loading…
Cancel
Save