Transformer Architecture in Deep Learning

The Transformer architecture, introduced in the paper "Attention Is All You Need" (Vaswani et al., 2017), has revolutionized natural language processing and beyond. This guide explores the architecture, its components, and applications.

Architecture Overview

Core Components

  1. Encoder-Decoder Structure

    • Multiple encoder layers
    • Multiple decoder layers
    • Self-attention mechanisms
    • Feed-forward networks
  2. Attention Mechanisms

    • Self-attention
    • Multi-head attention
    • Cross-attention
    • Scaled dot-product attention

Implementation Details

Self-Attention Mechanism

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Project, then split the embedding dimension into heads
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)
        keys = self.keys(key).reshape(N, key_len, self.heads, self.head_dim)
        values = self.values(value).reshape(N, value_len, self.heads, self.head_dim)

        # Attention scores per head: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            # Positions where mask == 0 may not be attended to
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), the per-head dimension, as in the original paper
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        # Weighted sum of values: (N, query_len, heads, head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])

        # Concatenate the heads and apply the output projection
        out = out.reshape(N, query_len, self.embed_size)
        return self.fc_out(out)
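
A quick shape check for the module above; passing the same tensor as query, key, and value gives plain self-attention (the sizes are illustrative):

attn = SelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)  # (batch, seq_len, embed_size)
out = attn(x, x, x)
print(out.shape)  # torch.Size([2, 10, 256])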

Transformer Block

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Position-wise feed-forward network with an expansion factor
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask=None):
        # Attention sublayer with residual connection and layer norm (post-norm)
        attention = self.attention(query, key, value, mask)
        x = self.dropout(self.norm1(attention + query))
        # Feed-forward sublayer with its own residual connection and norm
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
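
The block above normalizes after each residual sum (post-norm, as in the original paper). Many later implementations use a pre-norm ordering instead, which tends to train more stably at depth. A minimal sketch of that alternative forward pass, assuming self-attention (value, key, and query are the same tensor) and the same submodules:

def forward_pre_norm(self, x, mask=None):
    # Normalize before each sublayer; keep the residual path unnormalized
    h = self.norm1(x)
    x = x + self.dropout(self.attention(h, h, h, mask))
    x = x + self.dropout(self.feed_forward(self.norm2(x)))
    return x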

Key Concepts

Positional Encoding

  1. Fixed Encodings
def get_positional_encoding(max_seq_len, embed_size):
    # Sinusoidal encodings: each position gets sine/cosine values at
    # geometrically spaced frequencies (uses the math import above)
    pos = torch.arange(max_seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_size, 2) * -(math.log(10000.0) / embed_size))
    pe = torch.zeros(max_seq_len, embed_size)
    pe[:, 0::2] = torch.sin(pos * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * div_term)  # odd dimensions
    return pe
  2. Learned Positional Embeddings (sketched below)
    • Trainable parameters
    • Tied to a fixed maximum sequence length
    • Task-specific learning
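
A minimal sketch of the learned alternative, assuming the maximum sequence length is fixed at construction time:

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, embed_size):
        super().__init__()
        # One trainable vector per position, learned jointly with the model
        self.pos_embed = nn.Embedding(max_seq_len, embed_size)

    def forward(self, x):
        # x: (batch, seq_len, embed_size)
        positions = torch.arange(x.shape[1], device=x.device)
        return x + self.pos_embed(positions)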

Multi-Head Attention

  1. Purpose

    • Parallel attention computation
    • Different representation subspaces
    • Enhanced feature capture
  2. Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Single projection producing queries, keys, and values in one pass
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # (3, batch, heads, seq, head_dim) -> unpack into q, k, v
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # Scaled dot-product attention within each head
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = F.softmax(scores, dim=-1)

        # Weighted values, then merge the heads back into the model dimension
        out = torch.matmul(attention, v)
        out = out.transpose(1, 2).contiguous()
        out = out.reshape(batch_size, seq_len, self.d_model)
        return self.proj(out)
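
A causal (look-ahead) mask keeps each position from attending to later positions during autoregressive decoding. A minimal sketch using the convention above, where 0 marks blocked positions; the sizes are illustrative:

seq_len, d_model = 10, 256
mha = MultiHeadAttention(d_model, num_heads=8)
x = torch.randn(2, seq_len, d_model)

# Lower-triangular mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
out = mha(x, mask=causal_mask)  # (2, 10, 256)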

Applications

Natural Language Processing

  1. Machine Translation

    • Encoder-decoder architecture
    • Language understanding
    • Generation quality
  2. Text Generation

    • Autoregressive generation
    • Beam search
    • Temperature sampling (see the sketch after this list)
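
A minimal sketch of temperature sampling, assuming logits holds unnormalized next-token scores from a language model; higher temperatures flatten the distribution, lower temperatures sharpen it:

def sample_next_token(logits, temperature=1.0):
    # logits: (batch, vocab_size) unnormalized scores for the next token
    probs = F.softmax(logits / temperature, dim=-1)
    # Draw one token id per batch element from the rescaled distribution
    return torch.multinomial(probs, num_samples=1)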

Vision Transformers

  1. Image Processing

    • Patch embeddings
    • Position encodings
    • Attention patterns
  2. Vision Tasks

# Requires the einops package for the patch rearrangement below
from einops import rearrange

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth=12, heads=8):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.patch_embed = nn.Linear(patch_dim, dim)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # PyTorch's built-in encoder stack stands in for the transformer backbone
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # Split the image into flattened patches: (batch, num_patches, patch_dim)
        p = self.patch_size
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=p, p2=p)
        x = self.patch_embed(x)

        # Prepend the learnable [CLS] token and add positional embeddings
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        x = self.transformer(x)
        # Classify from the [CLS] token representation
        return self.mlp_head(x[:, 0])
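
A quick shape check for the sketch above, using illustrative ImageNet-style settings:

vit = VisionTransformer(image_size=224, patch_size=16, num_classes=1000, dim=768)
images = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
logits = vit(images)
print(logits.shape)  # torch.Size([2, 1000])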

Training Considerations

Optimization

  1. Learning Rate

    • Warm-up schedule (sketched after this list)
    • Decay strategies
    • Adaptive methods
  2. Regularization

    • Dropout
    • Layer normalization
    • Weight decay
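
A minimal sketch of the warm-up schedule from the original paper (linear warm-up, then inverse square-root decay), wired up with torch.optim.lr_scheduler.LambdaLR; the d_model and warmup_steps values are illustrative, and the earlier TransformerBlock stands in for a full model:

def noam_lr_factor(step, d_model=256, warmup_steps=4000):
    # Learning-rate factor from "Attention Is All You Need":
    # linear warm-up followed by inverse square-root decay
    step = max(step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

model = TransformerBlock(embed_size=256, heads=8, dropout=0.1, forward_expansion=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, noam_lr_factor)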

Memory Management

  1. Efficient Implementation

    • Gradient checkpointing
    • Mixed precision training (see the sketch below)
    • Model parallelism
  2. Resource Optimization

# Gradient checkpointing example: recompute activations during the backward
# pass instead of storing them, trading compute for memory
from torch.utils.checkpoint import checkpoint

def forward_with_checkpoint(self, x):
    # use_reentrant=False is the recommended mode in recent PyTorch releases
    return checkpoint(self.transformer_block, x, use_reentrant=False)
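
A minimal mixed-precision training step with torch.cuda.amp; the model, optimizer, loss function, and data passed in are assumed to exist elsewhere:

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, loss_fn, x, target):
    # Run the forward pass in reduced precision where it is safe to do so
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(x, x, x)  # TransformerBlock takes (value, key, query)
        loss = loss_fn(output, target)
    # Scale the loss to avoid float16 gradient underflow, then unscale and step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()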

Advanced Topics

Variants and Improvements

  1. Efficient Attention

    • Linear attention (sketched after this list)
    • Sparse attention
    • Local attention
  2. Architecture Modifications

    • Reformer
    • Performer
    • Linformer
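
A minimal sketch of the linear-attention idea (in the spirit of Katharopoulos et al.'s linear transformers, using an elu + 1 feature map): replacing softmax with a feature map lets the key-value product be aggregated once, so the cost grows linearly rather than quadratically with sequence length. Shapes and names below are illustrative:

def linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, seq_len, head_dim)
    # Feature map phi(x) = elu(x) + 1 keeps attention weights positive
    q, k = F.elu(q) + 1, F.elu(k) + 1
    # Aggregate keys and values once: (batch, heads, head_dim, head_dim)
    kv = torch.einsum("bhsd,bhse->bhde", k, v)
    # Per-query normalizer: (batch, heads, seq_len)
    z = 1 / (torch.einsum("bhsd,bhd->bhs", q, k.sum(dim=2)) + eps)
    # Output: (batch, heads, seq_len, head_dim)
    return torch.einsum("bhsd,bhde,bhs->bhse", q, kv, z)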

Scaling Considerations

  1. Model Size

    • Parameter efficiency
    • Computational complexity
    • Memory requirements
  2. Training Strategies

    • Distributed training
    • Pipeline parallelism
    • Zero redundancy optimizer

Best Practices

Implementation Tips

  1. Code Organization

    • Modular design
    • Clear interfaces
    • Reusable components
  2. Performance Optimization

    • Batch processing
    • Caching mechanisms
    • Memory management

Common Pitfalls

  1. Training Issues

    • Vanishing gradients
    • Attention collapse
    • Overfitting
  2. Implementation Challenges

    • Memory constraints
    • Numerical stability
    • Debugging complexity

Future Directions

Research Areas

  1. Architecture Improvements

    • Efficiency enhancements
    • Scale optimization
    • Task adaptation
  2. Applications

    • Cross-modal learning
    • Domain adaptation
    • Few-shot learning

Resources

Learning Materials

  1. Papers

    • "Attention Is All You Need"
    • Key transformer variants
    • Application studies
  2. Implementations

    • PyTorch
    • TensorFlow
    • JAX

Remember that the Transformer is a rapidly evolving architecture, with new variants and applications emerging regularly. Stay updated with the latest research while maintaining a solid understanding of the fundamentals.