diff --git a/common/types.go b/common/types.go index a035a9f..60e404c 100644 --- a/common/types.go +++ b/common/types.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" _ "github.com/go-sql-driver/mysql" @@ -31,7 +32,7 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector { panic(fmt.Errorf("unkown driver: %s", driver)) } -var dbs = map[string]*DB{} +var dbs sync.Map // DB provide more func over gorm.DB type DB struct { @@ -104,16 +105,18 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { // DbGet get db connection for specified conf func DbGet(conf map[string]string) *DB { dsn := dtmcli.GetDsn(conf) - if dbs[dsn] == nil { + db, ok := dbs.Load(dsn) + if !ok { dtmcli.Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) db1, err := gorm.Open(getGormDialetor(conf["driver"], dsn), &gorm.Config{ SkipDefaultTransaction: true, }) dtmcli.E2P(err) db1.Use(&tracePlugin{}) - dbs[dsn] = &DB{DB: db1} + db = &DB{DB: db1} + dbs.Store(dsn, db) } - return dbs[dsn] + return db.(*DB) } type dtmConfigType struct { diff --git a/dtmcli/utils.go b/dtmcli/utils.go index 2ef73c0..f0df2cf 100644 --- a/dtmcli/utils.go +++ b/dtmcli/utils.go @@ -10,6 +10,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "github.com/go-resty/resty/v2" @@ -183,19 +184,21 @@ func MayReplaceLocalhost(host string) string { return host } -var sqlDbs = map[string]*sql.DB{} +var sqlDbs sync.Map // PooledDB get pooled sql.DB func PooledDB(conf map[string]string) (*sql.DB, error) { dsn := GetDsn(conf) - if sqlDbs[dsn] == nil { - db, err := StandaloneDB(conf) + db, ok := sqlDbs.Load(dsn) + if !ok { + db2, err := StandaloneDB(conf) if err != nil { return nil, err } - sqlDbs[dsn] = db + db = db2 + sqlDbs.Store(dsn, db) } - return sqlDbs[dsn], nil + return db.(*sql.DB), nil } // StandaloneDB get a standalone db instance @@ -225,13 +228,13 @@ func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row { // GetDsn get dsn from map config func GetDsn(conf map[string]string) string { - conf["host"] = MayReplaceLocalhost(conf["host"]) + host := MayReplaceLocalhost(conf["host"]) driver := conf["driver"] dsn := MS{ "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", - conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]), + conf["user"], conf["password"], host, conf["port"], conf["database"]), "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai", - conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]), + host, conf["user"], conf["password"], conf["database"], conf["port"]), }[driver] PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) return dsn