Spaces:
Running
Running
Initial deploy of NanoChat-ClimbMix-D12 demo
Browse files

Adding model weights (step 971), Gradio UI (app.py), and Docker configuration for Hugging Face Space hosting. Includes the nanochat engine and core dependencies.
- Dockerfile +27 -0
- LICENSE +21 -0
- app.py +36 -0
- meta_000971.json +38 -0
- model_000971.pt +3 -0
- nanochat/__init__.py +0 -0
- nanochat/__pycache__/__init__.cpython-310.pyc +0 -0
- nanochat/__pycache__/checkpoint_manager.cpython-310.pyc +0 -0
- nanochat/__pycache__/common.cpython-310.pyc +0 -0
- nanochat/__pycache__/core_eval.cpython-310.pyc +0 -0
- nanochat/__pycache__/dataloader.cpython-310.pyc +0 -0
- nanochat/__pycache__/dataset.cpython-310.pyc +0 -0
- nanochat/__pycache__/engine.cpython-310.pyc +0 -0
- nanochat/__pycache__/execution.cpython-310.pyc +0 -0
- nanochat/__pycache__/flash_attention.cpython-310.pyc +0 -0
- nanochat/__pycache__/gpt.cpython-310.pyc +0 -0
- nanochat/__pycache__/loss_eval.cpython-310.pyc +0 -0
- nanochat/__pycache__/optim.cpython-310.pyc +0 -0
- nanochat/__pycache__/report.cpython-310.pyc +0 -0
- nanochat/__pycache__/tokenizer.cpython-310.pyc +0 -0
- nanochat/checkpoint_manager.py +194 -0
- nanochat/common.py +278 -0
- nanochat/core_eval.py +262 -0
- nanochat/dataloader.py +166 -0
- nanochat/dataset.py +160 -0
- nanochat/engine.py +351 -0
- nanochat/execution.py +349 -0
- nanochat/flash_attention.py +187 -0
- nanochat/fp8.py +266 -0
- nanochat/gpt.py +465 -0
- nanochat/logo.svg +8 -0
- nanochat/loss_eval.py +65 -0
- nanochat/optim.py +533 -0
- nanochat/report.py +418 -0
- nanochat/tokenizer.py +406 -0
- nanochat/ui.html +566 -0
- pyproject.toml +74 -0
- requirements.txt +1 -0
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use slim base for smaller size on free CPU tier
FROM python:3.10-slim

# Install minimal build tools only if any C extensions are needed (nanochat is mostly pure torch/Python)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential git curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements first for better caching.
# FIX: install the CPU-only torch wheel *before* the other requirements, so a
# transitive dependency cannot first pull in the much larger default (CUDA)
# torch build. Also dropped the trailing `pip cache purge` layer: every
# install already uses --no-cache-dir, and an extra RUN layer cannot shrink
# an image anyway.
COPY requirements.txt .
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu \
    && pip install --no-cache-dir -r requirements.txt

# Copy the rest of your project (nanochat/, weights, app.py, etc.)
COPY . .

# Gradio defaults
EXPOSE 7860
ENV GRADIO_SERVER_NAME="0.0.0.0"

# Run the app
CMD ["python", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Andrej Karpathy
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
app.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Gradio demo app for the NanoChat-ClimbMix-D12 Space (CPU-only inference)."""
import gradio as gr
import torch
from nanochat.engine import Engine

MODEL_PATH = "model_000971.pt"
META_PATH = "meta_000971.json"

print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
# NOTE(review): assumes Engine accepts (model_path, meta_path, device) kwargs —
# confirm against nanochat/engine.py.
engine = Engine(model_path=MODEL_PATH, meta_path=META_PATH, device="cpu")

def chat_fn(message, history):
    """ChatInterface callback: generate one reply to *message*.

    BUG FIX: gr.ChatInterface invokes its fn as fn(message, history); the
    original single-argument signature raised TypeError on every message.
    *history* is accepted but unused — the engine is called statelessly.
    """
    response = engine.generate(message, max_new_tokens=300, temperature=0.85)  # higher temp = more fun/confident nonsense
    return response

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# 🧸 NanoChat-ClimbMix-D12 – The Confident Toddler LLM")
    gr.Markdown("Inspired by Andrej Karpathy's nanochat. Currently in 'Preschool Phase': 100% confident spelling, 0% reliable facts. 😂")
    gr.Markdown("**Coming soon:** D14 (Elementary), D16 (Middle School), D18 (High School), D20+ (University level) – less hallucinations, more wisdom!")

    with gr.Accordion("⚠️ Hallucination Disclaimer", open=True):
        gr.Markdown("This model boldly answers everything — even when wrong. Enjoy the comedy! Next versions will grow up fast.")

    gr.ChatInterface(
        fn=chat_fn,
        examples=[
            "Why is the sky blue?",
            "How many planets are in the solar system?",
            "Write Python code to say hello world",
            "Explain photosynthesis in one sentence"
        ],
        title="Chat with the Toddler",
        description="Ask anything — it will reply with maximum confidence!"
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
meta_000971.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"step": 971,
|
| 3 |
+
"val_bpb": 0.36456499504385426,
|
| 4 |
+
"model_config": {
|
| 5 |
+
"sequence_len": 2048,
|
| 6 |
+
"vocab_size": 32768,
|
| 7 |
+
"n_layer": 12,
|
| 8 |
+
"n_head": 6,
|
| 9 |
+
"n_kv_head": 6,
|
| 10 |
+
"n_embd": 768,
|
| 11 |
+
"window_pattern": "L"
|
| 12 |
+
},
|
| 13 |
+
"user_config": {
|
| 14 |
+
"run": "d12_climbmix_sft_v1",
|
| 15 |
+
"device_type": "",
|
| 16 |
+
"model_tag": "d12_climbmix_v1",
|
| 17 |
+
"model_step": 2205,
|
| 18 |
+
"load_optimizer": 1,
|
| 19 |
+
"num_iterations": -1,
|
| 20 |
+
"max_seq_len": null,
|
| 21 |
+
"device_batch_size": 2,
|
| 22 |
+
"total_batch_size": null,
|
| 23 |
+
"embedding_lr": null,
|
| 24 |
+
"unembedding_lr": null,
|
| 25 |
+
"matrix_lr": null,
|
| 26 |
+
"init_lr_frac": 0.8,
|
| 27 |
+
"warmup_ratio": 0.0,
|
| 28 |
+
"warmdown_ratio": 0.5,
|
| 29 |
+
"final_lr_frac": 0.0,
|
| 30 |
+
"eval_every": -1,
|
| 31 |
+
"eval_tokens": 20971520,
|
| 32 |
+
"chatcore_every": 200,
|
| 33 |
+
"chatcore_max_cat": -1,
|
| 34 |
+
"chatcore_max_sample": 24,
|
| 35 |
+
"mmlu_epochs": 3,
|
| 36 |
+
"gsm8k_epochs": 4
|
| 37 |
+
}
|
| 38 |
+
}
|
model_000971.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb7c387bae513ea816333e6bce1f442491dc3ab6535f0f62ed8032a82121ec8f
|
| 3 |
+
size 792760433
|
nanochat/__init__.py
ADDED
|
File without changes
|
nanochat/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (131 Bytes). View file
|
|
|
nanochat/__pycache__/checkpoint_manager.cpython-310.pyc
ADDED
|
Binary file (6.94 kB). View file
|
|
|
nanochat/__pycache__/common.cpython-310.pyc
ADDED
|
Binary file (9.6 kB). View file
|
|
|
nanochat/__pycache__/core_eval.cpython-310.pyc
ADDED
|
Binary file (8.45 kB). View file
|
|
|
nanochat/__pycache__/dataloader.cpython-310.pyc
ADDED
|
Binary file (5.36 kB). View file
|
|
|
nanochat/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (5.46 kB). View file
|
|
|
nanochat/__pycache__/engine.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
nanochat/__pycache__/execution.cpython-310.pyc
ADDED
|
Binary file (8.72 kB). View file
|
|
|
nanochat/__pycache__/flash_attention.cpython-310.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
nanochat/__pycache__/gpt.cpython-310.pyc
ADDED
|
Binary file (17.2 kB). View file
|
|
|
nanochat/__pycache__/loss_eval.cpython-310.pyc
ADDED
|
Binary file (2.31 kB). View file
|
|
|
nanochat/__pycache__/optim.cpython-310.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
nanochat/__pycache__/report.cpython-310.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
nanochat/__pycache__/tokenizer.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
nanochat/checkpoint_manager.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for saving and loading model/optim/state checkpoints.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import glob
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from nanochat.common import get_base_dir
|
| 12 |
+
from nanochat.gpt import GPT, GPTConfig
|
| 13 |
+
from nanochat.tokenizer import get_tokenizer
|
| 14 |
+
from nanochat.common import setup_default_logging
|
| 15 |
+
|
| 16 |
+
# Set up logging
|
| 17 |
+
setup_default_logging()
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
def log0(message):
    """Emit an info-level log line, but only on rank 0 (per the RANK env var)."""
    rank = int(os.environ.get('RANK', 0))
    if rank == 0:
        logger.info(message)
|
| 22 |
+
|
| 23 |
+
def _patch_missing_config_keys(model_config_kwargs):
|
| 24 |
+
"""Add default values for new config keys missing in old checkpoints."""
|
| 25 |
+
# Old models were trained with full context (no sliding window)
|
| 26 |
+
if "window_pattern" not in model_config_kwargs:
|
| 27 |
+
model_config_kwargs["window_pattern"] = "L"
|
| 28 |
+
log0(f"Patching missing window_pattern in model config to 'L'")
|
| 29 |
+
|
| 30 |
+
def _patch_missing_keys(model_data, model_config):
|
| 31 |
+
"""Add default values for new parameters that may be missing in old checkpoints."""
|
| 32 |
+
n_layer = model_config.n_layer
|
| 33 |
+
# resid_lambdas defaults to 1.0 (identity scaling)
|
| 34 |
+
if "resid_lambdas" not in model_data:
|
| 35 |
+
model_data["resid_lambdas"] = torch.ones(n_layer)
|
| 36 |
+
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
| 37 |
+
# x0_lambdas defaults to 0.0 (disabled)
|
| 38 |
+
if "x0_lambdas" not in model_data:
|
| 39 |
+
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
| 40 |
+
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
| 41 |
+
|
| 42 |
+
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    """Write a checkpoint for *step* into checkpoint_dir.

    Rank 0 persists the model weights (model_<step>.pt) and the metadata
    (meta_<step>.json). Every rank additionally writes its own optimizer shard
    (optim_<step>_rank<r>.pt) when optimizer_data is given, because optimizer
    state is sharded across ranks.
    """
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Model state parameters.
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_path)
        logger.info(f"Saved model parameters to: {model_path}")
        # Metadata dict as json.
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
        logger.info(f"Saved metadata to: {meta_path}")
    if optimizer_data is not None:
        # Non-zero ranks may reach here before rank 0 created the directory.
        os.makedirs(checkpoint_dir, exist_ok=True)
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        torch.save(optimizer_data, optimizer_path)
        logger.info(f"Saved optimizer state to: {optimizer_path}")
|
| 60 |
+
|
| 61 |
+
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    """Read back what save_checkpoint wrote for *step*.

    Returns (model_data, optimizer_data, meta_data); optimizer_data is None
    unless load_optimizer is True (the optimizer file is per-rank because its
    state is sharded).
    """
    model_data = torch.load(
        os.path.join(checkpoint_dir, f"model_{step:06d}.pt"), map_location=device
    )
    optimizer_data = None
    if load_optimizer:
        optimizer_data = torch.load(
            os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"),
            map_location=device,
        )
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.

    Args:
        checkpoint_dir: directory holding model_<step>.pt / meta_<step>.json.
        step: checkpoint step to load.
        device: torch.device to place the model on (a torch.device, not a str —
            device.type is read below).
        phase: "train" or "eval"; selects the module mode.
    Returns:
     - base model - uncompiled, not wrapped in DDP
     - tokenizer
     - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type in {"cpu", "mps"}:
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    # Backfill config keys / state-dict tensors that old checkpoints predate.
    _patch_missing_config_keys(model_config_kwargs)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(model_data, model_config)
    # Construct on the meta device (no allocation), then materialize on the
    # target device; the order meta -> to_empty -> init_weights -> load_state_dict
    # matters and must not be rearranged.
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
    return model, tokenizer, meta_data
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def find_largest_model(checkpoints_dir):
    """Guess a model tag inside checkpoints_dir.

    Prefers the deepest tag of the form d<number>; if no tag matches that
    shape, falls back to the most recently modified subdirectory.
    Raises FileNotFoundError when checkpoints_dir has no subdirectories.
    """
    subdirs = [
        name for name in os.listdir(checkpoints_dir)
        if os.path.isdir(os.path.join(checkpoints_dir, name))
    ]
    if not subdirs:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
    # 1) tags shaped like d<number>: pick the one with the largest depth
    depth_tagged = []
    for name in subdirs:
        m = re.match(r"d(\d+)", name)
        if m:
            depth_tagged.append((int(m.group(1)), name))
    if depth_tagged:
        return max(depth_tagged, key=lambda pair: pair[0])[1]
    # 2) otherwise: the most recently updated directory
    return max(subdirs, key=lambda name: os.path.getmtime(os.path.join(checkpoints_dir, name)))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def find_last_step(checkpoint_dir):
    """Return the highest <step> among model_<step>.pt files in checkpoint_dir."""
    matches = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not matches:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # Steps are zero-padded to 6 digits, so the lexicographic max of the
    # string fields equals the numeric max.
    step_strings = [os.path.basename(p).split("_")[-1].split(".")[0] for p in matches]
    return int(max(step_strings))
|
| 145 |
+
|
| 146 |
+
# -----------------------------------------------------------------------------
|
| 147 |
+
# convenience functions that take into account nanochat's directory structure
|
| 148 |
+
|
| 149 |
+
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    """Resolve defaults for (model_tag, step) inside checkpoints_dir and build the model.

    Returns (model, tokenizer, meta_data), exactly as build_model does.
    """
    tag = model_tag
    if tag is None:
        # Default to the largest model available.
        tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, tag)
    resolved_step = step if step is not None else find_last_step(checkpoint_dir)
    assert resolved_step is not None, f"No checkpoints found in {checkpoint_dir}"
    log0(f"Loading model from {checkpoint_dir} with step {resolved_step}")
    return build_model(checkpoint_dir, resolved_step, device, phase)
|
| 163 |
+
|
| 164 |
+
def load_model(source, *args, **kwargs):
    """Load a model from one of the known checkpoint families.

    *source* is "base", "sft", or "rl"; remaining args are forwarded to
    load_model_from_dir.
    """
    subdir_by_source = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    checkpoints_dir = os.path.join(get_base_dir(), subdir_by_source[source])
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
| 173 |
+
|
| 174 |
+
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
    """Load just the optimizer shard for a given rank, without re-loading the model.

    Returns the shard's state dict, or None when no optimizer checkpoint
    exists at the resolved path.
    """
    subdir_by_source = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    checkpoints_dir = os.path.join(get_base_dir(), subdir_by_source[source])
    tag = model_tag if model_tag is not None else find_largest_model(checkpoints_dir)
    checkpoint_dir = os.path.join(checkpoints_dir, tag)
    resolved_step = step if step is not None else find_last_step(checkpoint_dir)
    optimizer_path = os.path.join(checkpoint_dir, f"optim_{resolved_step:06d}_rank{rank:d}.pt")
    if not os.path.exists(optimizer_path):
        log0(f"Optimizer checkpoint not found: {optimizer_path}")
        return None
    log0(f"Loading optimizer state from {optimizer_path}")
    return torch.load(optimizer_path, map_location=device)
|
nanochat/common.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common utilities for nanochat.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
import urllib.request
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from filelock import FileLock
|
| 12 |
+
|
| 13 |
+
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
|
| 14 |
+
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
|
| 15 |
+
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
|
| 16 |
+
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
| 17 |
+
def _detect_compute_dtype():
|
| 18 |
+
env = os.environ.get("NANOCHAT_DTYPE")
|
| 19 |
+
if env is not None:
|
| 20 |
+
return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
|
| 21 |
+
if torch.cuda.is_available():
|
| 22 |
+
# bf16 requires SM 80+ (Ampere: A100, A10, etc.)
|
| 23 |
+
# Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
|
| 24 |
+
capability = torch.cuda.get_device_capability()
|
| 25 |
+
if capability >= (8, 0):
|
| 26 |
+
return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
|
| 27 |
+
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
|
| 28 |
+
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
|
| 29 |
+
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
|
| 30 |
+
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
|
| 31 |
+
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
|
| 32 |
+
|
| 33 |
+
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages.

    Wraps the level name in ANSI color codes and, for INFO records, also
    bold-highlights size/percentage figures and "Shard N" substrings in the
    rendered message.
    """
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m', # Cyan
        'INFO': '\033[32m', # Green
        'WARNING': '\033[33m', # Yellow
        'ERROR': '\033[31m', # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'
    def format(self, record):
        # Add color to the level name
        # NOTE(review): record.levelname is mutated in place, so any other
        # handler that formats this same record afterwards also sees the
        # ANSI-wrapped name — confirm this is intended.
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages (e.g. "1.5 GB", "80%", "12 docs")
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
|
| 58 |
+
|
| 59 |
+
def setup_default_logging():
    """Install a colored stream handler on the root logger at INFO level."""
    stream_handler = logging.StreamHandler()
    fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    stream_handler.setFormatter(ColoredFormatter(fmt))
    logging.basicConfig(level=logging.INFO, handlers=[stream_handler])
|
| 66 |
+
|
| 67 |
+
setup_default_logging()
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
|
| 70 |
+
def get_base_dir():
    """Return (and create) the nanochat data directory.

    $NANOCHAT_BASE_DIR when set and non-empty, else ~/.cache/nanochat —
    co-located with other cached data by default.
    """
    override = os.environ.get("NANOCHAT_BASE_DIR")
    if override:
        nanochat_dir = override
    else:
        nanochat_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir
|
| 80 |
+
|
| 81 |
+
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    Args:
        url: source URL, fetched with urllib.
        filename: target file name, resolved relative to get_base_dir().
        postprocess_fn: optional callable invoked with the final path after a
            fresh download only (skipped when the file already exists).
    Returns:
        Path to the downloaded (or pre-existing) file.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    # Fast path: avoid taking the lock when the file is already present.
    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released

        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read() # bytes

        # Write to local file
        with open(file_path, 'wb') as f:
            f.write(content)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)

        return file_path
|
| 116 |
+
|
| 117 |
+
def print0(s="", **kwargs):
    """print(), but only on rank 0 (as given by the RANK env var)."""
    if int(os.environ.get('RANK', 0)) == 0:
        print(s, **kwargs)
|
| 121 |
+
|
| 122 |
+
def print_banner():
|
| 123 |
+
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
|
| 124 |
+
banner = """
|
| 125 |
+
█████ █████
|
| 126 |
+
░░███ ░░███
|
| 127 |
+
████████ ██████ ██��█████ ██████ ██████ ░███████ ██████ ███████
|
| 128 |
+
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
|
| 129 |
+
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
|
| 130 |
+
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
|
| 131 |
+
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
|
| 132 |
+
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
|
| 133 |
+
"""
|
| 134 |
+
print0(banner)
|
| 135 |
+
|
| 136 |
+
def is_ddp_requested() -> bool:
|
| 137 |
+
"""
|
| 138 |
+
True if launched by torchrun (env present), even before init.
|
| 139 |
+
Used to decide whether we *should* initialize a PG.
|
| 140 |
+
"""
|
| 141 |
+
return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
|
| 142 |
+
|
| 143 |
+
def is_ddp_initialized() -> bool:
|
| 144 |
+
"""
|
| 145 |
+
True if torch.distributed is available and the process group is initialized.
|
| 146 |
+
Used at cleanup to avoid destroying a non-existent PG.
|
| 147 |
+
"""
|
| 148 |
+
return dist.is_available() and dist.is_initialized()
|
| 149 |
+
|
| 150 |
+
def get_dist_info():
|
| 151 |
+
if is_ddp_requested():
|
| 152 |
+
# We rely on torchrun's env to decide if we SHOULD init.
|
| 153 |
+
# (Initialization itself happens in compute init.)
|
| 154 |
+
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
|
| 155 |
+
ddp_rank = int(os.environ['RANK'])
|
| 156 |
+
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
| 157 |
+
ddp_world_size = int(os.environ['WORLD_SIZE'])
|
| 158 |
+
return True, ddp_rank, ddp_local_rank, ddp_world_size
|
| 159 |
+
else:
|
| 160 |
+
return False, 0, 0, 1
|
| 161 |
+
|
| 162 |
+
def autodetect_device_type():
|
| 163 |
+
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
| 164 |
+
if torch.cuda.is_available():
|
| 165 |
+
device_type = "cuda"
|
| 166 |
+
elif torch.backends.mps.is_available():
|
| 167 |
+
device_type = "mps"
|
| 168 |
+
else:
|
| 169 |
+
device_type = "cpu"
|
| 170 |
+
print0(f"Autodetected device type: {device_type}")
|
| 171 |
+
return device_type
|
| 172 |
+
|
| 173 |
+
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
| 174 |
+
"""Basic initialization that we keep doing over and over, so make common."""
|
| 175 |
+
|
| 176 |
+
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
| 177 |
+
if device_type == "cuda":
|
| 178 |
+
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
| 179 |
+
if device_type == "mps":
|
| 180 |
+
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
| 181 |
+
|
| 182 |
+
# Reproducibility
|
| 183 |
+
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
|
| 184 |
+
# The only place where global rng might be used is nn.Module initialization of the model weights.
|
| 185 |
+
torch.manual_seed(42)
|
| 186 |
+
if device_type == "cuda":
|
| 187 |
+
torch.cuda.manual_seed(42)
|
| 188 |
+
# skipping full reproducibility for now, possibly investigate slowdown later
|
| 189 |
+
# torch.use_deterministic_algorithms(True)
|
| 190 |
+
|
| 191 |
+
# Precision
|
| 192 |
+
if device_type == "cuda":
|
| 193 |
+
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
|
| 194 |
+
|
| 195 |
+
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
| 196 |
+
is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
| 197 |
+
if is_ddp_requested and device_type == "cuda":
|
| 198 |
+
device = torch.device("cuda", ddp_local_rank)
|
| 199 |
+
torch.cuda.set_device(device) # make "cuda" default to this device
|
| 200 |
+
dist.init_process_group(backend="nccl", device_id=device)
|
| 201 |
+
dist.barrier()
|
| 202 |
+
else:
|
| 203 |
+
device = torch.device(device_type) # mps|cpu
|
| 204 |
+
|
| 205 |
+
if ddp_rank == 0:
|
| 206 |
+
logger.info(f"Distributed world size: {ddp_world_size}")
|
| 207 |
+
|
| 208 |
+
return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
|
| 209 |
+
|
| 210 |
+
def compute_cleanup():
    """Companion to compute_init: tear down distributed state before script exit."""
    if not is_ddp_initialized():
        return
    dist.destroy_process_group()
|
| 214 |
+
|
| 215 |
+
class DummyWandb:
    """No-op stand-in for a wandb run.

    Exposes the same `log` / `finish` signatures as a real wandb run so calling
    code does not need to branch on whether wandb logging is enabled.
    """

    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        return None

    def finish(self):
        return None
|
| 223 |
+
|
| 224 |
+
# hardcoded BF16 peak flops for various GPUs
|
| 225 |
+
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
| 226 |
+
# and PR: https://github.com/karpathy/nanochat/pull/147
|
| 227 |
+
def get_peak_flops(device_name: str) -> float:
    """Map a GPU device name to its hardcoded peak BF16 FLOPS.

    Matching is case-insensitive substring search against a fixed table.
    Unknown devices return float('inf') so MFU reads as 0% rather than a
    wrong guess.
    """
    name = device_name.lower()

    # (substring patterns, peak flops). Order matters: more specific entries first.
    peak_flops_table = (
        # NVIDIA Blackwell
        (["gb200"], 2.5e15),
        (["grace blackwell"], 2.5e15),
        (["b200"], 2.25e15),
        (["b100"], 1.8e15),
        # NVIDIA Hopper
        (["h200", "nvl"], 836e12),
        (["h200", "pcie"], 836e12),
        (["h200"], 989e12),
        (["h100", "nvl"], 835e12),
        (["h100", "pcie"], 756e12),
        (["h100"], 989e12),
        (["h800", "nvl"], 989e12),
        (["h800"], 756e12),
        # NVIDIA Ampere data center
        (["a100"], 312e12),
        (["a800"], 312e12),
        (["a40"], 149.7e12),
        (["a30"], 165e12),
        # NVIDIA Ada data center
        (["l40s"], 362e12),
        (["l40-s"], 362e12),
        (["l40 s"], 362e12),
        (["l4"], 121e12),
        # AMD CDNA accelerators
        (["mi355"], 2.5e15),
        (["mi325"], 1.3074e15),
        (["mi300x"], 1.3074e15),
        (["mi300a"], 980.6e12),
        (["mi250x"], 383e12),
        (["mi250"], 362.1e12),
        # Consumer RTX
        (["5090"], 209.5e12),
        (["4090"], 165.2e12),
        (["3090"], 71e12),
    )
    matched = next(
        (flops for patterns, flops in peak_flops_table
         if all(pattern in name for pattern in patterns)),
        None,
    )
    if matched is not None:
        return matched

    if "data center gpu max 1550" in name:
        # Ponte Vecchio (PVC) - dynamic based on compute units
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6

    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
    return float('inf')
|
nanochat/core_eval.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions for evaluating the CORE metric, as described in the DCLM paper.
|
| 3 |
+
https://arxiv.org/abs/2406.11794
|
| 4 |
+
|
| 5 |
+
TODOs:
|
| 6 |
+
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
|
| 7 |
+
"""
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
from jinja2 import Template
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
|
| 14 |
+
# -----------------------------------------------------------------------------
|
| 15 |
+
# Prompt rendering utilities
|
| 16 |
+
|
| 17 |
+
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render one complete prompt per answer choice of a multiple-choice item."""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
    template = Template(template_str)
    render_kwargs = {
        'fewshot_examples': fewshot_examples or [],
        'continuation_delimiter': continuation_delimiter,
        'item': item,
    }
    # One rendered prompt per candidate answer; only `choice` varies across them.
    return [template.render(choice=choice, **render_kwargs) for choice in item['choices']]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """Render one complete prompt per context option of a schema item."""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
    template = Template(template_str)
    render_kwargs = {
        'fewshot_examples': fewshot_examples or [],
        'continuation_delimiter': continuation_delimiter,
        'item': item,
    }
    # One rendered prompt per context option; the continuation is shared.
    return [
        template.render(context=option, **render_kwargs)
        for option in item['context_options']
    ]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render complete prompt for a language modeling task.
    Notice that we manually trim the context in the template,
    which in some datasets seems to have trailing whitespace (which we don't want).
    Returns [prompt_without_continuation, prompt_with_continuation].
    """
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
    template = Template(template_str)
    render_kwargs = {
        'fewshot_examples': fewshot_examples or [],
        'continuation_delimiter': continuation_delimiter,
        'item': item,
    }
    prompt_without = template.render(include_continuation=False, **render_kwargs)
    prompt_with = template.render(include_continuation=True, **render_kwargs)
    # Strip the "without" variant so trailing whitespace (which the tokenizer would
    # otherwise absorb into the first continuation token) doesn't break the clean
    # token-space prefix relationship between the two prompts.
    return [prompt_without.strip(), prompt_with]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def find_common_length(token_sequences, direction='left'):
    """
    Return the length of the common prefix ('left') or suffix ('right')
    shared by all token sequences.
    """
    shortest = min(len(seq) for seq in token_sequences)
    if direction == 'left':
        positions = range(shortest)
    elif direction == 'right':
        positions = range(-1, -shortest - 1, -1)
    else:
        # preserve the original dict-lookup failure mode for bad directions
        raise KeyError(direction)
    # Walk positions until sequences disagree; count how many matched so far.
    for matched, pos in enumerate(positions):
        reference = token_sequences[0][pos]
        if any(seq[pos] != reference for seq in token_sequences):
            return matched
    return shortest
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def stack_sequences(tokens, pad_token_id):
    """Right-pad a list of token lists to the longest length and stack into a (B, T) LongTensor."""
    num_rows = len(tokens)
    max_len = max(len(seq) for seq in tokens)
    out = torch.full((num_rows, max_len), pad_token_id, dtype=torch.long)
    for row, seq in enumerate(tokens):
        out[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    return out
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def batch_sequences_mc(tokenizer, prompts):
    """
    Tokenize multiple-choice prompts. The context is shared across options,
    so the continuations start right after the common token prefix.
    Returns (tokens, start_indices, end_indices).
    """
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    prefix_len = find_common_length(tokens, direction='left')
    start_indices = [prefix_len for _ in tokens]
    end_indices = [len(seq) for seq in tokens]
    return tokens, start_indices, end_indices
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def batch_sequences_schema(tokenizer, prompts):
    """
    Tokenize schema prompts. Contexts vary but the continuation is shared,
    so each scored span is the common token suffix.
    Returns (tokens, start_indices, end_indices).
    """
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    suffix_len = find_common_length(tokens, direction='right')
    end_indices = [len(seq) for seq in tokens]
    start_indices = [end - suffix_len for end in end_indices]
    return tokens, start_indices, end_indices
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def batch_sequences_lm(tokenizer, prompts):
    """
    Tokenize the (without, with) LM prompt pair. The continuation is the token
    suffix that the "with" prompt adds on top of the "without" prompt.
    Returns a batch of size 1: ([tokens_with], [start], [end]).
    """
    encoded = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = encoded
    start_idx = len(tokens_without)
    end_idx = len(tokens_with)
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    # Only the "with continuation" sequence is needed downstream (batch size 1).
    return [tokens_with], [start_idx], [end_idx]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@torch.no_grad()
def forward_model(model, input_ids):
    """
    Forward a (B, T) batch of token ids through the model. Returns a (B, T)
    tensor of per-position cross-entropy losses and a (B, T) tensor of argmax
    predictions. The last loss column is nan: there is no autoregressive target
    for the final position.
    """
    batch_size, seq_len = input_ids.shape
    logits = model(input_ids)
    # Autoregressive targets: the inputs shifted left by one. The wrap-around
    # at the end is harmless because that column is overwritten with nan below.
    target_ids = torch.roll(input_ids, shifts=-1, dims=1)
    flat_losses = torch.nn.functional.cross_entropy(
        logits.view(batch_size * seq_len, -1),
        target_ids.view(batch_size * seq_len),
        reduction='none',
    )
    losses = flat_losses.view(batch_size, seq_len)
    losses[:, -1] = float('nan')
    predictions = logits.argmax(dim=-1)
    return losses, predictions
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
    """Evaluate a single example, return True if correct, False otherwise.

    Args:
        idx: index of the example within `data`; also seeds the few-shot sampler
            so few-shot selection is deterministic per example.
        model: callable mapping a (B, T) LongTensor of token ids to logits;
            may expose `max_seq_len` to request truncation.
        tokenizer: tokenizer exposing `__call__` and `get_bos_token_id()`.
        data: sequence of task items (dicts).
        device: device the input batch is moved to.
        task_meta: dict with 'task_type', 'num_fewshot', 'continuation_delimiter'.
    """
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    # Sample few-shot examples (excluding current item).
    # Seeded per-example so results are reproducible across runs and ranks.
    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(1234 + idx)
        available_indices = [i for i in range(len(data)) if i != idx]
        fewshot_indices = rng.sample(available_indices, min(num_fewshot, len(available_indices)))
        fewshot_examples = [data[i] for i in fewshot_indices]

    # Render prompts and batch sequences based on task type
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
    # In these cases, we have to truncate sequences to max length and adjust the indices
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        max_tokens = model.max_seq_len
        new_tokens, new_start_idxs, new_end_idxs = [], [], []
        for t, s, e in zip(tokens, start_idxs, end_idxs):
            if len(t) > max_tokens:
                num_to_crop = len(t) - max_tokens
                new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
                new_start_idxs.append(s - num_to_crop) # shift the indices down
                new_end_idxs.append(e - num_to_crop)
                # The continuation should never be cropped away entirely, only context;
                # these asserts guard that assumption.
                assert s - num_to_crop >= 0, "this should never happen right?"
                assert e - num_to_crop >= 0, "this should never happen right?"
            else:
                new_tokens.append(t) # keep unchanged
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs

    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok (padded positions are never scored)
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)

    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)

    # See if the losses/predictions come out correctly
    if task_type == 'language_modeling':
        # language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predict input_ids[i+1] autoregressively, hence the -1 shift
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with lowest average loss over the scored span
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    return is_correct
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    Evaluate one task across all examples and return mean accuracy.
    When running under torchrun, examples are strided across ranks and the
    per-example results are summed with an all-reduce before averaging.
    """
    if dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
    else:
        rank, world_size = 0, 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # Each rank evaluates every world_size-th example starting at its rank.
    for idx in range(rank, len(data), world_size):
        correct[idx] = float(evaluate_example(idx, model, tokenizer, data, device, task_meta))
    # Merge the disjoint per-rank results so every rank sees the full vector.
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    return correct.mean().item()
|
nanochat/dataloader.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Distributed dataloaders for pretraining.
|
| 3 |
+
|
| 4 |
+
BOS-aligned bestfit:
|
| 5 |
+
- Every row starts with BOS token
|
| 6 |
+
- Documents packed using best-fit algorithm to minimize cropping
|
| 7 |
+
- When no document fits remaining space, crops a document to fill exactly
|
| 8 |
+
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
| 9 |
+
|
| 10 |
+
Compared to the original tokenizing_distributed_data_loader:
|
| 11 |
+
BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
| 12 |
+
there are fewer "confusing" tokens in the train/val batches as every token can
|
| 13 |
+
now attend back to the BOS token and sees the full context of the document.
|
| 14 |
+
|
| 15 |
+
Fallback to the original if you have very limited data AND long documents:
|
| 16 |
+
https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import pyarrow.parquet as pq
|
| 21 |
+
|
| 22 |
+
from nanochat.common import get_dist_info
|
| 23 |
+
from nanochat.dataset import list_parquet_files
|
| 24 |
+
|
| 25 |
+
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
    """
    Infinite iterator over document batches (list of text strings) from parquet files.

    Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
    where text_batch is a list of document strings, indices track position for resumption,
    and epoch counts how many times we've cycled through the dataset (starts at 1).

    Sharding scheme: row groups within each parquet file are strided across ranks
    (rank r reads groups r, r+world_size, ...), so every rank sees disjoint data.
    Resume is "approximate": it restarts at the row group AFTER the saved one,
    so up to one row group per rank may be skipped, never repeated.
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

    warn_on_legacy = ddp_rank == 0 and split == "train" # rank 0 on train split will warn on legacy
    parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
    assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
    # Last parquet file is held out as the validation split.
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]

    resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
    resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
    # .get() for backward compatibility with state dicts saved before epoch tracking existed.
    resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
    first_pass = True
    pq_idx = resume_pq_idx
    epoch = resume_epoch

    while True: # iterate infinitely (multi-epoch)
        # Only the first pass starts at the resume file; later epochs start from 0.
        pq_idx = resume_pq_idx if first_pass else 0
        while pq_idx < len(parquet_paths):
            filepath = parquet_paths[pq_idx]
            pf = pq.ParquetFile(filepath)
            # Start from resume point if resuming on same file, otherwise from DDP rank
            if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
                # Recover this rank's stride position from the saved row-group index...
                base_idx = resume_rg_idx // ddp_world_size
                base_idx += 1 # advance by 1 so we don't repeat data after resuming
                rg_idx = base_idx * ddp_world_size + ddp_rank
                if rg_idx >= pf.num_row_groups:
                    # Resume point was at the end of this file; move on to the next one.
                    pq_idx += 1
                    continue
                resume_rg_idx = None # only do this once
            else:
                rg_idx = ddp_rank
            while rg_idx < pf.num_row_groups:
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                # Chunk the row group into tokenizer-sized batches of document strings.
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
                rg_idx += ddp_world_size
            pq_idx += 1
        first_pass = False
        epoch += 1
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
    tokenizer, B, T, split,
    tokenizer_threads=4, tokenizer_batch_size=128,
    device="cuda", resume_state_dict=None,
    buffer_size=1000
):
    """
    BOS-aligned dataloader with Best-Fit Cropping.

    Reduces token waste compared to simple greedy cropping by searching a buffer
    for documents that fit well, while maintaining 100% utilization (no padding).

    Algorithm for each row:
    1. From buffered docs, pick the LARGEST doc that fits entirely
    2. Repeat until no doc fits
    3. When nothing fits, crop a doc to fill remaining space exactly

    Key properties:
    - Every row starts with BOS
    - 100% utilization (no padding, every token is trained on)
    - Approximately 35% of all tokens are discarded due to cropping

    Yields (inputs, targets, state_dict) where inputs/targets are (B, T) views
    into a persistent on-device buffer (consumers must use a yield before
    advancing, since the buffer is overwritten in place), and state_dict can be
    passed back as resume_state_dict to approximately resume iteration.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    # Each row holds T+1 tokens so inputs (row[:-1]) and targets (row[1:]) overlap by T.
    row_capacity = T + 1
    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    pq_idx, rg_idx, epoch = 0, 0, 1

    def refill_buffer():
        # Pull the next batch of documents, tokenize them (each prepended with BOS),
        # and append to the candidate pool. Also records the dataset position for state_dict.
        nonlocal pq_idx, rg_idx, epoch
        doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
        for tokens in token_lists:
            doc_buffer.append(tokens)

    # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
    # This gives us contiguous views and a single HtoD transfer
    use_cuda = device == "cuda"
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
    cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                # Ensure buffer has documents (a full buffer gives best-fit more candidates)
                while len(doc_buffer) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    doc_len = len(doc)
                    row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
                    pos += doc_len
                else:
                    # No doc fits - crop shortest in buffer to fill remaining and minimize waste
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        # Copy to pinned CPU buffer, then single HtoD transfer
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])

        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}

        # Single HtoD copy into persistent GPU buffer and yield
        # (non_blocking is safe here because cpu_buffer is pinned when use_cuda)
        gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
        yield inputs, targets, state_dict
|
| 162 |
+
|
| 163 |
+
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
    """Same as the stateful best-fit loader, but yields only (inputs, targets)."""
    loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs)
    for inputs, targets, _ in loader:
        yield inputs, targets
|
nanochat/dataset.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The base/pretraining dataset is a set of parquet files.
|
| 3 |
+
This file contains utilities for:
|
| 4 |
+
- iterating over the parquet files and yielding documents from it
|
| 5 |
+
- download the files on demand if they are not on disk
|
| 6 |
+
|
| 7 |
+
For details of how the dataset was prepared, see `repackage_data_reference.py`.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import time
|
| 13 |
+
import requests
|
| 14 |
+
import pyarrow.parquet as pq
|
| 15 |
+
from multiprocessing import Pool
|
| 16 |
+
|
| 17 |
+
from nanochat.common import get_base_dir
|
| 18 |
+
|
| 19 |
+
# -----------------------------------------------------------------------------
|
| 20 |
+
# The specifics of the current pretraining dataset
|
| 21 |
+
|
| 22 |
+
# The URL on the internet where the data is hosted and downloaded from on demand
|
| 23 |
+
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
|
| 24 |
+
MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
|
| 25 |
+
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
|
| 26 |
+
base_dir = get_base_dir()
|
| 27 |
+
DATA_DIR = os.path.join(base_dir, "base_data_climbmix")
|
| 28 |
+
|
| 29 |
+
# -----------------------------------------------------------------------------
|
| 30 |
+
# These functions are useful utilities to other modules, can/should be imported
|
| 31 |
+
|
| 32 |
+
def list_parquet_files(data_dir=None, warn_on_legacy=False):
    """ Looks into a data dir and returns full paths to all parquet files.

    Args:
        data_dir: directory to scan; defaults to the module-level DATA_DIR.
        warn_on_legacy: if True, print the dataset-upgrade warning when the
            ClimbMix directory is missing (callers set this on rank 0 only
            to avoid duplicated output under DDP).

    Returns:
        Sorted list of absolute paths to .parquet files (sorted order is what
        makes the last file a stable validation split for callers).
    """
    data_dir = DATA_DIR if data_dir is None else data_dir

    # Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B
    # This code will eventually be deleted.
    if not os.path.exists(data_dir):
        if warn_on_legacy:
            print()
            print("=" * 80)
            print(" WARNING: DATASET UPGRADE REQUIRED")
            print("=" * 80)
            print()
            print(f" Could not find: {data_dir}")
            print()
            print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.")
            print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.")
            print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:")
            print()
            print(" python -m nanochat.dataset -n 170 # download ~170 shards, enough for GPT-2, adjust as desired")
            print(" python -m scripts.tok_train # re-train tokenizer on new ClimbMix data")
            print()
            print(" For now, falling back to your old FinewebEdu-100B dataset...")
            print("=" * 80)
            print()
        # attempt a fallback to the legacy data directory
        data_dir = os.path.join(base_dir, "base_data")

    # NOTE(review): the .tmp clause is redundant — a name ending in '.parquet'
    # can't also end in '.tmp' (in-progress downloads are named '*.parquet.tmp'
    # and already fail the first check). Kept for safety/clarity.
    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        if f.endswith('.parquet') and not f.endswith('.tmp')
    ])
    parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
    return parquet_paths
|
| 66 |
+
|
| 67 |
+
def parquets_iter_batched(split, start=0, step=1):
    """
    Iterate through the dataset, yielding lists of document texts one parquet
    row group at a time (row groups are the efficient read unit).
    - split: "train" uses all but the last parquet file, "val" only the last.
    - start/step shard row groups across workers, e.g. start=rank, step=world_size.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    paths = list_parquet_files()
    paths = paths[:-1] if split == "train" else paths[-1:]
    for path in paths:
        parquet_file = pq.ParquetFile(path)
        for group_idx in range(start, parquet_file.num_row_groups, step):
            group = parquet_file.read_row_group(group_idx)
            yield group.column('text').to_pylist()
|
| 82 |
+
|
| 83 |
+
# -----------------------------------------------------------------------------
|
| 84 |
+
def download_single_file(index):
    """
    Download shard `index` into DATA_DIR, retrying with exponential backoff.
    Returns True on success (or if the shard already exists locally), False otherwise.
    """
    # Construct the local filepath for this file and skip if it already exists
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True

    # Construct the remote URL for this file.
    # Fix: interpolate the actual shard filename; previously these f-strings
    # contained a literal placeholder instead of {filename}, so the URL was wrong.
    url = f"{BASE_URL}/{filename}"
    print(f"Downloading {filename}...")

    # Download with retries
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Write to a temporary file first so a partial download never looks complete
            temp_path = filepath + ".tmp"
            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
                    if chunk:
                        f.write(chunk)
            # Move temp file to final location (atomic publish on POSIX)
            os.rename(temp_path, filepath)
            print(f"Successfully downloaded {filename}")
            return True

        except (requests.RequestException, IOError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial files
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass  # best-effort cleanup; a retry will overwrite anyway
            # Try a few times with exponential backoff: 2^attempt seconds
            if attempt < max_attempts:
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to download {filename} after {max_attempts} attempts")
                return False

    return False
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
    # CLI entry point: download N train shards plus the pinned validation shard.
    parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1), -1 = disable")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()

    # Prepare the output directory
    os.makedirs(DATA_DIR, exist_ok=True)

    # The way this works is that the user specifies the number of train shards to download via the -n flag.
    # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
    num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
    ids_to_download = list(range(num_train_shards))
    ids_to_download.append(MAX_SHARD) # always download the validation shard

    # Download the shards across a pool of worker processes
    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # Report results (download_single_file returns True on success / already-present)
    successful = sum(1 for success in results if success)
    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
|
nanochat/engine.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engine for efficient inference of our models.
|
| 3 |
+
|
| 4 |
+
Everything works around token sequences:
|
| 5 |
+
- The user can send token sequences to the engine
|
| 6 |
+
- The engine returns the next token
|
| 7 |
+
|
| 8 |
+
Notes:
|
| 9 |
+
- The engine knows nothing about tokenization, it's purely token id sequences.
|
| 10 |
+
|
| 11 |
+
The whole thing is made as efficient as possible.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import signal
|
| 17 |
+
import warnings
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from collections import deque
|
| 20 |
+
from nanochat.common import compute_init, autodetect_device_type
|
| 21 |
+
from nanochat.checkpoint_manager import load_model
|
| 22 |
+
|
| 23 |
+
# -----------------------------------------------------------------------------
|
| 24 |
+
# Calculator tool helpers
|
| 25 |
+
@contextmanager
def timeout(duration, formula):
    """
    Raise inside the with-block if it runs longer than `duration` seconds.
    Unix-only: relies on SIGALRM, so it must run on the main thread.
    `formula` is only used to build a descriptive error message.
    """
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # Always disarm the alarm, even when the body raises. Previously an
        # exception skipped the disarm, leaving a pending SIGALRM that could
        # fire later and kill an unrelated code path.
        signal.alarm(0)
|
| 34 |
+
|
| 35 |
+
def eval_with_timeout(formula, max_time=3):
    """Evaluate `formula` with empty builtins; return None on any error or after `max_time` seconds."""
    try:
        with timeout(max_time, formula), warnings.catch_warnings():
            warnings.simplefilter("ignore", SyntaxWarning)
            return eval(formula, {"__builtins__": {}}, {})
    except Exception:
        # Make sure no alarm is left pending after a failed evaluation
        signal.alarm(0)
        # Wrong calculator usage from the model is expected and fine to ignore
        return None
|
| 45 |
+
|
| 46 |
+
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas (thousands separators) from numbers
    expr = expr.replace(",", "")

    # Fast path: pure arithmetic expressions
    math_chars = set("0123456789*+-/.() ")
    if set(expr) <= math_chars:
        # Disallow the power operator (can be abused to burn CPU/memory)
        if "**" in expr:
            return None
        return eval_with_timeout(expr)

    # String-operation path: strings (single/double quotes), letters, digits,
    # spaces, parens, dots — nothing else
    permitted = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ ")
    if not set(expr) <= permitted:
        return None

    # Reject expressions containing dangerous substrings
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    lowered = expr.lower()
    for needle in dangerous_patterns:
        if needle in lowered:
            return None

    # Only the .count() method is whitelisted for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)
|
| 80 |
+
|
| 81 |
+
# -----------------------------------------------------------------------------
|
| 82 |
+
class KVCache:
    """
    Key/value cache laid out for Flash Attention 3's flash_attn_with_kvcache API.

    Compared to an FA2-style cache:
    - tensors are stored as (B, T, H, D), not (B, H, T, D)
    - FA3 writes new keys/values into the cache in-place during attention
    - the current write position is tracked per batch row in `cache_seqlens`
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # One pre-allocated buffer per side, indexed by layer: (L, B, T, H, D)
        cache_shape = (num_layers, batch_size, seq_len, num_heads, head_dim)
        self.k_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
        self.v_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
        # Per-row sequence length; FA3 requires int32
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    def reset(self):
        """Empty the cache by rewinding every row's position to zero."""
        self.cache_seqlens.zero_()

    def get_pos(self):
        """Current position (assumes every batch row sits at the same offset)."""
        return self.cache_seqlens[0].item()

    def get_layer_cache(self, layer_idx):
        """Views of the (k, v) buffers belonging to one transformer layer."""
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def advance(self, num_tokens):
        """Move every row's write position forward by `num_tokens`."""
        self.cache_seqlens += num_tokens

    def prefill(self, other):
        """
        Seed this (empty) cache with the contents of `other`.
        Used when prefilling once at batch size 1 and then fanning out to
        generate multiple samples in parallel (broadcast over the batch dim).
        """
        assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
        assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
        assert self.max_seq_len >= other.max_seq_len
        filled = other.get_pos()
        self.k_cache[:, :, :filled, :, :] = other.k_cache[:, :, :filled, :, :]
        self.v_cache[:, :, :filled, :, :] = other.v_cache[:, :, :filled, :, :]
        self.cache_seqlens.fill_(filled)
|
| 132 |
+
|
| 133 |
+
# -----------------------------------------------------------------------------
|
| 134 |
+
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample one next token per row from (B, vocab_size) logits. Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    # Greedy decoding
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is None or top_k <= 0:
        # Full-vocabulary sampling
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
    # Top-k sampling: restrict to the k highest-scoring tokens, renormalize, sample
    k = min(top_k, logits.size(-1))
    top_vals, top_idx = torch.topk(logits, k, dim=-1)
    probs = F.softmax(top_vals / temperature, dim=-1)
    choice = torch.multinomial(probs, num_samples=1, generator=rng)
    return top_idx.gather(1, choice)
|
| 151 |
+
|
| 152 |
+
# -----------------------------------------------------------------------------
|
| 153 |
+
|
| 154 |
+
class RowState:
    """Per-row decoding state tracked while generating a batch of samples."""

    def __init__(self, current_tokens=None):
        # Tokens accumulated so far for this row (prompt + generated)
        self.current_tokens = current_tokens if current_tokens else []
        # Queue of tokens waiting to be force-injected (e.g. calculator output)
        self.forced_tokens = deque()
        # True while we are collecting tokens inside a python tool block
        self.in_python_block = False
        # Tokens of the python expression currently being collected
        self.python_expr_tokens = []
        # True once this row has emitted a terminal token
        self.completed = False
|
| 162 |
+
|
| 163 |
+
class Engine:
    """
    Batched inference driver over a (model, tokenizer) pair.

    Prefills the prompt once, replicates the KV cache across sample rows, and
    decodes all rows in lockstep. Special tokens drive a small calculator
    tool-use state machine (<|python_start|> ... <|python_end|>).
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer # needed for tool use: decoding python expressions, encoding results

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """
        Streaming generator. Does a single batch-1 prefill of the prompt, clones
        the KV cache into `num_samples` rows, then yields one
        (token_column, token_masks) pair per decoding step. mask=1 means the
        row's token was sampled, mask=0 means it was force-injected
        (e.g. calculator tool output).
        """
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        # NOTE: setting the dtype here and in this way is an ugly hack.
        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
        # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
        # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
        # In particular, the KVCache should allocate its tensors lazily
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        # Dedicated RNG so sampling is reproducible per call
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize states for each sample; each row starts from a copy of the prompt
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop
        num_generated = 0
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed
            if all(state.completed for state in row_states):
                break

            # Sample the next token for each row
            next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
            sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = [] # contains the next token id along each row
            token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row; queued forced tokens take
                # precedence over the freshly sampled one
                is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
                token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic: collect tokens between python_start/python_end,
                # evaluate them with the calculator, and queue the result for injection
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column (completed rows still emit; callers filter
            # via the terminal tokens)
            yield token_column, token_masks
            num_generated += 1

            # Prepare logits for next iteration: one decode step per row
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation.
        Returns (results, masks): `results` is a list of final token sequences
        (prompt + generated, list of lists of ints); `masks` marks each token as
        forced (0) or sampled (1), with the prompt counted as forced.
        Terminal tokens (assistant_end, bos) are not included in the results.
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # Stop if all rows are completed
            if all(completed):
                break
        return results, masks
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    import time
    # init compute
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    # common hyperparameters; temperature=0.0 -> greedy decoding, so both paths
    # must produce token-for-token identical output
    kwargs = dict(max_tokens=64, temperature=0.0)
    # set the starting prompt
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # generate the reference sequence using the model.generate() function
    generated_tokens = []
    # NOTE(review): torch.cuda.synchronize() is called unconditionally below —
    # this smoke test assumes a CUDA device; confirm before running CPU/MPS-only.
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    for token in stream:
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    # generate tokens with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
    torch.cuda.synchronize()
    t0 = time.time()
    for token_column, token_masks in stream:
        token = token_column[0] # only print out the first row
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # compare the two sequences token by token and report the first mismatch
    for i in range(len(reference_ids)):
        if reference_ids[i] != generated_tokens[i]:
            print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
            break
    print(f"Match: {reference_ids == generated_tokens}")
|
nanochat/execution.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sandboxed execution utilities for running Python code that comes out of an LLM.
|
| 3 |
+
Adapted from OpenAI HumanEval code:
|
| 4 |
+
https://github.com/openai/human-eval/blob/master/human_eval/execution.py
|
| 5 |
+
|
| 6 |
+
What is covered:
|
| 7 |
+
- Each execution runs in its own process (can be killed if it hangs or crashes)
|
| 8 |
+
- Execution is limited by a timeout to stop infinite loops
|
| 9 |
+
- Memory limits are enforced by default (256MB)
|
| 10 |
+
- stdout and stderr are captured and returned
|
| 11 |
+
- Code runs in a temporary directory that is deleted afterwards
|
| 12 |
+
- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
|
| 13 |
+
|
| 14 |
+
What is not covered:
|
| 15 |
+
- Not a true security sandbox
|
| 16 |
+
- Network access is not blocked (e.g. sockets could be opened)
|
| 17 |
+
- Python's dynamic features (e.g. ctypes) could bypass restrictions
|
| 18 |
+
- No kernel-level isolation (no seccomp, no containers, no virtualization)
|
| 19 |
+
|
| 20 |
+
Overall this sandbox is good for evaluation of generated code and protects against
|
| 21 |
+
accidental destructive behavior, but it is not safe against malicious adversarial code.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import contextlib
|
| 25 |
+
import faulthandler
|
| 26 |
+
import io
|
| 27 |
+
import multiprocessing
|
| 28 |
+
import os
|
| 29 |
+
import platform
|
| 30 |
+
import signal
|
| 31 |
+
import tempfile
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from typing import Optional
|
| 34 |
+
|
| 35 |
+
# -----------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
@dataclass
class ExecutionResult:
    """Outcome of one sandboxed execution of model-generated Python code."""
    success: bool                    # ran to completion without error
    stdout: str                      # captured standard output
    stderr: str                      # captured standard error
    error: Optional[str] = None      # exception description, if any
    timeout: bool = False            # killed for exceeding the time limit
    memory_exceeded: bool = False    # killed for exceeding the memory limit

    def __repr__(self):
        # Compact repr that only shows the fields that carry information
        pieces = [f"ExecutionResult(success={self.success}"]
        if self.timeout:
            pieces.append(", timeout=True")
        if self.memory_exceeded:
            pieces.append(", memory_exceeded=True")
        if self.error:
            pieces.append(f", error={self.error!r}")
        if self.stdout:
            pieces.append(f", stdout={self.stdout!r}")
        if self.stderr:
            pieces.append(f", stderr={self.stderr!r}")
        return "".join(pieces) + ")"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@contextlib.contextmanager
def time_limit(seconds: float):
    """
    Limit the wall-clock time of the enclosed block to `seconds`, raising
    TimeoutException when exceeded. Unix-only (SIGALRM), main thread only.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler *before* arming the timer. The previous order left a
    # window in which the alarm could fire with the default SIGALRM action and
    # terminate the whole process instead of raising TimeoutException.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Always disarm the timer, even if the block raised
        signal.setitimer(signal.ITIMER_REAL, 0)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@contextlib.contextmanager
def capture_io():
    """Capture stdout/stderr into StringIO buffers; replace stdin with a write-only stub."""
    out_buf = io.StringIO()
    err_buf = io.StringIO()
    no_stdin = WriteOnlyStringIO()
    with contextlib.redirect_stdout(out_buf), \
         contextlib.redirect_stderr(err_buf), \
         redirect_stdin(no_stdin):
        yield out_buf, err_buf
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@contextlib.contextmanager
def create_tempdir():
    """Yield a fresh temporary directory, cd'd into for the duration, removed on exit."""
    with tempfile.TemporaryDirectory() as tmp, chdir(tmp):
        yield tmp
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TimeoutException(Exception):
    """Raised when a sandboxed execution exceeds its time limit."""
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class WriteOnlyStringIO(io.StringIO):
    """In-memory text stream that accepts writes but refuses every read."""

    def _refuse(self):
        # All read paths funnel here so the no-read policy lives in one place
        raise IOError

    def read(self, *args, **kwargs):
        self._refuse()

    def readline(self, *args, **kwargs):
        self._refuse()

    def readlines(self, *args, **kwargs):
        self._refuse()

    def readable(self, *args, **kwargs):
        """This stream is never readable."""
        return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class redirect_stdin(contextlib._RedirectStream): # type: ignore
    """Context manager that temporarily replaces sys.stdin (mirrors contextlib.redirect_stdout)."""
    _stream = "stdin"
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@contextlib.contextmanager
def chdir(root):
    """Temporarily change the working directory to `root`, restoring it on exit."""
    if root == ".":
        # Already there: nothing to change or restore
        yield
        return
    previous = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        os.chdir(previous)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    if platform.uname().system != "Darwin":
        # These resource limit calls seem to fail on macOS (Darwin), skip?
        import resource
        # Cap address space, data segment, and stack at the same byte budget.
        # NOTE(review): if maximum_memory_bytes is None, setrlimit receives
        # (None, None) and likely raises TypeError — confirm callers always
        # pass an explicit limit.
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    # Stop sandboxed code from re-enabling crash tracebacks on signals
    faulthandler.disable()

    import builtins

    # Prevent sandboxed code from exiting the worker process
    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    # Null out os-level functions that can kill processes or mutate the filesystem
    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.fchdir = None # (duplicate of the assignment above; kept as-is)
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    # Block recursive deletes, moves, and ownership changes
    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    # Block spawning subprocesses
    subprocess.Popen = None  # type: ignore

    # NOTE(review): __builtins__ is the builtins *module* when this file runs as
    # __main__, but its __dict__ (a dict) when imported — this subscript assumes
    # the imported case; confirm this module is never executed as a script.
    __builtins__["help"] = None

    import sys

    # Blank out modules that could inspect or control the host process
    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None
|
| 213 |
+
|
| 214 |
+
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
    """Execute `code` in this (child) process with safety guards.

    Runs inside a subprocess spawned by execute_code(). Outcomes are written
    into `result_dict` (a multiprocessing.Manager dict shared with the parent)
    under the keys: success, stdout, stderr, timeout, memory_exceeded, error.
    """
    with create_tempdir():

        # These system calls are needed when cleaning up tempdir, so stash
        # references before reliability_guard() disables them below.
        import os
        import shutil

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        unlink = os.unlink

        # Disable functionalities that can make destructive changes to the test.
        reliability_guard(maximum_memory_bytes=maximum_memory_bytes)

        # Default to failure; overwritten below on success or a specific error.
        result_dict.update({
            "success": False,
            "stdout": "",
            "stderr": "",
            "timeout": False,
            "memory_exceeded": False,
            "error": None,
        })

        try:
            exec_globals = {}
            with capture_io() as (stdout_capture, stderr_capture):
                with time_limit(timeout):
                    # WARNING
                    # This executes untrusted (model-generated) code. Even with
                    # reliability_guard() in place this is NOT a security
                    # sandbox (see the warning in reliability_guard's docstring
                    # and the Codex paper) — run under an external sandbox for
                    # any real deployment.
                    exec(code, exec_globals)

            result_dict.update({
                "success": True,
                "stdout": stdout_capture.getvalue(),
                "stderr": stderr_capture.getvalue(),
            })

        except TimeoutException:
            result_dict.update({
                "timeout": True,
                "error": "Execution timed out",
            })

        except MemoryError as e:
            result_dict.update({
                "memory_exceeded": True,
                "error": f"Memory limit exceeded: {e}",
            })

        except BaseException as e:
            # BaseException (not just Exception) so SystemExit/KeyboardInterrupt
            # raised by the executed code are reported rather than killing the
            # worker silently.
            result_dict.update({
                "error": f"{type(e).__name__}: {e}",
            })

        # Restore the saved functions so create_tempdir() can clean up.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir
        os.unlink = unlink
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def execute_code(
    code: str,
    timeout: float = 5.0,  # 5 seconds default
    maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024,  # 256MB default
) -> ExecutionResult:
    """
    Execute Python code in a sandboxed environment.

    Args:
        code: Python code to execute as a string
        timeout: Maximum execution time in seconds (default: 5.0)
        maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)

    Returns:
        ExecutionResult with success status, stdout/stderr, and error information

    Example:
        >>> result = execute_code("print('hello world')")
        >>> result.success
        True
        >>> result.stdout
        'hello world\\n'
    """
    # A Manager-backed dict lets the child report results across the process boundary.
    manager = multiprocessing.Manager()
    result_dict = manager.dict()

    p = multiprocessing.Process(
        target=_unsafe_execute,
        args=(code, timeout, maximum_memory_bytes, result_dict)
    )
    p.start()
    # Give the child slightly longer than its in-process time_limit so the
    # child's own timeout reporting gets a chance to run first.
    p.join(timeout=timeout + 1)

    if p.is_alive():
        p.kill()
        p.join()  # reap the killed child so it doesn't linger as a zombie
        return ExecutionResult(
            success=False,
            stdout="",
            stderr="",
            error="Execution timed out (process killed)",
            timeout=True,
            memory_exceeded=False,
        )

    if not result_dict:
        # The child exited without writing anything — e.g. it was hard-killed
        # by the OS (OOM killer) or crashed before reporting. Flagged as a
        # timeout for backward compatibility, though the cause may differ.
        return ExecutionResult(
            success=False,
            stdout="",
            stderr="",
            error="Execution failed (no result returned)",
            timeout=True,
            memory_exceeded=False,
        )

    return ExecutionResult(
        success=result_dict["success"],
        stdout=result_dict["stdout"],
        stderr=result_dict["stderr"],
        error=result_dict["error"],
        timeout=result_dict["timeout"],
        memory_exceeded=result_dict["memory_exceeded"],
    )
|
| 349 |
+
|
nanochat/flash_attention.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
| 3 |
+
|
| 4 |
+
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
| 5 |
+
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
| 6 |
+
|
| 7 |
+
Usage (drop-in replacement for FA3):
|
| 8 |
+
from nanochat.flash_attention import flash_attn
|
| 9 |
+
|
| 10 |
+
# Training (no KV cache)
|
| 11 |
+
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
| 12 |
+
|
| 13 |
+
# Inference (with KV cache)
|
| 14 |
+
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
| 15 |
+
"""
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# Detection: Try to load FA3 on Hopper+ GPUs
|
| 22 |
+
# =============================================================================
|
| 23 |
+
def _load_flash_attention_3():
    """Return the FA3 interface, or None when it can't be used here (no CUDA / non-Hopper)."""
    if not torch.cuda.is_available():
        return None
    try:
        compute_major, _minor = torch.cuda.get_device_capability()
    except Exception:
        return None
    # FA3 kernels are compiled for Hopper (sm90) only.
    # Ada (sm89) and Blackwell (sm100) take the SDPA fallback until FA3 is recompiled.
    if compute_major != 9:
        return None
    try:
        import os
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        from kernels import get_kernel
        fa3_pkg = get_kernel('varunneal/flash-attention-3')
        return fa3_pkg.flash_attn_interface
    except Exception:
        return None


_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None

# Override for testing: set to 'fa3', 'sdpa', or None (auto)
_override_impl = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _resolve_use_fa3():
    """Decide once, at import time, whether the FA3 kernels or SDPA will be used."""
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    if _override_impl == 'sdpa':
        return False
    if not HAS_FA3:
        return False
    # FA3's Hopper kernels support only bf16 and fp8; fp16/fp32 must take
    # the SDPA fallback instead.
    from nanochat.common import COMPUTE_DTYPE
    return COMPUTE_DTYPE == torch.bfloat16


USE_FA3 = _resolve_use_fa3()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# =============================================================================
|
| 67 |
+
# SDPA helpers
|
| 68 |
+
# =============================================================================
|
| 69 |
+
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
| 70 |
+
"""
|
| 71 |
+
SDPA attention with sliding window support.
|
| 72 |
+
q, k, v are (B, H, T, D) format.
|
| 73 |
+
"""
|
| 74 |
+
Tq = q.size(2)
|
| 75 |
+
Tk = k.size(2)
|
| 76 |
+
window = window_size[0]
|
| 77 |
+
|
| 78 |
+
# Full context, same length
|
| 79 |
+
if (window < 0 or window >= Tq) and Tq == Tk:
|
| 80 |
+
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
| 81 |
+
|
| 82 |
+
# Single token generation
|
| 83 |
+
if Tq == 1:
|
| 84 |
+
if window >= 0 and window < Tk:
|
| 85 |
+
# window is "left" tokens we need to include (window + 1) keys total
|
| 86 |
+
start = max(0, Tk - (window + 1))
|
| 87 |
+
k = k[:, :, start:, :]
|
| 88 |
+
v = v[:, :, start:, :]
|
| 89 |
+
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
| 90 |
+
|
| 91 |
+
# Need explicit mask for sliding window/chunk inference
|
| 92 |
+
device = q.device
|
| 93 |
+
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
| 94 |
+
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
| 95 |
+
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
| 96 |
+
mask = col_idx <= row_idx
|
| 97 |
+
|
| 98 |
+
# sliding window (left)
|
| 99 |
+
if window >= 0 and window < Tk:
|
| 100 |
+
mask = mask & ((row_idx - col_idx) <= window)
|
| 101 |
+
|
| 102 |
+
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
| 103 |
+
|
| 104 |
+
# =============================================================================
|
| 105 |
+
# Public API: Same interface as FA3
|
| 106 |
+
# =============================================================================
|
| 107 |
+
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: Tensors of shape (B, T, H, D)
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T, H, D)
    """
    if USE_FA3:
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)

    # SDPA fallback works in (B, H, T, D); inputs arrive as (B, T, H, D).
    q_bhtd, k_bhtd, v_bhtd = (t.transpose(1, 2) for t in (q, k, v))
    needs_gqa = q_bhtd.size(1) != k_bhtd.size(1)
    out = _sdpa_attention(q_bhtd, k_bhtd, v_bhtd, window_size, needs_gqa)
    # Back to the (B, T, H, D) layout callers expect.
    return out.transpose(1, 2)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)

    NOTE(review): the SDPA fallback below never consults `causal` —
    _sdpa_attention always applies causal-style masking. Presumably callers
    always pass causal=True during decoding; confirm before relying on
    causal=False on non-Hopper hardware.
    """
    if USE_FA3:
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )

    # SDPA fallback: manually manage KV cache
    B, T_new, H, D = q.shape
    pos = cache_seqlens[0].item()  # assume uniform position across batch

    # Insert new k, v into cache (in-place, matching FA3 behavior)
    if k is not None and v is not None:
        k_cache[:, pos:pos+T_new, :, :] = k
        v_cache[:, pos:pos+T_new, :, :] = v

    # Get full cache up to current position + new tokens
    end_pos = pos + T_new
    k_full = k_cache[:, :end_pos, :, :]
    v_full = v_cache[:, :end_pos, :, :]

    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_full.transpose(1, 2)
    v_sdpa = v_full.transpose(1, 2)

    # GQA is needed whenever there are fewer KV heads than query heads
    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)

    return y_sdpa.transpose(1, 2)  # back to (B, T_new, H, D)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# =============================================================================
|
| 181 |
+
# Export: flash_attn module interface (drop-in replacement for FA3)
|
| 182 |
+
# =============================================================================
|
| 183 |
+
from types import SimpleNamespace

# Expose the same attribute surface as the real flash-attn package, so callers
# can do `from nanochat.flash_attention import flash_attn` and then call
# `flash_attn.flash_attn_func(...)` / `flash_attn.flash_attn_with_kvcache(...)`
# without caring which backend (FA3 or SDPA) is active.
flash_attn = SimpleNamespace(
    flash_attn_func=flash_attn_func,
    flash_attn_with_kvcache=flash_attn_with_kvcache,
)
|
nanochat/fp8.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only.
|
| 2 |
+
|
| 3 |
+
Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
|
| 4 |
+
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
|
| 5 |
+
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
|
| 6 |
+
subclass dispatch tables, etc.)
|
| 7 |
+
|
| 8 |
+
How FP8 training works
|
| 9 |
+
======================
|
| 10 |
+
A standard Linear layer does one matmul in forward and two in backward:
|
| 11 |
+
forward: output = input @ weight.T
|
| 12 |
+
backward: grad_input = grad_output @ weight
|
| 13 |
+
grad_weight= grad_output.T @ input
|
| 14 |
+
|
| 15 |
+
FP8 training wraps each of these three matmuls with:
|
| 16 |
+
1. Compute scale = FP8_MAX / max(|tensor|) for each operand
|
| 17 |
+
2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
|
| 18 |
+
3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
|
| 19 |
+
4. Dequantize: _scaled_mm handles this internally using the inverse scales
|
| 20 |
+
|
| 21 |
+
The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
|
| 22 |
+
torchao is just orchestration around these primitives. We can call them directly.
|
| 23 |
+
|
| 24 |
+
FP8 dtype choice
|
| 25 |
+
================
|
| 26 |
+
There are two FP8 formats. We use both, following the standard convention:
|
| 27 |
+
- float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
|
| 28 |
+
Higher precision (more mantissa bits), used for input and weight.
|
| 29 |
+
- float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
|
| 30 |
+
Wider range (more exponent bits), used for gradients which can be large.
|
| 31 |
+
|
| 32 |
+
torch._scaled_mm layout requirements
|
| 33 |
+
=====================================
|
| 34 |
+
The cuBLAS FP8 kernel requires specific memory layouts:
|
| 35 |
+
- First argument (A): must be row-major (contiguous)
|
| 36 |
+
- Second argument (B): must be column-major (B.t().contiguous().t())
|
| 37 |
+
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
|
| 38 |
+
already column-major — no copy needed. Otherwise we use _to_col_major().
|
| 39 |
+
|
| 40 |
+
How this differs from torchao's approach
|
| 41 |
+
========================================
|
| 42 |
+
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
|
| 43 |
+
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
|
| 44 |
+
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
|
| 45 |
+
reshape, clone, ...) and handles it in FP8-aware fashion. When you call
|
| 46 |
+
output = input @ weight.T
|
| 47 |
+
the @ operator dispatches to aten.mm, which gets intercepted and routed to
|
| 48 |
+
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
|
| 49 |
+
a handler for every tensor operation that might touch an FP8 tensor.
|
| 50 |
+
|
| 51 |
+
We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
|
| 52 |
+
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
|
| 53 |
+
full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
|
| 54 |
+
opaque node rather than trying to trace inside.
|
| 55 |
+
|
| 56 |
+
The trade-off is in how torch.compile sees the two approaches:
|
| 57 |
+
- torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
|
| 58 |
+
sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
|
| 59 |
+
nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
|
| 60 |
+
amax computation with the preceding layer's activation function).
|
| 61 |
+
- ours: compile sees a single opaque call. It can optimize everything around
|
| 62 |
+
the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.
|
| 63 |
+
|
| 64 |
+
Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical.
|
| 65 |
+
The difference is only in the "glue" ops (amax, scale, cast) which are tiny
|
| 66 |
+
compared to the matmul. In practice this means our version is slightly faster
|
| 67 |
+
(less compilation overhead, no tensor subclass dispatch cost) but can produce
|
| 68 |
+
subtly different floating-point rounding paths under torch.compile, since Inductor
|
| 69 |
+
generates a different graph. Numerics are bitwise identical in eager mode.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
import torch
|
| 73 |
+
import torch.nn as nn
|
| 74 |
+
|
| 75 |
+
from nanochat.common import COMPUTE_DTYPE
|
| 76 |
+
|
| 77 |
+
# Avoid division by zero when computing scale from an all-zeros tensor
|
| 78 |
+
EPS = 1e-12
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
|
| 82 |
+
def _to_fp8(x, fp8_dtype):
|
| 83 |
+
"""Dynamically quantize a tensor to FP8 using tensorwise scaling.
|
| 84 |
+
|
| 85 |
+
"Tensorwise" means one scalar scale for the entire tensor (as opposed to
|
| 86 |
+
"rowwise" which computes a separate scale per row). Tensorwise is faster
|
| 87 |
+
because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.
|
| 88 |
+
|
| 89 |
+
Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
|
| 90 |
+
"""
|
| 91 |
+
fp8_max = torch.finfo(fp8_dtype).max
|
| 92 |
+
# Compute the max absolute value across the entire tensor
|
| 93 |
+
amax = x.float().abs().max()
|
| 94 |
+
# Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
|
| 95 |
+
# ensure consistent numerics between torch.compile and eager mode.
|
| 96 |
+
# (torchao does the same upcast — without it, compile/eager can diverge)
|
| 97 |
+
scale = fp8_max / amax.double().clamp(min=EPS)
|
| 98 |
+
scale = scale.float()
|
| 99 |
+
# Quantize: scale into FP8 range, saturate (clamp prevents overflow when
|
| 100 |
+
# casting — PyTorch's default is to wrap, not saturate), then cast to FP8
|
| 101 |
+
x_scaled = x.float() * scale
|
| 102 |
+
x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
|
| 103 |
+
x_fp8 = x_clamped.to(fp8_dtype)
|
| 104 |
+
# _scaled_mm expects the *inverse* of our scale (it multiplies by this to
|
| 105 |
+
# convert FP8 values back to the original range during the matmul)
|
| 106 |
+
inv_scale = scale.reciprocal()
|
| 107 |
+
return x_fp8, inv_scale
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _to_col_major(x):
    """Return a 2D tensor with the same values but column-major memory layout.

    torch._scaled_mm requires its second operand in column-major layout.
    Transposing, forcing a contiguous copy in that transposed order, then
    transposing back yields the same logical [M, N] tensor with strides
    (1, M) instead of (N, 1).
    """
    return x.transpose(0, 1).contiguous().transpose(0, 1)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# allow_in_graph tells torch.compile to treat this as an opaque operation —
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
    """Custom autograd for the three FP8 GEMMs of a Linear layer.

    The forward quantizes input and weight to FP8 and saves
    the quantized tensors + scales for backward.

    Expects a 2D `input_2d` of shape [B, K] and a weight of shape [N, K]
    (Float8Linear.forward flattens batch dimensions before calling apply()).
    """

    @staticmethod
    def forward(ctx, input_2d, weight):
        # Quantize both operands to e4m3 (higher precision format)
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
        # Save the already-quantized tensors: backward reuses this FP8 data
        # instead of re-quantizing the full-precision originals.
        ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)

        # output = input @ weight.T
        # input_fp8 is [B, K] contiguous = row-major (good for first arg)
        # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
        # strides (1, K) = column-major (good for second arg, no copy needed!)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input_2d.dtype,
            # use_fast_accum=True accumulates the dot products in lower precision.
            # Slightly less accurate but measurably faster. Standard practice for
            # the forward pass; we use False in backward for more precise gradients.
            use_fast_accum=True,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors

        # === GEMM 1: grad_input = grad_output @ weight ===
        # Shapes: [B, N] @ [N, K] -> [B, K]
        # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
        go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
        # go_fp8 is [B, N] contiguous = row-major, good for first arg
        # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
        w_col = _to_col_major(w_fp8)
        grad_input = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        # === GEMM 2: grad_weight = grad_output.T @ input ===
        # Shapes: [N, B] @ [B, K] -> [N, K]
        # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
        # Transposing gives column-major, but first arg needs row-major,
        # so we must call .contiguous() to physically rearrange the memory.
        go_T = go_fp8.t().contiguous()  # [N, B] row-major
        in_col = _to_col_major(in_fp8)  # [B, K] column-major
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        # One gradient per forward input, in the same order (input_2d, weight).
        return grad_input, grad_weight
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Float8Linear(nn.Linear):
    """Drop-in nn.Linear replacement that does FP8 compute.

    Weights and biases remain in their original precision (e.g. fp32/bf16).
    Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
    """

    def forward(self, input):
        # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
        # reduced precision input, and we no longer rely on autocast to do this.
        input = input.to(COMPUTE_DTYPE)
        # _scaled_mm only works on 2D tensors, so flatten batch dimensions
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = _Float8Matmul.apply(input_2d, self.weight)
        # Restore the original leading (batch) dimensions.
        output = output.reshape(*orig_shape[:-1], output.shape[-1])
        # Bias (if any) is added outside the FP8 matmul, in the output dtype.
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output

    @classmethod
    def from_float(cls, mod):
        """Create Float8Linear from nn.Linear, sharing the same weight and bias.

        Uses meta device to avoid allocating a temporary weight tensor — we
        create the module shell on meta (shapes/dtypes only, no memory), then
        point .weight and .bias to the original module's parameters. The shell
        is built with bias=False and .bias is then re-pointed at the source
        module's bias (which may be None).
        """
        with torch.device("meta"):
            new_mod = cls(mod.in_features, mod.out_features, bias=False)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class Float8LinearConfig:
    """API-compatible stand-in for torchao's config. Only the tensorwise recipe exists here."""

    @staticmethod
    def from_recipe_name(recipe_name):
        """Build a config for `recipe_name`; anything but 'tensorwise' is rejected."""
        if recipe_name == "tensorwise":
            return Float8LinearConfig()
        raise ValueError(
            f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
            f"Rowwise/axiswise recipes require the full torchao library."
        )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
    """Replace nn.Linear layers with Float8Linear throughout a module.

    Walks the module tree in post-order (children before parents) and swaps
    each nn.Linear that passes the optional filter. The new Float8Linear shares
    the original weight and bias tensors — no copies, no extra memory.

    Args:
        module: Root module to convert.
        config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
        module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
            are converted. Common use: skip layers with dims not divisible by 16
            (hardware requirement for FP8 matmuls on H100).
    """
    def _swap_children(parent, prefix=""):
        # Recurse first so grandchildren are handled before their parents.
        for child_name, child in parent.named_children():
            child_fqn = f"{prefix}.{child_name}" if prefix else child_name
            _swap_children(child, child_fqn)
            # Only plain nn.Linear modules are candidates; already-converted
            # Float8Linear instances are left alone.
            if not isinstance(child, nn.Linear) or isinstance(child, Float8Linear):
                continue
            if module_filter_fn is not None and not module_filter_fn(child, child_fqn):
                continue
            setattr(parent, child_name, Float8Linear.from_float(child))

    _swap_children(module)
    return module
|
nanochat/gpt.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT model (rewrite, a lot simpler)
|
| 3 |
+
Notable features:
|
| 4 |
+
- rotary embeddings (and no positional embeddings)
|
| 5 |
+
- QK norm
|
| 6 |
+
- untied weights for token embedding and lm_head
|
| 7 |
+
- relu^2 activation in MLP
|
| 8 |
+
- norm after token embedding
|
| 9 |
+
- no learnable params in rmsnorm
|
| 10 |
+
- no bias in linear layers
|
| 11 |
+
- Group-Query Attention (GQA) support for more efficient inference
|
| 12 |
+
- Flash Attention 3 integration
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from functools import partial
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
| 23 |
+
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
| 24 |
+
|
| 25 |
+
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
| 26 |
+
from nanochat.flash_attention import flash_attn
|
| 27 |
+
|
| 28 |
+
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT model."""
    sequence_len: int = 2048  # max training sequence length; rotary cache is sized from this
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
    n_embd: int = 768
    # Sliding window attention pattern string, tiled across layers. Final layer always L.
    # Characters: L=long (full context), S=short (~1/3 context, rounded up to the FA3 tile size)
    # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
    window_pattern: str = "SSSL"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def norm(x):
|
| 43 |
+
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
| 44 |
+
|
| 45 |
+
class Linear(nn.Linear):
    """nn.Linear variant that casts its weight to the input's dtype on each call.

    This replaces autocast: the master weight stays fp32 for optimizer precision,
    while the matmul itself executes in the activation dtype (typically bf16
    coming from the embeddings).
    """
    def forward(self, x):
        weight = self.weight.to(dtype=x.dtype)
        return F.linear(x, weight)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def has_ve(layer_idx, n_layer):
    """Return True if the given GPT layer carries a Value Embedding.

    VEs are placed on alternating layers, phased so that the final layer
    (index n_layer - 1) is always one of them.
    """
    last_layer_parity = (n_layer - 1) % 2
    return layer_idx % 2 == last_layer_parity
|
| 56 |
+
|
| 57 |
+
def apply_rotary_emb(x, cos, sin):
    """Apply rotary positional embedding (RoPE) by rotating the two halves of the last dim."""
    assert x.ndim == 4  # expects multihead layout (B, T, H, D)
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat([rotated_first, rotated_second], 3)
|
| 64 |
+
|
| 65 |
+
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with GQA, RoPE, QK norm and optional value embeddings.

    Projections use the file-local fp32-master/activation-dtype Linear. Layout is
    (B, T, H, D) throughout, matching FA3's native layout (no transposes needed).
    """
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx  # needed to address this layer's slice of the KV cache
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0  # GQA: query heads must be a multiple of kv heads
        self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
        # Value-embedding gate reads only the first few channels of the input to stay cheap
        self.ve_gate_channels = 12
        self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        """One attention pass.

        x: (B, T, C) activations; ve: optional value embedding of shape (B, T, n_kv_head*head_dim);
        cos_sin: (cos, sin) rotary tables already sliced to this position range;
        window_size: (left, right) FA-style tuple; kv_cache: None during training.
        """
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
            v = v + gate.unsqueeze(-1) * ve

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k) # QK norm
        q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
        k = k * 1.15

        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
        # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
        if kv_cache is None:
            # Training: causal attention with optional sliding window
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        else:
            # Inference: use flash_attn_with_kvcache which handles cache management
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance position after last layer processes
            # (all layers share one position counter, so only the last layer bumps it)
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)

        # Re-assemble the heads and project back to residual stream
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class MLP(nn.Module):
    """Position-wise feed-forward block: 4x expansion, relu^2 activation, projection back."""
    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = Linear(hidden_dim, config.n_embd, bias=False)

    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = F.relu(hidden).square()  # relu^2 activation
        return self.c_proj(hidden)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Block(nn.Module):
    """One transformer layer: pre-norm attention then pre-norm MLP, each with a residual add."""
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class GPT(nn.Module):
    """Decoder-only transformer with RoPE, GQA, sliding-window attention,
    per-layer residual/x0 scalars and ResFormer-style value embeddings."""

    def __init__(self, config, pad_vocab_size_to=64):
        """
        NOTE a major footgun: this __init__ function runs in meta device context (!!)
        Therefore, any calculations inside here are shapes and dtypes only, no actual data.
        => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
        """
        super().__init__()
        self.config = config
        # Compute per-layer window sizes for sliding window attention
        # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
        self.window_sizes = self._compute_window_sizes(config)
        # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
        # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
        padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
        if padded_vocab_size != config.vocab_size:
            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(padded_vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
        # Per-layer learnable scalars (inspired by modded-nanogpt)
        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
        # Value embeddings (ResFormer-style): alternating layers, last layer always included
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
        # In the future we can dynamically grow the cache, for now it's fine.
        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
        head_dim = config.n_embd // config.n_head  # NOTE(review): redundant recompute, identical to head_dim above
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        """
        Initialize the full model in this one function for maximum clarity.

        wte (embedding): normal, std=0.8
        lm_head: normal, std=0.001
        for each block:
            attn.c_q: uniform, std=1/sqrt(n_embd)
            attn.c_k: uniform, std=1/sqrt(n_embd)
            attn.c_v: uniform, std=1/sqrt(n_embd)
            attn.c_proj: zeros
            mlp.c_fc: uniform, std=0.5/sqrt(n_embd)
            mlp.c_proj: zeros
        """

        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)

        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
            torch.nn.init.zeros_(block.mlp.c_proj.weight)

        # Per-layer scalars
        self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
        self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding

        # Value embeddings (init like c_v: uniform with same std)
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)

        # Gate weights init with small positive values so gates start slightly above neutral
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)

        # Rotary embeddings (real data this time, replacing the meta tensors from __init__)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

        # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
        # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
        # because GradScaler cannot unscale fp16 gradients.
        if COMPUTE_DTYPE != torch.float16:
            self.transformer.wte.to(dtype=COMPUTE_DTYPE)
            for ve in self.value_embeds.values():
                ve.to(dtype=COMPUTE_DTYPE)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
        """Build the (cos, sin) rotary tables of shape (1, seq_len, 1, head_dim/2)."""
        # base theta is already 100K (the more common modern choice, up from the classic 10K)
        # autodetect the device from model embeddings
        if device is None:
            device = self.transformer.wte.weight.device
        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin

    def _compute_window_sizes(self, config):
        """
        Compute per-layer window sizes for sliding window attention.

        Returns list of (left, right) tuples for FA3's window_size parameter:
        - left: how many tokens before current position to attend to (-1 = unlimited)
        - right: how many tokens after current position to attend to (0 for causal)

        Pattern string is tiled across layers. Final layer always gets L (full context).
        Characters: L=long (full context), S=short (~1/3 context, rounded up to FA3 tile size)
        """
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
        # Map characters to window sizes
        long_window = config.sequence_len
        short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
        char_to_window = {
            "L": (long_window, 0),
            "S": (short_window, 0),
        }
        # Tile pattern across layers
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        # Final layer always gets full context
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def get_device(self):
        """Return the device the model lives on (inferred from the token embedding)."""
        return self.transformer.wte.weight.device

    def estimate_flops(self):
        """
        Return the estimated FLOPs per token for the model (forward + backward).
        Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
        Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
        On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
        With sliding windows, effective_seq_len varies per layer (capped by window size).
        Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
        This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
        - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
        - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
        """
        nparams = sum(p.numel() for p in self.parameters())
        # Exclude non-matmul params: embeddings and per-layer scalars
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        # Sum attention FLOPs per layer, accounting for sliding window
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0] # (left, right) tuple, we use left
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
        return num_flops_per_token

    def num_scaling_params(self):
        """
        Return detailed parameter counts for scaling law analysis.
        Different papers use different conventions:
        - Kaplan et al. excluded embedding parameters
        - Chinchilla included all parameters
        Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
        Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)

        Returns a dict with counts for each parameter group, so downstream analysis
        can experiment with which combination gives the cleanest scaling laws.
        """
        # Count each group separately (mirrors the grouping in setup_optimizer)
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
        return {
            'wte': wte,
            'value_embeds': value_embeds,
            'lm_head': lm_head,
            'transformer_matrices': transformer_matrices,
            'scalars': scalars,
            'total': total,
        }

    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
        """Build the combined Muon/AdamW optimizer: AdamW for embeddings/head/scalars, Muon for matrices."""
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()

        # Separate out all parameters into groups
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)

        # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")

        # Build param_groups with all required fields explicit
        param_groups = [
            # AdamW groups (embeddings, lm_head, scalars)
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
        ]
        # Muon groups (matrix params, grouped by shape for stacking)
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
            ))

        Factory = DistMuonAdamW if ddp else MuonAdamW
        optimizer = Factory(param_groups)
        # remember each group's base LR so schedulers can scale relative to it
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
        """Run the model. Returns the loss if targets are given, otherwise the logits.

        idx: (B, T) token ids; targets: optional (B, T) ids with -1 = ignore;
        kv_cache: optional inference-time cache (offsets rotary position).
        """
        B, T = idx.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

        # Forward the trunk of the Transformer
        x = self.transformer.wte(idx) # embed current token
        x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
        x = norm(x)
        x0 = x # save initial normalized embedding for x0 residual
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
        x = norm(x)

        # Forward the lm_head (compute logits)
        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
        logits = logits[..., :self.config.vocab_size] # slice to remove padding
        logits = logits.float() # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap) # squash the logits

        if targets is not None:
            # training: given the targets, compute and return the loss
            # TODO experiment with chunked cross-entropy?
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference: just return the logits directly
            return logits

    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """
        Naive autoregressive streaming inference.
        To make it super simple, let's assume:
        - batch size is 1
        - ids and the yielded tokens are simple Python lists and ints
        Note: re-runs the full prefix each step (no KV cache) — simple but O(T^2).
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
        for _ in range(max_tokens):
            logits = self.forward(ids) # (B, T, vocab_size)
            logits = logits[:, -1, :] # (B, vocab_size)
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf') # mask everything below the k-th logit
            if temperature > 0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True) # greedy decoding at temperature 0
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token
|
nanochat/logo.svg
ADDED
|
|
nanochat/loss_eval.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A number of functions that help with evaluating a base model.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
    """
    Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
    which is a tokenization vocab size-independent metric, meaning you are still comparing
    apples:apples if you change the vocab size. The way this works is that instead of just
    calculating the average loss as usual, you calculate the sum loss, and independently
    also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
    the number of bytes that the target tokens represent.

    The added complexity is so that:
    1) All "normal" tokens are normalized by the length of the token in bytes
    2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
    3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.

    In addition to evaluate_loss, we need the token_bytes tensor:
    It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
    each token id, or 0 if the token is to not be counted (e.g. special tokens).

    Args:
        model: the GPT model; must support loss_reduction='none' and get_device()
        batches: iterable yielding (x, y) id tensors; must yield at least `steps` items
        steps: number of batches to consume
        token_bytes: 1D int tensor mapping token id -> byte length (0 = excluded)
    Returns:
        float bpb (all-reduced across ranks if torch.distributed is initialized),
        or inf if no countable bytes were seen.
    """
    # record the losses
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)
        loss2d = model(x, y, loss_reduction='none') # (B, T)
        loss2d = loss2d.view(-1) # flatten
        y = y.view(-1) # flatten
        if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
            # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
            # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
            valid = y >= 0
            y_safe = torch.where(valid, y, torch.zeros_like(y))
            # map valid targets to their byte length; ignored targets contribute 0 bytes
            num_bytes2d = torch.where(
                valid,
                token_bytes[y_safe],
                torch.zeros_like(y, dtype=token_bytes.dtype)
            )
            # count the loss only where the target has a nonzero byte length
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
        else:
            # fast path: no ignored targets, safe to index directly
            num_bytes2d = token_bytes[y]
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
    # sum reduce across all ranks
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
    # move both to cpu, calculate bpb and return
    total_nats = total_nats.item()
    total_bytes = total_bytes.item()
    if total_bytes == 0:
        # nothing countable was seen (e.g. all targets masked) - avoid division by zero
        return float('inf')
    bpb = total_nats / (math.log(2) * total_bytes)
    return bpb
|
nanochat/optim.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A nice and efficient mixed AdamW/Muon Combined Optimizer.
|
| 3 |
+
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
|
| 4 |
+
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
|
| 5 |
+
|
| 6 |
+
Adapted from: https://github.com/KellerJordan/modded-nanogpt
|
| 7 |
+
Further contributions from @karpathy and @chrisjmccormick.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
|
| 14 |
+
# -----------------------------------------------------------------------------
|
| 15 |
+
"""
|
| 16 |
+
Good old AdamW optimizer, fused kernel.
|
| 17 |
+
https://arxiv.org/abs/1711.05101
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
    p: Tensor,           # (32768, 768) - parameter tensor, updated in place
    grad: Tensor,        # (32768, 768) - gradient, same shape as p (read-only here)
    exp_avg: Tensor,     # (32768, 768) - first moment, same shape as p, updated in place
    exp_avg_sq: Tensor,  # (32768, 768) - second moment, same shape as p, updated in place
    step_t: Tensor,      # () - 0-D CPU tensor, step count (1-based)
    lr_t: Tensor,        # () - 0-D CPU tensor, learning rate
    beta1_t: Tensor,     # () - 0-D CPU tensor, beta1
    beta2_t: Tensor,     # () - 0-D CPU tensor, beta2
    eps_t: Tensor,       # () - 0-D CPU tensor, epsilon
    wd_t: Tensor,        # () - 0-D CPU tensor, weight decay
) -> None:
    """
    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
    All in one compiled graph to eliminate Python overhead between ops.
    The 0-D CPU tensors avoid recompilation when hyperparameter values change
    (torch.compile specializes on Python scalars but not on tensor values).
    Mutates p, exp_avg and exp_avg_sq in place; returns nothing.
    """
    # Weight decay (decoupled, applied before the update) -- AdamW per
    # https://arxiv.org/abs/1711.05101: decay scales the param directly,
    # it is not folded into the gradient.
    p.mul_(1 - lr_t * wd_t)
    # Update running averages (lerp_ is cleaner and fuses well):
    # lerp_(x, w) computes buf = buf*(1-w) + x*w, i.e. the EMA update.
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections (standard Adam: EMAs start at zero, so early steps
    # are scaled up by 1/(1 - beta^t))
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    # Compute update and apply
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
|
| 50 |
+
|
| 51 |
+
# -----------------------------------------------------------------------------
|
| 52 |
+
"""
|
| 53 |
+
Muon optimizer adapted and simplified from modded-nanogpt.
|
| 54 |
+
https://github.com/KellerJordan/modded-nanogpt
|
| 55 |
+
|
| 56 |
+
Background:
|
| 57 |
+
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
| 58 |
+
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
| 59 |
+
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
| 60 |
+
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
| 61 |
+
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
| 62 |
+
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
| 63 |
+
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 64 |
+
|
| 65 |
+
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
| 66 |
+
Polar Express Sign Method for orthogonalization.
|
| 67 |
+
https://arxiv.org/pdf/2505.16932
|
| 68 |
+
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
| 69 |
+
|
| 70 |
+
NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
|
| 71 |
+
update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
|
| 72 |
+
https://arxiv.org/pdf/2510.05491
|
| 73 |
+
|
| 74 |
+
Some of the changes in nanochat implementation:
|
| 75 |
+
- Uses a simpler, more general approach to parameter grouping and stacking
|
| 76 |
+
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
| 77 |
+
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
# Each tuple (a, b, c) parameterizes one quintic iteration of the form
# X <- a*X + b*(A)X + c*(A^2)X with A = X X^T (or X^T X for tall matrices).
# The schedule is consumed as polar_express_coeffs[:ns_steps], so the list
# length (5) is the maximum supported number of iterations.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
|
| 89 |
+
|
| 90 |
+
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
    stacked_grads: Tensor,           # (12, 768, 3072) - stacked gradients (clobbered by the Nesterov lerp_)
    stacked_params: Tensor,          # (12, 768, 3072) - stacked parameters, updated in place
    momentum_buffer: Tensor,         # (12, 768, 3072) - first moment buffer, updated in place
    second_momentum_buffer: Tensor,  # (12, 768, 1) or (12, 1, 3072) - factored second moment, updated in place
    momentum_t: Tensor,              # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,                    # () - 0-D CPU tensor, learning rate
    wd_t: Tensor,                    # () - 0-D CPU tensor, weight decay
    beta2_t: Tensor,                 # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,                   # 5 - number of Newton-Schulz/Polar Express iterations
    red_dim: int,                    # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
    All in one compiled graph to eliminate Python overhead between ops.
    Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
    NOTE: mutates stacked_grads, stacked_params, momentum_buffer and
    second_momentum_buffer in place; returns nothing.
    """

    # Nesterov momentum: update the buffer, then blend the (in-place mutated)
    # grads toward the buffer to get the look-ahead gradient g.
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)

    # Polar express orthogonalization (approximate zeroth power of g).
    # Runs in bfloat16; the pre-scale by norm*1.01 keeps the spectral norm
    # inside the iteration's convergence region.
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
    if g.size(-2) > g.size(-1):  # Tall matrix: multiply on the right so A is the smaller Gram matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:  # Wide matrix (original math)
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X

    # Variance reduction (NorMuon, https://arxiv.org/pdf/2510.05491):
    # per-row/column adaptive scale from a factored second moment, then a
    # global rescale so the update's overall norm is preserved.
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    # v_norm_sq = total sum of squares of g (mean * count, summed over the matrix)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so ||g_new|| == ||g_old|| after the per-neuron normalization
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)

    # Cautious weight decay + parameter update: decay is applied only where
    # the update and the parameter agree in sign (mask), so decay never
    # fights the optimizer's direction.
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
| 147 |
+
|
| 148 |
+
# -----------------------------------------------------------------------------
|
| 149 |
+
# Single GPU version of the MuonAdamW optimizer.
|
| 150 |
+
# Used mostly for reference, debugging and testing.
|
| 151 |
+
|
| 152 |
+
class MuonAdamW(torch.optim.Optimizer):
    """
    Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.

    AdamW - Fused AdamW optimizer step.

    Muon - MomentUm Orthogonalized by Newton-schulz
    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        param_groups: List of dicts, each containing:
            - 'params': List of parameters
            - 'kind': 'adamw' or 'muon'
            - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
            - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
    """
    def __init__(self, param_groups: list[dict]):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change.
        # These are shared across all groups and refilled right before each
        # fused-kernel call, so their values are only valid during that call.
        # AdamW tensors
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        # Muon tensors
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group: dict) -> None:
        """
        AdamW update for each param in the group individually.
        Lazy init the state, fill in all 0-D tensors, call the fused kernel.
        """
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]

            # State init (lazy: first step this param is seen with a grad)
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            exp_avg = state['exp_avg']
            exp_avg_sq = state['exp_avg_sq']
            state['step'] += 1

            # Fill 0-D tensors with current values
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])

            # Fused update: weight_decay -> momentum -> bias_correction -> param_update
            adamw_step_fused(
                p, grad, exp_avg, exp_avg_sq,
                self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
            )

    def _step_muon(self, group: dict) -> None:
        """
        Muon update for all params in the group (stacked for efficiency).
        Lazy init the state, fill in all 0-D tensors, call the fused kernel.
        NOTE: assumes every param in the group has the same shape, and that
        all params have grads (no None check, unlike the AdamW path).
        """
        params: list[Tensor] = group['params']
        if not params:
            return

        # Get or create group-level buffers (stored in first param's state for convenience)
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype

        # Momentum for every individual parameter
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        momentum_buffer = state["momentum_buffer"]

        # Second momentum buffer is factored, either per-row or per-column
        # (reduce over the longer dimension, keep state for the shorter one)
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        second_momentum_buffer = state["second_momentum_buffer"]
        red_dim = -1 if shape[-2] >= shape[-1] else -2

        # Stack grads and params (NOTE: this assumes all params have the same shape)
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)

        # Fill all the 0-D tensors with current values
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # LR is scaled up for tall matrices, following the Muon convention
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])

        # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
        muon_step_fused(
            stacked_grads,
            stacked_params,
            momentum_buffer,
            second_momentum_buffer,
            self._muon_momentum_t,
            self._muon_lr_t,
            self._muon_wd_t,
            self._muon_beta2_t,
            group["ns_steps"],
            red_dim,
        )

        # Copy back to original params (the kernel updated the stacked copy, not p)
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        """Run one optimization step over all param groups, dispatching on group['kind']."""
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
| 292 |
+
|
| 293 |
+
# -----------------------------------------------------------------------------
|
| 294 |
+
# Distributed version of the MuonAdamW optimizer.
|
| 295 |
+
# Used for training on multiple GPUs.
|
| 296 |
+
|
| 297 |
+
class DistMuonAdamW(torch.optim.Optimizer):
    """
    Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.

    See MuonAdamW for the algorithmic details of each optimizer. This class adds
    distributed communication to enable multi-GPU training without PyTorch DDP.

    Design Goals:
    - Overlap communication with computation (async ops)
    - Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
    - Batch small tensors into single comm ops where possible

    Communication Pattern (3-phase async):
    We use a 3-phase structure to maximize overlap between communication and compute:

    Phase 1: Launch all async reduce ops
      - Kick off all reduce_scatter/all_reduce operations
      - Don't wait - let them run in background while we continue

    Phase 2: Wait for reduces, compute updates, launch gathers
      - For each group: wait for its reduce, compute the update, launch gather
      - By processing groups in order, earlier gathers run while later computes happen

    Phase 3: Wait for gathers, copy back
      - Wait for all gathers to complete
      - Copy updated params back to original tensors (Muon only)

    AdamW Communication (ZeRO-2 style):
    - Small params (<1024 elements): all_reduce gradients, update full param on each rank.
      Optimizer state is replicated but these params are tiny (scalars, biases).
    - Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
      only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
      exp_avg_sq) is sharded - each rank only stores state for its slice.
      Requires param.shape[0] divisible by world_size.

    Muon Communication (stacked + chunked):
    - All params in a Muon group must have the same shape (caller's responsibility).
    - Stack all K params into a single (K, *shape) tensor for efficient comm.
    - Divide K params across N ranks: each rank "owns" ceil(K/N) params.
    - reduce_scatter the stacked grads so each rank gets its chunk.
    - Each rank computes Muon update only for params it owns.
    - all_gather the updated params back to all ranks.
    - Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
    - Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
      then ignore the padding when copying back.

    Buffer Reuse:
    - For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
      same buffer as the output for all_gather (stacked_params). This saves memory
      since we don't need both buffers simultaneously.

    Arguments:
        param_groups: List of dicts, each containing:
            - 'params': List of parameters
            - 'kind': 'adamw' or 'muon'
            - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
            - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
    """
    def __init__(self, param_groups: list[dict]):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change.
        # Refilled immediately before each fused-kernel call.
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _reduce_adamw(self, group: dict, world_size: int) -> dict:
        """Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
        param_infos = {}
        for p in group['params']:
            grad = p.grad
            if p.numel() < 1024:
                # Small params: all_reduce (no scatter/gather needed).
                # Note the reduce averages gradients in place, into p.grad.
                future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
                param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
            else:
                # Large params: reduce_scatter along dim 0, each rank receives its slice
                assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
                param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
        return dict(param_infos=param_infos)

    def _reduce_muon(self, group: dict, world_size: int) -> dict:
        """Launch async reduce op for Muon group. Returns info dict."""
        params = group['params']
        # Each rank owns a contiguous chunk of ceil(K / world_size) params
        chunk_size = (len(params) + world_size - 1) // world_size
        padded_num_params = chunk_size * world_size
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype

        # Stack grads and zero-pad to padded_num_params (reduce_scatter needs
        # the input length to be an exact multiple of world_size)
        grad_stack = torch.stack([p.grad for p in params])
        stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
        stacked_grads[:len(params)].copy_(grad_stack)
        if len(params) < padded_num_params:
            stacked_grads[len(params):].zero_()

        # Reduce_scatter to get this rank's chunk
        grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
        future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()

        # stacked_grads is kept alive so _compute_muon can reuse it as the all_gather output
        return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)

    def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
        """Wait for reduce, compute AdamW updates, launch gathers for large params."""
        param_infos = info['param_infos']
        for p in group['params']:
            pinfo = param_infos[p]
            pinfo['future'].wait()
            grad_slice = pinfo['grad_slice']
            state = self.state[p]

            # For small params, operate on full param; for large, operate on slice
            if pinfo['is_small']:
                p_slice = p
            else:
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]

            # State init (lazy; for large params the state is sharded, i.e.
            # sized to this rank's slice only)
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p_slice)
                state['exp_avg_sq'] = torch.zeros_like(p_slice)
            state['step'] += 1

            # Fill 0-D tensors and run fused kernel
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(
                p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
                self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
            )

            # Large params need all_gather (p_slice is a view of p, so gathering
            # into p reassembles the full updated param; params=None tells
            # _finish_gathers that no copy-back is needed)
            if not pinfo['is_small']:
                future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
                gather_list.append(dict(future=future, params=None))

    def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
        """Wait for reduce, compute Muon updates, launch gather."""
        info['future'].wait()
        params = group['params']
        chunk_size = info['chunk_size']
        grad_chunk = info['grad_chunk']
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype

        # How many params does this rank own? (last rank(s) may own fewer
        # than chunk_size, or zero, due to padding)
        start_idx = rank * chunk_size
        num_owned = min(chunk_size, max(0, len(params) - start_idx))

        # Get or create group-level state (stored under the first param's key;
        # sharded, i.e. sized to this rank's chunk only)
        state = self.state[p]
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2

        # Build output buffer for all_gather
        updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

        if num_owned > 0:
            owned_params = [params[start_idx + i] for i in range(num_owned)]
            stacked_owned = torch.stack(owned_params)

            # Fill 0-D tensors and run fused kernel
            self._muon_momentum_t.fill_(group["momentum"])
            self._muon_beta2_t.fill_(group["beta2"])
            # LR is scaled up for tall matrices, following the Muon convention
            self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
            self._muon_wd_t.fill_(group["weight_decay"])
            muon_step_fused(
                grad_chunk[:num_owned], stacked_owned,
                state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
                self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
                group["ns_steps"], red_dim,
            )
            updated_params[:num_owned].copy_(stacked_owned)

        if num_owned < chunk_size:
            updated_params[num_owned:].zero_()

        # Reuse stacked_grads buffer for all_gather output (safe: the
        # reduce_scatter that read it has already completed)
        stacked_params = info["stacked_grads"]
        future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
        gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))

    def _finish_gathers(self, gather_list: list) -> None:
        """Wait for all gathers and copy Muon params back."""
        for info in gather_list:
            info["future"].wait()
            if info["params"] is not None:
                # Muon: copy from stacked buffer back to individual params,
                # skipping the zero-padding at the tail
                torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))

    @torch.no_grad()
    def step(self):
        """Run one distributed optimization step (3-phase async; see class docstring)."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Phase 1: launch all async reduce ops
        reduce_infos: list[dict] = []
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                reduce_infos.append(self._reduce_adamw(group, world_size))
            elif group['kind'] == 'muon':
                reduce_infos.append(self._reduce_muon(group, world_size))
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

        # Phase 2: wait for reduces, compute updates, launch gathers
        gather_list: list[dict] = []
        for group, info in zip(self.param_groups, reduce_infos):
            if group['kind'] == 'adamw':
                self._compute_adamw(group, info, gather_list, rank, world_size)
            elif group['kind'] == 'muon':
                self._compute_muon(group, info, gather_list, rank)
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

        # Phase 3: wait for gathers, copy back
        self._finish_gathers(gather_list)
|
nanochat/report.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for generating training report cards. More messy code than usual, will fix.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
import subprocess
|
| 9 |
+
import socket
|
| 10 |
+
import datetime
|
| 11 |
+
import platform
|
| 12 |
+
import psutil
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
def run_command(cmd):
    """Run a shell command and return its stripped stdout, or None if it fails.

    Return value:
        - stripped stdout if the command produced any output, even on a
          nonzero exit code (e.g. some files in an xargs pipeline failed);
        - "" if the command succeeded (exit 0) with no output;
        - None on nonzero exit with no output, on timeout, or if the command
          could not be launched at all.
    """
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
    # Catch only expected failures; a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and make the process hard to stop.
    except (subprocess.SubprocessError, OSError):
        return None
    # Return stdout if we got output (even if some files in xargs failed)
    stdout = result.stdout.strip()
    if stdout:
        return stdout
    return "" if result.returncode == 0 else None
|
| 27 |
+
|
| 28 |
+
def get_git_info():
    """Collect git metadata for the current repo: short commit hash, branch,
    dirty flag (uncommitted changes), and the first line of the last commit
    message (truncated to 80 chars). Falls back to "unknown"/"" when git is
    unavailable or the commands fail."""
    commit = run_command("git rev-parse --short HEAD") or "unknown"
    branch = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"

    # Any porcelain output means the working tree has uncommitted changes
    status = run_command("git status --porcelain")
    dirty = bool(status) if status is not None else False

    # Last commit message: keep only the subject line, capped at 80 chars
    message = (run_command("git log -1 --pretty=%B") or "").split('\n')[0][:80]

    return {'commit': commit, 'branch': branch, 'dirty': dirty, 'message': message}
|
| 43 |
+
|
| 44 |
+
def get_gpu_info():
    """Describe visible CUDA devices: count, names, per-device memory, CUDA version."""
    if not torch.cuda.is_available():
        return {"available": False}
    device_count = torch.cuda.device_count()
    # query properties once per device and unzip into parallel lists
    props = [torch.cuda.get_device_properties(i) for i in range(device_count)]
    return {
        "available": True,
        "count": device_count,
        "names": [p.name for p in props],
        "memory_gb": [p.total_memory / (1024 ** 3) for p in props],
        "cuda_version": torch.version.cuda or "unknown",
    }
|
| 66 |
+
|
| 67 |
+
def get_system_info():
    """Gather host, interpreter, CPU/memory, and environment details for the report."""
    vm = psutil.virtual_memory()
    return {
        # basic system identity
        'hostname': socket.gethostname(),
        'platform': platform.system(),
        'python_version': platform.python_version(),
        'torch_version': torch.__version__,
        # CPU and memory
        'cpu_count': psutil.cpu_count(logical=False),
        'cpu_count_logical': psutil.cpu_count(logical=True),
        'memory_gb': vm.total / (1024 ** 3),
        # user and environment
        'user': os.environ.get('USER', 'unknown'),
        'nanochat_base_dir': os.environ.get('NANOCHAT_BASE_DIR', 'out'),
        'working_dir': os.getcwd(),
    }
|
| 88 |
+
|
| 89 |
+
def estimate_cost(gpu_info, runtime_hours=None):
    """Estimate training cost from the GPU fleet and an optional runtime.

    Args:
        gpu_info: dict as produced by get_gpu_info().
        runtime_hours: optional wall-clock hours to price out.

    Returns:
        None when no GPU is available, otherwise a dict with the total hourly
        rate (all GPUs combined), the first GPU's name, and the estimated
        total cost (None if runtime_hours was not given).
    """
    if not gpu_info.get("available"):
        return None

    # Rough per-GPU hourly pricing, from Lambda Cloud; unknown GPUs fall back
    # to a flat default rate.
    known_rates = {"H100": 3.00, "A100": 1.79, "V100": 0.55}
    fallback_rate = 2.0

    gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
    per_gpu_rate = next(
        (rate for label, rate in known_rates.items() if label in gpu_name),
        fallback_rate,
    )
    hourly_rate = per_gpu_rate * gpu_info["count"]

    return {
        "hourly_rate": hourly_rate,
        "gpu_type": gpu_name,
        "estimated_total": hourly_rate * runtime_hours if runtime_hours else None,
    }
|
| 119 |
+
|
| 120 |
+
def generate_header():
    """Generate the markdown header for a training report.

    Assembles git/GPU/system/cost information plus repo "bloat" metrics
    (line/char/file counts of git-tracked sources) into a single markdown
    string. Shells out to git via run_command, so sections degrade gracefully
    to defaults when git is unavailable.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # gather all environment facts up front
    git_info = get_git_info()
    gpu_info = get_gpu_info()
    sys_info = get_system_info()
    cost_info = estimate_cost(gpu_info)

    header = f"""# nanochat training report

Generated: {timestamp}

## Environment

### Git Information
- Branch: {git_info['branch']}
- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
- Message: {git_info['message']}

### Hardware
- Platform: {sys_info['platform']}
- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
- Memory: {sys_info['memory_gb']:.1f} GB
"""

    if gpu_info.get("available"):
        # collapse duplicate device names, sum VRAM across devices
        gpu_names = ", ".join(set(gpu_info["names"]))
        total_vram = sum(gpu_info["memory_gb"])
        header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
"""
    else:
        header += "- GPUs: None available\n"

    if cost_info and cost_info["hourly_rate"] > 0:
        header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""

    header += f"""
### Software
- Python: {sys_info['python_version']}
- PyTorch: {sys_info['torch_version']}

"""

    # bloat metrics: count lines/chars in git-tracked source files only
    extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
    git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
    files_output = run_command(f"git ls-files -- {git_patterns}")
    file_list = [f for f in (files_output or '').split('\n') if f]
    num_files = len(file_list)
    num_lines = 0
    num_chars = 0
    if num_files > 0:
        # wc prints a final "total" line; grab its line and char counts
        wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
        if wc_output:
            total_line = wc_output.strip().split('\n')[-1]
            parts = total_line.split()
            if len(parts) >= 2:
                num_lines = int(parts[0])
                num_chars = int(parts[1])
    num_tokens = num_chars // 4  # assume approximately 4 chars per token

    # count dependencies via uv.lock
    uv_lock_lines = 0
    if os.path.exists('uv.lock'):
        with open('uv.lock', 'r', encoding='utf-8') as f:
            uv_lock_lines = len(f.readlines())

    # NOTE: generate() later re-extracts this "### Bloat" section by regex,
    # so the section name and trailing blank line are load-bearing.
    header += f"""
### Bloat
- Characters: {num_chars:,}
- Lines: {num_lines:,}
- Files: {num_files:,}
- Tokens (approx): {num_tokens:,}
- Dependencies (uv.lock lines): {uv_lock_lines:,}

"""
    return header
|
| 200 |
+
|
| 201 |
+
# -----------------------------------------------------------------------------
|
| 202 |
+
|
| 203 |
+
def slugify(text):
    """Lowercase `text` and turn every space into a hyphen."""
    return "-".join(text.lower().split(" "))
|
| 206 |
+
|
| 207 |
+
# the expected files and their order
# (each pipeline stage writes one of these via Report.log; generate() stitches
# them together in exactly this order)
EXPECTED_FILES = [
    "tokenizer-training.md",
    "tokenizer-evaluation.md",
    "base-model-training.md",
    "base-model-loss.md",
    "base-model-evaluation.md",
    "chat-sft.md",
    "chat-evaluation-sft.md",
    "chat-rl.md",
    "chat-evaluation-rl.md",
]
# the metrics we're currently interested in (pulled from the chat eval sections
# into the final summary table)
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
|
| 221 |
+
|
| 222 |
+
def extract(section, keys):
    """Extract `key: value` pairs from a markdown section.

    Args:
        section: markdown text with "key: value" style lines.
        keys: a single key string, or a list of key strings, to look for.

    Returns:
        dict mapping each found key to its stripped string value. If a key
        matches multiple lines, the last one wins.
    """
    if not isinstance(keys, list):
        keys = [keys]  # convenience
    out = {}
    for line in section.split("\n"):
        for key in keys:
            if key in line:
                # split on the first ':' only, so values that themselves
                # contain a colon (e.g. times, ratios) are not truncated
                out[key] = line.split(":", 1)[1].strip()
    return out
|
| 232 |
+
|
| 233 |
+
def extract_timestamp(content, prefix):
    """Extract a datetime from the first parseable line starting with `prefix`.

    The line is expected to look like "<prefix> YYYY-mm-dd HH:MM:SS".

    Returns:
        datetime.datetime, or None if no line matches or none parses.
    """
    for line in content.split('\n'):
        if line.startswith(prefix):
            time_str = line.split(":", 1)[1].strip()
            try:
                return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                # narrowed from a bare `except:`; a malformed timestamp on
                # this line just means we keep scanning later lines
                pass
    return None
|
| 243 |
+
|
| 244 |
+
class Report:
    """Maintains a bunch of logs, generates a final markdown report.

    Each pipeline stage calls log() to write one section file into report_dir;
    generate() stitches the header and all EXPECTED_FILES sections into a
    single report.md with a summary metrics table, and reset() clears state
    and writes a fresh header with the run start time.
    """

    def __init__(self, report_dir):
        # report_dir: directory holding header.md, section files, and report.md
        os.makedirs(report_dir, exist_ok=True)
        self.report_dir = report_dir

    def log(self, section, data):
        """Log a section of data to the report.

        Args:
            section: human-readable section title; its slug becomes the file name.
            data: iterable of items; strings are written verbatim, dicts are
                rendered as "- key: value" bullet lists, falsy items skipped.

        Returns:
            Path of the section file that was written.
        """
        slug = slugify(section)
        file_name = f"{slug}.md"
        file_path = os.path.join(self.report_dir, file_name)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(f"## {section}\n")
            # generate() later reads this line back via extract_timestamp()
            f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            for item in data:
                if not item:
                    # skip falsy values like None or empty dict etc.
                    continue
                if isinstance(item, str):
                    # directly write the string
                    f.write(item)
                else:
                    # render a dict
                    for k, v in item.items():
                        if isinstance(v, float):
                            vstr = f"{v:.4f}"
                        elif isinstance(v, int) and v >= 10000:
                            # thousands separators for large integers
                            vstr = f"{v:,.0f}"
                        else:
                            vstr = str(v)
                        f.write(f"- {k}: {vstr}\n")
                f.write("\n")
        return file_path

    def generate(self):
        """Generate the final report.

        Concatenates header.md and every existing EXPECTED_FILES section into
        report.md, appends a Summary with the bloat metrics, a metrics table
        across the base/sft/rl stages, and the total wall clock time, then
        copies report.md into the current directory.
        """
        report_dir = self.report_dir
        report_file = os.path.join(report_dir, "report.md")
        print(f"Generating report to {report_file}")
        final_metrics = {}  # the most important final metrics we'll add as table at the end
        start_time = None
        end_time = None
        with open(report_file, "w", encoding="utf-8") as out_file:
            # write the header first
            header_file = os.path.join(report_dir, "header.md")
            if os.path.exists(header_file):
                with open(header_file, "r", encoding="utf-8") as f:
                    header_content = f.read()
                out_file.write(header_content)
                start_time = extract_timestamp(header_content, "Run started:")
                # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
                bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
                bloat_data = bloat_data.group(1) if bloat_data else ""
            else:
                start_time = None  # will cause us to not write the total wall clock time
                bloat_data = "[bloat data missing]"
                print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
            # process all the individual sections
            for file_name in EXPECTED_FILES:
                section_file = os.path.join(report_dir, file_name)
                if not os.path.exists(section_file):
                    print(f"Warning: {section_file} does not exist, skipping")
                    continue
                with open(section_file, "r", encoding="utf-8") as in_file:
                    section = in_file.read()
                # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
                if "rl" not in file_name:
                    # Skip RL sections for end_time calculation because RL is experimental
                    end_time = extract_timestamp(section, "timestamp:")
                # extract the most important metrics from the sections
                if file_name == "base-model-evaluation.md":
                    final_metrics["base"] = extract(section, "CORE")
                if file_name == "chat-evaluation-sft.md":
                    final_metrics["sft"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-rl.md":
                    final_metrics["rl"] = extract(section, "GSM8K")  # RL only evals GSM8K
                # append this section of the report
                out_file.write(section)
                out_file.write("\n")
            # add the final metrics table
            out_file.write("## Summary\n\n")
            # Copy over the bloat metrics from the header
            out_file.write(bloat_data)
            out_file.write("\n\n")
            # Collect all unique metric names
            all_metrics = set()
            for stage_metrics in final_metrics.values():
                all_metrics.update(stage_metrics.keys())
            # Custom ordering: CORE first, ChatCORE last, rest in middle
            all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
            # Fixed column widths
            stages = ["base", "sft", "rl"]
            metric_width = 15
            value_width = 8
            # Write table header
            header = f"| {'Metric'.ljust(metric_width)} |"
            for stage in stages:
                header += f" {stage.upper().ljust(value_width)} |"
            out_file.write(header + "\n")
            # Write separator
            separator = f"|{'-' * (metric_width + 2)}|"
            for stage in stages:
                separator += f"{'-' * (value_width + 2)}|"
            out_file.write(separator + "\n")
            # Write table rows
            for metric in all_metrics:
                row = f"| {metric.ljust(metric_width)} |"
                for stage in stages:
                    value = final_metrics.get(stage, {}).get(metric, "-")
                    row += f" {str(value).ljust(value_width)} |"
                out_file.write(row + "\n")
            out_file.write("\n")
            # Calculate and write total wall clock time
            if start_time and end_time:
                duration = end_time - start_time
                total_seconds = int(duration.total_seconds())
                hours = total_seconds // 3600
                minutes = (total_seconds % 3600) // 60
                out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
            else:
                out_file.write("Total wall clock time: unknown\n")
        # also cp the report.md file to current directory
        print(f"Copying report.md to current directory for convenience")
        shutil.copy(report_file, "report.md")
        return report_file

    def reset(self):
        """Reset the report.

        Deletes all section files and report.md, then writes a fresh header.md
        stamped with the run start time (read back later by generate()).
        """
        # Remove section files
        for file_name in EXPECTED_FILES:
            file_path = os.path.join(self.report_dir, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
        # Remove report.md if it exists
        report_file = os.path.join(self.report_dir, "report.md")
        if os.path.exists(report_file):
            os.remove(report_file)
        # Generate and write the header section with start timestamp
        header_file = os.path.join(self.report_dir, "header.md")
        header = generate_header()
        start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(header_file, "w", encoding="utf-8") as f:
            f.write(header)
            f.write(f"Run started: {start_time}\n\n---\n\n")
        print(f"Reset report and wrote header to {header_file}")
|
| 390 |
+
|
| 391 |
+
# -----------------------------------------------------------------------------
|
| 392 |
+
# nanochat-specific convenience functions
|
| 393 |
+
|
| 394 |
+
class DummyReport:
    """No-op stand-in for Report, used on non-zero DDP ranks so only rank 0 writes files."""

    def log(self, *args, **kwargs):
        return None

    def reset(self, *args, **kwargs):
        return None
|
| 399 |
+
|
| 400 |
+
def get_report():
    """Return a Report on DDP rank 0 and a DummyReport on every other rank.

    Only rank 0 writes report files so multi-process runs do not race on the
    same paths. The report directory lives under <base_dir>/report.
    """
    # just for convenience, only rank 0 logs to report
    from nanochat.common import get_base_dir, get_dist_info
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp_rank == 0:
        report_dir = os.path.join(get_base_dir(), "report")
        return Report(report_dir)
    else:
        return DummyReport()
|
| 409 |
+
|
| 410 |
+
if __name__ == "__main__":
    # CLI entry point: `python -m nanochat.report [generate|reset]`
    # "generate" stitches the final report.md; "reset" clears sections and
    # writes a fresh header with the run start time.
    import argparse
    parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
    parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
    args = parser.parse_args()
    if args.command == "generate":
        get_report().generate()
    elif args.command == "reset":
        get_report().reset()
|
nanochat/tokenizer.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BPE Tokenizer in the style of GPT-4.
|
| 3 |
+
|
| 4 |
+
Two implementations are available:
|
| 5 |
+
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
|
| 6 |
+
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import copy
|
| 11 |
+
from functools import lru_cache
|
| 12 |
+
|
| 13 |
+
# Special tokens shared by both tokenizer implementations. Their ids are
# appended after the learned BPE vocabulary, in this list order.
SPECIAL_TOKENS = [
    # every document begins with the Beginning of Sequence (BOS) token that delimits documents
    "<|bos|>",
    # tokens below are only used during finetuning to render Conversations into token ids
    "<|user_start|>", # user messages
    "<|user_end|>",
    "<|assistant_start|>", # assistant messages
    "<|assistant_end|>",
    "<|python_start|>", # assistant invokes python REPL tool
    "<|python_end|>",
    "<|output_start|>", # python REPL outputs back to assistant
    "<|output_end|>",
]

# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
| 31 |
+
|
| 32 |
+
# -----------------------------------------------------------------------------
|
| 33 |
+
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
|
| 34 |
+
from tokenizers import Tokenizer as HFTokenizer
|
| 35 |
+
from tokenizers import pre_tokenizers, decoders, Regex
|
| 36 |
+
from tokenizers.models import BPE
|
| 37 |
+
from tokenizers.trainers import BpeTrainer
|
| 38 |
+
|
| 39 |
+
class HuggingFaceTokenizer:
    """Light wrapper around HuggingFace Tokenizer for some utilities.

    Mirrors the interface of RustBPETokenizer below (encode/decode/save plus
    vocab and special-token introspection) so callers can swap implementations.
    """

    def __init__(self, tokenizer):
        # tokenizer: a tokenizers.Tokenizer instance (see the classmethod constructors)
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, hf_path):
        """Init from a HuggingFace pretrained tokenizer (e.g. "gpt2")."""
        tokenizer = HFTokenizer.from_pretrained(hf_path)
        return cls(tokenizer)

    @classmethod
    def from_directory(cls, tokenizer_dir):
        """Init from a local directory on disk (e.g. "out/tokenizer") written by save()."""
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        tokenizer = HFTokenizer.from_file(tokenizer_path)
        return cls(tokenizer)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        """Train a byte-level BPE tokenizer of size vocab_size from an iterator of text."""
        # Configure the HuggingFace Tokenizer
        tokenizer = HFTokenizer(BPE(
            byte_fallback=True, # needed!
            unk_token=None,
            fuse_unk=False,
        ))
        # Normalizer: None
        tokenizer.normalizer = None
        # Pre-tokenizer: GPT-4 style
        # the regex pattern used by GPT-4 to split text into groups before BPE
        # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
        # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
        # (but I haven't validated this! TODO)
        gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
        tokenizer.decoder = decoders.ByteLevel()
        # Post-processor: None
        tokenizer.post_processor = None
        # Trainer: BPE
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0, # no minimum frequency
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=SPECIAL_TOKENS,
        )
        # Kick off the training
        tokenizer.train_from_iterator(text_iterator, trainer)
        return cls(tokenizer)

    def get_vocab_size(self):
        """Total vocabulary size, including special tokens."""
        return self.tokenizer.get_vocab_size()

    def get_special_tokens(self):
        """Return the list of special token strings registered on the tokenizer."""
        special_tokens_map = self.tokenizer.get_added_tokens_decoder()
        special_tokens = [w.content for w in special_tokens_map.values()]
        return special_tokens

    def id_to_token(self, id):
        """Map a single token id back to its string form."""
        return self.tokenizer.id_to_token(id)

    def _encode_one(self, text, prepend=None, append=None, num_threads=None):
        """Encode a single string into token ids.

        prepend/append can be either a string of a special token or a token id directly.
        num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding).
        """
        assert isinstance(text, str)
        ids = []
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
            ids.append(prepend_id)
        ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)
            ids.append(append_id)
        return ids

    def encode_special(self, text):
        # encode a single special token via exact match; None if not found
        return self.tokenizer.token_to_id(text)

    def get_bos_token_id(self):
        """Best-effort lookup of the document-delimiter token id."""
        # Different HuggingFace models use different BOS tokens and there is little consistency
        # 1) attempt to find a <|bos|> token
        bos = self.encode_special("<|bos|>")
        # 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
        if bos is None:
            bos = self.encode_special("<|endoftext|>")
        # 3) if these fail, it's better to crash than to silently return None
        assert bos is not None, "Failed to find BOS token in tokenizer"
        return bos

    def encode(self, text, *args, **kwargs):
        """Encode a string or a list of strings; extra args forwarded to _encode_one."""
        if isinstance(text, str):
            return self._encode_one(text, *args, **kwargs)
        elif isinstance(text, list):
            return [self._encode_one(t, *args, **kwargs) for t in text]
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

    def __call__(self, *args, **kwargs):
        """Alias for encode()."""
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        """Decode token ids to text; special tokens are kept in the output."""
        return self.tokenizer.decode(ids, skip_special_tokens=False)

    def save(self, tokenizer_dir):
        """Save the tokenizer to <tokenizer_dir>/tokenizer.json."""
        os.makedirs(tokenizer_dir, exist_ok=True)
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        self.tokenizer.save(tokenizer_path)
        print(f"Saved tokenizer to {tokenizer_path}")
| 156 |
+
|
| 157 |
+
# -----------------------------------------------------------------------------
|
| 158 |
+
# Tokenizer based on rustbpe + tiktoken combo
|
| 159 |
+
import pickle
|
| 160 |
+
import rustbpe
|
| 161 |
+
import tiktoken
|
| 162 |
+
|
| 163 |
+
class RustBPETokenizer:
|
| 164 |
+
"""Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
|
| 165 |
+
|
| 166 |
+
    def __init__(self, enc, bos_token):
        # enc: a tiktoken.Encoding; bos_token: string name of the document
        # delimiter token (e.g. "<|bos|>" or "<|endoftext|>")
        self.enc = enc
        self.bos_token_id = self.encode_special(bos_token)
|
| 169 |
+
|
| 170 |
+
    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        """Train a BPE vocab with rustbpe, then wrap it in a tiktoken Encoding.

        The final vocab_size includes the SPECIAL_TOKENS, which are appended
        after the learned merges rather than trained.
        """
        # 1) train using rustbpe
        tokenizer = rustbpe.Tokenizer()
        # the special tokens are inserted later in __init__, we don't train them here
        vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
        assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
        tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
        # 2) construct the associated tiktoken encoding for inference
        pattern = tokenizer.get_pattern()
        mergeable_ranks_list = tokenizer.get_mergeable_ranks()
        mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
        # special token ids start right after the learned vocabulary
        tokens_offset = len(mergeable_ranks)
        special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
        enc = tiktoken.Encoding(
            name="rustbpe",
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
            special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
        )
        return cls(enc, "<|bos|>")
|
| 191 |
+
|
| 192 |
+
    @classmethod
    def from_directory(cls, tokenizer_dir):
        """Load a previously save()'d tiktoken Encoding from tokenizer_dir.

        NOTE(review): pickle.load executes arbitrary code from the file; only
        load tokenizer.pkl files you trust.
        """
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "rb") as f:
            enc = pickle.load(f)
        return cls(enc, "<|bos|>")
|
| 198 |
+
|
| 199 |
+
    @classmethod
    def from_pretrained(cls, tiktoken_name):
        """Wrap one of tiktoken's stock encodings (e.g. "gpt2", "cl100k_base")."""
        # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
        enc = tiktoken.get_encoding(tiktoken_name)
        # tiktoken calls the special document delimiter token "<|endoftext|>"
        # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
        # it most often is used to signal the start of a new sequence to the LLM during inference etc.
        # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
        return cls(enc, "<|endoftext|>")
|
| 208 |
+
|
| 209 |
+
    def get_vocab_size(self):
        """Total vocabulary size, including the special tokens."""
        return self.enc.n_vocab
|
| 211 |
+
|
| 212 |
+
    def get_special_tokens(self):
        """Return the set of special-token strings known to the encoding."""
        return self.enc.special_tokens_set
|
| 214 |
+
|
| 215 |
+
    def id_to_token(self, id):
        """Decode a single token id back into its string form."""
        return self.enc.decode([id])
|
| 217 |
+
|
| 218 |
+
@lru_cache(maxsize=32)
|
| 219 |
+
def encode_special(self, text):
|
| 220 |
+
return self.enc.encode_single_token(text)
|
| 221 |
+
|
| 222 |
+
    def get_bos_token_id(self):
        """Token id of the document delimiter, resolved once in __init__."""
        return self.bos_token_id
|
| 224 |
+
|
| 225 |
+
def encode(self, text, prepend=None, append=None, num_threads=8):
|
| 226 |
+
# text can be either a string or a list of strings
|
| 227 |
+
|
| 228 |
+
if prepend is not None:
|
| 229 |
+
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
|
| 230 |
+
if append is not None:
|
| 231 |
+
append_id = append if isinstance(append, int) else self.encode_special(append)
|
| 232 |
+
|
| 233 |
+
if isinstance(text, str):
|
| 234 |
+
ids = self.enc.encode_ordinary(text)
|
| 235 |
+
if prepend is not None:
|
| 236 |
+
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
|
| 237 |
+
if append is not None:
|
| 238 |
+
ids.append(append_id)
|
| 239 |
+
elif isinstance(text, list):
|
| 240 |
+
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
|
| 241 |
+
if prepend is not None:
|
| 242 |
+
for ids_row in ids:
|
| 243 |
+
ids_row.insert(0, prepend_id) # TODO: same
|
| 244 |
+
if append is not None:
|
| 245 |
+
for ids_row in ids:
|
| 246 |
+
ids_row.append(append_id)
|
| 247 |
+
else:
|
| 248 |
+
raise ValueError(f"Invalid input type: {type(text)}")
|
| 249 |
+
|
| 250 |
+
return ids
|
| 251 |
+
|
| 252 |
+
    def __call__(self, *args, **kwargs):
        """Alias for encode()."""
        return self.encode(*args, **kwargs)
|
| 254 |
+
|
| 255 |
+
    def decode(self, ids):
        """Decode a list of token ids back into a string."""
        return self.enc.decode(ids)
|
| 257 |
+
|
| 258 |
+
def save(self, tokenizer_dir):
|
| 259 |
+
# save the encoding object to disk
|
| 260 |
+
os.makedirs(tokenizer_dir, exist_ok=True)
|
| 261 |
+
pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
|
| 262 |
+
with open(pickle_path, "wb") as f:
|
| 263 |
+
pickle.dump(self.enc, f)
|
| 264 |
+
print(f"Saved tokenizer encoding to {pickle_path}")
|
| 265 |
+
|
| 266 |
+
def render_conversation(self, conversation, max_tokens=2048):
    """
    Tokenize a single Chat conversation (which we call a "doc" or "document" here).

    Args:
    - conversation: dict with a "messages" list. Each message has "role"
      ("system"/"user"/"assistant") and "content" (a str, or — for assistant
      messages only — a list of {"type": ..., "text": ...} parts).
    - max_tokens: hard cap on the returned sequence length (helps prevent OOMs).

    Returns:
    - ids: list[int] is a list of token ids of this rendered conversation
    - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
    """
    # ids, masks that we will return and a helper function to help build them up.
    ids, mask = [], []
    def add_tokens(token_ids, mask_val):
        # Accepts a single token id or a list; extends ids and mask in lockstep.
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))

    # sometimes the first message is a system message...
    # => just merge it with the second (user) message
    if conversation["messages"][0]["role"] == "system":
        # some conversation surgery is necessary here for now...
        conversation = copy.deepcopy(conversation) # avoid mutating the original
        messages = conversation["messages"]
        assert messages[1]["role"] == "user", "System message must be followed by a user message"
        messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
        messages = messages[1:]
    else:
        messages = conversation["messages"]
    assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"

    # fetch all the special tokens we need
    bos = self.get_bos_token_id()
    user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
    assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
    python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
    output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")

    # now we can tokenize the conversation
    add_tokens(bos, 0)
    for i, message in enumerate(messages):

        # some sanity checking here around assumptions, to prevent footguns
        # (after the system merge above, turns must strictly alternate user/assistant)
        must_be_from = "user" if i % 2 == 0 else "assistant"
        assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"

        # content can be either a simple string or a list of parts (e.g. containing tool calls)
        content = message["content"]

        if message["role"] == "user":
            # User tokens are never supervised (mask=0 throughout).
            assert isinstance(content, str), "User messages are simply expected to be strings"
            value_ids = self.encode(content)
            add_tokens(user_start, 0)
            add_tokens(value_ids, 0)
            add_tokens(user_end, 0)
        elif message["role"] == "assistant":
            # The start delimiter itself is not supervised; the content is.
            add_tokens(assistant_start, 0)
            if isinstance(content, str):
                # simple string => simply add the tokens
                value_ids = self.encode(content)
                add_tokens(value_ids, 1)
            elif isinstance(content, list):
                for part in content:
                    value_ids = self.encode(part["text"])
                    if part["type"] == "text":
                        # string part => simply add the tokens
                        add_tokens(value_ids, 1)
                    elif part["type"] == "python":
                        # python tool call => add the tokens inside <|python_start|> and <|python_end|>
                        add_tokens(python_start, 1)
                        add_tokens(value_ids, 1)
                        add_tokens(python_end, 1)
                    elif part["type"] == "python_output":
                        # python output => add the tokens inside <|output_start|> and <|output_end|>
                        # none of these tokens are supervised because the tokens come from Python at test time
                        add_tokens(output_start, 0)
                        add_tokens(value_ids, 0)
                        add_tokens(output_end, 0)
                    else:
                        raise ValueError(f"Unknown part type: {part['type']}")
            else:
                raise ValueError(f"Unknown content type: {type(content)}")
            # The end-of-turn token IS supervised so the model learns to stop.
            add_tokens(assistant_end, 1)

    # truncate to max_tokens tokens MAX (helps prevent OOMs)
    ids = ids[:max_tokens]
    mask = mask[:max_tokens]
    return ids, mask
|
| 351 |
+
|
| 352 |
+
def visualize_tokenization(self, ids, mask, with_token_id=False):
    """
    Small helper function useful in debugging: visualize the tokenization of render_conversation.

    Tokens with mask == 1 (Assistant-supervised) render green, all others red.
    If with_token_id is True, each token is followed by its numeric id in gray.
    Returns a single string with the pieces joined by '|'.
    """
    RED = '\033[91m'
    GREEN = '\033[92m'
    RESET = '\033[0m'
    GRAY = '\033[90m'
    tokens = []
    # zip pairs each token id with its supervision flag; the loop index from
    # the previous enumerate() was unused, so it has been dropped.
    for token_id, mask_val in zip(ids, mask):
        token_str = self.decode([token_id])
        color = GREEN if mask_val == 1 else RED
        tokens.append(f"{color}{token_str}{RESET}")
        if with_token_id:
            tokens.append(f"{GRAY}({token_id}){RESET}")
    return '|'.join(tokens)
|
| 366 |
+
|
| 367 |
+
def render_for_completion(self, conversation):
    """
    Used during Reinforcement Learning: render the conversation so it primes
    the Assistant for a completion. The trailing Assistant message is dropped
    and an <|assistant_start|> token is appended. Unlike the Chat SFT case,
    no mask is returned — only the list of token ids.
    """
    # Work on a deep copy so the caller's conversation stays untouched.
    convo = copy.deepcopy(conversation)
    msgs = convo["messages"]
    assert msgs[-1]["role"] == "assistant", "Last message must be from the Assistant"
    # Strip the Assistant turn we want the model to produce itself.
    del msgs[-1]

    # Tokenize what remains; the mask is irrelevant here.
    ids, _ = self.render_conversation(convo)

    # Prime the model to speak as the Assistant next.
    ids.append(self.encode_special("<|assistant_start|>"))
    return ids
|
| 386 |
+
|
| 387 |
+
# -----------------------------------------------------------------------------
|
| 388 |
+
# nanochat-specific convenience functions
|
| 389 |
+
|
| 390 |
+
def get_tokenizer():
    """Load the project tokenizer from <base_dir>/tokenizer on disk."""
    from nanochat.common import get_base_dir
    tokenizer_dir = os.path.join(get_base_dir(), "tokenizer")
    # Alternative backend kept for reference:
    # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
    return RustBPETokenizer.from_directory(tokenizer_dir)
|
| 396 |
+
|
| 397 |
+
def get_token_bytes(device="cpu"):
    """Load the precomputed token_bytes tensor from disk onto `device`."""
    import torch
    from nanochat.common import get_base_dir
    token_bytes_path = os.path.join(get_base_dir(), "tokenizer", "token_bytes.pt")
    assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
    # Deserialize directly to the requested device.
    with open(token_bytes_path, "rb") as fh:
        return torch.load(fh, map_location=device)
|
nanochat/ui.html
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
|
| 6 |
+
<title>NanoChat</title>
|
| 7 |
+
<link rel="icon" type="image/svg+xml" href="/logo.svg">
|
| 8 |
+
<style>
|
| 9 |
+
:root {
|
| 10 |
+
color-scheme: light;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
* {
|
| 14 |
+
box-sizing: border-box;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
html, body{
|
| 18 |
+
height: 100%;
|
| 19 |
+
margin: 0;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
body {
|
| 23 |
+
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
| 24 |
+
background-color: #ffffff;
|
| 25 |
+
color: #111827;
|
| 26 |
+
min-height: 100dvh;
|
| 27 |
+
margin: 0;
|
| 28 |
+
display: flex;
|
| 29 |
+
flex-direction: column;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.header {
|
| 33 |
+
background-color: #ffffff;
|
| 34 |
+
padding: 1.25rem 1.5rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.header-left {
|
| 38 |
+
display: flex;
|
| 39 |
+
align-items: center;
|
| 40 |
+
gap: 0.75rem;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
.header-logo {
|
| 44 |
+
height: 32px;
|
| 45 |
+
width: auto;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
.header h1 {
|
| 49 |
+
font-size: 1.25rem;
|
| 50 |
+
font-weight: 600;
|
| 51 |
+
margin: 0;
|
| 52 |
+
color: #111827;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
.new-conversation-btn {
|
| 56 |
+
width: 32px;
|
| 57 |
+
height: 32px;
|
| 58 |
+
padding: 0;
|
| 59 |
+
border: 1px solid #e5e7eb;
|
| 60 |
+
border-radius: 0.5rem;
|
| 61 |
+
background-color: #ffffff;
|
| 62 |
+
color: #6b7280;
|
| 63 |
+
cursor: pointer;
|
| 64 |
+
display: flex;
|
| 65 |
+
align-items: center;
|
| 66 |
+
justify-content: center;
|
| 67 |
+
transition: all 0.2s ease;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
.new-conversation-btn:hover {
|
| 71 |
+
background-color: #f3f4f6;
|
| 72 |
+
border-color: #d1d5db;
|
| 73 |
+
color: #374151;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.chat-container {
|
| 77 |
+
flex: 1;
|
| 78 |
+
overflow-y: auto;
|
| 79 |
+
background-color: #ffffff;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.chat-wrapper {
|
| 83 |
+
max-width: 48rem;
|
| 84 |
+
margin: 0 auto;
|
| 85 |
+
padding: 2rem 1.5rem 3rem;
|
| 86 |
+
display: flex;
|
| 87 |
+
flex-direction: column;
|
| 88 |
+
gap: 0.75rem;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.message {
|
| 92 |
+
display: flex;
|
| 93 |
+
justify-content: flex-start;
|
| 94 |
+
margin-bottom: 0.5rem;
|
| 95 |
+
color: #0d0d0d;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
.message.assistant {
|
| 99 |
+
justify-content: flex-start;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
.message.user {
|
| 103 |
+
justify-content: flex-end;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.message-content {
|
| 107 |
+
white-space: pre-wrap;
|
| 108 |
+
line-height: 1.6;
|
| 109 |
+
max-width: 100%;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
.message.assistant .message-content {
|
| 113 |
+
background: transparent;
|
| 114 |
+
border: none;
|
| 115 |
+
cursor: pointer;
|
| 116 |
+
border-radius: 0.5rem;
|
| 117 |
+
padding: 0.5rem;
|
| 118 |
+
margin-left: -0.5rem;
|
| 119 |
+
transition: background-color 0.2s ease;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.message.assistant .message-content:hover {
|
| 123 |
+
background-color: #f9fafb;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.message.user .message-content {
|
| 127 |
+
background-color: #f3f4f6;
|
| 128 |
+
border-radius: 1.25rem;
|
| 129 |
+
padding: 0.8rem 1rem;
|
| 130 |
+
max-width: 65%;
|
| 131 |
+
cursor: pointer;
|
| 132 |
+
transition: background-color 0.2s ease;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.message.user .message-content:hover {
|
| 136 |
+
background-color: #e5e7eb;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
.message.console .message-content {
|
| 140 |
+
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
|
| 141 |
+
font-size: 0.875rem;
|
| 142 |
+
background-color: #fafafa;
|
| 143 |
+
padding: 0.75rem 1rem;
|
| 144 |
+
color: #374151;
|
| 145 |
+
max-width: 80%;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.input-container {
|
| 149 |
+
background-color: #ffffff;
|
| 150 |
+
padding: 1rem;
|
| 151 |
+
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.input-wrapper {
|
| 155 |
+
max-width: 48rem;
|
| 156 |
+
margin: 0 auto;
|
| 157 |
+
display: flex;
|
| 158 |
+
gap: 0.75rem;
|
| 159 |
+
align-items: flex-end;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.chat-input {
|
| 163 |
+
flex: 1;
|
| 164 |
+
padding: 0.8rem 1rem;
|
| 165 |
+
border: 1px solid #d1d5db;
|
| 166 |
+
border-radius: 0.75rem;
|
| 167 |
+
background-color: #ffffff;
|
| 168 |
+
color: #111827;
|
| 169 |
+
font-size: 1rem;
|
| 170 |
+
line-height: 1.5;
|
| 171 |
+
resize: none;
|
| 172 |
+
outline: none;
|
| 173 |
+
min-height: 54px;
|
| 174 |
+
max-height: 200px;
|
| 175 |
+
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
.chat-input::placeholder {
|
| 179 |
+
color: #9ca3af;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.chat-input:focus {
|
| 183 |
+
border-color: #2563eb;
|
| 184 |
+
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.send-button {
|
| 188 |
+
flex-shrink: 0;
|
| 189 |
+
padding: 0;
|
| 190 |
+
width: 54px;
|
| 191 |
+
height: 54px;
|
| 192 |
+
border: 1px solid #111827;
|
| 193 |
+
border-radius: 0.75rem;
|
| 194 |
+
background-color: #111827;
|
| 195 |
+
color: #ffffff;
|
| 196 |
+
display: flex;
|
| 197 |
+
align-items: center;
|
| 198 |
+
justify-content: center;
|
| 199 |
+
cursor: pointer;
|
| 200 |
+
transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.send-button:hover:not(:disabled) {
|
| 204 |
+
background-color: #2563eb;
|
| 205 |
+
border-color: #2563eb;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.send-button:disabled {
|
| 209 |
+
cursor: not-allowed;
|
| 210 |
+
border-color: #d1d5db;
|
| 211 |
+
background-color: #e5e7eb;
|
| 212 |
+
color: #9ca3af;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.typing-indicator {
|
| 216 |
+
display: inline-block;
|
| 217 |
+
color: #6b7280;
|
| 218 |
+
letter-spacing: 0.15em;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.typing-indicator::after {
|
| 222 |
+
content: '···';
|
| 223 |
+
animation: typing 1.4s infinite;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
@keyframes typing {
|
| 227 |
+
0%, 60%, 100% { opacity: 0.2; }
|
| 228 |
+
30% { opacity: 1; }
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
.error-message {
|
| 232 |
+
background-color: #fee2e2;
|
| 233 |
+
border: 1px solid #fecaca;
|
| 234 |
+
color: #b91c1c;
|
| 235 |
+
padding: 0.75rem 1rem;
|
| 236 |
+
border-radius: 0.75rem;
|
| 237 |
+
margin-top: 0.5rem;
|
| 238 |
+
}
|
| 239 |
+
</style>
|
| 240 |
+
</head>
|
| 241 |
+
<body>
|
| 242 |
+
<div class="header">
|
| 243 |
+
<div class="header-left">
|
| 244 |
+
<button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
|
| 245 |
+
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 246 |
+
<path d="M12 5v14"></path>
|
| 247 |
+
<path d="M5 12h14"></path>
|
| 248 |
+
</svg>
|
| 249 |
+
</button>
|
| 250 |
+
<h1>nanochat</h1>
|
| 251 |
+
</div>
|
| 252 |
+
</div>
|
| 253 |
+
|
| 254 |
+
<div class="chat-container" id="chatContainer">
|
| 255 |
+
<div class="chat-wrapper" id="chatWrapper">
|
| 256 |
+
<!-- Messages will be added here -->
|
| 257 |
+
</div>
|
| 258 |
+
</div>
|
| 259 |
+
|
| 260 |
+
<div class="input-container">
|
| 261 |
+
<div class="input-wrapper">
|
| 262 |
+
<textarea
|
| 263 |
+
id="chatInput"
|
| 264 |
+
class="chat-input"
|
| 265 |
+
placeholder="Ask anything"
|
| 266 |
+
rows="1"
|
| 267 |
+
onkeydown="handleKeyDown(event)"
|
| 268 |
+
></textarea>
|
| 269 |
+
<button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
|
| 270 |
+
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 271 |
+
<path d="M22 2L11 13"></path>
|
| 272 |
+
<path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
|
| 273 |
+
</svg>
|
| 274 |
+
</button>
|
| 275 |
+
</div>
|
| 276 |
+
</div>
|
| 277 |
+
|
| 278 |
+
<script>
|
| 279 |
+
// Backend base URL; empty string means same-origin requests.
const API_URL = '';
// Cached DOM handles used throughout the page.
const chatContainer = document.getElementById('chatContainer');
const chatWrapper = document.getElementById('chatWrapper');
const chatInput = document.getElementById('chatInput');
const sendButton = document.getElementById('sendButton');

// Conversation state: ordered user/assistant turns sent to the backend.
let messages = [];
// True while a completion is streaming; gates all user-triggered actions.
let isGenerating = false;
// Sampling parameters, adjustable via /temperature and /topk slash commands.
let currentTemperature = 0.8;
let currentTopK = 50;

// Auto-grow the textarea with its content (capped at 200px) and keep the
// send button disabled while the input is empty or a reply is streaming.
chatInput.addEventListener('input', function() {
    this.style.height = 'auto';
    this.style.height = Math.min(this.scrollHeight, 200) + 'px';
    sendButton.disabled = !this.value.trim() || isGenerating;
});
|
| 295 |
+
|
| 296 |
+
function handleKeyDown(event) {
    // Enter submits the message; Shift+Enter falls through to insert a newline.
    const isPlainEnter = event.key === 'Enter' && !event.shiftKey;
    if (!isPlainEnter) return;
    event.preventDefault();
    sendMessage();
}
|
| 302 |
+
|
| 303 |
+
// Global keyboard shortcut handler.
document.addEventListener('keydown', function(event) {
    // Ctrl+Shift+N for new conversation
    if (event.ctrlKey && event.shiftKey && event.key === 'N') {
        event.preventDefault();
        // Don't wipe state mid-stream.
        if (!isGenerating) {
            newConversation();
        }
    }
});
|
| 312 |
+
|
| 313 |
+
function newConversation() {
    // Reset conversation state and restore the UI to its initial, ready state.
    messages = [];
    isGenerating = false;
    chatWrapper.innerHTML = '';
    chatInput.value = '';
    chatInput.style.height = 'auto';
    sendButton.disabled = false;
    chatInput.focus();
}
|
| 322 |
+
|
| 323 |
+
function addMessage(role, content, messageIndex = null) {
    // Build a chat bubble for `role`, wire up click-to-edit (user) or
    // click-to-regenerate (assistant) when an index is given, append it to
    // the transcript, and return the content element for streaming updates.
    const messageDiv = document.createElement('div');
    messageDiv.className = `message ${role}`;

    const contentDiv = document.createElement('div');
    contentDiv.className = 'message-content';
    contentDiv.textContent = content;

    if (messageIndex !== null && (role === 'user' || role === 'assistant')) {
        const isUser = role === 'user';
        contentDiv.setAttribute('data-message-index', messageIndex);
        contentDiv.setAttribute('title', isUser ? 'Click to edit and restart from here' : 'Click to regenerate this response');
        contentDiv.addEventListener('click', function() {
            if (isGenerating) return;
            if (isUser) {
                editMessage(messageIndex);
            } else {
                regenerateMessage(messageIndex);
            }
        });
    }

    messageDiv.appendChild(contentDiv);
    chatWrapper.appendChild(messageDiv);

    // Keep the newest message scrolled into view.
    chatContainer.scrollTop = chatContainer.scrollHeight;
    return contentDiv;
}
|
| 359 |
+
|
| 360 |
+
function editMessage(messageIndex) {
    // Restore a prior user message into the input box and rewind the
    // conversation to just before it, so the user can edit and resend.
    if (messageIndex < 0 || messageIndex >= messages.length) return;

    const messageToEdit = messages[messageIndex];
    if (messageToEdit.role !== 'user') return;

    // Copy message content to input and re-fit the textarea height.
    chatInput.value = messageToEdit.content;
    chatInput.style.height = 'auto';
    chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';

    // Remove this message and all subsequent messages from the array.
    messages = messages.slice(0, messageIndex);

    // BUG FIX: 'console' bubbles (slash-command output) exist in the DOM but
    // are never pushed to `messages`, so a raw index into
    // querySelectorAll('.message') can drift from `messageIndex`. Locate the
    // bubble by its data-message-index attribute and remove it plus all later
    // siblings instead.
    const target = chatWrapper.querySelector(`.message-content[data-message-index="${messageIndex}"]`);
    if (target) {
        let node = target.closest('.message');
        while (node) {
            const next = node.nextElementSibling;
            node.remove();
            node = next;
        }
    } else {
        // Fallback: previous index-based removal.
        const allMessages = chatWrapper.querySelectorAll('.message');
        for (let i = messageIndex; i < allMessages.length; i++) {
            allMessages[i].remove();
        }
    }

    // Enable send button and focus input.
    sendButton.disabled = false;
    chatInput.focus();
}
|
| 385 |
+
|
| 386 |
+
async function generateAssistantResponse() {
    // Request a streamed completion for `messages` and render tokens as they
    // arrive. Appends the finished assistant turn to `messages` on success.
    isGenerating = true;
    sendButton.disabled = true;

    const assistantContent = addMessage('assistant', '');
    assistantContent.innerHTML = '<span class="typing-indicator"></span>';

    try {
        const response = await fetch(`${API_URL}/chat/completions`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                messages: messages,
                temperature: currentTemperature,
                top_k: currentTopK,
                max_tokens: 512
            }),
        });

        if (!response.ok) {
            throw new Error(`HTTP error! status: ${response.status}`);
        }

        const reader = response.body.getReader();
        const decoder = new TextDecoder();
        let fullResponse = '';
        let buffered = '';
        assistantContent.textContent = '';

        while (true) {
            const { done, value } = await reader.read();
            if (done) break;

            // BUG FIX: pass {stream: true} so multi-byte UTF-8 sequences that
            // straddle read() boundaries decode correctly, and keep a carry
            // buffer so SSE lines split across chunks are not dropped.
            buffered += decoder.decode(value, { stream: true });
            const lines = buffered.split('\n');
            buffered = lines.pop();  // last element may be an incomplete line

            for (const line of lines) {
                if (line.startsWith('data: ')) {
                    try {
                        const data = JSON.parse(line.slice(6));
                        if (data.token) {
                            fullResponse += data.token;
                            assistantContent.textContent = fullResponse;
                            chatContainer.scrollTop = chatContainer.scrollHeight;
                        }
                    } catch (e) {
                        // Ignore malformed SSE payloads (e.g. keep-alives).
                    }
                }
            }
        }

        const assistantMessageIndex = messages.length;
        messages.push({ role: 'assistant', content: fullResponse });

        // Add click handler to regenerate this assistant message.
        assistantContent.setAttribute('data-message-index', assistantMessageIndex);
        assistantContent.setAttribute('title', 'Click to regenerate this response');
        assistantContent.addEventListener('click', function() {
            if (!isGenerating) {
                regenerateMessage(assistantMessageIndex);
            }
        });

    } catch (error) {
        console.error('Error:', error);
        assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
    } finally {
        isGenerating = false;
        sendButton.disabled = !chatInput.value.trim();
    }
}
|
| 458 |
+
|
| 459 |
+
async function regenerateMessage(messageIndex) {
    // Discard the assistant reply at messageIndex (and everything after it)
    // and request a fresh completion for the remaining conversation.
    if (messageIndex < 0 || messageIndex >= messages.length) return;

    const messageToRegenerate = messages[messageIndex];
    if (messageToRegenerate.role !== 'assistant') return;

    // Remove this message and all subsequent messages from the array.
    messages = messages.slice(0, messageIndex);

    // BUG FIX: 'console' bubbles live in the DOM but not in `messages`, so an
    // index into querySelectorAll('.message') can drift from `messageIndex`.
    // Find the bubble by its data-message-index attribute and remove it plus
    // all later siblings.
    const target = chatWrapper.querySelector(`.message-content[data-message-index="${messageIndex}"]`);
    if (target) {
        let node = target.closest('.message');
        while (node) {
            const next = node.nextElementSibling;
            node.remove();
            node = next;
        }
    } else {
        // Fallback: previous index-based removal.
        const allMessages = chatWrapper.querySelectorAll('.message');
        for (let i = messageIndex; i < allMessages.length; i++) {
            allMessages[i].remove();
        }
    }

    // Regenerate the assistant response.
    await generateAssistantResponse();
}
|
| 478 |
+
|
| 479 |
+
function handleSlashCommand(command) {
    // Parse and execute a local "/command". Returns true when the command was
    // recognized (even if its argument was invalid), false otherwise.
    const parts = command.trim().split(/\s+/);
    const cmd = parts[0].toLowerCase();
    const arg = parts[1];

    if (cmd === '/temperature') {
        if (arg === undefined) {
            addMessage('console', `Current temperature: ${currentTemperature}`);
        } else {
            const temp = parseFloat(arg);
            if (isNaN(temp) || temp < 0 || temp > 2) {
                addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
            } else {
                currentTemperature = temp;
                addMessage('console', `Temperature set to ${currentTemperature}`);
            }
        }
        return true;
    } else if (cmd === '/topk') {
        if (arg === undefined) {
            addMessage('console', `Current top-k: ${currentTopK}`);
        } else {
            // BUG FIX: always pass an explicit radix — bare parseInt can
            // interpret "0x"-prefixed input as hexadecimal.
            const topk = parseInt(arg, 10);
            if (isNaN(topk) || topk < 1 || topk > 200) {
                addMessage('console', 'Invalid top-k. Must be between 1 and 200');
            } else {
                currentTopK = topk;
                addMessage('console', `Top-k set to ${currentTopK}`);
            }
        }
        return true;
    } else if (cmd === '/clear') {
        newConversation();
        return true;
    } else if (cmd === '/help') {
        addMessage('console',
            'Available commands:\n' +
            '/temperature - Show current temperature\n' +
            '/temperature <value> - Set temperature (0.0-2.0)\n' +
            '/topk - Show current top-k\n' +
            '/topk <value> - Set top-k (1-200)\n' +
            '/clear - Clear conversation\n' +
            '/help - Show this help message'
        );
        return true;
    }
    return false;
}
|
| 527 |
+
|
| 528 |
+
async function sendMessage() {
    // Submit the current input: route slash commands locally, otherwise
    // record a user turn and stream the assistant's reply.
    const message = chatInput.value.trim();
    if (!message || isGenerating) return;

    // Clear and shrink the input box in either path.
    chatInput.value = '';
    chatInput.style.height = 'auto';

    // Handle slash commands without contacting the backend.
    if (message.startsWith('/')) {
        handleSlashCommand(message);
        return;
    }

    const userMessageIndex = messages.length;
    messages.push({ role: 'user', content: message });
    addMessage('user', message, userMessageIndex);

    await generateAssistantResponse();
}
|
| 549 |
+
|
| 550 |
+
// Allow sending immediately on load (the input listener re-gates it later).
sendButton.disabled = false;

// Autofocus the chat input on page load
chatInput.focus();

// Probe the backend once at startup; on failure, show a banner telling the
// user to start the engine.
fetch(`${API_URL}/health`)
    .then(response => response.json())
    .then(data => {
        console.log('Engine status:', data);
    })
    .catch(error => {
        console.error('Engine not available:', error);
        chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
    });
|
| 564 |
+
</script>
|
| 565 |
+
</body>
|
| 566 |
+
</html>
|
pyproject.toml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "nanochat"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "the minimal full-stack ChatGPT clone"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"datasets>=4.0.0",
|
| 9 |
+
"fastapi>=0.117.1",
|
| 10 |
+
"ipykernel>=7.1.0",
|
| 11 |
+
"kernels>=0.11.7",
|
| 12 |
+
"matplotlib>=3.10.8",
|
| 13 |
+
"psutil>=7.1.0",
|
| 14 |
+
"python-dotenv>=1.2.1",
|
| 15 |
+
"regex>=2025.9.1",
|
| 16 |
+
"rustbpe>=0.1.0",
|
| 17 |
+
"scipy>=1.15.3",
|
| 18 |
+
"setuptools>=80.9.0",
|
| 19 |
+
"tabulate>=0.9.0",
|
| 20 |
+
"tiktoken>=0.11.0",
|
| 21 |
+
"tokenizers>=0.22.0",
|
| 22 |
+
"torch==2.9.1",
|
| 23 |
+
"transformers>=4.57.3",
|
| 24 |
+
"uvicorn>=0.36.0",
|
| 25 |
+
"wandb>=0.21.3",
|
| 26 |
+
"zstandard>=0.25.0",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[dependency-groups]
|
| 30 |
+
dev = [
|
| 31 |
+
"pytest>=8.0.0",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
[tool.pytest.ini_options]
|
| 35 |
+
markers = [
|
| 36 |
+
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
| 37 |
+
]
|
| 38 |
+
testpaths = ["tests"]
|
| 39 |
+
python_files = ["test_*.py"]
|
| 40 |
+
python_classes = ["Test*"]
|
| 41 |
+
python_functions = ["test_*"]
|
| 42 |
+
|
| 43 |
+
# target torch to cuda 12.8 or CPU
|
| 44 |
+
[tool.uv.sources]
|
| 45 |
+
torch = [
|
| 46 |
+
{ index = "pytorch-cpu", extra = "cpu" },
|
| 47 |
+
{ index = "pytorch-cu128", extra = "gpu" },
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
[[tool.uv.index]]
|
| 51 |
+
name = "pytorch-cpu"
|
| 52 |
+
url = "https://download.pytorch.org/whl/cpu"
|
| 53 |
+
explicit = true
|
| 54 |
+
|
| 55 |
+
[[tool.uv.index]]
|
| 56 |
+
name = "pytorch-cu128"
|
| 57 |
+
url = "https://download.pytorch.org/whl/cu128"
|
| 58 |
+
explicit = true
|
| 59 |
+
|
| 60 |
+
[project.optional-dependencies]
|
| 61 |
+
cpu = [
|
| 62 |
+
"torch==2.9.1",
|
| 63 |
+
]
|
| 64 |
+
gpu = [
|
| 65 |
+
"torch==2.9.1",
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
[tool.uv]
|
| 69 |
+
conflicts = [
|
| 70 |
+
[
|
| 71 |
+
{ extra = "cpu" },
|
| 72 |
+
{ extra = "gpu" },
|
| 73 |
+
],
|
| 74 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
gradio
|