From 2ff1b4bb7504fdef3a02c5719a110717cf65952b Mon Sep 17 00:00:00 2001 From: Alan YU Date: Fri, 12 Apr 2024 16:07:45 +0800 Subject: [PATCH] azure mysql ca support --- client/dtmcli/dtmimp/types.go | 2 ++ client/dtmcli/dtmimp/utils.go | 4 ++-- dtmsvr/config/config.go | 4 ++++ dtmutil/db.go | 22 ++++++++++++++++++---- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/client/dtmcli/dtmimp/types.go b/client/dtmcli/dtmimp/types.go index a889d5f..d8fed7c 100644 --- a/client/dtmcli/dtmimp/types.go +++ b/client/dtmcli/dtmimp/types.go @@ -23,4 +23,6 @@ type DBConf struct { Password string `yaml:"Password"` Db string `yaml:"Db"` Schema string `yaml:"Schema"` + Ca string `yaml:"Ca"` + Tls string `yaml:"Tls"` } diff --git a/client/dtmcli/dtmimp/utils.go b/client/dtmcli/dtmimp/utils.go index cc7b912..3ebc914 100644 --- a/client/dtmcli/dtmimp/utils.go +++ b/client/dtmcli/dtmimp/utils.go @@ -226,8 +226,8 @@ func GetDsn(conf DBConf) string { host := MayReplaceLocalhost(conf.Host) driver := conf.Driver dsn := map[string]string{ - "mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true", - conf.User, conf.Password, host, conf.Port, conf.Db), + "mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true&tls=%s", + conf.User, conf.Password, host, conf.Port, conf.Db, conf.Tls), "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' search_path=%s port=%d sslmode=disable", host, conf.User, conf.Password, conf.Db, conf.Schema, conf.Port), // sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30 diff --git a/dtmsvr/config/config.go b/dtmsvr/config/config.go index 30f0115..600980f 100644 --- a/dtmsvr/config/config.go +++ b/dtmsvr/config/config.go @@ -56,6 +56,8 @@ type Store struct { User string `yaml:"User"` Password string `yaml:"Password"` Db string `yaml:"Db" default:"dtm"` + Ca string `yaml:"Ca" default:""` + Tls string `yaml:"Tls" default:"false"` Schema string `yaml:"Schema" default:"public"` MaxOpenConns int64 `yaml:"MaxOpenConns" default:"500"` MaxIdleConns int64 `yaml:"MaxIdleConns" default:"500"` @@ -80,6 +82,8 @@ func (s *Store) GetDBConf() dtmcli.DBConf { Password: s.Password, Db: s.Db, Schema: s.Schema, + Ca: s.Ca, + Tls: s.Tls, } } diff --git a/dtmutil/db.go b/dtmutil/db.go index db09ab9..223f705 100644 --- a/dtmutil/db.go +++ b/dtmutil/db.go @@ -1,6 +1,8 @@ package dtmutil import ( + "crypto/tls" + "crypto/x509" "database/sql" "fmt" "sync" @@ -9,7 +11,7 @@ import ( "github.com/dtm-labs/dtm/client/dtmcli" "github.com/dtm-labs/dtm/client/dtmcli/dtmimp" "github.com/dtm-labs/logger" - _ "github.com/go-sql-driver/mysql" // register mysql driver + mysql_driver "github.com/go-sql-driver/mysql" // register mysql driver _ "github.com/lib/pq" // register postgres driver // _ "github.com/microsoft/go-mssqldb" // Microsoft's package conflicts with gorm's package: panic: sql: Register called twice for driver mssql @@ -26,7 +28,16 @@ type ModelBase struct { UpdateTime *time.Time `json:"update_time" gorm:"autoUpdateTime"` } -func getGormDialetor(driver string, dsn string) gorm.Dialector { +func registerMysqlCA(caPath string) { + rootCertPool := x509.NewCertPool() + pem, _ := os.ReadFile(caPath) + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + logger.Errorf("Failed to append PEM.") + } + mysql_driver.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCertPool}) +} + +func getGormDialetor(driver string, dsn string, ca string) gorm.Dialector { if driver == dtmcli.DBTypePostgres { return postgres.Open(dsn) } @@ -34,6 +45,9 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector { return sqlserver.Open(dsn) } dtmimp.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unknown driver: %s", driver)) + if ca != "" { + registerMysqlCA(ca) + } return mysql.Open(dsn) } @@ -106,8 +120,8 @@ func DbGet(conf dtmcli.DBConf, ops ...func(*gorm.DB)) *DB { dsn := dtmimp.GetDsn(conf) db, ok := dbs.Load(dsn) if !ok { - logger.Infof("connecting '%s' '%s' '%s' '%d' '%s'", conf.Driver, conf.Host, conf.User, conf.Port, conf.Db) - db1, err := gorm.Open(getGormDialetor(conf.Driver, dsn), &gorm.Config{ + logger.Infof("connecting '%s' '%s' '%s' '%d' '%s', '%s', '%s'", conf.Driver, conf.Host, conf.User, conf.Port, conf.Db, conf.Ca, conf.Tls) + db1, err := gorm.Open(getGormDialetor(conf.Driver, dsn, conf.Ca), &gorm.Config{ SkipDefaultTransaction: true, }) dtmimp.E2P(err)