package httpserver_test import ( "context" "encoding/json" "errors" "io" "net" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "kode.naiv.no/olemd/forgejo-mcp-broker/internal/httpserver" brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log" ) // fakePinger implements httpserver.Pinger for /healthz tests. type fakePinger struct{ err error } func (f *fakePinger) Ping(context.Context) error { return f.err } func TestHealth_OK(t *testing.T) { s := &httpserver.Server{Log: brokerlog.Discard(), Store: &fakePinger{}} req := httptest.NewRequest(http.MethodGet, "/healthz", nil) w := httptest.NewRecorder() s.Handler().ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("status = %d, want %d", w.Code, http.StatusOK) } var resp map[string]string if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("body not JSON: %v", err) } for _, k := range []string{"status", "version", "git_revision", "build_date", "store"} { if resp[k] == "" { t.Errorf("healthz response missing field %q: %v", k, resp) } } if resp["status"] != "ok" { t.Errorf("status = %q, want ok", resp["status"]) } if resp["store"] != "ok" { t.Errorf("store = %q, want ok", resp["store"]) } } func TestHealth_DegradedOnStoreFailure(t *testing.T) { s := &httpserver.Server{Log: brokerlog.Discard(), Store: &fakePinger{err: errors.New("boom")}} w := httptest.NewRecorder() s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/healthz", nil)) if w.Code != http.StatusServiceUnavailable { t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) } if !strings.Contains(w.Body.String(), "degraded") { t.Errorf("body should mark status as degraded: %s", w.Body.String()) } if !strings.Contains(w.Body.String(), "boom") { t.Errorf("body should include underlying error: %s", w.Body.String()) } } func TestHealth_NoStoreConfigured(t *testing.T) { s := &httpserver.Server{Log: brokerlog.Discard()} w := httptest.NewRecorder() s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/healthz", nil)) if w.Code != http.StatusOK { t.Errorf("status = %d, want 200 when store is unconfigured", w.Code) } if !strings.Contains(w.Body.String(), "not configured") { t.Errorf("expected 'not configured' marker: %s", w.Body.String()) } } func TestHandler_WrongMethodIsRejected(t *testing.T) { // Go 1.22+ mux: POST /healthz should not dispatch to the GET handler. s := &httpserver.Server{Log: brokerlog.Discard()} w := httptest.NewRecorder() s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/healthz", nil)) if w.Code != http.StatusMethodNotAllowed { t.Errorf("POST /healthz should return 405, got %d", w.Code) } } func TestHandler_ExtraHandlerReceivesOtherPaths(t *testing.T) { extra := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "extra:"+r.URL.Path) }) s := &httpserver.Server{Log: brokerlog.Discard(), ExtraHandler: extra} w := httptest.NewRecorder() s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/other", nil)) if !strings.Contains(w.Body.String(), "extra:/other") { t.Errorf("extra handler not invoked, got %q", w.Body.String()) } } func TestRun_ShutdownOnContextCancel(t *testing.T) { addr := freeAddr(t) s := &httpserver.Server{Addr: addr, Log: brokerlog.Discard(), Store: &fakePinger{}} ctx, cancel := context.WithCancel(t.Context()) runErr := make(chan error, 1) go func() { runErr <- s.Run(ctx) }() // Wait for the listener to be ready. waitReady(t, addr, 2*time.Second) // Sanity: /healthz works while running. resp, err := http.Get("http://" + addr + "/healthz") if err != nil { t.Fatalf("pre-shutdown GET: %v", err) } _ = resp.Body.Close() // Cancel the context to simulate SIGTERM. start := time.Now() cancel() select { case err := <-runErr: if err != nil { t.Errorf("Run returned error: %v", err) } if elapsed := time.Since(start); elapsed > 2*time.Second { t.Errorf("Run took %s to shut down, want < 2s", elapsed) } case <-time.After(3 * time.Second): t.Fatal("Run did not return within 3s of cancel") } } func TestRun_ShutdownTimeout_ForciblyClosesSlowRequests(t *testing.T) { addr := freeAddr(t) // An extra handler that blocks until the test releases it — simulates // a slow in-flight request that outlives the shutdown deadline. release := make(chan struct{}) startedServing := make(chan struct{}) var once sync.Once extra := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { once.Do(func() { close(startedServing) }) select { case <-release: case <-r.Context().Done(): } }) s := &httpserver.Server{ Addr: addr, Log: brokerlog.Discard(), ExtraHandler: extra, ShutdownTimeout: 200 * time.Millisecond, } ctx, cancel := context.WithCancel(t.Context()) runErr := make(chan error, 1) go func() { runErr <- s.Run(ctx) }() waitReady(t, addr, 2*time.Second) // Fire a request that will hang in the handler. reqDone := make(chan error, 1) go func() { resp, err := http.Get("http://" + addr + "/slow") if resp != nil { _ = resp.Body.Close() } reqDone <- err }() <-startedServing // handler is blocking cancel() // shutdown fires; handler won't return voluntarily select { case err := <-runErr: // Shutdown should complete (with an error reporting deadline breach) // within about the timeout + small wiggle. if err == nil { t.Logf("Run returned nil — http.Server closed the conn via ctx.Done() cascade; acceptable") } else if !strings.Contains(err.Error(), "shutdown") { t.Errorf("unexpected Run error: %v", err) } case <-time.After(2 * time.Second): t.Fatal("Run did not return within 2s of cancel") } // The hanging request's conn should have been forcibly closed. select { case err := <-reqDone: if err == nil { t.Error("slow request should have been terminated by shutdown") } case <-time.After(2 * time.Second): t.Fatal("slow request did not terminate after shutdown") } // Drain the handler so no goroutine leaks past the test. close(release) } func TestRun_MissingFieldsErr(t *testing.T) { cases := []struct { name string server *httpserver.Server want string }{ {"no_log", &httpserver.Server{Addr: ":0"}, "Log"}, {"no_addr", &httpserver.Server{Log: brokerlog.Discard()}, "Addr"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := tc.server.Run(t.Context()) if err == nil || !strings.Contains(err.Error(), tc.want) { t.Errorf("want error containing %q, got %v", tc.want, err) } }) } } // freeAddr returns a loopback "host:port" with a port chosen by the kernel. // The listener is closed immediately, so there is a tiny race window before // the caller rebinds — acceptable for loopback test use. func freeAddr(t *testing.T) string { t.Helper() l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } addr := l.Addr().String() _ = l.Close() return addr } // waitReady polls the target address until a TCP dial succeeds or the // deadline expires. func waitReady(t *testing.T, addr string, within time.Duration) { t.Helper() deadline := time.Now().Add(within) for time.Now().Before(deadline) { c, err := net.DialTimeout("tcp", addr, 50*time.Millisecond) if err == nil { _ = c.Close() return } time.Sleep(10 * time.Millisecond) } t.Fatalf("server not reachable at %s within %s", addr, within) }