mi_crow-0.1.1.post12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/mechanistic/sae/sae_trainer.py
@@ -0,0 +1,604 @@
"""Training utilities for SAE models using overcomplete's training functions."""

from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING
from datetime import datetime
import json
import logging
import gc

import torch

from amber.store.store_dataloader import StoreDataloader
from amber.utils import get_logger

if TYPE_CHECKING:
    from amber.mechanistic.sae.sae import Sae
    from amber.store.store import Store


@dataclass
class SaeTrainingConfig:
    """Configuration for SAE training (compatible with overcomplete.train_sae)."""
    epochs: int = 1
    batch_size: int = 1024
    lr: float = 1e-3
    l1_lambda: float = 0.0
    device: str | torch.device = "cpu"
    dtype: Optional[torch.dtype] = None
    max_batches_per_epoch: Optional[int] = None
    verbose: bool = False
    use_amp: bool = True
    amp_dtype: Optional[torch.dtype] = None
    grad_accum_steps: int = 1
    clip_grad: float = 1.0  # Gradient clipping (overcomplete parameter)
    monitoring: int = 1  # 0=silent, 1=basic, 2=detailed (overcomplete parameter)
    scheduler: Optional[Any] = None  # Learning rate scheduler (overcomplete parameter)
    max_nan_fallbacks: int = 5  # For train_sae_amp (overcomplete parameter)
    # Wandb configuration
    use_wandb: bool = False  # Enable wandb logging
    wandb_project: Optional[str] = None  # Wandb project name (defaults to "sae-training" if not set)
    wandb_entity: Optional[str] = None  # Wandb entity/team name
    wandb_name: Optional[str] = None  # Wandb run name (defaults to run_id if not set)
    wandb_tags: Optional[list[str]] = None  # Additional tags for wandb run
    wandb_config: Optional[dict[str, Any]] = None  # Additional config to log to wandb
    wandb_mode: str = "online"  # Wandb mode: "online", "offline", or "disabled"
    wandb_slow_metrics_frequency: int = 50  # Log slow metrics (L0, dead features) every N epochs (default: 50)
    memory_efficient: bool = False  # Enable memory-efficient processing (moves tensors to CPU, clears cache)

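# Example (illustrative, not part of the package's own documentation): a small CPU run could be
# configured as SaeTrainingConfig(epochs=5, batch_size=256, lr=3e-4, l1_lambda=1e-3, device="cpu").
# The wandb_* fields are only consulted when use_wandb=True.
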
class SaeTrainer:
    """
    Composite trainer class for SAE models using overcomplete's training functions.

    This trainer handles training of any SAE that has a sae_engine attribute
    compatible with overcomplete's train_sae functions.
    """

    def __init__(self, sae: "Sae") -> None:
        """
        Initialize SaeTrainer.

        Args:
            sae: The SAE instance to train
        """
        self.sae = sae
        self.logger = get_logger(__name__)

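    # train() wires the pieces together: optional wandb init, device/optimizer/criterion setup,
    # a StoreDataloader over the cached activations, the overcomplete training loop, metric
    # extraction into a history dict, and persistence of model/history/metadata to the store.
    # Hedged usage sketch (assumes an existing Sae and Store; the run id and layer signature
    # below are purely illustrative):
    #   trainer = SaeTrainer(sae)
    #   result = trainer.train(store=store, run_id="activations_run", layer_signature="layers.10")
    #   result["history"]["loss"]  # per-epoch average loss from overcomplete's logs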
    def train(
        self,
        store: "Store",
        run_id: str,
        layer_signature: str | int,
        config: SaeTrainingConfig | None = None,
        training_run_id: str | None = None
    ) -> dict[str, Any]:
        self._ensure_overcomplete_available()
        cfg = config or SaeTrainingConfig()

        wandb_run = self._initialize_wandb(cfg, run_id)
        device = self._setup_device(cfg)
        optimizer = self._create_optimizer(cfg)
        criterion = self._create_criterion(cfg)
        dataloader = self._create_dataloader(store, run_id, layer_signature, cfg, device)
        monitoring = self._configure_logging(cfg, run_id)

        logs = self._run_training(cfg, dataloader, criterion, optimizer, device, monitoring)
        history = self._process_training_logs(logs, cfg)
        if cfg.memory_efficient:
            self._clear_memory()

        if wandb_run is not None:
            self._log_to_wandb(wandb_run, history, cfg)

        if cfg.verbose:
            self.logger.info(f"[SaeTrainer] Training completed, processing {len(history['loss'])} epochs of results")
            self.logger.info("[SaeTrainer] Completed training")

        if training_run_id is None:
            training_run_id = f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        if cfg.verbose:
            self.logger.info(f"[SaeTrainer] Saving training outputs to store/runs/{training_run_id}/")

        self._save_training_to_store(
            store=store,
            training_run_id=training_run_id,
            run_id=run_id,
            layer_signature=layer_signature,
            history=history,
            config=cfg
        )

        return {
            "history": history,
            "training_run_id": training_run_id
        }

    def _ensure_overcomplete_available(self) -> None:
        try:
            from overcomplete.sae.train import train_sae, train_sae_amp
        except ImportError:
            raise ImportError("overcomplete.sae.train module not available. Cannot use overcomplete training.")

    def _initialize_wandb(self, cfg: SaeTrainingConfig, run_id: str) -> Any:
        if not cfg.use_wandb:
            return None

        try:
            import wandb
            wandb_project = cfg.wandb_project or "sae-training"
            wandb_name = cfg.wandb_name or run_id
            wandb_mode = cfg.wandb_mode.lower() if cfg.wandb_mode else "online"

            return wandb.init(
                project=wandb_project,
                entity=cfg.wandb_entity,
                name=wandb_name,
                mode=wandb_mode,
                config={
                    "run_id": run_id,
                    "epochs": cfg.epochs,
                    "batch_size": cfg.batch_size,
                    "lr": cfg.lr,
                    "l1_lambda": cfg.l1_lambda,
                    "device": str(cfg.device),
                    "dtype": str(cfg.dtype) if cfg.dtype else None,
                    "use_amp": cfg.use_amp,
                    "clip_grad": cfg.clip_grad,
                    "max_batches_per_epoch": cfg.max_batches_per_epoch,
                    **(cfg.wandb_config or {}),
                },
                tags=cfg.wandb_tags or [],
            )
        except ImportError:
            self.logger.warning("[SaeTrainer] wandb not installed, skipping wandb logging")
            self.logger.warning("[SaeTrainer] Install with: pip install wandb")
            return None
        except Exception as e:
            self.logger.warning(f"[SaeTrainer] Unexpected error initializing wandb: {e}")
            self.logger.warning("[SaeTrainer] Continuing training without wandb logging")
            return None

    def _setup_device(self, cfg: SaeTrainingConfig) -> torch.device:
        device_str = str(cfg.device)
        device = torch.device(device_str)
        self.sae.sae_engine.to(device)

        if cfg.dtype is not None:
            try:
                self.sae.sae_engine.to(device, dtype=cfg.dtype)
            except (TypeError, AttributeError):
                self.sae.sae_engine.to(device)

        return device

    def _create_optimizer(self, cfg: SaeTrainingConfig) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.sae.sae_engine.parameters(), lr=cfg.lr)

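    # Loss callback handed to overcomplete's training functions: reconstruction MSE between x and
    # x_hat plus an optional L1 penalty on the codes z, weighted by cfg.l1_lambda.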
    def _create_criterion(self, cfg: SaeTrainingConfig):
        def criterion(x: torch.Tensor, x_hat: torch.Tensor, z_pre: torch.Tensor, z: torch.Tensor,
                      dictionary: torch.Tensor) -> torch.Tensor:
            recon_loss = ((x_hat - x) ** 2).mean()
            l1_penalty = z.abs().mean() * cfg.l1_lambda if cfg.l1_lambda > 0 else torch.tensor(0.0, device=x.device)
            return recon_loss + l1_penalty
        return criterion

    def _create_dataloader(self, store: "Store", run_id: str, layer_signature: str | int, cfg: SaeTrainingConfig, device: torch.device) -> StoreDataloader:
        return StoreDataloader(
            store=store,
            run_id=run_id,
            layer=layer_signature,
            key="activations",
            batch_size=cfg.batch_size,
            dtype=cfg.dtype,
            device=device,
            max_batches=cfg.max_batches_per_epoch,
            logger_instance=self.logger
        )

    def _configure_logging(self, cfg: SaeTrainingConfig, run_id: str) -> int:
        monitoring = cfg.monitoring
        if cfg.verbose and monitoring < 2:
            monitoring = 2

        if cfg.verbose:
            device_str = str(cfg.device)
            self.logger.info(
                f"[SaeTrainer] Starting training run_id={run_id} epochs={cfg.epochs} batch_size={cfg.batch_size} "
                f"device={device_str} use_amp={cfg.use_amp}"
            )

        overcomplete_logger = logging.getLogger("overcomplete")
        if cfg.verbose:
            overcomplete_logger.setLevel(logging.INFO)
            if not overcomplete_logger.handlers:
                handler = logging.StreamHandler()
                handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
                overcomplete_logger.addHandler(handler)
            overcomplete_logger.propagate = True
        else:
            overcomplete_logger.setLevel(logging.WARNING)

        return monitoring

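    # Dispatch to overcomplete: train_sae_amp (mixed precision) when cfg.use_amp is set and the
    # device is CUDA or CPU, otherwise the plain train_sae loop (e.g. on MPS).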
    def _run_training(self, cfg: SaeTrainingConfig, dataloader: StoreDataloader, criterion, optimizer: torch.optim.Optimizer,
                      device: torch.device, monitoring: int) -> dict[str, Any]:
        from overcomplete.sae.train import train_sae, train_sae_amp

        device_str = str(device)

        try:
            if cfg.use_amp and device.type in ("cuda", "cpu"):
                if cfg.verbose:
                    self.logger.info(f"[SaeTrainer] Using train_sae_amp with monitoring={monitoring}")
                logs = train_sae_amp(
                    model=self.sae.sae_engine,
                    dataloader=dataloader,
                    criterion=criterion,
                    optimizer=optimizer,
                    scheduler=cfg.scheduler,
                    nb_epochs=cfg.epochs,
                    clip_grad=cfg.clip_grad,
                    monitoring=monitoring,
                    device=device_str,
                    max_nan_fallbacks=cfg.max_nan_fallbacks
                )
            else:
                if cfg.verbose:
                    self.logger.info(f"[SaeTrainer] Using train_sae with monitoring={monitoring}")
                logs = train_sae(
                    model=self.sae.sae_engine,
                    dataloader=dataloader,
                    criterion=criterion,
                    optimizer=optimizer,
                    scheduler=cfg.scheduler,
                    nb_epochs=cfg.epochs,
                    clip_grad=cfg.clip_grad,
                    monitoring=monitoring,
                    device=device_str
                )

            if cfg.verbose:
                self.logger.info(
                    f"[SaeTrainer] Overcomplete training function returned, processing {len(logs.get('avg_loss', []))} epoch results...")

            return logs
        except Exception as e:
            self.logger.error(f"[SaeTrainer] Error during training: {e}")
            import traceback
            self.logger.error(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
            raise

    def _process_training_logs(self, logs: dict[str, Any], cfg: SaeTrainingConfig) -> dict[str, list[float]]:
        history: dict[str, list[float]] = {
            "loss": logs.get("avg_loss", []),
            "recon_mse": [],
            "l1": [],
            "r2": [],
            "l0": [],
            "dead_features_pct": [],
        }

        self._extract_r2_and_mse(history, logs)
        self._extract_sparsity_metrics(history, logs, cfg)

        return history

    def _extract_r2_and_mse(self, history: dict[str, list[float]], logs: dict[str, Any]) -> None:
        if "r2" in logs:
            history["r2"] = logs["r2"]
            history["recon_mse"] = [(1.0 - r2) for r2 in logs["r2"]]
        else:
            history["r2"] = [0.0] * len(history["loss"])

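    # Sparsity metrics come from the per-batch codes in logs["z"]: L1 is the mean absolute
    # activation, while the slower L0 and dead-feature statistics are recomputed only every
    # cfg.wandb_slow_metrics_frequency epochs (every epoch when wandb is disabled).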
    def _extract_sparsity_metrics(self, history: dict[str, list[float]], logs: dict[str, Any], cfg: SaeTrainingConfig) -> None:
        memory_efficient = cfg.memory_efficient
        num_epochs = len(history["loss"])

        if "dead_features" in logs and isinstance(logs["dead_features"], list) and len(logs["dead_features"]) == num_epochs:
            history["dead_features_pct"] = logs["dead_features"]
        else:
            history["dead_features_pct"] = [0.0] * num_epochs

        history["l0"] = [0.0] * num_epochs

        if "z" in logs and logs["z"]:
            first_z = logs["z"][0] if len(logs["z"]) > 0 else None
            is_flat_list = isinstance(first_z, torch.Tensor) if first_z is not None else False

            if is_flat_list:
                batches_per_epoch = len(logs["z"]) // num_epochs if num_epochs > 0 else len(logs["z"])

                for epoch_idx in range(num_epochs):
                    start_idx = epoch_idx * batches_per_epoch
                    end_idx = start_idx + batches_per_epoch
                    epoch_z_tensors = logs["z"][start_idx:end_idx]

                    if epoch_z_tensors:
                        l1_val = self._compute_l1(epoch_z_tensors, memory_efficient)
                        history["l1"].append(l1_val)
                    else:
                        history["l1"].append(0.0)
            else:
                n_latents = self._get_n_latents()
                slow_metrics_freq = cfg.wandb_slow_metrics_frequency if cfg.use_wandb else 1

                for epoch_idx, z_batch_list in enumerate(logs["z"]):
                    if isinstance(z_batch_list, list) and len(z_batch_list) > 0:
                        l1_val = self._compute_l1(z_batch_list, memory_efficient)
                        history["l1"].append(l1_val)

                        should_compute_slow = (epoch_idx % slow_metrics_freq == 0) or (epoch_idx == len(logs["z"]) - 1)

                        if should_compute_slow:
                            l0, dead_pct = self._compute_slow_metrics(z_batch_list, n_latents, memory_efficient)
                            if len(history["l0"]) <= epoch_idx:
                                history["l0"].extend([None] * (epoch_idx + 1 - len(history["l0"])))
                            history["l0"][epoch_idx] = l0
                            if len(history["dead_features_pct"]) <= epoch_idx:
                                history["dead_features_pct"].extend([None] * (epoch_idx + 1 - len(history["dead_features_pct"])))
                            history["dead_features_pct"][epoch_idx] = dead_pct
                        else:
                            if len(history["l0"]) <= epoch_idx:
                                history["l0"].extend([None] * (epoch_idx + 1 - len(history["l0"])))
                            history["l0"].append(None)
                            if len(history["dead_features_pct"]) <= epoch_idx:
                                history["dead_features_pct"].extend([None] * (epoch_idx + 1 - len(history["dead_features_pct"])))
                            history["dead_features_pct"].append(None)

                        if memory_efficient:
                            del z_batch_list
                            if epoch_idx % 5 == 0:
                                self._clear_memory()
                    else:
                        history["l1"].append(0.0)
        elif "z_sparsity" in logs:
            history["l1"] = logs["z_sparsity"]
        else:
            history["l1"] = [0.0] * num_epochs

        if memory_efficient and "z" in logs:
            del logs["z"]
            self._clear_memory()

    def _get_n_latents(self) -> Optional[int]:
        if hasattr(self.sae, 'context') and hasattr(self.sae.context, 'n_latents'):
            return self.sae.context.n_latents
        return None

    def _clear_memory(self) -> None:
        """Clear GPU/MPS memory cache and run garbage collection."""
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()

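    # _compute_l1 averages the mean absolute code value over an epoch's batches; _compute_slow_metrics
    # reports the mean L0 (codes with magnitude above 1e-6) and the percentage of latents that never
    # activate across the epoch. With memory_efficient=True tensors are moved to CPU first.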
    def _compute_l1(self, z_batch_list: list[torch.Tensor], memory_efficient: bool = False) -> float:
        l1_vals = []
        for z in z_batch_list:
            if isinstance(z, torch.Tensor):
                if memory_efficient and z.device.type != "cpu":
                    z_cpu = z.cpu()
                    l1_vals.append(z_cpu.abs().mean().item())
                    del z
                else:
                    l1_vals.append(z.abs().mean().item())
        return sum(l1_vals) / len(l1_vals) if l1_vals else 0.0

    def _compute_slow_metrics(self, z_batch_list: list[torch.Tensor], n_latents: Optional[int], memory_efficient: bool = False) -> tuple[float, float]:
        l0_vals = []
        all_z_epoch = []

        for z in z_batch_list:
            if isinstance(z, torch.Tensor):
                if memory_efficient and z.device.type != "cpu":
                    z_cpu = z.cpu()
                    active = (z_cpu.abs() > 1e-6).float()
                    l0_vals.append(active.sum(dim=-1).mean().item())
                    all_z_epoch.append(z_cpu)
                    del z
                else:
                    active = (z.abs() > 1e-6).float()
                    l0_vals.append(active.sum(dim=-1).mean().item())
                    all_z_epoch.append(z)

        l0 = sum(l0_vals) / len(l0_vals) if l0_vals else 0.0

        if all_z_epoch and n_latents is not None:
            z_concatenated = torch.cat(all_z_epoch, dim=0)
            feature_activity = (z_concatenated.abs() > 1e-6).any(dim=0).float()
            dead_count = (feature_activity == 0).sum().item()
            dead_features_pct = dead_count / n_latents * 100.0 if n_latents > 0 else 0.0
            if memory_efficient:
                del z_concatenated, feature_activity
        else:
            dead_features_pct = 0.0

        if memory_efficient:
            del all_z_epoch
        return l0, dead_features_pct

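    # Wandb logging replays the collected history after training: one log() call per epoch under the
    # train/ namespace, slow metrics only on the configured epochs, and final/best values written to
    # the run summary.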
    def _log_to_wandb(self, wandb_run: Any, history: dict[str, list[float]], cfg: SaeTrainingConfig) -> None:
        try:
            num_epochs = len(history["loss"])
            slow_metrics_freq = cfg.wandb_slow_metrics_frequency

            for epoch in range(1, num_epochs + 1):
                epoch_idx = epoch - 1
                should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)

                metrics = self._build_epoch_metrics(history, epoch, epoch_idx, cfg, should_log_slow)
                wandb_run.log(metrics)

            final_metrics = self._build_final_metrics(history, num_epochs)
            wandb_run.summary.update(final_metrics)

            if cfg.verbose:
                try:
                    url = wandb_run.url
                    self.logger.info(f"[SaeTrainer] Metrics logged to wandb: {url}")
                except (AttributeError, RuntimeError):
                    self.logger.info("[SaeTrainer] Metrics logged to wandb (offline mode)")
        except Exception as e:
            self.logger.warning(f"[SaeTrainer] Failed to log metrics to wandb: {e}")

    def _build_epoch_metrics(self, history: dict[str, list[float]], epoch: int, epoch_idx: int,
                             cfg: SaeTrainingConfig, should_log_slow: bool) -> dict[str, Any]:
        metrics = {
            "epoch": epoch,
            "train/loss": history["loss"][epoch_idx] if epoch_idx < len(history["loss"]) else 0.0,
            "train/reconstruction_mse": history["recon_mse"][epoch_idx] if epoch_idx < len(history["recon_mse"]) else 0.0,
            "train/r2_score": history["r2"][epoch_idx] if epoch_idx < len(history["r2"]) else 0.0,
            "train/l1_penalty": history["l1"][epoch_idx] if epoch_idx < len(history["l1"]) else 0.0,
            "train/learning_rate": cfg.lr,
        }

        if should_log_slow:
            l0_val = self._get_metric_value(history["l0"], epoch_idx)
            dead_pct = self._get_metric_value(history["dead_features_pct"], epoch_idx)
            metrics["train/l0_sparsity"] = l0_val
            metrics["train/dead_features_pct"] = dead_pct

        return metrics

    def _get_metric_value(self, values: list[float | None], idx: int) -> float:
        if idx < len(values) and values[idx] is not None:
            return values[idx]
        return self._get_last_known_value(values, idx)

    def _get_last_known_value(self, values: list[float | None], idx: int) -> float:
        for i in range(idx, -1, -1):
            if i < len(values) and values[i] is not None:
                return values[i]
        return 0.0

    def _build_final_metrics(self, history: dict[str, list[float]], num_epochs: int) -> dict[str, Any]:
        final_l0 = self._get_last_known_value(history["l0"], len(history["l0"]) - 1) if history["l0"] else 0.0
        final_dead_pct = self._get_last_known_value(history["dead_features_pct"], len(history["dead_features_pct"]) - 1) if history["dead_features_pct"] else 0.0

        final_metrics = {
            "final/loss": history["loss"][-1] if history["loss"] else 0.0,
            "final/reconstruction_mse": history["recon_mse"][-1] if history["recon_mse"] else 0.0,
            "final/r2_score": history["r2"][-1] if history["r2"] else 0.0,
            "final/l1_penalty": history["l1"][-1] if history["l1"] else 0.0,
            "final/l0_sparsity": final_l0,
            "final/dead_features_pct": final_dead_pct,
            "training/num_epochs": num_epochs,
        }

        if history["loss"]:
            best_loss_idx = min(range(len(history["loss"])), key=lambda i: history["loss"][i])
            final_metrics["best/loss"] = history["loss"][best_loss_idx]
            final_metrics["best/loss_epoch"] = best_loss_idx + 1

        if history["r2"]:
            best_r2_idx = max(range(len(history["r2"])), key=lambda i: history["r2"][i])
            final_metrics["best/r2_score"] = history["r2"][best_r2_idx]
            final_metrics["best/r2_epoch"] = best_r2_idx + 1

        return final_metrics

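    # Persisted layout under the store's runs/<training_run_id>/ directory: model.pt (engine
    # state_dict plus amber metadata such as concept weights and layer signature), history.json,
    # and run metadata recorded via store.put_run_metadata().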
    def _save_training_to_store(
        self,
        store: "Store",
        training_run_id: str,
        run_id: str,
        layer_signature: str | int,
        history: dict[str, list[float]],
        config: SaeTrainingConfig
    ) -> None:
        """Save training outputs (model, history, metadata) to store under training_run_id.

        Args:
            store: Store instance
            training_run_id: Training run ID to save under
            run_id: Original activation run ID used for training
            layer_signature: Layer signature used for training
            history: Training history dictionary
            config: Training configuration
        """
        try:
            run_path = store._run_key(training_run_id)
            run_path.mkdir(parents=True, exist_ok=True)

            model_path = run_path / "model.pt"
            history_path = run_path / "history.json"

            sae_state_dict = self.sae.sae_engine.state_dict()

            amber_metadata = {
                "concepts_state": {
                    'multiplication': self.sae.concepts.multiplication.data.cpu().clone(),
                    'bias': self.sae.concepts.bias.data.cpu().clone(),
                },
                "n_latents": self.sae.context.n_latents,
                "n_inputs": self.sae.context.n_inputs,
                "device": self.sae.context.device,
                "layer_signature": self.sae.context.lm_layer_signature,
                "model_id": self.sae.context.model_id,
            }

            if hasattr(self.sae, 'k'):
                amber_metadata["k"] = self.sae.k

            payload = {
                "sae_state_dict": sae_state_dict,
                "amber_metadata": amber_metadata,
            }

            torch.save(payload, model_path)

            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            training_metadata = {
                "training_run_id": training_run_id,
                "activation_run_id": run_id,
                "layer_signature": str(layer_signature),
                "model_id": self.sae.context.model_id if hasattr(self.sae.context, 'model_id') else None,
                "sae_type": type(self.sae).__name__,
                "training_config": {
                    "epochs": config.epochs,
                    "batch_size": config.batch_size,
                    "lr": config.lr,
                    "l1_lambda": config.l1_lambda,
                    "device": str(config.device),
                    "dtype": str(config.dtype) if config.dtype else None,
                    "use_amp": config.use_amp,
                    "clip_grad": config.clip_grad,
                    "monitoring": config.monitoring,
                },
                "final_metrics": {
                    "loss": history["loss"][-1] if history["loss"] else None,
                    "r2": history["r2"][-1] if history["r2"] else None,
                    "recon_mse": history["recon_mse"][-1] if history["recon_mse"] else None,
                    "l1": history["l1"][-1] if history["l1"] else None,
                },
                "n_epochs": len(history["loss"]),
                "timestamp": datetime.now().isoformat(),
            }

            if history["l0"]:
                final_l0 = [x for x in history["l0"] if x is not None]
                if final_l0:
                    training_metadata["final_metrics"]["l0"] = final_l0[-1]

            if history["dead_features_pct"]:
                final_dead = [x for x in history["dead_features_pct"] if x is not None]
                if final_dead:
                    training_metadata["final_metrics"]["dead_features_pct"] = final_dead[-1]

            store.put_run_metadata(training_run_id, training_metadata)

            if config.verbose:
                self.logger.info(f"[SaeTrainer] Saved model to: {model_path}")
                self.logger.info(f"[SaeTrainer] Saved history to: {history_path}")
                self.logger.info(f"[SaeTrainer] Saved metadata to: runs/{training_run_id}/meta.json")

        except Exception as e:
            self.logger.warning(f"[SaeTrainer] Failed to save training outputs to store: {e}")
            if config.verbose:
                import traceback
                self.logger.warning(f"[SaeTrainer] Traceback: {traceback.format_exc()}")