Usage Guide

This guide covers how to use SAEs for inference and analysis. For training SAEs, see Training SAEs.

Loading SAEs

From Pretrained (Hugging Face)

Load SAEs from the SAELens registry or any Hugging Face repository with the saelens tag.

from sae_lens import SAE

# Load from SAELens registry
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)

# Load from any Hugging Face repo with saelens tag
sae = SAE.from_pretrained(
    release="your-username/your-sae-repo",
    sae_id="path/to/sae",
    device="cuda"
)

See Pretrained SAEs for a full list of available SAEs.
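
You can also browse the registry programmatically. The sketch below assumes your SAELens version exposes get_pretrained_saes_directory at this import path and that each entry has a saes_map attribute; both may differ between versions.

from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# Map of release name -> lookup entry (attribute names assumed; check your version)
directory = get_pretrained_saes_directory()
for release, entry in list(directory.items())[:5]:
    print(release, "->", len(entry.saes_map), "SAEs")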

From Disk

Load SAEs that you've trained yourself or downloaded manually.

from sae_lens import SAE

sae = SAE.load_from_disk(
    path="/path/to/your/sae",
    device="cuda"
)

Running SAEs Directly

The SAE class provides three main methods for inference: encode(), decode(), and forward().

Encode

Convert activations to sparse feature representations.

import torch
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)

# activations shape: (batch, seq_len, d_model)
# Gemma 2 2B has d_model=2304
activations = torch.randn(1, 128, 2304, device="cuda")

# feature_acts shape: (batch, seq_len, d_sae)
feature_acts = sae.encode(activations)

# Check which features are active
active_features = (feature_acts > 0).sum(dim=-1)
print(f"Average L0: {active_features.float().mean().item()}")

Decode

Convert sparse feature representations back to activation space.

# Reconstruct activations from features
reconstructed = sae.decode(feature_acts)

# Compute reconstruction error
mse = (activations - reconstructed).pow(2).mean()
print(f"Reconstruction MSE: {mse.item()}")

Forward

Run the full SAE pipeline (encode + decode) in one call.

# Equivalent to sae.decode(sae.encode(activations))
reconstructed = sae.forward(activations)

# Or simply call the SAE directly
reconstructed = sae(activations)

Using HookedSAETransformer

HookedSAETransformer extends TransformerLens's HookedTransformer to seamlessly integrate SAEs into the model's forward pass.

Warning

When using HookedSAETransformer or HookedTransformer, you should generally load the model with from_pretrained_no_processing rather than from_pretrained. Most SAEs are trained on raw LLM activations; the default processing applied by from_pretrained changes those activations and may break your SAE.

Setup

from sae_lens import SAE, HookedSAETransformer

# Load model
model = HookedSAETransformer.from_pretrained_no_processing("gemma-2-2b", device="cuda")

# Load SAE
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)

Run with SAEs (Temporary)

Run a forward pass with SAEs attached temporarily. SAEs are removed after the forward pass.

tokens = model.to_tokens("Hello, world!")

# Run with SAE - SAE is removed after this call
logits = model.run_with_saes(tokens, saes=[sae])
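
HookedSAETransformer also provides a saes context manager for temporary attachment (the same pattern shown for SAETransformerBridge below); if your SAELens version predates it, use run_with_saes instead.

# SAEs are attached inside the with-block and removed on exit
with model.saes(saes=[sae]):
    logits = model(tokens)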

Run with Cache and SAEs

Cache activations including SAE feature activations.

logits, cache = model.run_with_cache_with_saes(tokens, saes=[sae])

# Access SAE feature activations
sae_acts = cache["blocks.12.hook_resid_post.hook_sae_acts_post"]
print(f"SAE activations shape: {sae_acts.shape}")

Run with Hooks and SAEs

Intervene on SAE activations during the forward pass.

from functools import partial

def ablate_feature(sae_acts, hook, feature_id):
    sae_acts[:, :, feature_id] = 0.0
    return sae_acts

# Ablate feature 1000 during forward pass
logits = model.run_with_hooks_with_saes(
    tokens,
    saes=[sae],
    fwd_hooks=[
        ("blocks.12.hook_resid_post.hook_sae_acts_post",
         partial(ablate_feature, feature_id=1000))
    ]
)
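
The same hook point can be used to steer a feature rather than ablate it. This sketch mirrors the ablation example above; the feature index and clamp value are arbitrary illustrative choices.

def steer_feature(sae_acts, hook, feature_id, value):
    # Clamp a single SAE feature to a fixed value at every position
    sae_acts[:, :, feature_id] = value
    return sae_acts

logits = model.run_with_hooks_with_saes(
    tokens,
    saes=[sae],
    fwd_hooks=[
        ("blocks.12.hook_resid_post.hook_sae_acts_post",
         partial(steer_feature, feature_id=1000, value=5.0))
    ]
)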

Add SAEs (Persistent)

Permanently attach SAEs to the model until explicitly removed.

# Add SAE permanently
model.add_sae(sae)

# Now standard forward passes include the SAE
logits = model(tokens)
logits, cache = model.run_with_cache(tokens)

# Remove all attached SAEs
model.reset_saes()

# Or remove specific SAEs
model.reset_saes(act_names=["blocks.12.hook_resid_post"])

Using Error Terms

Include error terms to preserve original model behavior while accessing SAE features.

sae.use_error_term = True
model.add_sae(sae)

# Output is now: SAE(x) + error_term = x (original activation)
logits = model(tokens)

# You can intervene on the error term
logits = model.run_with_hooks(
    tokens,
    fwd_hooks=[
        ("blocks.12.hook_resid_post.hook_sae_error",
         lambda act, hook: torch.zeros_like(act))
    ]
)
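
As an illustrative sanity check, you can confirm that attaching the SAE with an error term leaves the model's outputs effectively unchanged; the tolerance below is an arbitrary choice to absorb floating point noise.

import torch

# Baseline logits without any SAE attached
model.reset_saes()
baseline_logits = model(tokens)

# Logits with the SAE attached and the error term enabled
sae.use_error_term = True
model.add_sae(sae)
sae_logits = model(tokens)
model.reset_saes()

print(torch.allclose(baseline_logits, sae_logits, atol=1e-4))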

Using SAETransformerBridge (Beta)

For models not natively supported by HookedTransformer (such as Gemma 3), use SAETransformerBridge. This wraps TransformerLens v3's TransformerBridge, which provides hook points for Hugging Face models without the overhead of weight processing.

Beta Feature

SAETransformerBridge requires TransformerLens v3, which is currently in beta. Install it with pip install "transformer-lens>=3.0.0b0" (quoting the specifier so your shell does not interpret >=). The API may change in future versions.

Setup

from sae_lens import SAE
from sae_lens.analysis.sae_transformer_bridge import SAETransformerBridge

# Load model using TransformerBridge
model = SAETransformerBridge.boot_transformers("google/gemma-3-4b-it", device="cuda")

# Load SAE (Gemma Scope 2 SAEs work with Gemma 3 models)
sae = SAE.from_pretrained(
    release="gemma-scope-2-4b-it-res",
    sae_id="layer_17_width_16k_l0_medium",
    device="cuda"
)

Run with SAEs

The API mirrors HookedSAETransformer:

# Add SAE permanently
model.add_sae(sae)
logits = model("Hello, world!")
model.reset_saes()

# Or use context manager for temporary attachment
with model.saes(saes=[sae]):
    logits = model("Hello, world!")

# Run with SAEs (temporary, removed after forward pass)
logits = model.run_with_saes("Hello, world!", saes=[sae])

# Run with cache to access SAE activations
logits, cache = model.run_with_cache_with_saes("Hello, world!", saes=[sae])

Supported Models

SAETransformerBridge supports any model that TransformerBridge supports, including:

  • Gemma 3 (all sizes)
  • Other Hugging Face models not natively supported by HookedTransformer

For models supported by both HookedTransformer and TransformerBridge (like GPT-2, Gemma 2), prefer HookedSAETransformer as it has more mature support.

Using SAEs Without TransformerLens

SAEs from SAELens are standard PyTorch modules and can be used with any model or framework. The key is extracting activations from your model and passing them to the SAE's encode(), decode(), or forward() methods. Note that hook point names differ between TransformerLens and Hugging Face / NNsight.

Pure PyTorch with Hugging Face Transformers

Use standard PyTorch hooks to extract activations from Hugging Face models.

import torch
from transformers import AutoModel, AutoTokenizer
from sae_lens import SAE

# Load Hugging Face model
model = AutoModel.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model.eval()

# Load SAE (trained on Gemma 2 2B residual stream at layer 12)
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cpu"
)

# Storage for activations
activations = {}

def hook_fn(module, input, output):
    # Gemma transformer blocks output a tuple; hidden states are first
    hidden_states = output[0] if isinstance(output, tuple) else output
    activations["layer_12"] = hidden_states.detach()

# Register hook on layer 12
handle = model.layers[12].register_forward_hook(hook_fn)

# Run forward pass
inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    model(**inputs)

# Remove hook
handle.remove()

# Use SAE on extracted activations
layer_12_acts = activations["layer_12"]
feature_acts = sae.encode(layer_12_acts)
reconstructed = sae.decode(feature_acts)

print(f"Input shape: {layer_12_acts.shape}")
print(f"Feature activations shape: {feature_acts.shape}")
print(f"Active features per token: {(feature_acts > 0).sum(dim=-1)}")

Full Example: Analyzing Features with Hugging Face

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model.eval()

# Load SAE
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device=device
)

def get_sae_features(text, layer=12):
    """Extract SAE features for a given text."""
    activations = {}

    def hook_fn(module, input, output):
        hidden_states = output[0] if isinstance(output, tuple) else output
        activations["hidden"] = hidden_states.detach()

    handle = model.model.layers[layer].register_forward_hook(hook_fn)

    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        model(**inputs)

    handle.remove()

    feature_acts = sae.encode(activations["hidden"])
    return feature_acts, inputs["input_ids"]

# Analyze a prompt
text = "The capital of France is"
features, tokens = get_sae_features(text)

# Find top active features at the last token
last_token_features = features[0, -1, :]
top_features = torch.topk(last_token_features, k=10)

print(f"Top 10 active features at last token:")
for idx, (feat_idx, value) in enumerate(zip(top_features.indices, top_features.values)):
    print(f"  Feature {feat_idx.item()}: {value.item():.4f}")

Using SAEs with NNsight

nnsight provides a clean interface for model interventions. SAEs integrate naturally with nnsight's tracing API.

import torch
from nnsight import LanguageModel
from sae_lens import SAE

# Load model with nnsight
model = LanguageModel("google/gemma-2-2b", device_map="auto")

# Load SAE
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)

prompt = "The Eiffel Tower is located in"

# Extract activations and compute SAE features
with model.trace(prompt):
    # Access hidden states at layer 12
    hidden_states = model.model.layers[12].output[0]

    # Save the hidden states
    hidden_states_saved = hidden_states.save()

# Get SAE features outside the trace
with torch.no_grad():
    features = sae.encode(hidden_states_saved)

print(f"Feature activations shape: {features.shape}")
print(f"Average L0: {(features[:, 1:, :] > 0).sum(dim=-1).float().mean().item():.1f}")

Intervening on SAE Features with NNsight

import torch
from nnsight import LanguageModel
from sae_lens import SAE

model = LanguageModel("google/gemma-2-2b", device_map="auto")

sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)

prompt = "The Eiffel Tower is located in"

def ablate_top_features(hidden_states, sae, k=10):
    """Ablate the top-k active features and return modified activations."""
    features = sae.encode(hidden_states)

    # Find and ablate top-k features at each position
    for pos in range(features.shape[1]):
        top_k = torch.topk(features[0, pos], k=k)
        features[0, pos, top_k.indices] = 0.0

    # Reconstruct with ablated features
    return sae.decode(features)

# Run with intervention
with model.trace(prompt) as tracer:
    # Get hidden states
    hidden_states = model.model.layers[12].output[0]

    # Modify using SAE
    modified = ablate_top_features(hidden_states, sae, k=10)

    # Replace the hidden states
    model.model.layers[12].output[0][:] = modified

    # Get output logits
    logits = model.lm_head.output.save()

print(f"Output shape: {logits.shape}")

Key SAE Attributes

After loading an SAE, you can access useful configuration and metadata:

sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)

# Model dimensions
print(f"Input dimension (d_in): {sae.cfg.d_in}")
print(f"SAE dimension (d_sae): {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in}")

# Metadata about the SAE
print(f"Hook name: {sae.cfg.metadata.hook_name}")
print(f"Model name: {sae.cfg.metadata.model_name}")
print(f"Context size: {sae.cfg.metadata.context_size}")

# Hugging Face / NNsight Hook Name (if present)
print(f"Hook name: {sae.cfg.metadata.hf_hook_name}")

# Weights
print(f"Encoder weights shape: {sae.W_enc.shape}")  # (d_in, d_sae)
print(f"Decoder weights shape: {sae.W_dec.shape}")  # (d_sae, d_in)
print(f"Decoder bias shape: {sae.b_dec.shape}")     # (d_in,)