mi_crow-1.0.0-py3-none-any.whl → mi_crow-1.0.0.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mi_crow/datasets/base_dataset.py +71 -1
- mi_crow/datasets/classification_dataset.py +136 -30
- mi_crow/datasets/text_dataset.py +165 -24
- mi_crow/hooks/controller.py +12 -7
- mi_crow/hooks/implementations/layer_activation_detector.py +30 -34
- mi_crow/hooks/implementations/model_input_detector.py +87 -87
- mi_crow/hooks/implementations/model_output_detector.py +43 -42
- mi_crow/hooks/utils.py +74 -0
- mi_crow/language_model/activations.py +174 -77
- mi_crow/language_model/device_manager.py +119 -0
- mi_crow/language_model/inference.py +18 -5
- mi_crow/language_model/initialization.py +10 -6
- mi_crow/language_model/language_model.py +67 -97
- mi_crow/language_model/layers.py +16 -13
- mi_crow/language_model/persistence.py +4 -2
- mi_crow/language_model/utils.py +5 -5
- mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py +157 -95
- mi_crow/mechanistic/sae/concepts/concept_dictionary.py +12 -2
- mi_crow/mechanistic/sae/concepts/text_heap.py +161 -0
- mi_crow/mechanistic/sae/modules/topk_sae.py +29 -22
- mi_crow/mechanistic/sae/sae.py +3 -1
- mi_crow/mechanistic/sae/sae_trainer.py +362 -29
- mi_crow/store/local_store.py +11 -5
- mi_crow/store/store.py +34 -1
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA +2 -1
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/RECORD +28 -26
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/WHEEL +0 -0
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/top_level.txt +0 -0
mi_crow/mechanistic/sae/sae_trainer.py
CHANGED

@@ -3,12 +3,20 @@
 from dataclasses import dataclass
 from typing import Any, Optional, TYPE_CHECKING
 from datetime import datetime
+from pathlib import Path
 import json
 import logging
 import gc
+import os

 import torch

+try:
+    from dotenv import load_dotenv
+    DOTENV_AVAILABLE = True
+except ImportError:
+    DOTENV_AVAILABLE = False
+
 from mi_crow.store.store_dataloader import StoreDataloader
 from mi_crow.utils import get_logger

@@ -35,7 +43,6 @@ class SaeTrainingConfig:
     monitoring: int = 1  # 0=silent, 1=basic, 2=detailed (overcomplete parameter)
     scheduler: Optional[Any] = None  # Learning rate scheduler (overcomplete parameter)
     max_nan_fallbacks: int = 5  # For train_sae_amp (overcomplete parameter)
-    # Wandb configuration
     use_wandb: bool = False  # Enable wandb logging
     wandb_project: Optional[str] = None  # Wandb project name (defaults to "sae-training" if not set)
     wandb_entity: Optional[str] = None  # Wandb entity/team name
@@ -44,7 +51,10 @@ class SaeTrainingConfig:
     wandb_config: Optional[dict[str, Any]] = None  # Additional config to log to wandb
     wandb_mode: str = "online"  # Wandb mode: "online", "offline", or "disabled"
     wandb_slow_metrics_frequency: int = 50  # Log slow metrics (L0, dead features) every N epochs (default: 50)
+    wandb_api_key: Optional[str] = None  # Wandb API key (can also be set via WANDB_API_KEY env var)
     memory_efficient: bool = False  # Enable memory-efficient processing (moves tensors to CPU, clears cache)
+    snapshot_every_n_epochs: Optional[int] = None  # Save model snapshot every N epochs (None = disabled)
+    snapshot_base_path: Optional[str] = None  # Base path for snapshots (defaults to training_run_id/snapshots)


 class SaeTrainer:
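For orientation, the three new knobs combine like any other `SaeTrainingConfig` field. A minimal sketch, assuming the remaining dataclass fields keep their defaults (only the field names and defaults above come from the diff):

```python
# Hypothetical usage of the fields added in 1.0.0.post1.
from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig

cfg = SaeTrainingConfig(
    use_wandb=True,
    wandb_api_key=None,           # None -> fall back to WANDB_API_KEY / .env
    snapshot_every_n_epochs=10,   # write a snapshot every 10 epochs (None disables)
    snapshot_base_path=None,      # None -> <training_run_id>/snapshots under the store
)
```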
@@ -77,27 +87,73 @@ class SaeTrainer:
         cfg = config or SaeTrainingConfig()

         wandb_run = self._initialize_wandb(cfg, run_id)
+        if wandb_run is not None and cfg.verbose:
+            try:
+                self.logger.info(f"[SaeTrainer] Wandb run initialized: {wandb_run.name} (id: {wandb_run.id})")
+                self.logger.info(f"[SaeTrainer] Wandb project: {wandb_run.project}, entity: {wandb_run.entity}")
+                self.logger.info(f"[SaeTrainer] Wandb mode: {wandb_run.mode}")
+            except Exception:
+                pass
         device = self._setup_device(cfg)
         optimizer = self._create_optimizer(cfg)
         criterion = self._create_criterion(cfg)
         dataloader = self._create_dataloader(store, run_id, layer_signature, cfg, device)
         monitoring = self._configure_logging(cfg, run_id)

-
+        if training_run_id is None:
+            training_run_id = f"sae_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+        logs = self._run_training(cfg, dataloader, criterion, optimizer, device, monitoring, store, training_run_id, run_id, layer_signature)
         history = self._process_training_logs(logs, cfg)
+
+        if cfg.verbose and history.get("loss"):
+            num_epochs = len(history["loss"])
+            for epoch_idx in range(num_epochs):
+                epoch = epoch_idx + 1
+                loss = history["loss"][epoch_idx]
+                recon_mse = history["recon_mse"][epoch_idx] if epoch_idx < len(history["recon_mse"]) else None
+                l1_val = history["l1"][epoch_idx] if epoch_idx < len(history["l1"]) else None
+                r2_val = history["r2"][epoch_idx] if epoch_idx < len(history["r2"]) else None
+                self.logger.info(
+                    "[SaeTrainer] Epoch %d: loss=%.6f recon_mse=%s l1=%s r2=%s",
+                    epoch,
+                    float(loss),
+                    f"{float(recon_mse):.6f}" if recon_mse is not None else "n/a",
+                    f"{float(l1_val):.6f}" if l1_val is not None else "n/a",
+                    f"{float(r2_val):.6f}" if r2_val is not None else "n/a",
+                )
+
         if cfg.memory_efficient:
             self._clear_memory()

+        wandb_url = None
         if wandb_run is not None:
+            if cfg.verbose:
+                self.logger.info(f"[SaeTrainer] Logging metrics to wandb (history keys: {list(history.keys())}, num epochs: {len(history.get('loss', []))})")
             self._log_to_wandb(wandb_run, history, cfg)
+            try:
+                wandb_url = wandb_run.url
+                if cfg.verbose:
+                    self.logger.info(f"[SaeTrainer] Wandb run URL: {wandb_url}")
+            except (AttributeError, RuntimeError) as e:
+                if cfg.verbose:
+                    self.logger.warning(f"[SaeTrainer] Could not get wandb URL: {e}")
+
+            # Ensure wandb run is finished and synced
+            try:
+                import wandb
+                if wandb.run is not None:
+                    wandb.finish()
+                    if cfg.verbose:
+                        self.logger.info("[SaeTrainer] Wandb run finished and synced")
+            except Exception as e:
+                if cfg.verbose:
+                    self.logger.warning(f"[SaeTrainer] Error finishing wandb run: {e}")

         if cfg.verbose:
             self.logger.info(f"[SaeTrainer] Training completed, processing {len(history['loss'])} epochs of results")
             self.logger.info("[SaeTrainer] Completed training")

-        if training_run_id is None:
-            training_run_id = f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-
         if cfg.verbose:
             self.logger.info(f"[SaeTrainer] Saving training outputs to store/runs/{training_run_id}/")

@@ -112,7 +168,8 @@ class SaeTrainer:

         return {
             "history": history,
-            "training_run_id": training_run_id
+            "training_run_id": training_run_id,
+            "wandb_url": wandb_url
         }

     def _ensure_overcomplete_available(self) -> None:
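Callers can now recover the wandb URL from the result; a hedged sketch of consuming the new return shape (the `train(...)` entry-point name and arguments are assumptions, only the returned keys are confirmed by the hunk above):

```python
# Hypothetical call site; "history", "training_run_id" and "wandb_url"
# are the keys returned per the diff above.
result = trainer.train(store=store, run_id=run_id, layer_signature=layer_signature, config=cfg)

history = result["history"]                      # per-epoch metric lists (loss, recon_mse, ...)
print("outputs under:", f"store/runs/{result['training_run_id']}/")
if result["wandb_url"] is not None:
    print("wandb run:", result["wandb_url"])
```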
@@ -122,35 +179,111 @@ class SaeTrainer:
             raise ImportError("overcomplete.sae.train module not available. Cannot use overcomplete training.")

     def _initialize_wandb(self, cfg: SaeTrainingConfig, run_id: str) -> Any:
+        """Initialize wandb run if enabled in config."""
         if not cfg.use_wandb:
             return None

         try:
             import wandb
-
+            import os
+
+            # Load .env file if available
+            if DOTENV_AVAILABLE:
+                # Try to find .env file in project root
+                current_dir = Path(__file__).parent
+                project_root = current_dir
+                # Walk up to find project root (pyproject.toml or .git)
+                for _ in range(10):  # Limit depth
+                    if (project_root / "pyproject.toml").exists() or (project_root / ".git").exists():
+                        break
+                    if project_root == project_root.parent:
+                        break
+                    project_root = project_root.parent
+
+                env_file = project_root / ".env"
+                if env_file.exists():
+                    load_dotenv(env_file, override=True)
+
+            # Get API key from config, environment, or .env (already loaded)
+            wandb_api_key = getattr(cfg, 'wandb_api_key', None) or os.getenv('WANDB_API_KEY')
+            wandb_project = cfg.wandb_project or os.getenv('WANDB_PROJECT') or os.getenv('SERVER_WANDB_PROJECT') or "sae-training"
             wandb_name = cfg.wandb_name or run_id
             wandb_mode = cfg.wandb_mode.lower() if cfg.wandb_mode else "online"
+
+            # Login with API key before init if available
+            if wandb_api_key:
+                # Clean API key (strip whitespace and quotes)
+                wandb_api_key = wandb_api_key.strip().strip('"').strip("'")
+                os.environ['WANDB_API_KEY'] = wandb_api_key
+                # Login with API key before init
+                try:
+                    # Ensure API key is in environment
+                    os.environ['WANDB_API_KEY'] = wandb_api_key
+                    wandb.login(key=wandb_api_key, relogin=True)
+                except Exception as login_error:
+                    self.logger.warning(f"[SaeTrainer] Wandb login failed: {login_error}")
+                    # Fall back to offline mode if login fails
+                    wandb_mode = "offline"

-
-
-
-
-
-
-
-
-
-            "
-
-
-
-
-
-
-
-
-
+            try:
+                # Ensure WANDB_API_KEY is set in environment before init
+                if wandb_api_key:
+                    os.environ['WANDB_API_KEY'] = wandb_api_key
+                wandb_run = wandb.init(
+                    project=wandb_project,
+                    entity=cfg.wandb_entity,
+                    name=wandb_name,
+                    mode=wandb_mode,
+                    settings=wandb.Settings(_disable_stats=True) if wandb_mode == "offline" else None,
+                    config={
+                        "run_id": run_id,
+                        "epochs": cfg.epochs,
+                        "batch_size": cfg.batch_size,
+                        "lr": cfg.lr,
+                        "l1_lambda": cfg.l1_lambda,
+                        "device": str(cfg.device),
+                        "dtype": str(cfg.dtype) if cfg.dtype else None,
+                        "use_amp": cfg.use_amp,
+                        "clip_grad": cfg.clip_grad,
+                        "max_batches_per_epoch": cfg.max_batches_per_epoch,
+                        **(cfg.wandb_config or {}),
+                    },
+                    tags=cfg.wandb_tags or [],
+                )
+            except Exception as init_error:
+                # If init fails with auth error (401), try offline mode
+                if "401" in str(init_error) or "PERMISSION_ERROR" in str(init_error) or "not logged in" in str(init_error).lower():
+                    self.logger.warning(f"[SaeTrainer] Wandb init failed with auth error, retrying in offline mode: {init_error}")
+                    try:
+                        wandb_run = wandb.init(
+                            project=wandb_project,
+                            entity=cfg.wandb_entity,
+                            name=wandb_name,
+                            mode="offline",
+                            config={
+                                "run_id": run_id,
+                                "epochs": cfg.epochs,
+                                "batch_size": cfg.batch_size,
+                                "lr": cfg.lr,
+                                "l1_lambda": cfg.l1_lambda,
+                                "device": str(cfg.device),
+                                "dtype": str(cfg.dtype) if cfg.dtype else None,
+                                "use_amp": cfg.use_amp,
+                                "clip_grad": cfg.clip_grad,
+                                "max_batches_per_epoch": cfg.max_batches_per_epoch,
+                                **(cfg.wandb_config or {}),
+                            },
+                            tags=cfg.wandb_tags or [],
+                        )
+                        self.logger.info("[SaeTrainer] Wandb initialized in offline mode - sync later with: wandb sync")
+                    except Exception as offline_error:
+                        self.logger.warning(f"[SaeTrainer] Wandb init failed even in offline mode: {offline_error}")
+                        return None
+                else:
+                    # Re-raise if it's not an auth error
+                    raise
+
+            return wandb_run
         except ImportError:
             self.logger.warning("[SaeTrainer] wandb not installed, skipping wandb logging")
             self.logger.warning("[SaeTrainer] Install with: pip install wandb")
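The credential resolution above follows a fixed precedence: explicit `cfg.wandb_api_key` first, then the `WANDB_API_KEY` environment variable, which `load_dotenv()` may have populated from a `.env` file found next to `pyproject.toml` or `.git`. A minimal restatement of that order (the helper name is hypothetical):

```python
import os

def resolve_wandb_api_key(cfg_value: str | None) -> str | None:
    # Mirrors _initialize_wandb: the config value wins, then the environment
    # (which load_dotenv(env_file, override=True) may already have filled in).
    key = cfg_value or os.getenv("WANDB_API_KEY")
    if key:
        key = key.strip().strip('"').strip("'")  # tolerate quoted .env values
    return key or None
```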
@@ -222,12 +355,31 @@ class SaeTrainer:

         return monitoring

-    def _run_training(
-
+    def _run_training(
+        self,
+        cfg: SaeTrainingConfig,
+        dataloader: StoreDataloader,
+        criterion,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        monitoring: int,
+        store: "Store",
+        training_run_id: str,
+        run_id: str,
+        layer_signature: str | int
+    ) -> dict[str, Any]:
         from overcomplete.sae.train import train_sae, train_sae_amp

         device_str = str(device)

+        should_save_snapshots = cfg.snapshot_every_n_epochs is not None and cfg.snapshot_every_n_epochs > 0
+
+        if should_save_snapshots:
+            return self._run_training_with_snapshots(
+                cfg, dataloader, criterion, optimizer, device, monitoring,
+                store, training_run_id, run_id, layer_signature
+            )
+
         try:
             if cfg.use_amp and device.type in ("cuda", "cpu"):
                 if cfg.verbose:
@@ -270,6 +422,173 @@ class SaeTrainer:
             self.logger.error(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
             raise

+    def _run_training_with_snapshots(
+        self,
+        cfg: SaeTrainingConfig,
+        dataloader: StoreDataloader,
+        criterion,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        monitoring: int,
+        store: "Store",
+        training_run_id: str,
+        run_id: str,
+        layer_signature: str | int
+    ) -> dict[str, Any]:
+        """
+        Run training epoch by epoch, saving snapshots every N epochs.
+
+        This method runs training in a loop, one epoch at a time, to enable
+        snapshot saving between epochs.
+        """
+        from overcomplete.sae.train import train_sae, train_sae_amp
+
+        device_str = str(device)
+        snapshot_freq = cfg.snapshot_every_n_epochs
+
+        if cfg.verbose:
+            self.logger.info(f"[SaeTrainer] Running training with snapshots every {snapshot_freq} epochs")
+
+        all_logs: dict[str, Any] = {}
+
+        try:
+            for epoch in range(1, cfg.epochs + 1):
+                if cfg.verbose:
+                    self.logger.info(f"[SaeTrainer] Training epoch {epoch}/{cfg.epochs}")
+
+                if cfg.use_amp and device.type in ("cuda", "cpu"):
+                    epoch_logs = train_sae_amp(
+                        model=self.sae.sae_engine,
+                        dataloader=dataloader,
+                        criterion=criterion,
+                        optimizer=optimizer,
+                        scheduler=cfg.scheduler,
+                        nb_epochs=1,
+                        clip_grad=cfg.clip_grad,
+                        monitoring=monitoring,
+                        device=device_str,
+                        max_nan_fallbacks=cfg.max_nan_fallbacks
+                    )
+                else:
+                    epoch_logs = train_sae(
+                        model=self.sae.sae_engine,
+                        dataloader=dataloader,
+                        criterion=criterion,
+                        optimizer=optimizer,
+                        scheduler=cfg.scheduler,
+                        nb_epochs=1,
+                        clip_grad=cfg.clip_grad,
+                        monitoring=monitoring,
+                        device=device_str
+                    )
+
+                self._merge_epoch_logs(all_logs, epoch_logs)
+
+                if epoch % snapshot_freq == 0 or epoch == cfg.epochs:
+                    self._save_snapshot(store, training_run_id, run_id, layer_signature, epoch, cfg)
+
+                if cfg.memory_efficient and epoch % 5 == 0:
+                    self._clear_memory()
+
+            if cfg.verbose:
+                self.logger.info(
+                    f"[SaeTrainer] Training with snapshots completed, processing {len(all_logs.get('avg_loss', []))} epoch results...")
+
+            return all_logs
+        except Exception as e:
+            self.logger.error(f"[SaeTrainer] Error during training with snapshots: {e}")
+            import traceback
+            self.logger.error(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
+            raise
+
+    def _merge_epoch_logs(self, all_logs: dict[str, Any], epoch_logs: dict[str, Any]) -> None:
+        """
+        Merge single-epoch logs into accumulated logs.
+
+        Handles all fields that overcomplete's training functions may return,
+        including avg_loss, r2, z, and dead_features.
+
+        Args:
+            all_logs: Accumulated logs dictionary to update
+            epoch_logs: Single epoch logs from overcomplete training
+        """
+        for key, value in epoch_logs.items():
+            if key not in all_logs:
+                all_logs[key] = []
+
+            if isinstance(value, list):
+                all_logs[key].extend(value)
+            else:
+                all_logs[key].append(value)
+
+    def _save_snapshot(
+        self,
+        store: "Store",
+        training_run_id: str,
+        run_id: str,
+        layer_signature: str | int,
+        epoch: int,
+        config: SaeTrainingConfig
+    ) -> None:
+        """
+        Save a snapshot of the current model state.
+
+        Args:
+            store: Store instance
+            training_run_id: Training run ID
+            run_id: Original activation run ID
+            layer_signature: Layer signature
+            epoch: Current epoch number
+            config: Training configuration
+        """
+        try:
+            snapshot_base = config.snapshot_base_path
+            if snapshot_base is None:
+                run_path = store._run_key(training_run_id)
+                snapshot_dir = run_path / "snapshots"
+            else:
+                snapshot_dir = Path(snapshot_base)
+
+            snapshot_dir.mkdir(parents=True, exist_ok=True)
+
+            snapshot_path = snapshot_dir / f"model_epoch_{epoch}.pt"
+
+            sae_state_dict = self.sae.sae_engine.state_dict()
+
+            mi_crow_metadata = {
+                "concepts_state": {
+                    'multiplication': self.sae.concepts.multiplication.data.cpu().clone(),
+                    'bias': self.sae.concepts.bias.data.cpu().clone(),
+                },
+                "n_latents": self.sae.context.n_latents,
+                "n_inputs": self.sae.context.n_inputs,
+                "device": self.sae.context.device,
+                "layer_signature": self.sae.context.lm_layer_signature,
+                "model_id": self.sae.context.model_id,
+            }
+
+            if hasattr(self.sae, 'k'):
+                mi_crow_metadata["k"] = self.sae.k
+
+            payload = {
+                "sae_state_dict": sae_state_dict,
+                "mi_crow_metadata": mi_crow_metadata,
+                "epoch": epoch,
+                "training_run_id": training_run_id,
+                "activation_run_id": run_id,
+                "layer_signature": str(layer_signature),
+            }
+
+            torch.save(payload, snapshot_path)
+
+            if config.verbose:
+                self.logger.info(f"[SaeTrainer] Saved snapshot to: {snapshot_path}")
+        except Exception as e:
+            self.logger.warning(f"[SaeTrainer] Failed to save snapshot at epoch {epoch}: {e}")
+            if config.verbose:
+                import traceback
+                self.logger.warning(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
+
     def _process_training_logs(self, logs: dict[str, Any], cfg: SaeTrainingConfig) -> dict[str, list[float]]:
         history: dict[str, list[float]] = {
             "loss": logs.get("avg_loss", []),
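Since `_save_snapshot` writes a plain `torch.save` payload, snapshots can be inspected without the trainer; a minimal sketch (the path is illustrative, the payload keys come from the method above):

```python
import torch

# Layout per _save_snapshot: <snapshot_dir>/model_epoch_<N>.pt
payload = torch.load("snapshots/model_epoch_10.pt", map_location="cpu")

sae_state = payload["sae_state_dict"]      # weights for sae.sae_engine
meta = payload["mi_crow_metadata"]         # concepts_state, n_latents, n_inputs, ...
print(payload["epoch"], payload["training_run_id"], payload["layer_signature"])
```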
@@ -422,18 +741,30 @@ class SaeTrainer:

     def _log_to_wandb(self, wandb_run: Any, history: dict[str, list[float]], cfg: SaeTrainingConfig) -> None:
         try:
+            if not history.get("loss"):
+                self.logger.warning("[SaeTrainer] No loss data in history, skipping wandb logging")
+                return
+
             num_epochs = len(history["loss"])
             slow_metrics_freq = cfg.wandb_slow_metrics_frequency
+
+            if cfg.verbose:
+                self.logger.info(f"[SaeTrainer] Logging {num_epochs} epochs to wandb")

             for epoch in range(1, num_epochs + 1):
                 epoch_idx = epoch - 1
                 should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)

                 metrics = self._build_epoch_metrics(history, epoch, epoch_idx, cfg, should_log_slow)
+                if cfg.verbose:
+                    self.logger.info(f"[SaeTrainer] Logging epoch {epoch} metrics to wandb: {list(metrics.keys())}")
                 wandb_run.log(metrics)

             final_metrics = self._build_final_metrics(history, num_epochs)
             wandb_run.summary.update(final_metrics)
+
+            if cfg.verbose:
+                self.logger.info(f"[SaeTrainer] Updated wandb summary with final metrics: {list(final_metrics.keys())}")

             if cfg.verbose:
                 try:
@@ -443,6 +774,8 @@ class SaeTrainer:
                 self.logger.info("[SaeTrainer] Metrics logged to wandb (offline mode)")
         except Exception as e:
             self.logger.warning(f"[SaeTrainer] Failed to log metrics to wandb: {e}")
+            import traceback
+            self.logger.warning(f"[SaeTrainer] Traceback: {traceback.format_exc()}")

     def _build_epoch_metrics(self, history: dict[str, list[float]], epoch: int, epoch_idx: int,
                              cfg: SaeTrainingConfig, should_log_slow: bool) -> dict[str, Any]:
mi_crow/store/local_store.py
CHANGED
@@ -36,7 +36,9 @@ class LocalStore(Store):

     def put_tensor(self, key: str, tensor: torch.Tensor) -> None:
         path = self._full(key)
-
+        tensor_copy = tensor.clone().detach()
+        storch.save_file({"tensor": tensor_copy}, str(path))
+        del tensor_copy

     def get_tensor(self, key: str) -> torch.Tensor:
         loaded = storch.load_file(str(self._full(key)))
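The new `clone().detach()` step presumably exists because safetensors refuses to serialize tensors that share storage (and detaching drops any autograd history before the tensor hits disk); a standalone illustration of the failure mode it avoids (file name illustrative):

```python
import torch
from safetensors.torch import save_file  # the module the diff aliases as `storch`

buf = torch.randn(2, 8)
a, b = buf[0], buf[1]                         # views sharing buf's storage
# save_file({"a": a, "b": b}, "t.safetensors")   # raises: safetensors rejects tensors that share memory
save_file({"a": a.clone().detach(), "b": b.clone().detach()}, "t.safetensors")
```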
@@ -248,7 +250,9 @@ class LocalStore(Store):
         tensor_filename = f"{tensor_key}.safetensors"
         tensor_path = layer_dir / tensor_filename
         try:
-
+            tensor_copy = tensor.clone().detach()
+            storch.save_file({"tensor": tensor_copy}, str(tensor_path))
+            del tensor_copy
         except Exception as e:
             raise OSError(
                 f"Failed to save tensor at {tensor_path} for run_id={run_id!r}, "
@@ -423,15 +427,17 @@ class LocalStore(Store):
         existing = {}

         batch_key = f"batch_{batch_index}"
-        existing[batch_key] = tensor
+        existing[batch_key] = tensor.clone().detach()

         try:
             storch.save_file(existing, str(tensor_path))
+            del existing[batch_key]
         except Exception as e:
+            tensor_shape = tuple(tensor.shape) if hasattr(tensor, 'shape') else 'unknown'
             raise OSError(
                 f"Failed to save unified tensor at {tensor_path} for run_id={run_id!r}, "
-                f"layer={layer_signature!r}, key={tensor_key!r}, batch_index={batch_index}. "
-                f"Error: {e}"
+                f"layer={layer_signature!r}, key={tensor_key!r}, batch_index={batch_index}, "
+                f"shape={tensor_shape}. Error: {e}"
             ) from e

         return f"{self.runs_prefix}/{run_id}/detectors"
mi_crow/store/store.py
CHANGED
@@ -4,9 +4,42 @@ import abc
 from pathlib import Path
 from typing import Dict, Any, List, Iterator

-
+# #region agent log
+import json
+import sys
+import os
+from pathlib import Path

+_debug_log_path = Path('/mnt/evafs/groups/mi2lab/akaniasty/Mi-Crow/.cursor/debug.log')
+if _debug_log_path.parent.exists():
+    torch_search_paths = [p for p in sys.path if 'torch' in p.lower()]
+    try:
+        with open(_debug_log_path, 'a') as f:
+            f.write(json.dumps({"sessionId":"debug-session","runId":"pre-fix","hypothesisId":"A,B,C,D,E","location":"store.py:7","message":"Before torch import","data":{"sys_path":sys.path[:5],"torch_in_path":torch_search_paths},"timestamp":__import__('time').time()*1000}) + '\n')
+    except (OSError, IOError):
+        pass
+# #endregion
+import torch
+# #region agent log
+if _debug_log_path.parent.exists():
+    torch_file = getattr(torch, '__file__', None)
+    torch_path = getattr(torch, '__path__', None)
+    torch_loader = str(getattr(torch, '__loader__', None))
+    try:
+        with open(_debug_log_path, 'a') as f:
+            f.write(json.dumps({"sessionId":"debug-session","runId":"pre-fix","hypothesisId":"A,B,C,D,E","location":"store.py:8","message":"After torch import","data":{"torch_type":str(type(torch)),"torch_file":torch_file,"torch_path":str(torch_path) if torch_path else None,"torch_loader":torch_loader[:100],"has_tensor":hasattr(torch,'Tensor'),"torch_dir_count":len(dir(torch))},"timestamp":__import__('time').time()*1000}) + '\n')
+    except (OSError, IOError):
+        pass
+# #endregion

+# #region agent log
+if _debug_log_path.parent.exists():
+    try:
+        with open(_debug_log_path, 'a') as f:
+            f.write(json.dumps({"sessionId":"debug-session","runId":"pre-fix","hypothesisId":"A,B,C,D,E","location":"store.py:10","message":"Before TensorMetadata definition","data":{"torch_has_tensor":hasattr(torch,'Tensor'),"torch_attrs":str([x for x in dir(torch) if 'Tensor' in x or 'tensor' in x.lower()][:10])},"timestamp":__import__('time').time()*1000}) + '\n')
+    except (OSError, IOError):
+        pass
+# #endregion
 TensorMetadata = Dict[str, Dict[str, torch.Tensor]]


{mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mi-crow
-Version: 1.0.0
+Version: 1.0.0.post1
 Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
 Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
 Requires-Python: >=3.10
@@ -19,6 +19,7 @@ Requires-Dist: wandb>=0.22.1
 Requires-Dist: pytest>=8.4.2
 Requires-Dist: pytest-xdist>=3.8.0
 Requires-Dist: seaborn>=0.13.2
+Requires-Dist: numpy<2.0,>=1.20.0
 Provides-Extra: dev
 Requires-Dist: pre-commit>=4.3.0; extra == "dev"
 Requires-Dist: ruff>=0.13.2; extra == "dev"