SAE × RL: Qwen2.5-0.5B-Instruct (strict-reward chain)

TopK sparse autoencoders trained on residual-stream activations from Qwen/Qwen2.5-0.5B-Instruct as it is PPO-fine-tuned on GSM8k under a strict binary reward (the final answer must exactly match the gold answer). Companion to a flexible-reward chain (looser numeric match) trained with the same KL budget.
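
As a rough illustration of the two reward shapes (the exact answer-extraction and normalisation logic lives in the training code, so treat this as a hypothetical sketch):

def strict_reward(pred: str, gold: str) -> float:
    # strict chain: exact string match against the gold final answer
    return 1.0 if pred.strip() == gold.strip() else 0.0

def flexible_reward(pred: str, gold: str, tol: float = 1e-4) -> float:
    # flexible chain: looser numeric match; non-numeric predictions score 0
    try:
        return 1.0 if abs(float(pred) - float(gold)) <= tol else 0.0
    except ValueError:
        return 0.0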

The point of releasing both chains is to study how the residual-stream feature space drifts during RL fine-tuning, and whether that drift depends on what the reward measures or on how much the policy is allowed to move per step (the KL budget).

Repo contents

folder          contents                                      files
saes_k64/       k=64, expansion ×8, layers {6, 12, 18, 23}    24
saes_l23_k256/  k=256, expansion ×8, layer 23 only            6

Six stages per layer: instruct_base, ppo_step{10,30,50,80,116}. File naming: sae_{stage}_layer{N}.pt.
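
For example, the complete k=64 file list follows mechanically from the naming scheme:

from itertools import product

stages = ["instruct_base"] + [f"ppo_step{n}" for n in (10, 30, 50, 80, 116)]
layers = (6, 12, 18, 23)

# 6 stages x 4 layers = the 24 files in saes_k64/
paths = [f"saes_k64/sae_{s}_layer{l}.pt" for s, l in product(stages, layers)]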

The k=256 L23 chain exists because k=64 was capacity-limited at L23 (NMSE ~0.29 vs ~0.15 elsewhere). For L23 results, prefer saes_l23_k256/.
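
NMSE here is the normalised mean squared reconstruction error. One common convention (the paper may normalise differently) is:

import torch

def nmse(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    # reconstruction error normalised by the total variance of the activations
    return ((x - x_hat) ** 2).sum() / ((x - x.mean(0)) ** 2).sum()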

Training setup

  • Activations: 100k GSM8k prompt+response tokens per stage, collected from the merged PPO+LoRA checkpoint at that stage.
  • Architecture: TopK SAE, expansion Γ—8 (n_features = 7168).
  • Schedule: 20 epochs, lr 1e-4, batch 512.
  • Dead-feature resampling: activation < 1e-4 over a 10-epoch interval β†’ resampled.
  • Warm-start chain: each stage's SAE is initialised from the previous stage's weights (instruct_base β†’ step10 β†’ … β†’ step116).
  • Split: 80/20 train/val per stage, seed = 0.
  • PPO: KL coef 0.005 (matched to the flexible chain), strict binary reward.

Headline result

Per-step reconstruction-quality drift (ΔNMSE/step) at L18 and L23 is within ~3% of the flexible-reward chain run at the same KL coefficient, evidence that the drift rate is set by the per-step optimisation budget rather than by what the reward measures. Full numbers and the matching flexible chain are in the paper (in preparation, EMNLP 2026 target).
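
One plausible reading of ΔNMSE/step is the least-squares slope of per-stage NMSE against PPO step; the paper's exact estimator may differ, so treat this helper as hypothetical:

import numpy as np

def drift_per_step(steps, nmse_values) -> float:
    # slope of a degree-1 least-squares fit = average NMSE change per PPO step
    return float(np.polyfit(steps, nmse_values, 1)[0])

# e.g. drift_per_step([0, 10, 30, 50, 80, 116], per_stage_nmse)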

Loading

import torch
sd = torch.load("saes_k64/sae_ppo_step116_layer18.pt", map_location="cpu")
# state_dict for the TopK SAE (encoder W/b, decoder W/b, k, etc.)

The SAE module definition lives in the companion code repo.
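
Without the companion repo you can still inspect a checkpoint's keys and shapes:

for name, value in sd.items():
    shape = getattr(value, "shape", None)  # non-tensor entries (e.g. k) have none
    print(name, tuple(shape) if shape is not None else value)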

Limitations

  • One base model, one task (GSM8k). No claim to generalisation.
  • Two reward shapes only (strict + flexible). A KL-coefficient sweep and a randomised-reward control are the natural follow-ups.
  • TopK only; no JumpReLU comparison.

Citation

Paper in preparation. Please cite this repository in the meantime.
