89 lines
1.8 KiB
Go
89 lines
1.8 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
	"net"
	"net/http"
	"sync"
	"time"
)
|
||
|
|
|
||
|
|
// RateLimiter implements a simple per-IP token bucket rate limiter for
|
||
|
|
// protecting auth endpoints from brute-force attacks.
|
||
|
|
type RateLimiter struct {
|
||
|
|
mu sync.Mutex
|
||
|
|
visitors map[string]*bucket
|
||
|
|
rate int
|
||
|
|
window time.Duration
|
||
|
|
}
|
||
|
|
|
||
|
|
// bucket is the per-client accounting record kept by RateLimiter.
type bucket struct {
	// tokens is the number of requests still permitted in the current window.
	tokens int
	// lastReset is the instant at which the current window began.
	lastReset time.Time
}
|
||
|
|
|
||
|
|
// NewRateLimiter creates a rate limiter that allows `rate` requests per minute per IP.
|
||
|
|
func NewRateLimiter(rate int) *RateLimiter {
|
||
|
|
return &RateLimiter{
|
||
|
|
visitors: make(map[string]*bucket),
|
||
|
|
rate: rate,
|
||
|
|
window: time.Minute,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Limit wraps a handler with rate limiting based on the real client IP.
|
||
|
|
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
ip := RealIPFromContext(r.Context())
|
||
|
|
if ip == "" {
|
||
|
|
ip = r.RemoteAddr
|
||
|
|
}
|
||
|
|
|
||
|
|
if !rl.allow(ip) {
|
||
|
|
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rl *RateLimiter) allow(ip string) bool {
|
||
|
|
rl.mu.Lock()
|
||
|
|
defer rl.mu.Unlock()
|
||
|
|
|
||
|
|
now := time.Now()
|
||
|
|
|
||
|
|
b, ok := rl.visitors[ip]
|
||
|
|
if !ok {
|
||
|
|
rl.visitors[ip] = &bucket{tokens: rl.rate - 1, lastReset: now}
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
|
||
|
|
// Reset tokens if the window has passed.
|
||
|
|
if now.Sub(b.lastReset) >= rl.window {
|
||
|
|
b.tokens = rl.rate - 1
|
||
|
|
b.lastReset = now
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
|
||
|
|
if b.tokens <= 0 {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
b.tokens--
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
|
||
|
|
// Cleanup removes stale entries. Call periodically from a goroutine.
|
||
|
|
func (rl *RateLimiter) Cleanup() {
|
||
|
|
rl.mu.Lock()
|
||
|
|
defer rl.mu.Unlock()
|
||
|
|
|
||
|
|
now := time.Now()
|
||
|
|
for ip, b := range rl.visitors {
|
||
|
|
if now.Sub(b.lastReset) >= 2*rl.window {
|
||
|
|
delete(rl.visitors, ip)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|