mirror of
https://github.com/SamyRai/tercul-backend.git
synced 2025-12-27 05:11:34 +00:00
feat: add security middleware, graphql apq, and improved linting
- Add RateLimit, RequestValidation, and CORS middleware. - Configure middleware chain in API server. - Implement Redis cache for GraphQL Automatic Persisted Queries. - Add .golangci.yml and fix linting issues (shadowing, timeouts).
This commit is contained in:
parent
be97b587b2
commit
6d40b4c686
59
.golangci.yml
Normal file
59
.golangci.yml
Normal file
@ -0,0 +1,59 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
|
||||
linters-settings:
|
||||
govet:
|
||||
check-shadowing: true
|
||||
gocyclo:
|
||||
min-complexity: 15
|
||||
goconst:
|
||||
min-len: 2
|
||||
min-occurrences: 3
|
||||
misspell:
|
||||
locale: US
|
||||
lll:
|
||||
line-length: 140
|
||||
goimports:
|
||||
local-prefixes: tercul
|
||||
gocritic:
|
||||
enabled-tags:
|
||||
- diagnostic
|
||||
- performance
|
||||
- style
|
||||
disabled-checks:
|
||||
- wrapperFunc
|
||||
- ifElseChain
|
||||
- octalLiteral
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- bodyclose
|
||||
- errcheck
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- gofmt
|
||||
- goimports
|
||||
- gosec
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- lll
|
||||
- misspell
|
||||
- nakedret
|
||||
- noctx
|
||||
- nolintlint
|
||||
- staticcheck
|
||||
- stylecheck
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- whitespace
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
@ -8,6 +8,8 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"tercul/internal/adapters/graphql"
|
||||
"tercul/internal/app"
|
||||
"tercul/internal/app/analytics"
|
||||
@ -31,12 +33,13 @@ import (
|
||||
"tercul/internal/jobs/linguistics"
|
||||
"tercul/internal/observability"
|
||||
platform_auth "tercul/internal/platform/auth"
|
||||
"tercul/internal/platform/cache"
|
||||
"tercul/internal/platform/config"
|
||||
"tercul/internal/platform/db"
|
||||
app_log "tercul/internal/platform/log"
|
||||
"tercul/internal/platform/search"
|
||||
"time"
|
||||
|
||||
gql "github.com/99designs/gqlgen/graphql"
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
||||
@ -75,13 +78,13 @@ func main() {
|
||||
obsLogger := observability.NewLogger("tercul-api", cfg.Environment)
|
||||
|
||||
// Initialize OpenTelemetry Tracer Provider
|
||||
tp, err := observability.TracerProvider("tercul-api", cfg.Environment)
|
||||
if err != nil {
|
||||
app_log.Fatal(err, "Failed to initialize OpenTelemetry tracer")
|
||||
tp, traceErr := observability.TracerProvider("tercul-api", cfg.Environment)
|
||||
if traceErr != nil {
|
||||
app_log.Fatal(traceErr, "Failed to initialize OpenTelemetry tracer")
|
||||
}
|
||||
defer func() {
|
||||
if err := tp.Shutdown(context.Background()); err != nil {
|
||||
app_log.Error(err, "Error shutting down tracer provider")
|
||||
if shutdownErr := tp.Shutdown(context.Background()); shutdownErr != nil {
|
||||
app_log.Error(shutdownErr, "Error shutting down tracer provider")
|
||||
}
|
||||
}()
|
||||
|
||||
@ -92,18 +95,18 @@ func main() {
|
||||
app_log.Info(fmt.Sprintf("Starting Tercul application in %s environment, version 1.0.0", cfg.Environment))
|
||||
|
||||
// Initialize database connection
|
||||
database, err := db.InitDB(cfg, metrics)
|
||||
if err != nil {
|
||||
app_log.Fatal(err, "Failed to initialize database")
|
||||
database, dbErr := db.InitDB(cfg, metrics)
|
||||
if dbErr != nil {
|
||||
app_log.Fatal(dbErr, "Failed to initialize database")
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(database); err != nil {
|
||||
app_log.Error(err, "Error closing database")
|
||||
if closeErr := db.Close(database); closeErr != nil {
|
||||
app_log.Error(closeErr, "Error closing database")
|
||||
}
|
||||
}()
|
||||
|
||||
if err := runMigrations(database, cfg.MigrationPath); err != nil {
|
||||
app_log.Fatal(err, "Failed to apply database migrations")
|
||||
if migErr := runMigrations(database, cfg.MigrationPath); migErr != nil {
|
||||
app_log.Fatal(migErr, "Failed to apply database migrations")
|
||||
}
|
||||
|
||||
// Initialize Weaviate client
|
||||
@ -111,9 +114,9 @@ func main() {
|
||||
Host: cfg.WeaviateHost,
|
||||
Scheme: cfg.WeaviateScheme,
|
||||
}
|
||||
weaviateClient, err := weaviate.NewClient(weaviateCfg)
|
||||
if err != nil {
|
||||
app_log.Fatal(err, "Failed to create weaviate client")
|
||||
weaviateClient, wErr := weaviate.NewClient(weaviateCfg)
|
||||
if wErr != nil {
|
||||
app_log.Fatal(wErr, "Failed to create weaviate client")
|
||||
}
|
||||
|
||||
// Create search client
|
||||
@ -124,9 +127,9 @@ func main() {
|
||||
|
||||
// Create linguistics dependencies
|
||||
analysisRepo := linguistics.NewGORMAnalysisRepository(database)
|
||||
sentimentProvider, err := linguistics.NewGoVADERSentimentProvider()
|
||||
if err != nil {
|
||||
app_log.Fatal(err, "Failed to create sentiment provider")
|
||||
sentimentProvider, sErr := linguistics.NewGoVADERSentimentProvider()
|
||||
if sErr != nil {
|
||||
app_log.Fatal(sErr, "Failed to create sentiment provider")
|
||||
}
|
||||
|
||||
// Create platform components
|
||||
@ -178,13 +181,24 @@ func main() {
|
||||
App: application,
|
||||
}
|
||||
|
||||
// Initialize Redis Cache for APQ
|
||||
redisCache, cacheErr := cache.NewDefaultRedisCache(cfg)
|
||||
var queryCache gql.Cache[string]
|
||||
if cacheErr != nil {
|
||||
app_log.Warn("Redis cache initialization failed, APQ disabled: " + cacheErr.Error())
|
||||
} else {
|
||||
queryCache = &cache.GraphQLCacheAdapter{RedisCache: redisCache}
|
||||
app_log.Info("Redis cache initialized for APQ")
|
||||
}
|
||||
|
||||
// Create the consolidated API server with all routes.
|
||||
apiHandler := NewAPIServer(resolver, jwtManager, metrics, obsLogger, reg)
|
||||
apiHandler := NewAPIServer(cfg, resolver, queryCache, jwtManager, metrics, obsLogger, reg)
|
||||
|
||||
// Create the main HTTP server.
|
||||
mainServer := &http.Server{
|
||||
Addr: cfg.ServerPort,
|
||||
Handler: apiHandler,
|
||||
Addr: cfg.ServerPort,
|
||||
Handler: apiHandler,
|
||||
ReadHeaderTimeout: 5 * time.Second, // Gosec: Prevent Slowloris attack
|
||||
}
|
||||
app_log.Info(fmt.Sprintf("API server listening on port %s", cfg.ServerPort))
|
||||
|
||||
@ -205,8 +219,8 @@ func main() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mainServer.Shutdown(ctx); err != nil {
|
||||
app_log.Error(err, "Server forced to shutdown")
|
||||
if shutdownErr := mainServer.Shutdown(ctx); shutdownErr != nil {
|
||||
app_log.Error(shutdownErr, "Server forced to shutdown")
|
||||
}
|
||||
|
||||
app_log.Info("Server shut down successfully")
|
||||
|
||||
@ -2,11 +2,16 @@ package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"tercul/internal/adapters/graphql"
|
||||
"tercul/internal/observability"
|
||||
"tercul/internal/platform/auth"
|
||||
"tercul/internal/platform/config"
|
||||
platform_http "tercul/internal/platform/http"
|
||||
|
||||
gql "github.com/99designs/gqlgen/graphql"
|
||||
"github.com/99designs/gqlgen/graphql/handler"
|
||||
"github.com/99designs/gqlgen/graphql/handler/extension"
|
||||
"github.com/99designs/gqlgen/graphql/playground"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
@ -14,7 +19,9 @@ import (
|
||||
// NewAPIServer creates a new http.ServeMux and configures it with all the API routes,
|
||||
// including the GraphQL endpoint, GraphQL Playground, and Prometheus metrics.
|
||||
func NewAPIServer(
|
||||
cfg *config.Config,
|
||||
resolver *graphql.Resolver,
|
||||
queryCache gql.Cache[string],
|
||||
jwtManager *auth.JWTManager,
|
||||
metrics *observability.Metrics,
|
||||
logger *observability.Logger,
|
||||
@ -26,10 +33,18 @@ func NewAPIServer(
|
||||
|
||||
// Create the core GraphQL handler
|
||||
graphqlHandler := handler.New(graphql.NewExecutableSchema(c))
|
||||
|
||||
// Enable Automatic Persisted Queries (APQ) if cache is provided
|
||||
if queryCache != nil {
|
||||
graphqlHandler.Use(extension.AutomaticPersistedQuery{
|
||||
Cache: queryCache,
|
||||
})
|
||||
}
|
||||
|
||||
graphqlHandler.SetErrorPresenter(graphql.NewErrorPresenter())
|
||||
|
||||
// Create the middleware chain for the GraphQL endpoint.
|
||||
// Middlewares are applied from bottom to top.
|
||||
// Middlewares are applied from bottom to top (last applied is first executed).
|
||||
var chain http.Handler
|
||||
chain = graphqlHandler
|
||||
chain = metrics.PrometheusMiddleware(chain)
|
||||
@ -38,6 +53,14 @@ func NewAPIServer(
|
||||
chain = observability.TracingMiddleware(chain)
|
||||
chain = observability.RequestIDMiddleware(chain)
|
||||
|
||||
// Security and Validation Middlewares
|
||||
chain = platform_http.RequestValidationMiddleware(chain)
|
||||
chain = platform_http.RateLimitMiddleware(cfg)(chain)
|
||||
|
||||
// CORS should be the outermost to handle preflight OPTIONS requests
|
||||
// TODO: Make allowed origins configurable
|
||||
chain = platform_http.CORSMiddleware([]string{"*"})(chain)
|
||||
|
||||
// Create a new ServeMux and register all handlers
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/query", chain)
|
||||
|
||||
4
go.mod
4
go.mod
@ -6,6 +6,7 @@ require (
|
||||
github.com/99designs/gqlgen v0.17.72
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2
|
||||
github.com/blevesearch/bleve/v2 v2.5.5
|
||||
github.com/go-openapi/strfmt v0.25.0
|
||||
github.com/go-playground/validator/v10 v10.28.0
|
||||
github.com/go-redis/redismock/v9 v9.2.0
|
||||
@ -19,6 +20,7 @@ require (
|
||||
github.com/prometheus/client_golang v1.20.5
|
||||
github.com/redis/go-redis/v9 v9.8.0
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go v0.40.0
|
||||
@ -48,7 +50,6 @@ require (
|
||||
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bits-and-blooms/bitset v1.22.0 // indirect
|
||||
github.com/blevesearch/bleve/v2 v2.5.5 // indirect
|
||||
github.com/blevesearch/bleve_index_api v1.2.11 // indirect
|
||||
github.com/blevesearch/geo v0.2.4 // indirect
|
||||
github.com/blevesearch/go-faiss v1.0.26 // indirect
|
||||
@ -172,7 +173,6 @@ require (
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/cobra v1.10.1 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stretchr/objx v0.5.3 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
|
||||
4
internal/platform/cache/cache.go
vendored
4
internal/platform/cache/cache.go
vendored
@ -54,7 +54,7 @@ type KeyGenerator interface {
|
||||
ListKey(entityType string, page, pageSize int) string
|
||||
|
||||
// QueryKey generates a key for a custom query
|
||||
QueryKey(entityType string, queryName string, params ...interface{}) string
|
||||
QueryKey(entityType, queryName string, params ...interface{}) string
|
||||
}
|
||||
|
||||
// DefaultKeyGenerator implements the KeyGenerator interface
|
||||
@ -83,7 +83,7 @@ func (g *DefaultKeyGenerator) ListKey(entityType string, page, pageSize int) str
|
||||
}
|
||||
|
||||
// QueryKey generates a key for a custom query
|
||||
func (g *DefaultKeyGenerator) QueryKey(entityType string, queryName string, params ...interface{}) string {
|
||||
func (g *DefaultKeyGenerator) QueryKey(entityType, queryName string, params ...interface{}) string {
|
||||
key := g.Prefix + entityType + ":" + queryName
|
||||
for _, param := range params {
|
||||
key += ":" + fmt.Sprintf("%v", param)
|
||||
|
||||
110
internal/platform/cache/cache_test.go
vendored
110
internal/platform/cache/cache_test.go
vendored
@ -1,68 +1,68 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultKeyGenerator_DefaultPrefix(t *testing.T) {
|
||||
g := NewDefaultKeyGenerator("")
|
||||
require.NotNil(t, g)
|
||||
// Table-driven tests for key generation
|
||||
tests := []struct {
|
||||
name string
|
||||
entity string
|
||||
id uint
|
||||
page int
|
||||
pageSize int
|
||||
queryName string
|
||||
params []interface{}
|
||||
wantEntity string
|
||||
wantList string
|
||||
wantQuery string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
entity: "user",
|
||||
id: 42,
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
queryName: "byEmail",
|
||||
params: []interface{}{"foo@bar.com"},
|
||||
wantEntity: "tercul:user:id:42",
|
||||
wantList: "tercul:user:list:1:20",
|
||||
wantQuery: "tercul:user:byEmail:foo@bar.com",
|
||||
},
|
||||
{
|
||||
name: "different entity and multiple params",
|
||||
entity: "work",
|
||||
id: 7,
|
||||
page: 3,
|
||||
pageSize: 15,
|
||||
queryName: "search",
|
||||
params: []interface{}{"abc", 2020, true},
|
||||
wantEntity: "tercul:work:id:7",
|
||||
wantList: "tercul:work:list:3:15",
|
||||
wantQuery: "tercul:work:search:abc:2020:true",
|
||||
},
|
||||
}
|
||||
g := NewDefaultKeyGenerator("")
|
||||
require.NotNil(t, g)
|
||||
// Table-driven tests for key generation
|
||||
tests := []struct {
|
||||
name string
|
||||
entity string
|
||||
id uint
|
||||
page int
|
||||
pageSize int
|
||||
queryName string
|
||||
params []interface{}
|
||||
wantEntity string
|
||||
wantList string
|
||||
wantQuery string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
entity: "user",
|
||||
id: 42,
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
queryName: "byEmail",
|
||||
params: []interface{}{"foo@bar.com"},
|
||||
wantEntity: "tercul:user:id:42",
|
||||
wantList: "tercul:user:list:1:20",
|
||||
wantQuery: "tercul:user:byEmail:foo@bar.com",
|
||||
},
|
||||
{
|
||||
name: "different entity and multiple params",
|
||||
entity: "work",
|
||||
id: 7,
|
||||
page: 3,
|
||||
pageSize: 15,
|
||||
queryName: "search",
|
||||
params: []interface{}{"abc", 2020, true},
|
||||
wantEntity: "tercul:work:id:7",
|
||||
wantList: "tercul:work:list:3:15",
|
||||
wantQuery: "tercul:work:search:abc:2020:true",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.wantEntity, g.EntityKey(tt.entity, tt.id))
|
||||
assert.Equal(t, tt.wantList, g.ListKey(tt.entity, tt.page, tt.pageSize))
|
||||
assert.Equal(t, tt.wantQuery, g.QueryKey(tt.entity, tt.queryName, tt.params...))
|
||||
})
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.wantEntity, g.EntityKey(tt.entity, tt.id))
|
||||
assert.Equal(t, tt.wantList, g.ListKey(tt.entity, tt.page, tt.pageSize))
|
||||
assert.Equal(t, tt.wantQuery, g.QueryKey(tt.entity, tt.queryName, tt.params...))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultKeyGenerator_CustomPrefix(t *testing.T) {
|
||||
g := NewDefaultKeyGenerator("mypfx:")
|
||||
require.NotNil(t, g)
|
||||
g := NewDefaultKeyGenerator("mypfx:")
|
||||
require.NotNil(t, g)
|
||||
|
||||
assert.Equal(t, "mypfx:book:id:1", g.EntityKey("book", 1))
|
||||
assert.Equal(t, "mypfx:book:list:2:10", g.ListKey("book", 2, 10))
|
||||
assert.Equal(t, "mypfx:book:find:tag:99", g.QueryKey("book", "find", "tag", 99))
|
||||
assert.Equal(t, "mypfx:book:id:1", g.EntityKey("book", 1))
|
||||
assert.Equal(t, "mypfx:book:list:2:10", g.ListKey("book", 2, 10))
|
||||
assert.Equal(t, "mypfx:book:find:tag:99", g.QueryKey("book", "find", "tag", 99))
|
||||
}
|
||||
|
||||
29
internal/platform/cache/graphql_adapter.go
vendored
Normal file
29
internal/platform/cache/graphql_adapter.go
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
// Package cache provides cache implementations and adapters.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GraphQLCacheAdapter adapts the RedisCache to the graphql.Cache[string] interface
|
||||
type GraphQLCacheAdapter struct {
|
||||
RedisCache *RedisCache
|
||||
}
|
||||
|
||||
// Get looks up a key in the cache
|
||||
func (a *GraphQLCacheAdapter) Get(ctx context.Context, key string) (string, bool) {
|
||||
// gqlgen APQ stores strings.
|
||||
var s string
|
||||
err := a.RedisCache.Get(ctx, key, &s)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// Add adds a key to the cache
|
||||
func (a *GraphQLCacheAdapter) Add(ctx context.Context, key, value string) {
|
||||
// Use default TTL of 24 hours for APQ. The interface does not provide TTL.
|
||||
_ = a.RedisCache.Set(ctx, key, value, 24*time.Hour)
|
||||
}
|
||||
11
internal/platform/cache/redis_cache.go
vendored
11
internal/platform/cache/redis_cache.go
vendored
@ -5,9 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"tercul/internal/platform/config"
|
||||
"tercul/internal/platform/log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@ -171,7 +172,13 @@ func (c *RedisCache) GetList(ctx context.Context, entityType string, page, pageS
|
||||
}
|
||||
|
||||
// SetList stores a list of entities in the cache
|
||||
func (c *RedisCache) SetList(ctx context.Context, entityType string, page, pageSize int, value interface{}, expiration time.Duration) error {
|
||||
func (c *RedisCache) SetList(
|
||||
ctx context.Context,
|
||||
entityType string,
|
||||
page, pageSize int,
|
||||
value interface{},
|
||||
expiration time.Duration,
|
||||
) error {
|
||||
key := c.keyGenerator.ListKey(entityType, page, pageSize)
|
||||
return c.Set(ctx, key, value, expiration)
|
||||
}
|
||||
|
||||
3
internal/platform/cache/redis_cache_test.go
vendored
3
internal/platform/cache/redis_cache_test.go
vendored
@ -6,9 +6,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tercul/internal/platform/cache"
|
||||
"tercul/internal/platform/config"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redismock/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
86
internal/platform/http/middleware.go
Normal file
86
internal/platform/http/middleware.go
Normal file
@ -0,0 +1,86 @@
|
||||
// Package http provides HTTP middleware and utilities.
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CORSMiddleware handles Cross-Origin Resource Sharing
|
||||
func CORSMiddleware(allowedOrigins []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
allowed := false
|
||||
|
||||
// If no allowed origins configured, allow all (development mode usually)
|
||||
if len(allowedOrigins) == 0 {
|
||||
allowed = true
|
||||
} else {
|
||||
for _, o := range allowedOrigins {
|
||||
if o == "*" || o == origin {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Safe default if we want to allow everything
|
||||
if allowed {
|
||||
// If origin is present, use it, otherwise *
|
||||
if origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers",
|
||||
"Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Client-ID, X-API-Key")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequestValidationMiddleware performs basic request validation
|
||||
func RequestValidationMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check Content-Type for POST requests to /query
|
||||
if r.Method == "POST" && r.URL.Path == "/query" {
|
||||
ct := r.Header.Get("Content-Type")
|
||||
// GraphQL clients might send application/json; charset=utf-8
|
||||
if !strings.Contains(ct, "application/json") {
|
||||
// Some clients might send no content type or something else?
|
||||
// Strictly enforcing application/json is good for security.
|
||||
// But we should be careful not to break existing clients if they are sloppy.
|
||||
// For now, let's enforce it as requested.
|
||||
http.Error(w, "Content-Type must be application/json", http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// APIKeyMiddleware checks for X-API-Key header
|
||||
// This is a placeholder for future external integrations.
|
||||
// It allows requests with a valid API key to bypass other auth or strictly enforce it.
|
||||
// Currently it is a pass-through as we don't have defined API keys in config yet.
|
||||
func APIKeyMiddleware(validAPIKeys []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// If we had keys, we would check them here.
|
||||
// apiKey := r.Header.Get("X-API-Key")
|
||||
// validate(apiKey)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,12 +1,14 @@
|
||||
// Package http provides HTTP middleware and utilities.
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tercul/internal/platform/config"
|
||||
"tercul/internal/platform/log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Canonical token bucket implementation for strict burst/rate enforcement
|
||||
|
||||
@ -3,11 +3,12 @@ package http_test
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"tercul/internal/platform/config"
|
||||
platformhttp "tercul/internal/platform/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tercul/internal/platform/config"
|
||||
platformhttp "tercul/internal/platform/http"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
@ -95,7 +96,7 @@ func (s *RateLimiterSuite) TestRateLimiterMiddleware() {
|
||||
staticID := "test-client-id"
|
||||
// Test that the first 3 requests are allowed (burst)
|
||||
for i := 0; i < 3; i++ {
|
||||
req, _ := http.NewRequest("GET", server.URL, nil)
|
||||
req, _ := http.NewRequest("GET", server.URL, http.NoBody)
|
||||
req.Header.Set("X-Client-ID", staticID)
|
||||
resp, err := client.Do(req)
|
||||
s.Require().NoError(err)
|
||||
@ -104,7 +105,7 @@ func (s *RateLimiterSuite) TestRateLimiterMiddleware() {
|
||||
}
|
||||
|
||||
// Test that the 4th request is not allowed (burst exceeded)
|
||||
req, _ := http.NewRequest("GET", server.URL, nil)
|
||||
req, _ := http.NewRequest("GET", server.URL, http.NoBody)
|
||||
req.Header.Set("X-Client-ID", staticID)
|
||||
resp, err := client.Do(req)
|
||||
s.Require().NoError(err)
|
||||
@ -116,7 +117,7 @@ func (s *RateLimiterSuite) TestRateLimiterMiddleware() {
|
||||
|
||||
// Test that the next 2 requests are allowed (rate)
|
||||
for i := 0; i < 2; i++ {
|
||||
req, _ := http.NewRequest("GET", server.URL, nil)
|
||||
req, _ := http.NewRequest("GET", server.URL, http.NoBody)
|
||||
req.Header.Set("X-Client-ID", staticID)
|
||||
resp, err := client.Do(req)
|
||||
s.Require().NoError(err)
|
||||
@ -125,7 +126,7 @@ func (s *RateLimiterSuite) TestRateLimiterMiddleware() {
|
||||
}
|
||||
|
||||
// Test that the 3rd request after wait is not allowed (rate exceeded)
|
||||
req, _ = http.NewRequest("GET", server.URL, nil)
|
||||
req, _ = http.NewRequest("GET", server.URL, http.NoBody)
|
||||
req.Header.Set("X-Client-ID", staticID)
|
||||
resp, err = client.Do(req)
|
||||
s.Require().NoError(err)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user