In this part, I scaled the pretraining pipeline up to a ~10B-token corpus: pre-tokenization and chunking for streaming, a Flash Attention replacement inside the GPT blocks, training-loop features (warmup, cosine decay, gradient accumulation), torch.compile for runtime speedups, and GaLoreAdamW as the optimizer. I then ran a long single-GPU pretraining run (~12B tokens over ~11 days on an NVIDIA 3080 Ti). This post documents the full process, explains the design choices, and shows the exact code so readers can reproduce and adapt each step.

Overview

What this part covers

  • Dataset assembly and preprocessing: combining multiple corpora, pre-tokenization with tiktoken, and chunking into fixed-length shards stored as Parquet for streaming.
  • Model changes: replacing standard attention with a Flash Attention style implementation using torch.nn.functional.scaled_dot_product_attention, and wiring that into a new transformer block and model class.
  • Training loop improvements: LR warmup, cosine decay, gradient accumulation, gradient clipping, periodic evaluation, and rotating checkpoints.
  • Performance engineering: torch.compile usage and runtime flags, mixed-precision considerations, and optimizer selection (GaLoreAdamW).
  • Run summary and practical lessons from training ~12B tokens on a single 3080 Ti.

Below I walk through each stage and include the notebook code so you can see exactly what was done.

Dataset assembly and tokenization

Goal: build a large, mixed corpus and convert it into tokenized, fixed-length chunks that can be streamed efficiently during training.

Key ideas

  • Keep raw text columns minimal to save space.
  • Pre-tokenize with tiktoken (GPT-2 encoding) to get deterministic token counts.
  • Stream token lists into a buffer and emit fixed-size chunks (here CHUNK_SIZE = 2048) into Parquet shards for efficient, memory-mapped reads.

Code: dataset loading, trimming, concatenation, and saving

from datasets import load_dataset, load_from_disk, concatenate_datasets

fineweb_dataset = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train")
wikipedia_dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train")
arxiv_dataset = load_dataset("timaeus/pile-arxiv", split="train")

fineweb_dataset = fineweb_dataset.remove_columns([col for col in fineweb_dataset.column_names if col != 'text'])
wikipedia_dataset = wikipedia_dataset.remove_columns([col for col in wikipedia_dataset.column_names if col != 'text'])
arxiv_dataset = arxiv_dataset.remove_columns([col for col in arxiv_dataset.column_names if col != 'text'])

fineweb_dataset.save_to_disk('data/fineweb_dataset')
wikipedia_dataset.save_to_disk('data/wikipedia_dataset')
arxiv_dataset.save_to_disk('data/arxiv_dataset')

fineweb_dataset = load_from_disk('data/fineweb_dataset')
wikipedia_dataset = load_from_disk('data/wikipedia_dataset')
arxiv_dataset = load_from_disk('data/arxiv_dataset')

fineweb_dataset = fineweb_dataset.shuffle(seed=42)
wikipedia_dataset = wikipedia_dataset.shuffle(seed=42)
arxiv_dataset = arxiv_dataset.shuffle(seed=42)
# Keep the first 50% of each shuffled dataset (drops half the rows)
fineweb_dataset = fineweb_dataset.select(range(len(fineweb_dataset)//2))
wikipedia_dataset = wikipedia_dataset.select(range(len(wikipedia_dataset)//2))
arxiv_dataset = arxiv_dataset.select(range(len(arxiv_dataset)//2))

# Concatenate the datasets
combined_dataset = concatenate_datasets([fineweb_dataset, wikipedia_dataset, arxiv_dataset])


combined_dataset.save_to_disk('data/combined_dataset')

Code: tokenization with tiktoken and saving tokenized dataset

import tiktoken
enc = tiktoken.get_encoding("gpt2")

from typing import Optional, Tuple
from datasets import Dataset

# Tokenize a HF Dataset and save to disk. Returns (tokenized_dataset, total_tokens)
def tokenize_and_save(
    dataset: Dataset,
    out_dir: str,
    text_column: str = 'text',
    keep_text: bool = False,
    batch_size: int = 1000,
    num_proc: Optional[int] = None,
) -> Tuple[Dataset, int]:
    """
    - Tokenizes each row's text using the global `enc` (tiktoken GPT-2).
    - Adds 'input_ids' (List[int]) and 'length' (int) columns.
    - Optionally removes the original 'text' column to save space.
    - Saves the resulting dataset to `out_dir`.
    Returns the tokenized dataset and the total token count.
    """

    def tok_batch(batch):
        texts = batch[text_column]
        input_ids = [enc.encode(t, allowed_special={'<|endoftext|>'}) for t in texts]
        lengths = [len(ids) for ids in input_ids]
        return {'input_ids': input_ids, 'length': lengths}

    remove_cols = None if keep_text else [text_column]

    tokenized = dataset.map(
        tok_batch,
        batched=True,
        batch_size=batch_size,
        num_proc=num_proc,
        remove_columns=remove_cols,
        desc=f"Tokenizing -> {out_dir}",
    )

    # Compute total tokens efficiently by summing the 'length' column
    total_tokens = int(sum(tokenized['length']))

    # Persist to disk
    tokenized.save_to_disk(out_dir)

    return tokenized, total_tokens

dataset_path = 'data/combined_dataset'
tokenized_dataset_path = 'data/combined_tokenized_dataset'

# Load datasets from disk
dataset = load_from_disk(dataset_path)

dataset_tokenized, dataset_token_count = tokenize_and_save(dataset, tokenized_dataset_path, keep_text=False, batch_size=1000, num_proc=None)


print('Tokenized dataset sizes (rows):', {
    'combined_rows': len(dataset_tokenized),
})
print('Total token count:', {
    'combined_tokens': dataset_token_count,
})

Why this matters

  • Pre-tokenization gives you an exact token count and lets you reason about how many chunks and epochs you can run.
  • Saving tokenized rows to disk avoids repeated tokenization during experiments and makes preprocessing reproducible.
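For example, the token count printed above translates directly into a training budget. A rough back-of-the-envelope sketch (the chunk size, split ratio, and batch settings match values used later in this post, so treat them as assumptions at this point):

# Rough planning: turn the total token count into chunk, batch, and step estimates.
CHUNK_SIZE = 2048          # fixed chunk length used in the next section
TRAIN_PROB = 0.8           # approximate train share of chunks
BATCH_SIZE = 2             # per-iteration batch size used later
GRAD_ACCUM_STEPS = 64      # gradient accumulation factor used later

total_chunks = dataset_token_count // CHUNK_SIZE
train_chunks = int(total_chunks * TRAIN_PROB)
batches_per_epoch = train_chunks // BATCH_SIZE
optimizer_steps_per_epoch = batches_per_epoch // GRAD_ACCUM_STEPS

print({
    'total_chunks': total_chunks,
    'train_chunks': train_chunks,
    'batches_per_epoch': batches_per_epoch,
    'optimizer_steps_per_epoch': optimizer_steps_per_epoch,
})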

Chunking into fixed-length shards for streaming

Goal: convert variable-length token lists into fixed-length chunks (2048 tokens) and write them into Parquet shards for efficient streaming and reproducible sampling.

Design choices

  • Buffering: accumulate tokens across rows until you can emit a full chunk.
  • Shard sizing: choose a shard size (SHARD_SIZE_CHUNKS) that balances file count and I/O throughput.
  • Train/val split: random assignment at chunk emission time to get an approximate 80/20 split.
  • Parquet: memory-mapped reads via HF Dataset.from_parquet avoid loading everything into RAM.

Code: chunking pipeline and DataLoader wrappers

import os
import random
from pathlib import Path
from glob import glob
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from torch.utils.data import Dataset as TorchDataset, DataLoader
from datasets import load_from_disk, Dataset

# Config
SOURCE_PATH = 'data/combined_tokenized_dataset'   # variable-length input_ids
OUT_TRAIN_DIR = Path('data/combined_chunks_train_parquet')
OUT_VAL_DIR = Path('data/combined_chunks_val_parquet')
CHUNK_SIZE = 2048
TRAIN_PROB = 0.8          # approximate 80/20 split at chunk level
SHARD_SIZE_CHUNKS = 25000  # number of chunks per parquet shard (tune for memory/disk throughput)
BATCH_SIZE = 2
SEED = 42

random.seed(SEED)
OUT_TRAIN_DIR.mkdir(parents=True, exist_ok=True)
OUT_VAL_DIR.mkdir(parents=True, exist_ok=True)

# Optional: clean old shards (only .parquet files)
for f in list(OUT_TRAIN_DIR.glob('*.parquet')) + list(OUT_VAL_DIR.glob('*.parquet')):
    try:
        f.unlink()
    except Exception:
        pass

# Helper to write a shard of chunks to Parquet
# chunks: List[List[int]] (all must be CHUNK_SIZE long)
def write_parquet_shard(chunks, out_dir: Path, shard_idx: int):
    if not chunks:
        return
    array = pa.array(chunks, type=pa.list_(pa.int32()))
    table = pa.table({'input_ids': array})
    pq.write_table(table, out_dir / f'part-{shard_idx:05d}.parquet')

# Stream over dataset and produce fixed-size chunks
buf = []  # token buffer
train_batch, val_batch = [], []
train_shard, val_shard = 0, 0
train_count, val_count = 0, 0

src = load_from_disk(SOURCE_PATH)
print('Streaming rows:', len(src))

for row in src:
    toks = row['input_ids']
    if not toks:
        continue
    buf.extend(toks)
    while len(buf) >= CHUNK_SIZE:
        chunk = buf[:CHUNK_SIZE]
        del buf[:CHUNK_SIZE]
        if random.random() < TRAIN_PROB:
            train_batch.append(chunk)
            train_count += 1
            if len(train_batch) >= SHARD_SIZE_CHUNKS:
                write_parquet_shard(train_batch, OUT_TRAIN_DIR, train_shard)
                train_shard += 1
                train_batch = []
        else:
            val_batch.append(chunk)
            val_count += 1
            if len(val_batch) >= SHARD_SIZE_CHUNKS:
                write_parquet_shard(val_batch, OUT_VAL_DIR, val_shard)
                val_shard += 1
                val_batch = []

# Flush leftovers
write_parquet_shard(train_batch, OUT_TRAIN_DIR, train_shard)
write_parquet_shard(val_batch, OUT_VAL_DIR, val_shard)

print({'train_chunks_written': train_count, 'val_chunks_written': val_count, 'leftover_tokens_dropped': len(buf)})

# Build HF Datasets from Parquet shards (memory-mapped; avoids loading everything at once)
train_parquet_files = sorted(glob(str(OUT_TRAIN_DIR / '*.parquet')))
val_parquet_files = sorted(glob(str(OUT_VAL_DIR / '*.parquet')))

train_hfds = Dataset.from_parquet(train_parquet_files)
val_hfds = Dataset.from_parquet(val_parquet_files)

print({'train_rows': len(train_hfds), 'val_rows': len(val_hfds)})

# Torch wrappers and DataLoaders (no attention_mask), with external shift in collate
class FixedLenHFDataset(TorchDataset):
    def __init__(self, hf_ds: Dataset):
        self.ds = hf_ds
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        ids = self.ds[idx]['input_ids']
        return torch.tensor(ids, dtype=torch.long)

def collate_shift(batch):
    x = torch.stack(batch)      # (B, CHUNK_SIZE)
    y = x.clone()
    y[:, :-1] = x[:, 1:]
    y[:, -1] = -100
    return {'input_ids': x, 'targets': y}

train_fixed = FixedLenHFDataset(train_hfds)
val_fixed = FixedLenHFDataset(val_hfds)

training_dataloader = DataLoader(train_fixed, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_shift)
validation_dataloader = DataLoader(val_fixed, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_shift)

Notes on the collate function

  • The collate_shift function prepares input_ids and targets by shifting tokens left for next-token prediction and using -100 as the ignore index for the final token. This keeps the loss computation simple and efficient.
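A tiny sanity check makes the shift concrete (illustrative only; real chunks are CHUNK_SIZE tokens long):

import torch

toy_batch = [torch.tensor([10, 11, 12, 13]), torch.tensor([20, 21, 22, 23])]
out = collate_shift(toy_batch)
print(out['input_ids'])  # tensor([[10, 11, 12, 13], [20, 21, 22, 23]])
print(out['targets'])    # tensor([[11, 12, 13, -100], [21, 22, 23, -100]])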

Model architecture and Flash Attention

Goal: reduce attention memory pressure and improve throughput by using torch.nn.functional.scaled_dot_product_attention (Flash Attention style) while preserving causal masking.

Key points

  • The FlashAttention class computes Q/K/V via a single linear, reshapes into heads, and calls scaled_dot_product_attention with is_causal=True.
  • The transformer block (TransformerBlockv2) uses LayerNorm, the FlashAttention module, and a FeedForward module.
  • The top-level model SydsGPTv2 wires token and position embeddings, a stack of TransformerBlockv2, final layer norm, and an output projection.

Code: FlashAttention, TransformerBlockv2, and SydsGPTv2

import torch
import torch.nn as nn

class FlashAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.dropout = dropout

        self.qkv = nn.Linear(embedding_dim, 3 * embedding_dim)
        self.out_proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        batch_size, seq_length, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(batch_size, seq_length, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        queries, keys, values = qkv
        dropout = 0.0 if not self.training else self.dropout
        context_vectors = torch.nn.functional.scaled_dot_product_attention(queries, keys, values, attn_mask = None, dropout_p = dropout, is_causal = True)
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(batch_size, seq_length, self.embedding_dim)
        context_vectors = self.out_proj(context_vectors)
        return context_vectors

from modules.LayerNorm import LayerNorm
from modules.FeedForward import FeedForward
import torch.nn as nn

class TransformerBlockv2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = FlashAttention(
            embedding_dim = config["embedding_dim"],
            num_heads = config["num_heads"],
            dropout = config["dropout"],
        )
        self.layer_norm1 = LayerNorm(config["embedding_dim"])
        self.feed_forward = FeedForward(config)
        self.layer_norm2 = LayerNorm(config["embedding_dim"])
        self.dropout = nn.Dropout(config["dropout"])

    def forward(self, x):
        shortcut = x
        x = self.layer_norm1(x)
        x = self.attention(x)
        x = self.dropout(x)
        x = x + shortcut
        shortcut = x
        x = self.layer_norm2(x)
        x = self.feed_forward(x)
        x = self.dropout(x)
        x = x + shortcut
        return x

import torch
import torch.nn as nn
from modules.LayerNorm import LayerNorm

class SydsGPTv2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config["vocab_size"], config["embedding_dim"])
        self.position_embedding = nn.Embedding(config["context_length"], config["embedding_dim"])
        self.drop_embedding = nn.Dropout(config["dropout"])
        self.transformer_blocks = nn.Sequential(*[TransformerBlockv2(config) for _ in range(config["num_layers"])])
        self.final_layer_norm = LayerNorm(config["embedding_dim"])
        self.output_projection = nn.Linear(config["embedding_dim"], config["vocab_size"], bias = False)
    
    def forward(self, input):
        batch_size, seq_length = input.shape
        token_embeddings = self.token_embedding(input)
        position_embeddings = self.position_embedding(torch.arange(seq_length, device=input.device))
        x = token_embeddings + position_embeddings
        x = self.drop_embedding(x)
        x = self.transformer_blocks(x)
        x = self.final_layer_norm(x)
        logits = self.output_projection(x)
        return logits

Practical validation

  • Compare logits on a small batch between the FlashAttention model and a baseline to ensure numerical parity within tolerance.
  • Confirm is_causal=True to preserve autoregressive behavior.
  • Watch dtype: scaled_dot_product_attention supports mixed precision; ensure your autocast and torch.set_float32_matmul_precision settings align with your hardware.
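As a minimal sketch of such a parity check, one can compare scaled_dot_product_attention against a plain eager-mode implementation of causal attention (this checks the kernel itself rather than the full model; the reference function below is written for illustration):

import torch
import torch.nn.functional as F

def reference_causal_attention(q, k, v):
    # Plain eager-mode causal attention for comparison.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    seq_len = q.shape[-2]
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(causal_mask, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # (batch, heads, seq, head_dim)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
ref = reference_causal_attention(q, k, v)
print(torch.allclose(fused, ref, atol=1e-5))  # expect True within tolerance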

Compilation and runtime flags

Goal: reduce Python overhead and fuse kernels where possible using torch.compile, and enable safe TF32/precision knobs on Ampere+ GPUs.

Code: performance flags and torch.compile

# Compile model with torch.compile and set performance flags
import torch
import contextlib

# Optional performance knobs (safe on Ampere+ GPUs; harmless on CPU)
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
except Exception:
    pass

# Allow faster, slightly lower-precision float32 matmuls (TF32-backed) where supported
with contextlib.suppress(Exception):
    torch.set_float32_matmul_precision('high')  # 'high' or 'medium'

# Choose a compile configuration
compile_backend = 'inductor'          # default backend
compile_mode = 'default'              # try 'reduce-overhead' or 'max-autotune' later
dynamic_shapes = False                # set True if you plan to change batch size frequently
compile_ok = False


try:
    model = torch.compile(model, backend=compile_backend, mode=compile_mode, dynamic=dynamic_shapes)
    compile_ok = True
    print(f"Model compiled with torch.compile (backend={compile_backend}, mode={compile_mode}, dynamic={dynamic_shapes})")
    print("Note: First iteration includes compile time; subsequent steps are faster.")
except Exception as e:
    print("torch.compile failed; falling back to eager. Error:\n", e)

Notes

  • The first iteration after torch.compile includes compilation overhead; measure steady-state throughput after warmup.
  • torch.backends.cudnn.benchmark = True helps when input sizes are stable.
  • torch.set_float32_matmul_precision('high') can improve matmul performance on supported hardware.
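A simple way to separate the one-time compile cost from steady-state throughput is to time a few batches twice (a sketch; it assumes model, training_dataloader, and device from earlier cells and only times forward passes):

import time
import itertools
import torch

def measure_tokens_per_sec(model, dataloader, device, num_steps):
    # Forward-only timing; enough to see the compile-then-fast pattern.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    tokens = 0
    with torch.no_grad():
        for batch in itertools.islice(iter(dataloader), num_steps):
            x = batch['input_ids'].to(device)
            model(x)
            tokens += x.numel()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return tokens / (time.perf_counter() - start)

# The first call pays the one-time compilation cost; the second reflects steady state.
print('first-pass tokens/sec:  ', measure_tokens_per_sec(model, training_dataloader, device, num_steps=3))
print('steady-state tokens/sec:', measure_tokens_per_sec(model, training_dataloader, device, num_steps=20))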

Training loop: warmup, cosine decay, gradient accumulation, and checkpoints

Goals

  • Stabilize early training with LR warmup.
  • Use cosine decay to anneal LR smoothly across the full training horizon.
  • Use gradient accumulation to simulate large effective batch sizes on a single GPU.
  • Rotate checkpoints to limit disk usage while keeping recent history.

Hyperparameters used in the run

  • initial_lr = 1e-6, peak_lr = 1e-4, min_lr = 0.1 * peak_lr.
  • Warmup set to ~2% of steps per epoch.
  • grad_accum_steps = 64 to scale effective batch size.
  • checkpoint_interval = 10000 steps (rotating saves).
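For reference, the scheduler arguments passed to the training calls later in the post can be derived from these numbers roughly as follows (a sketch; total_steps_per_epoch is taken as the DataLoader length, and the 2% warmup fraction follows the bullet above):

initial_lr = 1e-6
peak_lr = 1e-4
min_lr = 0.1 * peak_lr

# One "step" here is one DataLoader iteration (see the note in train_model_v3).
total_steps_per_epoch = len(training_dataloader)
warmup_steps = int(0.02 * total_steps_per_epoch)  # ~2% of steps per epoch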

Code: training function v2 (basic warmup + cosine decay)

import math
import os
from modules.Loss import calc_batch_loss
from modules.Generate import generate_sample_text

def train_model_v2(model, training_dataloader, validation_dataloader, optimizer, device,
                   num_epochs, evaluation_frequency, start_context,
                   tokenizer, checkpoint_interval, total_steps_per_epoch, warmup_steps, initial_lr, peak_lr, min_lr):
    
    training_losses, validation_losses, total_tokens_processed, learning_rates = [], [], [], []
    total_tokens_processed, global_step = 0, -1
    total_training_steps = num_epochs * total_steps_per_epoch
    lr_increment = (peak_lr - initial_lr) / warmup_steps

    for epoch in range(num_epochs):
        model.train()
        for batch in training_dataloader:
            optimizer.zero_grad()
            global_step += 1
            if global_step < warmup_steps:
                lr = initial_lr + global_step * lr_increment
            else:
                progress = (global_step - warmup_steps) / (total_training_steps - warmup_steps)
                lr = min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            learning_rates.append(lr)

            loss = calc_batch_loss(batch['input_ids'], batch['targets'], model, device)
            loss.backward()

            if global_step >= warmup_steps:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0)

            optimizer.step()
            training_losses.append(loss.item())
            total_tokens_processed += (batch['input_ids'] != -100).sum().item()
            print(f"Epoch {epoch + 1}, Step {global_step}: Tokens Processed = {total_tokens_processed}, Training Loss = {loss.item()}")

            if global_step >= evaluation_frequency and global_step % evaluation_frequency == 0:
                model.eval()
                val_batch = next(iter(validation_dataloader))
                with torch.no_grad():
                    val_loss = calc_batch_loss(val_batch['input_ids'], val_batch['targets'], model, device)
                validation_losses.append(val_loss.item())
                print(f"--- Evaluation at Epoch {epoch + 1}, Step {global_step}: Validation Loss = {val_loss.item()} ---")
                generate_sample_text(model, tokenizer, device, start_context)
                model.train()
            
            if global_step >= checkpoint_interval and global_step % checkpoint_interval == 0:
                base_ckpt = "autosave_ckpt1_sydsgpt_v2_164m_trained_model_optimizer.pth"
                prev1_ckpt = "autosave_ckpt1_prev1_sydsgpt_v2_164m_trained_model_optimizer.pth"

                try:
                    if os.path.exists(prev1_ckpt):
                        os.remove(prev1_ckpt)
                except Exception:
                    pass

                try:
                    if os.path.exists(base_ckpt):
                        os.replace(base_ckpt, prev1_ckpt)
                except Exception:
                    pass

                torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, base_ckpt)
                print(f"Checkpoint saved (rotating): {base_ckpt} | prev1 -> {prev1_ckpt}")


    return training_losses, validation_losses, total_tokens_processed, learning_rates

Code: optimizer instantiation

from galore_torch import GaLoreAdamW
optimizer = GaLoreAdamW(model.parameters(), weight_decay=0.05)

Code: training function v3 (warmup + cosine + gradient accumulation + rotated checkpoints)

import math
import os
from modules.Loss import calc_batch_loss
from modules.Generate import generate_sample_text

def train_model_v3(model, training_dataloader, validation_dataloader, optimizer, device,
                   num_epochs, evaluation_frequency, start_context,
                   tokenizer, checkpoint_interval, total_steps_per_epoch, warmup_steps, initial_lr, peak_lr, min_lr,
                   grad_accum_steps: int = 1):
    """
    Train with cosine decay + warmup and optional gradient accumulation.

    Notes:
    - LR/warmup here are updated per batch (DataLoader iteration). If you want warmup
      in optimizer steps, compute warmup_steps accordingly (divide by grad_accum_steps).
    - loss is scaled by 1/grad_accum_steps before backward to keep gradients invariant.
    """
    training_losses, validation_losses, total_tokens_processed, learning_rates = [], [], [], []
    total_tokens_processed, global_step = 0, -1
    total_training_steps = num_epochs * total_steps_per_epoch
    lr_increment = (peak_lr - initial_lr) / max(1, warmup_steps)
    accum_counter = 0
    
    optimizer.zero_grad(set_to_none=True)
    
    for epoch in range(num_epochs):
        model.train()
        for batch in training_dataloader:
            global_step += 1
            # Learning rate schedule per batch step
            if global_step < warmup_steps:
                lr = initial_lr + global_step * lr_increment
            else:
                progress = (global_step - warmup_steps) / max(1, (total_training_steps - warmup_steps))
                lr = min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
            for pg in optimizer.param_groups:
                pg['lr'] = lr
            learning_rates.append(lr)
            
            # Forward + backward (accumulated)
            loss = calc_batch_loss(batch['input_ids'], batch['targets'], model, device)
            training_losses.append(loss.item())  # log unscaled loss
            (loss / max(1, grad_accum_steps)).backward()
            accum_counter += 1
            
            # Token accounting (per batch)
            total_tokens_processed += (batch['input_ids'] != -100).sum().item()
            
            did_optimizer_step = False
            if accum_counter % max(1, grad_accum_steps) == 0:
                if global_step >= warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                did_optimizer_step = True
            
            print(f"Epoch {epoch + 1}, Step {global_step} ({'opt-step' if did_optimizer_step else 'accumulating'}): Tokens Processed = {total_tokens_processed}, Training Loss = {loss.item():.4f}, LR = {lr:.2e}")
            
            # Periodic evaluation
            if global_step >= evaluation_frequency and global_step % evaluation_frequency == 0:
                model.eval()
                try:
                    val_batch = next(iter(validation_dataloader))
                    with torch.no_grad():
                        val_loss = calc_batch_loss(val_batch['input_ids'], val_batch['targets'], model, device)
                    validation_losses.append(val_loss.item())
                    print(f"--- Evaluation at Epoch {epoch + 1}, Step {global_step}: Validation Loss = {val_loss.item():.4f} ---")
                    generate_sample_text(model, tokenizer, device, start_context)
                except StopIteration:
                    print("Validation loader empty; skipping eval.")
                finally:
                    model.train()
            
            # Checkpoint rotation (keep last 2)
            if global_step >= checkpoint_interval and global_step % checkpoint_interval == 0:
                base_ckpt = "autosave_ckpt1_sydsgpt_v2_164m_trained_model_optimizer.pth"
                prev1_ckpt = "autosave_ckpt1_prev1_sydsgpt_v2_164m_trained_model_optimizer.pth"
                prev2_ckpt = "autosave_ckpt1_prev2_sydsgpt_v2_164m_trained_model_optimizer.pth"
                try:
                    if os.path.exists(prev2_ckpt):
                        os.remove(prev2_ckpt)
                except Exception:
                    pass
                try:
                    if os.path.exists(prev1_ckpt):
                        os.replace(prev1_ckpt, prev2_ckpt)
                except Exception:
                    pass
                try:
                    if os.path.exists(base_ckpt):
                        os.replace(base_ckpt, prev1_ckpt)
                except Exception:
                    pass
                torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, base_ckpt)
                print(f"Checkpoint saved (rotating): {base_ckpt} | prev1 -> {prev1_ckpt} | prev2 -> {prev2_ckpt}")
        
        # Flush leftover grads at epoch end (if any)
        if accum_counter % max(1, grad_accum_steps) != 0:
            if global_step >= warmup_steps:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            print("Flushed leftover accumulated gradients at epoch end.")
    
    return training_losses, validation_losses, total_tokens_processed, learning_rates

Practical tips

  • Loss scaling: dividing the loss by grad_accum_steps before .backward() keeps gradient magnitudes consistent with larger batch training.
  • Gradient clipping: apply only at optimizer step time to avoid clipping partial gradients repeatedly.
  • Checkpoint rotation: keeps disk usage bounded while preserving recent history for recovery.
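A quick, self-contained check of the loss-scaling claim, using a toy linear model rather than the training loop above: accumulating loss / k over k micro-batches reproduces the gradients of a single pass over the full batch.

import torch
import torch.nn as nn

torch.manual_seed(0)
big_batch_model = nn.Linear(4, 1)
accum_model = nn.Linear(4, 1)
accum_model.load_state_dict(big_batch_model.state_dict())

x = torch.randn(8, 4)
y = torch.randn(8, 1)
loss_fn = nn.MSELoss()

# One pass over the full batch
loss_fn(big_batch_model(x), y).backward()

# k accumulated micro-batches, each loss scaled by 1/k
k = 2
for xb, yb in zip(x.chunk(k), y.chunk(k)):
    (loss_fn(accum_model(xb), yb) / k).backward()

print(torch.allclose(big_batch_model.weight.grad, accum_model.weight.grad, atol=1e-6))  # True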

Running the experiment and saving final weights

  • Instantiated SydsGPTv2 with a 164M-parameter configuration and compiled it with torch.compile where possible.
  • Used GaLoreAdamW with weight_decay=0.05.
  • Ran train_model_v3 with grad_accum_steps = 64 and saved the final model as "sydsgpt_v2_164m_trained_model-11.8B.pth".

Code: training invocation and final save

from galore_torch import GaLoreAdamW
optimizer = GaLoreAdamW(model.parameters(), weight_decay=0.05)


num_epochs = 2
training_losses, validation_losses, total_tokens_processed, learning_rates = train_model_v2(
    model,
    training_dataloader,
    validation_dataloader,
    optimizer,
    device,
    num_epochs,
    evaluation_frequency = 10000,
    start_context = "Once upon a time",
    tokenizer = enc,
    checkpoint_interval = 10000,
    total_steps_per_epoch = total_steps_per_epoch,
    warmup_steps = warmup_steps,
    initial_lr = initial_lr,
    peak_lr = peak_lr,
    min_lr = min_lr
)
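
# Reload the rotating autosave checkpoint and continue training with train_model_v3 (gradient accumulation)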
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = SydsGPTv2(SYDSGPT_CONFIG_V2_164M)
checkpoint = torch.load("autosave_ckpt1_sydsgpt_v2_164m_trained_model_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = GaLoreAdamW(model.parameters(), weight_decay=0.05)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
model.to(device)
num_epochs = 1
grad_accum_steps = 64  # effective batch = BATCH_SIZE * grad_accum_steps
training_losses, validation_losses, total_tokens_processed, learning_rates = train_model_v3(
    model,
    training_dataloader,
    validation_dataloader,
    optimizer,
    device,
    num_epochs,
    evaluation_frequency = 10000,
    start_context = "Once upon a time",
    tokenizer = enc,
    checkpoint_interval = 10000,
    total_steps_per_epoch = total_steps_per_epoch,
    warmup_steps = warmup_steps,
    initial_lr = initial_lr,
    peak_lr = peak_lr,
    min_lr = min_lr,
    grad_accum_steps = grad_accum_steps
)

torch.save(model.state_dict(), "sydsgpt_v2_164m_trained_model-11.8B.pth")

Notes

  • The notebook shows both train_model_v2 and train_model_v3 being used; train_model_v3 is the final training function with gradient accumulation.
  • grad_accum_steps = 64 with BATCH_SIZE = 2 yields an effective batch size of 128, which is a practical way to approximate larger-batch training on a single GPU.

Loading and generation

Code: loading the final checkpoint and generating text

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = SydsGPTv2(SYDSGPT_CONFIG_V2_164M)
model.load_state_dict(torch.load("sydsgpt_v2_164m_trained_model-11.8B.pth", map_location=device))
model.to(device)

from modules.Generate import generate, text_to_tokens, tokens_to_text

import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")


input_text = "A deep neural network is a type of artificial neural network with multiple layers between the input and output layers, which allows it to learn hierarchical patterns in data."
input_tokens = text_to_tokens(input_text, tokenizer).to(device)
output_tokens = generate(model, input_tokens, 1000, SYDSGPT_CONFIG_V2_164M['context_length'], temperature = 1.5, top_k = 40)
output_text = tokens_to_text(output_tokens, tokenizer)
print(f"Output Text:\n {output_text}")

What to watch during generation

  • Temperature and top-k: higher temperature and top-k produce more diverse outputs but can also increase incoherence.
  • Context length: ensure the input fits within context_length or is truncated appropriately.
  • Token-to-text mapping: use the same tiktoken encoder used during training to avoid tokenization mismatches.
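For reference, the usual temperature plus top-k sampling step looks roughly like this (a sketch of the standard technique; the repo's generate function may differ in detail):

import torch

def sample_next_token(logits, temperature=1.5, top_k=40):
    # logits: (vocab_size,) for the last position of the sequence
    scaled = logits / temperature
    top_values, top_indices = torch.topk(scaled, top_k)
    probs = torch.softmax(top_values, dim=-1)  # renormalize over the top-k candidates only
    choice = torch.multinomial(probs, num_samples=1)
    return top_indices[choice]

# Example with random logits over the GPT-2 vocabulary size
print(sample_next_token(torch.randn(50257)).item())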

Observations from the run

Throughput and runtime

  • Running ~12B tokens on a single 3080 Ti required careful memory management: Flash Attention, gradient accumulation, and mixed-precision-friendly flags were essential.
  • torch.compile can reduce Python overhead and improve steady-state throughput, but the first iteration includes compilation time. Measure both compile time and steady-state tokens/sec.
  • Parquet shards and memory-mapped HF datasets kept RAM usage low and allowed streaming large corpora without loading everything into memory.

Stability

  • LR warmup prevented early divergence. A small initial_lr and a short warmup window (2% of steps per epoch) stabilized the first phase.
  • Cosine decay provided a smooth annealing schedule across the full run.
  • Gradient clipping applied at optimizer step time helped avoid gradient explosions after warmup.

Practical trade-offs

  • Shard size: larger shards reduce file count but increase I/O per read; tune SHARD_SIZE_CHUNKS to your disk and training pattern.
  • Batch size vs. accumulation: accumulation increases effective batch size but increases wall-clock time per optimizer step; choose grad_accum_steps to balance memory and throughput.
  • Checkpoint cadence: frequent checkpoints increase disk usage and I/O; rotating saves keep recent history while bounding storage.

Lessons learned and recommendations

Data

  • Pre-tokenize and persist tokenized rows to avoid repeated tokenization and to get accurate token counts for planning.
  • Use deterministic sharding and manifest files for reproducibility.
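A manifest can be as small as a JSON file listing shard paths and chunk counts (a sketch against the Parquet directories created earlier; write_manifest and the output filename are illustrative):

import json
from glob import glob
from pathlib import Path
import pyarrow.parquet as pq

def write_manifest(shard_dir: Path, manifest_path: Path):
    # Record each shard's path and row (chunk) count so a run can be reproduced exactly.
    entries = [
        {'file': f, 'num_chunks': pq.read_metadata(f).num_rows}
        for f in sorted(glob(str(shard_dir / '*.parquet')))
    ]
    manifest = {'shards': entries, 'total_chunks': sum(e['num_chunks'] for e in entries)}
    manifest_path.write_text(json.dumps(manifest, indent=2))

write_manifest(Path('data/combined_chunks_train_parquet'), Path('data/combined_chunks_train_manifest.json'))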

Model

  • Flash Attention (or scaled_dot_product_attention) is a practical way to reduce memory pressure and increase throughput on consumer GPUs. Validate numerical parity with a baseline.

Training

  • Warmup + cosine decay is a robust schedule for long runs.
  • Gradient accumulation is essential for single-GPU large-scale pretraining. Ensure correct loss scaling and clipping semantics.
  • Use rotating checkpoints to limit disk usage while keeping recoverability.

Performance

  • torch.compile can help but measure compile overhead vs. steady-state gains.
  • Enable TF32 and set_float32_matmul_precision on Ampere+ GPUs for faster matmuls when acceptable.

Final thoughts

This part of the series demonstrates how careful engineering across the data pipeline, attention kernel, training loop, and runtime configuration makes large-scale pretraining feasible even on constrained hardware. The code provided is a practical, reproducible blueprint: tokenize once, shard into fixed-length chunks, stream shards with memory-mapped HF datasets, replace attention with a Flash Attention style kernel, compile the model when possible, and run a disciplined training loop with warmup, cosine decay, gradient accumulation, and rotating checkpoints.

Try It Yourself

The full notebook, covering corpus preparation, data loaders, loss computation, the pretraining loop, and text sampling and generation, is available here:

SydsGPT pretraining on Large corpus Repository

Clone the repo, open the Jupyter notebook, and step through the code.

Build It Yourself

If you want to try building it yourself, you can find the complete code with detailed explanations of each block in the source code section at the end of this post. All the best!

What comes next

Part 8 will focus on fine-tuning, specifically instruction fine-tuning and alignment: curating and cleaning an instruction-style dataset (paired prompts and high-quality responses), normalizing formatting and tokenization to match the pretraining pipeline, and splitting into train/validation shards for reproducible experiments. I plan to experiment with lightweight adaptation methods first (LoRA/PEFT or adapters) to iterate quickly on learning rates, weight decay, and few-epoch schedules before committing to full-model fine-tuning.

Later, I will add tool calling for web search and build a RAG pipeline to interact with private data. The aim is a private assistant that respects privacy and delivers practical value, proving that small models can go far when engineered with care.

Source Code

pretraining-largecorpus
