diff --git a/backend/libs/handlers/accounts_test.go b/backend/libs/handlers/accounts_test.go index 823c93e..fa74146 100644 --- a/backend/libs/handlers/accounts_test.go +++ b/backend/libs/handlers/accounts_test.go @@ -414,6 +414,80 @@ func TestLayeredUploadLimits(t *testing.T) { } } +func TestBatchedUploadAppendBypassesDailyBoxCreationCap(t *testing.T) { + app, cleanup := newTestApp(t) + defer cleanup() + policy := testPolicy(t, app) + policy.AnonymousDailyBoxes = 1 + policy.AnonymousActiveBoxes = 10 + if err := app.settingsService.UpdateUploadPolicy(policy); err != nil { + t.Fatalf("UpdateUploadPolicy returned error: %v", err) + } + + first := multipartUploadRequest(t, "/api/v1/upload", "file", "first.txt", "hello") + first.Header.Set("Accept", "application/json") + first.Header.Set(uploadBatchHeader, "sharex-test") + firstResponse := httptest.NewRecorder() + app.Upload(firstResponse, first) + if firstResponse.Code != http.StatusCreated { + t.Fatalf("first batched status = %d, body = %s", firstResponse.Code, firstResponse.Body.String()) + } + + second := multipartUploadRequest(t, "/api/v1/upload", "file", "second.txt", "hello") + second.Header.Set("Accept", "application/json") + second.Header.Set(uploadBatchHeader, "sharex-test") + secondResponse := httptest.NewRecorder() + app.Upload(secondResponse, second) + if secondResponse.Code != http.StatusCreated { + t.Fatalf("second batched status = %d, body = %s", secondResponse.Code, secondResponse.Body.String()) + } + + third := multipartUploadRequest(t, "/api/v1/upload", "file", "third.txt", "hello") + third.Header.Set("Accept", "application/json") + thirdResponse := httptest.NewRecorder() + app.Upload(thirdResponse, third) + if thirdResponse.Code != http.StatusTooManyRequests { + t.Fatalf("non-batched status = %d, body = %s", thirdResponse.Code, thirdResponse.Body.String()) + } +} + +func TestBatchedUploadAppendBypassesActiveBoxCreationCap(t *testing.T) { + app, cleanup := newTestApp(t) + defer cleanup() + policy := testPolicy(t, app) + policy.AnonymousDailyBoxes = 10 + policy.AnonymousActiveBoxes = 1 + if err := app.settingsService.UpdateUploadPolicy(policy); err != nil { + t.Fatalf("UpdateUploadPolicy returned error: %v", err) + } + + first := multipartUploadRequest(t, "/api/v1/upload", "file", "first.txt", "hello") + first.Header.Set("Accept", "application/json") + first.Header.Set(uploadBatchHeader, "active-cap") + firstResponse := httptest.NewRecorder() + app.Upload(firstResponse, first) + if firstResponse.Code != http.StatusCreated { + t.Fatalf("first batched status = %d, body = %s", firstResponse.Code, firstResponse.Body.String()) + } + + second := multipartUploadRequest(t, "/api/v1/upload", "file", "second.txt", "hello") + second.Header.Set("Accept", "application/json") + second.Header.Set(uploadBatchHeader, "active-cap") + secondResponse := httptest.NewRecorder() + app.Upload(secondResponse, second) + if secondResponse.Code != http.StatusCreated { + t.Fatalf("second batched status = %d, body = %s", secondResponse.Code, secondResponse.Body.String()) + } + + third := multipartUploadRequest(t, "/api/v1/upload", "file", "third.txt", "hello") + third.Header.Set("Accept", "application/json") + thirdResponse := httptest.NewRecorder() + app.Upload(thirdResponse, third) + if thirdResponse.Code != http.StatusTooManyRequests { + t.Fatalf("non-batched status = %d, body = %s", thirdResponse.Code, thirdResponse.Body.String()) + } +} + func TestUserPolicyOverrideChangesUploadEnforcement(t *testing.T) { app, cleanup := newTestApp(t) defer cleanup() diff --git a/backend/libs/handlers/upload.go b/backend/libs/handlers/upload.go index 27756b0..cff3035 100644 --- a/backend/libs/handlers/upload.go +++ b/backend/libs/handlers/upload.go @@ -121,7 +121,12 @@ func (a *App) Upload(w http.ResponseWriter, r *http.Request) { CreatorIP: uploadClientIP(r), StorageBackendID: effectivePolicy.StorageBackendID, } - result, boxesAdded, err := a.createOrAppendBox(r, user, loggedIn, files, opts) + result, boxesAdded, status, policyMessage, err := a.createOrAppendBox(r, user, loggedIn, effectivePolicy, files, opts, !isAdminUpload) + if policyMessage != "" { + a.logger.Warn("upload rejected by policy", "source", "quota", "severity", "warn", "code", status, "ip", uploadClientIP(r), "user_id", user.ID, "message", policyMessage, "bytes", totalBytes, "files", len(files)) + helpers.WriteJSONError(w, status, policyMessage) + return + } if err != nil { a.logger.Warn("upload failed", "source", "user-upload", "severity", "warn", "code", 4001, "ip", uploadClientIP(r), "user_id", user.ID, "error", err.Error()) helpers.WriteJSONError(w, http.StatusBadRequest, err.Error()) @@ -154,14 +159,19 @@ func (a *App) Upload(w http.ResponseWriter, r *http.Request) { // uploadGroupWindow are folded into one box. Without the header the behaviour is // identical to creating a fresh box every time. Returns the result and how many // boxes were created (1 for a new box, 0 for an append) for usage accounting. -func (a *App) createOrAppendBox(r *http.Request, user services.User, loggedIn bool, files []*multipart.FileHeader, opts services.UploadOptions) (services.UploadResult, int, error) { +func (a *App) createOrAppendBox(r *http.Request, user services.User, loggedIn bool, policy services.EffectiveUploadPolicy, files []*multipart.FileHeader, opts services.UploadOptions, enforceBoxLimits bool) (services.UploadResult, int, int, string, error) { batch := strings.TrimSpace(r.Header.Get(uploadBatchHeader)) if batch == "" { + if enforceBoxLimits { + if status, message := a.checkBoxCreationPolicy(r, user, loggedIn, policy); message != "" { + return services.UploadResult{}, 0, status, message, nil + } + } result, err := a.uploadService.CreateBox(files, opts) if err != nil { - return services.UploadResult{}, 0, err + return services.UploadResult{}, 0, 0, "", err } - return result, 1, nil + return result, 1, 0, "", nil } // Group key is scoped to the uploader so batches never cross accounts/IPs. @@ -184,20 +194,25 @@ func (a *App) createOrAppendBox(r *http.Request, user services.User, loggedIn bo result.ManageURL = entry.manageURL result.DeleteURL = entry.deleteURL entry.at = time.Now() - return result, 0, nil + return result, 0, 0, "", nil } } } + if enforceBoxLimits { + if status, message := a.checkBoxCreationPolicy(r, user, loggedIn, policy); message != "" { + return services.UploadResult{}, 0, status, message, nil + } + } result, err := a.uploadService.CreateBox(files, opts) if err != nil { - return services.UploadResult{}, 0, err + return services.UploadResult{}, 0, 0, "", err } entry.boxID = result.BoxID entry.manageURL = result.ManageURL entry.deleteURL = result.DeleteURL entry.at = time.Now() - return result, 1, nil + return result, 1, 0, "", nil } // batchBoxMatches guards that a batched append only ever touches a box owned by @@ -230,16 +245,6 @@ func (a *App) checkUploadPolicy(r *http.Request, user services.User, loggedIn bo if policy.DailyUploadMB > 0 && usage.UploadedBytes+totalBytes > services.MegabytesToBytes(policy.DailyUploadMB) { return http.StatusTooManyRequests, "anonymous daily upload limit reached" } - if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes { - return http.StatusTooManyRequests, "anonymous daily box limit reached" - } - activeBoxes, err := a.uploadService.ActiveBoxCountForIP(uploadClientIP(r)) - if err != nil { - return http.StatusInternalServerError, "active box limit could not be checked" - } - if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes { - return http.StatusTooManyRequests, "anonymous active box limit reached" - } if status, message := a.checkStorageBackendCapacity(policy.StorageBackendID, settings, totalBytes); message != "" { return status, message } @@ -253,16 +258,6 @@ func (a *App) checkUploadPolicy(r *http.Request, user services.User, loggedIn bo if policy.DailyUploadMB > 0 && usage.UploadedBytes+totalBytes > services.MegabytesToBytes(policy.DailyUploadMB) { return http.StatusTooManyRequests, "daily upload limit reached" } - if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes { - return http.StatusTooManyRequests, "daily box limit reached" - } - activeBoxes, err := a.uploadService.ActiveBoxCountForUser(user.ID) - if err != nil { - return http.StatusInternalServerError, "active box limit could not be checked" - } - if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes { - return http.StatusTooManyRequests, "active box limit reached" - } activeStorage, err := a.uploadService.UserActiveStorageUsed(user.ID) if err != nil { return http.StatusInternalServerError, "storage quota could not be checked" @@ -276,6 +271,42 @@ func (a *App) checkUploadPolicy(r *http.Request, user services.User, loggedIn bo return 0, "" } +func (a *App) checkBoxCreationPolicy(r *http.Request, user services.User, loggedIn bool, policy services.EffectiveUploadPolicy) (int, string) { + now := time.Now().UTC() + if !loggedIn { + usage, err := a.settingsService.UsageForIP(uploadClientIP(r), now) + if err != nil { + return http.StatusInternalServerError, "upload usage could not be checked" + } + if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes { + return http.StatusTooManyRequests, "anonymous daily box limit reached" + } + activeBoxes, err := a.uploadService.ActiveBoxCountForIP(uploadClientIP(r)) + if err != nil { + return http.StatusInternalServerError, "active box limit could not be checked" + } + if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes { + return http.StatusTooManyRequests, "anonymous active box limit reached" + } + return 0, "" + } + usage, err := a.settingsService.UsageForUser(user.ID, now) + if err != nil { + return http.StatusInternalServerError, "upload usage could not be checked" + } + if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes { + return http.StatusTooManyRequests, "daily box limit reached" + } + activeBoxes, err := a.uploadService.ActiveBoxCountForUser(user.ID) + if err != nil { + return http.StatusInternalServerError, "active box limit could not be checked" + } + if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes { + return http.StatusTooManyRequests, "active box limit reached" + } + return 0, "" +} + func (a *App) recordUploadUsage(r *http.Request, user services.User, loggedIn bool, totalBytes int64, boxes int) error { now := time.Now().UTC() if loggedIn { diff --git a/backend/libs/handlers/upload_group.go b/backend/libs/handlers/upload_group.go index 26f33da..3c0e656 100644 --- a/backend/libs/handlers/upload_group.go +++ b/backend/libs/handlers/upload_group.go @@ -48,7 +48,7 @@ func (g *uploadGrouper) entryFor(key string) *uploadGroupEntry { g.pruneLocked(time.Now()) entry, ok := g.entries[key] if !ok { - entry = &uploadGroupEntry{} + entry = &uploadGroupEntry{at: time.Now()} g.entries[key] = entry } return entry @@ -56,8 +56,8 @@ func (g *uploadGrouper) entryFor(key string) *uploadGroupEntry { // pruneLocked drops entries whose last use is well past the grouping window so // the map stays bounded to recently-active keys. Callers must hold g.mu. Entries -// currently in use, or freshly created but not yet used (zero timestamp), are -// kept to avoid removing one a request is about to populate. +// currently in use are kept to avoid removing one a request is about to +// populate. func (g *uploadGrouper) pruneLocked(now time.Time) { if now.Sub(g.lastPrune) < uploadGroupPruneInterval { return @@ -67,7 +67,7 @@ func (g *uploadGrouper) pruneLocked(now time.Time) { if !entry.mu.TryLock() { continue } - stale := !entry.at.IsZero() && now.Sub(entry.at) > 2*uploadGroupWindow + stale := now.Sub(entry.at) > 2*uploadGroupWindow entry.mu.Unlock() if stale { delete(g.entries, key) diff --git a/backend/libs/handlers/upload_group_test.go b/backend/libs/handlers/upload_group_test.go new file mode 100644 index 0000000..0f527c7 --- /dev/null +++ b/backend/libs/handlers/upload_group_test.go @@ -0,0 +1,24 @@ +package handlers + +import ( + "testing" + "time" +) + +func TestUploadGroupPrunesFailedEntries(t *testing.T) { + g := newUploadGrouper() + entry := g.entryFor("ip:203.0.113.1|failed") + entry.mu.Lock() + entry.at = time.Now().Add(-3 * uploadGroupWindow) + entry.mu.Unlock() + g.lastPrune = time.Now().Add(-uploadGroupPruneInterval) + + _ = g.entryFor("ip:203.0.113.1|next") + + if _, ok := g.entries["ip:203.0.113.1|failed"]; ok { + t.Fatalf("stale failed entry was not pruned") + } + if _, ok := g.entries["ip:203.0.113.1|next"]; !ok { + t.Fatalf("new entry was not created") + } +} diff --git a/backend/libs/middleware/bans.go b/backend/libs/middleware/bans.go index 4360f1f..4869bfc 100644 --- a/backend/libs/middleware/bans.go +++ b/backend/libs/middleware/bans.go @@ -26,10 +26,20 @@ func Bans(logger *slog.Logger, bans *services.BanService, trustedProxies []strin return } + settings, err := bans.Settings() + if err != nil { + logger.Error("ban settings load failed", "source", "ban", "severity", "error", "code", 5004, "ip", ip, "error", err.Error()) + next.ServeHTTP(w, r) + return + } + if !settings.AutoBanEnabled { + next.ServeHTTP(w, r) + return + } if pattern, err := bans.MaliciousPattern(r.URL.Path); err != nil { logger.Error("malicious path check failed", "source", "ban", "severity", "error", "code", 5002, "ip", ip, "error", err.Error()) } else if pattern != "" { - if result, err := bans.RecordAbuse(ip, services.AbuseKindMaliciousPath, r.URL.Path, banThreshold(bans, services.AbuseKindMaliciousPath), now); err != nil { + if result, err := bans.RecordAbuse(ip, services.AbuseKindMaliciousPath, r.URL.Path, settings.MaliciousPathThreshold, now); err != nil { logger.Error("malicious path event failed", "source", "ban", "severity", "error", "code", 5003, "ip", ip, "path", r.URL.Path, "error", err.Error()) } else if result.Enabled { logger.Warn("malicious path requested", "source", "ban", "severity", "warn", "code", 4302, "ip", ip, "path", r.URL.Path, "pattern", pattern, "count", result.Event.Count) @@ -48,18 +58,3 @@ func Bans(logger *slog.Logger, bans *services.BanService, trustedProxies []strin }) } } - -func banThreshold(bans *services.BanService, kind string) int { - settings, err := bans.Settings() - if err != nil { - return 0 - } - switch kind { - case services.AbuseKindAdminLogin: - return settings.AdminLoginFailureThreshold - case services.AbuseKindUserLogin: - return settings.UserLoginFailureThreshold - default: - return settings.MaliciousPathThreshold - } -} diff --git a/backend/libs/middleware/bans_test.go b/backend/libs/middleware/bans_test.go index 681fee9..b6dcb61 100644 --- a/backend/libs/middleware/bans_test.go +++ b/backend/libs/middleware/bans_test.go @@ -79,6 +79,26 @@ func TestBansMiddlewareAutoBansMaliciousPaths(t *testing.T) { } } +func TestBansMiddlewareSkipsAutoBanWhenDisabled(t *testing.T) { + bans := newMiddlewareBanService(t) + handler := Chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }), Bans(slog.New(slog.NewTextHandler(io.Discard, nil)), bans, nil)) + + for i := 0; i < 5; i++ { + request := httptest.NewRequest(http.MethodGet, "/.env", nil) + request.RemoteAddr = "203.0.113.23:6070" + response := httptest.NewRecorder() + handler.ServeHTTP(response, request) + if response.Code == http.StatusForbidden { + t.Fatalf("request %d was blocked while auto-ban disabled", i+1) + } + } + if _, ok, err := bans.Match("203.0.113.23", time.Now().UTC()); err != nil || ok { + t.Fatalf("disabled auto-ban Match = %v, %v", ok, err) + } +} + func newMiddlewareBanService(t *testing.T) *services.BanService { t.Helper() root := t.TempDir() diff --git a/backend/libs/services/bans.go b/backend/libs/services/bans.go index 8f646f8..d5fc7d9 100644 --- a/backend/libs/services/bans.go +++ b/backend/libs/services/bans.go @@ -519,6 +519,11 @@ func (s *BanService) RecordAbuse(ip, kind, detail string, threshold int, now tim if err != nil || !triggered { return AbuseResult{Event: event, Triggered: false, Enabled: true}, err } + if matched, ok, err := s.Match(ip, now); err != nil { + return AbuseResult{}, err + } else if ok { + return AbuseResult{Event: event, Ban: matched.Ban, Triggered: true, Enabled: true}, nil + } reason := fmt.Sprintf("%s threshold reached: %s", strings.ReplaceAll(kind, "_", " "), detail) ban, err = s.createBan(ip, reason, BanSourceAuto, "", now.Add(time.Duration(settings.AutoBanDurationHours)*time.Hour), now) if err != nil { diff --git a/backend/libs/services/bans_test.go b/backend/libs/services/bans_test.go index 6358dac..f558b96 100644 --- a/backend/libs/services/bans_test.go +++ b/backend/libs/services/bans_test.go @@ -79,6 +79,17 @@ func TestBanServiceAutoBanThresholdsAndDisabled(t *testing.T) { if err != nil || !result.Triggered || result.Ban.ID == "" { t.Fatalf("RecordAbuse threshold = %+v, %v", result, err) } + again, err := bans.RecordAbuse("203.0.113.8", AbuseKindMaliciousPath, "/.env", 3, now.Add(4*time.Minute)) + if err != nil || !again.Triggered || again.Ban.ID != result.Ban.ID { + t.Fatalf("RecordAbuse duplicate = %+v, %v", again, err) + } + records, err := bans.ListBans() + if err != nil { + t.Fatalf("ListBans returned error: %v", err) + } + if len(records) != 1 { + t.Fatalf("ban count = %d, want 1", len(records)) + } } func TestBanServiceMaliciousPathRules(t *testing.T) {