From 2c859149190dd0e047ac48b0040d26d30151c4e6 Mon Sep 17 00:00:00 2001 From: yinheli Date: Sat, 14 Sep 2024 23:04:15 +0800 Subject: [PATCH] fix: static server in sub app with mount (#3104) --- router.go | 7 +++++++ router_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/router.go b/router.go index 4afa741537..5cbecbceab 100644 --- a/router.go +++ b/router.go @@ -5,6 +5,7 @@ package fiber import ( + "bytes" "fmt" "html" "sort" @@ -357,6 +358,12 @@ func (app *App) registerStatic(prefix, root string, config ...Static) { IndexNames: []string{"index.html"}, PathRewrite: func(fctx *fasthttp.RequestCtx) []byte { path := fctx.Path() + mountPath := app.MountPath() + if n := len(mountPath); n > 0 { + if bytes.Equal(path[:n], utils.UnsafeBytes(mountPath)) { + path = path[n:] + } + } if len(path) >= prefixLen { if isStar && app.getString(path[0:prefixLen]) == prefix { path = append(path[0:0], '/') diff --git a/router_test.go b/router_test.go index 6a43db5937..adb66f3c03 100644 --- a/router_test.go +++ b/router_test.go @@ -471,6 +471,37 @@ func Test_Route_Static_HasPrefix(t *testing.T) { body, err = io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + + app = New() + app.Static("/css", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/css/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) +} + +func Test_Route_Static_SubApp(t *testing.T) { + t.Parallel() + + dir := "./.github/testdata/fs/css" + app := New() + + subApp := New() + subApp.Static("/css", dir) + + app.Mount("/sub", subApp) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/sub/css/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) } func Test_Router_NotFound(t *testing.T) {