go › database › transactions

Working with Database Transactions in Go

Database transactions ensure data consistency and integrity in your applications. This guide covers how to work with transactions effectively in Go.

Basic Transaction Handling

SQL Transactions

package main

import (
    "context"
    "database/sql"
    "errors"
    "fmt"
    "strings"
    "time"
)

// CreateUserWithProfile inserts a user row and its profile row inside one
// transaction, so either both rows are written or neither is.
//
// The id produced by the user INSERT (via RETURNING id) is used as the
// profile's user_id. Errors from Begin, either INSERT, or Commit are returned
// unchanged. The deferred Rollback undoes partial work on every early return.
// After a successful Commit it only returns sql.ErrTxDone, which is ignored.
func CreateUserWithProfile(db *sql.DB, user User, profile Profile) error {
    const insertUser = `
        INSERT INTO users (name, email, created_at)
        VALUES ($1, $2, NOW())
        RETURNING id
    `
    const insertProfile = `
        INSERT INTO profiles (user_id, bio, avatar_url)
        VALUES ($1, $2, $3)
    `

    tx, err := db.Begin()
    if err != nil {
        return err
    }
    defer tx.Rollback()

    var newID int64
    if err := tx.QueryRow(insertUser, user.Name, user.Email).Scan(&newID); err != nil {
        return err
    }

    if _, err := tx.Exec(insertProfile, newID, profile.Bio, profile.AvatarURL); err != nil {
        return err
    }

    return tx.Commit()
}

GORM Transactions

// CreateOrderWithItems creates an order and its line items, then decrements
// the stock of each referenced product. All of this happens inside a single
// GORM transaction, so any returned error rolls back every step.
//
// Each item's OrderID is set to the new order's ID before the item is
// inserted. Because items is a slice, the caller sees those assignments.
// order is passed by value, so the caller's copy does not receive the
// generated ID.
//
// Stock is decremented with a conditional UPDATE (stock >= quantity), so the
// check and the decrement happen in one statement. RowsAffected == 0 covers
// both "not enough stock" and "no such product".
func CreateOrderWithItems(db *gorm.DB, order Order, items []OrderItem) error {
    return db.Transaction(func(tx *gorm.DB) error {
        if err := tx.Create(&order).Error; err != nil {
            return fmt.Errorf("creating order: %w", err)
        }

        for i := range items {
            items[i].OrderID = order.ID
            if err := tx.Create(&items[i]).Error; err != nil {
                return fmt.Errorf("creating order item %d: %w", i, err)
            }
        }

        for _, item := range items {
            result := tx.Model(&Product{}).
                Where("id = ? AND stock >= ?", item.ProductID, item.Quantity).
                UpdateColumn("stock", gorm.Expr("stock - ?", item.Quantity))

            // A failed UPDATE also reports RowsAffected == 0. Check Error
            // first so a database failure is not misreported as low stock.
            if result.Error != nil {
                return fmt.Errorf("updating stock for product %d: %w", item.ProductID, result.Error)
            }
            if result.RowsAffected == 0 {
                return fmt.Errorf("insufficient stock for product %d", item.ProductID)
            }
        }

        return nil
    })
}

Transaction Patterns

1. Repository Pattern with Transactions

// Transaction plumbing used by the repository pattern below.
type (
    // TransactionManager runs a callback inside a database transaction;
    // implementations decide how that transaction is begun and finished.
    TransactionManager interface {
        WithTransaction(fn func(*sql.Tx) error) error
    }

    // SQLTransactionManager opens transactions on a *sql.DB. Its pointer
    // type provides the WithTransaction method.
    SQLTransactionManager struct {
        db *sql.DB
    }
)

// WithTransaction runs fn inside a transaction begun on tm.db.
//
// The transaction is committed only when fn returns nil, and Commit's result
// is returned. If Begin or fn fails, that error is returned unchanged. The
// deferred Rollback then discards the transaction, including when fn panics.
// After a successful Commit, the deferred Rollback has no effect.
func (tm *SQLTransactionManager) WithTransaction(fn func(*sql.Tx) error) error {
    tx, err := tm.db.Begin()
    if err == nil {
        defer tx.Rollback()
        if err = fn(tx); err == nil {
            err = tx.Commit()
        }
    }
    return err
}

// UserRepository persists users through a TransactionManager, so that
// multi-statement operations share a single transaction.
type UserRepository struct {
    tm TransactionManager
}

// CreateUserWithProfile creates user and then profile inside one transaction
// obtained from r.tm.WithTransaction. Before the profile is written, its
// UserID is overwritten with the id returned by r.createUser. An error from
// either step is returned through WithTransaction, and the profile step is
// skipped if creating the user fails.
func (r *UserRepository) CreateUserWithProfile(user User, profile Profile) error {
    createBoth := func(tx *sql.Tx) error {
        newID, err := r.createUser(tx, user)
        if err != nil {
            return err
        }
        profile.UserID = newID
        return r.createProfile(tx, profile)
    }
    return r.tm.WithTransaction(createBoth)
}

2. Nested Transactions (Savepoints)

// TransactionOptions configures a top-level transaction started by
// NestedTransaction. It is not used when running inside a parent
// transaction, because a savepoint runs under the parent's settings.
type TransactionOptions struct {
    Isolation sql.IsolationLevel
    ReadOnly  bool
}

// NestedTransaction runs fn either in a new transaction or, when parent is
// non-nil, inside a savepoint of parent.
//
// With a parent:
//   - fn's work is bracketed by SAVEPOINT / RELEASE SAVEPOINT.
//   - If fn returns an error, ROLLBACK TO SAVEPOINT undoes just that work and
//     the error is returned; the parent transaction stays usable.
//   - The parent itself is never committed or rolled back here.
//
// Without a parent:
//   - A new transaction is begun with opts; nil opts means driver defaults.
//   - It is committed if fn returns nil and rolled back otherwise.
//
// The savepoint statements are standard SQL, as accepted by PostgreSQL,
// MySQL and SQLite.
func (tm *SQLTransactionManager) NestedTransaction(
    parent *sql.Tx,
    opts *TransactionOptions,
    fn func(*sql.Tx) error,
) error {
    if parent != nil {
        return withSavepoint(parent, fn)
    }

    var txOpts *sql.TxOptions
    if opts != nil {
        txOpts = &sql.TxOptions{
            Isolation: opts.Isolation,
            ReadOnly:  opts.ReadOnly,
        }
    }

    tx, err := tm.db.BeginTx(context.Background(), txOpts)
    if err != nil {
        return err
    }
    defer tx.Rollback() // no effect after a successful Commit

    if err := fn(tx); err != nil {
        return err
    }
    return tx.Commit()
}

// withSavepoint runs fn inside a freshly named savepoint of tx. It rolls back
// to the savepoint if fn fails and releases the savepoint if fn succeeds.
func withSavepoint(tx *sql.Tx, fn func(*sql.Tx) error) error {
    // The name only needs to be unique among savepoints open on tx.
    name := fmt.Sprintf("sp_%d", time.Now().UnixNano())
    if _, err := tx.Exec("SAVEPOINT " + name); err != nil {
        return err
    }

    if fnErr := fn(tx); fnErr != nil {
        // Keep fn's error as the primary result. Attach any failure to undo
        // the savepoint rather than hiding either one.
        if _, rbErr := tx.Exec("ROLLBACK TO SAVEPOINT " + name); rbErr != nil {
            return errors.Join(fnErr, fmt.Errorf("rolling back to savepoint %s: %w", name, rbErr))
        }
        return fnErr
    }

    _, err := tx.Exec("RELEASE SAVEPOINT " + name)
    return err
}

3. Retry Logic

// RetryOptions controls WithRetry.
//
// MaxAttempts is the total number of calls to fn; values below 1 are treated
// as 1. Delay is the pause before the second attempt, and each later pause is
// the previous one multiplied by Factor.
type RetryOptions struct {
    MaxAttempts int
    Delay       time.Duration
    Factor      float64
}

// WithRetry calls fn until one of these happens:
//   - fn succeeds;
//   - fn returns an error that isRetryableError rejects;
//   - MaxAttempts calls have been made.
//
// A non-retryable error is returned unchanged. If every attempt fails with a
// retryable error, the last one is returned wrapped with %w, so callers can
// still inspect it with errors.Is / errors.As.
//
// fn must be safe to run more than once. For database work this usually
// means it begins and finishes its own transaction on every call.
func WithRetry(opts RetryOptions, fn func() error) error {
    attempts := opts.MaxAttempts
    if attempts < 1 {
        // Always call fn at least once; a zero RetryOptions means one try.
        attempts = 1
    }
    delay := opts.Delay

    var err error
    for attempt := 1; ; attempt++ {
        err = fn()
        if err == nil {
            return nil
        }
        if !isRetryableError(err) {
            return err
        }
        if attempt >= attempts {
            break
        }
        time.Sleep(delay)
        delay = time.Duration(float64(delay) * opts.Factor)
    }

    return fmt.Errorf("max retry attempts reached: %w", err)
}

// isRetryableError reports whether err looks like a transient failure worth
// retrying: a deadlock, a serialization failure, or a connection problem.
//
// Drivers do not share typed errors for these conditions, so this is a
// case-insensitive substring match on err.Error(). Examples it covers:
//   - MySQL: "Deadlock found when trying to get lock".
//   - PostgreSQL: "deadlock detected" and "could not serialize access ...".
//   - pgx: messages ending in "(SQLSTATE 40001)" or "(SQLSTATE 40P01)".
//
// Caution: a "connection" error raised during COMMIT leaves the outcome
// unknown. Retrying non-idempotent work in that case can apply it twice.
func isRetryableError(err error) bool {
    if err == nil {
        return false
    }

    msg := strings.ToLower(err.Error())
    markers := []string{
        "deadlock",       // PostgreSQL "deadlock detected", MySQL "Deadlock found ..."
        "serializ",       // "serialization failure", PostgreSQL "could not serialize access"
        "sqlstate 40001", // serialization_failure code
        "sqlstate 40p01", // deadlock_detected code
        "connection",
    }
    for _, marker := range markers {
        if strings.Contains(msg, marker) {
            return true
        }
    }
    return false
}

Best Practices

1. Transaction Scope

// TransactionScope owns one *sql.Tx and makes sure it gets finished.
// Typical use:
//
//     scope, err := NewTransactionScope(db)
//     if err != nil { ... }
//     defer scope.Complete()
//     ... run statements on the scope's transaction ...
//     return scope.Commit()
type TransactionScope struct {
    db        *sql.DB
    tx        *sql.Tx
    committed bool // becomes true only after Commit succeeds
}

// NewTransactionScope begins a transaction on db and wraps it in a scope.
// If Begin fails, the error is returned with a nil scope.
func NewTransactionScope(db *sql.DB) (*TransactionScope, error) {
    tx, err := db.Begin()
    if err != nil {
        return nil, err
    }
    scope := &TransactionScope{db: db, tx: tx}
    return scope, nil
}

// Complete rolls the transaction back unless Commit has already succeeded.
// It returns nil when there is no transaction or it was committed.
// Otherwise it returns Rollback's result, which is sql.ErrTxDone if the
// transaction is already finished (an earlier Complete or a failed Commit).
func (ts *TransactionScope) Complete() error {
    if ts.tx != nil && !ts.committed {
        return ts.tx.Rollback()
    }
    return nil
}

// Commit commits the transaction and records that it succeeded, so that a
// later Complete does not roll back. With no transaction it returns nil.
func (ts *TransactionScope) Commit() error {
    if ts.tx == nil {
        return nil
    }
    if err := ts.tx.Commit(); err != nil {
        return err
    }
    ts.committed = true
    return nil
}

2. Transaction Context

// txKey is the unexported context key under which WithTx stores a *sql.Tx.
// A private struct type cannot collide with keys defined by other packages.
type txKey struct{}

// WithTx returns a copy of ctx that carries tx. Code further down the call
// chain can retrieve it with GetTx and reuse it instead of opening its own
// transaction.
func WithTx(ctx context.Context, tx *sql.Tx) context.Context {
    return context.WithValue(ctx, txKey{}, tx)
}

// GetTx returns the transaction stored in ctx by WithTx. ok is false when no
// transaction was stored, or when the stored value is a nil *sql.Tx, so a
// true ok always comes with a usable pointer.
func GetTx(ctx context.Context) (*sql.Tx, bool) {
    tx, ok := ctx.Value(txKey{}).(*sql.Tx)
    // A typed nil passes the type assertion. Report it as absent so callers
    // such as execWithTx never hand a nil *sql.Tx to their callback.
    return tx, ok && tx != nil
}

// Repository runs its statements either in a transaction carried by the
// context or in one it opens itself.
type Repository struct {
    db *sql.DB
}

// execWithTx runs fn inside a transaction.
//
// If ctx already carries a transaction (per GetTx), fn joins it, and
// execWithTx neither commits nor rolls back; that is the owner's job.
// Otherwise a new transaction is begun under ctx, committed if fn returns nil
// and rolled back otherwise.
func (r *Repository) execWithTx(ctx context.Context, fn func(*sql.Tx) error) error {
    if tx, ok := GetTx(ctx); ok {
        return fn(tx)
    }

    // BeginTx (not Begin) ties the new transaction to ctx. A cancelled or
    // expired ctx then aborts it instead of being silently ignored.
    tx, err := r.db.BeginTx(ctx, nil)
    if err != nil {
        return err
    }
    defer tx.Rollback() // no effect after a successful Commit

    if err := fn(tx); err != nil {
        return err
    }
    return tx.Commit()
}

3. Isolation Levels

// TransactionConfig describes how a transaction is opened by
// BeginTxWithConfig / BeginTxWithConfigContext. A positive Timeout bounds the
// whole transaction: once it elapses, database/sql rolls the transaction back.
type TransactionConfig struct {
    Isolation sql.IsolationLevel
    ReadOnly  bool
    Timeout   time.Duration
}

// BeginTxWithConfigContext begins a transaction on db, derived from ctx, with
// the isolation level, read-only flag and optional timeout from config.
//
// The returned CancelFunc releases the timeout's resources. Call it after the
// transaction has been committed or rolled back, typically via defer; calling
// it earlier rolls the transaction back. On success it is never nil, even
// without a timeout. On error both the *sql.Tx and the CancelFunc are nil.
func BeginTxWithConfigContext(ctx context.Context, db *sql.DB, config TransactionConfig) (*sql.Tx, context.CancelFunc, error) {
    cancel := context.CancelFunc(func() {})
    if config.Timeout > 0 {
        ctx, cancel = context.WithTimeout(ctx, config.Timeout)
    }

    tx, err := db.BeginTx(ctx, &sql.TxOptions{
        Isolation: config.Isolation,
        ReadOnly:  config.ReadOnly,
    })
    if err != nil {
        cancel()
        return nil, nil, err
    }
    return tx, cancel, nil
}

// BeginTxWithConfig is BeginTxWithConfigContext with context.Background().
//
// The transaction's context must stay alive until the caller finishes with
// it, so the context is never cancelled on return. Because the CancelFunc
// cannot be handed back, a configured Timeout's timer is only released when
// it fires, at most config.Timeout after the call. Prefer
// BeginTxWithConfigContext in new code.
func BeginTxWithConfig(db *sql.DB, config TransactionConfig) (*sql.Tx, error) {
    tx, _, err := BeginTxWithConfigContext(context.Background(), db, config)
    return tx, err
}

// Usage example
// GetUserBalance reads the balance of userID's account inside a read-only,
// repeatable-read transaction, requesting a 5-second Timeout through
// BeginTxWithConfig.
//
// If the transaction cannot be started or the row cannot be read, it returns
// decimal.Zero with the error; this includes sql.ErrNoRows when no account
// row matches. Otherwise it returns the balance and Commit's result.
//
// NOTE(review): this only works if BeginTxWithConfig keeps the
// transaction's context alive after returning. Confirm against its
// implementation.
func GetUserBalance(db *sql.DB, userID int64) (decimal.Decimal, error) {
    cfg := TransactionConfig{
        Isolation: sql.LevelRepeatableRead,
        ReadOnly:  true,
        Timeout:   5 * time.Second,
    }
    tx, err := BeginTxWithConfig(db, cfg)
    if err != nil {
        return decimal.Zero, err
    }
    defer tx.Rollback()

    row := tx.QueryRow("SELECT balance FROM accounts WHERE user_id = $1", userID)
    var balance decimal.Decimal
    if err := row.Scan(&balance); err != nil {
        return decimal.Zero, err
    }
    return balance, tx.Commit()
}

Common Patterns

1. Unit of Work

// UnitOfWork groups the user, order and product repositories around one
// shared *sql.Tx, so that their changes are committed or rolled back
// together.
//
// Usage: call Begin, work through Users()/Orders()/Products(), then call
// Commit or Rollback. The repository accessors return nil until the first
// Begin.
type UnitOfWork struct {
    db          *sql.DB
    tx          *sql.Tx // open transaction; nil when none is in progress
    userRepo    *UserRepository
    orderRepo   *OrderRepository
    productRepo *ProductRepository
}

// NewUnitOfWork returns a UnitOfWork bound to db with no open transaction.
func NewUnitOfWork(db *sql.DB) *UnitOfWork {
    return &UnitOfWork{db: db}
}

// Begin opens a transaction and rebinds all repositories to it.
//
// It fails if a transaction is already in progress. Silently replacing it
// would abandon the earlier *sql.Tx together with the connection it holds.
func (uow *UnitOfWork) Begin() error {
    if uow.tx != nil {
        return errors.New("unit of work: transaction already in progress")
    }

    tx, err := uow.db.Begin()
    if err != nil {
        return err
    }

    uow.tx = tx
    uow.userRepo = NewUserRepository(tx)
    uow.orderRepo = NewOrderRepository(tx)
    uow.productRepo = NewProductRepository(tx)
    return nil
}

// Commit commits the open transaction and returns Commit's result. With no
// open transaction it returns nil. The transaction is forgotten either way,
// so a deferred Rollback afterwards is a no-op and Begin can be called again.
func (uow *UnitOfWork) Commit() error {
    if uow.tx == nil {
        return nil
    }
    tx := uow.tx
    uow.tx = nil
    return tx.Commit()
}

// Rollback rolls back the open transaction and returns Rollback's result.
// With no open transaction it returns nil. The transaction is forgotten
// either way.
func (uow *UnitOfWork) Rollback() error {
    if uow.tx == nil {
        return nil
    }
    tx := uow.tx
    uow.tx = nil
    return tx.Rollback()
}

// Users returns the user repository bound to the most recent Begin.
func (uow *UnitOfWork) Users() *UserRepository {
    return uow.userRepo
}

// Orders returns the order repository bound to the most recent Begin.
func (uow *UnitOfWork) Orders() *OrderRepository {
    return uow.orderRepo
}

// Products returns the product repository bound to the most recent Begin.
func (uow *UnitOfWork) Products() *ProductRepository {
    return uow.productRepo
}

2. Event Sourcing

// Event-sourcing storage types.
type (
    // Event is one entry in an aggregate's event stream. When stored by
    // SaveEvents, Version is recomputed (any caller-set value is ignored)
    // and ID is not written.
    Event struct {
        ID        string
        Type      string
        Data      []byte
        Timestamp time.Time
        Version   int
    }

    // EventStore appends events to the events table and records each
    // aggregate's latest version in the aggregates table.
    EventStore struct {
        db *sql.DB
    }
)

// SaveEvents appends events to aggregateID's stream and advances the
// aggregate's stored version. The whole transaction is retried through
// WithRetry: up to 3 attempts, sleeping 100ms and then 200ms between them.
//
// Each attempt:
//  1. locks the aggregate row with SELECT ... FOR UPDATE and reads its
//     version (0 if the row does not exist yet);
//  2. inserts the events with consecutive versions version+1, version+2, ...;
//  3. upserts the aggregate row with the new version;
//  4. commits.
//
// Versions are assigned to copies of the events, so the caller's slice is not
// modified. With an empty slice, the aggregate row is still upserted with its
// current version (and created if missing).
//
// NOTE(review): when the aggregate row does not exist yet, FOR UPDATE locks
// nothing, so two concurrent first writers can both start from version 0.
// Presumably a UNIQUE (aggregate_id, version) constraint on events rejects
// the loser; confirm against the schema. isRetryableError's current checks do
// not treat that unique violation as retryable.
func (es *EventStore) SaveEvents(aggregateID string, events []Event) error {
    return WithRetry(RetryOptions{
        MaxAttempts: 3,
        Delay:       100 * time.Millisecond,
        Factor:      2,
    }, func() error {
        tx, err := es.db.Begin()
        if err != nil {
            return err
        }
        defer tx.Rollback()

        // Lock the aggregate row; a missing row leaves version at 0.
        var version int
        err = tx.QueryRow(
            "SELECT version FROM aggregates WHERE id = $1 FOR UPDATE",
            aggregateID,
        ).Scan(&version)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
            return err
        }

        stmt, err := tx.Prepare(`
            INSERT INTO events (aggregate_id, type, data, timestamp, version)
            VALUES ($1, $2, $3, $4, $5)
        `)
        if err != nil {
            return err
        }
        // database/sql also closes tx-scoped statements when tx ends;
        // closing explicitly keeps this attempt's cleanup self-contained.
        defer stmt.Close()

        for i, event := range events {
            event.Version = version + i + 1
            if _, err := stmt.Exec(
                aggregateID,
                event.Type,
                event.Data,
                event.Timestamp,
                event.Version,
            ); err != nil {
                return err
            }
        }

        _, err = tx.Exec(
            `INSERT INTO aggregates (id, version)
             VALUES ($1, $2)
             ON CONFLICT (id) DO UPDATE SET version = $2`,
            aggregateID,
            version+len(events),
        )
        if err != nil {
            return err
        }

        return tx.Commit()
    })
}

Next Steps