Really-Amazing commited on
Commit
1702c1d
·
verified ·
1 Parent(s): d2143a5

Initial deploy of NanoChat-ClimbMix-D12 demo

Browse files

Adding model weights (step 971), Gradio UI (app.py), and Docker configuration for Hugging Face Space hosting. Includes nanochat engine and core dependencies.

Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use slim base for smaller size on free CPU tier
2
+ FROM python:3.10-slim
3
+
4
+ # Install minimal build tools only if any C extensions are needed (nanochat is mostly pure torch/Python)
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ build-essential git curl \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ WORKDIR /app
10
+
11
+ # Copy requirements first for better caching
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt \
14
+ && pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
15
+
16
+ # Copy the rest of your project (nanochat/, weights, app.py, etc.)
17
+ COPY . .
18
+
19
+ # Clean up pip cache to save space
20
+ RUN pip cache purge
21
+
22
+ # Gradio defaults
23
+ EXPOSE 7860
24
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
25
+
26
+ # Run the app
27
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Andrej Karpathy
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from nanochat.engine import Engine
4
+
5
+ MODEL_PATH = "model_000971.pt"
6
+ META_PATH = "meta_000971.json"
7
+
8
+ print("Waking up the toddler (NanoChat-ClimbMix-D12)...")
9
+ engine = Engine(model_path=MODEL_PATH, meta_path=META_PATH, device="cpu")
10
+
11
+ def chat_fn(message):
12
+ response = engine.generate(message, max_new_tokens=300, temperature=0.85) # higher temp = more fun/confident nonsense
13
+ return response
14
+
15
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
16
+ gr.Markdown("# 🧸 NanoChat-ClimbMix-D12 – The Confident Toddler LLM")
17
+ gr.Markdown("Inspired by Andrej Karpathy's nanochat. Currently in 'Preschool Phase': 100% confident spelling, 0% reliable facts. 😂")
18
+ gr.Markdown("**Coming soon:** D14 (Elementary), D16 (Middle School), D18 (High School), D20+ (University level) – less hallucinations, more wisdom!")
19
+
20
+ with gr.Accordion("⚠️ Hallucination Disclaimer", open=True):
21
+ gr.Markdown("This model boldly answers everything — even when wrong. Enjoy the comedy! Next versions will grow up fast.")
22
+
23
+ gr.ChatInterface(
24
+ fn=chat_fn,
25
+ examples=[
26
+ "Why is the sky blue?",
27
+ "How many planets are in the solar system?",
28
+ "Write Python code to say hello world",
29
+ "Explain photosynthesis in one sentence"
30
+ ],
31
+ title="Chat with the Toddler",
32
+ description="Ask anything — it will reply with maximum confidence!"
33
+ )
34
+
35
+ if __name__ == "__main__":
36
+ demo.launch(server_name="0.0.0.0", server_port=7860)
meta_000971.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "step": 971,
3
+ "val_bpb": 0.36456499504385426,
4
+ "model_config": {
5
+ "sequence_len": 2048,
6
+ "vocab_size": 32768,
7
+ "n_layer": 12,
8
+ "n_head": 6,
9
+ "n_kv_head": 6,
10
+ "n_embd": 768,
11
+ "window_pattern": "L"
12
+ },
13
+ "user_config": {
14
+ "run": "d12_climbmix_sft_v1",
15
+ "device_type": "",
16
+ "model_tag": "d12_climbmix_v1",
17
+ "model_step": 2205,
18
+ "load_optimizer": 1,
19
+ "num_iterations": -1,
20
+ "max_seq_len": null,
21
+ "device_batch_size": 2,
22
+ "total_batch_size": null,
23
+ "embedding_lr": null,
24
+ "unembedding_lr": null,
25
+ "matrix_lr": null,
26
+ "init_lr_frac": 0.8,
27
+ "warmup_ratio": 0.0,
28
+ "warmdown_ratio": 0.5,
29
+ "final_lr_frac": 0.0,
30
+ "eval_every": -1,
31
+ "eval_tokens": 20971520,
32
+ "chatcore_every": 200,
33
+ "chatcore_max_cat": -1,
34
+ "chatcore_max_sample": 24,
35
+ "mmlu_epochs": 3,
36
+ "gsm8k_epochs": 4
37
+ }
38
+ }
model_000971.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb7c387bae513ea816333e6bce1f442491dc3ab6535f0f62ed8032a82121ec8f
3
+ size 792760433
nanochat/__init__.py ADDED
File without changes
nanochat/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (131 Bytes). View file
 
nanochat/__pycache__/checkpoint_manager.cpython-310.pyc ADDED
Binary file (6.94 kB). View file
 
nanochat/__pycache__/common.cpython-310.pyc ADDED
Binary file (9.6 kB). View file
 
nanochat/__pycache__/core_eval.cpython-310.pyc ADDED
Binary file (8.45 kB). View file
 
nanochat/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (5.36 kB). View file
 
nanochat/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.46 kB). View file
 
nanochat/__pycache__/engine.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
nanochat/__pycache__/execution.cpython-310.pyc ADDED
Binary file (8.72 kB). View file
 
nanochat/__pycache__/flash_attention.cpython-310.pyc ADDED
Binary file (4.52 kB). View file
 
nanochat/__pycache__/gpt.cpython-310.pyc ADDED
Binary file (17.2 kB). View file
 
nanochat/__pycache__/loss_eval.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
nanochat/__pycache__/optim.cpython-310.pyc ADDED
Binary file (15.5 kB). View file
 
nanochat/__pycache__/report.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
nanochat/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for saving and loading model/optim/state checkpoints.
3
+ """
4
+ import os
5
+ import re
6
+ import glob
7
+ import json
8
+ import logging
9
+ import torch
10
+
11
+ from nanochat.common import get_base_dir
12
+ from nanochat.gpt import GPT, GPTConfig
13
+ from nanochat.tokenizer import get_tokenizer
14
+ from nanochat.common import setup_default_logging
15
+
16
+ # Set up logging
17
+ setup_default_logging()
18
+ logger = logging.getLogger(__name__)
19
+ def log0(message):
20
+ if int(os.environ.get('RANK', 0)) == 0:
21
+ logger.info(message)
22
+
23
+ def _patch_missing_config_keys(model_config_kwargs):
24
+ """Add default values for new config keys missing in old checkpoints."""
25
+ # Old models were trained with full context (no sliding window)
26
+ if "window_pattern" not in model_config_kwargs:
27
+ model_config_kwargs["window_pattern"] = "L"
28
+ log0(f"Patching missing window_pattern in model config to 'L'")
29
+
30
+ def _patch_missing_keys(model_data, model_config):
31
+ """Add default values for new parameters that may be missing in old checkpoints."""
32
+ n_layer = model_config.n_layer
33
+ # resid_lambdas defaults to 1.0 (identity scaling)
34
+ if "resid_lambdas" not in model_data:
35
+ model_data["resid_lambdas"] = torch.ones(n_layer)
36
+ log0(f"Patching missing resid_lambdas in model data to 1.0")
37
+ # x0_lambdas defaults to 0.0 (disabled)
38
+ if "x0_lambdas" not in model_data:
39
+ model_data["x0_lambdas"] = torch.zeros(n_layer)
40
+ log0(f"Patching missing x0_lambdas in model data to 0.0")
41
+
42
+ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
43
+ if rank == 0:
44
+ os.makedirs(checkpoint_dir, exist_ok=True)
45
+ # Save the model state parameters
46
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
47
+ torch.save(model_data, model_path)
48
+ logger.info(f"Saved model parameters to: {model_path}")
49
+ # Save the metadata dict as json
50
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
51
+ with open(meta_path, "w", encoding="utf-8") as f:
52
+ json.dump(meta_data, f, indent=2)
53
+ logger.info(f"Saved metadata to: {meta_path}")
54
+ # Note that optimizer state is sharded across ranks, so each rank must save its own.
55
+ if optimizer_data is not None:
56
+ os.makedirs(checkpoint_dir, exist_ok=True)
57
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
58
+ torch.save(optimizer_data, optimizer_path)
59
+ logger.info(f"Saved optimizer state to: {optimizer_path}")
60
+
61
+ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
62
+ # Load the model state
63
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
64
+ model_data = torch.load(model_path, map_location=device)
65
+ # Load the optimizer state if requested
66
+ optimizer_data = None
67
+ if load_optimizer:
68
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
69
+ optimizer_data = torch.load(optimizer_path, map_location=device)
70
+ # Load the metadata
71
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
72
+ with open(meta_path, "r", encoding="utf-8") as f:
73
+ meta_data = json.load(f)
74
+ return model_data, optimizer_data, meta_data
75
+
76
+
77
+ def build_model(checkpoint_dir, step, device, phase):
78
+ """
79
+ A bunch of repetitive code to build a model from a given checkpoint.
80
+ Returns:
81
+ - base model - uncompiled, not wrapped in DDP
82
+ - tokenizer
83
+ - meta data saved during base model training
84
+ """
85
+ assert phase in ["train", "eval"], f"Invalid phase: {phase}"
86
+ model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
87
+ if device.type in {"cpu", "mps"}:
88
+ # Convert bfloat16 tensors to float for CPU inference
89
+ model_data = {
90
+ k: v.float() if v.dtype == torch.bfloat16 else v
91
+ for k, v in model_data.items()
92
+ }
93
+ # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
94
+ model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
95
+ model_config_kwargs = meta_data["model_config"]
96
+ _patch_missing_config_keys(model_config_kwargs)
97
+ log0(f"Building model with config: {model_config_kwargs}")
98
+ model_config = GPTConfig(**model_config_kwargs)
99
+ _patch_missing_keys(model_data, model_config)
100
+ with torch.device("meta"):
101
+ model = GPT(model_config)
102
+ # Load the model state
103
+ model.to_empty(device=device)
104
+ model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
105
+ model.load_state_dict(model_data, strict=True, assign=True)
106
+ # Put the model in the right training phase / mode
107
+ if phase == "eval":
108
+ model.eval()
109
+ else:
110
+ model.train()
111
+ # Load the Tokenizer
112
+ tokenizer = get_tokenizer()
113
+ # Sanity check: compatibility between model and tokenizer
114
+ assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
115
+ return model, tokenizer, meta_data
116
+
117
+
118
+ def find_largest_model(checkpoints_dir):
119
+ # attempt to guess the model tag: take the biggest model available
120
+ model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
121
+ if not model_tags:
122
+ raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
123
+ # 1) normally all model tags are of the form d<number>, try that first:
124
+ candidates = []
125
+ for model_tag in model_tags:
126
+ match = re.match(r"d(\d+)", model_tag)
127
+ if match:
128
+ model_depth = int(match.group(1))
129
+ candidates.append((model_depth, model_tag))
130
+ if candidates:
131
+ candidates.sort(key=lambda x: x[0], reverse=True)
132
+ return candidates[0][1]
133
+ # 2) if that failed, take the most recently updated model:
134
+ model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
135
+ return model_tags[0]
136
+
137
+
138
+ def find_last_step(checkpoint_dir):
139
+ # Look into checkpoint_dir and find model_<step>.pt with the highest step
140
+ checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
141
+ if not checkpoint_files:
142
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
143
+ last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
144
+ return last_step
145
+
146
+ # -----------------------------------------------------------------------------
147
+ # convenience functions that take into account nanochat's directory structure
148
+
149
+ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
150
+ if model_tag is None:
151
+ # guess the model tag by defaulting to the largest model
152
+ model_tag = find_largest_model(checkpoints_dir)
153
+ log0(f"No model tag provided, guessing model tag: {model_tag}")
154
+ checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
155
+ if step is None:
156
+ # guess the step by defaulting to the last step
157
+ step = find_last_step(checkpoint_dir)
158
+ assert step is not None, f"No checkpoints found in {checkpoint_dir}"
159
+ # build the model
160
+ log0(f"Loading model from {checkpoint_dir} with step {step}")
161
+ model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
162
+ return model, tokenizer, meta_data
163
+
164
+ def load_model(source, *args, **kwargs):
165
+ model_dir = {
166
+ "base": "base_checkpoints",
167
+ "sft": "chatsft_checkpoints",
168
+ "rl": "chatrl_checkpoints",
169
+ }[source]
170
+ base_dir = get_base_dir()
171
+ checkpoints_dir = os.path.join(base_dir, model_dir)
172
+ return load_model_from_dir(checkpoints_dir, *args, **kwargs)
173
+
174
+ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
175
+ """Load just the optimizer shard for a given rank, without re-loading the model."""
176
+ model_dir = {
177
+ "base": "base_checkpoints",
178
+ "sft": "chatsft_checkpoints",
179
+ "rl": "chatrl_checkpoints",
180
+ }[source]
181
+ base_dir = get_base_dir()
182
+ checkpoints_dir = os.path.join(base_dir, model_dir)
183
+ if model_tag is None:
184
+ model_tag = find_largest_model(checkpoints_dir)
185
+ checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
186
+ if step is None:
187
+ step = find_last_step(checkpoint_dir)
188
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
189
+ if not os.path.exists(optimizer_path):
190
+ log0(f"Optimizer checkpoint not found: {optimizer_path}")
191
+ return None
192
+ log0(f"Loading optimizer state from {optimizer_path}")
193
+ optimizer_data = torch.load(optimizer_path, map_location=device)
194
+ return optimizer_data
nanochat/common.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Common utilities for nanochat.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import logging
8
+ import urllib.request
9
+ import torch
10
+ import torch.distributed as dist
11
+ from filelock import FileLock
12
+
13
+ # The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
14
+ # Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
15
+ # Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
16
+ _DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
17
+ def _detect_compute_dtype():
18
+ env = os.environ.get("NANOCHAT_DTYPE")
19
+ if env is not None:
20
+ return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
21
+ if torch.cuda.is_available():
22
+ # bf16 requires SM 80+ (Ampere: A100, A10, etc.)
23
+ # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
24
+ capability = torch.cuda.get_device_capability()
25
+ if capability >= (8, 0):
26
+ return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
27
+ # fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
28
+ # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
29
+ return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
30
+ return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
31
+ COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
32
+
33
+ class ColoredFormatter(logging.Formatter):
34
+ """Custom formatter that adds colors to log messages."""
35
+ # ANSI color codes
36
+ COLORS = {
37
+ 'DEBUG': '\033[36m', # Cyan
38
+ 'INFO': '\033[32m', # Green
39
+ 'WARNING': '\033[33m', # Yellow
40
+ 'ERROR': '\033[31m', # Red
41
+ 'CRITICAL': '\033[35m', # Magenta
42
+ }
43
+ RESET = '\033[0m'
44
+ BOLD = '\033[1m'
45
+ def format(self, record):
46
+ # Add color to the level name
47
+ levelname = record.levelname
48
+ if levelname in self.COLORS:
49
+ record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
50
+ # Format the message
51
+ message = super().format(record)
52
+ # Add color to specific parts of the message
53
+ if levelname == 'INFO':
54
+ # Highlight numbers and percentages
55
+ message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
56
+ message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
57
+ return message
58
+
59
+ def setup_default_logging():
60
+ handler = logging.StreamHandler()
61
+ handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
62
+ logging.basicConfig(
63
+ level=logging.INFO,
64
+ handlers=[handler]
65
+ )
66
+
67
+ setup_default_logging()
68
+ logger = logging.getLogger(__name__)
69
+
70
+ def get_base_dir():
71
+ # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
72
+ if os.environ.get("NANOCHAT_BASE_DIR"):
73
+ nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
74
+ else:
75
+ home_dir = os.path.expanduser("~")
76
+ cache_dir = os.path.join(home_dir, ".cache")
77
+ nanochat_dir = os.path.join(cache_dir, "nanochat")
78
+ os.makedirs(nanochat_dir, exist_ok=True)
79
+ return nanochat_dir
80
+
81
+ def download_file_with_lock(url, filename, postprocess_fn=None):
82
+ """
83
+ Downloads a file from a URL to a local path in the base directory.
84
+ Uses a lock file to prevent concurrent downloads among multiple ranks.
85
+ """
86
+ base_dir = get_base_dir()
87
+ file_path = os.path.join(base_dir, filename)
88
+ lock_path = file_path + ".lock"
89
+
90
+ if os.path.exists(file_path):
91
+ return file_path
92
+
93
+ with FileLock(lock_path):
94
+ # Only a single rank can acquire this lock
95
+ # All other ranks block until it is released
96
+
97
+ # Recheck after acquiring lock
98
+ if os.path.exists(file_path):
99
+ return file_path
100
+
101
+ # Download the content as bytes
102
+ print(f"Downloading {url}...")
103
+ with urllib.request.urlopen(url) as response:
104
+ content = response.read() # bytes
105
+
106
+ # Write to local file
107
+ with open(file_path, 'wb') as f:
108
+ f.write(content)
109
+ print(f"Downloaded to {file_path}")
110
+
111
+ # Run the postprocess function if provided
112
+ if postprocess_fn is not None:
113
+ postprocess_fn(file_path)
114
+
115
+ return file_path
116
+
117
+ def print0(s="",**kwargs):
118
+ ddp_rank = int(os.environ.get('RANK', 0))
119
+ if ddp_rank == 0:
120
+ print(s, **kwargs)
121
+
122
+ def print_banner():
123
+ # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
124
+ banner = """
125
+ █████ █████
126
+ ░░███ ░░███
127
+ ████████ ██████ ██��█████ ██████ ██████ ░███████ ██████ ███████
128
+ ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
129
+ ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
130
+ ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
131
+ ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
132
+ ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
133
+ """
134
+ print0(banner)
135
+
136
+ def is_ddp_requested() -> bool:
137
+ """
138
+ True if launched by torchrun (env present), even before init.
139
+ Used to decide whether we *should* initialize a PG.
140
+ """
141
+ return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
142
+
143
+ def is_ddp_initialized() -> bool:
144
+ """
145
+ True if torch.distributed is available and the process group is initialized.
146
+ Used at cleanup to avoid destroying a non-existent PG.
147
+ """
148
+ return dist.is_available() and dist.is_initialized()
149
+
150
+ def get_dist_info():
151
+ if is_ddp_requested():
152
+ # We rely on torchrun's env to decide if we SHOULD init.
153
+ # (Initialization itself happens in compute init.)
154
+ assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
155
+ ddp_rank = int(os.environ['RANK'])
156
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
157
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
158
+ return True, ddp_rank, ddp_local_rank, ddp_world_size
159
+ else:
160
+ return False, 0, 0, 1
161
+
162
+ def autodetect_device_type():
163
+ # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
164
+ if torch.cuda.is_available():
165
+ device_type = "cuda"
166
+ elif torch.backends.mps.is_available():
167
+ device_type = "mps"
168
+ else:
169
+ device_type = "cpu"
170
+ print0(f"Autodetected device type: {device_type}")
171
+ return device_type
172
+
173
+ def compute_init(device_type="cuda"): # cuda|cpu|mps
174
+ """Basic initialization that we keep doing over and over, so make common."""
175
+
176
+ assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
177
+ if device_type == "cuda":
178
+ assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
179
+ if device_type == "mps":
180
+ assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
181
+
182
+ # Reproducibility
183
+ # Note that we set the global seeds here, but most of the code uses explicit rng objects.
184
+ # The only place where global rng might be used is nn.Module initialization of the model weights.
185
+ torch.manual_seed(42)
186
+ if device_type == "cuda":
187
+ torch.cuda.manual_seed(42)
188
+ # skipping full reproducibility for now, possibly investigate slowdown later
189
+ # torch.use_deterministic_algorithms(True)
190
+
191
+ # Precision
192
+ if device_type == "cuda":
193
+ torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
194
+
195
+ # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
196
+ is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
197
+ if is_ddp_requested and device_type == "cuda":
198
+ device = torch.device("cuda", ddp_local_rank)
199
+ torch.cuda.set_device(device) # make "cuda" default to this device
200
+ dist.init_process_group(backend="nccl", device_id=device)
201
+ dist.barrier()
202
+ else:
203
+ device = torch.device(device_type) # mps|cpu
204
+
205
+ if ddp_rank == 0:
206
+ logger.info(f"Distributed world size: {ddp_world_size}")
207
+
208
+ return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
209
+
210
+ def compute_cleanup():
211
+ """Companion function to compute_init, to clean things up before script exit"""
212
+ if is_ddp_initialized():
213
+ dist.destroy_process_group()
214
+
215
+ class DummyWandb:
216
+ """Useful if we wish to not use wandb but have all the same signatures"""
217
+ def __init__(self):
218
+ pass
219
+ def log(self, *args, **kwargs):
220
+ pass
221
+ def finish(self):
222
+ pass
223
+
224
+ # hardcoded BF16 peak flops for various GPUs
225
+ # inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
226
+ # and PR: https://github.com/karpathy/nanochat/pull/147
227
+ def get_peak_flops(device_name: str) -> float:
228
+ name = device_name.lower()
229
+
230
+ # Table order matters: more specific patterns first.
231
+ _PEAK_FLOPS_TABLE = (
232
+ # NVIDIA Blackwell
233
+ (["gb200"], 2.5e15),
234
+ (["grace blackwell"], 2.5e15),
235
+ (["b200"], 2.25e15),
236
+ (["b100"], 1.8e15),
237
+ # NVIDIA Hopper
238
+ (["h200", "nvl"], 836e12),
239
+ (["h200", "pcie"], 836e12),
240
+ (["h200"], 989e12),
241
+ (["h100", "nvl"], 835e12),
242
+ (["h100", "pcie"], 756e12),
243
+ (["h100"], 989e12),
244
+ (["h800", "nvl"], 989e12),
245
+ (["h800"], 756e12),
246
+ # NVIDIA Ampere data center
247
+ (["a100"], 312e12),
248
+ (["a800"], 312e12),
249
+ (["a40"], 149.7e12),
250
+ (["a30"], 165e12),
251
+ # NVIDIA Ada data center
252
+ (["l40s"], 362e12),
253
+ (["l40-s"], 362e12),
254
+ (["l40 s"], 362e12),
255
+ (["l4"], 121e12),
256
+ # AMD CDNA accelerators
257
+ (["mi355"], 2.5e15),
258
+ (["mi325"], 1.3074e15),
259
+ (["mi300x"], 1.3074e15),
260
+ (["mi300a"], 980.6e12),
261
+ (["mi250x"], 383e12),
262
+ (["mi250"], 362.1e12),
263
+ # Consumer RTX
264
+ (["5090"], 209.5e12),
265
+ (["4090"], 165.2e12),
266
+ (["3090"], 71e12),
267
+ )
268
+ for patterns, flops in _PEAK_FLOPS_TABLE:
269
+ if all(p in name for p in patterns):
270
+ return flops
271
+ if "data center gpu max 1550" in name:
272
+ # Ponte Vecchio (PVC) - dynamic based on compute units
273
+ max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
274
+ return 512 * max_comp_units * 1300 * 10**6
275
+
276
+ # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
277
+ logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
278
+ return float('inf')
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions for evaluating the CORE metric, as described in the DCLM paper.
3
+ https://arxiv.org/abs/2406.11794
4
+
5
+ TODOs:
6
+ - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
7
+ """
8
+ import random
9
+
10
+ from jinja2 import Template
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # Prompt rendering utilities
16
+
17
+ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
18
+ """Render complete prompts for a multiple choice question"""
19
+ template_str = """
20
+ {%- for example in fewshot_examples -%}
21
+ {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
22
+
23
+ {% endfor -%}
24
+ {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
25
+ template = Template(template_str)
26
+ fewshot_examples = fewshot_examples or []
27
+ context = {
28
+ 'fewshot_examples': fewshot_examples,
29
+ 'continuation_delimiter': continuation_delimiter,
30
+ 'item': item
31
+ }
32
+ prompts = [template.render(choice=choice, **context) for choice in item['choices']]
33
+ return prompts
34
+
35
+
36
+ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
37
+ """Render complete prompts for a schema question"""
38
+ template_str = """
39
+ {%- for example in fewshot_examples -%}
40
+ {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
41
+
42
+ {% endfor -%}
43
+ {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
44
+ template = Template(template_str)
45
+ fewshot_examples = fewshot_examples or []
46
+ context = {
47
+ 'fewshot_examples': fewshot_examples,
48
+ 'continuation_delimiter': continuation_delimiter,
49
+ 'item': item
50
+ }
51
+ prompts = [template.render(context=context_option, **context)
52
+ for context_option in item['context_options']]
53
+ return prompts
54
+
55
+
56
+ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
57
+ """
58
+ Render complete prompt for a language modeling task.
59
+ Notice that we manually trim the context in the template,
60
+ which in some datasets seems to have trailing whitespace (which we don't want).
61
+ """
62
+ template_str = """
63
+ {%- for example in fewshot_examples -%}
64
+ {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
65
+
66
+ {% endfor -%}
67
+ {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
68
+ template = Template(template_str)
69
+ fewshot_examples = fewshot_examples or []
70
+ context = {
71
+ 'fewshot_examples': fewshot_examples,
72
+ 'continuation_delimiter': continuation_delimiter,
73
+ 'item': item
74
+ }
75
+ # Return two prompts: without and with the continuation
76
+ prompt_without = template.render(include_continuation=False, **context)
77
+ prompt_with = template.render(include_continuation=True, **context)
78
+ # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
79
+ # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
80
+ # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
81
+ # to detect the final continuation. Tokenizers...
82
+ prompt_without = prompt_without.strip()
83
+ return [prompt_without, prompt_with]
84
+
85
+
86
+ def find_common_length(token_sequences, direction='left'):
87
+ """
88
+ Find the length of the common prefix or suffix across token sequences
89
+ - direction: 'left' for prefix, 'right' for suffix
90
+ """
91
+ min_len = min(len(seq) for seq in token_sequences)
92
+ indices = {
93
+ 'left': range(min_len),
94
+ 'right': range(-1, -min_len-1, -1)
95
+ }[direction]
96
+ # Find the first position where the token sequences differ
97
+ for i, idx in enumerate(indices):
98
+ token = token_sequences[0][idx]
99
+ if not all(seq[idx] == token for seq in token_sequences):
100
+ return i
101
+ return min_len
102
+
103
+
104
+ def stack_sequences(tokens, pad_token_id):
105
+ """Stack up a list of token sequences, pad to longest on the right"""
106
+ bsz, seq_len = len(tokens), max(len(x) for x in tokens)
107
+ input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
108
+ for i, x in enumerate(tokens):
109
+ input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
110
+ return input_ids
111
+
112
+
113
+ def batch_sequences_mc(tokenizer, prompts):
114
+ # In multiple choice, contexts are the same but the continuation is different (common prefix)
115
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
116
+ # figure out the start and end of each continuation
117
+ answer_start_idx = find_common_length(tokens, direction='left')
118
+ start_indices = [answer_start_idx] * len(prompts)
119
+ end_indices = [len(x) for x in tokens]
120
+ return tokens, start_indices, end_indices
121
+
122
+
123
+ def batch_sequences_schema(tokenizer, prompts):
124
+ # In schema tasks, contexts vary but continuation is the same (common suffix)
125
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
126
+ # figure out the start and end of each context
127
+ suffix_length = find_common_length(tokens, direction='right')
128
+ end_indices = [len(x) for x in tokens]
129
+ start_indices = [ei - suffix_length for ei in end_indices]
130
+ return tokens, start_indices, end_indices
131
+
132
+
133
+ def batch_sequences_lm(tokenizer, prompts):
134
+ # In LM tasks, we have two prompts: without and with continuation
135
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
136
+ tokens_without, tokens_with = tokens
137
+ start_idx, end_idx = len(tokens_without), len(tokens_with)
138
+ assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
139
+ assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
140
+ # we only need the with continuation prompt in the LM task, i.e. batch size of 1
141
+ return [tokens_with], [start_idx], [end_idx]
142
+
143
+
144
+ @torch.no_grad()
145
+ def forward_model(model, input_ids):
146
+ """
147
+ Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
148
+ The last column of losses is set to nan because we don't have autoregressive targets there.
149
+ """
150
+ batch_size, seq_len = input_ids.size()
151
+ outputs = model(input_ids)
152
+ # Roll the tensor to the left by one position to get the (autoregressive) target ids
153
+ target_ids = torch.roll(input_ids, shifts=-1, dims=1)
154
+ # Calculate cross entropy at all positions
155
+ losses = torch.nn.functional.cross_entropy(
156
+ outputs.view(batch_size * seq_len, -1),
157
+ target_ids.view(batch_size * seq_len),
158
+ reduction='none'
159
+ ).view(batch_size, seq_len)
160
+ # Set the last column to be nan because there is no autoregressive loss there
161
+ losses[:, -1] = float('nan')
162
+ # Get the argmax predictions at each position
163
+ predictions = outputs.argmax(dim=-1)
164
+ return losses, predictions
165
+
166
+
167
+ @torch.no_grad()
168
+ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
169
+ """Evaluate a single example, return True if correct, False otherwise"""
170
+ item = data[idx]
171
+ task_type = task_meta['task_type']
172
+ num_fewshot = task_meta['num_fewshot']
173
+ continuation_delimiter = task_meta['continuation_delimiter']
174
+
175
+ # Sample few-shot examples (excluding current item)
176
+ fewshot_examples = []
177
+ if num_fewshot > 0:
178
+ rng = random.Random(1234 + idx)
179
+ available_indices = [i for i in range(len(data)) if i != idx]
180
+ fewshot_indices = rng.sample(available_indices, min(num_fewshot, len(available_indices)))
181
+ fewshot_examples = [data[i] for i in fewshot_indices]
182
+
183
+ # Render prompts and batch sequences based on task type
184
+ if task_type == 'multiple_choice':
185
+ prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
186
+ tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
187
+ elif task_type == 'schema':
188
+ prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
189
+ tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
190
+ elif task_type == 'language_modeling':
191
+ prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
192
+ tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
193
+ else:
194
+ raise ValueError(f"Unsupported task type: {task_type}")
195
+
196
+ # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
197
+ # In these cases, we have to truncate sequences to max length and adjust the indices
198
+ if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
199
+ max_tokens = model.max_seq_len
200
+ new_tokens, new_start_idxs, new_end_idxs = [], [], []
201
+ for t, s, e in zip(tokens, start_idxs, end_idxs):
202
+ if len(t) > max_tokens:
203
+ num_to_crop = len(t) - max_tokens
204
+ new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
205
+ new_start_idxs.append(s - num_to_crop) # shift the indices down
206
+ new_end_idxs.append(e - num_to_crop)
207
+ assert s - num_to_crop >= 0, "this should never happen right?"
208
+ assert e - num_to_crop >= 0, "this should never happen right?"
209
+ else:
210
+ new_tokens.append(t) # keep unchanged
211
+ new_start_idxs.append(s)
212
+ new_end_idxs.append(e)
213
+ tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
214
+
215
+ # Stack up all the sequences into a batch
216
+ pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
217
+ input_ids = stack_sequences(tokens, pad_token_id)
218
+ input_ids = input_ids.to(device)
219
+
220
+ # Forward the model, get the autoregressive loss and argmax prediction at each token
221
+ losses, predictions = forward_model(model, input_ids)
222
+
223
+ # See if the losses/predictions come out correctly
224
+ if task_type == 'language_modeling':
225
+ # language modeling task is currently always batch size 1
226
+ si = start_idxs[0]
227
+ ei = end_idxs[0]
228
+ # predictions[i] predict input_ids[i+1] autoregressively
229
+ predicted_tokens = predictions[0, si-1:ei-1]
230
+ actual_tokens = input_ids[0, si:ei]
231
+ is_correct = torch.all(predicted_tokens == actual_tokens).item()
232
+ elif task_type in ['multiple_choice', 'schema']:
233
+ # For MC/schema: find the option with lowest average loss
234
+ mean_losses = [losses[i, si-1:ei-1].mean().item()
235
+ for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
236
+ pred_idx = mean_losses.index(min(mean_losses))
237
+ is_correct = pred_idx == item['gold']
238
+ else:
239
+ raise ValueError(f"Unsupported task type: {task_type}")
240
+
241
+ return is_correct
242
+
243
+
244
+ def evaluate_task(model, tokenizer, data, device, task_meta):
245
+ """
246
+ This function is responsible for evaluating one task across many examples.
247
+ It also handles dispatch to all processes if the script is run with torchrun.
248
+ """
249
+ rank = dist.get_rank() if dist.is_initialized() else 0
250
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
251
+ correct = torch.zeros(len(data), dtype=torch.float32, device=device)
252
+ # stride the examples to each rank
253
+ for idx in range(rank, len(data), world_size):
254
+ is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
255
+ correct[idx] = float(is_correct)
256
+ # sync results across all the processes if running distributed
257
+ if world_size > 1:
258
+ dist.barrier()
259
+ dist.all_reduce(correct, op=dist.ReduceOp.SUM)
260
+ # compute the mean
261
+ mean_correct = correct.mean().item()
262
+ return mean_correct
nanochat/dataloader.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Distributed dataloaders for pretraining.
3
+
4
+ BOS-aligned bestfit:
5
+ - Every row starts with BOS token
6
+ - Documents packed using best-fit algorithm to minimize cropping
7
+ - When no document fits remaining space, crops a document to fill exactly
8
+ - 100% utilization (no padding), ~35% tokens cropped at T=2048
9
+
10
+ Compared to the original tokenizing_distributed_data_loader:
11
+ BOS-aligned loses ~35% of tokens to cropping, but ensures that
12
+ there are fewer "confusing" tokens in the train/val batches as every token can
13
+ now attend back to the BOS token and sees the full context of the document.
14
+
15
+ Fallback to the original if you have very limited data AND long documents:
16
+ https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
17
+ """
18
+
19
+ import torch
20
+ import pyarrow.parquet as pq
21
+
22
+ from nanochat.common import get_dist_info
23
+ from nanochat.dataset import list_parquet_files
24
+
25
+ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
26
+ """
27
+ Infinite iterator over document batches (list of text strings) from parquet files.
28
+
29
+ Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
30
+ where text_batch is a list of document strings, indices track position for resumption,
31
+ and epoch counts how many times we've cycled through the dataset (starts at 1).
32
+ """
33
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
34
+
35
+ warn_on_legacy = ddp_rank == 0 and split == "train" # rank 0 on train split will warn on legacy
36
+ parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
37
+ assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
38
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
39
+
40
+ resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
41
+ resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
42
+ resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
43
+ first_pass = True
44
+ pq_idx = resume_pq_idx
45
+ epoch = resume_epoch
46
+
47
+ while True: # iterate infinitely (multi-epoch)
48
+ pq_idx = resume_pq_idx if first_pass else 0
49
+ while pq_idx < len(parquet_paths):
50
+ filepath = parquet_paths[pq_idx]
51
+ pf = pq.ParquetFile(filepath)
52
+ # Start from resume point if resuming on same file, otherwise from DDP rank
53
+ if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
54
+ base_idx = resume_rg_idx // ddp_world_size
55
+ base_idx += 1 # advance by 1 so we don't repeat data after resuming
56
+ rg_idx = base_idx * ddp_world_size + ddp_rank
57
+ if rg_idx >= pf.num_row_groups:
58
+ pq_idx += 1
59
+ continue
60
+ resume_rg_idx = None # only do this once
61
+ else:
62
+ rg_idx = ddp_rank
63
+ while rg_idx < pf.num_row_groups:
64
+ rg = pf.read_row_group(rg_idx)
65
+ batch = rg.column('text').to_pylist()
66
+ for i in range(0, len(batch), tokenizer_batch_size):
67
+ yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
68
+ rg_idx += ddp_world_size
69
+ pq_idx += 1
70
+ first_pass = False
71
+ epoch += 1
72
+
73
+
74
+ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
75
+ tokenizer, B, T, split,
76
+ tokenizer_threads=4, tokenizer_batch_size=128,
77
+ device="cuda", resume_state_dict=None,
78
+ buffer_size=1000
79
+ ):
80
+ """
81
+ BOS-aligned dataloader with Best-Fit Cropping.
82
+
83
+ Reduces token waste compared to simple greedy cropping by searching a buffer
84
+ for documents that fit well, while maintaining 100% utilization (no padding).
85
+
86
+ Algorithm for each row:
87
+ 1. From buffered docs, pick the LARGEST doc that fits entirely
88
+ 2. Repeat until no doc fits
89
+ 3. When nothing fits, crop a doc to fill remaining space exactly
90
+
91
+ Key properties:
92
+ - Every row starts with BOS
93
+ - 100% utilization (no padding, every token is trained on)
94
+ - Approximately 35% of all tokens are discarded due to cropping
95
+ """
96
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
97
+
98
+ row_capacity = T + 1
99
+ batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
100
+ bos_token = tokenizer.get_bos_token_id()
101
+ doc_buffer = []
102
+ pq_idx, rg_idx, epoch = 0, 0, 1
103
+
104
+ def refill_buffer():
105
+ nonlocal pq_idx, rg_idx, epoch
106
+ doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
107
+ token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
108
+ for tokens in token_lists:
109
+ doc_buffer.append(tokens)
110
+
111
+ # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
112
+ # This gives us contiguous views and a single HtoD transfer
113
+ use_cuda = device == "cuda"
114
+ row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
115
+ cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
116
+ gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
117
+ cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
118
+ cpu_targets = cpu_buffer[B * T:].view(B, T)
119
+ inputs = gpu_buffer[:B * T].view(B, T)
120
+ targets = gpu_buffer[B * T:].view(B, T)
121
+
122
+ while True:
123
+ for row_idx in range(B):
124
+ pos = 0
125
+ while pos < row_capacity:
126
+ # Ensure buffer has documents
127
+ while len(doc_buffer) < buffer_size:
128
+ refill_buffer()
129
+
130
+ remaining = row_capacity - pos
131
+
132
+ # Find largest doc that fits entirely
133
+ best_idx = -1
134
+ best_len = 0
135
+ for i, doc in enumerate(doc_buffer):
136
+ doc_len = len(doc)
137
+ if doc_len <= remaining and doc_len > best_len:
138
+ best_idx = i
139
+ best_len = doc_len
140
+
141
+ if best_idx >= 0:
142
+ doc = doc_buffer.pop(best_idx)
143
+ doc_len = len(doc)
144
+ row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
145
+ pos += doc_len
146
+ else:
147
+ # No doc fits - crop shortest in buffer to fill remaining and minimize waste
148
+ shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
149
+ doc = doc_buffer.pop(shortest_idx)
150
+ row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
151
+ pos += remaining
152
+
153
+ # Copy to pinned CPU buffer, then single HtoD transfer
154
+ cpu_inputs.copy_(row_buffer[:, :-1])
155
+ cpu_targets.copy_(row_buffer[:, 1:])
156
+
157
+ state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
158
+
159
+ # Single HtoD copy into persistent GPU buffer and yield
160
+ gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
161
+ yield inputs, targets, state_dict
162
+
163
+ def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
164
+ """Helper that omits state_dict from yields."""
165
+ for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
166
+ yield inputs, targets
nanochat/dataset.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The base/pretraining dataset is a set of parquet files.
3
+ This file contains utilities for:
4
+ - iterating over the parquet files and yielding documents from it
5
+ - download the files on demand if they are not on disk
6
+
7
+ For details of how the dataset was prepared, see `repackage_data_reference.py`.
8
+ """
9
+
10
+ import os
11
+ import argparse
12
+ import time
13
+ import requests
14
+ import pyarrow.parquet as pq
15
+ from multiprocessing import Pool
16
+
17
+ from nanochat.common import get_base_dir
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # The specifics of the current pretraining dataset
21
+
22
+ # The URL on the internet where the data is hosted and downloaded from on demand
23
+ BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
24
+ MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
25
+ index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
26
+ base_dir = get_base_dir()
27
+ DATA_DIR = os.path.join(base_dir, "base_data_climbmix")
28
+
29
+ # -----------------------------------------------------------------------------
30
+ # These functions are useful utilities to other modules, can/should be imported
31
+
32
+ def list_parquet_files(data_dir=None, warn_on_legacy=False):
33
+ """ Looks into a data dir and returns full paths to all parquet files. """
34
+ data_dir = DATA_DIR if data_dir is None else data_dir
35
+
36
+ # Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B
37
+ # This code will eventually be deleted.
38
+ if not os.path.exists(data_dir):
39
+ if warn_on_legacy:
40
+ print()
41
+ print("=" * 80)
42
+ print(" WARNING: DATASET UPGRADE REQUIRED")
43
+ print("=" * 80)
44
+ print()
45
+ print(f" Could not find: {data_dir}")
46
+ print()
47
+ print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.")
48
+ print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.")
49
+ print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:")
50
+ print()
51
+ print(" python -m nanochat.dataset -n 170 # download ~170 shards, enough for GPT-2, adjust as desired")
52
+ print(" python -m scripts.tok_train # re-train tokenizer on new ClimbMix data")
53
+ print()
54
+ print(" For now, falling back to your old FinewebEdu-100B dataset...")
55
+ print("=" * 80)
56
+ print()
57
+ # attempt a fallback to the legacy data directory
58
+ data_dir = os.path.join(base_dir, "base_data")
59
+
60
+ parquet_files = sorted([
61
+ f for f in os.listdir(data_dir)
62
+ if f.endswith('.parquet') and not f.endswith('.tmp')
63
+ ])
64
+ parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
65
+ return parquet_paths
66
+
67
+ def parquets_iter_batched(split, start=0, step=1):
68
+ """
69
+ Iterate through the dataset, in batches of underlying row_groups for efficiency.
70
+ - split can be "train" or "val". the last parquet file will be val.
71
+ - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
72
+ """
73
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
74
+ parquet_paths = list_parquet_files()
75
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
76
+ for filepath in parquet_paths:
77
+ pf = pq.ParquetFile(filepath)
78
+ for rg_idx in range(start, pf.num_row_groups, step):
79
+ rg = pf.read_row_group(rg_idx)
80
+ texts = rg.column('text').to_pylist()
81
+ yield texts
82
+
83
+ # -----------------------------------------------------------------------------
84
+ def download_single_file(index):
85
+ """ Downloads a single file index, with some backoff """
86
+
87
+ # Construct the local filepath for this file and skip if it already exists
88
+ filename = index_to_filename(index)
89
+ filepath = os.path.join(DATA_DIR, filename)
90
+ if os.path.exists(filepath):
91
+ print(f"Skipping {filepath} (already exists)")
92
+ return True
93
+
94
+ # Construct the remote URL for this file
95
+ url = f"{BASE_URL}/{filename}"
96
+ print(f"Downloading {filename}...")
97
+
98
+ # Download with retries
99
+ max_attempts = 5
100
+ for attempt in range(1, max_attempts + 1):
101
+ try:
102
+ response = requests.get(url, stream=True, timeout=30)
103
+ response.raise_for_status()
104
+ # Write to temporary file first
105
+ temp_path = filepath + f".tmp"
106
+ with open(temp_path, 'wb') as f:
107
+ for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
108
+ if chunk:
109
+ f.write(chunk)
110
+ # Move temp file to final location
111
+ os.rename(temp_path, filepath)
112
+ print(f"Successfully downloaded {filename}")
113
+ return True
114
+
115
+ except (requests.RequestException, IOError) as e:
116
+ print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
117
+ # Clean up any partial files
118
+ for path in [filepath + f".tmp", filepath]:
119
+ if os.path.exists(path):
120
+ try:
121
+ os.remove(path)
122
+ except:
123
+ pass
124
+ # Try a few times with exponential backoff: 2^attempt seconds
125
+ if attempt < max_attempts:
126
+ wait_time = 2 ** attempt
127
+ print(f"Waiting {wait_time} seconds before retry...")
128
+ time.sleep(wait_time)
129
+ else:
130
+ print(f"Failed to download {filename} after {max_attempts} attempts")
131
+ return False
132
+
133
+ return False
134
+
135
+
136
+ if __name__ == "__main__":
137
+ parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
138
+ parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1), -1 = disable")
139
+ parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
140
+ args = parser.parse_args()
141
+
142
+ # Prepare the output directory
143
+ os.makedirs(DATA_DIR, exist_ok=True)
144
+
145
+ # The way this works is that the user specifies the number of train shards to download via the -n flag.
146
+ # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
147
+ num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
148
+ ids_to_download = list(range(num_train_shards))
149
+ ids_to_download.append(MAX_SHARD) # always download the validation shard
150
+
151
+ # Download the shards
152
+ print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
153
+ print(f"Target directory: {DATA_DIR}")
154
+ print()
155
+ with Pool(processes=args.num_workers) as pool:
156
+ results = pool.map(download_single_file, ids_to_download)
157
+
158
+ # Report results
159
+ successful = sum(1 for success in results if success)
160
+ print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
nanochat/engine.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engine for efficient inference of our models.
3
+
4
+ Everything works around token sequences:
5
+ - The user can send token sequences to the engine
6
+ - The engine returns the next token
7
+
8
+ Notes:
9
+ - The engine knows nothing about tokenization, it's purely token id sequences.
10
+
11
+ The whole thing is made as efficient as possible.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import signal
17
+ import warnings
18
+ from contextlib import contextmanager
19
+ from collections import deque
20
+ from nanochat.common import compute_init, autodetect_device_type
21
+ from nanochat.checkpoint_manager import load_model
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Calculator tool helpers
25
+ @contextmanager
26
+ def timeout(duration, formula):
27
+ def timeout_handler(signum, frame):
28
+ raise Exception(f"'{formula}': timed out after {duration} seconds")
29
+
30
+ signal.signal(signal.SIGALRM, timeout_handler)
31
+ signal.alarm(duration)
32
+ yield
33
+ signal.alarm(0)
34
+
35
+ def eval_with_timeout(formula, max_time=3):
36
+ try:
37
+ with timeout(max_time, formula):
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter("ignore", SyntaxWarning)
40
+ return eval(formula, {"__builtins__": {}}, {})
41
+ except Exception as e:
42
+ signal.alarm(0)
43
+ # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
44
+ return None
45
+
46
+ def use_calculator(expr):
47
+ """
48
+ Evaluate a Python expression safely.
49
+ Supports both math expressions and string operations like .count()
50
+ """
51
+ # Remove commas from numbers
52
+ expr = expr.replace(",", "")
53
+
54
+ # Check if it's a pure math expression (old behavior)
55
+ if all([x in "0123456789*+-/.() " for x in expr]):
56
+ if "**" in expr: # disallow power operator
57
+ return None
58
+ return eval_with_timeout(expr)
59
+
60
+ # Check if it's a string operation we support
61
+ # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
62
+ allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
63
+ if not all([x in allowed_chars for x in expr]):
64
+ return None
65
+
66
+ # Disallow dangerous patterns
67
+ dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
68
+ 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
69
+ 'getattr', 'setattr', 'delattr', 'hasattr']
70
+ expr_lower = expr.lower()
71
+ if any(pattern in expr_lower for pattern in dangerous_patterns):
72
+ return None
73
+
74
+ # Only allow .count() method for now (can expand later)
75
+ if '.count(' not in expr:
76
+ return None
77
+
78
+ # Evaluate with timeout
79
+ return eval_with_timeout(expr)
80
+
81
+ # -----------------------------------------------------------------------------
82
+ class KVCache:
83
+ """
84
+ KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
85
+
86
+ Key differences from FA2-style cache:
87
+ - Tensors are (B, T, H, D) not (B, H, T, D)
88
+ - FA3 updates the cache in-place during flash_attn_with_kvcache
89
+ - Position tracked per batch element via cache_seqlens tensor
90
+ """
91
+
92
+ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
93
+ self.batch_size = batch_size
94
+ self.max_seq_len = seq_len
95
+ self.n_layers = num_layers
96
+ self.n_heads = num_heads
97
+ self.head_dim = head_dim
98
+ # Pre-allocate cache tensors: (n_layers, B, T, H, D)
99
+ self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
100
+ self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
101
+ # Current sequence length per batch element (FA3 needs int32)
102
+ self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
103
+
104
+ def reset(self):
105
+ """Reset cache to empty state."""
106
+ self.cache_seqlens.zero_()
107
+
108
+ def get_pos(self):
109
+ """Get current position (assumes all batch elements at same position)."""
110
+ return self.cache_seqlens[0].item()
111
+
112
+ def get_layer_cache(self, layer_idx):
113
+ """Return (k_cache, v_cache) views for a specific layer."""
114
+ return self.k_cache[layer_idx], self.v_cache[layer_idx]
115
+
116
+ def advance(self, num_tokens):
117
+ """Advance the cache position by num_tokens."""
118
+ self.cache_seqlens += num_tokens
119
+
120
+ def prefill(self, other):
121
+ """
122
+ Copy cached KV from another cache into this one.
123
+ Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
124
+ """
125
+ assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
126
+ assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
127
+ assert self.max_seq_len >= other.max_seq_len
128
+ other_pos = other.get_pos()
129
+ self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
130
+ self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
131
+ self.cache_seqlens.fill_(other_pos)
132
+
133
+ # -----------------------------------------------------------------------------
134
+ @torch.inference_mode()
135
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
136
+ """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
137
+ assert temperature >= 0.0, "temperature must be non-negative"
138
+ if temperature == 0.0:
139
+ return torch.argmax(logits, dim=-1, keepdim=True)
140
+ if top_k is not None and top_k > 0:
141
+ k = min(top_k, logits.size(-1))
142
+ vals, idx = torch.topk(logits, k, dim=-1)
143
+ vals = vals / temperature
144
+ probs = F.softmax(vals, dim=-1)
145
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
146
+ return idx.gather(1, choice)
147
+ else:
148
+ logits = logits / temperature
149
+ probs = F.softmax(logits, dim=-1)
150
+ return torch.multinomial(probs, num_samples=1, generator=rng)
151
+
152
+ # -----------------------------------------------------------------------------
153
+
154
+ class RowState:
155
+ # Per-row state tracking during generation
156
+ def __init__(self, current_tokens=None):
157
+ self.current_tokens = current_tokens or [] # Current token sequence for this row
158
+ self.forced_tokens = deque() # Queue of tokens to force inject
159
+ self.in_python_block = False # Whether we are inside a python block
160
+ self.python_expr_tokens = [] # Tokens of the current python expression
161
+ self.completed = False # Whether this row has completed generation
162
+
163
+ class Engine:
164
+
165
+ def __init__(self, model, tokenizer):
166
+ self.model = model
167
+ self.tokenizer = tokenizer # needed for tool use
168
+
169
+ @torch.inference_mode()
170
+ def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
171
+ """Same as generate, but does single prefill and then clones the KV cache."""
172
+ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
173
+ device = self.model.get_device()
174
+ # NOTE: setting the dtype here and in this way is an ugly hack.
175
+ # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
176
+ # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
177
+ # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
178
+ # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
179
+ # In particular, the KVCache should allocate its tensors lazily
180
+ dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
181
+ rng = torch.Generator(device=device)
182
+ rng.manual_seed(seed)
183
+
184
+ # Get the special tokens we need to coordinate the tool use state machine
185
+ get_special = lambda s: self.tokenizer.encode_special(s)
186
+ python_start = get_special("<|python_start|>")
187
+ python_end = get_special("<|python_end|>")
188
+ output_start = get_special("<|output_start|>")
189
+ output_end = get_special("<|output_end|>")
190
+ assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
191
+ bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
192
+
193
+ # 1) Run a batch 1 prefill of the prompt tokens
194
+ m = self.model.config
195
+ kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
196
+ kv_cache_prefill = KVCache(
197
+ batch_size=1,
198
+ seq_len=len(tokens),
199
+ device=device,
200
+ dtype=dtype,
201
+ **kv_model_kwargs,
202
+ )
203
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
204
+ logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
205
+ logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
206
+
207
+ # 2) Replicate the KV cache for each sample/row
208
+ kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
209
+ kv_cache_decode = KVCache(
210
+ batch_size=num_samples,
211
+ seq_len=kv_length_hint,
212
+ device=device,
213
+ dtype=dtype,
214
+ **kv_model_kwargs,
215
+ )
216
+ kv_cache_decode.prefill(kv_cache_prefill)
217
+ del kv_cache_prefill # no need to keep this memory around
218
+
219
+ # 3) Initialize states for each sample
220
+ row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
221
+
222
+ # 4) Main generation loop
223
+ num_generated = 0
224
+ while True:
225
+ # Stop condition: we've reached max tokens
226
+ if max_tokens is not None and num_generated >= max_tokens:
227
+ break
228
+ # Stop condition: all rows are completed
229
+ if all(state.completed for state in row_states):
230
+ break
231
+
232
+ # Sample the next token for each row
233
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
234
+ sampled_tokens = next_ids[:, 0].tolist()
235
+
236
+ # Process each row: choose the next token, update state, optional tool use
237
+ token_column = [] # contains the next token id along each row
238
+ token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
239
+ for i, state in enumerate(row_states):
240
+ # Select the next token in this row
241
+ is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
242
+ token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
243
+ next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
244
+ token_column.append(next_token)
245
+ # Update the state of this row to include the next token
246
+ state.current_tokens.append(next_token)
247
+ # On <|assistant_end|> or <|bos|>, mark the row as completed
248
+ if next_token == assistant_end or next_token == bos:
249
+ state.completed = True
250
+ # Handle tool logic
251
+ if next_token == python_start:
252
+ state.in_python_block = True
253
+ state.python_expr_tokens = []
254
+ elif next_token == python_end and state.in_python_block:
255
+ state.in_python_block = False
256
+ if state.python_expr_tokens:
257
+ expr = self.tokenizer.decode(state.python_expr_tokens)
258
+ result = use_calculator(expr)
259
+ if result is not None:
260
+ result_tokens = self.tokenizer.encode(str(result))
261
+ state.forced_tokens.append(output_start)
262
+ state.forced_tokens.extend(result_tokens)
263
+ state.forced_tokens.append(output_end)
264
+ state.python_expr_tokens = []
265
+ elif state.in_python_block:
266
+ state.python_expr_tokens.append(next_token)
267
+
268
+ # Yield the token column
269
+ yield token_column, token_masks
270
+ num_generated += 1
271
+
272
+ # Prepare logits for next iteration
273
+ ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
274
+ logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
275
+
276
+ def generate_batch(self, tokens, num_samples=1, **kwargs):
277
+ """
278
+ Non-streaming batch generation that just returns the final token sequences.
279
+ Returns a list of token sequences (list of lists of ints).
280
+ Terminal tokens (assistant_end, bos) are not included in the results.
281
+ """
282
+ assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
283
+ bos = self.tokenizer.get_bos_token_id()
284
+ results = [tokens.copy() for _ in range(num_samples)]
285
+ masks = [[0] * len(tokens) for _ in range(num_samples)]
286
+ completed = [False] * num_samples
287
+ for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
288
+ for i, (token, mask) in enumerate(zip(token_column, token_masks)):
289
+ if not completed[i]:
290
+ if token == assistant_end or token == bos:
291
+ completed[i] = True
292
+ else:
293
+ results[i].append(token)
294
+ masks[i].append(mask)
295
+ # Stop if all rows are completed
296
+ if all(completed):
297
+ break
298
+ return results, masks
299
+
300
+
301
+ if __name__ == "__main__":
302
+ """
303
+ Quick inline test to make sure that the naive/slow model.generate function
304
+ is equivalent to the faster Engine.generate function here.
305
+ """
306
+ import time
307
+ # init compute
308
+ device_type = autodetect_device_type()
309
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
310
+ # load the model and tokenizer
311
+ model, tokenizer, meta = load_model("base", device, phase="eval")
312
+ bos_token_id = tokenizer.get_bos_token_id()
313
+ # common hyperparameters
314
+ kwargs = dict(max_tokens=64, temperature=0.0)
315
+ # set the starting prompt
316
+ prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
317
+ # generate the reference sequence using the model.generate() function
318
+ generated_tokens = []
319
+ torch.cuda.synchronize()
320
+ t0 = time.time()
321
+ stream = model.generate(prompt_tokens, **kwargs)
322
+ for token in stream:
323
+ generated_tokens.append(token)
324
+ chunk = tokenizer.decode([token])
325
+ print(chunk, end="", flush=True)
326
+ print()
327
+ torch.cuda.synchronize()
328
+ t1 = time.time()
329
+ print(f"Reference time: {t1 - t0:.2f}s")
330
+ reference_ids = generated_tokens
331
+ # generate tokens with Engine
332
+ generated_tokens = []
333
+ engine = Engine(model, tokenizer)
334
+ stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
335
+ torch.cuda.synchronize()
336
+ t0 = time.time()
337
+ for token_column, token_masks in stream:
338
+ token = token_column[0] # only print out the first row
339
+ generated_tokens.append(token)
340
+ chunk = tokenizer.decode([token])
341
+ print(chunk, end="", flush=True)
342
+ print()
343
+ torch.cuda.synchronize()
344
+ t1 = time.time()
345
+ print(f"Engine time: {t1 - t0:.2f}s")
346
+ # compare the two sequences
347
+ for i in range(len(reference_ids)):
348
+ if reference_ids[i] != generated_tokens[i]:
349
+ print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
350
+ break
351
+ print(f"Match: {reference_ids == generated_tokens}")
nanochat/execution.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
3
+ Adapted from OpenAI HumanEval code:
4
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5
+
6
+ What is covered:
7
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
8
+ - Execution is limited by a timeout to stop infinite loops
9
+ - Memory limits are enforced by default (256MB)
10
+ - stdout and stderr are captured and returned
11
+ - Code runs in a temporary directory that is deleted afterwards
12
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13
+
14
+ What is not covered:
15
+ - Not a true security sandbox
16
+ - Network access is not blocked (e.g. sockets could be opened)
17
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
18
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
19
+
20
+ Overall this sandbox is good for evaluation of generated code and protects against
21
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
22
+ """
23
+
24
+ import contextlib
25
+ import faulthandler
26
+ import io
27
+ import multiprocessing
28
+ import os
29
+ import platform
30
+ import signal
31
+ import tempfile
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ # -----------------------------------------------------------------------------
36
+
37
+ @dataclass
38
+ class ExecutionResult:
39
+ """Result of executing Python code in a sandbox."""
40
+ success: bool
41
+ stdout: str
42
+ stderr: str
43
+ error: Optional[str] = None
44
+ timeout: bool = False
45
+ memory_exceeded: bool = False
46
+
47
+ def __repr__(self):
48
+ parts = []
49
+ parts.append(f"ExecutionResult(success={self.success}")
50
+ if self.timeout:
51
+ parts.append(", timeout=True")
52
+ if self.memory_exceeded:
53
+ parts.append(", memory_exceeded=True")
54
+ if self.error:
55
+ parts.append(f", error={self.error!r}")
56
+ if self.stdout:
57
+ parts.append(f", stdout={self.stdout!r}")
58
+ if self.stderr:
59
+ parts.append(f", stderr={self.stderr!r}")
60
+ parts.append(")")
61
+ return "".join(parts)
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def time_limit(seconds: float):
66
+ def signal_handler(signum, frame):
67
+ raise TimeoutException("Timed out!")
68
+
69
+ signal.setitimer(signal.ITIMER_REAL, seconds)
70
+ signal.signal(signal.SIGALRM, signal_handler)
71
+ try:
72
+ yield
73
+ finally:
74
+ signal.setitimer(signal.ITIMER_REAL, 0)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def capture_io():
79
+ """Capture stdout and stderr, and disable stdin."""
80
+ stdout_capture = io.StringIO()
81
+ stderr_capture = io.StringIO()
82
+ stdin_block = WriteOnlyStringIO()
83
+ with contextlib.redirect_stdout(stdout_capture):
84
+ with contextlib.redirect_stderr(stderr_capture):
85
+ with redirect_stdin(stdin_block):
86
+ yield stdout_capture, stderr_capture
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def create_tempdir():
91
+ with tempfile.TemporaryDirectory() as dirname:
92
+ with chdir(dirname):
93
+ yield dirname
94
+
95
+
96
+ class TimeoutException(Exception):
97
+ pass
98
+
99
+
100
+ class WriteOnlyStringIO(io.StringIO):
101
+ """StringIO that throws an exception when it's read from"""
102
+
103
+ def read(self, *args, **kwargs):
104
+ raise IOError
105
+
106
+ def readline(self, *args, **kwargs):
107
+ raise IOError
108
+
109
+ def readlines(self, *args, **kwargs):
110
+ raise IOError
111
+
112
+ def readable(self, *args, **kwargs):
113
+ """Returns True if the IO object can be read."""
114
+ return False
115
+
116
+
117
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
118
+ _stream = "stdin"
119
+
120
+
121
+ @contextlib.contextmanager
122
+ def chdir(root):
123
+ if root == ".":
124
+ yield
125
+ return
126
+ cwd = os.getcwd()
127
+ os.chdir(root)
128
+ try:
129
+ yield
130
+ finally:
131
+ os.chdir(cwd)
132
+
133
+
134
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
135
+ """
136
+ This disables various destructive functions and prevents the generated code
137
+ from interfering with the test (e.g. fork bomb, killing other processes,
138
+ removing filesystem files, etc.)
139
+
140
+ WARNING
141
+ This function is NOT a security sandbox. Untrusted code, including, model-
142
+ generated code, should not be blindly executed outside of one. See the
143
+ Codex paper for more information about OpenAI's code sandbox, and proceed
144
+ with caution.
145
+ """
146
+
147
+ if platform.uname().system != "Darwin":
148
+ # These resource limit calls seem to fail on macOS (Darwin), skip?
149
+ import resource
150
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
151
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
152
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
153
+
154
+ faulthandler.disable()
155
+
156
+ import builtins
157
+
158
+ builtins.exit = None
159
+ builtins.quit = None
160
+
161
+ import os
162
+
163
+ os.environ["OMP_NUM_THREADS"] = "1"
164
+
165
+ os.kill = None
166
+ os.system = None
167
+ os.putenv = None
168
+ os.remove = None
169
+ os.removedirs = None
170
+ os.rmdir = None
171
+ os.fchdir = None
172
+ os.setuid = None
173
+ os.fork = None
174
+ os.forkpty = None
175
+ os.killpg = None
176
+ os.rename = None
177
+ os.renames = None
178
+ os.truncate = None
179
+ os.replace = None
180
+ os.unlink = None
181
+ os.fchmod = None
182
+ os.fchown = None
183
+ os.chmod = None
184
+ os.chown = None
185
+ os.chroot = None
186
+ os.fchdir = None
187
+ os.lchflags = None
188
+ os.lchmod = None
189
+ os.lchown = None
190
+ os.getcwd = None
191
+ os.chdir = None
192
+
193
+ import shutil
194
+
195
+ shutil.rmtree = None
196
+ shutil.move = None
197
+ shutil.chown = None
198
+
199
+ import subprocess
200
+
201
+ subprocess.Popen = None # type: ignore
202
+
203
+ __builtins__["help"] = None
204
+
205
+ import sys
206
+
207
+ sys.modules["ipdb"] = None
208
+ sys.modules["joblib"] = None
209
+ sys.modules["resource"] = None
210
+ sys.modules["psutil"] = None
211
+ sys.modules["tkinter"] = None
212
+
213
+
214
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
215
+ """Execute code in a subprocess with safety guards. Results are written to result_dict."""
216
+ with create_tempdir():
217
+
218
+ # These system calls are needed when cleaning up tempdir.
219
+ import os
220
+ import shutil
221
+
222
+ rmtree = shutil.rmtree
223
+ rmdir = os.rmdir
224
+ chdir = os.chdir
225
+ unlink = os.unlink
226
+
227
+ # Disable functionalities that can make destructive changes to the test.
228
+ reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
229
+
230
+ # Default to failure
231
+ result_dict.update({
232
+ "success": False,
233
+ "stdout": "",
234
+ "stderr": "",
235
+ "timeout": False,
236
+ "memory_exceeded": False,
237
+ "error": None,
238
+ })
239
+
240
+ try:
241
+ exec_globals = {}
242
+ with capture_io() as (stdout_capture, stderr_capture):
243
+ with time_limit(timeout):
244
+ # WARNING
245
+ # This program exists to execute untrusted model-generated code. Although
246
+ # it is highly unlikely that model-generated code will do something overtly
247
+ # malicious in response to this test suite, model-generated code may act
248
+ # destructively due to a lack of model capability or alignment.
249
+ # Users are strongly encouraged to sandbox this evaluation suite so that it
250
+ # does not perform destructive actions on their host or network. For more
251
+ # information on how OpenAI sandboxes its code, see the accompanying paper.
252
+ # Once you have read this disclaimer and taken appropriate precautions,
253
+ # uncomment the following line and proceed at your own risk:
254
+ exec(code, exec_globals)
255
+
256
+ result_dict.update({
257
+ "success": True,
258
+ "stdout": stdout_capture.getvalue(),
259
+ "stderr": stderr_capture.getvalue(),
260
+ })
261
+
262
+ except TimeoutException:
263
+ result_dict.update({
264
+ "timeout": True,
265
+ "error": "Execution timed out",
266
+ })
267
+
268
+ except MemoryError as e:
269
+ result_dict.update({
270
+ "memory_exceeded": True,
271
+ "error": f"Memory limit exceeded: {e}",
272
+ })
273
+
274
+ except BaseException as e:
275
+ result_dict.update({
276
+ "error": f"{type(e).__name__}: {e}",
277
+ })
278
+
279
+ # Needed for cleaning up.
280
+ shutil.rmtree = rmtree
281
+ os.rmdir = rmdir
282
+ os.chdir = chdir
283
+ os.unlink = unlink
284
+
285
+
286
+ def execute_code(
287
+ code: str,
288
+ timeout: float = 5.0, # 5 seconds default
289
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
290
+ ) -> ExecutionResult:
291
+ """
292
+ Execute Python code in a sandboxed environment.
293
+
294
+ Args:
295
+ code: Python code to execute as a string
296
+ timeout: Maximum execution time in seconds (default: 5.0)
297
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
298
+
299
+ Returns:
300
+ ExecutionResult with success status, stdout/stderr, and error information
301
+
302
+ Example:
303
+ >>> result = execute_code("print('hello world')")
304
+ >>> result.success
305
+ True
306
+ >>> result.stdout
307
+ 'hello world\\n'
308
+ """
309
+
310
+ manager = multiprocessing.Manager()
311
+ result_dict = manager.dict()
312
+
313
+ p = multiprocessing.Process(
314
+ target=_unsafe_execute,
315
+ args=(code, timeout, maximum_memory_bytes, result_dict)
316
+ )
317
+ p.start()
318
+ p.join(timeout=timeout + 1)
319
+
320
+ if p.is_alive():
321
+ p.kill()
322
+ return ExecutionResult(
323
+ success=False,
324
+ stdout="",
325
+ stderr="",
326
+ error="Execution timed out (process killed)",
327
+ timeout=True,
328
+ memory_exceeded=False,
329
+ )
330
+
331
+ if not result_dict:
332
+ return ExecutionResult(
333
+ success=False,
334
+ stdout="",
335
+ stderr="",
336
+ error="Execution failed (no result returned)",
337
+ timeout=True,
338
+ memory_exceeded=False,
339
+ )
340
+
341
+ return ExecutionResult(
342
+ success=result_dict["success"],
343
+ stdout=result_dict["stdout"],
344
+ stderr=result_dict["stderr"],
345
+ error=result_dict["error"],
346
+ timeout=result_dict["timeout"],
347
+ memory_exceeded=result_dict["memory_exceeded"],
348
+ )
349
+
nanochat/flash_attention.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Flash Attention interface with automatic FA3/SDPA switching.
3
+
4
+ Exports `flash_attn` module that matches the FA3 API exactly, but falls back
5
+ to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
6
+
7
+ Usage (drop-in replacement for FA3):
8
+ from nanochat.flash_attention import flash_attn
9
+
10
+ # Training (no KV cache)
11
+ y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
12
+
13
+ # Inference (with KV cache)
14
+ y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
15
+ """
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
+ # =============================================================================
21
+ # Detection: Try to load FA3 on Hopper+ GPUs
22
+ # =============================================================================
23
+ def _load_flash_attention_3():
24
+ """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
25
+ if not torch.cuda.is_available():
26
+ return None
27
+ try:
28
+ major, _ = torch.cuda.get_device_capability()
29
+ # FA3 kernels are compiled for Hopper (sm90) only
30
+ # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
31
+ if major != 9:
32
+ return None
33
+ import os
34
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
35
+ from kernels import get_kernel
36
+ return get_kernel('varunneal/flash-attention-3').flash_attn_interface
37
+ except Exception:
38
+ return None
39
+
40
+
41
+ _fa3 = _load_flash_attention_3()
42
+ HAS_FA3 = _fa3 is not None
43
+
44
+ # Override for testing: set to 'fa3', 'sdpa', or None (auto)
45
+ _override_impl = None
46
+
47
+
48
+ def _resolve_use_fa3():
49
+ """Decide once whether to use FA3, based on availability, override, and dtype."""
50
+ if _override_impl == 'fa3':
51
+ assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
52
+ return True
53
+ if _override_impl == 'sdpa':
54
+ return False
55
+ if HAS_FA3:
56
+ # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
57
+ from nanochat.common import COMPUTE_DTYPE
58
+ if COMPUTE_DTYPE == torch.bfloat16:
59
+ return True
60
+ return False
61
+ return False
62
+
63
+ USE_FA3 = _resolve_use_fa3()
64
+
65
+
66
+ # =============================================================================
67
+ # SDPA helpers
68
+ # =============================================================================
69
+ def _sdpa_attention(q, k, v, window_size, enable_gqa):
70
+ """
71
+ SDPA attention with sliding window support.
72
+ q, k, v are (B, H, T, D) format.
73
+ """
74
+ Tq = q.size(2)
75
+ Tk = k.size(2)
76
+ window = window_size[0]
77
+
78
+ # Full context, same length
79
+ if (window < 0 or window >= Tq) and Tq == Tk:
80
+ return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
81
+
82
+ # Single token generation
83
+ if Tq == 1:
84
+ if window >= 0 and window < Tk:
85
+ # window is "left" tokens we need to include (window + 1) keys total
86
+ start = max(0, Tk - (window + 1))
87
+ k = k[:, :, start:, :]
88
+ v = v[:, :, start:, :]
89
+ return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
90
+
91
+ # Need explicit mask for sliding window/chunk inference
92
+ device = q.device
93
+ # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
94
+ row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
95
+ col_idx = torch.arange(Tk, device=device).unsqueeze(0)
96
+ mask = col_idx <= row_idx
97
+
98
+ # sliding window (left)
99
+ if window >= 0 and window < Tk:
100
+ mask = mask & ((row_idx - col_idx) <= window)
101
+
102
+ return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
103
+
104
+ # =============================================================================
105
+ # Public API: Same interface as FA3
106
+ # =============================================================================
107
+ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
108
+ """
109
+ Flash Attention for training (no KV cache).
110
+
111
+ Args:
112
+ q, k, v: Tensors of shape (B, T, H, D)
113
+ causal: Whether to use causal masking
114
+ window_size: (left, right) sliding window. -1 means unlimited.
115
+
116
+ Returns:
117
+ Output tensor of shape (B, T, H, D)
118
+ """
119
+ if USE_FA3:
120
+ return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
121
+
122
+ # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
123
+ q = q.transpose(1, 2)
124
+ k = k.transpose(1, 2)
125
+ v = v.transpose(1, 2)
126
+ enable_gqa = q.size(1) != k.size(1)
127
+ y = _sdpa_attention(q, k, v, window_size, enable_gqa)
128
+ return y.transpose(1, 2) # back to (B, T, H, D)
129
+
130
+
131
+ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
132
+ causal=False, window_size=(-1, -1)):
133
+ """
134
+ Flash Attention with KV cache for inference.
135
+
136
+ FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
137
+
138
+ Args:
139
+ q: Queries, shape (B, T_new, H, D)
140
+ k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
141
+ k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
142
+ cache_seqlens: Current position in cache, shape (B,) int32
143
+ causal: Whether to use causal masking
144
+ window_size: (left, right) sliding window. -1 means unlimited.
145
+
146
+ Returns:
147
+ Output tensor of shape (B, T_new, H, D)
148
+ """
149
+ if USE_FA3:
150
+ return _fa3.flash_attn_with_kvcache(
151
+ q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
152
+ causal=causal, window_size=window_size
153
+ )
154
+
155
+ # SDPA fallback: manually manage KV cache
156
+ B, T_new, H, D = q.shape
157
+ pos = cache_seqlens[0].item() # assume uniform position across batch
158
+
159
+ # Insert new k, v into cache (in-place, matching FA3 behavior)
160
+ if k is not None and v is not None:
161
+ k_cache[:, pos:pos+T_new, :, :] = k
162
+ v_cache[:, pos:pos+T_new, :, :] = v
163
+
164
+ # Get full cache up to current position + new tokens
165
+ end_pos = pos + T_new
166
+ k_full = k_cache[:, :end_pos, :, :]
167
+ v_full = v_cache[:, :end_pos, :, :]
168
+
169
+ # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
170
+ q_sdpa = q.transpose(1, 2)
171
+ k_sdpa = k_full.transpose(1, 2)
172
+ v_sdpa = v_full.transpose(1, 2)
173
+
174
+ enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
175
+ y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
176
+
177
+ return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
178
+
179
+
180
+ # =============================================================================
181
+ # Export: flash_attn module interface (drop-in replacement for FA3)
182
+ # =============================================================================
183
+ from types import SimpleNamespace
184
+ flash_attn = SimpleNamespace(
185
+ flash_attn_func=flash_attn_func,
186
+ flash_attn_with_kvcache=flash_attn_with_kvcache,
187
+ )
nanochat/fp8.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal FP8 training for nanochat — tensorwise dynamic scaling only.
2
+
3
+ Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
4
+ We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
5
+ generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
6
+ subclass dispatch tables, etc.)
7
+
8
+ How FP8 training works
9
+ ======================
10
+ A standard Linear layer does one matmul in forward and two in backward:
11
+ forward: output = input @ weight.T
12
+ backward: grad_input = grad_output @ weight
13
+ grad_weight= grad_output.T @ input
14
+
15
+ FP8 training wraps each of these three matmuls with:
16
+ 1. Compute scale = FP8_MAX / max(|tensor|) for each operand
17
+ 2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
18
+ 3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
19
+ 4. Dequantize: _scaled_mm handles this internally using the inverse scales
20
+
21
+ The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
22
+ torchao is just orchestration around these primitives. We can call them directly.
23
+
24
+ FP8 dtype choice
25
+ ================
26
+ There are two FP8 formats. We use both, following the standard convention:
27
+ - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
28
+ Higher precision (more mantissa bits), used for input and weight.
29
+ - float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
30
+ Wider range (more exponent bits), used for gradients which can be large.
31
+
32
+ torch._scaled_mm layout requirements
33
+ =====================================
34
+ The cuBLAS FP8 kernel requires specific memory layouts:
35
+ - First argument (A): must be row-major (contiguous)
36
+ - Second argument (B): must be column-major (B.t().contiguous().t())
37
+ If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
38
+ already column-major — no copy needed. Otherwise we use _to_col_major().
39
+
40
+ How this differs from torchao's approach
41
+ ========================================
42
+ torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
43
+ of torch.Tensor that bundles FP8 data + scale + metadata. It implements
44
+ __torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
45
+ reshape, clone, ...) and handles it in FP8-aware fashion. When you call
46
+ output = input @ weight.T
47
+ the @ operator dispatches to aten.mm, which gets intercepted and routed to
48
+ torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
49
+ a handler for every tensor operation that might touch an FP8 tensor.
50
+
51
+ We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
52
+ full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
53
+ full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
54
+ opaque node rather than trying to trace inside.
55
+
56
+ The trade-off is in how torch.compile sees the two approaches:
57
+ - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
58
+ sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
59
+ nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
60
+ amax computation with the preceding layer's activation function).
61
+ - ours: compile sees a single opaque call. It can optimize everything around
62
+ the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.
63
+
64
+ Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical.
65
+ The difference is only in the "glue" ops (amax, scale, cast) which are tiny
66
+ compared to the matmul. In practice this means our version is slightly faster
67
+ (less compilation overhead, no tensor subclass dispatch cost) but can produce
68
+ subtly different floating-point rounding paths under torch.compile, since Inductor
69
+ generates a different graph. Numerics are bitwise identical in eager mode.
70
+ """
71
+
72
+ import torch
73
+ import torch.nn as nn
74
+
75
+ from nanochat.common import COMPUTE_DTYPE
76
+
77
+ # Avoid division by zero when computing scale from an all-zeros tensor
78
+ EPS = 1e-12
79
+
80
+
81
+ @torch.no_grad()
82
+ def _to_fp8(x, fp8_dtype):
83
+ """Dynamically quantize a tensor to FP8 using tensorwise scaling.
84
+
85
+ "Tensorwise" means one scalar scale for the entire tensor (as opposed to
86
+ "rowwise" which computes a separate scale per row). Tensorwise is faster
87
+ because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.
88
+
89
+ Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
90
+ """
91
+ fp8_max = torch.finfo(fp8_dtype).max
92
+ # Compute the max absolute value across the entire tensor
93
+ amax = x.float().abs().max()
94
+ # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
95
+ # ensure consistent numerics between torch.compile and eager mode.
96
+ # (torchao does the same upcast — without it, compile/eager can diverge)
97
+ scale = fp8_max / amax.double().clamp(min=EPS)
98
+ scale = scale.float()
99
+ # Quantize: scale into FP8 range, saturate (clamp prevents overflow when
100
+ # casting — PyTorch's default is to wrap, not saturate), then cast to FP8
101
+ x_scaled = x.float() * scale
102
+ x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
103
+ x_fp8 = x_clamped.to(fp8_dtype)
104
+ # _scaled_mm expects the *inverse* of our scale (it multiplies by this to
105
+ # convert FP8 values back to the original range during the matmul)
106
+ inv_scale = scale.reciprocal()
107
+ return x_fp8, inv_scale
108
+
109
+
110
+ def _to_col_major(x):
111
+ """Rearrange a 2D tensor's memory to column-major layout.
112
+
113
+ torch._scaled_mm requires its second operand in column-major layout.
114
+ The trick: transpose -> contiguous (forces a copy in transposed order)
115
+ -> transpose back. The result has the same logical shape but column-major
116
+ strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
117
+ """
118
+ return x.t().contiguous().t()
119
+
120
+
121
+ # allow_in_graph tells torch.compile to treat this as an opaque operation —
122
+ # dynamo won't try to decompose it into smaller ops. See the module docstring
123
+ # for how this differs from torchao's tensor subclass approach.
124
+ @torch._dynamo.allow_in_graph
125
+ class _Float8Matmul(torch.autograd.Function):
126
+ """Custom autograd for the three FP8 GEMMs of a Linear layer.
127
+
128
+ The forward quantizes input and weight to FP8 and saves
129
+ the quantized tensors + scales for backward.
130
+ """
131
+
132
+ @staticmethod
133
+ def forward(ctx, input_2d, weight):
134
+ # Quantize both operands to e4m3 (higher precision format)
135
+ input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
136
+ weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
137
+ ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
138
+
139
+ # output = input @ weight.T
140
+ # input_fp8 is [B, K] contiguous = row-major (good for first arg)
141
+ # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
142
+ # strides (1, K) = column-major (good for second arg, no copy needed!)
143
+ output = torch._scaled_mm(
144
+ input_fp8,
145
+ weight_fp8.t(),
146
+ scale_a=input_inv,
147
+ scale_b=weight_inv,
148
+ out_dtype=input_2d.dtype,
149
+ # use_fast_accum=True accumulates the dot products in lower precision.
150
+ # Slightly less accurate but measurably faster. Standard practice for
151
+ # the forward pass; we use False in backward for more precise gradients.
152
+ use_fast_accum=True,
153
+ )
154
+ return output
155
+
156
+ @staticmethod
157
+ def backward(ctx, grad_output):
158
+ in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
159
+
160
+ # === GEMM 1: grad_input = grad_output @ weight ===
161
+ # Shapes: [B, N] @ [N, K] -> [B, K]
162
+ # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
163
+ go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
164
+ # go_fp8 is [B, N] contiguous = row-major, good for first arg
165
+ # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
166
+ w_col = _to_col_major(w_fp8)
167
+ grad_input = torch._scaled_mm(
168
+ go_fp8,
169
+ w_col,
170
+ scale_a=go_inv,
171
+ scale_b=w_inv,
172
+ out_dtype=grad_output.dtype,
173
+ use_fast_accum=False,
174
+ )
175
+
176
+ # === GEMM 2: grad_weight = grad_output.T @ input ===
177
+ # Shapes: [N, B] @ [B, K] -> [N, K]
178
+ # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
179
+ # Transposing gives column-major, but first arg needs row-major,
180
+ # so we must call .contiguous() to physically rearrange the memory.
181
+ go_T = go_fp8.t().contiguous() # [N, B] row-major
182
+ in_col = _to_col_major(in_fp8) # [B, K] column-major
183
+ grad_weight = torch._scaled_mm(
184
+ go_T,
185
+ in_col,
186
+ scale_a=go_inv,
187
+ scale_b=in_inv,
188
+ out_dtype=grad_output.dtype,
189
+ use_fast_accum=False,
190
+ )
191
+
192
+ return grad_input, grad_weight
193
+
194
+
195
+ class Float8Linear(nn.Linear):
196
+ """Drop-in nn.Linear replacement that does FP8 compute.
197
+
198
+ Weights and biases remain in their original precision (e.g. fp32/bf16).
199
+ Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
200
+ """
201
+
202
+ def forward(self, input):
203
+ # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
204
+ # reduced precision input, and we no longer rely on autocast to do this.
205
+ input = input.to(COMPUTE_DTYPE)
206
+ # _scaled_mm only works on 2D tensors, so flatten batch dimensions
207
+ orig_shape = input.shape
208
+ input_2d = input.reshape(-1, orig_shape[-1])
209
+ output = _Float8Matmul.apply(input_2d, self.weight)
210
+ output = output.reshape(*orig_shape[:-1], output.shape[-1])
211
+ if self.bias is not None:
212
+ output = output + self.bias.to(output.dtype)
213
+ return output
214
+
215
+ @classmethod
216
+ def from_float(cls, mod):
217
+ """Create Float8Linear from nn.Linear, sharing the same weight and bias.
218
+
219
+ Uses meta device to avoid allocating a temporary weight tensor — we
220
+ create the module shell on meta (shapes/dtypes only, no memory), then
221
+ point .weight and .bias to the original module's parameters.
222
+ """
223
+ with torch.device("meta"):
224
+ new_mod = cls(mod.in_features, mod.out_features, bias=False)
225
+ new_mod.weight = mod.weight
226
+ new_mod.bias = mod.bias
227
+ return new_mod
228
+
229
+
230
+ class Float8LinearConfig:
231
+ """Minimal config matching torchao's API. Only tensorwise recipe is supported."""
232
+
233
+ @staticmethod
234
+ def from_recipe_name(recipe_name):
235
+ if recipe_name != "tensorwise":
236
+ raise ValueError(
237
+ f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
238
+ f"Rowwise/axiswise recipes require the full torchao library."
239
+ )
240
+ return Float8LinearConfig()
241
+
242
+
243
+ def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
244
+ """Replace nn.Linear layers with Float8Linear throughout a module.
245
+
246
+ Walks the module tree in post-order (children before parents) and swaps
247
+ each nn.Linear that passes the optional filter. The new Float8Linear shares
248
+ the original weight and bias tensors — no copies, no extra memory.
249
+
250
+ Args:
251
+ module: Root module to convert.
252
+ config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
253
+ module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
254
+ are converted. Common use: skip layers with dims not divisible by 16
255
+ (hardware requirement for FP8 matmuls on H100).
256
+ """
257
+ def _convert(mod, prefix=""):
258
+ for name, child in mod.named_children():
259
+ fqn = f"{prefix}.{name}" if prefix else name
260
+ _convert(child, fqn)
261
+ if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
262
+ if module_filter_fn is None or module_filter_fn(child, fqn):
263
+ setattr(mod, name, Float8Linear.from_float(child))
264
+
265
+ _convert(module)
266
+ return module
nanochat/gpt.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Group-Query Attention (GQA) support for more efficient inference
12
+ - Flash Attention 3 integration
13
+ """
14
+
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
23
+ from nanochat.optim import MuonAdamW, DistMuonAdamW
24
+
25
+ # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
26
+ from nanochat.flash_attention import flash_attn
27
+
28
+ @dataclass
29
+ class GPTConfig:
30
+ sequence_len: int = 2048
31
+ vocab_size: int = 32768
32
+ n_layer: int = 12
33
+ n_head: int = 6 # number of query heads
34
+ n_kv_head: int = 6 # number of key/value heads (GQA)
35
+ n_embd: int = 768
36
+ # Sliding window attention pattern string, tiled across layers. Final layer always L.
37
+ # Characters: L=long (full context), S=short (half context)
38
+ # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
39
+ window_pattern: str = "SSSL"
40
+
41
+
42
+ def norm(x):
43
+ return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
44
+
45
+ class Linear(nn.Linear):
46
+ """nn.Linear that casts weights to match input dtype in forward.
47
+ Replaces autocast: master weights stay fp32 for optimizer precision,
48
+ but matmuls run in the activation dtype (typically bf16 from embeddings)."""
49
+ def forward(self, x):
50
+ return F.linear(x, self.weight.to(dtype=x.dtype))
51
+
52
+
53
+ def has_ve(layer_idx, n_layer):
54
+ """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
55
+ return layer_idx % 2 == (n_layer - 1) % 2
56
+
57
+ def apply_rotary_emb(x, cos, sin):
58
+ assert x.ndim == 4 # multihead attention
59
+ d = x.shape[3] // 2
60
+ x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
61
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
62
+ y2 = x1 * (-sin) + x2 * cos
63
+ return torch.cat([y1, y2], 3)
64
+
65
+ class CausalSelfAttention(nn.Module):
66
+ def __init__(self, config, layer_idx):
67
+ super().__init__()
68
+ self.layer_idx = layer_idx
69
+ self.n_head = config.n_head
70
+ self.n_kv_head = config.n_kv_head
71
+ self.n_embd = config.n_embd
72
+ self.head_dim = self.n_embd // self.n_head
73
+ assert self.n_embd % self.n_head == 0
74
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
75
+ self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
76
+ self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
77
+ self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
78
+ self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
79
+ self.ve_gate_channels = 12
80
+ self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
81
+
82
+ def forward(self, x, ve, cos_sin, window_size, kv_cache):
83
+ B, T, C = x.size()
84
+
85
+ # Project the input to get queries, keys, and values
86
+ # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
87
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
88
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
89
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
90
+
91
+ # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
92
+ if ve is not None:
93
+ ve = ve.view(B, T, self.n_kv_head, self.head_dim)
94
+ gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
95
+ v = v + gate.unsqueeze(-1) * ve
96
+
97
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
98
+ cos, sin = cos_sin
99
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
100
+ q, k = norm(q), norm(k) # QK norm
101
+ q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
102
+ k = k * 1.15
103
+
104
+ # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
105
+ # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
106
+ if kv_cache is None:
107
+ # Training: causal attention with optional sliding window
108
+ y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
109
+ else:
110
+ # Inference: use flash_attn_with_kvcache which handles cache management
111
+ k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
112
+ y = flash_attn.flash_attn_with_kvcache(
113
+ q, k_cache, v_cache,
114
+ k=k, v=v,
115
+ cache_seqlens=kv_cache.cache_seqlens,
116
+ causal=True,
117
+ window_size=window_size,
118
+ )
119
+ # Advance position after last layer processes
120
+ if self.layer_idx == kv_cache.n_layers - 1:
121
+ kv_cache.advance(T)
122
+
123
+ # Re-assemble the heads and project back to residual stream
124
+ y = y.contiguous().view(B, T, -1)
125
+ y = self.c_proj(y)
126
+ return y
127
+
128
+
129
+ class MLP(nn.Module):
130
+ def __init__(self, config):
131
+ super().__init__()
132
+ self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
133
+ self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
134
+
135
+ def forward(self, x):
136
+ x = self.c_fc(x)
137
+ x = F.relu(x).square()
138
+ x = self.c_proj(x)
139
+ return x
140
+
141
+
142
+ class Block(nn.Module):
143
+ def __init__(self, config, layer_idx):
144
+ super().__init__()
145
+ self.attn = CausalSelfAttention(config, layer_idx)
146
+ self.mlp = MLP(config)
147
+
148
+ def forward(self, x, ve, cos_sin, window_size, kv_cache):
149
+ x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
150
+ x = x + self.mlp(norm(x))
151
+ return x
152
+
153
+
154
+ class GPT(nn.Module):
155
+ def __init__(self, config, pad_vocab_size_to=64):
156
+ """
157
+ NOTE a major footgun: this __init__ function runs in meta device context (!!)
158
+ Therefore, any calculations inside here are shapes and dtypes only, no actual data.
159
+ => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
160
+ """
161
+ super().__init__()
162
+ self.config = config
163
+ # Compute per-layer window sizes for sliding window attention
164
+ # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
165
+ self.window_sizes = self._compute_window_sizes(config)
166
+ # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
167
+ # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
168
+ padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
169
+ if padded_vocab_size != config.vocab_size:
170
+ print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
171
+ self.transformer = nn.ModuleDict({
172
+ "wte": nn.Embedding(padded_vocab_size, config.n_embd),
173
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
174
+ })
175
+ self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
176
+ # Per-layer learnable scalars (inspired by modded-nanogpt)
177
+ # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
178
+ # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
179
+ # Separate parameters so they can have different optimizer treatment
180
+ self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
181
+ self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
182
+ # Value embeddings (ResFormer-style): alternating layers, last layer always included
183
+ head_dim = config.n_embd // config.n_head
184
+ kv_dim = config.n_kv_head * head_dim
185
+ self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
186
+ # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
187
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
188
+ # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
189
+ # In the future we can dynamically grow the cache, for now it's fine.
190
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
191
+ head_dim = config.n_embd // config.n_head
192
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
193
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
194
+ self.register_buffer("sin", sin, persistent=False)
195
+
196
+ @torch.no_grad()
197
+ def init_weights(self):
198
+ """
199
+ Initialize the full model in this one function for maximum clarity.
200
+
201
+ wte (embedding): normal, std=1.0
202
+ lm_head: normal, std=0.001
203
+ for each block:
204
+ attn.c_q: uniform, std=1/sqrt(n_embd)
205
+ attn.c_k: uniform, std=1/sqrt(n_embd)
206
+ attn.c_v: uniform, std=1/sqrt(n_embd)
207
+ attn.c_proj: zeros
208
+ mlp.c_fc: uniform, std=1/sqrt(n_embd)
209
+ mlp.c_proj: zeros
210
+ """
211
+
212
+ # Embedding and unembedding
213
+ torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
214
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
215
+
216
+ # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
217
+ n_embd = self.config.n_embd
218
+ s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
219
+ for block in self.transformer.h:
220
+ torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
221
+ torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
222
+ torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
223
+ torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
224
+ torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
225
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
226
+
227
+ # Per-layer scalars
228
+ self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
229
+ self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
230
+
231
+ # Value embeddings (init like c_v: uniform with same std)
232
+ for ve in self.value_embeds.values():
233
+ torch.nn.init.uniform_(ve.weight, -s, s)
234
+
235
+ # Gate weights init with small positive values so gates start slightly above neutral
236
+ for block in self.transformer.h:
237
+ if block.attn.ve_gate is not None:
238
+ torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
239
+
240
+ # Rotary embeddings
241
+ head_dim = self.config.n_embd // self.config.n_head
242
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
243
+ self.cos, self.sin = cos, sin
244
+
245
+ # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
246
+ # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
247
+ # because GradScaler cannot unscale fp16 gradients.
248
+ if COMPUTE_DTYPE != torch.float16:
249
+ self.transformer.wte.to(dtype=COMPUTE_DTYPE)
250
+ for ve in self.value_embeds.values():
251
+ ve.to(dtype=COMPUTE_DTYPE)
252
+
253
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
254
+ # TODO: bump base theta more? e.g. 100K is more common more recently
255
+ # autodetect the device from model embeddings
256
+ if device is None:
257
+ device = self.transformer.wte.weight.device
258
+ # stride the channels
259
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
260
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
261
+ # stride the time steps
262
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
263
+ # calculate the rotation frequencies at each (time, channel) pair
264
+ freqs = torch.outer(t, inv_freq)
265
+ cos, sin = freqs.cos(), freqs.sin()
266
+ cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
267
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
268
+ return cos, sin
269
+
270
+ def _compute_window_sizes(self, config):
271
+ """
272
+ Compute per-layer window sizes for sliding window attention.
273
+
274
+ Returns list of (left, right) tuples for FA3's window_size parameter:
275
+ - left: how many tokens before current position to attend to (-1 = unlimited)
276
+ - right: how many tokens after current position to attend to (0 for causal)
277
+
278
+ Pattern string is tiled across layers. Final layer always gets L (full context).
279
+ Characters: L=long (full context), S=short (half context)
280
+ """
281
+ pattern = config.window_pattern.upper()
282
+ assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
283
+ # Map characters to window sizes
284
+ long_window = config.sequence_len
285
+ short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
286
+ char_to_window = {
287
+ "L": (long_window, 0),
288
+ "S": (short_window, 0),
289
+ }
290
+ # Tile pattern across layers
291
+ window_sizes = []
292
+ for layer_idx in range(config.n_layer):
293
+ char = pattern[layer_idx % len(pattern)]
294
+ window_sizes.append(char_to_window[char])
295
+ # Final layer always gets full context
296
+ window_sizes[-1] = (long_window, 0)
297
+ return window_sizes
298
+
299
+ def get_device(self):
300
+ return self.transformer.wte.weight.device
301
+
302
+ def estimate_flops(self):
303
+ """
304
+ Return the estimated FLOPs per token for the model (forward + backward).
305
+ Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
306
+ Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
307
+ On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
308
+ With sliding windows, effective_seq_len varies per layer (capped by window size).
309
+ Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
310
+ This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
311
+ - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
312
+ - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
313
+ """
314
+ nparams = sum(p.numel() for p in self.parameters())
315
+ # Exclude non-matmul params: embeddings and per-layer scalars
316
+ value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
317
+ nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
318
+ self.resid_lambdas.numel() + self.x0_lambdas.numel())
319
+ h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
320
+ # Sum attention FLOPs per layer, accounting for sliding window
321
+ attn_flops = 0
322
+ for window_size in self.window_sizes:
323
+ window = window_size[0] # (left, right) tuple, we use left
324
+ effective_seq = t if window < 0 else min(window, t)
325
+ attn_flops += 12 * h * q * effective_seq
326
+ num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
327
+ return num_flops_per_token
328
+
329
+ def num_scaling_params(self):
330
+ """
331
+ Return detailed parameter counts for scaling law analysis.
332
+ Different papers use different conventions:
333
+ - Kaplan et al. excluded embedding parameters
334
+ - Chinchilla included all parameters
335
+ Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
336
+ Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
337
+
338
+ Returns a dict with counts for each parameter group, so downstream analysis
339
+ can experiment with which combination gives the cleanest scaling laws.
340
+ """
341
+ # Count each group separately (mirrors the grouping in setup_optimizers)
342
+ wte = sum(p.numel() for p in self.transformer.wte.parameters())
343
+ value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
344
+ lm_head = sum(p.numel() for p in self.lm_head.parameters())
345
+ transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
346
+ scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
347
+ total = wte + value_embeds + lm_head + transformer_matrices + scalars
348
+ assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
349
+ return {
350
+ 'wte': wte,
351
+ 'value_embeds': value_embeds,
352
+ 'lm_head': lm_head,
353
+ 'transformer_matrices': transformer_matrices,
354
+ 'scalars': scalars,
355
+ 'total': total,
356
+ }
357
+
358
+ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
359
+ model_dim = self.config.n_embd
360
+ ddp, rank, local_rank, world_size = get_dist_info()
361
+
362
+ # Separate out all parameters into groups
363
+ matrix_params = list(self.transformer.h.parameters())
364
+ value_embeds_params = list(self.value_embeds.parameters())
365
+ embedding_params = list(self.transformer.wte.parameters())
366
+ lm_head_params = list(self.lm_head.parameters())
367
+ resid_params = [self.resid_lambdas]
368
+ x0_params = [self.x0_lambdas]
369
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
370
+
371
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
372
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
373
+ print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
374
+
375
+ # Build param_groups with all required fields explicit
376
+ param_groups = [
377
+ # AdamW groups (embeddings, lm_head, scalars)
378
+ dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
379
+ dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
380
+ dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
381
+ dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
382
+ dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
383
+ ]
384
+ # Muon groups (matrix params, grouped by shape for stacking)
385
+ for shape in sorted({p.shape for p in matrix_params}):
386
+ group_params = [p for p in matrix_params if p.shape == shape]
387
+ param_groups.append(dict(
388
+ kind='muon', params=group_params, lr=matrix_lr,
389
+ momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
390
+ ))
391
+
392
+ Factory = DistMuonAdamW if ddp else MuonAdamW
393
+ optimizer = Factory(param_groups)
394
+ for group in optimizer.param_groups:
395
+ group["initial_lr"] = group["lr"]
396
+ return optimizer
397
+
398
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
399
+ B, T = idx.size()
400
+
401
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
402
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
403
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
404
+ assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
405
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
406
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
407
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
408
+
409
+ # Forward the trunk of the Transformer
410
+ x = self.transformer.wte(idx) # embed current token
411
+ x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
412
+ x = norm(x)
413
+ x0 = x # save initial normalized embedding for x0 residual
414
+ for i, block in enumerate(self.transformer.h):
415
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
416
+ ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
417
+ x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
418
+ x = norm(x)
419
+
420
+ # Forward the lm_head (compute logits)
421
+ softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
422
+ logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
423
+ logits = logits[..., :self.config.vocab_size] # slice to remove padding
424
+ logits = logits.float() # switch to fp32 for logit softcap and loss computation
425
+ logits = softcap * torch.tanh(logits / softcap) # squash the logits
426
+
427
+ if targets is not None:
428
+ # training: given the targets, compute and return the loss
429
+ # TODO experiment with chunked cross-entropy?
430
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
431
+ return loss
432
+ else:
433
+ # inference: just return the logits directly
434
+ return logits
435
+
436
+ @torch.inference_mode()
437
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
438
+ """
439
+ Naive autoregressive streaming inference.
440
+ To make it super simple, let's assume:
441
+ - batch size is 1
442
+ - ids and the yielded tokens are simple Python lists and ints
443
+ """
444
+ assert isinstance(tokens, list)
445
+ device = self.get_device()
446
+ rng = None
447
+ if temperature > 0:
448
+ rng = torch.Generator(device=device)
449
+ rng.manual_seed(seed)
450
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
451
+ for _ in range(max_tokens):
452
+ logits = self.forward(ids) # (B, T, vocab_size)
453
+ logits = logits[:, -1, :] # (B, vocab_size)
454
+ if top_k is not None and top_k > 0:
455
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
456
+ logits[logits < v[:, [-1]]] = -float('Inf')
457
+ if temperature > 0:
458
+ logits = logits / temperature
459
+ probs = F.softmax(logits, dim=-1)
460
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
461
+ else:
462
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
463
+ ids = torch.cat((ids, next_ids), dim=1)
464
+ token = next_ids.item()
465
+ yield token
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a tokenization vocab size-independent metric, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to evaluate_loss, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ if total_bytes == 0:
63
+ return float('inf')
64
+ bpb = total_nats / (math.log(2) * total_bytes)
65
+ return bpb
nanochat/optim.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A nice and efficient mixed AdamW/Muon Combined Optimizer.
3
+ Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
4
+ Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
5
+
6
+ Addapted from: https://github.com/KellerJordan/modded-nanogpt
7
+ Further contributions from @karpathy and @chrisjmccormick.
8
+ """
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch import Tensor
13
+
14
+ # -----------------------------------------------------------------------------
15
+ """
16
+ Good old AdamW optimizer, fused kernel.
17
+ https://arxiv.org/abs/1711.05101
18
+ """
19
+
20
+ @torch.compile(dynamic=False, fullgraph=True)
21
+ def adamw_step_fused(
22
+ p: Tensor, # (32768, 768) - parameter tensor
23
+ grad: Tensor, # (32768, 768) - gradient, same shape as p
24
+ exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
25
+ exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
26
+ step_t: Tensor, # () - 0-D CPU tensor, step count
27
+ lr_t: Tensor, # () - 0-D CPU tensor, learning rate
28
+ beta1_t: Tensor, # () - 0-D CPU tensor, beta1
29
+ beta2_t: Tensor, # () - 0-D CPU tensor, beta2
30
+ eps_t: Tensor, # () - 0-D CPU tensor, epsilon
31
+ wd_t: Tensor, # () - 0-D CPU tensor, weight decay
32
+ ) -> None:
33
+ """
34
+ Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
35
+ All in one compiled graph to eliminate Python overhead between ops.
36
+ The 0-D CPU tensors avoid recompilation when hyperparameter values change.
37
+ """
38
+ # Weight decay (decoupled, applied before the update)
39
+ p.mul_(1 - lr_t * wd_t)
40
+ # Update running averages (lerp_ is cleaner and fuses well)
41
+ exp_avg.lerp_(grad, 1 - beta1_t)
42
+ exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
43
+ # Bias corrections
44
+ bias1 = 1 - beta1_t ** step_t
45
+ bias2 = 1 - beta2_t ** step_t
46
+ # Compute update and apply
47
+ denom = (exp_avg_sq / bias2).sqrt() + eps_t
48
+ step_size = lr_t / bias1
49
+ p.add_(exp_avg / denom, alpha=-step_size)
50
+
51
+ # -----------------------------------------------------------------------------
52
+ """
53
+ Muon optimizer adapted and simplified from modded-nanogpt.
54
+ https://github.com/KellerJordan/modded-nanogpt
55
+
56
+ Background:
57
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
58
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
59
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
60
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
61
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
62
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
63
+ performance at all relative to UV^T, where USV^T = G is the SVD.
64
+
65
+ Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
66
+ Polar Express Sign Method for orthogonalization.
67
+ https://arxiv.org/pdf/2505.16932
68
+ by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
69
+
70
+ NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
71
+ update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
72
+ https://arxiv.org/pdf/2510.05491
73
+
74
+ Some of the changes in nanochat implementation:
75
+ - Uses a simpler, more general approach to parameter grouping and stacking
76
+ - Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
77
+ - Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
78
+ """
79
+
80
+ # Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
81
+ # From https://arxiv.org/pdf/2505.16932
82
+ polar_express_coeffs = [
83
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
84
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
85
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
86
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
87
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
88
+ ]
89
+
90
+ @torch.compile(dynamic=False, fullgraph=True)
91
+ def muon_step_fused(
92
+ stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
93
+ stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
94
+ momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
95
+ second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
96
+ momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
97
+ lr_t: Tensor, # () - 0-D CPU tensor, learning rate
98
+ wd_t: Tensor, # () - 0-D CPU tensor, weight decay
99
+ beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
100
+ ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
101
+ red_dim: int, # -1 or -2 - reduction dimension for variance
102
+ ) -> None:
103
+ """
104
+ Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
105
+ All in one compiled graph to eliminate Python overhead between ops.
106
+ Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
107
+ """
108
+
109
+ # Nesterov momentum
110
+ momentum = momentum_t.to(stacked_grads.dtype)
111
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
112
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
113
+
114
+ # Polar express
115
+ X = g.bfloat16()
116
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
117
+ if g.size(-2) > g.size(-1): # Tall matrix
118
+ for a, b, c in polar_express_coeffs[:ns_steps]:
119
+ A = X.mT @ X
120
+ B = b * A + c * (A @ A)
121
+ X = a * X + X @ B
122
+ else: # Wide matrix (original math)
123
+ for a, b, c in polar_express_coeffs[:ns_steps]:
124
+ A = X @ X.mT
125
+ B = b * A + c * (A @ A)
126
+ X = a * X + B @ X
127
+ g = X
128
+
129
+ # Variance reduction
130
+ beta2 = beta2_t.to(g.dtype)
131
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
132
+ red_dim_size = g.size(red_dim)
133
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
134
+ v_norm = v_norm_sq.sqrt()
135
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
136
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
137
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
138
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
139
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
140
+ g = g * final_scale.to(g.dtype)
141
+
142
+ # Cautious weight decay + parameter update
143
+ lr = lr_t.to(g.dtype)
144
+ wd = wd_t.to(g.dtype)
145
+ mask = (g * stacked_params) >= 0
146
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
147
+
148
+ # -----------------------------------------------------------------------------
149
+ # Single GPU version of the MuonAdamW optimizer.
150
+ # Used mostly for reference, debugging and testing.
151
+
152
+ class MuonAdamW(torch.optim.Optimizer):
153
+ """
154
+ Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
155
+
156
+ AdamW - Fused AdamW optimizer step.
157
+
158
+ Muon - MomentUm Orthogonalized by Newton-schulz
159
+ https://kellerjordan.github.io/posts/muon/
160
+
161
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
162
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
163
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
164
+ the advantage that it can be stably run in bfloat16 on the GPU.
165
+
166
+ Some warnings:
167
+ - The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
168
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
169
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
170
+
171
+ Arguments:
172
+ param_groups: List of dicts, each containing:
173
+ - 'params': List of parameters
174
+ - 'kind': 'adamw' or 'muon'
175
+ - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
176
+ - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
177
+ """
178
+ def __init__(self, param_groups: list[dict]):
179
+ super().__init__(param_groups, defaults={})
180
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
181
+ # AdamW tensors
182
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
183
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
184
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
185
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
186
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
187
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
188
+ # Muon tensors
189
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
190
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
191
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
192
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
193
+
194
+ def _step_adamw(self, group: dict) -> None:
195
+ """
196
+ AdamW update for each param in the group individually.
197
+ Lazy init the state, fill in all 0-D tensors, call the fused kernel.
198
+ """
199
+ for p in group['params']:
200
+ if p.grad is None:
201
+ continue
202
+ grad = p.grad
203
+ state = self.state[p]
204
+
205
+ # State init
206
+ if not state:
207
+ state['step'] = 0
208
+ state['exp_avg'] = torch.zeros_like(p)
209
+ state['exp_avg_sq'] = torch.zeros_like(p)
210
+ exp_avg = state['exp_avg']
211
+ exp_avg_sq = state['exp_avg_sq']
212
+ state['step'] += 1
213
+
214
+ # Fill 0-D tensors with current values
215
+ self._adamw_step_t.fill_(state['step'])
216
+ self._adamw_lr_t.fill_(group['lr'])
217
+ self._adamw_beta1_t.fill_(group['betas'][0])
218
+ self._adamw_beta2_t.fill_(group['betas'][1])
219
+ self._adamw_eps_t.fill_(group['eps'])
220
+ self._adamw_wd_t.fill_(group['weight_decay'])
221
+
222
+ # Fused update: weight_decay -> momentum -> bias_correction -> param_update
223
+ adamw_step_fused(
224
+ p, grad, exp_avg, exp_avg_sq,
225
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
226
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
227
+ )
228
+
229
+ def _step_muon(self, group: dict) -> None:
230
+ """
231
+ Muon update for all params in the group (stacked for efficiency).
232
+ Lazy init the state, fill in all 0-D tensors, call the fused kernel.
233
+ """
234
+ params: list[Tensor] = group['params']
235
+ if not params:
236
+ return
237
+
238
+ # Get or create group-level buffers (stored in first param's state for convenience)
239
+ p = params[0]
240
+ state = self.state[p]
241
+ num_params = len(params)
242
+ shape, device, dtype = p.shape, p.device, p.dtype
243
+
244
+ # Momentum for every individual parameter
245
+ if "momentum_buffer" not in state:
246
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
247
+ momentum_buffer = state["momentum_buffer"]
248
+
249
+ # Second momentum buffer is factored, either per-row or per-column
250
+ if "second_momentum_buffer" not in state:
251
+ state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
252
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
253
+ second_momentum_buffer = state["second_momentum_buffer"]
254
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
255
+
256
+ # Stack grads and params (NOTE: this assumes all params have the same shape)
257
+ stacked_grads = torch.stack([p.grad for p in params])
258
+ stacked_params = torch.stack(params)
259
+
260
+ # Fill all the 0-D tensors with current values
261
+ self._muon_momentum_t.fill_(group["momentum"])
262
+ self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
263
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
264
+ self._muon_wd_t.fill_(group["weight_decay"])
265
+
266
+ # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
267
+ muon_step_fused(
268
+ stacked_grads,
269
+ stacked_params,
270
+ momentum_buffer,
271
+ second_momentum_buffer,
272
+ self._muon_momentum_t,
273
+ self._muon_lr_t,
274
+ self._muon_wd_t,
275
+ self._muon_beta2_t,
276
+ group["ns_steps"],
277
+ red_dim,
278
+ )
279
+
280
+ # Copy back to original params
281
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
282
+
283
+ @torch.no_grad()
284
+ def step(self):
285
+ for group in self.param_groups:
286
+ if group['kind'] == 'adamw':
287
+ self._step_adamw(group)
288
+ elif group['kind'] == 'muon':
289
+ self._step_muon(group)
290
+ else:
291
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
292
+
293
+ # -----------------------------------------------------------------------------
294
+ # Distributed version of the MuonAdamW optimizer.
295
+ # Used for training on multiple GPUs.
296
+
297
+ class DistMuonAdamW(torch.optim.Optimizer):
298
+ """
299
+ Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
300
+
301
+ See MuonAdamW for the algorithmic details of each optimizer. This class adds
302
+ distributed communication to enable multi-GPU training without PyTorch DDP.
303
+
304
+ Design Goals:
305
+ - Overlap communication with computation (async ops)
306
+ - Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
307
+ - Batch small tensors into single comm ops where possible
308
+
309
+ Communication Pattern (3-phase async):
310
+ We use a 3-phase structure to maximize overlap between communication and compute:
311
+
312
+ Phase 1: Launch all async reduce ops
313
+ - Kick off all reduce_scatter/all_reduce operations
314
+ - Don't wait - let them run in background while we continue
315
+
316
+ Phase 2: Wait for reduces, compute updates, launch gathers
317
+ - For each group: wait for its reduce, compute the update, launch gather
318
+ - By processing groups in order, earlier gathers run while later computes happen
319
+
320
+ Phase 3: Wait for gathers, copy back
321
+ - Wait for all gathers to complete
322
+ - Copy updated params back to original tensors (Muon only)
323
+
324
+ AdamW Communication (ZeRO-2 style):
325
+ - Small params (<1024 elements): all_reduce gradients, update full param on each rank.
326
+ Optimizer state is replicated but these params are tiny (scalars, biases).
327
+ - Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
328
+ only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
329
+ exp_avg_sq) is sharded - each rank only stores state for its slice.
330
+ Requires param.shape[0] divisible by world_size.
331
+
332
+ Muon Communication (stacked + chunked):
333
+ - All params in a Muon group must have the same shape (caller's responsibility).
334
+ - Stack all K params into a single (K, *shape) tensor for efficient comm.
335
+ - Divide K params across N ranks: each rank "owns" ceil(K/N) params.
336
+ - reduce_scatter the stacked grads so each rank gets its chunk.
337
+ - Each rank computes Muon update only for params it owns.
338
+ - all_gather the updated params back to all ranks.
339
+ - Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
340
+ - Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
341
+ then ignore the padding when copying back.
342
+
343
+ Buffer Reuse:
344
+ - For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
345
+ same buffer as the output for all_gather (stacked_params). This saves memory
346
+ since we don't need both buffers simultaneously.
347
+
348
+ Arguments:
349
+ param_groups: List of dicts, each containing:
350
+ - 'params': List of parameters
351
+ - 'kind': 'adamw' or 'muon'
352
+ - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
353
+ - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
354
+ """
355
+ def __init__(self, param_groups: list[dict]):
356
+ super().__init__(param_groups, defaults={})
357
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
358
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
359
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
360
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
361
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
362
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
363
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
364
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
365
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
366
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
367
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
368
+
369
+ def _reduce_adamw(self, group: dict, world_size: int) -> dict:
370
+ """Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
371
+ param_infos = {}
372
+ for p in group['params']:
373
+ grad = p.grad
374
+ if p.numel() < 1024:
375
+ # Small params: all_reduce (no scatter/gather needed)
376
+ future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
377
+ param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
378
+ else:
379
+ # Large params: reduce_scatter
380
+ assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
381
+ rank_size = grad.shape[0] // world_size
382
+ grad_slice = torch.empty_like(grad[:rank_size])
383
+ future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
384
+ param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
385
+ return dict(param_infos=param_infos)
386
+
387
+ def _reduce_muon(self, group: dict, world_size: int) -> dict:
388
+ """Launch async reduce op for Muon group. Returns info dict."""
389
+ params = group['params']
390
+ chunk_size = (len(params) + world_size - 1) // world_size
391
+ padded_num_params = chunk_size * world_size
392
+ p = params[0]
393
+ shape, device, dtype = p.shape, p.device, p.dtype
394
+
395
+ # Stack grads and zero-pad to padded_num_params
396
+ grad_stack = torch.stack([p.grad for p in params])
397
+ stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
398
+ stacked_grads[:len(params)].copy_(grad_stack)
399
+ if len(params) < padded_num_params:
400
+ stacked_grads[len(params):].zero_()
401
+
402
+ # Reduce_scatter to get this rank's chunk
403
+ grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
404
+ future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
405
+
406
+ return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
407
+
408
+ def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
409
+ """Wait for reduce, compute AdamW updates, launch gathers for large params."""
410
+ param_infos = info['param_infos']
411
+ for p in group['params']:
412
+ pinfo = param_infos[p]
413
+ pinfo['future'].wait()
414
+ grad_slice = pinfo['grad_slice']
415
+ state = self.state[p]
416
+
417
+ # For small params, operate on full param; for large, operate on slice
418
+ if pinfo['is_small']:
419
+ p_slice = p
420
+ else:
421
+ rank_size = p.shape[0] // world_size
422
+ p_slice = p[rank * rank_size:(rank + 1) * rank_size]
423
+
424
+ # State init
425
+ if not state:
426
+ state['step'] = 0
427
+ state['exp_avg'] = torch.zeros_like(p_slice)
428
+ state['exp_avg_sq'] = torch.zeros_like(p_slice)
429
+ state['step'] += 1
430
+
431
+ # Fill 0-D tensors and run fused kernel
432
+ self._adamw_step_t.fill_(state['step'])
433
+ self._adamw_lr_t.fill_(group['lr'])
434
+ self._adamw_beta1_t.fill_(group['betas'][0])
435
+ self._adamw_beta2_t.fill_(group['betas'][1])
436
+ self._adamw_eps_t.fill_(group['eps'])
437
+ self._adamw_wd_t.fill_(group['weight_decay'])
438
+ adamw_step_fused(
439
+ p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
440
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
441
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
442
+ )
443
+
444
+ # Large params need all_gather
445
+ if not pinfo['is_small']:
446
+ future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
447
+ gather_list.append(dict(future=future, params=None))
448
+
449
+ def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
450
+ """Wait for reduce, compute Muon updates, launch gather."""
451
+ info['future'].wait()
452
+ params = group['params']
453
+ chunk_size = info['chunk_size']
454
+ grad_chunk = info['grad_chunk']
455
+ p = params[0]
456
+ shape, device, dtype = p.shape, p.device, p.dtype
457
+
458
+ # How many params does this rank own?
459
+ start_idx = rank * chunk_size
460
+ num_owned = min(chunk_size, max(0, len(params) - start_idx))
461
+
462
+ # Get or create group-level state
463
+ state = self.state[p]
464
+ if "momentum_buffer" not in state:
465
+ state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
466
+ if "second_momentum_buffer" not in state:
467
+ state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
468
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
469
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
470
+
471
+ # Build output buffer for all_gather
472
+ updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
473
+
474
+ if num_owned > 0:
475
+ owned_params = [params[start_idx + i] for i in range(num_owned)]
476
+ stacked_owned = torch.stack(owned_params)
477
+
478
+ # Fill 0-D tensors and run fused kernel
479
+ self._muon_momentum_t.fill_(group["momentum"])
480
+ self._muon_beta2_t.fill_(group["beta2"])
481
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
482
+ self._muon_wd_t.fill_(group["weight_decay"])
483
+ muon_step_fused(
484
+ grad_chunk[:num_owned], stacked_owned,
485
+ state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
486
+ self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
487
+ group["ns_steps"], red_dim,
488
+ )
489
+ updated_params[:num_owned].copy_(stacked_owned)
490
+
491
+ if num_owned < chunk_size:
492
+ updated_params[num_owned:].zero_()
493
+
494
+ # Reuse stacked_grads buffer for all_gather output
495
+ stacked_params = info["stacked_grads"]
496
+ future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
497
+ gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
498
+
499
+ def _finish_gathers(self, gather_list: list) -> None:
500
+ """Wait for all gathers and copy Muon params back."""
501
+ for info in gather_list:
502
+ info["future"].wait()
503
+ if info["params"] is not None:
504
+ # Muon: copy from stacked buffer back to individual params
505
+ torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
506
+
507
+ @torch.no_grad()
508
+ def step(self):
509
+ rank = dist.get_rank()
510
+ world_size = dist.get_world_size()
511
+
512
+ # Phase 1: launch all async reduce ops
513
+ reduce_infos: list[dict] = []
514
+ for group in self.param_groups:
515
+ if group['kind'] == 'adamw':
516
+ reduce_infos.append(self._reduce_adamw(group, world_size))
517
+ elif group['kind'] == 'muon':
518
+ reduce_infos.append(self._reduce_muon(group, world_size))
519
+ else:
520
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
521
+
522
+ # Phase 2: wait for reduces, compute updates, launch gathers
523
+ gather_list: list[dict] = []
524
+ for group, info in zip(self.param_groups, reduce_infos):
525
+ if group['kind'] == 'adamw':
526
+ self._compute_adamw(group, info, gather_list, rank, world_size)
527
+ elif group['kind'] == 'muon':
528
+ self._compute_muon(group, info, gather_list, rank)
529
+ else:
530
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
531
+
532
+ # Phase 3: wait for gathers, copy back
533
+ self._finish_gathers(gather_list)
nanochat/report.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for generating training report cards. More messy code than usual, will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ # Return stdout if we got output (even if some files in xargs failed)
20
+ if result.stdout.strip():
21
+ return result.stdout.strip()
22
+ if result.returncode == 0:
23
+ return ""
24
+ return None
25
+ except:
26
+ return None
27
+
28
+ def get_git_info():
29
+ """Get current git commit, branch, and dirty status."""
30
+ info = {}
31
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
32
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
33
+
34
+ # Check if repo is dirty (has uncommitted changes)
35
+ status = run_command("git status --porcelain")
36
+ info['dirty'] = bool(status) if status is not None else False
37
+
38
+ # Get commit message
39
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
40
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
41
+
42
+ return info
43
+
44
+ def get_gpu_info():
45
+ """Get GPU information."""
46
+ if not torch.cuda.is_available():
47
+ return {"available": False}
48
+
49
+ num_devices = torch.cuda.device_count()
50
+ info = {
51
+ "available": True,
52
+ "count": num_devices,
53
+ "names": [],
54
+ "memory_gb": []
55
+ }
56
+
57
+ for i in range(num_devices):
58
+ props = torch.cuda.get_device_properties(i)
59
+ info["names"].append(props.name)
60
+ info["memory_gb"].append(props.total_memory / (1024**3))
61
+
62
+ # Get CUDA version
63
+ info["cuda_version"] = torch.version.cuda or "unknown"
64
+
65
+ return info
66
+
67
+ def get_system_info():
68
+ """Get system information."""
69
+ info = {}
70
+
71
+ # Basic system info
72
+ info['hostname'] = socket.gethostname()
73
+ info['platform'] = platform.system()
74
+ info['python_version'] = platform.python_version()
75
+ info['torch_version'] = torch.__version__
76
+
77
+ # CPU and memory
78
+ info['cpu_count'] = psutil.cpu_count(logical=False)
79
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
80
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
81
+
82
+ # User and environment
83
+ info['user'] = os.environ.get('USER', 'unknown')
84
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
85
+ info['working_dir'] = os.getcwd()
86
+
87
+ return info
88
+
89
+ def estimate_cost(gpu_info, runtime_hours=None):
90
+ """Estimate training cost based on GPU type and runtime."""
91
+
92
+ # Rough pricing, from Lambda Cloud
93
+ default_rate = 2.0
94
+ gpu_hourly_rates = {
95
+ "H100": 3.00,
96
+ "A100": 1.79,
97
+ "V100": 0.55,
98
+ }
99
+
100
+ if not gpu_info.get("available"):
101
+ return None
102
+
103
+ # Try to identify GPU type from name
104
+ hourly_rate = None
105
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
106
+ for gpu_type, rate in gpu_hourly_rates.items():
107
+ if gpu_type in gpu_name:
108
+ hourly_rate = rate * gpu_info["count"]
109
+ break
110
+
111
+ if hourly_rate is None:
112
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
113
+
114
+ return {
115
+ "hourly_rate": hourly_rate,
116
+ "gpu_type": gpu_name,
117
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
118
+ }
119
+
120
+ def generate_header():
121
+ """Generate the header for a training report."""
122
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
123
+
124
+ git_info = get_git_info()
125
+ gpu_info = get_gpu_info()
126
+ sys_info = get_system_info()
127
+ cost_info = estimate_cost(gpu_info)
128
+
129
+ header = f"""# nanochat training report
130
+
131
+ Generated: {timestamp}
132
+
133
+ ## Environment
134
+
135
+ ### Git Information
136
+ - Branch: {git_info['branch']}
137
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
138
+ - Message: {git_info['message']}
139
+
140
+ ### Hardware
141
+ - Platform: {sys_info['platform']}
142
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
143
+ - Memory: {sys_info['memory_gb']:.1f} GB
144
+ """
145
+
146
+ if gpu_info.get("available"):
147
+ gpu_names = ", ".join(set(gpu_info["names"]))
148
+ total_vram = sum(gpu_info["memory_gb"])
149
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
150
+ - GPU Memory: {total_vram:.1f} GB total
151
+ - CUDA Version: {gpu_info['cuda_version']}
152
+ """
153
+ else:
154
+ header += "- GPUs: None available\n"
155
+
156
+ if cost_info and cost_info["hourly_rate"] > 0:
157
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
158
+
159
+ header += f"""
160
+ ### Software
161
+ - Python: {sys_info['python_version']}
162
+ - PyTorch: {sys_info['torch_version']}
163
+
164
+ """
165
+
166
+ # bloat metrics: count lines/chars in git-tracked source files only
167
+ extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
168
+ git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
169
+ files_output = run_command(f"git ls-files -- {git_patterns}")
170
+ file_list = [f for f in (files_output or '').split('\n') if f]
171
+ num_files = len(file_list)
172
+ num_lines = 0
173
+ num_chars = 0
174
+ if num_files > 0:
175
+ wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
176
+ if wc_output:
177
+ total_line = wc_output.strip().split('\n')[-1]
178
+ parts = total_line.split()
179
+ if len(parts) >= 2:
180
+ num_lines = int(parts[0])
181
+ num_chars = int(parts[1])
182
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
183
+
184
+ # count dependencies via uv.lock
185
+ uv_lock_lines = 0
186
+ if os.path.exists('uv.lock'):
187
+ with open('uv.lock', 'r', encoding='utf-8') as f:
188
+ uv_lock_lines = len(f.readlines())
189
+
190
+ header += f"""
191
+ ### Bloat
192
+ - Characters: {num_chars:,}
193
+ - Lines: {num_lines:,}
194
+ - Files: {num_files:,}
195
+ - Tokens (approx): {num_tokens:,}
196
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
197
+
198
+ """
199
+ return header
200
+
201
+ # -----------------------------------------------------------------------------
202
+
203
+ def slugify(text):
204
+ """Slugify a text string."""
205
+ return text.lower().replace(" ", "-")
206
+
207
+ # the expected files and their order
208
+ EXPECTED_FILES = [
209
+ "tokenizer-training.md",
210
+ "tokenizer-evaluation.md",
211
+ "base-model-training.md",
212
+ "base-model-loss.md",
213
+ "base-model-evaluation.md",
214
+ "chat-sft.md",
215
+ "chat-evaluation-sft.md",
216
+ "chat-rl.md",
217
+ "chat-evaluation-rl.md",
218
+ ]
219
+ # the metrics we're currently interested in
220
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
221
+
222
+ def extract(section, keys):
223
+ """simple def to extract a single key from a section"""
224
+ if not isinstance(keys, list):
225
+ keys = [keys] # convenience
226
+ out = {}
227
+ for line in section.split("\n"):
228
+ for key in keys:
229
+ if key in line:
230
+ out[key] = line.split(":")[1].strip()
231
+ return out
232
+
233
+ def extract_timestamp(content, prefix):
234
+ """Extract timestamp from content with given prefix."""
235
+ for line in content.split('\n'):
236
+ if line.startswith(prefix):
237
+ time_str = line.split(":", 1)[1].strip()
238
+ try:
239
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
240
+ except:
241
+ pass
242
+ return None
243
+
244
+ class Report:
245
+ """Maintains a bunch of logs, generates a final markdown report."""
246
+
247
+ def __init__(self, report_dir):
248
+ os.makedirs(report_dir, exist_ok=True)
249
+ self.report_dir = report_dir
250
+
251
+ def log(self, section, data):
252
+ """Log a section of data to the report."""
253
+ slug = slugify(section)
254
+ file_name = f"{slug}.md"
255
+ file_path = os.path.join(self.report_dir, file_name)
256
+ with open(file_path, "w", encoding="utf-8") as f:
257
+ f.write(f"## {section}\n")
258
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
259
+ for item in data:
260
+ if not item:
261
+ # skip falsy values like None or empty dict etc.
262
+ continue
263
+ if isinstance(item, str):
264
+ # directly write the string
265
+ f.write(item)
266
+ else:
267
+ # render a dict
268
+ for k, v in item.items():
269
+ if isinstance(v, float):
270
+ vstr = f"{v:.4f}"
271
+ elif isinstance(v, int) and v >= 10000:
272
+ vstr = f"{v:,.0f}"
273
+ else:
274
+ vstr = str(v)
275
+ f.write(f"- {k}: {vstr}\n")
276
+ f.write("\n")
277
+ return file_path
278
+
279
+ def generate(self):
280
+ """Generate the final report."""
281
+ report_dir = self.report_dir
282
+ report_file = os.path.join(report_dir, "report.md")
283
+ print(f"Generating report to {report_file}")
284
+ final_metrics = {} # the most important final metrics we'll add as table at the end
285
+ start_time = None
286
+ end_time = None
287
+ with open(report_file, "w", encoding="utf-8") as out_file:
288
+ # write the header first
289
+ header_file = os.path.join(report_dir, "header.md")
290
+ if os.path.exists(header_file):
291
+ with open(header_file, "r", encoding="utf-8") as f:
292
+ header_content = f.read()
293
+ out_file.write(header_content)
294
+ start_time = extract_timestamp(header_content, "Run started:")
295
+ # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
296
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
297
+ bloat_data = bloat_data.group(1) if bloat_data else ""
298
+ else:
299
+ start_time = None # will cause us to not write the total wall clock time
300
+ bloat_data = "[bloat data missing]"
301
+ print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
302
+ # process all the individual sections
303
+ for file_name in EXPECTED_FILES:
304
+ section_file = os.path.join(report_dir, file_name)
305
+ if not os.path.exists(section_file):
306
+ print(f"Warning: {section_file} does not exist, skipping")
307
+ continue
308
+ with open(section_file, "r", encoding="utf-8") as in_file:
309
+ section = in_file.read()
310
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
311
+ if "rl" not in file_name:
312
+ # Skip RL sections for end_time calculation because RL is experimental
313
+ end_time = extract_timestamp(section, "timestamp:")
314
+ # extract the most important metrics from the sections
315
+ if file_name == "base-model-evaluation.md":
316
+ final_metrics["base"] = extract(section, "CORE")
317
+ if file_name == "chat-evaluation-sft.md":
318
+ final_metrics["sft"] = extract(section, chat_metrics)
319
+ if file_name == "chat-evaluation-rl.md":
320
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
321
+ # append this section of the report
322
+ out_file.write(section)
323
+ out_file.write("\n")
324
+ # add the final metrics table
325
+ out_file.write("## Summary\n\n")
326
+ # Copy over the bloat metrics from the header
327
+ out_file.write(bloat_data)
328
+ out_file.write("\n\n")
329
+ # Collect all unique metric names
330
+ all_metrics = set()
331
+ for stage_metrics in final_metrics.values():
332
+ all_metrics.update(stage_metrics.keys())
333
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
334
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
335
+ # Fixed column widths
336
+ stages = ["base", "sft", "rl"]
337
+ metric_width = 15
338
+ value_width = 8
339
+ # Write table header
340
+ header = f"| {'Metric'.ljust(metric_width)} |"
341
+ for stage in stages:
342
+ header += f" {stage.upper().ljust(value_width)} |"
343
+ out_file.write(header + "\n")
344
+ # Write separator
345
+ separator = f"|{'-' * (metric_width + 2)}|"
346
+ for stage in stages:
347
+ separator += f"{'-' * (value_width + 2)}|"
348
+ out_file.write(separator + "\n")
349
+ # Write table rows
350
+ for metric in all_metrics:
351
+ row = f"| {metric.ljust(metric_width)} |"
352
+ for stage in stages:
353
+ value = final_metrics.get(stage, {}).get(metric, "-")
354
+ row += f" {str(value).ljust(value_width)} |"
355
+ out_file.write(row + "\n")
356
+ out_file.write("\n")
357
+ # Calculate and write total wall clock time
358
+ if start_time and end_time:
359
+ duration = end_time - start_time
360
+ total_seconds = int(duration.total_seconds())
361
+ hours = total_seconds // 3600
362
+ minutes = (total_seconds % 3600) // 60
363
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
364
+ else:
365
+ out_file.write("Total wall clock time: unknown\n")
366
+ # also cp the report.md file to current directory
367
+ print(f"Copying report.md to current directory for convenience")
368
+ shutil.copy(report_file, "report.md")
369
+ return report_file
370
+
371
+ def reset(self):
372
+ """Reset the report."""
373
+ # Remove section files
374
+ for file_name in EXPECTED_FILES:
375
+ file_path = os.path.join(self.report_dir, file_name)
376
+ if os.path.exists(file_path):
377
+ os.remove(file_path)
378
+ # Remove report.md if it exists
379
+ report_file = os.path.join(self.report_dir, "report.md")
380
+ if os.path.exists(report_file):
381
+ os.remove(report_file)
382
+ # Generate and write the header section with start timestamp
383
+ header_file = os.path.join(self.report_dir, "header.md")
384
+ header = generate_header()
385
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
386
+ with open(header_file, "w", encoding="utf-8") as f:
387
+ f.write(header)
388
+ f.write(f"Run started: {start_time}\n\n---\n\n")
389
+ print(f"Reset report and wrote header to {header_file}")
390
+
391
+ # -----------------------------------------------------------------------------
392
+ # nanochat-specific convenience functions
393
+
394
+ class DummyReport:
395
+ def log(self, *args, **kwargs):
396
+ pass
397
+ def reset(self, *args, **kwargs):
398
+ pass
399
+
400
+ def get_report():
401
+ # just for convenience, only rank 0 logs to report
402
+ from nanochat.common import get_base_dir, get_dist_info
403
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
404
+ if ddp_rank == 0:
405
+ report_dir = os.path.join(get_base_dir(), "report")
406
+ return Report(report_dir)
407
+ else:
408
+ return DummyReport()
409
+
410
+ if __name__ == "__main__":
411
+ import argparse
412
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
413
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
414
+ args = parser.parse_args()
415
+ if args.command == "generate":
416
+ get_report().generate()
417
+ elif args.command == "reset":
418
+ get_report().reset()
nanochat/tokenizer.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BPE Tokenizer in the style of GPT-4.
3
+
4
+ Two implementations are available:
5
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
6
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
7
+ """
8
+
9
+ import os
10
+ import copy
11
+ from functools import lru_cache
12
+
13
+ SPECIAL_TOKENS = [
14
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
15
+ "<|bos|>",
16
+ # tokens below are only used during finetuning to render Conversations into token ids
17
+ "<|user_start|>", # user messages
18
+ "<|user_end|>",
19
+ "<|assistant_start|>", # assistant messages
20
+ "<|assistant_end|>",
21
+ "<|python_start|>", # assistant invokes python REPL tool
22
+ "<|python_end|>",
23
+ "<|output_start|>", # python REPL outputs back to assistant
24
+ "<|output_end|>",
25
+ ]
26
+
27
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
28
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
29
+ # I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
30
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
34
+ from tokenizers import Tokenizer as HFTokenizer
35
+ from tokenizers import pre_tokenizers, decoders, Regex
36
+ from tokenizers.models import BPE
37
+ from tokenizers.trainers import BpeTrainer
38
+
39
+ class HuggingFaceTokenizer:
40
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
41
+
42
+ def __init__(self, tokenizer):
43
+ self.tokenizer = tokenizer
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, hf_path):
47
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
48
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
49
+ return cls(tokenizer)
50
+
51
+ @classmethod
52
+ def from_directory(cls, tokenizer_dir):
53
+ # init from a local directory on disk (e.g. "out/tokenizer")
54
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
55
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
56
+ return cls(tokenizer)
57
+
58
+ @classmethod
59
+ def train_from_iterator(cls, text_iterator, vocab_size):
60
+ # train from an iterator of text
61
+ # Configure the HuggingFace Tokenizer
62
+ tokenizer = HFTokenizer(BPE(
63
+ byte_fallback=True, # needed!
64
+ unk_token=None,
65
+ fuse_unk=False,
66
+ ))
67
+ # Normalizer: None
68
+ tokenizer.normalizer = None
69
+ # Pre-tokenizer: GPT-4 style
70
+ # the regex pattern used by GPT-4 to split text into groups before BPE
71
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
72
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
73
+ # (but I haven't validated this! TODO)
74
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
76
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
77
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
78
+ ])
79
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
80
+ tokenizer.decoder = decoders.ByteLevel()
81
+ # Post-processor: None
82
+ tokenizer.post_processor = None
83
+ # Trainer: BPE
84
+ trainer = BpeTrainer(
85
+ vocab_size=vocab_size,
86
+ show_progress=True,
87
+ min_frequency=0, # no minimum frequency
88
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
89
+ special_tokens=SPECIAL_TOKENS,
90
+ )
91
+ # Kick off the training
92
+ tokenizer.train_from_iterator(text_iterator, trainer)
93
+ return cls(tokenizer)
94
+
95
+ def get_vocab_size(self):
96
+ return self.tokenizer.get_vocab_size()
97
+
98
+ def get_special_tokens(self):
99
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
100
+ special_tokens = [w.content for w in special_tokens_map.values()]
101
+ return special_tokens
102
+
103
+ def id_to_token(self, id):
104
+ return self.tokenizer.id_to_token(id)
105
+
106
+ def _encode_one(self, text, prepend=None, append=None, num_threads=None):
107
+ # encode a single string
108
+ # prepend/append can be either a string of a special token or a token id directly.
109
+ # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
110
+ assert isinstance(text, str)
111
+ ids = []
112
+ if prepend is not None:
113
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
114
+ ids.append(prepend_id)
115
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
116
+ if append is not None:
117
+ append_id = append if isinstance(append, int) else self.encode_special(append)
118
+ ids.append(append_id)
119
+ return ids
120
+
121
+ def encode_special(self, text):
122
+ # encode a single special token via exact match
123
+ return self.tokenizer.token_to_id(text)
124
+
125
+ def get_bos_token_id(self):
126
+ # Different HuggingFace models use different BOS tokens and there is little consistency
127
+ # 1) attempt to find a <|bos|> token
128
+ bos = self.encode_special("<|bos|>")
129
+ # 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
130
+ if bos is None:
131
+ bos = self.encode_special("<|endoftext|>")
132
+ # 3) if these fail, it's better to crash than to silently return None
133
+ assert bos is not None, "Failed to find BOS token in tokenizer"
134
+ return bos
135
+
136
+ def encode(self, text, *args, **kwargs):
137
+ if isinstance(text, str):
138
+ return self._encode_one(text, *args, **kwargs)
139
+ elif isinstance(text, list):
140
+ return [self._encode_one(t, *args, **kwargs) for t in text]
141
+ else:
142
+ raise ValueError(f"Invalid input type: {type(text)}")
143
+
144
+ def __call__(self, *args, **kwargs):
145
+ return self.encode(*args, **kwargs)
146
+
147
+ def decode(self, ids):
148
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
149
+
150
+ def save(self, tokenizer_dir):
151
+ # save the tokenizer to disk
152
+ os.makedirs(tokenizer_dir, exist_ok=True)
153
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
154
+ self.tokenizer.save(tokenizer_path)
155
+ print(f"Saved tokenizer to {tokenizer_path}")
156
+
157
+ # -----------------------------------------------------------------------------
158
+ # Tokenizer based on rustbpe + tiktoken combo
159
+ import pickle
160
+ import rustbpe
161
+ import tiktoken
162
+
163
+ class RustBPETokenizer:
164
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
165
+
166
+ def __init__(self, enc, bos_token):
167
+ self.enc = enc
168
+ self.bos_token_id = self.encode_special(bos_token)
169
+
170
+ @classmethod
171
+ def train_from_iterator(cls, text_iterator, vocab_size):
172
+ # 1) train using rustbpe
173
+ tokenizer = rustbpe.Tokenizer()
174
+ # the special tokens are inserted later in __init__, we don't train them here
175
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
176
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
177
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
178
+ # 2) construct the associated tiktoken encoding for inference
179
+ pattern = tokenizer.get_pattern()
180
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
181
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
182
+ tokens_offset = len(mergeable_ranks)
183
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
184
+ enc = tiktoken.Encoding(
185
+ name="rustbpe",
186
+ pat_str=pattern,
187
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
188
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
189
+ )
190
+ return cls(enc, "<|bos|>")
191
+
192
+ @classmethod
193
+ def from_directory(cls, tokenizer_dir):
194
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
195
+ with open(pickle_path, "rb") as f:
196
+ enc = pickle.load(f)
197
+ return cls(enc, "<|bos|>")
198
+
199
+ @classmethod
200
+ def from_pretrained(cls, tiktoken_name):
201
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
202
+ enc = tiktoken.get_encoding(tiktoken_name)
203
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
204
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
205
+ # it most often is used to signal the start of a new sequence to the LLM during inference etc.
206
+ # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
207
+ return cls(enc, "<|endoftext|>")
208
+
209
+ def get_vocab_size(self):
210
+ return self.enc.n_vocab
211
+
212
+ def get_special_tokens(self):
213
+ return self.enc.special_tokens_set
214
+
215
+ def id_to_token(self, id):
216
+ return self.enc.decode([id])
217
+
218
+ @lru_cache(maxsize=32)
219
+ def encode_special(self, text):
220
+ return self.enc.encode_single_token(text)
221
+
222
+ def get_bos_token_id(self):
223
+ return self.bos_token_id
224
+
225
+ def encode(self, text, prepend=None, append=None, num_threads=8):
226
+ # text can be either a string or a list of strings
227
+
228
+ if prepend is not None:
229
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
230
+ if append is not None:
231
+ append_id = append if isinstance(append, int) else self.encode_special(append)
232
+
233
+ if isinstance(text, str):
234
+ ids = self.enc.encode_ordinary(text)
235
+ if prepend is not None:
236
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
237
+ if append is not None:
238
+ ids.append(append_id)
239
+ elif isinstance(text, list):
240
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
241
+ if prepend is not None:
242
+ for ids_row in ids:
243
+ ids_row.insert(0, prepend_id) # TODO: same
244
+ if append is not None:
245
+ for ids_row in ids:
246
+ ids_row.append(append_id)
247
+ else:
248
+ raise ValueError(f"Invalid input type: {type(text)}")
249
+
250
+ return ids
251
+
252
+ def __call__(self, *args, **kwargs):
253
+ return self.encode(*args, **kwargs)
254
+
255
+ def decode(self, ids):
256
+ return self.enc.decode(ids)
257
+
258
+ def save(self, tokenizer_dir):
259
+ # save the encoding object to disk
260
+ os.makedirs(tokenizer_dir, exist_ok=True)
261
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
262
+ with open(pickle_path, "wb") as f:
263
+ pickle.dump(self.enc, f)
264
+ print(f"Saved tokenizer encoding to {pickle_path}")
265
+
266
+ def render_conversation(self, conversation, max_tokens=2048):
267
+ """
268
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
269
+ Returns:
270
+ - ids: list[int] is a list of token ids of this rendered conversation
271
+ - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
272
+ """
273
+ # ids, masks that we will return and a helper function to help build them up.
274
+ ids, mask = [], []
275
+ def add_tokens(token_ids, mask_val):
276
+ if isinstance(token_ids, int):
277
+ token_ids = [token_ids]
278
+ ids.extend(token_ids)
279
+ mask.extend([mask_val] * len(token_ids))
280
+
281
+ # sometimes the first message is a system message...
282
+ # => just merge it with the second (user) message
283
+ if conversation["messages"][0]["role"] == "system":
284
+ # some conversation surgery is necessary here for now...
285
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
286
+ messages = conversation["messages"]
287
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
288
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
289
+ messages = messages[1:]
290
+ else:
291
+ messages = conversation["messages"]
292
+ assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"
293
+
294
+ # fetch all the special tokens we need
295
+ bos = self.get_bos_token_id()
296
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
297
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
298
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
299
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
300
+
301
+ # now we can tokenize the conversation
302
+ add_tokens(bos, 0)
303
+ for i, message in enumerate(messages):
304
+
305
+ # some sanity checking here around assumptions, to prevent footguns
306
+ must_be_from = "user" if i % 2 == 0 else "assistant"
307
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
308
+
309
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
310
+ content = message["content"]
311
+
312
+ if message["role"] == "user":
313
+ assert isinstance(content, str), "User messages are simply expected to be strings"
314
+ value_ids = self.encode(content)
315
+ add_tokens(user_start, 0)
316
+ add_tokens(value_ids, 0)
317
+ add_tokens(user_end, 0)
318
+ elif message["role"] == "assistant":
319
+ add_tokens(assistant_start, 0)
320
+ if isinstance(content, str):
321
+ # simple string => simply add the tokens
322
+ value_ids = self.encode(content)
323
+ add_tokens(value_ids, 1)
324
+ elif isinstance(content, list):
325
+ for part in content:
326
+ value_ids = self.encode(part["text"])
327
+ if part["type"] == "text":
328
+ # string part => simply add the tokens
329
+ add_tokens(value_ids, 1)
330
+ elif part["type"] == "python":
331
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
332
+ add_tokens(python_start, 1)
333
+ add_tokens(value_ids, 1)
334
+ add_tokens(python_end, 1)
335
+ elif part["type"] == "python_output":
336
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
337
+ # none of these tokens are supervised because the tokens come from Python at test time
338
+ add_tokens(output_start, 0)
339
+ add_tokens(value_ids, 0)
340
+ add_tokens(output_end, 0)
341
+ else:
342
+ raise ValueError(f"Unknown part type: {part['type']}")
343
+ else:
344
+ raise ValueError(f"Unknown content type: {type(content)}")
345
+ add_tokens(assistant_end, 1)
346
+
347
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
348
+ ids = ids[:max_tokens]
349
+ mask = mask[:max_tokens]
350
+ return ids, mask
351
+
352
+ def visualize_tokenization(self, ids, mask, with_token_id=False):
353
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
354
+ RED = '\033[91m'
355
+ GREEN = '\033[92m'
356
+ RESET = '\033[0m'
357
+ GRAY = '\033[90m'
358
+ tokens = []
359
+ for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
360
+ token_str = self.decode([token_id])
361
+ color = GREEN if mask_val == 1 else RED
362
+ tokens.append(f"{color}{token_str}{RESET}")
363
+ if with_token_id:
364
+ tokens.append(f"{GRAY}({token_id}){RESET}")
365
+ return '|'.join(tokens)
366
+
367
+ def render_for_completion(self, conversation):
368
+ """
369
+ Used during Reinforcement Learning. In that setting, we want to
370
+ render the conversation priming the Assistant for a completion.
371
+ Unlike the Chat SFT case, we don't need to return the mask.
372
+ """
373
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
374
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
375
+ messages = conversation["messages"]
376
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
377
+ messages.pop() # remove the last message (of the Assistant) inplace
378
+
379
+ # Now tokenize the conversation
380
+ ids, mask = self.render_conversation(conversation)
381
+
382
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
383
+ assistant_start = self.encode_special("<|assistant_start|>")
384
+ ids.append(assistant_start)
385
+ return ids
386
+
387
+ # -----------------------------------------------------------------------------
388
+ # nanochat-specific convenience functions
389
+
390
+ def get_tokenizer():
391
+ from nanochat.common import get_base_dir
392
+ base_dir = get_base_dir()
393
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
394
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
395
+ return RustBPETokenizer.from_directory(tokenizer_dir)
396
+
397
+ def get_token_bytes(device="cpu"):
398
+ import torch
399
+ from nanochat.common import get_base_dir
400
+ base_dir = get_base_dir()
401
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
402
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
403
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
404
+ with open(token_bytes_path, "rb") as f:
405
+ token_bytes = torch.load(f, map_location=device)
406
+ return token_bytes
nanochat/ui.html ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ html, body{
18
+ height: 100%;
19
+ margin: 0;
20
+ }
21
+
22
+ body {
23
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
24
+ background-color: #ffffff;
25
+ color: #111827;
26
+ min-height: 100dvh;
27
+ margin: 0;
28
+ display: flex;
29
+ flex-direction: column;
30
+ }
31
+
32
+ .header {
33
+ background-color: #ffffff;
34
+ padding: 1.25rem 1.5rem;
35
+ }
36
+
37
+ .header-left {
38
+ display: flex;
39
+ align-items: center;
40
+ gap: 0.75rem;
41
+ }
42
+
43
+ .header-logo {
44
+ height: 32px;
45
+ width: auto;
46
+ }
47
+
48
+ .header h1 {
49
+ font-size: 1.25rem;
50
+ font-weight: 600;
51
+ margin: 0;
52
+ color: #111827;
53
+ }
54
+
55
+ .new-conversation-btn {
56
+ width: 32px;
57
+ height: 32px;
58
+ padding: 0;
59
+ border: 1px solid #e5e7eb;
60
+ border-radius: 0.5rem;
61
+ background-color: #ffffff;
62
+ color: #6b7280;
63
+ cursor: pointer;
64
+ display: flex;
65
+ align-items: center;
66
+ justify-content: center;
67
+ transition: all 0.2s ease;
68
+ }
69
+
70
+ .new-conversation-btn:hover {
71
+ background-color: #f3f4f6;
72
+ border-color: #d1d5db;
73
+ color: #374151;
74
+ }
75
+
76
+ .chat-container {
77
+ flex: 1;
78
+ overflow-y: auto;
79
+ background-color: #ffffff;
80
+ }
81
+
82
+ .chat-wrapper {
83
+ max-width: 48rem;
84
+ margin: 0 auto;
85
+ padding: 2rem 1.5rem 3rem;
86
+ display: flex;
87
+ flex-direction: column;
88
+ gap: 0.75rem;
89
+ }
90
+
91
+ .message {
92
+ display: flex;
93
+ justify-content: flex-start;
94
+ margin-bottom: 0.5rem;
95
+ color: #0d0d0d;
96
+ }
97
+
98
+ .message.assistant {
99
+ justify-content: flex-start;
100
+ }
101
+
102
+ .message.user {
103
+ justify-content: flex-end;
104
+ }
105
+
106
+ .message-content {
107
+ white-space: pre-wrap;
108
+ line-height: 1.6;
109
+ max-width: 100%;
110
+ }
111
+
112
+ .message.assistant .message-content {
113
+ background: transparent;
114
+ border: none;
115
+ cursor: pointer;
116
+ border-radius: 0.5rem;
117
+ padding: 0.5rem;
118
+ margin-left: -0.5rem;
119
+ transition: background-color 0.2s ease;
120
+ }
121
+
122
+ .message.assistant .message-content:hover {
123
+ background-color: #f9fafb;
124
+ }
125
+
126
+ .message.user .message-content {
127
+ background-color: #f3f4f6;
128
+ border-radius: 1.25rem;
129
+ padding: 0.8rem 1rem;
130
+ max-width: 65%;
131
+ cursor: pointer;
132
+ transition: background-color 0.2s ease;
133
+ }
134
+
135
+ .message.user .message-content:hover {
136
+ background-color: #e5e7eb;
137
+ }
138
+
139
+ .message.console .message-content {
140
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
141
+ font-size: 0.875rem;
142
+ background-color: #fafafa;
143
+ padding: 0.75rem 1rem;
144
+ color: #374151;
145
+ max-width: 80%;
146
+ }
147
+
148
+ .input-container {
149
+ background-color: #ffffff;
150
+ padding: 1rem;
151
+ padding-bottom: calc(1rem + env(safe-area-inset-bottom))
152
+ }
153
+
154
+ .input-wrapper {
155
+ max-width: 48rem;
156
+ margin: 0 auto;
157
+ display: flex;
158
+ gap: 0.75rem;
159
+ align-items: flex-end;
160
+ }
161
+
162
+ .chat-input {
163
+ flex: 1;
164
+ padding: 0.8rem 1rem;
165
+ border: 1px solid #d1d5db;
166
+ border-radius: 0.75rem;
167
+ background-color: #ffffff;
168
+ color: #111827;
169
+ font-size: 1rem;
170
+ line-height: 1.5;
171
+ resize: none;
172
+ outline: none;
173
+ min-height: 54px;
174
+ max-height: 200px;
175
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
176
+ }
177
+
178
+ .chat-input::placeholder {
179
+ color: #9ca3af;
180
+ }
181
+
182
+ .chat-input:focus {
183
+ border-color: #2563eb;
184
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
185
+ }
186
+
187
+ .send-button {
188
+ flex-shrink: 0;
189
+ padding: 0;
190
+ width: 54px;
191
+ height: 54px;
192
+ border: 1px solid #111827;
193
+ border-radius: 0.75rem;
194
+ background-color: #111827;
195
+ color: #ffffff;
196
+ display: flex;
197
+ align-items: center;
198
+ justify-content: center;
199
+ cursor: pointer;
200
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
201
+ }
202
+
203
+ .send-button:hover:not(:disabled) {
204
+ background-color: #2563eb;
205
+ border-color: #2563eb;
206
+ }
207
+
208
+ .send-button:disabled {
209
+ cursor: not-allowed;
210
+ border-color: #d1d5db;
211
+ background-color: #e5e7eb;
212
+ color: #9ca3af;
213
+ }
214
+
215
+ .typing-indicator {
216
+ display: inline-block;
217
+ color: #6b7280;
218
+ letter-spacing: 0.15em;
219
+ }
220
+
221
+ .typing-indicator::after {
222
+ content: '···';
223
+ animation: typing 1.4s infinite;
224
+ }
225
+
226
+ @keyframes typing {
227
+ 0%, 60%, 100% { opacity: 0.2; }
228
+ 30% { opacity: 1; }
229
+ }
230
+
231
+ .error-message {
232
+ background-color: #fee2e2;
233
+ border: 1px solid #fecaca;
234
+ color: #b91c1c;
235
+ padding: 0.75rem 1rem;
236
+ border-radius: 0.75rem;
237
+ margin-top: 0.5rem;
238
+ }
239
+ </style>
240
+ </head>
241
+ <body>
242
+ <div class="header">
243
+ <div class="header-left">
244
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
245
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
246
+ <path d="M12 5v14"></path>
247
+ <path d="M5 12h14"></path>
248
+ </svg>
249
+ </button>
250
+ <h1>nanochat</h1>
251
+ </div>
252
+ </div>
253
+
254
+ <div class="chat-container" id="chatContainer">
255
+ <div class="chat-wrapper" id="chatWrapper">
256
+ <!-- Messages will be added here -->
257
+ </div>
258
+ </div>
259
+
260
+ <div class="input-container">
261
+ <div class="input-wrapper">
262
+ <textarea
263
+ id="chatInput"
264
+ class="chat-input"
265
+ placeholder="Ask anything"
266
+ rows="1"
267
+ onkeydown="handleKeyDown(event)"
268
+ ></textarea>
269
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
270
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
271
+ <path d="M22 2L11 13"></path>
272
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
273
+ </svg>
274
+ </button>
275
+ </div>
276
+ </div>
277
+
278
+ <script>
279
+ const API_URL = '';
280
+ const chatContainer = document.getElementById('chatContainer');
281
+ const chatWrapper = document.getElementById('chatWrapper');
282
+ const chatInput = document.getElementById('chatInput');
283
+ const sendButton = document.getElementById('sendButton');
284
+
285
+ let messages = [];
286
+ let isGenerating = false;
287
+ let currentTemperature = 0.8;
288
+ let currentTopK = 50;
289
+
290
+ chatInput.addEventListener('input', function() {
291
+ this.style.height = 'auto';
292
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
293
+ sendButton.disabled = !this.value.trim() || isGenerating;
294
+ });
295
+
296
+ function handleKeyDown(event) {
297
+ if (event.key === 'Enter' && !event.shiftKey) {
298
+ event.preventDefault();
299
+ sendMessage();
300
+ }
301
+ }
302
+
303
+ document.addEventListener('keydown', function(event) {
304
+ // Ctrl+Shift+N for new conversation
305
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
306
+ event.preventDefault();
307
+ if (!isGenerating) {
308
+ newConversation();
309
+ }
310
+ }
311
+ });
312
+
313
+ function newConversation() {
314
+ messages = [];
315
+ chatWrapper.innerHTML = '';
316
+ chatInput.value = '';
317
+ chatInput.style.height = 'auto';
318
+ sendButton.disabled = false;
319
+ isGenerating = false;
320
+ chatInput.focus();
321
+ }
322
+
323
+ function addMessage(role, content, messageIndex = null) {
324
+ const messageDiv = document.createElement('div');
325
+ messageDiv.className = `message ${role}`;
326
+
327
+ const contentDiv = document.createElement('div');
328
+ contentDiv.className = 'message-content';
329
+ contentDiv.textContent = content;
330
+
331
+ // Add click handler for user messages to enable editing
332
+ if (role === 'user' && messageIndex !== null) {
333
+ contentDiv.setAttribute('data-message-index', messageIndex);
334
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
335
+ contentDiv.addEventListener('click', function() {
336
+ if (!isGenerating) {
337
+ editMessage(messageIndex);
338
+ }
339
+ });
340
+ }
341
+
342
+ // Add click handler for assistant messages to enable regeneration
343
+ if (role === 'assistant' && messageIndex !== null) {
344
+ contentDiv.setAttribute('data-message-index', messageIndex);
345
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
346
+ contentDiv.addEventListener('click', function() {
347
+ if (!isGenerating) {
348
+ regenerateMessage(messageIndex);
349
+ }
350
+ });
351
+ }
352
+
353
+ messageDiv.appendChild(contentDiv);
354
+ chatWrapper.appendChild(messageDiv);
355
+
356
+ chatContainer.scrollTop = chatContainer.scrollHeight;
357
+ return contentDiv;
358
+ }
359
+
360
+ function editMessage(messageIndex) {
361
+ // Find the message in the messages array
362
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
363
+
364
+ const messageToEdit = messages[messageIndex];
365
+ if (messageToEdit.role !== 'user') return;
366
+
367
+ // Copy message content to input
368
+ chatInput.value = messageToEdit.content;
369
+ chatInput.style.height = 'auto';
370
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
371
+
372
+ // Remove this message and all subsequent messages from the array
373
+ messages = messages.slice(0, messageIndex);
374
+
375
+ // Remove message elements from DOM starting from messageIndex
376
+ const allMessages = chatWrapper.querySelectorAll('.message');
377
+ for (let i = messageIndex; i < allMessages.length; i++) {
378
+ allMessages[i].remove();
379
+ }
380
+
381
+ // Enable send button and focus input
382
+ sendButton.disabled = false;
383
+ chatInput.focus();
384
+ }
385
+
386
+ async function generateAssistantResponse() {
387
+ isGenerating = true;
388
+ sendButton.disabled = true;
389
+
390
+ const assistantContent = addMessage('assistant', '');
391
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
392
+
393
+ try {
394
+ const response = await fetch(`${API_URL}/chat/completions`, {
395
+ method: 'POST',
396
+ headers: {
397
+ 'Content-Type': 'application/json',
398
+ },
399
+ body: JSON.stringify({
400
+ messages: messages,
401
+ temperature: currentTemperature,
402
+ top_k: currentTopK,
403
+ max_tokens: 512
404
+ }),
405
+ });
406
+
407
+ if (!response.ok) {
408
+ throw new Error(`HTTP error! status: ${response.status}`);
409
+ }
410
+
411
+ const reader = response.body.getReader();
412
+ const decoder = new TextDecoder();
413
+ let fullResponse = '';
414
+ assistantContent.textContent = '';
415
+
416
+ while (true) {
417
+ const { done, value } = await reader.read();
418
+ if (done) break;
419
+
420
+ const chunk = decoder.decode(value);
421
+ const lines = chunk.split('\n');
422
+
423
+ for (const line of lines) {
424
+ if (line.startsWith('data: ')) {
425
+ try {
426
+ const data = JSON.parse(line.slice(6));
427
+ if (data.token) {
428
+ fullResponse += data.token;
429
+ assistantContent.textContent = fullResponse;
430
+ chatContainer.scrollTop = chatContainer.scrollHeight;
431
+ }
432
+ } catch (e) {
433
+ }
434
+ }
435
+ }
436
+ }
437
+
438
+ const assistantMessageIndex = messages.length;
439
+ messages.push({ role: 'assistant', content: fullResponse });
440
+
441
+ // Add click handler to regenerate this assistant message
442
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
443
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
444
+ assistantContent.addEventListener('click', function() {
445
+ if (!isGenerating) {
446
+ regenerateMessage(assistantMessageIndex);
447
+ }
448
+ });
449
+
450
+ } catch (error) {
451
+ console.error('Error:', error);
452
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
453
+ } finally {
454
+ isGenerating = false;
455
+ sendButton.disabled = !chatInput.value.trim();
456
+ }
457
+ }
458
+
459
+ async function regenerateMessage(messageIndex) {
460
+ // Find the message in the messages array
461
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
462
+
463
+ const messageToRegenerate = messages[messageIndex];
464
+ if (messageToRegenerate.role !== 'assistant') return;
465
+
466
+ // Remove this message and all subsequent messages from the array
467
+ messages = messages.slice(0, messageIndex);
468
+
469
+ // Remove message elements from DOM starting from messageIndex
470
+ const allMessages = chatWrapper.querySelectorAll('.message');
471
+ for (let i = messageIndex; i < allMessages.length; i++) {
472
+ allMessages[i].remove();
473
+ }
474
+
475
+ // Regenerate the assistant response
476
+ await generateAssistantResponse();
477
+ }
478
+
479
+ function handleSlashCommand(command) {
480
+ const parts = command.trim().split(/\s+/);
481
+ const cmd = parts[0].toLowerCase();
482
+ const arg = parts[1];
483
+
484
+ if (cmd === '/temperature') {
485
+ if (arg === undefined) {
486
+ addMessage('console', `Current temperature: ${currentTemperature}`);
487
+ } else {
488
+ const temp = parseFloat(arg);
489
+ if (isNaN(temp) || temp < 0 || temp > 2) {
490
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
491
+ } else {
492
+ currentTemperature = temp;
493
+ addMessage('console', `Temperature set to ${currentTemperature}`);
494
+ }
495
+ }
496
+ return true;
497
+ } else if (cmd === '/topk') {
498
+ if (arg === undefined) {
499
+ addMessage('console', `Current top-k: ${currentTopK}`);
500
+ } else {
501
+ const topk = parseInt(arg);
502
+ if (isNaN(topk) || topk < 1 || topk > 200) {
503
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
504
+ } else {
505
+ currentTopK = topk;
506
+ addMessage('console', `Top-k set to ${currentTopK}`);
507
+ }
508
+ }
509
+ return true;
510
+ } else if (cmd === '/clear') {
511
+ newConversation();
512
+ return true;
513
+ } else if (cmd === '/help') {
514
+ addMessage('console',
515
+ 'Available commands:\n' +
516
+ '/temperature - Show current temperature\n' +
517
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
518
+ '/topk - Show current top-k\n' +
519
+ '/topk <value> - Set top-k (1-200)\n' +
520
+ '/clear - Clear conversation\n' +
521
+ '/help - Show this help message'
522
+ );
523
+ return true;
524
+ }
525
+ return false;
526
+ }
527
+
528
+ async function sendMessage() {
529
+ const message = chatInput.value.trim();
530
+ if (!message || isGenerating) return;
531
+
532
+ // Handle slash commands
533
+ if (message.startsWith('/')) {
534
+ chatInput.value = '';
535
+ chatInput.style.height = 'auto';
536
+ handleSlashCommand(message);
537
+ return;
538
+ }
539
+
540
+ chatInput.value = '';
541
+ chatInput.style.height = 'auto';
542
+
543
+ const userMessageIndex = messages.length;
544
+ messages.push({ role: 'user', content: message });
545
+ addMessage('user', message, userMessageIndex);
546
+
547
+ await generateAssistantResponse();
548
+ }
549
+
550
+ sendButton.disabled = false;
551
+
552
+ // Autofocus the chat input on page load
553
+ chatInput.focus();
554
+
555
+ fetch(`${API_URL}/health`)
556
+ .then(response => response.json())
557
+ .then(data => {
558
+ console.log('Engine status:', data);
559
+ })
560
+ .catch(error => {
561
+ console.error('Engine not available:', error);
562
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
563
+ });
564
+ </script>
565
+ </body>
566
+ </html>
pyproject.toml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "nanochat"
3
+ version = "0.1.0"
4
+ description = "the minimal full-stack ChatGPT clone"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "datasets>=4.0.0",
9
+ "fastapi>=0.117.1",
10
+ "ipykernel>=7.1.0",
11
+ "kernels>=0.11.7",
12
+ "matplotlib>=3.10.8",
13
+ "psutil>=7.1.0",
14
+ "python-dotenv>=1.2.1",
15
+ "regex>=2025.9.1",
16
+ "rustbpe>=0.1.0",
17
+ "scipy>=1.15.3",
18
+ "setuptools>=80.9.0",
19
+ "tabulate>=0.9.0",
20
+ "tiktoken>=0.11.0",
21
+ "tokenizers>=0.22.0",
22
+ "torch==2.9.1",
23
+ "transformers>=4.57.3",
24
+ "uvicorn>=0.36.0",
25
+ "wandb>=0.21.3",
26
+ "zstandard>=0.25.0",
27
+ ]
28
+
29
+ [dependency-groups]
30
+ dev = [
31
+ "pytest>=8.0.0",
32
+ ]
33
+
34
+ [tool.pytest.ini_options]
35
+ markers = [
36
+ "slow: marks tests as slow (deselect with '-m \"not slow\"')",
37
+ ]
38
+ testpaths = ["tests"]
39
+ python_files = ["test_*.py"]
40
+ python_classes = ["Test*"]
41
+ python_functions = ["test_*"]
42
+
43
+ # target torch to cuda 12.8 or CPU
44
+ [tool.uv.sources]
45
+ torch = [
46
+ { index = "pytorch-cpu", extra = "cpu" },
47
+ { index = "pytorch-cu128", extra = "gpu" },
48
+ ]
49
+
50
+ [[tool.uv.index]]
51
+ name = "pytorch-cpu"
52
+ url = "https://download.pytorch.org/whl/cpu"
53
+ explicit = true
54
+
55
+ [[tool.uv.index]]
56
+ name = "pytorch-cu128"
57
+ url = "https://download.pytorch.org/whl/cu128"
58
+ explicit = true
59
+
60
+ [project.optional-dependencies]
61
+ cpu = [
62
+ "torch==2.9.1",
63
+ ]
64
+ gpu = [
65
+ "torch==2.9.1",
66
+ ]
67
+
68
+ [tool.uv]
69
+ conflicts = [
70
+ [
71
+ { extra = "cpu" },
72
+ { extra = "gpu" },
73
+ ],
74
+ ]
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio