Transformer Architecture in Deep Learning
The Transformer architecture, introduced in the paper "Attention Is All You Need" (Vaswani et al., 2017), has revolutionized natural language processing and is increasingly applied in other domains such as computer vision. This guide explores the architecture, its components, and its applications.
Architecture Overview
Core Components
Encoder-Decoder Structure
- Multiple encoder layers
- Multiple decoder layers
- Self-attention mechanisms
- Feed-forward networks
Attention Mechanisms
- Self-attention
- Multi-head attention
- Cross-attention
- Scaled dot-product attention (see the formula below)
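All of these build on scaled dot-product attention. In the notation of the original paper, with queries Q, keys K, values V, and per-head key dimension d_k:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

The 1/sqrt(d_k) scaling keeps the dot products from growing with dimension, which would otherwise push the softmax into regions with very small gradients.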
Implementation Details
Self-Attention Mechanism
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

        self.queries = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Project, then split the embedding dimension into heads
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)
        keys = self.keys(key).reshape(N, key_len, self.heads, self.head_dim)
        values = self.values(value).reshape(N, value_len, self.heads, self.head_dim)

        # Scaled dot-product attention: energy has shape (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Scale by the square root of the per-head dimension, as in the original paper
        attention = torch.softmax(energy / math.sqrt(self.head_dim), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        out = out.reshape(N, query_len, self.embed_size)
        return self.fc_out(out)
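A quick usage sketch with illustrative sizes (a batch of 2 sequences of length 10, embedding size 256, 8 heads); for self-attention the same tensor is passed as query, key, and value:

attn = SelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)
out = attn(x, x, x)       # self-attention: query, key, and value are the same tensor
print(out.shape)          # torch.Size([2, 10, 256])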
Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask=None):
        # Residual connection around attention, followed by layer normalization (post-norm)
        attention = self.attention(query, key, value, mask)
        x = self.dropout(self.norm1(attention + query))
        # Residual connection around the position-wise feed-forward network
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
Key Concepts
Positional Encoding
- Fixed sinusoidal encodings, as implemented below
def get_positional_encoding(max_seq_len, embed_size):
    pos = torch.arange(max_seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_size, 2) * -(math.log(10000.0) / embed_size))
    pe = torch.zeros(max_seq_len, embed_size)
    pe[:, 0::2] = torch.sin(pos * div_term)
    pe[:, 1::2] = torch.cos(pos * div_term)
    return pe
- Learned Positional Embeddings (see the sketch after this list)
  - Trainable parameters, learned jointly with the model
  - Tied to a fixed maximum sequence length
  - Task-specific learning
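A minimal sketch of learned positional embeddings; the module name and shapes are illustrative, not from a specific library:

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, embed_size):
        super().__init__()
        # One trainable vector per position, learned jointly with the rest of the model
        self.pos_embed = nn.Embedding(max_seq_len, embed_size)

    def forward(self, x):
        # x: (batch, seq_len, embed_size); seq_len must not exceed max_seq_len
        positions = torch.arange(x.shape[1], device=x.device)
        return x + self.pos_embed(positions)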
Multi-Head Attention
Purpose
- Parallel attention computation
- Different representation subspaces
- Enhanced feature capture
Implementation
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Single projection producing queries, keys, and values in one pass
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # Each of q, k, v has shape (batch_size, num_heads, seq_len, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, v)
        out = out.transpose(1, 2).contiguous()
        out = out.reshape(batch_size, seq_len, self.d_model)
        return self.proj(out)
Applications
Natural Language Processing
Machine Translation
- Encoder-decoder architecture (see the sketch after this list)
- Language understanding
- Generation quality
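As an illustration (separate from the implementations above), PyTorch ships a ready-made encoder-decoder transformer; a minimal sketch with illustrative dimensions, reusing the imports from earlier:

model = nn.Transformer(d_model=512, nhead=8,
                       num_encoder_layers=6, num_decoder_layers=6,
                       batch_first=True)
src = torch.randn(2, 20, 512)   # (batch, source_len, d_model): already-embedded source tokens
tgt = torch.randn(2, 15, 512)   # (batch, target_len, d_model): already-embedded target prefix
out = model(src, tgt)           # (2, 15, 512); a final linear layer would map this to vocabulary logits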
Text Generation
- Autoregressive generation
- Beam search
- Temperature sampling (see the sketch after this list)
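A minimal sketch of autoregressive decoding with temperature sampling; `model` is a hypothetical decoder assumed to return next-token logits of shape (batch, seq_len, vocab_size), and `eos_id` is an assumed end-of-sequence token ID:

def sample_with_temperature(model, input_ids, max_new_tokens=50, temperature=0.8, eos_id=None):
    # input_ids: (1, seq_len) tensor of token IDs
    for _ in range(max_new_tokens):
        logits = model(input_ids)                      # (1, seq_len, vocab_size), by assumption
        next_logits = logits[:, -1, :] / temperature   # lower temperature -> sharper distribution
        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if eos_id is not None and next_token.item() == eos_id:
            break
    return input_ids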
Vision Transformers
Image Processing
- Patch embeddings
- Position encodings
- Attention patterns
Vision Tasks
from einops import rearrange  # third-party dependency used for patch extraction

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        self.patch_embed = nn.Linear(patch_dim, dim)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Placeholder: a stack of encoder layers (e.g. TransformerBlock above); not defined in this guide
        self.transformer = TransformerEncoder(dim, depth=12)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        p = self.patch_size
        # Split the image into non-overlapping p x p patches and flatten each one
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.transformer(x)
        # Classify from the [CLS] token representation
        return self.mlp_head(x[:, 0])
Training Considerations
Optimization
Learning Rate
- Warm-up schedule (see the sketch after this list)
- Decay strategies
- Adaptive methods
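A minimal sketch of the warm-up plus inverse square-root schedule from the original paper, wired into a standard PyTorch scheduler; `model`, d_model, and warmup_steps are illustrative placeholders:

d_model, warmup_steps = 512, 4000
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)

def noam_lambda(step):
    step = max(step, 1)
    # Linear warm-up for warmup_steps, then decay proportional to 1/sqrt(step)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)
# Call scheduler.step() once per optimization step, after optimizer.step()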
Regularization
- Dropout
- Layer normalization
- Weight decay (see the sketch after this list)
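Dropout and layer normalization appear in the blocks above; weight decay is typically applied through the optimizer. A minimal sketch with illustrative hyperparameters and a hypothetical `model`:

# Decoupled weight decay via AdamW; values are illustrative, not tuned
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)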
Memory Management
Efficient Implementation
- Gradient checkpointing
- Mixed precision training (see the sketch under Resource Optimization)
- Model parallelism
Resource Optimization
# Gradient checkpointing example: recompute activations in backward to save memory
from torch.utils.checkpoint import checkpoint

def forward_with_checkpoint(self, x):
    # In recent PyTorch versions, passing use_reentrant=False is recommended
    return checkpoint(self.transformer_block, x)
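Mixed precision training can be sketched with PyTorch automatic mixed precision; `model`, `optimizer`, and `dataloader` are hypothetical placeholders:

scaler = torch.cuda.amp.GradScaler()
for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():           # run the forward pass in reduced precision
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)
    scaler.scale(loss).backward()             # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()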
Advanced Topics
Variants and Improvements
Efficient Attention
- Linear attention
- Sparse attention
- Local attention (see the mask sketch after this list)
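As one concrete example, local (sliding-window) attention restricts each position to a fixed neighborhood; with the SelfAttention module above, this can be sketched as a banded boolean mask (the window size is illustrative):

def local_attention_mask(seq_len, window):
    # True where position j is within `window` tokens of position i
    idx = torch.arange(seq_len)
    return (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= window

# Hypothetical usage with the SelfAttention module defined earlier:
# mask = local_attention_mask(seq_len=10, window=2)   # (10, 10) boolean mask
# out = attn(x, x, x, mask=mask)                      # attention outside the window is masked out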
Architecture Modifications
- Reformer
- Performer
- Linformer
Scaling Considerations
Model Size
- Parameter efficiency
- Computational complexity
- Memory requirements
Training Strategies
- Distributed training (see the sketch after this list)
- Pipeline parallelism
- Zero redundancy optimizer
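A minimal distributed data parallel sketch, assuming one process per GPU launched with torchrun and a `model` defined elsewhere:

import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")       # torchrun sets the required environment variables
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
ddp_model = DDP(model.to(local_rank), device_ids=[local_rank])
# Gradients are averaged across processes automatically during backward()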
Best Practices
Implementation Tips
Code Organization
- Modular design
- Clear interfaces
- Reusable components
Performance Optimization
- Batch processing
- Caching mechanisms (see the key/value cache sketch after this list)
- Memory management
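A transformer-specific caching example is the key/value cache used during autoregressive decoding, so that keys and values for already-generated tokens are not recomputed at every step; a minimal, hypothetical sketch:

class KVCache:
    def __init__(self):
        self.k = None
        self.v = None

    def update(self, new_k, new_v):
        # new_k, new_v: (batch, heads, new_tokens, head_dim) for the newest token(s)
        self.k = new_k if self.k is None else torch.cat([self.k, new_k], dim=2)
        self.v = new_v if self.v is None else torch.cat([self.v, new_v], dim=2)
        return self.k, self.v

# At each decoding step, attention runs between the newest query and the cached
# keys/values, instead of re-running attention over the full sequence.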
Common Pitfalls
Training Issues
- Vanishing gradients
- Attention collapse
- Overfitting
Implementation Challenges
- Memory constraints
- Numerical stability
- Debugging complexity
Future Directions
Research Areas
Architecture Improvements
- Efficiency enhancements
- Scale optimization
- Task adaptation
Applications
- Cross-modal learning
- Domain adaptation
- Few-shot learning
Resources
Learning Materials
Papers
- "Attention Is All You Need"
- Key transformer variants
- Application studies
Implementations
- PyTorch
- TensorFlow
- JAX
Remember that the transformer is a rapidly evolving family of architectures, with new variants and applications emerging regularly. Stay up to date with the latest research while maintaining a solid understanding of the fundamentals.