import argparse
import glob
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

try:
    from mamba_ssm import Mamba
    from mamba_ssm.utils.generation import InferenceParams
    _HAS_MAMBA = True
except ImportError:
    _HAS_MAMBA = False
    InferenceParams = None
    print("=" * 80)
    print("[WARNING] mamba-ssm not installed. Mamba layers will not function.")
    print("Install with: pip install mamba-ssm")
    print("=" * 80)

    class Mamba(nn.Module):
        """Fallback stub, defined only when mamba-ssm is absent so it cannot
        shadow the real import; its forward is an identity map."""

        def __init__(self, *args, **kwargs):
            super().__init__()
            print("ERROR: Mamba placeholder. mamba-ssm not installed.")

        def forward(self, x, *args, **kwargs):
            print("ERROR: mamba-ssm not installed. Cannot run MambaBlock.")
            return x


@dataclass
class AdaptiveRiverConfig:
    # Core transformer dimensions
    vocab_size: int = 50257
    d_model: int = 1024
    n_layers: int = 24
    d_ff: int = 4096
    dropout: float = 0.0
    rope_theta: float = 10000.0
    rotary_pct: float = 1.0
    layer_norm_eps: float = 1e-5
    rope_scaling_type: str | None = None
    rope_scaling_factor: float = 1.0
    # MoE FFN routing
    experts_per_layer: int = 4
    top_k_ffn: int = 1
    moe_dropout: float = 0.0
    # MoE attention routing
    attn_n_experts: int = 6
    attn_top_k: int = 6
    attn_n_orig_heads: int = 16
    # Mamba (SSM) blocks
    mamba_d_state: int = 16
    mamba_d_conv: int = 4
    mamba_expand: int = 2
    # Auxiliary-loss weights, compute budget, and misc
    entropy_weight: float = 1e-4
    head_entropy_weight: float = 1e-4
    default_budget_ratio: float = 1.0
    init_std: float = 0.02
    tie_word_embeddings: bool = False
    load_balance_weight: float = 0.01
    router_z_weight: float = 0.001
    gate_temperature: float = 0.7
    checkpoint_attn_thresh: float = 0.35
    checkpoint_ffn_thresh: float = 0.35
    soak_dtype: str = "fp32"


def _init_weights(module: nn.Module, std: float):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def topk_mask_ste(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Hard top-k mask with a straight-through estimator: the forward value is
    a 0/1 mask, while gradients flow through the softmax of the scores."""
    s = scores.float()
    if k >= s.size(-1):
        return torch.ones_like(s)
    topk = torch.topk(s, k=k, dim=-1).indices
    one_hot = torch.zeros_like(s)
    one_hot.scatter_(dim=-1, index=topk, value=1.0)
    probs = F.softmax(s, dim=-1)
    return one_hot + probs - probs.detach()
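

# Illustrative sketch (not part of the original script): how topk_mask_ste
# behaves. The forward value is a hard 0/1 mask, yet the result still carries
# a gradient through the softmax term, which is what lets routers train
# through a discrete top-k selection.
def _demo_topk_mask_ste():
    scores = torch.tensor([[2.0, 0.5, -1.0, 0.1]], requires_grad=True)
    mask = topk_mask_ste(scores, k=2)
    assert mask.detach().tolist() == [[1.0, 1.0, 0.0, 0.0]]  # hard top-2 mask
    (mask * torch.tensor([1.0, 2.0, 3.0, 4.0])).sum().backward()
    assert scores.grad is not None  # gradient flows despite the hard mask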


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000.0, scaling_type: str | None = None, scaling_factor: float = 1.0):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.scaling_type = scaling_type
        self.scaling_factor = float(scaling_factor)
        base = self._effective_base()
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_sin_cache = None
        self._cos_sin_cache_device = None
        self._cos_sin_cache_dtype = None
        self._cos_sin_max_seq_len = -1

    def _effective_base(self) -> float:
        if not self.scaling_type or self.scaling_factor == 1.0:
            return self.base
        if self.scaling_type in ("ntk", "linear", "yarn"):
            return self.base * self.scaling_factor
        return self.base

    def _get_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        if (seq_len > self._cos_sin_max_seq_len or self._cos_sin_cache is None
                or self._cos_sin_cache_device != device or self._cos_sin_cache_dtype != dtype):
            self._cos_sin_max_seq_len = max(seq_len, 2048)
            t = torch.arange(self._cos_sin_max_seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().to(dtype)
            sin = emb.sin().to(dtype)
            self._cos_sin_cache = (cos, sin)
            self._cos_sin_cache_device = device
            self._cos_sin_cache_dtype = dtype
        return self._cos_sin_cache

    def forward(self, x, seq_len: int, offset: int | torch.Tensor = 0):
        device, dtype = x.device, x.dtype
        # Handle the tensor case first: int(offset) on a multi-element tensor
        # would raise before the per-sample branch could be reached.
        if isinstance(offset, torch.Tensor):
            if offset.numel() > 1:
                # Per-sample offsets are not supported; fall back to positions
                # 0..seq_len-1 (only cached decoding hits this path, and this
                # stateless script does not use a cache).
                t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos_val = emb.cos()[None, None, :, :].to(dtype)
                sin_val = emb.sin()[None, None, :, :].to(dtype)
                return cos_val, sin_val
            offset = int(offset.item())
        else:
            offset = int(offset)
        cos, sin = self._get_cos_sin_cache(seq_len + offset, device, dtype)
        cos = cos[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        sin = sin[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        return cos, sin


def rotate_half(x):
    # Matches the cat((freqs, freqs)) layout of the cos/sin cache above: the
    # feature dimension is split into two halves that form rotation pairs.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x, cos, sin):
    return x * cos + rotate_half(x) * sin
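

# Illustrative sketch (not part of the original script): RoPE rotates each
# feature pair, so it preserves vector norms, and relative angles depend only
# on position differences. Assumes a (batch, heads, seq, head_dim) layout.
def _demo_rotary():
    rope = RotaryEmbedding(dim=8)
    q = torch.randn(1, 1, 4, 8)
    cos, sin = rope(q, seq_len=4, offset=0)
    q_rot = apply_rotary(q, cos, sin)
    # Rotation is norm-preserving at every position.
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)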


class PTLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, x):
        # Compute in the norm's parameter dtype (fp32 when the LayerNorms are
        # kept in full precision under a bf16 model), then cast back.
        return self.ln(x.to(self.ln.weight.dtype)).to(x.dtype)

class GlobalSDPAHead(nn.Module):
    def __init__(self, d_model, head_dim, dropout, rope_theta, rotary_pct, cfg):
        super().__init__()
        self.q_proj = nn.Linear(d_model, head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, head_dim, bias=False)
        self.rotary_dim = int(head_dim * rotary_pct)
        self.dropout_p = dropout
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim, base=rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )

    def forward(self, x, position_offset):
        if isinstance(position_offset, torch.Tensor):
            position_offset = int(position_offset.view(-1)[0].item())
        else:
            position_offset = int(position_offset)
        B, T, C = x.shape
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        if self.rotary_dim > 0:
            cos, sin = self.rope(q, seq_len=T, offset=position_offset)
            cos = cos.squeeze(1)
            sin = sin.squeeze(1)
            q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
            k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
            q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
        q, k, v = [t.unsqueeze(1) for t in (q, k, v)]  # singleton head dim for SDPA
        dropout_p = self.dropout_p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)
        return out.squeeze(1)


class AttentionMoERouter(nn.Module):
    """Sequence-level router: mean-pools the tokens, then picks top-k attention
    experts. The effective k shrinks with the compute budget_ratio."""

    def __init__(self, d_model, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate_proj = nn.Linear(d_model, num_experts, bias=False)
        nn.init.normal_(self.gate_proj.weight, mean=0.0, std=0.01)

    def forward(self, x, budget_ratio, temperature):
        seq_embed = x.mean(dim=1)  # (B, C)
        logits = self.gate_proj(seq_embed) / max(1e-6, float(temperature))
        logits = logits.clamp(min=-10.0, max=10.0)
        k_target = max(1, int(round(self.top_k * (0.25 + 0.75 * budget_ratio))))
        k_target = min(k_target, logits.size(-1))
        vals, idx = torch.topk(logits, k_target, dim=-1)
        weights = F.softmax(vals.to(torch.float32), dim=-1).to(x.dtype)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(1, idx, True)
        with torch.no_grad():
            p = F.softmax(logits, dim=-1)
            entropy = -(p * (p.clamp_min(1e-12)).log()).sum(dim=-1).mean()
        return mask, weights, idx, entropy, logits
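

# Illustrative sketch (not part of the original script): the router's k scales
# with budget_ratio, so a lower budget activates fewer attention experts.
def _demo_attn_router():
    router = AttentionMoERouter(d_model=32, num_experts=6, top_k=6)
    x = torch.randn(2, 10, 32)  # (batch, seq, d_model)
    mask_full, w_full, *_ = router(x, budget_ratio=1.0, temperature=0.7)
    mask_low, w_low, *_ = router(x, budget_ratio=0.0, temperature=0.7)
    assert mask_full.sum(dim=-1).tolist() == [6, 6]  # full budget: all 6 experts
    assert mask_low.sum(dim=-1).tolist() == [2, 2]   # low budget: round(6 * 0.25) = 2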


class MoEAttention(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.d_model = cfg.d_model
        self.n_experts = cfg.attn_n_experts
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.attn_n_orig_heads
        self.rotary_dim = int(self.head_dim * cfg.rotary_pct)
        self.router = AttentionMoERouter(cfg.d_model, cfg.attn_n_experts, cfg.attn_top_k)
        self.q_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim, base=cfg.rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )
        self.o_proj = nn.Linear(cfg.attn_n_experts * self.head_dim, cfg.d_model, bias=False)

    def forward(self, x, position_offset, budget_ratio, temperature):
        B, T, C = x.shape
        E, H = self.n_experts, self.head_dim
        sel_mask, gate_w, gate_idx, entropy, gate_logits = self.router(x, budget_ratio, temperature)
        # Each expert is a single attention head over the full sequence.
        q = self.q_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        k = self.k_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        v = self.v_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        if self.rope:
            if isinstance(position_offset, torch.Tensor):
                position_offset = int(position_offset.view(-1)[0].item())
            else:
                position_offset = int(position_offset)
            cos, sin = self.rope(q, seq_len=T, offset=position_offset)
            cos = cos.squeeze(1)
            sin = sin.squeeze(1)
            q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
            k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
            q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
        q_b = q.reshape(B * E, T, H)
        k_b = k.reshape(B * E, T, H)
        v_b = v.reshape(B * E, T, H)
        dropout_p = self.cfg.dropout if self.training else 0.0
        out_b = F.scaled_dot_product_attention(q_b, k_b, v_b, is_causal=True, dropout_p=dropout_p)
        out = out_b.view(B, E, T, H).permute(0, 2, 1, 3)
        # Scatter the top-k gate weights into a dense (B, E) matrix; unselected
        # experts get weight 0, so they contribute nothing to the output.
        W = torch.zeros(B, E, device=x.device, dtype=out.dtype)
        W.scatter_(1, gate_idx, gate_w.to(out.dtype))
        weighted_out = torch.einsum('b t e h, b e -> b t e h', out, W)
        y = weighted_out.reshape(B, T, E * H).to(self.o_proj.weight.dtype)
        y = self.o_proj(y)
        # Diagnostics only: computed under no_grad, so these carry no gradient.
        with torch.no_grad():
            usage = sel_mask.float().mean(dim=0)
            expected = sel_mask.float().sum(dim=-1).mean()
            den = torch.clamp(expected, min=1e-6)
            usage_norm = usage / den
            uniform = 1.0 / self.n_experts
            attn_lb = ((usage_norm - uniform) ** 2).sum()
            attn_rz = (gate_logits ** 2).mean()
            head_keep = sel_mask.float().mean()
        return y, {
            "head_entropy": entropy,
            "head_keep_frac": head_keep,
            "attn_load_balance_loss": attn_lb,
            "attn_router_z_loss": attn_rz,
        }
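

# Illustrative sketch (not part of the original script): a MoEAttention forward
# pass on a tiny config, checking that the shape is preserved and routing stats
# come back. attn_n_orig_heads must divide d_model so head_dim is integral.
def _demo_moe_attention():
    cfg = AdaptiveRiverConfig(d_model=64, attn_n_experts=4, attn_top_k=4, attn_n_orig_heads=8)
    attn = MoEAttention(cfg).eval()
    x = torch.randn(2, 12, 64)
    y, stats = attn(x, position_offset=0, budget_ratio=1.0, temperature=cfg.gate_temperature)
    assert y.shape == x.shape
    assert 0.0 <= float(stats["head_keep_frac"]) <= 1.0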


class ExpertFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.w1(x)
        x = F.gelu(x, approximate="tanh")
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.w2(x)
        return x


class MoEFFN(nn.Module):
    """Token-level MoE FFN with stacked expert weights. Note that all experts
    are computed densely for every token and then mixed by the gate; the top-k
    mask sparsifies the mixture but saves no FLOPs here."""

    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int, dropout: float, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.n_experts = n_experts
        self.base_top_k = top_k
        self.cfg = cfg
        # `dropout` is accepted for signature parity with ExpertFFN but unused here.
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.w1_stacked = nn.Parameter(torch.empty(n_experts, d_ff, d_model))
        self.w2_stacked = nn.Parameter(torch.empty(n_experts, d_model, d_ff))
        std = cfg.init_std
        nn.init.normal_(self.router.weight, mean=0.0, std=std)
        nn.init.normal_(self.w1_stacked, mean=0.0, std=std)
        nn.init.normal_(self.w2_stacked, mean=0.0, std=std)

    def forward(self, x: torch.Tensor, budget_ratio: float):
        B, T, C = x.shape
        N = B * T
        X = x.reshape(N, C)
        k_target = max(1, int(round(self.base_top_k * (0.5 + budget_ratio / 2.0))))
        k_target = min(k_target, self.n_experts)
        scores = self.router(X).to(torch.float32).clamp(min=-10.0, max=10.0)
        probs = F.softmax(scores, dim=-1).to(X.dtype)
        mask = topk_mask_ste(scores, k=k_target).to(X.dtype)
        gate = mask * probs
        gate = gate / gate.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        x_ff = torch.einsum('n c, e d c -> n e d', X, self.w1_stacked)
        x_act = F.gelu(x_ff, approximate="tanh")
        y_experts = torch.einsum('n e d, e c d -> n e c', x_act, self.w2_stacked)
        y = torch.einsum('n e, n e c -> n c', gate, y_experts).view(B, T, C).to(x.dtype)
        # Diagnostics only: computed under no_grad, so these carry no gradient.
        with torch.no_grad():
            entropy = (-probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
            router_z = (scores ** 2).mean().clamp(max=10.0)
            frac = mask.mean(dim=0)
            uniform = 1.0 / self.n_experts
            lb = ((frac - uniform) ** 2).sum()
        return y, {
            "router_entropy": entropy,
            "ffn_expert_usage": frac.detach(),
            "ffn_load_balance_loss": lb,
            "ffn_router_z_loss": router_z,
        }
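

# Illustrative sketch (not part of the original script): MoEFFN gating on a
# tiny config. With top_k=1 and full budget, exactly one expert receives all
# gate mass per token, and the per-expert usage vector has n_experts entries.
def _demo_moe_ffn():
    cfg = AdaptiveRiverConfig(d_model=16, d_ff=32)
    ffn = MoEFFN(d_model=16, d_ff=32, n_experts=4, top_k=1, dropout=0.0, cfg=cfg)
    x = torch.randn(2, 5, 16)
    y, stats = ffn(x, budget_ratio=1.0)
    assert y.shape == x.shape
    assert stats["ffn_expert_usage"].shape == (4,)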


class MambaBlock(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig, enhanced: bool = False, layer_idx: int | None = None):
        super().__init__()
        if not _HAS_MAMBA:
            print(f"MambaBlock Layer {layer_idx} disabled: mamba-ssm not installed.")
            self.mamba = None
            return
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.mamba = Mamba(
            d_model=cfg.d_model,
            d_state=cfg.mamba_d_state,
            d_conv=cfg.mamba_d_conv,
            expand=cfg.mamba_expand * (2 if enhanced else 1),
            layer_idx=layer_idx,
        )
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_ff * (2 if enhanced else 1), bias=False),
            nn.GELU(approximate="tanh"),
            nn.Linear(cfg.d_ff * (2 if enhanced else 1), cfg.d_model, bias=False),
        )

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        if not _HAS_MAMBA or self.mamba is None:
            # Degrade to an identity block so the rest of the model still runs.
            stats = {"head_entropy": torch.tensor(0.0, device=x.device),
                     "head_keep_frac": torch.tensor(1.0, device=x.device),
                     "mamba_out_l2": torch.tensor(0.0, device=x.device)}
            return x, stats, (None, None)
        h = self.ln1(x)
        x_m = self.mamba(h)
        m_out_l2 = x_m.float().pow(2).mean()
        x = x + x_m
        h2 = self.ln2(x)
        x = x + self.ffn(h2)
        stats = {
            "head_entropy": torch.tensor(0.0, device=x.device),
            "head_keep_frac": torch.tensor(1.0, device=x.device),
            "mamba_out_l2": m_out_l2.detach(),
        }
        return x, stats, (None, None)
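

# Illustrative sketch (not part of the original script): the block is a
# standard pre-norm residual pair (SSM mixer, then FFN). Without mamba-ssm it
# degrades to an identity, which this check exercises safely on CPU; with
# mamba-ssm installed the real path generally needs a CUDA device.
def _demo_mamba_block_fallback():
    if _HAS_MAMBA:
        return  # the real Mamba kernels are exercised by the model itself
    block = MambaBlock(AdaptiveRiverConfig(d_model=16, d_ff=32))
    x = torch.randn(1, 4, 16)
    y, stats, _ = block(x)
    assert torch.equal(x, y) and float(stats["mamba_out_l2"]) == 0.0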


class RoutedBlock(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.attn = MoEAttention(cfg)
        self.ffn = MoEFFN(cfg.d_model, cfg.d_ff, cfg.experts_per_layer, cfg.top_k_ffn, cfg.moe_dropout, cfg)

    def _attn_forward(self, h: torch.Tensor, position_offset: int, budget_ratio: float):
        if isinstance(position_offset, torch.Tensor):
            position_offset = int(position_offset.view(-1)[0].item())
        else:
            position_offset = int(position_offset)
        return self.attn(h, position_offset, budget_ratio, self.cfg.gate_temperature)

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        h = self.ln1(x)
        attn_out, attn_stats = self._attn_forward(h, position_offset, budget_ratio)
        x = x + attn_out
        h2 = self.ln2(x)
        ffn_out, moe_stats = self.ffn(h2, budget_ratio=budget_ratio)
        x = x + ffn_out
        stats = {**attn_stats, **moe_stats}
        return x, stats, (None, None)


class AdaptiveRiverLM(nn.Module):
    """Hybrid stack: Mamba blocks at the first two and last two layers,
    routed MoE attention/FFN blocks in between."""

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList()
        mamba_layer_counter = 0
        for i in range(cfg.n_layers):
            if i < 2:
                print(f"[model] Layer {i}: Mamba")
                self.blocks.append(MambaBlock(cfg, enhanced=False, layer_idx=mamba_layer_counter))
                mamba_layer_counter += 1
            elif i >= (cfg.n_layers - 2):
                print(f"[model] Layer {i}: Mamba (enhanced)")
                self.blocks.append(MambaBlock(cfg, enhanced=True, layer_idx=mamba_layer_counter))
                mamba_layer_counter += 1
            else:
                if i == 2:
                    print(f"[model] Layers {i}-{cfg.n_layers - 3}: MoE Attention + MoE FFN")
                self.blocks.append(RoutedBlock(cfg))
        self.ln_f = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            self.lm_head.weight = self.embed.weight
        # Re-initializes every nn.Linear, including the router gates; harmless
        # here because a checkpoint is loaded over these weights afterwards.
        self.apply(lambda m: _init_weights(m, cfg.init_std))

    def forward(
        self,
        input_ids: torch.Tensor,
        budget_ratio: Optional[float] = None,
        mamba_states: Optional[List] = None,
        past_kvs: Optional[List] = None,
        position_offset: int | torch.Tensor = 0,
        return_expert_stats: bool = False,
        use_cache: bool = False,
    ):
        x = self.embed(input_ids)
        b = float(self.cfg.default_budget_ratio if budget_ratio is None else budget_ratio)
        all_stats: Dict[str, List[torch.Tensor]] = {}
        for block in self.blocks:
            x, stats, _ = block(
                x,
                position_offset=position_offset,
                past_kv=None,
                budget_ratio=b,
                use_cache=False,
                mamba_state=None,
            )
            for k, v in stats.items():
                all_stats.setdefault(k, []).append(torch.as_tensor(v.detach() if isinstance(v, torch.Tensor) else v))
        agg_stats = {k: torch.stack(v).mean() for k, v in all_stats.items() if v}
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits, agg_stats
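

# Illustrative sketch (not part of the original script): an end-to-end forward
# on a tiny config. With mamba-ssm absent the four Mamba layers act as
# identities, so this runs on CPU; logits come back as (batch, seq, vocab).
def _demo_tiny_forward():
    cfg = AdaptiveRiverConfig(vocab_size=100, d_model=32, n_layers=5, d_ff=64,
                              attn_n_experts=2, attn_top_k=2, attn_n_orig_heads=4)
    model = AdaptiveRiverLM(cfg).eval()
    ids = torch.randint(0, 100, (1, 8))
    with torch.no_grad():
        logits, stats = model(ids)
    assert logits.shape == (1, 8, 100)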


def estimate_1b_config() -> AdaptiveRiverConfig:
    return AdaptiveRiverConfig(
        vocab_size=50257,
        d_model=1024,
        n_layers=24,
        d_ff=4096,
        experts_per_layer=4,
        top_k_ffn=1,
        default_budget_ratio=1.0,
        attn_n_experts=6,
        attn_top_k=6,
        attn_n_orig_heads=16,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        gate_temperature=0.7,
        head_entropy_weight=1e-4,
        checkpoint_attn_thresh=0.35,
        checkpoint_ffn_thresh=0.35,
        load_balance_weight=0.01,
        router_z_weight=0.001,
        tie_word_embeddings=False,
    )
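

# Illustrative sketch (not part of the original script): a quick way to
# sanity-check the "~1B" label is to instantiate the config and count the
# parameters. The exact total depends on the tokenizer size and on whether
# mamba-ssm is installed.
def _count_params_sketch():
    model = AdaptiveRiverLM(estimate_1b_config())
    n = sum(p.numel() for p in model.parameters())
    print(f"total parameters: {n / 1e9:.2f}B")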


# -----------------------------------------------------------------------------
# Inference
# -----------------------------------------------------------------------------


class FastInferenceTester:
    def __init__(self, model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.eos_id = eos_id
        self.pad_id = pad_id

        self.model.eval()
        torch.set_grad_enabled(False)
        print("Using model's native precision")

        if _HAS_MAMBA:
            print("Skipping torch.compile due to mamba-ssm kernels.")
        elif hasattr(torch, "compile"):
            try:
                print("Compiling model with torch.compile...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled successfully")
            except Exception as e:
                print(f"Could not compile model: {e}")
                print("Running without compilation")
    def _format_to_training_chat(self, prompt: str) -> torch.Tensor:
        messages = [{"role": "user", "content": prompt}]
        formatted = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.tokenizer.encode(
            formatted, add_special_tokens=False, return_tensors="pt"
        ).to(self.device)
        return input_ids

    def _postprocess_like_training(self, text: str) -> str:
        # Strip the chat template down to the assistant's reply, mirroring how
        # targets were extracted at training time.
        if "<|im_start|>assistant" in text:
            return text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
        if "assistant\n" in text:
            return text.split("assistant\n")[-1].split("<|im_end|>")[0].strip()
        return text.split("<|im_end|>")[0].strip()
    def _reset_mamba_states(self):
        if not _HAS_MAMBA:
            return
        for block in self.model.blocks:
            if isinstance(block, MambaBlock) and hasattr(block, "mamba"):
                for attr in ("inference_params", "conv_state", "ssm_state"):
                    if hasattr(block.mamba, attr):
                        setattr(block.mamba, attr, None)
    def generate_once(
        self,
        prompt: str,
        max_tokens: int = 2000,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int = 0,
        budget_ratio: float = 1.0,
        show_tokens: bool = False,
        min_new_tokens: int = 3,
    ) -> Dict:
        self._reset_mamba_states()

        print(f"\n{'='*80}")
        print("FAST GENERATION (no cache)")
        print(f"{'='*80}")
        print(f"Prompt: {prompt}")
        print("─" * 80)

        input_ids = self._format_to_training_chat(prompt)

        generated_tokens: List[int] = []
        token_times: List[float] = []
        stop_ids = set(t for t in [self.im_end_id, self.eos_id] if t is not None)
        ban_initial_ids = set(t for t in [self.im_end_id, self.eos_id, self.im_start_id, self.pad_id] if t is not None)

        start_time = time.time()

        with torch.inference_mode():
            # Prefill: one full forward pass over the prompt.
            logits, _ = self.model(
                input_ids,
                budget_ratio=budget_ratio,
                position_offset=0,
                use_cache=False
            )
            next_token_logits = logits[:, -1, :]

            print("Generating...", end=" ", flush=True)
            is_cuda = torch.cuda.is_available()
            buffer = []

            for _ in range(max_tokens):
                if is_cuda:
                    torch.cuda.synchronize()
                t0 = time.time()

                logits_for_sampling = next_token_logits.squeeze(0).clone() / max(1e-6, temperature)
                vocab_size = logits_for_sampling.size(0)

                # Ban stop/special tokens until min_new_tokens are produced.
                if len(generated_tokens) < min_new_tokens and min_new_tokens > 0:
                    for tid in ban_initial_ids:
                        if 0 <= tid < vocab_size:
                            logits_for_sampling[tid] = float("-inf")

                # Top-k filtering.
                if top_k and top_k > 0:
                    kth = torch.topk(logits_for_sampling, top_k)[0][-1]
                    logits_for_sampling[logits_for_sampling < kth] = float("-inf")

                # Top-p (nucleus) filtering: drop the tail of the sorted
                # distribution once cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits_for_sampling, descending=True)
                    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept.
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = False
                    remove_idx = sorted_indices[sorted_indices_to_remove]
                    logits_for_sampling[remove_idx] = float("-inf")

                probs = F.softmax(logits_for_sampling, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1).item()

                generated_tokens.append(next_token_id)

                # Stream decoded text in small batches.
                if show_tokens:
                    tok_text = self.tokenizer.decode([next_token_id], skip_special_tokens=False)
                    buffer.append(tok_text)
                    if len(buffer) >= 16:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()

                if (next_token_id in stop_ids) and (len(generated_tokens) >= max(1, min_new_tokens)):
                    if buffer:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()
                    if show_tokens:
                        print(" [EOT]", flush=True)
                    break

                # Stateless decode: re-run the model over the whole sequence
                # (no KV cache) and take the logits at the last position.
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_token_id]], device=self.device)],
                    dim=1
                )
                logits, _ = self.model(
                    input_ids,
                    budget_ratio=budget_ratio,
                    position_offset=0,
                    use_cache=False
                )
                next_token_logits = logits[:, -1, :]

                if is_cuda:
                    torch.cuda.synchronize()
                token_times.append(time.time() - t0)

            if buffer:
                print("".join(buffer), end="", flush=True)
                buffer.clear()

        total_time = time.time() - start_time
        text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        text = self._postprocess_like_training(text)

        if show_tokens and (not generated_tokens or (generated_tokens[-1] not in stop_ids)):
            print()

        num_gen = len(generated_tokens)
        if num_gen == 0:
            print("\nNo tokens generated.")
            return {'output': '', 'tokens_per_sec': 0, 'decode_tps': 0, 'total_time': total_time, 'num_tokens': 0}

        decode_time = sum(token_times)
        toks_per_sec = num_gen / total_time if total_time > 0 else 0
        decode_tps = num_gen / decode_time if decode_time > 0 else 0

        print("\n" + "─" * 80)
        print("STATISTICS")
        print("─" * 80)
        print(f"Tokens: {num_gen}")
        print(f"Total time: {total_time:.2f}s")
        print(f"Overall speed: {toks_per_sec:.1f} tok/s (includes prompt)")
        print(f"Decode speed: {decode_tps:.1f} tok/s (generation only)")
        print(f"Time/token: {(decode_time / num_gen) * 1000:.1f}ms")
        print("─" * 80)
        print(f"Output: {text[:100]}{'...' if len(text) > 100 else ''}")
        print("=" * 80 + "\n")

        self._reset_mamba_states()

        return {
            'output': text,
            'tokens_per_sec': toks_per_sec,
            'decode_tps': decode_tps,
            'total_time': total_time,
            'num_tokens': num_gen,
        }
    def interactive_mode(self):
        print("\n" + "=" * 80)
        print("INTERACTIVE MODE (no cache, stateless)")
        print("Type 'quit' or your prompt")
        print("=" * 80 + "\n")
        while True:
            try:
                prompt = input("\nYou: ")
            except (EOFError, KeyboardInterrupt):
                print("\nBye.")
                break
            if prompt.lower() in ["quit", "exit", "q"]:
                break
            if not prompt.strip():
                continue
            print("\nAssistant: ", end="", flush=True)
            self.generate_once(prompt, max_tokens=2000, temperature=0.8, show_tokens=True)
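

# Illustrative sketch (not part of the original script): the same top-p
# (nucleus) filtering used inside generate_once, isolated for clarity. Tokens
# are kept in probability order until cumulative mass first exceeds top_p; the
# boundary token itself is kept. For example,
# _top_p_filter(torch.tensor([2.0, 1.0, 0.0, -1.0]), top_p=0.9) masks only the
# lowest-probability tail to -inf while always keeping the top token.
def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    logits = logits.clone()
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cumulative > top_p
    remove[1:] = remove[:-1].clone()  # shift right: keep the crossing token
    remove[0] = False
    logits[sorted_indices[remove]] = float("-inf")
    return logits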


def _cast_layernorm_fp32(module: nn.Module):
    # Keep LayerNorm parameters in fp32 for numerical stability; PTLayerNorm
    # casts activations to the norm's dtype and back (see above).
    for m in module.modules():
        if isinstance(m, nn.LayerNorm):
            m.float()

def load_model_and_tokenizer(model_dir: str):
    """
    Load AdaptiveRiverLM model and tokenizer from a folder layout like:

        model_dir/
            checkpoint.pt        (or any .pt file)
            tokenizer/
                tokenizer.json
                special_tokens_map.json
                ...

    Automatically finds the .pt file if not explicitly named.
    """
    print(f"Searching for model checkpoint in: {model_dir}")
    ckpts = glob.glob(os.path.join(model_dir, "*.pt"))
    if not ckpts:
        raise FileNotFoundError(f"No .pt checkpoint found in {model_dir}")
    if len(ckpts) > 1:
        print(f"[Warning] Multiple .pt files found, using: {ckpts[0]}")
    checkpoint_path = ckpts[0]

    tokenizer_path = os.path.join(model_dir, "tokenizer")
    if not os.path.isdir(tokenizer_path):
        raise FileNotFoundError(f"Missing tokenizer directory: {tokenizer_path}")

    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        print("Tokenizer missing pad_token. Assigning eos_token as pad_token.")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    print("Building model (AdaptiveRiverLM)...")
    cfg = estimate_1b_config()
    cfg.vocab_size = len(tokenizer)
    cfg.tie_word_embeddings = False

    model = AdaptiveRiverLM(cfg)

    print(f"Loading checkpoint: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location="cpu")
    model_state_dict = model.state_dict()
    converted_state = {}

    # Keep only tensors whose names and shapes match the freshly built model;
    # everything else is reported below via strict=False.
    for k, param in model_state_dict.items():
        if k in state and state[k].shape == param.shape:
            converted_state[k] = state[k]

    print("Loading weights...")
    load_result = model.load_state_dict(converted_state, strict=False)

    if load_result.missing_keys:
        print("\n--- Missing Keys ---")
        for k in load_result.missing_keys:
            print("  ", k)
    if load_result.unexpected_keys:
        print("\n--- Unexpected Keys ---")
        for k in load_result.unexpected_keys:
            print("  ", k)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    if device == "cuda" and torch.cuda.is_bf16_supported():
        # Cast to bf16 first, then restore the LayerNorms to fp32; doing it in
        # the opposite order would silently cast the norms back to bf16.
        model = model.to(torch.bfloat16)
        _cast_layernorm_fp32(model)
    else:
        model = model.to(torch.float32)

    model.eval()
    print(f"Model and tokenizer loaded successfully from {model_dir} on {device}")
    return model, tokenizer, device
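

# Illustrative sketch (not part of the original script): typical programmatic
# use, assuming a directory containing a .pt checkpoint and a tokenizer/
# subfolder.
#
#     model, tokenizer, device = load_model_and_tokenizer("path/to/model_dir")
#     tester = FastInferenceTester(model, tokenizer, device,
#                                  tokenizer.convert_tokens_to_ids("<|im_start|>"),
#                                  tokenizer.convert_tokens_to_ids("<|im_end|>"),
#                                  tokenizer.eos_token_id, tokenizer.pad_token_id)
#     result = tester.generate_once("Write a haiku about rivers.", max_tokens=64)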


def main():
    parser = argparse.ArgumentParser(description="Stateless inference for AdaptiveRiverLM (no KV cache), proper EOT handling")
    parser.add_argument("--model_dir", type=str, required=True, help="Path to model folder (with checkpoint.pt and tokenizer/)")
    parser.add_argument("--prompt", type=str, default="Hello, my name is")
    parser.add_argument("--max_tokens", type=int, default=2000)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--min_new_tokens", type=int, default=3)
    parser.add_argument("--interactive", action="store_true", help="Interactive mode (stateless)")
    args = parser.parse_args()

    model, tokenizer, device = load_model_and_tokenizer(args.model_dir)

    # Resolve the chat-template special tokens; the tester builds its own
    # stop/ban sets from these ids.
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id

    tester = FastInferenceTester(model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id)

    if args.interactive:
        tester.interactive_mode()
    else:
        tester.generate_once(
            args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            show_tokens=True,
            min_new_tokens=args.min_new_tokens,
        )


if __name__ == "__main__":
    main()