From 37dd2b8a5241a36be3585b9a90a876f0a414682a Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Fri, 29 Oct 2021 21:47:57 +0800 Subject: [PATCH] test coverage opt --- dtmsvr/api.go | 10 +++++----- dtmsvr/cron.go | 13 ++++++------- dtmsvr/trans.go | 29 ----------------------------- dtmsvr/utils.go | 24 ++++++++++++++++++++++++ dtmsvr/utils_test.go | 7 ++++++- examples/base_http.go | 3 +++ test/msg_test.go | 5 +++++ test/saga_test.go | 4 ++-- 8 files changed, 51 insertions(+), 44 deletions(-) diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 8060fe1..f2923f0 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -13,7 +13,7 @@ func svcSubmit(t *TransGlobal) (interface{}, error) { err := t.saveNew(db) if err == errUniqueConflict { - dbt := TransFromDb(db, t.Gid) + dbt := transFromDb(db, t.Gid) if dbt.Status == dtmcli.StatusPrepared { updates := t.setNextCron(cronReset) db.Must().Model(t).Where("gid=? and status=?", t.Gid, dtmcli.StatusPrepared).Select(append(updates, "status")).Updates(t) @@ -28,7 +28,7 @@ func svcPrepare(t *TransGlobal) (interface{}, error) { t.Status = dtmcli.StatusPrepared err := t.saveNew(dbGet()) if err == errUniqueConflict { - dbt := TransFromDb(dbGet(), t.Gid) + dbt := transFromDb(dbGet(), t.Gid) if dbt.Status != dtmcli.StatusPrepared { return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil } @@ -38,7 +38,7 @@ func svcPrepare(t *TransGlobal) (interface{}, error) { func svcAbort(t *TransGlobal) (interface{}, error) { db := dbGet() - dbt := TransFromDb(db, t.Gid) + dbt := transFromDb(db, t.Gid) if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting { return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil } @@ -48,7 +48,7 @@ func svcAbort(t *TransGlobal) (interface{}, error) { func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, error) { db := dbGet() - dbt := TransFromDb(db, branch.Gid) + dbt := transFromDb(db, branch.Gid) if dbt.Status != dtmcli.StatusPrepared { return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil } @@ -70,7 +70,7 @@ func svcRegisterTccBranch(branch *TransBranch, data dtmcli.MS) (interface{}, err func svcRegisterXaBranch(branch *TransBranch) (interface{}, error) { branch.Status = dtmcli.StatusPrepared db := dbGet() - dbt := TransFromDb(db, branch.Gid) + dbt := transFromDb(db, branch.Gid) if dbt.Status != dtmcli.StatusPrepared { return M{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil } diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 7aeda41..78e0546 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -2,7 +2,6 @@ package dtmsvr import ( "fmt" - "math" "math/rand" "runtime/debug" "time" @@ -37,7 +36,7 @@ func CronExpiredTrans(num int) { for i := 0; i < num || num == -1; i++ { hasTrans := CronTransOnce() if !hasTrans && num != 1 { - sleepCronTime(0) + sleepCronTime() } } } @@ -71,9 +70,9 @@ func handlePanic(perr *error) { } } -func sleepCronTime(milli int) { - delta := math.Min(3, float64(config.TransCronInterval)) - interval := time.Duration((float64(config.TransCronInterval) - rand.Float64()*delta) * float64(time.Second)) - dtmcli.Logf("sleeping for %v pass in %d milli", interval, milli) - time.Sleep(dtmcli.If(milli == 0, interval, time.Duration(milli*int(time.Millisecond))).(time.Duration)) +func sleepCronTime() { + normal := time.Duration((float64(config.TransCronInterval) - rand.Float64()) * float64(time.Second)) + interval := dtmcli.If(CronForwardDuration > 0, 1*time.Millisecond, normal).(time.Duration) + dtmcli.Logf("sleeping for %v milli", interval/time.Microsecond) + time.Sleep(interval) } diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 7444b8b..1eee285 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -86,13 +86,6 @@ func (t *TransGlobal) isTimeout() bool { return time.Since(*t.CreateTime)+NowForwardDuration >= time.Duration(timeout)*time.Second } -func (t *TransGlobal) getRetryInterval() int64 { - if t.RetryInterval > 0 { - return t.RetryInterval - } - return config.RetryInterval -} - func (t *TransGlobal) needProcess() bool { return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout() } @@ -348,25 +341,3 @@ func TransFromDtmRequest(c *dtmgrpc.DtmRequest) *TransGlobal { Protocol: "grpc", } } - -// TransFromDb construct trans from db -func TransFromDb(db *common.DB, gid string) *TransGlobal { - m := TransGlobal{} - dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) - if dbr.Error == gorm.ErrRecordNotFound { - return nil - } - e2p(dbr.Error) - return &m -} - -func checkLocalhost(branches []TransBranch) { - if config.DisableLocalhost == 0 { - return - } - for _, branch := range branches { - if strings.HasPrefix(branch.URL, "http://localhost") || strings.HasPrefix(branch.URL, "localhost") { - panic(errors.New("url for localhost is disabled. check for your config")) - } - } -} diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index abd5fb8..f498a5f 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -2,6 +2,7 @@ package dtmsvr import ( "encoding/hex" + "errors" "fmt" "net" "strings" @@ -10,6 +11,7 @@ import ( "github.com/bwmarrin/snowflake" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" + "gorm.io/gorm" ) // M a short name @@ -91,3 +93,25 @@ func getOneHexIP() string { fmt.Printf("err is: %s", err.Error()) return "" // 获取不到IP,则直接返回空 } + +// transFromDb construct trans from db +func transFromDb(db *common.DB, gid string) *TransGlobal { + m := TransGlobal{} + dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) + if dbr.Error == gorm.ErrRecordNotFound { + return nil + } + e2p(dbr.Error) + return &m +} + +func checkLocalhost(branches []TransBranch) { + if config.DisableLocalhost == 0 { + return + } + for _, branch := range branches { + if strings.HasPrefix(branch.URL, "http://localhost") || strings.HasPrefix(branch.URL, "localhost") { + panic(errors.New("url for localhost is disabled. check for your config")) + } + } +} diff --git a/dtmsvr/utils_test.go b/dtmsvr/utils_test.go index c678e7c..7bf44bd 100644 --- a/dtmsvr/utils_test.go +++ b/dtmsvr/utils_test.go @@ -16,7 +16,7 @@ func TestUtils(t *testing.T) { assert.Error(t, err) CronExpiredTrans(1) - sleepCronTime(10) + sleepCronTime() } func TestCheckLocalHost(t *testing.T) { @@ -43,3 +43,8 @@ func TestSetNextCron(t *testing.T) { tg.setNextCron(cronBackoff) assert.Equal(t, config.RetryInterval*2, tg.NextCronInterval) } +func TestTransFromDB(t *testing.T) { + db := dbGet() + trans := transFromDb(db, "-1") + assert.Nil(t, trans) +} diff --git a/examples/base_http.go b/examples/base_http.go index a643bc2..8dcebc7 100644 --- a/examples/base_http.go +++ b/examples/base_http.go @@ -80,6 +80,9 @@ func handleGeneralBusiness(c *gin.Context, result1 string, result2 string, busi info := infoFromContext(c) res := dtmcli.OrString(result1, result2, dtmcli.ResultSuccess) dtmcli.Logf("%s %s result: %s", busi, info.String(), res) + if res == "ERROR" { + return nil, errors.New("ERROR from user") + } return M{"dtm_result": res}, nil } diff --git a/test/msg_test.go b/test/msg_test.go index 85f6cb6..cac6ecb 100644 --- a/test/msg_test.go +++ b/test/msg_test.go @@ -22,11 +22,14 @@ func msgNormal(t *testing.T) { WaitTransProcessed(msg.Gid) assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) + CronTransOnce() } func msgOngoing(t *testing.T) { msg := genMsg("gid-msg-normal-pending") msg.Prepare("") + err := msg.Prepare("") // additional prepare to go conflict key path + assert.Nil(t, err) assert.Equal(t, dtmcli.StatusPrepared, getTransStatus(msg.Gid)) examples.MainSwitch.CanSubmitResult.SetOnce(dtmcli.ResultOngoing) cronTransOnceForwardNow(180) @@ -37,6 +40,8 @@ func msgOngoing(t *testing.T) { CronTransOnce() assert.Equal(t, []string{dtmcli.StatusSucceed, dtmcli.StatusSucceed}, getBranchesStatus(msg.Gid)) assert.Equal(t, dtmcli.StatusSucceed, getTransStatus(msg.Gid)) + err = msg.Prepare("") + assert.Error(t, err) } func msgOngoingFailed(t *testing.T) { diff --git a/test/saga_test.go b/test/saga_test.go index 8ceaada..d4b6eb8 100644 --- a/test/saga_test.go +++ b/test/saga_test.go @@ -40,7 +40,7 @@ func sagaCommittedOngoing(t *testing.T) { func sagaRollback(t *testing.T) { saga := genSaga("gid-rollback-saga", false, true) - examples.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) + examples.MainSwitch.TransOutRevertResult.SetOnce("ERROR") err := saga.Submit() assert.Nil(t, err) WaitTransProcessed(saga.Gid) @@ -67,7 +67,7 @@ func sagaRollback2(t *testing.T) { func sagaTimeout(t *testing.T) { saga := genSaga("gid-timeout-saga", false, false) saga.TimeoutToFail = 1800 - examples.MainSwitch.TransOutResult.SetOnce(dtmcli.ResultOngoing) + examples.MainSwitch.TransOutResult.SetOnce("UNKOWN") saga.Submit() WaitTransProcessed(saga.Gid) assert.Equal(t, dtmcli.StatusSubmitted, getTransStatus(saga.Gid))