diff --git a/dtmsvr/storage/boltdb/boltdb.go b/dtmsvr/storage/boltdb/boltdb.go index 439cfd0..3441f90 100644 --- a/dtmsvr/storage/boltdb/boltdb.go +++ b/dtmsvr/storage/boltdb/boltdb.go @@ -13,7 +13,6 @@ import ( "time" bolt "go.etcd.io/bbolt" - "gorm.io/gorm" "github.com/dtm-labs/dtm/common" "github.com/dtm-labs/dtm/dtmcli/dtmimp" @@ -297,8 +296,8 @@ func (s *BoltdbStore) FindBranches(gid string) []storage.TransBranchStore { return branches } -func (s *BoltdbStore) UpdateBranchesSql(branches []storage.TransBranchStore, updates []string) *gorm.DB { - return nil // not implemented +func (s *BoltdbStore) UpdateBranches(branches []storage.TransBranchStore, updates []string) (int, error) { + return 0, nil // not implemented } func (s *BoltdbStore) LockGlobalSaveBranches(gid string, status string, branches []storage.TransBranchStore, branchStart int) { diff --git a/dtmsvr/storage/redis/redis.go b/dtmsvr/storage/redis/redis.go index b32872b..7893f59 100644 --- a/dtmsvr/storage/redis/redis.go +++ b/dtmsvr/storage/redis/redis.go @@ -7,7 +7,6 @@ import ( "time" "github.com/go-redis/redis/v8" - "gorm.io/gorm" "github.com/dtm-labs/dtm/common" "github.com/dtm-labs/dtm/dtmcli/dtmimp" @@ -87,8 +86,8 @@ func (s *RedisStore) FindBranches(gid string) []storage.TransBranchStore { return branches } -func (s *RedisStore) UpdateBranchesSql(branches []storage.TransBranchStore, updates []string) *gorm.DB { - return nil // not implemented +func (s *RedisStore) UpdateBranches(branches []storage.TransBranchStore, updates []string) (int, error) { + return 0, nil // not implemented } type argList struct { diff --git a/dtmsvr/storage/sql/sql.go b/dtmsvr/storage/sql/sql.go index fd3fe56..9e86b82 100644 --- a/dtmsvr/storage/sql/sql.go +++ b/dtmsvr/storage/sql/sql.go @@ -62,11 +62,12 @@ func (s *SqlStore) FindBranches(gid string) []storage.TransBranchStore { return branches } -func (s *SqlStore) UpdateBranchesSql(branches []storage.TransBranchStore, updates []string) *gorm.DB { - return dbGet().Clauses(clause.OnConflict{ +func (s *SqlStore) UpdateBranches(branches []storage.TransBranchStore, updates []string) (int, error) { + db := dbGet().Clauses(clause.OnConflict{ OnConstraint: "trans_branch_op_pkey", DoUpdates: clause.AssignmentColumns(updates), }).Create(branches) + return int(db.RowsAffected), db.Error } func (s *SqlStore) LockGlobalSaveBranches(gid string, status string, branches []storage.TransBranchStore, branchStart int) { diff --git a/dtmsvr/storage/store.go b/dtmsvr/storage/store.go index 54a40f9..1a9da9c 100644 --- a/dtmsvr/storage/store.go +++ b/dtmsvr/storage/store.go @@ -3,11 +3,12 @@ package storage import ( "errors" "time" - - "gorm.io/gorm" ) +// ErrNotFound defines the query item is not found in storage implement. var ErrNotFound = errors.New("storage: NotFound") + +// ErrUniqueConflict defines the item is conflict with unique key in storage implement. var ErrUniqueConflict = errors.New("storage: UniqueKeyConflict") type Store interface { @@ -16,7 +17,7 @@ type Store interface { FindTransGlobalStore(gid string) *TransGlobalStore ScanTransGlobalStores(position *string, limit int64) []TransGlobalStore FindBranches(gid string) []TransBranchStore - UpdateBranchesSql(branches []TransBranchStore, updates []string) *gorm.DB + UpdateBranches(branches []TransBranchStore, updates []string) (int, error) LockGlobalSaveBranches(gid string, status string, branches []TransBranchStore, branchStart int) MaySaveNewTrans(global *TransGlobalStore, branches []TransBranchStore) error ChangeGlobalStatus(global *TransGlobalStore, newStatus string, updates []string, finished bool) diff --git a/dtmsvr/svr.go b/dtmsvr/svr.go index e2c739d..b782d05 100644 --- a/dtmsvr/svr.go +++ b/dtmsvr/svr.go @@ -77,13 +77,13 @@ func updateBranchAsync() { } } for len(updates) > 0 { - dbr := GetStore().UpdateBranchesSql(updates, []string{"status", "finish_time", "update_time"}) + rowAffected, err := GetStore().UpdateBranches(updates, []string{"status", "finish_time", "update_time"}) - if dbr.Error != nil { - logger.Errorf("async update branch status error: %v", dbr.Error) + if err != nil { + logger.Errorf("async update branch status error: %v", err) time.Sleep(1 * time.Second) } else { - logger.Infof("flushed %d branch status to db. affected: %d", len(updates), dbr.RowsAffected) + logger.Infof("flushed %d branch status to db. affected: %d", len(updates), rowAffected) updates = []TransBranch{} } } diff --git a/test/store_test.go b/test/store_test.go index 08c9966..28aafae 100644 --- a/test/store_test.go +++ b/test/store_test.go @@ -94,7 +94,7 @@ func TestStoreWait(t *testing.T) { func TestUpdateBranchSql(t *testing.T) { if !config.Store.IsDB() { - r := registry.GetStore().UpdateBranchesSql(nil, nil) - assert.Nil(t, r) + _, err := registry.GetStore().UpdateBranches(nil, nil) + assert.Nil(t, err) } }