API

ActivationsStore

Class for streaming tokens and generating and storing activations while training SAEs.

Source code in sae_lens/training/activations_store.py
class ActivationsStore:
    """
    Class for streaming tokens and generating and storing activations
    while training SAEs.
    """

    model: HookedRootModule
    dataset: HfDataset
    cached_activations_path: str | None
    cached_activation_dataset: Dataset | None = None
    tokens_column: Literal["tokens", "input_ids", "text", "problem"]
    hook_name: str
    hook_head_index: int | None
    _dataloader: Iterator[Any] | None = None
    exclude_special_tokens: torch.Tensor | None = None
    device: torch.device

    @classmethod
    def from_cache_activations(
        cls,
        model: HookedRootModule,
        cfg: CacheActivationsRunnerConfig,
    ) -> ActivationsStore:
        """
        Public API to create an ActivationsStore from a cached activations dataset.
        """
        return cls(
            cached_activations_path=cfg.new_cached_activations_path,
            dtype=cfg.dtype,
            hook_name=cfg.hook_name,
            context_size=cfg.context_size,
            d_in=cfg.d_in,
            n_batches_in_buffer=cfg.n_batches_in_buffer,
            total_training_tokens=cfg.training_tokens,
            store_batch_size_prompts=cfg.model_batch_size,  # get_buffer
            train_batch_size_tokens=cfg.model_batch_size,  # dataloader
            seqpos_slice=(None,),
            device=torch.device(cfg.device),  # since we're sending these to SAE
            # NOOP
            prepend_bos=False,
            hook_head_index=None,
            dataset=cfg.dataset_path,
            streaming=False,
            model=model,
            normalize_activations="none",
            model_kwargs=None,
            autocast_lm=False,
            dataset_trust_remote_code=None,
            exclude_special_tokens=None,
        )

    @classmethod
    def from_config(
        cls,
        model: HookedRootModule,
        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
        | CacheActivationsRunnerConfig,
        override_dataset: HfDataset | None = None,
    ) -> ActivationsStore:
        if isinstance(cfg, CacheActivationsRunnerConfig):
            return cls.from_cache_activations(model, cfg)

        cached_activations_path = cfg.cached_activations_path
        # set cached_activations_path to None if we're not using cached activations
        if (
            isinstance(cfg, LanguageModelSAERunnerConfig)
            and not cfg.use_cached_activations
        ):
            cached_activations_path = None

        if override_dataset is None and cfg.dataset_path == "":
            raise ValueError(
                "You must either pass in a dataset or specify a dataset_path in your configuration."
            )

        device = torch.device(cfg.act_store_device)
        exclude_special_tokens = cfg.exclude_special_tokens
        if exclude_special_tokens is False:
            exclude_special_tokens = None
        if exclude_special_tokens is True:
            exclude_special_tokens = get_special_token_ids(model.tokenizer)  # type: ignore
        if exclude_special_tokens is not None:
            exclude_special_tokens = torch.tensor(
                exclude_special_tokens, dtype=torch.long, device=device
            )
        return cls(
            model=model,
            dataset=override_dataset or cfg.dataset_path,
            streaming=cfg.streaming,
            hook_name=cfg.hook_name,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            d_in=cfg.d_in
            if isinstance(cfg, CacheActivationsRunnerConfig)
            else cfg.sae.d_in,
            n_batches_in_buffer=cfg.n_batches_in_buffer,
            total_training_tokens=cfg.training_tokens,
            store_batch_size_prompts=cfg.store_batch_size_prompts,
            train_batch_size_tokens=cfg.train_batch_size_tokens,
            prepend_bos=cfg.prepend_bos,
            normalize_activations=cfg.sae.normalize_activations,
            device=device,
            dtype=cfg.dtype,
            cached_activations_path=cached_activations_path,
            model_kwargs=cfg.model_kwargs,
            autocast_lm=cfg.autocast_lm,
            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
            seqpos_slice=cfg.seqpos_slice,
            exclude_special_tokens=exclude_special_tokens,
            disable_concat_sequences=cfg.disable_concat_sequences,
            sequence_separator_token=cfg.sequence_separator_token,
            activations_mixing_fraction=cfg.activations_mixing_fraction,
        )

    @classmethod
    def from_sae(
        cls,
        model: HookedRootModule,
        sae: SAE[T_SAE_CONFIG],
        dataset: HfDataset | str,
        dataset_trust_remote_code: bool = False,
        context_size: int | None = None,
        streaming: bool = True,
        store_batch_size_prompts: int = 8,
        n_batches_in_buffer: int = 8,
        train_batch_size_tokens: int = 4096,
        total_tokens: int = 10**9,
        device: str = "cpu",
        disable_concat_sequences: bool = False,
        sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
    ) -> ActivationsStore:
        if context_size is None:
            context_size = sae.cfg.metadata.context_size
        if sae.cfg.metadata.hook_name is None:
            raise ValueError("hook_name is required")
        if context_size is None:
            raise ValueError("context_size is required")
        if sae.cfg.metadata.prepend_bos is None:
            raise ValueError("prepend_bos is required")
        return cls(
            model=model,
            dataset=dataset,
            d_in=sae.cfg.d_in,
            hook_name=sae.cfg.metadata.hook_name,
            hook_head_index=sae.cfg.metadata.hook_head_index,
            context_size=context_size,
            prepend_bos=sae.cfg.metadata.prepend_bos,
            streaming=streaming,
            store_batch_size_prompts=store_batch_size_prompts,
            train_batch_size_tokens=train_batch_size_tokens,
            n_batches_in_buffer=n_batches_in_buffer,
            total_training_tokens=total_tokens,
            normalize_activations=sae.cfg.normalize_activations,
            dataset_trust_remote_code=dataset_trust_remote_code,
            dtype=sae.cfg.dtype,
            device=torch.device(device),
            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
            disable_concat_sequences=disable_concat_sequences,
            sequence_separator_token=sequence_separator_token,
        )

    def __init__(
        self,
        model: HookedRootModule,
        dataset: HfDataset | str,
        streaming: bool,
        hook_name: str,
        hook_head_index: int | None,
        context_size: int,
        d_in: int,
        n_batches_in_buffer: int,
        total_training_tokens: int,
        store_batch_size_prompts: int,
        train_batch_size_tokens: int,
        prepend_bos: bool,
        normalize_activations: str,
        device: torch.device,
        dtype: str,
        cached_activations_path: str | None = None,
        model_kwargs: dict[str, Any] | None = None,
        autocast_lm: bool = False,
        dataset_trust_remote_code: bool | None = None,
        seqpos_slice: tuple[int | None, ...] = (None,),
        exclude_special_tokens: torch.Tensor | None = None,
        disable_concat_sequences: bool = False,
        sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
        activations_mixing_fraction: float = 0.5,
    ):
        self.model = model
        if model_kwargs is None:
            model_kwargs = {}
        self.model_kwargs = model_kwargs
        self.dataset = (
            load_dataset(
                dataset,
                split="train",
                streaming=streaming,  # type: ignore
                trust_remote_code=dataset_trust_remote_code,  # type: ignore
            )
            if isinstance(dataset, str)
            else dataset
        )

        if isinstance(dataset, (Dataset, DatasetDict)):
            self.dataset = cast(Dataset | DatasetDict, self.dataset)
            n_samples = len(self.dataset)

            if n_samples < total_training_tokens:
                warnings.warn(
                    f"The training dataset contains fewer samples ({n_samples}) than the number of samples required by your training configuration ({total_training_tokens}). This will result in multiple training epochs and some samples being used more than once."
                )

        self.hook_name = hook_name
        self.hook_head_index = hook_head_index
        self.context_size = context_size
        self.d_in = d_in
        self.n_batches_in_buffer = n_batches_in_buffer
        self.total_training_tokens = total_training_tokens
        self.store_batch_size_prompts = store_batch_size_prompts
        self.train_batch_size_tokens = train_batch_size_tokens
        self.prepend_bos = prepend_bos
        self.normalize_activations = normalize_activations
        self.device = torch.device(device)
        self.dtype = str_to_dtype(dtype)
        self.cached_activations_path = cached_activations_path
        self.autocast_lm = autocast_lm
        self.seqpos_slice = seqpos_slice
        self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
        self.exclude_special_tokens = exclude_special_tokens
        self.disable_concat_sequences = disable_concat_sequences
        self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
            sequence_separator_token
        )
        self.activations_mixing_fraction = activations_mixing_fraction

        self.n_dataset_processed = 0

        # Check if dataset is tokenized
        dataset_sample = next(iter(self.dataset))

        # check if it's tokenized
        if "tokens" in dataset_sample:
            self.is_dataset_tokenized = True
            self.tokens_column = "tokens"
        elif "input_ids" in dataset_sample:
            self.is_dataset_tokenized = True
            self.tokens_column = "input_ids"
        elif "text" in dataset_sample:
            self.is_dataset_tokenized = False
            self.tokens_column = "text"
        elif "problem" in dataset_sample:
            self.is_dataset_tokenized = False
            self.tokens_column = "problem"
        else:
            raise ValueError(
                "Dataset must have a 'tokens', 'input_ids', 'text', or 'problem' column."
            )
        if self.is_dataset_tokenized:
            ds_context_size = len(dataset_sample[self.tokens_column])  # type: ignore
            if ds_context_size < self.context_size:
                raise ValueError(
                    f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}.
                    The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}."""
                )
            if self.context_size != ds_context_size:
                warnings.warn(
                    f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""",
                    RuntimeWarning,
                )
            # TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement
            if hasattr(self.dataset, "set_format"):
                self.dataset.set_format(type="torch", columns=[self.tokens_column])  # type: ignore

            if (
                isinstance(dataset, str)
                and hasattr(model, "tokenizer")
                and model.tokenizer is not None
            ):
                validate_pretokenized_dataset_tokenizer(
                    dataset_path=dataset,
                    model_tokenizer=model.tokenizer,  # type: ignore
                )
        else:
            warnings.warn(
                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
            )

        self.iterable_sequences = self._iterate_tokenized_sequences()

        self.cached_activation_dataset = self.load_cached_activation_dataset()

        # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

    def _iterate_raw_dataset(
        self,
    ) -> Generator[torch.Tensor | list[int] | str, None, None]:
        """
        Helper to iterate over the dataset while incrementing n_dataset_processed
        """
        for row in self.dataset:
            # typing datasets is difficult
            yield row[self.tokens_column]  # type: ignore
            self.n_dataset_processed += 1

    def _iterate_raw_dataset_tokens(self) -> Generator[torch.Tensor, None, None]:
        """
        Helper to create an iterator which tokenizes raw text from the dataset on the fly
        """
        for row in self._iterate_raw_dataset():
            tokens = (
                self.model.to_tokens(
                    row,
                    truncate=False,
                    move_to_device=False,  # we move to device below
                    prepend_bos=False,
                )  # type: ignore
                .squeeze(0)
                .to(self.device)
            )
            if len(tokens.shape) != 1:
                raise ValueError(f"tokens.shape should be 1D but was {tokens.shape}")
            yield tokens

    def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]:
        """
        Generator which iterates over full sequence of context_size tokens
        """
        # If the dataset is pretokenized, we will slice the dataset to the length of the context window if needed. Otherwise, no further processing is needed.
        # We assume that all necessary BOS/EOS/SEP tokens have been added during pretokenization.
        if self.is_dataset_tokenized:
            for row in self._iterate_raw_dataset():
                yield torch.tensor(
                    row[
                        : self.context_size
                    ],  # If self.context_size = None, this line simply returns the whole row
                    dtype=torch.long,
                    device=self.device,
                    requires_grad=False,
                )
        # If the dataset isn't tokenized, we'll tokenize, concat, and batch on the fly
        else:
            tokenizer = getattr(self.model, "tokenizer", None)
            bos_token_id = None if tokenizer is None else tokenizer.bos_token_id

            yield from concat_and_batch_sequences(
                tokens_iterator=self._iterate_raw_dataset_tokens(),
                context_size=self.context_size,
                begin_batch_token_id=(bos_token_id if self.prepend_bos else None),
                begin_sequence_token_id=None,
                sequence_separator_token_id=get_special_token_from_cfg(
                    self.sequence_separator_token, tokenizer
                )
                if tokenizer is not None
                else None,
                disable_concat_sequences=self.disable_concat_sequences,
            )

    def load_cached_activation_dataset(self) -> Dataset | None:
        """
        Load the cached activation dataset from disk.

        - If cached_activations_path is set, returns a Huggingface Dataset, else None
        - Checks that the loaded dataset has activations for the hooks in the config and that the shapes match.
        """
        if self.cached_activations_path is None:
            return None

        assert self.cached_activations_path is not None  # keep pyright happy
        # Sanity check: does the cache directory exist?
        if not os.path.exists(self.cached_activations_path):
            raise FileNotFoundError(
                f"Cache directory {self.cached_activations_path} does not exist. "
                "Consider double-checking your dataset, model, and hook names."
            )

        # ---
        # Actual code
        activations_dataset = datasets.load_from_disk(self.cached_activations_path)
        columns = [self.hook_name]
        if "token_ids" in activations_dataset.column_names:
            columns.append("token_ids")
        activations_dataset.set_format(
            type="torch", columns=columns, device=self.device, dtype=self.dtype
        )
        self.current_row_idx = 0  # idx to load next batch from
        # ---

        assert isinstance(activations_dataset, Dataset)

        # multiple hooks in future
        if not set([self.hook_name]).issubset(activations_dataset.column_names):
            raise ValueError(
                f"loaded dataset does not include hook activations, got {activations_dataset.column_names}"
            )

        if activations_dataset.features[self.hook_name].shape != (
            self.context_size,
            self.d_in,
        ):
            raise ValueError(
                f"Given dataset of shape {activations_dataset.features[self.hook_name].shape} does not match context_size ({self.context_size}) and d_in ({self.d_in})"
            )

        return activations_dataset

    def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
        """
        This applies a shuffle to the huggingface dataset that is the input to the activations store. This
        also shuffles the shards of the dataset, which is especially useful for evaluating on different
        sections of very large streaming datasets. Buffer size is only relevant for streaming datasets.
        The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will
        additionally shuffle individual elements within the shard.
        """
        if isinstance(self.dataset, IterableDataset):
            self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
        else:
            self.dataset = self.dataset.shuffle(seed=seed)
        self.iterable_dataset = iter(self.dataset)

    def reset_input_dataset(self):
        """
        Resets the input dataset iterator to the beginning.
        """
        self.iterable_dataset = iter(self.dataset)

    def get_batch_tokens(
        self, batch_size: int | None = None, raise_at_epoch_end: bool = False
    ):
        """
        Streams a batch of tokens from a dataset.

        If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.
        """
        if not batch_size:
            batch_size = self.store_batch_size_prompts
        sequences = []
        # the sequences iterator yields fully formed sequences of context_size tokens, so we just need to stack these into a batch
        for _ in range(batch_size):
            try:
                sequences.append(next(self.iterable_sequences))
            except StopIteration:
                self.iterable_sequences = self._iterate_tokenized_sequences()
                if raise_at_epoch_end:
                    raise StopIteration(
                        f"Ran out of tokens in dataset after {self.n_dataset_processed} samples, beginning the next epoch."
                    )
                sequences.append(next(self.iterable_sequences))

        return torch.stack(sequences, dim=0).to(_get_model_device(self.model))

    @torch.no_grad()
    def get_activations(self, batch_tokens: torch.Tensor):
        """
        Returns activations of shape (batches, context, num_layers, d_in)

        d_in may result from a concatenated head dimension.
        """
        with torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=self.autocast_lm,
        ):
            layerwise_activations_cache = self.model.run_with_cache(
                batch_tokens,
                names_filter=[self.hook_name],
                stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
                    self.hook_name
                ),
                prepend_bos=False,
                **self.model_kwargs,
            )[1]

        layerwise_activations = layerwise_activations_cache[self.hook_name][
            :, slice(*self.seqpos_slice)
        ]

        n_batches, n_context = layerwise_activations.shape[:2]

        stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

        if self.hook_head_index is not None:
            stacked_activations[:, :] = layerwise_activations[
                :, :, self.hook_head_index
            ]
        elif layerwise_activations.ndim > 3:  # if we have a head dimension
            try:
                stacked_activations[:, :] = layerwise_activations.view(
                    n_batches, n_context, -1
                )
            except RuntimeError as e:
                logger.error(f"Error during view operation: {e}")
                logger.info("Attempting to use reshape instead...")
                stacked_activations[:, :] = layerwise_activations.reshape(
                    n_batches, n_context, -1
                )
        else:
            stacked_activations[:, :] = layerwise_activations

        return stacked_activations

    def _load_raw_llm_batch_from_cached(
        self,
        raise_on_epoch_end: bool,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor | None,
    ]:
        """
        Loads a batch of activations from `cached_activation_dataset`

        The dataset has columns for each hook_name,
        each containing activations of shape (context_size, d_in).

        raises StopIteration
        """
        assert self.cached_activation_dataset is not None
        context_size = self.context_size
        batch_size = self.store_batch_size_prompts
        d_in = self.d_in

        # In future, could be a list of multiple hook names
        if self.hook_name not in self.cached_activation_dataset.column_names:
            raise ValueError(
                f"Missing columns in dataset. Expected {self.hook_name}, "
                f"got {self.cached_activation_dataset.column_names}."
            )

        if self.current_row_idx > len(self.cached_activation_dataset) - batch_size:
            self.current_row_idx = 0
            if raise_on_epoch_end:
                raise StopIteration

        ds_slice = self.cached_activation_dataset[
            self.current_row_idx : self.current_row_idx + batch_size
        ]
        # Load activations for each hook.
        # Usually faster to first slice dataset then pick column
        acts_buffer = ds_slice[self.hook_name]
        if acts_buffer.shape != (batch_size, context_size, d_in):
            raise ValueError(
                f"acts_buffer has shape {acts_buffer.shape}, "
                f"but expected ({batch_size}, {context_size}, {d_in})."
            )

        self.current_row_idx += batch_size
        acts_buffer = acts_buffer.reshape(batch_size * context_size, d_in)

        if "token_ids" not in self.cached_activation_dataset.column_names:
            return acts_buffer, None

        token_ids_buffer = ds_slice["token_ids"]
        if token_ids_buffer.shape != (batch_size, context_size):
            raise ValueError(
                f"token_ids_buffer has shape {token_ids_buffer.shape}, "
                f"but expected ({batch_size}, {context_size})."
            )
        token_ids_buffer = token_ids_buffer.reshape(batch_size * context_size)
        return acts_buffer, token_ids_buffer

    @torch.no_grad()
    def get_raw_llm_batch(
        self,
        raise_on_epoch_end: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Loads the next batch of activations from the LLM and returns it.

        If raise_on_epoch_end is True, when the dataset is exhausted it will
        automatically refill the dataset and then raise a StopIteration so that
        the caller has a chance to react.

        Returns:
            Tuple of (activations, token_ids) where activations has shape
            (batch_size * context_size, d_in) and token_ids has shape
            (batch_size * context_size,).
        """
        d_in = self.d_in

        if self.cached_activation_dataset is not None:
            return self._load_raw_llm_batch_from_cached(raise_on_epoch_end)

        # move batch toks to gpu for model
        batch_tokens = self.get_batch_tokens(raise_at_epoch_end=raise_on_epoch_end).to(
            _get_model_device(self.model)
        )
        activations = self.get_activations(batch_tokens).to(self.device)

        # handle seqpos_slice, this is done for activations in get_activations
        batch_tokens = batch_tokens[:, slice(*self.seqpos_slice)]

        # reshape from (batch, context, d_in) to (batch * context, d_in)
        activations = activations.reshape(-1, d_in)
        token_ids = batch_tokens.reshape(-1)

        return activations, token_ids

    def get_filtered_llm_batch(
        self,
        raise_on_epoch_end: bool = False,
    ) -> torch.Tensor:
        """
        Get a batch of LLM activations with special tokens filtered out.
        """
        return _filter_buffer_acts(
            self.get_raw_llm_batch(raise_on_epoch_end=raise_on_epoch_end),
            self.exclude_special_tokens,
        )

    def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
        """
        Iterate over filtered LLM activation batches.
        """
        while True:
            try:
                yield self.get_filtered_llm_batch(raise_on_epoch_end=True)
            except StopIteration:
                warnings.warn(
                    "All samples in the training dataset have been exhausted, beginning new epoch."
                )
                try:
                    yield self.get_filtered_llm_batch()
                except StopIteration:
                    raise ValueError(
                        "Unable to fill buffer after starting new epoch. Dataset may be too small."
                    )

    def get_data_loader(
        self,
    ) -> Iterator[Any]:
        """
        Return an auto-refilling stream of filtered and mixed activations.
        """
        return mixing_buffer(
            buffer_size=self.n_batches_in_buffer * self.training_context_size,
            batch_size=self.train_batch_size_tokens,
            activations_loader=self._iterate_filtered_activations(),
            mix_fraction=self.activations_mixing_fraction,
        )

    def next_batch(self) -> torch.Tensor:
        """Get next batch, updating buffer if needed."""
        return self.__next__()

    # ActivationsStore should be an iterator
    def __next__(self) -> torch.Tensor:
        if self._dataloader is None:
            self._dataloader = self.get_data_loader()
        return next(self._dataloader)

    def __iter__(self) -> Iterator[torch.Tensor]:
        return self

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

    def save(self, file_path: str):
        """save the state dict to a file in safetensors format"""
        save_file(self.state_dict(), file_path)

    def save_to_checkpoint(self, checkpoint_path: str | Path):
        """Save the state dict to a checkpoint path"""
        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))

    def load_from_checkpoint(self, checkpoint_path: str | Path):
        """Load the state dict from a checkpoint path"""
        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))

    def load(self, file_path: str):
        """Load the state dict from a file in safetensors format"""

        state_dict = load_file(file_path)

        if "n_dataset_processed" in state_dict:
            target_n_dataset_processed = state_dict["n_dataset_processed"].item()

            # Only fast-forward if needed

            if target_n_dataset_processed > self.n_dataset_processed:
                logger.info(
                    "Fast-forwarding through dataset samples to match checkpoint position"
                )
                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed

                pbar = tqdm(
                    total=samples_to_skip,
                    desc="Fast-forwarding through dataset",
                    leave=False,
                )
                while target_n_dataset_processed > self.n_dataset_processed:
                    start = self.n_dataset_processed
                    try:
                        # Just consume and ignore the values to fast-forward
                        next(self.iterable_sequences)
                    except StopIteration:
                        logger.warning(
                            "Dataset exhausted during fast-forward. Resetting dataset."
                        )
                        self.iterable_sequences = self._iterate_tokenized_sequences()
                    pbar.update(self.n_dataset_processed - start)
                pbar.close()
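
The listing above covers construction and iteration end to end. Below is a minimal usage sketch: the model name, SAE release and id, and dataset path are illustrative assumptions, and on older SAELens versions SAE.from_pretrained may return a tuple rather than the SAE alone.

from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.training.activations_store import ActivationsStore

# Illustrative identifiers only; substitute your own model, SAE, and dataset.
model = HookedTransformer.from_pretrained("gpt2")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    dataset="NeelNanda/pile-10k",
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=8,
    device="cpu",
)

# The store is an iterator over training batches of activations.
batch = store.next_batch()  # typically (train_batch_size_tokens, d_in)
print(batch.shape)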

from_cache_activations(model, cfg) classmethod

Public API to create an ActivationsStore from a cached activations dataset.

Source code in sae_lens/training/activations_store.py
@classmethod
def from_cache_activations(
    cls,
    model: HookedRootModule,
    cfg: CacheActivationsRunnerConfig,
) -> ActivationsStore:
    """
    Public API to create an ActivationsStore from a cached activations dataset.
    """
    return cls(
        cached_activations_path=cfg.new_cached_activations_path,
        dtype=cfg.dtype,
        hook_name=cfg.hook_name,
        context_size=cfg.context_size,
        d_in=cfg.d_in,
        n_batches_in_buffer=cfg.n_batches_in_buffer,
        total_training_tokens=cfg.training_tokens,
        store_batch_size_prompts=cfg.model_batch_size,  # get_buffer
        train_batch_size_tokens=cfg.model_batch_size,  # dataloader
        seqpos_slice=(None,),
        device=torch.device(cfg.device),  # since we're sending these to SAE
        # NOOP
        prepend_bos=False,
        hook_head_index=None,
        dataset=cfg.dataset_path,
        streaming=False,
        model=model,
        normalize_activations="none",
        model_kwargs=None,
        autocast_lm=False,
        dataset_trust_remote_code=None,
        exclude_special_tokens=None,
    )

get_activations(batch_tokens)

Returns activations of shape (batches, context, num_layers, d_in)

d_in may result from a concatenated head dimension.

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_activations(self, batch_tokens: torch.Tensor):
    """
    Returns activations of shape (batches, context, num_layers, d_in)

    d_in may result from a concatenated head dimension.
    """
    with torch.autocast(
        device_type="cuda",
        dtype=torch.bfloat16,
        enabled=self.autocast_lm,
    ):
        layerwise_activations_cache = self.model.run_with_cache(
            batch_tokens,
            names_filter=[self.hook_name],
            stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
                self.hook_name
            ),
            prepend_bos=False,
            **self.model_kwargs,
        )[1]

    layerwise_activations = layerwise_activations_cache[self.hook_name][
        :, slice(*self.seqpos_slice)
    ]

    n_batches, n_context = layerwise_activations.shape[:2]

    stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

    if self.hook_head_index is not None:
        stacked_activations[:, :] = layerwise_activations[
            :, :, self.hook_head_index
        ]
    elif layerwise_activations.ndim > 3:  # if we have a head dimension
        try:
            stacked_activations[:, :] = layerwise_activations.view(
                n_batches, n_context, -1
            )
        except RuntimeError as e:
            logger.error(f"Error during view operation: {e}")
            logger.info("Attempting to use reshape instead...")
            stacked_activations[:, :] = layerwise_activations.reshape(
                n_batches, n_context, -1
            )
    else:
        stacked_activations[:, :] = layerwise_activations

    return stacked_activations
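
As a hedged shape check (assuming a `store` built as in the sketch above): pull one batch of token ids and run them through the model to materialize activations.

# Assumes `store` is an ActivationsStore, e.g. built via from_sae as sketched earlier.
batch_tokens = store.get_batch_tokens()      # (store_batch_size_prompts, context_size)
acts = store.get_activations(batch_tokens)   # (batch, sliced context, d_in)
assert acts.shape == (
    batch_tokens.shape[0],
    store.training_context_size,
    store.d_in,
)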

get_batch_tokens(batch_size=None, raise_at_epoch_end=False)

Streams a batch of tokens from a dataset.

If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.

Source code in sae_lens/training/activations_store.py
def get_batch_tokens(
    self, batch_size: int | None = None, raise_at_epoch_end: bool = False
):
    """
    Streams a batch of tokens from a dataset.

    If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.
    """
    if not batch_size:
        batch_size = self.store_batch_size_prompts
    sequences = []
    # the sequences iterator yields fully formed sequences of context_size tokens, so we just need to stack these into a batch
    for _ in range(batch_size):
        try:
            sequences.append(next(self.iterable_sequences))
        except StopIteration:
            self.iterable_sequences = self._iterate_tokenized_sequences()
            if raise_at_epoch_end:
                raise StopIteration(
                    f"Ran out of tokens in dataset after {self.n_dataset_processed} samples, beginning the next epoch."
                )
            sequences.append(next(self.iterable_sequences))

    return torch.stack(sequences, dim=0).to(_get_model_device(self.model))
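
A hedged sketch of the explicit epoch handling, assuming the same `store` as in the earlier sketch: with raise_at_epoch_end=True the store resets its sequence iterator and raises, letting the caller decide what to do at the epoch boundary.

# Assumes `store` is an ActivationsStore built as sketched earlier.
try:
    tokens = store.get_batch_tokens(batch_size=4, raise_at_epoch_end=True)
except StopIteration:
    # The dataset was exhausted; the store already reset its sequence iterator,
    # so the next call starts a fresh epoch.
    tokens = store.get_batch_tokens(batch_size=4)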

get_data_loader()

Return an auto-refilling stream of filtered and mixed activations.

Source code in sae_lens/training/activations_store.py
def get_data_loader(
    self,
) -> Iterator[Any]:
    """
    Return an auto-refilling stream of filtered and mixed activations.
    """
    return mixing_buffer(
        buffer_size=self.n_batches_in_buffer * self.training_context_size,
        batch_size=self.train_batch_size_tokens,
        activations_loader=self._iterate_filtered_activations(),
        mix_fraction=self.activations_mixing_fraction,
    )
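
The returned loader is a plain Python iterator over mixed, filtered activation batches; a hedged sketch, again assuming the same `store`:

# Assumes `store` is an ActivationsStore built as sketched earlier.
loader = store.get_data_loader()
for step, acts in enumerate(loader):
    # acts: a (train_batch_size_tokens, d_in) batch drawn from the mixing buffer
    if step >= 9:
        break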

get_filtered_llm_batch(raise_on_epoch_end=False)

Get a batch of LLM activations with special tokens filtered out.

Source code in sae_lens/training/activations_store.py
def get_filtered_llm_batch(
    self,
    raise_on_epoch_end: bool = False,
) -> torch.Tensor:
    """
    Get a batch of LLM activations with special tokens filtered out.
    """
    return _filter_buffer_acts(
        self.get_raw_llm_batch(raise_on_epoch_end=raise_on_epoch_end),
        self.exclude_special_tokens,
    )

get_raw_llm_batch(raise_on_epoch_end=False)

Loads the next batch of activations from the LLM and returns it.

If raise_on_epoch_end is True, when the dataset is exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.

Returns:

  • tuple[Tensor, Tensor | None]: (activations, token_ids), where activations has shape (batch_size * context_size, d_in) and token_ids has shape (batch_size * context_size,).

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_raw_llm_batch(
    self,
    raise_on_epoch_end: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Loads the next batch of activations from the LLM and returns it.

    If raise_on_epoch_end is True, when the dataset is exhausted it will
    automatically refill the dataset and then raise a StopIteration so that
    the caller has a chance to react.

    Returns:
        Tuple of (activations, token_ids) where activations has shape
        (batch_size * context_size, d_in) and token_ids has shape
        (batch_size * context_size,).
    """
    d_in = self.d_in

    if self.cached_activation_dataset is not None:
        return self._load_raw_llm_batch_from_cached(raise_on_epoch_end)

    # move batch toks to gpu for model
    batch_tokens = self.get_batch_tokens(raise_at_epoch_end=raise_on_epoch_end).to(
        _get_model_device(self.model)
    )
    activations = self.get_activations(batch_tokens).to(self.device)

    # handle seqpos_slice, this is done for activations in get_activations
    batch_tokens = batch_tokens[:, slice(*self.seqpos_slice)]

    # reshape from (batch, context, d_in) to (batch * context, d_in)
    activations = activations.reshape(-1, d_in)
    token_ids = batch_tokens.reshape(-1)

    return activations, token_ids
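
A hedged sketch of the raw (unfiltered) path, assuming the same `store`; the manual special-token masking shown here only illustrates what the filtered helpers do for you, it is not the library's own implementation.

import torch

# Assumes `store` is an ActivationsStore built as sketched earlier.
acts, token_ids = store.get_raw_llm_batch()
# acts: (store_batch_size_prompts * sliced context, d_in)
# token_ids: matching leading dimension, or None for caches without a token_ids column
if token_ids is not None and store.exclude_special_tokens is not None:
    keep = ~torch.isin(token_ids, store.exclude_special_tokens)
    acts = acts[keep]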

load(file_path)

Load the state dict from a file in safetensors format

Source code in sae_lens/training/activations_store.py
def load(self, file_path: str):
    """Load the state dict from a file in safetensors format"""

    state_dict = load_file(file_path)

    if "n_dataset_processed" in state_dict:
        target_n_dataset_processed = state_dict["n_dataset_processed"].item()

        # Only fast-forward if needed

        if target_n_dataset_processed > self.n_dataset_processed:
            logger.info(
                "Fast-forwarding through dataset samples to match checkpoint position"
            )
            samples_to_skip = target_n_dataset_processed - self.n_dataset_processed

            pbar = tqdm(
                total=samples_to_skip,
                desc="Fast-forwarding through dataset",
                leave=False,
            )
            while target_n_dataset_processed > self.n_dataset_processed:
                start = self.n_dataset_processed
                try:
                    # Just consume and ignore the values to fast-forward
                    next(self.iterable_sequences)
                except StopIteration:
                    logger.warning(
                        "Dataset exhausted during fast-forward. Resetting dataset."
                    )
                    self.iterable_sequences = self._iterate_tokenized_sequences()
                pbar.update(self.n_dataset_processed - start)
            pbar.close()
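
A hedged sketch of a checkpoint round trip, assuming the same `store`; only n_dataset_processed is persisted, and load() fast-forwards the sequence iterator to match it.

from pathlib import Path

# Assumes `store` is an ActivationsStore built as sketched earlier.
ckpt_dir = Path("./checkpoints/step_1000")   # illustrative path
ckpt_dir.mkdir(parents=True, exist_ok=True)

store.save_to_checkpoint(ckpt_dir)

# Later, after rebuilding an identical store over the same dataset:
store.load_from_checkpoint(ckpt_dir)         # skips forward through the dataset if needed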

load_cached_activation_dataset()

Load the cached activation dataset from disk.

  • If cached_activations_path is set, returns a Huggingface Dataset, else None
  • Checks that the loaded dataset has activations for the hooks in the config and that the shapes match.

Source code in sae_lens/training/activations_store.py
def load_cached_activation_dataset(self) -> Dataset | None:
    """
    Load the cached activation dataset from disk.

    - If cached_activations_path is set, returns a Huggingface Dataset, else None
    - Checks that the loaded dataset has activations for the hooks in the config and that the shapes match.
    """
    if self.cached_activations_path is None:
        return None

    assert self.cached_activations_path is not None  # keep pyright happy
    # Sanity check: does the cache directory exist?
    if not os.path.exists(self.cached_activations_path):
        raise FileNotFoundError(
            f"Cache directory {self.cached_activations_path} does not exist. "
            "Consider double-checking your dataset, model, and hook names."
        )

    # ---
    # Actual code
    activations_dataset = datasets.load_from_disk(self.cached_activations_path)
    columns = [self.hook_name]
    if "token_ids" in activations_dataset.column_names:
        columns.append("token_ids")
    activations_dataset.set_format(
        type="torch", columns=columns, device=self.device, dtype=self.dtype
    )
    self.current_row_idx = 0  # idx to load next batch from
    # ---

    assert isinstance(activations_dataset, Dataset)

    # multiple hooks in future
    if not set([self.hook_name]).issubset(activations_dataset.column_names):
        raise ValueError(
            f"loaded dataset does not include hook activations, got {activations_dataset.column_names}"
        )

    if activations_dataset.features[self.hook_name].shape != (
        self.context_size,
        self.d_in,
    ):
        raise ValueError(
            f"Given dataset of shape {activations_dataset.features[self.hook_name].shape} does not match context_size ({self.context_size}) and d_in ({self.d_in})"
        )

    return activations_dataset
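
When cached_activations_path is set, the validated dataset is also stored on the instance at construction time; a hedged inspection sketch, assuming a `store` built with a cache path:

# Assumes `store` was constructed with cached_activations_path pointing at a cache.
ds = store.cached_activation_dataset            # populated in __init__ via this method
if ds is not None:
    print(ds.column_names)                      # includes store.hook_name, possibly "token_ids"
    print(ds.features[store.hook_name].shape)   # (context_size, d_in)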

load_from_checkpoint(checkpoint_path)

Load the state dict from a checkpoint path

Source code in sae_lens/training/activations_store.py
def load_from_checkpoint(self, checkpoint_path: str | Path):
    """Load the state dict from a checkpoint path"""
    self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))

next_batch()

Get next batch, updating buffer if needed.

Source code in sae_lens/training/activations_store.py
def next_batch(self) -> torch.Tensor:
    """Get next batch, updating buffer if needed."""
    return self.__next__()

reset_input_dataset()

Resets the input dataset iterator to the beginning.

Source code in sae_lens/training/activations_store.py
def reset_input_dataset(self):
    """
    Resets the input dataset iterator to the beginning.
    """
    self.iterable_dataset = iter(self.dataset)

save(file_path)

save the state dict to a file in safetensors format

Source code in sae_lens/training/activations_store.py
def save(self, file_path: str):
    """save the state dict to a file in safetensors format"""
    save_file(self.state_dict(), file_path)

save_to_checkpoint(checkpoint_path)

Save the state dict to a checkpoint path

Source code in sae_lens/training/activations_store.py
def save_to_checkpoint(self, checkpoint_path: str | Path):
    """Save the state dict to a checkpoint path"""
    self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))

shuffle_input_dataset(seed, buffer_size=1)

This applies a shuffle to the huggingface dataset that is the input to the activations store. This also shuffles the shards of the dataset, which is especially useful for evaluating on different sections of very large streaming datasets. Buffer size is only relevant for streaming datasets. The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will additionally shuffle individual elements within the shard.

Source code in sae_lens/training/activations_store.py
def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
    """
    This applies a shuffle to the huggingface dataset that is the input to the activations store. This
    also shuffles the shards of the dataset, which is especially useful for evaluating on different
    sections of very large streaming datasets. Buffer size is only relevant for streaming datasets.
    The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will
    additionally shuffle individual elements within the shard.
    """
    if isinstance(self.dataset, IterableDataset):
        self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
    else:
        self.dataset = self.dataset.shuffle(seed=seed)
    self.iterable_dataset = iter(self.dataset)
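
A hedged sketch, assuming the same `store` over a streaming dataset; buffer_size only matters for streaming inputs, and shuffling is typically done before drawing any batches so the new order is what the token iterator sees.

# Assumes `store` is an ActivationsStore over a streaming (IterableDataset) input,
# shuffled before any batches have been drawn.
store.shuffle_input_dataset(seed=42, buffer_size=1000)
tokens = store.get_batch_tokens()   # subsequent batches follow the shuffled order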

BatchTopKTrainingSAE

Bases: TopKTrainingSAE

Global Batch TopK Training SAE

This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.

BatchTopK SAEs are saved as JumpReLU SAEs after training.

Source code in sae_lens/saes/batchtopk_sae.py
class BatchTopKTrainingSAE(TopKTrainingSAE):
    """
    Global Batch TopK Training SAE

    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.

    BatchTopK SAEs are saved as JumpReLU SAEs after training.
    """

    topk_threshold: torch.Tensor
    cfg: BatchTopKTrainingSAEConfig  # type: ignore[assignment]

    def __init__(self, cfg: BatchTopKTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        self.register_buffer(
            "topk_threshold",
            # use double precision as otherwise we can run into numerical issues
            torch.tensor(0.0, dtype=torch.double, device=self.W_dec.device),
        )

    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return BatchTopK(self.cfg.k)

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        output = super().training_forward_pass(step_input)
        self.update_topk_threshold(output.feature_acts)
        output.metrics["topk_threshold"] = self.topk_threshold
        return output

    @torch.no_grad()
    def update_topk_threshold(self, acts_topk: torch.Tensor) -> None:
        positive_mask = acts_topk > 0
        lr = self.cfg.topk_threshold_lr
        # autocast can cause numerical issues with the threshold update
        with torch.autocast(self.topk_threshold.device.type, enabled=False):
            if positive_mask.any():
                min_positive = (
                    acts_topk[positive_mask].min().to(self.topk_threshold.dtype)
                )
                self.topk_threshold = (1 - lr) * self.topk_threshold + lr * min_positive

    @override
    def process_state_dict_for_saving_inference(
        self, state_dict: dict[str, Any]
    ) -> None:
        super().process_state_dict_for_saving_inference(state_dict)
        # turn the topk threshold into jumprelu threshold
        topk_threshold = state_dict.pop("topk_threshold").item()
        state_dict["threshold"] = torch.ones_like(self.b_enc) * topk_threshold

BatchTopKTrainingSAEConfig dataclass

Bases: TopKTrainingSAEConfig

Configuration class for training a BatchTopKTrainingSAE.

BatchTopK SAEs maintain k active features on average across the entire batch, rather than enforcing k features per sample like standard TopK SAEs. During training, the SAE learns a global threshold that is updated based on the minimum positive activation value. After training, BatchTopK SAEs are saved as JumpReLU SAEs.

Parameters:

  • k (float, default 100): Average number of features to keep active across the batch. Unlike standard TopK SAEs where k is an integer per sample, this is a float representing the average number of active features across all samples in the batch.
  • topk_threshold_lr (float, default 0.01): Learning rate for updating the global topk threshold. The threshold is updated using an exponential moving average of the minimum positive activation value.
  • aux_loss_coefficient (float, default 1.0): Coefficient for the auxiliary loss that encourages dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
  • rescale_acts_by_decoder_norm (bool, default True): Treat the decoder as if it was already normalized. Inherited from TopKTrainingSAEConfig.
  • decoder_init_norm (float | None, default 0.1): Norm to initialize decoder weights to. Inherited from TrainingSAEConfig.
  • d_in (int, required): Input dimension (dimensionality of the activations being encoded). Inherited from SAEConfig.
  • d_sae (int, required): SAE latent dimension (number of features in the SAE). Inherited from SAEConfig.
  • dtype (str, default "float32"): Data type for the SAE parameters. Inherited from SAEConfig.
  • device (str, default "cpu"): Device to place the SAE on. Inherited from SAEConfig.

Source code in sae_lens/saes/batchtopk_sae.py
@dataclass
class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
    """
    Configuration class for training a BatchTopKTrainingSAE.

    BatchTopK SAEs maintain k active features on average across the entire batch,
    rather than enforcing k features per sample like standard TopK SAEs. During training,
    the SAE learns a global threshold that is updated based on the minimum positive
    activation value. After training, BatchTopK SAEs are saved as JumpReLU SAEs.

    Args:
        k (float): Average number of features to keep active across the batch. Unlike
            standard TopK SAEs where k is an integer per sample, this is a float
            representing the average number of active features across all samples in
            the batch. Defaults to 100.
        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
            The threshold is updated using an exponential moving average of the minimum
            positive activation value. Defaults to 0.01.
        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
            Defaults to 1.0.
        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
            Inherited from TopKTrainingSAEConfig. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            Inherited from TrainingSAEConfig. Defaults to 0.1.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
    """

    k: float = 100  # type: ignore[assignment]
    topk_threshold_lr: float = 0.01

    @override
    @classmethod
    def architecture(cls) -> str:
        return "batchtopk"

    @override
    def get_inference_config_class(self) -> type[SAEConfig]:
        return JumpReLUSAEConfig

CacheActivationsRunner

Source code in sae_lens/cache_activations_runner.py
class CacheActivationsRunner:
    def __init__(
        self,
        cfg: CacheActivationsRunnerConfig,
        override_dataset: Dataset | None = None,
    ):
        self.cfg = cfg
        self.model: HookedRootModule = load_model(
            model_class_name=self.cfg.model_class_name,
            model_name=self.cfg.model_name,
            device=self.cfg.device,
            model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
        )
        if self.cfg.compile_llm:
            self.model = torch.compile(self.model, mode=self.cfg.llm_compilation_mode)  # type: ignore
        self.activations_store = _mk_activations_store(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )
        self.context_size = self._get_sliced_context_size(
            self.cfg.context_size, self.cfg.seqpos_slice
        )
        features_dict: dict[str, Array2D | Sequence] = {
            hook_name: Array2D(
                shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
            )
            for hook_name in [self.cfg.hook_name]
        }
        features_dict["token_ids"] = Sequence(  # type: ignore
            Value(dtype="int32"), length=self.context_size
        )
        self.features = Features(features_dict)

    def __str__(self):
        """
        Print the number of tokens to be cached.
        Print the number of buffers, and the number of tokens per buffer.
        Print the disk space required to store the activations.

        """

        bytes_per_token = (
            self.cfg.d_in * self.cfg.dtype.itemsize
            if isinstance(self.cfg.dtype, torch.dtype)
            else str_to_dtype(self.cfg.dtype).itemsize
        )
        total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
        total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

        return (
            f"Activation Cache Runner:\n"
            f"Total training tokens: {total_training_tokens}\n"
            f"Number of buffers: {self.cfg.n_buffers}\n"
            f"Tokens per buffer: {self.cfg.n_tokens_in_buffer}\n"
            f"Disk space required: {total_disk_space_gb:.2f} GB\n"
            f"Configuration:\n"
            f"{self.cfg}"
        )

    @staticmethod
    def _consolidate_shards(
        source_dir: Path, output_dir: Path, copy_files: bool = True
    ) -> Dataset:
        """Consolidate sharded datasets into a single directory without rewriting data.

        Each of the shards must be of the same format, aka the full dataset must be able to
        be recreated like so:

        ```
        ds = concatenate_datasets(
            [Dataset.load_from_disk(str(shard_dir)) for shard_dir in sorted(source_dir.iterdir())]
        )

        ```

        Sharded dataset format:
        ```
        source_dir/
            shard_00000/
                dataset_info.json
                state.json
                data-00000-of-00002.arrow
                data-00001-of-00002.arrow
            shard_00001/
                dataset_info.json
                state.json
                data-00000-of-00001.arrow
        ```

        And flattens them into the format:

        ```
        output_dir/
            dataset_info.json
            state.json
            data-00000-of-00003.arrow
            data-00001-of-00003.arrow
            data-00002-of-00003.arrow
        ```

        allowing the dataset to be loaded like so:

        ```
        ds = datasets.load_from_disk(output_dir)
        ```

        Args:
            source_dir: Directory containing the sharded datasets
            output_dir: Directory to consolidate the shards into
            copy_files: If True, copy files; if False, move them and delete source_dir
        """
        first_shard_dir_name = "shard_00000"  # shard_{i:05d}

        if not source_dir.exists() or not source_dir.is_dir():
            raise NotADirectoryError(
                f"source_dir is not an existing directory: {source_dir}"
            )

        if not output_dir.exists() or not output_dir.is_dir():
            raise NotADirectoryError(
                f"output_dir is not an existing directory: {output_dir}"
            )

        other_items = [p for p in output_dir.iterdir() if p.name != ".tmp_shards"]
        if other_items:
            raise FileExistsError(
                f"output_dir must be empty (besides .tmp_shards). Found: {other_items}"
            )

        if not (source_dir / first_shard_dir_name).exists():
            raise Exception(f"No shards in {source_dir} exist!")

        transfer_fn = shutil.copy2 if copy_files else shutil.move

        # Move dataset_info.json from any shard (all the same)
        transfer_fn(
            source_dir / first_shard_dir_name / "dataset_info.json",
            output_dir / "dataset_info.json",
        )

        arrow_files = []
        file_count = 0

        for shard_dir in sorted(source_dir.iterdir()):
            if not shard_dir.name.startswith("shard_"):
                continue

            # state.json contains arrow filenames
            state = json.loads((shard_dir / "state.json").read_text())

            for data_file in state["_data_files"]:
                src = shard_dir / data_file["filename"]
                new_name = f"data-{file_count:05d}-of-{len(list(source_dir.iterdir())):05d}.arrow"
                dst = output_dir / new_name
                transfer_fn(src, dst)
                arrow_files.append({"filename": new_name})
                file_count += 1

        new_state = {
            "_data_files": arrow_files,
            "_fingerprint": None,  # temporary
            "_format_columns": None,
            "_format_kwargs": {},
            "_format_type": None,
            "_output_all_columns": False,
            "_split": None,
        }

        # fingerprint is generated from dataset.__getstate__ (not including _fingerprint)
        with open(output_dir / "state.json", "w") as f:
            json.dump(new_state, f, indent=2)

        ds = Dataset.load_from_disk(str(output_dir))
        fingerprint = generate_fingerprint(ds)
        del ds

        with open(output_dir / "state.json", "r+") as f:
            state = json.loads(f.read())
            state["_fingerprint"] = fingerprint
            f.seek(0)
            json.dump(state, f, indent=2)
            f.truncate()

        if not copy_files:  # cleanup source dir
            shutil.rmtree(source_dir)

        return Dataset.load_from_disk(output_dir)

    @torch.no_grad()
    def run(self) -> Dataset:
        activation_save_path = self.cfg.new_cached_activations_path
        assert activation_save_path is not None

        ### Paths setup
        final_cached_activation_path = Path(activation_save_path)
        final_cached_activation_path.mkdir(exist_ok=True, parents=True)
        if any(final_cached_activation_path.iterdir()):
            raise Exception(
                f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
            )

        tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
        tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)

        ### Create temporary sharded datasets

        logger.info(f"Started caching activations for {self.cfg.dataset_path}")

        for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
            try:
                # Accumulate n_batches_in_buffer batches into one shard
                buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
                for _ in range(self.cfg.n_batches_in_buffer):
                    buffers.append(self.activations_store.get_raw_llm_batch())
                # Concatenate all batches
                acts = torch.cat([b[0] for b in buffers], dim=0)
                token_ids: torch.Tensor | None = None
                if buffers[0][1] is not None:
                    # All batches have token_ids if the first one does
                    token_ids = torch.cat([b[1] for b in buffers], dim=0)  # type: ignore[arg-type]
                shard = self._create_shard((acts, token_ids))
                shard.save_to_disk(
                    f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                )
                del buffers, acts, token_ids, shard
            except StopIteration:
                logger.warning(
                    f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
                )
                break

        ### Concatenate shards and push to Huggingface Hub

        dataset = self._consolidate_shards(
            tmp_cached_activation_path, final_cached_activation_path, copy_files=False
        )

        if self.cfg.shuffle:
            logger.info("Shuffling...")
            dataset = dataset.shuffle(seed=self.cfg.seed)

        if self.cfg.hf_repo_id:
            logger.info("Pushing to Huggingface Hub...")
            dataset.push_to_hub(
                repo_id=self.cfg.hf_repo_id,
                num_shards=self.cfg.hf_num_shards,
                private=self.cfg.hf_is_private_repo,
                revision=self.cfg.hf_revision,
            )

            meta_io = io.BytesIO()
            meta_contents = json.dumps(
                asdict(self.cfg), indent=2, ensure_ascii=False
            ).encode("utf-8")
            meta_io.write(meta_contents)
            meta_io.seek(0)

            api = HfApi()
            api.upload_file(
                path_or_fileobj=meta_io,
                path_in_repo="cache_activations_runner_cfg.json",
                repo_id=self.cfg.hf_repo_id,
                repo_type="dataset",
                commit_message="Add cache_activations_runner metadata",
            )

        return dataset

    def _create_shard(
        self,
        buffer: tuple[
            torch.Tensor,  # shape: (bs context_size) d_in
            torch.Tensor | None,  # shape: (bs context_size) or None
        ],
    ) -> Dataset:
        hook_names = [self.cfg.hook_name]
        acts, token_ids = buffer
        acts = einops.rearrange(
            acts,
            "(bs context_size) d_in -> bs context_size d_in",
            bs=self.cfg.n_seq_in_buffer,
            context_size=self.context_size,
            d_in=self.cfg.d_in,
        )
        shard_dict: dict[str, object] = {
            hook_name: act_batch
            for hook_name, act_batch in zip(hook_names, [acts], strict=True)
        }

        if token_ids is not None:
            token_ids = einops.rearrange(
                token_ids,
                "(bs context_size) -> bs context_size",
                bs=self.cfg.n_seq_in_buffer,
                context_size=self.context_size,
            )
            shard_dict["token_ids"] = token_ids.to(torch.int32)
        return Dataset.from_dict(
            shard_dict,
            features=self.features,
        )

    @staticmethod
    def _get_sliced_context_size(
        context_size: int, seqpos_slice: tuple[int | None, ...] | None
    ) -> int:
        if seqpos_slice is not None:
            context_size = len(range(context_size)[slice(*seqpos_slice)])
        return context_size

__str__()

Summarize the number of tokens to be cached, the number of buffers and the number of tokens per buffer, and the disk space required to store the activations.

Source code in sae_lens/cache_activations_runner.py
def __str__(self):
    """
    Print the number of tokens to be cached.
    Print the number of buffers, and the number of tokens per buffer.
    Print the disk space required to store the activations.

    """

    bytes_per_token = (
        self.cfg.d_in * self.cfg.dtype.itemsize
        if isinstance(self.cfg.dtype, torch.dtype)
        else str_to_dtype(self.cfg.dtype).itemsize
    )
    total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
    total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

    return (
        f"Activation Cache Runner:\n"
        f"Total training tokens: {total_training_tokens}\n"
        f"Number of buffers: {self.cfg.n_buffers}\n"
        f"Tokens per buffer: {self.cfg.n_tokens_in_buffer}\n"
        f"Disk space required: {total_disk_space_gb:.2f} GB\n"
        f"Configuration:\n"
        f"{self.cfg}"
    )
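
The summary above is simple arithmetic over the config. As a rough standalone sketch of the same calculation (the numbers are made up purely for illustration, and it assumes a PyTorch version where torch.dtype exposes .itemsize, as the source above does):

```python
import torch

# Hypothetical settings, chosen only to illustrate the arithmetic in __str__.
d_in = 768                   # activation width at the hook
dtype = torch.float32        # activation dtype
context_size = 128           # tokens per sequence (after any seqpos slicing)
n_seq_in_dataset = 100_000   # sequences to be cached

bytes_per_token = d_in * dtype.itemsize                  # 768 * 4 = 3072 bytes
total_training_tokens = n_seq_in_dataset * context_size  # 12800000 tokens
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

print(f"Total training tokens: {total_training_tokens}")
print(f"Disk space required: {total_disk_space_gb:.2f} GB")  # ~39.32 GB
```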

CacheActivationsRunnerConfig dataclass

Configuration for creating and caching activations of an LLM.

Parameters:

dataset_path (str, required): The path to the Hugging Face dataset. This may be tokenized or not.
model_name (str, required): The name of the model to use.
model_batch_size (int, required): How many prompts are in the batch of the language model when generating activations.
hook_name (str, required): The name of the hook to use.
d_in (int, required): Dimension of the model.
training_tokens (int, required): Total number of tokens to process.
context_size (int, default: -1): Context size to process. Can be left as -1 if the dataset is tokenized.
model_class_name (str, default: 'HookedTransformer'): The name of the class of the model to use. This should be either HookedTransformer or HookedMamba.
new_cached_activations_path (str, default: None): The path to save the activations.
shuffle (bool, default: True): Whether to shuffle the dataset.
seed (int, default: 42): The seed to use for shuffling.
dtype (str, default: 'float32'): Datatype of activations to be stored.
device (str, default: 'cuda' if torch.cuda.is_available() else 'cpu'): The device for the model.
buffer_size_gb (float, default: 2.0): The buffer size in GB. This should be < 2GB.
hf_repo_id (str, default: None): The Hugging Face repository id to save the activations to.
hf_num_shards (int, default: None): The number of shards to save the activations to.
hf_revision (str, default: 'main'): The revision to save the activations to.
hf_is_private_repo (bool, default: False): Whether the Hugging Face repository is private.
model_kwargs (dict, default: dict()): Keyword arguments for model.run_with_cache.
model_from_pretrained_kwargs (dict, default: dict()): Keyword arguments for the from_pretrained method of the model.
compile_llm (bool, default: False): Whether to compile the LLM.
llm_compilation_mode (str, default: None): The torch.compile mode to use.
prepend_bos (bool, default: True): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
seqpos_slice (tuple, default: (None,)): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
streaming (bool, default: True): Whether to stream the dataset. Streaming large datasets is usually practical.
autocast_lm (bool, default: False): Whether to use autocast during activation fetching.
dataset_trust_remote_code (bool, default: None): Whether to trust remote code when loading datasets from Huggingface.
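
A minimal usage sketch, assuming the runner class CacheActivationsRunner is importable from sae_lens alongside this config and that the dataset is already tokenized so context_size can stay at -1. The dataset path, hook name, and sizes below are placeholders:

```python
from sae_lens import CacheActivationsRunner, CacheActivationsRunnerConfig

cfg = CacheActivationsRunnerConfig(
    dataset_path="NeelNanda/pile-small-tokenized-2b",  # placeholder tokenized dataset
    model_name="gpt2",
    model_batch_size=8,
    hook_name="blocks.6.hook_resid_pre",
    d_in=768,
    training_tokens=1_000_000,
    new_cached_activations_path="./activations/gpt2_resid_pre_6",
    dtype="float32",
    device="cuda",
)

runner = CacheActivationsRunner(cfg)
print(runner)           # token / buffer / disk-space summary (see __str__ above)
dataset = runner.run()  # writes shards to disk and returns the consolidated Dataset
```
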
Source code in sae_lens/config.py
@dataclass
class CacheActivationsRunnerConfig:
    """
    Configuration for creating and caching activations of an LLM.

    Args:
        dataset_path (str): The path to the Hugging Face dataset. This may be tokenized or not.
        model_name (str): The name of the model to use.
        model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
        hook_name (str): The name of the hook to use.
        d_in (int): Dimension of the model.
        training_tokens (int): Total number of tokens to process.
        context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
        model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        new_cached_activations_path (str, optional): The path to save the activations.
        shuffle (bool): Whether to shuffle the dataset.
        seed (int): The seed to use for shuffling.
        dtype (str): Datatype of activations to be stored.
        device (str): The device for the model.
        buffer_size_gb (float): The buffer size in GB. This should be < 2GB.
        hf_repo_id (str, optional): The Hugging Face repository id to save the activations to.
        hf_num_shards (int, optional): The number of shards to save the activations to.
        hf_revision (str): The revision to save the activations to.
        hf_is_private_repo (bool): Whether the Hugging Face repository is private.
        model_kwargs (dict): Keyword arguments for `model.run_with_cache`.
        model_from_pretrained_kwargs (dict): Keyword arguments for the `from_pretrained` method of the model.
        compile_llm (bool): Whether to compile the LLM.
        llm_compilation_mode (str): The torch.compile mode to use.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
        autocast_lm (bool): Whether to use autocast during activation fetching.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
    """

    dataset_path: str
    model_name: str
    model_batch_size: int
    hook_name: str
    d_in: int
    training_tokens: int

    context_size: int = -1  # Required if dataset is not tokenized
    model_class_name: str = "HookedTransformer"
    # defaults to "activations/{dataset}/{model}/{hook_name}"
    new_cached_activations_path: str | None = None
    shuffle: bool = True
    seed: int = 42
    dtype: str = "float32"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    buffer_size_gb: float = 2.0  # HF datasets writer have problems with shards > 2GB

    # Huggingface Integration
    hf_repo_id: str | None = None
    hf_num_shards: int | None = None
    hf_revision: str = "main"
    hf_is_private_repo: bool = False

    # Model
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
    compile_llm: bool = False
    llm_compilation_mode: str | None = None  # which torch.compile mode to use

    # Activation Store
    prepend_bos: bool = True
    seqpos_slice: tuple[int | None, ...] = (None,)
    streaming: bool = True
    autocast_lm: bool = False
    dataset_trust_remote_code: bool | None = None

    def __post_init__(self):
        # Automatically determine context_size if dataset is tokenized
        if self.context_size == -1:
            ds = load_dataset(self.dataset_path, split="train", streaming=True)
            assert isinstance(ds, IterableDataset)
            first_sample = next(iter(ds))
            toks = first_sample.get("tokens") or first_sample.get("input_ids") or None
            if toks is None:
                raise ValueError(
                    "Dataset is not tokenized. Please specify context_size."
                )
            token_length = len(toks)
            self.context_size = token_length

        if self.context_size == -1:
            raise ValueError("context_size is still -1 after dataset inspection.")

        if self.seqpos_slice is not None:
            _validate_seqpos(
                seqpos=self.seqpos_slice,
                context_size=self.context_size,
            )

        if self.context_size > self.training_tokens:
            raise ValueError(
                f"context_size ({self.context_size}) is greater than training_tokens "
                f"({self.training_tokens}). Please reduce context_size or increase training_tokens."
            )

        if self.new_cached_activations_path is None:
            self.new_cached_activations_path = _default_cached_activations_path(  # type: ignore
                self.dataset_path, self.model_name, self.hook_name, None
            )

    @property
    def sliced_context_size(self) -> int:
        if self.seqpos_slice is not None:
            return len(range(self.context_size)[slice(*self.seqpos_slice)])
        return self.context_size

    @property
    def bytes_per_token(self) -> int:
        return self.d_in * str_to_dtype(self.dtype).itemsize

    @property
    def n_tokens_in_buffer(self) -> int:
        # Calculate raw tokens per buffer based on memory constraints
        _tokens_per_buffer = int(self.buffer_size_gb * 1e9) // self.bytes_per_token
        # Round down to nearest multiple of batch_token_size
        return _tokens_per_buffer - (_tokens_per_buffer % self.n_tokens_in_batch)

    @property
    def n_tokens_in_batch(self) -> int:
        return self.model_batch_size * self.sliced_context_size

    @property
    def n_batches_in_buffer(self) -> int:
        return self.n_tokens_in_buffer // self.n_tokens_in_batch

    @property
    def n_seq_in_dataset(self) -> int:
        return self.training_tokens // self.sliced_context_size

    @property
    def n_seq_in_buffer(self) -> int:
        return self.n_tokens_in_buffer // self.sliced_context_size

    @property
    def n_buffers(self) -> int:
        return math.ceil(self.training_tokens / self.n_tokens_in_buffer)
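
The derived properties above are pure arithmetic over the config fields. A standalone sketch of the same calculation with made-up numbers, to show how the buffer sizes fall out:

```python
import math

# Illustrative values only.
d_in, dtype_bytes = 768, 4                      # float32 activations
model_batch_size, sliced_context_size = 8, 128
buffer_size_gb, training_tokens = 2.0, 1_000_000

bytes_per_token = d_in * dtype_bytes                           # 3072
n_tokens_in_batch = model_batch_size * sliced_context_size     # 1024

# Tokens that fit in one ~2GB buffer, rounded down to a whole number of batches.
raw = int(buffer_size_gb * 1e9) // bytes_per_token             # 651041
n_tokens_in_buffer = raw - (raw % n_tokens_in_batch)           # 650240
n_batches_in_buffer = n_tokens_in_buffer // n_tokens_in_batch  # 635
n_buffers = math.ceil(training_tokens / n_tokens_in_buffer)    # 2
```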

GatedSAE

Bases: SAE[GatedSAEConfig]

GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE) using a gated linear encoder and a standard linear decoder.

Source code in sae_lens/saes/gated_sae.py
class GatedSAE(SAE[GatedSAEConfig]):
    """
    GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a gated linear encoder and a standard linear decoder.
    """

    b_gate: nn.Parameter
    b_mag: nn.Parameter
    r_mag: nn.Parameter

    def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)
        # Ensure b_enc does not exist for the gated architecture
        self.b_enc = None

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_gated(self)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input tensor into the feature space using a gated encoder.
        This must match the original encode_gated implementation from SAE class.
        """
        # Preprocess the SAE input (casting type, applying hooks, normalization)
        sae_in = self.process_sae_in(x)

        # Gating path exactly as in original SAE.encode_gated
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path (weight sharing with gated encoder)
        magnitude_pre_activation = self.hook_sae_acts_pre(
            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        )
        feature_magnitudes = self.activation_fn(magnitude_pre_activation)

        # Combine gating and magnitudes
        return self.hook_sae_acts_post(active_features * feature_magnitudes)

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back into the input space:
          1) Apply optional finetuning scaling.
          2) Linear transform plus bias.
          3) Run any reconstruction hooks and out-normalization if configured.
          4) If the SAE was reshaping hook_z activations, reshape back.
        """
        # 1) optional finetuning scaling
        # 2) linear transform
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        # 3) hooking and normalization
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        # 4) reshape if needed (hook_z)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Override to handle gated-specific parameters."""
        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Gated-specific parameters need special handling
        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
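
A brief usage sketch. It assumes sae is an already-loaded GatedSAE instance (for example built from a GatedSAEConfig or loaded via SAE.from_pretrained); the batch and sequence sizes are arbitrary:

```python
import torch

# sae is assumed to be a GatedSAE; sae.cfg.d_in is its input width.
acts = torch.randn(4, 128, sae.cfg.d_in)   # (batch, seq, d_in) activations

feature_acts = sae.encode(acts)            # gate * magnitude, shape (..., d_sae)
recon = sae.decode(feature_acts)           # reconstruction, shape (..., d_in)

# forward() composes encode and decode; with use_error_term=True the output
# would instead match the original activations exactly.
assert torch.allclose(recon, sae(acts))
```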

decode(feature_acts)

Decode the feature activations back into the input space:

1) Apply optional finetuning scaling.
2) Linear transform plus bias.
3) Run any reconstruction hooks and out-normalization if configured.
4) If the SAE was reshaping hook_z activations, reshape back.

Source code in sae_lens/saes/gated_sae.py
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """
    Decode the feature activations back into the input space:
      1) Apply optional finetuning scaling.
      2) Linear transform plus bias.
      3) Run any reconstruction hooks and out-normalization if configured.
      4) If the SAE was reshaping hook_z activations, reshape back.
    """
    # 1) optional finetuning scaling
    # 2) linear transform
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    # 3) hooking and normalization
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    # 4) reshape if needed (hook_z)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Encode the input tensor into the feature space using a gated encoder. This must match the original encode_gated implementation from SAE class.

Source code in sae_lens/saes/gated_sae.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode the input tensor into the feature space using a gated encoder.
    This must match the original encode_gated implementation from SAE class.
    """
    # Preprocess the SAE input (casting type, applying hooks, normalization)
    sae_in = self.process_sae_in(x)

    # Gating path exactly as in original SAE.encode_gated
    gating_pre_activation = sae_in @ self.W_enc + self.b_gate
    active_features = (gating_pre_activation > 0).to(self.dtype)

    # Magnitude path (weight sharing with gated encoder)
    magnitude_pre_activation = self.hook_sae_acts_pre(
        sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
    )
    feature_magnitudes = self.activation_fn(magnitude_pre_activation)

    # Combine gating and magnitudes
    return self.hook_sae_acts_post(active_features * feature_magnitudes)

fold_W_dec_norm()

Override to handle gated-specific parameters.

Source code in sae_lens/saes/gated_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """Override to handle gated-specific parameters."""
    W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
    self.W_dec.data = self.W_dec.data / W_dec_norms
    self.W_enc.data = self.W_enc.data * W_dec_norms.T

    # Gated-specific parameters need special handling
    # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
    self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
    self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
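
The folding above rescales parameters without changing the function the SAE computes: each decoder row is divided by its norm, while the matching encoder column, b_gate entry, and b_mag entry are multiplied by it, so feature activations are rescaled but the reconstruction is unchanged. A self-contained numerical check of that invariance on raw tensors (mirroring the gated encode/decode equations rather than using the SAE classes):

```python
import torch

torch.manual_seed(0)
d_in, d_sae = 16, 64
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_gate, b_mag, r_mag = torch.randn(d_sae), torch.randn(d_sae), torch.randn(d_sae)
b_dec = torch.randn(d_in)

def gated_sae(x, W_enc, W_dec, b_gate, b_mag, r_mag, b_dec):
    gate = ((x @ W_enc + b_gate) > 0).float()               # gating path
    mag = torch.relu(x @ (W_enc * r_mag.exp()) + b_mag)     # magnitude path
    return (gate * mag) @ W_dec + b_dec                     # reconstruction

x = torch.randn(8, d_in)
out_before = gated_sae(x, W_enc, W_dec, b_gate, b_mag, r_mag, b_dec)

# Fold decoder norms, as in fold_W_dec_norm above.
norms = W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)     # (d_sae, 1)
out_after = gated_sae(
    x,
    W_enc * norms.T,
    W_dec / norms,
    b_gate * norms.squeeze(),
    b_mag * norms.squeeze(),
    r_mag,
    b_dec,
)
assert torch.allclose(out_before, out_after, atol=1e-5)     # unchanged output
```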

GatedSAEConfig dataclass

Bases: SAEConfig

Configuration class for a GatedSAE.

Source code in sae_lens/saes/gated_sae.py
@dataclass
class GatedSAEConfig(SAEConfig):
    """
    Configuration class for a GatedSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "gated"

GatedTrainingSAE

Bases: TrainingSAE[GatedTrainingSAEConfig]

GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture. It implements:

  • initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
  • encode: calls encode_with_hidden_pre (standard training approach).
  • decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
  • encode_with_hidden_pre: gating logic.
  • calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
  • training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
Source code in sae_lens/saes/gated_sae.py
class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
    """
    GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
    It implements:

      - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
      - encode: calls encode_with_hidden_pre (standard training approach).
      - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
      - encode_with_hidden_pre: gating logic.
      - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
      - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
    """

    b_gate: nn.Parameter  # type: ignore
    b_mag: nn.Parameter  # type: ignore
    r_mag: nn.Parameter  # type: ignore

    def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
        if use_error_term:
            raise ValueError(
                "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
            )
        super().__init__(cfg, use_error_term)

    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_gated(self)

    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gated forward pass with pre-activation (for training).
        """
        sae_in = self.process_sae_in(x)

        # Gating path
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path
        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

        feature_magnitudes = self.activation_fn(magnitude_pre_activation)

        # Combine gating path and magnitude path
        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)

        # Return both the final feature activations and the pre-activation (for logging or penalty)
        return feature_acts, magnitude_pre_activation

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # Re-center the input if apply_b_dec_to_input is set
        sae_in_centered = step_input.sae_in - (
            self.b_dec * self.cfg.apply_b_dec_to_input
        )

        # The gating pre-activation (pi_gate) for the auxiliary path
        pi_gate = sae_in_centered @ self.W_enc + self.b_gate
        pi_gate_act = torch.relu(pi_gate)

        # L1-like penalty scaled by W_dec norms
        l1_loss = (
            step_input.coefficients["l1"]
            * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
        )

        # Aux reconstruction: reconstruct x purely from gating path
        via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
        aux_recon_loss = (
            (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
        )

        # Return both losses separately
        return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}

    def log_histograms(self) -> dict[str, NDArray[Any]]:
        """Log histograms of the weights and biases."""
        b_gate_dist = self.b_gate.detach().float().cpu().numpy()
        b_mag_dist = self.b_mag.detach().float().cpu().numpy()
        return {
            **super().log_histograms(),
            "weights/b_gate": b_gate_dist,
            "weights/b_mag": b_mag_dist,
        }

    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Override to handle gated-specific parameters."""
        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Gated-specific parameters need special handling
        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
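
For reference, the two auxiliary terms above can be written out on raw tensors: an L1-style penalty on the ReLU of the gating pre-activation, weighted by decoder row norms, plus a reconstruction of the input from the gating path alone. A standalone sketch (the tensors and the l1 coefficient are placeholders, not the training classes themselves):

```python
import torch

d_in, d_sae, batch = 16, 64, 8
x = torch.randn(batch, d_in)
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_gate, b_dec = torch.randn(d_sae), torch.randn(d_in)
l1_coefficient = 5.0

sae_in_centered = x - b_dec                  # if apply_b_dec_to_input is set
pi_gate = sae_in_centered @ W_enc + b_gate
pi_gate_act = torch.relu(pi_gate)

# Sparsity penalty: L1 of the gating pre-activations, scaled by decoder norms.
l1_loss = l1_coefficient * torch.sum(pi_gate_act * W_dec.norm(dim=1), dim=-1).mean()

# Auxiliary reconstruction: rebuild x from the gating path alone.
via_gate_recon = pi_gate_act @ W_dec + b_dec
aux_recon_loss = (via_gate_recon - x).pow(2).sum(dim=-1).mean()
```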

encode_with_hidden_pre(x)

Gated forward pass with pre-activation (for training).

Source code in sae_lens/saes/gated_sae.py
def encode_with_hidden_pre(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gated forward pass with pre-activation (for training).
    """
    sae_in = self.process_sae_in(x)

    # Gating path
    gating_pre_activation = sae_in @ self.W_enc + self.b_gate
    active_features = (gating_pre_activation > 0).to(self.dtype)

    # Magnitude path
    magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
    magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

    feature_magnitudes = self.activation_fn(magnitude_pre_activation)

    # Combine gating path and magnitude path
    feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)

    # Return both the final feature activations and the pre-activation (for logging or penalty)
    return feature_acts, magnitude_pre_activation

fold_W_dec_norm()

Override to handle gated-specific parameters.

Source code in sae_lens/saes/gated_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """Override to handle gated-specific parameters."""
    W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
    self.W_dec.data = self.W_dec.data / W_dec_norms
    self.W_enc.data = self.W_enc.data * W_dec_norms.T

    # Gated-specific parameters need special handling
    # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
    self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
    self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()

log_histograms()

Log histograms of the weights and biases.

Source code in sae_lens/saes/gated_sae.py
def log_histograms(self) -> dict[str, NDArray[Any]]:
    """Log histograms of the weights and biases."""
    b_gate_dist = self.b_gate.detach().float().cpu().numpy()
    b_mag_dist = self.b_mag.detach().float().cpu().numpy()
    return {
        **super().log_histograms(),
        "weights/b_gate": b_gate_dist,
        "weights/b_mag": b_mag_dist,
    }

GatedTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a GatedTrainingSAE.

Source code in sae_lens/saes/gated_sae.py
@dataclass
class GatedTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a GatedTrainingSAE.
    """

    l1_coefficient: float = 1.0
    l1_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "gated"

HookedSAETransformer

Bases: HookedTransformer

Source code in sae_lens/analysis/hooked_sae_transformer.py
class HookedSAETransformer(HookedTransformer):
    def __init__(
        self,
        *model_args: Any,
        **model_kwargs: Any,
    ):
        """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

        Note that if you want to load the model from pretrained weights, you should use
        :meth:`from_pretrained` instead.

        Args:
            *model_args: Positional arguments for HookedTransformer initialization
            **model_kwargs: Keyword arguments for HookedTransformer initialization
        """
        super().__init__(*model_args, **model_kwargs)

        for block in self.blocks:
            add_hook_in_to_mlp(block.mlp)  # type: ignore
        self.setup()

        self._acts_to_saes: dict[str, _SAEWrapper] = {}
        # Track output hooks used by transcoders for cleanup
        self._transcoder_output_hooks: dict[str, str] = {}

    @property
    def acts_to_saes(self) -> dict[str, SAE[Any]]:
        """Returns a dict mapping hook names to attached SAEs."""
        return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}

    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
        """Attaches an SAE or Transcoder to the model.

        WARNING: This SAE will be permanently attached until you remove it with
        reset_saes. This function will also overwrite any existing SAE attached
        to the same hook point.

        Args:
            sae: The SAE or Transcoder to attach to the model.
            use_error_term: If True, computes error term so output matches what the
                model would have produced without the SAE. This works for both SAEs
                (where input==output hook) and transcoders (where they differ).
                Defaults to None (uses SAE's existing setting).
        """
        input_hook = sae.cfg.metadata.hook_name
        output_hook = sae.cfg.metadata.hook_name_out or input_hook

        if (input_hook not in self._acts_to_saes) and (
            input_hook not in self.hook_dict
        ):
            logger.warning(
                f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
            )
            return

        # Check if output hook exists (either as hook_dict entry or already has SAE attached)
        output_hook_exists = (
            output_hook in self.hook_dict
            or output_hook in self._acts_to_saes
            or any(v == output_hook for v in self._transcoder_output_hooks.values())
        )
        if not output_hook_exists:
            logger.warning(f"No hook found for output {output_hook}. Skipping.")
            return

        # Always use wrapper - it handles both SAEs and transcoders uniformly
        # If use_error_term not specified, respect SAE's existing setting
        effective_use_error_term = (
            use_error_term if use_error_term is not None else sae.use_error_term
        )
        wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)

        # For transcoders (input != output), capture input at input hook
        if input_hook != output_hook:
            input_hook_point = get_deep_attr(self, input_hook)
            if isinstance(input_hook_point, HookPoint):
                input_hook_point.add_hook(
                    lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
                    dir="fwd",
                    is_permanent=True,
                )
            self._transcoder_output_hooks[input_hook] = output_hook

        # Store wrapper in _acts_to_saes and at output hook
        self._acts_to_saes[input_hook] = wrapper
        set_deep_attr(self, output_hook, wrapper)
        self.setup()

    def _reset_sae(
        self, act_name: str, prev_wrapper: _SAEWrapper | None = None
    ) -> None:
        """Resets an SAE that was attached to the model.

        By default will remove the SAE from that hook_point.
        If prev_wrapper is provided, will restore that wrapper's SAE with its settings.

        Args:
            act_name: The hook_name of the SAE to reset.
            prev_wrapper: The previous wrapper to restore. If None, will just
                remove the SAE from this hook point. Defaults to None.
        """
        if act_name not in self._acts_to_saes:
            logger.warning(
                f"No SAE is attached to {act_name}. There's nothing to reset."
            )
            return

        # Determine output hook location (different from input for transcoders)
        output_hook = self._transcoder_output_hooks.pop(act_name, act_name)

        # For transcoders, clear permanent hooks from input hook point
        if output_hook != act_name:
            input_hook_point = get_deep_attr(self, act_name)
            if isinstance(input_hook_point, HookPoint):
                input_hook_point.remove_hooks(dir="fwd", including_permanent=True)

        # Reset output hook location
        set_deep_attr(self, output_hook, HookPoint())
        del self._acts_to_saes[act_name]

        if prev_wrapper is not None:
            # Rebuild hook_dict before adding new SAE
            self.setup()
            self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)

    def reset_saes(
        self,
        act_names: str | list[str] | None = None,
    ) -> None:
        """Reset the SAEs attached to the model.

        If act_names are provided will just reset SAEs attached to those hooks.
        Otherwise will reset all SAEs attached to the model.

        Args:
            act_names: The act_names of the SAEs to reset. If None, will reset
                all SAEs attached to the model. Defaults to None.
        """
        if isinstance(act_names, str):
            act_names = [act_names]
        elif act_names is None:
            act_names = list(self._acts_to_saes.keys())

        for act_name in act_names:
            self._reset_sae(act_name)

        self.setup()

    def run_with_saes(
        self,
        *model_args: Any,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        use_error_term: bool | None = None,
        **model_kwargs: Any,
    ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
        """Wrapper around HookedTransformer forward pass.

        Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
            reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(
            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
        ):
            return self(*model_args, **model_kwargs)

    def run_with_cache_with_saes(
        self,
        *model_args: Any,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        use_error_term: bool | None = None,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> tuple[
        None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
        ActivationCache | dict[str, torch.Tensor],
    ]:
        """Wrapper around 'run_with_cache' in HookedTransformer.

        Attaches given SAEs before running the model with cache and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
            return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
                useful HookedTransformer specific methods, otherwise it will return a dictionary of
                activations as in HookedRootModule.
            remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            **kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(
            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
        ):
            return self.run_with_cache(  # type: ignore
                *model_args,
                return_cache_object=return_cache_object,  # type: ignore
                remove_batch_dim=remove_batch_dim,
                **kwargs,
            )

    def run_with_hooks_with_saes(
        self,
        *model_args: Any,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
        bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """Wrapper around 'run_with_hooks' in HookedTransformer.

        Attaches the given SAEs to the model before running the model with hooks and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
            fwd_hooks: (list[tuple[str | Callable, Callable]]) List of forward hooks to apply
            bwd_hooks: (list[tuple[str | Callable, Callable]]) List of backward hooks to apply
            reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
            clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
            return self.run_with_hooks(
                *model_args,
                fwd_hooks=fwd_hooks,
                bwd_hooks=bwd_hooks,
                reset_hooks_end=reset_hooks_end,
                clear_contexts=clear_contexts,
                **model_kwargs,
            )

    @contextmanager
    def saes(
        self,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        use_error_term: bool | None = None,
    ):
        """A context manager for adding temporary SAEs to the model.

        See HookedTransformer.hooks for a similar context manager for hooks.
        By default will keep track of previously attached SAEs, and restore
        them when the context manager exits.

        Args:
            saes: SAEs to be attached.
            reset_saes_end: If True, removes all SAEs added by this context
                manager when the context manager exits, returning previously
                attached SAEs to their original state.
            use_error_term: If provided, will set the use_error_term attribute
                of all SAEs attached during this run to this value.
        """
        saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
        if isinstance(saes, SAE):
            saes = [saes]
        try:
            for sae in saes:
                act_name = sae.cfg.metadata.hook_name
                prev_wrapper = self._acts_to_saes.get(act_name, None)
                saes_to_restore.append((act_name, prev_wrapper))
                self.add_sae(sae, use_error_term=use_error_term)
            yield self
        finally:
            if reset_saes_end:
                for act_name, prev_wrapper in saes_to_restore:
                    self._reset_sae(act_name, prev_wrapper)
                self.setup()
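
A rough end-to-end sketch: attach an SAE for a single forward pass with run_with_cache_with_saes and read its feature activations out of the cache. The release and SAE id strings are placeholders, and the exact return value of SAE.from_pretrained varies between SAE Lens versions, so treat this as illustrative rather than exact:

```python
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2")

# Placeholder release / id; substitute an entry from the pretrained SAE tables.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
)

logits, cache = model.run_with_cache_with_saes("The quick brown fox", saes=[sae])

# While attached, the SAE's internal hooks are cached under the hook it wraps.
hook_name = sae.cfg.metadata.hook_name
feature_acts = cache[f"{hook_name}.hook_sae_acts_post"]   # (batch, seq, d_sae)
```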

acts_to_saes: dict[str, SAE[Any]] property

Returns a dict mapping hook names to attached SAEs.

__init__(*model_args, **model_kwargs)

Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

Note that if you want to load the model from pretrained weights, you should use from_pretrained instead.

Parameters:

*model_args (Any): Positional arguments for HookedTransformer initialization.
**model_kwargs (Any): Keyword arguments for HookedTransformer initialization.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def __init__(
    self,
    *model_args: Any,
    **model_kwargs: Any,
):
    """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        *model_args: Positional arguments for HookedTransformer initialization
        **model_kwargs: Keyword arguments for HookedTransformer initialization
    """
    super().__init__(*model_args, **model_kwargs)

    for block in self.blocks:
        add_hook_in_to_mlp(block.mlp)  # type: ignore
    self.setup()

    self._acts_to_saes: dict[str, _SAEWrapper] = {}
    # Track output hooks used by transcoders for cleanup
    self._transcoder_output_hooks: dict[str, str] = {}

add_sae(sae, use_error_term=None)

Attaches an SAE or Transcoder to the model.

WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

Parameters:

sae (SAE[Any], required): The SAE or Transcoder to attach to the model.
use_error_term (bool | None): If True, computes error term so output matches what the model would have produced without the SAE. This works for both SAEs (where input==output hook) and transcoders (where they differ). Defaults to None (uses SAE's existing setting).
Source code in sae_lens/analysis/hooked_sae_transformer.py
def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
    """Attaches an SAE or Transcoder to the model.

    WARNING: This SAE will be permanently attached until you remove it with
    reset_saes. This function will also overwrite any existing SAE attached
    to the same hook point.

    Args:
        sae: The SAE or Transcoder to attach to the model.
        use_error_term: If True, computes error term so output matches what the
            model would have produced without the SAE. This works for both SAEs
            (where input==output hook) and transcoders (where they differ).
            Defaults to None (uses SAE's existing setting).
    """
    input_hook = sae.cfg.metadata.hook_name
    output_hook = sae.cfg.metadata.hook_name_out or input_hook

    if (input_hook not in self._acts_to_saes) and (
        input_hook not in self.hook_dict
    ):
        logger.warning(
            f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
        )
        return

    # Check if output hook exists (either as hook_dict entry or already has SAE attached)
    output_hook_exists = (
        output_hook in self.hook_dict
        or output_hook in self._acts_to_saes
        or any(v == output_hook for v in self._transcoder_output_hooks.values())
    )
    if not output_hook_exists:
        logger.warning(f"No hook found for output {output_hook}. Skipping.")
        return

    # Always use wrapper - it handles both SAEs and transcoders uniformly
    # If use_error_term not specified, respect SAE's existing setting
    effective_use_error_term = (
        use_error_term if use_error_term is not None else sae.use_error_term
    )
    wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)

    # For transcoders (input != output), capture input at input hook
    if input_hook != output_hook:
        input_hook_point = get_deep_attr(self, input_hook)
        if isinstance(input_hook_point, HookPoint):
            input_hook_point.add_hook(
                lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
                dir="fwd",
                is_permanent=True,
            )
        self._transcoder_output_hooks[input_hook] = output_hook

    # Store wrapper in _acts_to_saes and at output hook
    self._acts_to_saes[input_hook] = wrapper
    set_deep_attr(self, output_hook, wrapper)
    self.setup()

reset_saes(act_names=None)

Reset the SAEs attached to the model.

If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.

Parameters:

act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def reset_saes(
    self,
    act_names: str | list[str] | None = None,
) -> None:
    """Reset the SAEs attached to the model.

    If act_names are provided will just reset SAEs attached to those hooks.
    Otherwise will reset all SAEs attached to the model.

    Args:
        act_names: The act_names of the SAEs to reset. If None, will reset
            all SAEs attached to the model. Defaults to None.
    """
    if isinstance(act_names, str):
        act_names = [act_names]
    elif act_names is None:
        act_names = list(self._acts_to_saes.keys())

    for act_name in act_names:
        self._reset_sae(act_name)

    self.setup()
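
For persistent attachment rather than a single run, add_sae and reset_saes can be paired directly. A brief sketch, continuing from the model and sae loaded in the earlier example:

```python
model.add_sae(sae)             # splice the SAE in at its hook point until reset
logits = model("Hello world")  # every forward pass now goes through the SAE

print(model.acts_to_saes)      # {hook_name: attached SAE}

model.reset_saes()             # detach all SAEs and restore the original hooks
```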

run_with_cache_with_saes(*model_args, saes=[], reset_saes_end=True, use_error_term=None, return_cache_object=True, remove_batch_dim=False, **kwargs)

Wrapper around 'run_with_cache' in HookedTransformer.

Attaches given SAEs before running the model with cache and then removes them. By default, will reset all SAEs to original state after.

Parameters:

*model_args (Any): Positional arguments for the model forward pass.
saes (SAE[Any] | list[SAE[Any]], default: []): The SAEs to be attached for this forward pass.
reset_saes_end (bool, default: True): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state.
use_error_term (bool | None, default: None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction.
return_cache_object (bool, default: True): If True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule.
remove_batch_dim (bool, default: False): Whether to remove the batch dimension (only works for batch_size==1).
**kwargs (Any): Keyword arguments for the model forward pass.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_cache_with_saes(
    self,
    *model_args: Any,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    use_error_term: bool | None = None,
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    **kwargs: Any,
) -> tuple[
    None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
    ActivationCache | dict[str, torch.Tensor],
]:
    """Wrapper around 'run_with_cache' in HookedTransformer.

    Attaches given SAEs before running the model with cache and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
        return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
            useful HookedTransformer specific methods, otherwise it will return a dictionary of
            activations as in HookedRootModule.
        remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
        **kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(
        saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
    ):
        return self.run_with_cache(  # type: ignore
            *model_args,
            return_cache_object=return_cache_object,  # type: ignore
            remove_batch_dim=remove_batch_dim,
            **kwargs,
        )

run_with_hooks_with_saes(*model_args, saes=[], reset_saes_end=True, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False, **model_kwargs)

Wrapper around 'run_with_hooks' in HookedTransformer.

Attaches the given SAEs to the model before running the model with hooks and then removes them. By default, will reset all SAEs to original state after.

Parameters:

*model_args (Any): Positional arguments for the model forward pass.
saes (SAE[Any] | list[SAE[Any]], default: []): The SAEs to be attached for this forward pass.
reset_saes_end (bool, default: True): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state.
fwd_hooks (list[tuple[str | Callable, Callable]], default: []): List of forward hooks to apply.
bwd_hooks (list[tuple[str | Callable, Callable]], default: []): List of backward hooks to apply.
reset_hooks_end (bool, default: True): Whether to reset the hooks at the end of the forward pass.
clear_contexts (bool, default: False): Whether to clear the contexts at the end of the forward pass.
**model_kwargs (Any): Keyword arguments for the model forward pass.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_hooks_with_saes(
    self,
    *model_args: Any,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
    bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    **model_kwargs: Any,
):
    """Wrapper around 'run_with_hooks' in HookedTransformer.

    Attaches the given SAEs to the model before running the model with hooks and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
        fwd_hooks: (list[tuple[str | Callable, Callable]]) List of forward hooks to apply
        bwd_hooks: (list[tuple[str | Callable, Callable]]) List of backward hooks to apply
        reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
        clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(saes=saes, reset_saes_end=reset_saes_end):
        return self.run_with_hooks(
            *model_args,
            fwd_hooks=fwd_hooks,
            bwd_hooks=bwd_hooks,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
            **model_kwargs,
        )
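
As with run_with_cache_with_saes, the SAE's internal hook points are addressable while it is attached, so a forward hook can inspect (or edit) feature activations on the fly. A small sketch reusing the model and sae from the earlier examples; the ".hook_sae_acts_post" suffix refers to the SAE's post-activation hook:

```python
def count_active_features(acts, hook):
    # acts: (batch, seq, d_sae) feature activations from the attached SAE
    n_active = (acts > 0).float().sum(-1).mean().item()
    print(f"{hook.name}: {n_active:.1f} features active on average")
    return acts

hook_name = sae.cfg.metadata.hook_name
model.run_with_hooks_with_saes(
    "The quick brown fox",
    saes=[sae],
    fwd_hooks=[(f"{hook_name}.hook_sae_acts_post", count_active_features)],
)
```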

run_with_saes(*model_args, saes=[], reset_saes_end=True, use_error_term=None, **model_kwargs)

Wrapper around HookedTransformer forward pass.

Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

Parameters:

*model_args (Any): Positional arguments for the model forward pass.
saes (SAE[Any] | list[SAE[Any]], default: []): The SAEs to be attached for this forward pass.
reset_saes_end (bool, default: True): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state.
use_error_term (bool | None, default: None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value.
**model_kwargs (Any): Keyword arguments for the model forward pass.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_saes(
    self,
    *model_args: Any,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    use_error_term: bool | None = None,
    **model_kwargs: Any,
) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
    """Wrapper around HookedTransformer forward pass.

    Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
        reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(
        saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
    ):
        return self(*model_args, **model_kwargs)
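
A companion sketch under the same assumptions (model and sae already loaded, as in the example further up): compare an ordinary forward pass with an SAE-substituted one, and use use_error_term to make the substitution output-preserving.

# Ordinary forward pass, no SAEs attached.
clean_logits = model("The quick brown fox")

# Replace the hook-point activations with the SAE reconstruction for one pass.
sae_logits = model.run_with_saes("The quick brown fox", saes=[sae])

# With use_error_term=True the SAE error term is added back in, so the output
# matches the clean run while the SAE's hook points still fire.
lossless_logits = model.run_with_saes(
    "The quick brown fox", saes=[sae], use_error_term=True
)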

saes(saes=[], reset_saes_end=True, use_error_term=None)

A context manager for adding temporary SAEs to the model.

See HookedTransformer.hooks for a similar context manager for hooks. By default will keep track of previously attached SAEs, and restore them when the context manager exits.

Parameters:

Name Type Description Default
saes SAE[Any] | list[SAE[Any]]

SAEs to be attached.

[]
reset_saes_end bool

If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.

True
use_error_term bool | None

If provided, will set the use_error_term attribute of all SAEs attached during this run to this value.

None
Source code in sae_lens/analysis/hooked_sae_transformer.py
@contextmanager
def saes(
    self,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    use_error_term: bool | None = None,
):
    """A context manager for adding temporary SAEs to the model.

    See HookedTransformer.hooks for a similar context manager for hooks.
    By default will keep track of previously attached SAEs, and restore
    them when the context manager exits.

    Args:
        saes: SAEs to be attached.
        reset_saes_end: If True, removes all SAEs added by this context
            manager when the context manager exits, returning previously
            attached SAEs to their original state.
        use_error_term: If provided, will set the use_error_term attribute
            of all SAEs attached during this run to this value.
    """
    saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
    if isinstance(saes, SAE):
        saes = [saes]
    try:
        for sae in saes:
            act_name = sae.cfg.metadata.hook_name
            prev_wrapper = self._acts_to_saes.get(act_name, None)
            saes_to_restore.append((act_name, prev_wrapper))
            self.add_sae(sae, use_error_term=use_error_term)
        yield self
    finally:
        if reset_saes_end:
            for act_name, prev_wrapper in saes_to_restore:
                self._reset_sae(act_name, prev_wrapper)
            self.setup()
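
A short sketch of the context manager under the same assumptions as the examples above (the cached hook name is again assumed to be the SAE's hook point plus hook_sae_acts_post): the SAE stays attached for everything run inside the with-block and is removed, or the previously attached SAE restored, on exit.

with model.saes(saes=[sae], use_error_term=True):
    _, cache = model.run_with_cache("The quick brown fox")
    feature_acts = cache[f"{sae.cfg.metadata.hook_name}.hook_sae_acts_post"]
# Outside the block the SAE is gone again (reset_saes_end defaults to True).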

JumpReLUSAE

Bases: SAE[JumpReLUSAEConfig]

JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE) using a JumpReLU activation. For each unit, if its pre-activation is <= threshold, that unit is zeroed out; otherwise, it follows a user-specified activation function (e.g., ReLU etc.).

It implements:

  • initialize_weights: sets up parameters, including a threshold.
  • encode: computes the feature activations using JumpReLU.
  • decode: reconstructs the input from the feature activations.

The BaseSAE.forward() method automatically calls encode and decode, including any error-term processing if configured.

Source code in sae_lens/saes/jumprelu_sae.py
class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
    """
    JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a JumpReLU activation. For each unit, if its pre-activation is
    <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
    activation function (e.g., ReLU etc.).

    It implements:

      - initialize_weights: sets up parameters, including a threshold.
      - encode: computes the feature activations using JumpReLU.
      - decode: reconstructs the input from the feature activations.

    The BaseSAE.forward() method automatically calls encode and decode,
    including any error-term processing if configured.
    """

    b_enc: nn.Parameter
    threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        self.threshold = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input tensor into the feature space using JumpReLU.
        The threshold parameter determines which units remain active.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # 1) Apply the base "activation_fn" from config (e.g., ReLU).
        base_acts = self.activation_fn(hidden_pre)

        # 2) Zero out any unit whose (hidden_pre <= threshold).
        #    We cast the boolean mask to the same dtype for safe multiplication.
        jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)

        # 3) Multiply the normally activated units by that mask.
        return self.hook_sae_acts_post(base_acts * jump_relu_mask)

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back to the input space.
        Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Override to properly handle threshold adjustment with W_dec norms.
        When we scale the encoder weights, we need to scale the threshold
        by the same factor to maintain the same sparsity pattern.
        """
        # Save the current threshold before calling parent method
        current_thresh = self.threshold.clone()

        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)

        # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
        super().fold_W_dec_norm()

        # Scale the threshold by the same factor as we scaled b_enc
        # This ensures the same features remain active/inactive after folding
        self.threshold.data = current_thresh * W_dec_norms

decode(feature_acts)

Decode the feature activations back to the input space. Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.

Source code in sae_lens/saes/jumprelu_sae.py
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """
    Decode the feature activations back to the input space.
    Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
    """
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Encode the input tensor into the feature space using JumpReLU. The threshold parameter determines which units remain active.

Source code in sae_lens/saes/jumprelu_sae.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode the input tensor into the feature space using JumpReLU.
    The threshold parameter determines which units remain active.
    """
    sae_in = self.process_sae_in(x)
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

    # 1) Apply the base "activation_fn" from config (e.g., ReLU).
    base_acts = self.activation_fn(hidden_pre)

    # 2) Zero out any unit whose (hidden_pre <= threshold).
    #    We cast the boolean mask to the same dtype for safe multiplication.
    jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)

    # 3) Multiply the normally activated units by that mask.
    return self.hook_sae_acts_post(base_acts * jump_relu_mask)
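
To make the masking step concrete, a toy illustration in plain torch (no SAE object involved): units whose pre-activation does not exceed the threshold are zeroed, and every other unit keeps its ReLU value.

import torch

hidden_pre = torch.tensor([-0.20, 0.05, 0.30, 1.50])
threshold = torch.full((4,), 0.10)

base_acts = torch.relu(hidden_pre)                    # [0.00, 0.05, 0.30, 1.50]
mask = (hidden_pre > threshold).to(base_acts.dtype)   # [0.,   0.,   1.,   1.  ]
feature_acts = base_acts * mask                       # [0.00, 0.00, 0.30, 1.50]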

fold_W_dec_norm()

Override to properly handle threshold adjustment with W_dec norms. When we scale the encoder weights, we need to scale the threshold by the same factor to maintain the same sparsity pattern.

Source code in sae_lens/saes/jumprelu_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """
    Override to properly handle threshold adjustment with W_dec norms.
    When we scale the encoder weights, we need to scale the threshold
    by the same factor to maintain the same sparsity pattern.
    """
    # Save the current threshold before calling parent method
    current_thresh = self.threshold.clone()

    # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
    W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)

    # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
    super().fold_W_dec_norm()

    # Scale the threshold by the same factor as we scaled b_enc
    # This ensures the same features remain active/inactive after folding
    self.threshold.data = current_thresh * W_dec_norms
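
The reason the threshold must be rescaled: folding scales each feature's encoder weights, and hence its pre-activation, by the corresponding decoder row norm, and comparing against a threshold scaled by the same positive factor leaves the active/inactive pattern unchanged. A tiny check in plain torch (the norms here are made up):

import torch

hidden_pre = torch.tensor([0.05, 0.30, 1.50])
threshold = torch.full((3,), 0.10)
w_dec_norms = torch.tensor([0.5, 2.0, 4.0])  # hypothetical per-feature decoder norms

before = hidden_pre > threshold
after = (hidden_pre * w_dec_norms) > (threshold * w_dec_norms)
assert torch.equal(before, after)  # the same features remain active after folding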

JumpReLUSAEConfig dataclass

Bases: SAEConfig

Configuration class for a JumpReLUSAE.

Source code in sae_lens/saes/jumprelu_sae.py
@dataclass
class JumpReLUSAEConfig(SAEConfig):
    """
    Configuration class for a JumpReLUSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "jumprelu"

JumpReLUSkipTranscoder

Bases: JumpReLUTranscoder, SkipTranscoder

A transcoder with a learnable skip connection and JumpReLU activation function.

Source code in sae_lens/saes/transcoder.py
class JumpReLUSkipTranscoder(JumpReLUTranscoder, SkipTranscoder):
    """
    A transcoder with a learnable skip connection and JumpReLU activation function.
    """

    cfg: JumpReLUSkipTranscoderConfig  # type: ignore[assignment]

    def __init__(self, cfg: JumpReLUSkipTranscoderConfig):
        super().__init__(cfg)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoder":
        cfg = JumpReLUSkipTranscoderConfig.from_dict(config_dict)
        return cls(cfg)

JumpReLUSkipTranscoderConfig dataclass

Bases: JumpReLUTranscoderConfig

Configuration for the JumpReLU skip transcoder.

Source code in sae_lens/saes/transcoder.py
@dataclass
class JumpReLUSkipTranscoderConfig(JumpReLUTranscoderConfig):
    """Configuration for JumpReLU transcoder."""

    @classmethod
    def architecture(cls) -> str:
        """Return the architecture name for this config."""
        return "jumprelu_skip_transcoder"

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
        """Create a JumpReLUSkipTranscoderConfig from a dictionary."""
        # Filter to only include valid dataclass fields
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

        # Create the config instance
        res = cls(**filtered_config_dict)

        # Handle metadata if present
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])

        return res

architecture() classmethod

Return the architecture name for this config.

Source code in sae_lens/saes/transcoder.py
@classmethod
def architecture(cls) -> str:
    """Return the architecture name for this config."""
    return "jumprelu_skip_transcoder"

from_dict(config_dict) classmethod

Create a JumpReLUSkipTranscoderConfig from a dictionary.

Source code in sae_lens/saes/transcoder.py
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
    """Create a JumpReLUSkipTranscoderConfig from a dictionary."""
    # Filter to only include valid dataclass fields
    filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

    # Create the config instance
    res = cls(**filtered_config_dict)

    # Handle metadata if present
    if "metadata" in config_dict:
        res.metadata = SAEMetadata(**config_dict["metadata"])

    return res

JumpReLUTrainingSAE

Bases: TrainingSAE[JumpReLUTrainingSAEConfig]

JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

Similar to the inference-only JumpReLUSAE, but with:

  • A learnable log-threshold parameter (instead of a raw threshold).
  • A specialized auxiliary loss term for sparsity (L0 or similar).

Methods of interest include:

  • initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
  • encode_with_hidden_pre_jumprelu: runs a forward pass for training.
  • training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.

Source code in sae_lens/saes/jumprelu_sae.py
class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
    """
    JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

    Similar to the inference-only JumpReLUSAE, but with:

      - A learnable log-threshold parameter (instead of a raw threshold).
      - A specialized auxiliary loss term for sparsity (L0 or similar).

    Methods of interest include:

    - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
    - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
    - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
    """

    b_enc: nn.Parameter
    log_threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # We'll store a bandwidth for the training approach, if needed
        self.bandwidth = cfg.jumprelu_bandwidth

        # In typical JumpReLU training code, we may track a log_threshold:
        self.log_threshold = nn.Parameter(
            torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
            * np.log(cfg.jumprelu_init_threshold)
        )

    @override
    def initialize_weights(self) -> None:
        """
        Initialize parameters like the base SAE, but also add log_threshold.
        """
        super().initialize_weights()
        # Encoder Bias
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    @property
    def threshold(self) -> torch.Tensor:
        """
        Returns the parameterized threshold > 0 for each unit.
        threshold = exp(log_threshold).
        """
        return torch.exp(self.log_threshold)

    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sae_in = self.process_sae_in(x)

        hidden_pre = sae_in @ self.W_enc + self.b_enc
        feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)

        return feature_acts, hidden_pre  # type: ignore

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""

        threshold = self.threshold
        W_dec_norm = self.W_dec.norm(dim=1)
        if self.cfg.jumprelu_sparsity_loss_mode == "step":
            l0 = torch.sum(
                Step.apply(hidden_pre, threshold, self.bandwidth),  # type: ignore
                dim=-1,
            )
            l0_loss = (step_input.coefficients["l0"] * l0).mean()
        elif self.cfg.jumprelu_sparsity_loss_mode == "tanh":
            per_item_l0_loss = torch.tanh(
                self.cfg.jumprelu_tanh_scale * feature_acts * W_dec_norm
            ).sum(dim=-1)
            l0_loss = (step_input.coefficients["l0"] * per_item_l0_loss).mean()
        else:
            raise ValueError(
                f"Invalid sparsity loss mode: {self.cfg.jumprelu_sparsity_loss_mode}"
            )
        losses = {"l0_loss": l0_loss}

        if self.cfg.pre_act_loss_coefficient is not None:
            losses["pre_act_loss"] = calculate_pre_act_loss(
                self.cfg.pre_act_loss_coefficient,
                threshold,
                hidden_pre,
                step_input.dead_neuron_mask,
                W_dec_norm,
            )
        return losses

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l0": TrainCoefficientConfig(
                value=self.cfg.l0_coefficient,
                warm_up_steps=self.cfg.l0_warm_up_steps,
            ),
        }

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Override to properly handle threshold adjustment with W_dec norms.
        """
        # Save the current threshold before we call the parent method
        current_thresh = self.threshold.clone()

        # Get W_dec norms (clamped to avoid division by zero)
        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)

        # Call parent implementation to handle W_enc and W_dec adjustment
        super().fold_W_dec_norm()

        # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
        self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())

    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        """Convert log_threshold to threshold for saving"""
        if "log_threshold" in state_dict:
            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
            del state_dict["log_threshold"]
            state_dict["threshold"] = threshold

    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        """Convert threshold to log_threshold for loading"""
        if "threshold" in state_dict:
            threshold = state_dict["threshold"]
            del state_dict["threshold"]
            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()

threshold: torch.Tensor property

Returns the parameterized threshold > 0 for each unit. threshold = exp(log_threshold).
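
A one-line illustration of why the log parameterization is convenient: any real-valued log_threshold maps to a strictly positive threshold, so gradient updates can never push a unit's threshold to zero or below.

import torch

log_threshold = torch.tensor([-4.6, 0.0, 2.3])  # unconstrained parameter values
threshold = torch.exp(log_threshold)            # ~[0.01, 1.00, 9.97], always > 0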

calculate_aux_loss(step_input, feature_acts, hidden_pre, sae_out)

Calculate architecture-specific auxiliary loss terms.

Source code in sae_lens/saes/jumprelu_sae.py
@override
def calculate_aux_loss(
    self,
    step_input: TrainStepInput,
    feature_acts: torch.Tensor,
    hidden_pre: torch.Tensor,
    sae_out: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Calculate architecture-specific auxiliary loss terms."""

    threshold = self.threshold
    W_dec_norm = self.W_dec.norm(dim=1)
    if self.cfg.jumprelu_sparsity_loss_mode == "step":
        l0 = torch.sum(
            Step.apply(hidden_pre, threshold, self.bandwidth),  # type: ignore
            dim=-1,
        )
        l0_loss = (step_input.coefficients["l0"] * l0).mean()
    elif self.cfg.jumprelu_sparsity_loss_mode == "tanh":
        per_item_l0_loss = torch.tanh(
            self.cfg.jumprelu_tanh_scale * feature_acts * W_dec_norm
        ).sum(dim=-1)
        l0_loss = (step_input.coefficients["l0"] * per_item_l0_loss).mean()
    else:
        raise ValueError(
            f"Invalid sparsity loss mode: {self.cfg.jumprelu_sparsity_loss_mode}"
        )
    losses = {"l0_loss": l0_loss}

    if self.cfg.pre_act_loss_coefficient is not None:
        losses["pre_act_loss"] = calculate_pre_act_loss(
            self.cfg.pre_act_loss_coefficient,
            threshold,
            hidden_pre,
            step_input.dead_neuron_mask,
            W_dec_norm,
        )
    return losses
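
A plain-torch illustration of what the two sparsity penalties measure on a single example. The "step" branch in the source uses the SAE Lens-internal Step autograd function (a straight-through-style estimator over hidden_pre and the threshold), so it is approximated here by the hard count it estimates; the "tanh" branch is shown directly.

import torch

feature_acts = torch.tensor([0.0, 0.0, 0.3, 1.5])
w_dec_norm = torch.ones(4)
tanh_scale = 4.0

# "step" mode penalizes (an estimate of) the L0 count of active features
l0_step = (feature_acts > 0).float().sum()                          # -> 2.0

# "tanh" mode penalizes a soft, saturating version of the same count
l0_tanh = torch.tanh(tanh_scale * feature_acts * w_dec_norm).sum()  # -> ~1.83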

fold_W_dec_norm()

Override to properly handle threshold adjustment with W_dec norms.

Source code in sae_lens/saes/jumprelu_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """
    Override to properly handle threshold adjustment with W_dec norms.
    """
    # Save the current threshold before we call the parent method
    current_thresh = self.threshold.clone()

    # Get W_dec norms (clamped to avoid division by zero)
    W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)

    # Call parent implementation to handle W_enc and W_dec adjustment
    super().fold_W_dec_norm()

    # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
    self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())

initialize_weights()

Initialize parameters like the base SAE, but also add log_threshold.

Source code in sae_lens/saes/jumprelu_sae.py
@override
def initialize_weights(self) -> None:
    """
    Initialize parameters like the base SAE, but also add log_threshold.
    """
    super().initialize_weights()
    # Encoder Bias
    self.b_enc = nn.Parameter(
        torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
    )

process_state_dict_for_loading(state_dict)

Convert threshold to log_threshold for loading

Source code in sae_lens/saes/jumprelu_sae.py
def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
    """Convert threshold to log_threshold for loading"""
    if "threshold" in state_dict:
        threshold = state_dict["threshold"]
        del state_dict["threshold"]
        state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()

process_state_dict_for_saving(state_dict)

Convert log_threshold to threshold for saving

Source code in sae_lens/saes/jumprelu_sae.py
def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
    """Convert log_threshold to threshold for saving"""
    if "log_threshold" in state_dict:
        threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
        del state_dict["log_threshold"]
        state_dict["threshold"] = threshold

JumpReLUTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a JumpReLUTrainingSAE.

  • jumprelu_init_threshold: initial threshold for the JumpReLU activation
  • jumprelu_bandwidth: bandwidth for the JumpReLU activation
  • jumprelu_sparsity_loss_mode: mode for the sparsity loss, either "step" or "tanh". "step" is Google Deepmind's L0 loss, "tanh" is Anthropic's sparsity loss.
  • l0_coefficient: coefficient for the l0 sparsity loss
  • l0_warm_up_steps: number of warm-up steps for the l0 sparsity loss
  • pre_act_loss_coefficient: coefficient for the pre-activation loss. Set to None to disable. Set to 3e-6 to match Anthropic's setup. Default is None.
  • jumprelu_tanh_scale: scale for the tanh sparsity loss. Only relevant for "tanh" sparsity loss mode. Default is 4.0.

Source code in sae_lens/saes/jumprelu_sae.py
@dataclass
class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a JumpReLUTrainingSAE.

    - jumprelu_init_threshold: initial threshold for the JumpReLU activation
    - jumprelu_bandwidth: bandwidth for the JumpReLU activation
    - jumprelu_sparsity_loss_mode: mode for the sparsity loss, either "step" or "tanh". "step" is Google Deepmind's L0 loss, "tanh" is Anthropic's sparsity loss.
    - l0_coefficient: coefficient for the l0 sparsity loss
    - l0_warm_up_steps: number of warm-up steps for the l0 sparsity loss
    - pre_act_loss_coefficient: coefficient for the pre-activation loss. Set to None to disable. Set to 3e-6 to match Anthropic's setup. Default is None.
    - jumprelu_tanh_scale: scale for the tanh sparsity loss. Only relevant for "tanh" sparsity loss mode. Default is 4.0.
    """

    jumprelu_init_threshold: float = 0.01
    jumprelu_bandwidth: float = 0.05
    # step is Google Deepmind, tanh is Anthropic
    jumprelu_sparsity_loss_mode: Literal["step", "tanh"] = "step"
    l0_coefficient: float = 1.0
    l0_warm_up_steps: int = 0

    # anthropic's auxiliary loss to avoid dead features
    pre_act_loss_coefficient: float | None = None

    # only relevant for tanh sparsity loss mode
    jumprelu_tanh_scale: float = 4.0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "jumprelu"

JumpReLUTranscoder

Bases: Transcoder

A transcoder with JumpReLU activation function.

JumpReLU applies a threshold to activations: if pre-activation <= threshold, the unit is zeroed out; otherwise, it follows the base activation function.

Source code in sae_lens/saes/transcoder.py
class JumpReLUTranscoder(Transcoder):
    """
    A transcoder with JumpReLU activation function.

    JumpReLU applies a threshold to activations: if pre-activation <=
    threshold, the unit is zeroed out; otherwise, it follows the base
    activation function.
    """

    cfg: JumpReLUTranscoderConfig  # type: ignore[assignment]
    threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUTranscoderConfig):
        super().__init__(cfg)
        self.cfg = cfg

    def initialize_weights(self):
        """Initialize transcoder weights including threshold parameter."""
        super().initialize_weights()

        # Initialize threshold parameter for JumpReLU
        self.threshold = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode using JumpReLU activation.

        Applies base activation function (ReLU) then masks based on threshold.
        """
        # Preprocess the SAE input
        sae_in = self.process_sae_in(x)

        # Compute pre-activation values
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # Apply base activation function (ReLU)
        feature_acts = self.activation_fn(hidden_pre)

        # Apply JumpReLU threshold
        # During training, use detached threshold to prevent gradient flow
        threshold = self.threshold.detach() if self.training else self.threshold
        jump_relu_mask = (hidden_pre > threshold).to(self.dtype)

        # Apply mask and hook
        return self.hook_sae_acts_post(feature_acts * jump_relu_mask)

    def fold_W_dec_norm(self) -> None:
        """
        Fold the decoder weight norm into the threshold parameter.

        This is important for JumpReLU as the threshold needs to be scaled
        along with the decoder weights.
        """
        # Get the decoder weight norms before normalizing
        with torch.no_grad():
            W_dec_norms = self.W_dec.norm(dim=1)

        # Fold the decoder norms as in the parent class
        super().fold_W_dec_norm()

        # Scale the threshold by the decoder weight norms
        with torch.no_grad():
            self.threshold.data = self.threshold.data * W_dec_norms

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoder":
        cfg = JumpReLUTranscoderConfig.from_dict(config_dict)
        return cls(cfg)

encode(x)

Encode using JumpReLU activation.

Applies base activation function (ReLU) then masks based on threshold.

Source code in sae_lens/saes/transcoder.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode using JumpReLU activation.

    Applies base activation function (ReLU) then masks based on threshold.
    """
    # Preprocess the SAE input
    sae_in = self.process_sae_in(x)

    # Compute pre-activation values
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

    # Apply base activation function (ReLU)
    feature_acts = self.activation_fn(hidden_pre)

    # Apply JumpReLU threshold
    # During training, use detached threshold to prevent gradient flow
    threshold = self.threshold.detach() if self.training else self.threshold
    jump_relu_mask = (hidden_pre > threshold).to(self.dtype)

    # Apply mask and hook
    return self.hook_sae_acts_post(feature_acts * jump_relu_mask)

fold_W_dec_norm()

Fold the decoder weight norm into the threshold parameter.

This is important for JumpReLU as the threshold needs to be scaled along with the decoder weights.

Source code in sae_lens/saes/transcoder.py
def fold_W_dec_norm(self) -> None:
    """
    Fold the decoder weight norm into the threshold parameter.

    This is important for JumpReLU as the threshold needs to be scaled
    along with the decoder weights.
    """
    # Get the decoder weight norms before normalizing
    with torch.no_grad():
        W_dec_norms = self.W_dec.norm(dim=1)

    # Fold the decoder norms as in the parent class
    super().fold_W_dec_norm()

    # Scale the threshold by the decoder weight norms
    with torch.no_grad():
        self.threshold.data = self.threshold.data * W_dec_norms

initialize_weights()

Initialize transcoder weights including threshold parameter.

Source code in sae_lens/saes/transcoder.py
def initialize_weights(self):
    """Initialize transcoder weights including threshold parameter."""
    super().initialize_weights()

    # Initialize threshold parameter for JumpReLU
    self.threshold = nn.Parameter(
        torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
    )

JumpReLUTranscoderConfig dataclass

Bases: TranscoderConfig

Configuration for JumpReLU transcoder.

Source code in sae_lens/saes/transcoder.py
@dataclass
class JumpReLUTranscoderConfig(TranscoderConfig):
    """Configuration for JumpReLU transcoder."""

    @classmethod
    def architecture(cls) -> str:
        """Return the architecture name for this config."""
        return "jumprelu_transcoder"

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoderConfig":
        """Create a JumpReLUTranscoderConfig from a dictionary."""
        # Filter to only include valid dataclass fields
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

        # Create the config instance
        res = cls(**filtered_config_dict)

        # Handle metadata if present
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])

        return res

architecture() classmethod

Return the architecture name for this config.

Source code in sae_lens/saes/transcoder.py
@classmethod
def architecture(cls) -> str:
    """Return the architecture name for this config."""
    return "jumprelu_transcoder"

from_dict(config_dict) classmethod

Create a JumpReLUTranscoderConfig from a dictionary.

Source code in sae_lens/saes/transcoder.py
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoderConfig":
    """Create a JumpReLUTranscoderConfig from a dictionary."""
    # Filter to only include valid dataclass fields
    filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

    # Create the config instance
    res = cls(**filtered_config_dict)

    # Handle metadata if present
    if "metadata" in config_dict:
        res.metadata = SAEMetadata(**config_dict["metadata"])

    return res

LanguageModelSAERunnerConfig dataclass

Bases: Generic[T_TRAINING_SAE_CONFIG]

Configuration for training a sparse autoencoder on a language model.

Parameters:

Name Type Description Default
sae T_TRAINING_SAE_CONFIG

The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).

required
model_name str

The name of the model to use. This should be the name of the model in the Hugging Face model hub.

'gelu-2l'
model_class_name str

The name of the class of the model to use. This should be either HookedTransformer or HookedMamba.

'HookedTransformer'
hook_name str

The name of the hook to use. This should be a valid TransformerLens hook.

'blocks.0.hook_mlp_out'
hook_eval str

DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.

'NOT_IN_USE'
hook_head_index int

When the hook is for an activation with a head index, we can specify a specific head to use here.

None
dataset_path str

A Hugging Face dataset path.

''
dataset_trust_remote_code bool

Whether to trust remote code when loading datasets from Huggingface.

True
streaming bool

Whether to stream the dataset. Streaming large datasets is usually practical.

True
is_dataset_tokenized bool

Whether the dataset is already tokenized.

True
context_size int

The context size to use when generating activations on which to train the SAE.

128
use_cached_activations bool

Whether to use cached activations. This is useful when doing sweeps over the same activations.

False
cached_activations_path str

The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".

None
from_pretrained_path str

The path to a pretrained SAE. We can finetune an existing SAE if needed.

None
n_batches_in_buffer int

The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.

20
training_tokens int

The number of training tokens.

2000000
store_batch_size_prompts int

The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.

32
seqpos_slice tuple[int | None, ...]

Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.

(None,)
disable_concat_sequences bool

Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.

False
sequence_separator_token int | Literal['bos', 'eos', 'sep'] | None

If not None, this token will be placed between sentences in a batch to act as a separator. By default, this is the <bos> token.

special_token_field(default='bos')
activations_mixing_fraction float

Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).

0.5
device str

The device to use. Usually "cuda".

'cpu'
act_store_device str

The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.

'with_model'
seed int

The seed to use.

42
dtype str

The data type to use for the SAE and activations.

'float32'
prepend_bos bool

Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.

True
autocast bool

Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.

False
autocast_lm bool

Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.

False
compile_llm bool

Whether to compile the LLM using torch.compile.

False
llm_compilation_mode str

The compilation mode to use for the LLM if compile_llm is True.

None
compile_sae bool

Whether to compile the SAE using torch.compile.

False
sae_compilation_mode str

The compilation mode to use for the SAE if compile_sae is True.

None
train_batch_size_tokens int

The batch size for training, in tokens. This controls the batch size of the SAE training loop.

4096
adam_beta1 float

The beta1 parameter for the Adam optimizer.

0.9
adam_beta2 float

The beta2 parameter for the Adam optimizer.

0.999
lr float

The learning rate.

0.0003
lr_scheduler_name str

The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").

'constant'
lr_warm_up_steps int

The number of warm-up steps for the learning rate.

0
lr_end float

The end learning rate if using a scheduler like cosine annealing. Defaults to lr / 10.

None
lr_decay_steps int

The number of decay steps for the learning rate if using a scheduler with decay.

0
n_restart_cycles int

The number of restart cycles for the cosine annealing with warm restarts scheduler.

1
dead_feature_window int

The window size (in training steps) for detecting dead features.

1000
feature_sampling_window int

The window size (in training steps) for resampling features (e.g. dead features).

2000
dead_feature_threshold float

The threshold below which a feature's activation frequency is considered dead.

1e-08
n_eval_batches int

The number of batches to use for evaluation.

10
eval_batch_size_prompts int

The batch size for evaluation, in prompts. Useful if evals cause OOM.

None
logger LoggingConfig

Configuration for logging (e.g. W&B).

LoggingConfig()
n_checkpoints int

The number of checkpoints to save during training. 0 means no checkpoints.

0
checkpoint_path str | None

The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")

'checkpoints'
save_final_checkpoint bool

Whether to include an additional final checkpoint when training is finished. (default is False).

False
resume_from_checkpoint str | None

The path to the checkpoint to resume training from. (default is None).

None
output_path str | None

The path to save outputs. Set to None to disable output saving. (default is "output")

'output'
verbose bool

Whether to print verbose output. (default is True)

True
model_kwargs dict[str, Any]

Keyword arguments for model.run_with_cache

dict_field(default={})
model_from_pretrained_kwargs dict[str, Any]

Additional keyword arguments to pass to the model's from_pretrained method.

dict_field(default=None)
sae_lens_version str

The version of the sae_lens library.

(lambda: __version__)()
sae_lens_training_version str

The version of the sae_lens training library.

(lambda: __version__)()
exclude_special_tokens bool | list[int]

Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.

False
Source code in sae_lens/config.py
@dataclass
class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
    """
    Configuration for training a sparse autoencoder on a language model.

    Args:
        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
        model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
        model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
        hook_eval (str): DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
        dataset_path (str): A Hugging Face dataset path.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
        use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
        from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
        training_tokens (int): The number of training tokens.
        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
        disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
        sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
        activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
        device (str): The device to use. Usually "cuda".
        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
        seed (int): The seed to use.
        dtype (str): The data type to use for the SAE and activations.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
        lr (float): The learning rate.
        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
        lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
        dead_feature_window (int): The window size (in training steps) for detecting dead features.
        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
        n_eval_batches (int): The number of batches to use for evaluation.
        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
        logger (LoggingConfig): Configuration for logging (e.g. W&B).
        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
        checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
        save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
        output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
        verbose (bool): Whether to print verbose output. (default is True)
        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
        sae_lens_version (str): The version of the sae_lens library.
        sae_lens_training_version (str): The version of the sae_lens training library.
        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
    """

    sae: T_TRAINING_SAE_CONFIG

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.0.hook_mlp_out"
    hook_eval: str = "NOT_IN_USE"
    hook_head_index: int | None = None
    dataset_path: str = ""
    dataset_trust_remote_code: bool = True
    streaming: bool = True
    is_dataset_tokenized: bool = True
    context_size: int = 128
    use_cached_activations: bool = False
    cached_activations_path: str | None = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    )

    # SAE Parameters
    from_pretrained_path: str | None = None

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
    store_batch_size_prompts: int = 32
    seqpos_slice: tuple[int | None, ...] = (None,)
    disable_concat_sequences: bool = False
    sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
        special_token_field(default="bos")
    )
    activations_mixing_fraction: float = 0.5

    # Misc
    device: str = "cpu"
    act_store_device: str = "with_model"  # will be set by post init if with_model
    seed: int = 42
    dtype: str = "float32"  # type: ignore #
    prepend_bos: bool = True

    # Performance - see compilation section of lm_runner.py for info
    autocast: bool = False  # autocast to autocast_dtype during training
    autocast_lm: bool = False  # autocast lm during activation fetching
    compile_llm: bool = False  # use torch.compile on the LLM
    llm_compilation_mode: str | None = None  # which torch.compile mode to use
    compile_sae: bool = False  # use torch.compile on the SAE
    sae_compilation_mode: str | None = None

    # Training Parameters

    ## Batch size
    train_batch_size_tokens: int = 4096

    ## Adam
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999

    ## Learning Rate Schedule
    lr: float = 3e-4
    lr_scheduler_name: str = (
        "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
    )
    lr_warm_up_steps: int = 0
    lr_end: float | None = None  # only used for cosine annealing, default is lr / 10
    lr_decay_steps: int = 0
    n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts

    # Resampling protocol args
    dead_feature_window: int = 1000  # should be smaller than feature_sampling_window
    feature_sampling_window: int = 2000
    dead_feature_threshold: float = 1e-8

    # Evals
    n_eval_batches: int = 10
    eval_batch_size_prompts: int | None = None  # useful if evals cause OOM

    logger: LoggingConfig = field(default_factory=LoggingConfig)

    # Outputs/Checkpoints
    n_checkpoints: int = 0
    checkpoint_path: str | None = "checkpoints"
    save_final_checkpoint: bool = False
    output_path: str | None = "output"
    resume_from_checkpoint: str | None = None

    # Misc
    verbose: bool = True
    model_kwargs: dict[str, Any] = dict_field(default={})
    model_from_pretrained_kwargs: dict[str, Any] | None = dict_field(default=None)
    sae_lens_version: str = field(default_factory=lambda: __version__)
    sae_lens_training_version: str = field(default_factory=lambda: __version__)
    exclude_special_tokens: bool | list[int] = False

    def __post_init__(self):
        if self.hook_eval != "NOT_IN_USE":
            warnings.warn(
                "The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
                "It is not currently used and can be safely removed from your config.",
                DeprecationWarning,
                stacklevel=2,
            )

        if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
                self.model_name,
                self.hook_name,
                self.hook_head_index,
            )
        self.tokens_per_buffer = (
            self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
        )

        if self.logger.run_name is None:
            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

        if self.model_from_pretrained_kwargs is None:
            if self.model_class_name == "HookedTransformer":
                self.model_from_pretrained_kwargs = {"center_writing_weights": False}
            else:
                self.model_from_pretrained_kwargs = {}

        if self.act_store_device == "with_model":
            self.act_store_device = self.device

        if self.lr_end is None:
            self.lr_end = self.lr / 10

        unique_id = self.logger.wandb_id
        if unique_id is None:
            unique_id = cast(
                Any, wandb
            ).util.generate_id()  # not sure why this type is erroring
        self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

        if self.verbose:
            logger.info(
                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
            )
            # Print out some useful info:
            n_tokens_per_buffer = (
                self.store_batch_size_prompts
                * self.context_size
                * self.n_batches_in_buffer
            )
            logger.info(
                f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10**6}"
            )
            n_contexts_per_buffer = (
                self.store_batch_size_prompts * self.n_batches_in_buffer
            )
            logger.info(
                f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10**6}"
            )

            total_training_steps = (
                self.training_tokens
            ) // self.train_batch_size_tokens
            logger.info(f"Total training steps: {total_training_steps}")

            total_wandb_updates = (
                total_training_steps // self.logger.wandb_log_frequency
            )
            logger.info(f"Total wandb updates: {total_wandb_updates}")

            # how many times will we sample dead neurons?
            # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
            n_feature_window_samples = (
                total_training_steps // self.feature_sampling_window
            )
            logger.info(
                f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size_tokens) / 10**6}"
            )
            logger.info(
                f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size_tokens) / 10**6}"
            )
            logger.info(
                f"We will reset the sparsity calculation {n_feature_window_samples} times."
            )
            # logger.info("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
            logger.info(
                f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
            )

        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
            )

        _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

        if isinstance(self.exclude_special_tokens, list) and not all(
            isinstance(x, int) for x in self.exclude_special_tokens
        ):
            raise ValueError("exclude_special_tokens list must contain only integers")

    @property
    def total_training_tokens(self) -> int:
        return self.training_tokens

    @property
    def total_training_steps(self) -> int:
        return self.total_training_tokens // self.train_batch_size_tokens

    def get_training_sae_cfg_dict(self) -> dict[str, Any]:
        return self.sae.to_dict()

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the config to a dictionary.
        """

        d = asdict(self)

        d["logger"] = asdict(self.logger)
        d["sae"] = self.sae.to_dict()
        # Overwrite fields that might not be JSON-serializable
        d["dtype"] = str(self.dtype)
        d["device"] = str(self.device)
        d["act_store_device"] = str(self.act_store_device)
        return d

    @classmethod
    def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
        """
        Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.

        Args:
            cfg_dict (dict[str, Any]): The dictionary to load the config from.

        Returns:
            LanguageModelSAERunnerConfig: The loaded config.
        """
        if "sae" not in cfg_dict:
            raise ValueError("sae field is required in the config dictionary")
        if "architecture" not in cfg_dict["sae"]:
            raise ValueError("architecture field is required in the sae dictionary")
        if "logger" not in cfg_dict:
            raise ValueError("logger field is required in the config dictionary")
        sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
        sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
        logger_cfg = LoggingConfig(**cfg_dict["logger"])
        updated_cfg_dict: dict[str, Any] = {
            **cfg_dict,
            "sae": sae_cfg,
            "logger": logger_cfg,
        }
        output = cls(**updated_cfg_dict)
        # the post_init always appends to checkpoint path, so we need to set it explicitly here.
        if "checkpoint_path" in cfg_dict:
            output.checkpoint_path = cfg_dict["checkpoint_path"]
        return output

    def to_sae_trainer_config(self) -> "SAETrainerConfig":
        return SAETrainerConfig(
            n_checkpoints=self.n_checkpoints,
            checkpoint_path=self.checkpoint_path,
            save_final_checkpoint=self.save_final_checkpoint,
            total_training_samples=self.total_training_tokens,
            device=self.device,
            autocast=self.autocast,
            lr=self.lr,
            lr_end=self.lr_end,
            lr_scheduler_name=self.lr_scheduler_name,
            lr_warm_up_steps=self.lr_warm_up_steps,
            adam_beta1=self.adam_beta1,
            adam_beta2=self.adam_beta2,
            lr_decay_steps=self.lr_decay_steps,
            n_restart_cycles=self.n_restart_cycles,
            train_batch_size_samples=self.train_batch_size_tokens,
            dead_feature_window=self.dead_feature_window,
            feature_sampling_window=self.feature_sampling_window,
            logger=self.logger,
        )

from_dict(cfg_dict) classmethod

Load a LanguageModelSAERunnerConfig from a dictionary given by to_dict.

Parameters:

    cfg_dict (dict[str, Any], required): The dictionary to load the config from.

Returns:

    LanguageModelSAERunnerConfig[Any]: The loaded config.

Source code in sae_lens/config.py
@classmethod
def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
    """
    Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.

    Args:
        cfg_dict (dict[str, Any]): The dictionary to load the config from.

    Returns:
        LanguageModelSAERunnerConfig: The loaded config.
    """
    if "sae" not in cfg_dict:
        raise ValueError("sae field is required in the config dictionary")
    if "architecture" not in cfg_dict["sae"]:
        raise ValueError("architecture field is required in the sae dictionary")
    if "logger" not in cfg_dict:
        raise ValueError("logger field is required in the config dictionary")
    sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
    sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
    logger_cfg = LoggingConfig(**cfg_dict["logger"])
    updated_cfg_dict: dict[str, Any] = {
        **cfg_dict,
        "sae": sae_cfg,
        "logger": logger_cfg,
    }
    output = cls(**updated_cfg_dict)
    # the post_init always appends to checkpoint path, so we need to set it explicitly here.
    if "checkpoint_path" in cfg_dict:
        output.checkpoint_path = cfg_dict["checkpoint_path"]
    return output

to_dict()

Convert the config to a dictionary.

Source code in sae_lens/config.py
def to_dict(self) -> dict[str, Any]:
    """
    Convert the config to a dictionary.
    """

    d = asdict(self)

    d["logger"] = asdict(self.logger)
    d["sae"] = self.sae.to_dict()
    # Overwrite fields that might not be JSON-serializable
    d["dtype"] = str(self.dtype)
    d["device"] = str(self.device)
    d["act_store_device"] = str(self.act_store_device)
    return d
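
A minimal round-trip sketch (not taken from the generated reference): it assumes `LanguageModelSAERunnerConfig` is importable from the top-level `sae_lens` package and that the remaining runner fields keep workable defaults; the SAE sub-config uses the `MatchingPursuitTrainingSAEConfig` documented below purely as an example.

from sae_lens import LanguageModelSAERunnerConfig
from sae_lens.saes.matching_pursuit_sae import MatchingPursuitTrainingSAEConfig

# Toy config; only the SAE sub-config is specified here, everything else keeps its default.
cfg = LanguageModelSAERunnerConfig(sae=MatchingPursuitTrainingSAEConfig(d_in=64, d_sae=256))

# to_dict() converts dtype/device/act_store_device to strings and sae/logger to nested dicts,
# so the result is JSON-serializable. from_dict() rebuilds the nested configs and restores
# checkpoint_path verbatim instead of letting __post_init__ append a new suffix to it.
restored = LanguageModelSAERunnerConfig.from_dict(cfg.to_dict())
assert restored.checkpoint_path == cfg.checkpoint_path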

LanguageModelSAETrainingRunner

Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.

Source code in sae_lens/llm_sae_training_runner.py
class LanguageModelSAETrainingRunner:
    """
    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
    """

    cfg: LanguageModelSAERunnerConfig[Any]
    model: HookedRootModule
    sae: TrainingSAE[Any]
    activations_store: ActivationsStore

    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
        override_dataset: HfDataset | None = None,
        override_model: HookedRootModule | None = None,
        override_sae: TrainingSAE[Any] | None = None,
        resume_from_checkpoint: Path | str | None = None,
    ):
        if override_dataset is not None:
            logger.warning(
                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
            )
        if override_model is not None:
            logger.warning(
                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
            )

        self.cfg = cfg

        if override_model is None:
            self.model = load_model(
                self.cfg.model_class_name,
                self.cfg.model_name,
                device=self.cfg.device,
                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
            )
        else:
            self.model = override_model

        self.activations_store = ActivationsStore.from_config(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )

        if override_sae is None:
            if self.cfg.from_pretrained_path is not None:
                self.sae = TrainingSAE.load_from_disk(
                    self.cfg.from_pretrained_path, self.cfg.device
                )
            else:
                self.sae = TrainingSAE.from_dict(
                    TrainingSAEConfig.from_dict(
                        self.cfg.get_training_sae_cfg_dict(),
                    ).to_dict()
                )
        else:
            self.sae = override_sae

        self.sae.to(self.cfg.device)

    def run(self):
        """
        Run the training of the SAE.
        """
        self._set_sae_metadata()
        if self.cfg.logger.log_to_wandb:
            wandb.init(
                project=self.cfg.logger.wandb_project,
                entity=self.cfg.logger.wandb_entity,
                config=self.cfg.to_dict(),
                name=self.cfg.logger.run_name,
                id=self.cfg.logger.wandb_id,
            )

        evaluator = LLMSaeEvaluator(
            model=self.model,
            activations_store=self.activations_store,
            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
            n_eval_batches=self.cfg.n_eval_batches,
            model_kwargs=self.cfg.model_kwargs,
        )

        trainer = SAETrainer(
            sae=self.sae,
            data_provider=self.activations_store,
            evaluator=evaluator,
            save_checkpoint_fn=self.save_checkpoint,
            cfg=self.cfg.to_sae_trainer_config(),
        )

        if self.cfg.resume_from_checkpoint is not None:
            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)

        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

        if self.cfg.output_path is not None:
            self.save_final_sae(
                sae=sae,
                output_path=self.cfg.output_path,
                log_feature_sparsity=trainer.log_feature_sparsity,
            )

        if self.cfg.logger.log_to_wandb:
            wandb.finish()

        return sae

    def save_final_sae(
        self,
        sae: TrainingSAE[Any],
        output_path: str,
        log_feature_sparsity: torch.Tensor | None = None,
    ):
        base_output_path = Path(output_path)
        base_output_path.mkdir(exist_ok=True, parents=True)

        weights_path, cfg_path = sae.save_inference_model(str(base_output_path))

        sparsity_path = None
        if log_feature_sparsity is not None:
            sparsity_path = base_output_path / SPARSITY_FILENAME
            save_file({"sparsity": log_feature_sparsity}, sparsity_path)

        runner_config = self.cfg.to_dict()
        with open(base_output_path / RUNNER_CFG_FILENAME, "w") as f:
            json.dump(runner_config, f)

        if self.cfg.logger.log_to_wandb:
            self.cfg.logger.log(
                self,
                weights_path,
                cfg_path,
                sparsity_path=sparsity_path,
                wandb_aliases=["final_model"],
            )

    def _set_sae_metadata(self):
        self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
        self.sae.cfg.metadata.hook_name = self.cfg.hook_name
        self.sae.cfg.metadata.model_name = self.cfg.model_name
        self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
        self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
        self.sae.cfg.metadata.context_size = self.cfg.context_size
        self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
        self.sae.cfg.metadata.model_from_pretrained_kwargs = (
            self.cfg.model_from_pretrained_kwargs
        )
        self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
        self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
        self.sae.cfg.metadata.sequence_separator_token = (
            self.cfg.sequence_separator_token
        )
        self.sae.cfg.metadata.disable_concat_sequences = (
            self.cfg.disable_concat_sequences
        )

    def _compile_if_needed(self):
        # Compile model and SAE
        #  torch.compile can provide significant speedups (10-20% in testing)
        # using max-autotune gives the best speedups but:
        # (a) increases VRAM usage,
        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
        # (c) takes some time to compile
        # optimal settings seem to be:
        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
        # (also pylance seems to really hate this)
        if self.cfg.compile_llm:
            self.model = torch.compile(
                self.model,
                mode=self.cfg.llm_compilation_mode,
            )  # type: ignore

        if self.cfg.compile_sae:
            backend = "aot_eager" if self.cfg.device == "mps" else "inductor"

            self.sae.training_forward_pass = torch.compile(  # type: ignore
                self.sae.training_forward_pass,
                mode=self.cfg.sae_compilation_mode,
                backend=backend,
            )  # type: ignore

    def run_trainer_with_interruption_handling(
        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
    ):
        try:
            # signal handlers (if preempted)
            signal.signal(signal.SIGINT, interrupt_callback)
            signal.signal(signal.SIGTERM, interrupt_callback)

            # train SAE
            sae = trainer.fit()

        except (KeyboardInterrupt, InterruptedException):
            if self.cfg.checkpoint_path is not None:
                logger.warning("interrupted, saving progress")
                checkpoint_path = Path(self.cfg.checkpoint_path) / str(
                    trainer.n_training_samples
                )
                self.save_checkpoint(checkpoint_path)
                logger.info("done saving")
            raise

        return sae

    def save_checkpoint(
        self,
        checkpoint_path: Path | None,
    ) -> None:
        if checkpoint_path is None:
            return

        self.activations_store.save_to_checkpoint(checkpoint_path)

        runner_config = self.cfg.to_dict()
        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
            json.dump(runner_config, f)

run()

Run the training of the SAE.

Source code in sae_lens/llm_sae_training_runner.py
def run(self):
    """
    Run the training of the SAE.
    """
    self._set_sae_metadata()
    if self.cfg.logger.log_to_wandb:
        wandb.init(
            project=self.cfg.logger.wandb_project,
            entity=self.cfg.logger.wandb_entity,
            config=self.cfg.to_dict(),
            name=self.cfg.logger.run_name,
            id=self.cfg.logger.wandb_id,
        )

    evaluator = LLMSaeEvaluator(
        model=self.model,
        activations_store=self.activations_store,
        eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
        n_eval_batches=self.cfg.n_eval_batches,
        model_kwargs=self.cfg.model_kwargs,
    )

    trainer = SAETrainer(
        sae=self.sae,
        data_provider=self.activations_store,
        evaluator=evaluator,
        save_checkpoint_fn=self.save_checkpoint,
        cfg=self.cfg.to_sae_trainer_config(),
    )

    if self.cfg.resume_from_checkpoint is not None:
        logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
        trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
        self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
        self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)

    self._compile_if_needed()
    sae = self.run_trainer_with_interruption_handling(trainer)

    if self.cfg.output_path is not None:
        self.save_final_sae(
            sae=sae,
            output_path=self.cfg.output_path,
            log_feature_sparsity=trainer.log_feature_sparsity,
        )

    if self.cfg.logger.log_to_wandb:
        wandb.finish()

    return sae
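
A usage sketch for the runner, assuming `LanguageModelSAETrainingRunner` and `LanguageModelSAERunnerConfig` are importable from `sae_lens`; the model, hook, dimensions, and dataset path are placeholder values, and the SAE sub-config again uses the matching-pursuit training config documented below as an example.

from sae_lens import LanguageModelSAERunnerConfig, LanguageModelSAETrainingRunner
from sae_lens.saes.matching_pursuit_sae import MatchingPursuitTrainingSAEConfig

# Placeholder values; dataset_path should point at a (pre)tokenized text dataset and
# hook_name at the activation site the SAE is trained on.
cfg = LanguageModelSAERunnerConfig(
    sae=MatchingPursuitTrainingSAEConfig(d_in=768, d_sae=16384),
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    dataset_path="<your-tokenized-dataset>",
)

runner = LanguageModelSAETrainingRunner(cfg)
trained_sae = runner.run()  # trains the SAE, optionally logging to wandb, and returns it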

MatchingPursuitSAE

Bases: SAE[MatchingPursuitSAEConfig]

An inference-only sparse autoencoder using a "matching pursuit" activation function.

Source code in sae_lens/saes/matching_pursuit_sae.py
class MatchingPursuitSAE(SAE[MatchingPursuitSAEConfig]):
    """
    An inference-only sparse autoencoder using a "matching pursuit" activation function.
    """

    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
    @property
    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
        return self.W_dec.T

    # hacky way to get around the base class having W_enc.
    # TODO: harmonize with the base class in next major release
    @override
    def __setattr__(self, name: str, value: Any):
        if name == "W_enc":
            return
        super().__setattr__(name, value)

    @override
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Converts input x into feature activations.
        """
        sae_in = self.process_sae_in(x)
        return _encode_matching_pursuit(
            sae_in,
            self.W_dec,
            self.cfg.residual_threshold,
            max_iterations=self.cfg.max_iterations,
            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
        )

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
        )

    @override
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back to the input space.
        Now, if hook_z reshaping is turned on, we reverse the flattening.
        """
        sae_out_pre = feature_acts @ self.W_dec
        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
        if self.cfg.apply_b_dec_to_input:
            sae_out_pre = sae_out_pre + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

decode(feature_acts)

Decode the feature activations back to the input space. Now, if hook_z reshaping is turned on, we reverse the flattening.

Source code in sae_lens/saes/matching_pursuit_sae.py
@override
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """
    Decode the feature activations back to the input space.
    Now, if hook_z reshaping is turned on, we reverse the flattening.
    """
    sae_out_pre = feature_acts @ self.W_dec
    # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
    if self.cfg.apply_b_dec_to_input:
        sae_out_pre = sae_out_pre + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Converts input x into feature activations.

Source code in sae_lens/saes/matching_pursuit_sae.py
@override
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Converts input x into feature activations.
    """
    sae_in = self.process_sae_in(x)
    return _encode_matching_pursuit(
        sae_in,
        self.W_dec,
        self.cfg.residual_threshold,
        max_iterations=self.cfg.max_iterations,
        stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
    )

MatchingPursuitSAEConfig dataclass

Bases: SAEConfig

Configuration class for MatchingPursuitSAE inference.

Parameters:

    residual_threshold (float, default 0.01): Residual error at which to stop selecting latents.
    max_iterations (int | None, default None): Maximum iterations (default: d_in if set to None).
    stop_on_duplicate_support (bool, default True): Whether to stop selecting latents if the support set has not changed from the previous iteration.
    d_in (int, required): Input dimension (dimensionality of the activations being encoded). Inherited from SAEConfig.
    d_sae (int, required): SAE latent dimension (number of features in the SAE). Inherited from SAEConfig.
    dtype (str, default 'float32'): Data type for the SAE parameters. Inherited from SAEConfig.
    device (str, default 'cpu'): Device to place the SAE on. Inherited from SAEConfig.
    apply_b_dec_to_input (bool, default True): Whether to apply decoder bias to the input before encoding. Inherited from SAEConfig.
    normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"], default 'none'): Normalization strategy for input activations. Inherited from SAEConfig.
    reshape_activations (Literal["none", "hook_z"], default 'none'): How to reshape activations (useful for attention head outputs). Inherited from SAEConfig.
    metadata (SAEMetadata, default SAEMetadata()): Metadata about the SAE (model name, hook name, etc.). Inherited from SAEConfig.

Source code in sae_lens/saes/matching_pursuit_sae.py
@dataclass
class MatchingPursuitSAEConfig(SAEConfig):
    """
    Configuration class for MatchingPursuitSAE inference.

    Args:
        residual_threshold (float): residual error at which to stop selecting latents. Default 1e-2.
        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
            Defaults to None.
        stop_on_duplicate_support (bool): Whether to stop selecting latents if the support set has not changed from the previous iteration. Defaults to True.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Inherited from SAEConfig. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Inherited from SAEConfig.
            Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Inherited from SAEConfig.
            Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
            Inherited from SAEConfig.
    """

    residual_threshold: float = 1e-2
    max_iterations: int | None = None
    stop_on_duplicate_support: bool = True

    @override
    @classmethod
    def architecture(cls) -> str:
        return "matching_pursuit"

MatchingPursuitTrainingSAE

Bases: TrainingSAE[MatchingPursuitTrainingSAEConfig]

Source code in sae_lens/saes/matching_pursuit_sae.py
class MatchingPursuitTrainingSAE(TrainingSAE[MatchingPursuitTrainingSAEConfig]):
    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
    @property
    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
        return self.W_dec.T

    # hacky way to get around the base class having W_enc.
    # TODO: harmonize with the base class in next major release
    @override
    def __setattr__(self, name: str, value: Any):
        if name == "W_enc":
            return
        super().__setattr__(name, value)

    @override
    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        hidden_pre doesn't make sense for matching pursuit, since there is not a single pre-activation.
        We just return zeros for the hidden_pre.
        """

        sae_in = self.process_sae_in(x)
        acts = _encode_matching_pursuit(
            sae_in,
            self.W_dec,
            self.cfg.residual_threshold,
            max_iterations=self.cfg.max_iterations,
            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
        )
        return acts, torch.zeros_like(acts)

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
        )

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {}

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        return {}

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        output = super().training_forward_pass(step_input)
        l0 = output.feature_acts.bool().float().sum(-1).to_dense()
        residual_norm = (step_input.sae_in - output.sae_out).norm(dim=-1)
        output.metrics["max_l0"] = l0.max()
        output.metrics["min_l0"] = l0.min()
        output.metrics["residual_norm"] = residual_norm.mean()
        output.metrics["residual_threshold_converged_portion"] = (
            (residual_norm < self.cfg.residual_threshold).float().mean()
        )
        return output

    @override
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back to the input space.
        Now, if hook_z reshaping is turned on, we reverse the flattening.
        """
        sae_out_pre = feature_acts @ self.W_dec
        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
        if self.cfg.apply_b_dec_to_input:
            sae_out_pre = sae_out_pre + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

decode(feature_acts)

Decode the feature activations back to the input space. Now, if hook_z reshaping is turned on, we reverse the flattening.

Source code in sae_lens/saes/matching_pursuit_sae.py
@override
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """
    Decode the feature activations back to the input space.
    Now, if hook_z reshaping is turned on, we reverse the flattening.
    """
    sae_out_pre = feature_acts @ self.W_dec
    # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
    if self.cfg.apply_b_dec_to_input:
        sae_out_pre = sae_out_pre + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode_with_hidden_pre(x)

hidden_pre doesn't make sense for matching pursuit, since there is not a single pre-activation. We just return zeros for the hidden_pre.

Source code in sae_lens/saes/matching_pursuit_sae.py
@override
def encode_with_hidden_pre(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    hidden_pre doesn't make sense for matching pursuit, since there is not a single pre-activation.
    We just return zeros for the hidden_pre.
    """

    sae_in = self.process_sae_in(x)
    acts = _encode_matching_pursuit(
        sae_in,
        self.W_dec,
        self.cfg.residual_threshold,
        max_iterations=self.cfg.max_iterations,
        stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
    )
    return acts, torch.zeros_like(acts)

MatchingPursuitTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a MatchingPursuitTrainingSAE.

Parameters:

    residual_threshold (float, default 0.01): Residual error at which to stop selecting latents.
    max_iterations (int | None, default None): Maximum iterations (default: d_in if set to None).
    stop_on_duplicate_support (bool, default True): Whether to stop selecting latents if the support set has not changed from the previous iteration.
    decoder_init_norm (float | None, default 0.1): Norm to initialize decoder weights to. 0.1 corresponds to the "heuristic" initialization from Anthropic's April update. Use None to disable. Inherited from TrainingSAEConfig.
    d_in (int, required): Input dimension (dimensionality of the activations being encoded). Inherited from SAEConfig.
    d_sae (int, required): SAE latent dimension (number of features in the SAE). Inherited from SAEConfig.
    dtype (str, default 'float32'): Data type for the SAE parameters. Inherited from SAEConfig.
    device (str, default 'cpu'): Device to place the SAE on. Inherited from SAEConfig.
    apply_b_dec_to_input (bool, default True): Whether to apply decoder bias to the input before encoding. Inherited from SAEConfig.
    normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"], default 'none'): Normalization strategy for input activations. Inherited from SAEConfig.
    reshape_activations (Literal["none", "hook_z"], default 'none'): How to reshape activations (useful for attention head outputs). Inherited from SAEConfig.
    metadata (SAEMetadata, default SAEMetadata()): Metadata about the SAE training (model name, hook name, etc.). Inherited from SAEConfig.

Source code in sae_lens/saes/matching_pursuit_sae.py
@dataclass
class MatchingPursuitTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a MatchingPursuitTrainingSAE.

    Args:
        residual_threshold (float): residual error at which to stop selecting latents. Default 1e-2.
        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
            Defaults to None.
        stop_on_duplicate_support (bool): Whether to stop selecting latents if the support set has not changed from the previous iteration. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Inherited from SAEConfig. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Inherited from SAEConfig.
            Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Inherited from SAEConfig.
            Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
            Inherited from SAEConfig.
    """

    residual_threshold: float = 1e-2
    max_iterations: int | None = None
    stop_on_duplicate_support: bool = True

    @override
    @classmethod
    def architecture(cls) -> str:
        return "matching_pursuit"

    @override
    def __post_init__(self):
        super().__post_init__()
        if self.decoder_init_norm != 1.0:
            self.decoder_init_norm = 1.0
            warnings.warn(
                "decoder_init_norm must be set to 1.0 for MatchingPursuitTrainingSAE, setting to 1.0"
            )
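
A sketch of the `__post_init__` guard shown above: constructing the training config with the inherited `decoder_init_norm` default of 0.1 triggers the warning and forces the value back to 1.0. It assumes the remaining inherited fields beyond `d_in`/`d_sae` keep their defaults.

import warnings

from sae_lens.saes.matching_pursuit_sae import MatchingPursuitTrainingSAEConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # decoder_init_norm defaults to 0.1 via TrainingSAEConfig, so __post_init__ overrides it.
    cfg = MatchingPursuitTrainingSAEConfig(d_in=64, d_sae=512)

assert cfg.decoder_init_norm == 1.0
assert any("decoder_init_norm" in str(w.message) for w in caught)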

MatryoshkaBatchTopKTrainingSAE

Bases: BatchTopKTrainingSAE

Global Batch TopK Training SAE

This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.

BatchTopK SAEs are saved as JumpReLU SAEs after training.

Source code in sae_lens/saes/matryoshka_batchtopk_sae.py
class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
    """
    Global Batch TopK Training SAE

    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.

    BatchTopK SAEs are saved as JumpReLU SAEs after training.
    """

    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]

    def __init__(
        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
    ):
        super().__init__(cfg, use_error_term)
        _validate_matryoshka_config(cfg)

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        base_output = super().training_forward_pass(step_input)
        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
        # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
        for width in self.cfg.matryoshka_widths[:-1]:
            inner_reconstruction = self._decode_matryoshka_level(
                base_output.feature_acts, width, inv_W_dec_norm
            )
            inner_mse_loss = (
                self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
                .sum(dim=-1)
                .mean()
            )
            base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
            base_output.loss = base_output.loss + inner_mse_loss
        return base_output

    def _decode_matryoshka_level(
        self,
        feature_acts: torch.Tensor,
        width: int,
        inv_W_dec_norm: torch.Tensor,
    ) -> torch.Tensor:
        """
        Decodes feature activations back into input space for a matryoshka level
        """
        inner_feature_acts = feature_acts[:, :width]
        # Handle sparse tensors using efficient sparse matrix multiplication
        if self.cfg.rescale_acts_by_decoder_norm:
            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
        if inner_feature_acts.is_sparse:
            sae_out_pre = (
                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
            )
        else:
            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

MatryoshkaBatchTopKTrainingSAEConfig dataclass

Bases: BatchTopKTrainingSAEConfig

Configuration class for training a MatryoshkaBatchTopKTrainingSAE.

Matryoshka SAEs use a series of nested reconstruction losses of different widths during training to avoid feature absorption. This also has a nice side-effect of encouraging higher-frequency features to be learned in earlier levels. However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes longer to train due to requiring multiple forward passes per training step.

After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.

Parameters:

    matryoshka_widths (list[int], default []): The widths of the matryoshka levels.
    k (float, default 100): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
    topk_threshold_lr (float, default 0.01): Learning rate for updating the global topk threshold. The threshold is updated using an exponential moving average of the minimum positive activation value.
    aux_loss_coefficient (float, default 1.0): Coefficient for the auxiliary loss that encourages dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
    rescale_acts_by_decoder_norm (bool, default True): Treat the decoder as if it was already normalized. Inherited from TopKTrainingSAEConfig.
    decoder_init_norm (float | None, default 0.1): Norm to initialize decoder weights to. Inherited from TrainingSAEConfig.
    d_in (int, required): Input dimension (dimensionality of the activations being encoded). Inherited from SAEConfig.
    d_sae (int, required): SAE latent dimension (number of features in the SAE). Inherited from SAEConfig.
    dtype (str, default 'float32'): Data type for the SAE parameters. Inherited from SAEConfig.
    device (str, default 'cpu'): Device to place the SAE on. Inherited from SAEConfig.

Source code in sae_lens/saes/matryoshka_batchtopk_sae.py
@dataclass
class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
    """
    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.

    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
    losses of different widths during training to avoid feature absorption. This also has a
    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
    longer to train due to requiring multiple forward passes per training step.

    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.

    Args:
        matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
            Defaults to 100.
        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
            The threshold is updated using an exponential moving average of the minimum
            positive activation value. Defaults to 0.01.
        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
            Defaults to 1.0.
        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
            Inherited from TopKTrainingSAEConfig. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            Inherited from TrainingSAEConfig. Defaults to 0.1.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
    """

    matryoshka_widths: list[int] = field(default_factory=list)

    @override
    @classmethod
    def architecture(cls) -> str:
        return "matryoshka_batchtopk"

PretokenizeRunner

Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
class PretokenizeRunner:
    """
    Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.
    """

    def __init__(self, cfg: PretokenizeRunnerConfig):
        self.cfg = cfg

    def run(self):
        """
        Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
        """
        dataset = load_dataset(  # type: ignore
            self.cfg.dataset_path,
            name=self.cfg.dataset_name,
            data_dir=self.cfg.data_dir,
            data_files=self.cfg.data_files,
            split=self.cfg.split,  # type: ignore
            streaming=self.cfg.streaming,  # type: ignore
        )
        if isinstance(dataset, DatasetDict):
            raise ValueError(
                "Dataset has multiple splits. Must provide a 'split' param."
            )
        tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
        tokenizer.model_max_length = sys.maxsize
        tokenized_dataset = pretokenize_dataset(
            cast(Dataset, dataset), tokenizer, self.cfg
        )

        if self.cfg.save_path is not None:
            tokenized_dataset.save_to_disk(self.cfg.save_path)
            metadata = metadata_from_config(self.cfg)
            metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
            with open(metadata_path, "w") as f:
                json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

        if self.cfg.hf_repo_id is not None:
            push_to_hugging_face_hub(tokenized_dataset, self.cfg)

        return tokenized_dataset

run()

Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
def run(self):
    """
    Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
    """
    dataset = load_dataset(  # type: ignore
        self.cfg.dataset_path,
        name=self.cfg.dataset_name,
        data_dir=self.cfg.data_dir,
        data_files=self.cfg.data_files,
        split=self.cfg.split,  # type: ignore
        streaming=self.cfg.streaming,  # type: ignore
    )
    if isinstance(dataset, DatasetDict):
        raise ValueError(
            "Dataset has multiple splits. Must provide a 'split' param."
        )
    tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
    tokenizer.model_max_length = sys.maxsize
    tokenized_dataset = pretokenize_dataset(
        cast(Dataset, dataset), tokenizer, self.cfg
    )

    if self.cfg.save_path is not None:
        tokenized_dataset.save_to_disk(self.cfg.save_path)
        metadata = metadata_from_config(self.cfg)
        metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

    if self.cfg.hf_repo_id is not None:
        push_to_hugging_face_hub(tokenized_dataset, self.cfg)

    return tokenized_dataset

PretokenizeRunnerConfig dataclass

Configuration class for pretokenizing a dataset.

Source code in sae_lens/config.py
@dataclass
class PretokenizeRunnerConfig:
    """
    Configuration class for pretokenizing a dataset.
    """

    tokenizer_name: str = "gpt2"
    dataset_path: str = ""
    dataset_name: str | None = None
    dataset_trust_remote_code: bool | None = None
    split: str | None = "train"
    data_files: list[str] | None = None
    data_dir: str | None = None
    num_proc: int = 4
    context_size: int = 128
    column_name: str = "text"
    shuffle: bool = True
    seed: int | None = None
    streaming: bool = False
    pretokenize_batch_size: int | None = 1000

    # special tokens
    begin_batch_token: int | Literal["bos", "eos", "sep"] | None = "bos"
    begin_sequence_token: int | Literal["bos", "eos", "sep"] | None = None
    sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos"

    # sequence processing
    disable_concat_sequences: bool = False

    # if saving locally, set save_path
    save_path: str | None = None

    # if saving to huggingface, set hf_repo_id
    hf_repo_id: str | None = None
    hf_num_shards: int = 64
    hf_revision: str = "main"
    hf_is_private_repo: bool = False
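
A usage sketch, assuming `PretokenizeRunner` and `PretokenizeRunnerConfig` are importable from the top-level `sae_lens` package; the dataset name and output path are example values only.

from sae_lens import PretokenizeRunner, PretokenizeRunnerConfig

# Example values; save_path and hf_repo_id are both optional outputs.
cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="roneneldan/TinyStories",   # any HF text dataset with a "text" column
    split="train",
    context_size=128,
    shuffle=True,
    save_path="./tinystories-gpt2-128",
)
tokenized = PretokenizeRunner(cfg).run()     # Dataset of fixed-length token sequences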

SAE

Bases: HookedRootModule, Generic[T_SAE_CONFIG], ABC

Abstract base class for all SAE architectures.
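
A brief loading sketch using the `from_pretrained` classmethod shown in the source below; the release and sae_id strings are illustrative values, and the random round-trip is only meant to show the encode/decode API.

import torch

from sae_lens import SAE

# Example release/sae_id strings; consult the pretrained SAEs listing for valid pairs.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)

# Round-trip a random batch with the SAE's input dimensionality.
x = torch.randn(4, sae.cfg.d_in)
recon = sae.decode(sae.encode(x))
print(sae.get_name())  # "sae_<model_name>_<hook_name>_<d_sae>"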

Source code in sae_lens/saes/sae.py
class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
    """Abstract base class for all SAE architectures."""

    cfg: T_SAE_CONFIG
    dtype: torch.dtype
    device: torch.device
    _use_error_term: bool

    # For type checking only - don't provide default values
    # These will be initialized by subclasses
    W_enc: nn.Parameter
    W_dec: nn.Parameter
    b_dec: nn.Parameter

    def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
        """Initialize the SAE."""
        super().__init__()

        self.cfg = cfg

        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
            warnings.warn(
                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                "\nFor optimal performance, load the model like so:\n"
                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
                category=UserWarning,
                stacklevel=1,
            )

        self.dtype = str_to_dtype(cfg.dtype)
        self.device = torch.device(cfg.device)
        self._use_error_term = False  # Set directly to avoid warning during init
        if use_error_term:
            self.use_error_term = True  # Use property setter to trigger warning

        # Set up activation function
        self.activation_fn = self.get_activation_fn()

        # Initialize weights
        self.initialize_weights()

        # Set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        if self.cfg.reshape_activations == "hook_z":
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            self.turn_off_forward_pass_hook_z_reshaping()

        # Set up activation normalization
        self._setup_activation_normalization()

        self.setup()  # Required for HookedRootModule

    @property
    def use_error_term(self) -> bool:
        return self._use_error_term

    @use_error_term.setter
    def use_error_term(self, value: bool) -> None:
        if value and not self._use_error_term:
            warnings.warn(
                "Setting use_error_term directly on SAE is deprecated. "
                "Use HookedSAETransformer.add_sae(sae, use_error_term=True) instead. "
                "This will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )
        self._use_error_term = value

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(self, scaling_factor: float):
        self.W_enc.data *= scaling_factor  # type: ignore
        self.W_dec.data /= scaling_factor  # type: ignore
        self.b_dec.data /= scaling_factor  # type: ignore
        self.cfg.normalize_activations = "none"

    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Get the activation function specified in config."""
        return nn.ReLU()

    def _setup_activation_normalization(self):
        """Set up activation normalization functions based on config."""
        if self.cfg.normalize_activations == "constant_norm_rescale":

            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                return x * self.x_norm_coeff

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff  # type: ignore
                del self.x_norm_coeff
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

        elif self.cfg.normalize_activations == "layer_norm":
            #  we need to scale the norm of the input and store the scaling factor
            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
                mu = x.mean(dim=-1, keepdim=True)
                x = x - mu
                std = x.std(dim=-1, keepdim=True)
                x = x / (std + eps)
                self.ln_mu = mu
                self.ln_std = std
                return x

            def run_time_activation_ln_out(
                x: torch.Tensor,
                eps: float = 1e-5,  # noqa: ARG001
            ) -> torch.Tensor:
                return x * self.ln_std + self.ln_mu  # type: ignore

            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

    def initialize_weights(self):
        """Initialize model weights."""
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

        w_dec_data = torch.empty(
            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
        )
        nn.init.kaiming_uniform_(w_dec_data)
        self.W_dec = nn.Parameter(w_dec_data)

        w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
        self.W_enc = nn.Parameter(w_enc_data)

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input tensor to feature space."""
        pass

    @abstractmethod
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """Decode feature activations back to input space."""
        pass

    def turn_on_forward_pass_hook_z_reshaping(self):
        if (
            self.cfg.metadata.hook_name is not None
            and not self.cfg.metadata.hook_name.endswith("_z")
        ):
            raise ValueError("This method should only be called for hook_z SAEs.")

        # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")

        def reshape_fn_in(x: torch.Tensor):
            # print(f"reshape_fn_in input shape: {x.shape}")
            self.d_head = x.shape[-1]
            # print(f"Setting d_head to: {self.d_head}")
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in
        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True
        # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")

    def turn_off_forward_pass_hook_z_reshaping(self):
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x  # noqa: ARG005
        self.d_head = None
        self.hook_z_reshaping_mode = False

    @overload
    def to(
        self: T_SAE,
        device: torch.device | str | None = ...,
        dtype: torch.dtype | None = ...,
        non_blocking: bool = ...,
    ) -> T_SAE: ...

    @overload
    def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

    @overload
    def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

    def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
        device_arg = None
        dtype_arg = None

        # Check args
        for arg in args:
            if isinstance(arg, (torch.device, str)):
                device_arg = arg
            elif isinstance(arg, torch.dtype):
                dtype_arg = arg
            elif isinstance(arg, torch.Tensor):
                device_arg = arg.device
                dtype_arg = arg.dtype

        # Check kwargs
        device_arg = kwargs.get("device", device_arg)
        dtype_arg = kwargs.get("dtype", dtype_arg)

        # Update device in config if provided
        if device_arg is not None:
            # Convert device to torch.device if it's a string
            device = (
                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
            )

            # Update the cfg.device
            self.cfg.device = str(device)

            # Update the device property
            self.device = device

        # Update dtype in config if provided
        if dtype_arg is not None:
            # Update the cfg.dtype (use canonical short form like "float32")
            self.cfg.dtype = dtype_to_str(dtype_arg)

            # Update the dtype property
            self.dtype = dtype_arg

        return super().to(*args, **kwargs)

    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
        sae_in = sae_in.to(self.dtype)
        sae_in = self.reshape_fn_in(sae_in)

        sae_in = self.hook_sae_input(sae_in)
        sae_in = self.run_time_activation_norm_fn_in(sae_in)

        # apply_b_dec_to_input acts as a 0/1 mask on the decoder bias
        bias_term = self.b_dec * self.cfg.apply_b_dec_to_input

        return sae_in - bias_term

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SAE."""
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        if self.use_error_term:
            with torch.no_grad():
                # Recompute without hooks for true error term
                with _disable_hooks(self):
                    feature_acts_clean = self.encode(x)
                    x_reconstruct_clean = self.decode(feature_acts_clean)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            sae_out = sae_out + sae_error

        return self.hook_sae_output(sae_out)

    # overwrite this in subclasses to modify the state_dict in-place before saving
    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        pass

    # overwrite this in subclasses to modify the state_dict in-place after loading
    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        pass

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Fold decoder norms into encoder."""
        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Only update b_enc if it exists (standard/jumprelu architectures)
        if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    def get_name(self):
        """Generate a name for this SAE."""
        return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

    def save_model(self, path: str | Path) -> tuple[Path, Path]:
        """Save model weights and config to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.cfg.to_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        return model_weights_path, cfg_path

    # Class methods for loading models
    @classmethod
    @deprecated("Use load_from_disk instead")
    def load_from_pretrained(
        cls: type[T_SAE],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
    ) -> T_SAE:
        return cls.load_from_disk(path, device=device, dtype=dtype)

    @classmethod
    def load_from_disk(
        cls: type[T_SAE],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
    ) -> T_SAE:
        """
        Load a SAE from disk.

        Args:
            path: The path to the SAE weights and config.
            device: The device to load the SAE on, defaults to "cpu".
            dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
            converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
        """
        overrides = {"dtype": dtype} if dtype is not None else None
        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
        cfg_dict = handle_config_defaulting(cfg_dict)
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
        # hack to avoid using double memory when loading the SAE.
        # first put the SAE on the meta device, then load the weights.
        device = sae_cfg.device
        sae_cfg.device = "meta"
        sae = sae_cls(sae_cfg)
        sae.cfg.device = device
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict, assign=True)
        # the loaders should already handle the dtype / device conversion
        # but this is a fallback to guarantee the SAE is on the correct device and dtype
        return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)

    @classmethod
    def from_pretrained(
        cls: type[T_SAE],
        release: str,
        sae_id: str,
        device: str = "cpu",
        dtype: str = "float32",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
    ) -> T_SAE:
        """
        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on, defaults to "cpu".
            dtype: The dtype to load the SAE on, defaults to "float32".
            force_download: Whether to force download the SAE weights and config, defaults to False.
            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
        """
        return cls.from_pretrained_with_cfg_and_sparsity(
            release,
            sae_id,
            device,
            force_download=force_download,
            dtype=dtype,
            converter=converter,
        )[0]

    @classmethod
    def from_pretrained_with_cfg_and_sparsity(
        cls: type[T_SAE],
        release: str,
        sae_id: str,
        device: str = "cpu",
        dtype: str = "float32",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
        """
        Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
        In SAELens <= 5.x.x, this was called SAE.from_pretrained().

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on, defaults to "cpu".
            dtype: The dtype to load the SAE on, defaults to "float32".
            force_download: Whether to force download the SAE weights and config, defaults to False.
            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # Validate release and sae_id
        if release not in sae_directory:
            if "/" not in release:
                raise ValueError(
                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
                )
            # Check if the user passed a repo_id that's in the pretrained SAEs list
            matching_releases = get_releases_for_repo_id(release)
            if matching_releases:
                warnings.warn(
                    f"You are loading an SAE using the HuggingFace repo_id '{release}' directly. "
                    f"This repo is registered in the official pretrained SAEs list with release name(s): {matching_releases}. "
                    f"For better compatibility and to access additional metadata, consider loading with: "
                    f"SAE.from_pretrained(release='{matching_releases[0]}', sae_id='<sae_id>'). "
                    f"See the full list of pretrained SAEs at: https://decoderesearch.github.io/SAELens/latest/pretrained_saes/",
                    UserWarning,
                    stacklevel=2,
                )
        elif sae_id not in sae_directory[release].saes_map:
            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
            else:
                str_valid_ids = str(valid_ids)

            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
            )

        conversion_loader = (
            converter
            or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
        )
        repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
        config_overrides = get_config_overrides(release, sae_id)
        config_overrides["device"] = device
        config_overrides["dtype"] = dtype

        # Load config and weights
        cfg_dict, state_dict, log_sparsities = conversion_loader(
            repo_id=repo_id,
            folder_name=folder_name,
            device=device,
            force_download=force_download,
            cfg_overrides=config_overrides,
        )
        cfg_dict = handle_config_defaulting(cfg_dict)

        # Create SAE with appropriate architecture
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
        # hack to avoid using double memory when loading the SAE.
        # first put the SAE on the meta device, then load the weights.
        device = sae_cfg.device
        sae_cfg.device = "meta"
        sae = sae_cls(sae_cfg)
        sae.cfg.device = device
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict, assign=True)

        # the loaders should already handle the dtype / device conversion
        # but this is a fallback to guarantee the SAE is on the correct device and dtype
        return (
            sae.to(dtype=str_to_dtype(dtype), device=device),
            cfg_dict,
            log_sparsities,
        )

    @classmethod
    def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
        """Create an SAE from a config dictionary."""
        sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            config_dict["architecture"]
        )
        return sae_cls(sae_config_cls.from_dict(config_dict))

    @classmethod
    def get_sae_class_for_architecture(
        cls: type[T_SAE], architecture: str
    ) -> type[T_SAE]:
        """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls,
        architecture: str,  # noqa: ARG003
    ) -> type[SAEConfig]:
        return SAEConfig

    ### Methods to support deprecated usage of SAE.from_pretrained() ###

    def __getitem__(self, index: int) -> Any:
        """
        Support indexing for backward compatibility with tuple unpacking.
        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
        """
        warnings.warn(
            "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
            "to get the config dict and sparsity as well.",
            DeprecationWarning,
            stacklevel=2,
        )

        if index == 0:
            return self
        if index == 1:
            return self.cfg.to_dict()
        if index == 2:
            return None
        raise IndexError(f"SAE tuple index {index} out of range")

    def __iter__(self):
        """
        Support unpacking for backward compatibility with tuple unpacking.
        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
        """
        warnings.warn(
            "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
            "to get the config dict and sparsity as well.",
            DeprecationWarning,
            stacklevel=2,
        )

        yield self
        yield self.cfg.to_dict()
        yield None

    def __len__(self) -> int:
        """
        Support len() for backward compatibility with tuple unpacking.
        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
        """
        warnings.warn(
            "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
            "to get the config dict and sparsity as well.",
            DeprecationWarning,
            stacklevel=2,
        )

        return 3

__getitem__(index)

Support indexing for backward compatibility with tuple unpacking. DEPRECATED: SAE.from_pretrained() no longer returns a tuple. Use SAE.from_pretrained_with_cfg_and_sparsity() instead.

Source code in sae_lens/saes/sae.py
def __getitem__(self, index: int) -> Any:
    """
    Support indexing for backward compatibility with tuple unpacking.
    DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
    Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
    """
    warnings.warn(
        "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
        "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
        "to get the config dict and sparsity as well.",
        DeprecationWarning,
        stacklevel=2,
    )

    if index == 0:
        return self
    if index == 1:
        return self.cfg.to_dict()
    if index == 2:
        return None
    raise IndexError(f"SAE tuple index {index} out of range")

__init__(cfg, use_error_term=False)

Initialize the SAE.

Source code in sae_lens/saes/sae.py
def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
    """Initialize the SAE."""
    super().__init__()

    self.cfg = cfg

    if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
        warnings.warn(
            "\nThis SAE has non-empty model_from_pretrained_kwargs. "
            "\nFor optimal performance, load the model like so:\n"
            "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
            category=UserWarning,
            stacklevel=1,
        )

    self.dtype = str_to_dtype(cfg.dtype)
    self.device = torch.device(cfg.device)
    self._use_error_term = False  # Set directly to avoid warning during init
    if use_error_term:
        self.use_error_term = True  # Use property setter to trigger warning

    # Set up activation function
    self.activation_fn = self.get_activation_fn()

    # Initialize weights
    self.initialize_weights()

    # Set up hooks
    self.hook_sae_input = HookPoint()
    self.hook_sae_acts_pre = HookPoint()
    self.hook_sae_acts_post = HookPoint()
    self.hook_sae_output = HookPoint()
    self.hook_sae_recons = HookPoint()
    self.hook_sae_error = HookPoint()

    # handle hook_z reshaping if needed.
    if self.cfg.reshape_activations == "hook_z":
        self.turn_on_forward_pass_hook_z_reshaping()
    else:
        self.turn_off_forward_pass_hook_z_reshaping()

    # Set up activation normalization
    self._setup_activation_normalization()

    self.setup()  # Required for HookedRootModule

__iter__()

Support unpacking for backward compatibility with tuple unpacking. DEPRECATED: SAE.from_pretrained() no longer returns a tuple. Use SAE.from_pretrained_with_cfg_and_sparsity() instead.

Source code in sae_lens/saes/sae.py
def __iter__(self):
    """
    Support unpacking for backward compatibility with tuple unpacking.
    DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
    Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
    """
    warnings.warn(
        "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
        "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
        "to get the config dict and sparsity as well.",
        DeprecationWarning,
        stacklevel=2,
    )

    yield self
    yield self.cfg.to_dict()
    yield None

__len__()

Support len() for backward compatibility with tuple unpacking. DEPRECATED: SAE.from_pretrained() no longer returns a tuple. Use SAE.from_pretrained_with_cfg_and_sparsity() instead.

Source code in sae_lens/saes/sae.py
def __len__(self) -> int:
    """
    Support len() for backward compatibility with tuple unpacking.
    DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
    Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
    """
    warnings.warn(
        "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
        "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
        "to get the config dict and sparsity as well.",
        DeprecationWarning,
        stacklevel=2,
    )

    return 3
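
Together, __getitem__, __iter__, and __len__ let code written against the old tuple-returning SAE.from_pretrained() keep working (it now emits a DeprecationWarning) while new code asks for the tuple explicitly. A minimal sketch; the release and sae_id below are the usual GPT-2 small example and are illustrative:

from sae_lens import SAE

# Old style: unpacking the SAE object still works via the methods above,
# but emits a DeprecationWarning.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre"
)

# New style: request the tuple explicitly, or take just the SAE.
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre"
)
sae = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre")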

decode(feature_acts) abstractmethod

Decode feature activations back to input space.

Source code in sae_lens/saes/sae.py
@abstractmethod
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """Decode feature activations back to input space."""
    pass

encode(x) abstractmethod

Encode input tensor to feature space.

Source code in sae_lens/saes/sae.py
@abstractmethod
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode input tensor to feature space."""
    pass

fold_W_dec_norm()

Fold decoder norms into encoder.

Source code in sae_lens/saes/sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """Fold decoder norms into encoder."""
    W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
    self.W_dec.data = self.W_dec.data / W_dec_norms
    self.W_enc.data = self.W_enc.data * W_dec_norms.T

    # Only update b_enc if it exists (standard/jumprelu architectures)
    if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
        self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
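
Because W_enc (and b_enc, where present) are rescaled by exactly the norms divided out of W_dec, and ReLU commutes with positive per-feature scaling, folding leaves the reconstruction unchanged. A quick sanity-check sketch, assuming a StandardSAE (documented further down; the import path follows its listed source location):

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=16, d_sae=64))
x = torch.randn(8, 16)

before = sae(x)
sae.fold_W_dec_norm()
after = sae(x)

# Feature activations are rescaled per feature, but the reconstruction is preserved.
assert torch.allclose(before, after, atol=1e-5)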

forward(x)

Forward pass through the SAE.

Source code in sae_lens/saes/sae.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the SAE."""
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)

    if self.use_error_term:
        with torch.no_grad():
            # Recompute without hooks for true error term
            with _disable_hooks(self):
                feature_acts_clean = self.encode(x)
                x_reconstruct_clean = self.decode(feature_acts_clean)
            sae_error = self.hook_sae_error(x - x_reconstruct_clean)
        sae_out = sae_out + sae_error

    return self.hook_sae_output(sae_out)
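
With use_error_term enabled, the hooked output is the reconstruction plus the error term, so the forward pass returns the input itself up to floating point error; this is what lets an SAE be spliced into a model without changing its computation until hooks intervene. A small sketch, again using StandardSAE as a stand-in:

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=16, d_sae=64), use_error_term=True)
x = torch.randn(4, 16)

# sae_out + (x - clean_reconstruction) == x
assert torch.allclose(sae(x), x, atol=1e-5)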

from_dict(config_dict) classmethod

Create an SAE from a config dictionary.

Source code in sae_lens/saes/sae.py
@classmethod
def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
    """Create an SAE from a config dictionary."""
    sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
    sae_config_cls = cls.get_sae_config_class_for_architecture(
        config_dict["architecture"]
    )
    return sae_cls(sae_config_cls.from_dict(config_dict))
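
from_dict dispatches on the "architecture" key to select both the SAE class and its config class. A hedged sketch with a hand-written config dict (keys mirror SAEConfig.to_dict() output):

from sae_lens import SAE

config_dict = {
    "architecture": "standard",
    "d_in": 768,
    "d_sae": 12288,
    "dtype": "float32",
    "device": "cpu",
}
sae = SAE.from_dict(config_dict)
print(type(sae).__name__)  # StandardSAE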

from_pretrained(release, sae_id, device='cpu', dtype='float32', force_download=False, converter=None) classmethod

Load a pretrained SAE from the Hugging Face model hub.

Parameters:

    release (str, required): The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
    sae_id (str, required): The id of the SAE to load. This will be mapped to a path in the huggingface repo.
    device (str, default 'cpu'): The device to load the SAE on.
    dtype (str, default 'float32'): The dtype to load the SAE on.
    force_download (bool, default False): Whether to force download the SAE weights and config.
    converter (PretrainedSaeHuggingfaceLoader | None, default None): The converter to use to load the SAE. If None, the converter will be inferred from the release.
Source code in sae_lens/saes/sae.py
@classmethod
def from_pretrained(
    cls: type[T_SAE],
    release: str,
    sae_id: str,
    device: str = "cpu",
    dtype: str = "float32",
    force_download: bool = False,
    converter: PretrainedSaeHuggingfaceLoader | None = None,
) -> T_SAE:
    """
    Load a pretrained SAE from the Hugging Face model hub.

    Args:
        release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
        sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
        device: The device to load the SAE on, defaults to "cpu".
        dtype: The dtype to load the SAE on, defaults to "float32".
        force_download: Whether to force download the SAE weights and config, defaults to False.
        converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
    """
    return cls.from_pretrained_with_cfg_and_sparsity(
        release,
        sae_id,
        device,
        force_download=force_download,
        dtype=dtype,
        converter=converter,
    )[0]
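
Typical usage, with the release name and sae_id drawn from the pretrained SAEs directory (the GPT-2 small values below are illustrative; see the pretrained SAEs table for the full list):

from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # release name from pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",  # id mapped to a path within the repo
    device="cpu",
)
print(sae.cfg.d_in, sae.cfg.d_sae)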

from_pretrained_with_cfg_and_sparsity(release, sae_id, device='cpu', dtype='float32', force_download=False, converter=None) classmethod

Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present. In SAELens <= 5.x.x, this was called SAE.from_pretrained().

Parameters:

    release (str, required): The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
    sae_id (str, required): The id of the SAE to load. This will be mapped to a path in the huggingface repo.
    device (str, default 'cpu'): The device to load the SAE on.
    dtype (str, default 'float32'): The dtype to load the SAE on.
    force_download (bool, default False): Whether to force download the SAE weights and config.
    converter (PretrainedSaeHuggingfaceLoader | None, default None): The converter to use to load the SAE. If None, the converter will be inferred from the release.
Source code in sae_lens/saes/sae.py
@classmethod
def from_pretrained_with_cfg_and_sparsity(
    cls: type[T_SAE],
    release: str,
    sae_id: str,
    device: str = "cpu",
    dtype: str = "float32",
    force_download: bool = False,
    converter: PretrainedSaeHuggingfaceLoader | None = None,
) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
    """
    Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
    In SAELens <= 5.x.x, this was called SAE.from_pretrained().

    Args:
        release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
        sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
        device: The device to load the SAE on, defaults to "cpu".
        dtype: The dtype to load the SAE on, defaults to "float32".
        force_download: Whether to force download the SAE weights and config, defaults to False.
        converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
    """

    # get sae directory
    sae_directory = get_pretrained_saes_directory()

    # Validate release and sae_id
    if release not in sae_directory:
        if "/" not in release:
            raise ValueError(
                f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
            )
        # Check if the user passed a repo_id that's in the pretrained SAEs list
        matching_releases = get_releases_for_repo_id(release)
        if matching_releases:
            warnings.warn(
                f"You are loading an SAE using the HuggingFace repo_id '{release}' directly. "
                f"This repo is registered in the official pretrained SAEs list with release name(s): {matching_releases}. "
                f"For better compatibility and to access additional metadata, consider loading with: "
                f"SAE.from_pretrained(release='{matching_releases[0]}', sae_id='<sae_id>'). "
                f"See the full list of pretrained SAEs at: https://decoderesearch.github.io/SAELens/latest/pretrained_saes/",
                UserWarning,
                stacklevel=2,
            )
    elif sae_id not in sae_directory[release].saes_map:
        valid_ids = list(sae_directory[release].saes_map.keys())
        # Shorten the lengthy string of valid IDs
        if len(valid_ids) > 5:
            str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
        else:
            str_valid_ids = str(valid_ids)

        raise ValueError(
            f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
        )

    conversion_loader = (
        converter
        or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
    )
    repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
    config_overrides = get_config_overrides(release, sae_id)
    config_overrides["device"] = device
    config_overrides["dtype"] = dtype

    # Load config and weights
    cfg_dict, state_dict, log_sparsities = conversion_loader(
        repo_id=repo_id,
        folder_name=folder_name,
        device=device,
        force_download=force_download,
        cfg_overrides=config_overrides,
    )
    cfg_dict = handle_config_defaulting(cfg_dict)

    # Create SAE with appropriate architecture
    sae_config_cls = cls.get_sae_config_class_for_architecture(
        cfg_dict["architecture"]
    )
    sae_cfg = sae_config_cls.from_dict(cfg_dict)
    sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
    # hack to avoid using double memory when loading the SAE.
    # first put the SAE on the meta device, then load the weights.
    device = sae_cfg.device
    sae_cfg.device = "meta"
    sae = sae_cls(sae_cfg)
    sae.cfg.device = device
    sae.process_state_dict_for_loading(state_dict)
    sae.load_state_dict(state_dict, assign=True)

    # the loaders should already handle the dtype / device conversion
    # but this is a fallback to guarantee the SAE is on the correct device and dtype
    return (
        sae.to(dtype=str_to_dtype(dtype), device=device),
        cfg_dict,
        log_sparsities,
    )
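
The tuple form returns the SAE, the raw config dict produced by the loader, and the log-sparsity tensor when the release ships one (None otherwise). A short sketch with the same illustrative release:

from sae_lens import SAE

sae, cfg_dict, log_sparsities = SAE.from_pretrained_with_cfg_and_sparsity(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
)
if log_sparsities is not None:
    print("mean log10 feature sparsity:", log_sparsities.mean().item())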

get_activation_fn()

Get the activation function specified in config.

Source code in sae_lens/saes/sae.py
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
    """Get the activation function specified in config."""
    return nn.ReLU()

get_name()

Generate a name for this SAE.

Source code in sae_lens/saes/sae.py
def get_name(self):
    """Generate a name for this SAE."""
    return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

get_sae_class_for_architecture(architecture) classmethod

Get the SAE class for a given architecture.

Source code in sae_lens/saes/sae.py
@classmethod
def get_sae_class_for_architecture(
    cls: type[T_SAE], architecture: str
) -> type[T_SAE]:
    """Get the SAE class for a given architecture."""
    sae_cls, _ = get_sae_class(architecture)
    if not issubclass(sae_cls, cls):
        raise ValueError(
            f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
        )
    return sae_cls
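
For example, asking the base SAE class for the "standard" architecture resolves to StandardSAE; asking a subclass for an architecture it does not implement raises the ValueError above. A one-line sketch:

from sae_lens import SAE

print(SAE.get_sae_class_for_architecture("standard").__name__)  # StandardSAE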

initialize_weights()

Initialize model weights.

Source code in sae_lens/saes/sae.py
def initialize_weights(self):
    """Initialize model weights."""
    self.b_dec = nn.Parameter(
        torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
    )

    w_dec_data = torch.empty(
        self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
    )
    nn.init.kaiming_uniform_(w_dec_data)
    self.W_dec = nn.Parameter(w_dec_data)

    w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
    self.W_enc = nn.Parameter(w_enc_data)

load_from_disk(path, device='cpu', dtype=None, converter=sae_lens_disk_loader) classmethod

Load a SAE from disk.

Parameters:

    path (str | Path, required): The path to the SAE weights and config.
    device (str, default 'cpu'): The device to load the SAE on.
    dtype (str | None, default None): The dtype to load the SAE on. If None, the dtype will be inferred from the SAE config.
    converter (PretrainedSaeDiskLoader, default sae_lens_disk_loader): The converter to use to load the SAE.
Source code in sae_lens/saes/sae.py
@classmethod
def load_from_disk(
    cls: type[T_SAE],
    path: str | Path,
    device: str = "cpu",
    dtype: str | None = None,
    converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
) -> T_SAE:
    """
    Load a SAE from disk.

    Args:
        path: The path to the SAE weights and config.
        device: The device to load the SAE on, defaults to "cpu".
        dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
        converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
    """
    overrides = {"dtype": dtype} if dtype is not None else None
    cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
    cfg_dict = handle_config_defaulting(cfg_dict)
    sae_config_cls = cls.get_sae_config_class_for_architecture(
        cfg_dict["architecture"]
    )
    sae_cfg = sae_config_cls.from_dict(cfg_dict)
    sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
    # hack to avoid using double memory when loading the SAE.
    # first put the SAE on the meta device, then load the weights.
    device = sae_cfg.device
    sae_cfg.device = "meta"
    sae = sae_cls(sae_cfg)
    sae.cfg.device = device
    sae.process_state_dict_for_loading(state_dict)
    sae.load_state_dict(state_dict, assign=True)
    # the loaders should already handle the dtype / device conversion
    # but this is a fallback to guarantee the SAE is on the correct device and dtype
    return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)

save_model(path)

Save model weights and config to disk.

Source code in sae_lens/saes/sae.py
def save_model(self, path: str | Path) -> tuple[Path, Path]:
    """Save model weights and config to disk."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Generate the weights
    state_dict = self.state_dict()  # Use internal SAE state dict
    self.process_state_dict_for_saving(state_dict)
    model_weights_path = path / SAE_WEIGHTS_FILENAME
    save_file(state_dict, model_weights_path)

    # Save the config
    config = self.cfg.to_dict()
    cfg_path = path / SAE_CFG_FILENAME
    with open(cfg_path, "w") as f:
        json.dump(config, f)

    return model_weights_path, cfg_path
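
save_model and load_from_disk round-trip an SAE through SAE_WEIGHTS_FILENAME and SAE_CFG_FILENAME in the target directory. A minimal sketch (the directory name is arbitrary; StandardSAE stands in for any concrete SAE):

from sae_lens import SAE
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=16, d_sae=64))
weights_path, cfg_path = sae.save_model("checkpoints/my_sae")

reloaded = SAE.load_from_disk("checkpoints/my_sae", device="cpu")
assert isinstance(reloaded, StandardSAE)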

SAEConfig dataclass

Bases: ABC

Base configuration for SAE models.

Source code in sae_lens/saes/sae.py
@dataclass
class SAEConfig(ABC):
    """Base configuration for SAE models."""

    d_in: int
    d_sae: int
    dtype: str = "float32"
    device: str = "cpu"
    apply_b_dec_to_input: bool = True
    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
        "none"  # none, expected_average_only_in (Anthropic April Update)
    )
    reshape_activations: Literal["none", "hook_z"] = "none"
    metadata: SAEMetadata = field(default_factory=SAEMetadata)

    @classmethod
    @abstractmethod
    def architecture(cls) -> str: ...

    def to_dict(self) -> dict[str, Any]:
        res = {field.name: getattr(self, field.name) for field in fields(self)}
        res["metadata"] = self.metadata.to_dict()
        res["architecture"] = self.architecture()
        return res

    @classmethod
    def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
        cfg_class = get_sae_class(config_dict["architecture"])[1]
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
        res = cfg_class(**filtered_config_dict)
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])
        if not isinstance(res, cls):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(res)}"
            )
        return res

    def __post_init__(self):
        if self.normalize_activations not in [
            "none",
            "expected_average_only_in",
            "constant_norm_rescale",
            "layer_norm",
        ]:
            raise ValueError(
                f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
            )
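
to_dict always injects the architecture name and serialized metadata, which is what allows from_dict to reconstruct the right concrete config class later. A short round-trip sketch with StandardSAEConfig:

from sae_lens.saes.sae import SAEConfig
from sae_lens.saes.standard_sae import StandardSAEConfig

cfg = StandardSAEConfig(d_in=768, d_sae=12288)
cfg_dict = cfg.to_dict()
print(cfg_dict["architecture"])  # "standard"

restored = SAEConfig.from_dict(cfg_dict)
assert isinstance(restored, StandardSAEConfig)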

SAETrainer

Bases: Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]

Trainer for Sparse Autoencoder (SAE) models.

Source code in sae_lens/training/sae_trainer.py
class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
    """
    Trainer for Sparse Autoencoder (SAE) models.
    """

    data_provider: DataProvider
    activation_scaler: ActivationScaler
    evaluator: Evaluator[T_TRAINING_SAE] | None
    coefficient_schedulers: dict[str, CoefficientScheduler]

    def __init__(
        self,
        cfg: SAETrainerConfig,
        sae: T_TRAINING_SAE,
        data_provider: DataProvider,
        evaluator: Evaluator[T_TRAINING_SAE] | None = None,
        save_checkpoint_fn: SaveCheckpointFn | None = None,
    ) -> None:
        self.sae = sae
        self.data_provider = data_provider
        self.evaluator = evaluator
        self.activation_scaler = ActivationScaler()
        self.save_checkpoint_fn = save_checkpoint_fn
        self.cfg = cfg

        self.n_training_steps: int = 0
        self.n_training_samples: int = 0
        self.started_fine_tuning: bool = False

        _update_sae_lens_training_version(self.sae)

        self.checkpoint_thresholds = []
        if self.cfg.n_checkpoints > 0:
            self.checkpoint_thresholds = list(
                range(
                    0,
                    cfg.total_training_samples,
                    math.ceil(
                        cfg.total_training_samples / (self.cfg.n_checkpoints + 1)
                    ),
                )
            )[1:]

        self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
        self.n_forward_passes_since_fired = torch.zeros(
            sae.cfg.d_sae, device=cfg.device
        )
        self.n_frac_active_samples = 0

        self.optimizer = Adam(
            sae.parameters(),
            lr=cfg.lr,
            betas=(
                cfg.adam_beta1,
                cfg.adam_beta2,
            ),
        )
        assert cfg.lr_end is not None  # this is set in config post-init
        self.lr_scheduler = get_lr_scheduler(
            cfg.lr_scheduler_name,
            lr=cfg.lr,
            optimizer=self.optimizer,
            warm_up_steps=cfg.lr_warm_up_steps,
            decay_steps=cfg.lr_decay_steps,
            training_steps=self.cfg.total_training_steps,
            lr_end=cfg.lr_end,
            num_cycles=cfg.n_restart_cycles,
        )
        self.coefficient_schedulers = {}
        for name, coeff_cfg in self.sae.get_coefficients().items():
            if not isinstance(coeff_cfg, TrainCoefficientConfig):
                coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
            self.coefficient_schedulers[name] = CoefficientScheduler(
                warm_up_steps=coeff_cfg.warm_up_steps,
                final_value=coeff_cfg.value,
            )

        # Setup autocast if using
        self.grad_scaler = torch.amp.GradScaler(
            device=self.cfg.device, enabled=self.cfg.autocast
        )

        if self.cfg.autocast:
            self.autocast_if_enabled = torch.autocast(
                device_type=self.cfg.device,
                dtype=torch.bfloat16,
                enabled=self.cfg.autocast,
            )
        else:
            self.autocast_if_enabled = contextlib.nullcontext()

    @property
    def feature_sparsity(self) -> torch.Tensor:
        return self.act_freq_scores / self.n_frac_active_samples

    @property
    def log_feature_sparsity(self) -> torch.Tensor:
        return _log_feature_sparsity(self.feature_sparsity)

    @property
    def dead_neurons(self) -> torch.Tensor:
        return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

    def fit(self) -> T_TRAINING_SAE:
        self.sae.to(self.cfg.device)
        pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")

        if self.sae.cfg.normalize_activations == "expected_average_only_in":
            self.activation_scaler.estimate_scaling_factor(
                d_in=self.sae.cfg.d_in,
                data_provider=self.data_provider,
                n_batches_for_norm_estimate=int(1e3),
            )

        # Train loop
        while self.n_training_samples < self.cfg.total_training_samples:
            # Do a training step.
            batch = next(self.data_provider).to(self.sae.device)
            self.n_training_samples += batch.shape[0]
            scaled_batch = self.activation_scaler(batch)

            step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)

            if self.cfg.logger.log_to_wandb:
                self._log_train_step(step_output)
                self._run_and_log_evals()

            self._checkpoint_if_needed()
            self.n_training_steps += 1
            self._update_pbar(step_output, pbar)

        # fold the estimated norm scaling factor into the sae weights
        if self.activation_scaler.scaling_factor is not None:
            self.sae.fold_activation_norm_scaling_factor(
                self.activation_scaler.scaling_factor
            )
            self.activation_scaler.scaling_factor = None

        if self.cfg.save_final_checkpoint:
            self.save_checkpoint(checkpoint_name=f"final_{self.n_training_samples}")

        pbar.close()
        return self.sae

    def save_checkpoint(
        self,
        checkpoint_name: str,
        wandb_aliases: list[str] | None = None,
    ) -> None:
        checkpoint_path = None
        if self.cfg.checkpoint_path is not None or self.cfg.logger.log_to_wandb:
            with path_or_tmp_dir(self.cfg.checkpoint_path) as base_checkpoint_path:
                checkpoint_path = base_checkpoint_path / checkpoint_name
                checkpoint_path.mkdir(exist_ok=True, parents=True)

                weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))

                sparsity_path = checkpoint_path / SPARSITY_FILENAME
                save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)

                self.save_trainer_state(checkpoint_path)

                if self.cfg.logger.log_to_wandb:
                    self.cfg.logger.log(
                        self,
                        weights_path,
                        cfg_path,
                        sparsity_path=sparsity_path,
                        wandb_aliases=wandb_aliases,
                    )

        if self.save_checkpoint_fn is not None:
            self.save_checkpoint_fn(checkpoint_path=checkpoint_path)

    def save_trainer_state(self, checkpoint_path: Path) -> None:
        checkpoint_path.mkdir(exist_ok=True, parents=True)
        scheduler_state_dicts = {
            name: scheduler.state_dict()
            for name, scheduler in self.coefficient_schedulers.items()
        }
        torch.save(
            {
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "n_training_samples": self.n_training_samples,
                "n_training_steps": self.n_training_steps,
                "act_freq_scores": self.act_freq_scores,
                "n_forward_passes_since_fired": self.n_forward_passes_since_fired,
                "n_frac_active_samples": self.n_frac_active_samples,
                "started_fine_tuning": self.started_fine_tuning,
                "coefficient_schedulers": scheduler_state_dicts,
            },
            str(checkpoint_path / TRAINER_STATE_FILENAME),
        )
        activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
        self.activation_scaler.save(str(activation_scaler_path))

    def load_trainer_state(self, checkpoint_path: Path | str) -> None:
        checkpoint_path = Path(checkpoint_path)
        self.activation_scaler.load(checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME)
        state_dict = torch.load(checkpoint_path / TRAINER_STATE_FILENAME)
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        self.n_training_samples = state_dict["n_training_samples"]
        self.n_training_steps = state_dict["n_training_steps"]
        self.act_freq_scores = state_dict["act_freq_scores"]
        self.n_forward_passes_since_fired = state_dict["n_forward_passes_since_fired"]
        self.n_frac_active_samples = state_dict["n_frac_active_samples"]
        self.started_fine_tuning = state_dict["started_fine_tuning"]
        for name, scheduler_state_dict in state_dict["coefficient_schedulers"].items():
            self.coefficient_schedulers[name].load_state_dict(scheduler_state_dict)

    def _train_step(
        self,
        sae: T_TRAINING_SAE,
        sae_in: torch.Tensor,
    ) -> TrainStepOutput:
        sae.train()

        # log and then reset the feature sparsity every feature_sampling_window steps
        if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
            if self.cfg.logger.log_to_wandb:
                sparsity_log_dict = self._build_sparsity_log_dict()
                wandb.log(sparsity_log_dict, step=self.n_training_steps)
            self._reset_running_sparsity_stats()

        # for documentation on autocasting see:
        # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
        with self.autocast_if_enabled:
            train_step_output = self.sae.training_forward_pass(
                step_input=TrainStepInput(
                    sae_in=sae_in,
                    dead_neuron_mask=self.dead_neurons,
                    coefficients=self.get_coefficients(),
                    n_training_steps=self.n_training_steps,
                ),
            )

            with torch.no_grad():
                # calling .bool() should be equivalent to .abs() > 0, and work with coo tensors
                firing_feats = train_step_output.feature_acts.bool().float()
                did_fire = firing_feats.sum(-2).bool()
                if did_fire.is_sparse:
                    did_fire = did_fire.to_dense()
                self.n_forward_passes_since_fired += 1
                self.n_forward_passes_since_fired[did_fire] = 0
                self.act_freq_scores += firing_feats.sum(0)
                self.n_frac_active_samples += self.cfg.train_batch_size_samples

        # Grad scaler will rescale gradients if autocast is enabled
        self.grad_scaler.scale(
            train_step_output.loss
        ).backward()  # loss.backward() if not autocasting
        self.grad_scaler.unscale_(self.optimizer)  # needed to clip correctly
        # TODO: Work out if grad norm clipping should be in config / how to test it.
        torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
        self.grad_scaler.step(
            self.optimizer
        )  # just ctx.optimizer.step() if not autocasting
        self.grad_scaler.update()

        self.optimizer.zero_grad()
        self.lr_scheduler.step()
        for scheduler in self.coefficient_schedulers.values():
            scheduler.step()

        return train_step_output

    @torch.no_grad()
    def _log_train_step(self, step_output: TrainStepOutput):
        if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
            wandb.log(
                self._build_train_step_log_dict(
                    output=step_output,
                    n_training_samples=self.n_training_samples,
                ),
                step=self.n_training_steps,
            )

    @torch.no_grad()
    def get_coefficients(self) -> dict[str, float]:
        return {
            name: scheduler.value
            for name, scheduler in self.coefficient_schedulers.items()
        }

    @torch.no_grad()
    def _build_train_step_log_dict(
        self,
        output: TrainStepOutput,
        n_training_samples: int,
    ) -> dict[str, Any]:
        sae_in = output.sae_in
        sae_out = output.sae_out
        feature_acts = output.feature_acts
        loss = output.loss.item()

        # metrics for current acts
        l0 = feature_acts.bool().float().sum(-1).to_dense().mean()
        current_learning_rate = self.optimizer.param_groups[0]["lr"]

        per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
        total_variance = (sae_in - sae_in.mean(0)).pow(2).sum(-1)
        explained_variance_legacy = 1 - per_token_l2_loss / total_variance
        explained_variance = 1 - per_token_l2_loss.mean() / total_variance.mean()

        log_dict = {
            # losses
            "losses/overall_loss": loss,
            # variance explained
            "metrics/explained_variance_legacy": explained_variance_legacy.mean().item(),
            "metrics/explained_variance_legacy_std": explained_variance_legacy.std().item(),
            "metrics/explained_variance": explained_variance.item(),
            "metrics/l0": l0.item(),
            # sparsity
            "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
            "sparsity/dead_features": self.dead_neurons.sum().item(),
            "details/current_learning_rate": current_learning_rate,
            "details/n_training_samples": n_training_samples,
            **{
                f"details/{name}_coefficient": scheduler.value
                for name, scheduler in self.coefficient_schedulers.items()
            },
        }
        for loss_name, loss_value in output.losses.items():
            log_dict[f"losses/{loss_name}"] = _unwrap_item(loss_value)

        for metric_name, metric_value in output.metrics.items():
            log_dict[f"metrics/{metric_name}"] = _unwrap_item(metric_value)

        return log_dict

    @torch.no_grad()
    def _run_and_log_evals(self):
        # record loss frequently, but not all the time.
        if (self.n_training_steps + 1) % (
            self.cfg.logger.wandb_log_frequency
            * self.cfg.logger.eval_every_n_wandb_logs
        ) == 0:
            self.sae.eval()
            eval_metrics = (
                self.evaluator(self.sae, self.data_provider, self.activation_scaler)
                if self.evaluator is not None
                else {}
            )
            for key, value in self.sae.log_histograms().items():
                eval_metrics[key] = wandb.Histogram(value)  # type: ignore

            wandb.log(
                eval_metrics,
                step=self.n_training_steps,
            )
            self.sae.train()

    @torch.no_grad()
    def _build_sparsity_log_dict(self) -> dict[str, Any]:
        log_feature_sparsity = _log_feature_sparsity(self.feature_sparsity)
        wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy())  # type: ignore
        return {
            "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
            "plots/feature_density_line_chart": wandb_histogram,
            "sparsity/below_1e-5": (self.feature_sparsity < 1e-5).sum().item(),
            "sparsity/below_1e-6": (self.feature_sparsity < 1e-6).sum().item(),
        }

    @torch.no_grad()
    def _reset_running_sparsity_stats(self) -> None:
        self.act_freq_scores = torch.zeros(
            self.sae.cfg.d_sae,  # type: ignore
            device=self.cfg.device,
        )
        self.n_frac_active_samples = 0

    @torch.no_grad()
    def _checkpoint_if_needed(self):
        if (
            self.checkpoint_thresholds
            and self.n_training_samples > self.checkpoint_thresholds[0]
        ):
            self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
            self.checkpoint_thresholds.pop(0)

    @torch.no_grad()
    def _update_pbar(
        self,
        step_output: TrainStepOutput,
        pbar: tqdm,  # type: ignore
        update_interval: int = 100,
    ):
        if self.n_training_steps % update_interval == 0:
            loss_strs = " | ".join(
                f"{loss_name}: {_unwrap_item(loss_value):.5f}"
                for loss_name, loss_value in step_output.losses.items()
            )
            pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
            pbar.update(update_interval * self.cfg.train_batch_size_samples)
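
The firing statistics behind feature_sparsity and dead_neurons are plain tensor bookkeeping. The pure-PyTorch sketch below mirrors the updates made in _train_step; all names are local to the sketch rather than part of the API:

import torch

d_sae, batch_size, dead_feature_window = 64, 32, 1000

act_freq_scores = torch.zeros(d_sae)
n_forward_passes_since_fired = torch.zeros(d_sae)
n_frac_active_samples = 0

feature_acts = torch.relu(torch.randn(batch_size, d_sae))  # stand-in for SAE features

firing = feature_acts.bool().float()      # 1.0 wherever a feature fired
did_fire = firing.sum(-2).bool()          # did each feature fire anywhere in the batch?
n_forward_passes_since_fired += 1
n_forward_passes_since_fired[did_fire] = 0
act_freq_scores += firing.sum(0)
n_frac_active_samples += batch_size

feature_sparsity = act_freq_scores / n_frac_active_samples
dead_neurons = n_forward_passes_since_fired > dead_feature_window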

SkipTranscoder

Bases: Transcoder

A transcoder with a learnable skip connection.

Implements: f(x) = W_dec @ relu(W_enc @ x + b_enc) + W_skip @ x + b_dec where W_skip is initialized to zeros.

Source code in sae_lens/saes/transcoder.py
class SkipTranscoder(Transcoder):
    """
    A transcoder with a learnable skip connection.

    Implements: f(x) = W_dec @ relu(W_enc @ x + b_enc) + W_skip @ x + b_dec
    where W_skip is initialized to zeros.
    """

    cfg: SkipTranscoderConfig  # type: ignore[assignment]
    W_skip: nn.Parameter

    def __init__(self, cfg: SkipTranscoderConfig):
        super().__init__(cfg)
        self.cfg = cfg

        # Initialize skip connection matrix
        # Shape: [d_out, d_in] to map from input to output dimension
        self.W_skip = nn.Parameter(torch.zeros(self.cfg.d_out, self.cfg.d_in))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for skip transcoder.

        Args:
            x: Input activations from the input hook point [batch, d_in]

        Returns:
            sae_out: Reconstructed activations for the output hook point
            [batch, d_out]
        """
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        # Add skip connection: W_skip @ x
        # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
        skip_out = x @ self.W_skip.T.to(x.device)
        return sae_out + skip_out

    def forward_with_activations(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass returning both output and feature activations.

        Args:
            x: Input activations from the input hook point [batch, d_in]

        Returns:
            sae_out: Reconstructed activations for the output hook point
            [batch, d_out]
            feature_acts: Hidden activations [batch, d_sae]
        """
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        # Add skip connection: W_skip @ x
        # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
        skip_out = x @ self.W_skip.T.to(x.device)
        sae_out = sae_out + skip_out

        return sae_out, feature_acts

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoder":
        cfg = SkipTranscoderConfig.from_dict(config_dict)
        return cls(cfg)
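
To make the skip formula concrete, here is a pure-tensor sketch of f(x) = W_dec @ relu(W_enc @ x + b_enc) + W_skip @ x + b_dec with batched inputs stored row-wise (shapes follow the comments in forward; these are local tensors, not the module's parameters):

import torch

d_in, d_out, d_sae, batch = 16, 24, 64, 4

W_enc = torch.randn(d_in, d_sae)
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_out)
b_dec = torch.zeros(d_out)
W_skip = torch.zeros(d_out, d_in)  # zero-initialized, as in SkipTranscoder.__init__

x = torch.randn(batch, d_in)
feature_acts = torch.relu(x @ W_enc + b_enc)        # encode
out = feature_acts @ W_dec + b_dec + x @ W_skip.T   # decode plus skip connection
print(out.shape)  # torch.Size([4, 24])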

forward(x)

Forward pass for skip transcoder.

Parameters:

    x (Tensor, required): Input activations from the input hook point [batch, d_in]

Returns:

    sae_out (Tensor): Reconstructed activations for the output hook point [batch, d_out]

Source code in sae_lens/saes/transcoder.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass for skip transcoder.

    Args:
        x: Input activations from the input hook point [batch, d_in]

    Returns:
        sae_out: Reconstructed activations for the output hook point
        [batch, d_out]
    """
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)

    # Add skip connection: W_skip @ x
    # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
    skip_out = x @ self.W_skip.T.to(x.device)
    return sae_out + skip_out

forward_with_activations(x)

Forward pass returning both output and feature activations.

Parameters:

    x (Tensor, required): Input activations from the input hook point [batch, d_in]

Returns:

    sae_out (Tensor): Reconstructed activations for the output hook point [batch, d_out]
    feature_acts (Tensor): Hidden activations [batch, d_sae]

Source code in sae_lens/saes/transcoder.py
def forward_with_activations(
    self,
    x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass returning both output and feature activations.

    Args:
        x: Input activations from the input hook point [batch, d_in]

    Returns:
        sae_out: Reconstructed activations for the output hook point
        [batch, d_out]
        feature_acts: Hidden activations [batch, d_sae]
    """
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)

    # Add skip connection: W_skip @ x
    # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
    skip_out = x @ self.W_skip.T.to(x.device)
    sae_out = sae_out + skip_out

    return sae_out, feature_acts

SkipTranscoderConfig dataclass

Bases: TranscoderConfig

Source code in sae_lens/saes/transcoder.py
@dataclass
class SkipTranscoderConfig(TranscoderConfig):
    @classmethod
    def architecture(cls) -> str:
        """Return the architecture name for this config."""
        return "skip_transcoder"

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoderConfig":
        """Create a SkipTranscoderConfig from a dictionary."""
        # Filter to only include valid dataclass fields
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

        # Create the config instance
        res = cls(**filtered_config_dict)

        # Handle metadata if present
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])

        return res

architecture() classmethod

Return the architecture name for this config.

Source code in sae_lens/saes/transcoder.py
@classmethod
def architecture(cls) -> str:
    """Return the architecture name for this config."""
    return "skip_transcoder"

from_dict(config_dict) classmethod

Create a SkipTranscoderConfig from a dictionary.

Source code in sae_lens/saes/transcoder.py
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoderConfig":
    """Create a SkipTranscoderConfig from a dictionary."""
    # Filter to only include valid dataclass fields
    filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

    # Create the config instance
    res = cls(**filtered_config_dict)

    # Handle metadata if present
    if "metadata" in config_dict:
        res.metadata = SAEMetadata(**config_dict["metadata"])

    return res

StandardSAE

Bases: SAE[StandardSAEConfig]

StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE) using a simple linear encoder and decoder.

It implements the required abstract methods from BaseSAE:

  • initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
  • encode: computes the feature activations from an input.
  • decode: reconstructs the input from the feature activations.

The BaseSAE.forward() method automatically calls encode and decode, including any error-term processing if configured.

Source code in sae_lens/saes/standard_sae.py
class StandardSAE(SAE[StandardSAEConfig]):
    """
    StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a simple linear encoder and decoder.

    It implements the required abstract methods from BaseSAE:

      - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
      - encode: computes the feature activations from an input.
      - decode: reconstructs the input from the feature activations.

    The BaseSAE.forward() method automatically calls encode and decode,
    including any error-term processing if configured.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        # Initialize encoder weights and bias.
        super().initialize_weights()
        _init_weights_standard(self)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input tensor into the feature space.
        """
        # Preprocess the SAE input (casting type, applying hooks, normalization)
        sae_in = self.process_sae_in(x)
        # Compute the pre-activation values
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        # Apply the activation function (e.g., ReLU, depending on config)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back to the input space.
        Now, if hook_z reshaping is turned on, we reverse the flattening.
        """
        # 1) linear transform
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        # 2) hook reconstruction
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        # 3) optional out-normalization (e.g. constant_norm_rescale)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        # 4) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
        return self.reshape_fn_out(sae_out_pre, self.d_head)
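
A minimal end-to-end inference sketch, with random tensors standing in for model activations:

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=768, d_sae=8 * 768))

acts = torch.randn(4, 128, 768)      # [batch, seq, d_in]
feature_acts = sae.encode(acts)      # [batch, seq, d_sae]
recon = sae.decode(feature_acts)     # [batch, seq, d_in]
print(feature_acts.shape, recon.shape)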

decode(feature_acts)

Decode the feature activations back to the input space. Now, if hook_z reshaping is turned on, we reverse the flattening.

Source code in sae_lens/saes/standard_sae.py
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """
    Decode the feature activations back to the input space.
    Now, if hook_z reshaping is turned on, we reverse the flattening.
    """
    # 1) linear transform
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    # 2) hook reconstruction
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    # 3) optional out-normalization (e.g. constant_norm_rescale)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    # 4) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Encode the input tensor into the feature space.

Source code in sae_lens/saes/standard_sae.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode the input tensor into the feature space.
    """
    # Preprocess the SAE input (casting type, applying hooks, normalization)
    sae_in = self.process_sae_in(x)
    # Compute the pre-activation values
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    # Apply the activation function (e.g., ReLU, depending on config)
    return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

StandardSAEConfig dataclass

Bases: SAEConfig

Configuration class for a StandardSAE.

Source code in sae_lens/saes/standard_sae.py
@dataclass
class StandardSAEConfig(SAEConfig):
    """
    Configuration class for a StandardSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "standard"

StandardTrainingSAE

Bases: TrainingSAE[StandardTrainingSAEConfig]

StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture. It implements:

  • initialize_weights: basic weight initialization for encoder/decoder.
  • encode: inference encoding (invokes encode_with_hidden_pre).
  • decode: a simple linear decoder.
  • encode_with_hidden_pre: computes activations and pre-activations.
  • calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.

Source code in sae_lens/saes/standard_sae.py
class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
    """
    StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
    It implements:

      - initialize_weights: basic weight initialization for encoder/decoder.
      - encode: inference encoding (invokes encode_with_hidden_pre).
      - decode: a simple linear decoder.
      - encode_with_hidden_pre: computes activations and pre-activations.
      - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
    """

    b_enc: nn.Parameter

    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_standard(self)

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }

    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Process the input (including dtype conversion, hook call, and any activation normalization)
        sae_in = self.process_sae_in(x)
        # Compute the pre-activation (and allow for a hook if desired)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
        # Apply the activation function (and any post-activation hook)
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
        return feature_acts, hidden_pre

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # The "standard" auxiliary loss is a sparsity penalty on the feature activations
        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)

        # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()

        return {"l1_loss": l1_loss}

    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
        """Log histograms of the weights and biases."""
        b_e_dist = self.b_enc.detach().float().cpu().numpy()
        return {
            **super().log_histograms(),
            "weights/b_e": b_e_dist,
        }
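
The sparsity penalty computed in calculate_aux_loss above is just an Lp norm of decoder-norm-weighted feature activations, scaled by the (possibly warmed-up) l1 coefficient. A standalone sketch of that arithmetic, with hypothetical shapes and coefficient values:

import torch

feature_acts = torch.relu(torch.randn(8, 4096))  # hypothetical (batch, d_sae) activations
W_dec = torch.randn(4096, 512) * 0.02            # hypothetical (d_sae, d_in) decoder
l1_coefficient, lp_norm = 5.0, 1.0               # hypothetical training coefficients

weighted = feature_acts * W_dec.norm(dim=1)      # weight each feature by its decoder row norm
sparsity = weighted.norm(p=lp_norm, dim=-1)      # Lp norm over the feature dimension
l1_loss = (l1_coefficient * sparsity).mean()     # scalar penalty added to the MSE loss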

log_histograms()

Log histograms of the weights and biases.

Source code in sae_lens/saes/standard_sae.py
def log_histograms(self) -> dict[str, NDArray[np.generic]]:
    """Log histograms of the weights and biases."""
    b_e_dist = self.b_enc.detach().float().cpu().numpy()
    return {
        **super().log_histograms(),
        "weights/b_e": b_e_dist,
    }

StandardTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a StandardTrainingSAE.

Source code in sae_lens/saes/standard_sae.py
@dataclass
class StandardTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a StandardTrainingSAE.
    """

    l1_coefficient: float = 1.0
    lp_norm: float = 1.0
    l1_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "standard"

TemporalSAE

Bases: SAE[TemporalSAEConfig]

TemporalSAE: Sparse Autoencoder with temporal attention.

This SAE decomposes each activation x_t into:

  • x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
  • x_novel: Novel information at position t (encoded sparsely)

The forward pass:

  1. Uses attention layers to predict x_t from context
  2. Encodes the residual (novel part) with a sparse SAE
  3. Combines both for reconstruction
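
To make the decomposition concrete, the snippet below mimics the final step of forward(): the full reconstruction decodes the sum of predicted and novel codes, while decode() maps only the novel codes back. This is a standalone sketch with hypothetical shapes; the attention layers and preprocessing are omitted:

import torch

B, L, d_sae, d_in = 2, 16, 4096, 512            # hypothetical sizes
z_pred = torch.relu(torch.randn(B, L, d_sae))   # codes predicted from the context
z_novel = torch.relu(torch.randn(B, L, d_sae))  # sparse codes for novel information
W_dec = torch.randn(d_sae, d_in) * 0.02
b_dec = torch.zeros(d_in)

x_recons = (z_novel + z_pred) @ W_dec + b_dec   # full reconstruction, as in forward()
x_novel_only = z_novel @ W_dec + b_dec          # what decode() returns (x_pred is missing)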

Source code in sae_lens/saes/temporal_sae.py
class TemporalSAE(SAE[TemporalSAEConfig]):
    """TemporalSAE: Sparse Autoencoder with temporal attention.

    This SAE decomposes each activation x_t into:

    - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
    - x_novel: Novel information at position t (encoded sparsely)

    The forward pass:
    1. Uses attention layers to predict x_t from context
    2. Encodes the residual (novel part) with a sparse SAE
    3. Combines both for reconstruction
    """

    # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
    attn_layers: nn.ModuleList  # Attention layers
    eps: float
    lam: float

    def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
        # Call parent init first
        super().__init__(cfg, use_error_term)

        # Initialize attention layers after parent init and move to correct device
        self.attn_layers = nn.ModuleList(
            [
                ManualAttention(
                    dimin=cfg.d_sae,
                    n_heads=cfg.n_heads,
                    bottleneck_factor=cfg.bottleneck_factor,
                    bias_k=True,
                    bias_q=True,
                    bias_v=True,
                    bias_o=True,
                ).to(device=self.device, dtype=self.dtype)
                for _ in range(cfg.n_attn_layers)
            ]
        )

        self.eps = 1e-6
        self.lam = 1 / (4 * self.cfg.d_in)

    @override
    def _setup_activation_normalization(self):
        """Set up activation normalization functions for TemporalSAE.

        Overrides the base implementation to handle constant_scalar_rescale
        using the temporal-specific activation_normalization_factor.
        """
        if self.cfg.normalize_activations == "constant_scalar_rescale":
            # Handle constant scalar rescaling for temporal SAEs
            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                return x * self.cfg.activation_normalization_factor

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                return x / self.cfg.activation_normalization_factor

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
        else:
            # Delegate to parent for all other normalization types
            super()._setup_activation_normalization()

    @override
    def initialize_weights(self) -> None:
        """Initialize TemporalSAE weights."""
        # Initialize D (decoder) and b (bias)
        self.W_dec = nn.Parameter(
            torch.randn(
                (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
            )
        )
        self.b_dec = nn.Parameter(
            torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
        )

        # Initialize E (encoder) if not tied
        if not self.cfg.tied_weights:
            self.W_enc = nn.Parameter(
                torch.randn(
                    (self.cfg.d_in, self.cfg.d_sae),
                    dtype=self.dtype,
                    device=self.device,
                )
            )

    def encode_with_predictions(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode input to novel codes only.

        Returns only the sparse novel codes (not predicted codes).
        This is the main feature representation for TemporalSAE.
        """
        # Process input through SAELens preprocessing
        x = self.process_sae_in(x)

        B, L, _ = x.shape

        if self.cfg.tied_weights:  # noqa: SIM108
            W_enc = self.W_dec.T
        else:
            W_enc = self.W_enc

        # Compute predicted codes using attention
        x_residual = x
        z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)

        for attn_layer in self.attn_layers:
            # Encode input to latent space
            z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))

            # Shift context (causal masking)
            z_ctx = torch.cat(
                (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
            )

            # Apply attention to get predicted codes
            z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
            z_pred_ = F.relu(z_pred_)

            # Project predicted codes back to input space
            Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps

            # Compute projection scale
            proj_scale = (Dz_pred_ * x_residual).sum(
                dim=-1, keepdim=True
            ) / Dz_norm_.pow(2)

            # Accumulate predicted codes
            z_pred = z_pred + (z_pred_ * proj_scale)

            # Remove prediction from residual
            x_residual = x_residual - proj_scale * Dz_pred_

        # Encode residual (novel part) with sparse SAE
        z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
        if self.cfg.sae_diff_type == "topk":
            kval = self.cfg.kval_topk
            if kval is not None:
                _, topk_indices = torch.topk(z_novel, kval, dim=-1)
                mask = torch.zeros_like(z_novel)
                mask.scatter_(-1, topk_indices, 1)
                z_novel = z_novel * mask

        # Return only novel codes (these are the interpretable features)
        return z_novel, z_pred

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode_with_predictions(x)[0]

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """Decode novel codes to reconstruction.

        Note: This only decodes the novel codes. For full reconstruction,
        use forward() which includes predicted codes.
        """
        # Decode novel codes
        sae_out = torch.matmul(feature_acts, self.W_dec)
        sae_out = sae_out + self.b_dec

        # Apply hook
        sae_out = self.hook_sae_recons(sae_out)

        # Apply output activation normalization (reverses input normalization)
        sae_out = self.run_time_activation_norm_fn_out(sae_out)

        # Add bias (already removed in process_sae_in)
        logger.warning(
            "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
        )
        return sae_out

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass through TemporalSAE.

        Returns complete reconstruction (predicted + novel).
        """
        # Encode
        z_novel, z_pred = self.encode_with_predictions(x)

        # Decode the sum of predicted and novel codes.
        x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec

        # Apply output activation normalization (reverses input normalization)
        x_recons = self.run_time_activation_norm_fn_out(x_recons)

        return self.hook_sae_output(x_recons)

    @override
    def fold_W_dec_norm(self) -> None:
        raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")

    @override
    @torch.no_grad()
    def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
        raise NotImplementedError(
            "Folding activation norm scaling factor is not supported for TemporalSAE"
        )

decode(feature_acts)

Decode novel codes to reconstruction.

Note: This only decodes the novel codes. For full reconstruction, use forward() which includes predicted codes.

Source code in sae_lens/saes/temporal_sae.py
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """Decode novel codes to reconstruction.

    Note: This only decodes the novel codes. For full reconstruction,
    use forward() which includes predicted codes.
    """
    # Decode novel codes
    sae_out = torch.matmul(feature_acts, self.W_dec)
    sae_out = sae_out + self.b_dec

    # Apply hook
    sae_out = self.hook_sae_recons(sae_out)

    # Apply output activation normalization (reverses input normalization)
    sae_out = self.run_time_activation_norm_fn_out(sae_out)

    # Add bias (already removed in process_sae_in)
    logger.warning(
        "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
    )
    return sae_out

encode_with_predictions(x)

Encode input to novel codes only.

Returns only the sparse novel codes (not predicted codes). This is the main feature representation for TemporalSAE.

Source code in sae_lens/saes/temporal_sae.py
def encode_with_predictions(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode input to novel codes only.

    Returns only the sparse novel codes (not predicted codes).
    This is the main feature representation for TemporalSAE.
    """
    # Process input through SAELens preprocessing
    x = self.process_sae_in(x)

    B, L, _ = x.shape

    if self.cfg.tied_weights:  # noqa: SIM108
        W_enc = self.W_dec.T
    else:
        W_enc = self.W_enc

    # Compute predicted codes using attention
    x_residual = x
    z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)

    for attn_layer in self.attn_layers:
        # Encode input to latent space
        z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))

        # Shift context (causal masking)
        z_ctx = torch.cat(
            (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
        )

        # Apply attention to get predicted codes
        z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
        z_pred_ = F.relu(z_pred_)

        # Project predicted codes back to input space
        Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
        Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps

        # Compute projection scale
        proj_scale = (Dz_pred_ * x_residual).sum(
            dim=-1, keepdim=True
        ) / Dz_norm_.pow(2)

        # Accumulate predicted codes
        z_pred = z_pred + (z_pred_ * proj_scale)

        # Remove prediction from residual
        x_residual = x_residual - proj_scale * Dz_pred_

    # Encode residual (novel part) with sparse SAE
    z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
    if self.cfg.sae_diff_type == "topk":
        kval = self.cfg.kval_topk
        if kval is not None:
            _, topk_indices = torch.topk(z_novel, kval, dim=-1)
            mask = torch.zeros_like(z_novel)
            mask.scatter_(-1, topk_indices, 1)
            z_novel = z_novel * mask

    # Return only novel codes (these are the interpretable features)
    return z_novel, z_pred
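
The "shift context" step above simply prepends a zero frame and drops the last position, so the attention at position t only sees codes from positions earlier than t. A tiny standalone illustration:

import torch

z_input = torch.arange(1.0, 5.0).reshape(1, 4, 1)  # (B=1, L=4, d=1) codes: [1, 2, 3, 4]
z_ctx = torch.cat(
    (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :]), dim=1
)
print(z_ctx.flatten())  # tensor([0., 1., 2., 3.]): each position only sees earlier codes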

forward(x)

Full forward pass through TemporalSAE.

Returns complete reconstruction (predicted + novel).

Source code in sae_lens/saes/temporal_sae.py
@override
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Full forward pass through TemporalSAE.

    Returns complete reconstruction (predicted + novel).
    """
    # Encode
    z_novel, z_pred = self.encode_with_predictions(x)

    # Decode the sum of predicted and novel codes.
    x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec

    # Apply output activation normalization (reverses input normalization)
    x_recons = self.run_time_activation_norm_fn_out(x_recons)

    return self.hook_sae_output(x_recons)

initialize_weights()

Initialize TemporalSAE weights.

Source code in sae_lens/saes/temporal_sae.py
@override
def initialize_weights(self) -> None:
    """Initialize TemporalSAE weights."""
    # Initialize D (decoder) and b (bias)
    self.W_dec = nn.Parameter(
        torch.randn(
            (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
        )
    )
    self.b_dec = nn.Parameter(
        torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
    )

    # Initialize E (encoder) if not tied
    if not self.cfg.tied_weights:
        self.W_enc = nn.Parameter(
            torch.randn(
                (self.cfg.d_in, self.cfg.d_sae),
                dtype=self.dtype,
                device=self.device,
            )
        )

TemporalSAEConfig dataclass

Bases: SAEConfig

Configuration for TemporalSAE inference.

Parameters:

  • d_in (int): Input dimension (dimensionality of the activations being encoded). Required.
  • d_sae (int): SAE latent dimension (number of features). Required.
  • n_heads (int): Number of attention heads in temporal attention. Default: 8.
  • n_attn_layers (int): Number of attention layers. Default: 1.
  • bottleneck_factor (int): Bottleneck factor for attention dimension. Default: 64.
  • sae_diff_type (Literal["relu", "topk"]): Type of SAE for novel codes ("relu" or "topk"). Default: "topk".
  • kval_topk (int | None): K value for top-k sparsity (if sae_diff_type="topk"). Default: None.
  • tied_weights (bool): Whether to tie encoder and decoder weights. Default: True.
  • activation_normalization_factor (float): Scalar factor for rescaling activations (used with normalize_activations="constant_scalar_rescale"). Default: 1.0.

Source code in sae_lens/saes/temporal_sae.py
@dataclass
class TemporalSAEConfig(SAEConfig):
    """Configuration for TemporalSAE inference.

    Args:
        d_in: Input dimension (dimensionality of the activations being encoded)
        d_sae: SAE latent dimension (number of features)
        n_heads: Number of attention heads in temporal attention
        n_attn_layers: Number of attention layers
        bottleneck_factor: Bottleneck factor for attention dimension
        sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
        kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
        tied_weights: Whether to tie encoder and decoder weights
        activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
    """

    n_heads: int = 8
    n_attn_layers: int = 1
    bottleneck_factor: int = 64
    sae_diff_type: Literal["relu", "topk"] = "topk"
    kval_topk: int | None = None
    tied_weights: bool = True
    activation_normalization_factor: float = 1.0

    def __post_init__(self):
        # Call parent's __post_init__ first, but allow constant_scalar_rescale
        if self.normalize_activations not in [
            "none",
            "expected_average_only_in",
            "constant_norm_rescale",
            "constant_scalar_rescale",  # Temporal SAEs support this
            "layer_norm",
        ]:
            raise ValueError(
                f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
            )

    @override
    @classmethod
    def architecture(cls) -> str:
        return "temporal"

TopKSAE

Bases: SAE[TopKSAEConfig]

An inference-only sparse autoencoder using a "topk" activation function. It uses linear encoder and decoder layers, applying the TopK activation to the hidden pre-activation in its encode step.

Source code in sae_lens/saes/topk_sae.py
class TopKSAE(SAE[TopKSAEConfig]):
    """
    An inference-only sparse autoencoder using a "topk" activation function.
    It uses linear encoder and decoder layers, applying the TopK activation
    to the hidden pre-activation in its encode step.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
        """
        Args:
            cfg: SAEConfig defining model size and behavior.
            use_error_term: Whether to apply the error-term approach in the forward pass.
        """
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        # Initialize encoder weights and bias.
        super().initialize_weights()
        _init_weights_topk(self)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Converts input x into feature activations.
        Uses topk activation under the hood.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        if self.cfg.rescale_acts_by_decoder_norm:
            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(
        self,
        feature_acts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Reconstructs the input from topk feature activations.
        Applies optional finetuning scaling, hooking to recons, out normalization,
        and optional head reshaping.
        """
        # Handle sparse tensors using efficient sparse matrix multiplication
        if self.cfg.rescale_acts_by_decoder_norm:
            feature_acts = feature_acts / self.W_dec.norm(dim=-1)
        if feature_acts.is_sparse:
            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
        else:
            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return TopK(self.cfg.k, use_sparse_activations=False)

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        if not self.cfg.rescale_acts_by_decoder_norm:
            raise NotImplementedError(
                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
            )
        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)

__init__(cfg, use_error_term=False)

Parameters:

  • cfg (TopKSAEConfig): SAEConfig defining model size and behavior. Required.
  • use_error_term (bool): Whether to apply the error-term approach in the forward pass. Default: False.

Source code in sae_lens/saes/topk_sae.py
def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
    """
    Args:
        cfg: SAEConfig defining model size and behavior.
        use_error_term: Whether to apply the error-term approach in the forward pass.
    """
    super().__init__(cfg, use_error_term)

decode(feature_acts)

Reconstructs the input from topk feature activations. Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping.

Source code in sae_lens/saes/topk_sae.py
def decode(
    self,
    feature_acts: torch.Tensor,
) -> torch.Tensor:
    """
    Reconstructs the input from topk feature activations.
    Applies optional finetuning scaling, hooking to recons, out normalization,
    and optional head reshaping.
    """
    # Handle sparse tensors using efficient sparse matrix multiplication
    if self.cfg.rescale_acts_by_decoder_norm:
        feature_acts = feature_acts / self.W_dec.norm(dim=-1)
    if feature_acts.is_sparse:
        sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
    else:
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Converts input x into feature activations. Uses topk activation under the hood.

Source code in sae_lens/saes/topk_sae.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Converts input x into feature activations.
    Uses topk activation under the hood.
    """
    sae_in = self.process_sae_in(x)
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    if self.cfg.rescale_acts_by_decoder_norm:
        hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
    # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
    return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
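
The TopK activation referenced in the comment above keeps the k largest pre-activations per position and zeroes the rest. A minimal standalone equivalent in plain PyTorch (the library's TopK class may differ in details such as optional sparse outputs):

import torch

def topk_activation(hidden_pre: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest values along the last dimension and zero out the rest."""
    values, indices = torch.topk(hidden_pre, k, dim=-1)
    return torch.zeros_like(hidden_pre).scatter_(-1, indices, values)

hidden_pre = torch.randn(8, 4096)                  # hypothetical (batch, d_sae) pre-activations
feature_acts = topk_activation(hidden_pre, k=100)  # at most 100 non-zero features per row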

TopKSAEConfig dataclass

Bases: SAEConfig

Configuration class for TopKSAE inference.

Parameters:

  • k (int): Number of top features to keep active during inference. Only the top k features with the highest pre-activations will be non-zero. Default: 100.
  • rescale_acts_by_decoder_norm (bool): Whether to treat the decoder as if it was already normalized. This affects the topk selection by rescaling pre-activations by decoder norms. Requires that the SAE was trained this way. Default: False.
  • d_in (int): Input dimension (dimensionality of the activations being encoded). Inherited from SAEConfig. Required.
  • d_sae (int): SAE latent dimension (number of features in the SAE). Inherited from SAEConfig. Required.
  • dtype (str): Data type for the SAE parameters. Inherited from SAEConfig. Default: "float32".
  • device (str): Device to place the SAE on. Inherited from SAEConfig. Default: "cpu".
  • apply_b_dec_to_input (bool): Whether to apply decoder bias to the input before encoding. Inherited from SAEConfig. Default: True.
  • normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]): Normalization strategy for input activations. Inherited from SAEConfig. Default: "none".
  • reshape_activations (Literal["none", "hook_z"]): How to reshape activations (useful for attention head outputs). Inherited from SAEConfig. Default: "none".
  • metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.). Inherited from SAEConfig. Default: SAEMetadata().

Source code in sae_lens/saes/topk_sae.py
@dataclass
class TopKSAEConfig(SAEConfig):
    """
    Configuration class for TopKSAE inference.

    Args:
        k (int): Number of top features to keep active during inference. Only the top k
            features with the highest pre-activations will be non-zero. Defaults to 100.
        rescale_acts_by_decoder_norm (bool): Whether to treat the decoder as if it was
            already normalized. This affects the topk selection by rescaling pre-activations
            by decoder norms. Requires that the SAE was trained this way. Defaults to False.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Inherited from SAEConfig. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Inherited from SAEConfig.
            Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Inherited from SAEConfig.
            Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
            Inherited from SAEConfig.
    """

    k: int = 100
    rescale_acts_by_decoder_norm: bool = False

    @override
    @classmethod
    def architecture(cls) -> str:
        return "topk"

TopKTrainingSAE

Bases: TrainingSAE[TopKTrainingSAEConfig]

TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.

Source code in sae_lens/saes/topk_sae.py
class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
    """
    TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)
        self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
        self.setup()

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_topk(self)

    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Similar to the base training method: calculate pre-activations, then apply TopK.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        if self.cfg.rescale_acts_by_decoder_norm:
            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)

        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
        return feature_acts, hidden_pre

    @override
    def decode(
        self,
        feature_acts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Decodes feature activations back into input space,
        applying optional finetuning scale, hooking, out normalization, etc.
        """
        # Handle sparse tensors using efficient sparse matrix multiplication
        if self.cfg.rescale_acts_by_decoder_norm:
            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
            feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))
        if feature_acts.is_sparse:
            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
        else:
            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SAE."""
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        if self.use_error_term:
            with torch.no_grad():
                # Recompute without hooks for true error term
                with _disable_hooks(self):
                    feature_acts_clean = self.encode(x)
                    x_reconstruct_clean = self.decode(feature_acts_clean)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            sae_out = sae_out + sae_error

        return self.hook_sae_output(sae_out)

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # Calculate the auxiliary loss for dead neurons
        topk_loss = self.calculate_topk_aux_loss(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            hidden_pre=hidden_pre,
            dead_neuron_mask=step_input.dead_neuron_mask,
        )
        return {"auxiliary_reconstruction_loss": topk_loss}

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        if not self.cfg.rescale_acts_by_decoder_norm:
            raise NotImplementedError(
                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
            )
        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)

    @override
    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
        return {}

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Calculate TopK auxiliary loss.

        This auxiliary loss encourages dead neurons to learn useful features by having
        them reconstruct the residual error from the live neurons. It's a key part of
        preventing neuron death in TopK SAEs.
        """
        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
            return sae_out.new_tensor(0.0)
        residual = (sae_in - sae_out).detach()

        # Heuristic from Appendix B.1 in the paper
        k_aux = sae_in.shape[-1] // 2

        # Reduce the scale of the loss if there are a small number of dead latents
        scale = min(num_dead / k_aux, 1.0)
        k_aux = min(k_aux, num_dead)

        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Encourage the top ~50% of dead latents to predict the residual of the
        # top k living latents
        recons = self.decode(auxk_acts)
        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
        return self.cfg.aux_loss_coefficient * scale * auxk_loss

    @override
    def process_state_dict_for_saving_inference(
        self, state_dict: dict[str, Any]
    ) -> None:
        super().process_state_dict_for_saving_inference(state_dict)
        if self.cfg.rescale_acts_by_decoder_norm:
            _fold_norm_topk(
                W_enc=state_dict["W_enc"],
                b_enc=state_dict["b_enc"],
                W_dec=state_dict["W_dec"],
            )

calculate_topk_aux_loss(sae_in, sae_out, hidden_pre, dead_neuron_mask)

Calculate TopK auxiliary loss.

This auxiliary loss encourages dead neurons to learn useful features by having them reconstruct the residual error from the live neurons. It's a key part of preventing neuron death in TopK SAEs.

Source code in sae_lens/saes/topk_sae.py
def calculate_topk_aux_loss(
    self,
    sae_in: torch.Tensor,
    sae_out: torch.Tensor,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor | None,
) -> torch.Tensor:
    """
    Calculate TopK auxiliary loss.

    This auxiliary loss encourages dead neurons to learn useful features by having
    them reconstruct the residual error from the live neurons. It's a key part of
    preventing neuron death in TopK SAEs.
    """
    # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
    # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
    if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
        return sae_out.new_tensor(0.0)
    residual = (sae_in - sae_out).detach()

    # Heuristic from Appendix B.1 in the paper
    k_aux = sae_in.shape[-1] // 2

    # Reduce the scale of the loss if there are a small number of dead latents
    scale = min(num_dead / k_aux, 1.0)
    k_aux = min(k_aux, num_dead)

    auxk_acts = _calculate_topk_aux_acts(
        k_aux=k_aux,
        hidden_pre=hidden_pre,
        dead_neuron_mask=dead_neuron_mask,
    )

    # Encourage the top ~50% of dead latents to predict the residual of the
    # top k living latents
    recons = self.decode(auxk_acts)
    auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
    return self.cfg.aux_loss_coefficient * scale * auxk_loss

decode(feature_acts)

Decodes feature activations back into input space, applying optional finetuning scale, hooking, out normalization, etc.

Source code in sae_lens/saes/topk_sae.py
@override
def decode(
    self,
    feature_acts: torch.Tensor,
) -> torch.Tensor:
    """
    Decodes feature activations back into input space,
    applying optional finetuning scale, hooking, out normalization, etc.
    """
    # Handle sparse tensors using efficient sparse matrix multiplication
    if self.cfg.rescale_acts_by_decoder_norm:
        # need to multiply by the inverse of the norm because division is illegal with sparse tensors
        feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))
    if feature_acts.is_sparse:
        sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
    else:
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode_with_hidden_pre(x)

Similar to the base training method: calculate pre-activations, then apply TopK.

Source code in sae_lens/saes/topk_sae.py
def encode_with_hidden_pre(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Similar to the base training method: calculate pre-activations, then apply TopK.
    """
    sae_in = self.process_sae_in(x)
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

    if self.cfg.rescale_acts_by_decoder_norm:
        hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)

    # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
    feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
    return feature_acts, hidden_pre

forward(x)

Forward pass through the SAE.

Source code in sae_lens/saes/topk_sae.py
@override
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the SAE."""
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)

    if self.use_error_term:
        with torch.no_grad():
            # Recompute without hooks for true error term
            with _disable_hooks(self):
                feature_acts_clean = self.encode(x)
                x_reconstruct_clean = self.decode(feature_acts_clean)
            sae_error = self.hook_sae_error(x - x_reconstruct_clean)
        sae_out = sae_out + sae_error

    return self.hook_sae_output(sae_out)
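
The error term above makes the module's output numerically equal to its input (the SAE's reconstruction error is added back on), while still exposing the reconstruction and feature hooks for interventions. A standalone sketch of the arithmetic, using a stand-in reconstruction tensor:

import torch

x = torch.randn(8, 512)        # input activations
recon = torch.randn(8, 512)    # stand-in for decode(encode(x)) with hooks active
recon_clean = recon.clone()    # recomputed with hooks disabled (identical when no hook edits anything)

sae_error = x - recon_clean    # the part the SAE fails to reconstruct
sae_out = recon + sae_error    # equals x whenever no hook modified the activations
assert torch.allclose(sae_out, x)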

TopKTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a TopKTrainingSAE.

Parameters:

  • k (int): Number of top features to keep active. Only the top k features with the highest pre-activations will be non-zero. Default: 100.
  • use_sparse_activations (bool): Whether to use sparse tensor representations for activations during training. This can reduce memory usage and improve performance when k is small relative to d_sae, but is only worthwhile if using float32 and not using autocast. Default: False.
  • aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages dead neurons to learn useful features. This loss helps prevent neuron death in TopK SAEs by having dead neurons reconstruct the residual error from live neurons. Default: 1.0.
  • rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized. This is a good idea since decoder norm can randomly drift during training, and this affects what the topk activations will be. Default: True.
  • decoder_init_norm (float | None): Norm to initialize decoder weights to. 0.1 corresponds to the "heuristic" initialization from Anthropic's April update. Use None to disable. Inherited from TrainingSAEConfig. Default: 0.1.
  • d_in (int): Input dimension (dimensionality of the activations being encoded). Inherited from SAEConfig. Required.
  • d_sae (int): SAE latent dimension (number of features in the SAE). Inherited from SAEConfig. Required.
  • dtype (str): Data type for the SAE parameters. Inherited from SAEConfig. Default: "float32".
  • device (str): Device to place the SAE on. Inherited from SAEConfig. Default: "cpu".
  • apply_b_dec_to_input (bool): Whether to apply decoder bias to the input before encoding. Inherited from SAEConfig. Default: True.
  • normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]): Normalization strategy for input activations. Inherited from SAEConfig. Default: "none".
  • reshape_activations (Literal["none", "hook_z"]): How to reshape activations (useful for attention head outputs). Inherited from SAEConfig. Default: "none".
  • metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.). Inherited from SAEConfig. Default: SAEMetadata().

Source code in sae_lens/saes/topk_sae.py
@dataclass
class TopKTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a TopKTrainingSAE.

    Args:
        k (int): Number of top features to keep active. Only the top k features
            with the highest pre-activations will be non-zero. Defaults to 100.
        use_sparse_activations (bool): Whether to use sparse tensor representations
            for activations during training. This can reduce memory usage and improve
            performance when k is small relative to d_sae, but is only worthwhile if
            using float32 and not using autocast. Defaults to False.
        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
            dead neurons to learn useful features. This loss helps prevent neuron death
            in TopK SAEs by having dead neurons reconstruct the residual error from
            live neurons. Defaults to 1.0.
        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
            This is a good idea since decoder norm can randomly drift during training, and this
            affects what the topk activations will be. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Inherited from SAEConfig. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Inherited from SAEConfig.
            Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Inherited from SAEConfig.
            Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
            Inherited from SAEConfig.
    """

    k: int = 100
    use_sparse_activations: bool = False
    aux_loss_coefficient: float = 1.0
    rescale_acts_by_decoder_norm: bool = True

    @override
    @classmethod
    def architecture(cls) -> str:
        return "topk"

TrainingSAE

Bases: SAE[T_TRAINING_SAE_CONFIG], ABC

Abstract base class for training versions of SAEs.

Source code in sae_lens/saes/sae.py
class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
    """Abstract base class for training versions of SAEs."""

    def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # Turn off hook_z reshaping for training mode - the activation store
        # is expected to handle reshaping before passing data to the SAE
        self.turn_off_forward_pass_hook_z_reshaping()
        self.mse_loss_fn = mse_loss

    @abstractmethod
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...

    @abstractmethod
    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode with access to pre-activation values for training."""
        ...

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        For inference, just encode without returning hidden_pre.
        (training_forward_pass calls encode_with_hidden_pre).
        """
        feature_acts, _ = self.encode_with_hidden_pre(x)
        return feature_acts

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decodes feature activations back into input space,
        applying optional finetuning scale, hooking, out normalization, etc.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def initialize_weights(self):
        super().initialize_weights()
        if self.cfg.decoder_init_norm is not None:
            with torch.no_grad():
                self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
                self.W_dec.data *= self.cfg.decoder_init_norm
            self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()

    @abstractmethod
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""
        ...

    def training_forward_pass(
        self,
        step_input: TrainStepInput,
    ) -> TrainStepOutput:
        """Forward pass during training."""
        feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
        sae_out = self.decode(feature_acts)

        # Calculate MSE loss
        per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        # Calculate architecture-specific auxiliary losses
        aux_losses = self.calculate_aux_loss(
            step_input=step_input,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            sae_out=sae_out,
        )

        # Total loss is MSE plus all auxiliary losses
        total_loss = mse_loss

        # Create losses dictionary with mse_loss
        losses = {"mse_loss": mse_loss}

        # Add architecture-specific losses to the dictionary
        # Make sure aux_losses is a dictionary with string keys and tensor values
        if isinstance(aux_losses, dict):
            losses.update(aux_losses)

        # Sum all losses for total_loss
        if isinstance(aux_losses, dict):
            for loss_value in aux_losses.values():
                total_loss = total_loss + loss_value
        else:
            # Handle case where aux_losses is a tensor
            total_loss = total_loss + aux_losses

        return TrainStepOutput(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            loss=total_loss,
            losses=losses,
        )

    def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
        """Save inference version of model weights and config to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving_inference(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.cfg.get_inference_sae_cfg_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        return model_weights_path, cfg_path

    def process_state_dict_for_saving_inference(
        self, state_dict: dict[str, Any]
    ) -> None:
        """
        Process the state dict for saving the inference model.
        This is a hook that can be overridden to change how the state dict is processed for the inference model.
        """
        return self.process_state_dict_for_saving(state_dict)

    @torch.no_grad()
    def log_histograms(self) -> dict[str, NDArray[Any]]:
        """Log histograms of the weights and biases."""
        W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
        return {
            "weights/W_dec_norms": W_dec_norm_dist,
        }

    @classmethod
    def get_sae_class_for_architecture(
        cls: type[T_TRAINING_SAE], architecture: str
    ) -> type[T_TRAINING_SAE]:
        """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_training_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls,
        architecture: str,  # noqa: ARG003
    ) -> type[TrainingSAEConfig]:
        return get_sae_training_class(architecture)[1]

    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
        checkpoint_path = Path(checkpoint_path)
        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
        self.process_state_dict_for_loading(state_dict)
        self.load_state_dict(state_dict)

calculate_aux_loss(step_input, feature_acts, hidden_pre, sae_out) abstractmethod

Calculate architecture-specific auxiliary loss terms.

Source code in sae_lens/saes/sae.py
@abstractmethod
def calculate_aux_loss(
    self,
    step_input: TrainStepInput,
    feature_acts: torch.Tensor,
    hidden_pre: torch.Tensor,
    sae_out: torch.Tensor,
) -> torch.Tensor | dict[str, torch.Tensor]:
    """Calculate architecture-specific auxiliary loss terms."""
    ...

decode(feature_acts)

Decodes feature activations back into input space, applying optional finetuning scale, hooking, out normalization, etc.

Source code in sae_lens/saes/sae.py
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """
    Decodes feature activations back into input space,
    applying optional finetuning scale, hooking, out normalization, etc.
    """
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

For inference, just encode without returning hidden_pre. (training_forward_pass calls encode_with_hidden_pre).

Source code in sae_lens/saes/sae.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    For inference, just encode without returning hidden_pre.
    (training_forward_pass calls encode_with_hidden_pre).
    """
    feature_acts, _ = self.encode_with_hidden_pre(x)
    return feature_acts

encode_with_hidden_pre(x) abstractmethod

Encode with access to pre-activation values for training.

Source code in sae_lens/saes/sae.py
@abstractmethod
def encode_with_hidden_pre(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode with access to pre-activation values for training."""
    ...

get_sae_class_for_architecture(architecture) classmethod

Get the SAE class for a given architecture.

Source code in sae_lens/saes/sae.py
@classmethod
def get_sae_class_for_architecture(
    cls: type[T_TRAINING_SAE], architecture: str
) -> type[T_TRAINING_SAE]:
    """Get the SAE class for a given architecture."""
    sae_cls, _ = get_sae_training_class(architecture)
    if not issubclass(sae_cls, cls):
        raise ValueError(
            f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
        )
    return sae_cls

log_histograms()

Log histograms of the weights and biases.

Source code in sae_lens/saes/sae.py
@torch.no_grad()
def log_histograms(self) -> dict[str, NDArray[Any]]:
    """Log histograms of the weights and biases."""
    W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
    return {
        "weights/W_dec_norms": W_dec_norm_dist,
    }

process_state_dict_for_saving_inference(state_dict)

Process the state dict for saving the inference model. This is a hook that can be overridden to change how the state dict is processed for the inference model.

Source code in sae_lens/saes/sae.py
def process_state_dict_for_saving_inference(
    self, state_dict: dict[str, Any]
) -> None:
    """
    Process the state dict for saving the inference model.
    This is a hook that can be overridden to change how the state dict is processed for the inference model.
    """
    return self.process_state_dict_for_saving(state_dict)

save_inference_model(path)

Save inference version of model weights and config to disk.

Source code in sae_lens/saes/sae.py
def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
    """Save inference version of model weights and config to disk."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Generate the weights
    state_dict = self.state_dict()  # Use internal SAE state dict
    self.process_state_dict_for_saving_inference(state_dict)
    model_weights_path = path / SAE_WEIGHTS_FILENAME
    save_file(state_dict, model_weights_path)

    # Save the config
    config = self.cfg.get_inference_sae_cfg_dict()
    cfg_path = path / SAE_CFG_FILENAME
    with open(cfg_path, "w") as f:
        json.dump(config, f)

    return model_weights_path, cfg_path
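
A minimal saving sketch, assuming StandardTrainingSAE and StandardTrainingSAEConfig are importable from the sae_lens.saes.standard_sae module documented above; the config values are hypothetical and the SAE here is untrained:

from sae_lens.saes.standard_sae import StandardTrainingSAE, StandardTrainingSAEConfig

cfg = StandardTrainingSAEConfig(d_in=512, d_sae=4096, l1_coefficient=5.0)
sae = StandardTrainingSAE(cfg)

# Writes the inference-format weights and config files and returns their paths.
weights_path, cfg_path = sae.save_inference_model("out/standard_sae_inference")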

training_forward_pass(step_input)

Forward pass during training.

Source code in sae_lens/saes/sae.py
def training_forward_pass(
    self,
    step_input: TrainStepInput,
) -> TrainStepOutput:
    """Forward pass during training."""
    feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
    sae_out = self.decode(feature_acts)

    # Calculate MSE loss
    per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
    mse_loss = per_item_mse_loss.sum(dim=-1).mean()

    # Calculate architecture-specific auxiliary losses
    aux_losses = self.calculate_aux_loss(
        step_input=step_input,
        feature_acts=feature_acts,
        hidden_pre=hidden_pre,
        sae_out=sae_out,
    )

    # Total loss is MSE plus all auxiliary losses
    total_loss = mse_loss

    # Create losses dictionary with mse_loss
    losses = {"mse_loss": mse_loss}

    # Add architecture-specific losses to the dictionary
    # Make sure aux_losses is a dictionary with string keys and tensor values
    if isinstance(aux_losses, dict):
        losses.update(aux_losses)

    # Sum all losses for total_loss
    if isinstance(aux_losses, dict):
        for loss_value in aux_losses.values():
            total_loss = total_loss + loss_value
    else:
        # Handle case where aux_losses is a tensor
        total_loss = total_loss + aux_losses

    return TrainStepOutput(
        sae_in=step_input.sae_in,
        sae_out=sae_out,
        feature_acts=feature_acts,
        hidden_pre=hidden_pre,
        loss=total_loss,
        losses=losses,
    )
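
The loss bookkeeping above reduces to a per-item MSE summed over the input dimension and averaged over the batch, plus whatever terms calculate_aux_loss returns. A standalone sketch of that arithmetic with hypothetical tensors (the library's mse_loss_fn may differ in details such as normalization):

import torch

sae_in = torch.randn(8, 512)                      # inputs to the SAE
sae_out = torch.randn(8, 512)                     # stand-in reconstructions
aux_losses = {"l1_loss": torch.tensor(0.3)}       # hypothetical auxiliary terms

per_item_mse = (sae_out - sae_in) ** 2            # elementwise squared error
mse_loss = per_item_mse.sum(dim=-1).mean()        # sum over d_in, mean over batch
total_loss = mse_loss + sum(aux_losses.values())  # MSE plus all auxiliary losses
losses = {"mse_loss": mse_loss, **aux_losses}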

TrainingSAEConfig dataclass

Bases: SAEConfig, ABC

Source code in sae_lens/saes/sae.py
@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig, ABC):
    # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
    # 0.1 corresponds to the "heuristic" initialization, use None to disable
    decoder_init_norm: float | None = 0.1

    @classmethod
    @abstractmethod
    def architecture(cls) -> str: ...

    @classmethod
    def from_sae_runner_config(
        cls: type[T_TRAINING_SAE_CONFIG],
        cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
    ) -> T_TRAINING_SAE_CONFIG:
        metadata = SAEMetadata(
            model_name=cfg.model_name,
            hook_name=cfg.hook_name,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            prepend_bos=cfg.prepend_bos,
            seqpos_slice=cfg.seqpos_slice,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
        )
        if not isinstance(cfg.sae, cls):
            raise ValueError(
                f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
            )
        return replace(cfg.sae, metadata=metadata)

    @classmethod
    def from_dict(
        cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
    ) -> T_TRAINING_SAE_CONFIG:
        cfg_class = cls
        if "architecture" in config_dict:
            cfg_class = get_sae_training_class(config_dict["architecture"])[1]
        if not issubclass(cfg_class, cls):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
            )
        # remove any keys that are not in the dataclass
        # since we sometimes enhance the config with the whole LM runner config
        valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
        if "metadata" in config_dict:
            valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
        return cfg_class(**valid_config_dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            **super().to_dict(),
            **asdict(self),
            "metadata": self.metadata.to_dict(),
            "architecture": self.architecture(),
        }

    def get_inference_config_class(self) -> type[SAEConfig]:
        """
        Get the architecture for inference.
        """
        return get_sae_class(self.architecture())[1]

    # this needs to exist so we can initialize the parent sae cfg without the training specific
    # parameters. Maybe there's a cleaner way to do this
    def get_inference_sae_cfg_dict(self) -> dict[str, Any]:
        """
        Creates a dictionary containing attributes corresponding to the fields
        defined in the base SAEConfig class.
        """
        base_sae_cfg_class = self.get_inference_config_class()
        base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
        result_dict = {
            field_name: getattr(self, field_name)
            for field_name in base_config_field_names
        }
        result_dict["architecture"] = base_sae_cfg_class.architecture()
        result_dict["metadata"] = self.metadata.to_dict()
        return result_dict

get_inference_config_class()

Get the architecture for inference.

Source code in sae_lens/saes/sae.py
def get_inference_config_class(self) -> type[SAEConfig]:
    """
    Get the architecture for inference.
    """
    return get_sae_class(self.architecture())[1]

get_inference_sae_cfg_dict()

Creates a dictionary containing attributes corresponding to the fields defined in the base SAEConfig class.

Source code in sae_lens/saes/sae.py
def get_inference_sae_cfg_dict(self) -> dict[str, Any]:
    """
    Creates a dictionary containing attributes corresponding to the fields
    defined in the base SAEConfig class.
    """
    base_sae_cfg_class = self.get_inference_config_class()
    base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
    result_dict = {
        field_name: getattr(self, field_name)
        for field_name in base_config_field_names
    }
    result_dict["architecture"] = base_sae_cfg_class.architecture()
    result_dict["metadata"] = self.metadata.to_dict()
    return result_dict
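
A hedged sketch of how a training config relates to the inference config written by save_inference_model; training_cfg is assumed to be an instance of some concrete TrainingSAEConfig subclass (none is constructed here), so this is not runnable as-is:

# training_cfg: a concrete TrainingSAEConfig subclass instance (assumed to exist)
inference_cfg = training_cfg.get_inference_sae_cfg_dict()
inference_cfg["architecture"]  # name of the matching inference architecture
inference_cfg["metadata"]      # nested metadata dict (model_name, hook_name, ...)

# Round-tripping a training config through its dict form dispatches on "architecture".
restored = type(training_cfg).from_dict(training_cfg.to_dict())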

Transcoder

Bases: SAE[TranscoderConfig]

A transcoder maps activations from one hook point to another with potentially different dimensions. It extends the standard SAE, but its decoder maps to a different output dimension.

Source code in sae_lens/saes/transcoder.py
class Transcoder(SAE[TranscoderConfig]):
    """
    A transcoder maps activations from one hook point to another with
    potentially different dimensions. It extends the standard SAE but with a
    decoder that maps to a different output dimension.
    """

    cfg: TranscoderConfig
    W_enc: nn.Parameter
    b_enc: nn.Parameter
    W_dec: nn.Parameter
    b_dec: nn.Parameter

    def __init__(self, cfg: TranscoderConfig):
        super().__init__(cfg)
        self.cfg = cfg

    def initialize_weights(self):
        """Initialize transcoder weights with proper dimensions."""
        # Initialize b_dec with output dimension
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
        )

        # Initialize W_dec with shape [d_sae, d_out]
        w_dec_data = torch.empty(
            self.cfg.d_sae, self.cfg.d_out, dtype=self.dtype, device=self.device
        )
        nn.init.kaiming_uniform_(w_dec_data)
        self.W_dec = nn.Parameter(w_dec_data)

        # Initialize W_enc with shape [d_in, d_sae]
        w_enc_data = torch.empty(
            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
        )
        nn.init.kaiming_uniform_(w_enc_data)
        self.W_enc = nn.Parameter(w_enc_data)

        # Initialize b_enc
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
        """
        Process input without applying decoder bias (which has wrong dimension
        for transcoder).

        Overrides the parent method to skip the bias subtraction since b_dec
        has dimension d_out which doesn't match the input dimension d_in.
        """
        # Don't apply b_dec since it has different dimension
        # Just handle dtype conversion and hooks
        sae_in = sae_in.to(self.dtype)
        sae_in = self.hook_sae_input(sae_in)
        return self.run_time_activation_norm_fn_in(sae_in)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input tensor into the feature space.
        """
        # Preprocess the SAE input (casting type, applying hooks, normalization)
        sae_in = self.process_sae_in(x)
        # Compute the pre-activation values
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        # Apply the activation function (e.g., ReLU)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """Decode to output dimension."""
        # W_dec has shape [d_sae, d_out], feature_acts has shape
        # [batch, d_sae]
        sae_out = feature_acts @ self.W_dec + self.b_dec
        # Apply hooks
        # Note: We don't apply run_time_activation_norm_fn_out since the
        # output dimension is different from the input dimension
        return self.hook_sae_recons(sae_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for transcoder.

        Args:
            x: Input activations from the input hook point [batch, d_in]

        Returns:
            sae_out: Reconstructed activations for the output hook point
            [batch, d_out]
        """
        feature_acts = self.encode(x)
        return self.decode(feature_acts)

    def forward_with_activations(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass returning both output and feature activations.

        Args:
            x: Input activations from the input hook point [batch, d_in]

        Returns:
            sae_out: Reconstructed activations for the output hook point
            [batch, d_out]
            feature_acts: Hidden activations [batch, d_sae]
        """
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)
        return sae_out, feature_acts

    @property
    def d_out(self) -> int:
        """Output dimension of the transcoder."""
        return self.cfg.d_out

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "Transcoder":
        cfg = TranscoderConfig.from_dict(config_dict)
        return cls(cfg)

d_out: int property

Output dimension of the transcoder.

decode(feature_acts)

Decode to output dimension.

Source code in sae_lens/saes/transcoder.py
def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
    """Decode to output dimension."""
    # W_dec has shape [d_sae, d_out], feature_acts has shape
    # [batch, d_sae]
    sae_out = feature_acts @ self.W_dec + self.b_dec
    # Apply hooks
    # Note: We don't apply run_time_activation_norm_fn_out since the
    # output dimension is different from the input dimension
    return self.hook_sae_recons(sae_out)

encode(x)

Encode the input tensor into the feature space.

Source code in sae_lens/saes/transcoder.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode the input tensor into the feature space.
    """
    # Preprocess the SAE input (casting type, applying hooks, normalization)
    sae_in = self.process_sae_in(x)
    # Compute the pre-activation values
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    # Apply the activation function (e.g., ReLU)
    return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

forward(x)

Forward pass for transcoder.

Parameters:

Name Type Description Default
x Tensor

Input activations from the input hook point [batch, d_in]

required

Returns:

Name Type Description
sae_out Tensor

Reconstructed activations for the output hook point [batch, d_out]

Source code in sae_lens/saes/transcoder.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass for transcoder.

    Args:
        x: Input activations from the input hook point [batch, d_in]

    Returns:
        sae_out: Reconstructed activations for the output hook point
        [batch, d_out]
    """
    feature_acts = self.encode(x)
    return self.decode(feature_acts)

forward_with_activations(x)

Forward pass returning both output and feature activations.

Parameters:

Name Type Description Default
x Tensor

Input activations from the input hook point [batch, d_in]

required

Returns:

Name Type Description
sae_out Tensor

Reconstructed activations for the output hook point [batch, d_out]

feature_acts Tensor

Hidden activations [batch, d_sae]

Source code in sae_lens/saes/transcoder.py
def forward_with_activations(
    self,
    x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass returning both output and feature activations.

    Args:
        x: Input activations from the input hook point [batch, d_in]

    Returns:
        sae_out: Reconstructed activations for the output hook point
        [batch, d_out]
        feature_acts: Hidden activations [batch, d_sae]
    """
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)
    return sae_out, feature_acts

initialize_weights()

Initialize transcoder weights with proper dimensions.

Source code in sae_lens/saes/transcoder.py
def initialize_weights(self):
    """Initialize transcoder weights with proper dimensions."""
    # Initialize b_dec with output dimension
    self.b_dec = nn.Parameter(
        torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
    )

    # Initialize W_dec with shape [d_sae, d_out]
    w_dec_data = torch.empty(
        self.cfg.d_sae, self.cfg.d_out, dtype=self.dtype, device=self.device
    )
    nn.init.kaiming_uniform_(w_dec_data)
    self.W_dec = nn.Parameter(w_dec_data)

    # Initialize W_enc with shape [d_in, d_sae]
    w_enc_data = torch.empty(
        self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
    )
    nn.init.kaiming_uniform_(w_enc_data)
    self.W_enc = nn.Parameter(w_enc_data)

    # Initialize b_enc
    self.b_enc = nn.Parameter(
        torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
    )

process_sae_in(sae_in)

Process input without applying decoder bias (which has wrong dimension for transcoder).

Overrides the parent method to skip the bias subtraction since b_dec has dimension d_out which doesn't match the input dimension d_in.

Source code in sae_lens/saes/transcoder.py
def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
    """
    Process input without applying decoder bias (which has wrong dimension
    for transcoder).

    Overrides the parent method to skip the bias subtraction since b_dec
    has dimension d_out which doesn't match the input dimension d_in.
    """
    # Don't apply b_dec since it has different dimension
    # Just handle dtype conversion and hooks
    sae_in = sae_in.to(self.dtype)
    sae_in = self.hook_sae_input(sae_in)
    return self.run_time_activation_norm_fn_in(sae_in)
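
A hedged usage sketch, assuming d_in and d_sae are inherited SAEConfig fields (the source above reads them from cfg) and using TranscoderConfig, which is documented below; apply_b_dec_to_input is passed as False explicitly since transcoders reject decoder-bias subtraction on the input (see __post_init__ below):

import torch

from sae_lens.saes.transcoder import Transcoder, TranscoderConfig

cfg = TranscoderConfig(d_in=512, d_sae=2048, d_out=768, apply_b_dec_to_input=False)
transcoder = Transcoder(cfg)

x = torch.randn(32, 512)                                      # activations at the input hook point
recon = transcoder(x)                                         # [32, 768], reconstruction for the output hook point
recon, feature_acts = transcoder.forward_with_activations(x)  # feature_acts: [32, 2048]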

TranscoderConfig dataclass

Bases: SAEConfig

Source code in sae_lens/saes/transcoder.py
@dataclass
class TranscoderConfig(SAEConfig):
    # Output dimension fields
    d_out: int = 768
    # hook_name_out: str = ""
    # hook_layer_out: int = 0
    # hook_head_index_out: int | None = None

    @classmethod
    def architecture(cls) -> str:
        """Return the architecture name for this config."""
        return "transcoder"

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "TranscoderConfig":
        """Create a TranscoderConfig from a dictionary."""
        # Filter to only include valid dataclass fields
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

        # Create the config instance
        res = cls(**filtered_config_dict)

        # Handle metadata if present
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])

        return res

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary, including parent fields."""
        # Get the base dictionary from parent
        res = super().to_dict()

        # Add transcoder-specific fields
        res.update({"d_out": self.d_out})

        return res

    def __post_init__(self):
        if self.apply_b_dec_to_input:
            raise ValueError("apply_b_dec_to_input is not supported for transcoders")
        return super().__post_init__()

architecture() classmethod

Return the architecture name for this config.

Source code in sae_lens/saes/transcoder.py
@classmethod
def architecture(cls) -> str:
    """Return the architecture name for this config."""
    return "transcoder"

from_dict(config_dict) classmethod

Create a TranscoderConfig from a dictionary.

Source code in sae_lens/saes/transcoder.py
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "TranscoderConfig":
    """Create a TranscoderConfig from a dictionary."""
    # Filter to only include valid dataclass fields
    filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)

    # Create the config instance
    res = cls(**filtered_config_dict)

    # Handle metadata if present
    if "metadata" in config_dict:
        res.metadata = SAEMetadata(**config_dict["metadata"])

    return res

to_dict()

Convert to dictionary, including parent fields.

Source code in sae_lens/saes/transcoder.py
def to_dict(self) -> dict[str, Any]:
    """Convert to dictionary, including parent fields."""
    # Get the base dictionary from parent
    res = super().to_dict()

    # Add transcoder-specific fields
    res.update({"d_out": self.d_out})

    return res

Synthetic Data

Synthetic data utilities for SAE experiments.

This module provides tools for creating feature dictionaries and generating synthetic activations for testing and experimenting with SAEs.

Main components:

  • FeatureDictionary: Maps sparse feature activations to dense hidden activations
  • ActivationGenerator: Generates batches of synthetic feature activations
  • HierarchyNode: Enforces hierarchical structure on feature activations
  • Training utilities: Helpers for training and evaluating SAEs on synthetic data
  • Plotting utilities: Visualization helpers for understanding SAE behavior

ActivationGenerator

Bases: Module

Generator for synthetic feature activations.

This module provides a generator for synthetic feature activations with controlled properties.

Source code in sae_lens/synthetic/activation_generator.py
class ActivationGenerator(nn.Module):
    """
    Generator for synthetic feature activations.

    This module provides a generator for synthetic feature activations with controlled properties.
    """

    num_features: int
    firing_probabilities: torch.Tensor
    std_firing_magnitudes: torch.Tensor
    mean_firing_magnitudes: torch.Tensor
    modify_activations: ActivationsModifier | None
    correlation_matrix: torch.Tensor | None
    low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
    correlation_thresholds: torch.Tensor | None
    use_sparse_tensors: bool

    def __init__(
        self,
        num_features: int,
        firing_probabilities: torch.Tensor | float,
        std_firing_magnitudes: torch.Tensor | float = 0.0,
        mean_firing_magnitudes: torch.Tensor | float = 1.0,
        modify_activations: ActivationsModifierInput = None,
        correlation_matrix: CorrelationMatrixInput | None = None,
        device: torch.device | str = "cpu",
        dtype: torch.dtype | str = "float32",
        use_sparse_tensors: bool = False,
    ):
        """
        Create a new ActivationGenerator.

        Args:
            num_features: Number of features to generate activations for.
            firing_probabilities: Probability of each feature firing. Can be a single
                float (applied to all features) or a tensor of shape (num_features,).
            std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
                single float or a tensor of shape (num_features,). Defaults to 0.0
                (deterministic magnitudes).
            mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
                a single float or a tensor of shape (num_features,). Defaults to 1.0.
            modify_activations: Optional function(s) to modify activations after
                generation. Can be a single callable, a sequence of callables (applied
                in order), or None. Useful for applying hierarchy constraints.
            correlation_matrix: Optional correlation structure between features. Can be:

                - A full correlation matrix tensor of shape (num_features, num_features)
                - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
                - A tuple of (factor, diag) tensors representing low-rank structure

            device: Device to place tensors on. Defaults to "cpu".
            dtype: Data type for tensors. Defaults to "float32".
            use_sparse_tensors: If True, return sparse COO tensors from sample().
                Only recommended when using massive numbers of features. Defaults to False.
        """
        super().__init__()
        self.num_features = num_features
        self.firing_probabilities = _to_tensor(
            firing_probabilities, num_features, device, dtype
        )
        self.std_firing_magnitudes = _to_tensor(
            std_firing_magnitudes, num_features, device, dtype
        )
        self.mean_firing_magnitudes = _to_tensor(
            mean_firing_magnitudes, num_features, device, dtype
        )
        self.modify_activations = _normalize_modifiers(modify_activations)
        self.correlation_thresholds = None
        self.correlation_matrix = None
        self.low_rank_correlation = None
        self.use_sparse_tensors = use_sparse_tensors

        if correlation_matrix is not None:
            if isinstance(correlation_matrix, torch.Tensor):
                # Full correlation matrix
                _validate_correlation_matrix(correlation_matrix, num_features)
                self.correlation_matrix = correlation_matrix
            else:
                # Low-rank correlation (tuple or LowRankCorrelationMatrix)
                correlation_factor, correlation_diag = (
                    correlation_matrix[0],
                    correlation_matrix[1],
                )
                _validate_low_rank_correlation(
                    correlation_factor, correlation_diag, num_features
                )
                # Pre-compute sqrt for efficiency (used every sample call)
                self.low_rank_correlation = (
                    correlation_factor,
                    correlation_diag.sqrt(),
                )

            # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
            self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
                1 - 2 * self.firing_probabilities
            )

    @torch.no_grad()
    def sample(self, batch_size: int) -> torch.Tensor:
        """
        Generate a batch of feature activations with controlled properties.

        This is the main function for generating synthetic training data for SAEs.
        Features fire independently according to their firing probabilities unless
        a correlation matrix is provided.

        Args:
            batch_size: Number of samples to generate

        Returns:
            Tensor of shape [batch_size, num_features] with non-negative activations
        """
        # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
        # are on the same device from __init__ via _to_tensor()
        device = self.firing_probabilities.device

        if self.correlation_matrix is not None:
            assert self.correlation_thresholds is not None
            firing_indices = _generate_correlated_features(
                batch_size,
                self.correlation_matrix,
                self.correlation_thresholds,
                device,
            )
        elif self.low_rank_correlation is not None:
            assert self.correlation_thresholds is not None
            firing_indices = _generate_low_rank_correlated_features(
                batch_size,
                self.low_rank_correlation[0],
                self.low_rank_correlation[1],
                self.correlation_thresholds,
                device,
            )
        else:
            firing_indices = torch.bernoulli(
                self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
            ).nonzero(as_tuple=True)

        # Compute activations only at firing positions (sparse optimization)
        feature_indices = firing_indices[1]
        num_firing = feature_indices.shape[0]
        mean_at_firing = self.mean_firing_magnitudes[feature_indices]
        std_at_firing = self.std_firing_magnitudes[feature_indices]
        random_deltas = (
            torch.randn(
                num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
            )
            * std_at_firing
        )
        activations_at_firing = (mean_at_firing + random_deltas).relu()

        if self.use_sparse_tensors:
            # Return sparse COO tensor
            indices = torch.stack(firing_indices)  # [2, nnz]
            feature_activations = torch.sparse_coo_tensor(
                indices,
                activations_at_firing,
                size=(batch_size, self.num_features),
                device=device,
                dtype=self.mean_firing_magnitudes.dtype,
            )
        else:
            # Dense tensor path
            feature_activations = torch.zeros(
                batch_size,
                self.num_features,
                device=device,
                dtype=self.mean_firing_magnitudes.dtype,
            )
            feature_activations[firing_indices] = activations_at_firing

        if self.modify_activations is not None:
            feature_activations = self.modify_activations(feature_activations)
            if feature_activations.is_sparse:
                # Apply relu to sparse values
                feature_activations = feature_activations.coalesce()
                feature_activations = torch.sparse_coo_tensor(
                    feature_activations.indices(),
                    feature_activations.values().relu(),
                    feature_activations.shape,
                    device=feature_activations.device,
                    dtype=feature_activations.dtype,
                )
            else:
                feature_activations = feature_activations.relu()

        return feature_activations

    def forward(self, batch_size: int) -> torch.Tensor:
        return self.sample(batch_size)
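
A minimal usage sketch of the generator, importing from the module path shown above: 100 features that each fire independently with probability 0.02 and unit magnitude (the defaults for the remaining arguments):

from sae_lens.synthetic.activation_generator import ActivationGenerator

gen = ActivationGenerator(num_features=100, firing_probabilities=0.02)
acts = gen.sample(batch_size=64)              # dense tensor, shape [64, 100], non-negative
print(acts.shape, (acts > 0).float().mean())  # observed firing rate should be close to 0.02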

__init__(num_features, firing_probabilities, std_firing_magnitudes=0.0, mean_firing_magnitudes=1.0, modify_activations=None, correlation_matrix=None, device='cpu', dtype='float32', use_sparse_tensors=False)

Create a new ActivationGenerator.

Parameters:

Name Type Description Default
num_features int

Number of features to generate activations for.

required
firing_probabilities Tensor | float

Probability of each feature firing. Can be a single float (applied to all features) or a tensor of shape (num_features,).

required
std_firing_magnitudes Tensor | float

Standard deviation of firing magnitudes. Can be a single float or a tensor of shape (num_features,). Defaults to 0.0 (deterministic magnitudes).

0.0
mean_firing_magnitudes Tensor | float

Mean firing magnitude when a feature fires. Can be a single float or a tensor of shape (num_features,). Defaults to 1.0.

1.0
modify_activations ActivationsModifierInput

Optional function(s) to modify activations after generation. Can be a single callable, a sequence of callables (applied in order), or None. Useful for applying hierarchy constraints.

None
correlation_matrix CorrelationMatrixInput | None

Optional correlation structure between features. Can be:

  • A full correlation matrix tensor of shape (num_features, num_features)
  • A LowRankCorrelationMatrix for memory-efficient large-scale correlations
  • A tuple of (factor, diag) tensors representing low-rank structure
None
device device | str

Device to place tensors on. Defaults to "cpu".

'cpu'
dtype dtype | str

Data type for tensors. Defaults to "float32".

'float32'
use_sparse_tensors bool

If True, return sparse COO tensors from sample(). Only recommended when using massive numbers of features. Defaults to False.

False
Source code in sae_lens/synthetic/activation_generator.py
def __init__(
    self,
    num_features: int,
    firing_probabilities: torch.Tensor | float,
    std_firing_magnitudes: torch.Tensor | float = 0.0,
    mean_firing_magnitudes: torch.Tensor | float = 1.0,
    modify_activations: ActivationsModifierInput = None,
    correlation_matrix: CorrelationMatrixInput | None = None,
    device: torch.device | str = "cpu",
    dtype: torch.dtype | str = "float32",
    use_sparse_tensors: bool = False,
):
    """
    Create a new ActivationGenerator.

    Args:
        num_features: Number of features to generate activations for.
        firing_probabilities: Probability of each feature firing. Can be a single
            float (applied to all features) or a tensor of shape (num_features,).
        std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
            single float or a tensor of shape (num_features,). Defaults to 0.0
            (deterministic magnitudes).
        mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
            a single float or a tensor of shape (num_features,). Defaults to 1.0.
        modify_activations: Optional function(s) to modify activations after
            generation. Can be a single callable, a sequence of callables (applied
            in order), or None. Useful for applying hierarchy constraints.
        correlation_matrix: Optional correlation structure between features. Can be:

            - A full correlation matrix tensor of shape (num_features, num_features)
            - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
            - A tuple of (factor, diag) tensors representing low-rank structure

        device: Device to place tensors on. Defaults to "cpu".
        dtype: Data type for tensors. Defaults to "float32".
        use_sparse_tensors: If True, return sparse COO tensors from sample().
            Only recommended when using massive numbers of features. Defaults to False.
    """
    super().__init__()
    self.num_features = num_features
    self.firing_probabilities = _to_tensor(
        firing_probabilities, num_features, device, dtype
    )
    self.std_firing_magnitudes = _to_tensor(
        std_firing_magnitudes, num_features, device, dtype
    )
    self.mean_firing_magnitudes = _to_tensor(
        mean_firing_magnitudes, num_features, device, dtype
    )
    self.modify_activations = _normalize_modifiers(modify_activations)
    self.correlation_thresholds = None
    self.correlation_matrix = None
    self.low_rank_correlation = None
    self.use_sparse_tensors = use_sparse_tensors

    if correlation_matrix is not None:
        if isinstance(correlation_matrix, torch.Tensor):
            # Full correlation matrix
            _validate_correlation_matrix(correlation_matrix, num_features)
            self.correlation_matrix = correlation_matrix
        else:
            # Low-rank correlation (tuple or LowRankCorrelationMatrix)
            correlation_factor, correlation_diag = (
                correlation_matrix[0],
                correlation_matrix[1],
            )
            _validate_low_rank_correlation(
                correlation_factor, correlation_diag, num_features
            )
            # Pre-compute sqrt for efficiency (used every sample call)
            self.low_rank_correlation = (
                correlation_factor,
                correlation_diag.sqrt(),
            )

        # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
        self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
            1 - 2 * self.firing_probabilities
        )

sample(batch_size)

Generate a batch of feature activations with controlled properties.

This is the main function for generating synthetic training data for SAEs. Features fire independently according to their firing probabilities unless a correlation matrix is provided.

Parameters:

Name Type Description Default
batch_size int

Number of samples to generate

required

Returns:

Type Description
Tensor

Tensor of shape [batch_size, num_features] with non-negative activations

Source code in sae_lens/synthetic/activation_generator.py
@torch.no_grad()
def sample(self, batch_size: int) -> torch.Tensor:
    """
    Generate a batch of feature activations with controlled properties.

    This is the main function for generating synthetic training data for SAEs.
    Features fire independently according to their firing probabilities unless
    a correlation matrix is provided.

    Args:
        batch_size: Number of samples to generate

    Returns:
        Tensor of shape [batch_size, num_features] with non-negative activations
    """
    # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
    # are on the same device from __init__ via _to_tensor()
    device = self.firing_probabilities.device

    if self.correlation_matrix is not None:
        assert self.correlation_thresholds is not None
        firing_indices = _generate_correlated_features(
            batch_size,
            self.correlation_matrix,
            self.correlation_thresholds,
            device,
        )
    elif self.low_rank_correlation is not None:
        assert self.correlation_thresholds is not None
        firing_indices = _generate_low_rank_correlated_features(
            batch_size,
            self.low_rank_correlation[0],
            self.low_rank_correlation[1],
            self.correlation_thresholds,
            device,
        )
    else:
        firing_indices = torch.bernoulli(
            self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
        ).nonzero(as_tuple=True)

    # Compute activations only at firing positions (sparse optimization)
    feature_indices = firing_indices[1]
    num_firing = feature_indices.shape[0]
    mean_at_firing = self.mean_firing_magnitudes[feature_indices]
    std_at_firing = self.std_firing_magnitudes[feature_indices]
    random_deltas = (
        torch.randn(
            num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
        )
        * std_at_firing
    )
    activations_at_firing = (mean_at_firing + random_deltas).relu()

    if self.use_sparse_tensors:
        # Return sparse COO tensor
        indices = torch.stack(firing_indices)  # [2, nnz]
        feature_activations = torch.sparse_coo_tensor(
            indices,
            activations_at_firing,
            size=(batch_size, self.num_features),
            device=device,
            dtype=self.mean_firing_magnitudes.dtype,
        )
    else:
        # Dense tensor path
        feature_activations = torch.zeros(
            batch_size,
            self.num_features,
            device=device,
            dtype=self.mean_firing_magnitudes.dtype,
        )
        feature_activations[firing_indices] = activations_at_firing

    if self.modify_activations is not None:
        feature_activations = self.modify_activations(feature_activations)
        if feature_activations.is_sparse:
            # Apply relu to sparse values
            feature_activations = feature_activations.coalesce()
            feature_activations = torch.sparse_coo_tensor(
                feature_activations.indices(),
                feature_activations.values().relu(),
                feature_activations.shape,
                device=feature_activations.device,
                dtype=feature_activations.dtype,
            )
        else:
            feature_activations = feature_activations.relu()

    return feature_activations

CorrelationMatrixStats dataclass

Statistics computed from a correlation matrix.

Source code in sae_lens/synthetic/stats.py
@dataclass
class CorrelationMatrixStats:
    """Statistics computed from a correlation matrix."""

    rms_correlation: float  # Root mean square of off-diagonal correlations
    mean_correlation: float  # Mean of off-diagonal correlations (not absolute)
    num_features: int

FeatureDictionary

Bases: Module

A feature dictionary that maps sparse feature activations to dense hidden activations.

This class creates a set of feature vectors (the "dictionary") and provides methods to generate hidden activations from feature activations via a linear transformation.

The feature vectors can be configured to have a specific pairwise cosine similarity, which is useful for controlling the difficulty of sparse recovery.

Attributes:

Name Type Description
feature_vectors Parameter

Parameter of shape [num_features, hidden_dim] containing the feature embedding vectors

bias Parameter

Parameter of shape [hidden_dim] containing the bias term (zeros if bias=False)

Source code in sae_lens/synthetic/feature_dictionary.py
class FeatureDictionary(nn.Module):
    """
    A feature dictionary that maps sparse feature activations to dense hidden activations.

    This class creates a set of feature vectors (the "dictionary") and provides methods
    to generate hidden activations from feature activations via a linear transformation.

    The feature vectors can be configured to have a specific pairwise cosine similarity,
    which is useful for controlling the difficulty of sparse recovery.

    Attributes:
        feature_vectors: Parameter of shape [num_features, hidden_dim] containing the
            feature embedding vectors
        bias: Parameter of shape [hidden_dim] containing the bias term (zeros if bias=False)
    """

    feature_vectors: nn.Parameter
    bias: nn.Parameter

    def __init__(
        self,
        num_features: int,
        hidden_dim: int,
        bias: bool = False,
        initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
        device: str | torch.device = "cpu",
    ):
        """
        Create a new FeatureDictionary.

        Args:
            num_features: Number of features in the dictionary
            hidden_dim: Dimensionality of the hidden space
            bias: Whether to include a bias term in the embedding
            initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
            device: Device to use for the feature dictionary.
        """
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim

        # Initialize feature vectors as unit vectors
        embeddings = torch.randn(num_features, hidden_dim, device=device)
        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
            min=1e-8
        )
        self.feature_vectors = nn.Parameter(embeddings)

        # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
        self.bias = nn.Parameter(
            torch.zeros(hidden_dim, device=device), requires_grad=bias
        )

        if initializer is not None:
            initializer(self)

    def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
        """
        Convert feature activations to hidden activations.

        Args:
            feature_activations: Tensor of shape [batch, num_features] containing
                sparse feature activation values. Can be dense or sparse COO.

        Returns:
            Tensor of shape [batch, hidden_dim] containing dense hidden activations
        """
        if feature_activations.is_sparse:
            # autocast is disabled here because sparse matmul is not supported with bfloat16
            with torch.autocast(
                device_type=feature_activations.device.type, enabled=False
            ):
                return (
                    torch.sparse.mm(feature_activations, self.feature_vectors)
                    + self.bias
                )
        return feature_activations @ self.feature_vectors + self.bias
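
A minimal sketch mapping sparse feature activations into a dense hidden space, importing from the module path shown above; initializer=None keeps the documented random-unit-vector initialization:

import torch

from sae_lens.synthetic.feature_dictionary import FeatureDictionary

feature_dict = FeatureDictionary(num_features=100, hidden_dim=32, initializer=None)
feature_acts = torch.zeros(4, 100)
feature_acts[0, 7] = 1.0                 # sample 0: feature 7 fires with magnitude 1
hidden = feature_dict(feature_acts)      # dense hidden activations, shape [4, 32]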

__init__(num_features, hidden_dim, bias=False, initializer=orthogonal_initializer(), device='cpu')

Create a new FeatureDictionary.

Parameters:

Name Type Description Default
num_features int

Number of features in the dictionary

required
hidden_dim int

Dimensionality of the hidden space

required
bias bool

Whether to include a bias term in the embedding

False
initializer FeatureDictionaryInitializer | None

Initializer function to use. If None, the embeddings are initialized to random unit vectors; by default, the embeddings are orthogonalized.

orthogonal_initializer()
device str | device

Device to use for the feature dictionary.

'cpu'
Source code in sae_lens/synthetic/feature_dictionary.py
def __init__(
    self,
    num_features: int,
    hidden_dim: int,
    bias: bool = False,
    initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
    device: str | torch.device = "cpu",
):
    """
    Create a new FeatureDictionary.

    Args:
        num_features: Number of features in the dictionary
        hidden_dim: Dimensionality of the hidden space
        bias: Whether to include a bias term in the embedding
        initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
        device: Device to use for the feature dictionary.
    """
    super().__init__()
    self.num_features = num_features
    self.hidden_dim = hidden_dim

    # Initialize feature vectors as unit vectors
    embeddings = torch.randn(num_features, hidden_dim, device=device)
    embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
        min=1e-8
    )
    self.feature_vectors = nn.Parameter(embeddings)

    # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
    self.bias = nn.Parameter(
        torch.zeros(hidden_dim, device=device), requires_grad=bias
    )

    if initializer is not None:
        initializer(self)

forward(feature_activations)

Convert feature activations to hidden activations.

Parameters:

Name Type Description Default
feature_activations Tensor

Tensor of shape [batch, num_features] containing sparse feature activation values. Can be dense or sparse COO.

required

Returns:

Type Description
Tensor

Tensor of shape [batch, hidden_dim] containing dense hidden activations

Source code in sae_lens/synthetic/feature_dictionary.py
def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
    """
    Convert feature activations to hidden activations.

    Args:
        feature_activations: Tensor of shape [batch, num_features] containing
            sparse feature activation values. Can be dense or sparse COO.

    Returns:
        Tensor of shape [batch, hidden_dim] containing dense hidden activations
    """
    if feature_activations.is_sparse:
        # autocast is disabled here because sparse matmul is not supported with bfloat16
        with torch.autocast(
            device_type=feature_activations.device.type, enabled=False
        ):
            return (
                torch.sparse.mm(feature_activations, self.feature_vectors)
                + self.bias
            )
    return feature_activations @ self.feature_vectors + self.bias

HierarchyNode

Represents a node in a feature hierarchy tree.

Used to define hierarchical dependencies between features. Children are deactivated when their parent is inactive, and children can optionally be mutually exclusive.

Use hierarchy_modifier() to create an ActivationsModifier from one or more HierarchyNode trees.

Attributes:

Name Type Description
feature_index int | None

Index of this feature in the activation tensor

children Sequence[HierarchyNode]

Child HierarchyNode nodes

mutually_exclusive_children

If True, at most one child is active per sample

feature_id

Optional identifier for debugging

Source code in sae_lens/synthetic/hierarchy.py
class HierarchyNode:
    """
    Represents a node in a feature hierarchy tree.

    Used to define hierarchical dependencies between features. Children are
    deactivated when their parent is inactive, and children can optionally
    be mutually exclusive.

    Use `hierarchy_modifier()` to create an ActivationsModifier from one or
    more HierarchyNode trees.


    Attributes:
        feature_index: Index of this feature in the activation tensor
        children: Child HierarchyNode nodes
        mutually_exclusive_children: If True, at most one child is active per sample
        feature_id: Optional identifier for debugging
    """

    children: Sequence[HierarchyNode]
    feature_index: int | None

    @classmethod
    def from_dict(cls, tree_dict: dict[str, Any]) -> HierarchyNode:
        """
        Create a HierarchyNode from a dictionary specification.

        Args:
            tree_dict: Dictionary with keys:

                - feature_index (optional): Index in the activation tensor
                - children (optional): List of child tree dictionaries
                - mutually_exclusive_children (optional): Whether children are exclusive
                - id (optional): Identifier for this node

        Returns:
            HierarchyNode instance
        """
        children = [
            HierarchyNode.from_dict(child_dict)
            for child_dict in tree_dict.get("children", [])
        ]
        return cls(
            feature_index=tree_dict.get("feature_index"),
            children=children,
            mutually_exclusive_children=tree_dict.get(
                "mutually_exclusive_children", False
            ),
            feature_id=tree_dict.get("id"),
        )

    def __init__(
        self,
        feature_index: int | None = None,
        children: Sequence[HierarchyNode] | None = None,
        mutually_exclusive_children: bool = False,
        feature_id: str | None = None,
    ):
        """
        Create a new HierarchyNode.

        Args:
            feature_index: Index of this feature in the activation tensor.
                Use None for organizational nodes that don't correspond to a feature.
            children: Child nodes that depend on this feature
            mutually_exclusive_children: If True, only one child can be active per sample
            feature_id: Optional identifier for debugging
        """
        self.feature_index = feature_index
        self.children = children or []
        self.mutually_exclusive_children = mutually_exclusive_children
        self.feature_id = feature_id

        if self.mutually_exclusive_children and len(self.children) < 2:
            raise ValueError("Need at least 2 children for mutual exclusion")

    def get_all_feature_indices(self) -> list[int]:
        """Get all feature indices in this subtree."""
        indices = []
        if self.feature_index is not None:
            indices.append(self.feature_index)
        for child in self.children:
            indices.extend(child.get_all_feature_indices())
        return indices

    def validate(self) -> None:
        """
        Validate the hierarchy structure.

        Checks that:
        1. There are no loops (no node is its own ancestor)
        2. Each node has at most one parent (no node appears in multiple children lists)

        Raises:
            ValueError: If the hierarchy is invalid
        """
        _validate_hierarchy([self])

    def __repr__(self, indent: int = 0) -> str:
        s = " " * (indent * 2)
        s += str(self.feature_index) if self.feature_index is not None else "-"
        s += "x" if self.mutually_exclusive_children else " "
        if self.feature_id:
            s += f" ({self.feature_id})"

        for child in self.children:
            s += "\n" + child.__repr__(indent + 2)
        return s
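
A small sketch of building and validating a hierarchy: feature 0 is a parent whose children (features 1 and 2) are mutually exclusive. hierarchy_modifier() (mentioned above) would turn such trees into an ActivationsModifier for ActivationGenerator:

from sae_lens.synthetic.hierarchy import HierarchyNode

root = HierarchyNode(
    feature_index=0,
    mutually_exclusive_children=True,
    children=[HierarchyNode(feature_index=1), HierarchyNode(feature_index=2)],
)
root.validate()                         # raises ValueError if the tree is malformed
print(root.get_all_feature_indices())   # [0, 1, 2]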

__init__(feature_index=None, children=None, mutually_exclusive_children=False, feature_id=None)

Create a new HierarchyNode.

Parameters:

Name Type Description Default
feature_index int | None

Index of this feature in the activation tensor. Use None for organizational nodes that don't correspond to a feature.

None
children Sequence[HierarchyNode] | None

Child nodes that depend on this feature

None
mutually_exclusive_children bool

If True, only one child can be active per sample

False
feature_id str | None

Optional identifier for debugging

None
Source code in sae_lens/synthetic/hierarchy.py
def __init__(
    self,
    feature_index: int | None = None,
    children: Sequence[HierarchyNode] | None = None,
    mutually_exclusive_children: bool = False,
    feature_id: str | None = None,
):
    """
    Create a new HierarchyNode.

    Args:
        feature_index: Index of this feature in the activation tensor.
            Use None for organizational nodes that don't correspond to a feature.
        children: Child nodes that depend on this feature
        mutually_exclusive_children: If True, only one child can be active per sample
        feature_id: Optional identifier for debugging
    """
    self.feature_index = feature_index
    self.children = children or []
    self.mutually_exclusive_children = mutually_exclusive_children
    self.feature_id = feature_id

    if self.mutually_exclusive_children and len(self.children) < 2:
        raise ValueError("Need at least 2 children for mutual exclusion")

from_dict(tree_dict) classmethod

Create a HierarchyNode from a dictionary specification.

Parameters:

Name Type Description Default
tree_dict dict[str, Any]

Dictionary with keys:

  • feature_index (optional): Index in the activation tensor
  • children (optional): List of child tree dictionaries
  • mutually_exclusive_children (optional): Whether children are exclusive
  • id (optional): Identifier for this node
required

Returns:

Type Description
HierarchyNode

HierarchyNode instance

Source code in sae_lens/synthetic/hierarchy.py
@classmethod
def from_dict(cls, tree_dict: dict[str, Any]) -> HierarchyNode:
    """
    Create a HierarchyNode from a dictionary specification.

    Args:
        tree_dict: Dictionary with keys:

            - feature_index (optional): Index in the activation tensor
            - children (optional): List of child tree dictionaries
            - mutually_exclusive_children (optional): Whether children are exclusive
            - id (optional): Identifier for this node

    Returns:
        HierarchyNode instance
    """
    children = [
        HierarchyNode.from_dict(child_dict)
        for child_dict in tree_dict.get("children", [])
    ]
    return cls(
        feature_index=tree_dict.get("feature_index"),
        children=children,
        mutually_exclusive_children=tree_dict.get(
            "mutually_exclusive_children", False
        ),
        feature_id=tree_dict.get("id"),
    )

get_all_feature_indices()

Get all feature indices in this subtree.

Source code in sae_lens/synthetic/hierarchy.py
def get_all_feature_indices(self) -> list[int]:
    """Get all feature indices in this subtree."""
    indices = []
    if self.feature_index is not None:
        indices.append(self.feature_index)
    for child in self.children:
        indices.extend(child.get_all_feature_indices())
    return indices

validate()

Validate the hierarchy structure.

Checks that:

  1. There are no loops (no node is its own ancestor)
  2. Each node has at most one parent (no node appears in multiple children lists)

Raises:

Type Description
ValueError

If the hierarchy is invalid

Source code in sae_lens/synthetic/hierarchy.py
def validate(self) -> None:
    """
    Validate the hierarchy structure.

    Checks that:
    1. There are no loops (no node is its own ancestor)
    2. Each node has at most one parent (no node appears in multiple children lists)

    Raises:
        ValueError: If the hierarchy is invalid
    """
    _validate_hierarchy([self])

LowRankCorrelationMatrix

Bases: NamedTuple

Low-rank representation of a correlation matrix for scalable correlated sampling.

The correlation structure is represented as

correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)

This requires O(num_features * rank) storage instead of O(num_features^2), making it suitable for very large numbers of features (e.g., 1M+).

Attributes:

Name Type Description
correlation_factor Tensor

Factor matrix of shape (num_features, rank) that captures correlations through shared latent factors.

correlation_diag Tensor

Diagonal variance term of shape (num_features,). Should be chosen such that the diagonal of the full correlation matrix equals 1. Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)

Source code in sae_lens/synthetic/correlation.py
class LowRankCorrelationMatrix(NamedTuple):
    """
    Low-rank representation of a correlation matrix for scalable correlated sampling.

    The correlation structure is represented as:
        correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)

    This requires O(num_features * rank) storage instead of O(num_features^2),
    making it suitable for very large numbers of features (e.g., 1M+).

    Attributes:
        correlation_factor: Factor matrix of shape (num_features, rank) that captures
            correlations through shared latent factors.
        correlation_diag: Diagonal variance term of shape (num_features,). Should be
            chosen such that the diagonal of the full correlation matrix equals 1.
            Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)
    """

    correlation_factor: torch.Tensor
    correlation_diag: torch.Tensor
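
A sketch constructing a rank-4 correlation structure over 1,000 features, with the diagonal term chosen as the docstring suggests so the implied correlation matrix has ones on its diagonal; the result can be passed as correlation_matrix to ActivationGenerator:

import torch

from sae_lens.synthetic.correlation import LowRankCorrelationMatrix

num_features, rank = 1_000, 4
factor = 0.1 * torch.randn(num_features, rank)            # shared latent factors
diag = 1.0 - (factor**2).sum(dim=1)                       # makes the implied diagonal equal 1
low_rank_corr = LowRankCorrelationMatrix(correlation_factor=factor, correlation_diag=diag)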

SuperpositionStats dataclass

Statistics measuring superposition in a feature dictionary.

Source code in sae_lens/synthetic/stats.py
@dataclass
class SuperpositionStats:
    """Statistics measuring superposition in a feature dictionary."""

    # Per-latent statistics: for each latent, max and percentile of |cos_sim| with others
    max_abs_cos_sims: torch.Tensor  # Shape: (num_features,)
    percentile_abs_cos_sims: dict[int, torch.Tensor]  # {percentile: (num_features,)}

    # Summary statistics (means of the per-latent values)
    mean_max_abs_cos_sim: float
    mean_percentile_abs_cos_sim: dict[int, float]
    mean_abs_cos_sim: float  # Mean |cos_sim| across all pairs

    # Metadata
    num_features: int
    hidden_dim: int

SyntheticActivationIterator

Bases: Iterator[Tensor]

An iterator that generates synthetic activations for SAE training.

This iterator wraps a FeatureDictionary and a function that generates feature activations, producing hidden activations that can be used to train an SAE.

Source code in sae_lens/synthetic/training.py
class SyntheticActivationIterator(Iterator[torch.Tensor]):
    """
    An iterator that generates synthetic activations for SAE training.

    This iterator wraps a FeatureDictionary and a function that generates
    feature activations, producing hidden activations that can be used
    to train an SAE.
    """

    def __init__(
        self,
        feature_dict: FeatureDictionary,
        activations_generator: ActivationGenerator,
        batch_size: int,
        autocast: bool = False,
    ):
        """
        Create a new SyntheticActivationIterator.

        Args:
            feature_dict: The feature dictionary to use for generating hidden activations
            activations_generator: Generator that produces feature activations
            batch_size: Number of samples per batch
            autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
        """
        self.feature_dict = feature_dict
        self.activations_generator = activations_generator
        self.batch_size = batch_size
        self.autocast = autocast

    @torch.no_grad()
    def next_batch(self) -> torch.Tensor:
        """Generate the next batch of hidden activations."""
        with torch.autocast(
            device_type=self.feature_dict.feature_vectors.device.type,
            dtype=torch.bfloat16,
            enabled=self.autocast,
        ):
            features = self.activations_generator(self.batch_size)
            return self.feature_dict(features)

    def __iter__(self) -> "SyntheticActivationIterator":
        return self

    def __next__(self) -> torch.Tensor:
        return self.next_batch()
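
A hedged end-to-end sketch wiring together the pieces documented above: sparse feature activations from an ActivationGenerator are mapped through a FeatureDictionary to produce batches of hidden activations for SAE training:

from sae_lens.synthetic.activation_generator import ActivationGenerator
from sae_lens.synthetic.feature_dictionary import FeatureDictionary
from sae_lens.synthetic.training import SyntheticActivationIterator

gen = ActivationGenerator(num_features=100, firing_probabilities=0.02)
feature_dict = FeatureDictionary(num_features=100, hidden_dim=32, initializer=None)
iterator = SyntheticActivationIterator(feature_dict, gen, batch_size=128)

batch = next(iterator)  # hidden activations, shape [128, 32]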

__init__(feature_dict, activations_generator, batch_size, autocast=False)

Create a new SyntheticActivationIterator.

Parameters:

Name Type Description Default
feature_dict FeatureDictionary

The feature dictionary to use for generating hidden activations

required
activations_generator ActivationGenerator

Generator that produces feature activations

required
batch_size int

Number of samples per batch

required
autocast bool

Whether to autocast the activations generator and feature dictionary to bfloat16.

False
Source code in sae_lens/synthetic/training.py
def __init__(
    self,
    feature_dict: FeatureDictionary,
    activations_generator: ActivationGenerator,
    batch_size: int,
    autocast: bool = False,
):
    """
    Create a new SyntheticActivationIterator.

    Args:
        feature_dict: The feature dictionary to use for generating hidden activations
        activations_generator: Generator that produces feature activations
        batch_size: Number of samples per batch
        autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
    """
    self.feature_dict = feature_dict
    self.activations_generator = activations_generator
    self.batch_size = batch_size
    self.autocast = autocast

next_batch()

Generate the next batch of hidden activations.

Source code in sae_lens/synthetic/training.py
@torch.no_grad()
def next_batch(self) -> torch.Tensor:
    """Generate the next batch of hidden activations."""
    with torch.autocast(
        device_type=self.feature_dict.feature_vectors.device.type,
        dtype=torch.bfloat16,
        enabled=self.autocast,
    ):
        features = self.activations_generator(self.batch_size)
        return self.feature_dict(features)

SyntheticDataEvalResult dataclass

Results from evaluating an SAE on synthetic data.

Source code in sae_lens/synthetic/evals.py
@dataclass
class SyntheticDataEvalResult:
    """Results from evaluating an SAE on synthetic data."""

    true_l0: float
    """Average L0 of the true feature activations"""

    sae_l0: float
    """Average L0 of the SAE's latent activations"""

    dead_latents: int
    """Number of SAE latents that never fired"""

    shrinkage: float
    """Average ratio of SAE output norm to input norm (1.0 = no shrinkage)"""

    mcc: float
    """Mean Correlation Coefficient between SAE decoder and ground truth features"""

dead_latents: int instance-attribute

Number of SAE latents that never fired

mcc: float instance-attribute

Mean Correlation Coefficient between SAE decoder and ground truth features

sae_l0: float instance-attribute

Average L0 of the SAE's latent activations

shrinkage: float instance-attribute

Average ratio of SAE output norm to input norm (1.0 = no shrinkage)

true_l0: float instance-attribute

Average L0 of the true feature activations

compute_correlation_matrix_stats(correlation_matrix)

Compute correlation statistics from a dense correlation matrix.

Parameters:

  • correlation_matrix (Tensor, required): Dense correlation matrix of shape (n, n)

Returns:

  • CorrelationMatrixStats with correlation statistics

Source code in sae_lens/synthetic/stats.py
@torch.no_grad()
def compute_correlation_matrix_stats(
    correlation_matrix: torch.Tensor,
) -> CorrelationMatrixStats:
    """Compute correlation statistics from a dense correlation matrix.

    Args:
        correlation_matrix: Dense correlation matrix of shape (n, n)

    Returns:
        CorrelationMatrixStats with correlation statistics
    """
    num_features = correlation_matrix.shape[0]

    # Extract off-diagonal elements
    mask = ~torch.eye(num_features, dtype=torch.bool, device=correlation_matrix.device)
    off_diag = correlation_matrix[mask]

    rms_correlation = (off_diag**2).mean().sqrt().item()
    mean_correlation = off_diag.mean().item()

    return CorrelationMatrixStats(
        rms_correlation=rms_correlation,
        mean_correlation=mean_correlation,
        num_features=num_features,
    )
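
Example (a minimal sketch; it assumes the function is importable from its source module sae_lens.synthetic.stats shown above):

import torch
from sae_lens.synthetic.stats import compute_correlation_matrix_stats

# A small 3x3 correlation matrix with a single correlated pair (0, 1).
corr = torch.eye(3)
corr[0, 1] = corr[1, 0] = 0.5

stats = compute_correlation_matrix_stats(corr)
print(stats.num_features)      # 3
print(stats.mean_correlation)  # mean of the six off-diagonal entries: 1.0 / 6 ≈ 0.167
print(stats.rms_correlation)   # sqrt((0.25 + 0.25) / 6) ≈ 0.289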

compute_low_rank_correlation_matrix_stats(correlation_matrix)

Compute correlation statistics from a LowRankCorrelationMatrix.

The correlation matrix is represented as:

correlation = factor @ factor.T + diag(diag_term)

The off-diagonal elements are simply factor @ factor.T (the diagonal term only affects the diagonal).

All statistics are computed efficiently in O(n*r²) time and O(r²) memory without materializing the full n×n correlation matrix.

Parameters:

  • correlation_matrix (LowRankCorrelationMatrix, required): Low-rank correlation matrix

Returns:

  • CorrelationMatrixStats with correlation statistics

Source code in sae_lens/synthetic/stats.py
@torch.no_grad()
def compute_low_rank_correlation_matrix_stats(
    correlation_matrix: LowRankCorrelationMatrix,
) -> CorrelationMatrixStats:
    """Compute correlation statistics from a LowRankCorrelationMatrix.

    The correlation matrix is represented as:
        correlation = factor @ factor.T + diag(diag_term)

    The off-diagonal elements are simply factor @ factor.T (the diagonal term
    only affects the diagonal).

    All statistics are computed efficiently in O(n*r²) time and O(r²) memory
    without materializing the full n×n correlation matrix.

    Args:
        correlation_matrix: Low-rank correlation matrix

    Returns:
        CorrelationMatrixStats with correlation statistics
    """

    factor = correlation_matrix.correlation_factor
    num_features = factor.shape[0]
    num_off_diag = num_features * (num_features - 1)

    # RMS correlation: uses ||F @ F.T||_F² = ||F.T @ F||_F²
    # This avoids computing the (num_features, num_features) matrix
    G = factor.T @ factor  # (rank, rank) - small!
    frobenius_sq = (G**2).sum()
    row_norms_sq = (factor**2).sum(dim=1)  # ||F[i]||² for each row
    diag_sq_sum = (row_norms_sq**2).sum()  # Σᵢ ||F[i]||⁴
    off_diag_sq_sum = frobenius_sq - diag_sq_sum
    rms_correlation = (off_diag_sq_sum / num_off_diag).sqrt().item()

    # Mean correlation (not absolute): sum(C) = ||col_sums(F)||², trace(C) = Σ||F[i]||²
    col_sums = factor.sum(dim=0)  # (rank,)
    sum_all = (col_sums**2).sum()  # 1ᵀ C 1
    trace_C = row_norms_sq.sum()
    mean_correlation = ((sum_all - trace_C) / num_off_diag).item()

    return CorrelationMatrixStats(
        rms_correlation=rms_correlation,
        mean_correlation=mean_correlation,
        num_features=num_features,
    )
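
Example (a minimal sketch; it assumes LowRankCorrelationMatrix and this function can be imported from the source modules listed on this page):

import torch
from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
from sae_lens.synthetic.stats import compute_low_rank_correlation_matrix_stats

num_features, rank = 1000, 8
factor = torch.randn(num_features, rank) * 0.05
diag = 1 - (factor**2).sum(dim=1)  # keeps the implied correlation matrix's diagonal at 1

low_rank = LowRankCorrelationMatrix(correlation_factor=factor, correlation_diag=diag)
stats = compute_low_rank_correlation_matrix_stats(low_rank)
print(stats.rms_correlation, stats.mean_correlation, stats.num_features)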

compute_superposition_stats(feature_dictionary, batch_size=1024, device=None, percentiles=None)

Compute superposition statistics for a feature dictionary.

Computes pairwise cosine similarities in batches to handle large dictionaries.

For each latent i, computes:

  • max |cos_sim(i, j)| over all j != i
  • kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)

Parameters:

  • feature_dictionary (FeatureDictionary, required): FeatureDictionary containing the feature vectors
  • batch_size (int, default 1024): Number of features to process per batch
  • device (str | device | None, default None): Device for computation (defaults to feature dictionary's device)
  • percentiles (list[int] | None, default None): List of percentiles to compute per latent (default: [95, 99])

Returns:

  • SuperpositionStats with superposition metrics

Source code in sae_lens/synthetic/stats.py
@torch.no_grad()
def compute_superposition_stats(
    feature_dictionary: FeatureDictionary,
    batch_size: int = 1024,
    device: str | torch.device | None = None,
    percentiles: list[int] | None = None,
) -> SuperpositionStats:
    """Compute superposition statistics for a feature dictionary.

    Computes pairwise cosine similarities in batches to handle large dictionaries.

    For each latent i, computes:

    - max |cos_sim(i, j)| over all j != i
    - kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)

    Args:
        feature_dictionary: FeatureDictionary containing the feature vectors
        batch_size: Number of features to process per batch
        device: Device for computation (defaults to feature dictionary's device)
        percentiles: List of percentiles to compute per latent (default: [95, 99])

    Returns:
        SuperpositionStats with superposition metrics
    """
    if percentiles is None:
        percentiles = [95, 99]

    feature_vectors = feature_dictionary.feature_vectors
    num_features, hidden_dim = feature_vectors.shape

    if num_features < 2:
        raise ValueError("Need at least 2 features to compute superposition stats")
    if device is None:
        device = feature_vectors.device

    # Normalize features to unit norm for cosine similarity
    features_normalized = feature_vectors.to(device).float()
    norms = torch.linalg.norm(features_normalized, dim=1, keepdim=True)
    features_normalized = features_normalized / norms.clamp(min=1e-8)

    # Track per-latent statistics
    max_abs_cos_sims = torch.zeros(num_features, device=device)
    percentile_abs_cos_sims = {
        p: torch.zeros(num_features, device=device) for p in percentiles
    }
    sum_abs_cos_sim = 0.0
    n_pairs = 0

    # Process in batches: for each batch of features, compute similarities with all others
    for i in range(0, num_features, batch_size):
        batch_end = min(i + batch_size, num_features)
        batch = features_normalized[i:batch_end]  # (batch_size, hidden_dim)

        # Compute cosine similarities with all features: (batch_size, num_features)
        cos_sims = batch @ features_normalized.T

        # Absolute cosine similarities
        abs_cos_sims = cos_sims.abs()

        # Process each latent in the batch
        for j, idx in enumerate(range(i, batch_end)):
            # Get similarities with all other features (exclude self)
            row = abs_cos_sims[j].clone()
            row[idx] = 0.0  # Exclude self for max
            max_abs_cos_sims[idx] = row.max()

            # For percentiles, exclude self and compute
            other_sims = torch.cat([abs_cos_sims[j, :idx], abs_cos_sims[j, idx + 1 :]])
            for p in percentiles:
                percentile_abs_cos_sims[p][idx] = torch.quantile(other_sims, p / 100.0)

            # Sum for mean computation (only count pairs once - with features after this one)
            sum_abs_cos_sim += abs_cos_sims[j, idx + 1 :].sum().item()
            n_pairs += num_features - idx - 1

    # Compute summary statistics
    mean_max_abs_cos_sim = max_abs_cos_sims.mean().item()
    mean_percentile_abs_cos_sim = {
        p: percentile_abs_cos_sims[p].mean().item() for p in percentiles
    }
    mean_abs_cos_sim = sum_abs_cos_sim / n_pairs if n_pairs > 0 else 0.0

    return SuperpositionStats(
        max_abs_cos_sims=max_abs_cos_sims.cpu(),
        percentile_abs_cos_sims={
            p: v.cpu() for p, v in percentile_abs_cos_sims.items()
        },
        mean_max_abs_cos_sim=mean_max_abs_cos_sim,
        mean_percentile_abs_cos_sim=mean_percentile_abs_cos_sim,
        mean_abs_cos_sim=mean_abs_cos_sim,
        num_features=num_features,
        hidden_dim=hidden_dim,
    )
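
Example (a hedged sketch; it assumes a FeatureDictionary named feature_dict already exists and that the function is importable from sae_lens.synthetic.stats):

from sae_lens.synthetic.stats import compute_superposition_stats

stats = compute_superposition_stats(feature_dict, batch_size=512, percentiles=[90, 99])
print(stats.num_features, stats.hidden_dim)
print(stats.mean_max_abs_cos_sim)         # average over latents of max |cos sim| with any other latent
print(stats.mean_percentile_abs_cos_sim)  # {90: ..., 99: ...}
print(stats.mean_abs_cos_sim)             # mean |cos sim| over all distinct pairs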

cosine_similarities(mat1, mat2=None)

Compute cosine similarities between each row of mat1 and each row of mat2.

Parameters:

  • mat1 (Tensor, required): Tensor of shape [n1, d]
  • mat2 (Tensor | None, default None): Tensor of shape [n2, d]. If not provided, mat1 is compared with itself.

Returns:

  • Tensor of shape [n1, n2] with cosine similarities

Source code in sae_lens/util.py
def cosine_similarities(
    mat1: torch.Tensor, mat2: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Compute cosine similarities between each row of mat1 and each row of mat2.

    Args:
        mat1: Tensor of shape [n1, d]
        mat2: Tensor of shape [n2, d]. If not provided, mat1 = mat2

    Returns:
        Tensor of shape [n1, n2] with cosine similarities
    """
    if mat2 is None:
        mat2 = mat1
    # Clamp norm to 1e-8 to prevent division by zero. This threshold is chosen
    # to be small enough to not affect normal vectors but large enough to avoid
    # numerical instability. Zero vectors will effectively map to zero similarity.
    mat1_normed = mat1 / mat1.norm(dim=1, keepdim=True).clamp(min=1e-8)
    mat2_normed = mat2 / mat2.norm(dim=1, keepdim=True).clamp(min=1e-8)
    return mat1_normed @ mat2_normed.T
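
Example (a minimal sketch; it assumes the function is importable from sae_lens.util, the source module shown above):

import torch
from sae_lens.util import cosine_similarities

mat1 = torch.randn(4, 16)
mat2 = torch.randn(6, 16)

sims = cosine_similarities(mat1, mat2)  # shape [4, 6], values in [-1, 1]
self_sims = cosine_similarities(mat1)   # shape [4, 4]; diagonal entries are ~1.0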

create_correlation_matrix_from_correlations(num_features, correlations=None, default_correlation=0.0)

Create a correlation matrix with specified pairwise correlations.

Note: If the resulting matrix is not positive definite, it will be adjusted to ensure validity. This adjustment may change the specified correlation values. To minimize this effect, use smaller correlation magnitudes.

Parameters:

  • num_features (int, required): Number of features
  • correlations (dict[tuple[int, int], float] | None, default None): Dict mapping (i, j) pairs to correlation values. Pairs should have i < j. Pairs not specified will use default_correlation.
  • default_correlation (float, default 0.0): Default correlation for unspecified pairs

Returns:

  • Tensor: Correlation matrix of shape (num_features, num_features)

Source code in sae_lens/synthetic/correlation.py
def create_correlation_matrix_from_correlations(
    num_features: int,
    correlations: dict[tuple[int, int], float] | None = None,
    default_correlation: float = 0.0,
) -> torch.Tensor:
    """
    Create a correlation matrix with specified pairwise correlations.

    Note: If the resulting matrix is not positive definite, it will be adjusted
    to ensure validity. This adjustment may change the specified correlation
    values. To minimize this effect, use smaller correlation magnitudes.

    Args:
        num_features: Number of features
        correlations: Dict mapping (i, j) pairs to correlation values.
            Pairs should have i < j. Pairs not specified will use default_correlation.
        default_correlation: Default correlation for unspecified pairs

    Returns:
        Correlation matrix of shape (num_features, num_features)
    """
    matrix = torch.eye(num_features) + default_correlation * (
        1 - torch.eye(num_features)
    )

    if correlations is not None:
        for (i, j), corr in correlations.items():
            matrix[i, j] = corr
            matrix[j, i] = corr

    # Ensure matrix is symmetric (numerical precision)
    matrix = (matrix + matrix.T) / 2

    # Check positive definiteness and fix if necessary
    # Use eigvalsh for symmetric matrices (returns real eigenvalues)
    eigenvals = torch.linalg.eigvalsh(matrix)
    if torch.any(eigenvals < -1e-6):
        matrix = _fix_correlation_matrix(matrix)

    return matrix
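
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.correlation, the source module shown above):

from sae_lens.synthetic.correlation import create_correlation_matrix_from_correlations

# Features 0 and 1 positively correlated, 2 and 3 anti-correlated, all other pairs at 0.
corr = create_correlation_matrix_from_correlations(
    num_features=4,
    correlations={(0, 1): 0.6, (2, 3): -0.4},
)
print(corr.diagonal())         # all ones
print(corr[0, 1], corr[1, 0])  # 0.6 on both sides (the matrix is symmetric)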

eval_sae_on_synthetic_data(sae, feature_dict, activations_generator, num_samples=100000)

Evaluate an SAE on synthetic data with known ground truth.

Parameters:

  • sae (Module, required): The SAE to evaluate. Must have encode() and decode() methods.
  • feature_dict (FeatureDictionary, required): The feature dictionary used to generate activations
  • activations_generator (ActivationGenerator, required): Generator that produces feature activations
  • num_samples (int, default 100000): Number of samples to use for evaluation

Returns:

  • SyntheticDataEvalResult containing evaluation metrics

Source code in sae_lens/synthetic/evals.py
@torch.no_grad()
def eval_sae_on_synthetic_data(
    sae: torch.nn.Module,
    feature_dict: FeatureDictionary,
    activations_generator: ActivationGenerator,
    num_samples: int = 100_000,
) -> SyntheticDataEvalResult:
    """
    Evaluate an SAE on synthetic data with known ground truth.

    Args:
        sae: The SAE to evaluate. Must have encode() and decode() methods.
        feature_dict: The feature dictionary used to generate activations
        activations_generator: Generator that produces feature activations
        num_samples: Number of samples to use for evaluation

    Returns:
        SyntheticDataEvalResult containing evaluation metrics
    """
    sae.eval()

    # Generate samples
    feature_acts = activations_generator.sample(num_samples)
    true_l0 = (feature_acts > 0).float().sum(dim=-1).mean().item()
    hidden_acts = feature_dict(feature_acts)

    # Filter out entries where no features fire
    non_zero_mask = hidden_acts.norm(dim=-1) > 0
    hidden_acts_filtered = hidden_acts[non_zero_mask]

    # Get SAE reconstructions
    sae_latents = sae.encode(hidden_acts_filtered)  # type: ignore[attr-defined]
    sae_output = sae.decode(sae_latents)  # type: ignore[attr-defined]

    sae_l0 = (sae_latents > 0).float().sum(dim=-1).mean().item()
    dead_latents = int(
        ((sae_latents == 0).sum(dim=0) == sae_latents.shape[0]).sum().item()
    )
    if hidden_acts_filtered.shape[0] == 0:
        shrinkage = 0.0
    else:
        shrinkage = (
            (
                sae_output.norm(dim=-1)
                / hidden_acts_filtered.norm(dim=-1).clamp(min=1e-8)
            )
            .mean()
            .item()
        )

    # Compute MCC between SAE decoder and ground truth features
    sae_decoder: torch.Tensor = sae.W_dec  # type: ignore[attr-defined]
    gt_features = feature_dict.feature_vectors
    mcc = mean_correlation_coefficient(sae_decoder, gt_features)

    return SyntheticDataEvalResult(
        true_l0=true_l0,
        sae_l0=sae_l0,
        dead_latents=dead_latents,
        shrinkage=shrinkage,
        mcc=mcc,
    )
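
Example (a hedged sketch; it assumes a trained SAE sae plus the feature_dict and activations_generator used to create its training data already exist):

from sae_lens.synthetic.evals import eval_sae_on_synthetic_data

result = eval_sae_on_synthetic_data(sae, feature_dict, activations_generator, num_samples=50_000)
print(f"true L0:      {result.true_l0:.2f}")
print(f"SAE L0:       {result.sae_l0:.2f}")
print(f"dead latents: {result.dead_latents}")
print(f"shrinkage:    {result.shrinkage:.3f}")  # 1.0 means no shrinkage
print(f"MCC:          {result.mcc:.3f}")        # 1.0 means perfect feature recovery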

find_best_feature_ordering(sae_features, true_features)

Find the best ordering of SAE features to match true features.

Reorders SAE features so that each SAE latent aligns with its best-matching true feature in order. This makes cosine similarity plots more interpretable.

Parameters:

  • sae_features (Tensor, required): SAE decoder weights of shape [d_sae, hidden_dim]
  • true_features (Tensor, required): True feature vectors of shape [num_features, hidden_dim]

Returns:

  • Tensor of indices that reorders sae_features for best alignment

Source code in sae_lens/synthetic/plotting.py
def find_best_feature_ordering(
    sae_features: torch.Tensor,
    true_features: torch.Tensor,
) -> torch.Tensor:
    """
    Find the best ordering of SAE features to match true features.

    Reorders SAE features so that each SAE latent aligns with its best-matching
    true feature in order. This makes cosine similarity plots more interpretable.

    Args:
        sae_features: SAE decoder weights of shape [d_sae, hidden_dim]
        true_features: True feature vectors of shape [num_features, hidden_dim]

    Returns:
        Tensor of indices that reorders sae_features for best alignment
    """
    cos_sims = cosine_similarities(sae_features, true_features)
    best_matches = torch.argmax(torch.abs(cos_sims), dim=1)
    return torch.argsort(best_matches)
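
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.plotting, the source module shown above):

import torch
from sae_lens.synthetic.plotting import find_best_feature_ordering

true_features = torch.randn(8, 32)
# Simulate an SAE whose decoder recovered the true features in a shuffled order.
perm = torch.randperm(8)
sae_features = true_features[perm]

ordering = find_best_feature_ordering(sae_features, true_features)
reordered = sae_features[ordering]  # rows now line up with true_features row-by-row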

find_best_feature_ordering_across_saes(saes, feature_dict)

Find the best feature ordering that works across multiple SAEs.

Useful for creating consistent orderings across training snapshots.

Parameters:

  • saes (Iterable[Module], required): Iterable of SAEs to consider
  • feature_dict (FeatureDictionary, required): The feature dictionary containing true features

Returns:

  • Tensor: The best ordering tensor found across all SAEs

Source code in sae_lens/synthetic/plotting.py
def find_best_feature_ordering_across_saes(
    saes: Iterable[torch.nn.Module],
    feature_dict: FeatureDictionary,
) -> torch.Tensor:
    """
    Find the best feature ordering that works across multiple SAEs.

    Useful for creating consistent orderings across training snapshots.

    Args:
        saes: Iterable of SAEs to consider
        feature_dict: The feature dictionary containing true features

    Returns:
        The best ordering tensor found across all SAEs
    """
    best_score = float("-inf")
    best_ordering: torch.Tensor | None = None

    true_features = feature_dict.feature_vectors.detach()

    for sae in saes:
        sae_features = sae.W_dec.detach()  # type: ignore[attr-defined]
        cos_sims = cosine_similarities(sae_features, true_features)
        cos_sims = torch.round(cos_sims * 100) / 100  # Reduce numerical noise

        ordering = find_best_feature_ordering(sae_features, true_features)
        score = cos_sims[ordering, torch.arange(cos_sims.shape[1])].mean().item()

        if score > best_score:
            best_score = score
            best_ordering = ordering

    if best_ordering is None:
        raise ValueError("No SAEs provided")

    return best_ordering

find_best_feature_ordering_from_sae(sae, feature_dict)

Find the best feature ordering for an SAE given a feature dictionary.

Parameters:

  • sae (Module, required): SAE with W_dec attribute of shape [d_sae, hidden_dim]
  • feature_dict (FeatureDictionary, required): The feature dictionary containing true features

Returns:

  • Tensor of indices that reorders SAE latents for best alignment

Source code in sae_lens/synthetic/plotting.py
def find_best_feature_ordering_from_sae(
    sae: torch.nn.Module,
    feature_dict: FeatureDictionary,
) -> torch.Tensor:
    """
    Find the best feature ordering for an SAE given a feature dictionary.

    Args:
        sae: SAE with W_dec attribute of shape [d_sae, hidden_dim]
        feature_dict: The feature dictionary containing true features

    Returns:
        Tensor of indices that reorders SAE latents for best alignment
    """
    sae_features = sae.W_dec.detach()  # type: ignore[attr-defined]
    true_features = feature_dict.feature_vectors.detach()
    return find_best_feature_ordering(sae_features, true_features)

generate_random_correlation_matrix(num_features, positive_ratio=0.5, uncorrelated_ratio=0.3, min_correlation_strength=0.1, max_correlation_strength=0.8, seed=None, device='cpu', dtype=torch.float32)

Generate a random correlation matrix with specified constraints.

Uses vectorized torch operations for efficiency with large numbers of features.

Note: If the randomly generated matrix is not positive definite, it will be adjusted to ensure validity. This adjustment may change correlation values, including turning some zero correlations into non-zero values. To minimize this effect, use smaller correlation strengths (e.g., 0.01-0.1).

Parameters:

  • num_features (int, required): Number of features
  • positive_ratio (float, default 0.5): Fraction of correlated pairs that should be positive (0.0 to 1.0)
  • uncorrelated_ratio (float, default 0.3): Fraction of feature pairs that should have zero correlation (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce the actual number of zero correlations.
  • min_correlation_strength (float, default 0.1): Minimum absolute correlation strength for correlated pairs
  • max_correlation_strength (float, default 0.8): Maximum absolute correlation strength for correlated pairs
  • seed (int | None, default None): Random seed for reproducibility
  • device (device | str, default 'cpu'): Device to create the matrix on
  • dtype (dtype | str, default float32): Data type for the matrix

Returns:

  • Tensor: Random correlation matrix of shape (num_features, num_features)

Source code in sae_lens/synthetic/correlation.py
def generate_random_correlation_matrix(
    num_features: int,
    positive_ratio: float = 0.5,
    uncorrelated_ratio: float = 0.3,
    min_correlation_strength: float = 0.1,
    max_correlation_strength: float = 0.8,
    seed: int | None = None,
    device: torch.device | str = "cpu",
    dtype: torch.dtype | str = torch.float32,
) -> torch.Tensor:
    """
    Generate a random correlation matrix with specified constraints.

    Uses vectorized torch operations for efficiency with large numbers of features.

    Note: If the randomly generated matrix is not positive definite, it will be
    adjusted to ensure validity. This adjustment may change correlation values,
    including turning some zero correlations into non-zero values. To minimize
    this effect, use smaller correlation strengths (e.g., 0.01-0.1).

    Args:
        num_features: Number of features
        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
            (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce
            the actual number of zero correlations.
        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
        seed: Random seed for reproducibility
        device: Device to create the matrix on
        dtype: Data type for the matrix

    Returns:
        Random correlation matrix of shape (num_features, num_features)
    """
    dtype = str_to_dtype(dtype)
    _validate_correlation_params(
        positive_ratio,
        uncorrelated_ratio,
        min_correlation_strength,
        max_correlation_strength,
    )

    if num_features <= 1:
        return torch.eye(num_features, device=device, dtype=dtype)

    # Set random seed if provided
    generator = torch.Generator(device=device)
    if seed is not None:
        generator.manual_seed(seed)

    # Get upper triangular indices (i < j)
    row_idx, col_idx = torch.triu_indices(num_features, num_features, offset=1)
    num_pairs = row_idx.shape[0]

    # Generate random values for all pairs at once
    # is_correlated: 1 if this pair should have a correlation, 0 otherwise
    is_correlated = (
        torch.rand(num_pairs, generator=generator, device=device) >= uncorrelated_ratio
    )

    # signs: +1 for positive correlation, -1 for negative
    is_positive = (
        torch.rand(num_pairs, generator=generator, device=device) < positive_ratio
    )
    signs = torch.where(is_positive, 1.0, -1.0)

    # strengths: uniform in [min_strength, max_strength]
    strengths = (
        torch.rand(num_pairs, generator=generator, device=device, dtype=dtype)
        * (max_correlation_strength - min_correlation_strength)
        + min_correlation_strength
    )

    # Combine: correlation = is_correlated * sign * strength
    correlations = is_correlated.to(dtype) * signs.to(dtype) * strengths

    # Build the symmetric matrix
    matrix = torch.eye(num_features, device=device, dtype=dtype)
    matrix[row_idx, col_idx] = correlations
    matrix[col_idx, row_idx] = correlations

    # Check positive definiteness and fix if necessary
    eigenvals = torch.linalg.eigvalsh(matrix)
    if torch.any(eigenvals < -1e-6):
        matrix = _fix_correlation_matrix(matrix)

    return matrix
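
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.correlation, the source module shown above):

import torch
from sae_lens.synthetic.correlation import generate_random_correlation_matrix

corr = generate_random_correlation_matrix(
    num_features=50,
    uncorrelated_ratio=0.8,
    min_correlation_strength=0.05,
    max_correlation_strength=0.2,
    seed=0,
)
print(torch.allclose(corr, corr.T))                     # True: the matrix is symmetric
print(torch.allclose(corr.diagonal(), torch.ones(50)))  # True, unless the positive-definiteness fix adjusted it
print(torch.linalg.eigvalsh(corr).min())                # should not be meaningfully negative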

generate_random_correlations(num_features, positive_ratio=0.5, uncorrelated_ratio=0.3, min_correlation_strength=0.1, max_correlation_strength=0.8, seed=None)

Generate random correlations between features with specified constraints.

Parameters:

  • num_features (int, required): Number of features
  • positive_ratio (float, default 0.5): Fraction of correlated pairs that should be positive (0.0 to 1.0)
  • uncorrelated_ratio (float, default 0.3): Fraction of feature pairs that should have zero correlation (0.0 to 1.0). These pairs are omitted from the returned dictionary.
  • min_correlation_strength (float, default 0.1): Minimum absolute correlation strength for correlated pairs
  • max_correlation_strength (float, default 0.8): Maximum absolute correlation strength for correlated pairs
  • seed (int | None, default None): Random seed for reproducibility

Returns:

  • dict[tuple[int, int], float]: Dictionary mapping (i, j) pairs to correlation values. Pairs with zero correlation (determined by uncorrelated_ratio) are not included.

Source code in sae_lens/synthetic/correlation.py
def generate_random_correlations(
    num_features: int,
    positive_ratio: float = 0.5,
    uncorrelated_ratio: float = 0.3,
    min_correlation_strength: float = 0.1,
    max_correlation_strength: float = 0.8,
    seed: int | None = None,
) -> dict[tuple[int, int], float]:
    """
    Generate random correlations between features with specified constraints.

    Args:
        num_features: Number of features
        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
            (0.0 to 1.0). These pairs are omitted from the returned dictionary.
        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
        seed: Random seed for reproducibility

    Returns:
        Dictionary mapping (i, j) pairs to correlation values. Pairs with zero
        correlation (determined by uncorrelated_ratio) are not included.
    """
    # Use local random number generator to avoid side effects on global state
    rng = random.Random(seed)

    _validate_correlation_params(
        positive_ratio,
        uncorrelated_ratio,
        min_correlation_strength,
        max_correlation_strength,
    )

    # Generate all possible feature pairs (i, j) where i < j
    all_pairs = [
        (i, j) for i in range(num_features) for j in range(i + 1, num_features)
    ]
    total_pairs = len(all_pairs)

    if total_pairs == 0:
        return {}

    # Determine how many pairs to correlate vs leave uncorrelated
    num_uncorrelated = int(total_pairs * uncorrelated_ratio)
    num_correlated = total_pairs - num_uncorrelated

    # Randomly select which pairs to correlate
    correlated_pairs = rng.sample(all_pairs, num_correlated)

    # For correlated pairs, determine positive vs negative
    num_positive = int(num_correlated * positive_ratio)
    num_negative = num_correlated - num_positive

    # Assign signs
    signs = [1] * num_positive + [-1] * num_negative
    rng.shuffle(signs)

    # Generate correlation strengths
    correlations = {}
    for pair, sign in zip(correlated_pairs, signs):
        # Sample correlation strength uniformly from range
        strength = rng.uniform(min_correlation_strength, max_correlation_strength)
        correlations[pair] = sign * strength

    return correlations
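
Example (a minimal sketch; it assumes both functions are importable from sae_lens.synthetic.correlation, the source module shown above):

from sae_lens.synthetic.correlation import (
    create_correlation_matrix_from_correlations,
    generate_random_correlations,
)

# Draw sparse pairwise correlations, then build a dense matrix from them.
pairs = generate_random_correlations(
    num_features=20,
    uncorrelated_ratio=0.9,
    min_correlation_strength=0.05,
    max_correlation_strength=0.15,
    seed=42,
)
print(len(pairs))  # 19 correlated pairs out of the 190 possible for these settings
corr = create_correlation_matrix_from_correlations(num_features=20, correlations=pairs)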

generate_random_low_rank_correlation_matrix(num_features, rank, correlation_scale=0.075, seed=None, device='cpu', dtype=torch.float32)

Generate a random low-rank correlation structure for scalable correlated sampling.

The correlation structure is represented as:

correlation = factor @ factor.T + diag(diag_term)

This requires O(num_features * rank) storage instead of O(num_features^2), making it suitable for very large numbers of features (e.g., 1M+).

The factor matrix is initialized with random values scaled by correlation_scale, and the diagonal term is computed to ensure the implied correlation matrix has unit diagonal.

Parameters:

  • num_features (int, required): Number of features
  • rank (int, required): Rank of the low-rank approximation. Higher rank allows more complex correlation structures but uses more memory. Typical values: 10-100.
  • correlation_scale (float, default 0.075): Scale factor for random correlations. Larger values produce stronger correlations between features. Use 0 for no correlations (identity matrix). Should be small enough that rank * correlation_scale^2 < 1 to ensure valid diagonal terms.
  • seed (int | None, default None): Random seed for reproducibility
  • device (device | str, default 'cpu'): Device to create tensors on
  • dtype (dtype | str, default float32): Data type for tensors

Returns:

  • LowRankCorrelationMatrix containing the factor matrix and diagonal term

Source code in sae_lens/synthetic/correlation.py
def generate_random_low_rank_correlation_matrix(
    num_features: int,
    rank: int,
    correlation_scale: float = 0.075,
    seed: int | None = None,
    device: torch.device | str = "cpu",
    dtype: torch.dtype | str = torch.float32,
) -> LowRankCorrelationMatrix:
    """
    Generate a random low-rank correlation structure for scalable correlated sampling.

    The correlation structure is represented as:
        correlation = factor @ factor.T + diag(diag_term)

    This requires O(num_features * rank) storage instead of O(num_features^2),
    making it suitable for very large numbers of features (e.g., 1M+).

    The factor matrix is initialized with random values scaled by correlation_scale,
    and the diagonal term is computed to ensure the implied correlation matrix has
    unit diagonal.

    Args:
        num_features: Number of features
        rank: Rank of the low-rank approximation. Higher rank allows more complex
            correlation structures but uses more memory. Typical values: 10-100.
        correlation_scale: Scale factor for random correlations. Larger values produce
            stronger correlations between features. Use 0 for no correlations (identity
            matrix). Should be small enough that rank * correlation_scale^2 < 1 to
            ensure valid diagonal terms.
        seed: Random seed for reproducibility
        device: Device to create tensors on
        dtype: Data type for tensors

    Returns:
        LowRankCorrelationMatrix containing the factor matrix and diagonal term
    """
    # Minimum diagonal value to ensure numerical stability in the covariance matrix.
    # This limits how much variance can come from the low-rank factor.
    _MIN_DIAG = 0.01

    dtype = str_to_dtype(dtype)
    device = torch.device(device)

    if rank <= 0:
        raise ValueError("rank must be positive")
    if correlation_scale < 0:
        raise ValueError("correlation_scale must be non-negative")

    # Set random seed if provided
    generator = torch.Generator(device=device)
    if seed is not None:
        generator.manual_seed(seed)

    # Generate random factor matrix
    # Each row has norm roughly sqrt(rank) * correlation_scale
    factor = (
        torch.randn(num_features, rank, generator=generator, device=device, dtype=dtype)
        * correlation_scale
    )

    # Compute diagonal term to ensure unit diagonal in implied correlation matrix
    # diag(factor @ factor.T) + diag_term = 1
    # diag_term = 1 - sum(factor[i, :]^2)
    factor_sq_sum = (factor**2).sum(dim=1)
    diag_term = 1 - factor_sq_sum

    # For rows where the diagonal term would fall below _MIN_DIAG, rescale that row so
    # its remaining diagonal term is exactly _MIN_DIAG (keeping the implied matrix valid)
    mask = diag_term < _MIN_DIAG
    factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
    factor_sq_sum = (factor**2).sum(dim=1)
    diag_term = 1 - factor_sq_sum

    total_rescaled = mask.sum().item()
    if total_rescaled > 0:
        logger.warning(
            f"{total_rescaled} / {num_features} rows were capped. Either reduce the rank or reduce the correlation_scale to avoid rescaling."
        )

    return LowRankCorrelationMatrix(
        correlation_factor=factor, correlation_diag=diag_term
    )
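
Example (a minimal sketch; it assumes the functions are importable from the source modules listed on this page):

from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix
from sae_lens.synthetic.stats import compute_low_rank_correlation_matrix_stats

low_rank = generate_random_low_rank_correlation_matrix(
    num_features=100_000,
    rank=32,
    correlation_scale=0.05,  # rank * scale^2 = 0.08, comfortably below 1
    seed=0,
)
print(low_rank.correlation_factor.shape)  # torch.Size([100000, 32])
print(low_rank.correlation_diag.shape)    # torch.Size([100000])

# Stats are computed without ever materializing the 100k x 100k matrix.
stats = compute_low_rank_correlation_matrix_stats(low_rank)
print(stats.rms_correlation)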

hierarchy_modifier(roots)

Create an activations modifier from one or more hierarchy trees.

This is the recommended way to use hierarchies with ActivationGenerator. It validates the hierarchy structure and returns a modifier function that applies all hierarchy constraints.

Parameters:

  • roots (Sequence[HierarchyNode] | HierarchyNode, required): One or more root HierarchyNode objects. Each root defines an independent hierarchy tree. All trees are validated and applied.

Returns:

  • ActivationsModifier: A modifier function that can be passed to ActivationGenerator.

Raises:

  • ValueError: If any hierarchy contains loops or nodes with multiple parents.

Source code in sae_lens/synthetic/hierarchy.py
@torch.no_grad()
def hierarchy_modifier(
    roots: Sequence[HierarchyNode] | HierarchyNode,
) -> ActivationsModifier:
    """
    Create an activations modifier from one or more hierarchy trees.

    This is the recommended way to use hierarchies with ActivationGenerator.
    It validates the hierarchy structure and returns a modifier function that
    applies all hierarchy constraints.

    Args:
        roots: One or more root HierarchyNode objects. Each root defines an
            independent hierarchy tree. All trees are validated and applied.

    Returns:
        An ActivationsModifier function that can be passed to ActivationGenerator.

    Raises:
        ValueError: If any hierarchy contains loops or nodes with multiple parents.
    """
    if not roots:
        # No hierarchies - return identity function
        def identity(activations: torch.Tensor) -> torch.Tensor:
            return activations

        return identity

    if isinstance(roots, HierarchyNode):
        roots = [roots]
    _validate_hierarchy(roots)

    # Build sparse hierarchy data
    sparse_data = _build_sparse_hierarchy(roots)

    # Cache for device-specific tensors
    device_cache: dict[torch.device, _SparseHierarchyData] = {}

    def _get_sparse_for_device(device: torch.device) -> _SparseHierarchyData:
        """Get or create device-specific sparse hierarchy data."""
        if device not in device_cache:
            device_cache[device] = _SparseHierarchyData(
                level_data=[
                    _LevelData(
                        features=ld.features.to(device),
                        parents=ld.parents.to(device),
                        me_group_indices=ld.me_group_indices.to(device),
                    )
                    for ld in sparse_data.level_data
                ],
                me_group_siblings=sparse_data.me_group_siblings.to(device),
                me_group_sizes=sparse_data.me_group_sizes.to(device),
                me_group_parents=sparse_data.me_group_parents.to(device),
                num_groups=sparse_data.num_groups,
                feat_to_parent=(
                    sparse_data.feat_to_parent.to(device)
                    if sparse_data.feat_to_parent is not None
                    else None
                ),
                feat_to_me_group=(
                    sparse_data.feat_to_me_group.to(device)
                    if sparse_data.feat_to_me_group is not None
                    else None
                ),
            )
        return device_cache[device]

    def modifier(activations: torch.Tensor) -> torch.Tensor:
        device = activations.device
        cached = _get_sparse_for_device(device)
        if activations.is_sparse:
            return _apply_hierarchy_sparse_coo(activations, cached)
        return _apply_hierarchy_sparse(activations, cached)

    return modifier

init_sae_to_match_feature_dict(sae, feature_dict, noise_level=0.0, feature_ordering=None)

Initialize an SAE's weights to match a feature dictionary.

This can be useful for:

  • Starting training from a known good initialization
  • Testing SAE evaluation code with ground truth
  • Ablation studies on initialization

Parameters:

  • sae (Module, required): The SAE to initialize. Must have W_enc and W_dec attributes.
  • feature_dict (FeatureDictionary, required): The feature dictionary to match
  • noise_level (float, default 0.0): Standard deviation of Gaussian noise to add (0 = exact match)
  • feature_ordering (Tensor | None, default None): Optional permutation of feature indices
Source code in sae_lens/synthetic/initialization.py
@torch.no_grad()
def init_sae_to_match_feature_dict(
    sae: torch.nn.Module,
    feature_dict: FeatureDictionary,
    noise_level: float = 0.0,
    feature_ordering: torch.Tensor | None = None,
) -> None:
    """
    Initialize an SAE's weights to match a feature dictionary.

    This can be useful for:

    - Starting training from a known good initialization
    - Testing SAE evaluation code with ground truth
    - Ablation studies on initialization

    Args:
        sae: The SAE to initialize. Must have W_enc and W_dec attributes.
        feature_dict: The feature dictionary to match
        noise_level: Standard deviation of Gaussian noise to add (0 = exact match)
        feature_ordering: Optional permutation of feature indices
    """
    features = feature_dict.feature_vectors  # [num_features, hidden_dim]
    min_dim = min(sae.W_enc.shape[1], features.shape[0])  # type: ignore[attr-defined]

    if feature_ordering is not None:
        features = features[feature_ordering]

    features = features[:min_dim]

    # W_enc is [hidden_dim, d_sae], feature vectors are [num_features, hidden_dim]
    sae.W_enc.data[:, :min_dim] = (  # type: ignore[index]
        features.T + torch.randn_like(features.T) * noise_level
    )
    sae.W_dec.data = sae.W_enc.data.T.clone()  # type: ignore[union-attr]

linear_firing_probabilities(num_features, max_prob=0.3, min_prob=0.01)

Generate firing probabilities that decay linearly from max to min.

Parameters:

  • num_features (int, required): Number of features to generate probabilities for
  • max_prob (float, default 0.3): Firing probability for the first feature
  • min_prob (float, default 0.01): Firing probability for the last feature

Returns:

  • Tensor of shape [num_features] with linearly decaying probabilities

Source code in sae_lens/synthetic/firing_probabilities.py
def linear_firing_probabilities(
    num_features: int,
    max_prob: float = 0.3,
    min_prob: float = 0.01,
) -> torch.Tensor:
    """
    Generate firing probabilities that decay linearly from max to min.

    Args:
        num_features: Number of features to generate probabilities for
        max_prob: Firing probability for the first feature
        min_prob: Firing probability for the last feature

    Returns:
        Tensor of shape [num_features] with linearly decaying probabilities
    """
    if num_features < 1:
        raise ValueError("num_features must be at least 1")
    if not 0 < min_prob <= max_prob <= 1:
        raise ValueError("Must have 0 < min_prob <= max_prob <= 1")

    if num_features == 1:
        return torch.tensor([max_prob])

    return torch.linspace(max_prob, min_prob, num_features)
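
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.firing_probabilities, the source module shown above):

from sae_lens.synthetic.firing_probabilities import linear_firing_probabilities

probs = linear_firing_probabilities(num_features=5, max_prob=0.3, min_prob=0.1)
print(probs)  # tensor([0.3000, 0.2500, 0.2000, 0.1500, 0.1000])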

mean_correlation_coefficient(features_a, features_b)

Compute Mean Correlation Coefficient (MCC) between two sets of feature vectors.

MCC measures how well learned features align with ground truth features by finding an optimal one-to-one matching using the Hungarian algorithm and computing the mean absolute cosine similarity of matched pairs.

Reference: O'Neill et al. "Compute Optimal Inference and Provable Amortisation Gap in Sparse Autoencoders" (arXiv:2411.13117)

Parameters:

  • features_a (Tensor, required): Feature vectors of shape [num_features_a, hidden_dim]
  • features_b (Tensor, required): Feature vectors of shape [num_features_b, hidden_dim]

Returns:

  • float: MCC score in range [0, 1], where 1 indicates perfect alignment

Source code in sae_lens/synthetic/evals.py
def mean_correlation_coefficient(
    features_a: torch.Tensor,
    features_b: torch.Tensor,
) -> float:
    """
    Compute Mean Correlation Coefficient (MCC) between two sets of feature vectors.

    MCC measures how well learned features align with ground truth features by finding
    an optimal one-to-one matching using the Hungarian algorithm and computing the
    mean absolute cosine similarity of matched pairs.

    Reference: O'Neill et al. "Compute Optimal Inference and Provable Amortisation
    Gap in Sparse Autoencoders" (arXiv:2411.13117)

    Args:
        features_a: Feature vectors of shape [num_features_a, hidden_dim]
        features_b: Feature vectors of shape [num_features_b, hidden_dim]

    Returns:
        MCC score in range [0, 1], where 1 indicates perfect alignment
    """
    # Normalize to unit vectors
    a_norm = features_a / features_a.norm(dim=1, keepdim=True).clamp(min=1e-8)
    b_norm = features_b / features_b.norm(dim=1, keepdim=True).clamp(min=1e-8)

    # Compute absolute cosine similarity matrix
    cos_sim = torch.abs(a_norm @ b_norm.T)

    # Convert to cost matrix for Hungarian algorithm (which minimizes)
    cost_matrix = 1 - cos_sim.cpu().numpy()

    # Find optimal matching
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    # Compute mean of matched similarities
    matched_similarities = cos_sim[row_ind, col_ind]
    return matched_similarities.mean().item()
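
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.evals, the source module shown above):

import torch
from sae_lens.synthetic.evals import mean_correlation_coefficient

true_features = torch.randn(16, 64)
# A "learned" dictionary that recovered every feature, just permuted and sign-flipped.
learned = -true_features[torch.randperm(16)]

print(mean_correlation_coefficient(learned, true_features))              # ~1.0
print(mean_correlation_coefficient(torch.randn(16, 64), true_features))  # much lower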

orthogonalize_embeddings(embeddings, num_steps=200, lr=0.01, show_progress=False, chunk_size=1024)

Orthogonalize embeddings using gradient descent with chunked computation.

Uses chunked computation to avoid O(n²) memory usage when computing pairwise dot products. Memory usage is O(chunk_size × n) instead of O(n²).

Parameters:

  • embeddings (Tensor, required): Tensor of shape [num_vectors, hidden_dim]
  • num_steps (int, default 200): Number of optimization steps
  • lr (float, default 0.01): Learning rate for Adam optimizer
  • show_progress (bool, default False): Whether to show progress bar
  • chunk_size (int, default 1024): Number of vectors to process at once. Smaller values use less memory but may be slower.

Returns:

  • Tensor: Orthogonalized embeddings of the same shape, normalized to unit length.

Source code in sae_lens/synthetic/feature_dictionary.py
def orthogonalize_embeddings(
    embeddings: torch.Tensor,
    num_steps: int = 200,
    lr: float = 0.01,
    show_progress: bool = False,
    chunk_size: int = 1024,
) -> torch.Tensor:
    """
    Orthogonalize embeddings using gradient descent with chunked computation.

    Uses chunked computation to avoid O(n²) memory usage when computing pairwise
    dot products. Memory usage is O(chunk_size × n) instead of O(n²).

    Args:
        embeddings: Tensor of shape [num_vectors, hidden_dim]
        num_steps: Number of optimization steps
        lr: Learning rate for Adam optimizer
        show_progress: Whether to show progress bar
        chunk_size: Number of vectors to process at once. Smaller values use less
            memory but may be slower.

    Returns:
        Orthogonalized embeddings of the same shape, normalized to unit length.
    """
    num_vectors = embeddings.shape[0]
    # Create a detached copy and normalize, then enable gradients
    embeddings = embeddings.detach().clone()
    embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
    embeddings = embeddings.requires_grad_(True)

    optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]

    pbar = tqdm(
        range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
    )
    for _ in pbar:
        optimizer.zero_grad()

        off_diag_loss = torch.tensor(0.0, device=embeddings.device)
        diag_loss = torch.tensor(0.0, device=embeddings.device)

        for i in range(0, num_vectors, chunk_size):
            end_i = min(i + chunk_size, num_vectors)
            chunk = embeddings[i:end_i]
            chunk_dots = chunk @ embeddings.T  # [chunk_size, num_vectors]

            # Create mask to zero out diagonal elements for this chunk
            # Diagonal of full matrix: position (i+k, i+k) → in chunk_dots: (k, i+k)
            chunk_len = end_i - i
            row_indices = torch.arange(chunk_len, device=embeddings.device)
            col_indices = i + row_indices  # column indices in full matrix

            # Boolean mask: True for off-diagonal elements we want to include
            off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
            off_diag_mask[row_indices, col_indices] = False

            off_diag_loss = off_diag_loss + chunk_dots[off_diag_mask].pow(2).sum()

            # Diagonal loss: keep self-dot-products at 1
            diag_vals = chunk_dots[row_indices, col_indices]
            diag_loss = diag_loss + (diag_vals - 1).pow(2).sum()

        loss = off_diag_loss + num_vectors * diag_loss
        loss.backward()
        optimizer.step()
        pbar.set_description(f"loss: {loss.item():.3f}")

    with torch.no_grad():
        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
            min=1e-8
        )
    return embeddings.detach().clone()
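
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.feature_dictionary, the source module shown above):

import torch
from sae_lens.synthetic.feature_dictionary import orthogonalize_embeddings

vectors = torch.randn(512, 64)  # more vectors than dimensions, so exact orthogonality is impossible
ortho = orthogonalize_embeddings(vectors, num_steps=100, lr=0.01)

def max_offdiag_cos(x: torch.Tensor) -> float:
    # Largest absolute cosine similarity between any two distinct rows.
    normed = x / x.norm(dim=1, keepdim=True).clamp(min=1e-8)
    sims = normed @ normed.T
    sims.fill_diagonal_(0)
    return sims.abs().max().item()

print(f"before: {max_offdiag_cos(vectors):.3f}, after: {max_offdiag_cos(ortho):.3f}")  # should decrease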

plot_sae_feature_similarity(sae, feature_dict, title=None, reorder_features=False, decoder_only=False, show_values=False, height=400, width=800, save_path=None, show_plot=True, dtick=1, scale=1.0)

Plot cosine similarities between SAE features and true features.

Creates a heatmap showing how well each SAE latent aligns with each true feature. Useful for understanding what the SAE has learned.

Parameters:

  • sae (SAE[Any], required): The SAE to visualize. Must have W_enc and W_dec attributes.
  • feature_dict (FeatureDictionary, required): The feature dictionary containing true features
  • title (str | None, default None): Plot title. If None, a default title is used.
  • reorder_features (bool | Tensor, default False): If True, automatically reorders features for best alignment. If a tensor, uses that as the ordering.
  • decoder_only (bool, default False): If True, only plots the decoder (not encoder and decoder side-by-side)
  • show_values (bool, default False): If True, shows numeric values on the heatmap
  • height (int, default 400): Height of the figure in pixels
  • width (int, default 800): Width of the figure in pixels
  • save_path (str | Path | None, default None): If provided, saves the figure to this path
  • show_plot (bool, default True): If True, displays the plot
  • dtick (int | None, default 1): Tick spacing for axes
  • scale (float, default 1.0): Scale factor for image resolution when saving
Source code in sae_lens/synthetic/plotting.py
def plot_sae_feature_similarity(
    sae: SAE[Any],
    feature_dict: FeatureDictionary,
    title: str | None = None,
    reorder_features: bool | torch.Tensor = False,
    decoder_only: bool = False,
    show_values: bool = False,
    height: int = 400,
    width: int = 800,
    save_path: str | Path | None = None,
    show_plot: bool = True,
    dtick: int | None = 1,
    scale: float = 1.0,
):
    """
    Plot cosine similarities between SAE features and true features.

    Creates a heatmap showing how well each SAE latent aligns with each
    true feature. Useful for understanding what the SAE has learned.

    Args:
        sae: The SAE to visualize. Must have W_enc and W_dec attributes.
        feature_dict: The feature dictionary containing true features
        title: Plot title. If None, a default title is used.
        reorder_features: If True, automatically reorders features for best alignment.
            If a tensor, uses that as the ordering.
        decoder_only: If True, only plots the decoder (not encoder and decoder side-by-side)
        show_values: If True, shows numeric values on the heatmap
        height: Height of the figure in pixels
        width: Width of the figure in pixels
        save_path: If provided, saves the figure to this path
        show_plot: If True, displays the plot
        dtick: Tick spacing for axes
        scale: Scale factor for image resolution when saving
    """
    # Get cosine similarities
    true_features = feature_dict.feature_vectors.detach()
    dec_cos_sims = cosine_similarities(sae.W_dec.detach(), true_features)  # type: ignore[attr-defined]
    enc_cos_sims = cosine_similarities(sae.W_enc.T.detach(), true_features)  # type: ignore[attr-defined]

    # Round to reduce numerical noise
    dec_cos_sims = torch.round(dec_cos_sims * 100) / 100
    enc_cos_sims = torch.round(enc_cos_sims * 100) / 100

    # Apply feature reordering if requested
    if reorder_features is not False:
        if isinstance(reorder_features, bool):
            sorted_indices = find_best_feature_ordering(
                sae.W_dec.detach(),
                true_features,  # type: ignore[attr-defined]
            )
        else:
            sorted_indices = reorder_features
        dec_cos_sims = dec_cos_sims[sorted_indices]
        enc_cos_sims = enc_cos_sims[sorted_indices]

    hovertemplate = "True feature: %{x}<br>SAE Latent: %{y}<br>Cosine Similarity: %{z:.3f}<extra></extra>"

    if decoder_only:
        fig = make_subplots(rows=1, cols=1)

        decoder_args: dict[str, Any] = {
            "z": dec_cos_sims.cpu().numpy(),
            "zmin": -1,
            "zmax": 1,
            "colorscale": "RdBu",
            "colorbar": dict(title="cos sim", x=1.0, dtick=1, tickvals=[-1, 0, 1]),
            "hovertemplate": hovertemplate,
        }
        if show_values:
            decoder_args["texttemplate"] = "%{z:.2f}"
            decoder_args["textfont"] = {"size": 10}

        fig.add_trace(go.Heatmap(**decoder_args), row=1, col=1)
        fig.update_xaxes(title_text="True feature", row=1, col=1, dtick=dtick)
        fig.update_yaxes(title_text="SAE Latent", row=1, col=1, dtick=dtick)
    else:
        fig = make_subplots(
            rows=1, cols=2, subplot_titles=("SAE encoder", "SAE decoder")
        )

        # Encoder heatmap
        encoder_args: dict[str, Any] = {
            "z": enc_cos_sims.cpu().numpy(),
            "zmin": -1,
            "zmax": 1,
            "colorscale": "RdBu",
            "showscale": False,
            "hovertemplate": hovertemplate,
        }
        if show_values:
            encoder_args["texttemplate"] = "%{z:.2f}"
            encoder_args["textfont"] = {"size": 10}

        fig.add_trace(go.Heatmap(**encoder_args), row=1, col=1)

        # Decoder heatmap
        decoder_args = {
            "z": dec_cos_sims.cpu().numpy(),
            "zmin": -1,
            "zmax": 1,
            "colorscale": "RdBu",
            "colorbar": dict(title="cos sim", x=1.0, dtick=1, tickvals=[-1, 0, 1]),
            "hovertemplate": hovertemplate,
        }
        if show_values:
            decoder_args["texttemplate"] = "%{z:.2f}"
            decoder_args["textfont"] = {"size": 10}

        fig.add_trace(go.Heatmap(**decoder_args), row=1, col=2)

        fig.update_xaxes(title_text="True feature", row=1, col=1, dtick=dtick)
        fig.update_xaxes(title_text="True feature", row=1, col=2, dtick=dtick)
        fig.update_yaxes(title_text="SAE Latent", row=1, col=1, dtick=dtick)
        fig.update_yaxes(title_text="SAE Latent", row=1, col=2, dtick=dtick)

    # Set main title
    if title is None:
        title = "Cosine similarity with true features"
    fig.update_layout(height=height, width=width, title_text=title)

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(save_path, scale=scale)

    if show_plot:
        fig.show()

random_firing_probabilities(num_features, max_prob=0.5, min_prob=0.01, seed=None)

Generate random firing probabilities uniformly sampled from a range.

Parameters:

  • num_features (int, required): Number of features to generate probabilities for
  • max_prob (float, default 0.5): Maximum firing probability
  • min_prob (float, default 0.01): Minimum firing probability
  • seed (int | None, default None): Optional random seed for reproducibility

Returns:

  • Tensor of shape [num_features] with random firing probabilities

Source code in sae_lens/synthetic/firing_probabilities.py
def random_firing_probabilities(
    num_features: int,
    max_prob: float = 0.5,
    min_prob: float = 0.01,
    seed: int | None = None,
) -> torch.Tensor:
    """
    Generate random firing probabilities uniformly sampled from a range.

    Args:
        num_features: Number of features to generate probabilities for
        max_prob: Maximum firing probability
        min_prob: Minimum firing probability
        seed: Optional random seed for reproducibility

    Returns:
        Tensor of shape [num_features] with random firing probabilities
    """
    if num_features < 1:
        raise ValueError("num_features must be at least 1")
    if not 0 < min_prob < max_prob <= 1:
        raise ValueError("Must have 0 < min_prob < max_prob <= 1")

    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)

    probs = torch.rand(num_features, generator=generator, dtype=torch.float32)
    return min_prob + (max_prob - min_prob) * probs
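
Example (a minimal sketch; it assumes the function is importable from sae_lens.synthetic.firing_probabilities, the source module shown above):

from sae_lens.synthetic.firing_probabilities import random_firing_probabilities

probs = random_firing_probabilities(num_features=1000, max_prob=0.2, min_prob=0.005, seed=0)
print(probs.min() >= 0.005, probs.max() <= 0.2)  # tensor(True) tensor(True): values stay in range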

train_toy_sae(sae, feature_dict, activations_generator, training_samples=10000000, batch_size=1024, lr=0.0003, lr_warm_up_steps=0, lr_decay_steps=0, device='cpu', n_snapshots=0, snapshot_fn=None, autocast_sae=False, autocast_data=False)

Train an SAE on synthetic activations from a feature dictionary.

This is a convenience function that sets up the training loop with sensible defaults for small-scale synthetic data experiments.

Parameters:

  • sae (TrainingSAE[Any], required): The TrainingSAE to train
  • feature_dict (FeatureDictionary, required): The feature dictionary that maps feature activations to hidden activations
  • activations_generator (ActivationGenerator, required): Generator that produces feature activations
  • training_samples (int, default 10000000): Total number of training samples
  • batch_size (int, default 1024): Batch size for training
  • lr (float, default 0.0003): Learning rate
  • lr_warm_up_steps (int, default 0): Number of warmup steps for learning rate
  • lr_decay_steps (int, default 0): Number of steps over which to decay learning rate
  • device (str | device, default 'cpu'): Device to train on
  • n_snapshots (int, default 0): Number of snapshots to take during training. Snapshots are evenly spaced throughout training.
  • snapshot_fn (Callable[[SAETrainer[Any, Any]], None] | None, default None): Callback function called at each snapshot point. Receives the SAETrainer instance, allowing access to the SAE, training step, and other training state. Required if n_snapshots > 0.
  • autocast_sae (bool, default False): Whether to autocast the SAE to bfloat16. Only recommended for large SAEs on CUDA.
  • autocast_data (bool, default False): Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommended for large data on CUDA.
Source code in sae_lens/synthetic/training.py
def train_toy_sae(
    sae: TrainingSAE[Any],
    feature_dict: FeatureDictionary,
    activations_generator: ActivationGenerator,
    training_samples: int = 10_000_000,
    batch_size: int = 1024,
    lr: float = 3e-4,
    lr_warm_up_steps: int = 0,
    lr_decay_steps: int = 0,
    device: str | torch.device = "cpu",
    n_snapshots: int = 0,
    snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
    autocast_sae: bool = False,
    autocast_data: bool = False,
) -> None:
    """
    Train an SAE on synthetic activations from a feature dictionary.

    This is a convenience function that sets up the training loop with
    sensible defaults for small-scale synthetic data experiments.

    Args:
        sae: The TrainingSAE to train
        feature_dict: The feature dictionary that maps feature activations to
            hidden activations
        activations_generator: Generator that produces feature activations
        training_samples: Total number of training samples
        batch_size: Batch size for training
        lr: Learning rate
        lr_warm_up_steps: Number of warmup steps for learning rate
        lr_decay_steps: Number of steps over which to decay learning rate
        device: Device to train on
        n_snapshots: Number of snapshots to take during training. Snapshots are
            evenly spaced throughout training.
        snapshot_fn: Callback function called at each snapshot point. Receives
            the SAETrainer instance, allowing access to the SAE, training step,
            and other training state. Required if n_snapshots > 0.
        autocast_sae: Whether to autocast the SAE to bfloat16. Only recommended for large SAEs on CUDA.
        autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommended for large data on CUDA.
    """

    device_str = str(device) if isinstance(device, torch.device) else device

    # Create data iterator
    data_iterator = SyntheticActivationIterator(
        feature_dict=feature_dict,
        activations_generator=activations_generator,
        batch_size=batch_size,
        autocast=autocast_data,
    )

    # Create trainer config
    trainer_cfg = SAETrainerConfig(
        n_checkpoints=n_snapshots,
        checkpoint_path=None,
        save_final_checkpoint=False,
        total_training_samples=training_samples,
        device=device_str,
        autocast=autocast_sae,
        lr=lr,
        lr_end=lr,
        lr_scheduler_name="constant",
        lr_warm_up_steps=lr_warm_up_steps,
        adam_beta1=0.9,
        adam_beta2=0.999,
        lr_decay_steps=lr_decay_steps,
        n_restart_cycles=1,
        train_batch_size_samples=batch_size,
        dead_feature_window=1000,
        feature_sampling_window=2000,
        logger=LoggingConfig(
            log_to_wandb=False,
            # hacky way to disable evals, but works for now
            eval_every_n_wandb_logs=2**31 - 1,
        ),
    )

    def snapshot_wrapper(
        snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None,
    ) -> SaveCheckpointFn:
        def save_checkpoint(checkpoint_path: Path | None) -> None:  # noqa: ARG001
            if snapshot_fn is None:
                raise ValueError("snapshot_fn must be provided to take snapshots")
            snapshot_fn(trainer)

        return save_checkpoint

    # Create trainer and train
    feature_dict.eval()
    trainer = SAETrainer(
        cfg=trainer_cfg,
        sae=sae,
        data_provider=data_iterator,
        save_checkpoint_fn=snapshot_wrapper(snapshot_fn),
    )

    trainer.fit()
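
A minimal usage sketch follows. It assumes sae, feature_dict, and activations_generator have already been constructed via the TrainingSAE, FeatureDictionary, and ActivationGenerator APIs documented elsewhere on this page; the sample count, snapshot count, and the state-dict snapshotting strategy are illustrative choices, not recommendations.

import copy

import torch

from sae_lens.synthetic.training import train_toy_sae

# `sae`, `feature_dict`, and `activations_generator` are assumed to exist already;
# this sketch only shows the call pattern and a simple snapshot callback.
snapshots = []

def record_snapshot(trainer) -> None:
    # Called at each evenly spaced snapshot point. We close over `sae` directly
    # rather than relying on specific SAETrainer attributes.
    snapshots.append(copy.deepcopy(sae.state_dict()))

train_toy_sae(
    sae=sae,
    feature_dict=feature_dict,
    activations_generator=activations_generator,
    training_samples=2_000_000,  # smaller than the 10M default for a quick run
    batch_size=1024,
    lr=3e-4,
    device="cuda" if torch.cuda.is_available() else "cpu",
    n_snapshots=4,               # four evenly spaced snapshots
    snapshot_fn=record_snapshot,
)

Because snapshot_fn receives the live SAETrainer, heavier callbacks (for example, running a small evaluation batch through the SAE) follow the same shape; the closure-based version above simply keeps the example independent of the trainer's internal attributes.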

zipfian_firing_probabilities(num_features, exponent=1.0, max_prob=0.3, min_prob=0.01)

Generate firing probabilities following a Zipfian (power-law) distribution.

Creates probabilities where a few features fire frequently and most fire rarely, which mirrors the distribution often observed in real neural network features.

Parameters:

    num_features (int): Number of features to generate probabilities for. Required.
    exponent (float): Zipf exponent (higher = steeper dropoff). Default: 1.0.
    max_prob (float): Maximum firing probability (for the most frequent feature). Default: 0.3.
    min_prob (float): Minimum firing probability (for the least frequent feature). Default: 0.01.

Returns:

    Tensor: Tensor of shape [num_features] with firing probabilities in descending order.

Source code in sae_lens/synthetic/firing_probabilities.py
def zipfian_firing_probabilities(
    num_features: int,
    exponent: float = 1.0,
    max_prob: float = 0.3,
    min_prob: float = 0.01,
) -> torch.Tensor:
    """
    Generate firing probabilities following a Zipfian (power-law) distribution.

    Creates probabilities where a few features fire frequently and most fire rarely,
    which mirrors the distribution often observed in real neural network features.

    Args:
        num_features: Number of features to generate probabilities for
        exponent: Zipf exponent (higher = steeper dropoff). Default 1.0.
        max_prob: Maximum firing probability (for the most frequent feature)
        min_prob: Minimum firing probability (for the least frequent feature)

    Returns:
        Tensor of shape [num_features] with firing probabilities in descending order
    """
    if num_features < 1:
        raise ValueError("num_features must be at least 1")
    if exponent <= 0:
        raise ValueError("exponent must be positive")
    if not 0 < min_prob < max_prob <= 1:
        raise ValueError("Must have 0 < min_prob < max_prob <= 1")

    ranks = torch.arange(1, num_features + 1, dtype=torch.float32)
    probs = 1.0 / ranks**exponent

    # Scale to [min_prob, max_prob]
    if num_features == 1:
        return torch.tensor([max_prob])

    probs_min, probs_max = probs.min(), probs.max()
    return min_prob + (max_prob - min_prob) * (probs - probs_min) / (
        probs_max - probs_min
    )
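
For a concrete sense of the scaling, the sketch below (using the module path shown in the source listing above) generates probabilities for five features with the default exponent: the raw Zipf weights 1, 1/2, 1/3, 1/4, 1/5 are mapped linearly onto [min_prob, max_prob], so the first feature lands exactly at max_prob and the last exactly at min_prob.

from sae_lens.synthetic.firing_probabilities import zipfian_firing_probabilities

# Raw Zipf weights for ranks 1..5 are 1, 0.5, 0.333, 0.25, 0.2; after the linear
# rescaling into [0.01, 0.3] the expected values are roughly
# [0.3000, 0.1188, 0.0583, 0.0281, 0.0100] (the last digit may vary with printing).
probs = zipfian_firing_probabilities(num_features=5, exponent=1.0)
print(probs)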