76 lines
1.8 KiB
Go
76 lines
1.8 KiB
Go
|
|
package services
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
)
|
||
|
|
|
||
|
|
// clientIPContextKey is the unexported context key under which WithClientIP
// stores the resolved client IP. A private struct type cannot be constructed
// by other packages, so the key cannot collide with their context values.
type clientIPContextKey struct{}

// WithClientIP returns a copy of r whose context carries ip, so it can later be
// read back with ClientIPFromContext. r itself is left unchanged.
func WithClientIP(r *http.Request, ip string) *http.Request {
	ctx := context.WithValue(r.Context(), clientIPContextKey{}, ip)
	return r.WithContext(ctx)
}

// ClientIPFromContext returns the client IP previously attached to r's context
// by WithClientIP. The boolean is false (and the string empty) when no IP was
// stored or when the stored IP is the empty string.
func ClientIPFromContext(r *http.Request) (string, bool) {
	stored := r.Context().Value(clientIPContextKey{})
	ip, isString := stored.(string)
	if !isString || ip == "" {
		return "", false
	}
	return ip, true
}
|
||
|
|
|
||
|
|
// ClientIP resolves the effective client IP. When trustedProxies is empty,
|
||
|
|
// forwarded headers are trusted for easy reverse-proxy/container defaults.
|
||
|
|
func ClientIP(remoteAddr, forwardedFor, realIP string, trustedProxies []string) string {
|
||
|
|
remoteIP := remoteIPOnly(remoteAddr)
|
||
|
|
if len(trustedProxies) == 0 || remoteTrusted(remoteIP, trustedProxies) {
|
||
|
|
if ip := firstForwardedIP(forwardedFor); ip != "" {
|
||
|
|
return ip
|
||
|
|
}
|
||
|
|
if ip := strings.TrimSpace(realIP); ip != "" {
|
||
|
|
return ip
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return remoteIP
|
||
|
|
}
|
||
|
|
|
||
|
|
// remoteIPOnly extracts the host part of a peer address such as
// http.Request.RemoteAddr. It accepts "host:port", "[v6]:port", a bare IPv4 or
// IPv6 literal, or a bracketed "[v6]" literal. Surrounding whitespace and IPv6
// brackets are removed. The result is not validated as an IP address.
func remoteIPOnly(remoteAddr string) string {
	host := strings.TrimSpace(remoteAddr)
	// Split the trimmed value so stray whitespace cannot end up in the host.
	// The error is deliberately ignored: an address without a port (or a bare
	// IPv6 literal) is not splittable and is used as-is.
	if splitHost, _, err := net.SplitHostPort(host); err == nil {
		host = splitHost
	}
	return strings.Trim(host, "[]")
}
|
||
|
|
|
||
|
|
// firstForwardedIP returns the leftmost non-blank entry of a comma-separated
// X-Forwarded-For value, with surrounding whitespace and then any IPv6
// brackets removed. It returns "" when every entry is blank. The entry is not
// validated as an IP address.
func firstForwardedIP(forwardedFor string) string {
	remaining := forwardedFor
	for {
		comma := strings.IndexByte(remaining, ',')
		segment := remaining
		if comma >= 0 {
			segment = remaining[:comma]
		}
		if entry := strings.TrimSpace(segment); entry != "" {
			return strings.Trim(entry, "[]")
		}
		if comma < 0 {
			return ""
		}
		remaining = remaining[comma+1:]
	}
}
|
||
|
|
|
||
|
|
// remoteTrusted reports whether remoteIP matches an entry in trustedProxies.
// Entries containing "/" are treated as CIDR ranges and all others as single
// IP addresses. Surrounding whitespace is ignored, and blank or unparsable
// entries never match. A remoteIP that is not a valid IP is never trusted.
func remoteTrusted(remoteIP string, trustedProxies []string) bool {
	addr := net.ParseIP(remoteIP)
	if addr == nil {
		return false
	}
	for _, raw := range trustedProxies {
		entry := strings.TrimSpace(raw)
		switch {
		case entry == "":
			// Blank entries are skipped.
		case strings.Contains(entry, "/"):
			if _, network, err := net.ParseCIDR(entry); err == nil && network.Contains(addr) {
				return true
			}
		default:
			if candidate := net.ParseIP(entry); candidate != nil && candidate.Equal(addr) {
				return true
			}
		}
	}
	return false
}
|