// Package auth provides HTTP middleware and context helpers for JWT-based
// authentication and role-based authorization.
package auth

import (
	"context"
	"net/http"
	"strings"

	"tercul/internal/observability"
	"tercul/internal/platform/log"
)

// ContextKey is the type of the keys this package uses with
// context.WithValue. Using a dedicated type keeps these keys distinct from
// plain string keys set by other packages.
type ContextKey string

const (
	// UserContextKey is the context key reserved for a user value. It is not
	// read or written by the middleware in this file.
	UserContextKey ContextKey = "user"
	// ClaimsContextKey is the context key under which validated *Claims are
	// stored by the middleware and test helpers in this file.
	ClaimsContextKey ContextKey = "claims"
)

// AuthMiddleware returns middleware that requires a valid JWT on every
// request whose path is not exempted by shouldSkipAuth.
//
// The token is taken from the Authorization header via
// jwtManager.ExtractTokenFromHeader and checked with jwtManager.ValidateToken.
// On success the resulting claims are stored in the request context under
// ClaimsContextKey before calling next. On any failure the request is
// rejected with 401 Unauthorized and next is not called.
func AuthMiddleware(jwtManager *JWTManager) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if shouldSkipAuth(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}

			logger := log.FromContext(r.Context())

			tokenString, err := jwtManager.ExtractTokenFromHeader(r.Header.Get("Authorization"))
			if err != nil {
				logger.With("error", err.Error()).Warn("Authentication failed - missing or invalid token")
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			claims, err := jwtManager.ValidateToken(tokenString)
			if err != nil {
				logger.With("error", err.Error()).Warn("Authentication failed - invalid token")
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			ctx := context.WithValue(r.Context(), ClaimsContextKey, claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// RoleMiddleware returns middleware that only lets a request through when
// the claims in its context satisfy jwtManager.RequireRole for requiredRole.
//
// It expects claims to have been placed in the context by an earlier
// middleware (e.g. AuthMiddleware). A request with no usable claims, or with
// an insufficient role, is rejected with 403 Forbidden.
func RoleMiddleware(jwtManager *JWTManager, requiredRole string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			logger := log.FromContext(r.Context())

			// GetClaimsFromContext also rejects a nil *Claims, so claims.Role
			// below can never dereference nil.
			claims, ok := GetClaimsFromContext(r.Context())
			if !ok {
				logger.Warn("Authorization failed - no claims in context")
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}

			if err := jwtManager.RequireRole(claims.Role, requiredRole); err != nil {
				logger.With("user_role", claims.Role).With("required_role", requiredRole).Warn("Authorization failed - insufficient role")
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// GraphQLAuthMiddleware returns middleware for GraphQL requests that
// authenticates opportunistically and never rejects a request.
//
// The request is passed to next unchanged, and so proceeds without claims,
// in three cases: it has no Authorization header, no token can be extracted
// from the header, or the token fails validation.
//
// When the token is valid, the claims are stored under ClaimsContextKey. A
// logger enriched with "user_id" is also stored under
// observability.LoggerContextKey.
func GraphQLAuthMiddleware(jwtManager *JWTManager) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				next.ServeHTTP(w, r)
				return
			}

			logger := log.FromContext(r.Context())

			// A bad or expired client token is not a server fault. These
			// failures are logged at Warn, matching AuthMiddleware, and the
			// request continues anonymously.
			tokenString, err := jwtManager.ExtractTokenFromHeader(authHeader)
			if err != nil {
				logger.With("error", err.Error()).Warn("GraphQL authentication failed - could not extract token")
				next.ServeHTTP(w, r) // Proceed without auth
				return
			}

			claims, err := jwtManager.ValidateToken(tokenString)
			if err != nil {
				logger.With("error", err.Error()).Warn("GraphQL authentication failed - invalid token")
				next.ServeHTTP(w, r) // Proceed without auth
				return
			}

			// Add claims and enriched logger to context for authenticated requests.
			ctx := context.WithValue(r.Context(), ClaimsContextKey, claims)
			enrichedLogger := logger.With("user_id", claims.UserID)
			ctx = context.WithValue(ctx, observability.LoggerContextKey, enrichedLogger)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// GetClaimsFromContext returns the *Claims stored in ctx under
// ClaimsContextKey.
//
// ok is false, and claims is nil, in three cases: nothing is stored under
// the key, the stored value is not a *Claims, or the value is a nil *Claims.
// Treating a typed-nil pointer as absent means callers that see ok == true
// can always dereference claims safely.
func GetClaimsFromContext(ctx context.Context) (*Claims, bool) {
	claims, ok := ctx.Value(ClaimsContextKey).(*Claims)
	if !ok || claims == nil {
		return nil, false
	}
	return claims, true
}

// GetUserIDFromContext returns the UserID from the claims in ctx, or
// (0, false) when ctx carries no usable claims.
func GetUserIDFromContext(ctx context.Context) (uint, bool) {
	claims, ok := GetClaimsFromContext(ctx)
	if !ok {
		return 0, false
	}
	return claims.UserID, true
}

// IsAuthenticated reports whether ctx carries usable claims.
func IsAuthenticated(ctx context.Context) bool {
	_, ok := GetClaimsFromContext(ctx)
	return ok
}

// RequireAuth returns the claims in ctx, or ErrMissingToken when ctx carries
// no usable claims.
func RequireAuth(ctx context.Context) (*Claims, error) {
	claims, ok := GetClaimsFromContext(ctx)
	if !ok {
		return nil, ErrMissingToken
	}
	return claims, nil
}

// RequireRole returns the claims in ctx if they satisfy
// jwtManager.RequireRole for requiredRole.
//
// It returns RequireAuth's error when ctx has no claims. Otherwise it returns
// the error from jwtManager.RequireRole when the role check fails.
func RequireRole(ctx context.Context, jwtManager *JWTManager, requiredRole string) (*Claims, error) {
	claims, err := RequireAuth(ctx)
	if err != nil {
		return nil, err
	}
	if err := jwtManager.RequireRole(claims.Role, requiredRole); err != nil {
		return nil, err
	}
	return claims, nil
}

// shouldSkipAuth reports whether AuthMiddleware should let path through
// without a token. Exempt paths are a fixed set of exact matches plus
// anything under /static/.
//
// "/query" is exempt here; GraphQL requests are expected to be handled by
// GraphQLAuthMiddleware instead.
func shouldSkipAuth(path string) bool {
	switch path {
	case "/", "/query", "/health", "/metrics", "/favicon.ico":
		return true
	}
	return strings.HasPrefix(path, "/static/")
}

// ContextWithUserID adds a user ID to the context for testing purposes.
// The stored claims have only UserID set.
func ContextWithUserID(ctx context.Context, userID uint) context.Context {
	claims := &Claims{UserID: userID}
	return context.WithValue(ctx, ClaimsContextKey, claims)
}

// ContextWithAdminUser adds an admin user to the context for testing
// purposes. The stored claims have UserID set and Role "admin".
func ContextWithAdminUser(ctx context.Context, userID uint) context.Context {
	claims := &Claims{
		UserID: userID,
		Role:   "admin",
	}
	return context.WithValue(ctx, ClaimsContextKey, claims)
}