Browse Source

fix concurrent map access

pull/39/head
yedf2 5 years ago
parent
commit
a411208880
  1. 11
      common/types.go
  2. 19
      dtmcli/utils.go

11
common/types.go

@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
@ -31,7 +32,7 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector {
panic(fmt.Errorf("unkown driver: %s", driver))
}
var dbs = map[string]*DB{}
var dbs sync.Map
// DB provide more func over gorm.DB
type DB struct {
@ -104,16 +105,18 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
// DbGet get db connection for specified conf
func DbGet(conf map[string]string) *DB {
dsn := dtmcli.GetDsn(conf)
if dbs[dsn] == nil {
db, ok := dbs.Load(dsn)
if !ok {
dtmcli.Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
db1, err := gorm.Open(getGormDialetor(conf["driver"], dsn), &gorm.Config{
SkipDefaultTransaction: true,
})
dtmcli.E2P(err)
db1.Use(&tracePlugin{})
dbs[dsn] = &DB{DB: db1}
db = &DB{DB: db1}
dbs.Store(dsn, db)
}
return dbs[dsn]
return db.(*DB)
}
type dtmConfigType struct {

19
dtmcli/utils.go

@ -10,6 +10,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/go-resty/resty/v2"
@ -183,19 +184,21 @@ func MayReplaceLocalhost(host string) string {
return host
}
var sqlDbs = map[string]*sql.DB{}
var sqlDbs sync.Map
// PooledDB get pooled sql.DB
func PooledDB(conf map[string]string) (*sql.DB, error) {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
db, err := StandaloneDB(conf)
db, ok := sqlDbs.Load(dsn)
if !ok {
db2, err := StandaloneDB(conf)
if err != nil {
return nil, err
}
sqlDbs[dsn] = db
db = db2
sqlDbs.Store(dsn, db)
}
return sqlDbs[dsn], nil
return db.(*sql.DB), nil
}
// StandaloneDB get a standalone db instance
@ -225,13 +228,13 @@ func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row {
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
host := MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
conf["user"], conf["password"], host, conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
host, conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn

Loading…
Cancel
Save