package http

import (
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"tercul/internal/platform/config"
	"tercul/internal/platform/log"
)

// RateLimiter is a per-client token bucket limiter.
//
// Each client has a bucket holding up to capacity tokens that refills
// continuously at rate tokens per second. On each request the bucket is
// refilled for the elapsed time, and the request is allowed only if at least
// one token is available, which it then consumes.
//
// A bucket that has refilled to capacity behaves exactly like the full bucket
// a brand-new client receives, so Allow periodically deletes such buckets.
// This keeps memory proportional to recently active clients rather than to
// every client ever seen, without changing any allow/deny decision.
type RateLimiter struct {
	tokens     map[string]float64   // tokens per client
	lastRefill map[string]time.Time // last refill time per client
	rate       float64              // tokens per second
	capacity   float64              // maximum tokens
	mu         sync.Mutex           // guards all fields of the limiter
	lastSweep  time.Time            // when idle buckets were last pruned
}

// NewRateLimiter creates a RateLimiter using cfg.RateLimit as the refill rate
// (tokens per second) and cfg.RateLimitBurst as the bucket capacity.
// Non-positive values, or a nil cfg, fall back to a rate of 10 and a capacity
// of 100.
func NewRateLimiter(cfg *config.Config) *RateLimiter {
	rate := 10.0      // default rate: 10 requests per second
	capacity := 100.0 // default capacity: 100 tokens
	if cfg != nil {
		if cfg.RateLimit > 0 {
			rate = float64(cfg.RateLimit)
		}
		if cfg.RateLimitBurst > 0 {
			capacity = float64(cfg.RateLimitBurst)
		}
	}
	return &RateLimiter{
		tokens:     make(map[string]float64),
		lastRefill: make(map[string]time.Time),
		rate:       rate,
		capacity:   capacity,
		lastSweep:  time.Now(),
	}
}

// Allow reports whether a request from the client identified by clientIP may
// proceed, consuming one token from that client's bucket if so. A client seen
// for the first time starts with a full bucket. Safe for concurrent use.
func (rl *RateLimiter) Allow(clientIP string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	rl.sweepIdle(now)

	tokens, exists := rl.tokens[clientIP]
	if !exists {
		// Initialize bucket for new client.
		tokens = rl.capacity
		rl.lastRefill[clientIP] = now
	}

	// Refill tokens based on elapsed time since the last refill.
	if refill := now.Sub(rl.lastRefill[clientIP]).Seconds() * rl.rate; refill > 0 {
		tokens = minF(rl.capacity, tokens+refill)
		rl.lastRefill[clientIP] = now
	}

	allowed := tokens >= 1
	if allowed {
		tokens--
	}
	rl.tokens[clientIP] = tokens
	return allowed
}

// sweepIdle deletes every bucket that has refilled to capacity. It runs at
// most once per minute so the O(clients) scan is amortized across requests.
// Deleting a full bucket cannot change any later decision: the next request
// from that client recreates an identical full bucket. The caller must hold
// rl.mu.
func (rl *RateLimiter) sweepIdle(now time.Time) {
	const idleSweepInterval = time.Minute
	if now.Sub(rl.lastSweep) < idleSweepInterval {
		return
	}
	rl.lastSweep = now
	// Deleting map entries while ranging over the map is permitted in Go.
	for client, last := range rl.lastRefill {
		if rl.tokens[client]+now.Sub(last).Seconds()*rl.rate >= rl.capacity {
			delete(rl.tokens, client)
			delete(rl.lastRefill, client)
		}
	}
}

// minF returns the minimum of two float64s.
func minF(a, b float64) float64 {
	if a < b {
		return a
	}
	return b
}

// RateLimitMiddleware returns middleware that answers 429 Too Many Requests
// once a client exceeds the rate configured in cfg.
//
// Clients are identified by the X-Client-ID header when it is present, and
// otherwise by the host part of r.RemoteAddr. The port is dropped so that a
// client cannot obtain a fresh bucket simply by opening a new connection.
//
// NOTE(review): X-Client-ID is supplied by the caller, so any client can evade
// the limit by sending a different value on each request. It is kept here
// because tests rely on it; consider gating it behind configuration.
func RateLimitMiddleware(cfg *config.Config) func(http.Handler) http.Handler {
	rateLimiter := NewRateLimiter(cfg)
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Use X-Client-ID header for client identification in tests.
			clientID := r.Header.Get("X-Client-ID")
			if clientID == "" {
				clientID = r.RemoteAddr
				// RemoteAddr is normally "host:port". If it cannot be split
				// (unusual listener or test setup), fall back to the raw
				// value rather than failing the request.
				if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
					clientID = host
				}
			}

			// Check if request is allowed.
			if !rateLimiter.Allow(clientID) {
				log.FromContext(r.Context()).
					With("clientID", clientID).
					Warn("Rate limit exceeded")
				w.WriteHeader(http.StatusTooManyRequests)
				if _, err := w.Write([]byte("Rate limit exceeded. Please try again later.")); err != nil {
					// The status line has already been sent, so the body
					// cannot be retried; log the error for observability.
					log.FromContext(r.Context()).Error(err, fmt.Sprintf("Failed to write rate limit response body for clientID %s", clientID))
				}
				return
			}

			// Continue to the next handler.
			next.ServeHTTP(w, r)
		})
	}
}