mirror of https://github.com/dtm-labs/dtm.git
csharpjavadistributed-transactionsdtmgogolangmicroservicenodejsphpdatabasesagaseatatcctransactiontransactionsxapythondistributed
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
310 lines
7.6 KiB
310 lines
7.6 KiB
/*
|
|
* Copyright (c) 2021 yedf. All rights reserved.
|
|
* Use of this source code is governed by a BSD-style
|
|
* license that can be found in the LICENSE file.
|
|
*/
|
|
|
|
package dtmimp
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/dtm-labs/logger"
|
|
"github.com/go-resty/resty/v2"
|
|
)
|
|
|
|
// Logf an alias of Infof
|
|
// Deprecated: use logger.Errorf
|
|
var Logf = logger.Infof
|
|
|
|
// LogRedf an alias of Errorf
|
|
// Deprecated: use logger.Errorf
|
|
var LogRedf = logger.Errorf
|
|
|
|
// FatalIfError fatal if error is not nil
|
|
// Deprecated: use logger.FatalIfError
|
|
var FatalIfError = logger.FatalIfError
|
|
|
|
// LogIfFatalf fatal if cond is true
|
|
// Deprecated: use logger.FatalfIf
|
|
var LogIfFatalf = logger.FatalfIf
|
|
|
|
// AsError wrap a panic value as an error
|
|
func AsError(x interface{}) error {
|
|
logger.Errorf("panic wrapped to error: '%v'", x)
|
|
if e, ok := x.(error); ok {
|
|
return e
|
|
}
|
|
return fmt.Errorf("%v", x)
|
|
}
|
|
|
|
// P2E panic to error
|
|
func P2E(perr *error) {
|
|
if x := recover(); x != nil {
|
|
*perr = AsError(x)
|
|
}
|
|
}
|
|
|
|
// E2P error to panic
|
|
func E2P(err error) {
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// CatchP catch panic to error
|
|
func CatchP(f func()) (rerr error) {
|
|
defer P2E(&rerr)
|
|
f()
|
|
return nil
|
|
}
|
|
|
|
// PanicIf name is clear
|
|
func PanicIf(cond bool, err error) {
|
|
if cond {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// MustAtoi is string to int
|
|
func MustAtoi(s string) int {
|
|
r, err := strconv.Atoi(s)
|
|
if err != nil {
|
|
E2P(errors.New("convert to int error: " + s))
|
|
}
|
|
return r
|
|
}
|
|
|
|
// OrString return the first not empty string
|
|
func OrString(ss ...string) string {
|
|
for _, s := range ss {
|
|
if s != "" {
|
|
return s
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// If ternary operator
|
|
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
|
|
if condition {
|
|
return trueObj
|
|
}
|
|
return falseObj
|
|
}
|
|
|
|
// MustMarshal checked version for marshal
|
|
func MustMarshal(v interface{}) []byte {
|
|
b, err := json.Marshal(v)
|
|
E2P(err)
|
|
return b
|
|
}
|
|
|
|
// MustMarshalString string version of MustMarshal
|
|
func MustMarshalString(v interface{}) string {
|
|
return string(MustMarshal(v))
|
|
}
|
|
|
|
// MustUnmarshal checked version for unmarshal
|
|
func MustUnmarshal(b []byte, obj interface{}) {
|
|
err := json.Unmarshal(b, obj)
|
|
E2P(err)
|
|
}
|
|
|
|
// MustUnmarshalString string version of MustUnmarshal
|
|
func MustUnmarshalString(s string, obj interface{}) {
|
|
MustUnmarshal([]byte(s), obj)
|
|
}
|
|
|
|
// MustRemarshal marshal and unmarshal, and check error
|
|
func MustRemarshal(from interface{}, to interface{}) {
|
|
b, err := json.Marshal(from)
|
|
E2P(err)
|
|
err = json.Unmarshal(b, to)
|
|
E2P(err)
|
|
}
|
|
|
|
// GetFuncName get current call func name
|
|
func GetFuncName() string {
|
|
pc, _, _, _ := runtime.Caller(1)
|
|
nm := runtime.FuncForPC(pc).Name()
|
|
return nm[strings.LastIndex(nm, ".")+1:]
|
|
}
|
|
|
|
// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network
|
|
func MayReplaceLocalhost(host string) string {
|
|
if os.Getenv("IS_DOCKER") != "" {
|
|
return strings.Replace(strings.Replace(host,
|
|
"localhost", "host.docker.internal", 1),
|
|
"127.0.0.1", "host.docker.internal", 1)
|
|
}
|
|
return host
|
|
}
|
|
|
|
var sqlDbs = &mapCache{cache: map[string]*sql.DB{}}
|
|
|
|
type mapCache struct {
|
|
mutex sync.Mutex
|
|
cache map[string]*sql.DB
|
|
}
|
|
|
|
func (m *mapCache) LoadOrStore(conf DBConf, factory func(conf DBConf) (*sql.DB, error)) (*sql.DB, error) {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
dsn := GetDsn(conf)
|
|
if db, ok := m.cache[dsn]; ok {
|
|
return db, nil
|
|
}
|
|
db, err := factory(conf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m.cache[dsn] = db
|
|
return db, nil
|
|
}
|
|
|
|
// PooledDB get pooled sql.DB
|
|
func PooledDB(conf DBConf) (*sql.DB, error) {
|
|
return sqlDbs.LoadOrStore(conf, StandaloneDB)
|
|
}
|
|
|
|
// StandaloneDB get a standalone db instance
|
|
func StandaloneDB(conf DBConf) (*sql.DB, error) {
|
|
dsn := GetDsn(conf)
|
|
logger.Infof("opening standalone %s: %s", conf.Driver, strings.Replace(dsn, conf.Password, "****", 1))
|
|
return sql.Open(conf.Driver, dsn)
|
|
}
|
|
|
|
// XaDB return a standalone db instance for xa
|
|
func XaDB(conf DBConf) (*sql.DB, error) {
|
|
dsn := GetDsn(conf)
|
|
if conf.Driver == DBTypeMysql {
|
|
dsn += "&autocommit=0"
|
|
}
|
|
logger.Infof("opening xa standalone %s: %s", conf.Driver, strings.Replace(dsn, conf.Password, "****", 1))
|
|
return sql.Open(conf.Driver, dsn)
|
|
}
|
|
|
|
// XaClose will log and close the db
|
|
func XaClose(db *sql.DB) {
|
|
logger.Infof("closing xa db")
|
|
_ = db.Close()
|
|
}
|
|
|
|
// DBExec use raw db to exec
|
|
func DBExec(dbType string, db DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
|
if sql == "" {
|
|
return 0, nil
|
|
}
|
|
began := time.Now()
|
|
if len(values) > 0 {
|
|
sql = GetDBSpecial(dbType).GetPlaceHoldSQL(sql)
|
|
}
|
|
r, rerr := db.Exec(sql, values...)
|
|
used := time.Since(began) / time.Millisecond
|
|
if rerr == nil {
|
|
affected, rerr = r.RowsAffected()
|
|
logger.Debugf("used: %d ms affected: %d for %s %v", used, affected, sql, values)
|
|
} else {
|
|
logger.Errorf("used: %d ms exec error: %v for %s %v", used, rerr, sql, values)
|
|
}
|
|
return
|
|
}
|
|
|
|
// GetDsn get dsn from map config
|
|
func GetDsn(conf DBConf) string {
|
|
host := MayReplaceLocalhost(conf.Host)
|
|
driver := conf.Driver
|
|
dsn := map[string]string{
|
|
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true",
|
|
conf.User, conf.Password, host, conf.Port, conf.Db),
|
|
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' search_path=%s port=%d sslmode=disable",
|
|
host, conf.User, conf.Password, conf.Db, conf.Schema, conf.Port),
|
|
// sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30
|
|
"sqlserver": getSQLServerConnectionString(&conf, &host),
|
|
}[driver]
|
|
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
|
return dsn
|
|
}
|
|
|
|
func getSQLServerConnectionString(conf *DBConf, host *string) string {
|
|
query := url.Values{}
|
|
query.Add("database", conf.Db)
|
|
u := &url.URL{
|
|
Scheme: "sqlserver",
|
|
User: url.UserPassword(conf.User, conf.Password),
|
|
Host: fmt.Sprintf("%s:%d", *host, conf.Port),
|
|
// Path: instance, // if connecting to an instance instead of a port
|
|
RawQuery: query.Encode(),
|
|
}
|
|
return u.String()
|
|
}
|
|
|
|
// RespAsErrorByJSONRPC translate json rpc resty response to error
|
|
func RespAsErrorByJSONRPC(resp *resty.Response) error {
|
|
str := resp.String()
|
|
var result map[string]interface{}
|
|
MustUnmarshalString(str, &result)
|
|
if result["error"] != nil {
|
|
rerr := result["error"].(map[string]interface{})
|
|
if rerr["code"] == JrpcCodeFailure {
|
|
return fmt.Errorf("%s. %w", str, ErrFailure)
|
|
} else if rerr["code"] == JrpcCodeOngoing {
|
|
return ErrOngoing
|
|
}
|
|
return errors.New(resp.String())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeferDo a common defer do used in dtmcli/dtmgrpc
|
|
func DeferDo(rerr *error, success func() error, fail func() error) {
|
|
if x := recover(); x != nil {
|
|
*rerr = AsError(x)
|
|
_ = fail()
|
|
panic(x)
|
|
} else if *rerr != nil {
|
|
_ = fail()
|
|
} else {
|
|
*rerr = success()
|
|
}
|
|
}
|
|
|
|
// Escape solve CodeQL reported problem
|
|
func Escape(input string) string {
|
|
v := strings.Replace(input, "\n", "", -1)
|
|
v = strings.Replace(v, "\r", "", -1)
|
|
v = strings.Replace(v, ";", "", -1)
|
|
// v = strings.Replace(v, "'", "", -1)
|
|
return v
|
|
}
|
|
|
|
// EscapeGet escape get
|
|
func EscapeGet(qs url.Values, key string) string {
|
|
return Escape(qs.Get(key))
|
|
}
|
|
|
|
// InsertBarrier insert a record to barrier
|
|
func InsertBarrier(tx DB, transType string, gid string, branchID string, op string, barrierID string, reason string, dbType string, barrierTableName string) (int64, error) {
|
|
if op == "" {
|
|
return 0, nil
|
|
}
|
|
if dbType == "" {
|
|
dbType = currentDBType
|
|
}
|
|
if barrierTableName == "" {
|
|
barrierTableName = BarrierTableName
|
|
}
|
|
sql := GetDBSpecial(dbType).GetInsertIgnoreTemplate(barrierTableName+"(trans_type, gid, branch_id, op, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier")
|
|
return DBExec(dbType, tx, sql, transType, gid, branchID, op, barrierID, reason)
|
|
}
|
|
|