package services

import (
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"sort"
	"strings"
	"time"

	"go.etcd.io/bbolt"
)

// Bucket and key names used by BanService inside the bbolt database.
// defaultBanRulesSeed is a marker key stored in banRulesBucket once the
// default malicious-path rules have been written; it is not itself a rule.
var (
	bansBucket          = []byte("bans")
	abuseEventsBucket   = []byte("abuse_events")
	banRulesBucket      = []byte("ban_rules")
	banSettingsBucket   = []byte("ban_settings")
	banSettingsKey      = []byte("settings")
	defaultBanRulesSeed = []byte("default_rules_seeded")
)

const (
	// BanSourceManual and BanSourceAuto are the values stored in
	// BanRecord.Source for bans created through CreateManualBan and through
	// RecordAbuse respectively.
	BanSourceManual = "manual"
	BanSourceAuto   = "auto"

	// AbuseKind* identify the category of an abuse event. The kind is part of
	// the storage key of the event (see abuseKey).
	AbuseKindMaliciousPath = "malicious_path"
	AbuseKindAdminLogin    = "admin_login_failure"
	AbuseKindUserLogin     = "user_login_failure"
)

// defaultMaliciousPathRules are the patterns seeded as enabled rules the
// first time NewBanService runs against a database.
var defaultMaliciousPathRules = []string{
	"/wp-admin",
	"/.env",
	"/.git/config",
	"/phpmyadmin",
	"/wp-login.php",
	"/xmlrpc.php",
	"/config.php",
	"/vendor/phpunit",
	".env",
	"backup",
	"dump.sql",
}

// ErrBanNotFound is returned by Unban when no ban is stored under the given ID.
var ErrBanNotFound = errors.New("ban not found")

// BanService stores bans, ban rules, abuse events and ban settings in a bbolt
// database.
type BanService struct {
	db *bbolt.DB
}

// BanSettings holds the auto-ban configuration persisted in the settings
// bucket.
type BanSettings struct {
	AutoBanEnabled             bool `json:"autoBanEnabled"`
	AutoBanDurationHours       int  `json:"autoBanDurationHours"`
	MaliciousPathThreshold     int  `json:"maliciousPathThreshold"`
	AdminLoginFailureThreshold int  `json:"adminLoginFailureThreshold"`
	UserLoginFailureThreshold  int  `json:"userLoginFailureThreshold"`
	AbuseWindowHours           int  `json:"abuseWindowHours"`
}

// BanRecord is a single stored ban. Target is the trimmed input as supplied,
// Normalized is the canonical IP or CIDR produced by NormalizeBanTarget.
type BanRecord struct {
	ID            string     `json:"id"`
	Target        string     `json:"target"`
	Normalized    string     `json:"normalized"`
	Reason        string     `json:"reason"`
	Source        string     `json:"source"`
	CreatedBy     string     `json:"createdBy,omitempty"`
	CreatedAt     time.Time  `json:"createdAt"`
	UpdatedAt     time.Time  `json:"updatedAt"`
	ExpiresAt     time.Time  `json:"expiresAt"`
	UnbannedAt    *time.Time `json:"unbannedAt,omitempty"`
	LastMatchedAt *time.Time `json:"lastMatchedAt,omitempty"`
}

// BanRule is a stored path pattern checked by MaliciousPattern.
type BanRule struct {
	ID        string    `json:"id"`
	Pattern   string    `json:"pattern"`
	Enabled   bool      `json:"enabled"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}

// AbuseEvent is the running counter kept per IP and abuse kind.
type AbuseEvent struct {
	Key       string    `json:"key"`
	IP        string    `json:"ip"`
	Kind      string    `json:"kind"`
	Count     int       `json:"count"`
	FirstSeen time.Time `json:"firstSeen"`
	LastSeen  time.Time `json:"lastSeen"`
	Detail    string    `json:"detail,omitempty"`
}

// MatchedBan pairs a matching ban with the IP string that was checked.
type MatchedBan struct {
	Ban BanRecord
	IP  string
}

// AbuseResult is the outcome of RecordAbuse. Enabled reports whether
// auto-banning was enabled; Triggered reports whether the threshold was
// reached, in which case Ban holds the existing or newly created ban.
type AbuseResult struct {
	Event     AbuseEvent
	Ban       BanRecord
	Triggered bool
	Enabled   bool
}

// NewBanService creates the buckets BanService relies on, stores the default
// settings when none exist yet, and seeds the default malicious-path rules
// the first time it runs against db (tracked by a marker key in the rules
// bucket). It returns a nil service together with the error when db is nil or
// the initialization transaction fails.
func NewBanService(db *bbolt.DB) (*BanService, error) {
	if db == nil {
		return nil, errors.New("ban service requires a database")
	}

	err := db.Update(func(tx *bbolt.Tx) error {
		for _, name := range [][]byte{bansBucket, abuseEventsBucket, banRulesBucket, banSettingsBucket} {
			if _, err := tx.CreateBucketIfNotExists(name); err != nil {
				return err
			}
		}

		settings := tx.Bucket(banSettingsBucket)
		if settings.Get(banSettingsKey) == nil {
			data, err := json.Marshal(DefaultBanSettings())
			if err != nil {
				return err
			}
			if err := settings.Put(banSettingsKey, data); err != nil {
				return err
			}
		}

		rules := tx.Bucket(banRulesBucket)
		if rules.Get(defaultBanRulesSeed) != nil {
			// Defaults were seeded before; do not bring back rules that
			// may have been removed since.
			return nil
		}
		now := time.Now().UTC()
		for _, pattern := range defaultMaliciousPathRules {
			rule := BanRule{ID: randomID(10), Pattern: pattern, Enabled: true, CreatedAt: now, UpdatedAt: now}
			data, err := json.Marshal(rule)
			if err != nil {
				return err
			}
			if err := rules.Put([]byte(rule.ID), data); err != nil {
				return err
			}
		}
		return rules.Put(defaultBanRulesSeed, []byte("1"))
	})
	if err != nil {
		return nil, err
	}
	return &BanService{db: db}, nil
}

// DefaultBanSettings returns the settings used when nothing has been stored:
// auto-ban disabled, 24-hour ban duration and abuse window, and thresholds of
// 3 (malicious paths), 10 (admin login failures) and 30 (user login failures).
func DefaultBanSettings() BanSettings {
	return BanSettings{
		AutoBanEnabled:             false,
		AutoBanDurationHours:       24,
		MaliciousPathThreshold:     3,
		AdminLoginFailureThreshold: 10,
		UserLoginFailureThreshold:  30,
		AbuseWindowHours:           24,
	}
}

// Settings loads the stored settings. Fields missing from the stored JSON and
// non-positive numeric values fall back to the defaults; when nothing is
// stored the defaults are returned unchanged.
func (s *BanService) Settings() (BanSettings, error) {
	settings := DefaultBanSettings()
	err := s.db.View(func(tx *bbolt.Tx) error {
		data := tx.Bucket(banSettingsBucket).Get(banSettingsKey)
		if data == nil {
			return nil
		}
		if err := json.Unmarshal(data, &settings); err != nil {
			return err
		}
		settings = withBanSettingDefaults(settings)
		return nil
	})
	if err != nil {
		return BanSettings{}, err
	}
	return settings, nil
}

// UpdateSettings persists settings after replacing every non-positive numeric
// value with its default.
func (s *BanService) UpdateSettings(settings BanSettings) error {
	settings = withBanSettingDefaults(settings)
	// NOTE(review): withBanSettingDefaults has already replaced every
	// non-positive value, so this guard cannot currently fail. It is kept as
	// a safety net; moving it above the defaulting would start rejecting
	// zero values that callers may rely on being defaulted.
	if settings.AutoBanDurationHours <= 0 ||
		settings.MaliciousPathThreshold <= 0 ||
		settings.AdminLoginFailureThreshold <= 0 ||
		settings.UserLoginFailureThreshold <= 0 ||
		settings.AbuseWindowHours <= 0 {
		return fmt.Errorf("ban settings must be positive")
	}

	data, err := json.Marshal(settings)
	if err != nil {
		return err
	}
	return s.db.Update(func(tx *bbolt.Tx) error {
		return tx.Bucket(banSettingsBucket).Put(banSettingsKey, data)
	})
}

// withBanSettingDefaults returns settings with each non-positive numeric
// field replaced by the corresponding value from DefaultBanSettings.
// AutoBanEnabled is left untouched.
func withBanSettingDefaults(settings BanSettings) BanSettings {
	defaults := DefaultBanSettings()
	if settings.AutoBanDurationHours <= 0 {
		settings.AutoBanDurationHours = defaults.AutoBanDurationHours
	}
	if settings.MaliciousPathThreshold <= 0 {
		settings.MaliciousPathThreshold = defaults.MaliciousPathThreshold
	}
	if settings.AdminLoginFailureThreshold <= 0 {
		settings.AdminLoginFailureThreshold = defaults.AdminLoginFailureThreshold
	}
	if settings.UserLoginFailureThreshold <= 0 {
		settings.UserLoginFailureThreshold = defaults.UserLoginFailureThreshold
	}
	if settings.AbuseWindowHours <= 0 {
		settings.AbuseWindowHours = defaults.AbuseWindowHours
	}
	return settings
}

// CreateManualBan stores a ban with source BanSourceManual, using the current
// UTC time as creation time. See createBan for the validation rules.
func (s *BanService) CreateManualBan(target, reason, createdBy string, expiresAt time.Time) (BanRecord, error) {
	return s.createBan(target, reason, BanSourceManual, createdBy, expiresAt, time.Now().UTC())
}

// createBan validates and stores a new ban. target must be a valid IP or
// CIDR, reason must be non-empty after trimming, and expiresAt must be after
// now. On any failure, including a failed write, the zero BanRecord is
// returned together with the error.
func (s *BanService) createBan(target, reason, source, createdBy string, expiresAt, now time.Time) (BanRecord, error) {
	normalized, err := NormalizeBanTarget(target)
	if err != nil {
		return BanRecord{}, err
	}
	reason = strings.TrimSpace(reason)
	if reason == "" {
		return BanRecord{}, fmt.Errorf("ban reason is required")
	}
	if !expiresAt.After(now) {
		return BanRecord{}, fmt.Errorf("ban expiration must be in the future")
	}

	record := BanRecord{
		ID:         randomID(12),
		Target:     strings.TrimSpace(target),
		Normalized: normalized,
		Reason:     reason,
		Source:     source,
		CreatedBy:  createdBy,
		CreatedAt:  now,
		UpdatedAt:  now,
		ExpiresAt:  expiresAt.UTC(),
	}
	data, err := json.Marshal(record)
	if err != nil {
		return BanRecord{}, err
	}
	err = s.db.Update(func(tx *bbolt.Tx) error {
		return tx.Bucket(bansBucket).Put([]byte(record.ID), data)
	})
	if err != nil {
		return BanRecord{}, err
	}
	return record, nil
}

// NormalizeBanTarget trims target and returns its canonical form. A target
// containing "/" is parsed as CIDR and returned as the network's string form;
// anything else is parsed as a single IP address. Empty or unparsable targets
// yield an error.
func NormalizeBanTarget(target string) (string, error) {
	target = strings.TrimSpace(target)
	if target == "" {
		return "", fmt.Errorf("ban target is required")
	}
	if strings.Contains(target, "/") {
		_, network, err := net.ParseCIDR(target)
		if err != nil {
			return "", fmt.Errorf("invalid CIDR target")
		}
		return network.String(), nil
	}
	ip := net.ParseIP(target)
	if ip == nil {
		return "", fmt.Errorf("invalid IP target")
	}
	return ip.String(), nil
}

// ListBans returns every stored ban, newest first by CreatedAt. On a read or
// decode error it returns nil and the error.
func (s *BanService) ListBans() ([]BanRecord, error) {
	// Kept non-nil so an empty result encodes as [] if marshalled to JSON.
	records := []BanRecord{}
	err := s.db.View(func(tx *bbolt.Tx) error {
		return tx.Bucket(bansBucket).ForEach(func(_, value []byte) error {
			var record BanRecord
			if err := json.Unmarshal(value, &record); err != nil {
				return err
			}
			records = append(records, record)
			return nil
		})
	})
	if err != nil {
		return nil, err
	}
	sort.Slice(records, func(i, j int) bool {
		return records[i].CreatedAt.After(records[j].CreatedAt)
	})
	return records, nil
}

// Unban marks the ban stored under id (trimmed) as unbanned at now, converted
// to UTC. It returns ErrBanNotFound when no such ban exists.
func (s *BanService) Unban(id string, now time.Time) error {
	id = strings.TrimSpace(id)
	now = now.UTC()
	return s.db.Update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(bansBucket)
		data := bucket.Get([]byte(id))
		if data == nil {
			return ErrBanNotFound
		}
		var record BanRecord
		if err := json.Unmarshal(data, &record); err != nil {
			return err
		}
		record.UnbannedAt = &now
		record.UpdatedAt = now
		next, err := json.Marshal(record)
		if err != nil {
			return err
		}
		return bucket.Put([]byte(id), next)
	})
}

// Match reports whether ip is covered by an active ban at now. An unparsable
// ip never matches and is not an error. On a hit, the stored record's
// LastMatchedAt and UpdatedAt are updated on a best-effort basis and the
// returned ban carries the same timestamps.
func (s *BanService) Match(ip string, now time.Time) (MatchedBan, bool, error) {
	parsed := net.ParseIP(strings.TrimSpace(ip))
	if parsed == nil {
		return MatchedBan{}, false, nil
	}
	now = now.UTC()

	var matched BanRecord
	var matchedKey []byte
	// Read-only scan first: the common case (no match) only takes a concurrent
	// read transaction, instead of grabbing the single bbolt write lock on every
	// request that flows through the ban middleware.
	err := s.db.View(func(tx *bbolt.Tx) error {
		return tx.Bucket(bansBucket).ForEach(func(key, value []byte) error {
			if matched.ID != "" {
				// A ban was already found; skip decoding the remaining records.
				return nil
			}
			var record BanRecord
			if err := json.Unmarshal(value, &record); err != nil {
				return err
			}
			if !record.Active(now) || !banTargetMatches(record.Normalized, parsed) {
				return nil
			}
			matched = record
			// key bytes are only valid within the txn, so copy them.
			matchedKey = append([]byte(nil), key...)
			return nil
		})
	})
	if err != nil || matched.ID == "" {
		return MatchedBan{Ban: matched, IP: ip}, false, err
	}

	// On a hit, record the match time in a short write transaction.
	matched.LastMatchedAt = &now
	matched.UpdatedAt = now
	// Best effort: failing to persist the match timestamp must not turn a
	// positive match into an error, so the result is deliberately ignored.
	_ = s.db.Update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(bansBucket)
		data := bucket.Get(matchedKey)
		if data == nil {
			// The ban was removed between the two transactions.
			return nil
		}
		var record BanRecord
		if err := json.Unmarshal(data, &record); err != nil {
			return nil
		}
		record.LastMatchedAt = &now
		record.UpdatedAt = now
		next, err := json.Marshal(record)
		if err != nil {
			return nil
		}
		return bucket.Put(matchedKey, next)
	})
	return MatchedBan{Ban: matched, IP: ip}, true, nil
}

// Active reports whether the ban has not been lifted and has not yet expired
// at now.
func (r BanRecord) Active(now time.Time) bool {
	return r.UnbannedAt == nil && r.ExpiresAt.After(now.UTC())
}

// Status returns "unbanned", "expired" or "active" for the ban at now. A
// lifted ban is reported as "unbanned" even if it has also expired.
func (r BanRecord) Status(now time.Time) string {
	switch {
	case r.UnbannedAt != nil:
		return "unbanned"
	case !r.ExpiresAt.After(now.UTC()):
		return "expired"
	default:
		return "active"
	}
}

// banTargetMatches reports whether ip falls under target, where target is
// either a CIDR (contains "/") or a single IP. An unparsable target never
// matches.
func banTargetMatches(target string, ip net.IP) bool {
	if strings.Contains(target, "/") {
		_, network, err := net.ParseCIDR(target)
		if err != nil {
			return false
		}
		return network.Contains(ip)
	}
	targetIP := net.ParseIP(target)
	return targetIP != nil && targetIP.Equal(ip)
}

// ListRules returns every stored rule sorted case-insensitively by pattern.
// The seed marker key is skipped. On a read or decode error it returns nil
// and the error.
func (s *BanService) ListRules() ([]BanRule, error) {
	// Kept non-nil so an empty result encodes as [] if marshalled to JSON.
	rules := []BanRule{}
	err := s.db.View(func(tx *bbolt.Tx) error {
		return tx.Bucket(banRulesBucket).ForEach(func(key, value []byte) error {
			if string(key) == string(defaultBanRulesSeed) {
				return nil
			}
			var rule BanRule
			if err := json.Unmarshal(value, &rule); err != nil {
				return err
			}
			rules = append(rules, rule)
			return nil
		})
	})
	if err != nil {
		return nil, err
	}
	sort.Slice(rules, func(i, j int) bool {
		return strings.ToLower(rules[i].Pattern) < strings.ToLower(rules[j].Pattern)
	})
	return rules, nil
}

// SaveRules replaces all stored rules with patterns in a single transaction.
// Patterns are trimmed; empty ones and case-insensitive duplicates are
// dropped. Every saved rule gets a fresh ID, is enabled, and is stamped with
// now (UTC). The seed marker key is preserved.
func (s *BanService) SaveRules(patterns []string, now time.Time) error {
	now = now.UTC()
	return s.db.Update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(banRulesBucket)

		// Collect the keys first and delete afterwards, rather than deleting
		// from inside the ForEach callback.
		deleteKeys := [][]byte{}
		err := bucket.ForEach(func(key, _ []byte) error {
			if string(key) == string(defaultBanRulesSeed) {
				return nil
			}
			deleteKeys = append(deleteKeys, append([]byte(nil), key...))
			return nil
		})
		if err != nil {
			return err
		}
		for _, key := range deleteKeys {
			if err := bucket.Delete(key); err != nil {
				return err
			}
		}

		seen := map[string]bool{}
		for _, pattern := range patterns {
			pattern = strings.TrimSpace(pattern)
			folded := strings.ToLower(pattern)
			if pattern == "" || seen[folded] {
				continue
			}
			seen[folded] = true
			rule := BanRule{ID: randomID(10), Pattern: pattern, Enabled: true, CreatedAt: now, UpdatedAt: now}
			data, err := json.Marshal(rule)
			if err != nil {
				return err
			}
			if err := bucket.Put([]byte(rule.ID), data); err != nil {
				return err
			}
		}
		return nil
	})
}

// DeleteRule removes the rule stored under id (trimmed). The seed marker
// shares the bucket with the rules but is not a rule: deleting it would make
// NewBanService re-seed the default rules on its next run, so a request for
// that key is treated as a no-op.
func (s *BanService) DeleteRule(id string) error {
	key := []byte(strings.TrimSpace(id))
	if string(key) == string(defaultBanRulesSeed) {
		return nil
	}
	return s.db.Update(func(tx *bbolt.Tx) error {
		return tx.Bucket(banRulesBucket).Delete(key)
	})
}

// MaliciousPattern returns the pattern of the first enabled rule (in ListRules
// order) that occurs in path, compared case-insensitively as a substring. It
// returns "" when no rule matches or when path is exempt (see
// shouldSkipMaliciousPath).
func (s *BanService) MaliciousPattern(path string) (string, error) {
	if shouldSkipMaliciousPath(path) {
		return "", nil
	}
	rules, err := s.ListRules()
	if err != nil {
		return "", err
	}
	lowerPath := strings.ToLower(path)
	for _, rule := range rules {
		if rule.Enabled && strings.Contains(lowerPath, strings.ToLower(rule.Pattern)) {
			return rule.Pattern, nil
		}
	}
	return "", nil
}

// shouldSkipMaliciousPath reports whether path is exempt from rule matching:
// exactly "/health" or anything under "/static/".
func shouldSkipMaliciousPath(path string) bool {
	return path == "/health" || strings.HasPrefix(path, "/static/")
}

// RecordAbuse counts one abuse event of kind for ip and creates an automatic
// ban once the count within the configured abuse window reaches threshold.
//
// When auto-ban is disabled nothing is recorded and Enabled is false. A
// non-positive threshold records nothing either. The counter restarts when
// the stored event's FirstSeen is older than the abuse window. When the
// threshold is reached and ip is already covered by an active ban, that ban is
// returned instead of creating a new one.
func (s *BanService) RecordAbuse(ip, kind, detail string, threshold int, now time.Time) (AbuseResult, error) {
	settings, err := s.Settings()
	if err != nil {
		return AbuseResult{}, err
	}
	if !settings.AutoBanEnabled {
		return AbuseResult{Enabled: false}, nil
	}
	if threshold <= 0 {
		return AbuseResult{Enabled: true}, nil
	}

	now = now.UTC()
	window := time.Duration(settings.AbuseWindowHours) * time.Hour
	key := abuseKey(ip, kind)

	var event AbuseEvent
	var triggered bool
	err = s.db.Update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(abuseEventsBucket)
		data := bucket.Get([]byte(key))
		if data != nil {
			if err := json.Unmarshal(data, &event); err != nil {
				return err
			}
		}
		if data == nil || now.Sub(event.FirstSeen) > window {
			// No event yet, or the previous window has elapsed: start over.
			event = AbuseEvent{Key: key, IP: ip, Kind: kind, FirstSeen: now}
		}
		event.Count++
		event.LastSeen = now
		event.Detail = detail
		next, err := json.Marshal(event)
		if err != nil {
			return err
		}
		if err := bucket.Put([]byte(key), next); err != nil {
			return err
		}
		triggered = event.Count >= threshold
		return nil
	})
	if err != nil || !triggered {
		return AbuseResult{Event: event, Triggered: false, Enabled: true}, err
	}

	matched, ok, err := s.Match(ip, now)
	if err != nil {
		return AbuseResult{}, err
	}
	if ok {
		return AbuseResult{Event: event, Ban: matched.Ban, Triggered: true, Enabled: true}, nil
	}

	reason := fmt.Sprintf("%s threshold reached: %s", strings.ReplaceAll(kind, "_", " "), detail)
	expiresAt := now.Add(time.Duration(settings.AutoBanDurationHours) * time.Hour)
	ban, err := s.createBan(ip, reason, BanSourceAuto, "", expiresAt, now)
	if err != nil {
		return AbuseResult{}, err
	}
	return AbuseResult{Event: event, Ban: ban, Triggered: true, Enabled: true}, nil
}

// CleanupAbuseEvents deletes abuse events whose LastSeen is older than the
// configured abuse window relative to now, as well as events that can no
// longer be decoded. It returns the number of deleted events, or 0 and the
// error when the transaction fails (in which case nothing was deleted).
func (s *BanService) CleanupAbuseEvents(now time.Time) (int, error) {
	settings, err := s.Settings()
	if err != nil {
		return 0, err
	}
	cutoff := now.UTC().Add(-time.Duration(settings.AbuseWindowHours) * time.Hour)

	cleaned := 0
	err = s.db.Update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(abuseEventsBucket)

		deleteKeys := [][]byte{}
		err := bucket.ForEach(func(key, value []byte) error {
			var event AbuseEvent
			if err := json.Unmarshal(value, &event); err != nil {
				// Undecodable entries are dropped rather than failing the cleanup.
				deleteKeys = append(deleteKeys, append([]byte(nil), key...))
				return nil
			}
			if event.LastSeen.Before(cutoff) {
				deleteKeys = append(deleteKeys, append([]byte(nil), key...))
			}
			return nil
		})
		if err != nil {
			return err
		}
		for _, key := range deleteKeys {
			if err := bucket.Delete(key); err != nil {
				return err
			}
			cleaned++
		}
		return nil
	})
	if err != nil {
		// The transaction was rolled back, so no event was actually removed.
		return 0, err
	}
	return cleaned, nil
}

// abuseKey builds the storage key of an abuse event: kind, a colon, and the
// trimmed ip.
func abuseKey(ip, kind string) string {
	return kind + ":" + strings.TrimSpace(ip)
}