feat(security): Implemented more security information
This commit is contained in:
@@ -1,10 +1,16 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@@ -26,8 +32,14 @@ type Guard struct {
|
||||
scanAttempts map[string][]time.Time
|
||||
uploadEvents map[string][]uploadEvent
|
||||
bannedUntil map[string]time.Time
|
||||
ipWhitelist map[string]bool
|
||||
adminWhitelist map[string]bool
|
||||
ipWhitelist []ipMatcher
|
||||
adminWhitelist []ipMatcher
|
||||
banDB *badger.DB
|
||||
}
|
||||
|
||||
type ipMatcher struct {
|
||||
exact net.IP
|
||||
prefix *net.IPNet
|
||||
}
|
||||
|
||||
type uploadEvent struct {
|
||||
@@ -40,34 +52,90 @@ type BanEntry struct {
|
||||
Until time.Time `json:"until"`
|
||||
}
|
||||
|
||||
const banKeyPrefix = "ban:"
|
||||
|
||||
func NewGuard() *Guard {
|
||||
return &Guard{
|
||||
failedLogins: map[string][]time.Time{},
|
||||
scanAttempts: map[string][]time.Time{},
|
||||
uploadEvents: map[string][]uploadEvent{},
|
||||
bannedUntil: map[string]time.Time{},
|
||||
ipWhitelist: map[string]bool{},
|
||||
adminWhitelist: map[string]bool{},
|
||||
ipWhitelist: []ipMatcher{},
|
||||
adminWhitelist: []ipMatcher{},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) Reload(cfg Config) {
|
||||
func (g *Guard) Close() error {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.ipWhitelist = parseList(cfg.IPWhitelist)
|
||||
g.adminWhitelist = parseList(cfg.AdminIPWhitelist)
|
||||
if g.banDB == nil {
|
||||
return nil
|
||||
}
|
||||
err := g.banDB.Close()
|
||||
g.banDB = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *Guard) EnableBanPersistence(path string) error {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if g.banDB != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := badger.DefaultOptions(path)
|
||||
opts.Logger = nil
|
||||
db, err := badger.Open(opts)
|
||||
if err != nil {
|
||||
// Corruption-safe fallback: quarantine badger files and start fresh.
|
||||
_ = os.Rename(path, path+".corrupt."+time.Now().UTC().Format("20060102T150405"))
|
||||
db, err = badger.Open(opts)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.banDB = db
|
||||
|
||||
if err := g.loadBansLocked(); err != nil {
|
||||
_ = g.banDB.Close()
|
||||
g.banDB = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Guard) Reload(cfg Config) error {
|
||||
ipWhitelist, err := ParseIPMatchers(cfg.IPWhitelist, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ip whitelist: %w", err)
|
||||
}
|
||||
adminWhitelist, err := ParseIPMatchers(cfg.AdminIPWhitelist, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("admin ip whitelist: %w", err)
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.ipWhitelist = ipWhitelist
|
||||
g.adminWhitelist = adminWhitelist
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Guard) IsWhitelisted(ip string) bool {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
return g.ipWhitelist[ip]
|
||||
return matchIP(g.ipWhitelist, ip)
|
||||
}
|
||||
|
||||
func (g *Guard) IsAdminWhitelisted(ip string) bool {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
return g.adminWhitelist[ip] || g.ipWhitelist[ip]
|
||||
return matchIP(g.adminWhitelist, ip) || matchIP(g.ipWhitelist, ip)
|
||||
}
|
||||
|
||||
func (g *Guard) IsBanned(ip string) bool {
|
||||
@@ -79,6 +147,7 @@ func (g *Guard) IsBanned(ip string) bool {
|
||||
}
|
||||
if time.Now().UTC().After(until) {
|
||||
delete(g.bannedUntil, ip)
|
||||
g.deleteBanLocked(ip)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -90,7 +159,9 @@ func (g *Guard) Ban(ip string, seconds int64) {
|
||||
}
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.bannedUntil[ip] = time.Now().UTC().Add(time.Duration(seconds) * time.Second)
|
||||
until := time.Now().UTC().Add(time.Duration(seconds) * time.Second)
|
||||
g.bannedUntil[ip] = until
|
||||
g.saveBanLocked(ip, until)
|
||||
}
|
||||
|
||||
func (g *Guard) BanUntil(ip string, until time.Time) {
|
||||
@@ -99,7 +170,9 @@ func (g *Guard) BanUntil(ip string, until time.Time) {
|
||||
}
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.bannedUntil[ip] = until.UTC()
|
||||
until = until.UTC()
|
||||
g.bannedUntil[ip] = until
|
||||
g.saveBanLocked(ip, until)
|
||||
}
|
||||
|
||||
func (g *Guard) Unban(ip string) {
|
||||
@@ -109,6 +182,7 @@ func (g *Guard) Unban(ip string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
delete(g.bannedUntil, ip)
|
||||
g.deleteBanLocked(ip)
|
||||
}
|
||||
|
||||
func (g *Guard) BanList() []BanEntry {
|
||||
@@ -119,6 +193,7 @@ func (g *Guard) BanList() []BanEntry {
|
||||
for ip, until := range g.bannedUntil {
|
||||
if now.After(until) {
|
||||
delete(g.bannedUntil, ip)
|
||||
g.deleteBanLocked(ip)
|
||||
continue
|
||||
}
|
||||
out = append(out, BanEntry{IP: ip, Until: until})
|
||||
@@ -141,7 +216,9 @@ func (g *Guard) RegisterFailedLogin(ip string, windowSeconds int64, maxAttempts
|
||||
attempts = append(attempts, now)
|
||||
g.failedLogins[ip] = attempts
|
||||
if len(attempts) >= maxAttempts {
|
||||
g.bannedUntil[ip] = now.Add(time.Duration(banSeconds) * time.Second)
|
||||
until := now.Add(time.Duration(banSeconds) * time.Second)
|
||||
g.bannedUntil[ip] = until
|
||||
g.saveBanLocked(ip, until)
|
||||
return true, len(attempts)
|
||||
}
|
||||
return false, len(attempts)
|
||||
@@ -159,7 +236,9 @@ func (g *Guard) RegisterScanAttempt(ip string, windowSeconds int64, maxAttempts
|
||||
attempts = append(attempts, now)
|
||||
g.scanAttempts[ip] = attempts
|
||||
if len(attempts) >= maxAttempts {
|
||||
g.bannedUntil[ip] = now.Add(time.Duration(banSeconds) * time.Second)
|
||||
until := now.Add(time.Duration(banSeconds) * time.Second)
|
||||
g.bannedUntil[ip] = until
|
||||
g.saveBanLocked(ip, until)
|
||||
return true, len(attempts)
|
||||
}
|
||||
return false, len(attempts)
|
||||
@@ -195,6 +274,49 @@ func (g *Guard) AllowUpload(ip string, size int64, windowSeconds int64, maxReque
|
||||
return true, nextCount, nextBytes
|
||||
}
|
||||
|
||||
func ParseIPMatchers(raw string, allowCIDR bool) ([]ipMatcher, error) {
|
||||
entries := []ipMatcher{}
|
||||
for _, chunk := range strings.Split(raw, ",") {
|
||||
value := strings.TrimSpace(chunk)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(value, "/") {
|
||||
if !allowCIDR {
|
||||
return nil, fmt.Errorf("%q must be a CIDR", value)
|
||||
}
|
||||
_, network, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %q", value)
|
||||
}
|
||||
entries = append(entries, ipMatcher{prefix: network})
|
||||
continue
|
||||
}
|
||||
parsed := net.ParseIP(value)
|
||||
if parsed == nil {
|
||||
return nil, fmt.Errorf("invalid IP %q", value)
|
||||
}
|
||||
entries = append(entries, ipMatcher{exact: parsed})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func ParseCIDRList(raw string) ([]net.IPNet, error) {
|
||||
entries := []net.IPNet{}
|
||||
for _, chunk := range strings.Split(raw, ",") {
|
||||
value := strings.TrimSpace(chunk)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, network, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %q", value)
|
||||
}
|
||||
entries = append(entries, *network)
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func pruneTimes(values []time.Time, cutoff time.Time) []time.Time {
|
||||
kept := make([]time.Time, 0, len(values))
|
||||
for _, value := range values {
|
||||
@@ -205,13 +327,100 @@ func pruneTimes(values []time.Time, cutoff time.Time) []time.Time {
|
||||
return kept
|
||||
}
|
||||
|
||||
func parseList(raw string) map[string]bool {
|
||||
out := map[string]bool{}
|
||||
for _, chunk := range strings.Split(raw, ",") {
|
||||
value := strings.TrimSpace(chunk)
|
||||
if value != "" {
|
||||
out[value] = true
|
||||
func matchIP(rules []ipMatcher, value string) bool {
|
||||
ip := net.ParseIP(strings.TrimSpace(value))
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if rule.exact != nil && rule.exact.Equal(ip) {
|
||||
return true
|
||||
}
|
||||
if rule.prefix != nil && rule.prefix.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return out
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *Guard) saveBanLocked(ip string, until time.Time) {
|
||||
if g.banDB == nil || ip == "" || until.IsZero() {
|
||||
return
|
||||
}
|
||||
seconds := int64(time.Until(until).Seconds())
|
||||
if seconds <= 0 {
|
||||
_ = g.banDB.Update(func(txn *badger.Txn) error {
|
||||
return txn.Delete([]byte(banKeyPrefix + ip))
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
value := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(value, uint64(until.Unix()))
|
||||
_ = g.banDB.Update(func(txn *badger.Txn) error {
|
||||
entry := badger.NewEntry([]byte(banKeyPrefix+ip), value).WithTTL(time.Duration(seconds) * time.Second)
|
||||
return txn.SetEntry(entry)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *Guard) deleteBanLocked(ip string) {
|
||||
if g.banDB == nil || ip == "" {
|
||||
return
|
||||
}
|
||||
_ = g.banDB.Update(func(txn *badger.Txn) error {
|
||||
return txn.Delete([]byte(banKeyPrefix + ip))
|
||||
})
|
||||
}
|
||||
|
||||
func (g *Guard) loadBansLocked() error {
|
||||
if g.banDB == nil {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
loaded := map[string]time.Time{}
|
||||
expired := [][]byte{}
|
||||
|
||||
err := g.banDB.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
for it.Seek([]byte(banKeyPrefix)); it.ValidForPrefix([]byte(banKeyPrefix)); it.Next() {
|
||||
item := it.Item()
|
||||
key := string(item.Key())
|
||||
ip := strings.TrimPrefix(key, banKeyPrefix)
|
||||
err := item.Value(func(val []byte) error {
|
||||
if len(val) != 8 {
|
||||
expired = append(expired, append([]byte(nil), item.Key()...))
|
||||
return nil
|
||||
}
|
||||
unix := int64(binary.BigEndian.Uint64(val))
|
||||
until := time.Unix(unix, 0).UTC()
|
||||
if now.After(until) {
|
||||
expired = append(expired, append([]byte(nil), item.Key()...))
|
||||
return nil
|
||||
}
|
||||
loaded[ip] = until
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.bannedUntil = loaded
|
||||
if len(expired) == 0 {
|
||||
return nil
|
||||
}
|
||||
return g.banDB.Update(func(txn *badger.Txn) error {
|
||||
for _, key := range expired {
|
||||
if err := txn.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user