From 8db7e03c4d32911c0bdc7e1358eb78c1e7a196ef Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Thu, 11 Nov 2021 10:00:33 +0800 Subject: [PATCH] run change branch status within a transaction which lock transglobal --- dtmsvr/api.go | 65 +++++++++++++++++++++++------------------- dtmsvr/api_http.go | 2 +- dtmsvr/trans_status.go | 20 ++++++++----- dtmsvr/utils.go | 8 ++++-- examples/base_http.go | 1 - 5 files changed, 56 insertions(+), 40 deletions(-) diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 458f752..9e93c74 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/yedf/dtm/dtmcli" + "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -13,10 +14,11 @@ func svcSubmit(t *TransGlobal) (interface{}, error) { err := t.saveNew(db) if err == errUniqueConflict { - dbt := transFromDb(db, t.Gid) + dbt := transFromDb(db.DB, t.Gid, false) if dbt.Status == dtmcli.StatusPrepared { updates := t.setNextCron(cronReset) - db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) + dbr := db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) + checkAffected(dbr) } else if dbt.Status != dtmcli.StatusSubmitted { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil } @@ -28,7 +30,7 @@ func svcPrepare(t *TransGlobal) (interface{}, error) { t.Status = dtmcli.StatusPrepared err := t.saveNew(dbGet()) if err == errUniqueConflict { - dbt := transFromDb(dbGet(), t.Gid) + dbt := transFromDb(dbGet().DB, t.Gid, false) if dbt.Status != dtmcli.StatusPrepared { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil } @@ -38,7 +40,7 @@ func svcPrepare(t *TransGlobal) (interface{}, error) { func svcAbort(t *TransGlobal) (interface{}, error) { db := dbGet() - dbt := transFromDb(db, t.Gid) + dbt := transFromDb(db.DB, t.Gid, false) if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting { return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil } @@ -46,32 +48,37 @@ func svcAbort(t *TransGlobal) (interface{}, error) { return dbt.Process(db), nil } -func svcRegisterBranch(branch *TransBranch, data map[string]string) (interface{}, error) { - db := dbGet() - dbt := transFromDb(db, branch.Gid) - if dbt.Status != dtmcli.StatusPrepared { - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil - } +func svcRegisterBranch(branch *TransBranch, data map[string]string) (ret interface{}, rerr error) { + err := dbGet().Transaction(func(db *gorm.DB) error { + dbt := transFromDb(db, branch.Gid, true) + if dbt.Status != dtmcli.StatusPrepared { + ret = map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)} + return nil + } - branches := []TransBranch{*branch, *branch} - if dbt.TransType == "tcc" { - for i, b := range []string{dtmcli.BranchCancel, dtmcli.BranchConfirm} { - branches[i].Op = b - branches[i].URL = data[b] + branches := []TransBranch{*branch, *branch} + if dbt.TransType == "tcc" { + for i, b := range []string{dtmcli.BranchCancel, dtmcli.BranchConfirm} { + branches[i].Op = b + branches[i].URL = data[b] + } + } else if dbt.TransType == "xa" { + branches[0].Op = dtmcli.BranchRollback + branches[0].URL = data["url"] + branches[1].Op = dtmcli.BranchCommit + branches[1].URL = data["url"] + } else { + rerr = fmt.Errorf("unknow trans type: %s", dbt.TransType) + return nil } - } else if dbt.TransType == "xa" { - branches[0].Op = dtmcli.BranchRollback - branches[0].URL = data["url"] - branches[1].Op = dtmcli.BranchCommit - branches[1].URL = data["url"] - } else { - return nil, fmt.Errorf("unknow trans type: %s", dbt.TransType) - } - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(branches) - global := TransGlobal{Gid: branch.Gid} - global.touch(dbGet(), cronKeep) - return dtmcli.MapSuccess, nil + dbr := db.Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(branches) + checkAffected(dbr) + ret = dtmcli.MapSuccess + return nil + }) + e2p(err) + return } diff --git a/dtmsvr/api_http.go b/dtmsvr/api_http.go index b1f2e90..6cbdb9e 100644 --- a/dtmsvr/api_http.go +++ b/dtmsvr/api_http.go @@ -57,7 +57,7 @@ func query(c *gin.Context) (interface{}, error) { return nil, errors.New("no gid specified") } db := dbGet() - trans := transFromDb(db, gid) + trans := transFromDb(db.DB, gid, false) branches := []TransBranch{} db.Must().Where("gid", gid).Find(&branches) return map[string]interface{}{"transaction": trans, "branches": branches}, nil diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index 625ebff..dcee6b1 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func (t *TransGlobal) touch(db *common.DB, ctype cronType) *gorm.DB { @@ -33,23 +34,28 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB { t.RollbackTime = &now updates = append(updates, "rollback_time") } - dbr := db.Must().Model(t).Where("status=?", old).Select(updates).Updates(t) + dbr := db.Must().Model(&TransGlobal{}).Where("status=? and gid=?", old, t.Gid).Select(updates).Updates(t) checkAffected(dbr) return dbr } -func (t *TransGlobal) changeBranchStatus(db *common.DB, b *TransBranch, status string) *gorm.DB { +func (t *TransGlobal) changeBranchStatus(db *common.DB, b *TransBranch, status string) { if common.DtmConfig.UpdateBranchSync > 0 || t.TransType == "saga" && t.TimeoutToFail > 0 { - dbr := db.Must().Model(b).Updates(map[string]interface{}{ - "status": status, - "finish_time": time.Now(), + err := db.Transaction(func(tx *gorm.DB) error { + dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, t.Status).Find(&[]TransGlobal{}) + checkAffected(dbr) // check TransGlobal is not modified + dbr = tx.Model(b).Updates(map[string]interface{}{ + "status": status, + "finish_time": time.Now(), + }) + checkAffected(dbr) + return dbr.Error }) - checkAffected(dbr) + e2p(err) } else { // 为了性能优化,把branch的status更新异步化 updateBranchAsyncChan <- branchStatus{id: b.ID, status: status} } b.Status = status - return db.DB } func (t *TransGlobal) isTimeout() bool { diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 1aeb587..4ae90b5 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -12,6 +12,7 @@ import ( "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli/dtmimp" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type branchStatus struct { @@ -65,9 +66,12 @@ func getOneHexIP() string { } // transFromDb construct trans from db -func transFromDb(db *common.DB, gid string) *TransGlobal { +func transFromDb(db *gorm.DB, gid string, lock bool) *TransGlobal { m := TransGlobal{} - dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) + if lock { + db = db.Clauses(clause.Locking{Strength: "UPDATE"}) + } + dbr := db.Model(&m).Where("gid=?", gid).First(&m) if dbr.Error == gorm.ErrRecordNotFound { return nil } diff --git a/examples/base_http.go b/examples/base_http.go index 6a6dd4d..447fc47 100644 --- a/examples/base_http.go +++ b/examples/base_http.go @@ -70,7 +70,6 @@ func (s *AutoEmptyString) SetOnce(v string) { // Fetch fetch the stored value, then reset the value to empty func (s *AutoEmptyString) Fetch() string { - dtmimp.Logf("fetch result is: %s", s.value) v := s.value s.value = "" return v