// SPDX-License-Identifier: AGPL-3.0-or-later package middleware import ( "net" "net/http" "net/http/httptest" "testing" ) func TestRealIPFromTrustedProxy(t *testing.T) { _, tailscale, _ := net.ParseCIDR("100.64.0.0/10") trusted := []*net.IPNet{tailscale} handler := RealIP(trusted)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := RealIPFromContext(r.Context()) w.Write([]byte(ip)) })) // Request from trusted proxy with X-Forwarded-For. req := httptest.NewRequest("GET", "/", nil) req.RemoteAddr = "100.64.1.1:12345" req.Header.Set("X-Forwarded-For", "203.0.113.50, 100.64.1.1") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Body.String() != "203.0.113.50" { t.Errorf("real IP = %q, want %q", rr.Body.String(), "203.0.113.50") } } func TestRealIPFromUntrustedProxy(t *testing.T) { _, localhost, _ := net.ParseCIDR("127.0.0.1/32") trusted := []*net.IPNet{localhost} handler := RealIP(trusted)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := RealIPFromContext(r.Context()) w.Write([]byte(ip)) })) // Request from untrusted IP — X-Forwarded-For should be ignored. req := httptest.NewRequest("GET", "/", nil) req.RemoteAddr = "192.168.1.100:12345" req.Header.Set("X-Forwarded-For", "spoofed-ip") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Body.String() != "192.168.1.100" { t.Errorf("real IP = %q, want %q (should ignore XFF from untrusted)", rr.Body.String(), "192.168.1.100") } } func TestBasePathStripping(t *testing.T) { handler := BasePath("/faves")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.Path)) })) tests := []struct { path string want string }{ {"/faves/", "/"}, {"/faves/login", "/login"}, {"/faves/u/test", "/u/test"}, {"/other", "/other"}, // No prefix match — passed through unchanged. } for _, tt := range tests { req := httptest.NewRequest("GET", tt.path, nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Body.String() != tt.want { t.Errorf("BasePath(%q) = %q, want %q", tt.path, rr.Body.String(), tt.want) } } } func TestBasePathEmpty(t *testing.T) { // Empty base path should be a no-op. inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.Path)) }) handler := BasePath("")(inner) req := httptest.NewRequest("GET", "/login", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Body.String() != "/login" { t.Errorf("empty base path: got %q, want /login", rr.Body.String()) } } func TestRateLimiter(t *testing.T) { rl := NewRateLimiter(3) handler := rl.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) for i := 0; i < 3; i++ { req := httptest.NewRequest("POST", "/login", nil) // RealIP middleware hasn't run, so RealIPFromContext returns "". // The rate limiter falls back to RemoteAddr. req.RemoteAddr = "192.168.1.1:1234" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("request %d: got %d, want 200", i+1, rr.Code) } } // 4th request should be rate-limited. req := httptest.NewRequest("POST", "/login", nil) req.RemoteAddr = "192.168.1.1:1234" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusTooManyRequests { t.Errorf("rate-limited request: got %d, want 429", rr.Code) } // Different IP should not be rate-limited. req = httptest.NewRequest("POST", "/login", nil) req.RemoteAddr = "10.0.0.1:1234" rr = httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("different IP: got %d, want 200", rr.Code) } } func TestRecovery(t *testing.T) { handler := Recovery(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic("test panic") })) req := httptest.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusInternalServerError { t.Errorf("panic recovery: got %d, want 500", rr.Code) } } func TestSecurityHeaders(t *testing.T) { handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) headers := map[string]string{ "X-Content-Type-Options": "nosniff", "X-Frame-Options": "DENY", "Referrer-Policy": "strict-origin-when-cross-origin", } for key, want := range headers { got := rr.Header().Get(key) if got != want { t.Errorf("%s = %q, want %q", key, got, want) } } csp := rr.Header().Get("Content-Security-Policy") if csp == "" { t.Error("Content-Security-Policy header missing") } } func TestRequireLoginRedirects(t *testing.T) { handler := RequireLogin("/faves")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // No user in context — should redirect. req := httptest.NewRequest("GET", "/settings", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusSeeOther { t.Errorf("no user: got %d, want 303", rr.Code) } loc := rr.Header().Get("Location") if loc != "/faves/login" { t.Errorf("redirect location = %q, want /faves/login", loc) } } func TestMustResetPasswordGuard(t *testing.T) { handler := MustResetPasswordGuard("")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // No user — should pass through. req := httptest.NewRequest("GET", "/faves", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("no user: got %d, want 200", rr.Code) } // Static paths should always pass through even with must-reset user. req = httptest.NewRequest("GET", "/static/css/style.css", nil) rr = httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("static path: got %d, want 200", rr.Code) } }