211 lines
5.9 KiB
Go
211 lines
5.9 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestRealIPFromTrustedProxy(t *testing.T) {
|
||
|
|
_, tailscale, _ := net.ParseCIDR("100.64.0.0/10")
|
||
|
|
trusted := []*net.IPNet{tailscale}
|
||
|
|
|
||
|
|
handler := RealIP(trusted)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
ip := RealIPFromContext(r.Context())
|
||
|
|
w.Write([]byte(ip))
|
||
|
|
}))
|
||
|
|
|
||
|
|
// Request from trusted proxy with X-Forwarded-For.
|
||
|
|
req := httptest.NewRequest("GET", "/", nil)
|
||
|
|
req.RemoteAddr = "100.64.1.1:12345"
|
||
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.50, 100.64.1.1")
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Body.String() != "203.0.113.50" {
|
||
|
|
t.Errorf("real IP = %q, want %q", rr.Body.String(), "203.0.113.50")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRealIPFromUntrustedProxy(t *testing.T) {
|
||
|
|
_, localhost, _ := net.ParseCIDR("127.0.0.1/32")
|
||
|
|
trusted := []*net.IPNet{localhost}
|
||
|
|
|
||
|
|
handler := RealIP(trusted)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
ip := RealIPFromContext(r.Context())
|
||
|
|
w.Write([]byte(ip))
|
||
|
|
}))
|
||
|
|
|
||
|
|
// Request from untrusted IP — X-Forwarded-For should be ignored.
|
||
|
|
req := httptest.NewRequest("GET", "/", nil)
|
||
|
|
req.RemoteAddr = "192.168.1.100:12345"
|
||
|
|
req.Header.Set("X-Forwarded-For", "spoofed-ip")
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Body.String() != "192.168.1.100" {
|
||
|
|
t.Errorf("real IP = %q, want %q (should ignore XFF from untrusted)", rr.Body.String(), "192.168.1.100")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestBasePathStripping(t *testing.T) {
|
||
|
|
handler := BasePath("/faves")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.Write([]byte(r.URL.Path))
|
||
|
|
}))
|
||
|
|
|
||
|
|
tests := []struct {
|
||
|
|
path string
|
||
|
|
want string
|
||
|
|
}{
|
||
|
|
{"/faves/", "/"},
|
||
|
|
{"/faves/login", "/login"},
|
||
|
|
{"/faves/u/test", "/u/test"},
|
||
|
|
{"/other", "/other"}, // No prefix match — passed through unchanged.
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
req := httptest.NewRequest("GET", tt.path, nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Body.String() != tt.want {
|
||
|
|
t.Errorf("BasePath(%q) = %q, want %q", tt.path, rr.Body.String(), tt.want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestBasePathEmpty(t *testing.T) {
|
||
|
|
// Empty base path should be a no-op.
|
||
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.Write([]byte(r.URL.Path))
|
||
|
|
})
|
||
|
|
handler := BasePath("")(inner)
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "/login", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Body.String() != "/login" {
|
||
|
|
t.Errorf("empty base path: got %q, want /login", rr.Body.String())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRateLimiter(t *testing.T) {
|
||
|
|
rl := NewRateLimiter(3)
|
||
|
|
handler := rl.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
for i := 0; i < 3; i++ {
|
||
|
|
req := httptest.NewRequest("POST", "/login", nil)
|
||
|
|
// RealIP middleware hasn't run, so RealIPFromContext returns "".
|
||
|
|
// The rate limiter falls back to RemoteAddr.
|
||
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("request %d: got %d, want 200", i+1, rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 4th request should be rate-limited.
|
||
|
|
req := httptest.NewRequest("POST", "/login", nil)
|
||
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Code != http.StatusTooManyRequests {
|
||
|
|
t.Errorf("rate-limited request: got %d, want 429", rr.Code)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Different IP should not be rate-limited.
|
||
|
|
req = httptest.NewRequest("POST", "/login", nil)
|
||
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
||
|
|
rr = httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("different IP: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRecovery(t *testing.T) {
|
||
|
|
handler := Recovery(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
panic("test panic")
|
||
|
|
}))
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "/", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusInternalServerError {
|
||
|
|
t.Errorf("panic recovery: got %d, want 500", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSecurityHeaders(t *testing.T) {
|
||
|
|
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "/", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
headers := map[string]string{
|
||
|
|
"X-Content-Type-Options": "nosniff",
|
||
|
|
"X-Frame-Options": "DENY",
|
||
|
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||
|
|
}
|
||
|
|
for key, want := range headers {
|
||
|
|
got := rr.Header().Get(key)
|
||
|
|
if got != want {
|
||
|
|
t.Errorf("%s = %q, want %q", key, got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
csp := rr.Header().Get("Content-Security-Policy")
|
||
|
|
if csp == "" {
|
||
|
|
t.Error("Content-Security-Policy header missing")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequireLoginRedirects(t *testing.T) {
|
||
|
|
handler := RequireLogin("/faves")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
// No user in context — should redirect.
|
||
|
|
req := httptest.NewRequest("GET", "/settings", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusSeeOther {
|
||
|
|
t.Errorf("no user: got %d, want 303", rr.Code)
|
||
|
|
}
|
||
|
|
loc := rr.Header().Get("Location")
|
||
|
|
if loc != "/faves/login" {
|
||
|
|
t.Errorf("redirect location = %q, want /faves/login", loc)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestMustResetPasswordGuard(t *testing.T) {
|
||
|
|
handler := MustResetPasswordGuard("")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
// No user — should pass through.
|
||
|
|
req := httptest.NewRequest("GET", "/faves", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("no user: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Static paths should always pass through even with must-reset user.
|
||
|
|
req = httptest.NewRequest("GET", "/static/css/style.css", nil)
|
||
|
|
rr = httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("static path: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
}
|