mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
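
Note that while the distribution is named mi_crow, the single entry in top_level.txt and the file layout above show that the importable package is amber. A minimal orientation snippet (an illustrative sketch, assuming the wheel and its dependencies such as torch are installed):

    import amber                                    # top-level package shipped by this wheel
    from amber.mechanistic.sae import sae_trainer   # the new module whose full diff follows
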
amber/mechanistic/sae/sae_trainer.py (new file)
@@ -0,0 +1,604 @@
+"""Training utilities for SAE models using overcomplete's training functions."""
+
+from dataclasses import dataclass
+from typing import Any, Optional, TYPE_CHECKING
+from datetime import datetime
+import json
+import logging
+import gc
+
+import torch
+
+from amber.store.store_dataloader import StoreDataloader
+from amber.utils import get_logger
+
+if TYPE_CHECKING:
+    from amber.mechanistic.sae.sae import Sae
+    from amber.store.store import Store
+
+
+@dataclass
+class SaeTrainingConfig:
+    """Configuration for SAE training (compatible with overcomplete.train_sae)."""
+    epochs: int = 1
+    batch_size: int = 1024
+    lr: float = 1e-3
+    l1_lambda: float = 0.0
+    device: str | torch.device = "cpu"
+    dtype: Optional[torch.dtype] = None
+    max_batches_per_epoch: Optional[int] = None
+    verbose: bool = False
+    use_amp: bool = True
+    amp_dtype: Optional[torch.dtype] = None
+    grad_accum_steps: int = 1
+    clip_grad: float = 1.0  # Gradient clipping (overcomplete parameter)
+    monitoring: int = 1  # 0=silent, 1=basic, 2=detailed (overcomplete parameter)
+    scheduler: Optional[Any] = None  # Learning rate scheduler (overcomplete parameter)
+    max_nan_fallbacks: int = 5  # For train_sae_amp (overcomplete parameter)
+    # Wandb configuration
+    use_wandb: bool = False  # Enable wandb logging
+    wandb_project: Optional[str] = None  # Wandb project name (defaults to "sae-training" if not set)
+    wandb_entity: Optional[str] = None  # Wandb entity/team name
+    wandb_name: Optional[str] = None  # Wandb run name (defaults to run_id if not set)
+    wandb_tags: Optional[list[str]] = None  # Additional tags for wandb run
+    wandb_config: Optional[dict[str, Any]] = None  # Additional config to log to wandb
+    wandb_mode: str = "online"  # Wandb mode: "online", "offline", or "disabled"
+    wandb_slow_metrics_frequency: int = 50  # Log slow metrics (L0, dead features) every N epochs (default: 50)
+    memory_efficient: bool = False  # Enable memory-efficient processing (moves tensors to CPU, clears cache)
+
+
+class SaeTrainer:
+    """
+    Composite trainer class for SAE models using overcomplete's training functions.
+
+    This trainer handles training of any SAE that has a sae_engine attribute
+    compatible with overcomplete's train_sae functions.
+    """
+
+    def __init__(self, sae: "Sae") -> None:
+        """
+        Initialize SaeTrainer.
+
+        Args:
+            sae: The SAE instance to train
+        """
+        self.sae = sae
+        self.logger = get_logger(__name__)
+
+    def train(
+        self,
+        store: "Store",
+        run_id: str,
+        layer_signature: str | int,
+        config: SaeTrainingConfig | None = None,
+        training_run_id: str | None = None
+    ) -> dict[str, Any]:
+        self._ensure_overcomplete_available()
+        cfg = config or SaeTrainingConfig()
+
+        wandb_run = self._initialize_wandb(cfg, run_id)
+        device = self._setup_device(cfg)
+        optimizer = self._create_optimizer(cfg)
+        criterion = self._create_criterion(cfg)
+        dataloader = self._create_dataloader(store, run_id, layer_signature, cfg, device)
+        monitoring = self._configure_logging(cfg, run_id)
+
+        logs = self._run_training(cfg, dataloader, criterion, optimizer, device, monitoring)
+        history = self._process_training_logs(logs, cfg)
+        if cfg.memory_efficient:
+            self._clear_memory()
+
+        if wandb_run is not None:
+            self._log_to_wandb(wandb_run, history, cfg)
+
+        if cfg.verbose:
+            self.logger.info(f"[SaeTrainer] Training completed, processing {len(history['loss'])} epochs of results")
+            self.logger.info("[SaeTrainer] Completed training")
+
+        if training_run_id is None:
+            training_run_id = f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+        if cfg.verbose:
+            self.logger.info(f"[SaeTrainer] Saving training outputs to store/runs/{training_run_id}/")
+
+        self._save_training_to_store(
+            store=store,
+            training_run_id=training_run_id,
+            run_id=run_id,
+            layer_signature=layer_signature,
+            history=history,
+            config=cfg
+        )
+
+        return {
+            "history": history,
+            "training_run_id": training_run_id
+        }
+
+    def _ensure_overcomplete_available(self) -> None:
+        try:
+            from overcomplete.sae.train import train_sae, train_sae_amp
+        except ImportError:
+            raise ImportError("overcomplete.sae.train module not available. Cannot use overcomplete training.")
+
+    def _initialize_wandb(self, cfg: SaeTrainingConfig, run_id: str) -> Any:
+        if not cfg.use_wandb:
+            return None
+
+        try:
+            import wandb
+            wandb_project = cfg.wandb_project or "sae-training"
+            wandb_name = cfg.wandb_name or run_id
+            wandb_mode = cfg.wandb_mode.lower() if cfg.wandb_mode else "online"
+
+            return wandb.init(
+                project=wandb_project,
+                entity=cfg.wandb_entity,
+                name=wandb_name,
+                mode=wandb_mode,
+                config={
+                    "run_id": run_id,
+                    "epochs": cfg.epochs,
+                    "batch_size": cfg.batch_size,
+                    "lr": cfg.lr,
+                    "l1_lambda": cfg.l1_lambda,
+                    "device": str(cfg.device),
+                    "dtype": str(cfg.dtype) if cfg.dtype else None,
+                    "use_amp": cfg.use_amp,
+                    "clip_grad": cfg.clip_grad,
+                    "max_batches_per_epoch": cfg.max_batches_per_epoch,
+                    **(cfg.wandb_config or {}),
+                },
+                tags=cfg.wandb_tags or [],
+            )
+        except ImportError:
+            self.logger.warning("[SaeTrainer] wandb not installed, skipping wandb logging")
+            self.logger.warning("[SaeTrainer] Install with: pip install wandb")
+            return None
+        except Exception as e:
+            self.logger.warning(f"[SaeTrainer] Unexpected error initializing wandb: {e}")
+            self.logger.warning("[SaeTrainer] Continuing training without wandb logging")
+            return None
+
+    def _setup_device(self, cfg: SaeTrainingConfig) -> torch.device:
+        device_str = str(cfg.device)
+        device = torch.device(device_str)
+        self.sae.sae_engine.to(device)
+
+        if cfg.dtype is not None:
+            try:
+                self.sae.sae_engine.to(device, dtype=cfg.dtype)
+            except (TypeError, AttributeError):
+                self.sae.sae_engine.to(device)
+
+        return device
+
+    def _create_optimizer(self, cfg: SaeTrainingConfig) -> torch.optim.Optimizer:
+        return torch.optim.AdamW(self.sae.sae_engine.parameters(), lr=cfg.lr)
+
+    def _create_criterion(self, cfg: SaeTrainingConfig):
+        def criterion(x: torch.Tensor, x_hat: torch.Tensor, z_pre: torch.Tensor, z: torch.Tensor,
+                      dictionary: torch.Tensor) -> torch.Tensor:
+            recon_loss = ((x_hat - x) ** 2).mean()
+            l1_penalty = z.abs().mean() * cfg.l1_lambda if cfg.l1_lambda > 0 else torch.tensor(0.0, device=x.device)
+            return recon_loss + l1_penalty
+        return criterion
+
+    def _create_dataloader(self, store: "Store", run_id: str, layer_signature: str | int, cfg: SaeTrainingConfig, device: torch.device) -> StoreDataloader:
+        return StoreDataloader(
+            store=store,
+            run_id=run_id,
+            layer=layer_signature,
+            key="activations",
+            batch_size=cfg.batch_size,
+            dtype=cfg.dtype,
+            device=device,
+            max_batches=cfg.max_batches_per_epoch,
+            logger_instance=self.logger
+        )
+
+    def _configure_logging(self, cfg: SaeTrainingConfig, run_id: str) -> int:
+        monitoring = cfg.monitoring
+        if cfg.verbose and monitoring < 2:
+            monitoring = 2
+
+        if cfg.verbose:
+            device_str = str(cfg.device)
+            self.logger.info(
+                f"[SaeTrainer] Starting training run_id={run_id} epochs={cfg.epochs} batch_size={cfg.batch_size} "
+                f"device={device_str} use_amp={cfg.use_amp}"
+            )
+
+        overcomplete_logger = logging.getLogger("overcomplete")
+        if cfg.verbose:
+            overcomplete_logger.setLevel(logging.INFO)
+            if not overcomplete_logger.handlers:
+                handler = logging.StreamHandler()
+                handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
+                overcomplete_logger.addHandler(handler)
+            overcomplete_logger.propagate = True
+        else:
+            overcomplete_logger.setLevel(logging.WARNING)
+
+        return monitoring
+
+    def _run_training(self, cfg: SaeTrainingConfig, dataloader: StoreDataloader, criterion, optimizer: torch.optim.Optimizer,
+                      device: torch.device, monitoring: int) -> dict[str, Any]:
+        from overcomplete.sae.train import train_sae, train_sae_amp
+
+        device_str = str(device)
+
+        try:
+            if cfg.use_amp and device.type in ("cuda", "cpu"):
+                if cfg.verbose:
+                    self.logger.info(f"[SaeTrainer] Using train_sae_amp with monitoring={monitoring}")
+                logs = train_sae_amp(
+                    model=self.sae.sae_engine,
+                    dataloader=dataloader,
+                    criterion=criterion,
+                    optimizer=optimizer,
+                    scheduler=cfg.scheduler,
+                    nb_epochs=cfg.epochs,
+                    clip_grad=cfg.clip_grad,
+                    monitoring=monitoring,
+                    device=device_str,
+                    max_nan_fallbacks=cfg.max_nan_fallbacks
+                )
+            else:
+                if cfg.verbose:
+                    self.logger.info(f"[SaeTrainer] Using train_sae with monitoring={monitoring}")
+                logs = train_sae(
+                    model=self.sae.sae_engine,
+                    dataloader=dataloader,
+                    criterion=criterion,
+                    optimizer=optimizer,
+                    scheduler=cfg.scheduler,
+                    nb_epochs=cfg.epochs,
+                    clip_grad=cfg.clip_grad,
+                    monitoring=monitoring,
+                    device=device_str
+                )
+
+            if cfg.verbose:
+                self.logger.info(
+                    f"[SaeTrainer] Overcomplete training function returned, processing {len(logs.get('avg_loss', []))} epoch results...")
+
+            return logs
+        except Exception as e:
+            self.logger.error(f"[SaeTrainer] Error during training: {e}")
+            import traceback
+            self.logger.error(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
+            raise
+
+    def _process_training_logs(self, logs: dict[str, Any], cfg: SaeTrainingConfig) -> dict[str, list[float]]:
+        history: dict[str, list[float]] = {
+            "loss": logs.get("avg_loss", []),
+            "recon_mse": [],
+            "l1": [],
+            "r2": [],
+            "l0": [],
+            "dead_features_pct": [],
+        }
+
+        self._extract_r2_and_mse(history, logs)
+        self._extract_sparsity_metrics(history, logs, cfg)
+
+        return history
+
+    def _extract_r2_and_mse(self, history: dict[str, list[float]], logs: dict[str, Any]) -> None:
+        if "r2" in logs:
+            history["r2"] = logs["r2"]
+            history["recon_mse"] = [(1.0 - r2) for r2 in logs["r2"]]
+        else:
+            history["r2"] = [0.0] * len(history["loss"])
+
+    def _extract_sparsity_metrics(self, history: dict[str, list[float]], logs: dict[str, Any], cfg: SaeTrainingConfig) -> None:
+        memory_efficient = cfg.memory_efficient
+        num_epochs = len(history["loss"])
+
+        if "dead_features" in logs and isinstance(logs["dead_features"], list) and len(logs["dead_features"]) == num_epochs:
+            history["dead_features_pct"] = logs["dead_features"]
+        else:
+            history["dead_features_pct"] = [0.0] * num_epochs
+
+        history["l0"] = [0.0] * num_epochs
+
+        if "z" in logs and logs["z"]:
+            first_z = logs["z"][0] if len(logs["z"]) > 0 else None
+            is_flat_list = isinstance(first_z, torch.Tensor) if first_z is not None else False
+
+            if is_flat_list:
+                batches_per_epoch = len(logs["z"]) // num_epochs if num_epochs > 0 else len(logs["z"])
+
+                for epoch_idx in range(num_epochs):
+                    start_idx = epoch_idx * batches_per_epoch
+                    end_idx = start_idx + batches_per_epoch
+                    epoch_z_tensors = logs["z"][start_idx:end_idx]
+
+                    if epoch_z_tensors:
+                        l1_val = self._compute_l1(epoch_z_tensors, memory_efficient)
+                        history["l1"].append(l1_val)
+                    else:
+                        history["l1"].append(0.0)
+            else:
+                n_latents = self._get_n_latents()
+                slow_metrics_freq = cfg.wandb_slow_metrics_frequency if cfg.use_wandb else 1
+
+                for epoch_idx, z_batch_list in enumerate(logs["z"]):
+                    if isinstance(z_batch_list, list) and len(z_batch_list) > 0:
+                        l1_val = self._compute_l1(z_batch_list, memory_efficient)
+                        history["l1"].append(l1_val)
+
+                        should_compute_slow = (epoch_idx % slow_metrics_freq == 0) or (epoch_idx == len(logs["z"]) - 1)
+
+                        if should_compute_slow:
+                            l0, dead_pct = self._compute_slow_metrics(z_batch_list, n_latents, memory_efficient)
+                            if len(history["l0"]) <= epoch_idx:
+                                history["l0"].extend([None] * (epoch_idx + 1 - len(history["l0"])))
+                            history["l0"][epoch_idx] = l0
+                            if len(history["dead_features_pct"]) <= epoch_idx:
+                                history["dead_features_pct"].extend([None] * (epoch_idx + 1 - len(history["dead_features_pct"])))
+                            history["dead_features_pct"][epoch_idx] = dead_pct
+                        else:
+                            if len(history["l0"]) <= epoch_idx:
+                                history["l0"].extend([None] * (epoch_idx + 1 - len(history["l0"])))
+                            history["l0"].append(None)
+                            if len(history["dead_features_pct"]) <= epoch_idx:
+                                history["dead_features_pct"].extend([None] * (epoch_idx + 1 - len(history["dead_features_pct"])))
+                            history["dead_features_pct"].append(None)
+
+                        if memory_efficient:
+                            del z_batch_list
+                            if epoch_idx % 5 == 0:
+                                self._clear_memory()
+                    else:
+                        history["l1"].append(0.0)
+        elif "z_sparsity" in logs:
+            history["l1"] = logs["z_sparsity"]
+        else:
+            history["l1"] = [0.0] * num_epochs
+
+        if memory_efficient and "z" in logs:
+            del logs["z"]
+            self._clear_memory()
+
+    def _get_n_latents(self) -> Optional[int]:
+        if hasattr(self.sae, 'context') and hasattr(self.sae.context, 'n_latents'):
+            return self.sae.context.n_latents
+        return None
+
+    def _clear_memory(self) -> None:
+        """Clear GPU/MPS memory cache and run garbage collection."""
+        gc.collect()
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def _compute_l1(self, z_batch_list: list[torch.Tensor], memory_efficient: bool = False) -> float:
+        l1_vals = []
+        for z in z_batch_list:
+            if isinstance(z, torch.Tensor):
+                if memory_efficient and z.device.type != "cpu":
+                    z_cpu = z.cpu()
+                    l1_vals.append(z_cpu.abs().mean().item())
+                    del z
+                else:
+                    l1_vals.append(z.abs().mean().item())
+        return sum(l1_vals) / len(l1_vals) if l1_vals else 0.0
+
+    def _compute_slow_metrics(self, z_batch_list: list[torch.Tensor], n_latents: Optional[int], memory_efficient: bool = False) -> tuple[float, float]:
+        l0_vals = []
+        all_z_epoch = []
+
+        for z in z_batch_list:
+            if isinstance(z, torch.Tensor):
+                if memory_efficient and z.device.type != "cpu":
+                    z_cpu = z.cpu()
+                    active = (z_cpu.abs() > 1e-6).float()
+                    l0_vals.append(active.sum(dim=-1).mean().item())
+                    all_z_epoch.append(z_cpu)
+                    del z
+                else:
+                    active = (z.abs() > 1e-6).float()
+                    l0_vals.append(active.sum(dim=-1).mean().item())
+                    all_z_epoch.append(z)
+
+        l0 = sum(l0_vals) / len(l0_vals) if l0_vals else 0.0
+
+        if all_z_epoch and n_latents is not None:
+            z_concatenated = torch.cat(all_z_epoch, dim=0)
+            feature_activity = (z_concatenated.abs() > 1e-6).any(dim=0).float()
+            dead_count = (feature_activity == 0).sum().item()
+            dead_features_pct = dead_count / n_latents * 100.0 if n_latents > 0 else 0.0
+            if memory_efficient:
+                del z_concatenated, feature_activity
+        else:
+            dead_features_pct = 0.0
+
+        if memory_efficient:
+            del all_z_epoch
+        return l0, dead_features_pct
+
+    def _log_to_wandb(self, wandb_run: Any, history: dict[str, list[float]], cfg: SaeTrainingConfig) -> None:
+        try:
+            num_epochs = len(history["loss"])
+            slow_metrics_freq = cfg.wandb_slow_metrics_frequency
+
+            for epoch in range(1, num_epochs + 1):
+                epoch_idx = epoch - 1
+                should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)
+
+                metrics = self._build_epoch_metrics(history, epoch, epoch_idx, cfg, should_log_slow)
+                wandb_run.log(metrics)
+
+            final_metrics = self._build_final_metrics(history, num_epochs)
+            wandb_run.summary.update(final_metrics)
+
+            if cfg.verbose:
+                try:
+                    url = wandb_run.url
+                    self.logger.info(f"[SaeTrainer] Metrics logged to wandb: {url}")
+                except (AttributeError, RuntimeError):
+                    self.logger.info("[SaeTrainer] Metrics logged to wandb (offline mode)")
+        except Exception as e:
+            self.logger.warning(f"[SaeTrainer] Failed to log metrics to wandb: {e}")
+
+    def _build_epoch_metrics(self, history: dict[str, list[float]], epoch: int, epoch_idx: int,
+                             cfg: SaeTrainingConfig, should_log_slow: bool) -> dict[str, Any]:
+        metrics = {
+            "epoch": epoch,
+            "train/loss": history["loss"][epoch_idx] if epoch_idx < len(history["loss"]) else 0.0,
+            "train/reconstruction_mse": history["recon_mse"][epoch_idx] if epoch_idx < len(history["recon_mse"]) else 0.0,
+            "train/r2_score": history["r2"][epoch_idx] if epoch_idx < len(history["r2"]) else 0.0,
+            "train/l1_penalty": history["l1"][epoch_idx] if epoch_idx < len(history["l1"]) else 0.0,
+            "train/learning_rate": cfg.lr,
+        }
+
+        if should_log_slow:
+            l0_val = self._get_metric_value(history["l0"], epoch_idx)
+            dead_pct = self._get_metric_value(history["dead_features_pct"], epoch_idx)
+            metrics["train/l0_sparsity"] = l0_val
+            metrics["train/dead_features_pct"] = dead_pct
+
+        return metrics
+
+    def _get_metric_value(self, values: list[float | None], idx: int) -> float:
+        if idx < len(values) and values[idx] is not None:
+            return values[idx]
+        return self._get_last_known_value(values, idx)
+
+    def _get_last_known_value(self, values: list[float | None], idx: int) -> float:
+        for i in range(idx, -1, -1):
+            if i < len(values) and values[i] is not None:
+                return values[i]
+        return 0.0
+
+    def _build_final_metrics(self, history: dict[str, list[float]], num_epochs: int) -> dict[str, Any]:
+        final_l0 = self._get_last_known_value(history["l0"], len(history["l0"]) - 1) if history["l0"] else 0.0
+        final_dead_pct = self._get_last_known_value(history["dead_features_pct"], len(history["dead_features_pct"]) - 1) if history["dead_features_pct"] else 0.0
+
+        final_metrics = {
+            "final/loss": history["loss"][-1] if history["loss"] else 0.0,
+            "final/reconstruction_mse": history["recon_mse"][-1] if history["recon_mse"] else 0.0,
+            "final/r2_score": history["r2"][-1] if history["r2"] else 0.0,
+            "final/l1_penalty": history["l1"][-1] if history["l1"] else 0.0,
+            "final/l0_sparsity": final_l0,
+            "final/dead_features_pct": final_dead_pct,
+            "training/num_epochs": num_epochs,
+        }
+
+        if history["loss"]:
+            best_loss_idx = min(range(len(history["loss"])), key=lambda i: history["loss"][i])
+            final_metrics["best/loss"] = history["loss"][best_loss_idx]
+            final_metrics["best/loss_epoch"] = best_loss_idx + 1
+
+        if history["r2"]:
+            best_r2_idx = max(range(len(history["r2"])), key=lambda i: history["r2"][i])
+            final_metrics["best/r2_score"] = history["r2"][best_r2_idx]
+            final_metrics["best/r2_epoch"] = best_r2_idx + 1
+
+        return final_metrics
+
+    def _save_training_to_store(
+        self,
+        store: "Store",
+        training_run_id: str,
+        run_id: str,
+        layer_signature: str | int,
+        history: dict[str, list[float]],
+        config: SaeTrainingConfig
+    ) -> None:
+        """Save training outputs (model, history, metadata) to store under training_run_id.
+
+        Args:
+            store: Store instance
+            training_run_id: Training run ID to save under
+            run_id: Original activation run ID used for training
+            layer_signature: Layer signature used for training
+            history: Training history dictionary
+            config: Training configuration
+        """
+        try:
+            run_path = store._run_key(training_run_id)
+            run_path.mkdir(parents=True, exist_ok=True)
+
+            model_path = run_path / "model.pt"
+            history_path = run_path / "history.json"
+
+            sae_state_dict = self.sae.sae_engine.state_dict()
+
+            amber_metadata = {
+                "concepts_state": {
+                    'multiplication': self.sae.concepts.multiplication.data.cpu().clone(),
+                    'bias': self.sae.concepts.bias.data.cpu().clone(),
+                },
+                "n_latents": self.sae.context.n_latents,
+                "n_inputs": self.sae.context.n_inputs,
+                "device": self.sae.context.device,
+                "layer_signature": self.sae.context.lm_layer_signature,
+                "model_id": self.sae.context.model_id,
+            }
+
+            if hasattr(self.sae, 'k'):
+                amber_metadata["k"] = self.sae.k
+
+            payload = {
+                "sae_state_dict": sae_state_dict,
+                "amber_metadata": amber_metadata,
+            }
+
+            torch.save(payload, model_path)
+
+            with open(history_path, "w") as f:
+                json.dump(history, f, indent=2)
+
+            training_metadata = {
+                "training_run_id": training_run_id,
+                "activation_run_id": run_id,
+                "layer_signature": str(layer_signature),
+                "model_id": self.sae.context.model_id if hasattr(self.sae.context, 'model_id') else None,
+                "sae_type": type(self.sae).__name__,
+                "training_config": {
+                    "epochs": config.epochs,
+                    "batch_size": config.batch_size,
+                    "lr": config.lr,
+                    "l1_lambda": config.l1_lambda,
+                    "device": str(config.device),
+                    "dtype": str(config.dtype) if config.dtype else None,
+                    "use_amp": config.use_amp,
+                    "clip_grad": config.clip_grad,
+                    "monitoring": config.monitoring,
+                },
+                "final_metrics": {
+                    "loss": history["loss"][-1] if history["loss"] else None,
+                    "r2": history["r2"][-1] if history["r2"] else None,
+                    "recon_mse": history["recon_mse"][-1] if history["recon_mse"] else None,
+                    "l1": history["l1"][-1] if history["l1"] else None,
+                },
+                "n_epochs": len(history["loss"]),
+                "timestamp": datetime.now().isoformat(),
+            }
+
+            if history["l0"]:
+                final_l0 = [x for x in history["l0"] if x is not None]
+                if final_l0:
+                    training_metadata["final_metrics"]["l0"] = final_l0[-1]
+
+            if history["dead_features_pct"]:
+                final_dead = [x for x in history["dead_features_pct"] if x is not None]
+                if final_dead:
+                    training_metadata["final_metrics"]["dead_features_pct"] = final_dead[-1]
+
+            store.put_run_metadata(training_run_id, training_metadata)
+
+            if config.verbose:
+                self.logger.info(f"[SaeTrainer] Saved model to: {model_path}")
+                self.logger.info(f"[SaeTrainer] Saved history to: {history_path}")
+                self.logger.info(f"[SaeTrainer] Saved metadata to: runs/{training_run_id}/meta.json")
+
+        except Exception as e:
+            self.logger.warning(f"[SaeTrainer] Failed to save training outputs to store: {e}")
+            if config.verbose:
+                import traceback
+                self.logger.warning(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
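
For orientation, a minimal usage sketch of the trainer added in this file (not part of the package): it assumes an already-constructed Sae instance (`sae`, exposing the `sae_engine` attribute the trainer drives) and a Store (`store`) that already holds activations recorded under an activation run id; the run and layer identifiers below are placeholders.

    import torch
    from amber.mechanistic.sae.sae_trainer import SaeTrainer, SaeTrainingConfig

    # `sae` and `store` are assumed to be constructed elsewhere with the package's own APIs.
    config = SaeTrainingConfig(
        epochs=10,
        batch_size=1024,
        lr=1e-3,
        l1_lambda=1e-4,        # enables the L1 term added to the reconstruction MSE in the criterion
        device="cuda" if torch.cuda.is_available() else "cpu",
        verbose=True,
    )
    trainer = SaeTrainer(sae)
    result = trainer.train(
        store=store,
        run_id="my_activation_run",            # placeholder: run id under which activations were stored
        layer_signature="model.layers.10",     # placeholder: layer whose activations are used
        config=config,
    )
    history = result["history"]                # per-epoch "loss", "r2", "l1", "l0", ... lists
    training_run_id = result["training_run_id"]

Because _save_training_to_store writes a plain torch.save payload (sae_state_dict plus amber_metadata) to model.pt and the history to history.json, the saved outputs can be inspected directly; a sketch using the same _run_key helper the trainer itself calls:

    import json

    run_path = store._run_key(training_run_id)
    payload = torch.load(run_path / "model.pt", map_location="cpu")
    state_dict = payload["sae_state_dict"]     # weights of the overcomplete SAE engine
    metadata = payload["amber_metadata"]       # n_latents, n_inputs, layer_signature, model_id, ...
    with open(run_path / "history.json") as f:
        history = json.load(f)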