How to Build an LLM from Scratch with PyTorch
Practical tutorial: build a GPT-style decoder-only language model in PyTorch, from multi-head attention through training and text generation.
Large language models have transformed how we interact with technology, but understanding their inner workings remains a challenge for many developers. According to Wikipedia, a large language model (LLM) is a neural network trained on a vast amount of text for natural language processing tasks, especially language generation. While most practitioners use pre-trained models through APIs, building one from scratch provides invaluable insight into the architecture, training dynamics, and limitations of these systems.
In this tutorial, we'll implement a ChatGPT-like [5] LLM in PyTorch from scratch, step by step, following the approach of the popular GitHub repository "LLMs-from-scratch", which had 87,799 stars and 13,374 forks as of May 2026. The repository, written as Jupyter notebooks, provides a comprehensive guide to the fundamental components of modern LLMs.
Understanding the Architecture: From Attention to Generation
Before diving into code, we need to understand the core components that make LLMs work. The transformer architecture, introduced in 2017, revolutionized natural language processing by replacing recurrent neural networks with attention mechanisms. Our implementation will focus on the decoder-only architecture used by GPT models, which consists of:
- Token Embedding Layer: Converts discrete tokens into continuous vector representations
- Positional Encoding: Adds information about token positions in the sequence
- Multi-Head Self-Attention: Allows the model to weigh the importance of different tokens
- Feed-Forward Networks: Processes attention outputs through non-linear transformations
- Layer Normalization: Stabilizes training by normalizing activations
- Output Projection: Maps the final hidden states to vocabulary logits, which a softmax converts into next-token probabilities
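The key insight is that LLMs are trained only to predict the next token in a sequence. At inference time, feeding each prediction back in as input (autoregressive decoding) lets them generate coherent text. Decoding is computationally expensive because every new token requires another forward pass, but this simple objective is enough for the model to capture complex patterns in language.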
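Concretely, next-token prediction turns any token sequence into training pairs by shifting it one position: the target at position t is the input token at position t + 1. Here is a minimal sketch using made-up token IDs (the values are arbitrary and chosen only for illustration):

```python
import torch

# Arbitrary token IDs standing in for a tokenized sentence (illustrative only).
tokens = torch.tensor([464, 2068, 7586, 21831, 18045, 625])

# Inputs are every token except the last; targets are every token except the first.
inputs = tokens[:-1]
targets = tokens[1:]

# At each position t, the model sees inputs[: t + 1] and must predict targets[t].
for t in range(len(inputs)):
    context = inputs[: t + 1].tolist()
    print(f"context={context} -> next token {targets[t].item()}")
```

The `get_batch` function later in this tutorial applies the same one-token shift to random windows of the training data.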
Prerequisites and Environment Setup
We'll need a modern Python environment with PyTorch [6] and supporting libraries. Let's set up our development environment:
# Create a virtual environment
python -m venv llm_env
source llm_env/bin/activate # On Windows: llm_env\Scripts\activate
# Install core dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install numpy tqdm matplotlib wandb
# For tokenization and data processing
pip install tiktoken datasets
Our implementation will use the following key components:
- PyTorch 2.x: For tensor operations and automatic differentiation
- tiktoken: OpenAI's fast BPE tokenizer [8]
- datasets: Hugging Face's dataset library for training data
- wandb: Optional experiment tracking
Core Implementation: Building the Transformer Block
Let's start with the fundamental building block of our LLM: the transformer decoder layer. This implementation follows the general architecture described in the "LLMs-from-scratch" repository, a widely used reference for understanding LLM internals.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    """
    Implements multi-head scaled dot-product attention.
    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        dropout: Dropout probability
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head
        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        # Project and reshape for multi-head attention
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Apply attention to values
        context = torch.matmul(attention_weights, V)
        # Reshape back to original dimensions
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        return self.w_o(context)
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network with GELU activation.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))
class TransformerDecoderLayer(nn.Module):
    """
    Single transformer decoder layer with pre-layer normalization.
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-layer normalization (more stable than post-norm)
        attn_output = self.self_attention(self.norm1(x), mask)
        x = x + self.dropout(attn_output)
        ff_output = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_output)
        return x
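This implementation uses pre-layer normalization, where LayerNorm is applied before each sublayer inside the residual branch. This arrangement generally trains more stably than the post-norm arrangement of the original Transformer. The MultiHeadAttention class accepts an optional mask. Passing a lower-triangular (causal) mask sets the scores for future positions to -inf before the softmax, so each token can attend only to itself and earlier tokens, as causal language modeling requires.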
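To see what the causal mask does, the sketch below builds the same lower-triangular mask that `generate` uses. It applies the mask to random attention scores exactly as `MultiHeadAttention.forward` does, then checks that no position attends to a later one. The shapes are small and arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 5

# Lower-triangular mask: 1 where query position i may attend to key position j (j <= i).
mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

# Random scores shaped (batch, heads, seq_len, seq_len), as produced by Q @ K^T / sqrt(d_k).
scores = torch.randn(2, 3, seq_len, seq_len)
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)

# Entries above the diagonal (future tokens) receive exactly zero weight.
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(weights[0, 0])
```

Because exp(-inf) is 0, masked positions drop out of the softmax entirely, and each row still sums to 1 over the visible positions.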
Building the Complete Language Model
Now let's assemble the full model with token embeddings, positional encodings, and the output projection layer:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding as described in "Attention Is All You Need".
    """
    def __init__(self, d_model: int, max_seq_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1), :]
class GPTModel(nn.Module):
    """
    Decoder-only transformer model for language modeling.
    Args:
        vocab_size: Size of vocabulary
        d_model: Model dimension
        n_layers: Number of transformer layers
        n_heads: Number of attention heads
        d_ff: Feed-forward dimension
        max_seq_len: Maximum sequence length
        dropout: Dropout probability
    """
    def __init__(self, vocab_size: int, d_model: int = 768, n_layers: int = 12,
                 n_heads: int = 12, d_ff: int = 3072, max_seq_len: int = 2048,
                 dropout: float = 0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
        # Tie weights between embedding and output projection
        self.token_embedding.weight = self.output_projection.weight
        self._init_weights()
    def _init_weights(self):
        """Initialize weights using small normal distribution."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        if mask is None:
            # Default to a causal mask so each position attends only to itself and
            # earlier positions; without it, training could "see" the target tokens
            seq_len = x.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
        # Token embeddings + positional encoding
        x = self.token_embedding(x)
        x = self.positional_encoding(x)
        x = self.dropout(x)
        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        # Final normalization and output projection
        x = self.final_norm(x)
        logits = self.output_projection(x)
        return logits
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
        """
        Autoregressive text generation.
        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random); must be > 0
            top_k: If set, only sample from top-k tokens
        """
        self.eval()
        max_len = self.positional_encoding.pe.size(1)
        for _ in range(max_new_tokens):
            # Crop the model's context to the maximum sequence length,
            # while still returning the full generated sequence
            idx_cond = input_ids[:, -max_len:]
            # Create causal mask
            seq_len = idx_cond.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).view(1, 1, seq_len, seq_len)
            # Forward pass
            with torch.no_grad():
                logits = self.forward(idx_cond, mask)
            # Get logits for the last token
            logits = logits[:, -1, :] / temperature
            # Apply top-k filtering
            if top_k is not None:
                top_k_values, _ = torch.topk(logits, top_k, dim=-1)
                min_top_k = top_k_values[:, -1].unsqueeze(-1)
                logits[logits < min_top_k] = float('-inf')
            # Sample from the distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
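The sampling step at the end of `generate` can be examined on its own. The sketch below applies the same temperature division and top-k filtering to a single hand-picked logit vector (the numbers are arbitrary). It shows that only the k highest-scoring tokens keep nonzero probability, and that lower temperatures sharpen the distribution. The helper name `filtered_probs` is hypothetical.

```python
import torch
import torch.nn.functional as F

# Hand-picked logits for a 6-token vocabulary, batch size 1 (illustrative values).
logits = torch.tensor([[2.0, 1.0, 0.5, 3.0, -1.0, 0.0]])

def filtered_probs(logits, temperature=1.0, top_k=None):
    """Temperature scaling plus top-k filtering, mirroring GPTModel.generate."""
    logits = logits / temperature  # creates a new tensor, so the input is not modified
    if top_k is not None:
        top_k_values, _ = torch.topk(logits, top_k, dim=-1)
        min_top_k = top_k_values[:, -1].unsqueeze(-1)
        logits[logits < min_top_k] = float("-inf")
    return F.softmax(logits, dim=-1)

probs_k2 = filtered_probs(logits, top_k=2)
probs_cold = filtered_probs(logits, temperature=0.5)
probs_hot = filtered_probs(logits, temperature=2.0)

# Sample one token id from the filtered distribution, as generate() does.
next_token = torch.multinomial(probs_k2, num_samples=1)
print(probs_k2, next_token)
```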
The GPTModel class implements weight tying between the embedding layer and output projection, which reduces the number of parameters and improves training efficiency. The generate method implements autoregressive decoding with temperature scaling and top-k sampling for controlled text generation.
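The parameter saving from weight tying is easy to quantify. The embedding matrix and the output projection both have shape (vocab_size, d_model), so sharing them removes vocab_size × d_model parameters. For the GPT-2 vocabulary and d_model = 768, that is 50,257 × 768 = 38,597,376. The standalone sketch below uses a deliberately small toy vocabulary and dimension (arbitrary values) and a hypothetical `TinyLM` module. It relies on the fact that `Module.parameters()` yields a shared parameter only once.

```python
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy embedding + output projection, optionally with tied weights (illustrative)."""
    def __init__(self, vocab_size: int, d_model: int, tie: bool):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)  # weight: (vocab, d_model)
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)  # weight: (vocab, d_model)
        if tie:
            # Same assignment as in GPTModel: both modules now share one Parameter.
            self.token_embedding.weight = self.output_projection.weight

def count_params(model: nn.Module) -> int:
    # parameters() skips duplicates of a shared Parameter, so tied weights count once.
    return sum(p.numel() for p in model.parameters())

vocab_size, d_model = 1000, 64
untied = count_params(TinyLM(vocab_size, d_model, tie=False))
tied = count_params(TinyLM(vocab_size, d_model, tie=True))
print(untied, tied, untied - tied)
```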
Training the Model: Data Preparation and Training Loop
To train our model, we need to prepare a dataset and implement an efficient training loop. Let's use the Tiny Shakespeare dataset for demonstration:
import tiktoken
from datasets import load_dataset
def prepare_data():
    """
    Load and tokenize the Tiny Shakespeare dataset.
    """
    # Load dataset (whether this hub ID loads depends on your `datasets` version;
    # any plain-text corpus can be substituted)
    dataset = load_dataset("tiny_shakespeare", split="train")
    # Initialize tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")
    # Tokenize the entire dataset
    text = "\n\n".join(dataset["text"])
    tokens = tokenizer.encode(text)
    # Convert to tensor
    data = torch.tensor(tokens, dtype=torch.long)
    # Create training and validation splits (90% / 10%)
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
    return train_data, val_data, tokenizer
def get_batch(data: torch.Tensor, batch_size: int, block_size: int):
    """
    Get a random batch of sequences; targets are the inputs shifted by one token.
    """
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y
def train_model(model: GPTModel, train_data: torch.Tensor, val_data: torch.Tensor,
                num_epochs: int = 10, batch_size: int = 8, block_size: int = 256,
                learning_rate: float = 3e-4):
    """
    Training loop with learning rate scheduling and gradient clipping.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Causal mask: each position may only attend to itself and earlier positions
    causal_mask = torch.tril(torch.ones(block_size, block_size, device=device)).view(1, 1, block_size, block_size)
    steps_per_epoch = len(train_data) // (batch_size * block_size)
    val_steps = 100
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for step in range(steps_per_epoch):
            x, y = get_batch(train_data, batch_size, block_size)
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x, causal_mask)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
        scheduler.step()
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for _ in range(val_steps):
                x, y = get_batch(val_data, batch_size, block_size)
                x, y = x.to(device), y.to(device)
                logits = model(x, causal_mask)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1)
                )
                val_loss += loss.item()
        # Average over the number of batches actually processed
        print(f"Epoch {epoch}, Train Loss: {total_loss/steps_per_epoch:.4f}, Val Loss: {val_loss/val_steps:.4f}")
# Initialize and train the model
vocab_size = 50257  # GPT-2 vocabulary size
model = GPTModel(
    vocab_size=vocab_size,
    d_model=768,
    n_layers=12,
    n_heads=12,
    d_ff=3072,
    max_seq_len=2048,
    dropout=0.1
)
train_data, val_data, tokenizer = prepare_data()
train_model(model, train_data, val_data)
This training loop implements several best practices:
- Gradient clipping to prevent exploding gradients
- Cosine annealing learning rate schedule for better convergence
- Weight decay for regularization
- Efficient batching with random sequence sampling
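In `train_model`, the scheduler is created with `T_max=num_epochs` and stepped once per epoch. With that setup, `CosineAnnealingLR` decays the learning rate from its initial value along half a cosine toward `eta_min` (0 by default): lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2. The sketch below records that schedule for a single placeholder parameter, using the tutorial's defaults (lr = 3e-4, 10 epochs). It then compares the recorded values against the closed-form curve.

```python
import math
import torch

lr0, num_epochs = 3e-4, 10
param = torch.nn.Parameter(torch.zeros(1))  # placeholder parameter, not a real model
optimizer = torch.optim.AdamW([param], lr=lr0, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

lrs = []
for epoch in range(num_epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()   # in real training, the per-batch updates happen here
    scheduler.step()   # decay once per epoch, as in train_model

# Closed-form cosine schedule for comparison (eta_min = 0).
expected = [lr0 * (1 + math.cos(math.pi * t / num_epochs)) / 2 for t in range(num_epochs)]
print([f"{lr:.2e}" for lr in lrs])
```

Because the scheduler steps once per epoch, the learning rate stays constant within each epoch. Stepping it per batch, with `T_max` set to the total number of steps, gives a smoother decay.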
Edge Cases and Production Considerations
When deploying LLMs in production, several edge cases require careful handling:
1. Memory Management
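Large models can easily exceed GPU memory. Two standard remedies are gradient checkpointing, which recomputes activations during the backward pass instead of storing them, and mixed precision training. Here is a mixed-precision version of the inner training step from train_model. It assumes a CUDA GPU and the variables defined inside that function:
import torch
import torch.nn.functional as F
# torch.amp.GradScaler("cuda") needs PyTorch 2.3+; older 2.x releases use torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler("cuda")
causal_mask = torch.tril(torch.ones(block_size, block_size, device=device)).view(1, 1, block_size, block_size)
for step in range(len(train_data) // (batch_size * block_size)):
    x, y = get_batch(train_data, batch_size, block_size)
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):  # float16 where numerically safe
        logits = model(x, causal_mask)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
    scaler.unscale_(optimizer)  # unscale so clipping sees the true gradient norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skips the update if gradients contain inf/NaN
    scaler.update()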
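Gradient checkpointing can be added per layer with `torch.utils.checkpoint.checkpoint`. Activations inside each wrapped block are discarded after the forward pass and recomputed during backward, which trades extra compute for lower memory use. In the sketch below, the toy `Block` and `Stack` modules are hypothetical stand-ins for `TransformerDecoderLayer` and `GPTModel`, not part of the tutorial. The sketch checks that checkpointing leaves the gradients unchanged.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Toy residual feed-forward block standing in for a transformer layer."""
    def __init__(self, d: int):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
    def forward(self, x):
        return x + self.ff(self.norm(x))

class Stack(nn.Module):
    def __init__(self, d: int, n_layers: int, use_checkpointing: bool):
        super().__init__()
        self.layers = nn.ModuleList([Block(d) for _ in range(n_layers)])
        self.use_checkpointing = use_checkpointing
    def forward(self, x):
        for layer in self.layers:
            if self.use_checkpointing and self.training:
                # Keep only the block input; recompute the block internals during backward.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

torch.manual_seed(0)
plain = Stack(d=32, n_layers=4, use_checkpointing=False)
ckpt = Stack(d=32, n_layers=4, use_checkpointing=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights for a fair comparison

x = torch.randn(2, 8, 32)
plain(x).pow(2).mean().backward()
ckpt(x).pow(2).mean().backward()

# Checkpointing changes memory use, not the math: gradients should match.
grads_match = all(
    torch.allclose(p.grad, q.grad, atol=1e-6)
    for p, q in zip(plain.parameters(), ckpt.parameters())
)
print(grads_match)
```

In `GPTModel.forward`, the equivalent change would be `x = checkpoint(layer, x, mask, use_reentrant=False)` inside the layer loop during training.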
2. Tokenization Edge Cases
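Byte-level BPE tokenizers such as tiktoken's GPT-2 encoding never produce out-of-vocabulary tokens, because any text can be encoded as bytes. However, empty input, overly long input, and text that happens to contain special-token strings such as <|endoftext|> still need handling:
import torch
def safe_tokenize(text: str, tokenizer, max_length: int = 2048):
    """Tokenize with handling for common edge cases (a tiktoken Encoding is assumed)."""
    if not text or text.isspace():
        return torch.tensor([], dtype=torch.long)
    # disallowed_special=() encodes strings like "<|endoftext|>" as ordinary text
    # instead of raising a ValueError
    tokens = tokenizer.encode(text, disallowed_special=())
    # Truncate very long inputs to the model's context length
    tokens = tokens[:max_length]
    # Ensure at least 2 tokens so one (input, target) pair can be formed;
    # left-pad with the end-of-text token instead of discarding the original token
    if len(tokens) < 2:
        tokens = [tokenizer.eot_token] * (2 - len(tokens)) + tokens
    return torch.tensor(tokens, dtype=torch.long)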
3. Inference Optimization
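For production inference, implement KV caching. The cache stores each layer's keys and values for previously processed tokens, so each decoding step only computes them for the newest token:
class CachedGPTModel(GPTModel):
    """GPT model with a KV cache for efficient inference (sketch).
    forward_with_cache is a hypothetical method you must implement: it should run
    only the new tokens through each layer, reuse the keys and values stored in
    `cache` for earlier positions, and return (logits, updated_cache).
    """
    @torch.no_grad()
    def generate_with_cache(self, input_ids, max_new_tokens=100):
        self.eval()
        cache = None
        next_input = input_ids  # first step: process the whole prompt
        for _ in range(max_new_tokens):
            logits, cache = self.forward_with_cache(next_input, cache)
            # Greedy decoding; keepdim gives shape (batch, 1) for concatenation
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            next_input = next_token  # later steps: feed only the new token
        return input_ids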
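The class above leaves `forward_with_cache` unimplemented. The core idea can be shown with a single causal self-attention layer: keys and values computed for earlier positions are stored and reused, so each decoding step only projects the newest token. The `CachedSelfAttention` module below is a hypothetical, self-contained sketch (single head, no dropout). It is not code from this tutorial or from the LLMs-from-scratch repository. The sketch checks that incremental decoding reproduces the outputs of one full causal pass.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedSelfAttention(nn.Module):
    """Single-head causal self-attention with an optional key/value cache (sketch)."""
    def __init__(self, d_model: int):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(self, x, cache=None):
        # x: (batch, new_len, d_model); cache: None or (past_k, past_v)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=1)  # keys for every position so far
            v = torch.cat([past_v, v], dim=1)
        new_len, total_len = q.size(1), k.size(1)
        scores = q @ k.transpose(-2, -1) * self.scale  # (batch, new_len, total_len)
        # New query i sits at absolute position total_len - new_len + i and may see keys up to it.
        mask = torch.tril(torch.ones(new_len, total_len, device=x.device),
                          diagonal=total_len - new_len)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v
        return out, (k, v)

torch.manual_seed(0)
attn = CachedSelfAttention(d_model=16).eval()
x = torch.randn(1, 6, 16)

with torch.no_grad():
    # Reference: one full causal pass over all 6 positions.
    full_out, _ = attn(x)
    # Incremental: process a 3-token "prompt", then one token at a time using the cache.
    out, cache = attn(x[:, :3])
    outputs = [out]
    for t in range(3, 6):
        out, cache = attn(x[:, t:t + 1], cache)
        outputs.append(out)
    incremental_out = torch.cat(outputs, dim=1)

print(torch.allclose(full_out, incremental_out, atol=1e-5))
```

In a full model, each layer keeps its own (k, v) pair, and multi-head attention caches the per-head tensors in the same way. The positional encoding for a new token must use its absolute position; with the sinusoidal encoding here, that is `pe[:, past_len:past_len + new_len]`.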
Evaluation and Benchmarking
To evaluate our model, we can use perplexity: the exponential of the average per-token cross-entropy on held-out text, where lower is better. Capability benchmarks probe other skills. For example, the "DiscoverPhysics" benchmark, published on arXiv on May 25, 2026, evaluates LLMs for scientific thinking. Our small model won't approach state-of-the-art performance on benchmarks like that, but perplexity lets us measure its learning progress:
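def calculate_perplexity(model, data, block_size=256):
    """Calculate perplexity over non-overlapping windows of a token tensor."""
    model.eval()
    device = next(model.parameters()).device
    # Causal mask so predictions cannot look at the tokens being scored
    mask = torch.tril(torch.ones(block_size, block_size, device=device)).view(1, 1, block_size, block_size)
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for i in range(0, len(data) - block_size, block_size):
            x = data[i:i+block_size].unsqueeze(0).to(device)
            y = data[i+1:i+block_size+1].unsqueeze(0).to(device)
            logits = model(x, mask)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1),
                reduction='sum'
            )
            total_loss += loss.item()
            total_tokens += y.numel()
    return math.exp(total_loss / total_tokens)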
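Because perplexity is the exponential of the average per-token negative log-likelihood, it can be read as an effective branching factor. A useful sanity check follows from this. A model that assigns equal logits to every token in a vocabulary of size V has a cross-entropy of ln V per token, and therefore a perplexity of V. The sketch below verifies this using the same `F.cross_entropy(..., reduction='sum')` computation as `calculate_perplexity`, with random targets. The vocabulary size matches GPT-2, and the sequence length is arbitrary.

```python
import math
import torch
import torch.nn.functional as F

vocab_size, seq_len = 50257, 256
torch.manual_seed(0)

# Uniform "model": identical logits for every token at every position.
logits = torch.zeros(1, seq_len, vocab_size)
targets = torch.randint(vocab_size, (1, seq_len))

total_loss = F.cross_entropy(
    logits.view(-1, vocab_size), targets.view(-1), reduction="sum"
).item()
perplexity = math.exp(total_loss / seq_len)
print(f"{perplexity:.1f}")
```

With the small 0.02-std initialization used here, an untrained model's logits are close to uniform, so its perplexity should start in the neighborhood of the vocabulary size. Validation perplexity falling well below that value indicates that training is working.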
Conclusion and What's Next
Building an LLM from scratch provides deep insight into the mechanics of modern language models. Our implementation, following the architecture from the "LLMs-from-scratch" repository, demonstrates the core components: multi-head attention, feed-forward networks, and autoregressive generation.
The field of LLM development continues to evolve rapidly. Recent work on knowledge distillation, as collected in the "Awesome-Knowledge-Distillation-of-LLMs" repository (1,264 stars, 71 forks), explores techniques for compressing large models into smaller, more efficient versions. Additionally, understanding model vulnerabilities through resources like the "jailbreak_llms" dataset (3,596 stars, 320 forks) is crucial for building robust systems.
What's Next:
- Scale up: Train on larger datasets with distributed training
- Fine-tune: Adapt the model for specific tasks using instruction tuning
- Optimize: Implement Flash Attention and other memory-efficient techniques
- Deploy: Use ONNX or TensorRT for production inference
- Evaluate: Test against benchmarks like DiscoverPhysics for scientific reasoning
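For the "Optimize" step, PyTorch 2.x already ships a fused attention function, `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to FlashAttention-style kernels on supported GPUs. The minimal sketch below assumes PyTorch 2.0 or later. It checks that the fused function matches the manual masked attention from `MultiHeadAttention.forward`, with dropout disabled and arbitrary small shapes.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 10, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Manual causal attention, as in MultiHeadAttention.forward.
mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, float("-inf"))
manual = F.softmax(scores, dim=-1) @ V

# Fused kernel: is_causal=True applies the same lower-triangular mask internally,
# and the default scale is 1 / sqrt(d_k).
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

Inside `MultiHeadAttention.forward`, the lines from the `scores` computation through `context = torch.matmul(attention_weights, V)` could be replaced by a single call to this function. To keep attention dropout, that call would pass `dropout_p=self.dropout.p if self.training else 0.0`.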
Remember that LLMs, as Wikipedia notes, can produce unreliable output if trained on biased or inaccurate data. Always validate your model's outputs and implement appropriate safeguards for production use.
The LLMs-from-scratch repository provides a complete reference implementation, along with additional notebooks and exercises for deepening your understanding of transformer architectures.
References