sae-lens 5.9.1__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,10 +1,15 @@
  # ruff: noqa: E402
- __version__ = "5.9.1"
+ __version__ = "6.0.0-rc.1"

  import logging

  logger = logging.getLogger(__name__)

+ from sae_lens.saes.gated_sae import GatedSAE, GatedTrainingSAE
+ from sae_lens.saes.jumprelu_sae import JumpReLUSAE, JumpReLUTrainingSAE
+ from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE
+ from sae_lens.saes.topk_sae import TopKSAE, TopKTrainingSAE
+
  from .analysis.hooked_sae_transformer import HookedSAETransformer
  from .cache_activations_runner import CacheActivationsRunner
  from .config import (
@@ -13,17 +18,26 @@ from .config import (
      PretokenizeRunnerConfig,
  )
  from .evals import run_evals
- from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
- from .sae import SAE, SAEConfig
- from .sae_training_runner import SAETrainingRunner
- from .toolkit.pretrained_sae_loaders import (
+ from .loading.pretrained_sae_loaders import (
      PretrainedSaeDiskLoader,
      PretrainedSaeHuggingfaceLoader,
  )
+ from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
+ from .regsitry import register_sae_class, register_sae_training_class
+ from .sae_training_runner import SAETrainingRunner
+ from .saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
  from .training.activations_store import ActivationsStore
- from .training.training_sae import TrainingSAE, TrainingSAEConfig
  from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

+ register_sae_class("standard", StandardSAE)
+ register_sae_training_class("standard", StandardTrainingSAE)
+ register_sae_class("gated", GatedSAE)
+ register_sae_training_class("gated", GatedTrainingSAE)
+ register_sae_class("topk", TopKSAE)
+ register_sae_training_class("topk", TopKTrainingSAE)
+ register_sae_class("jumprelu", JumpReLUSAE)
+ register_sae_training_class("jumprelu", JumpReLUTrainingSAE)
+
  __all__ = [
      "SAE",
      "SAEConfig",
@@ -42,4 +56,6 @@ __all__ = [
      "upload_saes_to_huggingface",
      "PretrainedSaeHuggingfaceLoader",
      "PretrainedSaeDiskLoader",
+     "register_sae_class",
+     "register_sae_training_class",
  ]
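
For orientation (not part of the diff): a minimal sketch of the reorganized public surface, based only on the imports and register_sae_class calls above. Architecture-specific classes now live under sae_lens.saes, the built-in architectures are registered at import time, and the registry module is spelled "regsitry" in this release.

    from sae_lens import SAE, StandardSAE, TopKSAE  # top-level names still resolve
    from sae_lens.regsitry import get_sae_class     # module name as spelled in 6.0.0-rc.1

    # "standard", "gated", "topk" and "jumprelu" are registered in __init__.py,
    # so architecture strings can be mapped back to their classes:
    assert get_sae_class("topk") is TopKSAE
    assert get_sae_class("standard") is StandardSAE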
sae_lens/analysis/hooked_sae_transformer.py CHANGED
@@ -8,7 +8,7 @@ from transformer_lens.ActivationCache import ActivationCache
  from transformer_lens.hook_points import HookPoint # Hooking utilities
  from transformer_lens.HookedTransformer import HookedTransformer

- from sae_lens.sae import SAE
+ from sae_lens.saes.sae import SAE

  SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
  LossPerToken = Float[torch.Tensor, "batch pos-1"]
@@ -275,7 +275,7 @@ class HookedSAETransformer(HookedTransformer):
      .. code-block:: python

          from transformer_lens import HookedSAETransformer
-         from sae_lens.sae import SAE
+         from sae_lens.saes.sae import SAE

          model = HookedSAETransformer.from_pretrained('gpt2-small')
          sae_cfg = SAEConfig(...)
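
A hedged sketch (not part of the diff) of the only change downstream code needs here, the SAE import path. The release and SAE id strings are illustrative, and from_pretrained is assumed to keep its 5.x return of (sae, cfg_dict, sparsity).

    from transformer_lens import HookedSAETransformer
    from sae_lens.saes.sae import SAE  # previously: from sae_lens.sae import SAE

    model = HookedSAETransformer.from_pretrained("gpt2-small")
    # illustrative release / SAE id; assumes the 5.x (sae, cfg_dict, sparsity) return
    sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
    logits = model.run_with_saes("Hello world", saes=[sae])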
sae_lens/config.py CHANGED
@@ -1,7 +1,8 @@
  import json
  import math
  import os
- from dataclasses import dataclass, field
+ from dataclasses import asdict, dataclass, field
+ from pathlib import Path
  from typing import Any, Literal, cast

  import simple_parsing
@@ -53,6 +54,52 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: i
      return simple_parsing.helpers.dict_field(default, type=json_dict, **kwargs)


+ @dataclass
+ class LoggingConfig:
+     # WANDB
+     log_to_wandb: bool = True
+     log_activations_store_to_wandb: bool = False
+     log_optimizer_state_to_wandb: bool = False
+     wandb_project: str = "sae_lens_training"
+     wandb_id: str | None = None
+     run_name: str | None = None
+     wandb_entity: str | None = None
+     wandb_log_frequency: int = 10
+     eval_every_n_wandb_logs: int = 100  # logs every 100 steps.
+
+     def log(
+         self,
+         trainer: Any,  # avoid import cycle from importing SAETrainer
+         weights_path: Path | str,
+         cfg_path: Path | str,
+         sparsity_path: Path | str | None,
+         wandb_aliases: list[str] | None = None,
+     ) -> None:
+         # Avoid wandb saving errors such as:
+         # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
+         sae_name = trainer.sae.get_name().replace("/", "__")
+
+         # save model weights and cfg
+         model_artifact = wandb.Artifact(
+             sae_name,
+             type="model",
+             metadata=dict(trainer.cfg.__dict__),
+         )
+         model_artifact.add_file(str(weights_path))
+         model_artifact.add_file(str(cfg_path))
+         wandb.log_artifact(model_artifact, aliases=wandb_aliases)
+
+         # save log feature sparsity
+         sparsity_artifact = wandb.Artifact(
+             f"{sae_name}_log_feature_sparsity",
+             type="log_feature_sparsity",
+             metadata=dict(trainer.cfg.__dict__),
+         )
+         if sparsity_path is not None:
+             sparsity_artifact.add_file(str(sparsity_path))
+             wandb.log_artifact(sparsity_artifact)
+
+
  @dataclass
  class LanguageModelSAERunnerConfig:
      """
@@ -245,16 +292,7 @@ class LanguageModelSAERunnerConfig:
      n_eval_batches: int = 10
      eval_batch_size_prompts: int | None = None  # useful if evals cause OOM

-     # WANDB
-     log_to_wandb: bool = True
-     log_activations_store_to_wandb: bool = False
-     log_optimizer_state_to_wandb: bool = False
-     wandb_project: str = "mats_sae_training_language_model"
-     wandb_id: str | None = None
-     run_name: str | None = None
-     wandb_entity: str | None = None
-     wandb_log_frequency: int = 10
-     eval_every_n_wandb_logs: int = 100  # logs every 1000 steps.
+     logger: LoggingConfig = field(default_factory=LoggingConfig)

      # Misc
      resume: bool = False
@@ -310,8 +348,8 @@ class LanguageModelSAERunnerConfig:
              self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
          )

-         if self.run_name is None:
-             self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+         if self.logger.run_name is None:
+             self.logger.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

          if self.model_from_pretrained_kwargs is None:
              if self.model_class_name == "HookedTransformer":
@@ -356,7 +394,7 @@ class LanguageModelSAERunnerConfig:
          if self.lr_end is None:
              self.lr_end = self.lr / 10

-         unique_id = self.wandb_id
+         unique_id = self.logger.wandb_id
          if unique_id is None:
              unique_id = cast(
                  Any, wandb
@@ -388,7 +426,9 @@ class LanguageModelSAERunnerConfig:
          ) // self.train_batch_size_tokens
          logger.info(f"Total training steps: {total_training_steps}")

-         total_wandb_updates = total_training_steps // self.wandb_log_frequency
+         total_wandb_updates = (
+             total_training_steps // self.logger.wandb_log_frequency
+         )
          logger.info(f"Total wandb updates: {total_wandb_updates}")

          # how many times will we sample dead neurons?
@@ -445,7 +485,7 @@ class LanguageModelSAERunnerConfig:
              "hook_name": self.hook_name,
              "hook_layer": self.hook_layer,
              "hook_head_index": self.hook_head_index,
-             "activation_fn_str": self.activation_fn,
+             "activation_fn": self.activation_fn,
              "apply_b_dec_to_input": self.apply_b_dec_to_input,
              "context_size": self.context_size,
              "prepend_bos": self.prepend_bos,
@@ -478,13 +518,16 @@ class LanguageModelSAERunnerConfig:
          }

      def to_dict(self) -> dict[str, Any]:
-         return {
-             **self.__dict__,
-             # some args may not be serializable by default
-             "dtype": str(self.dtype),
-             "device": str(self.device),
-             "act_store_device": str(self.act_store_device),
-         }
+         # Make a shallow copy of config's dictionary
+         d = dict(self.__dict__)
+
+         d["logger"] = asdict(self.logger)
+
+         # Overwrite fields that might not be JSON-serializable
+         d["dtype"] = str(self.dtype)
+         d["device"] = str(self.device)
+         d["act_store_device"] = str(self.act_store_device)
+         return d

      def to_json(self, path: str) -> None:
          if not os.path.exists(os.path.dirname(path)):
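
Sketch of the config change (not part of the diff): the W&B settings move from top-level fields into a nested LoggingConfig under `logger`. Other constructor arguments are omitted here and would be filled in as in 5.x.

    from sae_lens.config import LanguageModelSAERunnerConfig, LoggingConfig

    cfg = LanguageModelSAERunnerConfig(
        # ... model / hook / training fields as in 5.x ...
        logger=LoggingConfig(
            log_to_wandb=True,
            wandb_project="sae_lens_training",
            wandb_log_frequency=10,
        ),
    )
    # Previously: cfg.wandb_project; now: cfg.logger.wandb_project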
sae_lens/evals.py CHANGED
@@ -19,8 +19,8 @@ from tqdm import tqdm
  from transformer_lens import HookedTransformer
  from transformer_lens.hook_points import HookedRootModule

- from sae_lens.sae import SAE
- from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
+ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
+ from sae_lens.saes.sae import SAE
  from sae_lens.training.activations_store import ActivationsStore


@@ -279,7 +279,6 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
      unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()

      encoder_norms = sae.W_enc.norm(dim=-2).cpu().tolist()
-     encoder_bias = sae.b_enc.cpu().tolist()
      encoder_decoder_cosine_sim = (
          torch.nn.functional.cosine_similarity(
              unit_norm_decoder.T,
@@ -289,11 +288,13 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
          .tolist()
      )

-     return {
-         "encoder_bias": encoder_bias,
+     metrics = {
          "encoder_norm": encoder_norms,
          "encoder_decoder_cosine_sim": encoder_decoder_cosine_sim,
      }
+     if hasattr(sae, "b_enc") and sae.b_enc is not None:
+         metrics["encoder_bias"] = sae.b_enc.cpu().tolist()  # type: ignore
+     return metrics


  def get_downstream_reconstruction_metrics(
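
The guard above reflects that not every SAE architecture exposes a `b_enc` parameter (a gated encoder, for example, splits the encoder bias into separate gate and magnitude terms), so `encoder_bias` is only reported when it exists. A tiny illustrative helper, not part of the library:

    from sae_lens.saes.sae import SAE

    def maybe_encoder_bias(sae: SAE) -> list[float] | None:
        # Mirrors the conditional above: return the encoder bias if the
        # architecture defines one, otherwise None.
        if hasattr(sae, "b_enc") and sae.b_enc is not None:
            return sae.b_enc.cpu().tolist()
        return None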
sae_lens/toolkit/pretrained_sae_loaders.py → sae_lens/loading/pretrained_sae_loaders.py RENAMED
@@ -17,7 +17,7 @@ from sae_lens.config import (
      SAE_WEIGHTS_FILENAME,
      SPARSITY_FILENAME,
  )
- from sae_lens.toolkit.pretrained_saes_directory import (
+ from sae_lens.loading.pretrained_saes_directory import (
      get_config_overrides,
      get_pretrained_saes_directory,
      get_repo_id_and_folder_name,
@@ -174,30 +174,38 @@ def get_sae_lens_config_from_disk(


  def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+     rename_keys_map = {
+         "hook_point": "hook_name",
+         "hook_point_layer": "hook_layer",
+         "hook_point_head_index": "hook_head_index",
+         "activation_fn_str": "activation_fn",
+     }
+     new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
      # Set default values for backwards compatibility
-     cfg_dict.setdefault("prepend_bos", True)
-     cfg_dict.setdefault("dataset_trust_remote_code", True)
-     cfg_dict.setdefault("apply_b_dec_to_input", True)
-     cfg_dict.setdefault("finetuning_scaling_factor", False)
-     cfg_dict.setdefault("sae_lens_training_version", None)
-     cfg_dict.setdefault("activation_fn_str", cfg_dict.get("activation_fn", "relu"))
-     cfg_dict.setdefault("architecture", "standard")
-     cfg_dict.setdefault("neuronpedia_id", None)
-
-     if "normalize_activations" in cfg_dict and isinstance(
-         cfg_dict["normalize_activations"], bool
+     new_cfg.setdefault("prepend_bos", True)
+     new_cfg.setdefault("dataset_trust_remote_code", True)
+     new_cfg.setdefault("apply_b_dec_to_input", True)
+     new_cfg.setdefault("finetuning_scaling_factor", False)
+     new_cfg.setdefault("sae_lens_training_version", None)
+     new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+     new_cfg.setdefault("architecture", "standard")
+     new_cfg.setdefault("neuronpedia_id", None)
+
+     if "normalize_activations" in new_cfg and isinstance(
+         new_cfg["normalize_activations"], bool
      ):
          # backwards compatibility
-         cfg_dict["normalize_activations"] = (
+         new_cfg["normalize_activations"] = (
              "none"
-             if not cfg_dict["normalize_activations"]
+             if not new_cfg["normalize_activations"]
              else "expected_average_only_in"
          )

-     cfg_dict.setdefault("normalize_activations", "none")
-     cfg_dict.setdefault("device", "cpu")
+     new_cfg.setdefault("normalize_activations", "none")
+     new_cfg.setdefault("device", "cpu")

-     return cfg_dict
+     return new_cfg


  def get_connor_rob_hook_z_config_from_hf(
@@ -223,7 +231,7 @@ def get_connor_rob_hook_z_config_from_hf(
          "hook_name": old_cfg_dict["act_name"],
          "hook_layer": old_cfg_dict["layer"],
          "hook_head_index": None,
-         "activation_fn_str": "relu",
+         "activation_fn": "relu",
          "apply_b_dec_to_input": True,
          "finetuning_scaling_factor": False,
          "sae_lens_training_version": None,
@@ -372,7 +380,7 @@ def get_gemma_2_config_from_hf(
          "hook_name": hook_name,
          "hook_layer": layer,
          "hook_head_index": None,
-         "activation_fn_str": "relu",
+         "activation_fn": "relu",
          "finetuning_scaling_factor": False,
          "sae_lens_training_version": None,
          "prepend_bos": True,
@@ -485,7 +493,7 @@ def get_llama_scope_config_from_hf(
          "hook_name": old_cfg_dict["hook_point_in"],
          "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
          "hook_head_index": None,
-         "activation_fn_str": "relu",
+         "activation_fn": "relu",
          "finetuning_scaling_factor": False,
          "sae_lens_training_version": None,
          "prepend_bos": True,
@@ -597,8 +605,8 @@ def get_dictionary_learning_config_1_from_hf(

      hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"

-     activation_fn_str = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
-     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn_str == "topk" else {}
+     activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}

      return {
          "architecture": (
@@ -612,7 +620,7 @@ def get_dictionary_learning_config_1_from_hf(
          "hook_name": hook_point_name,
          "hook_layer": trainer["layer"],
          "hook_head_index": None,
-         "activation_fn_str": activation_fn_str,
+         "activation_fn": activation_fn,
          "activation_fn_kwargs": activation_fn_kwargs,
          "apply_b_dec_to_input": True,
          "finetuning_scaling_factor": False,
@@ -655,7 +663,7 @@ def get_deepseek_r1_config_from_hf(
          "dataset_path": "lmsys/lmsys-chat-1m",
          "dataset_trust_remote_code": True,
          "sae_lens_training_version": None,
-         "activation_fn_str": "relu",
+         "activation_fn": "relu",
          "normalize_activations": "none",
          "device": device,
          "apply_b_dec_to_input": False,
@@ -810,7 +818,7 @@ def get_llama_scope_r1_distill_config_from_hf(
          "hook_name": huggingface_cfg_dict["hook_point_in"],
          "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
          "hook_head_index": None,
-         "activation_fn_str": "relu",
+         "activation_fn": "relu",
          "finetuning_scaling_factor": False,
          "sae_lens_training_version": None,
          "prepend_bos": True,
sae_lens/regsitry.py ADDED
@@ -0,0 +1,34 @@
+ from typing import TYPE_CHECKING
+
+ # avoid circular imports
+ if TYPE_CHECKING:
+     from sae_lens.saes.sae import SAE, TrainingSAE
+
+ SAE_CLASS_REGISTRY: dict[str, "type[SAE]"] = {}
+ SAE_TRAINING_CLASS_REGISTRY: dict[str, "type[TrainingSAE]"] = {}
+
+
+ def register_sae_class(architecture: str, sae_class: "type[SAE]") -> None:
+     if architecture in SAE_CLASS_REGISTRY:
+         raise ValueError(
+             f"SAE class for architecture {architecture} already registered."
+         )
+     SAE_CLASS_REGISTRY[architecture] = sae_class
+
+
+ def register_sae_training_class(
+     architecture: str, sae_training_class: "type[TrainingSAE]"
+ ) -> None:
+     if architecture in SAE_TRAINING_CLASS_REGISTRY:
+         raise ValueError(
+             f"SAE training class for architecture {architecture} already registered."
+         )
+     SAE_TRAINING_CLASS_REGISTRY[architecture] = sae_training_class
+
+
+ def get_sae_class(architecture: str) -> "type[SAE]":
+     return SAE_CLASS_REGISTRY[architecture]
+
+
+ def get_sae_training_class(architecture: str) -> "type[TrainingSAE]":
+     return SAE_TRAINING_CLASS_REGISTRY[architecture]
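
Sketch of how the new registry is intended to be used for a custom architecture (not part of the diff). MyCustomSAE is hypothetical; the built-in architectures are already registered in sae_lens/__init__.py, and registering the same architecture string twice raises ValueError.

    from sae_lens import SAE
    from sae_lens.regsitry import get_sae_class, register_sae_class


    class MyCustomSAE(SAE):  # hypothetical subclass
        ...


    register_sae_class("my_custom", MyCustomSAE)
    assert get_sae_class("my_custom") is MyCustomSAE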
sae_lens/sae_training_runner.py CHANGED
@@ -13,10 +13,10 @@ from transformer_lens.hook_points import HookedRootModule
  from sae_lens import logger
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.load_model import load_model
+ from sae_lens.saes.sae import TrainingSAE, TrainingSAEConfig
  from sae_lens.training.activations_store import ActivationsStore
  from sae_lens.training.geometric_median import compute_geometric_median
  from sae_lens.training.sae_trainer import SAETrainer
- from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig


  class InterruptedException(Exception):
@@ -73,14 +73,14 @@ class SAETrainingRunner:

          if override_sae is None:
              if self.cfg.from_pretrained_path is not None:
-                 self.sae = TrainingSAE.load_from_pretrained(
+                 self.sae = TrainingSAE.load_from_disk(
                      self.cfg.from_pretrained_path, self.cfg.device
                  )
              else:
-                 self.sae = TrainingSAE(
+                 self.sae = TrainingSAE.from_dict(
                      TrainingSAEConfig.from_dict(
                          self.cfg.get_training_sae_cfg_dict(),
-                     )
+                     ).to_dict()
                  )
              self._init_sae_group_b_decs()
          else:
@@ -91,13 +91,13 @@ class SAETrainingRunner:
          Run the training of the SAE.
          """

-         if self.cfg.log_to_wandb:
+         if self.cfg.logger.log_to_wandb:
              wandb.init(
-                 project=self.cfg.wandb_project,
-                 entity=self.cfg.wandb_entity,
+                 project=self.cfg.logger.wandb_project,
+                 entity=self.cfg.logger.wandb_entity,
                  config=cast(Any, self.cfg),
-                 name=self.cfg.run_name,
-                 id=self.cfg.wandb_id,
+                 name=self.cfg.logger.run_name,
+                 id=self.cfg.logger.wandb_id,
              )

          trainer = SAETrainer(
@@ -111,7 +111,7 @@ class SAETrainingRunner:
          self._compile_if_needed()
          sae = self.run_trainer_with_interruption_handling(trainer)

-         if self.cfg.log_to_wandb:
+         if self.cfg.logger.log_to_wandb:
              wandb.finish()

          return sae
@@ -175,7 +175,7 @@ class SAETrainingRunner:
                  layer_acts,
                  maxiter=100,
              ).median
-             self.sae.initialize_b_dec_with_precalculated(median)  # type: ignore
+             self.sae.initialize_b_dec_with_precalculated(median)
          elif self.cfg.b_dec_init_method == "mean":
              self.activations_store.set_norm_scaling_factor_if_needed()
              layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
@@ -208,29 +208,14 @@ class SAETrainingRunner:
          with open(cfg_path, "w") as f:
              json.dump(config, f)

-         if trainer.cfg.log_to_wandb:
-             # Avoid wandb saving errors such as:
-             # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
-             sae_name = trainer.sae.get_name().replace("/", "__")
-
-             # save model weights and cfg
-             model_artifact = wandb.Artifact(
-                 sae_name,
-                 type="model",
-                 metadata=dict(trainer.cfg.__dict__),
-             )
-             model_artifact.add_file(str(weights_path))
-             model_artifact.add_file(str(cfg_path))
-             wandb.log_artifact(model_artifact, aliases=wandb_aliases)
-
-             # save log feature sparsity
-             sparsity_artifact = wandb.Artifact(
-                 f"{sae_name}_log_feature_sparsity",
-                 type="log_feature_sparsity",
-                 metadata=dict(trainer.cfg.__dict__),
+         if trainer.cfg.logger.log_to_wandb:
+             trainer.cfg.logger.log(
+                 trainer,
+                 weights_path,
+                 cfg_path,
+                 sparsity_path=sparsity_path,
+                 wandb_aliases=wandb_aliases,
              )
-             sparsity_artifact.add_file(str(sparsity_path))
-             wandb.log_artifact(sparsity_artifact)


  def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
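
Sketch of the two construction-path changes above, outside the runner (not part of the diff). The checkpoint path is illustrative, and the config is built with defaults only to keep the example short.

    from sae_lens import TrainingSAE, TrainingSAEConfig
    from sae_lens.config import LanguageModelSAERunnerConfig

    cfg = LanguageModelSAERunnerConfig()  # defaults only, for illustration

    # 5.9.1 built the SAE with TrainingSAE(TrainingSAEConfig.from_dict(...));
    # 6.0.0-rc.1 round-trips the config through a plain dict instead:
    sae = TrainingSAE.from_dict(
        TrainingSAEConfig.from_dict(cfg.get_training_sae_cfg_dict()).to_dict()
    )

    # Loading a partially trained SAE from a checkpoint directory:
    # 5.9.1: TrainingSAE.load_from_pretrained(path, device)
    sae = TrainingSAE.load_from_disk("checkpoints/my_sae", "cpu")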