diff --git a/client/dtmcli/dtmimp/utils.go b/client/dtmcli/dtmimp/utils.go index 3811523..b0a6a8a 100644 --- a/client/dtmcli/dtmimp/utils.go +++ b/client/dtmcli/dtmimp/utils.go @@ -151,21 +151,31 @@ func MayReplaceLocalhost(host string) string { return host } -var sqlDbs sync.Map +var sqlDbs = &mapCache{cache: map[string]*sql.DB{}} -// PooledDB get pooled sql.DB -func PooledDB(conf DBConf) (*sql.DB, error) { +type mapCache struct { + mutex sync.Mutex + cache map[string]*sql.DB +} + +func (m *mapCache) LoadOrStore(conf DBConf, factory func(conf DBConf) (*sql.DB, error)) (*sql.DB, error) { + m.mutex.Lock() + defer m.mutex.Unlock() dsn := GetDsn(conf) - db, ok := sqlDbs.Load(dsn) - if !ok { - db2, err := StandaloneDB(conf) - if err != nil { - return nil, err - } - db = db2 - sqlDbs.Store(dsn, db) + if db, ok := m.cache[dsn]; ok { + return db, nil } - return db.(*sql.DB), nil + db, err := factory(conf) + if err != nil { + return nil, err + } + m.cache[dsn] = db + return db, nil +} + +// PooledDB get pooled sql.DB +func PooledDB(conf DBConf) (*sql.DB, error) { + return sqlDbs.LoadOrStore(conf, StandaloneDB) } // StandaloneDB get a standalone db instance