diff --git a/.gitignore b/.gitignore index f01bdba..7de6399 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .idea/** .vscode/*.json default.etcd +*/**/*.bolt diff --git a/.travis.yml b/.travis.yml index a57c97b..9a08ffa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,8 +10,9 @@ branches: - alpha services: - mysql + - redis-server before_install: - go get -t -v ./... - - go get github.com/mattn/goveralls + - go get github.com/yedf2/goveralls script: - - $GOPATH/bin/goveralls -service=travis-ci -ignore="examples/*,dtmgrpc/dtmgimp/*.pb.go,bench/*,test/*" + - $GOPATH/bin/goveralls -envs=TEST_STORE=redis,TEST_STORE=mysql,TEST_STORE=boltdb -service=travis-ci -ignore="examples/*,dtmgrpc/dtmgimp/*.pb.go,bench/*,test/*" diff --git a/README-cn.md b/README-cn.md index 241d125..8f866fc 100644 --- a/README-cn.md +++ b/README-cn.md @@ -98,7 +98,7 @@ DTM是一款golang开发的分布式事务管理器,解决了跨数据库、 const qsBusi = "http://localhost:8081/api/busi_saga" req := &gin.H{"amount": 30} // 微服务的载荷 // DtmServer为DTM服务的地址,是一个url - DtmServer := "http://localhost:8080/api/dtmsvr" + DtmServer := "http://localhost:36789/api/dtmsvr" saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). // 添加一个TransOut的子事务,正向操作为url: qsBusi+"/TransOut", 补偿操作为url: qsBusi+"/TransOutCompensate" Add(qsBusi+"/TransOut", qsBusi+"/TransOutCompensate", req). diff --git a/README-en.md b/README-en.md index bd13a01..e106d9e 100644 --- a/README-en.md +++ b/README-en.md @@ -87,7 +87,7 @@ If your language stack is Java, you can also choose to access dtm and use sub-tr // business micro-service address const qsBusi = "http://localhost:8081/api/busi_saga" // The address where DtmServer serves DTM, which is a url - DtmServer := "http://localhost:8080/api/dtmsvr" + DtmServer := "http://localhost:36789/api/dtmsvr" req := &gin.H{"amount": 30} // micro-service payload // DtmServer is the address of DTM micro-service saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). diff --git a/README.md b/README.md index 241d125..8f866fc 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ DTM是一款golang开发的分布式事务管理器,解决了跨数据库、 const qsBusi = "http://localhost:8081/api/busi_saga" req := &gin.H{"amount": 30} // 微服务的载荷 // DtmServer为DTM服务的地址,是一个url - DtmServer := "http://localhost:8080/api/dtmsvr" + DtmServer := "http://localhost:36789/api/dtmsvr" saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). // 添加一个TransOut的子事务,正向操作为url: qsBusi+"/TransOut", 补偿操作为url: qsBusi+"/TransOutCompensate" Add(qsBusi+"/TransOut", qsBusi+"/TransOutCompensate", req). diff --git a/app/main.go b/app/main.go index 0e344b0..fb97b4b 100644 --- a/app/main.go +++ b/app/main.go @@ -15,6 +15,7 @@ import ( "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/dtmsvr" + "github.com/yedf/dtm/dtmsvr/storage" "github.com/yedf/dtm/examples" _ "go.uber.org/automaxprocs" @@ -51,9 +52,9 @@ func main() { } dtmimp.Logf("starting dtm....") common.MustLoadConfig() - dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) + dtmcli.SetCurrentDBType(common.Config.ExamplesDB.Driver) if os.Args[1] != "dtmsvr" { // 实际线上运行,只启动dtmsvr,不准备table相关的数据 - common.WaitDBUp() + storage.WaitStoreUp() dtmsvr.PopulateDB(true) examples.PopulateDB(true) } diff --git a/bench/http.go b/bench/http.go index 4308658..f9a03c7 100644 --- a/bench/http.go +++ b/bench/http.go @@ -31,7 +31,7 @@ const total = 200000 var benchBusi = fmt.Sprintf("http://localhost:%d%s", benchPort, benchAPI) func sdbGet() *sql.DB { - db, err := dtmimp.PooledDB(common.DtmConfig.DB) + db, err := dtmimp.PooledDB(common.Config.Store.GetDBConf()) dtmimp.FatalIfError(err) return db } diff --git a/bench/main.go b/bench/main.go index f5bfefd..02b2931 100644 --- a/bench/main.go +++ b/bench/main.go @@ -2,12 +2,14 @@ package main import ( "fmt" + "os" + "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/dtmsvr" + "github.com/yedf/dtm/dtmsvr/storage" "github.com/yedf/dtm/examples" - "os" ) var hint = `To start the bench server, you need to specify the parameters: @@ -25,8 +27,8 @@ func main() { if os.Args[1] == "http" { fmt.Println("start bench server") common.MustLoadConfig() - dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) - common.WaitDBUp() + dtmcli.SetCurrentDBType(common.Config.ExamplesDB.Driver) + storage.WaitStoreUp() dtmsvr.PopulateDB(true) examples.PopulateDB(true) dtmsvr.StartSvr() // 启动dtmsvr的api服务 diff --git a/common/config.go b/common/config.go new file mode 100644 index 0000000..5be36d3 --- /dev/null +++ b/common/config.go @@ -0,0 +1,103 @@ +package common + +import ( + "errors" + "io/ioutil" + "path/filepath" + + "github.com/yedf/dtm/dtmcli" + "github.com/yedf/dtm/dtmcli/dtmimp" + "gopkg.in/yaml.v2" +) + +const ( + DtmMetricsPort = 8889 +) + +// MicroService config type for micro service +type MicroService struct { + Driver string `yaml:"Driver" default:"default"` + Target string `yaml:"Target"` + EndPoint string `yaml:"EndPoint"` +} + +type Store struct { + Driver string `yaml:"Driver" default:"boltdb"` + Host string `yaml:"Host"` + Port int64 `yaml:"Port"` + User string `yaml:"User"` + Password string `yaml:"Password"` + MaxOpenConns int64 `yaml:"MaxOpenConns" default:"500"` + MaxIdleConns int64 `yaml:"MaxIdleConns" default:"500"` + ConnMaxLifeTime int64 `yaml:"ConnMaxLifeTime" default:"5"` + DataExpire int64 `yaml:"DataExpire" default:"604800"` // Trans data will expire in 7 days. only for redis/boltdb. + RedisPrefix string `yaml:"RedisPrefix" default:"{}"` // Redis storage prefix. store data to only one slot in cluster +} + +func (s *Store) IsDB() bool { + return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres +} + +func (s *Store) GetDBConf() dtmcli.DBConf { + return dtmcli.DBConf{ + Driver: s.Driver, + Host: s.Host, + Port: s.Port, + User: s.User, + Passwrod: s.Password, + } +} + +type configType struct { + Store Store `yaml:"Store"` + TransCronInterval int64 `yaml:"TransCronInterval" default:"3"` + TimeoutToFail int64 `yaml:"TimeoutToFail" default:"35"` + RetryInterval int64 `yaml:"RetryInterval" default:"10"` + HttpPort int64 `yaml:"HttpPort" default:"36789"` + GrpcPort int64 `yaml:"GrpcPort" default:"36790"` + MicroService MicroService `yaml:"MicroService"` + UpdateBranchSync int64 `yaml:"UpdateBranchSync"` + ExamplesDB dtmcli.DBConf `yaml:"ExamplesDB"` +} + +// Config 配置 +var Config = configType{} + +func MustLoadConfig() { + loadFromEnv("", &Config) + cont := []byte{} + for d := MustGetwd(); d != "" && d != "/"; d = filepath.Dir(d) { + cont1, err := ioutil.ReadFile(d + "/conf.yml") + if err != nil { + cont1, err = ioutil.ReadFile(d + "/conf.sample.yml") + } + if cont1 != nil { + cont = cont1 + break + } + } + if len(cont) != 0 { + dtmimp.Logf("config is: \n%s", string(cont)) + err := yaml.UnmarshalStrict(cont, &Config) + dtmimp.FatalIfError(err) + } + err := checkConfig() + dtmimp.LogIfFatalf(err != nil, `config error: '%v'. + check you env, and conf.yml/conf.sample.yml in current and parent path: %s. + please visit http://d.dtm.pub to see the config document. + loaded config is: + %v`, err, MustGetwd(), Config) +} + +func checkConfig() error { + if Config.RetryInterval < 10 { + return errors.New("RetryInterval should not be less than 10") + } else if Config.TimeoutToFail < Config.RetryInterval { + return errors.New("TimeoutToFail should not be less than RetryInterval") + } else if Config.Store.Driver == "boltdb" { + return nil + } else if Config.Store.Driver != "redis" && (Config.Store.User == "" || Config.Store.Host == "" || Config.Store.Port == 0) { + return errors.New("db config not valid") + } + return nil +} diff --git a/common/config_test.go b/common/config_test.go new file mode 100644 index 0000000..dc2a660 --- /dev/null +++ b/common/config_test.go @@ -0,0 +1,17 @@ +package common + +import ( + "os" + "testing" + + "github.com/go-playground/assert/v2" +) + +func TestLoadFromEnv(t *testing.T) { + assert.Equal(t, "MICRO_SERVICE_DRIVER", toUnderscoreUpper("MicroService_Driver")) + + ms := MicroService{} + os.Setenv("T_DRIVER", "d1") + loadFromEnv("T", &ms) + assert.Equal(t, "d1", ms.Driver) +} diff --git a/common/config_utils.go b/common/config_utils.go new file mode 100644 index 0000000..44c93f6 --- /dev/null +++ b/common/config_utils.go @@ -0,0 +1,55 @@ +package common + +import ( + "fmt" + "os" + "reflect" + "regexp" + "strings" + + "github.com/yedf/dtm/dtmcli/dtmimp" +) + +func loadFromEnv(prefix string, conf interface{}) { + rv := reflect.ValueOf(conf) + dtmimp.PanicIf(rv.Kind() != reflect.Ptr || rv.IsNil(), + fmt.Errorf("should be a valid pointer, but %s found", reflect.TypeOf(conf).Name())) + loadFromEnvInner(prefix, rv.Elem(), "") +} + +func loadFromEnvInner(prefix string, conf reflect.Value, defaultValue string) { + kind := conf.Kind() + switch kind { + case reflect.Struct: + t := conf.Type() + for i := 0; i < t.NumField(); i++ { + tag := t.Field(i).Tag + loadFromEnvInner(prefix+"_"+tag.Get("yaml"), conf.Field(i), tag.Get("default")) + } + case reflect.String: + str := os.Getenv(toUnderscoreUpper(prefix)) + if str == "" { + str = defaultValue + } + conf.Set(reflect.ValueOf(str)) + case reflect.Int64: + str := os.Getenv(toUnderscoreUpper(prefix)) + if str == "" { + str = defaultValue + } + if str == "" { + str = "0" + } + conf.Set(reflect.ValueOf(int64(dtmimp.MustAtoi(str)))) + default: + panic(fmt.Errorf("unsupported type: %s", conf.Type().Name())) + } +} + +func toUnderscoreUpper(key string) string { + key = strings.Trim(key, "_") + matchFirstCap := regexp.MustCompile("([a-z])([A-Z]+)") + s2 := matchFirstCap.ReplaceAllString(key, "${1}_${2}") + // dtmimp.Logf("loading from env: %s", strings.ToUpper(s2)) + return strings.ToUpper(s2) +} diff --git a/common/db.go b/common/db.go index bf6c02e..a9c315d 100644 --- a/common/db.go +++ b/common/db.go @@ -3,7 +3,6 @@ package common import ( "database/sql" "fmt" - "strconv" "strings" "sync" "time" @@ -20,8 +19,8 @@ import ( // ModelBase model base for gorm to provide base fields type ModelBase struct { ID uint64 - CreateTime *time.Time `gorm:"autoCreateTime"` - UpdateTime *time.Time `gorm:"autoUpdateTime"` + CreateTime *time.Time `json:"create_time" gorm:"autoCreateTime"` + UpdateTime *time.Time `json:"update_time" gorm:"autoUpdateTime"` } func getGormDialetor(driver string, dsn string) gorm.Dialector { @@ -45,12 +44,6 @@ func (m *DB) Must() *DB { return &DB{DB: db} } -// NoMust unset must flag, don't panic when error occur -func (m *DB) NoMust() *DB { - db := m.InstanceSet("ivy.must", false) - return &DB{DB: db} -} - // ToSQLDB get the sql.DB func (m *DB) ToSQLDB() *sql.DB { d, err := m.DB.DB() @@ -105,27 +98,18 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { // SetDBConn set db connection conf func SetDBConn(db *DB) { sqldb, _ := db.DB.DB() - maxOpenCons, err := strconv.Atoi(DtmConfig.DB["max_open_conns"]) - if err == nil { - sqldb.SetMaxOpenConns(maxOpenCons) - } - maxIdleCons, err := strconv.Atoi(DtmConfig.DB["max_idle_conns"]) - if err == nil { - sqldb.SetMaxIdleConns(maxIdleCons) - } - connMaxLifeTime, err := strconv.ParseInt(DtmConfig.DB["conn_max_life_time"], 10, 64) - if err == nil { - sqldb.SetConnMaxLifetime(time.Duration(connMaxLifeTime) * time.Minute) - } + sqldb.SetMaxOpenConns(int(Config.Store.MaxOpenConns)) + sqldb.SetMaxIdleConns(int(Config.Store.MaxIdleConns)) + sqldb.SetConnMaxLifetime(time.Duration(Config.Store.ConnMaxLifeTime) * time.Minute) } // DbGet get db connection for specified conf -func DbGet(conf map[string]string) *DB { +func DbGet(conf dtmcli.DBConf) *DB { dsn := dtmimp.GetDsn(conf) db, ok := dbs.Load(dsn) if !ok { - dtmimp.Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) - db1, err := gorm.Open(getGormDialetor(conf["driver"], dsn), &gorm.Config{ + dtmimp.Logf("connecting %s", strings.Replace(dsn, conf.Passwrod, "****", 1)) + db1, err := gorm.Open(getGormDialetor(conf.Driver, dsn), &gorm.Config{ SkipDefaultTransaction: true, }) dtmimp.E2P(err) @@ -136,16 +120,3 @@ func DbGet(conf map[string]string) *DB { } return db.(*DB) } - -// WaitDBUp wait for db to go up -func WaitDBUp() { - sdb, err := dtmimp.StandaloneDB(DtmConfig.DB) - dtmimp.FatalIfError(err) - defer func() { - sdb.Close() - }() - for _, err = dtmimp.DBExec(sdb, "select 1"); err != nil; { // wait for mysql to start - time.Sleep(3 * time.Second) - _, err = dtmimp.DBExec(sdb, "select 1") - } -} diff --git a/common/types.go b/common/types.go index 478bea5..7577db0 100644 --- a/common/types.go +++ b/common/types.go @@ -7,97 +7,21 @@ package common import ( - "errors" - "io/ioutil" - "os" - "path/filepath" + "fmt" + "sync" - "gopkg.in/yaml.v2" - - "github.com/yedf/dtm/dtmcli/dtmimp" -) - -const ( - DtmHttpPort = 36789 - DtmGrpcPort = 36790 + "github.com/go-redis/redis/v8" ) -// MicroService config type for micro service -type MicroService struct { - Driver string `yaml:"Driver"` - Target string `yaml:"Target"` - EndPoint string `yaml:"EndPoint"` -} - -type dtmConfigType struct { - TransCronInterval int64 `yaml:"TransCronInterval"` - TimeoutToFail int64 `yaml:"TimeoutToFail"` - RetryInterval int64 `yaml:"RetryInterval"` - DB map[string]string `yaml:"DB"` - MicroService MicroService `yaml:"MicroService"` - DisableLocalhost int64 `yaml:"DisableLocalhost"` - UpdateBranchSync int64 `yaml:"UpdateBranchSync"` -} - -// DtmConfig 配置 -var DtmConfig = dtmConfigType{} - -func getIntEnv(key string, defaultV string) int64 { - return int64(dtmimp.MustAtoi(dtmimp.OrString(os.Getenv(key), defaultV))) -} - -func MustLoadConfig() { - DtmConfig.TransCronInterval = getIntEnv("TRANS_CRON_INTERVAL", "3") - DtmConfig.TimeoutToFail = getIntEnv("TIMEOUT_TO_FAIL", "35") - DtmConfig.RetryInterval = getIntEnv("RETRY_INTERVAL", "10") - DtmConfig.DB = map[string]string{ - "driver": dtmimp.OrString(os.Getenv("DB_DRIVER"), "mysql"), - "host": os.Getenv("DB_HOST"), - "port": dtmimp.OrString(os.Getenv("DB_PORT"), "3306"), - "user": os.Getenv("DB_USER"), - "password": os.Getenv("DB_PASSWORD"), - "max_open_conns": dtmimp.OrString(os.Getenv("DB_MAX_OPEN_CONNS"), "500"), - "max_idle_conns": dtmimp.OrString(os.Getenv("DB_MAX_IDLE_CONNS"), "500"), - "conn_max_life_time": dtmimp.OrString(os.Getenv("DB_CONN_MAX_LIFE_TIME"), "5"), - } - DtmConfig.MicroService.Driver = dtmimp.OrString(os.Getenv("MICRO_SERVICE_DRIVER"), "default") - DtmConfig.MicroService.Target = os.Getenv("MICRO_SERVICE_TARGET") - DtmConfig.MicroService.EndPoint = os.Getenv("MICRO_SERVICE_ENDPOINT") - DtmConfig.DisableLocalhost = getIntEnv("DISABLE_LOCALHOST", "0") - DtmConfig.UpdateBranchSync = getIntEnv("UPDATE_BRANCH_SYNC", "0") - cont := []byte{} - for d := MustGetwd(); d != "" && d != "/"; d = filepath.Dir(d) { - cont1, err := ioutil.ReadFile(d + "/conf.yml") - if err != nil { - cont1, err = ioutil.ReadFile(d + "/conf.sample.yml") - } - if cont1 != nil { - cont = cont1 - break - } - } - if len(cont) != 0 { - dtmimp.Logf("config is: \n%s", string(cont)) - err := yaml.Unmarshal(cont, &DtmConfig) - dtmimp.FatalIfError(err) - } - err := checkConfig() - dtmimp.LogIfFatalf(err != nil, `config error: '%v'. - check you env, and conf.yml/conf.sample.yml in current and parent path: %s. - please visit http://d.dtm.pub to see the config document. - loaded config is: - %v`, err, MustGetwd(), DtmConfig) -} - -func checkConfig() error { - if DtmConfig.DB["driver"] == "" { - return errors.New("db driver empty") - } else if DtmConfig.DB["user"] == "" || DtmConfig.DB["host"] == "" { - return errors.New("db config not valid") - } else if DtmConfig.RetryInterval < 10 { - return errors.New("RetryInterval should not be less than 10") - } else if DtmConfig.TimeoutToFail < DtmConfig.RetryInterval { - return errors.New("TimeoutToFail should not be less than RetryInterval") - } - return nil +var rdb *redis.Client +var once sync.Once + +func RedisGet() *redis.Client { + once.Do(func() { + rdb = redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", Config.Store.Host, Config.Store.Port), + Password: Config.Store.Password, + }) + }) + return rdb } diff --git a/common/types_test.go b/common/types_test.go index 353b4ed..8aea076 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -13,25 +13,25 @@ import ( "github.com/yedf/dtm/dtmcli/dtmimp" ) -func TestDb(t *testing.T) { +func TestGeneralDB(t *testing.T) { MustLoadConfig() - db := DbGet(DtmConfig.DB) + if Config.Store.IsDB() { + testSql(t) + testDbAlone(t) + } +} +func testSql(t *testing.T) { + db := DbGet(Config.Store.GetDBConf()) err := func() (rerr error) { defer dtmimp.P2E(&rerr) - dbr := db.NoMust().Exec("select a") - assert.NotEqual(t, nil, dbr.Error) db.Must().Exec("select a") return nil }() assert.NotEqual(t, nil, err) } -func TestWaitDBUp(t *testing.T) { - WaitDBUp() -} - -func TestDbAlone(t *testing.T) { - db, err := dtmimp.StandaloneDB(DtmConfig.DB) +func testDbAlone(t *testing.T) { + db, err := dtmimp.StandaloneDB(Config.Store.GetDBConf()) assert.Nil(t, err) _, err = dtmimp.DBExec(db, "select 1") assert.Equal(t, nil, err) @@ -43,18 +43,18 @@ func TestDbAlone(t *testing.T) { } func TestConfig(t *testing.T) { - testConfigStringField(DtmConfig.DB, "driver", "", t) - testConfigStringField(DtmConfig.DB, "user", "", t) - testConfigIntField(&DtmConfig.RetryInterval, 9, t) - testConfigIntField(&DtmConfig.TimeoutToFail, 9, t) + testConfigStringField(&Config.Store.Driver, "", t) + testConfigStringField(&Config.Store.User, "", t) + testConfigIntField(&Config.RetryInterval, 9, t) + testConfigIntField(&Config.TimeoutToFail, 9, t) } -func testConfigStringField(m map[string]string, key string, val string, t *testing.T) { - old := m[key] - m[key] = val +func testConfigStringField(fd *string, val string, t *testing.T) { + old := *fd + *fd = val str := checkConfig() assert.NotEqual(t, "", str) - m[key] = old + *fd = old } func testConfigIntField(fd *int64, val int64, t *testing.T) { diff --git a/common/utils.go b/common/utils.go index dda68eb..24e25d7 100644 --- a/common/utils.go +++ b/common/utils.go @@ -19,6 +19,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" + "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" ) @@ -82,9 +83,37 @@ func MustGetwd() string { // GetCallerCodeDir 获取调用该函数的caller源代码的目录,主要用于测试时,查找相关文件 func GetCallerCodeDir() string { _, file, _, _ := runtime.Caller(1) - wd := MustGetwd() - if strings.HasSuffix(wd, "/test") { - wd = filepath.Dir(wd) + return filepath.Dir(file) +} + +func RecoverPanic(err *error) { + if x := recover(); x != nil { + e := dtmimp.AsError(x) + if err != nil { + *err = e + } + } +} + +func GetNextTime(second int64) *time.Time { + next := time.Now().Add(time.Duration(second) * time.Second) + return &next +} + +// RunSQLScript 1 +func RunSQLScript(conf dtmcli.DBConf, script string, skipDrop bool) { + con, err := dtmimp.StandaloneDB(conf) + dtmimp.FatalIfError(err) + defer func() { con.Close() }() + content, err := ioutil.ReadFile(script) + dtmimp.FatalIfError(err) + sqls := strings.Split(string(content), ";") + for _, sql := range sqls { + s := strings.TrimSpace(sql) + if s == "" || (skipDrop && strings.Contains(s, "drop")) { + continue + } + _, err = dtmimp.DBExec(con, s) + dtmimp.FatalIfError(err) } - return wd + "/" + filepath.Base(filepath.Dir(file)) } diff --git a/common/utils_test.go b/common/utils_test.go index 851c575..a3240a2 100644 --- a/common/utils_test.go +++ b/common/utils_test.go @@ -8,6 +8,7 @@ package common import ( "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -45,3 +46,11 @@ func TestFuncs(t *testing.T) { assert.Equal(t, true, strings.HasSuffix(dir1, "common")) } + +func TestRecoverPanic(t *testing.T) { + err := func() (rerr error) { + defer RecoverPanic(&rerr) + panic(fmt.Errorf("an error")) + }() + assert.Equal(t, "an error", err.Error()) +} diff --git a/conf.sample.yml b/conf.sample.yml index 4701009..3d698c0 100644 --- a/conf.sample.yml +++ b/conf.sample.yml @@ -1,29 +1,48 @@ -DB: - driver: 'mysql' - host: 'localhost' - user: 'root' - password: '' - port: '3306' - - # driver: 'postgres' - # host: 'localhost' - # user: 'postgres' - # password: 'mysecretpassword' - # port: '5432' - - # max_open_conns: 'dbmaxopenconns' - # max_idle_conns: 'dbmaxidleconns' - # conn_max_life_time: 'dbconnmaxlifetime' +Store: # specify which engine to store trans status +# Driver: 'boltdb' # default store engine + +# Driver: 'redis' +# Host: 'localhost' +# User: '' +# Password: '' +# Port: 6379 + + Driver: 'mysql' + Host: 'localhost' + User: 'root' + Password: '' + Port: 3306 + +# Driver: 'postgres' +# Host: 'localhost' +# User: 'postgres' +# Password: 'mysecretpassword' +# Port: '5432' + +### following connection config is for only Driver postgres/mysql +# MaxOpenConns: 500 +# MaxIdleConns: 500 +# ConnMaxLifeTime 5 # default value is 5 (minutes) + +### flollowing config is only for some Driver +# DataExpire: 604800 # Trans data will expire in 7 days. only for redis/boltdb. +# RedisPrefix: '{}' # default value is '{}'. Redis storage prefix. store data to only one slot in cluster + # MicroService: # Driver: 'dtm-driver-gozero' # name of the driver to handle register/discover # Target: 'etcd://localhost:2379/dtmservice' # register dtm server to this url # EndPoint: 'localhost:36790' -# MicroService: -# Driver: 'dtm-driver-protocol1' - # the unit of following configurations is second # TransCronInterval: 3 # the interval to poll unfinished global transaction for every dtm process # TimeoutToFail: 35 # timeout for XA, TCC to fail. saga's timeout default to infinite, which can be overwritten in saga options # RetryInterval: 10 # the subtrans branch will be retried after this interval + +### dtm can run examples, and examples will use following config to connect db +ExamplesDB: + Driver: 'mysql' + Host: 'localhost' + User: 'root' + Password: '' + Port: 3306 diff --git a/dtmcli/dtmimp/consts.go b/dtmcli/dtmimp/consts.go index 8227a98..050d086 100644 --- a/dtmcli/dtmimp/consts.go +++ b/dtmcli/dtmimp/consts.go @@ -17,4 +17,6 @@ const ( DBTypeMysql = "mysql" // DBTypePostgres const for driver postgres DBTypePostgres = "postgres" + // DBTypeRedis const for driver redis + DBTypeRedis = "redis" ) diff --git a/dtmcli/dtmimp/db_special.go b/dtmcli/dtmimp/db_special.go index 3eb3071..d45f460 100644 --- a/dtmcli/dtmimp/db_special.go +++ b/dtmcli/dtmimp/db_special.go @@ -13,7 +13,6 @@ import ( // DBSpecial db specific operations type DBSpecial interface { - TimestampAdd(second int) string GetPlaceHoldSQL(sql string) string GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string GetXaSQL(command string, xid string) string @@ -24,10 +23,6 @@ var currentDBType = DBTypeMysql type mysqlDBSpecial struct{} -func (*mysqlDBSpecial) TimestampAdd(second int) string { - return fmt.Sprintf("date_add(now(), interval %d second)", second) -} - func (*mysqlDBSpecial) GetPlaceHoldSQL(sql string) string { return sql } diff --git a/dtmcli/dtmimp/db_special_test.go b/dtmcli/dtmimp/db_special_test.go index 3bf7012..3966cd2 100644 --- a/dtmcli/dtmimp/db_special_test.go +++ b/dtmcli/dtmimp/db_special_test.go @@ -22,13 +22,11 @@ func TestDBSpecial(t *testing.T) { assert.Equal(t, "? ?", sp.GetPlaceHoldSQL("? ?")) assert.Equal(t, "xa start 'xa1'", sp.GetXaSQL("start", "xa1")) - assert.Equal(t, "date_add(now(), interval 1000 second)", sp.TimestampAdd(1000)) assert.Equal(t, "insert ignore into a(f) values(?)", sp.GetInsertIgnoreTemplate("a(f) values(?)", "c")) SetCurrentDBType(DBTypePostgres) sp = GetDBSpecial() assert.Equal(t, "$1 $2", sp.GetPlaceHoldSQL("? ?")) assert.Equal(t, "begin", sp.GetXaSQL("start", "xa1")) - assert.Equal(t, "current_timestamp + interval '1000 second'", sp.TimestampAdd(1000)) assert.Equal(t, "insert into a(f) values(?) on conflict ON CONSTRAINT c do nothing", sp.GetInsertIgnoreTemplate("a(f) values(?)", "c")) SetCurrentDBType(old) } diff --git a/dtmcli/dtmimp/trans_xa_base.go b/dtmcli/dtmimp/trans_xa_base.go index a899c67..737f2a4 100644 --- a/dtmcli/dtmimp/trans_xa_base.go +++ b/dtmcli/dtmimp/trans_xa_base.go @@ -14,7 +14,7 @@ import ( // XaClientBase XaClient/XaGrpcClient base type XaClientBase struct { Server string - Conf map[string]string + Conf DBConf NotifyURL string } diff --git a/dtmcli/dtmimp/types.go b/dtmcli/dtmimp/types.go index 3848fb0..a092295 100644 --- a/dtmcli/dtmimp/types.go +++ b/dtmcli/dtmimp/types.go @@ -13,3 +13,11 @@ type DB interface { Exec(query string, args ...interface{}) (sql.Result, error) QueryRow(query string, args ...interface{}) *sql.Row } + +type DBConf struct { + Driver string `yaml:"Driver"` + Host string `yaml:"Host"` + Port int64 `yaml:"Port"` + User string `yaml:"User"` + Passwrod string `yaml:"Password"` +} diff --git a/dtmcli/dtmimp/utils.go b/dtmcli/dtmimp/utils.go index 26686ab..c484884 100644 --- a/dtmcli/dtmimp/utils.go +++ b/dtmcli/dtmimp/utils.go @@ -190,7 +190,7 @@ func MayReplaceLocalhost(host string) string { var sqlDbs sync.Map // PooledDB get pooled sql.DB -func PooledDB(conf map[string]string) (*sql.DB, error) { +func PooledDB(conf DBConf) (*sql.DB, error) { dsn := GetDsn(conf) db, ok := sqlDbs.Load(dsn) if !ok { @@ -205,10 +205,10 @@ func PooledDB(conf map[string]string) (*sql.DB, error) { } // StandaloneDB get a standalone db instance -func StandaloneDB(conf map[string]string) (*sql.DB, error) { +func StandaloneDB(conf DBConf) (*sql.DB, error) { dsn := GetDsn(conf) - Logf("opening standalone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1)) - return sql.Open(conf["driver"], dsn) + Logf("opening standalone %s: %s", conf.Driver, strings.Replace(dsn, conf.Passwrod, "****", 1)) + return sql.Open(conf.Driver, dsn) } // DBExec use raw db to exec @@ -230,14 +230,14 @@ func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr erro } // GetDsn get dsn from map config -func GetDsn(conf map[string]string) string { - host := MayReplaceLocalhost(conf["host"]) - driver := conf["driver"] +func GetDsn(conf DBConf) string { + host := MayReplaceLocalhost(conf.Host) + driver := conf.Driver dsn := map[string]string{ - "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", - conf["user"], conf["password"], host, conf["port"], conf["database"]), - "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable", - host, conf["user"], conf["password"], conf["database"], conf["port"]), + "mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", + conf.User, conf.Passwrod, host, conf.Port, ""), + "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%d sslmode=disable", + host, conf.User, conf.Passwrod, "", conf.Port), }[driver] PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) return dsn diff --git a/dtmcli/trans_test.go b/dtmcli/trans_test.go index cbc2a02..05e6ab4 100644 --- a/dtmcli/trans_test.go +++ b/dtmcli/trans_test.go @@ -25,6 +25,6 @@ func TestQuery(t *testing.T) { } func TestXa(t *testing.T) { - _, err := NewXaClient("http://localhost:8080", map[string]string{}, ":::::", nil) + _, err := NewXaClient("http://localhost:36789", DBConf{}, ":::::", nil) assert.Error(t, err) } diff --git a/dtmcli/types.go b/dtmcli/types.go index c478334..3726e56 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -28,6 +28,8 @@ type DB = dtmimp.DB // TransOptions transaction option type TransOptions = dtmimp.TransOptions +type DBConf = dtmimp.DBConf + // SetCurrentDBType set currentDBType func SetCurrentDBType(dbType string) { dtmimp.SetCurrentDBType(dbType) diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go index 849fe57..8128e7d 100644 --- a/dtmcli/types_test.go +++ b/dtmcli/types_test.go @@ -16,7 +16,7 @@ import ( func TestTypes(t *testing.T) { err := dtmimp.CatchP(func() { - MustGenGid("http://localhost:8080/api/no") + MustGenGid("http://localhost:36789/api/no") }) assert.Error(t, err) assert.Error(t, err) diff --git a/dtmcli/xa.go b/dtmcli/xa.go index ac26120..cd9ca48 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -44,7 +44,7 @@ func XaFromQuery(qs url.Values) (*Xa, error) { } // NewXaClient construct a xa client -func NewXaClient(server string, mysqlConf map[string]string, notifyURL string, register XaRegisterCallback) (*XaClient, error) { +func NewXaClient(server string, mysqlConf DBConf, notifyURL string, register XaRegisterCallback) (*XaClient, error) { xa := &XaClient{XaClientBase: dtmimp.XaClientBase{ Server: server, Conf: mysqlConf, diff --git a/dtmgrpc/xa.go b/dtmgrpc/xa.go index 405b3d7..c12a7e6 100644 --- a/dtmgrpc/xa.go +++ b/dtmgrpc/xa.go @@ -11,6 +11,7 @@ import ( "database/sql" "fmt" + "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/dtmgrpc/dtmgimp" "github.com/yedf/dtmdriver" @@ -47,7 +48,7 @@ func XaGrpcFromRequest(ctx context.Context) (*XaGrpc, error) { } // NewXaGrpcClient construct a xa client -func NewXaGrpcClient(server string, mysqlConf map[string]string, notifyURL string) *XaGrpcClient { +func NewXaGrpcClient(server string, mysqlConf dtmcli.DBConf, notifyURL string) *XaGrpcClient { xa := &XaGrpcClient{XaClientBase: dtmimp.XaClientBase{ Server: server, Conf: mysqlConf, diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 03b8300..eb26ffb 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -10,33 +10,30 @@ import ( "fmt" "github.com/yedf/dtm/dtmcli" - "gorm.io/gorm" - "gorm.io/gorm/clause" + "github.com/yedf/dtm/dtmcli/dtmimp" + "github.com/yedf/dtm/dtmsvr/storage" ) func svcSubmit(t *TransGlobal) (interface{}, error) { - db := dbGet() t.Status = dtmcli.StatusSubmitted - err := t.saveNew(db) + err := t.saveNew() - if err == errUniqueConflict { - dbt := transFromDb(db.DB, t.Gid, false) + if err == storage.ErrUniqueConflict { + dbt := GetTransGlobal(t.Gid) if dbt.Status == dtmcli.StatusPrepared { - updates := t.setNextCron(cronReset) - dbr := db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) - checkAffected(dbr) + dbt.changeStatus(t.Status) } else if dbt.Status != dtmcli.StatusSubmitted { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil } } - return t.Process(db), nil + return t.Process(), nil } func svcPrepare(t *TransGlobal) (interface{}, error) { t.Status = dtmcli.StatusPrepared - err := t.saveNew(dbGet()) - if err == errUniqueConflict { - dbt := transFromDb(dbGet().DB, t.Gid, false) + err := t.saveNew() + if err == storage.ErrUniqueConflict { + dbt := GetTransGlobal(t.Gid) if dbt.Status != dtmcli.StatusPrepared { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil } @@ -45,46 +42,36 @@ func svcPrepare(t *TransGlobal) (interface{}, error) { } func svcAbort(t *TransGlobal) (interface{}, error) { - db := dbGet() - dbt := transFromDb(db.DB, t.Gid, false) + dbt := GetTransGlobal(t.Gid) if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil } - dbt.changeStatus(db, dtmcli.StatusAborting) - return dbt.Process(db), nil + dbt.changeStatus(dtmcli.StatusAborting) + return dbt.Process(), nil } -func svcRegisterBranch(branch *TransBranch, data map[string]string) (ret interface{}, rerr error) { - err := dbGet().Transaction(func(db *gorm.DB) error { - dbt := transFromDb(db, branch.Gid, true) - if dbt.Status != dtmcli.StatusPrepared { - ret = map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)} - return nil - } - - branches := []TransBranch{*branch, *branch} - if dbt.TransType == "tcc" { - for i, b := range []string{dtmcli.BranchCancel, dtmcli.BranchConfirm} { - branches[i].Op = b - branches[i].URL = data[b] - } - } else if dbt.TransType == "xa" { - branches[0].Op = dtmcli.BranchRollback - branches[0].URL = data["url"] - branches[1].Op = dtmcli.BranchCommit - branches[1].URL = data["url"] - } else { - rerr = fmt.Errorf("unknow trans type: %s", dbt.TransType) - return nil +func svcRegisterBranch(transType string, branch *TransBranch, data map[string]string) (ret interface{}, rerr error) { + branches := []TransBranch{*branch, *branch} + if transType == "tcc" { + for i, b := range []string{dtmcli.BranchCancel, dtmcli.BranchConfirm} { + branches[i].Op = b + branches[i].URL = data[b] } + } else if transType == "xa" { + branches[0].Op = dtmcli.BranchRollback + branches[0].URL = data["url"] + branches[1].Op = dtmcli.BranchCommit + branches[1].URL = data["url"] + } else { + return nil, fmt.Errorf("unknow trans type: %s", transType) + } - dbr := db.Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(branches) - checkAffected(dbr) - ret = dtmcli.MapSuccess - return nil + err := dtmimp.CatchP(func() { + GetStore().LockGlobalSaveBranches(branch.Gid, dtmcli.StatusPrepared, branches, -1) }) - e2p(err) - return + if err == storage.ErrNotFound { + msg := fmt.Sprintf("no trans with gid: %s status: %s found", branch.Gid, dtmcli.StatusPrepared) + return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": msg}, nil + } + return dtmimp.If(err != nil, nil, dtmcli.MapSuccess), err } diff --git a/dtmsvr/api_grpc.go b/dtmsvr/api_grpc.go index d2783ca..07e4d0e 100644 --- a/dtmsvr/api_grpc.go +++ b/dtmsvr/api_grpc.go @@ -40,7 +40,7 @@ func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empt } func (s *dtmServer) RegisterBranch(ctx context.Context, in *pb.DtmBranchRequest) (*emptypb.Empty, error) { - r, err := svcRegisterBranch(&TransBranch{ + r, err := svcRegisterBranch(in.TransType, &TransBranch{ Gid: in.Gid, BranchID: in.BranchID, Status: dtmcli.StatusPrepared, diff --git a/dtmsvr/api_http.go b/dtmsvr/api_http.go index 5208102..5102506 100644 --- a/dtmsvr/api_http.go +++ b/dtmsvr/api_http.go @@ -8,7 +8,6 @@ package dtmsvr import ( "errors" - "math" "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -61,7 +60,7 @@ func registerBranch(c *gin.Context) (interface{}, error) { Status: dtmcli.StatusPrepared, BinData: []byte(data["data"]), } - return svcRegisterBranch(&branch, data) + return svcRegisterBranch(data["trans_type"], &branch, data) } func query(c *gin.Context) (interface{}, error) { @@ -69,20 +68,14 @@ func query(c *gin.Context) (interface{}, error) { if gid == "" { return nil, errors.New("no gid specified") } - db := dbGet() - trans := transFromDb(db.DB, gid, false) - branches := []TransBranch{} - db.Must().Where("gid", gid).Find(&branches) + trans := GetStore().FindTransGlobalStore(gid) + branches := GetStore().FindBranches(gid) return map[string]interface{}{"transaction": trans, "branches": branches}, nil } func all(c *gin.Context) (interface{}, error) { - lastID := c.Query("last_id") - lid := math.MaxInt64 - if lastID != "" { - lid = dtmimp.MustAtoi(lastID) - } - trans := []TransGlobal{} - dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(100).Find(&trans) - return map[string]interface{}{"transactions": trans}, nil + position := c.Query("position") + slimit := dtmimp.OrString(c.Query("limit"), "100") + globals := GetStore().ScanTransGlobalStores(&position, int64(dtmimp.MustAtoi(slimit))) + return map[string]interface{}{"transactions": globals, "next_position": position}, nil } diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index dcf0f48..c4ef282 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -30,7 +30,7 @@ func CronTransOnce() (gid string) { } gid = trans.Gid trans.WaitResult = true - trans.Process(dbGet()) + trans.Process() return } @@ -45,23 +45,11 @@ func CronExpiredTrans(num int) { } func lockOneTrans(expireIn time.Duration) *TransGlobal { - trans := TransGlobal{} - owner := GenGid() - db := dbGet() - getTime := dtmimp.GetDBSpecial().TimestampAdd - expire := int(expireIn / time.Second) - whereTime := fmt.Sprintf("next_cron_time < %s and update_time < %s", getTime(expire), getTime(expire-3)) - // 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢 - // 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出 - dbr := db.Must().Model(&trans). - Where(whereTime+"and status in ('prepared', 'aborting', 'submitted')").Limit(1).Update("owner", owner) - if dbr.RowsAffected == 0 { + global := GetStore().LockOneGlobalTrans(expireIn) + if global == nil { return nil } - dbr = db.Must().Where("owner=?", owner).Find(&trans) - updates := trans.setNextCron(cronKeep) - db.Must().Model(&trans).Select(updates).Updates(&trans) - return &trans + return &TransGlobal{TransGlobalStore: *global} } func handlePanic(perr *error) { diff --git a/dtmsvr/storage/boltdb.go b/dtmsvr/storage/boltdb.go new file mode 100644 index 0000000..e1e6cc0 --- /dev/null +++ b/dtmsvr/storage/boltdb.go @@ -0,0 +1,244 @@ +package storage + +import ( + "fmt" + "sync" + "time" + + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli/dtmimp" + bolt "go.etcd.io/bbolt" + "gorm.io/gorm" +) + +type BoltdbStore struct { +} + +var boltDb *bolt.DB = nil +var boltOnce sync.Once + +func boltGet() *bolt.DB { + boltOnce.Do(func() { + db, err := bolt.Open("./dtm.bolt", 0666, &bolt.Options{Timeout: 1 * time.Second}) + dtmimp.E2P(err) + boltDb = db + }) + return boltDb +} + +var bucketGlobal = []byte("global") +var bucketBranches = []byte("branches") +var bucketIndex = []byte("index") + +func tGetGlobal(t *bolt.Tx, gid string) *TransGlobalStore { + trans := TransGlobalStore{} + bs := t.Bucket(bucketGlobal).Get([]byte(gid)) + if bs == nil { + return nil + } + dtmimp.MustUnmarshal(bs, &trans) + return &trans +} + +func tGetBranches(t *bolt.Tx, gid string) []TransBranchStore { + branches := []TransBranchStore{} + cursor := t.Bucket(bucketBranches).Cursor() + for k, v := cursor.Seek([]byte(gid)); k != nil; k, v = cursor.Next() { + b := TransBranchStore{} + dtmimp.MustUnmarshal(v, &b) + if b.Gid != gid { + break + } + branches = append(branches, b) + } + return branches +} +func tPutGlobal(t *bolt.Tx, global *TransGlobalStore) { + bs := dtmimp.MustMarshal(global) + err := t.Bucket(bucketGlobal).Put([]byte(global.Gid), bs) + dtmimp.E2P(err) +} + +func tPutBranches(t *bolt.Tx, branches []TransBranchStore, start int64) { + if start == -1 { + bs := tGetBranches(t, branches[0].Gid) + start = int64(len(bs)) + } + for i, b := range branches { + k := b.Gid + fmt.Sprintf("%03d", i+int(start)) + v := dtmimp.MustMarshalString(b) + err := t.Bucket(bucketBranches).Put([]byte(k), []byte(v)) + dtmimp.E2P(err) + } +} + +func tDelIndex(t *bolt.Tx, unix int64, gid string) { + k := fmt.Sprintf("%d-%s", unix, gid) + err := t.Bucket(bucketIndex).Delete([]byte(k)) + dtmimp.E2P(err) +} + +func tPutIndex(t *bolt.Tx, unix int64, gid string) { + k := fmt.Sprintf("%d-%s", unix, gid) + err := t.Bucket(bucketIndex).Put([]byte(k), []byte(gid)) + dtmimp.E2P(err) +} + +func (s *BoltdbStore) Ping() error { + return nil +} + +func (s *BoltdbStore) PopulateData(skipDrop bool) { + if !skipDrop { + err := boltGet().Update(func(t *bolt.Tx) error { + t.DeleteBucket(bucketIndex) + t.DeleteBucket(bucketBranches) + t.DeleteBucket(bucketGlobal) + t.CreateBucket(bucketIndex) + t.CreateBucket(bucketBranches) + t.CreateBucket(bucketGlobal) + return nil + }) + dtmimp.E2P(err) + } +} + +func (s *BoltdbStore) FindTransGlobalStore(gid string) (trans *TransGlobalStore) { + err := boltGet().View(func(t *bolt.Tx) error { + trans = tGetGlobal(t, gid) + return nil + }) + dtmimp.E2P(err) + return +} + +func (s *BoltdbStore) ScanTransGlobalStores(position *string, limit int64) []TransGlobalStore { + globals := []TransGlobalStore{} + err := boltGet().View(func(t *bolt.Tx) error { + cursor := t.Bucket(bucketGlobal).Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + if string(k) == *position { + continue + } + g := TransGlobalStore{} + dtmimp.MustUnmarshal(v, &g) + globals = append(globals, g) + if len(globals) == int(limit) { + break + } + } + return nil + }) + dtmimp.E2P(err) + if len(globals) < int(limit) { + *position = "" + } else { + *position = globals[len(globals)-1].Gid + } + return globals +} + +func (s *BoltdbStore) FindBranches(gid string) []TransBranchStore { + var branches []TransBranchStore = nil + err := boltGet().View(func(t *bolt.Tx) error { + branches = tGetBranches(t, gid) + return nil + }) + dtmimp.E2P(err) + return branches +} + +func (s *BoltdbStore) UpdateBranchesSql(branches []TransBranchStore, updates []string) *gorm.DB { + return nil // not implemented +} + +func (s *BoltdbStore) LockGlobalSaveBranches(gid string, status string, branches []TransBranchStore, branchStart int) { + err := boltGet().Update(func(t *bolt.Tx) error { + g := tGetGlobal(t, gid) + if g == nil { + return ErrNotFound + } + if g.Status != status { + return ErrNotFound + } + tPutBranches(t, branches, int64(branchStart)) + return nil + }) + dtmimp.E2P(err) +} + +func (s *BoltdbStore) MaySaveNewTrans(global *TransGlobalStore, branches []TransBranchStore) error { + return boltGet().Update(func(t *bolt.Tx) error { + g := tGetGlobal(t, global.Gid) + if g != nil { + return ErrUniqueConflict + } + tPutGlobal(t, global) + tPutIndex(t, global.NextCronTime.Unix(), global.Gid) + tPutBranches(t, branches, 0) + return nil + }) +} + +func (s *BoltdbStore) ChangeGlobalStatus(global *TransGlobalStore, newStatus string, updates []string, finished bool) { + old := global.Status + global.Status = newStatus + err := boltGet().Update(func(t *bolt.Tx) error { + g := tGetGlobal(t, global.Gid) + if g == nil || g.Status != old { + return ErrNotFound + } + if finished { + tDelIndex(t, g.NextCronTime.Unix(), g.Gid) + } + tPutGlobal(t, global) + return nil + }) + dtmimp.E2P(err) +} + +func (s *BoltdbStore) TouchCronTime(global *TransGlobalStore, nextCronInterval int64) { + oldUnix := global.NextCronTime.Unix() + global.NextCronTime = common.GetNextTime(nextCronInterval) + global.UpdateTime = common.GetNextTime(0) + global.NextCronInterval = nextCronInterval + err := boltGet().Update(func(t *bolt.Tx) error { + g := tGetGlobal(t, global.Gid) + if g == nil || g.Gid != global.Gid { + return ErrNotFound + } + tDelIndex(t, oldUnix, global.Gid) + tPutGlobal(t, global) + tPutIndex(t, global.NextCronTime.Unix(), global.Gid) + return nil + }) + dtmimp.E2P(err) +} + +func (s *BoltdbStore) LockOneGlobalTrans(expireIn time.Duration) *TransGlobalStore { + var trans *TransGlobalStore = nil + min := fmt.Sprintf("%d", time.Now().Add(expireIn).Unix()) + next := time.Now().Add(time.Duration(config.RetryInterval) * time.Second) + err := boltGet().Update(func(t *bolt.Tx) error { + cursor := t.Bucket(bucketIndex).Cursor() + k, v := cursor.First() + if k == nil || string(k) > min { + return ErrNotFound + } + trans = tGetGlobal(t, string(v)) + err := t.Bucket(bucketIndex).Delete(k) + dtmimp.E2P(err) + if trans == nil { // index exists, but global trans not exists, so retry to get next + return ErrShouldRetry + } + trans.NextCronTime = &next + tPutGlobal(t, trans) + tPutIndex(t, next.Unix(), trans.Gid) + return nil + }) + if err == ErrNotFound { + return nil + } + dtmimp.E2P(err) + return trans +} diff --git a/dtmsvr/storage/redis.go b/dtmsvr/storage/redis.go new file mode 100644 index 0000000..5bc751e --- /dev/null +++ b/dtmsvr/storage/redis.go @@ -0,0 +1,259 @@ +package storage + +import ( + "context" + "fmt" + "time" + + "github.com/go-redis/redis/v8" + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli/dtmimp" + "gorm.io/gorm" +) + +var ctx context.Context = context.Background() + +type RedisStore struct { +} + +func (s *RedisStore) Ping() error { + _, err := redisGet().Ping(ctx).Result() + return err +} + +func (s *RedisStore) PopulateData(skipDrop bool) { + _, err := redisGet().FlushAll(ctx).Result() + dtmimp.PanicIf(err != nil, err) +} + +func (s *RedisStore) FindTransGlobalStore(gid string) *TransGlobalStore { + r, err := redisGet().Get(ctx, config.Store.RedisPrefix+"_g_"+gid).Result() + if err == redis.Nil { + return nil + } + dtmimp.E2P(err) + trans := &TransGlobalStore{} + dtmimp.MustUnmarshalString(r, trans) + return trans +} + +func (s *RedisStore) ScanTransGlobalStores(position *string, limit int64) []TransGlobalStore { + lid := uint64(0) + if *position != "" { + lid = uint64(dtmimp.MustAtoi(*position)) + } + keys, cursor, err := redisGet().Scan(ctx, lid, config.Store.RedisPrefix+"_g_*", limit).Result() + dtmimp.E2P(err) + globals := []TransGlobalStore{} + if len(keys) > 0 { + values, err := redisGet().MGet(ctx, keys...).Result() + dtmimp.E2P(err) + for _, v := range values { + global := TransGlobalStore{} + dtmimp.MustUnmarshalString(v.(string), &global) + globals = append(globals, global) + } + } + if cursor > 0 { + *position = fmt.Sprintf("%d", cursor) + } else { + *position = "" + } + return globals +} + +func (s *RedisStore) FindBranches(gid string) []TransBranchStore { + sa, err := redisGet().LRange(ctx, config.Store.RedisPrefix+"_b_"+gid, 0, -1).Result() + dtmimp.E2P(err) + branches := make([]TransBranchStore, len(sa)) + for k, v := range sa { + dtmimp.MustUnmarshalString(v, &branches[k]) + } + return branches +} + +func (s *RedisStore) UpdateBranchesSql(branches []TransBranchStore, updates []string) *gorm.DB { + return nil // not implemented +} + +type argList struct { + List []interface{} +} + +func newArgList() *argList { + a := &argList{} + return a.AppendRaw(config.Store.RedisPrefix).AppendObject(config.Store.DataExpire) +} + +func (a *argList) AppendRaw(v interface{}) *argList { + a.List = append(a.List, v) + return a +} + +func (a *argList) AppendObject(v interface{}) *argList { + return a.AppendRaw(dtmimp.MustMarshalString(v)) +} + +func (a *argList) AppendBranches(branches []TransBranchStore) *argList { + for _, b := range branches { + a.AppendRaw(dtmimp.MustMarshalString(b)) + } + return a +} + +func handleRedisResult(ret interface{}, err error) (string, error) { + dtmimp.Logf("result is: '%v', err: '%v'", ret, err) + if err != nil && err != redis.Nil { + return "", err + } + s, _ := ret.(string) + err = map[string]error{ + "NOT_FOUND": ErrNotFound, + "UNIQUE_CONFLICT": ErrUniqueConflict, + }[s] + return s, err +} + +func callLua(args []interface{}, lua string) (string, error) { + dtmimp.Logf("calling lua. args: %v\nlua:%s", args, lua) + ret, err := redisGet().Eval(ctx, lua, []string{config.Store.RedisPrefix}, args...).Result() + return handleRedisResult(ret, err) +} + +func (s *RedisStore) MaySaveNewTrans(global *TransGlobalStore, branches []TransBranchStore) error { + args := newArgList(). + AppendObject(global). + AppendRaw(global.NextCronTime.Unix()). + AppendBranches(branches). + List + global.Steps = nil + global.Payloads = nil + _, err := callLua(args, `-- MaySaveNewTrans +local gs = cjson.decode(ARGV[3]) +local g = redis.call('GET', ARGV[1] .. '_g_' .. gs.gid) +if g ~= false then + return 'UNIQUE_CONFLICT' +end + +redis.call('SET', ARGV[1] .. '_g_' .. gs.gid, ARGV[3], 'EX', ARGV[2]) +redis.call('ZADD', ARGV[1] .. '_u', ARGV[4], gs.gid) +for k = 5, table.getn(ARGV) do + redis.call('RPUSH', ARGV[1] .. '_b_' .. gs.gid, ARGV[k]) +end +redis.call('EXPIRE', ARGV[1] .. '_b_' .. gs.gid, ARGV[2]) +`) + return err +} + +func (s *RedisStore) LockGlobalSaveBranches(gid string, status string, branches []TransBranchStore, branchStart int) { + args := newArgList(). + AppendObject(&TransGlobalStore{Gid: gid, Status: status}). + AppendRaw(branchStart). + AppendBranches(branches). + List + _, err := callLua(args, ` +local pre = ARGV[1] +local gs = cjson.decode(ARGV[3]) +local g = redis.call('GET', pre .. '_g_' .. gs.gid) +if (g == false) then + return 'NOT_FOUND' +end +local js = cjson.decode(g) +if js.status ~= gs.status then + return 'NOT_FOUND' +end +local start = ARGV[4] +for k = 5, table.getn(ARGV) do + if start == "-1" then + redis.call('RPUSH', pre .. '_b_' .. gs.gid, ARGV[k]) + else + redis.call('LSET', pre .. '_b_' .. gs.gid, start+k-5, ARGV[k]) + end +end +redis.call('EXPIRE', pre .. '_b_' .. gs.gid, ARGV[2]) + `) + dtmimp.E2P(err) +} + +func (s *RedisStore) ChangeGlobalStatus(global *TransGlobalStore, newStatus string, updates []string, finished bool) { + old := global.Status + global.Status = newStatus + args := newArgList().AppendObject(global).AppendRaw(old).AppendRaw(finished).List + _, err := callLua(args, `-- ChangeGlobalStatus +local p = ARGV[1] +local gs = cjson.decode(ARGV[3]) +local old = redis.call('GET', p .. '_g_' .. gs.gid) +if old == false then + return 'NOT_FOUND' +end +local os = cjson.decode(old) +if os.status ~= ARGV[4] then + return 'NOT_FOUND' +end +redis.call('SET', p .. '_g_' .. gs.gid, ARGV[3], 'EX', ARGV[2]) +redis.log(redis.LOG_WARNING, 'finished: ', ARGV[5]) +if ARGV[5] == '1' then + redis.call('ZREM', p .. '_u', gs.gid) +end +`) + dtmimp.E2P(err) +} + +func (s *RedisStore) LockOneGlobalTrans(expireIn time.Duration) *TransGlobalStore { + expired := time.Now().Add(expireIn).Unix() + next := time.Now().Add(time.Duration(config.RetryInterval) * time.Second).Unix() + args := newArgList().AppendRaw(expired).AppendRaw(next).List + lua := `-- LocakOneGlobalTrans +local k = ARGV[1] .. '_u' +local r = redis.call('ZRANGE', k, 0, 0, 'WITHSCORES') +local gid = r[1] +if gid == nil then + return 'NOT_FOUND' +end +local g = redis.call('GET', ARGV[1] .. '_g_' .. gid) +redis.log(redis.LOG_WARNING, 'g is: ', g, 'gid is: ', gid) +if g == false then + redis.call('ZREM', k, gid) + return 'NOT_FOUND' +end + +if tonumber(r[2]) > tonumber(ARGV[3]) then + return 'NOT_FOUND' +end +redis.call('ZADD', k, ARGV[4], gid) +return g +` + r, err := callLua(args, lua) + for err == ErrShouldRetry { + r, err = callLua(args, lua) + } + if err == ErrNotFound { + return nil + } + dtmimp.E2P(err) + global := &TransGlobalStore{} + dtmimp.MustUnmarshalString(r, global) + return global +} + +func (s *RedisStore) TouchCronTime(global *TransGlobalStore, nextCronInterval int64) { + global.NextCronTime = common.GetNextTime(nextCronInterval) + global.UpdateTime = common.GetNextTime(0) + global.NextCronInterval = nextCronInterval + args := newArgList().AppendObject(global).AppendRaw(global.NextCronTime.Unix()).List + _, err := callLua(args, `-- TouchCronTime +local p = ARGV[1] +local g = cjson.decode(ARGV[3]) +local old = redis.call('GET', p .. '_g_' .. g.gid) +if old == false then + return 'NOT_FOUND' +end +local os = cjson.decode(old) +if os.status ~= g.status then + return 'NOT_FOUND' +end +redis.call('ZADD', p .. '_u', ARGV[4], g.gid) +redis.call('SET', p .. '_g_' .. g.gid, ARGV[3], 'EX', ARGV[2]) + `) + dtmimp.E2P(err) +} diff --git a/dtmsvr/storage/sql.go b/dtmsvr/storage/sql.go new file mode 100644 index 0000000..61e8862 --- /dev/null +++ b/dtmsvr/storage/sql.go @@ -0,0 +1,138 @@ +package storage + +import ( + "fmt" + "math" + "time" + + "github.com/google/uuid" + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli/dtmimp" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type SqlStore struct { +} + +func (s *SqlStore) Ping() error { + dbr := dbGet().Exec("select 1") + return dbr.Error +} + +func (s *SqlStore) PopulateData(skipDrop bool) { + file := fmt.Sprintf("%s/storage.%s.sql", common.GetCallerCodeDir(), config.Store.Driver) + common.RunSQLScript(config.Store.GetDBConf(), file, skipDrop) +} + +func (s *SqlStore) FindTransGlobalStore(gid string) *TransGlobalStore { + trans := &TransGlobalStore{} + dbr := dbGet().Model(trans).Where("gid=?", gid).First(trans) + if dbr.Error == gorm.ErrRecordNotFound { + return nil + } + dtmimp.E2P(dbr.Error) + return trans +} + +func (s *SqlStore) ScanTransGlobalStores(position *string, limit int64) []TransGlobalStore { + globals := []TransGlobalStore{} + lid := math.MaxInt64 + if *position != "" { + lid = dtmimp.MustAtoi(*position) + } + dbr := dbGet().Must().Where("id < ?", lid).Order("id desc").Limit(int(limit)).Find(&globals) + if dbr.RowsAffected < limit { + *position = "" + } else { + *position = fmt.Sprintf("%d", globals[len(globals)-1].ID) + } + return globals +} + +func (s *SqlStore) FindBranches(gid string) []TransBranchStore { + branches := []TransBranchStore{} + dbGet().Must().Where("gid=?", gid).Order("id asc").Find(&branches) + return branches +} + +func (s *SqlStore) UpdateBranchesSql(branches []TransBranchStore, updates []string) *gorm.DB { + return dbGet().Clauses(clause.OnConflict{ + OnConstraint: "trans_branch_op_pkey", + DoUpdates: clause.AssignmentColumns(updates), + }).Create(branches) +} + +func (s *SqlStore) LockGlobalSaveBranches(gid string, status string, branches []TransBranchStore, branchStart int) { + err := dbGet().Transaction(func(tx *gorm.DB) error { + g := &TransGlobalStore{} + dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g) + if dbr.Error == nil { + dbr = tx.Save(branches) + } + return wrapError(dbr.Error) + }) + dtmimp.E2P(err) +} + +func (s *SqlStore) MaySaveNewTrans(global *TransGlobalStore, branches []TransBranchStore) error { + return dbGet().Transaction(func(db1 *gorm.DB) error { + db := &common.DB{DB: db1} + dbr := db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(global) + if dbr.RowsAffected <= 0 { // 如果这个不是新事务,返回错误 + return ErrUniqueConflict + } + if len(branches) > 0 { + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(&branches) + } + return nil + }) +} + +func (s *SqlStore) ChangeGlobalStatus(global *TransGlobalStore, newStatus string, updates []string, finished bool) { + old := global.Status + global.Status = newStatus + dbr := dbGet().Must().Model(global).Where("status=? and gid=?", old, global.Gid).Select(updates).Updates(global) + if dbr.RowsAffected == 0 { + dtmimp.E2P(ErrNotFound) + } +} + +func (s *SqlStore) TouchCronTime(global *TransGlobalStore, nextCronInterval int64) { + global.NextCronTime = common.GetNextTime(nextCronInterval) + global.UpdateTime = common.GetNextTime(0) + global.NextCronInterval = nextCronInterval + dbGet().Must().Model(global).Where("status=? and gid=?", global.Status, global.Gid). + Select([]string{"next_cron_time", "update_time", "next_cron_interval"}).Updates(global) +} + +func (s *SqlStore) LockOneGlobalTrans(expireIn time.Duration) *TransGlobalStore { + db := dbGet() + getTime := func(second int) string { + return map[string]string{ + "mysql": fmt.Sprintf("date_add(now(), interval %d second)", second), + "postgres": fmt.Sprintf("current_timestamp + interval '%d second'", second), + }[config.Store.Driver] + } + expire := int(expireIn / time.Second) + whereTime := fmt.Sprintf("next_cron_time < %s", getTime(expire)) + owner := uuid.NewString() + global := &TransGlobalStore{} + dbr := db.Must().Model(global). + Where(whereTime + "and status in ('prepared', 'aborting', 'submitted')"). + Limit(1). + Select([]string{"owner", "next_cron_time"}). + Updates(&TransGlobalStore{ + Owner: owner, + NextCronTime: common.GetNextTime(common.Config.RetryInterval), + }) + if dbr.RowsAffected == 0 { + return nil + } + dbr = db.Must().Where("owner=?", owner).First(global) + return global +} diff --git a/dtmsvr/dtmsvr.mysql.sql b/dtmsvr/storage/storage.mysql.sql similarity index 100% rename from dtmsvr/dtmsvr.mysql.sql rename to dtmsvr/storage/storage.mysql.sql diff --git a/dtmsvr/dtmsvr.postgres.sql b/dtmsvr/storage/storage.postgres.sql similarity index 100% rename from dtmsvr/dtmsvr.postgres.sql rename to dtmsvr/storage/storage.postgres.sql diff --git a/dtmsvr/storage/store.go b/dtmsvr/storage/store.go new file mode 100644 index 0000000..223818b --- /dev/null +++ b/dtmsvr/storage/store.go @@ -0,0 +1,54 @@ +package storage + +import ( + "errors" + "time" + + "github.com/go-redis/redis/v8" + "github.com/yedf/dtm/dtmcli/dtmimp" + "gorm.io/gorm" +) + +var ErrNotFound = errors.New("storage: NotFound") +var ErrShouldRetry = errors.New("storage: ShoudRetry") +var ErrUniqueConflict = errors.New("storage: UniqueKeyConflict") + +type Store interface { + Ping() error + PopulateData(skipDrop bool) + FindTransGlobalStore(gid string) *TransGlobalStore + ScanTransGlobalStores(position *string, limit int64) []TransGlobalStore + FindBranches(gid string) []TransBranchStore + UpdateBranchesSql(branches []TransBranchStore, updates []string) *gorm.DB + LockGlobalSaveBranches(gid string, status string, branches []TransBranchStore, branchStart int) + MaySaveNewTrans(global *TransGlobalStore, branches []TransBranchStore) error + ChangeGlobalStatus(global *TransGlobalStore, newStatus string, updates []string, finished bool) + TouchCronTime(global *TransGlobalStore, nextCronInterval int64) + LockOneGlobalTrans(expireIn time.Duration) *TransGlobalStore +} + +var stores map[string]Store = map[string]Store{ + "redis": &RedisStore{}, + "mysql": &SqlStore{}, + "postgres": &SqlStore{}, + "boltdb": &BoltdbStore{}, +} + +func GetStore() Store { + return stores[config.Store.Driver] +} + +// WaitStoreUp wait for db to go up +func WaitStoreUp() { + for err := GetStore().Ping(); err != nil; err = GetStore().Ping() { + time.Sleep(3 * time.Second) + } +} + +func wrapError(err error) error { + if err == gorm.ErrRecordNotFound || err == redis.Nil { + return ErrNotFound + } + dtmimp.E2P(err) + return err +} diff --git a/dtmsvr/storage/trans.go b/dtmsvr/storage/trans.go new file mode 100644 index 0000000..dbabc01 --- /dev/null +++ b/dtmsvr/storage/trans.go @@ -0,0 +1,52 @@ +package storage + +import ( + "time" + + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" +) + +type TransGlobalStore struct { + common.ModelBase + Gid string `json:"gid,omitempty"` + TransType string `json:"trans_type,omitempty"` + Steps []map[string]string `json:"steps,omitempty" gorm:"-"` + Payloads []string `json:"payloads,omitempty" gorm:"-"` + BinPayloads [][]byte `json:"-" gorm:"-"` + Status string `json:"status,omitempty"` + QueryPrepared string `json:"query_prepared,omitempty"` + Protocol string `json:"protocol,omitempty"` + CommitTime *time.Time `json:"commit_time,omitempty"` + FinishTime *time.Time `json:"finish_time,omitempty"` + RollbackTime *time.Time `json:"rollback_time,omitempty"` + Options string `json:"options,omitempty"` + CustomData string `json:"custom_data,omitempty"` + NextCronInterval int64 `json:"next_cron_interval,omitempty"` + NextCronTime *time.Time `json:"next_cron_time,omitempty"` + Owner string `json:"owner,omitempty"` + dtmcli.TransOptions +} + +// TableName TableName +func (*TransGlobalStore) TableName() string { + return "dtm.trans_global" +} + +// TransBranchStore branch transaction +type TransBranchStore struct { + common.ModelBase + Gid string `json:"gid,omitempty"` + URL string `json:"url,omitempty"` + BinData []byte + BranchID string `json:"branch_id,omitempty"` + Op string `json:"op,omitempty"` + Status string `json:"status,omitempty"` + FinishTime *time.Time `json:"finish_time,omitempty"` + RollbackTime *time.Time `json:"rollback_time,omitempty"` +} + +// TableName TableName +func (*TransBranchStore) TableName() string { + return "dtm.trans_branch_op" +} diff --git a/dtmsvr/storage/utils.go b/dtmsvr/storage/utils.go new file mode 100644 index 0000000..d2d73d1 --- /dev/null +++ b/dtmsvr/storage/utils.go @@ -0,0 +1,16 @@ +package storage + +import ( + "github.com/go-redis/redis/v8" + "github.com/yedf/dtm/common" +) + +var config = &common.Config + +func dbGet() *common.DB { + return common.DbGet(config.Store.GetDBConf()) +} + +func redisGet() *redis.Client { + return common.RedisGet() +} diff --git a/dtmsvr/dtmsvr.go b/dtmsvr/svr.go similarity index 76% rename from dtmsvr/dtmsvr.go rename to dtmsvr/svr.go index d5a4c0b..af92105 100644 --- a/dtmsvr/dtmsvr.go +++ b/dtmsvr/svr.go @@ -15,14 +15,8 @@ import ( "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/dtmgrpc/dtmgimp" - "github.com/yedf/dtm/examples" "github.com/yedf/dtmdriver" "google.golang.org/grpc" - "gorm.io/gorm/clause" - - _ "github.com/ychensha/dtmdriver-polaris" - _ "github.com/yedf/dtmdriver-gozero" - _ "github.com/yedf/dtmdriver-protocol1" ) // StartSvr StartSvr @@ -31,10 +25,10 @@ func StartSvr() { app := common.GetGinApp() app = httpMetrics(app) addRoute(app) - dtmimp.Logf("dtmsvr listen at: %d", common.DtmHttpPort) - go app.Run(fmt.Sprintf(":%d", common.DtmHttpPort)) + dtmimp.Logf("dtmsvr listen at: %d", config.HttpPort) + go app.Run(fmt.Sprintf(":%d", config.HttpPort)) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", common.DtmGrpcPort)) + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", config.GrpcPort)) dtmimp.FatalIfError(err) s := grpc.NewServer( grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( @@ -57,8 +51,7 @@ func StartSvr() { // PopulateDB setup mysql data func PopulateDB(skipDrop bool) { - file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCallerCodeDir(), config.DB["driver"]) - examples.RunSQLScript(config.DB, file, skipDrop) + GetStore().PopulateData(skipDrop) } // UpdateBranchAsyncInterval interval to flush branch @@ -67,6 +60,7 @@ var updateBranchAsyncChan chan branchStatus = make(chan branchStatus, 1000) func updateBranchAsync() { for { // flush branches every second + defer common.RecoverPanic(nil) updates := []TransBranch{} started := time.Now() checkInterval := 20 * time.Millisecond @@ -82,10 +76,8 @@ func updateBranchAsync() { } } for len(updates) > 0 { - dbr := dbGet().Clauses(clause.OnConflict{ - OnConstraint: "trans_branch_op_pkey", - DoUpdates: clause.AssignmentColumns([]string{"status", "finish_time", "update_time"}), - }).Create(updates) + dbr := GetStore().UpdateBranchesSql(updates, []string{"status", "finish_time", "update_time"}) + dtmimp.Logf("flushed %d branch status to db. affected: %d", len(updates), dbr.RowsAffected) if dbr.Error != nil { dtmimp.LogRedf("async update branch status error: %v", dbr.Error) diff --git a/dtmsvr/svr_imports.go b/dtmsvr/svr_imports.go new file mode 100644 index 0000000..f708b1f --- /dev/null +++ b/dtmsvr/svr_imports.go @@ -0,0 +1,7 @@ +package dtmsvr + +import ( + _ "github.com/ychensha/dtmdriver-polaris" + _ "github.com/yedf/dtmdriver-gozero" + _ "github.com/yedf/dtmdriver-protocol1" +) diff --git a/dtmsvr/trans_class.go b/dtmsvr/trans_class.go index bda1a37..4ce6b9e 100644 --- a/dtmsvr/trans_class.go +++ b/dtmsvr/trans_class.go @@ -7,69 +7,28 @@ package dtmsvr import ( - "errors" - "fmt" "time" "github.com/gin-gonic/gin" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/dtmgrpc/dtmgimp" - "gorm.io/gorm" + "github.com/yedf/dtm/dtmsvr/storage" ) -var errUniqueConflict = errors.New("unique key conflict error") - // TransGlobal global transaction type TransGlobal struct { - common.ModelBase - Gid string `json:"gid"` - TransType string `json:"trans_type"` - Steps []map[string]string `json:"steps" gorm:"-"` - Payloads []string `json:"payloads" gorm:"-"` - BinPayloads [][]byte `json:"-" gorm:"-"` - Status string `json:"status"` - QueryPrepared string `json:"query_prepared"` - Protocol string `json:"protocol"` - CommitTime *time.Time - FinishTime *time.Time - RollbackTime *time.Time - Options string - CustomData string `json:"custom_data"` - NextCronInterval int64 - NextCronTime *time.Time - dtmcli.TransOptions + storage.TransGlobalStore lastTouched time.Time // record the start time of process updateBranchSync bool } -// TableName TableName -func (*TransGlobal) TableName() string { - return "dtm.trans_global" -} - // TransBranch branch transaction -type TransBranch struct { - common.ModelBase - Gid string - URL string `json:"url"` - BinData []byte - BranchID string `json:"branch_id"` - Op string - Status string - FinishTime *time.Time - RollbackTime *time.Time -} - -// TableName TableName -func (*TransBranch) TableName() string { - return "dtm.trans_branch_op" -} +type TransBranch = storage.TransBranchStore type transProcessor interface { GenBranches() []TransBranch - ProcessOnce(db *common.DB, branches []TransBranch) error + ProcessOnce(branches []TransBranch) error } type processorCreator func(*TransGlobal) transProcessor @@ -118,7 +77,7 @@ func TransFromDtmRequest(c *dtmgimp.DtmRequest) *TransGlobal { if c.TransOptions != nil { o = c.TransOptions } - r := TransGlobal{ + r := TransGlobal{TransGlobalStore: storage.TransGlobalStore{ Gid: c.Gid, TransType: c.TransType, QueryPrepared: c.QueryPrepared, @@ -129,15 +88,9 @@ func TransFromDtmRequest(c *dtmgimp.DtmRequest) *TransGlobal { TimeoutToFail: o.TimeoutToFail, RetryInterval: o.RetryInterval, }, - } + }} if c.Steps != "" { dtmimp.MustUnmarshalString(c.Steps, &r.Steps) } return &r } - -func checkAffected(db1 *gorm.DB) { - if db1.RowsAffected == 0 { - panic(fmt.Errorf("rows affected 0, please check for abnormal trans")) - } -} diff --git a/dtmsvr/trans_process.go b/dtmsvr/trans_process.go index 7aad95a..38a68b9 100644 --- a/dtmsvr/trans_process.go +++ b/dtmsvr/trans_process.go @@ -12,28 +12,26 @@ import ( "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" - "gorm.io/gorm" - "gorm.io/gorm/clause" ) // Process process global transaction once -func (t *TransGlobal) Process(db *common.DB) map[string]interface{} { - r := t.process(db) +func (t *TransGlobal) Process() map[string]interface{} { + r := t.process() transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess) return r } -func (t *TransGlobal) process(db *common.DB) map[string]interface{} { +func (t *TransGlobal) process() map[string]interface{} { if t.Options != "" { dtmimp.MustUnmarshalString(t.Options, &t.TransOptions) } if !t.WaitResult { - go t.processInner(db) + go t.processInner() return dtmcli.MapSuccess } submitting := t.Status == dtmcli.StatusSubmitted - err := t.processInner(db) + err := t.processInner() if err != nil { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": err.Error()} } @@ -43,7 +41,7 @@ func (t *TransGlobal) process(db *common.DB) map[string]interface{} { return dtmcli.MapSuccess } -func (t *TransGlobal) processInner(db *common.DB) (rerr error) { +func (t *TransGlobal) processInner() (rerr error) { defer handlePanic(&rerr) defer func() { if rerr != nil { @@ -56,34 +54,22 @@ func (t *TransGlobal) processInner(db *common.DB) (rerr error) { } }() dtmimp.Logf("processing: %s status: %s", t.Gid, t.Status) - branches := []TransBranch{} - db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches) + branches := GetStore().FindBranches(t.Gid) t.lastTouched = time.Now() - rerr = t.getProcessor().ProcessOnce(db, branches) + rerr = t.getProcessor().ProcessOnce(branches) return } -func (t *TransGlobal) saveNew(db *common.DB) error { - return db.Transaction(func(db1 *gorm.DB) error { - db := &common.DB{DB: db1} - t.setNextCron(cronReset) - t.Options = dtmimp.MustMarshalString(t.TransOptions) - if t.Options == "{}" { - t.Options = "" - } - dbr := db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(t) - if dbr.RowsAffected <= 0 { // 如果这个不是新事务,返回错误 - return errUniqueConflict - } - branches := t.getProcessor().GenBranches() - if len(branches) > 0 { - checkLocalhost(branches) - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(&branches) - } - return nil - }) +func (t *TransGlobal) saveNew() error { + branches := t.getProcessor().GenBranches() + t.NextCronInterval = t.getNextCronInterval(cronReset) + t.NextCronTime = common.GetNextTime(t.NextCronInterval) + t.Options = dtmimp.MustMarshalString(t.TransOptions) + if t.Options == "{}" { + t.Options = "" + } + now := time.Now() + t.CreateTime = &now + t.UpdateTime = &now + return GetStore().MaySaveNewTrans(&t.TransGlobalStore, branches) } diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index e646f5d..b7fda72 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -11,28 +11,21 @@ import ( "strings" "time" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/dtmgrpc/dtmgimp" "github.com/yedf/dtmdriver" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "gorm.io/gorm" - "gorm.io/gorm/clause" ) -func (t *TransGlobal) touch(db *common.DB, ctype cronType) *gorm.DB { +func (t *TransGlobal) touchCronTime(ctype cronType) { t.lastTouched = time.Now() - updates := t.setNextCron(ctype) - return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Select(updates).Updates(t) + GetStore().TouchCronTime(&t.TransGlobalStore, t.getNextCronInterval(ctype)) } -func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { - old := t.Status - t.Status = status - updates := t.setNextCron(cronReset) - updates = append(updates, "status") +func (t *TransGlobal) changeStatus(status string) { + updates := []string{"status", "update_time"} now := time.Now() if status == dtmcli.StatusSucceed { t.FinishTime = &now @@ -41,30 +34,21 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { t.RollbackTime = &now updates = append(updates, "rollback_time") } - dbr := db.Must().Model(&TransGlobal{}).Where("status=? and gid=?", old, t.Gid).Select(updates).Updates(t) - checkAffected(dbr) - return dbr + t.UpdateTime = &now + GetStore().ChangeGlobalStatus(&t.TransGlobalStore, status, updates, status == dtmcli.StatusSucceed || status == dtmcli.StatusFailed) + t.Status = status } -func (t *TransGlobal) changeBranchStatus(db *common.DB, b *TransBranch, status string) { +func (t *TransGlobal) changeBranchStatus(b *TransBranch, status string, branchPos int) { now := time.Now() - if common.DtmConfig.UpdateBranchSync > 0 || t.updateBranchSync { - err := db.Transaction(func(tx *gorm.DB) error { - dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, t.Status).Find(&[]TransGlobal{}) - checkAffected(dbr) // check TransGlobal is not modified - dbr = tx.Model(b).Updates(map[string]interface{}{ - "status": status, - "finish_time": now, - "update_time": now, - }) - checkAffected(dbr) - return dbr.Error - }) - e2p(err) + b.Status = status + b.FinishTime = &now + b.UpdateTime = &now + if config.Store.Driver != dtmimp.DBTypeMysql && config.Store.Driver != dtmimp.DBTypePostgres || config.UpdateBranchSync > 0 || t.updateBranchSync { + GetStore().LockGlobalSaveBranches(t.Gid, t.Status, []TransBranch{*b}, branchPos) } else { // 为了性能优化,把branch的status更新异步化 updateBranchAsyncChan <- branchStatus{id: b.ID, status: status, finishTime: &now} } - b.Status = status } func (t *TransGlobal) isTimeout() bool { @@ -136,36 +120,32 @@ func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { return "", fmt.Errorf("http result should contains SUCCESS|FAILURE|ONGOING. grpc error should return nil|Aborted with message(FAILURE|ONGOING). \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body) } -func (t *TransGlobal) execBranch(db *common.DB, branch *TransBranch) error { +func (t *TransGlobal) execBranch(branch *TransBranch, branchPos int) error { status, err := t.getBranchResult(branch) if status != "" { - t.changeBranchStatus(db, branch, status) + t.changeBranchStatus(branch, status, branchPos) } branchMetrics(t, branch, status == dtmcli.StatusSucceed) // if time pass 1500ms and NextCronInterval is not default, then reset NextCronInterval if err == nil && time.Since(t.lastTouched)+NowForwardDuration >= 1500*time.Millisecond || t.NextCronInterval > config.RetryInterval && t.NextCronInterval > t.RetryInterval { - t.touch(db, cronReset) + t.touchCronTime(cronReset) } else if err == dtmimp.ErrOngoing { - t.touch(db, cronKeep) + t.touchCronTime(cronKeep) } else if err != nil { - t.touch(db, cronBackoff) + t.touchCronTime(cronBackoff) } return err } -func (t *TransGlobal) setNextCron(ctype cronType) []string { +func (t *TransGlobal) getNextCronInterval(ctype cronType) int64 { if ctype == cronBackoff { - t.NextCronInterval = t.NextCronInterval * 2 + return t.NextCronInterval * 2 } else if ctype == cronKeep { - // do nothing + return t.NextCronInterval } else if t.RetryInterval != 0 { - t.NextCronInterval = t.RetryInterval + return t.RetryInterval } else { - t.NextCronInterval = config.RetryInterval + return config.RetryInterval } - - next := time.Now().Add(time.Duration(t.NextCronInterval) * time.Second) - t.NextCronTime = &next - return []string{"next_cron_interval", "next_cron_time"} } diff --git a/dtmsvr/trans_type_msg.go b/dtmsvr/trans_type_msg.go index 4327bf5..f6bc9b1 100644 --- a/dtmsvr/trans_type_msg.go +++ b/dtmsvr/trans_type_msg.go @@ -10,7 +10,6 @@ import ( "fmt" "strings" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" ) @@ -39,25 +38,25 @@ func (t *transMsgProcessor) GenBranches() []TransBranch { return branches } -func (t *TransGlobal) mayQueryPrepared(db *common.DB) { +func (t *TransGlobal) mayQueryPrepared() { if !t.needProcess() || t.Status == dtmcli.StatusSubmitted { return } body, err := t.getURLResult(t.QueryPrepared, "", "", nil) if strings.Contains(body, dtmcli.ResultSuccess) { - t.changeStatus(db, dtmcli.StatusSubmitted) + t.changeStatus(dtmcli.StatusSubmitted) } else if strings.Contains(body, dtmcli.ResultFailure) { - t.changeStatus(db, dtmcli.StatusFailed) + t.changeStatus(dtmcli.StatusFailed) } else if strings.Contains(body, dtmcli.ResultOngoing) { - t.touch(db, cronReset) + t.touchCronTime(cronReset) } else { dtmimp.LogRedf("getting result failed for %s. error: %s", t.QueryPrepared, err.Error()) - t.touch(db, cronBackoff) + t.touchCronTime(cronBackoff) } } -func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { - t.mayQueryPrepared(db) +func (t *transMsgProcessor) ProcessOnce(branches []TransBranch) error { + t.mayQueryPrepared() if !t.needProcess() || t.Status == dtmcli.StatusPrepared { return nil } @@ -67,7 +66,7 @@ func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) e if branch.Op != dtmcli.BranchAction || branch.Status != dtmcli.StatusPrepared { continue } - err := t.execBranch(db, branch) + err := t.execBranch(branch, current) if err != nil { return err } @@ -76,7 +75,7 @@ func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) e } } if current == len(branches) { // msg 事务完成 - t.changeStatus(db, dtmcli.StatusSucceed) + t.changeStatus(dtmcli.StatusSucceed) return nil } panic("msg go pass all branch") diff --git a/dtmsvr/trans_type_saga.go b/dtmsvr/trans_type_saga.go index bf804a0..dad3d6f 100644 --- a/dtmsvr/trans_type_saga.go +++ b/dtmsvr/trans_type_saga.go @@ -10,7 +10,6 @@ import ( "fmt" "time" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" ) @@ -53,11 +52,11 @@ type branchResult struct { op string } -func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { +func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { // when saga tasks is fetched, it always need to process dtmimp.Logf("status: %s timeout: %t", t.Status, t.isTimeout()) if t.Status == dtmcli.StatusSubmitted && t.isTimeout() { - t.changeStatus(db, dtmcli.StatusAborting) + t.changeStatus(dtmcli.StatusAborting) } n := len(branches) @@ -108,7 +107,7 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) dtmimp.LogRedf("exec branch error: %v", err) } }() - err = t.execBranch(db, &branches[i]) + err = t.execBranch(&branches[i], i) } pickToRunActions := func() []int { toRun := []int{} @@ -175,11 +174,11 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) waitDoneOnce() } if t.Status == dtmcli.StatusSubmitted && rsAFailed == 0 && rsAToStart == rsASucceed { - t.changeStatus(db, dtmcli.StatusSucceed) + t.changeStatus(dtmcli.StatusSucceed) return nil } if t.Status == dtmcli.StatusSubmitted && (rsAFailed > 0 || t.isTimeout()) { - t.changeStatus(db, dtmcli.StatusAborting) + t.changeStatus(dtmcli.StatusAborting) } if t.Status == dtmcli.StatusAborting { toRun := pickToRunActions() @@ -189,7 +188,7 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) } } if t.Status == dtmcli.StatusAborting && rsCToStart == rsCSucceed { - t.changeStatus(db, dtmcli.StatusFailed) + t.changeStatus(dtmcli.StatusFailed) } return nil } diff --git a/dtmsvr/trans_type_tcc.go b/dtmsvr/trans_type_tcc.go index e22188d..88145d1 100644 --- a/dtmsvr/trans_type_tcc.go +++ b/dtmsvr/trans_type_tcc.go @@ -7,7 +7,6 @@ package dtmsvr import ( - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" ) @@ -24,23 +23,23 @@ func (t *transTccProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { +func (t *transTccProcessor) ProcessOnce(branches []TransBranch) error { if !t.needProcess() { return nil } if t.Status == dtmcli.StatusPrepared && t.isTimeout() { - t.changeStatus(db, dtmcli.StatusAborting) + t.changeStatus(dtmcli.StatusAborting) } op := dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchConfirm, dtmcli.BranchCancel).(string) for current := len(branches) - 1; current >= 0; current-- { if branches[current].Op == op && branches[current].Status == dtmcli.StatusPrepared { dtmimp.Logf("branch info: current: %d ID: %d", current, branches[current].ID) - err := t.execBranch(db, &branches[current]) + err := t.execBranch(&branches[current], current) if err != nil { return err } } } - t.changeStatus(db, dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) + t.changeStatus(dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) return nil } diff --git a/dtmsvr/trans_type_xa.go b/dtmsvr/trans_type_xa.go index cbdaf7c..ad44dd0 100644 --- a/dtmsvr/trans_type_xa.go +++ b/dtmsvr/trans_type_xa.go @@ -7,7 +7,6 @@ package dtmsvr import ( - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" ) @@ -24,22 +23,22 @@ func (t *transXaProcessor) GenBranches() []TransBranch { return []TransBranch{} } -func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) error { +func (t *transXaProcessor) ProcessOnce(branches []TransBranch) error { if !t.needProcess() { return nil } if t.Status == dtmcli.StatusPrepared && t.isTimeout() { - t.changeStatus(db, dtmcli.StatusAborting) + t.changeStatus(dtmcli.StatusAborting) } currentType := dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmcli.BranchCommit, dtmcli.BranchRollback).(string) - for _, branch := range branches { + for i, branch := range branches { if branch.Op == currentType && branch.Status != dtmcli.StatusSucceed { - err := t.execBranch(db, &branch) + err := t.execBranch(&branch, i) if err != nil { return err } } } - t.changeStatus(db, dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) + t.changeStatus(dtmimp.If(t.Status == dtmcli.StatusSubmitted, dtmcli.StatusSucceed, dtmcli.StatusFailed).(string)) return nil } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 79f8263..fe5c7cb 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -7,18 +7,13 @@ package dtmsvr import ( - "encoding/hex" - "errors" "fmt" - "net" - "strings" "time" - "github.com/bwmarrin/snowflake" + "github.com/google/uuid" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli/dtmimp" - "gorm.io/gorm" - "gorm.io/gorm/clause" + "github.com/yedf/dtm/dtmsvr/storage" ) type branchStatus struct { @@ -30,68 +25,23 @@ type branchStatus struct { var p2e = dtmimp.P2E var e2p = dtmimp.E2P -var config = &common.DtmConfig +var config = &common.Config -func dbGet() *common.DB { - return common.DbGet(config.DB) +func GetStore() storage.Store { + return storage.GetStore() } // TransProcessedTestChan only for test usage. when transaction processed once, write gid to this chan var TransProcessedTestChan chan string = nil -var gNode *snowflake.Node = nil - -func init() { - node, err := snowflake.NewNode(1) - e2p(err) - gNode = node -} - -// GenGid generate gid, use ip + snowflake +// GenGid generate gid, use uuid func GenGid() string { - return getOneHexIP() + "_" + gNode.Generate().Base58() -} - -func getOneHexIP() string { - addrs, err := net.InterfaceAddrs() - if err == nil { - for _, address := range addrs { - if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { - ip := ipnet.IP.To4().String() - ns := strings.Split(ip, ".") - r := []byte{} - for _, n := range ns { - r = append(r, byte(dtmimp.MustAtoi(n))) - } - return hex.EncodeToString(r) - } - } - } - fmt.Printf("err is: %s", err.Error()) - return "" // 获取不到IP,则直接返回空 -} - -// transFromDb construct trans from db -func transFromDb(db *gorm.DB, gid string, lock bool) *TransGlobal { - m := TransGlobal{} - if lock { - db = db.Clauses(clause.Locking{Strength: "UPDATE"}) - } - dbr := db.Model(&m).Where("gid=?", gid).First(&m) - if dbr.Error == gorm.ErrRecordNotFound { - return nil - } - e2p(dbr.Error) - return &m + return uuid.NewString() } -func checkLocalhost(branches []TransBranch) { - if config.DisableLocalhost == 0 { - return - } - for _, branch := range branches { - if strings.HasPrefix(branch.URL, "http://localhost") || strings.HasPrefix(branch.URL, "localhost") { - panic(errors.New("url for localhost is disabled. check for your config")) - } - } +// GetTransGlobal construct trans from db +func GetTransGlobal(gid string) *TransGlobal { + trans := GetStore().FindTransGlobalStore(gid) + dtmimp.PanicIf(trans == nil, fmt.Errorf("no TransGlobal with gid: %s found", gid)) + return &TransGlobal{TransGlobalStore: *trans} } diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index 08b1acd..6c6f371 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -10,44 +10,18 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/yedf/dtm/common" - "github.com/yedf/dtm/dtmcli/dtmimp" ) func TestUtils(t *testing.T) { - common.MustLoadConfig() - db := dbGet() - db.NoMust() - err := dtmimp.CatchP(func() { - checkAffected(db.DB) - }) - assert.Error(t, err) - CronExpiredTrans(1) sleepCronTime() } -func TestCheckLocalHost(t *testing.T) { - config.DisableLocalhost = 1 - err := dtmimp.CatchP(func() { - checkLocalhost([]TransBranch{{URL: "http://localhost"}}) - }) - assert.Error(t, err) - config.DisableLocalhost = 0 - err = dtmimp.CatchP(func() { - checkLocalhost([]TransBranch{{URL: "http://localhost"}}) - }) - assert.Nil(t, err) -} - func TestSetNextCron(t *testing.T) { tg := TransGlobal{} tg.RetryInterval = 15 - tg.setNextCron(cronReset) - assert.Equal(t, int64(15), tg.NextCronInterval) + assert.Equal(t, int64(15), tg.getNextCronInterval(cronReset)) tg.RetryInterval = 0 - tg.setNextCron(cronReset) - assert.Equal(t, config.RetryInterval, tg.NextCronInterval) - tg.setNextCron(cronBackoff) - assert.Equal(t, config.RetryInterval*2, tg.NextCronInterval) + assert.Equal(t, config.RetryInterval, tg.getNextCronInterval(cronReset)) + assert.Equal(t, config.RetryInterval*2, tg.getNextCronInterval(cronBackoff)) } diff --git a/examples/base.go b/examples/base.go new file mode 100644 index 0000000..d9b8fcf --- /dev/null +++ b/examples/base.go @@ -0,0 +1,14 @@ +package examples + +import "fmt" + +func Startup() { + InitConfig() + GrpcStartup() + BaseAppStartup() +} + +func InitConfig() { + DtmHttpServer = fmt.Sprintf("http://localhost:%d/api/dtmsvr", config.HttpPort) + DtmGrpcServer = fmt.Sprintf("localhost:%d", config.GrpcPort) +} diff --git a/examples/base_grpc.go b/examples/base_grpc.go index 91546a5..d0002b4 100644 --- a/examples/base_grpc.go +++ b/examples/base_grpc.go @@ -36,7 +36,7 @@ var XaGrpcClient *dtmgrpc.XaGrpcClient = nil func init() { setupFuncs["XaGrpcSetup"] = func(app *gin.Engine) { - XaGrpcClient = dtmgrpc.NewXaGrpcClient(DtmGrpcServer, config.DB, BusiGrpc+"/examples.Busi/XaNotify") + XaGrpcClient = dtmgrpc.NewXaGrpcClient(DtmGrpcServer, config.ExamplesDB, BusiGrpc+"/examples.Busi/XaNotify") } } diff --git a/examples/base_types.go b/examples/base_types.go index 8ebfb5f..a40f491 100644 --- a/examples/base_types.go +++ b/examples/base_types.go @@ -19,10 +19,10 @@ import ( ) // DtmHttpServer dtm service address -var DtmHttpServer = fmt.Sprintf("http://localhost:%d/api/dtmsvr", common.DtmHttpPort) +var DtmHttpServer = fmt.Sprintf("http://localhost:%d/api/dtmsvr", 36789) // DtmGrpcServer dtm grpc service address -var DtmGrpcServer = fmt.Sprintf("localhost:%d", common.DtmGrpcPort) +var DtmGrpcServer = fmt.Sprintf("localhost:%d", 36790) // TransReq transaction request payload type TransReq struct { @@ -76,11 +76,11 @@ func infoFromContext(c *gin.Context) *dtmcli.BranchBarrier { } func dbGet() *common.DB { - return common.DbGet(config.DB) + return common.DbGet(config.ExamplesDB) } func sdbGet() *sql.DB { - db, err := dtmimp.PooledDB(config.DB) + db, err := dtmimp.PooledDB(config.ExamplesDB) dtmimp.FatalIfError(err) return db } diff --git a/examples/data.go b/examples/data.go index 7d22d3d..f7ad10b 100644 --- a/examples/data.go +++ b/examples/data.go @@ -8,35 +8,15 @@ package examples import ( "fmt" - "io/ioutil" - "strings" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli/dtmimp" ) -var config = &common.DtmConfig - -// RunSQLScript 1 -func RunSQLScript(conf map[string]string, script string, skipDrop bool) { - con, err := dtmimp.StandaloneDB(conf) - dtmimp.FatalIfError(err) - defer func() { con.Close() }() - content, err := ioutil.ReadFile(script) - dtmimp.FatalIfError(err) - sqls := strings.Split(string(content), ";") - for _, sql := range sqls { - s := strings.TrimSpace(sql) - if s == "" || (skipDrop && strings.Contains(s, "drop")) { - continue - } - _, err = dtmimp.DBExec(con, s) - dtmimp.FatalIfError(err) - } -} +var config = &common.Config func resetXaData() { - if config.DB["driver"] != "mysql" { + if config.ExamplesDB.Driver != "mysql" { return } @@ -54,10 +34,10 @@ func resetXaData() { // PopulateDB populate example mysql data func PopulateDB(skipDrop bool) { resetXaData() - file := fmt.Sprintf("%s/examples.%s.sql", common.GetCallerCodeDir(), config.DB["driver"]) - RunSQLScript(config.DB, file, skipDrop) - file = fmt.Sprintf("%s/../dtmcli/barrier.%s.sql", common.GetCallerCodeDir(), config.DB["driver"]) - RunSQLScript(config.DB, file, skipDrop) + file := fmt.Sprintf("%s/examples.%s.sql", common.GetCallerCodeDir(), config.ExamplesDB.Driver) + common.RunSQLScript(config.ExamplesDB, file, skipDrop) + file = fmt.Sprintf("%s/../dtmcli/barrier.%s.sql", common.GetCallerCodeDir(), config.ExamplesDB.Driver) + common.RunSQLScript(config.ExamplesDB, file, skipDrop) } type sampleInfo struct { diff --git a/examples/http_xa.go b/examples/http_xa.go index b4d0ca3..5ed0159 100644 --- a/examples/http_xa.go +++ b/examples/http_xa.go @@ -20,7 +20,7 @@ var XaClient *dtmcli.XaClient = nil func init() { setupFuncs["XaSetup"] = func(app *gin.Engine) { var err error - XaClient, err = dtmcli.NewXaClient(DtmHttpServer, config.DB, Busi+"/xa", func(path string, xa *dtmcli.XaClient) { + XaClient, err = dtmcli.NewXaClient(DtmHttpServer, config.ExamplesDB, Busi+"/xa", func(path string, xa *dtmcli.XaClient) { app.POST(path, common.WrapHandler(func(c *gin.Context) (interface{}, error) { return xa.HandleCallback(c.Query("gid"), c.Query("branch_id"), c.Query("op")) })) diff --git a/go.mod b/go.mod index eea5aaf..342daae 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,10 @@ require ( github.com/bwmarrin/snowflake v0.3.0 github.com/gin-gonic/gin v1.6.3 github.com/go-playground/assert/v2 v2.0.1 + github.com/go-redis/redis/v8 v8.11.4 github.com/go-resty/resty/v2 v2.7.0 github.com/go-sql-driver/mysql v1.6.0 + github.com/google/uuid v1.3.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/lib/pq v1.10.3 github.com/polarismesh/grpc-go-polaris v0.0.0-20211128162137-1a59cd7b5733 // indirect @@ -17,6 +19,7 @@ require ( github.com/yedf/dtmdriver v0.0.0-20211203060147-29426c663b6e github.com/yedf/dtmdriver-gozero v0.0.0-20211204083751-a14485949435 github.com/yedf/dtmdriver-protocol1 v0.0.0-20211205112411-d7a7052dc90e + go.etcd.io/bbolt v1.3.6 go.uber.org/atomic v1.9.0 // indirect go.uber.org/automaxprocs v1.4.1-0.20210525221652-0180b04c18a7 go.uber.org/multierr v1.7.0 // indirect diff --git a/go.sum b/go.sum index 0c6c588..665a229 100644 --- a/go.sum +++ b/go.sum @@ -43,7 +43,6 @@ github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWX github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0= -github.com/agiledragon/gomonkey v0.0.0-20190517145658-8fa491f7b918 h1:a88Ln+jbIokfi6xoKtq10dbgp4VMg1CmHF1J42p8EyE= github.com/agiledragon/gomonkey v0.0.0-20190517145658-8fa491f7b918/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -62,11 +61,12 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4= -github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0= github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -89,6 +89,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -108,6 +110,7 @@ github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoD github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= @@ -144,7 +147,10 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg= +github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= @@ -163,7 +169,6 @@ github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -226,7 +231,6 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gnostic v0.4.1 h1:DLJCy1n/vrD4HPjOvYcT8aYQXpPIzoRZONaYwyycI+I= github.com/googleapis/gnostic v0.4.1/go.mod h1:LRhVm6pbyptWbWbuZ38d1eyptfvIytN3ir6b65WBswg= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -318,7 +322,6 @@ github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMW github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= @@ -381,6 +384,7 @@ github.com/natefinch/lumberjack v2.0.0+incompatible h1:4QJd3OLAMgj7ph+yZTuX13Ld4 github.com/natefinch/lumberjack v2.0.0+incompatible/go.mod h1:Wi9p2TTF5DG5oU+6YfsmYQpsTIOm0B1VNzQg9Mw6nPk= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -389,12 +393,14 @@ github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/openzipkin/zipkin-go v0.2.5/go.mod h1:KpXfKdgRDnnhsxw4pNIH9Md5lyFqKUa4YDFlwRYAMyE= @@ -453,15 +459,12 @@ github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMB github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337 h1:WN9BUFbdyOsSH/XohnWpXOlq9NBD5sGAB2FciQMUEe8= github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -500,6 +503,8 @@ github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da/go.mod h1:E1AXubJB github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M= github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8= +go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= +go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.etcd.io/etcd/api/v3 v3.5.1 h1:v28cktvBq+7vGyJXF8G+rWJmj+1XUmMtqcLnH8hDocM= go.etcd.io/etcd/api/v3 v3.5.1/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= go.etcd.io/etcd/client/pkg/v3 v3.5.1 h1:XIQcHCFSG53bJETYeRJtIxdLv2EWRGxcfzR8lSnTH4E= @@ -514,7 +519,6 @@ go.opentelemetry.io/otel v1.1.0 h1:8p0uMLcyyIx0KHNTgO8o3CW8A1aA+dJZJW6PvnMz0Wc= go.opentelemetry.io/otel v1.1.0/go.mod h1:7cww0OW51jQ8IaZChIEdqLwgh+44+7uiTdWsAL0wQpA= go.opentelemetry.io/otel/exporters/jaeger v1.1.0/go.mod h1:D/GIBwAdrFTTqCy1iITpC9nh5rgJpIbFVgkhlz2vCXk= go.opentelemetry.io/otel/exporters/zipkin v1.1.0/go.mod h1:LZwDnf1mVGTPMq9hdRUHfFBH30SuQvZ1BJaVywpg0VI= -go.opentelemetry.io/otel/sdk v1.1.0 h1:j/1PngUJIDOddkCILQYTevrTIbWd494djgGkSsMit+U= go.opentelemetry.io/otel/sdk v1.1.0/go.mod h1:3aQvM6uLm6C4wJpHtT8Od3vNzeZ34Pqc6bps8MywWzo= go.opentelemetry.io/otel/trace v1.1.0 h1:N25T9qCL0+7IpOT8RrRy0WYlL7y6U0WiUJzXcVdXY/o= go.opentelemetry.io/otel/trace v1.1.0/go.mod h1:i47XtdcBQiktu5IsrPqOHe8w+sBmnLwwHt8wiUsWGTI= @@ -680,6 +684,7 @@ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201112073958-5cba982894dd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -776,7 +781,6 @@ google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -829,8 +833,8 @@ gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaD gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/helper/compose.cloud.yml b/helper/compose.cloud.yml index 1336052..641adf9 100644 --- a/helper/compose.cloud.yml +++ b/helper/compose.cloud.yml @@ -5,16 +5,13 @@ services: volumes: - /etc/localtime:/etc/localtime:ro - /etc/timezone:/etc/timezone:ro + - ..:/app/dtm extra_hosts: - 'host.docker.internal:host-gateway' environment: IS_DOCKER: 1 - DISABLE_LOCALHOST: 1 - RETRY_LIMIT: 6 ports: - '9080:8080' - volumes: - - ..:/app/dtm mysql: image: 'mysql:5.7' volumes: diff --git a/helper/compose.dev.yml b/helper/compose.dev.yml deleted file mode 100644 index d904d35..0000000 --- a/helper/compose.dev.yml +++ /dev/null @@ -1,34 +0,0 @@ -version: '3.3' -services: - api: - image: golang:1.16.6-alpine3.14 - extra_hosts: - - 'host.docker.internal:host-gateway' - volumes: - - /etc/localtime:/etc/localtime:ro - - /etc/timezone:/etc/timezone:ro - environment: - IS_DOCKER: '1' - GOPROXY: 'https://mirrors.aliyun.com/goproxy/,direct' - ports: - - '8080:8080' - - '8082:8082' - - '58080:58080' - volumes: - - ..:/app/work - command: ['go', 'run', '/app/work/app/main.go', 'dev'] - working_dir: /app/work - mysql: - image: 'mysql:5.7' - volumes: - - /etc/localtime:/etc/localtime:ro - - /etc/timezone:/etc/timezone:ro - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: 1 - command: - [ - '--character-set-server=utf8mb4', - '--collation-server=utf8mb4_unicode_ci', - ] - ports: - - '3306:3306' diff --git a/helper/compose.qs.yml b/helper/compose.store.yml similarity index 61% rename from helper/compose.qs.yml rename to helper/compose.store.yml index 2b50a8b..75a8868 100644 --- a/helper/compose.qs.yml +++ b/helper/compose.store.yml @@ -1,28 +1,12 @@ version: '3.3' services: - api: - image: 'yedf/dtm' - environment: - IS_DOCKER: '1' - ports: - - '8080:8080' - - '8082:8082' - - '58080:58080' + mysql: + image: 'mysql:5.7' volumes: - - ..:/app/work - /etc/localtime:/etc/localtime:ro - /etc/timezone:/etc/timezone:ro - command: ['/app/dtm/main', 'qs'] - working_dir: /app/work - extra_hosts: - - 'host.docker.internal:host-gateway' - db: - image: 'mysql:5.7' environment: MYSQL_ALLOW_EMPTY_PASSWORD: 1 - volumes: - - /etc/localtime:/etc/localtime:ro - - /etc/timezone:/etc/timezone:ro command: [ '--character-set-server=utf8mb4', @@ -30,3 +14,10 @@ services: ] ports: - '3306:3306' + redis: + image: 'redis' + volumes: + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro + ports: + - '6379:6379' diff --git a/test/api_test.go b/test/api_test.go index eefe35a..e44ed14 100644 --- a/test/api_test.go +++ b/test/api_test.go @@ -7,6 +7,7 @@ package test import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -14,14 +15,13 @@ import ( "github.com/yedf/dtm/examples" ) -const gidTestAPI = "TestAPI" - func TestAPIQuery(t *testing.T) { - err := genMsg(gidTestAPI).Submit() + gid := dtmimp.GetFuncName() + err := genMsg(gid).Submit() + assert.Nil(t, err) + waitTransProcessed(gid) + resp, err := dtmimp.RestyClient.R().SetQueryParam("gid", gid).Get(examples.DtmHttpServer + "/query") assert.Nil(t, err) - waitTransProcessed(gidTestAPI) - resp, err := dtmimp.RestyClient.R().SetQueryParam("gid", gidTestAPI).Get(examples.DtmHttpServer + "/query") - e2p(err) m := map[string]interface{}{} assert.Equal(t, resp.StatusCode(), 200) dtmimp.MustUnmarshalString(resp.String(), &m) @@ -41,10 +41,35 @@ func TestAPIQuery(t *testing.T) { } func TestAPIAll(t *testing.T) { - _, err := dtmimp.RestyClient.R().Get(examples.DtmHttpServer + "/all") + for i := 0; i < 3; i++ { // add three + gid := dtmimp.GetFuncName() + fmt.Sprintf("%d", i) + err := genMsg(gid).Submit() + assert.Nil(t, err) + waitTransProcessed(gid) + } + resp, err := dtmimp.RestyClient.R().SetQueryParam("limit", "1").Get(examples.DtmHttpServer + "/all") assert.Nil(t, err) - _, err = dtmimp.RestyClient.R().SetQueryParam("last_id", "10").Get(examples.DtmHttpServer + "/all") + m := map[string]interface{}{} + dtmimp.MustUnmarshalString(resp.String(), &m) + nextPos := m["next_position"].(string) + assert.NotEqual(t, "", nextPos) + + resp, err = dtmimp.RestyClient.R().SetQueryParams(map[string]string{ + "limit": "1", + "position": nextPos, + }).Get(examples.DtmHttpServer + "/all") assert.Nil(t, err) - resp, err := dtmimp.RestyClient.R().SetQueryParam("last_id", "abc").Get(examples.DtmHttpServer + "/all") - assert.Equal(t, resp.StatusCode(), 500) + dtmimp.MustUnmarshalString(resp.String(), &m) + nextPos2 := m["next_position"].(string) + assert.NotEqual(t, "", nextPos2) + assert.NotEqual(t, nextPos, nextPos2) + + resp, err = dtmimp.RestyClient.R().SetQueryParams(map[string]string{ + "limit": "1000", + "position": nextPos, + }).Get(examples.DtmHttpServer + "/all") + assert.Nil(t, err) + dtmimp.MustUnmarshalString(resp.String(), &m) + nextPos3 := m["next_position"].(string) + assert.Equal(t, "", nextPos3) } diff --git a/test/base_test.go b/test/base_test.go index d4024b9..7dbc192 100644 --- a/test/base_test.go +++ b/test/base_test.go @@ -29,7 +29,7 @@ func (BarrierModel) TableName() string { return "dtm_barrier.barrier" } func TestBaseSqlDB(t *testing.T) { asserts := assert.New(t) - db := common.DbGet(config.DB) + db := common.DbGet(config.ExamplesDB) barrier := &dtmcli.BranchBarrier{ TransType: "saga", Gid: "gid2", diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index 05930c1..52ac6ae 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -24,16 +24,11 @@ var Busi = examples.Busi var app *gin.Engine func getTransStatus(gid string) string { - sm := TransGlobal{} - dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) - e2p(dbr.Error) - return sm.Status + return dtmsvr.GetTransGlobal(gid).Status } func getBranchesStatus(gid string) []string { - branches := []TransBranch{} - dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Order("id").Find(&branches) - e2p(dbr.Error) + branches := dtmsvr.GetStore().FindBranches(gid) status := []string{} for _, branch := range branches { status = append(status, branch.Status) @@ -47,7 +42,10 @@ func assertSucceed(t *testing.T, gid string) { } func TestUpdateBranchAsync(t *testing.T) { - common.DtmConfig.UpdateBranchSync = 0 + if config.Store.Driver != "mysql" { + return + } + common.Config.UpdateBranchSync = 0 saga := genSaga1(dtmimp.GetFuncName(), false, false) saga.SetOptions(&dtmcli.TransOptions{WaitResult: true}) err := saga.Submit() @@ -56,5 +54,5 @@ func TestUpdateBranchAsync(t *testing.T) { time.Sleep(dtmsvr.UpdateBranchAsyncInterval) assert.Equal(t, []string{StatusPrepared, StatusSucceed}, getBranchesStatus(saga.Gid)) assert.Equal(t, StatusSucceed, getTransStatus(saga.Gid)) - common.DtmConfig.UpdateBranchSync = 1 + common.Config.UpdateBranchSync = 1 } diff --git a/test/main_test.go b/test/main_test.go index 251d4e5..9207a06 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -7,28 +7,54 @@ package test import ( + "os" "testing" "time" + "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmsvr" "github.com/yedf/dtm/examples" ) +func exitIf(code int) { + if code != 0 { + os.Exit(code) + } +} + func TestMain(m *testing.M) { common.MustLoadConfig() - dtmcli.SetCurrentDBType(common.DtmConfig.DB["driver"]) + dtmcli.SetCurrentDBType(common.Config.ExamplesDB.Driver) dtmsvr.TransProcessedTestChan = make(chan string, 1) dtmsvr.NowForwardDuration = 0 * time.Second dtmsvr.CronForwardDuration = 180 * time.Second - common.DtmConfig.UpdateBranchSync = 1 - dtmsvr.PopulateDB(false) - examples.PopulateDB(false) + common.Config.UpdateBranchSync = 1 + // 启动组件 go dtmsvr.StartSvr() examples.GrpcStartup() app = examples.BaseAppStartup() + app.POST(examples.BusiAPI+"/TccBSleepCancel", common.WrapHandler(func(c *gin.Context) (interface{}, error) { + return disorderHandler(c) + })) + tenv := os.Getenv("TEST_STORE") + if tenv == "boltdb" { + config.Store.Driver = "boltdb" + } else if tenv == "mysql" { + config.Store.Driver = "mysql" + config.Store.Host = "localhost" + config.Store.Port = 3306 + config.Store.User = "root" + config.Store.Password = "" + } else { + config.Store.Driver = "redis" + config.Store.Host = "localhost" + config.Store.Port = 6379 + } + dtmsvr.PopulateDB(false) + examples.PopulateDB(false) + exitIf(m.Run()) - m.Run() } diff --git a/test/store_test.go b/test/store_test.go new file mode 100644 index 0000000..24a2f22 --- /dev/null +++ b/test/store_test.go @@ -0,0 +1,98 @@ +package test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/yedf/dtm/dtmcli/dtmimp" + "github.com/yedf/dtm/dtmsvr/storage" +) + +func initTransGlobal(gid string) (*storage.TransGlobalStore, storage.Store) { + next := time.Now().Add(10 * time.Second) + g := &storage.TransGlobalStore{Gid: gid, Status: "prepared", NextCronTime: &next} + bs := []storage.TransBranchStore{ + {Gid: gid, BranchID: "01"}, + } + s := storage.GetStore() + err := s.MaySaveNewTrans(g, bs) + dtmimp.E2P(err) + return g, s +} + +func TestStoreSave(t *testing.T) { + gid := dtmimp.GetFuncName() + bs := []storage.TransBranchStore{ + {Gid: gid, BranchID: "01"}, + {Gid: gid, BranchID: "02"}, + } + g, s := initTransGlobal(gid) + g2 := s.FindTransGlobalStore(gid) + assert.NotNil(t, g2) + assert.Equal(t, gid, g2.Gid) + + bs2 := s.FindBranches(gid) + assert.Equal(t, len(bs2), int(1)) + assert.Equal(t, "01", bs2[0].BranchID) + + s.LockGlobalSaveBranches(gid, g.Status, []storage.TransBranchStore{bs[1]}, -1) + bs3 := s.FindBranches(gid) + assert.Equal(t, 2, len(bs3)) + assert.Equal(t, "02", bs3[1].BranchID) + assert.Equal(t, "01", bs3[0].BranchID) + + err := dtmimp.CatchP(func() { + s.LockGlobalSaveBranches(g.Gid, "submitted", []storage.TransBranchStore{bs[1]}, 1) + }) + assert.Equal(t, storage.ErrNotFound, err) + + s.ChangeGlobalStatus(g, "succeed", []string{}, true) +} + +func TestStoreChangeStatus(t *testing.T) { + gid := dtmimp.GetFuncName() + g, s := initTransGlobal(gid) + g.Status = "no" + err := dtmimp.CatchP(func() { + s.ChangeGlobalStatus(g, "submitted", []string{}, false) + }) + assert.Equal(t, storage.ErrNotFound, err) + g.Status = "prepared" + s.ChangeGlobalStatus(g, "submitted", []string{}, false) + s.ChangeGlobalStatus(g, "succeed", []string{}, true) +} + +func TestStoreLockTrans(t *testing.T) { + // lock trans will only lock unfinished trans. ensure all other trans are finished + gid := dtmimp.GetFuncName() + g, s := initTransGlobal(gid) + + g2 := s.LockOneGlobalTrans(2 * time.Duration(config.RetryInterval) * time.Second) + assert.NotNil(t, g2) + assert.Equal(t, gid, g2.Gid) + + s.TouchCronTime(g, 3*config.RetryInterval) + g2 = s.LockOneGlobalTrans(2 * time.Duration(config.RetryInterval) * time.Second) + assert.Nil(t, g2) + + s.TouchCronTime(g, 1*config.RetryInterval) + g2 = s.LockOneGlobalTrans(2 * time.Duration(config.RetryInterval) * time.Second) + assert.NotNil(t, g2) + assert.Equal(t, gid, g2.Gid) + + s.ChangeGlobalStatus(g, "succeed", []string{}, true) + g2 = s.LockOneGlobalTrans(2 * time.Duration(config.RetryInterval) * time.Second) + assert.Nil(t, g2) +} + +func TestStoreWait(t *testing.T) { + storage.WaitStoreUp() +} + +func TestUpdateBranchSql(t *testing.T) { + if !config.Store.IsDB() { + r := storage.GetStore().UpdateBranchesSql(nil, nil) + assert.Nil(t, r) + } +} diff --git a/test/tcc_barrier_test.go b/test/tcc_barrier_test.go index d2139c9..a1b3c88 100644 --- a/test/tcc_barrier_test.go +++ b/test/tcc_barrier_test.go @@ -17,7 +17,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" "github.com/stretchr/testify/assert" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/examples" @@ -51,6 +50,8 @@ func TestTccBarrierRollback(t *testing.T) { assert.Equal(t, []string{StatusSucceed, StatusPrepared, StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) } +var disorderHandler func(c *gin.Context) (interface{}, error) = nil + func TestTccBarrierDisorder(t *testing.T) { timeoutChan := make(chan string, 2) finishedChan := make(chan string, 2) @@ -63,7 +64,7 @@ func TestTccBarrierDisorder(t *testing.T) { // 请参见子事务屏障里的时序图,这里为了模拟该时序图,手动拆解了callbranch branchID := tcc.NewSubBranchID() sleeped := false - app.POST(examples.BusiAPI+"/TccBSleepCancel", common.WrapHandler(func(c *gin.Context) (interface{}, error) { + disorderHandler = func(c *gin.Context) (interface{}, error) { res, err := examples.TccBarrierTransOutCancel(c) if !sleeped { sleeped = true @@ -72,7 +73,7 @@ func TestTccBarrierDisorder(t *testing.T) { finishedChan <- "1" } return res, err - })) + } // 注册子事务 resp, err := dtmimp.RestyClient.R(). SetBody(map[string]interface{}{ diff --git a/test/types.go b/test/types.go index 399cb6b..7c54b86 100644 --- a/test/types.go +++ b/test/types.go @@ -15,10 +15,10 @@ import ( "github.com/yedf/dtm/dtmsvr" ) -var config = &common.DtmConfig +var config = &common.Config func dbGet() *common.DB { - return common.DbGet(config.DB) + return common.DbGet(config.ExamplesDB) } // waitTransProcessed only for test usage. wait for transaction processed once diff --git a/test/xa_cover_test.go b/test/xa_cover_test.go index a2c515e..bd6d8be 100644 --- a/test/xa_cover_test.go +++ b/test/xa_cover_test.go @@ -11,22 +11,22 @@ import ( ) func TestXaCoverDBError(t *testing.T) { - oldDriver := getXc().Conf["driver"] + oldDriver := getXc().Conf.Driver gid := dtmimp.GetFuncName() err := getXc().XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { req := examples.GenTransReq(30, false, false) _, err := xa.CallBranch(req, examples.Busi+"/TransOutXa") assert.Nil(t, err) - getXc().Conf["driver"] = "no-driver" + getXc().Conf.Driver = "no-driver" _, err = xa.CallBranch(req, examples.Busi+"/TransInXa") assert.Error(t, err) - getXc().Conf["driver"] = oldDriver // make abort succeed + getXc().Conf.Driver = oldDriver // make abort succeed return nil, err }) assert.Error(t, err) - getXc().Conf["driver"] = "no-driver" // make xa rollback failed + getXc().Conf.Driver = "no-driver" // make xa rollback failed waitTransProcessed(gid) - getXc().Conf["driver"] = oldDriver + getXc().Conf.Driver = oldDriver cronTransOnceForwardNow(500) // rollback succeeded here assert.Equal(t, StatusFailed, getTransStatus(gid)) } diff --git a/test/xa_test.go b/test/xa_test.go index 0cd96ba..7122910 100644 --- a/test/xa_test.go +++ b/test/xa_test.go @@ -12,7 +12,6 @@ import ( "github.com/go-resty/resty/v2" "github.com/stretchr/testify/assert" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli/dtmimp" "github.com/yedf/dtm/examples" @@ -44,7 +43,7 @@ func TestXaDuplicate(t *testing.T) { req := examples.GenTransReq(30, false, false) _, err := xa.CallBranch(req, examples.Busi+"/TransOutXa") assert.Nil(t, err) - sdb, err := dtmimp.StandaloneDB(common.DtmConfig.DB) + sdb, err := dtmimp.StandaloneDB(config.ExamplesDB) assert.Nil(t, err) if dtmcli.GetCurrentDBType() == dtmcli.DBTypeMysql { _, err = dtmimp.DBExec(sdb, "xa recover")