/* * Copyright (c) 2021 yedf. All rights reserved. * Use of this source code is governed by a BSD-style * license that can be found in the LICENSE file. */ package dtmutil import ( "bytes" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "os" "path/filepath" "strings" "time" "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" ) // GetGinApp init and return gin func GetGinApp() *gin.Engine { gin.SetMode(gin.ReleaseMode) app := gin.New() app.Use(gin.Recovery()) app.Use(func(c *gin.Context) { body := "" if c.Request.Body != nil { rb, err := c.GetRawData() dtmimp.E2P(err) if len(rb) > 0 { body = string(rb) c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(rb)) } } logger.Debugf("begin %s %s body: %s", c.Request.Method, c.Request.URL, body) c.Next() }) app.Any("/api/ping", func(c *gin.Context) { c.JSON(200, map[string]interface{}{"msg": "pong"}) }) return app } // WrapHandler2 wrap a function te bo the handler of gin request func WrapHandler2(fn func(*gin.Context) interface{}) gin.HandlerFunc { return func(c *gin.Context) { began := time.Now() var err error r := func() interface{} { defer dtmimp.P2E(&err) return fn(c) }() status := http.StatusOK // in dtm test/busi, there are some functions, which will return a resty response // pass resty response as gin's response if resp, ok := r.(*resty.Response); ok { b := resp.Body() status = resp.StatusCode() r = nil err = json.Unmarshal(b, &r) } // error maybe returned in r, assign it to err if ne, ok := r.(error); ok && err == nil { err = ne } // if err != nil || r == nil. then set the status and dtm_result // dtm_result is for compatible with version lower than v1.10 // when >= v1.10, result test should base on status, not dtm_result. result := map[string]interface{}{} if err != nil { if errors.Is(err, dtmcli.ErrFailure) { status = http.StatusConflict result["dtm_result"] = dtmcli.ResultFailure } else if errors.Is(err, dtmcli.ErrOngoing) { status = http.StatusTooEarly result["dtm_result"] = dtmcli.ResultOngoing } else if err != nil { status = http.StatusInternalServerError } result["message"] = err.Error() r = result } else if r == nil { result["dtm_result"] = dtmcli.ResultSuccess r = result } b, _ := json.Marshal(r) cont := string(b) if status == http.StatusOK || status == http.StatusTooEarly { logger.Infof("%2dms %d %s %s %s", time.Since(began).Milliseconds(), status, c.Request.Method, c.Request.RequestURI, cont) } else { logger.Errorf("%2dms %d %s %s %s", time.Since(began).Milliseconds(), status, c.Request.Method, c.Request.RequestURI, cont) } c.JSON(status, r) } } const jrpcCodeFailure = -32901 const jrpcCodeOngoing = -32902 // JrpcReq json-rpc request type JrpcReq struct { Method string `json:"method"` Jsonrpc string `json:"jsonrpc"` Params interface{} `json:"params"` ID string `json:"id"` } // WrapJrpcHandler wrap a gin func to be a gin handler func func WrapJrpcHandler(fn func(*JrpcReq) interface{}) gin.HandlerFunc { return func(c *gin.Context) { began := time.Now() var err error var req JrpcReq var jerr map[string]interface{} r := func() interface{} { defer dtmimp.P2E(&err) err2 := c.BindJSON(&req) if err2 != nil { jerr = map[string]interface{}{ "code": -32700, "message": fmt.Sprintf("Parse json error: %s", err2.Error()), } } else if req.ID == "" || req.Jsonrpc != "2.0" { jerr = map[string]interface{}{ "code": -32600, "message": fmt.Sprintf("Bad json request: %s", dtmimp.MustMarshalString(req)), } } else { return fn(&req) } return nil }() // error maybe returned in r, assign it to err if ne, ok := r.(error); ok && err == nil { err = ne } if err != nil { if errors.Is(err, dtmcli.ErrFailure) { jerr = map[string]interface{}{ "code": jrpcCodeFailure, "message": err.Error(), } } else if errors.Is(err, dtmcli.ErrOngoing) { jerr = map[string]interface{}{ "code": jrpcCodeOngoing, "message": err.Error(), } } else if jerr == nil { jerr = map[string]interface{}{ "code": -32603, "message": err.Error(), } } } result := map[string]interface{}{ "jsonrpc": "2.0", "id": req.ID, "error": jerr, "result": r, } b, _ := json.Marshal(result) cont := string(b) if jerr == nil || jerr["code"] == jrpcCodeOngoing { logger.Infof("%2dms %d %s %s %s", time.Since(began).Milliseconds(), 200, c.Request.Method, c.Request.RequestURI, cont) } else { logger.Errorf("%2dms %d %s %s %s", time.Since(began).Milliseconds(), 200, c.Request.Method, c.Request.RequestURI, cont) } c.JSON(200, result) } } // MustGetwd must version of os.Getwd func MustGetwd() string { wd, err := os.Getwd() dtmimp.E2P(err) return wd } // GetSQLDir 获取调用该函数的caller源代码的目录,主要用于测试时,查找相关文件 func GetSQLDir() string { wd := MustGetwd() if filepath.Base(wd) == "test" { wd = filepath.Dir(wd) } return wd + "/sqls" } // RecoverPanic execs recovery operation func RecoverPanic(err *error) { if x := recover(); x != nil { e := dtmimp.AsError(x) if err != nil { *err = e } } } // GetNextTime gets next time from second func GetNextTime(second int64) *time.Time { next := time.Now().Add(time.Duration(second) * time.Second) return &next } // RunSQLScript 1 func RunSQLScript(conf dtmcli.DBConf, script string, skipDrop bool) { con, err := dtmimp.StandaloneDB(conf) logger.FatalIfError(err) defer func() { _ = con.Close() }() content, err := ioutil.ReadFile(script) logger.FatalIfError(err) sqls := strings.Split(string(content), ";") for _, sql := range sqls { s := strings.TrimSpace(sql) if s == "" || (skipDrop && strings.Contains(s, "drop")) { continue } _, err = dtmimp.DBExec(con, s) logger.FatalIfError(err) logger.Infof("sql scripts finished: %s", s) } }