package db

import (
	"fmt"
	"strings"
	"time"

	"gorm.io/gorm"

	"tercul/internal/observability"
)

const (
	// startTime is the statement setting key under which the before callbacks
	// record when an operation began; the after callbacks read it back.
	startTime = "start_time"

	// maxOperationLabelLen caps, in bytes, how much SQL text is used as the
	// "operation" label value.
	maxOperationLabelLen = 50
)

// Compile-time check that PrometheusPlugin satisfies gorm.Plugin.
var _ gorm.Plugin = (*PrometheusPlugin)(nil)

// PrometheusPlugin is a gorm.Plugin that records the duration and outcome of
// create, query, update, delete, row and raw operations.
type PrometheusPlugin struct {
	// Metrics holds the collectors the plugin writes to. When nil, the plugin
	// records nothing.
	Metrics *observability.Metrics
}

// Name implements gorm.Plugin.
func (p *PrometheusPlugin) Name() string {
	return "PrometheusPlugin"
}

// Initialize implements gorm.Plugin. It registers p.before ahead of, and
// p.after behind, GORM's built-in "gorm:create", "gorm:query", "gorm:update",
// "gorm:delete", "gorm:row" and "gorm:raw" callbacks. All before hooks are
// registered first, then all after hooks. The first registration failure is
// returned, wrapped with the name of the hook that could not be registered.
func (p *PrometheusPlugin) Initialize(db *gorm.DB) error {
	cb := db.Callback()

	// GORM's processor and callback types are unexported, so each stage is
	// captured as a Register method value rather than as the processor itself.
	ops := []struct {
		name   string
		before func(string, func(*gorm.DB)) error
		after  func(string, func(*gorm.DB)) error
	}{
		{"create", cb.Create().Before("gorm:create").Register, cb.Create().After("gorm:create").Register},
		{"query", cb.Query().Before("gorm:query").Register, cb.Query().After("gorm:query").Register},
		{"update", cb.Update().Before("gorm:update").Register, cb.Update().After("gorm:update").Register},
		{"delete", cb.Delete().Before("gorm:delete").Register, cb.Delete().After("gorm:delete").Register},
		{"row", cb.Row().Before("gorm:row").Register, cb.Row().After("gorm:row").Register},
		{"raw", cb.Raw().Before("gorm:raw").Register, cb.Raw().After("gorm:raw").Register},
	}

	for _, op := range ops {
		name := "prometheus:before_" + op.name
		if err := op.before(name, p.before); err != nil {
			return fmt.Errorf("registering gorm callback %q: %w", name, err)
		}
	}
	for _, op := range ops {
		name := "prometheus:after_" + op.name
		if err := op.after(name, p.after); err != nil {
			return fmt.Errorf("registering gorm callback %q: %w", name, err)
		}
	}
	return nil
}

// before stores the current time on the statement under startTime.
func (p *PrometheusPlugin) before(db *gorm.DB) {
	db.Set(startTime, time.Now())
}

// after observes the elapsed time since before ran and increments the query
// counter. Both are labelled with a truncated form of the executed SQL and
// with "success" or "error" depending on db.Error. It does nothing when the
// plugin has no Metrics, or when no time.Time was stored under startTime.
func (p *PrometheusPlugin) after(db *gorm.DB) {
	if p.Metrics == nil {
		return
	}
	v, ok := db.Get(startTime)
	if !ok {
		return
	}
	start, ok := v.(time.Time)
	if !ok {
		return
	}

	operation := operationLabel(db.Statement.SQL.String())
	status := "success"
	if db.Error != nil {
		status = "error"
	}

	duration := time.Since(start).Seconds()
	p.Metrics.DBQueryDuration.WithLabelValues(operation, status).Observe(duration)
	p.Metrics.DBQueriesTotal.WithLabelValues(operation, status).Inc()
}

// operationLabel converts SQL text into an "operation" label value of at most
// maxOperationLabelLen bytes that is always valid UTF-8.
//
// Cutting at a byte offset can split a multi-byte rune, and raw SQL may itself
// contain invalid UTF-8. Assuming Metrics wraps client_golang vectors,
// WithLabelValues panics on such values, so any invalid byte runs are dropped.
// For ASCII SQL this is simply the first maxOperationLabelLen bytes.
//
// NOTE(review): SQL text is a high-cardinality label source even when
// truncated; consider labelling by operation type and table instead.
func operationLabel(sql string) string {
	if len(sql) > maxOperationLabelLen {
		sql = sql[:maxOperationLabelLen]
	}
	return strings.ToValidUTF8(sql, "")
}

// NewPrometheusPlugin returns a PrometheusPlugin that reports to metrics.
func NewPrometheusPlugin(metrics *observability.Metrics) *PrometheusPlugin {
	return &PrometheusPlugin{Metrics: metrics}
}