Browse Source

refactor postgres support

pull/43/head
yedf2 5 years ago
parent
commit
be36d4b30e
  1. 2
      app/main.go
  2. 4
      common/types.go
  3. 5
      dtmcli/barrier.go
  4. 11
      dtmcli/consts.go
  5. 41
      dtmcli/db_special.go
  6. 11
      dtmcli/db_special_test.go
  7. 4
      examples/base_http.go
  8. 54
      helper/sync-dtmcli.sh
  9. 2
      test/main_test.go

2
app/main.go

@ -27,7 +27,7 @@ Available commands:
`
func main() {
dtmcli.DBDriver = common.DtmConfig.DB["driver"]
dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"])
if len(os.Args) == 1 {
fmt.Println(usage)
for name := range examples.Samples {

4
common/types.go

@ -28,10 +28,10 @@ type ModelBase struct {
}
func getGormDialetor(driver string, dsn string) gorm.Dialector {
if driver == dtmcli.DriverPostgres {
if driver == dtmcli.DBTypePostgres {
return postgres.Open(dsn)
}
dtmcli.PanicIf(driver != dtmcli.DriverMysql, fmt.Errorf("unkown driver: %s", driver))
dtmcli.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unkown driver: %s", driver))
return mysql.Open(dsn)
}

5
dtmcli/barrier.go

@ -44,10 +44,7 @@ func insertBarrier(tx Tx, transType string, gid string, branchID string, branchT
if branchType == "" {
return 0, nil
}
sql := map[string]string{
"mysql": "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)",
"postgres": "insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?) on conflict ON CONSTRAINT uniq_barrier do nothing",
}[DBDriver]
sql := GetDBSpecial().GetInsertIgnoreTemplate("dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier")
return DBExec(tx, sql, transType, gid, branchID, branchType, barrierID, reason)
}

11
dtmcli/consts.go

@ -30,11 +30,8 @@ const (
// ResultFailure for result of a trans/trans branch
ResultFailure = "FAILURE"
// DriverMysql const for driver mysql
DriverMysql = "mysql"
// DriverPostgres const for driver postgres
DriverPostgres = "postgres"
// DBTypeMysql const for driver mysql
DBTypeMysql = "mysql"
// DBTypePostgres const for driver postgres
DBTypePostgres = "postgres"
)
// DBDriver dtm和dtmcli可以支持mysql和postgres,但不支持混合,通过全局变量指定当前要支持的驱动
var DBDriver = DriverMysql

41
dtmcli/db_special.go

@ -1,7 +1,6 @@
package dtmcli
import (
"errors"
"fmt"
"strings"
)
@ -10,9 +9,13 @@ import (
type DBSpecial interface {
TimestampAdd(second int) string
GetPlaceHoldSQL(sql string) string
GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string
GetXaSQL(command string, xid string) string
}
var dbSpecials = map[string]DBSpecial{}
var currentDBType = DBTypeMysql
type mysqlDBSpecial struct{}
func (*mysqlDBSpecial) TimestampAdd(second int) string {
@ -27,6 +30,14 @@ func (*mysqlDBSpecial) GetXaSQL(command string, xid string) string {
return fmt.Sprintf("xa %s '%s'", command, xid)
}
func (*mysqlDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string {
return fmt.Sprintf("insert ignore into %s", tableAndValues)
}
func init() {
dbSpecials[DBTypeMysql] = &mysqlDBSpecial{}
}
type postgresDBSpecial struct{}
func (*postgresDBSpecial) TimestampAdd(second int) string {
@ -59,12 +70,26 @@ func (*postgresDBSpecial) GetPlaceHoldSQL(sql string) string {
return strings.Join(parts, "")
}
// GetDBSpecial get DBSpecial for DBDriver
func (*postgresDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string {
return fmt.Sprintf("insert into %s on conflict ON CONSTRAINT %s do nothing", tableAndValues, pgConstraint)
}
func init() {
dbSpecials[DBTypePostgres] = &postgresDBSpecial{}
}
// GetDBSpecial get DBSpecial for currentDBType
func GetDBSpecial() DBSpecial {
if DBDriver == DriverMysql {
return &mysqlDBSpecial{}
} else if DBDriver == DriverPostgres {
return &postgresDBSpecial{}
}
panic(errors.New("unknown DBDriver, please set it to a valid driver: " + DBDriver))
return dbSpecials[currentDBType]
}
// SetCurrentDBType set currentDBType
func SetCurrentDBType(dbType string) {
spec := dbSpecials[dbType]
PanicIf(spec == nil, fmt.Errorf("unknown db type %s", dbType))
currentDBType = dbType
}
// GetCurrentDBType get currentDBType
func GetCurrentDBType() string {
return currentDBType
}

11
dtmcli/db_special_test.go

@ -7,21 +7,20 @@ import (
)
func TestDBSpecial(t *testing.T) {
old := DBDriver
DBDriver = "no-driver"
old := currentDBType
assert.Error(t, CatchP(func() {
GetDBSpecial()
SetCurrentDBType("no-driver")
}))
DBDriver = DriverMysql
SetCurrentDBType(DBTypeMysql)
sp := GetDBSpecial()
assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?"))
assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1"))
assert.Equal(t, "date_add(now(), interval 1000 second)", sp.TimestampAdd(1000))
DBDriver = DriverPostgres
SetCurrentDBType(DBTypePostgres)
sp = GetDBSpecial()
assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?"))
assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1"))
assert.Equal(t, "current_timestamp + interval '1000 second'", sp.TimestampAdd(1000))
DBDriver = old
SetCurrentDBType(old)
}

4
examples/base_http.go

@ -131,9 +131,9 @@ func BaseAddRoute(app *gin.Engine) {
return dtmcli.MapFailure, nil
}
var dia gorm.Dialector = nil
if dtmcli.DBDriver == dtmcli.DriverMysql {
if dtmcli.GetCurrentDBType() == dtmcli.DBTypeMysql {
dia = mysql.New(mysql.Config{Conn: db})
} else if dtmcli.DBDriver == dtmcli.DriverPostgres {
} else if dtmcli.GetCurrentDBType() == dtmcli.DBTypePostgres {
dia = postgres.New(postgres.Config{Conn: db})
}
gdb, err := gorm.Open(dia, &gorm.Config{})

54
helper/sync-dtmcli.sh

@ -0,0 +1,54 @@
#! /bin/bash
set -x
if [ x$1 == x ]; then
echo please specify you version like vx.x.x;
exit 1;
fi
if [ ${1:1:1} != v ]; then
echo please specify you version like vx.x.x;
exit 1;
fi
cd ../dtmcli
cp ../dtm/dtmcli/*.go ./
rm -f *_test.go
go mod tidy
go build || exit 1
git add .
git commit -m'update from dtm'
git push
# git tag $1
# git push --tags
cd ../dtmcli-go-sample
go get -u github.com/yedf/dtmcli
go mod tidy
go build || exit 1
git add .
git commit -m'update from dtm'
git push
cd ../dtmgrpc
cp ../dtm/dtmgrpc/*.go ./
cp ../dtm/dtmgrpc/*.proto ./
sed -i '' -e 's/yedf\/dtm\//yedf\//g' *.go *.proto
rm -rf *_test.go
go get -u github.com/yedf/dtmcli
go mod tidy
go build || exit 1
git add .
git commit -m'update from dtm'
git push
# git tag $1
# git push --tags
cd ../dtmgrpc-go-sample
go get -u github.com/yedf/dtmcli
go get -u github.com/yedf/dtmgrpc
go build || exit 1
git add .
git commit -m'update from dtm'
git push

2
test/main_test.go

@ -12,7 +12,7 @@ import (
)
func TestMain(m *testing.M) {
dtmcli.DBDriver = common.DtmConfig.DB["driver"]
dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"])
dtmsvr.TransProcessedTestChan = make(chan string, 1)
dtmsvr.CronForwardDuration = 60 * time.Second
common.DtmConfig.UpdateBranchSync = 1

Loading…
Cancel
Save