sae-lens 6.0.0rc3__tar.gz → 6.0.0rc4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/PKG-INFO +1 -1
  2. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/pyproject.toml +2 -1
  3. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/analysis/neuronpedia_integration.py +3 -3
  5. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/config.py +5 -3
  6. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/evals.py +20 -9
  7. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/llm_sae_training_runner.py +113 -5
  8. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/loading/pretrained_sae_loaders.py +24 -5
  9. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/gated_sae.py +0 -4
  10. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/jumprelu_sae.py +4 -10
  11. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/sae.py +121 -48
  12. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/standard_sae.py +4 -11
  13. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/topk_sae.py +18 -12
  14. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/activation_scaler.py +1 -1
  15. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/activations_store.py +0 -2
  16. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/sae_trainer.py +11 -3
  17. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/upload_saes_to_huggingface.py +1 -1
  18. sae_lens-6.0.0rc3/sae_lens/training/geometric_median.py +0 -101
  19. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/LICENSE +0 -0
  20. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/README.md +0 -0
  21. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/analysis/__init__.py +0 -0
  22. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  23. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/cache_activations_runner.py +0 -0
  24. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/constants.py +0 -0
  25. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/load_model.py +0 -0
  26. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/loading/__init__.py +0 -0
  27. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  28. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/pretokenize_runner.py +0 -0
  29. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/pretrained_saes.yaml +0 -0
  30. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/registry.py +0 -0
  31. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/__init__.py +0 -0
  32. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/tokenization_and_batching.py +0 -0
  33. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/__init__.py +0 -0
  34. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/mixing_buffer.py +0 -0
  35. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/optim.py +0 -0
  36. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/tutorial/tsea.py +0 -0
  38. {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.0.0rc3
3
+ Version: 6.0.0rc4
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.0.0-rc.3"
3
+ version = "6.0.0-rc.4"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -60,6 +60,7 @@ tabulate = "^0.9.0"
60
60
  ruff = "^0.7.4"
61
61
  eai-sparsify = "^1.1.1"
62
62
  mike = "^2.0.0"
63
+ trio = "^0.30.0"
63
64
 
64
65
  [tool.poetry.extras]
65
66
  mamba = ["mamba-lens"]
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.0.0-rc.3"
2
+ __version__ = "6.0.0-rc.4"
3
3
 
4
4
  import logging
5
5
 
@@ -59,7 +59,7 @@ def NanAndInfReplacer(value: str):
59
59
 
60
60
 
61
61
  def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
62
- sae_id = sae.cfg.neuronpedia_id
62
+ sae_id = sae.cfg.metadata.neuronpedia_id
63
63
  if sae_id is None:
64
64
  logger.warning(
65
65
  "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -74,7 +74,7 @@ def get_neuronpedia_quick_list(
74
74
  features: list[int],
75
75
  name: str = "temporary_list",
76
76
  ):
77
- sae_id = sae.cfg.neuronpedia_id
77
+ sae_id = sae.cfg.metadata.neuronpedia_id
78
78
  if sae_id is None:
79
79
  logger.warning(
80
80
  "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -86,7 +86,7 @@ def get_neuronpedia_quick_list(
86
86
  url = url + "?name=" + name
87
87
  list_feature = [
88
88
  {
89
- "modelId": sae.cfg.model_name,
89
+ "modelId": sae.cfg.metadata.model_name,
90
90
  "layer": sae_id.split("/")[1],
91
91
  "index": str(feature),
92
92
  }
@@ -201,7 +201,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
201
201
  train_batch_size_tokens: int = 4096
202
202
 
203
203
  ## Adam
204
- adam_beta1: float = 0.0
204
+ adam_beta1: float = 0.9
205
205
  adam_beta2: float = 0.999
206
206
 
207
207
  ## Learning Rate Schedule
@@ -390,7 +390,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
390
390
  adam_beta2=self.adam_beta2,
391
391
  lr_decay_steps=self.lr_decay_steps,
392
392
  n_restart_cycles=self.n_restart_cycles,
393
- total_training_steps=self.total_training_steps,
394
393
  train_batch_size_samples=self.train_batch_size_tokens,
395
394
  dead_feature_window=self.dead_feature_window,
396
395
  feature_sampling_window=self.feature_sampling_window,
@@ -613,8 +612,11 @@ class SAETrainerConfig:
613
612
  adam_beta2: float
614
613
  lr_decay_steps: int
615
614
  n_restart_cycles: int
616
- total_training_steps: int
617
615
  train_batch_size_samples: int
618
616
  dead_feature_window: int
619
617
  feature_sampling_window: int
620
618
  logger: LoggingConfig
619
+
620
+ @property
621
+ def total_training_steps(self) -> int:
622
+ return self.total_training_samples // self.train_batch_size_samples
@@ -4,6 +4,7 @@ import json
4
4
  import math
5
5
  import re
6
6
  import subprocess
7
+ import sys
7
8
  from collections import defaultdict
8
9
  from collections.abc import Mapping
9
10
  from dataclasses import dataclass, field
@@ -15,7 +16,7 @@ from typing import Any
15
16
  import einops
16
17
  import pandas as pd
17
18
  import torch
18
- from tqdm import tqdm
19
+ from tqdm.auto import tqdm
19
20
  from transformer_lens import HookedTransformer
20
21
  from transformer_lens.hook_points import HookedRootModule
21
22
 
@@ -814,16 +815,18 @@ def multiple_evals(
814
815
  release=sae_release_name, # see other options in sae_lens/pretrained_saes.yaml
815
816
  sae_id=sae_id, # won't always be a hook point
816
817
  device=device,
817
- )[0]
818
+ )
818
819
 
819
820
  # move SAE to device if not there already
820
821
  sae.to(device)
821
822
 
822
- if current_model_str != sae.cfg.model_name:
823
+ if current_model_str != sae.cfg.metadata.model_name:
823
824
  del current_model # potentially saves GPU memory
824
- current_model_str = sae.cfg.model_name
825
+ current_model_str = sae.cfg.metadata.model_name
825
826
  current_model = HookedTransformer.from_pretrained_no_processing(
826
- current_model_str, device=device, **sae.cfg.model_from_pretrained_kwargs
827
+ current_model_str,
828
+ device=device,
829
+ **sae.cfg.metadata.model_from_pretrained_kwargs,
827
830
  )
828
831
  assert current_model is not None
829
832
 
@@ -941,7 +944,7 @@ def process_results(
941
944
  }
942
945
 
943
946
 
944
- if __name__ == "__main__":
947
+ def process_args(args: list[str]) -> argparse.Namespace:
945
948
  arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
946
949
  arg_parser.add_argument(
947
950
  "sae_regex_pattern",
@@ -1031,11 +1034,19 @@ if __name__ == "__main__":
1031
1034
  help="Enable verbose output with tqdm loaders.",
1032
1035
  )
1033
1036
 
1034
- args = arg_parser.parse_args()
1035
- eval_results = run_evaluations(args)
1036
- output_files = process_results(eval_results, args.output_dir)
1037
+ return arg_parser.parse_args(args)
1038
+
1039
+
1040
+ def run_evals_cli(args: list[str]) -> None:
1041
+ opts = process_args(args)
1042
+ eval_results = run_evaluations(opts)
1043
+ output_files = process_results(eval_results, opts.output_dir)
1037
1044
 
1038
1045
  print("Evaluation complete. Output files:")
1039
1046
  print(f"Individual JSONs: {len(output_files['individual_jsons'])}") # type: ignore
1040
1047
  print(f"Combined JSON: {output_files['combined_json']}")
1041
1048
  print(f"CSV: {output_files['csv']}")
1049
+
1050
+
1051
+ if __name__ == "__main__":
1052
+ run_evals_cli(sys.argv[1:])
@@ -4,7 +4,7 @@ import sys
4
4
  from collections.abc import Sequence
5
5
  from dataclasses import dataclass
6
6
  from pathlib import Path
7
- from typing import Any, Generic, cast
7
+ from typing import Any, Generic
8
8
 
9
9
  import torch
10
10
  import wandb
@@ -17,12 +17,16 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
17
17
  from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
18
18
  from sae_lens.evals import EvalConfig, run_evals
19
19
  from sae_lens.load_model import load_model
20
+ from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
21
+ from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
20
22
  from sae_lens.saes.sae import (
21
23
  T_TRAINING_SAE,
22
24
  T_TRAINING_SAE_CONFIG,
23
25
  TrainingSAE,
24
26
  TrainingSAEConfig,
25
27
  )
28
+ from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
29
+ from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
26
30
  from sae_lens.training.activation_scaler import ActivationScaler
27
31
  from sae_lens.training.activations_store import ActivationsStore
28
32
  from sae_lens.training.sae_trainer import SAETrainer
@@ -145,17 +149,18 @@ class LanguageModelSAETrainingRunner:
145
149
  )
146
150
  else:
147
151
  self.sae = override_sae
152
+ self.sae.to(self.cfg.device)
148
153
 
149
154
  def run(self):
150
155
  """
151
156
  Run the training of the SAE.
152
157
  """
153
-
158
+ self._set_sae_metadata()
154
159
  if self.cfg.logger.log_to_wandb:
155
160
  wandb.init(
156
161
  project=self.cfg.logger.wandb_project,
157
162
  entity=self.cfg.logger.wandb_entity,
158
- config=cast(Any, self.cfg),
163
+ config=self.cfg.to_dict(),
159
164
  name=self.cfg.logger.run_name,
160
165
  id=self.cfg.logger.wandb_id,
161
166
  )
@@ -184,6 +189,20 @@ class LanguageModelSAETrainingRunner:
184
189
 
185
190
  return sae
186
191
 
192
+ def _set_sae_metadata(self):
193
+ self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
194
+ self.sae.cfg.metadata.hook_name = self.cfg.hook_name
195
+ self.sae.cfg.metadata.model_name = self.cfg.model_name
196
+ self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
197
+ self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
198
+ self.sae.cfg.metadata.context_size = self.cfg.context_size
199
+ self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
200
+ self.sae.cfg.metadata.model_from_pretrained_kwargs = (
201
+ self.cfg.model_from_pretrained_kwargs
202
+ )
203
+ self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
204
+ self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
205
+
187
206
  def _compile_if_needed(self):
188
207
  # Compile model and SAE
189
208
  # torch.compile can provide significant speedups (10-20% in testing)
@@ -247,11 +266,100 @@ class LanguageModelSAETrainingRunner:
247
266
  def _parse_cfg_args(
248
267
  args: Sequence[str],
249
268
  ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
269
+ """
270
+ Parse command line arguments into a LanguageModelSAERunnerConfig.
271
+
272
+ This function first parses the architecture argument to determine which
273
+ concrete SAE config class to use, then parses the full configuration
274
+ with that concrete type.
275
+ """
250
276
  if len(args) == 0:
251
277
  args = ["--help"]
278
+
279
+ # First, parse only the architecture to determine which concrete class to use
280
+ architecture_parser = ArgumentParser(
281
+ description="Parse architecture to determine SAE config class",
282
+ exit_on_error=False,
283
+ add_help=False, # Don't add help to avoid conflicts
284
+ )
285
+ architecture_parser.add_argument(
286
+ "--architecture",
287
+ type=str,
288
+ choices=["standard", "gated", "jumprelu", "topk"],
289
+ default="standard",
290
+ help="SAE architecture to use",
291
+ )
292
+
293
+ # Parse known args to extract architecture, ignore unknown args for now
294
+ arch_args, remaining_args = architecture_parser.parse_known_args(args)
295
+ architecture = arch_args.architecture
296
+
297
+ # Remove architecture from remaining args if it exists
298
+ filtered_args = []
299
+ skip_next = False
300
+ for arg in remaining_args:
301
+ if skip_next:
302
+ skip_next = False
303
+ continue
304
+ if arg == "--architecture":
305
+ skip_next = True # Skip the next argument (the architecture value)
306
+ continue
307
+ filtered_args.append(arg)
308
+
309
+ # Create a custom wrapper class that simple_parsing can handle
310
+ def create_config_class(
311
+ sae_config_type: type[TrainingSAEConfig],
312
+ ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
313
+ """Create a concrete config class for the given SAE config type."""
314
+
315
+ # Create the base config without the sae field
316
+ from dataclasses import field as dataclass_field
317
+ from dataclasses import fields, make_dataclass
318
+
319
+ # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
320
+ base_fields = []
321
+ for field_obj in fields(LanguageModelSAERunnerConfig):
322
+ if field_obj.name != "sae":
323
+ base_fields.append((field_obj.name, field_obj.type, field_obj))
324
+
325
+ # Add the concrete sae field
326
+ base_fields.append(
327
+ (
328
+ "sae",
329
+ sae_config_type,
330
+ dataclass_field(
331
+ default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
332
+ ),
333
+ )
334
+ )
335
+
336
+ # Create the concrete class
337
+ return make_dataclass(
338
+ f"{sae_config_type.__name__}RunnerConfig",
339
+ base_fields,
340
+ bases=(LanguageModelSAERunnerConfig,),
341
+ )
342
+
343
+ # Map architecture to concrete config class
344
+ sae_config_map = {
345
+ "standard": StandardTrainingSAEConfig,
346
+ "gated": GatedTrainingSAEConfig,
347
+ "jumprelu": JumpReLUTrainingSAEConfig,
348
+ "topk": TopKTrainingSAEConfig,
349
+ }
350
+
351
+ sae_config_type = sae_config_map[architecture]
352
+ concrete_config_class = create_config_class(sae_config_type)
353
+
354
+ # Now parse the full configuration with the concrete type
252
355
  parser = ArgumentParser(exit_on_error=False)
253
- parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
254
- return parser.parse_args(args).cfg
356
+ parser.add_arguments(concrete_config_class, dest="cfg")
357
+
358
+ # Parse the filtered arguments (without --architecture)
359
+ parsed_args = parser.parse_args(filtered_args)
360
+
361
+ # Return the parsed configuration
362
+ return parsed_args.cfg
255
363
 
256
364
 
257
365
  # moved into its own function to make it easier to test
@@ -26,6 +26,22 @@ from sae_lens.loading.pretrained_saes_directory import (
26
26
  from sae_lens.registry import get_sae_class
27
27
  from sae_lens.util import filter_valid_dataclass_fields
28
28
 
29
+ LLM_METADATA_KEYS = {
30
+ "model_name",
31
+ "hook_name",
32
+ "model_class_name",
33
+ "hook_head_index",
34
+ "model_from_pretrained_kwargs",
35
+ "prepend_bos",
36
+ "exclude_special_tokens",
37
+ "neuronpedia_id",
38
+ "context_size",
39
+ "seqpos_slice",
40
+ "dataset_path",
41
+ "sae_lens_version",
42
+ "sae_lens_training_version",
43
+ }
44
+
29
45
 
30
46
  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
31
47
  class PretrainedSaeHuggingfaceLoader(Protocol):
@@ -207,6 +223,10 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
207
223
  new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
208
224
  new_cfg.setdefault("architecture", "standard")
209
225
  new_cfg.setdefault("neuronpedia_id", None)
226
+ new_cfg.setdefault(
227
+ "reshape_activations",
228
+ "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
229
+ )
210
230
 
211
231
  if "normalize_activations" in new_cfg and isinstance(
212
232
  new_cfg["normalize_activations"], bool
@@ -231,11 +251,9 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
231
251
  if architecture == "topk":
232
252
  sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
233
253
 
234
- # import here to avoid circular import
235
- from sae_lens.saes.sae import SAEMetadata
236
-
237
- meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
238
- sae_cfg_dict["metadata"] = meta_dict
254
+ sae_cfg_dict["metadata"] = {
255
+ k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
256
+ }
239
257
  sae_cfg_dict["architecture"] = architecture
240
258
  return sae_cfg_dict
241
259
 
@@ -271,6 +289,7 @@ def get_connor_rob_hook_z_config_from_hf(
271
289
  "context_size": 128,
272
290
  "normalize_activations": "none",
273
291
  "dataset_trust_remote_code": True,
292
+ "reshape_activations": "hook_z",
274
293
  **(cfg_overrides or {}),
275
294
  }
276
295
 
@@ -168,10 +168,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
168
168
 
169
169
  # Magnitude path
170
170
  magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
171
- if self.training and self.cfg.noise_scale > 0:
172
- magnitude_pre_activation += (
173
- torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
174
- )
175
171
  magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
176
172
 
177
173
  feature_magnitudes = self.activation_fn(magnitude_pre_activation)
@@ -105,7 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
105
105
  JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
106
106
  using a JumpReLU activation. For each unit, if its pre-activation is
107
107
  <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
108
- activation function (e.g., ReLU, tanh-relu, etc.).
108
+ activation function (e.g., ReLU etc.).
109
109
 
110
110
  It implements:
111
111
  - initialize_weights: sets up parameters, including a threshold.
@@ -142,7 +142,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
142
142
  sae_in = self.process_sae_in(x)
143
143
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
144
144
 
145
- # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
145
+ # 1) Apply the base "activation_fn" from config (e.g., ReLU).
146
146
  base_acts = self.activation_fn(hidden_pre)
147
147
 
148
148
  # 2) Zero out any unit whose (hidden_pre <= threshold).
@@ -191,8 +191,8 @@ class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
191
191
  Configuration class for training a JumpReLUTrainingSAE.
192
192
  """
193
193
 
194
- jumprelu_init_threshold: float = 0.001
195
- jumprelu_bandwidth: float = 0.001
194
+ jumprelu_init_threshold: float = 0.01
195
+ jumprelu_bandwidth: float = 0.05
196
196
  l0_coefficient: float = 1.0
197
197
  l0_warm_up_steps: int = 0
198
198
 
@@ -257,12 +257,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
257
257
  sae_in = self.process_sae_in(x)
258
258
 
259
259
  hidden_pre = sae_in @ self.W_enc + self.b_enc
260
-
261
- if self.training and self.cfg.noise_scale > 0:
262
- hidden_pre = (
263
- hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
264
- )
265
-
266
260
  feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)
267
261
 
268
262
  return feature_acts, hidden_pre # type: ignore
@@ -1,5 +1,6 @@
1
1
  """Base classes for Sparse Autoencoders (SAEs)."""
2
2
 
3
+ import copy
3
4
  import json
4
5
  import warnings
5
6
  from abc import ABC, abstractmethod
@@ -59,23 +60,91 @@ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
59
60
  T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore
60
61
 
61
62
 
62
- @dataclass
63
63
  class SAEMetadata:
64
64
  """Core metadata about how this SAE should be used, if known."""
65
65
 
66
- model_name: str | None = None
67
- hook_name: str | None = None
68
- model_class_name: str | None = None
69
- hook_head_index: int | None = None
70
- model_from_pretrained_kwargs: dict[str, Any] | None = None
71
- prepend_bos: bool | None = None
72
- exclude_special_tokens: bool | list[int] | None = None
73
- neuronpedia_id: str | None = None
74
- context_size: int | None = None
75
- seqpos_slice: tuple[int | None, ...] | None = None
76
- dataset_path: str | None = None
77
- sae_lens_version: str = field(default_factory=lambda: __version__)
78
- sae_lens_training_version: str = field(default_factory=lambda: __version__)
66
+ def __init__(self, **kwargs: Any):
67
+ # Set default version fields with their current behavior
68
+ self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
69
+ self.sae_lens_training_version = kwargs.pop(
70
+ "sae_lens_training_version", __version__
71
+ )
72
+
73
+ # Set all other attributes dynamically
74
+ for key, value in kwargs.items():
75
+ setattr(self, key, value)
76
+
77
+ def __getattr__(self, name: str) -> None:
78
+ """Return None for any missing attribute (like defaultdict)"""
79
+ return
80
+
81
+ def __setattr__(self, name: str, value: Any) -> None:
82
+ """Allow setting any attribute"""
83
+ super().__setattr__(name, value)
84
+
85
+ def __getitem__(self, key: str) -> Any:
86
+ """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
87
+ return getattr(self, key)
88
+
89
+ def __setitem__(self, key: str, value: Any) -> None:
90
+ """Allow dictionary-style assignment: metadata['key'] = value"""
91
+ setattr(self, key, value)
92
+
93
+ def __contains__(self, key: str) -> bool:
94
+ """Allow 'in' operator: 'key' in metadata"""
95
+ # Only return True if the attribute was explicitly set (not just defaulting to None)
96
+ return key in self.__dict__
97
+
98
+ def get(self, key: str, default: Any = None) -> Any:
99
+ """Dictionary-style get with default"""
100
+ value = getattr(self, key)
101
+ # If the attribute wasn't explicitly set and we got None from __getattr__,
102
+ # use the provided default instead
103
+ if key not in self.__dict__ and value is None:
104
+ return default
105
+ return value
106
+
107
+ def keys(self):
108
+ """Return all explicitly set attribute names"""
109
+ return self.__dict__.keys()
110
+
111
+ def values(self):
112
+ """Return all explicitly set attribute values"""
113
+ return self.__dict__.values()
114
+
115
+ def items(self):
116
+ """Return all explicitly set attribute name-value pairs"""
117
+ return self.__dict__.items()
118
+
119
+ def to_dict(self) -> dict[str, Any]:
120
+ """Convert to dictionary for serialization"""
121
+ return self.__dict__.copy()
122
+
123
+ @classmethod
124
+ def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
125
+ """Create from dictionary"""
126
+ return cls(**data)
127
+
128
+ def __repr__(self) -> str:
129
+ return f"SAEMetadata({self.__dict__})"
130
+
131
+ def __eq__(self, other: object) -> bool:
132
+ if not isinstance(other, SAEMetadata):
133
+ return False
134
+ return self.__dict__ == other.__dict__
135
+
136
+ def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
137
+ """Support for deep copying"""
138
+
139
+ return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
140
+
141
+ def __getstate__(self) -> dict[str, Any]:
142
+ """Support for pickling"""
143
+ return self.__dict__
144
+
145
+ def __setstate__(self, state: dict[str, Any]) -> None:
146
+ """Support for unpickling"""
147
+ self.__dict__.update(state)
79
148
 
80
149
 
81
150
  @dataclass
@@ -99,7 +168,7 @@ class SAEConfig(ABC):
99
168
 
100
169
  def to_dict(self) -> dict[str, Any]:
101
170
  res = {field.name: getattr(self, field.name) for field in fields(self)}
102
- res["metadata"] = asdict(self.metadata)
171
+ res["metadata"] = self.metadata.to_dict()
103
172
  res["architecture"] = self.architecture()
104
173
  return res
105
174
 
@@ -124,7 +193,7 @@ class SAEConfig(ABC):
124
193
  "layer_norm",
125
194
  ]:
126
195
  raise ValueError(
127
- f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
196
+ f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
128
197
  )
129
198
 
130
199
 
@@ -238,9 +307,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
238
307
 
239
308
  self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
240
309
  self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
241
-
242
310
  elif self.cfg.normalize_activations == "layer_norm":
243
-
311
+ # we need to scale the norm of the input and store the scaling factor
244
312
  def run_time_activation_ln_in(
245
313
  x: torch.Tensor, eps: float = 1e-5
246
314
  ) -> torch.Tensor:
@@ -522,7 +590,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
522
590
  device: str = "cpu",
523
591
  force_download: bool = False,
524
592
  converter: PretrainedSaeHuggingfaceLoader | None = None,
525
- ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
593
+ ) -> T_SAE:
526
594
  """
527
595
  Load a pretrained SAE from the Hugging Face model hub.
528
596
 
@@ -530,7 +598,28 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
530
598
  release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
531
599
  id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
532
600
  device: The device to load the SAE on.
533
- return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
601
+ """
602
+ return cls.from_pretrained_with_cfg_and_sparsity(
603
+ release, sae_id, device, force_download, converter=converter
604
+ )[0]
605
+
606
+ @classmethod
607
+ def from_pretrained_with_cfg_and_sparsity(
608
+ cls: Type[T_SAE],
609
+ release: str,
610
+ sae_id: str,
611
+ device: str = "cpu",
612
+ force_download: bool = False,
613
+ converter: PretrainedSaeHuggingfaceLoader | None = None,
614
+ ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
615
+ """
616
+ Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
617
+ In SAELens <= 5.x.x, this was called SAE.from_pretrained().
618
+
619
+ Args:
620
+ release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
621
+ id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
622
+ device: The device to load the SAE on.
534
623
  """
535
624
 
536
625
  # get sae directory
@@ -646,8 +735,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
646
735
 
647
736
  @dataclass(kw_only=True)
648
737
  class TrainingSAEConfig(SAEConfig, ABC):
649
- noise_scale: float = 0.0
650
- mse_loss_normalization: str | None = None
651
738
  # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
652
739
  # 0.1 corresponds to the "heuristic" initialization, use None to disable
653
740
  decoder_init_norm: float | None = 0.1
@@ -680,9 +767,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
680
767
  def from_dict(
681
768
  cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
682
769
  ) -> T_TRAINING_SAE_CONFIG:
683
- # remove any keys that are not in the dataclass
684
- # since we sometimes enhance the config with the whole LM runner config
685
- valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
686
770
  cfg_class = cls
687
771
  if "architecture" in config_dict:
688
772
  cfg_class = get_sae_training_class(config_dict["architecture"])[1]
@@ -690,6 +774,9 @@ class TrainingSAEConfig(SAEConfig, ABC):
690
774
  raise ValueError(
691
775
  f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
692
776
  )
777
+ # remove any keys that are not in the dataclass
778
+ # since we sometimes enhance the config with the whole LM runner config
779
+ valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
693
780
  if "metadata" in config_dict:
694
781
  valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
695
782
  return cfg_class(**valid_config_dict)
@@ -698,6 +785,7 @@ class TrainingSAEConfig(SAEConfig, ABC):
698
785
  return {
699
786
  **super().to_dict(),
700
787
  **asdict(self),
788
+ "metadata": self.metadata.to_dict(),
701
789
  "architecture": self.architecture(),
702
790
  }
703
791
 
@@ -708,12 +796,14 @@ class TrainingSAEConfig(SAEConfig, ABC):
708
796
  Creates a dictionary containing attributes corresponding to the fields
709
797
  defined in the base SAEConfig class.
710
798
  """
711
- base_config_field_names = {f.name for f in fields(SAEConfig)}
799
+ base_sae_cfg_class = get_sae_class(self.architecture())[1]
800
+ base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
712
801
  result_dict = {
713
802
  field_name: getattr(self, field_name)
714
803
  for field_name in base_config_field_names
715
804
  }
716
805
  result_dict["architecture"] = self.architecture()
806
+ result_dict["metadata"] = self.metadata.to_dict()
717
807
  return result_dict
718
808
 
719
809
 
@@ -726,7 +816,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
726
816
  # Turn off hook_z reshaping for training mode - the activation store
727
817
  # is expected to handle reshaping before passing data to the SAE
728
818
  self.turn_off_forward_pass_hook_z_reshaping()
729
- self.mse_loss_fn = self._get_mse_loss_fn()
819
+ self.mse_loss_fn = mse_loss
730
820
 
731
821
  @abstractmethod
732
822
  def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
@@ -861,27 +951,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
861
951
  """
862
952
  return self.process_state_dict_for_saving(state_dict)
863
953
 
864
- def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
865
- """Get the MSE loss function based on config."""
866
-
867
- def standard_mse_loss_fn(
868
- preds: torch.Tensor, target: torch.Tensor
869
- ) -> torch.Tensor:
870
- return torch.nn.functional.mse_loss(preds, target, reduction="none")
871
-
872
- def batch_norm_mse_loss_fn(
873
- preds: torch.Tensor, target: torch.Tensor
874
- ) -> torch.Tensor:
875
- target_centered = target - target.mean(dim=0, keepdim=True)
876
- normalization = target_centered.norm(dim=-1, keepdim=True)
877
- return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
878
- normalization + 1e-6
879
- )
880
-
881
- if self.cfg.mse_loss_normalization == "dense_batch":
882
- return batch_norm_mse_loss_fn
883
- return standard_mse_loss_fn
884
-
885
954
  @torch.no_grad()
886
955
  def remove_gradient_parallel_to_decoder_directions(self) -> None:
887
956
  """Remove gradient components parallel to decoder directions."""
@@ -943,3 +1012,7 @@ def _disable_hooks(sae: SAE[Any]):
943
1012
  finally:
944
1013
  for hook_name, hook in sae.hook_dict.items():
945
1014
  setattr(sae, hook_name, hook)
1015
+
1016
+
1017
+ def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
1018
+ return torch.nn.functional.mse_loss(preds, target, reduction="none")
@@ -67,7 +67,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
67
67
  sae_in = self.process_sae_in(x)
68
68
  # Compute the pre-activation values
69
69
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
70
- # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
70
+ # Apply the activation function (e.g., ReLU, depending on config)
71
71
  return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
72
72
 
73
73
  def decode(
@@ -81,7 +81,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
81
81
  sae_out_pre = feature_acts @ self.W_dec + self.b_dec
82
82
  # 2) hook reconstruction
83
83
  sae_out_pre = self.hook_sae_recons(sae_out_pre)
84
- # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
84
+ # 4) optional out-normalization (e.g. constant_norm_rescale)
85
85
  sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
86
86
  # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
87
87
  return self.reshape_fn_out(sae_out_pre, self.d_head)
@@ -136,16 +136,9 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
136
136
  sae_in = self.process_sae_in(x)
137
137
  # Compute the pre-activation (and allow for a hook if desired)
138
138
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # type: ignore
139
- # Add noise during training for robustness (scaled by noise_scale from the configuration)
140
- if self.training and self.cfg.noise_scale > 0:
141
- hidden_pre_noised = (
142
- hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
143
- )
144
- else:
145
- hidden_pre_noised = hidden_pre
146
139
  # Apply the activation function (and any post-activation hook)
147
- feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
148
- return feature_acts, hidden_pre_noised
140
+ feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
141
+ return feature_acts, hidden_pre
149
142
 
150
143
  def calculate_aux_loss(
151
144
  self,
@@ -91,8 +91,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
91
91
  ) -> Float[torch.Tensor, "... d_sae"]:
92
92
  """
93
93
  Converts input x into feature activations.
94
- Uses topk activation from the config (cfg.activation_fn == "topk")
95
- under the hood.
94
+ Uses topk activation under the hood.
96
95
  """
97
96
  sae_in = self.process_sae_in(x)
98
97
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -116,6 +115,13 @@ class TopKSAE(SAE[TopKSAEConfig]):
116
115
  def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
117
116
  return TopK(self.cfg.k)
118
117
 
118
+ @override
119
+ @torch.no_grad()
120
+ def fold_W_dec_norm(self) -> None:
121
+ raise NotImplementedError(
122
+ "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
123
+ )
124
+
119
125
 
120
126
  @dataclass
121
127
  class TopKTrainingSAEConfig(TrainingSAEConfig):
@@ -156,18 +162,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
156
162
  sae_in = self.process_sae_in(x)
157
163
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
158
164
 
159
- # Inject noise if training
160
- if self.training and self.cfg.noise_scale > 0:
161
- hidden_pre_noised = (
162
- hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
163
- )
164
- else:
165
- hidden_pre_noised = hidden_pre
166
-
167
165
  # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
168
- feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
169
- return feature_acts, hidden_pre_noised
166
+ feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
167
+ return feature_acts, hidden_pre
170
168
 
169
+ @override
171
170
  def calculate_aux_loss(
172
171
  self,
173
172
  step_input: TrainStepInput,
@@ -184,6 +183,13 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
184
183
  )
185
184
  return {"auxiliary_reconstruction_loss": topk_loss}
186
185
 
186
+ @override
187
+ @torch.no_grad()
188
+ def fold_W_dec_norm(self) -> None:
189
+ raise NotImplementedError(
190
+ "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
191
+ )
192
+
187
193
  @override
188
194
  def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
189
195
  return TopK(self.cfg.k)
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
3
  from statistics import mean
4
4
 
5
5
  import torch
6
- from tqdm import tqdm
6
+ from tqdm.auto import tqdm
7
7
 
8
8
  from sae_lens.training.types import DataProvider
9
9
 
@@ -161,8 +161,6 @@ class ActivationsStore:
161
161
  ) -> ActivationsStore:
162
162
  if sae.cfg.metadata.hook_name is None:
163
163
  raise ValueError("hook_name is required")
164
- if sae.cfg.metadata.hook_head_index is None:
165
- raise ValueError("hook_head_index is required")
166
164
  if sae.cfg.metadata.context_size is None:
167
165
  raise ValueError("context_size is required")
168
166
  if sae.cfg.metadata.prepend_bos is None:
@@ -7,7 +7,7 @@ import torch
7
7
  import wandb
8
8
  from safetensors.torch import save_file
9
9
  from torch.optim import Adam
10
- from tqdm import tqdm
10
+ from tqdm.auto import tqdm
11
11
 
12
12
  from sae_lens import __version__
13
13
  from sae_lens.config import SAETrainerConfig
@@ -161,6 +161,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
161
161
  return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
162
162
 
163
163
  def fit(self) -> T_TRAINING_SAE:
164
+ self.sae.to(self.cfg.device)
164
165
  pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
165
166
 
166
167
  if self.sae.cfg.normalize_activations == "expected_average_only_in":
@@ -194,10 +195,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
194
195
  )
195
196
  self.activation_scaler.scaling_factor = None
196
197
 
197
- # save final sae group to checkpoints folder
198
+ # save final inference sae group to checkpoints folder
198
199
  self.save_checkpoint(
199
200
  checkpoint_name=f"final_{self.n_training_samples}",
200
201
  wandb_aliases=["final_model"],
202
+ save_inference_model=True,
201
203
  )
202
204
 
203
205
  pbar.close()
@@ -207,11 +209,17 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
207
209
  self,
208
210
  checkpoint_name: str,
209
211
  wandb_aliases: list[str] | None = None,
212
+ save_inference_model: bool = False,
210
213
  ) -> None:
211
214
  checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
212
215
  checkpoint_path.mkdir(exist_ok=True, parents=True)
213
216
 
214
- weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
217
+ save_fn = (
218
+ self.sae.save_inference_model
219
+ if save_inference_model
220
+ else self.sae.save_model
221
+ )
222
+ weights_path, cfg_path = save_fn(str(checkpoint_path))
215
223
 
216
224
  sparsity_path = checkpoint_path / SPARSITY_FILENAME
217
225
  save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
@@ -88,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
88
88
  ```python
89
89
  from sae_lens import SAE
90
90
 
91
- sae, cfg_dict, sparsity = SAE.from_pretrained("{repo_id}", "<sae_id>")
91
+ sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
92
92
  ```
93
93
  """
94
94
  )
@@ -1,101 +0,0 @@
1
- from types import SimpleNamespace
2
-
3
- import torch
4
- import tqdm
5
-
6
-
7
- def weighted_average(points: torch.Tensor, weights: torch.Tensor):
8
- weights = weights / weights.sum()
9
- return (points * weights.view(-1, 1)).sum(dim=0)
10
-
11
-
12
- @torch.no_grad()
13
- def geometric_median_objective(
14
- median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
15
- ) -> torch.Tensor:
16
- norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
17
-
18
- return (norms * weights).sum()
19
-
20
-
21
- def compute_geometric_median(
22
- points: torch.Tensor,
23
- weights: torch.Tensor | None = None,
24
- eps: float = 1e-6,
25
- maxiter: int = 100,
26
- ftol: float = 1e-20,
27
- do_log: bool = False,
28
- ):
29
- """
30
- :param points: ``torch.Tensor`` of shape ``(n, d)``
31
- :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
32
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
33
- Equivalently, this is a smoothing parameter. Default 1e-6.
34
- :param maxiter: Maximum number of Weiszfeld iterations. Default 100
35
- :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
36
- :param do_log: If true will return a log of function values encountered through the course of the algorithm
37
- :return: SimpleNamespace object with fields
38
- - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
39
- - `termination`: string explaining how the algorithm terminated.
40
- - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
41
- """
42
- with torch.no_grad():
43
- if weights is None:
44
- weights = torch.ones((points.shape[0],), device=points.device)
45
- # initialize median estimate at mean
46
- new_weights = weights
47
- median = weighted_average(points, weights)
48
- objective_value = geometric_median_objective(median, points, weights)
49
- logs = [objective_value] if do_log else None
50
-
51
- # Weiszfeld iterations
52
- early_termination = False
53
- pbar = tqdm.tqdm(range(maxiter))
54
- for _ in pbar:
55
- prev_obj_value = objective_value
56
-
57
- norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
58
- new_weights = weights / torch.clamp(norms, min=eps)
59
- median = weighted_average(points, new_weights)
60
- objective_value = geometric_median_objective(median, points, weights)
61
-
62
- if logs is not None:
63
- logs.append(objective_value)
64
- if abs(prev_obj_value - objective_value) <= ftol * objective_value:
65
- early_termination = True
66
- break
67
-
68
- pbar.set_description(f"Objective value: {objective_value:.4f}")
69
-
70
- median = weighted_average(points, new_weights) # allow autodiff to track it
71
- return SimpleNamespace(
72
- median=median,
73
- new_weights=new_weights,
74
- termination=(
75
- "function value converged within tolerance"
76
- if early_termination
77
- else "maximum iterations reached"
78
- ),
79
- logs=logs,
80
- )
81
-
82
-
83
- if __name__ == "__main__":
84
- import time
85
-
86
- TOLERANCE = 1e-2
87
-
88
- dim1 = 10000
89
- dim2 = 768
90
- device = "cuda" if torch.cuda.is_available() else "cpu"
91
-
92
- sample = (
93
- torch.randn((dim1, dim2), device=device) * 100
94
- ) # seems to be the order of magnitude of the actual use case
95
- weights = torch.randn((dim1,), device=device)
96
-
97
- torch.tensor(weights, device=device)
98
-
99
- tic = time.perf_counter()
100
- new = compute_geometric_median(sample, weights=weights, maxiter=100)
101
- print(f"new code takes {time.perf_counter()-tic} seconds!") # noqa: T201
File without changes
File without changes