package middleware

import (
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"testing"
	"time"

	"warpbox.dev/backend/libs/services"
)

// TestBansMiddlewareBlocksActiveBan verifies that a request attributed to an
// IP with an active manual ban is answered with 403 "forbidden" and never
// reaches the next handler. The peer address is loopback and the banned IP is
// only present in X-Forwarded-For, so the test expects the middleware to honour
// that header in this configuration (third Bans argument nil).
func TestBansMiddlewareBlocksActiveBan(t *testing.T) {
	bans := newMiddlewareBanService(t)
	if _, err := bans.CreateManualBan("203.0.113.20", "test", "admin", time.Now().UTC().Add(time.Hour)); err != nil {
		t.Fatalf("CreateManualBan returned error: %v", err)
	}

	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	handler := Chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ServeHTTP is called synchronously below, so t.Fatal runs on the test goroutine.
		t.Fatal("next handler should not be called")
	}), Bans(logger, bans, nil))

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	request.RemoteAddr = "127.0.0.1:6070"
	request.Header.Set("X-Forwarded-For", "203.0.113.20")
	response := httptest.NewRecorder()
	handler.ServeHTTP(response, request)

	if response.Code != http.StatusForbidden || response.Body.String() != "forbidden\n" {
		t.Fatalf("blocked response = %d %q", response.Code, response.Body.String())
	}
}

// TestBansMiddlewareAllowsNonBannedIP verifies that a request from an IP with
// no ban is passed through to the next handler and its response is returned
// unchanged.
func TestBansMiddlewareAllowsNonBannedIP(t *testing.T) {
	bans := newMiddlewareBanService(t)

	called := false
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	handler := Chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		// Writing to an httptest.ResponseRecorder does not fail; the error is irrelevant here.
		_, _ = io.WriteString(w, "ok")
	}), Bans(logger, bans, nil))

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	request.RemoteAddr = "203.0.113.21:6070"
	response := httptest.NewRecorder()
	handler.ServeHTTP(response, request)

	if !called || response.Code != http.StatusOK || response.Body.String() != "ok" {
		t.Fatalf("allowed response = called %v code %d body %q", called, response.Code, response.Body.String())
	}
}

// TestBansMiddlewareAutoBansMaliciousPaths verifies that, with auto-ban
// enabled and a malicious-path threshold of three, requests for "/.env" are
// let through until the threshold is reached, the request that reaches it is
// answered with 403, and a ban for the client IP is then reported by Match.
func TestBansMiddlewareAutoBansMaliciousPaths(t *testing.T) {
	const (
		clientIP  = "203.0.113.22"
		threshold = 3
	)

	bans := newMiddlewareBanService(t)
	settings, err := bans.Settings()
	if err != nil {
		t.Fatalf("Settings returned error: %v", err)
	}
	settings.AutoBanEnabled = true
	settings.MaliciousPathThreshold = threshold
	if err := bans.UpdateSettings(settings); err != nil {
		t.Fatalf("UpdateSettings returned error: %v", err)
	}

	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	handler := Chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.NotFound(w, r)
	}), Bans(logger, bans, nil))

	for i := 0; i < threshold; i++ {
		request := httptest.NewRequest(http.MethodGet, "/.env", nil)
		request.RemoteAddr = clientIP + ":6070"
		response := httptest.NewRecorder()
		handler.ServeHTTP(response, request)

		last := i == threshold-1
		if !last && response.Code == http.StatusForbidden {
			t.Fatalf("request %d blocked before threshold", i+1)
		}
		if last && response.Code != http.StatusForbidden {
			t.Fatalf("request %d status = %d, want forbidden", i+1, response.Code)
		}
	}

	// A 403 on the threshold request is not enough on its own; also confirm the
	// ban is visible through the service, mirroring the disabled-case check.
	if _, ok, err := bans.Match(clientIP, time.Now().UTC()); err != nil || !ok {
		t.Fatalf("auto-ban Match = %v, %v", ok, err)
	}
}

// TestBansMiddlewareSkipsAutoBanWhenDisabled verifies that, with auto-ban
// disabled, requests for a malicious path well past the configured threshold
// are never blocked and no ban is recorded for the client IP.
func TestBansMiddlewareSkipsAutoBanWhenDisabled(t *testing.T) {
	const (
		clientIP  = "203.0.113.23"
		threshold = 3
		attempts  = 5
	)

	bans := newMiddlewareBanService(t)

	// Disable auto-ban explicitly instead of relying on service defaults, and
	// use a threshold below the attempt count. If the flag were ignored, the
	// loop below would then trigger a ban and fail the test.
	settings, err := bans.Settings()
	if err != nil {
		t.Fatalf("Settings returned error: %v", err)
	}
	settings.AutoBanEnabled = false
	settings.MaliciousPathThreshold = threshold
	if err := bans.UpdateSettings(settings); err != nil {
		t.Fatalf("UpdateSettings returned error: %v", err)
	}

	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	handler := Chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.NotFound(w, r)
	}), Bans(logger, bans, nil))

	for i := 0; i < attempts; i++ {
		request := httptest.NewRequest(http.MethodGet, "/.env", nil)
		request.RemoteAddr = clientIP + ":6070"
		response := httptest.NewRecorder()
		handler.ServeHTTP(response, request)
		if response.Code == http.StatusForbidden {
			t.Fatalf("request %d was blocked while auto-ban disabled", i+1)
		}
	}

	if _, ok, err := bans.Match(clientIP, time.Now().UTC()); err != nil || ok {
		t.Fatalf("disabled auto-ban Match = %v, %v", ok, err)
	}
}

// newMiddlewareBanService returns a BanService built on the database of a fresh
// UploadService rooted in a per-test temporary directory. The UploadService is
// closed during test cleanup. Cleanups run last-in first-out, so that close
// happens before t.TempDir's own cleanup removes the directory.
func newMiddlewareBanService(t *testing.T) *services.BanService {
	t.Helper()

	root := t.TempDir()
	upload, err := services.NewUploadService(1024*1024, filepath.Join(root, "data"), "http://example.test", slog.Default())
	if err != nil {
		t.Fatalf("NewUploadService returned error: %v", err)
	}
	t.Cleanup(func() {
		// The test body has already returned by now; record the failure
		// rather than calling FailNow from a cleanup.
		if err := upload.Close(); err != nil {
			t.Errorf("Close returned error: %v", err)
		}
	})

	bans, err := services.NewBanService(upload.DB())
	if err != nil {
		t.Fatalf("NewBanService returned error: %v", err)
	}
	return bans
}