package middleware_test

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"tercul/config"
	"tercul/middleware"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
)

// Token-bucket parameters shared by the tests in this file: testRate requests
// per second with a burst capacity of testBurst. Kept untyped so they convert
// to whatever numeric types NewRateLimiter and config.Cfg use.
const (
	testRate  = 2
	testBurst = 3
)

// refillWait is how long the tests sleep so the bucket can refill. It is a
// little over one second so that at least testRate tokens have accrued even
// with scheduler jitter, without approaching the time needed for one more.
const refillWait = 1100 * time.Millisecond

// RateLimiterSuite groups the RateLimiter and RateLimitMiddleware tests.
type RateLimiterSuite struct {
	suite.Suite
}

// TestRateLimiter checks that a single client gets its full burst, is then
// rejected, and after refillWait is allowed exactly testRate more requests.
func (s *RateLimiterSuite) TestRateLimiter() {
	limiter := middleware.NewRateLimiter(testRate, testBurst)
	const id = "test-client"

	for i := 0; i < testBurst; i++ {
		s.True(limiter.Allow(id), "Request %d should be allowed (burst)", i+1)
	}
	s.False(limiter.Allow(id), "Request %d should not be allowed (burst exceeded)", testBurst+1)

	time.Sleep(refillWait)

	for i := 0; i < testRate; i++ {
		s.True(limiter.Allow(id), "Request %d after wait should be allowed (rate)", i+1)
	}
	s.False(limiter.Allow(id), "Request %d after wait should not be allowed (rate exceeded)", testRate+1)
}

// TestRateLimiterMultipleClients checks that each client ID has its own
// bucket: exhausting client1's burst does not consume client2's.
func (s *RateLimiterSuite) TestRateLimiterMultipleClients() {
	limiter := middleware.NewRateLimiter(testRate, testBurst)
	clients := []string{"client1", "client2"}

	// Drain each client's burst in turn (client1 fully, then client2).
	for _, id := range clients {
		for i := 0; i < testBurst; i++ {
			s.True(limiter.Allow(id), "Request %d for %s should be allowed (burst)", i+1, id)
		}
	}
	// The next request for every client must now be rejected.
	for _, id := range clients {
		s.False(limiter.Allow(id), "Request %d for %s should not be allowed (burst exceeded)", testBurst+1, id)
	}
}

// TestRateLimiterMiddleware drives RateLimitMiddleware over real HTTP and
// checks that it answers 200 while tokens remain and 429 once they run out.
func (s *RateLimiterSuite) TestRateLimiterMiddleware() {
	// RateLimitMiddleware takes no limit arguments, so this test configures
	// it through the global config. Restore the previous values afterwards
	// so this test does not leak settings into others.
	prevRate, prevBurst := config.Cfg.RateLimit, config.Cfg.RateLimitBurst
	defer func() {
		config.Cfg.RateLimit, config.Cfg.RateLimitBurst = prevRate, prevBurst
	}()
	config.Cfg.RateLimit = testRate
	config.Cfg.RateLimitBurst = testBurst

	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	// Named handler (not "middleware") so the package name is not shadowed.
	handler := middleware.RateLimitMiddleware(ok)

	server := httptest.NewServer(handler)
	defer server.Close()
	client := server.Client()

	// Every request carries the same X-Client-ID so they share one bucket.
	const clientID = "test-client-id"

	for i := 0; i < testBurst; i++ {
		s.Equal(http.StatusOK, s.getStatus(client, server.URL, clientID),
			"Request %d should be allowed (burst)", i+1)
	}
	s.Equal(http.StatusTooManyRequests, s.getStatus(client, server.URL, clientID),
		"Request %d should not be allowed (burst exceeded)", testBurst+1)

	time.Sleep(refillWait)

	for i := 0; i < testRate; i++ {
		s.Equal(http.StatusOK, s.getStatus(client, server.URL, clientID),
			"Request %d after wait should be allowed (rate)", i+1)
	}
	s.Equal(http.StatusTooManyRequests, s.getStatus(client, server.URL, clientID),
		"Request %d after wait should not be allowed (rate exceeded)", testRate+1)
}

// getStatus sends a GET to target with the X-Client-ID header set to
// clientID and returns the response status code. It stops the test if the
// request cannot be built or sent, and always closes the response body.
func (s *RateLimiterSuite) getStatus(client *http.Client, target, clientID string) int {
	req, err := http.NewRequest(http.MethodGet, target, nil)
	s.Require().NoError(err)
	req.Header.Set("X-Client-ID", clientID)

	resp, err := client.Do(req)
	s.Require().NoError(err)
	defer resp.Body.Close()
	return resp.StatusCode
}

// TestRateLimiterSuite runs the RateLimiterSuite.
func TestRateLimiterSuite(t *testing.T) {
	suite.Run(t, new(RateLimiterSuite))
}

// TestNewRateLimiter checks that NewRateLimiter returns a non-nil limiter for
// valid, zero, and negative arguments. Only non-nil is asserted here; whether
// defaults are substituted for non-positive arguments is not observed.
func TestNewRateLimiter(t *testing.T) {
	limiter := middleware.NewRateLimiter(10, 20)
	assert.NotNil(t, limiter, "NewRateLimiter should return a non-nil limiter")

	// Zero rate (expected to fall back to a default rate).
	limiter = middleware.NewRateLimiter(0, 20)
	assert.NotNil(t, limiter, "NewRateLimiter should return a non-nil limiter with default rate")

	// Zero capacity (expected to fall back to a default capacity).
	limiter = middleware.NewRateLimiter(10, 0)
	assert.NotNil(t, limiter, "NewRateLimiter should return a non-nil limiter with default capacity")

	// Negative rate (expected to fall back to a default rate).
	limiter = middleware.NewRateLimiter(-10, 20)
	assert.NotNil(t, limiter, "NewRateLimiter should return a non-nil limiter with default rate")

	// Negative capacity (expected to fall back to a default capacity).
	limiter = middleware.NewRateLimiter(10, -20)
	assert.NotNil(t, limiter, "NewRateLimiter should return a non-nil limiter with default capacity")
}