68 lines
1.7 KiB
Go
68 lines
1.7 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
)
|
||
|
|
|
||
|
|
// RealIP extracts the real client IP from X-Forwarded-For, but only if the
|
||
|
|
// direct connection comes from a trusted proxy. This is essential when Caddy
|
||
|
|
// runs on a different machine (e.g. connected via WireGuard/Tailscale).
|
||
|
|
func RealIP(trustedProxies []*net.IPNet) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
ip := extractRealIP(r, trustedProxies)
|
||
|
|
ctx := context.WithValue(r.Context(), realIPKey, ip)
|
||
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func extractRealIP(r *http.Request, trusted []*net.IPNet) string {
|
||
|
|
// Get the direct connection IP.
|
||
|
|
directIP, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||
|
|
if directIP == "" {
|
||
|
|
directIP = r.RemoteAddr
|
||
|
|
}
|
||
|
|
|
||
|
|
// Only trust X-Forwarded-For if the direct connection is from a trusted proxy.
|
||
|
|
if !isTrusted(directIP, trusted) {
|
||
|
|
return directIP
|
||
|
|
}
|
||
|
|
|
||
|
|
// Parse X-Forwarded-For: client, proxy1, proxy2
|
||
|
|
// The rightmost non-trusted IP is the real client.
|
||
|
|
xff := r.Header.Get("X-Forwarded-For")
|
||
|
|
if xff == "" {
|
||
|
|
return directIP
|
||
|
|
}
|
||
|
|
|
||
|
|
ips := strings.Split(xff, ",")
|
||
|
|
// Walk from right to left, finding the first non-trusted IP.
|
||
|
|
for i := len(ips) - 1; i >= 0; i-- {
|
||
|
|
ip := strings.TrimSpace(ips[i])
|
||
|
|
if !isTrusted(ip, trusted) {
|
||
|
|
return ip
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// All IPs in the chain are trusted; use the leftmost.
|
||
|
|
return strings.TrimSpace(ips[0])
|
||
|
|
}
|
||
|
|
|
||
|
|
// isTrusted reports whether ipStr parses as an IP address contained in at
// least one of nets. A string that net.ParseIP rejects (empty, host:port,
// hostnames, and so on) is never trusted, and an empty nets trusts nothing.
func isTrusted(ipStr string, nets []*net.IPNet) bool {
	if addr := net.ParseIP(ipStr); addr != nil {
		for idx := range nets {
			if nets[idx].Contains(addr) {
				return true
			}
		}
	}
	return false
}
|