Extending SAELens
This guide explains how to extend SAELens to work with new model types, custom activation sources, or non-standard training workflows. While the LanguageModelSAETrainingRunner provides a batteries-included experience for TransformerLens models, you may want more control over the training process—for example, to train SAEs on vision models, sequence classifiers, or custom architectures.
Overview
SAELens separates the SAE training logic (handled by SAETrainer) from the activation generation (handled by data providers). This separation makes it straightforward to extend SAELens to new settings:
- `SAETrainer`: The core training loop that handles optimization, logging, checkpointing, and evaluation
- `DataProvider`: An iterator that yields batches of activations for training
- `Evaluator`: An optional callable that runs custom evaluations during training
- `mixing_buffer`: A utility for shuffling activations when your source produces them in sequence order
Using SAETrainer Directly
The SAETrainer class is the core of SAELens training. You can use it directly when you want full control over the training process. Here's what you need:
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.config import SAETrainerConfig, LoggingConfig
from sae_lens.saes.sae import TrainingSAE, TrainingSAEConfig
from sae_lens import StandardTrainingSAEConfig
from collections.abc import Iterator
import torch
# 1. Create a TrainingSAE with your desired architecture
sae_cfg = StandardTrainingSAEConfig(
d_in=768,
d_sae=768 * 8,
l1_coefficient=5.0,
)
sae = TrainingSAE.from_dict(sae_cfg.to_dict())
# 2. Create a data provider (any iterator that yields activation tensors)
def my_activation_generator() -> Iterator[torch.Tensor]:
while True:
# Your logic to generate activations
yield torch.randn(4096, 768) # (batch_size, d_in)
data_provider = my_activation_generator()
# 3. Create the trainer config
trainer_cfg = SAETrainerConfig(
total_training_samples=1_000_000,
train_batch_size_samples=4096,
device="cuda",
lr=3e-4,
lr_end=3e-5,
lr_scheduler_name="constant",
lr_warm_up_steps=1000,
adam_beta1=0.9,
adam_beta2=0.999,
lr_decay_steps=0,
n_restart_cycles=1,
autocast=True,
dead_feature_window=1000,
feature_sampling_window=2000,
n_checkpoints=0,
checkpoint_path=None,
save_final_checkpoint=False,
logger=LoggingConfig(log_to_wandb=False),
)
# 4. Create and run the trainer
trainer = SAETrainer(
cfg=trainer_cfg,
sae=sae,
data_provider=data_provider,
evaluator=None, # Optional: add custom evaluation
)
trained_sae = trainer.fit()
# Save the trained SAE
trained_sae.save_inference_model("path/to/sae")
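Once saved, the SAE can be loaded back for inference and analysis. A minimal sketch, assuming the `SAE.load_from_disk` classmethod available in recent SAELens versions (check the API of the version you have installed):

```python
import torch
from sae_lens import SAE

# Load the inference SAE saved above. The exact loading API may vary between
# SAELens versions; SAE.load_from_disk is assumed here.
sae = SAE.load_from_disk("path/to/sae")

# Encode a toy batch of activations with the SAE's input dimension (768 above).
features = sae.encode(torch.randn(4, 768))
print(features.shape)
```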
SAETrainerConfig
The SAETrainerConfig contains all training hyperparameters. Key fields include:
| Field | Description |
|---|---|
| `total_training_samples` | Total number of activation samples to train on |
| `train_batch_size_samples` | Batch size (number of activation vectors per step) |
| `device` | Device to train on (`"cuda"`, `"cpu"`, etc.) |
| `lr`, `lr_end` | Learning rate and final learning rate for schedulers |
| `lr_scheduler_name` | Scheduler type: `"constant"`, `"cosineannealing"`, `"cosineannealingwarmrestarts"` |
| `lr_warm_up_steps` | Number of warmup steps for learning rate |
| `autocast` | Whether to use mixed precision training |
| `dead_feature_window` | Window for detecting dead features |
| `n_checkpoints` | Number of checkpoints to save during training |
| `logger` | `LoggingConfig` for W&B logging |
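For instance, swapping the constant schedule used above for cosine annealing only changes the scheduler-related fields. A sketch with illustrative, untuned values:

```python
from sae_lens.config import SAETrainerConfig, LoggingConfig

# Same hyperparameters as the example above, except the learning rate now
# follows a cosine annealing schedule from lr down toward lr_end.
# Values are illustrative, not tuned.
trainer_cfg = SAETrainerConfig(
    total_training_samples=1_000_000,
    train_batch_size_samples=4096,
    device="cuda",
    lr=3e-4,
    lr_end=3e-5,
    lr_scheduler_name="cosineannealing",
    lr_warm_up_steps=1000,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_decay_steps=0,
    n_restart_cycles=1,
    autocast=True,
    dead_feature_window=1000,
    feature_sampling_window=2000,
    n_checkpoints=0,
    checkpoint_path=None,
    save_final_checkpoint=False,
    logger=LoggingConfig(log_to_wandb=False),
)
```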
DataProvider
A `DataProvider` is simply an iterator that yields `torch.Tensor` batches. Each tensor should have shape `(batch_size, d_in)`, where `d_in` matches your SAE's input dimension, and the trainer calls `next(data_provider)` to get each batch.
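Any plain generator satisfies this contract. A minimal sketch with random data (the function name and shapes are illustrative only):

```python
from collections.abc import Iterator
import torch

def random_activation_provider(d_in: int = 768, batch_size: int = 4096) -> Iterator[torch.Tensor]:
    """Toy data provider: yields random (batch_size, d_in) activation batches forever."""
    while True:
        yield torch.randn(batch_size, d_in)

provider = random_activation_provider()
batch = next(provider)  # this is all the trainer does with the provider
print(batch.shape)      # torch.Size([4096, 768])
```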
Creating Custom Runners
For more complex scenarios, you may want to create a custom runner class that encapsulates your model, activation generation, and training logic. Here's a template:
from dataclasses import dataclass
from collections.abc import Iterator
from typing import Any
import torch
from sae_lens.training.sae_trainer import SAETrainer, Evaluator
from sae_lens.config import SAETrainerConfig, LoggingConfig
from sae_lens.saes.sae import TrainingSAE, TrainingSAEConfig
from sae_lens.training.activation_scaler import ActivationScaler
from sae_lens.training.types import DataProvider
@dataclass
class MyCustomRunnerConfig:
"""Configuration for your custom runner."""
# Model settings
model_name: str
hook_layer: int
# Training settings
training_samples: int
batch_size: int
lr: float
device: str = "cuda"
# SAE settings
d_in: int = 768
d_sae: int = 768 * 8
# Output
output_path: str | None = None
class MyCustomRunner:
"""Custom runner for training SAEs on your model type."""
def __init__(self, cfg: MyCustomRunnerConfig):
self.cfg = cfg
# Load your model
self.model = self._load_model()
# Create the SAE
self.sae = self._create_sae()
def _load_model(self) -> torch.nn.Module:
"""Load your model here."""
# Example: load a custom model
from transformers import AutoModel
model = AutoModel.from_pretrained(self.cfg.model_name)
return model.to(self.cfg.device)
def _create_sae(self) -> TrainingSAE[Any]:
"""Create the SAE to train."""
from sae_lens import StandardTrainingSAEConfig
sae_cfg = StandardTrainingSAEConfig(
d_in=self.cfg.d_in,
d_sae=self.cfg.d_sae,
)
sae = TrainingSAE.from_dict(sae_cfg.to_dict())
return sae.to(self.cfg.device)
def _get_activations(self, inputs: Any) -> torch.Tensor:
"""Extract activations from your model."""
activations = []
def hook_fn(module: Any, input: Any, output: Any) -> None:
activations.append(output.detach())
# Register hook on desired layer
target_layer = self._get_target_layer()
handle = target_layer.register_forward_hook(hook_fn)
with torch.no_grad():
self.model(inputs)
handle.remove()
return activations[0]
def _get_target_layer(self) -> torch.nn.Module:
"""Get the layer to hook for activations."""
# Implement based on your model architecture
raise NotImplementedError
def _create_data_provider(self) -> DataProvider:
"""Create an iterator that yields activation batches."""
def activation_generator() -> Iterator[torch.Tensor]:
# Your data loading logic here
while True:
# Get a batch of inputs for your model
inputs = self._get_next_batch()
# Extract activations
activations = self._get_activations(inputs)
# Flatten to (batch_size * seq_len, d_in) if needed
if activations.dim() == 3:
activations = activations.view(-1, activations.size(-1))
yield activations
return activation_generator()
def _get_next_batch(self) -> Any:
"""Get the next batch of inputs for your model."""
raise NotImplementedError
def _create_evaluator(self) -> Evaluator[TrainingSAE[Any]] | None:
"""Optionally create an evaluator for periodic evaluation."""
return None
def run(self) -> TrainingSAE[Any]:
"""Run training and return the trained SAE."""
trainer_cfg = SAETrainerConfig(
total_training_samples=self.cfg.training_samples,
train_batch_size_samples=self.cfg.batch_size,
device=self.cfg.device,
lr=self.cfg.lr,
lr_end=self.cfg.lr / 10,
lr_scheduler_name="constant",
lr_warm_up_steps=1000,
adam_beta1=0.9,
adam_beta2=0.999,
lr_decay_steps=0,
n_restart_cycles=1,
autocast=True,
dead_feature_window=1000,
feature_sampling_window=2000,
n_checkpoints=0,
checkpoint_path=None,
save_final_checkpoint=False,
logger=LoggingConfig(log_to_wandb=True),
)
trainer = SAETrainer(
cfg=trainer_cfg,
sae=self.sae,
data_provider=self._create_data_provider(),
evaluator=self._create_evaluator(),
)
trained_sae = trainer.fit()
# Save the trained SAE
if self.cfg.output_path is not None:
trained_sae.save_inference_model(self.cfg.output_path)
return trained_sae
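Since the template leaves `_get_target_layer` and `_get_next_batch` unimplemented, using it means subclassing and filling them in. A hedged sketch, continuing directly from the template above, for a BERT-style `AutoModel` (the `encoder.layer[...].output` path and the random `input_ids` are illustrative assumptions, not part of the SAELens API):

```python
class MyBertRunner(MyCustomRunner):
    """Illustrative subclass hooking one encoder block of a BERT-style AutoModel."""

    def _get_target_layer(self) -> torch.nn.Module:
        # BertOutput returns a plain hidden-states tensor, which is what the
        # template's forward hook expects. Adjust the path for your architecture.
        return self.model.encoder.layer[self.cfg.hook_layer].output

    def _get_next_batch(self) -> Any:
        # Placeholder: random token ids standing in for a real tokenizer + DataLoader.
        # 32 prompts x 128 tokens = 4096 activations, matching batch_size below.
        return torch.randint(0, 30_000, (32, 128), device=self.cfg.device)


runner = MyBertRunner(
    MyCustomRunnerConfig(
        model_name="bert-base-uncased",
        hook_layer=6,
        training_samples=100_000,
        batch_size=4096,
        lr=3e-4,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
)
trained_sae = runner.run()
```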
Example: GPT2 Sequence Classifier
Here's a complete example showing how to train an SAE on a GPT2ForSequenceClassification model. This demonstrates training on a non-standard model type that isn't directly supported by the built-in LanguageModelSAETrainingRunner:
from dataclasses import dataclass
from collections.abc import Iterator
from typing import Any
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2ForSequenceClassification, GPT2Tokenizer
from sae_lens.training.sae_trainer import SAETrainer, Evaluator
from sae_lens.training.mixing_buffer import mixing_buffer
from sae_lens.config import SAETrainerConfig, LoggingConfig
from sae_lens.saes.sae import TrainingSAE
from sae_lens import StandardTrainingSAEConfig
from sae_lens.training.activation_scaler import ActivationScaler
from sae_lens.training.types import DataProvider
@dataclass
class ClassifierSAERunnerConfig:
"""Config for training SAEs on sequence classifiers."""
model_name: str = "gpt2"
dataset_path: str = "imdb"
hook_layer: int = 6 # Which transformer block to hook
# Training
training_tokens: int = 1_000_000
batch_size_prompts: int = 8
train_batch_size: int = 4096
context_size: int = 256
lr: float = 3e-4
device: str = "cuda"
# SAE
expansion_factor: int = 8
l1_coefficient: float = 5.0
# Buffer for activation shuffling
n_batches_in_buffer: int = 16
# Output
output_path: str | None = None
class ClassifierSAERunner:
"""Runner for training SAEs on GPT2ForSequenceClassification."""
def __init__(self, cfg: ClassifierSAERunnerConfig):
self.cfg = cfg
# Load tokenizer and model
self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = GPT2ForSequenceClassification.from_pretrained(
cfg.model_name,
num_labels=2, # Binary classification
).to(cfg.device)
self.model.config.pad_token_id = self.tokenizer.pad_token_id
self.model.eval()
# Get hidden size from model config
self.d_in = self.model.config.hidden_size
# Load dataset
self.dataset = load_dataset(cfg.dataset_path, split="train")
# Create SAE
sae_cfg = StandardTrainingSAEConfig(
d_in=self.d_in,
d_sae=self.d_in * cfg.expansion_factor,
l1_coefficient=cfg.l1_coefficient,
)
self.sae = TrainingSAE.from_dict(sae_cfg.to_dict()).to(cfg.device)
def _get_activations(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Extract activations from the target layer."""
activations: list[torch.Tensor] = []
def hook_fn(module: Any, input: Any, output: Any) -> None:
# output is (hidden_states, ...) tuple
hidden = output[0] if isinstance(output, tuple) else output
activations.append(hidden.detach())
# Hook into the transformer block's output
target_block = self.model.transformer.h[self.cfg.hook_layer]
handle = target_block.register_forward_hook(hook_fn)
try:
with torch.no_grad():
self.model(input_ids)
finally:
handle.remove()
return activations[0] # (batch, seq_len, hidden_size)
def _iterate_raw_activations(self) -> Iterator[torch.Tensor]:
"""Iterate over batches of activations from the model."""
dataloader = DataLoader(
self.dataset, # type: ignore
batch_size=self.cfg.batch_size_prompts,
shuffle=True,
)
while True:
for batch in dataloader:
# Tokenize
encoded = self.tokenizer(
batch["text"],
padding="max_length",
truncation=True,
max_length=self.cfg.context_size,
return_tensors="pt",
)
input_ids = encoded["input_ids"].to(self.cfg.device)
# Get activations: (batch, seq_len, hidden_size)
activations = self._get_activations(input_ids)
# Flatten to (batch * seq_len, hidden_size)
flat_activations = activations.view(-1, self.d_in)
yield flat_activations
def _create_data_provider(self) -> DataProvider:
"""Create a data provider with activation shuffling."""
buffer_size = (
self.cfg.n_batches_in_buffer
* self.cfg.batch_size_prompts
* self.cfg.context_size
)
return mixing_buffer(
buffer_size=buffer_size,
batch_size=self.cfg.train_batch_size,
activations_loader=self._iterate_raw_activations(),
)
def _create_evaluator(
self,
) -> Evaluator[TrainingSAE[Any]] | None:
"""Create an optional evaluator."""
def simple_evaluator(
sae: TrainingSAE[Any],
data_provider: DataProvider,
activation_scaler: ActivationScaler,
) -> dict[str, Any]:
"""Simple evaluation: compute reconstruction error on a batch."""
sae.eval()
batch = next(data_provider).to(sae.device)
with torch.no_grad():
sae_out = sae.decode(sae.encode(batch))
mse = (batch - sae_out).pow(2).mean().item()
sae.train()
return {"eval/mse": mse}
return simple_evaluator
def run(self) -> TrainingSAE[Any]:
"""Run training."""
trainer_cfg = SAETrainerConfig(
total_training_samples=self.cfg.training_tokens,
train_batch_size_samples=self.cfg.train_batch_size,
device=self.cfg.device,
lr=self.cfg.lr,
lr_end=self.cfg.lr / 10,
lr_scheduler_name="constant",
lr_warm_up_steps=1000,
adam_beta1=0.9,
adam_beta2=0.999,
lr_decay_steps=0,
n_restart_cycles=1,
autocast=True,
dead_feature_window=1000,
feature_sampling_window=2000,
n_checkpoints=0,
checkpoint_path=None,
save_final_checkpoint=False,
logger=LoggingConfig(
log_to_wandb=True,
wandb_project="classifier-sae",
),
)
trainer = SAETrainer(
cfg=trainer_cfg,
sae=self.sae,
data_provider=self._create_data_provider(),
evaluator=self._create_evaluator(),
)
trained_sae = trainer.fit()
# Save the trained SAE
if self.cfg.output_path is not None:
trained_sae.save_inference_model(self.cfg.output_path)
return trained_sae
# Usage
if __name__ == "__main__":
cfg = ClassifierSAERunnerConfig(
model_name="gpt2",
dataset_path="imdb",
hook_layer=6,
training_tokens=500_000,
device="cuda" if torch.cuda.is_available() else "cpu",
output_path="classifier_sae",
)
runner = ClassifierSAERunner(cfg)
runner.run()
Activation Shuffling with mixing_buffer
When collecting activations sequentially (e.g., processing documents one at a time), consecutive activations are highly correlated. This can hurt training. The mixing_buffer utility helps by shuffling activations:
from sae_lens.training.mixing_buffer import mixing_buffer
from collections.abc import Iterator
import torch
def my_sequential_activations() -> Iterator[torch.Tensor]:
"""Yields activations in document order (correlated)."""
while True:
# Process a document and yield its activations
yield torch.randn(1024, 768) # (tokens_in_doc, d_in)
# Wrap with mixing_buffer to shuffle
shuffled_provider = mixing_buffer(
buffer_size=100_000, # Total activations to buffer
batch_size=4096, # Output batch size
activations_loader=my_sequential_activations(),
)
# Now batches are shuffled!
for batch in shuffled_provider:
print(batch.shape) # (4096, 768)
break
The mixing buffer:
- Accumulates activations until reaching `buffer_size`
- Randomly shuffles the buffer
- Yields half as batches while keeping the other half
- Refills with new activations and repeats
This ensures each batch contains activations from many different contexts rather than consecutive tokens from the same document.
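In pseudocode, the strategy looks roughly like the following simplified sketch. This is only to illustrate the half-buffer idea; the real `mixing_buffer` implementation handles devices, partial batches, and edge cases more carefully:

```python
from collections.abc import Iterator
import torch

def simplified_mixing_buffer(
    buffer_size: int,
    batch_size: int,
    activations_loader: Iterator[torch.Tensor],
) -> Iterator[torch.Tensor]:
    """Illustrative half-buffer shuffle (not the actual SAELens implementation)."""
    kept = torch.empty(0)
    while True:
        # 1. Refill the buffer with new activations until it holds buffer_size rows.
        chunks = [kept] if kept.numel() else []
        n_rows = kept.shape[0] if kept.numel() else 0
        while n_rows < buffer_size:
            chunk = next(activations_loader)
            chunks.append(chunk)
            n_rows += chunk.shape[0]
        buffer = torch.cat(chunks)
        # 2. Shuffle the whole buffer.
        buffer = buffer[torch.randperm(buffer.shape[0])]
        # 3. Yield the first half as training batches...
        half = buffer.shape[0] // 2
        for start in range(0, half - batch_size + 1, batch_size):
            yield buffer[start : start + batch_size]
        # 4. ...and keep the second half to mix with the next round of activations.
        kept = buffer[half:]
```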
Custom Evaluators
Evaluators are called periodically during training to compute metrics. They receive the SAE, data provider, and activation scaler:
from typing import Any
import torch
from sae_lens.saes.sae import TrainingSAE
from sae_lens.training.activation_scaler import ActivationScaler
from sae_lens.training.types import DataProvider
def my_evaluator(
sae: TrainingSAE[Any],
data_provider: DataProvider,
activation_scaler: ActivationScaler,
) -> dict[str, Any]:
"""Custom evaluation function."""
sae.eval()
# Collect some evaluation batches
eval_batches = [next(data_provider) for _ in range(10)]
metrics: dict[str, float] = {}
with torch.no_grad():
for batch in eval_batches:
batch = batch.to(sae.device)
scaled_batch = activation_scaler(batch)
# Forward pass
encoded = sae.encode(scaled_batch)
decoded = sae.decode(encoded)
# Compute metrics
mse = (scaled_batch - decoded).pow(2).mean()
l0 = (encoded != 0).float().sum(dim=-1).mean()
metrics["mse"] = metrics.get("mse", 0) + mse.item()
metrics["l0"] = metrics.get("l0", 0) + l0.item()
# Average
n_batches = len(eval_batches)
metrics = {f"eval/{k}": v / n_batches for k, v in metrics.items()}
sae.train()
return metrics
The returned dictionary is logged to W&B if logging is enabled.
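To use it, pass the function as the `evaluator` argument when constructing the trainer, reusing the `trainer_cfg`, `sae`, and `data_provider` objects from the first example:

```python
trainer = SAETrainer(
    cfg=trainer_cfg,
    sae=sae,
    data_provider=data_provider,
    evaluator=my_evaluator,  # called periodically during training
)
trained_sae = trainer.fit()
```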
Example: Vision Model SAE
Here's a sketch for training SAEs on vision model activations:
from dataclasses import dataclass
from collections.abc import Iterator
from typing import Any
import torch
from torchvision import transforms
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.training.mixing_buffer import mixing_buffer
from sae_lens.config import SAETrainerConfig, LoggingConfig
from sae_lens.saes.sae import TrainingSAE
from sae_lens import StandardTrainingSAEConfig
@dataclass
class VisionSAEConfig:
model_name: str = "vit_base_patch16_224"
layer_name: str = "blocks.6" # Target layer
dataset_path: str = "/path/to/imagenet"
batch_size: int = 32
training_tokens: int = 1_000_000
device: str = "cuda"
output_path: str | None = None
class VisionSAERunner:
def __init__(self, cfg: VisionSAEConfig):
self.cfg = cfg
# Load vision model (using timm for example)
import timm
self.model = timm.create_model(
cfg.model_name,
pretrained=True,
).to(cfg.device)
self.model.eval()
# Get the hidden dimension
# This depends on your model - ViT base has d=768
self.d_in = 768
# Create SAE
sae_cfg = StandardTrainingSAEConfig(
d_in=self.d_in,
d_sae=self.d_in * 8,
)
self.sae = TrainingSAE.from_dict(sae_cfg.to_dict()).to(cfg.device)
# Setup data loading
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
self.dataset = ImageNet(cfg.dataset_path, split="train", transform=transform)
self.dataloader = DataLoader(
self.dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=4,
)
def _get_layer(self, name: str) -> torch.nn.Module:
"""Get a layer by name like 'blocks.6'."""
parts = name.split(".")
module = self.model
for part in parts:
if part.isdigit():
module = module[int(part)]
else:
module = getattr(module, part)
return module
def _iterate_activations(self) -> Iterator[torch.Tensor]:
"""Extract activations from vision model."""
while True:
for images, _ in self.dataloader:
images = images.to(self.cfg.device)
activations: list[torch.Tensor] = []
def hook_fn(m: Any, i: Any, o: Any) -> None:
activations.append(o.detach())
layer = self._get_layer(self.cfg.layer_name)
handle = layer.register_forward_hook(hook_fn)
try:
with torch.no_grad():
self.model(images)
finally:
handle.remove()
# ViT outputs: (batch, num_patches + 1, hidden_dim)
# Flatten patches: (batch * num_patches, hidden_dim)
act = activations[0]
flat_act = act.view(-1, act.size(-1))
yield flat_act
def run(self) -> TrainingSAE[Any]:
# Use mixing buffer for shuffling patch activations
data_provider = mixing_buffer(
buffer_size=50_000,
batch_size=4096,
activations_loader=self._iterate_activations(),
)
trainer_cfg = SAETrainerConfig(
total_training_samples=self.cfg.training_tokens,
train_batch_size_samples=4096,
device=self.cfg.device,
lr=3e-4,
lr_end=3e-5,
lr_scheduler_name="constant",
lr_warm_up_steps=1000,
adam_beta1=0.9,
adam_beta2=0.999,
lr_decay_steps=0,
n_restart_cycles=1,
autocast=True,
dead_feature_window=1000,
feature_sampling_window=2000,
n_checkpoints=0,
checkpoint_path=None,
save_final_checkpoint=False,
logger=LoggingConfig(log_to_wandb=True),
)
trainer = SAETrainer(
cfg=trainer_cfg,
sae=self.sae,
data_provider=data_provider,
)
trained_sae = trainer.fit()
# Save the trained SAE
if self.cfg.output_path is not None:
trained_sae.save_inference_model(self.cfg.output_path)
return trained_sae
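Usage mirrors the classifier example; a short sketch, continuing from the code above (the ImageNet path stays a placeholder):

```python
if __name__ == "__main__":
    cfg = VisionSAEConfig(
        dataset_path="/path/to/imagenet",
        training_tokens=500_000,
        device="cuda" if torch.cuda.is_available() else "cpu",
        output_path="vision_sae",
    )
    runner = VisionSAERunner(cfg)
    runner.run()
```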
Tips for Custom Runners
- Use `mixing_buffer` for any sequential data source to improve training stability
- Match batch sizes carefully: the `train_batch_size_samples` in `SAETrainerConfig` should match the batch size yielded by your data provider (see the sketch after this list)
- Handle activation scaling: if using `normalize_activations="expected_average_only_in"` in your SAE config, the trainer will automatically estimate and apply scaling
- Monitor dead features: the trainer tracks feature activation frequency and logs dead feature counts to W&B
- For custom SAE architectures: see the Custom SAEs guide for creating new SAE types
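For example, assuming your SAE config exposes the `normalize_activations` option mentioned above, tying the provider and trainer batch sizes to one constant looks like this (the reuse of `my_sequential_activations` from the mixing_buffer section and all values are illustrative):

```python
from sae_lens import StandardTrainingSAEConfig
from sae_lens.training.mixing_buffer import mixing_buffer

TRAIN_BATCH_SIZE = 4096

# Normalization is configured on the SAE config; the trainer then estimates
# and applies the scaling factor automatically at the start of training.
sae_cfg = StandardTrainingSAEConfig(
    d_in=768,
    d_sae=768 * 8,
    normalize_activations="expected_average_only_in",
)

# Keep the batch size yielded by the data provider in sync with the trainer.
data_provider = mixing_buffer(
    buffer_size=100_000,
    batch_size=TRAIN_BATCH_SIZE,  # batches handed to the trainer
    activations_loader=my_sequential_activations(),  # generator from the mixing_buffer section
)
# ...and pass train_batch_size_samples=TRAIN_BATCH_SIZE when building SAETrainerConfig.
```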
Summary
| Component | Purpose | Required? |
|---|---|---|
| `SAETrainer` | Core training loop | Yes |
| `SAETrainerConfig` | Training hyperparameters | Yes |
| `TrainingSAE` | The SAE to train | Yes |
| `DataProvider` | Iterator yielding `(batch, d_in)` tensors | Yes |
| `Evaluator` | Periodic evaluation callback | No |
| `mixing_buffer` | Shuffle activations | Recommended |
| `save_checkpoint_fn` | Custom checkpoint logic | No |
For most custom scenarios, you'll create a runner class that:
- Loads your model
- Creates a `TrainingSAE` with your desired architecture
- Implements activation extraction with hooks
- Optionally uses `mixing_buffer` to shuffle activations
- Creates an `SAETrainer` and calls `.fit()` to train the SAE
- Saves the trained inference SAE