mi-crow 1.0.0__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,12 +3,20 @@
  from dataclasses import dataclass
  from typing import Any, Optional, TYPE_CHECKING
  from datetime import datetime
+ from pathlib import Path
  import json
  import logging
  import gc
+ import os

  import torch

+ try:
+     from dotenv import load_dotenv
+     DOTENV_AVAILABLE = True
+ except ImportError:
+     DOTENV_AVAILABLE = False
+
  from mi_crow.store.store_dataloader import StoreDataloader
  from mi_crow.utils import get_logger

@@ -35,7 +43,6 @@ class SaeTrainingConfig:
      monitoring: int = 1  # 0=silent, 1=basic, 2=detailed (overcomplete parameter)
      scheduler: Optional[Any] = None  # Learning rate scheduler (overcomplete parameter)
      max_nan_fallbacks: int = 5  # For train_sae_amp (overcomplete parameter)
-     # Wandb configuration
      use_wandb: bool = False  # Enable wandb logging
      wandb_project: Optional[str] = None  # Wandb project name (defaults to "sae-training" if not set)
      wandb_entity: Optional[str] = None  # Wandb entity/team name
@@ -44,7 +51,10 @@ class SaeTrainingConfig:
      wandb_config: Optional[dict[str, Any]] = None  # Additional config to log to wandb
      wandb_mode: str = "online"  # Wandb mode: "online", "offline", or "disabled"
      wandb_slow_metrics_frequency: int = 50  # Log slow metrics (L0, dead features) every N epochs (default: 50)
+     wandb_api_key: Optional[str] = None  # Wandb API key (can also be set via WANDB_API_KEY env var)
      memory_efficient: bool = False  # Enable memory-efficient processing (moves tensors to CPU, clears cache)
+     snapshot_every_n_epochs: Optional[int] = None  # Save model snapshot every N epochs (None = disabled)
+     snapshot_base_path: Optional[str] = None  # Base path for snapshots (defaults to training_run_id/snapshots)


  class SaeTrainer:
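
A minimal sketch of how the new configuration fields introduced in this release might be used; the import path is an assumption, not something shown in this diff:

    from mi_crow.sae.sae_trainer import SaeTrainingConfig  # hypothetical import path

    # Enable wandb logging and periodic snapshots via the fields added in 1.0.0.post1
    cfg = SaeTrainingConfig(
        use_wandb=True,
        wandb_project="sae-training",
        wandb_api_key=None,           # None -> falls back to WANDB_API_KEY from the environment / .env
        snapshot_every_n_epochs=10,   # save a model snapshot every 10 epochs
        snapshot_base_path=None,      # None -> <training_run_id>/snapshots under the store
    )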
@@ -77,27 +87,73 @@ class SaeTrainer:
          cfg = config or SaeTrainingConfig()

          wandb_run = self._initialize_wandb(cfg, run_id)
+         if wandb_run is not None and cfg.verbose:
+             try:
+                 self.logger.info(f"[SaeTrainer] Wandb run initialized: {wandb_run.name} (id: {wandb_run.id})")
+                 self.logger.info(f"[SaeTrainer] Wandb project: {wandb_run.project}, entity: {wandb_run.entity}")
+                 self.logger.info(f"[SaeTrainer] Wandb mode: {wandb_run.mode}")
+             except Exception:
+                 pass
          device = self._setup_device(cfg)
          optimizer = self._create_optimizer(cfg)
          criterion = self._create_criterion(cfg)
          dataloader = self._create_dataloader(store, run_id, layer_signature, cfg, device)
          monitoring = self._configure_logging(cfg, run_id)

-         logs = self._run_training(cfg, dataloader, criterion, optimizer, device, monitoring)
+         if training_run_id is None:
+             training_run_id = f"sae_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+         logs = self._run_training(cfg, dataloader, criterion, optimizer, device, monitoring, store, training_run_id, run_id, layer_signature)
          history = self._process_training_logs(logs, cfg)
+
+         if cfg.verbose and history.get("loss"):
+             num_epochs = len(history["loss"])
+             for epoch_idx in range(num_epochs):
+                 epoch = epoch_idx + 1
+                 loss = history["loss"][epoch_idx]
+                 recon_mse = history["recon_mse"][epoch_idx] if epoch_idx < len(history["recon_mse"]) else None
+                 l1_val = history["l1"][epoch_idx] if epoch_idx < len(history["l1"]) else None
+                 r2_val = history["r2"][epoch_idx] if epoch_idx < len(history["r2"]) else None
+                 self.logger.info(
+                     "[SaeTrainer] Epoch %d: loss=%.6f recon_mse=%s l1=%s r2=%s",
+                     epoch,
+                     float(loss),
+                     f"{float(recon_mse):.6f}" if recon_mse is not None else "n/a",
+                     f"{float(l1_val):.6f}" if l1_val is not None else "n/a",
+                     f"{float(r2_val):.6f}" if r2_val is not None else "n/a",
+                 )
+
          if cfg.memory_efficient:
              self._clear_memory()

+         wandb_url = None
          if wandb_run is not None:
+             if cfg.verbose:
+                 self.logger.info(f"[SaeTrainer] Logging metrics to wandb (history keys: {list(history.keys())}, num epochs: {len(history.get('loss', []))})")
              self._log_to_wandb(wandb_run, history, cfg)
+             try:
+                 wandb_url = wandb_run.url
+                 if cfg.verbose:
+                     self.logger.info(f"[SaeTrainer] Wandb run URL: {wandb_url}")
+             except (AttributeError, RuntimeError) as e:
+                 if cfg.verbose:
+                     self.logger.warning(f"[SaeTrainer] Could not get wandb URL: {e}")
+
+             # Ensure wandb run is finished and synced
+             try:
+                 import wandb
+                 if wandb.run is not None:
+                     wandb.finish()
+                     if cfg.verbose:
+                         self.logger.info("[SaeTrainer] Wandb run finished and synced")
+             except Exception as e:
+                 if cfg.verbose:
+                     self.logger.warning(f"[SaeTrainer] Error finishing wandb run: {e}")

          if cfg.verbose:
              self.logger.info(f"[SaeTrainer] Training completed, processing {len(history['loss'])} epochs of results")
              self.logger.info("[SaeTrainer] Completed training")

-         if training_run_id is None:
-             training_run_id = f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-
          if cfg.verbose:
              self.logger.info(f"[SaeTrainer] Saving training outputs to store/runs/{training_run_id}/")

@@ -112,7 +168,8 @@ class SaeTrainer:

          return {
              "history": history,
-             "training_run_id": training_run_id
+             "training_run_id": training_run_id,
+             "wandb_url": wandb_url
          }

      def _ensure_overcomplete_available(self) -> None:
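
The training result now exposes the wandb run URL alongside the training run id. A hedged sketch of consuming the returned mapping; the trainer's public entry point and its arguments are assumptions, not shown in this hunk:

    # Hypothetical call; the method name and arguments are illustrative only.
    result = trainer.train(store=store, run_id=run_id, layer_signature=layer_signature, config=cfg)

    history = result["history"]                  # per-epoch metric lists
    training_run_id = result["training_run_id"]  # auto-generated as sae_<timestamp> when not provided
    wandb_url = result["wandb_url"]              # None if wandb is disabled or the URL is unavailable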
@@ -122,35 +179,111 @@ class SaeTrainer:
              raise ImportError("overcomplete.sae.train module not available. Cannot use overcomplete training.")

      def _initialize_wandb(self, cfg: SaeTrainingConfig, run_id: str) -> Any:
+         """Initialize wandb run if enabled in config."""
          if not cfg.use_wandb:
              return None

          try:
              import wandb
-             wandb_project = cfg.wandb_project or "sae-training"
+             import os
+
+             # Load .env file if available
+             if DOTENV_AVAILABLE:
+                 # Try to find .env file in project root
+                 current_dir = Path(__file__).parent
+                 project_root = current_dir
+                 # Walk up to find project root (pyproject.toml or .git)
+                 for _ in range(10):  # Limit depth
+                     if (project_root / "pyproject.toml").exists() or (project_root / ".git").exists():
+                         break
+                     if project_root == project_root.parent:
+                         break
+                     project_root = project_root.parent
+
+                 env_file = project_root / ".env"
+                 if env_file.exists():
+                     load_dotenv(env_file, override=True)
+
+             # Get API key from config, environment, or .env (already loaded)
+             wandb_api_key = getattr(cfg, 'wandb_api_key', None) or os.getenv('WANDB_API_KEY')
+             wandb_project = cfg.wandb_project or os.getenv('WANDB_PROJECT') or os.getenv('SERVER_WANDB_PROJECT') or "sae-training"
              wandb_name = cfg.wandb_name or run_id
              wandb_mode = cfg.wandb_mode.lower() if cfg.wandb_mode else "online"
+
+             # Login with API key before init if available
+             if wandb_api_key:
+                 # Clean API key (strip whitespace and quotes)
+                 wandb_api_key = wandb_api_key.strip().strip('"').strip("'")
+                 os.environ['WANDB_API_KEY'] = wandb_api_key
+                 # Login with API key before init
+                 try:
+                     # Ensure API key is in environment
+                     os.environ['WANDB_API_KEY'] = wandb_api_key
+                     wandb.login(key=wandb_api_key, relogin=True)
+                 except Exception as login_error:
+                     self.logger.warning(f"[SaeTrainer] Wandb login failed: {login_error}")
+                     # Fall back to offline mode if login fails
+                     wandb_mode = "offline"

-             return wandb.init(
-                 project=wandb_project,
-                 entity=cfg.wandb_entity,
-                 name=wandb_name,
-                 mode=wandb_mode,
-                 config={
-                     "run_id": run_id,
-                     "epochs": cfg.epochs,
-                     "batch_size": cfg.batch_size,
-                     "lr": cfg.lr,
-                     "l1_lambda": cfg.l1_lambda,
-                     "device": str(cfg.device),
-                     "dtype": str(cfg.dtype) if cfg.dtype else None,
-                     "use_amp": cfg.use_amp,
-                     "clip_grad": cfg.clip_grad,
-                     "max_batches_per_epoch": cfg.max_batches_per_epoch,
-                     **(cfg.wandb_config or {}),
-                 },
-                 tags=cfg.wandb_tags or [],
-             )
+             try:
+                 # Ensure WANDB_API_KEY is set in environment before init
+                 if wandb_api_key:
+                     os.environ['WANDB_API_KEY'] = wandb_api_key
+                 wandb_run = wandb.init(
+                     project=wandb_project,
+                     entity=cfg.wandb_entity,
+                     name=wandb_name,
+                     mode=wandb_mode,
+                     settings=wandb.Settings(_disable_stats=True) if wandb_mode == "offline" else None,
+                     config={
+                         "run_id": run_id,
+                         "epochs": cfg.epochs,
+                         "batch_size": cfg.batch_size,
+                         "lr": cfg.lr,
+                         "l1_lambda": cfg.l1_lambda,
+                         "device": str(cfg.device),
+                         "dtype": str(cfg.dtype) if cfg.dtype else None,
+                         "use_amp": cfg.use_amp,
+                         "clip_grad": cfg.clip_grad,
+                         "max_batches_per_epoch": cfg.max_batches_per_epoch,
+                         **(cfg.wandb_config or {}),
+                     },
+                     tags=cfg.wandb_tags or [],
+                 )
+             except Exception as init_error:
+                 # If init fails with auth error (401), try offline mode
+                 if "401" in str(init_error) or "PERMISSION_ERROR" in str(init_error) or "not logged in" in str(init_error).lower():
+                     self.logger.warning(f"[SaeTrainer] Wandb init failed with auth error, retrying in offline mode: {init_error}")
+                     try:
+                         wandb_run = wandb.init(
+                             project=wandb_project,
+                             entity=cfg.wandb_entity,
+                             name=wandb_name,
+                             mode="offline",
+                             config={
+                                 "run_id": run_id,
+                                 "epochs": cfg.epochs,
+                                 "batch_size": cfg.batch_size,
+                                 "lr": cfg.lr,
+                                 "l1_lambda": cfg.l1_lambda,
+                                 "device": str(cfg.device),
+                                 "dtype": str(cfg.dtype) if cfg.dtype else None,
+                                 "use_amp": cfg.use_amp,
+                                 "clip_grad": cfg.clip_grad,
+                                 "max_batches_per_epoch": cfg.max_batches_per_epoch,
+                                 **(cfg.wandb_config or {}),
+                             },
+                             tags=cfg.wandb_tags or [],
+                         )
+                         self.logger.info("[SaeTrainer] Wandb initialized in offline mode - sync later with: wandb sync")
+                     except Exception as offline_error:
+                         self.logger.warning(f"[SaeTrainer] Wandb init failed even in offline mode: {offline_error}")
+                         return None
+                 else:
+                     # Re-raise if it's not an auth error
+                     raise
+
+             return wandb_run
          except ImportError:
              self.logger.warning("[SaeTrainer] wandb not installed, skipping wandb logging")
              self.logger.warning("[SaeTrainer] Install with: pip install wandb")
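
To illustrate the credential resolution implemented above: the explicit cfg.wandb_api_key wins; otherwise WANDB_API_KEY is read from the environment, into which a project-root .env (found by walking up to pyproject.toml or .git) has already been loaded with override=True. A rough sketch, assuming python-dotenv is installed; the path is illustrative:

    import os
    from pathlib import Path
    from dotenv import load_dotenv

    # Values from the .env file override the existing environment (override=True).
    load_dotenv(Path("/path/to/project/.env"), override=True)

    api_key = os.getenv("WANDB_API_KEY")  # later passed to wandb.login(key=..., relogin=True)
    project = os.getenv("WANDB_PROJECT") or os.getenv("SERVER_WANDB_PROJECT") or "sae-training"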
@@ -222,12 +355,31 @@ class SaeTrainer:

          return monitoring

-     def _run_training(self, cfg: SaeTrainingConfig, dataloader: StoreDataloader, criterion, optimizer: torch.optim.Optimizer,
-                       device: torch.device, monitoring: int) -> dict[str, Any]:
+     def _run_training(
+         self,
+         cfg: SaeTrainingConfig,
+         dataloader: StoreDataloader,
+         criterion,
+         optimizer: torch.optim.Optimizer,
+         device: torch.device,
+         monitoring: int,
+         store: "Store",
+         training_run_id: str,
+         run_id: str,
+         layer_signature: str | int
+     ) -> dict[str, Any]:
          from overcomplete.sae.train import train_sae, train_sae_amp

          device_str = str(device)

+         should_save_snapshots = cfg.snapshot_every_n_epochs is not None and cfg.snapshot_every_n_epochs > 0
+
+         if should_save_snapshots:
+             return self._run_training_with_snapshots(
+                 cfg, dataloader, criterion, optimizer, device, monitoring,
+                 store, training_run_id, run_id, layer_signature
+             )
+
          try:
              if cfg.use_amp and device.type in ("cuda", "cpu"):
                  if cfg.verbose:
@@ -270,6 +422,173 @@ class SaeTrainer:
              self.logger.error(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
              raise

+     def _run_training_with_snapshots(
+         self,
+         cfg: SaeTrainingConfig,
+         dataloader: StoreDataloader,
+         criterion,
+         optimizer: torch.optim.Optimizer,
+         device: torch.device,
+         monitoring: int,
+         store: "Store",
+         training_run_id: str,
+         run_id: str,
+         layer_signature: str | int
+     ) -> dict[str, Any]:
+         """
+         Run training epoch by epoch, saving snapshots every N epochs.
+
+         This method runs training in a loop, one epoch at a time, to enable
+         snapshot saving between epochs.
+         """
+         from overcomplete.sae.train import train_sae, train_sae_amp
+
+         device_str = str(device)
+         snapshot_freq = cfg.snapshot_every_n_epochs
+
+         if cfg.verbose:
+             self.logger.info(f"[SaeTrainer] Running training with snapshots every {snapshot_freq} epochs")
+
+         all_logs: dict[str, Any] = {}
+
+         try:
+             for epoch in range(1, cfg.epochs + 1):
+                 if cfg.verbose:
+                     self.logger.info(f"[SaeTrainer] Training epoch {epoch}/{cfg.epochs}")
+
+                 if cfg.use_amp and device.type in ("cuda", "cpu"):
+                     epoch_logs = train_sae_amp(
+                         model=self.sae.sae_engine,
+                         dataloader=dataloader,
+                         criterion=criterion,
+                         optimizer=optimizer,
+                         scheduler=cfg.scheduler,
+                         nb_epochs=1,
+                         clip_grad=cfg.clip_grad,
+                         monitoring=monitoring,
+                         device=device_str,
+                         max_nan_fallbacks=cfg.max_nan_fallbacks
+                     )
+                 else:
+                     epoch_logs = train_sae(
+                         model=self.sae.sae_engine,
+                         dataloader=dataloader,
+                         criterion=criterion,
+                         optimizer=optimizer,
+                         scheduler=cfg.scheduler,
+                         nb_epochs=1,
+                         clip_grad=cfg.clip_grad,
+                         monitoring=monitoring,
+                         device=device_str
+                     )
+
+                 self._merge_epoch_logs(all_logs, epoch_logs)
+
+                 if epoch % snapshot_freq == 0 or epoch == cfg.epochs:
+                     self._save_snapshot(store, training_run_id, run_id, layer_signature, epoch, cfg)
+
+                 if cfg.memory_efficient and epoch % 5 == 0:
+                     self._clear_memory()
+
+             if cfg.verbose:
+                 self.logger.info(
+                     f"[SaeTrainer] Training with snapshots completed, processing {len(all_logs.get('avg_loss', []))} epoch results...")
+
+             return all_logs
+         except Exception as e:
+             self.logger.error(f"[SaeTrainer] Error during training with snapshots: {e}")
+             import traceback
+             self.logger.error(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
+             raise
+
+     def _merge_epoch_logs(self, all_logs: dict[str, Any], epoch_logs: dict[str, Any]) -> None:
+         """
+         Merge single-epoch logs into accumulated logs.
+
+         Handles all fields that overcomplete's training functions may return,
+         including avg_loss, r2, z, and dead_features.
+
+         Args:
+             all_logs: Accumulated logs dictionary to update
+             epoch_logs: Single epoch logs from overcomplete training
+         """
+         for key, value in epoch_logs.items():
+             if key not in all_logs:
+                 all_logs[key] = []
+
+             if isinstance(value, list):
+                 all_logs[key].extend(value)
+             else:
+                 all_logs[key].append(value)
+
+     def _save_snapshot(
+         self,
+         store: "Store",
+         training_run_id: str,
+         run_id: str,
+         layer_signature: str | int,
+         epoch: int,
+         config: SaeTrainingConfig
+     ) -> None:
+         """
+         Save a snapshot of the current model state.
+
+         Args:
+             store: Store instance
+             training_run_id: Training run ID
+             run_id: Original activation run ID
+             layer_signature: Layer signature
+             epoch: Current epoch number
+             config: Training configuration
+         """
+         try:
+             snapshot_base = config.snapshot_base_path
+             if snapshot_base is None:
+                 run_path = store._run_key(training_run_id)
+                 snapshot_dir = run_path / "snapshots"
+             else:
+                 snapshot_dir = Path(snapshot_base)
+
+             snapshot_dir.mkdir(parents=True, exist_ok=True)
+
+             snapshot_path = snapshot_dir / f"model_epoch_{epoch}.pt"
+
+             sae_state_dict = self.sae.sae_engine.state_dict()
+
+             mi_crow_metadata = {
+                 "concepts_state": {
+                     'multiplication': self.sae.concepts.multiplication.data.cpu().clone(),
+                     'bias': self.sae.concepts.bias.data.cpu().clone(),
+                 },
+                 "n_latents": self.sae.context.n_latents,
+                 "n_inputs": self.sae.context.n_inputs,
+                 "device": self.sae.context.device,
+                 "layer_signature": self.sae.context.lm_layer_signature,
+                 "model_id": self.sae.context.model_id,
+             }
+
+             if hasattr(self.sae, 'k'):
+                 mi_crow_metadata["k"] = self.sae.k
+
+             payload = {
+                 "sae_state_dict": sae_state_dict,
+                 "mi_crow_metadata": mi_crow_metadata,
+                 "epoch": epoch,
+                 "training_run_id": training_run_id,
+                 "activation_run_id": run_id,
+                 "layer_signature": str(layer_signature),
+             }
+
+             torch.save(payload, snapshot_path)
+
+             if config.verbose:
+                 self.logger.info(f"[SaeTrainer] Saved snapshot to: {snapshot_path}")
+         except Exception as e:
+             self.logger.warning(f"[SaeTrainer] Failed to save snapshot at epoch {epoch}: {e}")
+             if config.verbose:
+                 import traceback
+                 self.logger.warning(f"[SaeTrainer] Traceback: {traceback.format_exc()}")
+
      def _process_training_logs(self, logs: dict[str, Any], cfg: SaeTrainingConfig) -> dict[str, list[float]]:
          history: dict[str, list[float]] = {
              "loss": logs.get("avg_loss", []),
@@ -422,18 +741,30 @@

      def _log_to_wandb(self, wandb_run: Any, history: dict[str, list[float]], cfg: SaeTrainingConfig) -> None:
          try:
+             if not history.get("loss"):
+                 self.logger.warning("[SaeTrainer] No loss data in history, skipping wandb logging")
+                 return
+
              num_epochs = len(history["loss"])
              slow_metrics_freq = cfg.wandb_slow_metrics_frequency
+
+             if cfg.verbose:
+                 self.logger.info(f"[SaeTrainer] Logging {num_epochs} epochs to wandb")

              for epoch in range(1, num_epochs + 1):
                  epoch_idx = epoch - 1
                  should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)

                  metrics = self._build_epoch_metrics(history, epoch, epoch_idx, cfg, should_log_slow)
+                 if cfg.verbose:
+                     self.logger.info(f"[SaeTrainer] Logging epoch {epoch} metrics to wandb: {list(metrics.keys())}")
                  wandb_run.log(metrics)

              final_metrics = self._build_final_metrics(history, num_epochs)
              wandb_run.summary.update(final_metrics)
+
+             if cfg.verbose:
+                 self.logger.info(f"[SaeTrainer] Updated wandb summary with final metrics: {list(final_metrics.keys())}")

              if cfg.verbose:
                  try:
@@ -443,6 +774,8 @@
                      self.logger.info("[SaeTrainer] Metrics logged to wandb (offline mode)")
          except Exception as e:
              self.logger.warning(f"[SaeTrainer] Failed to log metrics to wandb: {e}")
+             import traceback
+             self.logger.warning(f"[SaeTrainer] Traceback: {traceback.format_exc()}")

      def _build_epoch_metrics(self, history: dict[str, list[float]], epoch: int, epoch_idx: int,
                               cfg: SaeTrainingConfig, should_log_slow: bool) -> dict[str, Any]:
@@ -36,7 +36,9 @@ class LocalStore(Store):

      def put_tensor(self, key: str, tensor: torch.Tensor) -> None:
          path = self._full(key)
-         storch.save_file({"tensor": tensor}, str(path))
+         tensor_copy = tensor.clone().detach()
+         storch.save_file({"tensor": tensor_copy}, str(path))
+         del tensor_copy

      def get_tensor(self, key: str) -> torch.Tensor:
          loaded = storch.load_file(str(self._full(key)))
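
The clone().detach() introduced here presumably hands safetensors an independent copy, since its save_file rejects non-contiguous tensors (for example views sliced out of a larger activation buffer). A small sketch of the pattern; the tensor shapes are illustrative:

    import torch
    from safetensors.torch import save_file

    activations = torch.randn(8, 16)
    view = activations[:, :4]  # non-contiguous view sharing storage with the full buffer

    # Saving a detached clone mirrors the change in this release and sidesteps
    # safetensors' non-contiguous / shared-storage checks.
    save_file({"tensor": view.clone().detach()}, "tensor.safetensors")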
@@ -248,7 +250,9 @@ class LocalStore(Store):
          tensor_filename = f"{tensor_key}.safetensors"
          tensor_path = layer_dir / tensor_filename
          try:
-             storch.save_file({"tensor": tensor}, str(tensor_path))
+             tensor_copy = tensor.clone().detach()
+             storch.save_file({"tensor": tensor_copy}, str(tensor_path))
+             del tensor_copy
          except Exception as e:
              raise OSError(
                  f"Failed to save tensor at {tensor_path} for run_id={run_id!r}, "
@@ -423,15 +427,17 @@ class LocalStore(Store):
              existing = {}

          batch_key = f"batch_{batch_index}"
-         existing[batch_key] = tensor
+         existing[batch_key] = tensor.clone().detach()

          try:
              storch.save_file(existing, str(tensor_path))
+             del existing[batch_key]
          except Exception as e:
+             tensor_shape = tuple(tensor.shape) if hasattr(tensor, 'shape') else 'unknown'
              raise OSError(
                  f"Failed to save unified tensor at {tensor_path} for run_id={run_id!r}, "
-                 f"layer={layer_signature!r}, key={tensor_key!r}, batch_index={batch_index}. "
-                 f"Error: {e}"
+                 f"layer={layer_signature!r}, key={tensor_key!r}, batch_index={batch_index}, "
+                 f"shape={tensor_shape}. Error: {e}"
              ) from e

          return f"{self.runs_prefix}/{run_id}/detectors"
mi_crow/store/store.py CHANGED
@@ -4,9 +4,42 @@ import abc
  from pathlib import Path
  from typing import Dict, Any, List, Iterator

- import torch
+ # #region agent log
+ import json
+ import sys
+ import os
+ from pathlib import Path

+ _debug_log_path = Path('/mnt/evafs/groups/mi2lab/akaniasty/Mi-Crow/.cursor/debug.log')
+ if _debug_log_path.parent.exists():
+     torch_search_paths = [p for p in sys.path if 'torch' in p.lower()]
+     try:
+         with open(_debug_log_path, 'a') as f:
+             f.write(json.dumps({"sessionId":"debug-session","runId":"pre-fix","hypothesisId":"A,B,C,D,E","location":"store.py:7","message":"Before torch import","data":{"sys_path":sys.path[:5],"torch_in_path":torch_search_paths},"timestamp":__import__('time').time()*1000}) + '\n')
+     except (OSError, IOError):
+         pass
+ # #endregion
+ import torch
+ # #region agent log
+ if _debug_log_path.parent.exists():
+     torch_file = getattr(torch, '__file__', None)
+     torch_path = getattr(torch, '__path__', None)
+     torch_loader = str(getattr(torch, '__loader__', None))
+     try:
+         with open(_debug_log_path, 'a') as f:
+             f.write(json.dumps({"sessionId":"debug-session","runId":"pre-fix","hypothesisId":"A,B,C,D,E","location":"store.py:8","message":"After torch import","data":{"torch_type":str(type(torch)),"torch_file":torch_file,"torch_path":str(torch_path) if torch_path else None,"torch_loader":torch_loader[:100],"has_tensor":hasattr(torch,'Tensor'),"torch_dir_count":len(dir(torch))},"timestamp":__import__('time').time()*1000}) + '\n')
+     except (OSError, IOError):
+         pass
+ # #endregion

+ # #region agent log
+ if _debug_log_path.parent.exists():
+     try:
+         with open(_debug_log_path, 'a') as f:
+             f.write(json.dumps({"sessionId":"debug-session","runId":"pre-fix","hypothesisId":"A,B,C,D,E","location":"store.py:10","message":"Before TensorMetadata definition","data":{"torch_has_tensor":hasattr(torch,'Tensor'),"torch_attrs":str([x for x in dir(torch) if 'Tensor' in x or 'tensor' in x.lower()][:10])},"timestamp":__import__('time').time()*1000}) + '\n')
+     except (OSError, IOError):
+         pass
+ # #endregion
  TensorMetadata = Dict[str, Dict[str, torch.Tensor]]


@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mi-crow
- Version: 1.0.0
+ Version: 1.0.0.post1
  Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
  Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
  Requires-Python: >=3.10
@@ -19,6 +19,7 @@ Requires-Dist: wandb>=0.22.1
  Requires-Dist: pytest>=8.4.2
  Requires-Dist: pytest-xdist>=3.8.0
  Requires-Dist: seaborn>=0.13.2
+ Requires-Dist: numpy<2.0,>=1.20.0
  Provides-Extra: dev
  Requires-Dist: pre-commit>=4.3.0; extra == "dev"
  Requires-Dist: ruff>=0.13.2; extra == "dev"