fix(handlers): bypass box creation limits for batched uploads
All checks were successful
Build and Publish Docker Image / deploy (push) Successful in 1m42s
All checks were successful
Build and Publish Docker Image / deploy (push) Successful in 1m42s
Update `createOrAppendBox` to accept the upload policy and admin status, allowing policy enforcement to be handled during the box creation/append decision process. This ensures that appending files to an existing batch does not incorrectly trigger daily or active box creation limits, as no new box is being created. Also, add unit tests to verify that batched uploads successfully bypass both daily and active box creation caps.
This commit is contained in:
@@ -414,6 +414,80 @@ func TestLayeredUploadLimits(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBatchedUploadAppendBypassesDailyBoxCreationCap(t *testing.T) {
|
||||||
|
app, cleanup := newTestApp(t)
|
||||||
|
defer cleanup()
|
||||||
|
policy := testPolicy(t, app)
|
||||||
|
policy.AnonymousDailyBoxes = 1
|
||||||
|
policy.AnonymousActiveBoxes = 10
|
||||||
|
if err := app.settingsService.UpdateUploadPolicy(policy); err != nil {
|
||||||
|
t.Fatalf("UpdateUploadPolicy returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
first := multipartUploadRequest(t, "/api/v1/upload", "file", "first.txt", "hello")
|
||||||
|
first.Header.Set("Accept", "application/json")
|
||||||
|
first.Header.Set(uploadBatchHeader, "sharex-test")
|
||||||
|
firstResponse := httptest.NewRecorder()
|
||||||
|
app.Upload(firstResponse, first)
|
||||||
|
if firstResponse.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("first batched status = %d, body = %s", firstResponse.Code, firstResponse.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
second := multipartUploadRequest(t, "/api/v1/upload", "file", "second.txt", "hello")
|
||||||
|
second.Header.Set("Accept", "application/json")
|
||||||
|
second.Header.Set(uploadBatchHeader, "sharex-test")
|
||||||
|
secondResponse := httptest.NewRecorder()
|
||||||
|
app.Upload(secondResponse, second)
|
||||||
|
if secondResponse.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("second batched status = %d, body = %s", secondResponse.Code, secondResponse.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
third := multipartUploadRequest(t, "/api/v1/upload", "file", "third.txt", "hello")
|
||||||
|
third.Header.Set("Accept", "application/json")
|
||||||
|
thirdResponse := httptest.NewRecorder()
|
||||||
|
app.Upload(thirdResponse, third)
|
||||||
|
if thirdResponse.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("non-batched status = %d, body = %s", thirdResponse.Code, thirdResponse.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchedUploadAppendBypassesActiveBoxCreationCap(t *testing.T) {
|
||||||
|
app, cleanup := newTestApp(t)
|
||||||
|
defer cleanup()
|
||||||
|
policy := testPolicy(t, app)
|
||||||
|
policy.AnonymousDailyBoxes = 10
|
||||||
|
policy.AnonymousActiveBoxes = 1
|
||||||
|
if err := app.settingsService.UpdateUploadPolicy(policy); err != nil {
|
||||||
|
t.Fatalf("UpdateUploadPolicy returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
first := multipartUploadRequest(t, "/api/v1/upload", "file", "first.txt", "hello")
|
||||||
|
first.Header.Set("Accept", "application/json")
|
||||||
|
first.Header.Set(uploadBatchHeader, "active-cap")
|
||||||
|
firstResponse := httptest.NewRecorder()
|
||||||
|
app.Upload(firstResponse, first)
|
||||||
|
if firstResponse.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("first batched status = %d, body = %s", firstResponse.Code, firstResponse.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
second := multipartUploadRequest(t, "/api/v1/upload", "file", "second.txt", "hello")
|
||||||
|
second.Header.Set("Accept", "application/json")
|
||||||
|
second.Header.Set(uploadBatchHeader, "active-cap")
|
||||||
|
secondResponse := httptest.NewRecorder()
|
||||||
|
app.Upload(secondResponse, second)
|
||||||
|
if secondResponse.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("second batched status = %d, body = %s", secondResponse.Code, secondResponse.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
third := multipartUploadRequest(t, "/api/v1/upload", "file", "third.txt", "hello")
|
||||||
|
third.Header.Set("Accept", "application/json")
|
||||||
|
thirdResponse := httptest.NewRecorder()
|
||||||
|
app.Upload(thirdResponse, third)
|
||||||
|
if thirdResponse.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("non-batched status = %d, body = %s", thirdResponse.Code, thirdResponse.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUserPolicyOverrideChangesUploadEnforcement(t *testing.T) {
|
func TestUserPolicyOverrideChangesUploadEnforcement(t *testing.T) {
|
||||||
app, cleanup := newTestApp(t)
|
app, cleanup := newTestApp(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|||||||
@@ -121,7 +121,12 @@ func (a *App) Upload(w http.ResponseWriter, r *http.Request) {
|
|||||||
CreatorIP: uploadClientIP(r),
|
CreatorIP: uploadClientIP(r),
|
||||||
StorageBackendID: effectivePolicy.StorageBackendID,
|
StorageBackendID: effectivePolicy.StorageBackendID,
|
||||||
}
|
}
|
||||||
result, boxesAdded, err := a.createOrAppendBox(r, user, loggedIn, files, opts)
|
result, boxesAdded, status, policyMessage, err := a.createOrAppendBox(r, user, loggedIn, effectivePolicy, files, opts, !isAdminUpload)
|
||||||
|
if policyMessage != "" {
|
||||||
|
a.logger.Warn("upload rejected by policy", "source", "quota", "severity", "warn", "code", status, "ip", uploadClientIP(r), "user_id", user.ID, "message", policyMessage, "bytes", totalBytes, "files", len(files))
|
||||||
|
helpers.WriteJSONError(w, status, policyMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Warn("upload failed", "source", "user-upload", "severity", "warn", "code", 4001, "ip", uploadClientIP(r), "user_id", user.ID, "error", err.Error())
|
a.logger.Warn("upload failed", "source", "user-upload", "severity", "warn", "code", 4001, "ip", uploadClientIP(r), "user_id", user.ID, "error", err.Error())
|
||||||
helpers.WriteJSONError(w, http.StatusBadRequest, err.Error())
|
helpers.WriteJSONError(w, http.StatusBadRequest, err.Error())
|
||||||
@@ -154,14 +159,19 @@ func (a *App) Upload(w http.ResponseWriter, r *http.Request) {
|
|||||||
// uploadGroupWindow are folded into one box. Without the header the behaviour is
|
// uploadGroupWindow are folded into one box. Without the header the behaviour is
|
||||||
// identical to creating a fresh box every time. Returns the result and how many
|
// identical to creating a fresh box every time. Returns the result and how many
|
||||||
// boxes were created (1 for a new box, 0 for an append) for usage accounting.
|
// boxes were created (1 for a new box, 0 for an append) for usage accounting.
|
||||||
func (a *App) createOrAppendBox(r *http.Request, user services.User, loggedIn bool, files []*multipart.FileHeader, opts services.UploadOptions) (services.UploadResult, int, error) {
|
func (a *App) createOrAppendBox(r *http.Request, user services.User, loggedIn bool, policy services.EffectiveUploadPolicy, files []*multipart.FileHeader, opts services.UploadOptions, enforceBoxLimits bool) (services.UploadResult, int, int, string, error) {
|
||||||
batch := strings.TrimSpace(r.Header.Get(uploadBatchHeader))
|
batch := strings.TrimSpace(r.Header.Get(uploadBatchHeader))
|
||||||
if batch == "" {
|
if batch == "" {
|
||||||
|
if enforceBoxLimits {
|
||||||
|
if status, message := a.checkBoxCreationPolicy(r, user, loggedIn, policy); message != "" {
|
||||||
|
return services.UploadResult{}, 0, status, message, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
result, err := a.uploadService.CreateBox(files, opts)
|
result, err := a.uploadService.CreateBox(files, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return services.UploadResult{}, 0, err
|
return services.UploadResult{}, 0, 0, "", err
|
||||||
}
|
}
|
||||||
return result, 1, nil
|
return result, 1, 0, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group key is scoped to the uploader so batches never cross accounts/IPs.
|
// Group key is scoped to the uploader so batches never cross accounts/IPs.
|
||||||
@@ -184,20 +194,25 @@ func (a *App) createOrAppendBox(r *http.Request, user services.User, loggedIn bo
|
|||||||
result.ManageURL = entry.manageURL
|
result.ManageURL = entry.manageURL
|
||||||
result.DeleteURL = entry.deleteURL
|
result.DeleteURL = entry.deleteURL
|
||||||
entry.at = time.Now()
|
entry.at = time.Now()
|
||||||
return result, 0, nil
|
return result, 0, 0, "", nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if enforceBoxLimits {
|
||||||
|
if status, message := a.checkBoxCreationPolicy(r, user, loggedIn, policy); message != "" {
|
||||||
|
return services.UploadResult{}, 0, status, message, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
result, err := a.uploadService.CreateBox(files, opts)
|
result, err := a.uploadService.CreateBox(files, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return services.UploadResult{}, 0, err
|
return services.UploadResult{}, 0, 0, "", err
|
||||||
}
|
}
|
||||||
entry.boxID = result.BoxID
|
entry.boxID = result.BoxID
|
||||||
entry.manageURL = result.ManageURL
|
entry.manageURL = result.ManageURL
|
||||||
entry.deleteURL = result.DeleteURL
|
entry.deleteURL = result.DeleteURL
|
||||||
entry.at = time.Now()
|
entry.at = time.Now()
|
||||||
return result, 1, nil
|
return result, 1, 0, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// batchBoxMatches guards that a batched append only ever touches a box owned by
|
// batchBoxMatches guards that a batched append only ever touches a box owned by
|
||||||
@@ -230,16 +245,6 @@ func (a *App) checkUploadPolicy(r *http.Request, user services.User, loggedIn bo
|
|||||||
if policy.DailyUploadMB > 0 && usage.UploadedBytes+totalBytes > services.MegabytesToBytes(policy.DailyUploadMB) {
|
if policy.DailyUploadMB > 0 && usage.UploadedBytes+totalBytes > services.MegabytesToBytes(policy.DailyUploadMB) {
|
||||||
return http.StatusTooManyRequests, "anonymous daily upload limit reached"
|
return http.StatusTooManyRequests, "anonymous daily upload limit reached"
|
||||||
}
|
}
|
||||||
if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes {
|
|
||||||
return http.StatusTooManyRequests, "anonymous daily box limit reached"
|
|
||||||
}
|
|
||||||
activeBoxes, err := a.uploadService.ActiveBoxCountForIP(uploadClientIP(r))
|
|
||||||
if err != nil {
|
|
||||||
return http.StatusInternalServerError, "active box limit could not be checked"
|
|
||||||
}
|
|
||||||
if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes {
|
|
||||||
return http.StatusTooManyRequests, "anonymous active box limit reached"
|
|
||||||
}
|
|
||||||
if status, message := a.checkStorageBackendCapacity(policy.StorageBackendID, settings, totalBytes); message != "" {
|
if status, message := a.checkStorageBackendCapacity(policy.StorageBackendID, settings, totalBytes); message != "" {
|
||||||
return status, message
|
return status, message
|
||||||
}
|
}
|
||||||
@@ -253,16 +258,6 @@ func (a *App) checkUploadPolicy(r *http.Request, user services.User, loggedIn bo
|
|||||||
if policy.DailyUploadMB > 0 && usage.UploadedBytes+totalBytes > services.MegabytesToBytes(policy.DailyUploadMB) {
|
if policy.DailyUploadMB > 0 && usage.UploadedBytes+totalBytes > services.MegabytesToBytes(policy.DailyUploadMB) {
|
||||||
return http.StatusTooManyRequests, "daily upload limit reached"
|
return http.StatusTooManyRequests, "daily upload limit reached"
|
||||||
}
|
}
|
||||||
if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes {
|
|
||||||
return http.StatusTooManyRequests, "daily box limit reached"
|
|
||||||
}
|
|
||||||
activeBoxes, err := a.uploadService.ActiveBoxCountForUser(user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return http.StatusInternalServerError, "active box limit could not be checked"
|
|
||||||
}
|
|
||||||
if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes {
|
|
||||||
return http.StatusTooManyRequests, "active box limit reached"
|
|
||||||
}
|
|
||||||
activeStorage, err := a.uploadService.UserActiveStorageUsed(user.ID)
|
activeStorage, err := a.uploadService.UserActiveStorageUsed(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return http.StatusInternalServerError, "storage quota could not be checked"
|
return http.StatusInternalServerError, "storage quota could not be checked"
|
||||||
@@ -276,6 +271,42 @@ func (a *App) checkUploadPolicy(r *http.Request, user services.User, loggedIn bo
|
|||||||
return 0, ""
|
return 0, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) checkBoxCreationPolicy(r *http.Request, user services.User, loggedIn bool, policy services.EffectiveUploadPolicy) (int, string) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if !loggedIn {
|
||||||
|
usage, err := a.settingsService.UsageForIP(uploadClientIP(r), now)
|
||||||
|
if err != nil {
|
||||||
|
return http.StatusInternalServerError, "upload usage could not be checked"
|
||||||
|
}
|
||||||
|
if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes {
|
||||||
|
return http.StatusTooManyRequests, "anonymous daily box limit reached"
|
||||||
|
}
|
||||||
|
activeBoxes, err := a.uploadService.ActiveBoxCountForIP(uploadClientIP(r))
|
||||||
|
if err != nil {
|
||||||
|
return http.StatusInternalServerError, "active box limit could not be checked"
|
||||||
|
}
|
||||||
|
if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes {
|
||||||
|
return http.StatusTooManyRequests, "anonymous active box limit reached"
|
||||||
|
}
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
usage, err := a.settingsService.UsageForUser(user.ID, now)
|
||||||
|
if err != nil {
|
||||||
|
return http.StatusInternalServerError, "upload usage could not be checked"
|
||||||
|
}
|
||||||
|
if policy.DailyBoxes > 0 && usage.UploadedBoxes+1 > policy.DailyBoxes {
|
||||||
|
return http.StatusTooManyRequests, "daily box limit reached"
|
||||||
|
}
|
||||||
|
activeBoxes, err := a.uploadService.ActiveBoxCountForUser(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return http.StatusInternalServerError, "active box limit could not be checked"
|
||||||
|
}
|
||||||
|
if policy.ActiveBoxes > 0 && activeBoxes+1 > policy.ActiveBoxes {
|
||||||
|
return http.StatusTooManyRequests, "active box limit reached"
|
||||||
|
}
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) recordUploadUsage(r *http.Request, user services.User, loggedIn bool, totalBytes int64, boxes int) error {
|
func (a *App) recordUploadUsage(r *http.Request, user services.User, loggedIn bool, totalBytes int64, boxes int) error {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
if loggedIn {
|
if loggedIn {
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (g *uploadGrouper) entryFor(key string) *uploadGroupEntry {
|
|||||||
g.pruneLocked(time.Now())
|
g.pruneLocked(time.Now())
|
||||||
entry, ok := g.entries[key]
|
entry, ok := g.entries[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
entry = &uploadGroupEntry{}
|
entry = &uploadGroupEntry{at: time.Now()}
|
||||||
g.entries[key] = entry
|
g.entries[key] = entry
|
||||||
}
|
}
|
||||||
return entry
|
return entry
|
||||||
@@ -56,8 +56,8 @@ func (g *uploadGrouper) entryFor(key string) *uploadGroupEntry {
|
|||||||
|
|
||||||
// pruneLocked drops entries whose last use is well past the grouping window so
|
// pruneLocked drops entries whose last use is well past the grouping window so
|
||||||
// the map stays bounded to recently-active keys. Callers must hold g.mu. Entries
|
// the map stays bounded to recently-active keys. Callers must hold g.mu. Entries
|
||||||
// currently in use, or freshly created but not yet used (zero timestamp), are
|
// currently in use are kept to avoid removing one a request is about to
|
||||||
// kept to avoid removing one a request is about to populate.
|
// populate.
|
||||||
func (g *uploadGrouper) pruneLocked(now time.Time) {
|
func (g *uploadGrouper) pruneLocked(now time.Time) {
|
||||||
if now.Sub(g.lastPrune) < uploadGroupPruneInterval {
|
if now.Sub(g.lastPrune) < uploadGroupPruneInterval {
|
||||||
return
|
return
|
||||||
@@ -67,7 +67,7 @@ func (g *uploadGrouper) pruneLocked(now time.Time) {
|
|||||||
if !entry.mu.TryLock() {
|
if !entry.mu.TryLock() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
stale := !entry.at.IsZero() && now.Sub(entry.at) > 2*uploadGroupWindow
|
stale := now.Sub(entry.at) > 2*uploadGroupWindow
|
||||||
entry.mu.Unlock()
|
entry.mu.Unlock()
|
||||||
if stale {
|
if stale {
|
||||||
delete(g.entries, key)
|
delete(g.entries, key)
|
||||||
|
|||||||
24
backend/libs/handlers/upload_group_test.go
Normal file
24
backend/libs/handlers/upload_group_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUploadGroupPrunesFailedEntries(t *testing.T) {
|
||||||
|
g := newUploadGrouper()
|
||||||
|
entry := g.entryFor("ip:203.0.113.1|failed")
|
||||||
|
entry.mu.Lock()
|
||||||
|
entry.at = time.Now().Add(-3 * uploadGroupWindow)
|
||||||
|
entry.mu.Unlock()
|
||||||
|
g.lastPrune = time.Now().Add(-uploadGroupPruneInterval)
|
||||||
|
|
||||||
|
_ = g.entryFor("ip:203.0.113.1|next")
|
||||||
|
|
||||||
|
if _, ok := g.entries["ip:203.0.113.1|failed"]; ok {
|
||||||
|
t.Fatalf("stale failed entry was not pruned")
|
||||||
|
}
|
||||||
|
if _, ok := g.entries["ip:203.0.113.1|next"]; !ok {
|
||||||
|
t.Fatalf("new entry was not created")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,10 +26,20 @@ func Bans(logger *slog.Logger, bans *services.BanService, trustedProxies []strin
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settings, err := bans.Settings()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("ban settings load failed", "source", "ban", "severity", "error", "code", 5004, "ip", ip, "error", err.Error())
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !settings.AutoBanEnabled {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
if pattern, err := bans.MaliciousPattern(r.URL.Path); err != nil {
|
if pattern, err := bans.MaliciousPattern(r.URL.Path); err != nil {
|
||||||
logger.Error("malicious path check failed", "source", "ban", "severity", "error", "code", 5002, "ip", ip, "error", err.Error())
|
logger.Error("malicious path check failed", "source", "ban", "severity", "error", "code", 5002, "ip", ip, "error", err.Error())
|
||||||
} else if pattern != "" {
|
} else if pattern != "" {
|
||||||
if result, err := bans.RecordAbuse(ip, services.AbuseKindMaliciousPath, r.URL.Path, banThreshold(bans, services.AbuseKindMaliciousPath), now); err != nil {
|
if result, err := bans.RecordAbuse(ip, services.AbuseKindMaliciousPath, r.URL.Path, settings.MaliciousPathThreshold, now); err != nil {
|
||||||
logger.Error("malicious path event failed", "source", "ban", "severity", "error", "code", 5003, "ip", ip, "path", r.URL.Path, "error", err.Error())
|
logger.Error("malicious path event failed", "source", "ban", "severity", "error", "code", 5003, "ip", ip, "path", r.URL.Path, "error", err.Error())
|
||||||
} else if result.Enabled {
|
} else if result.Enabled {
|
||||||
logger.Warn("malicious path requested", "source", "ban", "severity", "warn", "code", 4302, "ip", ip, "path", r.URL.Path, "pattern", pattern, "count", result.Event.Count)
|
logger.Warn("malicious path requested", "source", "ban", "severity", "warn", "code", 4302, "ip", ip, "path", r.URL.Path, "pattern", pattern, "count", result.Event.Count)
|
||||||
@@ -48,18 +58,3 @@ func Bans(logger *slog.Logger, bans *services.BanService, trustedProxies []strin
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func banThreshold(bans *services.BanService, kind string) int {
|
|
||||||
settings, err := bans.Settings()
|
|
||||||
if err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
switch kind {
|
|
||||||
case services.AbuseKindAdminLogin:
|
|
||||||
return settings.AdminLoginFailureThreshold
|
|
||||||
case services.AbuseKindUserLogin:
|
|
||||||
return settings.UserLoginFailureThreshold
|
|
||||||
default:
|
|
||||||
return settings.MaliciousPathThreshold
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -79,6 +79,26 @@ func TestBansMiddlewareAutoBansMaliciousPaths(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBansMiddlewareSkipsAutoBanWhenDisabled(t *testing.T) {
|
||||||
|
bans := newMiddlewareBanService(t)
|
||||||
|
handler := Chain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}), Bans(slog.New(slog.NewTextHandler(io.Discard, nil)), bans, nil))
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "/.env", nil)
|
||||||
|
request.RemoteAddr = "203.0.113.23:6070"
|
||||||
|
response := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(response, request)
|
||||||
|
if response.Code == http.StatusForbidden {
|
||||||
|
t.Fatalf("request %d was blocked while auto-ban disabled", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok, err := bans.Match("203.0.113.23", time.Now().UTC()); err != nil || ok {
|
||||||
|
t.Fatalf("disabled auto-ban Match = %v, %v", ok, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newMiddlewareBanService(t *testing.T) *services.BanService {
|
func newMiddlewareBanService(t *testing.T) *services.BanService {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
|
|||||||
@@ -519,6 +519,11 @@ func (s *BanService) RecordAbuse(ip, kind, detail string, threshold int, now tim
|
|||||||
if err != nil || !triggered {
|
if err != nil || !triggered {
|
||||||
return AbuseResult{Event: event, Triggered: false, Enabled: true}, err
|
return AbuseResult{Event: event, Triggered: false, Enabled: true}, err
|
||||||
}
|
}
|
||||||
|
if matched, ok, err := s.Match(ip, now); err != nil {
|
||||||
|
return AbuseResult{}, err
|
||||||
|
} else if ok {
|
||||||
|
return AbuseResult{Event: event, Ban: matched.Ban, Triggered: true, Enabled: true}, nil
|
||||||
|
}
|
||||||
reason := fmt.Sprintf("%s threshold reached: %s", strings.ReplaceAll(kind, "_", " "), detail)
|
reason := fmt.Sprintf("%s threshold reached: %s", strings.ReplaceAll(kind, "_", " "), detail)
|
||||||
ban, err = s.createBan(ip, reason, BanSourceAuto, "", now.Add(time.Duration(settings.AutoBanDurationHours)*time.Hour), now)
|
ban, err = s.createBan(ip, reason, BanSourceAuto, "", now.Add(time.Duration(settings.AutoBanDurationHours)*time.Hour), now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -79,6 +79,17 @@ func TestBanServiceAutoBanThresholdsAndDisabled(t *testing.T) {
|
|||||||
if err != nil || !result.Triggered || result.Ban.ID == "" {
|
if err != nil || !result.Triggered || result.Ban.ID == "" {
|
||||||
t.Fatalf("RecordAbuse threshold = %+v, %v", result, err)
|
t.Fatalf("RecordAbuse threshold = %+v, %v", result, err)
|
||||||
}
|
}
|
||||||
|
again, err := bans.RecordAbuse("203.0.113.8", AbuseKindMaliciousPath, "/.env", 3, now.Add(4*time.Minute))
|
||||||
|
if err != nil || !again.Triggered || again.Ban.ID != result.Ban.ID {
|
||||||
|
t.Fatalf("RecordAbuse duplicate = %+v, %v", again, err)
|
||||||
|
}
|
||||||
|
records, err := bans.ListBans()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListBans returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(records) != 1 {
|
||||||
|
t.Fatalf("ban count = %d, want 1", len(records))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBanServiceMaliciousPathRules(t *testing.T) {
|
func TestBanServiceMaliciousPathRules(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user