package middleware

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"

	"bugulma/backend/internal/service"

	"github.com/gin-gonic/gin"
)

// inMemorySettingsRepo is a map-backed settings store used as the repository
// passed to service.NewSettingsService in these tests.
//
// It has no locking; the test below drives it from a single goroutine.
type inMemorySettingsRepo struct {
	data map[string]map[string]any
}

// newInMemorySettingsRepo returns an empty inMemorySettingsRepo.
func newInMemorySettingsRepo() *inMemorySettingsRepo {
	return &inMemorySettingsRepo{data: map[string]map[string]any{}}
}

// Get returns the value stored under key. A missing key yields a nil map and a
// nil error (no not-found error is reported).
func (r *inMemorySettingsRepo) Get(_ context.Context, key string) (map[string]any, error) {
	return r.data[key], nil
}

// Set stores value under key, replacing any previous value. The map is stored
// as given, not copied.
func (r *inMemorySettingsRepo) Set(_ context.Context, key string, value map[string]any) error {
	r.data[key] = value
	return nil
}

// serveMaintenanceRequest issues a GET for path against h and returns the
// recorded response. When clientIP is non-empty it is used for RemoteAddr and
// for the X-Real-IP / X-Forwarded-For headers, so the peer address and the
// forwarded headers agree on the client. When clientIP is empty the request
// keeps httptest's default RemoteAddr and carries no forwarding headers.
func serveMaintenanceRequest(h http.Handler, path, clientIP string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, path, nil)
	if clientIP != "" {
		req.RemoteAddr = net.JoinHostPort(clientIP, "12345")
		req.Header.Set("X-Real-IP", clientIP)
		req.Header.Set("X-Forwarded-For", clientIP)
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, req)
	return w
}

// TestMaintenanceMiddleware_BlocksAndAllows walks MaintenanceMiddleware through
// its states in order:
//   - disabled: requests pass and X-Maintenance is "false";
//   - enabled: a regular route gets 503, while a whitelisted prefix still passes;
//   - enabled with an allowed client IP: the regular route passes for that IP.
//
// Each step depends on the settings written by the previous one.
func TestMaintenanceMiddleware_BlocksAndAllows(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx := context.Background()

	repo := newInMemorySettingsRepo()
	svc := service.NewSettingsService(repo)

	r := gin.New()
	r.Use(MaintenanceMiddleware(svc, []string{"/api/v1/admin", "/health"}))
	r.GET("/api/v1/resource", func(c *gin.Context) {
		// Echo the resolved client IP so the allowed-IP step can confirm which
		// address gin attributed the request to.
		c.Header("X-Client-IP", c.ClientIP())
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})
	r.GET("/api/v1/admin/settings", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"admin": true})
	})

	// Nothing has been stored in the repo yet; the test expects maintenance to
	// start out disabled, so the request passes and the header reports false.
	w := serveMaintenanceRequest(r, "/api/v1/resource", "127.0.0.1")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 when disabled: got %d", w.Code)
	}
	if got := w.Header().Get("X-Maintenance"); got != "false" {
		t.Fatalf("expected X-Maintenance=false, got %q", got)
	}

	// Enable maintenance with no allowed IPs.
	if err := svc.SetMaintenance(ctx, &service.MaintenanceSetting{Enabled: true, Message: "maintenance in progress"}); err != nil {
		t.Fatalf("set maintenance: %v", err)
	}

	// A non-whitelisted path is blocked, even for a loopback client, because no
	// allowed IPs are configured yet.
	w = serveMaintenanceRequest(r, "/api/v1/resource", "127.0.0.1")
	if w.Code != http.StatusServiceUnavailable {
		t.Fatalf("expected 503 when enabled: got %d", w.Code)
	}
	if got := w.Header().Get("X-Maintenance"); got != "true" {
		t.Fatalf("expected X-Maintenance=true, got %q", got)
	}
	var body map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if body["maintenance"] != true {
		t.Fatalf("expected maintenance:true body, got %#v", body)
	}

	// A whitelisted path prefix is let through, and the header is still present.
	w = serveMaintenanceRequest(r, "/api/v1/admin/settings", "")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 for whitelisted path, got %d", w.Code)
	}
	if got := w.Header().Get("X-Maintenance"); got != "true" {
		t.Fatalf("expected X-Maintenance=true for whitelisted, got %q", got)
	}

	// Re-enable maintenance with loopback on the allow list.
	if err := svc.SetMaintenance(ctx, &service.MaintenanceSetting{Enabled: true, Message: "maintenance", AllowedIPs: []string{"127.0.0.1"}}); err != nil {
		t.Fatalf("set maintenance with allowed ips: %v", err)
	}

	// Confirm the service now reports the allowed IP before relying on it.
	m, err := svc.GetMaintenance(ctx)
	if err != nil {
		t.Fatalf("get maintenance: %v", err)
	}
	if len(m.AllowedIPs) == 0 || m.AllowedIPs[0] != "127.0.0.1" {
		t.Fatalf("expected allowed ip to be present, got %#v", m.AllowedIPs)
	}

	// A request from the allowed IP bypasses maintenance on a regular route.
	w = serveMaintenanceRequest(r, "/api/v1/resource", "127.0.0.1")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 for allowed ip, got %d", w.Code)
	}
	if got := w.Header().Get("X-Client-IP"); got != "127.0.0.1" {
		t.Fatalf("expected request to be attributed to 127.0.0.1, got %q", got)
	}
}