diff --git a/main.go b/main.go index 1ade997..645031f 100644 --- a/main.go +++ b/main.go @@ -7,15 +7,18 @@ package main import ( + "bytes" + "compress/gzip" "embed" "fmt" + "io" "io/fs" - "io/ioutil" "net/http" "net/http/httputil" "net/url" "os" "os/signal" + "strconv" "strings" "syscall" @@ -64,17 +67,13 @@ func addAdmin(app *gin.Engine, conf *config.Type) { router = app.Group(conf.AdminBasePath) } if err == nil { - cont, err := ioutil.ReadAll(index) + cont, err := io.ReadAll(index) logger.FatalIfError(err) _ = index.Close() - sfile := string(cont) + cont = bytesTryReplaceIndex(cont, conf) - //replace base path - sfile = strings.Replace(sfile, "\"assets/", "\""+conf.AdminBasePath+"/assets/", -1) - sfile = strings.Replace(sfile, "PUBLIC-PATH-VARIABLE", conf.AdminBasePath, -1) renderIndex := func(c *gin.Context) { - c.Header("content-type", "text/html;charset=utf-8") - c.String(200, sfile) + c.Data(200, "text/html; charset=utf-8", cont) } router.StaticFS("/assets", http.FS(getSub(dist, "assets"))) router.GET("/admin/*name", renderIndex) @@ -84,9 +83,9 @@ func addAdmin(app *gin.Engine, conf *config.Type) { }) logger.Infof("admin is served from dir 'admin/dist/'") } else { - router.GET("/", proxyAdmin) - router.GET("/assets/*name", proxyAdmin) - router.GET("/admin/*name", proxyAdmin) + router.GET("/", proxyAdmin(conf)) + router.GET("/assets/*name", proxyAdmin(conf)) + router.GET("/admin/*name", proxyAdmin(conf)) lang := os.Getenv("LANG") if strings.HasPrefix(lang, "zh_CN") { target = "cn-admin.dtm.pub" @@ -98,20 +97,79 @@ func addAdmin(app *gin.Engine, conf *config.Type) { logger.Infof("admin is running at: http://localhost:%d%s", conf.HTTPPort, conf.AdminBasePath) } -func proxyAdmin(c *gin.Context) { +func proxyAdmin(conf *config.Type) func(c *gin.Context) { + return func(c *gin.Context) { + u := &url.URL{} + u.Scheme = "http" + u.Host = target + proxy := httputil.NewSingleHostReverseProxy(u) + originalDirector := proxy.Director + proxy.Director = func(r *http.Request) { + originalDirector(r) + p := strings.TrimPrefix(r.URL.Path, conf.AdminBasePath) + rp := strings.TrimPrefix(r.URL.RawPath, conf.AdminBasePath) + r.URL.Path = p + r.URL.RawPath = rp + } + proxy.Transport = &transport{RoundTripper: http.DefaultTransport, conf: conf} + proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { + logger.Warnf("http: proxy error: %v", err) + ret := fmt.Sprintf("http proxy error %v", err) + _, _ = rw.Write([]byte(ret)) + } + logger.Debugf("proxy admin to %s", target) + c.Request.Host = target + proxy.ServeHTTP(c.Writer, c.Request) + } - u := &url.URL{} - u.Scheme = "http" - u.Host = target - proxy := httputil.NewSingleHostReverseProxy(u) +} - proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - logger.Warnf("http: proxy error: %v", err) - ret := fmt.Sprintf("http proxy error %v", err) - _, _ = rw.Write([]byte(ret)) - } - logger.Debugf("proxy admin to %s", target) - c.Request.Host = target - proxy.ServeHTTP(c.Writer, c.Request) +// bytesTryReplaceIndex replace index.html base path +func bytesTryReplaceIndex(source []byte, conf *config.Type) []byte { + source = bytes.Replace(source, []byte("\"assets/"), []byte("\""+conf.AdminBasePath+"/assets/"), -1) + source = bytes.Replace(source, []byte("PUBLIC-PATH-VARIABLE"), []byte(conf.AdminBasePath), -1) + return source +} +type transport struct { + http.RoundTripper + conf *config.Type +} + +func (t *transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + resp, err = t.RoundTripper.RoundTrip(req) + if err != nil { + return nil, err + } + //modify html only + if !strings.Contains(resp.Header.Get("Content-Type"), "text/html") { + return resp, err + } + var reader io.ReadCloser + switch resp.Header.Get("Content-Encoding") { + case "gzip": + reader, err = gzip.NewReader(resp.Body) + defer func() { + if tmpErr := reader.Close(); err == nil && tmpErr != nil { + err = tmpErr + } + }() + default: + reader = resp.Body + } + delete(resp.Header, "Content-Encoding") + b, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + err = resp.Body.Close() + if err != nil { + return nil, err + } + b = bytesTryReplaceIndex(b, t.conf) + body := io.NopCloser(bytes.NewReader(b)) + resp.Body = body + resp.ContentLength = int64(len(b)) + resp.Header.Set("Content-Length", strconv.Itoa(len(b))) + return resp, nil }