In Part 5, I assembled the complete GPT medium model and validated its architecture with forward passes and text generation. In Part 6, I moved into the crucial stage of pretraining. My goal was to understand the basics by building a complete, reproducible pipeline around a GPT‑2 style model I call SydsGPT. I pretrained it on 20 books from Project Gutenberg and used it to generate grammatically correct text. The focus was not scale but clarity, control, and laying the groundwork for a private assistant that can operate without privacy concerns.

Why small models? Because I cannot compete with AI labs that train giant models on thousands of GPUs. My aim is to train a compact language model around 200M parameters on approximately 3B tokens, fine‑tune it for domain‑specific tasks, and extend it with tool calling for web search and RAG to interact with private data. This journey is about exploring the art of the possible with small models.
Configuring and instantiating a GPT‑2 345M style model
I started by defining SydsGPT with a GPT‑2 345M‑like configuration: vocabulary size, context length, embedding dimension, number of heads and layers, dropout rate, and whether to include QKV biases. I fixed a manual seed for reproducibility and moved the model to the available device (CUDA when present, otherwise CPU). I then switched it to eval mode, which disables dropout so that forward passes are deterministic.
import torch
from model.SydsGPT import SydsGPT
SYDSGPT_CONFIG_345M = {
"vocab_size" : 50257,
"context_length" : 512,
"embedding_dim" : 1024,
"num_heads" : 16,
"num_layers" : 24,
"dropout" : 0.1,
"qkv_bias" : False
}
torch.manual_seed(246)
model = SydsGPT(SYDSGPT_CONFIG_345M)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
- Inputs: Integer token IDs shaped (batch size, sequence length)
- Outputs: Logits shaped (batch size, sequence length, vocab size)
- Device tips: Use GPU if available; consider bfloat16/float16 for inference where supported
- Common pitfalls: Ensure model/ and modules/ are importable, make sure num_heads divides embedding_dim, and do not call model.train() during inference
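To see what "345M-style" means for this configuration, the sketch below counts parameters from the config alone. It is a dependency-free estimate under stated assumptions, not output from the model code. count_parameters is a hypothetical helper, and each LayerNorm is assumed to have a scale and a shift vector. The bias layout follows the module printout in the Source Code section: no Q/K/V biases, biased attention-output and feed-forward layers, and a bias-free output projection that is not tied to the token embedding.

```python
def count_parameters(cfg, tie_weights=False):
    """Estimate trainable parameters of a GPT-2 style model from its config."""
    V, C = cfg["vocab_size"], cfg["context_length"]
    d, n_layers = cfg["embedding_dim"], cfg["num_layers"]
    assert d % cfg["num_heads"] == 0, "num_heads must divide embedding_dim"

    qkv = 3 * (d * d + (d if cfg["qkv_bias"] else 0))  # query, key, value projections
    attn_out = d * d + d                                 # attention output projection (with bias)
    ffn = (d * 4 * d + 4 * d) + (4 * d * d + d)          # expand to 4d, then project back
    norms = 2 * (2 * d)                                  # two LayerNorms, scale + shift each (assumed)
    per_block = qkv + attn_out + ffn + norms

    total = V * d + C * d            # token and position embeddings
    total += n_layers * per_block    # transformer blocks
    total += 2 * d                   # final LayerNorm
    if not tie_weights:
        total += d * V               # separate output projection, no bias
    return total

SYDSGPT_CONFIG_345M = {
    "vocab_size": 50257, "context_length": 512, "embedding_dim": 1024,
    "num_heads": 16, "num_layers": 24, "dropout": 0.1, "qkv_bias": False,
}
print(f"{count_parameters(SYDSGPT_CONFIG_345M):,}")                    # 405,688,320
print(f"{count_parameters(SYDSGPT_CONFIG_345M, tie_weights=True):,}")  # 354,225,152
```

Under these assumptions, the untied model has about 405.7M parameters. Tying the output projection to the token embedding, as GPT‑2 does, would bring it down to about 354.2M, the size usually associated with GPT‑2 medium.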
Minimal text generation with GPT‑2 BPE
To verify that the model and tokenizer work together, I built a thin encode/decode wrapper around tiktoken and used a simple generation loop. This confirmed end‑to‑end functionality.
from modules.GenerateSimple import generate_simple
import tiktoken
def text_to_tokens(text, tokenizer):
    tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
    return tokens
def tokens_to_text(tokens, tokenizer):
    text = tokenizer.decode(tokens.squeeze(0).tolist())
    return text
tokenizer = tiktoken.get_encoding("gpt2")
Then I ran a small sample:
input_text = "Once upon a time"
input_tokens = text_to_tokens(input_text, tokenizer).to(device)
output_tokens = generate_simple(model, input_tokens, 100, SYDSGPT_CONFIG_345M['context_length'])
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text: {output_text}")
Outcome: As expected from an untrained model with random weights, the continuation of the seed prompt was mostly gibberish.
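generate_simple comes from the author's modules package, and its implementation is not shown in this post. Assuming it follows the usual greedy pattern (crop the context, take the last position's logits, append the argmax), a dependency-free sketch of that loop looks like this. generate_greedy and toy_bigram are hypothetical stand-ins for the helper and the model. An untrained model produces essentially arbitrary logits, so the argmax at each step carries no meaning, which explains the gibberish above.

```python
from typing import Callable, List

def generate_greedy(next_logits: Callable[[List[int]], List[float]],
                    tokens: List[int], max_new_tokens: int, context_size: int) -> List[int]:
    tokens = list(tokens)
    for _ in range(max_new_tokens):
        context = tokens[-context_size:]                            # crop to the context window
        logits = next_logits(context)                               # scores for the next token only
        next_id = max(range(len(logits)), key=logits.__getitem__)   # greedy argmax
        tokens.append(next_id)
    return tokens

def toy_bigram(context: List[int]) -> List[float]:
    """Hypothetical 5-token 'model' that always favors (last token + 1) mod 5."""
    scores = [0.0] * 5
    scores[(context[-1] + 1) % 5] = 1.0
    return scores

print(generate_greedy(toy_bigram, [0], 6, context_size=3))  # [0, 1, 2, 3, 4, 0, 1]
```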
Batch inference and greedy selection
Next I tested batched inference on two prompts. I computed logits, converted them to probabilities, and selected the most likely tokens via greedy argmax to get a quick qualitative signal.
example_input_text_1 = "The quick brown fox"
example_target_text_1 = " quick brown fox jumps"
example_input_text_2 = "In a galaxy far"
example_target_text_2 = " a galaxy far away"
input_tokens_1 = text_to_tokens(example_input_text_1, tokenizer)
input_tokens_2 = text_to_tokens(example_input_text_2, tokenizer)
target_tokens_1 = text_to_tokens(example_target_text_1, tokenizer)
target_tokens_2 = text_to_tokens(example_target_text_2, tokenizer)
batch_input_tokens = torch.cat([input_tokens_1, input_tokens_2], dim=0).to(device)
batch_target_tokens = torch.cat([target_tokens_1, target_tokens_2], dim=0).to(device)
print(f"Batch Input Tokens Shape: {batch_input_tokens.shape}")
print(f"Batch Input Tokens: {batch_input_tokens}")
print(f"Batch Target Tokens Shape: {batch_target_tokens.shape}")
print(f"Batch Target Tokens: {batch_target_tokens}")
with torch.no_grad():
    logits = model(batch_input_tokens)
probs = torch.softmax(logits, dim = -1)
generated_tokens = torch.argmax(probs, dim = -1, keepdim = True)
print(f"Generated Tokens: \n{generated_tokens}")
print(f"Target text for example 1: {example_target_text_1}")
print(f"Generated text for example 1: {tokens_to_text(generated_tokens[0].flatten(), tokenizer)}")
print(f"Target text for example 2: {example_target_text_2}")
print(f"Generated text for example 2: {tokens_to_text(generated_tokens[1].flatten(), tokenizer)}")
Goal: Validate shapes, decoding, and baseline behavior under greedy prediction
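The target strings are the input strings shifted one token to the left, which is exactly the next-token prediction objective. The check below uses the token IDs printed in the batch inference output in the Source Code section. next_token_pairs is a hypothetical helper that lists what the model should predict after each prefix.

```python
def next_token_pairs(x, y):
    """Pair every prefix of x with the token that should come next."""
    # Targets must be the inputs shifted left by one position.
    assert x[1:] == y[:-1], "targets must be inputs shifted by one"
    return [(x[: t + 1], y[t]) for t in range(len(x))]

# Token IDs as printed in the batch inference output
inputs = [[464, 2068, 7586, 21831], [818, 257, 16161, 1290]]
targets = [[2068, 7586, 21831, 18045], [257, 16161, 1290, 1497]]

for x, y in zip(inputs, targets):
    for prefix, next_id in next_token_pairs(x, y):
        print(f"{prefix} -> {next_id}")
```

Each of the four positions is therefore a separate prediction, which is why the logits have one row per position.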
Selecting target probabilities and estimating a simple loss
I extracted the probabilities assigned to the target tokens using advanced indexing, converted them to log‑probabilities, and averaged their negatives as a proxy for the loss. This was a hands‑on way to inspect the model's confidence at each position.
batch_index = 0
target_probs_1 = probs[batch_index, [0,1,2,3], batch_target_tokens[batch_index]]
print(f"Target probabilities for example 1: {target_probs_1}")
batch_index = 1
target_probs_2 = probs[batch_index, [0,1,2,3], batch_target_tokens[batch_index]]
print(f"Target probabilities for example 2: {target_probs_2}")
log_probs = torch.log(torch.cat((target_probs_1, target_probs_2)))
print(f"Log probabilities: {log_probs}")
mean_log_probs = torch.mean(log_probs)
print(f"Mean log probability: {mean_log_probs}")
negative_mean_log_probs = mean_log_probs * -1
print(f"Negative mean log probability (loss): {negative_mean_log_probs}")
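Note: The two index lists ([0,1,2,3] for the time dimension and the target IDs for the vocabulary dimension) have the same length. Advanced indexing therefore pairs them elementwise and returns one target probability per position. A grid of probabilities would only result if the indices were shaped to broadcast against each other, for example a column of positions against a row of token IDs.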
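The difference between paired and grid indexing is easiest to see on a tiny example. In this dependency-free sketch, paired picks one probability per position, the way probs[b, [0,1,2,3], targets] uses equal-length index lists in PyTorch. grid instead evaluates every position against every target ID. The averaged negative log of the paired values is the loss proxy computed above.

```python
import math

# Toy probabilities for T=3 positions over a vocabulary of V=4 tokens
probs = [
    [0.10, 0.60, 0.20, 0.10],
    [0.25, 0.25, 0.25, 0.25],
    [0.05, 0.05, 0.80, 0.10],
]
targets = [1, 3, 2]

# Paired: position t is matched with targets[t], giving T values
paired = [probs[t][targets[t]] for t in range(len(targets))]
# Grid: every position against every target ID, giving T x T values
grid = [[probs[t][v] for v in targets] for t in range(len(targets))]

# Negative mean log-probability of the paired values (the loss proxy)
loss = -sum(math.log(p) for p in paired) / len(paired)
print(paired)          # [0.6, 0.25, 0.8]
print(round(loss, 4))  # 0.7068
```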
Inspecting logits and computing cross‑entropy
I compared the manual loss proxy with PyTorch's standard cross‑entropy by flattening the logits and targets into the shapes it expects. This confirmed that the two agree: cross‑entropy is the mean negative log‑probability of the target tokens.
print(f"Logits shape: {logits.shape}")
print(f"Logits: {logits}")
print(f"Targets shape: {batch_target_tokens.shape}")
print(f"Targets: {batch_target_tokens}")
flat_logits = logits.flatten(0, 1)
flat_targets = batch_target_tokens.flatten()
print(f"Flattened Logits shape: {flat_logits.shape}")
print(f"Flattened Targets shape: {flat_targets.shape}")
loss_fn = torch.nn.functional.cross_entropy
loss = loss_fn(flat_logits, flat_targets)
print(f"Cross-entropy loss: {loss}")
Why flatten: cross_entropy expects logits of shape (N, C) and targets of shape (N,). Treating every time step as its own classification example is standard for language modeling, so (B, T, V) logits become (B*T, V) and (B, T) targets become (B*T,).
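Cross-entropy on logits is the same quantity as the manual softmax-then-negative-log computation, rearranged for numerical stability. For each example, it equals logsumexp(logits) minus the target's logit. This dependency-free sketch compares the two on a toy (N=2, C=3) batch. cross_entropy and softmax_nll are hypothetical helpers, not PyTorch's implementation.

```python
import math

def cross_entropy(logits_rows, targets):
    """Mean over rows of logsumexp(logits) - logits[target]."""
    total = 0.0
    for logits, y in zip(logits_rows, targets):
        m = max(logits)  # shift by the max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[y]
    return total / len(targets)

def softmax_nll(logits_rows, targets):
    """The long way: softmax, pick the target probability, take -log, average."""
    total = 0.0
    for logits, y in zip(logits_rows, targets):
        exps = [math.exp(z) for z in logits]
        total += -math.log(exps[y] / sum(exps))
    return total / len(targets)

flat_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]  # shape (N=2, C=3)
flat_targets = [0, 2]                              # shape (N,)
print(cross_entropy(flat_logits, flat_targets))
print(softmax_nll(flat_logits, flat_targets))
```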
Loading the corpus and estimating tokens
I read the combined raw corpus from disk and reported character and token counts to estimate the training budget under GPT‑2 BPE.
data_file_path = "data/all_books.txt"
with open(data_file_path, 'r', encoding = 'utf-8') as books:
    text_data = books.read()
print(f"Total Characters: {len(text_data)}")
print(f"Total Tokens after encoding: {len(tokenizer.encode(text_data))}")
Reminder: GPT‑2 uses byte‑level BPE with a vocabulary of 50,257 tokens, so the token count differs from the word count. Common words map to a single token, while rarer words are split into several.
Total Characters: 19849702
Total Tokens after encoding: 5611150
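A little arithmetic on these counts puts the corpus in perspective. GPT‑2 BPE averages about 3.5 characters per token on this text. The ~3B-token budget stated at the start of the post is several hundred times larger than this corpus. Because the train/validation split below is done on characters, roughly 90% of these tokens end up in the training set, assuming tokens are spread evenly through the text.

```python
total_characters = 19_849_702  # from the output above
total_tokens = 5_611_150
target_budget = 3_000_000_000  # the ~3B-token goal stated at the start of the post

chars_per_token = total_characters / total_tokens
print(f"{chars_per_token:.2f} characters per token")              # 3.54
print(f"Budget is ~{target_budget / total_tokens:.0f}x this corpus")  # 535
```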
Building training and validation DataLoaders
I split the text into training and validation subsets (90/10 by character count) and built DataLoaders that yield (x, y) batches for next‑token prediction over a fixed context window.
training_ratio = 0.9
training_size = int(training_ratio * len(text_data))
training_dataset = text_data[:training_size]
validation_dataset = text_data[training_size:]
from modules.DataLoader import create_dataloader
training_dataloader = create_dataloader(
training_dataset,
max_length = SYDSGPT_CONFIG_345M['context_length'],
step_size = SYDSGPT_CONFIG_345M['context_length'],
batch_size = 8,
shuffle = True,
drop_last = True,
num_workers = 0,
)
validation_dataloader = create_dataloader(
validation_dataset,
max_length = SYDSGPT_CONFIG_345M['context_length'],
step_size = SYDSGPT_CONFIG_345M['context_length'],
batch_size = 8,
shuffle = True,
drop_last = True,
num_workers = 0,
)
print(f"Number of training batches: {len(training_dataloader)}")
print("Training loader:")
for x, y in training_dataloader:
    print(x.shape, y.shape)
    break
print(f"Number of validation batches: {len(validation_dataloader)}")
print("Validation loader:")
for x, y in validation_dataloader:
    print(x.shape, y.shape)
    break
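- Key parameters: max_length (set to the context length) defines the window size; step_size is the stride, and setting it equal to the context length means windows do not overlap
- Batch shapes: x and y are LongTensors of shape (batch size, max length), with y holding the tokens of x shifted one position to the left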
- Tip: Adjust batch size or max length based on memory constraints
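create_dataloader lives in the author's modules package and is not shown here. Assume it builds a sliding-window dataset like the one sketched below (make_windows and num_batches are hypothetical helpers). Under that assumption, each window starts step_size tokens after the previous one, y is x shifted by one token, and drop_last discards a final partial batch.

```python
def make_windows(token_ids, max_length, stride):
    """Hypothetical sliding-window dataset producing (x, y) next-token pairs."""
    pairs = []
    for start in range(0, len(token_ids) - max_length, stride):
        x = token_ids[start : start + max_length]
        y = token_ids[start + 1 : start + max_length + 1]  # same window shifted by one
        pairs.append((x, y))
    return pairs

def num_batches(num_windows, batch_size, drop_last=True):
    """Batches per epoch; drop_last discards the final partial batch."""
    return num_windows // batch_size if drop_last else -(-num_windows // batch_size)

windows = make_windows(list(range(10)), max_length=4, stride=4)
print(windows)  # [([0, 1, 2, 3], [1, 2, 3, 4]), ([4, 5, 6, 7], [5, 6, 7, 8])]
```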
Utilities for computing loss
I wrote two small helpers: one for per‑batch loss and one for averaging loss over a loader. These make evaluation and training loops concise and consistent.
def calc_batch_loss(input_batch, target_batch, model, device):
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0,1), target_batch.flatten())
    return loss
def calc_loader_loss(data_loader, model, device, num_batches = None):
    total_loss = 0
    if len(data_loader) == 0:
        return float('nan')
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for batch_index, (input_batch, target_batch) in enumerate(data_loader):
        if batch_index >= num_batches:
            break
        else:
            batch_loss = calc_batch_loss(input_batch, target_batch, model, device)
            total_loss += batch_loss.item()
    return total_loss / num_batches
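Averaging semantics: Each batch loss is a per‑token mean, and the loader loss weights those batch means equally. Because drop_last=True makes every batch the same size, this equals the per‑token mean over the evaluated batches.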
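Whether "mean of batch means" equals "mean over all tokens" depends on batch sizes. This plain-Python sketch (with hypothetical helper names) shows the two agreeing for equal-size batches and diverging for unequal ones.

```python
def mean(values):
    return sum(values) / len(values)

def loader_loss(batches):
    """Average of per-batch means, as calc_loader_loss computes it."""
    return mean([mean(b) for b in batches])

def token_loss(batches):
    """Mean over every individual token loss."""
    return mean([v for b in batches for v in b])

equal = [[1.0, 2.0], [3.0, 4.0]]
unequal = [[1.0], [2.0, 3.0, 4.0]]
print(loader_loss(equal), token_loss(equal))      # 2.5 2.5
print(loader_loss(unequal), token_loss(unequal))  # 2.0 2.5
```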
Baseline training and validation losses
Before training, I computed baseline losses without autograd to sanity check the pipeline.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model.to(device)
with torch.no_grad():
    training_loss = calc_loader_loss(training_dataloader, model, device)
    validation_loss = calc_loader_loss(validation_dataloader, model, device)
print(f"Initial Training Loss: {training_loss}")
print(f"Initial Validation Loss: {validation_loss}")
Expectation: With random initialization and a vocab of 50,257, an initial loss near ln(50257) ≈ 10.82 is typical
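The ln(50257) figure comes from a uniform prediction. If every token gets probability 1/V, the cross-entropy is -ln(1/V) = ln V. Exponentiating a loss gives perplexity, the effective number of tokens the model is choosing between. A loss near 10.82 therefore means the model is about as uncertain as picking uniformly among all 50,257 tokens. A quick check with the standard library (perplexity is a hypothetical helper):

```python
import math

def perplexity(loss):
    """Perplexity is exp(cross-entropy) when the loss is measured in nats."""
    return math.exp(loss)

vocab_size = 50257
uniform_loss = math.log(vocab_size)  # loss of a model that spreads probability evenly

print(f"{uniform_loss:.2f}")           # 10.82
print(round(perplexity(uniform_loss)))  # 50257
```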
Sample generation helper during training
To monitor qualitative progress, I added a utility that generates text from a start context at the end of each epoch.
def generate_sample_text(model, tokenizer, device, start_context):
    model.eval()
    context_size = SYDSGPT_CONFIG_345M['context_length']
    input_tokens = text_to_tokens(start_context, tokenizer).to(device)
    with torch.no_grad():
        generated_tokens = generate_simple(model, input_tokens, 100, context_size)
    generated_text = tokens_to_text(generated_tokens, tokenizer)
    print(f"Generated Text: {generated_text}".replace("\n", " "))
    model.train()
Training loop with periodic evaluation and checkpointing
I trained the model with AdamW, tracked tokens processed, evaluated losses periodically, saved autosave checkpoints, and printed sample generations after each epoch.
def train_model_v1(model, training_dataloader, validation_dataloader, optimizer, device, num_epochs, evaluation_frequency, evaluation_iterations, start_context, tokenizer, checkpoint_interval = 500):
    training_losses, validation_losses, total_tokens_processed = [], [], []
    tokens_processed = 0
    global_step = -1
    for epoch in range(num_epochs):
        model.train()
        for input_batch, target_batch in training_dataloader:
            optimizer.zero_grad()
            loss = calc_batch_loss(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            tokens_processed += input_batch.numel()
            global_step += 1
            total_tokens_processed.append(tokens_processed)
            print(f"Epoch {epoch+1}, Step {global_step}: Tokens Processed = {tokens_processed}")
            if global_step % evaluation_frequency == 0:
                model.eval()
                with torch.no_grad():
                    training_loss = calc_loader_loss(training_dataloader, model, device, evaluation_iterations)
                    validation_loss = calc_loader_loss(validation_dataloader, model, device, evaluation_iterations)
                training_losses.append(training_loss)
                validation_losses.append(validation_loss)
                print(f"Epoch {epoch+1}, Step {global_step}: Training Loss = {training_loss}, Validation Loss = {validation_loss}, Tokens Processed = {tokens_processed}")
                model.train()
            if global_step % checkpoint_interval == 0:
                torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, "autosave_sydsgpt_345m_trained_model_optimizer.pth")
        generate_sample_text(model, tokenizer, device, start_context)
    return training_losses, validation_losses, total_tokens_processed
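In train_model_v1, global_step starts at -1 and is incremented before the checks. The very first optimizer step is therefore step 0, so the loop evaluates and writes an autosave immediately, then repeats every evaluation_frequency and checkpoint_interval steps. The counter keeps running across epochs but restarts whenever train_model_v1 is called again. The sketch below shows that schedule; the 1,200-step run and the helper name evaluation_and_checkpoint_steps are hypothetical.

```python
def evaluation_and_checkpoint_steps(num_steps, evaluation_frequency, checkpoint_interval):
    """Steps on which the loop evaluates and autosaves, counting steps from 0."""
    evals = [s for s in range(num_steps) if s % evaluation_frequency == 0]
    saves = [s for s in range(num_steps) if s % checkpoint_interval == 0]
    return evals, saves

tokens_per_step = 8 * 512  # batch_size * context_length with drop_last=True

# Hypothetical run of 1,200 optimizer steps with evaluation_frequency=100 and the default checkpoint_interval=500
evals, saves = evaluation_and_checkpoint_steps(1200, evaluation_frequency=100, checkpoint_interval=500)
print(len(evals), saves)       # 12 [0, 500, 1000]
print(1200 * tokens_per_step)  # 4915200
```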
I ran the training as follows, initially for 5 epochs. Each epoch took approximately 11 hours on my RTX 3080 Ti GPU:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(246)
model = SydsGPT(SYDSGPT_CONFIG_345M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.0002, weight_decay = 0.05)
num_epochs = 5
training_losses, validation_losses, total_tokens_processed = train_model_v1(
model,
training_dataloader,
validation_dataloader,
optimizer,
device,
num_epochs,
evaluation_frequency = 100,
evaluation_iterations = 2,
start_context = "Once upon a time",
tokenizer = tokenizer
)
Checkpointing: Saved a final checkpoint after training
torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, "sydsgpt_345m_trained_model_optimizer.pth")
Restoring from checkpoint and continuing training
I restored the model and optimizer states from the checkpoint, moved the optimizer state tensors to the correct device, and generated a sample to verify the restore. Then I continued training for two more epochs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = SydsGPT(SYDSGPT_CONFIG_345M)
checkpoint = torch.load("sydsgpt_345m_trained_model_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)  # move parameters first so the optimizer is built on the target device
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.0002, weight_decay = 0.05)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Defensive: ensure every optimizer state tensor lives on the same device as the model
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
generate_sample_text(model, tokenizer, device, "once upon a time")
num_epochs = 2
training_losses, validation_losses, total_tokens_processed = train_model_v1(
model,
training_dataloader,
validation_dataloader,
optimizer,
device,
num_epochs,
evaluation_frequency = 100,
evaluation_iterations = 2,
start_context = "Once upon a time",
tokenizer = tokenizer
)
I also generated a longer sample to inspect coherence:
model.eval()
output_tokens = generate_simple(model, text_to_tokens("once upon a time", tokenizer).to(device), 200, SYDSGPT_CONFIG_345M['context_length'])
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text: {output_text}")
Sample excerpt: The output contained multi‑sentence, largely grammatical text with dialogue markers and recurring narrative structures, along with noticeable repetition (for example, "I felt quite sure that I felt quite sure")
Output Text: once upon a time before she was so busy, that I felt quite sure that I felt quite sure that I was not quite sure that I felt it was. “You are a child,” said I, “that you are a beautiful woman, and you are a beautiful woman.” “Yes,” said Ada, “that there is nothing in it.” “That is,” said my guardian, “that there is nothing else that makes it so.” “That is not,” said my guardian, “that there is such a time as you are.” “You are not to be removed,” said my guardian, “that there is no considerable answer.” “You are not to be always happy,” said my guardian, “that there is something of the kind
Exploring sampling strategies: probabilities, temperature, and top‑k
To understand sampling dynamics, I built a small illustrative example with a toy vocabulary and hand‑picked logits. I compared greedy selection with multinomial sampling and examined how temperature scaling and top‑k filtering reshape the token distribution.
example_vocab = {
"once" : 0,
"upon" : 1,
"a" : 2,
"time" : 3,
"before" : 4,
"she" : 5,
"lived" : 6,
"happily" : 7,
"ever" : 8,
"after" : 9
}
inverse_example_vocab = {v: k for k, v in example_vocab.items()}
example_next_token_logits = torch.tensor([1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38])
example_next_token_probs = torch.softmax(example_next_token_logits, dim = 0)
example_greedy_next_token = torch.argmax(example_next_token_probs).item()
print(f"Greedy Next Token: {inverse_example_vocab[example_greedy_next_token]}")
torch.manual_seed(246)
example_random_next_token = torch.multinomial(example_next_token_probs, num_samples = 1).item()
print(f"Random Next Token: {inverse_example_vocab[example_random_next_token]}")
def get_sampled_tokens(probs):
    sampled_token = [torch.multinomial(probs, num_samples = 1).item() for i in range(1000)]
    sampled_tokens = torch.bincount(torch.tensor(sampled_token))
    for i, frequency in enumerate(sampled_tokens):
        print(f"Token: {inverse_example_vocab[i]}: {frequency.item()} times")
get_sampled_tokens(example_next_token_probs)
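The same experiment can be reproduced with only the standard library, which makes the difference between the two strategies explicit. Greedy decoding always returns the single most likely word, while sampling returns each word roughly in proportion to its probability. This sketch mirrors the toy vocabulary and logits above. random.choices stands in for torch.multinomial, so the exact counts will differ from the PyTorch run.

```python
import math
import random
from collections import Counter

vocab = ["once", "upon", "a", "time", "before", "she", "lived", "happily", "ever", "after"]
logits = [1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38]

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(logits)
greedy = vocab[max(range(len(probs)), key=probs.__getitem__)]

# 1000 draws; a seeded local RNG is reproducible but will not match torch's draws
rng = random.Random(246)
counts = Counter(rng.choices(vocab, weights=probs, k=1000))

print(greedy)  # before
for word, n in counts.most_common():
    print(f"{word}: {n} times")
```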
Temperature scaling:
def softmax_with_temperature(logits, temperature):
    scaled_logits = logits / temperature
    probs = torch.softmax(scaled_logits, dim = 0)
    return probs
temperatures = [0.1, 0.5, 1.0, 2.0]
for temp in temperatures:
    temperature_scaled_probs = softmax_with_temperature(example_next_token_logits, temp)
    print(f"\n Temperature: {temp}")
    get_sampled_tokens(temperature_scaled_probs)
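Temperature divides the logits before the softmax, so T < 1 widens the gaps between logits and T > 1 shrinks them. The dependency-free sketch below measures this with the entropy of each distribution. Entropy is an added metric, not part of the original notebook.

```python
import math

def softmax_with_temperature(logits, temperature):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

logits = [1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38]
for t in [0.1, 0.5, 1.0, 2.0]:
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: p(max)={max(probs):.3f}, entropy={entropy(probs):.3f} nats")
```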
Top‑k filtering
top_k = 4
top_k_logits, top_k_indices = torch.topk(example_next_token_logits, top_k)
print(f"Top-{top_k} Indices: {top_k_indices}")
print(f"Top-{top_k} Logits: {top_k_logits}")
new_logits = torch.where(
condition = example_next_token_logits < top_k_logits[-1],
input = torch.tensor(float('-inf')),
other = example_next_token_logits
)
print(f"New Logits after Top-{top_k} filtering: {new_logits}")
top_k_probs = torch.softmax(new_logits, dim = 0)
print(f"Top-{top_k} Probabilities: {top_k_probs}")
get_sampled_tokens(top_k_probs)
Insight: Lower temperature sharpens the distribution toward the highest‑probability tokens, while higher temperature flattens it; top‑k truncates the distribution to the k most likely tokens for more controlled sampling
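Top-k filtering is a mask applied before the softmax. Logits below the k-th largest value are set to negative infinity, so they receive exactly zero probability, and the survivors are renormalized. Below is a plain-Python sketch on the same toy logits. top_k_filter is a hypothetical name, and like the torch.where version above, it keeps ties with the k-th value.

```python
import math

def top_k_filter(logits, k):
    """Keep logits >= the k-th largest value; set the rest to -inf."""
    kth = sorted(logits, reverse=True)[k - 1]
    return [z if z >= kth else float("-inf") for z in logits]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38]
probs = softmax(top_k_filter(logits, 4))
kept = [i for i, p in enumerate(probs) if p > 0]
print(kept)  # [0, 1, 2, 4]
```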
A configurable generation function with temperature and top‑k
I implemented a general generation helper that supports temperature scaling, top‑k filtering, context truncation, and optional EOS termination. A temperature of 0 falls back to greedy argmax decoding.
def generate(model, input_tokens, max_new_tokens, context_size, temperature = 1.0, top_k = None, eos_id = None):
    for _ in range(max_new_tokens):
        # Crop to the last context_size tokens the model can attend to
        input_context = input_tokens[:, -context_size:]
        with torch.no_grad():
            logits = model(input_context)
        logits = logits[:, -1, :]  # keep only the last position: (batch, vocab)
        if top_k is not None:
            top_k_logits, _ = torch.topk(logits, top_k)
            min_top_k_logit = top_k_logits[:, -1:]  # shape (batch, 1) so it broadcasts per row
            logits = torch.where(logits < min_top_k_logit, torch.tensor(float('-inf')).to(logits.device), logits)
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim = -1)
            next_token = torch.multinomial(probs, num_samples = 1)
        else:
            next_token = torch.argmax(logits, dim = -1, keepdim = True)
        # Stop once every sequence produced EOS; the EOS token itself is not appended
        if eos_id is not None and torch.all(next_token == eos_id):
            break
        input_tokens = torch.cat((input_tokens, next_token), dim = 1)
    return input_tokens
I used it to generate longer outputs with controlled sampling:
torch.manual_seed(246)
input_text = "once upon a time"
input_tokens = text_to_tokens(input_text, tokenizer).to(device)
output_tokens = generate(model, input_tokens, 200, SYDSGPT_CONFIG_345M['context_length'], temperature = 1.5, top_k = 40)
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text:\n {output_text}")
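Observation: With temperature 1.5 and top‑k 40, the output became more varied and stylistically richer while still mostly following sentence structure, though the high temperature also produced malformed words (such as "unbension" and "intimatelyision") and looser coherence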
Output Text: once upon a time after his arrival. But again had that time too already settled in the possibility of talking in the existing histories given personal opportunity of reproachfulness? Hath it not not been brought together simply that one who had not always tried? And the most wonderful man, in a sort of unbension with which he had been capable of using the money by a man who had done so intimatelyision and must not think about as a politician be better in his physical conversation, for whose knowledge there must give a reference to the facts (a lady, especially on purpose, placed upright in their hands) of the unhappy man. The victim might receive her reason to be as much as a hypocrite as possible, but of having supposed she to do as, as it came upon him, as a mode of their being a woman. The latter part of his respect took place to him as much as much as possible to give it him
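The decision order used by the generate helper (top-k filter, then temperature, then a multinomial draw, with argmax when temperature is 0) can also be written without PyTorch. That makes the effect of each knob easy to test in isolation. sample_next and generate_ids are hypothetical names. Python's random.choices stands in for torch.multinomial, and next_logits stands in for a model forward pass.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Pick one token id: top-k filter, then temperature scaling, then sampling."""
    if top_k is not None:
        kth = sorted(logits, reverse=True)[top_k - 1]
        logits = [z if z >= kth else float("-inf") for z in logits]
    if temperature <= 0.0:
        # Greedy fallback, like the argmax branch of generate()
        return max(range(len(logits)), key=logits.__getitem__)
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    weights = [math.exp(z - m) for z in scaled]  # filtered entries get weight 0.0
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

def generate_ids(next_logits, tokens, max_new_tokens, eos_id=None, **sampling):
    """Append sampled tokens until max_new_tokens or EOS (EOS itself is not appended)."""
    tokens = list(tokens)
    for _ in range(max_new_tokens):
        next_id = sample_next(next_logits(tokens), **sampling)
        if eos_id is not None and next_id == eos_id:
            break
        tokens.append(next_id)
    return tokens

rng = random.Random(246)
logits = [1.35, 1.86, 1.53, 0.17, 3.63]
print(sample_next(logits, temperature=0.0))                    # 4
print(sample_next(logits, temperature=1.5, top_k=2, rng=rng))  # 1 or 4
```

With top_k=1, any temperature reduces to greedy decoding, and with top_k=2 only the two highest-scoring tokens can ever be drawn.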
What I learned in Part 6
- Pretraining basics: I built and validated a complete language modeling pipeline including data loading, tokenization, batching, loss computation, training, evaluation, sampling, and checkpointing.
- Grammatically correct generation: After training on 20 books from Project Gutenberg, SydsGPT produced coherent, grammatical text with clear sentence structure and narrative elements.
- Reproducibility: Fixed seeds, consistent device handling, and periodic checkpoints made the process repeatable and auditable.
- Sampling behavior: Temperature and top‑k are powerful controls over style, diversity, and determinism during generation.
Try It Yourself
The full notebook, covering every step from preparing the corpus and data loaders to loss computation, the pretraining loop, and text sampling and generation, is available here:
SydsGPT pretraining Repository
Clone the repo, open the Jupyter notebook, and step through the code.
Build It Yourself
If you want to try building it yourself, you can find the complete code with detailed explanations of each block in the source code section at the end of this post. All the best!
What comes next
Part 7 will focus on training optimization techniques to push efficiency on limited hardware.
- Mixed precision training: Reduce memory footprint and increase throughput by using bfloat16/float16 safely during training.
- Gradient accumulation: Simulate larger batch sizes without exceeding memory limits.
- Flash attention: Optimize attention computation for speed and memory efficiency.
- KV cache: Speed up autoregressive generation by caching key/value tensors across steps.
The roadmap remains clear. I will train a small language model around 200M parameters on a diverse corpus of approximately 3B tokens, then fine‑tune it on domain‑specific data. I will add tool calling for web search and build a RAG pipeline to interact with private data. The aim is a private assistant that respects privacy and delivers practical value, proving that small models can go far when engineered with care
Source Code
Configure and instantiate SydsGPT (345M-style)
This cell prepares a ready-to-use SydsGPT model with a GPT-2 345M-like configuration.
It imports PyTorch and the model class, defines a configuration dictionary, sets a random seed for reproducibility, instantiates the model, moves it to the available device, and switches it to evaluation mode.
What the code does
- import torch: Loads PyTorch for tensors, RNG control, and device utilities.
- from model.SydsGPT import SydsGPT: Imports the model class from the local model package.
- SYDSGPT_CONFIG_345M = {...}: Defines the key hyperparameters:
  - vocab_size (int): Size of the tokenizer vocabulary the model predicts over.
  - context_length (int): Maximum sequence length (number of time steps) the model attends to.
  - embedding_dim (int): Hidden size (channel dimension) of the token and position embeddings and the transformer layers.
  - num_heads (int): Number of attention heads per transformer block. Must divide embedding_dim evenly.
  - num_layers (int): Number of stacked transformer blocks (depth).
  - dropout (float): Dropout probability used during training; inactive in eval mode.
  - qkv_bias (bool): Whether to include bias terms in the Q/K/V projection layers.
- torch.manual_seed(246): Sets the RNG seed for reproducibility (affects random initialization, sampling, etc.).
- model = SydsGPT(SYDSGPT_CONFIG_345M): Builds the model from the config.
- model = model.to(device): Moves the parameters to the GPU when CUDA is available, otherwise keeps them on the CPU.
- model.eval(): Puts the model in inference mode, which disables dropout (the architecture has no batch normalization, so dropout is the only behavior that changes).
Expected inputs and outputs
- Input: integer token IDs shaped (batch_size, seq_len) with 0 <= token_id < vocab_size and seq_len <= context_length.
- Output: logits shaped (batch_size, seq_len, vocab_size), ready for softmax or sampling.
Device and performance tips
To use a GPU if available:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
x = x.to(device)
For memory-sensitive scenarios, consider bfloat16/float16 inference (on supported hardware):
model = model.to(device).to(dtype=torch.bfloat16)
x = x.to(device)
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
    logits = model(x)
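How much bfloat16 saves depends on how large the model is in memory. The back-of-the-envelope sketch below assumes about 405.7M parameters for this configuration. That figure assumes LayerNorms with scale and shift vectors, and an output projection that is not tied to the token embedding, as the printout below suggests. The sketch counts only parameter-sized buffers, ignoring activations and framework overhead. memory_gb is a hypothetical helper.

```python
def memory_gb(num_params, bytes_per_value, copies=1):
    """Memory for `copies` parameter-sized buffers, in decimal gigabytes."""
    return num_params * bytes_per_value * copies / 1e9

num_params = 405_688_320  # assumed parameter count for the 345M-style config (untied)

fp32_weights = memory_gb(num_params, 4)
bf16_weights = memory_gb(num_params, 2)
# fp32 training with AdamW holds weights, gradients, exp_avg and exp_avg_sq: 4 buffers
adamw_fp32 = memory_gb(num_params, 4, copies=4)

print(f"fp32 weights: {fp32_weights:.2f} GB")             # 1.62 GB
print(f"bf16 weights: {bf16_weights:.2f} GB")             # 0.81 GB
print(f"fp32 AdamW training state: {adamw_fp32:.2f} GB")  # 6.49 GB
```

Halving the bytes per value halves inference weight memory. During training, the parameter-sized buffers alone already take several gigabytes before any activations are stored.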
Common pitfalls
- ModuleNotFoundError: Run the notebook from the repo root (where this notebook lives) and make sure model/ and modules/ contain __init__.py.
- num_heads must evenly divide embedding_dim.
- Calling model.train() during inference enables dropout and changes outputs; keep model.eval() for deterministic behavior (given a fixed seed and inputs).
import torch
from model.SydsGPT import SydsGPT
SYDSGPT_CONFIG_345M = {
"vocab_size" : 50257,
"context_length" : 512,
"embedding_dim" : 1024,
"num_heads" : 16,
"num_layers" : 24,
"dropout" : 0.1,
"qkv_bias" : False
}
torch.manual_seed(246)
model = SydsGPT(SYDSGPT_CONFIG_345M)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
SydsGPT(
(token_embedding): Embedding(50257, 1024)
(position_embedding): Embedding(512, 1024)
(drop_embedding): Dropout(p=0.1, inplace=False)
(transformer_blocks): Sequential(
(0): TransformerBlock(
(attention): MultiHeadAttention(
(weight_query): Linear(in_features=1024, out_features=1024, bias=False)
(weight_key): Linear(in_features=1024, out_features=1024, bias=False)
(weight_value): Linear(in_features=1024, out_features=1024, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(output_projection): Linear(in_features=1024, out_features=1024, bias=True)
)
(layer_norm1): LayerNorm()
(feed_forward): FeedForward(
(layers): Sequential(
(0): Linear(in_features=1024, out_features=4096, bias=True)
(1): GELU()
(2): Linear(in_features=4096, out_features=1024, bias=True)
)
)
(layer_norm2): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
...
(23): TransformerBlock(
(attention): MultiHeadAttention(
(weight_query): Linear(in_features=1024, out_features=1024, bias=False)
(weight_key): Linear(in_features=1024, out_features=1024, bias=False)
(weight_value): Linear(in_features=1024, out_features=1024, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(output_projection): Linear(in_features=1024, out_features=1024, bias=True)
)
(layer_norm1): LayerNorm()
(feed_forward): FeedForward(
(layers): Sequential(
(0): Linear(in_features=1024, out_features=4096, bias=True)
(1): GELU()
(2): Linear(in_features=4096, out_features=1024, bias=True)
)
)
(layer_norm2): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
(final_layer_norm): LayerNorm()
(output_projection): Linear(in_features=1024, out_features=50257, bias=False)
)
Simple text generation with tiktoken
This cell demonstrates a minimal text-generation loop using the model and GPT‑2 BPE tokenization.
It converts input text to token IDs, generates a few new tokens with generate_simple, then decodes the tokens back to text.
What the code does
- from modules.GenerateSimple import generate_simple: Imports a helper that appends new tokens to the context.
- import tiktoken: Loads the tokenizer library that implements GPT‑2's BPE encoding and decoding.
- text_to_tokens(text, tokenizer): Encodes a string into token IDs and wraps them in a batch of size 1 with shape (1, seq_len).
- tokens_to_text(tokens, tokenizer): Decodes a tensor of token IDs back to a string.
- tokenizer = tiktoken.get_encoding("gpt2"): Uses the GPT‑2 BPE vocabulary (vocab size 50257).
- input_text = "Once upon a time": Seed prompt for generation.
- input_tokens = text_to_tokens(...): Converts the prompt to a (1, seq_len) tensor.
- output_tokens = generate_simple(model, input_tokens, 100, context_length): Generates 100 new tokens (total length increases by up to 100, subject to context length).
- print(...): Shows the decoded text after generation.
Inputs and outputs
- Input text: a Python string prompt.
- Encoded input: input_tokens with dtype torch.long and shape (1, T), where 0 <= token_id < vocab_size.
- Output tokens: a tensor typically shaped (1, T + N), where N is max_new_tokens (here 100), but it may be shorter if the function stops early (e.g., on end‑of‑text).
- Decoded output: a Python string built from output_tokens.
Tokenization notes
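The text_to_tokens helper calls tokenizer.encode(text) with tiktoken's defaults, which raise an error if the text contains the special token <|endoftext|>. To allow it during encoding, pass it explicitly: tokenizer.encode(text, allowed_special={"<|endoftext|>"})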
Decoding uses tokenizer.decode(...) to reconstruct readable text.
from modules.GenerateSimple import generate_simple
import tiktoken
def text_to_tokens(text, tokenizer):
    tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
    return tokens
def tokens_to_text(tokens, tokenizer):
    text = tokenizer.decode(tokens.squeeze(0).tolist())
    return text
tokenizer = tiktoken.get_encoding("gpt2")
input_text = "Once upon a time"
input_tokens = text_to_tokens(input_text, tokenizer).to(device)
output_tokens = generate_simple(model, input_tokens, 100, SYDSGPT_CONFIG_345M['context_length'])
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text: {output_text}")
Output Text: Once upon a time,
Batch inference and greedy token selection (baseline)
This cell runs a batched forward pass on two prompts, converts logits to probabilities, and selects the most likely token at each position via greedy argmax. It then decodes these per‑position predictions to text for a quick qualitative check.
What the code does
- Defines two short input prompts and two short "target" snippets for comparison.
- Encodes strings to token IDs with text_to_tokens and batches them using torch.cat to form tensors of shape (batch_size=2, seq_len).
- Computes model logits with model(batch_input_tokens) → shape (B, T, V), where V = vocab_size.
- Converts logits to probabilities with softmax and picks the most likely token ID at each position using argmax over the vocab dimension.
- Decodes the resulting token IDs back to text with tokens_to_text for a quick sanity check.
Shapes and dtypes
- batch_input_tokens: (B, T) of dtype torch.long.
- logits: (B, T, V) of dtype torch.float32 (or your model's dtype).
- probs: (B, T, V) after softmax.
- generated_tokens: (B, T, 1) after argmax(..., keepdim=True).
example_input_text_1 = "The quick brown fox"
example_target_text_1 = " quick brown fox jumps"
example_input_text_2 = "In a galaxy far"
example_target_text_2 = " a galaxy far away"
input_tokens_1 = text_to_tokens(example_input_text_1, tokenizer)
input_tokens_2 = text_to_tokens(example_input_text_2, tokenizer)
target_tokens_1 = text_to_tokens(example_target_text_1, tokenizer)
target_tokens_2 = text_to_tokens(example_target_text_2, tokenizer)
batch_input_tokens = torch.cat([input_tokens_1, input_tokens_2], dim=0)
batch_target_tokens = torch.cat([target_tokens_1, target_tokens_2], dim=0)
print(f"Batch Input Tokens Shape: {batch_input_tokens.shape}")
print(f"Batch Input Tokens: {batch_input_tokens}")
print(f"Batch Target Tokens Shape: {batch_target_tokens.shape}")
print(f"Batch Target Tokens: {batch_target_tokens}")
with torch.no_grad():
    logits = model(batch_input_tokens)
    probs = torch.softmax(logits, dim = -1)
    generated_tokens = torch.argmax(probs, dim = -1, keepdim = True)
print(f"Generated Tokens: \n{generated_tokens}")
print(f"Target text for example 1: {example_target_text_1}")
print(f"Generated text for example 1: {tokens_to_text(generated_tokens[0].flatten(), tokenizer)}")
print(f"Target text for example 2: {example_target_text_2}")
print(f"Generated text for example 2: {tokens_to_text(generated_tokens[1].flatten(), tokenizer)}")
Batch Input Tokens Shape: torch.Size([2, 4])
Batch Input Tokens: tensor([[ 464, 2068, 7586, 21831],
[ 818, 257, 16161, 1290]])
Batch Target Tokens Shape: torch.Size([2, 4])
Batch Target Tokens: tensor([[ 2068, 7586, 21831, 18045],
[ 257, 16161, 1290, 1497]])
Generated Tokens:
tensor([[[34912],
[11918],
[50106],
[22915]],
[[47903],
[31311],
[27997],
[ 8618]]])
Target text for example 1: quick brown fox jumps
Generated text for example 1: Trace chaoscatchingoutput
Target text for example 2: a galaxy far away
Generated text for example 2: ּ Debbie fats gained
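Because softmax is strictly increasing in each logit, the argmax over probabilities is the same as the argmax over raw logits, so the softmax call in the cell above does not change the greedy choice. A minimal sketch on random logits (the toy sizes B=2, T=4, V=10 are arbitrary stand‑ins for the model's real output):

```python
import torch

torch.manual_seed(0)
B, T, V = 2, 4, 10                      # toy (batch, time, vocab) sizes
logits = torch.randn(B, T, V)

# Greedy choice from probabilities, as in the cell above
probs = torch.softmax(logits, dim=-1)
from_probs = torch.argmax(probs, dim=-1, keepdim=True)    # (B, T, 1)

# Greedy choice straight from logits: softmax preserves the ordering
from_logits = torch.argmax(logits, dim=-1, keepdim=True)  # (B, T, 1)

print(from_probs.shape)                      # torch.Size([2, 4, 1])
print(torch.equal(from_probs, from_logits))  # True
```

For greedy decoding the softmax can therefore be skipped; it only matters once you sample from the distribution.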
Selecting target probabilities and estimating a simple loss
This cell extracts the probabilities the model assigned to the target tokens, converts them to log‑probs, and computes their negated average (the average negative log‑likelihood, which is the cross‑entropy loss). It demonstrates indexing into a (batch, time, vocab) probability tensor and aggregates results across two examples.
Tensors and shapes
- probs: shape (B, T, V), probabilities over the vocabulary at every time step for each example.
- batch_target_tokens: shape (B, T), integer token IDs for the target sequence of each example.
- batch_index: scalar int selecting which example (0 or 1 here).
What each line does
- batch_index = 0 (and later 1): chooses which example in the batch to analyze.
- target_probs_1 = probs[batch_index, [0,1,2,3], batch_target_tokens[batch_index]]: narrows to the chosen example (probs[batch_index] has shape (T, V)) and then applies advanced indexing with two index sequences:
  - the time indices [0,1,2,3] select positions along the time dimension;
  - batch_target_tokens[batch_index], a length‑T vector of token IDs, indexes the vocab dimension.
  Both index sequences have shape (4,) because T = 4, so PyTorch pairs them element‑wise instead of forming a grid: element i is probs[batch_index, i, batch_target_tokens[batch_index, i]], the probability the model gave to the correct next token at step i. The result has shape (4,), a 1‑to‑1 per‑step alignment. Index sequences of different lengths, say (4,) and (5,), cannot be broadcast together and raise an error; using torch.arange(T) instead of the hard‑coded [0,1,2,3] keeps the code correct for any T.
- Repeat for batch_index = 1 to get target_probs_2, also of shape (4,).
- log_probs = torch.log(torch.cat((target_probs_1, target_probs_2))): concatenates the two vectors into shape (8,) and applies the natural log.
- mean_log_probs = torch.mean(log_probs): averages the 8 log‑probs.
- negative_mean_log_probs = mean_log_probs * -1: flips the sign to give the average negative log‑probability, the quantity that training minimizes.
batch_index = 0
target_probs_1 = probs[batch_index, [0,1,2,3], batch_target_tokens[batch_index]]
print(f"Target probabilities for example 1: {target_probs_1}")
batch_index = 1
target_probs_2 = probs[batch_index, [0,1,2,3], batch_target_tokens[batch_index]]
print(f"Target probabilities for example 2: {target_probs_2}")
log_probs = torch.log(torch.cat((target_probs_1, target_probs_2)))
print(f"Log probabilities: {log_probs}")
mean_log_probs = torch.mean(log_probs)
print(f"Mean log probability: {mean_log_probs}")
negative_mean_log_probs = mean_log_probs * -1
print(f"Negative mean log probability (loss): {negative_mean_log_probs}")
Target probabilities for example 1: tensor([3.4469e-05, 1.3686e-05, 6.9494e-06, 1.3889e-05])
Target probabilities for example 2: tensor([4.0864e-05, 1.1410e-05, 9.1020e-06, 8.7344e-06])
Log probabilities: tensor([-10.2754, -11.1992, -11.8768, -11.1844, -10.1053, -11.3810, -11.6070,
-11.6482])
Mean log probability: -11.15966510772705
Negative mean log probability (loss): 11.15966510772705
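The same per‑step lookup can be written for any sequence length with torch.gather, which pulls out probs[b, t, targets[b, t]] for every (b, t) at once and avoids the hard‑coded [0,1,2,3]. A sketch on random stand‑in tensors (toy sizes, not the model's real outputs), checked against the explicit indexing used above:

```python
import torch

torch.manual_seed(0)
B, T, V = 2, 4, 10
probs = torch.softmax(torch.randn(B, T, V), dim=-1)  # stand-in for model probabilities
targets = torch.randint(0, V, (B, T))                # stand-in for batch_target_tokens

# gather along the vocab dimension; the index needs the same number of dims as probs
target_probs = torch.gather(probs, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # (B, T)

# Equivalent advanced indexing: index tensors broadcast to (B, T) and pair element-wise
rows = torch.arange(B).unsqueeze(1)   # (B, 1)
cols = torch.arange(T).unsqueeze(0)   # (1, T)
target_probs_adv = probs[rows, cols, targets]  # (B, T)

# Average negative log-likelihood over all B*T positions
loss = -torch.log(target_probs).mean()
print(target_probs.shape)  # torch.Size([2, 4])
```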
Inspecting logits/targets and computing cross‑entropy loss
This cell prints tensor shapes/values, flattens logits and targets to 2‑D and 1‑D form, and computes a categorical cross‑entropy loss. It's a quick way to sanity‑check model outputs against targets and to illustrate how a classification loss is applied to sequence models.
What the code does
- Prints logits.shape and batch_target_tokens.shape to verify the expected dimensions:
  - logits: (B, T, V), unnormalized scores over the vocabulary for each batch/time position.
  - batch_target_tokens: (B, T), integer token IDs (class labels) for each batch/time position.
- Flattens the tensors for loss computation:
  - flat_logits = logits.flatten(0, 1) → shape (B*T, V).
  - flat_targets = batch_target_tokens.flatten() → shape (B*T,).
- Computes cross‑entropy using torch.nn.functional.cross_entropy on the raw logits and integer class targets.
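Why flatten?
torch.nn.functional.cross_entropy expects the class dimension second: (N, C) logits with (N,) targets, or (N, C, d1, ...) with (N, d1, ...). Our logits put the vocabulary last, (B, T, V), so flattening the batch and time dimensions into (B*T, V) is the simplest fit. It treats every time step in the batch as an independent classification example, and with the default reduction='mean' the result is exactly the mean loss over all (b, t) positions.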
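Written out, the quantity that both the manual calculation and the flattened cross_entropy call compute is the mean negative log‑likelihood over all B·T positions, where z_{b,t,v} are the logits at position (b, t) and y_{b,t} is the target token:

```latex
\mathcal{L} \;=\; -\frac{1}{B\,T}\sum_{b=1}^{B}\sum_{t=1}^{T}\log p_\theta\!\left(y_{b,t}\mid x_{b,1},\dots,x_{b,t}\right),
\qquad
p_\theta\!\left(v\mid x_{b,1},\dots,x_{b,t}\right) \;=\; \frac{\exp z_{b,t,v}}{\sum_{v'=1}^{V}\exp z_{b,t,v'}}
```

With B = 2 and T = 4 this is the sign‑flipped average of the 8 log‑probabilities printed in the previous cell, which is why both approaches report 11.1597.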
Dtype, device, and numerical stability
- logits should be floating point (e.g., float32), and targets must be an integer type (torch.long).
- Ensure logits and targets reside on the same device (CPU/GPU) to avoid runtime errors.
- cross_entropy combines log_softmax and the negative log‑likelihood internally, so it is numerically stable; you rarely need to clamp or add epsilons yourself.
Conclusion
- We get the same cross‑entropy loss (11.1597) as our manual calculation in the previous step.
print(f"Logits shape: {logits.shape}")
print(f"Logits: {logits}")
print(f"Targets shape: {batch_target_tokens.shape}")
print(f"Targets: {batch_target_tokens}")
flat_logits = logits.flatten(0, 1)
flat_targets = batch_target_tokens.flatten()
print(f"Flattened Logits shape: {flat_logits.shape}")
print(f"Flattened Targets shape: {flat_targets.shape}")
loss_fn = torch.nn.functional.cross_entropy
loss = loss_fn(flat_logits, flat_targets)
print(f"Cross-entropy loss: {loss}")
Logits shape: torch.Size([2, 4, 50257])
Logits: tensor([[[-0.0807, 0.1419, -0.0128, ..., 0.0488, -1.1006, -0.5177],
[ 0.2431, 0.1199, 0.4347, ..., -1.4129, -0.4291, -0.3951],
[-0.4904, -0.1851, -0.0027, ..., -1.0833, 1.1487, -0.5754],
[-0.0147, 0.2563, 1.2010, ..., -0.4941, -0.5542, -1.3598]],
[[ 0.0052, 0.1310, 0.0080, ..., -0.1653, -0.8215, -0.9171],
[-0.2092, -0.1521, 0.5149, ..., -0.5578, 0.0754, -1.4415],
[-0.2539, -0.2681, 0.4995, ..., -0.6771, 0.0557, -0.8401],
[-0.3334, -0.0690, 1.2449, ..., 0.1689, -0.3848, -0.5397]]])
Targets shape: torch.Size([2, 4])
Targets: tensor([[ 2068, 7586, 21831, 18045],
[ 257, 16161, 1290, 1497]])
Flattened Logits shape: torch.Size([8, 50257])
Flattened Targets shape: torch.Size([8])
Cross-entropy loss: 11.15966510772705
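Flattening is not the only option: cross_entropy also accepts (N, C, d1) logits with (N, d1) targets, so moving the vocabulary to dimension 1 with transpose gives the same mean loss without merging the batch and time axes. A quick equivalence check on random tensors (toy sizes assumed):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 4, 10
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# Option 1: flatten batch and time -> (B*T, V) logits and (B*T,) targets
loss_flat = F.cross_entropy(logits.flatten(0, 1), targets.flatten())

# Option 2: put the class dimension second -> (B, V, T) logits with (B, T) targets
loss_transposed = F.cross_entropy(logits.transpose(1, 2), targets)

print(torch.allclose(loss_flat, loss_transposed))  # True
```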
Load raw corpus and estimate token count
This cell reads a text corpus from disk and reports:
- The total number of characters in the file (useful for sanity checks and throughput planning).
- The total number of tokens after GPT‑2 BPE encoding using tiktoken (this approximates the training token budget).
What the code does
- Opens data/all_books.txt with UTF‑8 encoding and loads it into memory as a single string.
- Uses the previously created tokenizer = tiktoken.get_encoding("gpt2") to encode the entire string to token IDs.
- Prints character and token counts.
Prerequisites
- The file data/all_books.txt should exist relative to this notebook's working directory (the repository root).
- A tokenizer must be available in the notebook scope (earlier cells set tokenizer = tiktoken.get_encoding("gpt2")).
Notes on GPT‑2 BPE tokenization
- GPT‑2 uses a byte‑level BPE with a vocabulary size of 50,257; token count ≠ word count.
- Token counts depend on punctuation, whitespace, and casing; the same characters can yield different tokenizations if the text changes slightly.
data_file_path = "data/all_books.txt"
with open(data_file_path, 'r', encoding = 'utf-8') as books:
    text_data = books.read()
print(f"Total Characters: {len(text_data)}")
print(f"Total Tokens after encoding: {len(tokenizer.encode(text_data))}")
Total Characters: 19849702
Total Tokens after encoding: 5611150
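The two counts above give a quick characters‑per‑token ratio and show how far this 20‑book corpus is from the roughly 3B‑token budget mentioned in the introduction. A small arithmetic check using only the printed numbers:

```python
total_chars = 19_849_702       # printed above
total_tokens = 5_611_150       # printed above
target_tokens = 3_000_000_000  # the ~3B-token goal from the introduction

chars_per_token = total_chars / total_tokens
corpus_fraction = total_tokens / target_tokens

print(f"{chars_per_token:.2f} characters per token")    # 3.54 characters per token
print(f"{corpus_fraction:.2%} of the 3B-token budget")  # 0.19% of the 3B-token budget
```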
Create training/validation DataLoaders from raw text
This cell splits the raw corpus into train/validation subsets and builds batched (input, target) sequences for next‑token prediction using a fixed context window.
What the code does
- Splits text_data with training_ratio = 0.9:
  - training_dataset = first 90% of characters
  - validation_dataset = remaining 10%
- Uses modules.DataLoader.create_dataloader(...) to produce iterable loaders that yield (x, y) batches.
- Reports the number of batches and prints the shape of one batch from each loader for a quick sanity check.
Key parameters
- max_length = SYDSGPT_CONFIG_345M['context_length']: sequence length per sample (the context window size, 512 here).
- step_size = SYDSGPT_CONFIG_345M['context_length']: the stride equals the window length, so windows do not overlap and each token is used at most once per epoch. Setting step_size = context_length // 2 instead would give half‑overlapping windows, roughly doubling the number of samples at the cost of more compute and more redundancy.
- batch_size = 8: number of sequences per batch. Tune based on GPU/CPU memory; larger batches increase throughput but require more memory.
- shuffle = True: randomizes sample order to improve training stability. (Shuffling the validation loader is harmless but not required.)
- drop_last = True: drops the incomplete final batch to keep shapes consistent.
- num_workers = 0: data loading happens in the main process. Increase it (e.g., 2–8) to parallelize loading if your implementation supports it.
Expected batch shapes and semantics
- x: LongTensor (batch_size, max_length), input token IDs.
- y: LongTensor (batch_size, max_length), target token IDs.
- Language modeling convention: y is x shifted left by one token (y[:, t] = x[:, t+1]), so the logits at position t are trained to predict y[:, t]. calc_batch_loss below compares logits and targets position by position without any shift, so create_dataloader should return y already shifted.
Assumptions (API contract)
- create_dataloader(text, max_length, step_size, ...) tokenizes text (using the same tokenizer as earlier), slices it into windows of length max_length with stride step_size, and returns batches of (x, y) suitable for a next‑token objective.
- If your implementation expects pre‑tokenized data or returns different shapes, adjust the parameters or downstream code accordingly.
Tips and troubleshooting
- Throughput vs. redundancy: a smaller step_size gives more overlapping windows, hence more training samples but higher redundancy. This run uses step_size = max_length (no overlap); max_length // 2 is a common next setting to try.
- Memory pressure: reduce batch_size if you hit OOM; alternatively, keep batch_size and reduce max_length.
training_ratio = 0.9
training_size = int(training_ratio * len(text_data))
training_dataset = text_data[:training_size]
validation_dataset = text_data[training_size:]
from modules.DataLoader import create_dataloader
training_dataloader = create_dataloader(
training_dataset,
max_length = SYDSGPT_CONFIG_345M['context_length'],
step_size = SYDSGPT_CONFIG_345M['context_length'],
batch_size = 8,
shuffle = True,
drop_last = True,
num_workers = 0,
)
validation_dataloader = create_dataloader(
validation_dataset,
max_length = SYDSGPT_CONFIG_345M['context_length'],
step_size = SYDSGPT_CONFIG_345M['context_length'],
batch_size = 8,
shuffle = True,
drop_last = True,
num_workers = 0,
)
print(f"Number of training batches: {len(training_dataloader)}")
print("Training loader:")
for x, y in training_dataloader:
    print(x.shape, y.shape)
    break
print(f"Number of validation batches: {len(validation_dataloader)}")
print("Validation loader:")
for x, y in validation_dataloader:
    print(x.shape, y.shape)
    break
Number of training batches: 1249
Training loader:
torch.Size([8, 512]) torch.Size([8, 512])
Number of validation batches: 120
Validation loader:
torch.Size([8, 512]) torch.Size([8, 512])
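create_dataloader is imported from modules.DataLoader and isn't shown here. The sketch below is a hypothetical stand‑in (sliding_windows is not the real function) that illustrates the two behaviors documented above: windows of max_length tokens taken every step_size tokens, and targets that are the inputs shifted left by one token.

```python
import torch

def sliding_windows(token_ids, max_length, step_size):
    """Hypothetical helper: split a token list into (input, target) windows."""
    inputs, targets = [], []
    # Stop early enough that every target window still has max_length tokens
    for start in range(0, len(token_ids) - max_length, step_size):
        inputs.append(token_ids[start:start + max_length])
        targets.append(token_ids[start + 1:start + max_length + 1])  # shifted by one
    return torch.tensor(inputs), torch.tensor(targets)

ids = list(range(10))  # pretend these are 10 token IDs

x, y = sliding_windows(ids, max_length=4, step_size=4)   # no overlap
print(x.tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(y.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]

x_half, _ = sliding_windows(ids, max_length=4, step_size=2)  # half overlap
print(len(x_half))  # 3
```

With step_size equal to max_length, the training batch count above corresponds to 1,249 × 8 × 512 = 5,115,904 training tokens per epoch, roughly 91% of the 5.6M‑token corpus.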
Mini utility: per-batch loss computation
This helper computes the cross‑entropy loss for a single batch. It moves tensors to the desired device, runs the model's forward pass, flattens logits/targets into the expected shapes, and returns a scalar loss suitable for backward().
Function signature
calc_batch_loss(input_batch, target_batch, model, device) -> torch.Tensor
- Returns a scalar 0‑D tensor (the mean cross‑entropy over all tokens in the batch).
Inputs and shapes
- input_batch: LongTensor (B, T), token IDs fed into the model.
- target_batch: LongTensor (B, T), token IDs used as labels.
- model: an nn.Module with forward(input_batch) -> logits of shape (B, T, V).
- device: target device (e.g., torch.device("cuda") or "cpu").
What happens inside
- Moves input_batch and target_batch to device for consistency.
- Computes logits = model(input_batch) with shape (B, T, V).
- Flattens for the loss: logits.flatten(0, 1) → (B*T, V) and target_batch.flatten() → (B*T,).
- Applies torch.nn.functional.cross_entropy to the raw logits and integer targets (no softmax needed).
- Returns the mean loss across all B*T token positions as a scalar.
def calc_batch_loss(input_batch, target_batch, model, device):
    # Move both tensors to the model's device
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)  # (B, T, V)
    # (B, T, V) -> (B*T, V) and (B, T) -> (B*T,); mean over all token positions
    loss = torch.nn.functional.cross_entropy(logits.flatten(0,1), target_batch.flatten())
    return loss
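If a loader returned only input windows, with no separately shifted y, the shift could be applied inside the loss instead, as the dataloader notes mention. A minimal sketch of that variant, assuming any model that maps (B, T) token IDs to (B, T, V) logits; calc_batch_loss_shifted and the tiny embedding model are hypothetical, for illustration only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def calc_batch_loss_shifted(token_batch, model, device):
    """Hypothetical variant: predict token t+1 from position t within one window."""
    token_batch = token_batch.to(device)
    logits = model(token_batch[:, :-1])   # (B, T-1, V): drop the last input position
    targets = token_batch[:, 1:]          # (B, T-1): drop the first token
    return F.cross_entropy(logits.flatten(0, 1), targets.flatten())

# Tiny stand-in model: embedding followed by a linear head over a 16-token vocab
torch.manual_seed(0)
V = 16
toy_model = nn.Sequential(nn.Embedding(V, 8), nn.Linear(8, V))
tokens = torch.randint(0, V, (2, 5))      # (B=2, T=5)

loss = calc_batch_loss_shifted(tokens, toy_model, torch.device("cpu"))
print(loss.dim())  # 0
```

The cost is one fewer prediction per window; shifting inside the loader instead keeps all max_length positions, which is what calc_batch_loss assumes.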
Compute average loss over a DataLoader (validation helper)
This utility iterates over a DataLoader, computes the per‑batch cross‑entropy via calc_batch_loss, and returns the mean of the batch losses. It's intended for quick validation/evaluation, not training.
Function signature
calc_loader_loss(data_loader, model, device, num_batches=None) -> float
- Returns a Python float: the average of num_batches batch losses.
Parameters
- data_loader: iterable yielding (input_batch, target_batch) tensors, typically each shaped (B, T).
- model: the language model (nn.Module). For evaluation, call model.eval() beforehand.
- device: torch.device or string identifying the device (e.g., "cuda", "cpu").
- num_batches (optional): if provided, caps the number of batches processed (useful for fast estimates). Defaults to the full length of the loader.
What the code does
- Initializes total_loss = 0.
- Returns nan early if the loader is empty.
- Determines how many batches to process:
  - if num_batches is None, uses len(data_loader);
  - otherwise takes min(num_batches, len(data_loader)).
- Iterates over the loader and calls calc_batch_loss(...) for each batch until the cap is reached.
- Sums batch_loss.item() and returns total_loss / num_batches.
Averaging semantics and weighting
- calc_batch_loss uses cross_entropy with its default reduction='mean' on the flattened B*T elements, so each batch loss is the mean per token within that batch.
- This helper averages those batch means equally across batches. If all batches are the same size (as here, with fixed‑length windows and drop_last=True), this equals the dataset‑wide per‑token mean.
- Because both loaders shuffle, a capped estimate (num_batches smaller than the loader) sees a different random subset of batches on every call, so repeated estimates will vary.
Performance tips
- Wrap calls in torch.no_grad() or torch.inference_mode() to disable autograd and save memory/time during validation; the helper does not do this itself.
- Increase num_workers in your dataloader (if supported) to accelerate data preparation.
- Ensure batches and model are on the same device to avoid implicit transfers.
def calc_loader_loss(data_loader, model, device, num_batches = None):
    total_loss = 0
    if len(data_loader) == 0:
        return float('nan')
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for batch_index, (input_batch, target_batch) in enumerate(data_loader):
        if batch_index >= num_batches:
            break
        else:
            batch_loss = calc_batch_loss(input_batch, target_batch, model, device)
            total_loss += batch_loss.item()
    return total_loss / num_batches
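When batches can differ in size (for example with drop_last=False), averaging batch means weights a short final batch as heavily as a full one. A sketch of a token‑weighted alternative that sums per‑token losses with reduction="sum" and divides by the token count; calc_loader_loss_weighted is a hypothetical helper, and unlike calc_loader_loss it wraps itself in torch.no_grad():

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def calc_loader_loss_weighted(data_loader, model, device, num_batches=None):
    """Hypothetical helper: per-token mean cross-entropy over (up to) num_batches batches."""
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for batch_index, (input_batch, target_batch) in enumerate(data_loader):
            if num_batches is not None and batch_index >= num_batches:
                break
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)
            logits = model(input_batch)
            # Sum (not mean) so every token counts equally across batches
            total_loss += F.cross_entropy(logits.flatten(0, 1), target_batch.flatten(),
                                          reduction="sum").item()
            total_tokens += target_batch.numel()
    return total_loss / total_tokens if total_tokens else float("nan")

# Toy check: two batches of different sizes (a plain list works as a loader here)
torch.manual_seed(0)
V = 16
toy_model = nn.Sequential(nn.Embedding(V, 8), nn.Linear(8, V))
batches = [(torch.randint(0, V, (3, 4)), torch.randint(0, V, (3, 4))),
           (torch.randint(0, V, (1, 4)), torch.randint(0, V, (1, 4)))]

weighted = calc_loader_loss_weighted(batches, toy_model, torch.device("cpu"))
```

The result matches the loss over all tokens concatenated into one batch, regardless of how they were split.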
Evaluate initial training/validation loss (no grad)
This cell runs a quick, read‑only evaluation to get baseline losses on the training and validation loaders. It's useful for sanity‑checking your pipeline before training (e.g., confirming shapes and devices, that the dataloaders yield batches, and a reasonable initial loss).
What the code does
- Picks a device (cuda if available, else cpu) and prints it for visibility.
- Moves the model to that device with model.to(device).
- Disables autograd with torch.no_grad() and computes:
  - training_loss = calc_loader_loss(training_dataloader, model, device)
  - validation_loss = calc_loader_loss(validation_dataloader, model, device)
- Prints both losses.
Interpreting the numbers
- The loss is the mean cross‑entropy over all tokens seen by the helper.
- With random initialization and a vocab of 50,257, a loss near ln(50257) ≈ 10.82 (the loss of a perfectly uniform prediction) is expected. The observed ≈10.96 sits slightly above it because randomly initialized logits are not perfectly flat, and non‑uniform predictions that are unrelated to the targets score worse than uniform on average. Lower is better; after training, this value should decrease.
Example options:
- Fast estimate: calc_loader_loss(training_dataloader, model, device, num_batches=50)
- Exact: for the standard loaders here, omitting num_batches already covers every batch, as this cell does. With a streaming loader, pass num_batches=num_train_batches where num_train_batches = count_batches_streaming(...).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model.to(device)
with torch.no_grad():
    training_loss = calc_loader_loss(training_dataloader, model, device)
    validation_loss = calc_loader_loss(validation_dataloader, model, device)
print(f"Initial Training Loss: {training_loss}")
print(f"Initial Validation Loss: {validation_loss}")
Using device: cuda
Initial Training Loss: 10.957578182220459
Initial Validation Loss: 10.96420328957694
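The ln(50257) ≈ 10.82 baseline, and the reason an untrained model can land slightly above it, can be checked numerically. This sketch uses random Gaussian logits as a stand‑in for an untrained model (an assumption; SydsGPT's actual logit scale at initialization differs) and random targets that are unrelated to the logits:

```python
import math
import torch
import torch.nn.functional as F

V = 50257
uniform_loss = math.log(V)          # loss of a perfectly flat prediction
print(f"{uniform_loss:.2f}")        # 10.82

torch.manual_seed(0)
N = 64                              # number of token positions in the toy batch
targets = torch.randint(0, V, (N,))

flat_logits = torch.zeros(N, V)     # identical logits -> exactly uniform softmax
noisy_logits = torch.randn(N, V)    # unit-variance random logits

print(F.cross_entropy(flat_logits, targets).item())   # ≈ 10.82
print(F.cross_entropy(noisy_logits, targets).item())  # noticeably higher than 10.82
```

The notebook's initial ≈10.96 falls between these two cases, which is consistent with small but non‑zero variation in the logits at initialization.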
Helper: generate_sample_text (quick qualitative checkpoint)
This helper produces a short sample completion from the current model state during or after training. It lets you visually inspect language quality and drift without exporting the model or writing a separate script.
Function Signature
def generate_sample_text(model, tokenizer, device, start_context):
    ...
Parameters
- model (nn.Module): your SydsGPT instance, already moved to device. It can be in either mode on entry; the helper switches to eval() for generation and always leaves the model in train() mode on exit.
- tokenizer (tiktoken.Encoding): GPT‑2 BPE tokenizer used to encode the prompt and decode the output tokens.
- device (torch.device): target compute device (e.g., cuda or cpu); the prompt tokens are moved here to match the model's placement.
- start_context (str): the initial prompt (seed text) that conditions generation. It can be a sentence fragment, an instruction, or a domain phrase.
Returns
- Nothing explicit; it prints a one‑line generated string (newlines collapsed to spaces). You could refactor it to return the string if needed.
Step‑by‑Step Workflow
- Switches the model to eval() to disable dropout. Gradient tracking is turned off separately by the explicit torch.no_grad() block.
- Reads context_size from the model config (SYDSGPT_CONFIG_345M['context_length']) and passes it to generate_simple as the context limit.
- Encodes start_context into token IDs as a batch of size 1 and moves the tensor to device.
- Calls generate_simple(model, input_tokens, 100, context_size):
  - greedy extension by up to max_new_tokens=100 tokens, within the context limit;
  - generate_simple internally ensures the input tokens are on the same device as the model (device‑safety fix applied earlier).
- Decodes the resulting tokens back to text via the tokenizer.
- Prints the generated text with newlines replaced by spaces for cleaner logging.
- Restores model.train() so the outer training loop continues with dropout active.
Usage Patterns
Typical placement:
- End of each epoch (qualitative progress checkpoint).
- After major learning rate changes.
- Before and after loading a checkpoint to verify restoration.
Example invocation inside a loop:
generate_sample_text(model, tokenizer, device, start_context="Once upon a time")
Design Considerations
- Non‑determinism: with dropout or stochastic sampling (e.g., top‑k, temperature), results change across calls. Here, greedy decoding plus eval() yields the same output for fixed weights.
- Prompt length vs. context: if start_context is already close to context_length, little room is left in the window; depending on how generate_simple handles the limit, generation may stop early or the prompt may be cropped. Consider truncating long prompts.
- Performance: 100 tokens is modest; for faster feedback on large models, reduce it to 32 or 64.
- Mode switching: the helper returns the model to training mode automatically, which prevents dropout from being left disabled mid‑training.
Summary
generate_sample_text is a lightweight, deterministic snapshot tool: it briefly switches the model to evaluation mode, performs greedy decoding for a fixed number of tokens, prints the result, and returns the model to training mode. This gives you rapid human feedback on training progress without disrupting optimization.
def generate_sample_text(model, tokenizer, device, start_context):
    model.eval()  # disable dropout while generating
    context_size = SYDSGPT_CONFIG_345M['context_length']
    input_tokens = text_to_tokens(start_context, tokenizer).to(device)
    with torch.no_grad():  # no gradient tracking needed for generation
        generated_tokens = generate_simple(model, input_tokens, 100, context_size)
    generated_text = tokens_to_text(generated_tokens, tokenizer)
    print(f"Generated Text: {generated_text}".replace("\n", " "))
    model.train()  # hand the model back to the training loop in training mode
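generate_simple was defined in an earlier part and isn't reproduced here. To show the mechanics in one place, here is a hypothetical greedy loop in the same spirit (greedy_generate and CountingLM are illustrative names, not the author's code): it crops the running sequence to the last context_size tokens before each forward pass, appends the argmax token, and restores whatever mode the model was in rather than always calling train().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def greedy_generate(model, token_ids, max_new_tokens, context_size):
    """Hypothetical greedy decoder: token_ids is a (B, T) LongTensor."""
    was_training = model.training
    model.eval()                                              # disable dropout while generating
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = token_ids[:, -context_size:]            # keep at most context_size tokens
            logits = model(context)[:, -1, :]                 # logits at the last position
            next_id = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
            token_ids = torch.cat([token_ids, next_id], dim=1)
    model.train(was_training)                                 # restore the caller's mode
    return token_ids

class CountingLM(nn.Module):
    """Toy model whose greedy next token is always (current token + 1) mod vocab_size."""
    def __init__(self, vocab_size=10):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, idx):
        return F.one_hot((idx + 1) % self.vocab_size, self.vocab_size).float()

toy = CountingLM()
toy.train()
out = greedy_generate(toy, torch.tensor([[7]]), max_new_tokens=4, context_size=3)
print(out.tolist())  # [[7, 8, 9, 0, 1]]
print(toy.training)  # True
```

Restoring the previous mode also makes a helper like this safe to call from pure inference code, where the model should stay in eval mode.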
Training Loop: train_model_v1 (core optimization engine)
This section documents the training function responsible for iterating over data, computing gradients, updating model weights, periodically evaluating progress, checkpointing state, and logging token throughput.
Function Signature
def train_model_v1(model,
                   training_dataloader,
                   validation_dataloader,
                   optimizer,
                   device,
                   num_epochs,
                   evaluation_frequency,
                   evaluation_iterations,
                   start_context,
                   tokenizer,
                   checkpoint_interval=500):
    ...
Arguments
| Name | Type | Description |
|---|---|---|
| model | nn.Module | SydsGPT model to optimize. Must already reside on device. |
| training_dataloader | iterable of (input_batch, target_batch) | Yields training batches (LongTensors of shape (B, T)). |
| validation_dataloader | iterable of (input_batch, target_batch) | Used for periodic evaluation (subset or full, depending on evaluation_iterations). |
| optimizer | torch.optim.Optimizer | Optimizer instance (e.g., AdamW) over the model parameters. |
| device | torch.device / str | Compute target (GPU preferred if available). |
| num_epochs | int | Number of full passes over training_dataloader. |
| evaluation_frequency | int | Evaluate every N optimization steps (not epochs). |
| evaluation_iterations | int | Number of batches to sample from each loader during evaluation (fast approximation). |
| start_context | str | Prompt for the qualitative sample produced by generate_sample_text after each epoch. |
| tokenizer | tokenizer object | Used only by the sampling helper (not in the training loop math). |
| checkpoint_interval | int (default 500) | Save a checkpoint every N steps (model + optimizer state). |
Returns
A tuple (training_losses, validation_losses, total_tokens_processed):
- training_losses: list of sampled training loss values, one per evaluation event.
- validation_losses: list of sampled validation loss values, aligned index by index with training_losses.
- total_tokens_processed: cumulative token counts with one entry per training step (not per evaluation). To plot loss against tokens, take the entries at the evaluation steps, i.e. total_tokens_processed[::evaluation_frequency]; this lines up with the loss lists because evaluations fire at global steps 0, evaluation_frequency, 2 × evaluation_frequency, and so on.
Internal State Variables
- tokens_processed: running count of all token positions used in gradient steps, across steps and epochs. Incremented by input_batch.numel() (= batch_size × sequence_length, 8 × 512 = 4,096 here).
- global_step: counts optimization steps across epochs. It starts at -1 and is incremented once per batch, so the first batch is step 0. Because 0 % N == 0, the first evaluation and the first autosave both happen right after the very first update.
Step-by-Step Flow
- Epoch loop: repeats num_epochs times.
- Set model.train(): enables training‑mode behavior such as dropout. (It does not change how gradients are computed or accumulated.)
- Batch loop (training):
  - Zero gradients: optimizer.zero_grad() (baseline approach; could switch to a gradient accumulation strategy if needed).
  - Forward + loss: uses calc_batch_loss, which (a) moves tensors to device, (b) computes logits, and (c) applies cross‑entropy across the flattened B*T tokens.
  - Backprop: loss.backward() computes gradients.
  - Optimizer step: optimizer.step() applies parameter updates.
  - Update the token counter: tokens_processed += input_batch.numel().
  - Increment global_step.
  - Append the new cumulative token count to total_tokens_processed.
  - Progress print: reports epoch, step, and cumulative tokens (once per step).
- Conditional evaluation (if global_step % evaluation_frequency == 0):
  - Switch to model.eval().
  - Wrap in torch.no_grad() to disable gradient tracking.
  - Compute an approximate training loss: calc_loader_loss(training_dataloader, ..., num_batches=evaluation_iterations).
  - Compute an approximate validation loss the same way.
  - Append both losses to their lists and print a concise summary.
  - Return to model.train().
- Conditional checkpoint (if global_step % checkpoint_interval == 0):
  - Save a dictionary with both the model and optimizer state dicts to autosave_sydsgpt_345m_trained_model_optimizer.pth.
- End‑of‑epoch qualitative sample: calls generate_sample_text with start_context to track improvements qualitatively.
- Loop or exit: repeats until all epochs are complete, then returns the logged metrics.
Evaluation Strategy Rationale
- Using a small evaluation_iterations drastically reduces overhead, enabling frequent snapshots (e.g., every 100 steps) without stalling training. The trade‑off is noise: with 2 batches per loader, each snapshot averages only 2 × 8 × 512 = 8,192 tokens per loader, and because the loaders shuffle, each snapshot sees a different random subset.
- For a precise validation curve later, run a dedicated full evaluation pass using the batch‑counting helpers from the streaming dataloader (if integrated), or iterate over the full loader with drop_last=False.
Checkpointing Notes
- Each autosave overwrites the same filename, minimizing disk usage.
- For resilience, consider timestamped or step‑indexed filenames, e.g., autosave_step_{global_step}.pth.
- Include tokens_processed and global_step in future checkpoint metadata to support exact resumption.
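A minimal sketch of the checkpointing suggestions above, assuming a step‑indexed filename and extra global_step/tokens_processed fields; save_checkpoint and its file naming are illustrative, not part of the notebook, and a tiny nn.Linear stands in for SydsGPT:

```python
import os
import tempfile
import torch
import torch.nn as nn

def save_checkpoint(directory, model, optimizer, global_step, tokens_processed):
    """Hypothetical helper: write a step-indexed checkpoint with resume metadata."""
    path = os.path.join(directory, f"autosave_step_{global_step}.pth")
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "global_step": global_step,            # lets a resumed run continue the step count
        "tokens_processed": tokens_processed,  # lets a resumed run continue the token count
    }, path)
    return path

# Toy usage with a tiny model and optimizer in a temporary directory
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.05)
with tempfile.TemporaryDirectory() as tmp:
    # After step 500 at 4,096 tokens per step: 501 * 4096 = 2,052,096 tokens
    path = save_checkpoint(tmp, model, optimizer, global_step=500, tokens_processed=2_052_096)
    restored = torch.load(path)
    print(os.path.basename(path), restored["global_step"], restored["tokens_processed"])
```

Step‑indexed files accumulate, so in a long run you may want to keep only the most recent few.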
Logging & Monitoring
- Current approach: simple print statements.
- Optional enhancements:
  - Use tqdm progress bars (wrap the dataloader) for ETA visibility.
  - Log scalars (loss, tokens) to TensorBoard or a tracking service.
  - Store (step, training_loss, validation_loss, tokens_processed) as rows in a CSV for later analysis.
Performance Tips
- Prefer larger batch sizes if GPU memory permits; this improves arithmetic intensity.
- If the dataloader becomes a bottleneck: increase num_workers, enable pin_memory=True (if your create_dataloader exposes it), and pre‑tokenize (if not already streaming efficiently).
- Track tokens_processed against wall‑clock time to estimate throughput (tokens/sec) for capacity planning.
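Throughput tracking needs only a timer alongside the token counter. A small sketch using time.perf_counter; the ThroughputMeter class is hypothetical and would be updated once per optimization step inside a loop like train_model_v1:

```python
import time

class ThroughputMeter:
    """Hypothetical helper: average tokens/sec since construction."""

    def __init__(self):
        self.start = time.perf_counter()
        self.tokens = 0

    def update(self, num_tokens):
        self.tokens += num_tokens

    def tokens_per_second(self):
        elapsed = time.perf_counter() - self.start
        return self.tokens / elapsed if elapsed > 0 else 0.0

meter = ThroughputMeter()
for _ in range(3):
    time.sleep(0.01)          # stand-in for one training step
    meter.update(8 * 512)     # batch_size * context_length tokens per step
print(f"{meter.tokens_per_second():,.0f} tokens/sec")
```

Dividing a planned token budget by the measured rate gives a rough wall‑clock estimate for a run.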
Usage Example
training_losses, validation_losses, token_counts = train_model_v1(
model, training_dataloader, validation_dataloader, optimizer, device,
num_epochs=5, evaluation_frequency=100, evaluation_iterations=2,
start_context="Once upon a time", tokenizer=tokenizer, checkpoint_interval=500)
Summary
train_model_v1 provides a clear, modular training scaffold: it tracks tokens, performs scheduled lightweight evaluations, produces qualitative samples, and saves recoverable checkpoints. This makes it a solid foundation for iterative experimentation and scaling with minimal friction.
def train_model_v1(model, training_dataloader, validation_dataloader, optimizer, device, num_epochs, evaluation_frequency, evaluation_iterations, start_context, tokenizer, checkpoint_interval = 500):
    training_losses, validation_losses, total_tokens_processed = [], [], []
    tokens_processed = 0
    global_step = -1  # the first batch becomes step 0
    for epoch in range(num_epochs):
        model.train()
        for input_batch, target_batch in training_dataloader:
            # One optimization step: reset grads, forward + loss, backprop, update
            optimizer.zero_grad()
            loss = calc_batch_loss(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            tokens_processed += input_batch.numel()
            global_step += 1
            total_tokens_processed.append(tokens_processed)
            print(f"Epoch {epoch+1}, Step {global_step}: Tokens Processed = {tokens_processed}")
            # Periodic lightweight evaluation on a few batches from each loader
            if global_step % evaluation_frequency == 0:
                model.eval()
                with torch.no_grad():
                    training_loss = calc_loader_loss(training_dataloader, model, device, evaluation_iterations)
                    validation_loss = calc_loader_loss(validation_dataloader, model, device, evaluation_iterations)
                training_losses.append(training_loss)
                validation_losses.append(validation_loss)
                print(f"Epoch {epoch+1}, Step {global_step}: Training Loss = {training_loss}, Validation Loss = {validation_loss}, Tokens Processed = {tokens_processed}")
                model.train()
            # Periodic autosave (overwrites the same file each time)
            if global_step % checkpoint_interval == 0:
                torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, "autosave_sydsgpt_345m_trained_model_optimizer.pth")
        # Qualitative sample once per epoch
        generate_sample_text(model, tokenizer, device, start_context)
    return training_losses, validation_losses, total_tokens_processed
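Because total_tokens_processed grows by one entry per step while the loss lists grow by one entry per evaluation, the two need aligning before plotting loss against tokens. Evaluations fire at global steps 0, evaluation_frequency, 2 × evaluation_frequency, and so on, so slicing with that stride picks the matching token counts. A sketch with synthetic numbers (250 steps of 4,096 tokens, evaluations every 100 steps); tokens_at_evaluations is a hypothetical helper:

```python
def tokens_at_evaluations(total_tokens_processed, evaluation_frequency):
    """Token counts at the steps where train_model_v1 evaluates (global_step % freq == 0)."""
    return total_tokens_processed[::evaluation_frequency]

tokens_per_step = 8 * 512
total_tokens_processed = [(step + 1) * tokens_per_step for step in range(250)]  # synthetic log

eval_tokens = tokens_at_evaluations(total_tokens_processed, evaluation_frequency=100)
print(eval_tokens)  # [4096, 413696, 823296]

# With real results you could then plot, e.g. with matplotlib:
# plt.plot(eval_tokens, training_losses); plt.plot(eval_tokens, validation_losses)
```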
Initial training run (from scratch)
This cell wires everything together to kick off a multi‑epoch training run using the previously defined train_model_v1 loop. It selects a device, seeds the RNG for reproducibility, instantiates a fresh model, configures an optimizer, and starts training with periodic evaluations and qualitative sampling.
What the code does
- Chooses a compute device (cuda if available, else cpu) and prints it.
- Sets a fixed torch seed (246) so initialization and early steps are repeatable.
- Creates a new SydsGPT model with the 345M‑style config and moves it to the selected device.
- Constructs an AdamW optimizer over the model parameters with:
  - learning rate: 2e-4
  - weight decay: 0.05
- Sets num_epochs = 5 and calls train_model_v1(...) with:
  - evaluation_frequency=100: run a quick loss snapshot every 100 optimization steps.
  - evaluation_iterations=2: evaluate only 2 mini‑batches per loader per snapshot to keep it fast.
  - start_context="Once upon a time": prompt for qualitative samples at the end of each epoch.
  - checkpoint_interval left at its default of 500 steps.
- Captures the outputs:
  - training_losses: sampled training losses (one per evaluation event)
  - validation_losses: sampled validation losses aligned with the training snapshots
  - total_tokens_processed: cumulative token count after each step (for plotting)
Inputs and assumptions
- training_dataloader and validation_dataloader yield (input_batch, target_batch), each of dtype torch.long and shape (B, T).
- The model's vocab size and context_length match the tokenizer/loader.
- The training loop handles device placement for batches and uses cross_entropy on flattened logits/targets.
Tuning knobs (quick guidance)
- Batch size vs. context length: if you see CUDA OOM, reduce the batch size first; then consider lowering context_length.
- Learning rate: 2e-4 is a reasonable starting point for AdamW; try 1e-4 to 3e-4 depending on stability.
- Weight decay: 0.05 acts as a regularizer that encourages generalization; adjust it if you notice under‑ or over‑regularization (for example, a widening gap between training and validation loss).
- Evaluation cadence: increase evaluation_iterations once runs are stable to get more precise estimates (at a compute cost).
Expected outputs
- Console logs for steps, tokens processed, and periodic training/validation losses.
- A short generated sample at the end of each epoch to gauge progress qualitatively.
- On completion, you can optionally save a final checkpoint in the next cell (separate from the autosaves inside the loop).
Troubleshooting
- Device mismatch: ensure model.to(device) ran before training; the loop moves batches to device internally.
- Empty/short dataloader: if your dataset is tiny, evaluation_frequency may fire too often relative to an epoch; increase it or reduce epochs.
- Slow evaluations: lower evaluation_iterations or run snapshots less frequently.
With this cell, you should see both quantitative (loss) and qualitative (sample text) signals to verify your pipeline is training correctly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(246)
model = SydsGPT(SYDSGPT_CONFIG_345M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.0002, weight_decay = 0.05)
num_epochs = 5
training_losses, validation_losses, total_tokens_processed = train_model_v1(
model,
training_dataloader,
validation_dataloader,
optimizer,
device,
num_epochs,
evaluation_frequency = 100,
evaluation_iterations = 2,
start_context = "Once upon a time",
tokenizer = tokenizer
)
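With the loader sizes printed earlier (1,249 training batches of 8 × 512 tokens) and the settings in this cell, the run's scale can be worked out ahead of time. A small arithmetic sketch that mirrors the step counting in train_model_v1 (global steps start at 0 and run continuously across epochs):

```python
batches_per_epoch = 1249        # len(training_dataloader), printed earlier
tokens_per_step = 8 * 512       # batch_size * context_length
num_epochs = 5
evaluation_frequency, checkpoint_interval = 100, 500

total_steps = batches_per_epoch * num_epochs            # global steps 0 .. total_steps - 1
total_tokens = total_steps * tokens_per_step
num_evaluations = len(range(0, total_steps, evaluation_frequency))
num_autosaves = len(range(0, total_steps, checkpoint_interval))

print(total_steps, total_tokens)        # 6245 25579520
print(num_evaluations, num_autosaves)   # 63 13
```

Note that the five epochs revisit the same ~5.1M training tokens five times rather than covering ~25.6M distinct tokens.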
Final checkpoint save (explicit milestone)
This cell performs an explicit, manual save of the model and optimizer state at the current training milestone. Although the training loop has already produced periodic autosaves, this dedicated checkpoint is useful for tagging a known-good completion of an initial training phase (e.g., after all planned epochs).
What gets saved
The file sydsgpt_345m_trained_model_optimizer.pth contains a Python dictionary with two keys:
{
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}
- `model_state_dict`: Parameter tensors (weights, biases, embeddings, layer norms, attention projections, etc.).
- `optimizer_state_dict`: Internal optimizer buffers (e.g., AdamW moment estimates `exp_avg` and `exp_avg_sq`, step counters) plus the hyperparameters of each parameter group (learning rate, weight decay, betas).
Including the optimizer state allows seamless continuation of training without losing momentum statistics or adaptive learning rate context. This preserves optimization dynamics and avoids a transient spike in loss that can occur when restarting with a freshly initialized optimizer.
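To see what `optimizer_state_dict` actually holds, the sketch below uses a toy `nn.Linear` as a stand-in for SydsGPT (an assumption, purely to keep it fast on CPU), takes one AdamW step so the optimizer populates its state, and lists the per-parameter buffers and group hyperparameters.

```python
import torch
import torch.nn as nn

# Toy stand-in for SydsGPT so the example runs quickly on CPU
toy_model = nn.Linear(4, 2)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=2e-4, weight_decay=0.05)

# AdamW creates its per-parameter state lazily, on the first step
loss = toy_model(torch.randn(3, 4)).pow(2).mean()
loss.backward()
toy_optimizer.step()

opt_state = toy_optimizer.state_dict()
# "state" maps parameter indices to their buffers; "param_groups" holds lr, weight_decay, betas, ...
for param_id, buffers in opt_state["state"].items():
    print(param_id, sorted(buffers.keys()))
print("lr:", opt_state["param_groups"][0]["lr"], "weight_decay:", opt_state["param_groups"][0]["weight_decay"])
```

With default AdamW settings each parameter entry contains `exp_avg`, `exp_avg_sq`, and `step`; these are exactly the buffers that a freshly constructed optimizer would have to rebuild from zero.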
Why take a manual checkpoint here
- Marks the end of a planned training segment (e.g., initial 5 epochs) distinctly from autosaves that may overwrite.
- Ensures you have a stable artifact before experimenting with new hyperparameters, architectures, or data.
- Provides a rollback point if subsequent fine‑tuning or continued training degrades performance.
Reloading later
To resume training or run inference:
checkpoint = torch.load("sydsgpt_345m_trained_model_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Move optimizer state tensors to device (see later reload cell pattern)
After loading, move any optimizer tensors to the target device so continued training doesn’t trigger device mismatch errors.
Differences vs. autosave
| Aspect | Autosave file | Manual final checkpoint |
|---|---|---|
| Trigger | Time/step interval (`checkpoint_interval`) | Explicit user intent (end of phase) |
| Overwrite behavior | Typically overwritten each interval | Usually kept / versioned |
| Semantic meaning | Latest state at periodic step | Milestone and reference artifact |
Consider versioning milestone checkpoints (e.g., append _epoch5.pth, _tokens{count}.pth) if you plan multiple phases.
Best practices
- Keep at least two historical milestone checkpoints in case of accidental corruption.
- Record metadata externally (JSON/YAML): total tokens processed, epoch count, validation loss at save time.
- For inference-only deployment, you can discard `optimizer_state_dict` to reduce file size (just save `model_state_dict`).
- Compress large checkpoints when archiving: by default `torch.save` pickles the data into an uncompressed zip container, so for long-term storage consider an extra `zip` or `tar.gz` pass.
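The versioning and metadata suggestions above can be combined into one helper. This is a minimal sketch, not the author's code: the function name `save_milestone`, the `prefix` default, and the metadata fields are illustrative assumptions. It writes a versioned `.pth` file with both state dicts and a JSON sidecar next to it.

```python
import json
from pathlib import Path

import torch


def save_milestone(model, optimizer, out_dir, epoch, tokens_processed, val_loss,
                   prefix="sydsgpt_345m"):
    """Save a versioned model+optimizer checkpoint plus a JSON metadata sidecar.

    The naming scheme and metadata fields are illustrative assumptions.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{prefix}_epoch{epoch}_tokens{tokens_processed}"
    ckpt_path = out_dir / f"{stem}.pth"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        ckpt_path,
    )
    # Human-readable metadata stored next to the checkpoint
    meta = {
        "checkpoint": ckpt_path.name,
        "epoch": epoch,
        "tokens_processed": tokens_processed,
        "validation_loss": val_loss,
    }
    meta_path = out_dir / f"{stem}.json"
    meta_path.write_text(json.dumps(meta, indent=2))
    return ckpt_path, meta_path
```

Because every milestone gets a distinct filename, a later save can never silently overwrite an earlier known-good snapshot, and the sidecar can be read without loading multi-GB tensors.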
Common pitfalls
| Pitfall | Symptom | Mitigation |
|---|---|---|
| Forgetting optimizer state | Restarted training shows learning instability | Always save optimizer_state_dict during training checkpoints |
| Device mismatch on reload | Runtime error about CPU vs CUDA tensors | After optimizer.load_state_dict, move state tensors to device |
| Overwriting milestone unintentionally | Loss of earlier good snapshot | Use distinct filenames per milestone |
| Large disk usage | Many multi‑GB files | Prune intermediate autosaves; retain only curated milestones |
Minimal variant for inference
If you only need the model weights:
torch.save({"model_state_dict": model.state_dict()}, "sydsgpt_345m_model_only.pth")
This is smaller and faster to load for pure generation tasks.
Summary
This cell captures a durable, milestone checkpoint bundling both model parameters and optimizer state—essential for faithful resumption and reproducible experimentation beyond the initial training phase.
torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, "sydsgpt_345m_trained_model_optimizer.pth")
Reloading a saved checkpoint (model + optimizer)
This section explains the checkpoint reload code directly below. Its purpose is to restore both model weights and optimizer state so you can seamlessly continue training or run inference after a previous session ended.
What the code does (line by line)
- Picks a device: `cuda` if available, else `cpu`.
- Prints the chosen device for visibility.
- Instantiates a fresh `SydsGPT` model with the same config used during training.
- `torch.load(..., map_location=device)` loads the serialized checkpoint dict from disk and maps all tensors onto the selected device (or the CPU if no GPU is available).
- `model.load_state_dict(checkpoint['model_state_dict'])` populates the newly created model with the trained parameters.
- Creates a new `AdamW` optimizer with the same hyperparameters (LR, weight decay) used earlier.
- `optimizer.load_state_dict(checkpoint['optimizer_state_dict'])` restores the optimizer's internal buffers (moment estimates, step counters) for continuity.
- Iterates through every tensor in the optimizer state and moves it to `device`. `optimizer.load_state_dict` places state tensors on the device of the matching parameters, and at that point the model is still on the CPU; without this loop the optimizer state would stay on the CPU while the model moves to the GPU, causing a device mismatch at the next `optimizer.step()`.
- Moves the model itself to `device` (`model.to(device)`).
Why instantiate a new model before loading
You need an object with the correct architecture to receive parameters. The state_dict only contains raw tensors keyed by module names; without constructing the model first, there’s nowhere to load them.
Optimizer state importance
Restoring optimizer buffers keeps training dynamics smooth:
- AdamW uses first (`exp_avg`) and second (`exp_avg_sq`) moment estimates to adapt per‑parameter learning rates.
- Omitting them causes a brief instability phase while the moments re‑warm.
- Preserving the `step` counter keeps AdamW's bias correction consistent. A learning-rate scheduler, if added later, tracks its own progress in its own `state_dict()`, which must be saved and restored separately.
Device handling details
- `map_location=device` ensures checkpoint tensors don't try to allocate on an unavailable GPU.
- Move optimizer state tensors explicitly after `load_state_dict` (or move the model to `device` before creating the optimizer): `optimizer.load_state_dict` places each state tensor on the device of its parameter at load time, regardless of `map_location`.
- If later using multiple GPUs or `DistributedDataParallel`, load on CPU first, then wrap/replicate.
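The reload pattern can be packaged as a reusable helper. This is a sketch under assumptions (the names `optimizer_state_to` and `restore_training_state` are hypothetical): it moves the model to `device` before building the optimizer, so `load_state_dict` already places the moments next to the parameters, and it keeps the explicit move loop from the original cell as a safeguard.

```python
import torch
import torch.nn as nn


def optimizer_state_to(optimizer, device):
    """Move every tensor in an optimizer's per-parameter state to `device`."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)


def restore_training_state(model, checkpoint_path, device, lr=2e-4, weight_decay=0.05):
    """Load model + AdamW state; the model is moved to `device` before the optimizer is built."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)  # parameters now live on `device`
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Restores buffers and the saved param_groups (so lr/weight_decay come from the checkpoint)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    optimizer_state_to(optimizer, device)  # same safeguard as the loop in the reload cell
    return optimizer
```

In the notebook this would be called as `optimizer = restore_training_state(SydsGPT(SYDSGPT_CONFIG_345M), ...)` with the model kept in a variable; note that because `load_state_dict` restores the saved parameter groups, the `lr` passed here is overwritten by the checkpoint's value.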
Minimal inference-only variant
If you only need text generation (no further training):
model = SydsGPT(SYDSGPT_CONFIG_345M)
checkpoint = torch.load("sydsgpt_345m_trained_model_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
# Skip optimizer entirely
This loads faster and uses less memory.
Summary
This reload cell reconstructs the training state (model + optimizer) on the chosen device, enabling seamless continuation or evaluation. Proper device mapping and optimizer state preservation prevent subtle training regressions and runtime errors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = SydsGPT(SYDSGPT_CONFIG_345M)
checkpoint = torch.load("sydsgpt_345m_trained_model_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.0002, weight_decay = 0.05)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
model.to(device)
Using device: cuda
SydsGPT(
(token_embedding): Embedding(50257, 1024)
(position_embedding): Embedding(512, 1024)
(drop_embedding): Dropout(p=0.1, inplace=False)
(transformer_blocks): Sequential(
(0): TransformerBlock(
(attention): MultiHeadAttention(
(weight_query): Linear(in_features=1024, out_features=1024, bias=False)
(weight_key): Linear(in_features=1024, out_features=1024, bias=False)
(weight_value): Linear(in_features=1024, out_features=1024, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(output_projection): Linear(in_features=1024, out_features=1024, bias=True)
)
...
(final_layer_norm): LayerNorm()
(output_projection): Linear(in_features=1024, out_features=50257, bias=False)
)
Post‑reload qualitative sanity check (generate_sample_text call)
This cell performs a quick qualitative verification immediately after restoring the model and optimizer from a checkpoint. It doesn't train or modify weights; it simply generates a short continuation from a prompt to confirm the load succeeded and the model produces coherent text.
Purpose
- Validate that `model.load_state_dict(...)` correctly restored parameters.
- Confirm the model is on the intended `device` (no silent CPU/GPU mismatch).
- Provide a human‑readable signal (sample text) before investing time in continued training.
What happens
- Calls `generate_sample_text(model, tokenizer, device, "once upon a time")`.
- Inside the helper:
  - Switches to `eval()` (disables dropout) and wraps generation in `no_grad()`.
  - Encodes the prompt, performs greedy extension (100 new tokens, as set in the helper), and decodes the output.
  - Prints the generated string (newlines flattened) and returns the model to `train()` mode.
Why a qualitative sample here
- Faster than computing a full validation loss (no dataloader iteration).
- Immediately surfaces obvious load issues (garbled or purely random tokens vs. plausible language).
- Lets you compare style with samples produced before saving the checkpoint.
Interpreting output
| Observation | Likely Meaning | Action |
|---|---|---|
| Fluent continuation resembling earlier runs | Successful checkpoint restore | Proceed to continued training / evaluation |
| Completely random / high entropy gibberish | Wrong weights file or failed load | Recheck filename, config mismatch, or state_dict keys |
| Runtime error about device mismatch | Some tensors still on CPU | Ensure optimizer state + model moved with .to(device) earlier |
| Identical text every run (expected here) | Deterministic greedy decoding | Introduce temperature/top‑k if diversity needed |
Customizing
- Change the prompt: "In a distant future", "Chapter 1:", or a domain‑specific phrase.
- Shorter test: modify the helper to use fewer new tokens (e.g., 40) for speed.
- Add timing: wrap the call with `time.perf_counter()` to gauge generation latency.
Summary
A lightweight, single‑call checkpoint verification: generate a deterministic sample to ensure the restored model behaves plausibly before resuming expensive training.
generate_sample_text(model, tokenizer, device, "once upon a time")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 1
----> 1 generate_sample_text(model, tokenizer, device, "once upon a time")

Cell In[7], line 6, in generate_sample_text(model, tokenizer, device, start_context)
      4 input_tokens = text_to_tokens(start_context, tokenizer).to(device)
      5 with torch.no_grad():
----> 6     generated_tokens = generate_simple(model, input_tokens, 100, context_size)
      7 generated_text = tokens_to_text(generated_tokens, tokenizer)
      8 print(f"Generated Text: {generated_text}".replace("\n", " "))

File e:\Code\SydsGPT-Pretraining\modules\GenerateSimple.py:7, in generate_simple(model, input_ids, max_length, context_size)
      5 input_ids_crop = input_ids[:, -context_size:]
      6 with torch.no_grad():
----> 7     logits = model(input_ids_crop)
      8 next_token_logits = logits[:, -1, :]
      9 next_token_probs = torch.softmax(next_token_logits, dim = -1)

File e:\Code\SydsGPT-Pretraining\model\SydsGPT.py:18, in SydsGPT.forward(self, input)
     16 def forward(self, input):
     17     batch_size, seq_length = input.shape
---> 18     token_embeddings = self.token_embedding(input)
     19     position_embeddings = self.position_embedding(torch.arange(seq_length, device=input.device))
     20     x = token_embeddings + position_embeddings

...  (torch.nn.Module._call_impl, Embedding.forward, and F.embedding frames)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
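The traceback ends in a device mismatch inside the token-embedding lookup: the embedding weights and the token IDs were on different devices (CPU and `cuda:0`). The notebook output does not show which of the two was misplaced, so a reasonable first step is to print where the model's parameters and the prompt tokens actually live. A minimal sketch with a hypothetical helper `report_devices`; in the notebook it could be called as `report_devices(model, text_to_tokens("once upon a time", tokenizer).to(device))`.

```python
import torch
import torch.nn as nn


def report_devices(model: nn.Module, *tensors: torch.Tensor) -> set:
    """Print and return the set of devices used by a model's parameters/buffers and some inputs."""
    model_devices = {p.device for p in model.parameters()} | {b.device for b in model.buffers()}
    input_devices = {t.device for t in tensors}
    print(f"model parameters/buffers on: {sorted(str(d) for d in model_devices)}")
    print(f"inputs on: {sorted(str(d) for d in input_devices)}")
    return model_devices | input_devices


# Example: a small embedding layer and an index tensor, both created on the CPU
emb = nn.Embedding(10, 4)
ids = torch.tensor([[1, 2, 3]])
devices = report_devices(emb, ids)
if len(devices) > 1:
    print("Device mismatch: move the model and the inputs to one device before calling the model")
```

If more than one device shows up, moving the model (`model.to(device)`) and the input tokens (`.to(device)`) to the same device before the call resolves this class of error.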
Continue training from a saved checkpoint (resume run)
This cell resumes optimization from a previously saved checkpoint. It reconstructs the model and optimizer states on the chosen device, then runs a shorter training segment with periodic evaluation and autosaves.
What the code does
- Selects a device (`cuda` if available, else `cpu`) and prints it.
- Instantiates a fresh `SydsGPT` with the same configuration used earlier.
- Loads the checkpoint dictionary from disk with `map_location=device` so tensors land on the intended device even without a GPU.
- Restores weights: `model.load_state_dict(checkpoint["model_state_dict"])`.
- Creates an `AdamW` optimizer and restores its state with `optimizer.load_state_dict(...)` (includes moment buffers and step counters).
- Ensures all optimizer state tensors live on `device` (prevents CPU/GPU mismatch during `optimizer.step()`).
- Moves the `model` to `device` and starts training for `num_epochs = 2` via `train_model_v1` with quick evaluations.
Resume semantics and counters
- Optimizer continuity: AdamW internal buffers (`exp_avg`, `exp_avg_sq`) and `state['step']` are restored, preserving learning dynamics.
- Training loop counters: `train_model_v1` initializes a fresh `global_step = -1` and `tokens_processed = 0` each time you call it. That means:
  - Evaluation/checkpoint cadence restarts from step 0 for this resumed segment.
  - If you need a global step across sessions, extend the function to accept and persist `global_step`/`tokens_processed`.
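One way to persist those counters, sketched under the assumption that `train_model_v1` would be extended to accept starting values (it does not today): store them in the checkpoint dict next to the two state dicts, and read them back with defaults that match a fresh start (`-1` and `0`). The helper names and the extra keys are hypothetical.

```python
import torch


def save_resumable(path, model, optimizer, global_step, tokens_processed):
    """Save model/optimizer state plus loop counters so a later run can continue them."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step,            # hypothetical extra keys
            "tokens_processed": tokens_processed,
        },
        path,
    )


def load_counters(checkpoint):
    """Return (global_step, tokens_processed); older checkpoints fall back to a fresh start."""
    return checkpoint.get("global_step", -1), checkpoint.get("tokens_processed", 0)
```

Using `dict.get` with defaults keeps the reader compatible with checkpoints saved before the counters were added, such as `sydsgpt_345m_trained_model_optimizer.pth`.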
Evaluation and checkpoint cadence
- `evaluation_frequency = 100`: samples losses every 100 updates using only `evaluation_iterations = 2` mini‑batches for speed.
- Autosave: every `checkpoint_interval` steps (default set inside `train_model_v1`), an autosave writes to `autosave_sydsgpt_345m_trained_model_optimizer.pth`.
  - Note: This filename may overwrite a prior autosave. Use timestamped or step‑indexed names if you want a trail of snapshots.
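Tuning knobs
- `num_epochs`: increase for longer continuation runs.
- `lr` and `weight_decay`: keep consistent with the original run for stability; adjust cautiously if loss plateaus. Note that `optimizer.load_state_dict` restores the saved values, so a new value has to be set on `optimizer.param_groups` after loading rather than in the `AdamW(...)` constructor.
- `evaluation_iterations`: raise to improve estimate accuracy (at the cost of extra compute).
- Consider enabling mixed precision (AMP) in `train_model_v1` for speed on supported GPUs.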
Assumptions and safety checks
- Model config matches the checkpoint (same `embedding_dim`, `num_layers`, `num_heads`, `vocab_size`, `context_length`).
- Checkpoint path `sydsgpt_345m_trained_model_optimizer.pth` exists in the working directory.
- Tokenizer and dataloaders are unchanged from the original training setup.
Quick verification
- Watch early training log lines for decreasing validation loss compared to just‑restored values.
- Generate a short sample at each epoch end (already handled via `generate_sample_text` inside `train_model_v1`).
Outputs
- `training_losses`, `validation_losses`: sampled losses at the chosen evaluation cadence.
- `total_tokens_processed`: cumulative tokens (within this resumed segment) for plotting progress.
Summary
This cell cleanly restores model and optimizer state, then continues training with lightweight periodic evaluations and autosaves—ideal for incremental improvements without losing prior optimization momentum.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = SydsGPT(SYDSGPT_CONFIG_345M)
checkpoint = torch.load("sydsgpt_345m_trained_model_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.0002, weight_decay = 0.05)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
model.to(device)
num_epochs = 2
training_losses, validation_losses, total_tokens_processed = train_model_v1(
model,
training_dataloader,
validation_dataloader,
optimizer,
device,
num_epochs,
evaluation_frequency = 100,
evaluation_iterations = 2,
start_context = "Once upon a time",
tokenizer = tokenizer
)
Using device: cuda
Epoch 1, Step 0: Tokens Processed = 4096
Epoch 1, Step 0: Training Loss = 2.4921298027038574, Validation Loss = 5.017168760299683, Tokens Processed = 4096
Epoch 1, Step 1: Tokens Processed = 8192
...
Epoch 1, Step 1248: Tokens Processed = 5115904
Generated Text: Once upon a time, and a tax of apprenticeship, a tax of three or four shillings and a-half, imposed upon the wages of the land tax, which, in the time of time, is said to be the most unequal. The tythe, or of the most other taxes, is not, in the highest degree, of the most
Manual generation cell (direct greedy sampling after training)
This cell performs a one‑off text generation directly with the trained model (outside the wrapped helper) to give a longer qualitative sample (200 new tokens) from the prompt “once upon a time”.
What the code does (step by step)
- `model.eval()` switches the model to evaluation mode:
  - Disables dropout.
  - Ensures deterministic output given fixed weights and input.
- `generate_simple(...)` is called with:
  - The current `model`.
  - A tokenized prompt (conversion handled by `text_to_tokens`).
  - `200` as the number of new tokens (the `max_length` parameter of `generate_simple`) for a longer continuation.
  - The configured `context_length` (prevents exceeding the model's positional window).
- The returned tensor of token IDs is decoded back to human‑readable text via `tokens_to_text`.
- The final string is printed for inspection.
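The full source of `generate_simple` is not shown here, but the traceback earlier reveals its core: crop the input to the last `context_size` tokens, run the model, and take the logits at the last position. A greedy loop along those lines might look like the sketch below (`greedy_generate` is a hypothetical name, not the author's function). Softmax is skipped because it does not change which token has the largest logit.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens, context_size):
    """Greedy decoding: append the argmax token `max_new_tokens` times.

    input_ids: LongTensor of shape (batch, seq_len). Only the last `context_size`
    tokens are fed to the model, but the returned sequence keeps growing.
    """
    model.eval()  # disable dropout
    for _ in range(max_new_tokens):
        window = input_ids[:, -context_size:]                       # stay inside the positional window
        logits = model(window)                                       # (batch, window_len, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch, 1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids
```

With SydsGPT this would be called as `greedy_generate(model, text_to_tokens("once upon a time", tokenizer).to(device), 200, SYDSGPT_CONFIG_345M["context_length"])`; the result has prompt length + 200 tokens.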
Purpose
- Obtain a longer, standalone sample than the shorter (100-token) helper output.
- Compare stylistic/coherence evolution after training or fine‑tuning phases.
- Serve as a quick qualitative regression test if later modifications are made to generation logic.
Inputs and outputs
| Item | Description |
|---|---|
| Prompt | “once upon a time” (lowercase variant) |
| `max_length` argument | 200 new tokens; the output is prompt length + 200 tokens long, and only the window fed to the model is cropped to `context_length` |
| Output tokens | Tensor of token IDs (a batch of one sequence) after greedy extension |
| Printed text | Decoded English continuation for human evaluation |
Greedy decoding characteristics
- Always selects argmax token; produces deterministic, lower‑entropy completions.
- May become repetitive over very long stretches (lack of sampling diversity). For more natural variation consider top‑k, nucleus (top‑p), or temperature sampling.
Adjustments you can make
| Change | Effect |
|---|---|
| Increase the new-token count (`max_length`) | Longer stories / more context, higher runtime, more memory usage |
| Shorter new-token count (e.g. 64) | Faster iteration / quick spot checks |
| Different prompt casing / content | Alters stylistic bias of continuation |
After running
You can proceed to additional evaluation, start continuation training, or experiment with alternative decoding strategies for more creative outputs.
Summary
A direct, transparent greedy generation producing a longer sample from a fixed prompt—ideal for quick qualitative assessment of model fluency and coherence after training.
model.eval()
output_tokens = generate_simple(model, text_to_tokens("once upon a time", tokenizer).to(device), 200, SYDSGPT_CONFIG_345M['context_length'])
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text: {output_text}")
Output Text: once upon a time before she was so busy, that I felt quite sure that I felt quite sure that I was not quite sure that I felt it was. “You are a child,” said I, “that you are a beautiful woman, and you are a beautiful woman.” “Yes,” said Ada, “that there is nothing in it.” “That is,” said my guardian, “that there is nothing else that makes it so.” “That is not,” said my guardian, “that there is such a time as you are.” “You are not to be removed,” said my guardian, “that there is no considerable answer.” “You are not to be always happy,” said my guardian, “that there is something of the kind
Token sampling demo: greedy vs. multinomial and empirical frequency
This cell illustrates fundamental next‑token selection strategies on a tiny handcrafted vocabulary. It shows how raw logits become probabilities, how greedy decoding picks the argmax, how stochastic sampling (torch.multinomial) introduces diversity, and how repeated sampling approximates the underlying probability distribution.
What the code builds
- `example_vocab`: A small mapping from token strings to integer IDs (10 tokens). Acts like a miniature language model vocabulary.
- `inverse_example_vocab`: Reverse lookup so we can convert sampled IDs back to readable tokens.
- `example_next_token_logits`: A 1‑D tensor of unnormalized scores (logits) for each vocabulary entry.
- `example_next_token_probs = torch.softmax(logits, dim=0)`: Converts logits to a valid probability distribution (non‑negative, sums to 1).
- `example_greedy_next_token = torch.argmax(...)`: Deterministic choice of the highest‑probability token.
- `torch.multinomial(example_next_token_probs, num_samples=1)`: Draws one token according to the categorical distribution (probabilistic sampling).
- `get_sampled_tokens(...)`: Repeats multinomial sampling 1000 times to empirically estimate how often each token is chosen; prints frequencies.
Why this matters
Language models typically output logits over a large vocabulary each step. Generation quality depends heavily on how you transform these logits into a selected next token:
- Greedy decoding maximizes immediate probability but can produce repetitive or bland sequences.
- Stochastic sampling (optionally with temperature/top‑k/top‑p) encourages diversity and can yield more creative or natural continuations.
- Empirical sampling frequency (many draws) converges toward the theoretical probability distribution—helpful for intuition.
Key concepts demonstrated
| Concept | Shown By | Notes |
|---|---|---|
| Logits vs. probabilities | `softmax` call | Softmax rescales differences; large positive logits dominate. |
| Greedy selection | `torch.argmax` | Always the same token given identical logits. |
| Random sampling | `torch.multinomial` | Draws proportionally to probability mass; requires non‑negative weights (they need not sum to 1, `torch.multinomial` normalizes them). |
| Empirical distribution | 1000 repeats + `bincount` | Frequencies approximate probabilities; more samples → tighter convergence (Law of Large Numbers). |
Interpreting results
- The printed “Greedy Next Token” is the single most likely token (highest probability after softmax). In this example it corresponds to the largest logit (here `before`, with logit 3.63).
- “Random Next Token” may or may not match the greedy token; it depends on a single draw from the distribution.
- Frequency table: Tokens with higher true probabilities appear more often; rare tokens may still show up occasionally, demonstrating stochastic exploration.
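The frequency table stops at `lived` because `torch.bincount` only returns counts up to the largest index that was actually sampled. The sketch below (the function name `sample_frequencies` is hypothetical) draws all samples in one vectorized `torch.multinomial` call with `replacement=True`, passes `minlength` so never-sampled tokens still get a zero count, and compares the empirical frequencies with the softmax probabilities.

```python
import torch


def sample_frequencies(probs, num_draws=1000, seed=None):
    """Draw `num_draws` samples in one call and count them per token index.

    `minlength` keeps zero counts for tokens that were never drawn, so the
    result always has one entry per vocabulary item.
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    draws = torch.multinomial(probs, num_samples=num_draws, replacement=True, generator=generator)
    return torch.bincount(draws, minlength=probs.numel())


logits = torch.tensor([1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38])
probs = torch.softmax(logits, dim=0)
counts = sample_frequencies(probs, num_draws=10_000, seed=246)
for idx, (count, p) in enumerate(zip(counts.tolist(), probs.tolist())):
    print(f"token {idx}: empirical {count / 10_000:.3f} vs. softmax {p:.3f}")
```

Using 10,000 draws instead of 1,000 shrinks the sampling noise by roughly a factor of three (it scales with one over the square root of the number of draws).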
Quick experiment ideas
Try changing one logit dramatically (e.g., raise after from -5.38 to 2.5) and re-run: watch probability mass shift and empirical frequencies respond proportionally.
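The suggested experiment takes only a few lines. This sketch recreates the ten-token vocabulary as a plain list (an assumption for self-containment), raises the logit for `after` from -5.38 to 2.5, and prints the probabilities before and after.

```python
import torch

vocab = ["once", "upon", "a", "time", "before", "she", "lived", "happily", "ever", "after"]
logits = torch.tensor([1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38])

modified = logits.clone()
modified[vocab.index("after")] = 2.5  # raise "after" from -5.38 to 2.5

old_probs = torch.softmax(logits, dim=0)
new_probs = torch.softmax(modified, dim=0)
for token, p_old, p_new in zip(vocab, old_probs.tolist(), new_probs.tolist()):
    print(f"{token:>8}: {p_old:.4f} -> {p_new:.4f}")
```

Because softmax divides by a shared sum, every other token's probability shrinks by the same factor; their relative odds are unchanged, and `before` remains the greedy choice because its logit is still the largest.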
Summary
This micro‑example demystifies next‑token selection: logits → probabilities → deterministic (greedy) vs. stochastic (multinomial) choice, with repeated sampling revealing the underlying distribution. It’s a conceptual foundation for more advanced decoding strategies (top‑k, top‑p, temperature) used in full language model generation.
example_vocab = {
"once" : 0,
"upon" : 1,
"a" : 2,
"time" : 3,
"before" : 4,
"she" : 5,
"lived" : 6,
"happily" : 7,
"ever" : 8,
"after" : 9
}
inverse_example_vocab = {v: k for k, v in example_vocab.items()}
example_next_token_logits = torch.tensor([1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38])
example_next_token_probs = torch.softmax(example_next_token_logits, dim = 0)
example_greedy_next_token = torch.argmax(example_next_token_probs).item()
print(f"Greedy Next Token: {inverse_example_vocab[example_greedy_next_token]}")
torch.manual_seed(246)
example_random_next_token = torch.multinomial(example_next_token_probs, num_samples = 1).item()
print(f"Random Next Token: {inverse_example_vocab[example_random_next_token]}")
def get_sampled_tokens(probs):
sampled_token = [torch.multinomial(probs, num_samples = 1).item() for i in range(1000)]
sampled_tokens = torch.bincount(torch.tensor(sampled_token))
for i, frequency in enumerate(sampled_tokens):
print(f"Token: {inverse_example_vocab[i]}: {frequency.item()} times")
get_sampled_tokens(example_next_token_probs)
Greedy Next Token: before
Random Next Token: once
Token: once: 68 times
Token: upon: 103 times
Token: a: 83 times
Token: time: 16 times
Token: before: 723 times
Token: she: 4 times
Token: lived: 3 times
Temperature scaling demo: controlling randomness in sampling
This cell explores how the softmax temperature T reshapes a probability distribution and affects stochastic next‑token sampling. Using the same logits as the previous mini‑vocabulary example, it prints empirical frequencies at multiple temperatures to show how generation becomes more/less diverse.
What the code does
- Defines `softmax_with_temperature(logits, temperature)`:
  - Scales logits by 1/T and applies softmax.
  - For a vector of logits z, probabilities are:
    p_i(T) = exp(z_i / T) / Σ_j exp(z_j / T)
- Iterates over `temperatures = [0.1, 0.5, 1.0, 2.0]`:
  - Computes `temperature_scaled_probs` for each T.
  - Calls `get_sampled_tokens(...)` (from the previous cell) to draw 1000 samples and print a frequency table for each T.
How temperature changes behavior
- T < 1.0 (e.g., 0.5, 0.1): sharpens the distribution
  - Increases contrast between high‑ and low‑probability tokens.
  - Sampling becomes more deterministic; top tokens dominate frequency counts.
- T = 1.0: baseline distribution (no scaling)
  - Frequencies reflect the original softmax over `example_next_token_logits`.
- T > 1.0 (e.g., 2.0): flattens the distribution
  - Reduces differences between tokens.
  - Increases diversity; lower‑probability tokens appear more often.
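Sampling counts are noisy, so a deterministic way to see the sharpening/flattening is to measure the Shannon entropy of the temperature-scaled distribution directly. This sketch reuses the same logits; `scaled_softmax` mirrors the cell's `softmax_with_temperature`, and `entropy_bits` is a hypothetical helper. Entropy rises with T toward the uniform upper bound of log2(10) ≈ 3.32 bits.

```python
import torch


def scaled_softmax(logits, temperature):
    # Same scaling as softmax_with_temperature: divide logits by T, then softmax
    return torch.softmax(logits / temperature, dim=0)


def entropy_bits(probs):
    # Shannon entropy in bits; zero-probability entries contribute nothing
    nonzero = probs[probs > 0]
    return -(nonzero * torch.log2(nonzero)).sum().item()


logits = torch.tensor([1.35, 1.86, 1.53, 0.17, 3.63, -1.82, -2.17, -3.90, -4.85, -5.38])
results = []
for T in [0.1, 0.5, 1.0, 2.0, 10.0]:
    probs = scaled_softmax(logits, T)
    results.append((T, probs.max().item(), entropy_bits(probs)))
    print(f"T={T:>4}: max prob {probs.max().item():.3f}, entropy {entropy_bits(probs):.3f} bits")
print(f"uniform upper bound: {torch.log2(torch.tensor(10.0)).item():.3f} bits")
```

Low entropy (near 0 bits at T = 0.1) means sampling is essentially greedy; entropy close to the upper bound means almost every token is equally likely.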
Interpreting the output
- For each T, you’ll see counts for each token over 1000 draws.
- As T decreases, the most likely token (greedy token) should dominate the histogram.
- As T increases, frequencies spread toward a more uniform distribution.
- Because sampling is stochastic, exact numbers vary between runs; larger sample sizes (e.g., 10k) reduce variance and better reveal trends.
Practical guidance for generation
- Start with T in the 0.7–1.2 range; adjust based on desired creativity vs. factuality.
- Keep T consistent across steps within a single generation unless experimenting with annealing strategies.
Summary
Temperature scaling is a simple, powerful knob for controlling randomness in language model sampling: low T → precise and repetitive; high T → diverse and creative. Use it alongside top‑k/top‑p for fine‑grained control over generation quality.
def softmax_with_temperature(logits, temperature):
scaled_logits = logits / temperature
probs = torch.softmax(scaled_logits, dim = 0)
return probs
temperatures = [0.1, 0.5, 1.0, 2.0]
for temp in temperatures:
temperature_scaled_probs = softmax_with_temperature(example_next_token_logits, temp)
print(f"\n Temperature: {temp}")
get_sampled_tokens(temperature_scaled_probs)
Temperature: 0.1
Token: once: 0 times
Token: upon: 0 times
Token: a: 0 times
Token: time: 0 times
Token: before: 1000 times

Temperature: 0.5
Token: once: 15 times
Token: upon: 34 times
Token: a: 18 times
Token: time: 3 times
Token: before: 930 times

Temperature: 1.0
Token: once: 75 times
Token: upon: 123 times
Token: a: 82 times
Token: time: 24 times
Token: before: 690 times
Token: she: 4 times
Token: lived: 1 times
Token: happily: 1 times

Temperature: 2.0
Token: once: 113 times
Token: upon: 159 times
Token: a: 131 times
Token: time: 78 times
Token: before: 447 times
Token: she: 25 times
Token: lived: 22 times
Token: happily: 13 times
Token: ever: 8 times
Token: after: 4 times
Top-k filtering demo: restricting the candidate set before sampling
This cell demonstrates a classic decoding refinement: top-k sampling. Instead of sampling from the full vocabulary distribution, we keep only the k highest‑logit tokens, mask the rest to negative infinity (so their post‑softmax probability becomes zero), then sample within that reduced set. This balances diversity and relevance.
What the code does (step by step)
- `top_k = 4`: Choose how many highest‑scoring tokens to retain.
- `torch.topk(example_next_token_logits, top_k)` returns:
  - `top_k_logits`: The 4 largest logits in descending order.
  - `top_k_indices`: Their original token indices.
- `torch.where(example_next_token_logits < top_k_logits[-1], -inf, example_next_token_logits)`:
  - Finds the cutoff logit (the smallest logit among the top-k set: `top_k_logits[-1]`).
  - Replaces any logit below that cutoff with `-inf`, effectively zeroing its probability after softmax.
- `top_k_probs = torch.softmax(new_logits, dim=0)`: Computes normalized probabilities over only the surviving top-k logits (others become exactly 0 probability).
- `get_sampled_tokens(top_k_probs)`: Samples repeatedly (1000 draws) and prints frequencies among the retained tokens.
Why top-k filtering
| Motivation | Benefit |
|---|---|
| Remove low-probability tail | Reduces chance of bizarre / out-of-context tokens |
| Retain diversity among strong candidates | Allows exploration beyond pure greedy argmax |
| Computational simplicity | Easy to implement; single topk + masking step |
Compared to pure greedy decoding, top-k can produce more varied yet still on-topic continuations. Compared to temperature-only scaling, it imposes a hard boundary on candidate tokens, preventing very low-probability choices even at higher temperatures.
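Alternative masking approaches
- Direct index filtering: apply softmax to the k values in top_k_logits, sample a position within that k-element array, and map it back to a vocabulary ID through top_k_indices. The resulting probabilities are identical to those obtained by masking.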
- Top-p (nucleus) sampling: Instead of a fixed k, pick the smallest set of tokens whose cumulative probability ≥ p (e.g., 0.9). Adapts dynamically to distribution shape.
- Temperature + top-k: Sharpen or flatten within the retained set for fine-grained control.
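To make the first two alternatives concrete, here is a minimal, self-contained sketch. The helper names sample_top_k_direct and top_p_filter and the logit values are hypothetical (they are not part of SydsGPT); top_p_filter follows the common sort, cumulative-sum, and scatter-back pattern for nucleus filtering.

```python
import torch

def sample_top_k_direct(logits: torch.Tensor, k: int, generator=None) -> int:
    """Hypothetical helper: sample among the k largest logits, then map back to a vocab ID."""
    top_logits, top_indices = torch.topk(logits, k)        # both shaped (k,)
    probs = torch.softmax(top_logits, dim=-1)              # normalized over the k survivors only
    position = torch.multinomial(probs, num_samples=1, generator=generator)  # index into the top-k
    return top_indices[position].item()                    # original vocabulary ID

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Hypothetical helper: keep the smallest set of tokens whose cumulative probability is >= p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    # Probability mass of all higher-ranked tokens; once it reaches p, this token is not needed
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    remove_sorted = mass_before >= p                        # the top token always has mass_before == 0
    # Scatter the keep/remove decision from sorted order back to vocabulary order
    remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
    return logits.masked_fill(remove, float("-inf"))

# Hypothetical logits for a 5-token vocabulary
logits = torch.tensor([0.5, 2.0, -1.0, 1.0, 0.1])
print(sample_top_k_direct(logits, k=2))                    # always 1 or 3 (the two largest logits)
nucleus_probs = torch.softmax(top_p_filter(logits, p=0.9), dim=-1)
print(nucleus_probs)
```

Because top_p_filter returns masked logits rather than probabilities, it could sit in the same place as the top-k mask, before temperature scaling and softmax.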
Numerical details
- Using -inf (negative infinity) ensures exp(-inf) = 0 in softmax, producing exact zeros without manual renormalization.
- With large tensors, broadcasting and in-place operations (for example, masked_fill_) can reduce memory pressure.
- Always mask before computing softmax; masking after softmax requires manual renormalization.
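A quick numerical check of the last two points, using hypothetical logits: masking to -inf before softmax gives the same distribution as zeroing the tail after softmax and then renormalizing by hand.

```python
import torch

# Hypothetical logits for a 5-token vocabulary
logits = torch.tensor([0.5, 2.0, -1.0, 1.0, 0.1])
k = 2
cutoff = torch.topk(logits, k).values[-1]    # k-th largest logit
tail = logits < cutoff                       # tokens outside the top-k

# Mask first, then softmax: masked entries become exact zeros and the rest already sum to 1
masked_first = torch.softmax(logits.masked_fill(tail, float("-inf")), dim=-1)

# Softmax first, then zero the tail: the survivors no longer sum to 1 ...
zeroed = torch.softmax(logits, dim=-1).masked_fill(tail, 0.0)
# ... so they have to be renormalized manually
renormalized = zeroed / zeroed.sum()

print(masked_first)
print(torch.allclose(masked_first, renormalized))  # True
```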
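Practical tips
- Typical values: k between 20 and 50 is a common starting point for LLM text generation, with k = 40 a frequently used default.
- Smaller k increases focus but risks repetition; larger k adds diversity but lets more low-probability tokens through.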
Summary
Top-k filtering discards the tail of the probability distribution, limiting sampling to the k most promising tokens. It offers a simple way to balance diversity and coherence (the cutoff itself is deterministic; only the draw among the surviving tokens is random), and it serves as a building block for more advanced decoding strategies used in modern language model generation.
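top_k = 4
# The k largest logits (in descending order) and their vocabulary indices
top_k_logits, top_k_indices = torch.topk(example_next_token_logits, top_k)
print(f"Top-{top_k} Indices: {top_k_indices}")
print(f"Top-{top_k} Logits: {top_k_logits}")
# Mask every logit below the k-th largest (top_k_logits[-1]) to -inf
new_logits = torch.where(
    condition = example_next_token_logits < top_k_logits[-1],
    input = torch.tensor(float('-inf')),
    other = example_next_token_logits
)
print(f"New Logits after Top-{top_k} filtering: {new_logits}")
# Softmax over the surviving logits; masked entries get probability 0
top_k_probs = torch.softmax(new_logits, dim = 0)
print(f"Top-{top_k} Probabilities: {top_k_probs}")
# Draw 1000 samples and report how often each token was picked
get_sampled_tokens(top_k_probs)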
Top-4 Indices: tensor([4, 1, 2, 0])
Top-4 Logits: tensor([3.6300, 1.8600, 1.5300, 1.3500])
New Logits after Top-4 filtering: tensor([1.3500, 1.8600, 1.5300, -inf, 3.6300, -inf, -inf, -inf, -inf,
-inf])
Top-4 Probabilities: tensor([0.0733, 0.1221, 0.0878, 0.0000, 0.7168, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000])
Token: once: 62 times
Token: upon: 120 times
Token: a: 92 times
Token: time: 0 times
Token: before: 726 times
Advanced generate function: temperature, top‑k, and EOS handling
This cell defines a more flexible text generation routine that supports temperature scaling, optional top‑k filtering, greedy fall‑back, and early stopping on an end‑of‑sequence (EOS) token. It also enforces the model’s context_size window by truncating the input to the most recent tokens at each step.
Function signature
def generate(model, input_tokens, max_new_tokens, context_size,
temperature=1.0, top_k=None, eos_id=None):
...
Parameters
- model (nn.Module): autoregressive language model producing logits over the vocabulary.
- input_tokens (LongTensor, shape (1, T0)): seed prompt. The function returns a new tensor with the generated tokens appended; the caller's tensor is not modified in place.
- max_new_tokens (int): maximum number of tokens to generate.
- context_size (int): maximum sequence length the model attends to; older tokens beyond this window are truncated.
- temperature (float, default 1.0):
  - Greater than 0: scale logits by 1/temperature and sample via torch.multinomial.
  - Equal to 0: use greedy argmax (deterministic).
- top_k (int | None): if set, keep only the top-k logits per step (mask the others to -inf) before decoding.
- eos_id (int | None): if set, stop generation early when this token is selected.
Returns
LongTensor of shape (1, T0 + N), where 0 ≤ N ≤ max_new_tokens (generation may stop early on eos_id).
Step‑by‑step flow
- Loop max_new_tokens times.
- Slice the active context with input_context = input_tokens[:, -context_size:] so the input never exceeds the model's positional limit.
- Run the forward pass under torch.no_grad() and keep only the last-step logits: logits = logits[:, -1, :].
- Optional top-k filtering:
  - top_k_logits, _ = torch.topk(logits, top_k) gets the k largest logits for each example.
  - min_top_k_logit, taken from the last column of top_k_logits, is the k-th largest logit per example (the cutoff).
  - All logits below the cutoff are masked to -inf using torch.where (with the -inf value on the same device and dtype as logits).
- Temperature / decoding mode:
  - If temperature > 0: scale the logits by 1/temperature, compute probs = softmax(logits, dim=-1), and sample next_token = torch.multinomial(probs, num_samples=1).
  - Otherwise (temperature == 0): decode greedily with next_token = torch.argmax(logits, dim=-1, keepdim=True).
- Early stop: if eos_id is provided and the selected token equals eos_id, break (the EOS token itself is not appended).
- Append: input_tokens = torch.cat((input_tokens, next_token), dim=1).
- Return the extended token sequence.
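The top-k step relies on the cutoff broadcasting correctly against the (batch, vocab) logits. The sketch below uses a hypothetical top_k_mask helper (with masked_fill swapped in for torch.where) and keeps the cutoff shaped (batch, 1) via top_k_logits[:, -1:], so each row is compared with its own k-th largest logit. With a cutoff shaped (batch,), the comparison lines the batch dimension up against the vocab dimension: it happens to work for a batch of one, but otherwise fails with a shape error (or, if the batch size equals the vocabulary size, silently uses the wrong cutoffs).

```python
import torch

def top_k_mask(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: keep the k largest logits in each row of a (batch, vocab) tensor."""
    top_k_logits, _ = torch.topk(logits, k, dim=-1)
    cutoff = top_k_logits[:, -1:]   # shape (batch, 1), broadcasts across the vocab dimension
    # masked_fill returns a new tensor; masked_fill_ would modify logits in place instead
    return logits.masked_fill(logits < cutoff, float("-inf"))

# Hypothetical logits for a batch of two sequences over a 5-token vocabulary
batch_logits = torch.tensor([[0.5, 2.0, -1.0, 1.0, 0.1],
                             [3.0, -2.0, 0.0, 1.5, 2.5]])
masked = top_k_mask(batch_logits, k=2)
print(masked)
```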
Design choices and rationale
- Context truncation keeps per-step compute and memory bounded by context_size rather than by the total generated length.
- Top‑k pruning removes the low‑probability tail for safer sampling, especially at higher temperatures.
- Temperature provides a single, intuitive knob for diversity: lower → more deterministic; higher → more varied.
- Greedy fallback via temperature == 0 keeps the API simple without a separate mode switch.
- eos_id allows clean termination when the model emits a special end token.
Usage examples
- Pure greedy (deterministic): generate(model, x, 128, context_size, temperature=0.0)
- Temperature‑only sampling: generate(model, x, 128, context_size, temperature=0.8)
- Top‑k sampling: generate(model, x, 128, context_size, temperature=0.8, top_k=40)
- Early stop on EOS: generate(model, x, 256, context_size, temperature=0.7, eos_id=tokenizer.eot_token)
Summary
This generate function brings together practical decoding controls (context management, temperature, top‑k, and EOS handling) in a compact loop suitable for qualitative sampling and quick experiments. It is a solid baseline to plug into training checkpoints and prompt‑driven evaluations.
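import torch

def generate(model, input_tokens, max_new_tokens, context_size, temperature = 1.0, top_k = None, eos_id = None):
    for _ in range(max_new_tokens):
        # Keep only the most recent context_size tokens (the model's positional limit)
        input_context = input_tokens[:, -context_size:]
        with torch.no_grad():
            logits = model(input_context)
        # Only the last position's logits are needed to predict the next token
        logits = logits[:, -1, :]
        if top_k is not None:
            top_k_logits, _ = torch.topk(logits, top_k)
            # k-th largest logit per example, shaped (batch, 1) so it broadcasts across the vocab
            min_top_k_logit = top_k_logits[:, -1:]
            # Mask everything below the cutoff to -inf (full_like matches device and dtype)
            logits = torch.where(logits < min_top_k_logit, torch.full_like(logits, float('-inf')), logits)
        if temperature > 0.0:
            # Temperature scaling, then sample from the resulting distribution
            logits = logits / temperature
            probs = torch.softmax(logits, dim = -1)
            next_token = torch.multinomial(probs, num_samples = 1)
        else:
            # Greedy decoding: take the highest-scoring token
            next_token = torch.argmax(logits, dim = -1, keepdim = True)
        # Stop early on the end-of-sequence token (batch size 1, as documented above)
        if eos_id is not None and (next_token == eos_id).all():
            break
        input_tokens = torch.cat((input_tokens, next_token), dim = 1)
    return input_tokens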
Example: temperature + top‑k generation (invocation)
This cell runs the advanced generate function on a natural prompt using both temperature scaling and top‑k filtering to produce a diverse yet controlled continuation.
What the code does
- Sets a fixed RNG seed (torch.manual_seed(246)) to make sampling reproducible.
- Defines a simple prompt: "once upon a time".
- Encodes it to token IDs and moves them to device.
- Calls generate(...) with:
  - max_new_tokens=200: up to 200 tokens of continuation.
  - context_size=SYDSGPT_CONFIG_345M['context_length']: enforces the model's maximum window.
  - temperature=1.5: flattens the probabilities to encourage variety.
  - top_k=40: restricts sampling to the 40 most likely tokens at each step.
- Decodes the tokens back to text and prints the result.
Why combine temperature and top‑k
- Temperature > 1.0 increases diversity, avoiding overly‑confident loops.
- Top‑k caps the candidate set, preventing extremely unlikely tokens from appearing even when temperature is high.
- Together they provide a practical balance: varied but not nonsensical.
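The effect of this pairing can be checked on a toy distribution. In the sketch below, the logits are hypothetical and next_token_probs is an illustrative helper that mirrors the order used in generate (top-k mask first, then temperature scaling and softmax): raising the temperature from 1.0 to 1.5 moves probability into the tail, while top-k keeps the tail at exactly zero.

```python
import torch

def next_token_probs(logits: torch.Tensor, temperature: float, top_k=None) -> torch.Tensor:
    """Illustrative helper: optional top-k mask, then temperature scaling and softmax."""
    if top_k is not None:
        cutoff = torch.topk(logits, top_k).values[-1]   # k-th largest logit
        logits = torch.where(logits < cutoff, torch.full_like(logits, float("-inf")), logits)
    return torch.softmax(logits / temperature, dim=-1)

# Hypothetical logits: three strong candidates followed by a low-probability tail
logits = torch.tensor([4.0, 3.0, 2.0, -1.0, -3.0, -5.0])

for temperature in (1.0, 1.5):
    temp_only = next_token_probs(logits, temperature)
    temp_and_top_k = next_token_probs(logits, temperature, top_k=3)
    print(f"T={temperature}: tail mass without top-k = {temp_only[3:].sum().item():.4f}, "
          f"with top-k = {temp_and_top_k[3:].sum().item():.4f}")
```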
Tuning tips
- If the output feels chaotic: lower temperature (e.g., 0.8–1.0) or reduce top_k.
- If the output is dull or repetitive: raise temperature (1.2–1.8) or increase top_k (e.g., 50).
- Keep the prompt specific to steer the model; generic prompts amplify variance.
After running
Skim the output for coherence, repetition, and topic adherence. Adjust temperature and top_k to your preferences, then reuse this pattern for different prompts or integrate it into a qualitative evaluation loop across checkpoints.
torch.manual_seed(246)
input_text = "once upon a time"
input_tokens = text_to_tokens(input_text, tokenizer).to(device)
output_tokens = generate(model, input_tokens, 200, SYDSGPT_CONFIG_345M['context_length'], temperature = 1.5, top_k = 40)
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text:\n {output_text}")
Output Text: once upon a time after his arrival. But again had that time too already settled in the possibility of talking in the existing histories given personal opportunity of reproachfulness? Hath it not not been brought together simply that one who had not always tried? And the most wonderful man, in a sort of unbension with which he had been capable of using the money by a man who had done so intimatelyision and must not think about as a politician be better in his physical conversation, for whose knowledge there must give a reference to the facts (a lady, especially on purpose, placed upright in their hands) of the unhappy man. The victim might receive her reason to be as much as a hypocrite as possible, but of having supposed she to do as, as it came upon him, as a mode of their being a woman. The latter part of his respect took place to him as much as much as possible to give it him
