253 lines
7.3 KiB
Go
253 lines
7.3 KiB
Go
|
|
package httpserver_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/httpserver"
|
||
|
|
brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
|
||
|
|
)
|
||
|
|
|
||
|
|
// fakePinger is a test double for httpserver.Pinger used by the /healthz
// tests. Every Ping reports the stored err; a nil err means "store healthy".
type fakePinger struct {
	err error // returned verbatim from each Ping call
}

// Ping ignores its context and returns the configured error.
func (p *fakePinger) Ping(_ context.Context) error {
	return p.err
}
|
||
|
|
|
||
|
|
func TestHealth_OK(t *testing.T) {
|
||
|
|
s := &httpserver.Server{Log: brokerlog.Discard(), Store: &fakePinger{}}
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
s.Handler().ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||
|
|
}
|
||
|
|
var resp map[string]string
|
||
|
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||
|
|
t.Fatalf("body not JSON: %v", err)
|
||
|
|
}
|
||
|
|
for _, k := range []string{"status", "version", "git_revision", "build_date", "store"} {
|
||
|
|
if resp[k] == "" {
|
||
|
|
t.Errorf("healthz response missing field %q: %v", k, resp)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if resp["status"] != "ok" {
|
||
|
|
t.Errorf("status = %q, want ok", resp["status"])
|
||
|
|
}
|
||
|
|
if resp["store"] != "ok" {
|
||
|
|
t.Errorf("store = %q, want ok", resp["store"])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHealth_DegradedOnStoreFailure(t *testing.T) {
|
||
|
|
s := &httpserver.Server{Log: brokerlog.Discard(), Store: &fakePinger{err: errors.New("boom")}}
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/healthz", nil))
|
||
|
|
|
||
|
|
if w.Code != http.StatusServiceUnavailable {
|
||
|
|
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||
|
|
}
|
||
|
|
if !strings.Contains(w.Body.String(), "degraded") {
|
||
|
|
t.Errorf("body should mark status as degraded: %s", w.Body.String())
|
||
|
|
}
|
||
|
|
if !strings.Contains(w.Body.String(), "boom") {
|
||
|
|
t.Errorf("body should include underlying error: %s", w.Body.String())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHealth_NoStoreConfigured(t *testing.T) {
|
||
|
|
s := &httpserver.Server{Log: brokerlog.Discard()}
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/healthz", nil))
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("status = %d, want 200 when store is unconfigured", w.Code)
|
||
|
|
}
|
||
|
|
if !strings.Contains(w.Body.String(), "not configured") {
|
||
|
|
t.Errorf("expected 'not configured' marker: %s", w.Body.String())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_WrongMethodIsRejected(t *testing.T) {
|
||
|
|
// Go 1.22+ mux: POST /healthz should not dispatch to the GET handler.
|
||
|
|
s := &httpserver.Server{Log: brokerlog.Discard()}
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/healthz", nil))
|
||
|
|
if w.Code != http.StatusMethodNotAllowed {
|
||
|
|
t.Errorf("POST /healthz should return 405, got %d", w.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_ExtraHandlerReceivesOtherPaths(t *testing.T) {
|
||
|
|
extra := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
_, _ = io.WriteString(w, "extra:"+r.URL.Path)
|
||
|
|
})
|
||
|
|
s := &httpserver.Server{Log: brokerlog.Discard(), ExtraHandler: extra}
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
s.Handler().ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/other", nil))
|
||
|
|
if !strings.Contains(w.Body.String(), "extra:/other") {
|
||
|
|
t.Errorf("extra handler not invoked, got %q", w.Body.String())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRun_ShutdownOnContextCancel(t *testing.T) {
|
||
|
|
addr := freeAddr(t)
|
||
|
|
s := &httpserver.Server{Addr: addr, Log: brokerlog.Discard(), Store: &fakePinger{}}
|
||
|
|
|
||
|
|
ctx, cancel := context.WithCancel(t.Context())
|
||
|
|
runErr := make(chan error, 1)
|
||
|
|
go func() { runErr <- s.Run(ctx) }()
|
||
|
|
|
||
|
|
// Wait for the listener to be ready.
|
||
|
|
waitReady(t, addr, 2*time.Second)
|
||
|
|
|
||
|
|
// Sanity: /healthz works while running.
|
||
|
|
resp, err := http.Get("http://" + addr + "/healthz")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("pre-shutdown GET: %v", err)
|
||
|
|
}
|
||
|
|
_ = resp.Body.Close()
|
||
|
|
|
||
|
|
// Cancel the context to simulate SIGTERM.
|
||
|
|
start := time.Now()
|
||
|
|
cancel()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-runErr:
|
||
|
|
if err != nil {
|
||
|
|
t.Errorf("Run returned error: %v", err)
|
||
|
|
}
|
||
|
|
if elapsed := time.Since(start); elapsed > 2*time.Second {
|
||
|
|
t.Errorf("Run took %s to shut down, want < 2s", elapsed)
|
||
|
|
}
|
||
|
|
case <-time.After(3 * time.Second):
|
||
|
|
t.Fatal("Run did not return within 3s of cancel")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRun_ShutdownTimeout_ForciblyClosesSlowRequests(t *testing.T) {
|
||
|
|
addr := freeAddr(t)
|
||
|
|
|
||
|
|
// An extra handler that blocks until the test releases it — simulates
|
||
|
|
// a slow in-flight request that outlives the shutdown deadline.
|
||
|
|
release := make(chan struct{})
|
||
|
|
startedServing := make(chan struct{})
|
||
|
|
var once sync.Once
|
||
|
|
extra := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
once.Do(func() { close(startedServing) })
|
||
|
|
select {
|
||
|
|
case <-release:
|
||
|
|
case <-r.Context().Done():
|
||
|
|
}
|
||
|
|
})
|
||
|
|
|
||
|
|
s := &httpserver.Server{
|
||
|
|
Addr: addr,
|
||
|
|
Log: brokerlog.Discard(),
|
||
|
|
ExtraHandler: extra,
|
||
|
|
ShutdownTimeout: 200 * time.Millisecond,
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx, cancel := context.WithCancel(t.Context())
|
||
|
|
runErr := make(chan error, 1)
|
||
|
|
go func() { runErr <- s.Run(ctx) }()
|
||
|
|
waitReady(t, addr, 2*time.Second)
|
||
|
|
|
||
|
|
// Fire a request that will hang in the handler.
|
||
|
|
reqDone := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
resp, err := http.Get("http://" + addr + "/slow")
|
||
|
|
if resp != nil {
|
||
|
|
_ = resp.Body.Close()
|
||
|
|
}
|
||
|
|
reqDone <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
<-startedServing // handler is blocking
|
||
|
|
cancel() // shutdown fires; handler won't return voluntarily
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-runErr:
|
||
|
|
// Shutdown should complete (with an error reporting deadline breach)
|
||
|
|
// within about the timeout + small wiggle.
|
||
|
|
if err == nil {
|
||
|
|
t.Logf("Run returned nil — http.Server closed the conn via ctx.Done() cascade; acceptable")
|
||
|
|
} else if !strings.Contains(err.Error(), "shutdown") {
|
||
|
|
t.Errorf("unexpected Run error: %v", err)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("Run did not return within 2s of cancel")
|
||
|
|
}
|
||
|
|
|
||
|
|
// The hanging request's conn should have been forcibly closed.
|
||
|
|
select {
|
||
|
|
case err := <-reqDone:
|
||
|
|
if err == nil {
|
||
|
|
t.Error("slow request should have been terminated by shutdown")
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("slow request did not terminate after shutdown")
|
||
|
|
}
|
||
|
|
|
||
|
|
// Drain the handler so no goroutine leaks past the test.
|
||
|
|
close(release)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRun_MissingFieldsErr(t *testing.T) {
|
||
|
|
cases := []struct {
|
||
|
|
name string
|
||
|
|
server *httpserver.Server
|
||
|
|
want string
|
||
|
|
}{
|
||
|
|
{"no_log", &httpserver.Server{Addr: ":0"}, "Log"},
|
||
|
|
{"no_addr", &httpserver.Server{Log: brokerlog.Discard()}, "Addr"},
|
||
|
|
}
|
||
|
|
for _, tc := range cases {
|
||
|
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
|
err := tc.server.Run(t.Context())
|
||
|
|
if err == nil || !strings.Contains(err.Error(), tc.want) {
|
||
|
|
t.Errorf("want error containing %q, got %v", tc.want, err)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// freeAddr asks the kernel for an unused TCP port on 127.0.0.1 and returns
// it as "host:port". The probe listener is closed before the caller binds the
// address, so another process could grab the port in between. That race is
// tolerated for loopback tests.
func freeAddr(t *testing.T) string {
	t.Helper()
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer func() { _ = probe.Close() }()
	return probe.Addr().String()
}
|
||
|
|
|
||
|
|
// waitReady repeatedly dials addr over TCP until a dial succeeds. Each attempt
// has a 50ms timeout, with a 10ms pause after each failed attempt. The test is
// failed if no attempt succeeds before `within` has elapsed.
func waitReady(t *testing.T, addr string, within time.Duration) {
	t.Helper()
	for deadline := time.Now().Add(within); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		conn, err := net.DialTimeout("tcp", addr, 50*time.Millisecond)
		if err != nil {
			continue
		}
		_ = conn.Close()
		return
	}
	t.Fatalf("server not reachable at %s within %s", addr, within)
}
|
||
|
|
|