// Source: turash/bugulma/backend/internal/repository/subscription_repository.go
// 266 lines, 7.9 KiB, Go

package repository
import (
"context"
"errors"
"time"
"bugulma/backend/internal/domain"
"gorm.io/gorm"
)
// SubscriptionRepository implements domain.SubscriptionRepository with GORM.
//
// It embeds the generic BaseRepository instantiated for domain.Subscription
// and additionally keeps its own *gorm.DB handle, which the
// subscription-specific query methods on this type use directly.
type SubscriptionRepository struct {
// Generic repository for domain.Subscription; its methods are promoted.
*BaseRepository[domain.Subscription]
// Database handle used by the custom queries defined on this type.
db *gorm.DB
}
// NewSubscriptionRepository builds a GORM-backed SubscriptionRepository on top
// of db and returns it as a domain.SubscriptionRepository.
func NewSubscriptionRepository(db *gorm.DB) domain.SubscriptionRepository {
	repo := new(SubscriptionRepository)
	repo.BaseRepository = NewBaseRepository[domain.Subscription](db)
	repo.db = db
	return repo
}
// GetByUserID looks up the first subscription whose user_id equals userID.
//
// It returns ErrNotFound when no row matches; any other database error is
// returned unchanged.
func (r *SubscriptionRepository) GetByUserID(ctx context.Context, userID string) (*domain.Subscription, error) {
	found := new(domain.Subscription)
	err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Translate GORM's sentinel into the repository-level one.
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// GetActiveByUserID looks up the first subscription of userID whose status is
// either active or trialing.
//
// It returns ErrNotFound when no row matches; any other database error is
// returned unchanged.
func (r *SubscriptionRepository) GetActiveByUserID(ctx context.Context, userID string) (*domain.Subscription, error) {
	// Statuses that count as "active" for this lookup.
	liveStatuses := []domain.SubscriptionStatus{
		domain.SubscriptionStatusActive,
		domain.SubscriptionStatusTrialing,
	}

	found := new(domain.Subscription)
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND status IN ?", userID, liveStatuses).
		First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// UpdateStatus writes status into the status column of the subscription with
// the given id.
//
// It returns the database error if the update fails, and ErrNotFound when the
// update succeeded but touched no rows.
func (r *SubscriptionRepository) UpdateStatus(ctx context.Context, id string, status domain.SubscriptionStatus) error {
	query := r.db.WithContext(ctx).Model(&domain.Subscription{}).Where("id = ?", id)
	outcome := query.Update("status", status)

	switch {
	case outcome.Error != nil:
		return outcome.Error
	case outcome.RowsAffected == 0:
		// Nothing matched the id.
		return ErrNotFound
	default:
		return nil
	}
}
// PaymentMethodRepository implements domain.PaymentMethodRepository with GORM.
//
// It embeds the generic BaseRepository instantiated for domain.PaymentMethod
// and additionally keeps its own *gorm.DB handle, which the
// payment-method-specific query methods on this type use directly.
type PaymentMethodRepository struct {
// Generic repository for domain.PaymentMethod; its methods are promoted.
*BaseRepository[domain.PaymentMethod]
// Database handle used by the custom queries defined on this type.
db *gorm.DB
}
// NewPaymentMethodRepository builds a GORM-backed PaymentMethodRepository on
// top of db and returns it as a domain.PaymentMethodRepository.
func NewPaymentMethodRepository(db *gorm.DB) domain.PaymentMethodRepository {
	repo := new(PaymentMethodRepository)
	repo.BaseRepository = NewBaseRepository[domain.PaymentMethod](db)
	repo.db = db
	return repo
}
// GetByUserID lists every payment method whose user_id equals userID, ordered
// by is_default descending and then by created_at descending.
//
// On a database error it returns a nil slice together with that error.
func (r *PaymentMethodRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.PaymentMethod, error) {
	query := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("is_default DESC, created_at DESC")

	var methods []*domain.PaymentMethod
	if err := query.Find(&methods).Error; err != nil {
		return nil, err
	}
	return methods, nil
}
// GetDefaultByUserID looks up the first payment method of userID that has
// is_default set to true.
//
// It returns ErrNotFound when no row matches; any other database error is
// returned unchanged.
func (r *PaymentMethodRepository) GetDefaultByUserID(ctx context.Context, userID string) (*domain.PaymentMethod, error) {
	found := new(domain.PaymentMethod)
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND is_default = ?", userID, true).
		First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// SetDefault makes paymentMethodID the only default payment method of userID.
//
// Both updates run inside a single GORM transaction: first every payment
// method of the user has is_default cleared, then the requested one has it
// set. If either step returns an error the transaction is rolled back, so the
// user's previous default is left untouched on failure. Failure to begin or
// commit the transaction is reported as well.
//
// It returns ErrNotFound when no payment method with that ID belongs to the
// user (the clearing step is rolled back in that case), and otherwise any
// error reported by the database.
func (r *PaymentMethodRepository) SetDefault(ctx context.Context, userID string, paymentMethodID string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Unset all defaults for the user.
		if err := tx.Model(&domain.PaymentMethod{}).
			Where("user_id = ?", userID).
			Update("is_default", false).Error; err != nil {
			return err
		}

		// Set the new default. The user_id condition ensures a payment method
		// owned by somebody else can never be flagged.
		res := tx.Model(&domain.PaymentMethod{}).
			Where("id = ? AND user_id = ?", paymentMethodID, userID).
			Update("is_default", true)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			// Nothing was flagged: returning an error rolls back the clearing
			// step above instead of leaving the user without any default.
			return ErrNotFound
		}
		return nil
	})
}
// InvoiceRepository implements domain.InvoiceRepository with GORM.
//
// It embeds the generic BaseRepository instantiated for domain.Invoice and
// additionally keeps its own *gorm.DB handle, which the invoice-specific
// query methods on this type use directly.
type InvoiceRepository struct {
// Generic repository for domain.Invoice; its methods are promoted.
*BaseRepository[domain.Invoice]
// Database handle used by the custom queries defined on this type.
db *gorm.DB
}
// NewInvoiceRepository builds a GORM-backed InvoiceRepository on top of db and
// returns it as a domain.InvoiceRepository.
func NewInvoiceRepository(db *gorm.DB) domain.InvoiceRepository {
	repo := new(InvoiceRepository)
	repo.BaseRepository = NewBaseRepository[domain.Invoice](db)
	repo.db = db
	return repo
}
// GetByUserID returns one page of the invoices belonging to userID, newest
// first (created_at descending), together with the total number of invoices
// the user has.
//
// limit and offset are passed straight to the query. The count and the page
// are fetched with two separate queries; if either fails, the result is
// (nil, 0, err).
func (r *InvoiceRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*domain.Invoice, int64, error) {
	// First query: how many invoices the user has overall.
	var count int64
	countErr := r.db.WithContext(ctx).
		Model(&domain.Invoice{}).
		Where("user_id = ?", userID).
		Count(&count).Error
	if countErr != nil {
		return nil, 0, countErr
	}

	// Second query: the requested page.
	var page []*domain.Invoice
	pageErr := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("created_at DESC").
		Limit(limit).
		Offset(offset).
		Find(&page).Error
	if pageErr != nil {
		return nil, 0, pageErr
	}

	return page, count, nil
}
// GetBySubscriptionID lists every invoice whose subscription_id equals
// subscriptionID, newest first (created_at descending).
//
// On a database error it returns a nil slice together with that error.
func (r *InvoiceRepository) GetBySubscriptionID(ctx context.Context, subscriptionID string) ([]*domain.Invoice, error) {
	query := r.db.WithContext(ctx).
		Where("subscription_id = ?", subscriptionID).
		Order("created_at DESC")

	var list []*domain.Invoice
	if err := query.Find(&list).Error; err != nil {
		return nil, err
	}
	return list, nil
}
// UsageTrackingRepository implements domain.UsageTrackingRepository with GORM.
//
// It embeds the generic BaseRepository instantiated for domain.UsageTracking
// and additionally keeps its own *gorm.DB handle, which the usage-specific
// query methods on this type use directly.
type UsageTrackingRepository struct {
// Generic repository for domain.UsageTracking; its methods are promoted.
*BaseRepository[domain.UsageTracking]
// Database handle used by the custom queries defined on this type.
db *gorm.DB
}
// NewUsageTrackingRepository builds a GORM-backed UsageTrackingRepository on
// top of db and returns it as a domain.UsageTrackingRepository.
func NewUsageTrackingRepository(db *gorm.DB) domain.UsageTrackingRepository {
	repo := new(UsageTrackingRepository)
	repo.BaseRepository = NewBaseRepository[domain.UsageTracking](db)
	repo.db = db
	return repo
}
// GetByUserIDAndType looks up the first usage-tracking row matching userID,
// limitType and an exactly equal period_start.
//
// It returns ErrNotFound when no row matches; any other database error is
// returned unchanged.
func (r *UsageTrackingRepository) GetByUserIDAndType(ctx context.Context, userID string, limitType domain.UsageLimitType, periodStart time.Time) (*domain.UsageTracking, error) {
	found := new(domain.UsageTracking)
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND limit_type = ? AND period_start = ?", userID, limitType, periodStart).
		First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// UpdateUsage sets the usage counter for (userID, limitType, periodStart) to
// amount, creating the tracking row when none exists yet.
//
// amount replaces the stored CurrentUsage; it is not added to it. A newly
// created row gets a period end one calendar month after periodStart, for
// every limit type (API calls included).
//
// Any database error from the lookup, the save or the create is returned
// unchanged.
//
// NOTE(review): the lookup and the subsequent write are separate statements,
// so two concurrent calls for a not-yet-existing row may both try to create
// it. Whether that is prevented depends on a unique constraint not visible
// here; confirm.
func (r *UsageTrackingRepository) UpdateUsage(ctx context.Context, userID string, limitType domain.UsageLimitType, periodStart time.Time, amount int64) error {
	// Try to get the existing record; "not found" just means we create one.
	usage, err := r.GetByUserIDAndType(ctx, userID, limitType, periodStart)
	if err != nil && !errors.Is(err, ErrNotFound) {
		return err
	}

	if usage != nil {
		// Update existing.
		usage.CurrentUsage = amount
		return r.db.WithContext(ctx).Save(usage).Error
	}

	// Create new. All limit types currently use a monthly period. AddDate
	// normalizes overflowing days, so a periodStart late in a month can land
	// in the month after next; callers pass the first of the month where
	// visible in this file.
	newUsage := &domain.UsageTracking{
		UserID:       userID,
		LimitType:    limitType,
		CurrentUsage: amount,
		PeriodStart:  periodStart,
		PeriodEnd:    periodStart.AddDate(0, 1, 0),
	}
	return r.db.WithContext(ctx).Create(newUsage).Error
}
// GetCurrentPeriodUsage returns the usage-tracking row of userID and limitType
// for the current calendar month.
//
// The period start is midnight on the first day of the current month in the
// location of time.Now(); the lookup, including its ErrNotFound result, is
// delegated to GetByUserIDAndType.
func (r *UsageTrackingRepository) GetCurrentPeriodUsage(ctx context.Context, userID string, limitType domain.UsageLimitType) (*domain.UsageTracking, error) {
	current := time.Now()
	year, month, _ := current.Date()
	monthStart := time.Date(year, month, 1, 0, 0, 0, 0, current.Location())
	return r.GetByUserIDAndType(ctx, userID, limitType, monthStart)
}