feat(security): add trusted proxies and abuse event cleanup
All checks were successful
Build and Publish Docker Image / deploy (push) Successful in 1m38s
All checks were successful
Build and Publish Docker Image / deploy (push) Successful in 1m38s
- Add `WARPBOX_TRUSTED_PROXIES` configuration to restrict accepted forwarded client IP headers to specific proxy IPs/CIDRs, securing client IP resolution. - Integrate `BanService` into the background cleanup job to automatically purge expired abuse and ban evidence events. - Update documentation with reverse proxy security guidelines and a production systemd deployment guide.
This commit is contained in:
545
backend/libs/services/bans.go
Normal file
545
backend/libs/services/bans.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
var (
|
||||
bansBucket = []byte("bans")
|
||||
abuseEventsBucket = []byte("abuse_events")
|
||||
banRulesBucket = []byte("ban_rules")
|
||||
banSettingsBucket = []byte("ban_settings")
|
||||
banSettingsKey = []byte("settings")
|
||||
defaultBanRulesSeed = []byte("default_rules_seeded")
|
||||
)
|
||||
|
||||
const (
|
||||
BanSourceManual = "manual"
|
||||
BanSourceAuto = "auto"
|
||||
|
||||
AbuseKindMaliciousPath = "malicious_path"
|
||||
AbuseKindAdminLogin = "admin_login_failure"
|
||||
AbuseKindUserLogin = "user_login_failure"
|
||||
)
|
||||
|
||||
var defaultMaliciousPathRules = []string{
|
||||
"/wp-admin",
|
||||
"/.env",
|
||||
"/.git/config",
|
||||
"/phpmyadmin",
|
||||
"/wp-login.php",
|
||||
"/xmlrpc.php",
|
||||
"/config.php",
|
||||
"/vendor/phpunit",
|
||||
".env",
|
||||
"backup",
|
||||
"dump.sql",
|
||||
}
|
||||
|
||||
var ErrBanNotFound = errors.New("ban not found")
|
||||
|
||||
type BanService struct {
|
||||
db *bbolt.DB
|
||||
}
|
||||
|
||||
type BanSettings struct {
|
||||
AutoBanEnabled bool `json:"autoBanEnabled"`
|
||||
AutoBanDurationHours int `json:"autoBanDurationHours"`
|
||||
MaliciousPathThreshold int `json:"maliciousPathThreshold"`
|
||||
AdminLoginFailureThreshold int `json:"adminLoginFailureThreshold"`
|
||||
UserLoginFailureThreshold int `json:"userLoginFailureThreshold"`
|
||||
AbuseWindowHours int `json:"abuseWindowHours"`
|
||||
}
|
||||
|
||||
type BanRecord struct {
|
||||
ID string `json:"id"`
|
||||
Target string `json:"target"`
|
||||
Normalized string `json:"normalized"`
|
||||
Reason string `json:"reason"`
|
||||
Source string `json:"source"`
|
||||
CreatedBy string `json:"createdBy,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
UnbannedAt *time.Time `json:"unbannedAt,omitempty"`
|
||||
LastMatchedAt *time.Time `json:"lastMatchedAt,omitempty"`
|
||||
}
|
||||
|
||||
type BanRule struct {
|
||||
ID string `json:"id"`
|
||||
Pattern string `json:"pattern"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type AbuseEvent struct {
|
||||
Key string `json:"key"`
|
||||
IP string `json:"ip"`
|
||||
Kind string `json:"kind"`
|
||||
Count int `json:"count"`
|
||||
FirstSeen time.Time `json:"firstSeen"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
type MatchedBan struct {
|
||||
Ban BanRecord
|
||||
IP string
|
||||
}
|
||||
|
||||
type AbuseResult struct {
|
||||
Event AbuseEvent
|
||||
Ban BanRecord
|
||||
Triggered bool
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
func NewBanService(db *bbolt.DB) (*BanService, error) {
|
||||
service := &BanService{db: db}
|
||||
err := db.Update(func(tx *bbolt.Tx) error {
|
||||
for _, bucket := range [][]byte{bansBucket, abuseEventsBucket, banRulesBucket, banSettingsBucket} {
|
||||
if _, err := tx.CreateBucketIfNotExists(bucket); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if tx.Bucket(banSettingsBucket).Get(banSettingsKey) == nil {
|
||||
data, err := json.Marshal(DefaultBanSettings())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Bucket(banSettingsBucket).Put(banSettingsKey, data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
rules := tx.Bucket(banRulesBucket)
|
||||
if rules.Get(defaultBanRulesSeed) == nil {
|
||||
now := time.Now().UTC()
|
||||
for _, pattern := range defaultMaliciousPathRules {
|
||||
rule := BanRule{ID: randomID(10), Pattern: pattern, Enabled: true, CreatedAt: now, UpdatedAt: now}
|
||||
data, err := json.Marshal(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rules.Put([]byte(rule.ID), data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := rules.Put(defaultBanRulesSeed, []byte("1")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return service, err
|
||||
}
|
||||
|
||||
func DefaultBanSettings() BanSettings {
|
||||
return BanSettings{
|
||||
AutoBanEnabled: false,
|
||||
AutoBanDurationHours: 24,
|
||||
MaliciousPathThreshold: 3,
|
||||
AdminLoginFailureThreshold: 10,
|
||||
UserLoginFailureThreshold: 30,
|
||||
AbuseWindowHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BanService) Settings() (BanSettings, error) {
|
||||
settings := DefaultBanSettings()
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
data := tx.Bucket(banSettingsBucket).Get(banSettingsKey)
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return err
|
||||
}
|
||||
settings = withBanSettingDefaults(settings)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return BanSettings{}, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (s *BanService) UpdateSettings(settings BanSettings) error {
|
||||
settings = withBanSettingDefaults(settings)
|
||||
if settings.AutoBanDurationHours <= 0 || settings.MaliciousPathThreshold <= 0 ||
|
||||
settings.AdminLoginFailureThreshold <= 0 || settings.UserLoginFailureThreshold <= 0 ||
|
||||
settings.AbuseWindowHours <= 0 {
|
||||
return fmt.Errorf("ban settings must be positive")
|
||||
}
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(banSettingsBucket).Put(banSettingsKey, data)
|
||||
})
|
||||
}
|
||||
|
||||
func withBanSettingDefaults(settings BanSettings) BanSettings {
|
||||
defaults := DefaultBanSettings()
|
||||
if settings.AutoBanDurationHours <= 0 {
|
||||
settings.AutoBanDurationHours = defaults.AutoBanDurationHours
|
||||
}
|
||||
if settings.MaliciousPathThreshold <= 0 {
|
||||
settings.MaliciousPathThreshold = defaults.MaliciousPathThreshold
|
||||
}
|
||||
if settings.AdminLoginFailureThreshold <= 0 {
|
||||
settings.AdminLoginFailureThreshold = defaults.AdminLoginFailureThreshold
|
||||
}
|
||||
if settings.UserLoginFailureThreshold <= 0 {
|
||||
settings.UserLoginFailureThreshold = defaults.UserLoginFailureThreshold
|
||||
}
|
||||
if settings.AbuseWindowHours <= 0 {
|
||||
settings.AbuseWindowHours = defaults.AbuseWindowHours
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func (s *BanService) CreateManualBan(target, reason, createdBy string, expiresAt time.Time) (BanRecord, error) {
|
||||
return s.createBan(target, reason, BanSourceManual, createdBy, expiresAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *BanService) createBan(target, reason, source, createdBy string, expiresAt, now time.Time) (BanRecord, error) {
|
||||
normalized, err := NormalizeBanTarget(target)
|
||||
if err != nil {
|
||||
return BanRecord{}, err
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if reason == "" {
|
||||
return BanRecord{}, fmt.Errorf("ban reason is required")
|
||||
}
|
||||
if !expiresAt.After(now) {
|
||||
return BanRecord{}, fmt.Errorf("ban expiration must be in the future")
|
||||
}
|
||||
record := BanRecord{
|
||||
ID: randomID(12),
|
||||
Target: strings.TrimSpace(target),
|
||||
Normalized: normalized,
|
||||
Reason: reason,
|
||||
Source: source,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ExpiresAt: expiresAt.UTC(),
|
||||
}
|
||||
data, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return BanRecord{}, err
|
||||
}
|
||||
err = s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(bansBucket).Put([]byte(record.ID), data)
|
||||
})
|
||||
return record, err
|
||||
}
|
||||
|
||||
func NormalizeBanTarget(target string) (string, error) {
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" {
|
||||
return "", fmt.Errorf("ban target is required")
|
||||
}
|
||||
if strings.Contains(target, "/") {
|
||||
_, network, err := net.ParseCIDR(target)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid CIDR target")
|
||||
}
|
||||
return network.String(), nil
|
||||
}
|
||||
ip := net.ParseIP(target)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("invalid IP target")
|
||||
}
|
||||
return ip.String(), nil
|
||||
}
|
||||
|
||||
func (s *BanService) ListBans() ([]BanRecord, error) {
|
||||
records := []BanRecord{}
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(bansBucket).ForEach(func(_, value []byte) error {
|
||||
var record BanRecord
|
||||
if err := json.Unmarshal(value, &record); err != nil {
|
||||
return err
|
||||
}
|
||||
records = append(records, record)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].CreatedAt.After(records[j].CreatedAt)
|
||||
})
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (s *BanService) Unban(id string, now time.Time) error {
|
||||
id = strings.TrimSpace(id)
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(bansBucket)
|
||||
data := bucket.Get([]byte(id))
|
||||
if data == nil {
|
||||
return ErrBanNotFound
|
||||
}
|
||||
var record BanRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
return err
|
||||
}
|
||||
now = now.UTC()
|
||||
record.UnbannedAt = &now
|
||||
record.UpdatedAt = now
|
||||
next, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put([]byte(id), next)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BanService) Match(ip string, now time.Time) (MatchedBan, bool, error) {
|
||||
parsed := net.ParseIP(strings.TrimSpace(ip))
|
||||
if parsed == nil {
|
||||
return MatchedBan{}, false, nil
|
||||
}
|
||||
now = now.UTC()
|
||||
var matched BanRecord
|
||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(bansBucket)
|
||||
return bucket.ForEach(func(key, value []byte) error {
|
||||
if matched.ID != "" {
|
||||
return nil
|
||||
}
|
||||
var record BanRecord
|
||||
if err := json.Unmarshal(value, &record); err != nil {
|
||||
return err
|
||||
}
|
||||
if !record.Active(now) || !banTargetMatches(record.Normalized, parsed) {
|
||||
return nil
|
||||
}
|
||||
record.LastMatchedAt = &now
|
||||
record.UpdatedAt = now
|
||||
next, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bucket.Put(key, next); err != nil {
|
||||
return err
|
||||
}
|
||||
matched = record
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return MatchedBan{Ban: matched, IP: ip}, matched.ID != "", err
|
||||
}
|
||||
|
||||
func (r BanRecord) Active(now time.Time) bool {
|
||||
return r.UnbannedAt == nil && r.ExpiresAt.After(now.UTC())
|
||||
}
|
||||
|
||||
func (r BanRecord) Status(now time.Time) string {
|
||||
switch {
|
||||
case r.UnbannedAt != nil:
|
||||
return "unbanned"
|
||||
case !r.ExpiresAt.After(now.UTC()):
|
||||
return "expired"
|
||||
default:
|
||||
return "active"
|
||||
}
|
||||
}
|
||||
|
||||
func banTargetMatches(target string, ip net.IP) bool {
|
||||
if strings.Contains(target, "/") {
|
||||
if _, network, err := net.ParseCIDR(target); err == nil {
|
||||
return network.Contains(ip)
|
||||
}
|
||||
return false
|
||||
}
|
||||
targetIP := net.ParseIP(target)
|
||||
return targetIP != nil && targetIP.Equal(ip)
|
||||
}
|
||||
|
||||
func (s *BanService) ListRules() ([]BanRule, error) {
|
||||
rules := []BanRule{}
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(banRulesBucket).ForEach(func(key, value []byte) error {
|
||||
if string(key) == string(defaultBanRulesSeed) {
|
||||
return nil
|
||||
}
|
||||
var rule BanRule
|
||||
if err := json.Unmarshal(value, &rule); err != nil {
|
||||
return err
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return strings.ToLower(rules[i].Pattern) < strings.ToLower(rules[j].Pattern)
|
||||
})
|
||||
return rules, err
|
||||
}
|
||||
|
||||
func (s *BanService) SaveRules(patterns []string, now time.Time) error {
|
||||
now = now.UTC()
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(banRulesBucket)
|
||||
deleteKeys := [][]byte{}
|
||||
if err := bucket.ForEach(func(key, _ []byte) error {
|
||||
if string(key) == string(defaultBanRulesSeed) {
|
||||
return nil
|
||||
}
|
||||
deleteKeys = append(deleteKeys, append([]byte(nil), key...))
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range deleteKeys {
|
||||
if err := bucket.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
seen := map[string]bool{}
|
||||
for _, pattern := range patterns {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
if pattern == "" || seen[strings.ToLower(pattern)] {
|
||||
continue
|
||||
}
|
||||
seen[strings.ToLower(pattern)] = true
|
||||
rule := BanRule{ID: randomID(10), Pattern: pattern, Enabled: true, CreatedAt: now, UpdatedAt: now}
|
||||
data, err := json.Marshal(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bucket.Put([]byte(rule.ID), data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BanService) DeleteRule(id string) error {
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(banRulesBucket).Delete([]byte(strings.TrimSpace(id)))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BanService) MaliciousPattern(path string) (string, error) {
|
||||
if shouldSkipMaliciousPath(path) {
|
||||
return "", nil
|
||||
}
|
||||
rules, err := s.ListRules()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
lowerPath := strings.ToLower(path)
|
||||
for _, rule := range rules {
|
||||
if rule.Enabled && strings.Contains(lowerPath, strings.ToLower(rule.Pattern)) {
|
||||
return rule.Pattern, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func shouldSkipMaliciousPath(path string) bool {
|
||||
return path == "/health" || path == "/healthz" || path == "/api/v1/health" || strings.HasPrefix(path, "/static/")
|
||||
}
|
||||
|
||||
func (s *BanService) RecordAbuse(ip, kind, detail string, threshold int, now time.Time) (AbuseResult, error) {
|
||||
settings, err := s.Settings()
|
||||
if err != nil {
|
||||
return AbuseResult{}, err
|
||||
}
|
||||
if !settings.AutoBanEnabled {
|
||||
return AbuseResult{Enabled: false}, nil
|
||||
}
|
||||
if threshold <= 0 {
|
||||
return AbuseResult{Enabled: true}, nil
|
||||
}
|
||||
now = now.UTC()
|
||||
window := time.Duration(settings.AbuseWindowHours) * time.Hour
|
||||
key := abuseKey(ip, kind)
|
||||
var event AbuseEvent
|
||||
var triggered bool
|
||||
var ban BanRecord
|
||||
err = s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(abuseEventsBucket)
|
||||
data := bucket.Get([]byte(key))
|
||||
if data != nil {
|
||||
if err := json.Unmarshal(data, &event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if data == nil || now.Sub(event.FirstSeen) > window {
|
||||
event = AbuseEvent{Key: key, IP: ip, Kind: kind, FirstSeen: now}
|
||||
}
|
||||
event.Count++
|
||||
event.LastSeen = now
|
||||
event.Detail = detail
|
||||
next, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bucket.Put([]byte(key), next); err != nil {
|
||||
return err
|
||||
}
|
||||
triggered = event.Count >= threshold
|
||||
return nil
|
||||
})
|
||||
if err != nil || !triggered {
|
||||
return AbuseResult{Event: event, Triggered: false, Enabled: true}, err
|
||||
}
|
||||
reason := fmt.Sprintf("%s threshold reached: %s", strings.ReplaceAll(kind, "_", " "), detail)
|
||||
ban, err = s.createBan(ip, reason, BanSourceAuto, "", now.Add(time.Duration(settings.AutoBanDurationHours)*time.Hour), now)
|
||||
if err != nil {
|
||||
return AbuseResult{}, err
|
||||
}
|
||||
return AbuseResult{Event: event, Ban: ban, Triggered: true, Enabled: true}, nil
|
||||
}
|
||||
|
||||
func (s *BanService) CleanupAbuseEvents(now time.Time) (int, error) {
|
||||
settings, err := s.Settings()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cutoff := now.UTC().Add(-time.Duration(settings.AbuseWindowHours) * time.Hour)
|
||||
cleaned := 0
|
||||
err = s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(abuseEventsBucket)
|
||||
deleteKeys := [][]byte{}
|
||||
if err := bucket.ForEach(func(key, value []byte) error {
|
||||
var event AbuseEvent
|
||||
if err := json.Unmarshal(value, &event); err != nil {
|
||||
deleteKeys = append(deleteKeys, append([]byte(nil), key...))
|
||||
return nil
|
||||
}
|
||||
if event.LastSeen.Before(cutoff) {
|
||||
deleteKeys = append(deleteKeys, append([]byte(nil), key...))
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range deleteKeys {
|
||||
if err := bucket.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
cleaned++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return cleaned, err
|
||||
}
|
||||
|
||||
func abuseKey(ip, kind string) string {
|
||||
return kind + ":" + strings.TrimSpace(ip)
|
||||
}
|
||||
117
backend/libs/services/bans_test.go
Normal file
117
backend/libs/services/bans_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBanServiceMatchesIPAndCIDR(t *testing.T) {
|
||||
bans := newTestBanService(t)
|
||||
now := time.Date(2026, 5, 31, 12, 0, 0, 0, time.UTC)
|
||||
ipBan, err := bans.createBan("203.0.113.5", "single IP", BanSourceManual, "test", now.Add(time.Hour), now)
|
||||
if err != nil {
|
||||
t.Fatalf("createBan IP returned error: %v", err)
|
||||
}
|
||||
cidrBan, err := bans.createBan("198.51.100.0/24", "CIDR", BanSourceManual, "test", now.Add(time.Hour), now)
|
||||
if err != nil {
|
||||
t.Fatalf("createBan CIDR returned error: %v", err)
|
||||
}
|
||||
|
||||
if matched, ok, err := bans.Match("203.0.113.5", now); err != nil || !ok || matched.Ban.ID != ipBan.ID {
|
||||
t.Fatalf("Match IP = %+v, %v, %v", matched, ok, err)
|
||||
}
|
||||
if matched, ok, err := bans.Match("198.51.100.42", now); err != nil || !ok || matched.Ban.ID != cidrBan.ID {
|
||||
t.Fatalf("Match CIDR = %+v, %v, %v", matched, ok, err)
|
||||
}
|
||||
if _, ok, err := bans.Match("192.0.2.1", now); err != nil || ok {
|
||||
t.Fatalf("Match unrelated = %v, %v", ok, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanServiceIgnoresExpiredAndUnbanned(t *testing.T) {
|
||||
bans := newTestBanService(t)
|
||||
now := time.Date(2026, 5, 31, 12, 0, 0, 0, time.UTC)
|
||||
expired, err := bans.createBan("203.0.113.6", "expired", BanSourceManual, "test", now.Add(time.Hour), now)
|
||||
if err != nil {
|
||||
t.Fatalf("createBan expired returned error: %v", err)
|
||||
}
|
||||
if _, ok, err := bans.Match("203.0.113.6", now.Add(2*time.Hour)); err != nil || ok {
|
||||
t.Fatalf("expired Match = %v, %v", ok, err)
|
||||
}
|
||||
active, err := bans.createBan("203.0.113.7", "active", BanSourceManual, "test", now.Add(time.Hour), now)
|
||||
if err != nil {
|
||||
t.Fatalf("createBan active returned error: %v", err)
|
||||
}
|
||||
if err := bans.Unban(active.ID, now.Add(time.Minute)); err != nil {
|
||||
t.Fatalf("Unban returned error: %v", err)
|
||||
}
|
||||
if _, ok, err := bans.Match("203.0.113.7", now.Add(2*time.Minute)); err != nil || ok {
|
||||
t.Fatalf("unbanned Match = %v, %v", ok, err)
|
||||
}
|
||||
if expired.Status(now.Add(2*time.Hour)) != "expired" {
|
||||
t.Fatalf("expired status = %q", expired.Status(now.Add(2*time.Hour)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanServiceAutoBanThresholdsAndDisabled(t *testing.T) {
|
||||
bans := newTestBanService(t)
|
||||
now := time.Date(2026, 5, 31, 12, 0, 0, 0, time.UTC)
|
||||
if result, err := bans.RecordAbuse("203.0.113.8", AbuseKindMaliciousPath, "/.env", 3, now); err != nil || result.Enabled {
|
||||
t.Fatalf("disabled RecordAbuse = %+v, %v", result, err)
|
||||
}
|
||||
settings, err := bans.Settings()
|
||||
if err != nil {
|
||||
t.Fatalf("Settings returned error: %v", err)
|
||||
}
|
||||
settings.AutoBanEnabled = true
|
||||
if err := bans.UpdateSettings(settings); err != nil {
|
||||
t.Fatalf("UpdateSettings returned error: %v", err)
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
result, err := bans.RecordAbuse("203.0.113.8", AbuseKindMaliciousPath, "/.env", 3, now.Add(time.Duration(i)*time.Minute))
|
||||
if err != nil || result.Triggered {
|
||||
t.Fatalf("RecordAbuse before threshold = %+v, %v", result, err)
|
||||
}
|
||||
}
|
||||
result, err := bans.RecordAbuse("203.0.113.8", AbuseKindMaliciousPath, "/.env", 3, now.Add(3*time.Minute))
|
||||
if err != nil || !result.Triggered || result.Ban.ID == "" {
|
||||
t.Fatalf("RecordAbuse threshold = %+v, %v", result, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanServiceMaliciousPathRules(t *testing.T) {
|
||||
bans := newTestBanService(t)
|
||||
if pattern, err := bans.MaliciousPattern("/foo/.ENV"); err != nil || pattern == "" {
|
||||
t.Fatalf("MaliciousPattern .env = %q, %v", pattern, err)
|
||||
}
|
||||
if pattern, err := bans.MaliciousPattern("/static/.env"); err != nil || pattern != "" {
|
||||
t.Fatalf("MaliciousPattern static = %q, %v", pattern, err)
|
||||
}
|
||||
if err := bans.SaveRules([]string{"/custom-probe"}, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("SaveRules returned error: %v", err)
|
||||
}
|
||||
if pattern, err := bans.MaliciousPattern("/x/CUSTOM-probe"); err != nil || pattern != "/custom-probe" {
|
||||
t.Fatalf("MaliciousPattern custom = %q, %v", pattern, err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestBanService(t *testing.T) *BanService {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
upload, err := NewUploadService(1024*1024, filepath.Join(root, "data"), "http://example.test", slog.Default())
|
||||
if err != nil {
|
||||
t.Fatalf("NewUploadService returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := upload.Close(); err != nil {
|
||||
t.Fatalf("Close returned error: %v", err)
|
||||
}
|
||||
})
|
||||
bans, err := NewBanService(upload.DB())
|
||||
if err != nil {
|
||||
t.Fatalf("NewBanService returned error: %v", err)
|
||||
}
|
||||
return bans
|
||||
}
|
||||
75
backend/libs/services/proxy.go
Normal file
75
backend/libs/services/proxy.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type clientIPContextKey struct{}
|
||||
|
||||
func WithClientIP(r *http.Request, ip string) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), clientIPContextKey{}, ip))
|
||||
}
|
||||
|
||||
func ClientIPFromContext(r *http.Request) (string, bool) {
|
||||
ip, ok := r.Context().Value(clientIPContextKey{}).(string)
|
||||
return ip, ok && ip != ""
|
||||
}
|
||||
|
||||
// ClientIP resolves the effective client IP. When trustedProxies is empty,
|
||||
// forwarded headers are trusted for easy reverse-proxy/container defaults.
|
||||
func ClientIP(remoteAddr, forwardedFor, realIP string, trustedProxies []string) string {
|
||||
remoteIP := remoteIPOnly(remoteAddr)
|
||||
if len(trustedProxies) == 0 || remoteTrusted(remoteIP, trustedProxies) {
|
||||
if ip := firstForwardedIP(forwardedFor); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := strings.TrimSpace(realIP); ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
func remoteIPOnly(remoteAddr string) string {
|
||||
host := strings.TrimSpace(remoteAddr)
|
||||
if splitHost, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||||
host = splitHost
|
||||
}
|
||||
return strings.Trim(host, "[]")
|
||||
}
|
||||
|
||||
func firstForwardedIP(forwardedFor string) string {
|
||||
for _, part := range strings.Split(forwardedFor, ",") {
|
||||
ip := strings.TrimSpace(part)
|
||||
if ip != "" {
|
||||
return strings.Trim(ip, "[]")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func remoteTrusted(remoteIP string, trustedProxies []string) bool {
|
||||
parsed := net.ParseIP(remoteIP)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
for _, trusted := range trustedProxies {
|
||||
trusted = strings.TrimSpace(trusted)
|
||||
if trusted == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(trusted, "/") {
|
||||
if _, network, err := net.ParseCIDR(trusted); err == nil && network.Contains(parsed) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(trusted); ip != nil && ip.Equal(parsed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
29
backend/libs/services/proxy_test.go
Normal file
29
backend/libs/services/proxy_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package services
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClientIPTrustsForwardedHeadersByDefault(t *testing.T) {
|
||||
ip := ClientIP("127.0.0.1:6070", "203.0.113.10, 10.0.0.2", "198.51.100.2", nil)
|
||||
if ip != "203.0.113.10" {
|
||||
t.Fatalf("ClientIP = %q, want forwarded IP", ip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIPUsesTrustedProxyCIDRs(t *testing.T) {
|
||||
trusted := []string{"127.0.0.1", "172.16.0.0/12"}
|
||||
ip := ClientIP("172.20.0.4:6070", "203.0.113.11", "", trusted)
|
||||
if ip != "203.0.113.11" {
|
||||
t.Fatalf("trusted ClientIP = %q", ip)
|
||||
}
|
||||
spoofed := ClientIP("198.51.100.20:6070", "203.0.113.12", "203.0.113.13", trusted)
|
||||
if spoofed != "198.51.100.20" {
|
||||
t.Fatalf("untrusted ClientIP = %q, want remote addr", spoofed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIPFallsBackToRealIP(t *testing.T) {
|
||||
ip := ClientIP("127.0.0.1:6070", "", "203.0.113.14", nil)
|
||||
if ip != "203.0.113.14" {
|
||||
t.Fatalf("ClientIP = %q, want real IP", ip)
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -484,19 +483,3 @@ func normalizeBackendID(id string) string {
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func ClientIP(remoteAddr, forwardedFor string) string {
|
||||
if forwardedFor != "" {
|
||||
parts := strings.Split(forwardedFor, ",")
|
||||
if ip := strings.TrimSpace(parts[0]); ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
host := remoteAddr
|
||||
if strings.Contains(remoteAddr, ":") {
|
||||
if splitHost, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||||
host = splitHost
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user