// Package security provides IP-based access control: whitelists, temporary
// bans (optionally persisted in a Badger store), and sliding-window rate
// limiting for failed logins, scan attempts, and uploads.
package security

import (
	"encoding/binary"
	"fmt"
	"net"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/dgraph-io/badger/v4"
)

// Config carries the security settings. Guard.Reload reads only IPWhitelist
// and AdminIPWhitelist; the remaining fields are not consumed in this file.
type Config struct {
	IPWhitelist         string
	AdminIPWhitelist    string
	LoginWindowSeconds  int64
	LoginMaxAttempts    int
	BanSeconds          int64
	ScanWindowSeconds   int64
	ScanMaxAttempts     int
	UploadWindowSeconds int64
	UploadMaxRequests   int
	UploadMaxBytes      int64
}

// Guard tracks whitelists, bans, and per-IP event history. Create it with
// NewGuard; every method takes the internal mutex.
type Guard struct {
	// mu guards every field below.
	mu sync.Mutex

	// Per-IP event history used by the sliding-window checks.
	failedLogins map[string][]time.Time
	scanAttempts map[string][]time.Time
	uploadEvents map[string][]uploadEvent

	// bannedUntil maps an IP string to the UTC instant at which its ban ends.
	bannedUntil map[string]time.Time

	ipWhitelist    []ipMatcher
	adminWhitelist []ipMatcher

	// banDB is the optional persistent ban store; nil when persistence is off.
	banDB *badger.DB
}

// ipMatcher matches either one exact address or a CIDR network; exactly one
// of the two fields is set by ParseIPMatchers.
type ipMatcher struct {
	exact  net.IP
	prefix *net.IPNet
}

// uploadEvent records one accepted upload: when it happened and its size.
type uploadEvent struct {
	at    time.Time
	bytes int64
}

// BanEntry is one active ban as reported by Guard.BanList.
type BanEntry struct {
	IP    string    `json:"ip"`
	Until time.Time `json:"until"`
}

// banKeyPrefix namespaces ban records inside the Badger store.
const banKeyPrefix = "ban:"

// NewGuard returns a Guard with empty state and ban persistence disabled.
func NewGuard() *Guard {
	return &Guard{
		failedLogins:   map[string][]time.Time{},
		scanAttempts:   map[string][]time.Time{},
		uploadEvents:   map[string][]uploadEvent{},
		bannedUntil:    map[string]time.Time{},
		ipWhitelist:    []ipMatcher{},
		adminWhitelist: []ipMatcher{},
	}
}

// Close closes the persistent ban store, if one is open, and returns its
// close error. In-memory state is kept. Calling Close again is a no-op.
func (g *Guard) Close() error {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.banDB == nil {
		return nil
	}
	err := g.banDB.Close()
	g.banDB = nil
	return err
}

// EnableBanPersistence opens (or creates) a Badger store at path and loads
// the bans saved there. A blank path, or a store that is already open, makes
// the call a no-op. If the store cannot be opened, the existing directory is
// renamed aside with a ".corrupt.<timestamp>" suffix and the open is retried
// once on a fresh directory.
func (g *Guard) EnableBanPersistence(path string) error {
	if strings.TrimSpace(path) == "" {
		return nil
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.banDB != nil {
		return nil
	}
	opts := badger.DefaultOptions(path)
	opts.Logger = nil
	db, err := badger.Open(opts)
	if err != nil {
		// Corruption-safe fallback: quarantine badger files and start fresh.
		// The rename is best-effort: if it fails, the retry below fails too
		// and that error is what gets reported.
		_ = os.Rename(path, path+".corrupt."+time.Now().UTC().Format("20060102T150405"))
		db, err = badger.Open(opts)
	}
	if err != nil {
		return fmt.Errorf("open ban store %q: %w", path, err)
	}
	g.banDB = db
	if err := g.loadBansLocked(); err != nil {
		// The load error is the one worth reporting; a close failure on a
		// store we are abandoning adds nothing.
		_ = g.banDB.Close()
		g.banDB = nil
		return fmt.Errorf("load persisted bans: %w", err)
	}
	return nil
}

// Reload parses both whitelists from cfg and swaps them in together. On a
// parse error nothing is changed and the error names the offending list.
func (g *Guard) Reload(cfg Config) error {
	ipWhitelist, err := ParseIPMatchers(cfg.IPWhitelist, true)
	if err != nil {
		return fmt.Errorf("ip whitelist: %w", err)
	}
	adminWhitelist, err := ParseIPMatchers(cfg.AdminIPWhitelist, true)
	if err != nil {
		return fmt.Errorf("admin ip whitelist: %w", err)
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	g.ipWhitelist = ipWhitelist
	g.adminWhitelist = adminWhitelist
	return nil
}

// IsWhitelisted reports whether ip matches the general whitelist.
func (g *Guard) IsWhitelisted(ip string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	return matchIP(g.ipWhitelist, ip)
}

// IsAdminWhitelisted reports whether ip matches the admin whitelist or the
// general whitelist.
func (g *Guard) IsAdminWhitelisted(ip string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	return matchIP(g.adminWhitelist, ip) || matchIP(g.ipWhitelist, ip)
}

// IsBanned reports whether ip has an unexpired ban. An expired ban is removed
// from memory and from the persistent store as a side effect.
func (g *Guard) IsBanned(ip string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	until, ok := g.bannedUntil[ip]
	if !ok {
		return false
	}
	if time.Now().UTC().After(until) {
		delete(g.bannedUntil, ip)
		g.deleteBanLocked(ip)
		return false
	}
	return true
}

// Ban bans ip for the given number of seconds from now. It does nothing when
// ip is empty or seconds is not positive.
func (g *Guard) Ban(ip string, seconds int64) {
	if seconds <= 0 || ip == "" {
		return
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	until := time.Now().UTC().Add(time.Duration(seconds) * time.Second)
	g.bannedUntil[ip] = until
	g.saveBanLocked(ip, until)
}

// BanUntil bans ip until the given instant (stored as UTC). It does nothing
// when ip is empty or until is the zero time.
func (g *Guard) BanUntil(ip string, until time.Time) {
	if ip == "" || until.IsZero() {
		return
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	until = until.UTC()
	g.bannedUntil[ip] = until
	g.saveBanLocked(ip, until)
}

// Unban removes any ban on ip from memory and from the persistent store.
func (g *Guard) Unban(ip string) {
	if ip == "" {
		return
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	delete(g.bannedUntil, ip)
	g.deleteBanLocked(ip)
}

// BanList returns the active bans ordered by expiry (soonest first), with
// ties broken by IP so the output is deterministic. Expired bans found along
// the way are purged.
func (g *Guard) BanList() []BanEntry {
	g.mu.Lock()
	defer g.mu.Unlock()
	now := time.Now().UTC()
	out := make([]BanEntry, 0, len(g.bannedUntil))
	for ip, until := range g.bannedUntil {
		if now.After(until) {
			// Deleting the current key while ranging over a map is allowed.
			delete(g.bannedUntil, ip)
			g.deleteBanLocked(ip)
			continue
		}
		out = append(out, BanEntry{IP: ip, Until: until})
	}
	sort.Slice(out, func(i, j int) bool {
		if !out[i].Until.Equal(out[j].Until) {
			return out[i].Until.Before(out[j].Until)
		}
		return out[i].IP < out[j].IP
	})
	return out
}

// RegisterFailedLogin records a failed login for ip and bans the address for
// banSeconds once maxAttempts failures fall inside the last windowSeconds.
// It returns whether a ban was applied and the attempt count in the window.
// An empty ip or non-positive maxAttempts/windowSeconds yields (false, 0).
func (g *Guard) RegisterFailedLogin(ip string, windowSeconds int64, maxAttempts int, banSeconds int64) (bool, int) {
	if ip == "" || maxAttempts <= 0 || windowSeconds <= 0 {
		return false, 0
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	return g.registerAttemptLocked(g.failedLogins, ip, windowSeconds, maxAttempts, banSeconds)
}

// RegisterScanAttempt records a scan attempt for ip and bans the address for
// banSeconds once maxAttempts attempts fall inside the last windowSeconds.
// It returns whether a ban was applied and the attempt count in the window.
// An empty ip or non-positive maxAttempts/windowSeconds yields (false, 0).
func (g *Guard) RegisterScanAttempt(ip string, windowSeconds int64, maxAttempts int, banSeconds int64) (bool, int) {
	if ip == "" || maxAttempts <= 0 || windowSeconds <= 0 {
		return false, 0
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	return g.registerAttemptLocked(g.scanAttempts, ip, windowSeconds, maxAttempts, banSeconds)
}

// registerAttemptLocked is the shared sliding-window counter behind
// RegisterFailedLogin and RegisterScanAttempt. It appends "now" to store[ip]
// after dropping entries older than the window, and applies a ban when the
// threshold is reached. The caller must hold g.mu and have validated inputs.
func (g *Guard) registerAttemptLocked(store map[string][]time.Time, ip string, windowSeconds int64, maxAttempts int, banSeconds int64) (bool, int) {
	now := time.Now().UTC()
	cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
	attempts := append(pruneTimes(store[ip], cutoff), now)
	store[ip] = attempts
	if len(attempts) < maxAttempts {
		return false, len(attempts)
	}
	until := now.Add(time.Duration(banSeconds) * time.Second)
	g.bannedUntil[ip] = until
	g.saveBanLocked(ip, until)
	return true, len(attempts)
}

// AllowUpload decides whether ip may upload size bytes now, given at most
// maxRequests uploads and (when maxBytes > 0) at most maxBytes bytes per
// windowSeconds. It returns the decision plus the request count and byte
// total the window would hold including this upload. Only accepted uploads
// are recorded. An empty ip or non-positive windowSeconds/maxRequests
// disables the check and yields (true, 0, 0).
func (g *Guard) AllowUpload(ip string, size int64, windowSeconds int64, maxRequests int, maxBytes int64) (bool, int, int64) {
	if ip == "" || windowSeconds <= 0 || maxRequests <= 0 {
		return true, 0, 0
	}
	// A negative size (e.g. an unknown content length reported as -1) must
	// not reduce the byte total in the window.
	if size < 0 {
		size = 0
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	now := time.Now().UTC()
	cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
	events := g.uploadEvents[ip]
	kept := make([]uploadEvent, 0, len(events)+1)
	totalBytes := int64(0)
	for _, event := range events {
		if event.at.After(cutoff) {
			kept = append(kept, event)
			totalBytes += event.bytes
		}
	}
	nextCount := len(kept) + 1
	nextBytes := totalBytes + size
	if nextCount > maxRequests {
		return false, nextCount, nextBytes
	}
	if maxBytes > 0 && nextBytes > maxBytes {
		return false, nextCount, nextBytes
	}
	kept = append(kept, uploadEvent{at: now, bytes: size})
	g.uploadEvents[ip] = kept
	return true, nextCount, nextBytes
}

// ParseIPMatchers parses a comma-separated list of IP addresses and, when
// allowCIDR is true, CIDR ranges. Blank entries are skipped. The result is
// never nil on success.
func ParseIPMatchers(raw string, allowCIDR bool) ([]ipMatcher, error) {
	entries := []ipMatcher{}
	for _, chunk := range strings.Split(raw, ",") {
		value := strings.TrimSpace(chunk)
		if value == "" {
			continue
		}
		if strings.Contains(value, "/") {
			if !allowCIDR {
				return nil, fmt.Errorf("%q must be a single IP address, not a CIDR", value)
			}
			_, network, err := net.ParseCIDR(value)
			if err != nil {
				return nil, fmt.Errorf("invalid CIDR %q", value)
			}
			entries = append(entries, ipMatcher{prefix: network})
			continue
		}
		parsed := net.ParseIP(value)
		if parsed == nil {
			return nil, fmt.Errorf("invalid IP %q", value)
		}
		entries = append(entries, ipMatcher{exact: parsed})
	}
	return entries, nil
}

// ParseCIDRList parses a comma-separated list of CIDR ranges. Blank entries
// are skipped. The result is never nil on success.
func ParseCIDRList(raw string) ([]net.IPNet, error) {
	entries := []net.IPNet{}
	for _, chunk := range strings.Split(raw, ",") {
		value := strings.TrimSpace(chunk)
		if value == "" {
			continue
		}
		_, network, err := net.ParseCIDR(value)
		if err != nil {
			return nil, fmt.Errorf("invalid CIDR %q", value)
		}
		entries = append(entries, *network)
	}
	return entries, nil
}

// pruneTimes returns a new slice holding the values strictly after cutoff.
func pruneTimes(values []time.Time, cutoff time.Time) []time.Time {
	kept := make([]time.Time, 0, len(values))
	for _, value := range values {
		if value.After(cutoff) {
			kept = append(kept, value)
		}
	}
	return kept
}

// matchIP reports whether value parses as an IP address that equals an exact
// rule or falls inside a CIDR rule. Unparseable input never matches.
func matchIP(rules []ipMatcher, value string) bool {
	ip := net.ParseIP(strings.TrimSpace(value))
	if ip == nil {
		return false
	}
	for _, rule := range rules {
		if rule.exact != nil && rule.exact.Equal(ip) {
			return true
		}
		if rule.prefix != nil && rule.prefix.Contains(ip) {
			return true
		}
	}
	return false
}

// banKey returns the store key under which the ban for ip is persisted.
func banKey(ip string) []byte {
	return []byte(banKeyPrefix + ip)
}

// saveBanLocked writes the ban to the persistent store as the expiry's Unix
// seconds (8 bytes, big-endian) with a TTL equal to the whole seconds left.
// A ban with less than one second left is removed from the store instead.
// It is a no-op without a store. The caller must hold g.mu.
func (g *Guard) saveBanLocked(ip string, until time.Time) {
	if g.banDB == nil || ip == "" || until.IsZero() {
		return
	}
	seconds := int64(time.Until(until).Seconds())
	if seconds <= 0 {
		g.deleteBanLocked(ip)
		return
	}
	value := make([]byte, 8)
	binary.BigEndian.PutUint64(value, uint64(until.Unix()))
	// Persistence is best-effort: the in-memory map stays authoritative for
	// this process, so a failed write only costs durability across restarts.
	_ = g.banDB.Update(func(txn *badger.Txn) error {
		entry := badger.NewEntry(banKey(ip), value).WithTTL(time.Duration(seconds) * time.Second)
		return txn.SetEntry(entry)
	})
}

// deleteBanLocked removes the ban for ip from the persistent store. It is a
// no-op without a store. The caller must hold g.mu.
func (g *Guard) deleteBanLocked(ip string) {
	if g.banDB == nil || ip == "" {
		return
	}
	// Best-effort: a record left behind is dropped as expired on the next
	// load once its stored expiry has passed.
	_ = g.banDB.Update(func(txn *badger.Txn) error {
		return txn.Delete(banKey(ip))
	})
}

// loadBansLocked reads every persisted ban into memory and deletes records
// that are malformed or already expired. Loaded bans are merged with bans
// already held in memory, keeping the later expiry, so bans recorded before
// persistence was enabled are not lost. The caller must hold g.mu.
func (g *Guard) loadBansLocked() error {
	if g.banDB == nil {
		return nil
	}
	now := time.Now().UTC()
	prefix := []byte(banKeyPrefix)
	loaded := map[string]time.Time{}
	var stale [][]byte
	err := g.banDB.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			item := it.Item()
			// The key bytes are only valid until the iterator advances, so
			// take a copy before keeping them.
			key := item.KeyCopy(nil)
			ip := strings.TrimPrefix(string(key), banKeyPrefix)
			err := item.Value(func(val []byte) error {
				if ip == "" || len(val) != 8 {
					stale = append(stale, key)
					return nil
				}
				until := time.Unix(int64(binary.BigEndian.Uint64(val)), 0).UTC()
				if now.After(until) {
					stale = append(stale, key)
					return nil
				}
				loaded[ip] = until
				return nil
			})
			if err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return err
	}
	if g.bannedUntil == nil {
		g.bannedUntil = loaded
	} else {
		for ip, until := range loaded {
			if current, ok := g.bannedUntil[ip]; !ok || until.After(current) {
				g.bannedUntil[ip] = until
			}
		}
	}
	if len(stale) == 0 {
		return nil
	}
	return g.banDB.Update(func(txn *badger.Txn) error {
		for _, key := range stale {
			if err := txn.Delete(key); err != nil {
				return err
			}
		}
		return nil
	})
}