From b210ad7b8ce052a03aeadb77c4b3df0f92729bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E4=BA=91=E9=87=91YunjinXu?= Date: Mon, 16 Oct 2023 10:51:54 +0800 Subject: [PATCH] implement sqlserver storage using gorm.io/driver/sqlserver --- client/dtmcli/consts.go | 2 ++ client/dtmcli/dtmimp/consts.go | 2 ++ client/dtmcli/dtmimp/db_special.go | 21 +++++++++++++++++++ client/dtmcli/dtmimp/utils.go | 15 ++++++++++++++ dtmsvr/config/config.go | 4 +++- dtmsvr/storage/registry/registry.go | 5 +++-- dtmsvr/storage/sql/sql.go | 31 +++++++++++++++++++++-------- dtmutil/db.go | 6 ++++++ test/main_test.go | 4 ++++ 9 files changed, 79 insertions(+), 11 deletions(-) diff --git a/client/dtmcli/consts.go b/client/dtmcli/consts.go index 7ae4fc6..bb87a15 100644 --- a/client/dtmcli/consts.go +++ b/client/dtmcli/consts.go @@ -35,6 +35,8 @@ const ( DBTypeMysql = dtmimp.DBTypeMysql // DBTypePostgres const for driver postgres DBTypePostgres = dtmimp.DBTypePostgres + // DBTypeSqlServer const for driver SqlServer + DBTypeSqlServer = dtmimp.DBTypeSqlServer ) // MapSuccess HTTP result of SUCCESS diff --git a/client/dtmcli/dtmimp/consts.go b/client/dtmcli/dtmimp/consts.go index 6f4e6cd..036e6dc 100644 --- a/client/dtmcli/dtmimp/consts.go +++ b/client/dtmcli/dtmimp/consts.go @@ -36,6 +36,8 @@ const ( DBTypeMysql = "mysql" // DBTypePostgres const for driver postgres DBTypePostgres = "postgres" + // DBTypeSqlServer const for driver SqlServer + DBTypeSqlServer = "sqlserver" // DBTypeRedis const for driver redis DBTypeRedis = "redis" // Jrpc const for json-rpc diff --git a/client/dtmcli/dtmimp/db_special.go b/client/dtmcli/dtmimp/db_special.go index d9128b1..2f0c6d4 100644 --- a/client/dtmcli/dtmimp/db_special.go +++ b/client/dtmcli/dtmimp/db_special.go @@ -78,6 +78,27 @@ func init() { dbSpecials[DBTypePostgres] = &postgresDBSpecial{} } +// TODO sqlserver implement (for go client only, not for dtm server) +type sqlserverDBSpecial struct{} + +func (*sqlserverDBSpecial) GetPlaceHoldSQL(sql string) string { + // TODO sqlserver implement + return sql +} + +func (*sqlserverDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string { + // TODO sqlserver implement + return "" +} + +func (*sqlserverDBSpecial) GetXaSQL(command string, xid string) string { + // TODO sqlserver implement + return "" +} +func init() { + dbSpecials[DBTypeSqlServer] = &sqlserverDBSpecial{} +} + // GetDBSpecial get DBSpecial for currentDBType func GetDBSpecial(dbType string) DBSpecial { if dbType == "" { diff --git a/client/dtmcli/dtmimp/utils.go b/client/dtmcli/dtmimp/utils.go index b0a6a8a..223ed39 100644 --- a/client/dtmcli/dtmimp/utils.go +++ b/client/dtmcli/dtmimp/utils.go @@ -228,11 +228,26 @@ func GetDsn(conf DBConf) string { conf.User, conf.Password, host, conf.Port, conf.Db), "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' search_path=%s port=%d sslmode=disable", host, conf.User, conf.Password, conf.Db, conf.Schema, conf.Port), + // sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30 + "sqlserver": getSqlServerConnectionString(&conf, &host), }[driver] PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) return dsn } +func getSqlServerConnectionString(conf *DBConf, host *string) string { + query := url.Values{} + query.Add("database", conf.Db) + u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword(conf.User, conf.Password), + Host: fmt.Sprintf("%s:%d", *host, conf.Port), + // Path: instance, // if connecting to an instance instead of a port + RawQuery: query.Encode(), + } + return u.String() +} + // RespAsErrorByJSONRPC translate json rpc resty response to error func RespAsErrorByJSONRPC(resp *resty.Response) error { str := resp.String() diff --git a/dtmsvr/config/config.go b/dtmsvr/config/config.go index 2021db4..e9ba1e6 100644 --- a/dtmsvr/config/config.go +++ b/dtmsvr/config/config.go @@ -20,6 +20,8 @@ const ( BoltDb = "boltdb" // Postgres is postgres driver Postgres = "postgres" + // SqlServer is SQL Server driver + SqlServer = "sqlserver" ) // MicroService config type for microservice based grpc @@ -65,7 +67,7 @@ type Store struct { // IsDB checks config driver is mysql or postgres func (s *Store) IsDB() bool { - return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres + return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres || s.Driver == dtmcli.DBTypeSqlServer } // GetDBConf returns db conf info diff --git a/dtmsvr/storage/registry/registry.go b/dtmsvr/storage/registry/registry.go index 297c875..469d20d 100644 --- a/dtmsvr/storage/registry/registry.go +++ b/dtmsvr/storage/registry/registry.go @@ -37,8 +37,9 @@ var storeFactorys = map[string]StorageFactory{ return &redis.Store{} }, }, - "mysql": sqlFac, - "postgres": sqlFac, + "mysql": sqlFac, + "postgres": sqlFac, + "sqlserver": sqlFac, } // GetStore returns storage.Store diff --git a/dtmsvr/storage/sql/sql.go b/dtmsvr/storage/sql/sql.go index 8e444bc..8ae53a4 100644 --- a/dtmsvr/storage/sql/sql.go +++ b/dtmsvr/storage/sql/sql.go @@ -69,10 +69,10 @@ func (s *Store) ScanTransGlobalStores(position *string, limit int64, condition s query = query.Where("trans_type = ?", condition.TransType) } if !condition.CreateTimeStart.IsZero() { - query = query.Where("create_time >= ?", condition.CreateTimeStart.Format("2006-01-02 15:04:05")) + query = query.Where("create_time >= ?", condition.CreateTimeStart) } if !condition.CreateTimeEnd.IsZero() { - query = query.Where("create_time <= ?", condition.CreateTimeEnd.Format("2006-01-02 15:04:05")) + query = query.Where("create_time <= ?", condition.CreateTimeEnd) } dbr := query.Order("id desc").Limit(int(limit)).Find(&globals) @@ -105,7 +105,13 @@ func (s *Store) UpdateBranches(branches []storage.TransBranchStore, updates []st func (s *Store) LockGlobalSaveBranches(gid string, status string, branches []storage.TransBranchStore, branchStart int) { err := dbGet().Transaction(func(tx *gorm.DB) error { g := &storage.TransGlobalStore{} - dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g) + var dbr *gorm.DB + // sqlserver sql should be: SELECT * FROM "trans_global" with(RowLock,UpdLock) ,but gorm generates "FOR UPDATE" at the back, raw sql instead. + if conf.Store.Driver == config.SqlServer { + dbr = tx.Raw("SELECT * FROM trans_global with(RowLock,UpdLock) WHERE gid=? and status=? ORDER BY id OFFSET 0 ROW FETCH NEXT 1 ROWS ONLY ", gid, status).First(g) + } else { + dbr = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g) + } if dbr.Error == nil { if branchStart == -1 { dbr = tx.Create(branches) @@ -164,11 +170,16 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS where := fmt.Sprintf(`next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted')`, nextCronTime) order := map[string]string{ - dtmimp.DBTypeMysql: `order by rand()`, - dtmimp.DBTypePostgres: `order by random()`, + dtmimp.DBTypeMysql: `order by rand()`, + dtmimp.DBTypePostgres: `order by random()`, + dtmimp.DBTypeSqlServer: `order by rand()`, }[conf.Store.Driver] - ssql := fmt.Sprintf(`select id from trans_global where %s %s limit 1`, where, order) + ssql := map[string]string{ + dtmimp.DBTypeMysql: fmt.Sprintf(`select id from trans_global where %s %s limit 1`, where, order), + dtmimp.DBTypePostgres: fmt.Sprintf(`select id from trans_global where %s %s limit 1`, where, order), + dtmimp.DBTypeSqlServer: fmt.Sprintf(`select top 1 id from trans_global where %s %s`, where, order), + }[conf.Store.Driver] var id int64 err := db.ToSQLDB().QueryRow(ssql).Scan(&id) if errors.Is(err, sql.ErrNoRows) { @@ -198,8 +209,9 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS func (s *Store) ResetCronTime(after time.Duration, limit int64) (succeedCount int64, hasRemaining bool, err error) { nextCronTime := getTimeStr(int64(after / time.Second)) where := map[string]string{ - dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit), - dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit), + dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit), + dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit), + dtmimp.DBTypeSqlServer: fmt.Sprintf(`id in (select top %d id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') )`, limit, nextCronTime), }[conf.Store.Driver] sql := fmt.Sprintf(`UPDATE trans_global SET update_time='%s',next_cron_time='%s' WHERE %s`, @@ -317,5 +329,8 @@ func wrapError(err error) error { } func getTimeStr(afterSecond int64) string { + if conf.Store.Driver == config.SqlServer { + return dtmutil.GetNextTime(afterSecond).Format(time.RFC3339) + } return dtmutil.GetNextTime(afterSecond).Format("2006-01-02 15:04:05") } diff --git a/dtmutil/db.go b/dtmutil/db.go index 7e8423e..4237ccf 100644 --- a/dtmutil/db.go +++ b/dtmutil/db.go @@ -11,8 +11,11 @@ import ( "github.com/dtm-labs/logger" _ "github.com/go-sql-driver/mysql" // register mysql driver _ "github.com/lib/pq" // register postgres driver + + // _ "github.com/microsoft/go-mssqldb" // Microsoft's package conflicts with gorm's package: panic: sql: Register called twice for driver mssql "gorm.io/driver/mysql" "gorm.io/driver/postgres" + "gorm.io/driver/sqlserver" // register sqlserver driver, "gorm.io/gorm" ) @@ -27,6 +30,9 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector { if driver == dtmcli.DBTypePostgres { return postgres.Open(dsn) } + if driver == dtmcli.DBTypeSqlServer { + return sqlserver.Open(dsn) + } dtmimp.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unknown driver: %s", driver)) return mysql.Open(dsn) } diff --git a/test/main_test.go b/test/main_test.go index b91e9dd..5b2c1ce 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -53,6 +53,10 @@ func TestMain(m *testing.M) { conf.Store.User = "" conf.Store.Password = "" conf.Store.Port = 6379 + } else if tenv == config.SqlServer { + conf.Store.User = "sa" + conf.Store.Password = "p@ssw0rd" + conf.Store.Port = 1433 } conf.Store.Db = "" registry.WaitStoreUp()