---
library_name: transformers
tags:
- interpretability
- sparse-autoencoder
- sae
- mechanistic-interpretability
- dictionary-learning
- gemma
base_model: google/gemma-3-27b-it
---

# Sparse Autoencoders for Gemma-3-27b-it

This repository contains **9 Sparse Autoencoders (SAEs)** trained on [google/gemma-3-27b-it](https://huggingface.co/google/gemma-3-27b-it) using the BatchTopK architecture.

## Architecture: BatchTopK SAE

These SAEs use the **BatchTopK** architecture, which enforces sparsity by:

1. Computing feature activations: `z = Wx + b` (encoder)
2. Selecting the top-k features **across the batch** (not per sample)
3. Reconstructing: `x̂ = W'z_topk + b_dec` (decoder)

This approach tends to produce more interpretable features than ReLU-based SAEs and has better training dynamics. A reference sketch of this forward pass is included at the end of the *Model Details* section below.

## Repository Structure

```
layer_45/
  dict_16k_k80/            # 16,384 features, k=80
    ae.pt                  # SAE weights
    config.json            # Training configuration
    feature_labels.json    # Natural language feature descriptions
  dict_16k_k160/           # 16,384 features, k=160
  dict_65k_k80/            # 65,536 features, k=80
  dict_65k_k160/           # 65,536 features, k=160
layer_47/                  # (same structure)
layer_45_mlp/              # (same structure)
```

## Available SAEs

| Layer | Dict Size | k | Activation Dim | Parameters | Sparsity |
|-------|-----------|---|----------------|------------|----------|
| 45 | 16,384 | 80 | 5,376 | 176,182,528 | 0.49% |
| 45 | 16,384 | 160 | 5,376 | 176,182,528 | 0.98% |
| 45 | 65,536 | 80 | 5,376 | 704,713,984 | 0.12% |
| 45 | 65,536 | 160 | 5,376 | 704,713,984 | 0.24% |
| 47 | 16,384 | 80 | 5,376 | 176,182,528 | 0.49% |
| 47 | 16,384 | 160 | 5,376 | 176,182,528 | 0.98% |
| 47 | 65,536 | 80 | 5,376 | 704,713,984 | 0.12% |
| 47 | 65,536 | 160 | 5,376 | 704,713,984 | 0.24% |

**Total Parameters**: 3,523,586,048

## Model Details

### Training Details

**Base Model**: [google/gemma-3-27b-it](https://huggingface.co/google/gemma-3-27b-it)

**Hook Point**: `residual_stream` (post-layer activations)

**Dataset**: FineWeb (HuggingFaceFW/fineweb)

**Training Hyperparameters**:
- Optimizer: Adam
- Learning rate: 5e-5
- Warmup steps: 1,000
- Training steps: ~244,140
- Context length: 2,048 tokens
- Batch size: 2,048 activations
- Decay start: 195,312 steps

**BatchTopK Parameters**:
- Auxiliary loss coefficient (α): 0.03125
- Threshold decay (β): 0.999
- Threshold start step: 1,000

**Sparsity Levels**:
- **k=80**: Higher sparsity, more selective features
- **k=160**: Lower sparsity, more features active per sample

**Dictionary Sizes**:
- **16,384**: Compact, efficient, good for resource-constrained analysis
- **65,536**: Comprehensive, captures more fine-grained patterns

### Feature Labels

This repository includes **natural language descriptions** for all features, generated using LLM-as-a-judge (GPT-4) on maximum activating examples.

Each feature has:
- **Title**: Short description of what the feature detects
- **Description**: Detailed explanation with examples
- **Examples**: Token sequences that maximally activate the feature
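### Forward Pass Sketch

For reference, the encode, select, and decode steps described in the *Architecture* section above can be sketched as follows. This is only an illustration of those equations with toy dimensions and random weights, not the training implementation (which additionally uses an auxiliary loss and threshold tracking):

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration; the real SAEs use activation_dim=5376,
# dict_size=16384 or 65536, and k=80 or 160.
batch, activation_dim, dict_size, k = 4, 16, 64, 8

W_enc = torch.randn(dict_size, activation_dim)   # encoder weight
b_enc = torch.zeros(dict_size)                   # encoder bias
W_dec = torch.randn(activation_dim, dict_size)   # decoder weight
b_dec = torch.zeros(activation_dim)              # decoder bias
x = torch.randn(batch, activation_dim)           # activations to reconstruct

# 1. Encode: z = Wx + b
z = F.linear(x, W_enc, b_enc)                    # [batch, dict_size]

# 2. BatchTopK: keep the batch*k largest activations across the whole batch,
#    so individual samples may use more or fewer than k features.
top = torch.topk(z.flatten(), k=batch * k)
z_sparse = torch.zeros_like(z.flatten())
z_sparse[top.indices] = top.values
z_sparse = z_sparse.view(batch, dict_size)

# 3. Decode: x̂ = W'z_topk + b_dec
x_hat = F.linear(z_sparse, W_dec, b_dec)         # [batch, activation_dim]
print("Reconstruction MSE:", F.mse_loss(x_hat, x).item())
```

At inference time a single sample's reconstruction should not depend on the rest of the batch, so per-sample top-k (as in the *Usage* section below) or the learned activation threshold stored in each checkpoint is typically used instead.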
## Usage

### Installation

```bash
pip install torch transformers huggingface_hub
```

### Loading an SAE

```python
import json

import torch
from huggingface_hub import hf_hub_download

# Download a specific SAE
ae_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/ae.pt",
)
config_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/config.json",
)

# Load SAE weights and training configuration
ae_data = torch.load(ae_path, map_location='cpu')
with open(config_path, 'r') as f:
    config = json.load(f)

print(f"Loaded SAE with {config['trainer']['dict_size']} features")
print(f"Activation dimension: {config['trainer']['activation_dim']}")
print(f"Top-k: {config['trainer']['k']}")

# SAE weights
encoder_weight = ae_data['encoder.weight']  # [dict_size, activation_dim]
encoder_bias = ae_data['encoder.bias']      # [dict_size]
decoder_weight = ae_data['decoder.weight']  # [activation_dim, dict_size]
decoder_bias = ae_data['b_dec']             # [activation_dim]
threshold = ae_data['threshold']            # Learned activation threshold
```

### Using the SAE for Analysis

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# Load base model
model_name = "google/gemma-3-27b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Get activations from layer 45
text = "The capital of France is Paris"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    layer_45_acts = outputs.hidden_states[45]  # [batch, seq, activation_dim]

# Flatten and match the SAE weights' device/dtype (loaded on CPU above)
acts_flat = layer_45_acts.reshape(-1, layer_45_acts.shape[-1])  # [batch*seq, dim]
acts_flat = acts_flat.to(device=encoder_weight.device, dtype=encoder_weight.dtype)

# Encoder: z = Wx + b
z = F.linear(acts_flat, encoder_weight, encoder_bias)  # [batch*seq, dict_size]

# Top-k selection (per sample here; batch-level selection is only used in training)
top_k = config['trainer']['k']
top_values, top_indices = torch.topk(z, k=top_k, dim=-1)

# Create sparse representation
z_topk = torch.zeros_like(z)
z_topk.scatter_(-1, top_indices, top_values)

# Decode: x̂ = W'z + b_dec
# (decoder_weight is [activation_dim, dict_size], so F.linear needs no transpose)
reconstructed = F.linear(z_topk, decoder_weight, decoder_bias)

# Compute reconstruction loss
mse_loss = F.mse_loss(reconstructed, acts_flat)
print(f"Reconstruction MSE: {mse_loss.item():.6f}")

# Find active features
active_features = top_indices[0]  # Indices of the k features active on the first token
print(f"Active features for first token: {active_features.tolist()}")
```

### Loading Feature Labels

```python
import json
from huggingface_hub import hf_hub_download

# Download feature labels
labels_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/feature_labels.json",
)

with open(labels_path, 'r') as f:
    labels = json.load(f)

# Examine a specific feature
feature_id = 1234
if str(feature_id) in labels:
    label = labels[str(feature_id)]
    print(f"Feature {feature_id}:")
    print(f"  Title: {label.get('title', 'N/A')}")
    print(f"  Description: {label.get('description', 'N/A')}")
```
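### Inference with the Learned Threshold

Each checkpoint also stores a learned activation `threshold` (loaded in the *Loading an SAE* snippet above). A common alternative to per-sample top-k at inference is to zero out encoder activations that do not exceed this threshold. The sketch below assumes the threshold broadcasts directly against `z`, which may differ from how the training framework applies it:

```python
# Reuses `z`, `decoder_weight`, `decoder_bias`, `acts_flat`, and `threshold`
# from the snippets above. Sketch only: keep activations above the learned
# threshold instead of taking a fixed top-k per sample.
z_thr = torch.where(z > threshold, z, torch.zeros_like(z))
reconstructed_thr = F.linear(z_thr, decoder_weight, decoder_bias)
print(f"Thresholded reconstruction MSE: {F.mse_loss(reconstructed_thr, acts_flat).item():.6f}")
print(f"Mean active features per token: {(z_thr > 0).float().sum(dim=-1).mean().item():.1f}")
```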
## Citation

If you use these SAEs in your research, please cite:

```bibtex
@software{gemma3_27b_saes,
  author = {Macar, Uzay},
  title = {Sparse Autoencoders for Gemma-3-27b-it},
  year = {2024},
  url = {https://huggingface.co/uzaymacar/gemma-3-27b-saes}
}
```

**SAE Training Framework:**

```bibtex
@software{dictionary_learning,
  author = {Marks, Samuel and others},
  title = {Dictionary Learning for Mechanistic Interpretability},
  year = {2024},
  url = {https://github.com/saprmarks/dictionary_learning}
}
```

**TopK / BatchTopK Architecture:**

```bibtex
@article{gao2024batchTopK,
  title={Scaling and evaluating sparse autoencoders},
  author={Gao, Leo and others},
  journal={arXiv preprint arXiv:2406.04093},
  year={2024}
}
```

## License

These SAEs are released under the same license as the base model (google/gemma-3-27b-it).

## Acknowledgments

- Trained using [dictionary_learning](https://github.com/saprmarks/dictionary_learning)
- Base model: [google/gemma-3-27b-it](https://huggingface.co/google/gemma-3-27b-it)
- Training data: [FineWeb](https://huggingface.co/datasets/HuggingFaceFW/fineweb)

## Contact

For questions or issues, please contact me at [uzaymacar@gmail.com](mailto:uzaymacar@gmail.com)

---

**Note**: These SAEs are research artifacts. While they provide valuable insights into model representations, they should be used as one tool among many for interpretability research.