package repository import ( "context" "errors" "time" "bugulma/backend/internal/domain" "gorm.io/gorm" ) // SubscriptionRepository implements domain.SubscriptionRepository with GORM type SubscriptionRepository struct { *BaseRepository[domain.Subscription] db *gorm.DB } // NewSubscriptionRepository creates a new GORM-based subscription repository func NewSubscriptionRepository(db *gorm.DB) domain.SubscriptionRepository { return &SubscriptionRepository{ BaseRepository: NewBaseRepository[domain.Subscription](db), db: db, } } // GetByUserID retrieves a subscription by user ID func (r *SubscriptionRepository) GetByUserID(ctx context.Context, userID string) (*domain.Subscription, error) { var subscription domain.Subscription result := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&subscription) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, result.Error } return &subscription, nil } // GetActiveByUserID retrieves an active subscription by user ID func (r *SubscriptionRepository) GetActiveByUserID(ctx context.Context, userID string) (*domain.Subscription, error) { var subscription domain.Subscription result := r.db.WithContext(ctx). Where("user_id = ? AND status IN ?", userID, []domain.SubscriptionStatus{ domain.SubscriptionStatusActive, domain.SubscriptionStatusTrialing, }). First(&subscription) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, result.Error } return &subscription, nil } // UpdateStatus updates the status of a subscription func (r *SubscriptionRepository) UpdateStatus(ctx context.Context, id string, status domain.SubscriptionStatus) error { result := r.db.WithContext(ctx). Model(&domain.Subscription{}). Where("id = ?", id). Update("status", status) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return ErrNotFound } return nil } // PaymentMethodRepository implements domain.PaymentMethodRepository with GORM type PaymentMethodRepository struct { *BaseRepository[domain.PaymentMethod] db *gorm.DB } // NewPaymentMethodRepository creates a new GORM-based payment method repository func NewPaymentMethodRepository(db *gorm.DB) domain.PaymentMethodRepository { return &PaymentMethodRepository{ BaseRepository: NewBaseRepository[domain.PaymentMethod](db), db: db, } } // GetByUserID retrieves all payment methods for a user func (r *PaymentMethodRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.PaymentMethod, error) { var paymentMethods []*domain.PaymentMethod result := r.db.WithContext(ctx). Where("user_id = ?", userID). Order("is_default DESC, created_at DESC"). Find(&paymentMethods) if result.Error != nil { return nil, result.Error } return paymentMethods, nil } // GetDefaultByUserID retrieves the default payment method for a user func (r *PaymentMethodRepository) GetDefaultByUserID(ctx context.Context, userID string) (*domain.PaymentMethod, error) { var paymentMethod domain.PaymentMethod result := r.db.WithContext(ctx). Where("user_id = ? AND is_default = ?", userID, true). First(&paymentMethod) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, result.Error } return &paymentMethod, nil } // SetDefault sets a payment method as default for a user func (r *PaymentMethodRepository) SetDefault(ctx context.Context, userID string, paymentMethodID string) error { // Start transaction tx := r.db.WithContext(ctx).Begin() defer func() { if r := recover(); r != nil { tx.Rollback() } }() // Unset all defaults for user if err := tx.Model(&domain.PaymentMethod{}). Where("user_id = ?", userID). Update("is_default", false).Error; err != nil { tx.Rollback() return err } // Set new default if err := tx.Model(&domain.PaymentMethod{}). Where("id = ? AND user_id = ?", paymentMethodID, userID). Update("is_default", true).Error; err != nil { tx.Rollback() return err } return tx.Commit().Error } // InvoiceRepository implements domain.InvoiceRepository with GORM type InvoiceRepository struct { *BaseRepository[domain.Invoice] db *gorm.DB } // NewInvoiceRepository creates a new GORM-based invoice repository func NewInvoiceRepository(db *gorm.DB) domain.InvoiceRepository { return &InvoiceRepository{ BaseRepository: NewBaseRepository[domain.Invoice](db), db: db, } } // GetByUserID retrieves invoices for a user with pagination func (r *InvoiceRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*domain.Invoice, int64, error) { var invoices []*domain.Invoice var total int64 // Get total count if err := r.db.WithContext(ctx). Model(&domain.Invoice{}). Where("user_id = ?", userID). Count(&total).Error; err != nil { return nil, 0, err } // Get paginated results result := r.db.WithContext(ctx). Where("user_id = ?", userID). Order("created_at DESC"). Limit(limit). Offset(offset). Find(&invoices) if result.Error != nil { return nil, 0, result.Error } return invoices, total, nil } // GetBySubscriptionID retrieves all invoices for a subscription func (r *InvoiceRepository) GetBySubscriptionID(ctx context.Context, subscriptionID string) ([]*domain.Invoice, error) { var invoices []*domain.Invoice result := r.db.WithContext(ctx). Where("subscription_id = ?", subscriptionID). Order("created_at DESC"). Find(&invoices) if result.Error != nil { return nil, result.Error } return invoices, nil } // UsageTrackingRepository implements domain.UsageTrackingRepository with GORM type UsageTrackingRepository struct { *BaseRepository[domain.UsageTracking] db *gorm.DB } // NewUsageTrackingRepository creates a new GORM-based usage tracking repository func NewUsageTrackingRepository(db *gorm.DB) domain.UsageTrackingRepository { return &UsageTrackingRepository{ BaseRepository: NewBaseRepository[domain.UsageTracking](db), db: db, } } // GetByUserIDAndType retrieves usage tracking for a user and limit type in a specific period func (r *UsageTrackingRepository) GetByUserIDAndType(ctx context.Context, userID string, limitType domain.UsageLimitType, periodStart time.Time) (*domain.UsageTracking, error) { var usage domain.UsageTracking result := r.db.WithContext(ctx). Where("user_id = ? AND limit_type = ? AND period_start = ?", userID, limitType, periodStart). First(&usage) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, result.Error } return &usage, nil } // UpdateUsage updates or creates usage tracking func (r *UsageTrackingRepository) UpdateUsage(ctx context.Context, userID string, limitType domain.UsageLimitType, periodStart time.Time, amount int64) error { // Try to get existing record usage, err := r.GetByUserIDAndType(ctx, userID, limitType, periodStart) if err != nil && !errors.Is(err, ErrNotFound) { return err } if usage != nil { // Update existing usage.CurrentUsage = amount return r.db.WithContext(ctx).Save(usage).Error } // Create new periodEnd := periodStart.AddDate(0, 1, 0) // Default to monthly period if limitType == domain.UsageLimitTypeAPICalls { // API calls reset monthly periodEnd = periodStart.AddDate(0, 1, 0) } newUsage := &domain.UsageTracking{ UserID: userID, LimitType: limitType, CurrentUsage: amount, PeriodStart: periodStart, PeriodEnd: periodEnd, } return r.db.WithContext(ctx).Create(newUsage).Error } // GetCurrentPeriodUsage retrieves current period usage func (r *UsageTrackingRepository) GetCurrentPeriodUsage(ctx context.Context, userID string, limitType domain.UsageLimitType) (*domain.UsageTracking, error) { // Get current month start now := time.Now() periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) return r.GetByUserIDAndType(ctx, userID, limitType, periodStart) }