mirror of
https://github.com/SamyRai/turash.git
synced 2025-12-26 23:01:33 +00:00
115 lines
3.8 KiB
Go
115 lines
3.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"bugulma/backend/internal/service"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// inMemorySettingsRepo is a map-backed settings store. The test in this file
// passes it to service.NewSettingsService in place of a real repository.
type inMemorySettingsRepo struct {
	// data maps a settings key to the value most recently stored under it.
	data map[string]map[string]any
}

// newInMemorySettingsRepo returns an empty repository that is ready for use.
func newInMemorySettingsRepo() *inMemorySettingsRepo {
	return &inMemorySettingsRepo{data: make(map[string]map[string]any)}
}

// Get returns the value stored under key. A key that was never set yields a
// nil map together with a nil error. The context argument is ignored.
func (r *inMemorySettingsRepo) Get(_ context.Context, key string) (map[string]any, error) {
	return r.data[key], nil
}

// Set stores value under key, replacing any previous entry, and always
// returns nil. The map is kept by reference, not copied. The context argument
// is ignored.
func (r *inMemorySettingsRepo) Set(_ context.Context, key string, value map[string]any) error {
	if r.data == nil {
		// Allocate lazily so that a zero-value repository (one not built by
		// newInMemorySettingsRepo) does not panic on a nil-map write.
		r.data = make(map[string]map[string]any)
	}
	r.data[key] = value
	return nil
}
|
|
|
|
func TestMaintenanceMiddleware_BlocksAndAllows(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
repo := newInMemorySettingsRepo()
|
|
svc := service.NewSettingsService(repo)
|
|
|
|
// Ensure maintenance disabled by default
|
|
r := gin.New()
|
|
r.Use(MaintenanceMiddleware(svc, []string{"/api/v1/admin", "/health"}))
|
|
r.GET("/api/v1/resource", func(c *gin.Context) {
|
|
c.Header("X-Client-IP", c.ClientIP())
|
|
c.JSON(200, gin.H{"ok": true})
|
|
})
|
|
r.GET("/api/v1/admin/settings", func(c *gin.Context) { c.JSON(200, gin.H{"admin": true}) })
|
|
|
|
// Disabled: request should succeed and header indicate false
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/resource", nil)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 when disabled: got %d", w.Code)
|
|
}
|
|
if got := w.Header().Get("X-Maintenance"); got != "false" {
|
|
t.Fatalf("expected X-Maintenance=false, got %s", got)
|
|
}
|
|
|
|
// Enable maintenance
|
|
if err := svc.SetMaintenance(context.Background(), &service.MaintenanceSetting{Enabled: true, Message: "maintenance in progress"}); err != nil {
|
|
t.Fatalf("set maintenance: %v", err)
|
|
}
|
|
|
|
// Non-whitelisted path should be blocked
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/api/v1/resource", nil)
|
|
req.Header.Set("X-Real-IP", "127.0.0.1")
|
|
req.Header.Set("X-Forwarded-For", "127.0.0.1")
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected 503 when enabled: got %d", w.Code)
|
|
}
|
|
if got := w.Header().Get("X-Maintenance"); got != "true" {
|
|
t.Fatalf("expected X-Maintenance=true, got %s", got)
|
|
}
|
|
var body map[string]any
|
|
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("unmarshal: %v", err)
|
|
}
|
|
if body["maintenance"] != true {
|
|
t.Fatalf("expected maintenance:true body, got %#v", body)
|
|
}
|
|
|
|
// Whitelisted path should be allowed and header present
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 for whitelisted path, got %d", w.Code)
|
|
}
|
|
if got := w.Header().Get("X-Maintenance"); got != "true" {
|
|
t.Fatalf("expected X-Maintenance=true for whitelisted, got %s", got)
|
|
}
|
|
|
|
// Allowed IP should bypass maintenance
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/api/v1/resource", nil)
|
|
// Set maintenance with allowed IPs to include loopback
|
|
if err := svc.SetMaintenance(context.Background(), &service.MaintenanceSetting{Enabled: true, Message: "maintenance", AllowedIPs: []string{"127.0.0.1"}}); err != nil {
|
|
t.Fatalf("set maintenance with allowed ips: %v", err)
|
|
}
|
|
// Ensure the settings service returned the allowed IPs (cache updated)
|
|
m, err := svc.GetMaintenance(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("get maintenance: %v", err)
|
|
}
|
|
if len(m.AllowedIPs) == 0 || m.AllowedIPs[0] != "127.0.0.1" {
|
|
t.Fatalf("expected allowed ip to be present, got %#v", m.AllowedIPs)
|
|
}
|
|
req.Header.Set("X-Real-IP", "127.0.0.1")
|
|
req.Header.Set("X-Forwarded-For", "127.0.0.1")
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 for allowed ip, got %d", w.Code)
|
|
}
|
|
}
|