tercul-backend/internal/platform/http/rate_limiter.go
google-labs-jules[bot] 6d40b4c686 feat: add security middleware, graphql apq, and improved linting
- Add RateLimit, RequestValidation, and CORS middleware.
- Configure middleware chain in API server.
- Implement Redis cache for GraphQL Automatic Persisted Queries.
- Add .golangci.yml and fix linting issues (shadowing, timeouts).
2025-11-30 21:17:43 +00:00

111 lines
3.0 KiB
Go

// Package http provides HTTP middleware and utilities.
package http
import (
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"tercul/internal/platform/config"
	"tercul/internal/platform/log"
)
// RateLimiter is a per-client token-bucket rate limiter.
//
// Each client key owns a bucket holding at most capacity tokens. Before every
// check the bucket is refilled at rate tokens per second for the time elapsed
// since its last refill. A request is admitted only if at least one whole token
// is available, and admitting it consumes that token. A client seen for the
// first time starts with a full bucket, so up to capacity requests may pass
// back-to-back.
//
// The maps are guarded by mu. Construct instances with NewRateLimiter: the zero
// value has nil maps, and Allow writes to them.
//
// NOTE(review): buckets are never removed, so tokens and lastRefill gain one
// entry per distinct client key for the lifetime of the limiter. A bucket that
// has refilled to capacity behaves exactly like a missing one, so such entries
// could be evicted without changing limiting behavior.
type RateLimiter struct {
	tokens     map[string]float64   // current token count per client key
	lastRefill map[string]time.Time // time of the last refill per client key
	rate       float64              // refill rate in tokens per second
	capacity   float64              // maximum tokens a bucket can hold (burst size)
	mu         sync.Mutex           // guards tokens and lastRefill
}
// NewRateLimiter builds a RateLimiter from cfg.
//
// cfg.RateLimit is the refill rate in tokens per second, and cfg.RateLimitBurst
// is the bucket capacity. A non-positive value for either falls back to a
// default: 10 tokens per second and a capacity of 100 tokens.
func NewRateLimiter(cfg *config.Config) *RateLimiter {
	perSecond, burst := cfg.RateLimit, cfg.RateLimitBurst
	if perSecond <= 0 {
		perSecond = 10
	}
	if burst <= 0 {
		burst = 100
	}

	limiter := &RateLimiter{
		rate:     float64(perSecond),
		capacity: float64(burst),
	}
	// Per-client state is created lazily by Allow; start with empty maps.
	limiter.tokens = make(map[string]float64)
	limiter.lastRefill = make(map[string]time.Time)
	return limiter
}
// Allow reports whether a request identified by clientIP may proceed. When it
// may, Allow consumes one token from that client's bucket.
//
// A client seen for the first time starts with a full bucket. On every call the
// bucket is first topped up by (seconds since last refill) * rate, capped at
// capacity. The request is admitted only if at least one token remains after
// that. The whole check runs while holding rl.mu.
func (rl *RateLimiter) Allow(clientIP string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()

	available, seen := rl.tokens[clientIP]
	last := rl.lastRefill[clientIP]
	if !seen {
		// First request from this client: start with a full bucket.
		available, last = rl.capacity, now
		rl.lastRefill[clientIP] = now
	}

	if gained := now.Sub(last).Seconds() * rl.rate; gained > 0 {
		available = minF(rl.capacity, available+gained)
		rl.lastRefill[clientIP] = now
	}

	if available < 1 {
		rl.tokens[clientIP] = available
		return false
	}
	rl.tokens[clientIP] = available - 1
	return true
}
// minF returns the smaller of a and b. Whenever a < b is false, b is returned.
// That includes ties and the case where either operand is NaN.
func minF(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
// RateLimitMiddleware returns middleware that enforces a per-client token-bucket
// limit configured from cfg (see NewRateLimiter for how cfg values and defaults
// are applied). A request over the limit is answered with 429 Too Many Requests
// and a short plain-text message, and is not passed to next.
//
// One limiter is created per call to RateLimitMiddleware. Every handler wrapped
// by the returned function therefore shares the same set of buckets.
func RateLimitMiddleware(cfg *config.Config) func(http.Handler) http.Handler {
	rateLimiter := NewRateLimiter(cfg)
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			clientID := rateLimitClientID(r)

			if !rateLimiter.Allow(clientID) {
				log.FromContext(r.Context()).
					With("clientID", clientID).
					Warn("Rate limit exceeded")
				w.WriteHeader(http.StatusTooManyRequests)
				if _, err := w.Write([]byte("Rate limit exceeded. Please try again later.")); err != nil {
					// The status has already been written. We can't send the body,
					// so log the failure for observability.
					log.FromContext(r.Context()).Error(err, fmt.Sprintf("Failed to write rate limit response body for clientID %s", clientID))
				}
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// rateLimitClientID returns the key that selects r's rate-limit bucket.
//
// A non-empty X-Client-ID header takes precedence; it is used to identify
// clients in tests. Otherwise the host part of r.RemoteAddr is used.
// RemoteAddr is normally "host:port", and keying on the full string would give
// every new connection (new ephemeral port) its own fresh bucket. If RemoteAddr
// cannot be split into host and port, it is used verbatim.
//
// NOTE(review): X-Client-ID is chosen by the caller, so a client can send a new
// value on every request to bypass the limit. Consider honoring it only in
// test or trusted environments.
func rateLimitClientID(r *http.Request) string {
	if id := r.Header.Get("X-Client-ID"); id != "" {
		return id
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}