mi-crow 0.1.1.post12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
@@ -0,0 +1,222 @@
1
+ """Wandb logging utilities for SAE training."""
2
+
3
+ from typing import Any, Optional
4
+
5
+ from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
6
+ from amber.utils import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
+ class WandbLogger:
12
+ """
13
+ Handles wandb logging for SAE training.
14
+
15
+ Encapsulates all wandb-related operations including initialization,
16
+ metric logging, and summary updates.
17
+ """
18
+
19
+ def __init__(self, config: SaeTrainingConfig, run_id: str):
20
+ """
21
+ Initialize WandbLogger.
22
+
23
+ Args:
24
+ config: Training configuration
25
+ run_id: Training run identifier
26
+ """
27
+ self.config = config
28
+ self.run_id = run_id
29
+ self.wandb_run: Optional[Any] = None
30
+ self._initialized = False
31
+
32
+ def initialize(self) -> bool:
33
+ """
34
+ Initialize wandb run if enabled in config.
35
+
36
+ Returns:
37
+ True if wandb was successfully initialized, False otherwise
38
+ """
39
+ if not self.config.use_wandb:
40
+ return False
41
+
42
+ try:
43
+ import wandb
44
+ except ImportError:
45
+ logger.warning("[WandbLogger] wandb not installed, skipping wandb logging")
46
+ logger.warning("[WandbLogger] Install with: pip install wandb")
47
+ return False
48
+
49
+ try:
50
+ wandb_project = self.config.wandb_project or "sae-training"
51
+ wandb_name = self.config.wandb_name or self.run_id
52
+ wandb_mode = self.config.wandb_mode.lower() if self.config.wandb_mode else "online"
53
+
54
+ self.wandb_run = wandb.init(
55
+ project=wandb_project,
56
+ entity=self.config.wandb_entity,
57
+ name=wandb_name,
58
+ mode=wandb_mode,
59
+ config=self._build_wandb_config(),
60
+ tags=self.config.wandb_tags or [],
61
+ )
62
+ self._initialized = True
63
+ return True
64
+ except Exception as e:
65
+ logger.warning(f"[WandbLogger] Unexpected error initializing wandb: {e}")
66
+ logger.warning("[WandbLogger] Continuing training without wandb logging")
67
+ return False
68
+
69
+ def _build_wandb_config(self) -> dict[str, Any]:
70
+ """Build wandb configuration dictionary."""
71
+ return {
72
+ "run_id": self.run_id,
73
+ "epochs": self.config.epochs,
74
+ "batch_size": self.config.batch_size,
75
+ "lr": self.config.lr,
76
+ "l1_lambda": self.config.l1_lambda,
77
+ "device": str(self.config.device),
78
+ "dtype": str(self.config.dtype) if self.config.dtype else None,
79
+ "use_amp": self.config.use_amp,
80
+ "clip_grad": self.config.clip_grad,
81
+ "max_batches_per_epoch": self.config.max_batches_per_epoch,
82
+ **(self.config.wandb_config or {}),
83
+ }
84
+
85
+ def log_metrics(
86
+ self,
87
+ history: dict[str, list[float | None]],
88
+ verbose: bool = False
89
+ ) -> None:
90
+ """
91
+ Log training metrics to wandb.
92
+
93
+ Args:
94
+ history: Dictionary with training history (loss, r2, l1, l0, etc.)
95
+ verbose: Whether to log verbose information
96
+ """
97
+ if not self._initialized or self.wandb_run is None:
98
+ return
99
+
100
+ try:
101
+ num_epochs = len(history.get("loss", []))
102
+ slow_metrics_freq = self.config.wandb_slow_metrics_frequency
103
+
104
+ # Helper to get last known value for slow metrics
105
+ def get_last_known_value(values: list[float | None], idx: int) -> float:
106
+ """Get the last non-None value up to idx, or 0.0 if none found."""
107
+ for i in range(idx, -1, -1):
108
+ if i < len(values) and values[i] is not None:
109
+ return values[i]
110
+ return 0.0
111
+
112
+ # Log metrics for each epoch
113
+ for epoch in range(1, num_epochs + 1):
114
+ epoch_idx = epoch - 1
115
+ should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)
116
+
117
+ metrics = self._build_epoch_metrics(history, epoch_idx, should_log_slow, get_last_known_value)
118
+ self.wandb_run.log(metrics)
119
+
120
+ # Log final summary metrics
121
+ self._log_summary_metrics(history, get_last_known_value)
122
+
123
+ if verbose:
124
+ self._log_wandb_url()
125
+
126
+ except Exception as e:
127
+ logger.warning(f"[WandbLogger] Failed to log metrics to wandb: {e}")
128
+
129
+ def _build_epoch_metrics(
130
+ self,
131
+ history: dict[str, list[float | None]],
132
+ epoch_idx: int,
133
+ should_log_slow: bool,
134
+ get_last_known_value: Any
135
+ ) -> dict[str, Any]:
136
+ """Build metrics dictionary for a single epoch."""
137
+ metrics = {
138
+ "epoch": epoch_idx + 1,
139
+ "train/loss": self._safe_get(history["loss"], epoch_idx, 0.0),
140
+ "train/reconstruction_mse": self._safe_get(history["recon_mse"], epoch_idx, 0.0),
141
+ "train/r2_score": self._safe_get(history["r2"], epoch_idx, 0.0),
142
+ "train/l1_penalty": self._safe_get(history["l1"], epoch_idx, 0.0),
143
+ "train/learning_rate": self.config.lr,
144
+ }
145
+
146
+ # Add slow metrics if computed this epoch
147
+ if should_log_slow:
148
+ l0_val = self._get_slow_metric(history["l0"], epoch_idx, get_last_known_value)
149
+ dead_pct = self._get_slow_metric(history["dead_features_pct"], epoch_idx, get_last_known_value)
150
+ metrics["train/l0_sparsity"] = l0_val
151
+ metrics["train/dead_features_pct"] = dead_pct
152
+
153
+ return metrics
154
+
155
+ def _get_slow_metric(
156
+ self,
157
+ values: list[float | None],
158
+ epoch_idx: int,
159
+ get_last_known_value: Any
160
+ ) -> float:
161
+ """Get slow metric value, using last known value if current is None."""
162
+ if epoch_idx < len(values) and values[epoch_idx] is not None:
163
+ return values[epoch_idx]
164
+ return get_last_known_value(values, epoch_idx)
165
+
166
+ def _safe_get(self, values: list[float | None], idx: int, default: float) -> float:
167
+ """Safely get value from list, returning default if out of bounds."""
168
+ if idx < len(values) and values[idx] is not None:
169
+ return values[idx]
170
+ return default
171
+
172
+ def _log_summary_metrics(
173
+ self,
174
+ history: dict[str, list[float | None]],
175
+ get_last_known_value: Any
176
+ ) -> None:
177
+ """Log final summary metrics to wandb."""
178
+ if self.wandb_run is None:
179
+ return
180
+
181
+ # Get last computed values for slow metrics
182
+ final_l0 = get_last_known_value(history["l0"], len(history["l0"]) - 1) if history.get("l0") else 0.0
183
+ final_dead_pct = get_last_known_value(
184
+ history["dead_features_pct"],
185
+ len(history["dead_features_pct"]) - 1
186
+ ) if history.get("dead_features_pct") else 0.0
187
+
188
+ final_metrics = {
189
+ "final/loss": history["loss"][-1] if history.get("loss") else 0.0,
190
+ "final/reconstruction_mse": history["recon_mse"][-1] if history.get("recon_mse") else 0.0,
191
+ "final/r2_score": history["r2"][-1] if history.get("r2") else 0.0,
192
+ "final/l1_penalty": history["l1"][-1] if history.get("l1") else 0.0,
193
+ "final/l0_sparsity": final_l0,
194
+ "final/dead_features_pct": final_dead_pct,
195
+ "training/num_epochs": len(history.get("loss", [])),
196
+ }
197
+
198
+ # Add best metrics
199
+ if history.get("loss"):
200
+ best_loss_idx = min(range(len(history["loss"])), key=lambda i: history["loss"][i] or float('inf'))
201
+ final_metrics["best/loss"] = history["loss"][best_loss_idx] or 0.0
202
+ final_metrics["best/loss_epoch"] = best_loss_idx + 1
203
+
204
+ if history.get("r2"):
205
+ best_r2_idx = max(range(len(history["r2"])), key=lambda i: history["r2"][i] or -float('inf'))
206
+ final_metrics["best/r2_score"] = history["r2"][best_r2_idx] or 0.0
207
+ final_metrics["best/r2_epoch"] = best_r2_idx + 1
208
+
209
+ self.wandb_run.summary.update(final_metrics)
210
+
211
+ def _log_wandb_url(self) -> None:
212
+ """Log wandb run URL if available."""
213
+ if self.wandb_run is None:
214
+ return
215
+
216
+ try:
217
+ url = self.wandb_run.url
218
+ logger.info(f"[WandbLogger] Metrics logged to wandb: {url}")
219
+ except (AttributeError, RuntimeError):
220
+ # Offline mode or URL not available
221
+ logger.info("[WandbLogger] Metrics logged to wandb (offline mode)")
222
+
@@ -0,0 +1,5 @@
1
+ from amber.store.store import Store
2
+ from amber.store.local_store import LocalStore
3
+
4
+ __all__ = ["Store", "LocalStore"]
5
+
@@ -0,0 +1,437 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, Any, List
4
+ import shutil
5
+
6
+ import torch
7
+
8
+ from amber.store.store import Store, TensorMetadata
9
+ import safetensors.torch as storch
10
+
11
+
12
class LocalStore(Store):
    """Local filesystem implementation of Store interface.

    Layout under ``base_path`` (as produced by the methods below):

    - ``{runs_prefix}/{run_id}/batch_{i:06d}.safetensors`` — run batches
    - ``{runs_prefix}/{run_id}/batch_{i}/metadata.json`` — per-batch detector metadata
    - ``{runs_prefix}/{run_id}/batch_{i}/{layer}/{key}.safetensors`` — detector tensors
    - ``{runs_prefix}/{run_id}/detectors/...`` — run-level (unified) detector data
    - ``{runs_prefix}/{run_id}/meta.json`` — run metadata

    NOTE(review): ``_run_metadata_key``, ``_run_batch_key`` and ``_run_key`` are
    inherited from ``Store`` and are assumed to return ``Path`` objects matching
    the layout above — confirm against the base class.
    """

    def __init__(
        self,
        base_path: Path | str = '',
        runs_prefix: str = "runs",
        dataset_prefix: str = "datasets",
        model_prefix: str = "models",
    ):
        """Initialize LocalStore.

        Args:
            base_path: Base directory path for the store
            runs_prefix: Prefix for runs directory
            dataset_prefix: Prefix for datasets directory
            model_prefix: Prefix for models directory
        """
        super().__init__(base_path, runs_prefix, dataset_prefix, model_prefix)

    def _full(self, key: str) -> Path:
        """Resolve `key` to an absolute path and create its parent directory.

        NOTE(review): the mkdir side effect also runs on read paths
        (get_tensor) — presumably harmless, but verify this is intended.
        """
        p = self.base_path / key
        p.parent.mkdir(parents=True, exist_ok=True)
        return p

    def put_tensor(self, key: str, tensor: torch.Tensor) -> None:
        """Save a single tensor at `key` as a safetensors file with key "tensor"."""
        path = self._full(key)
        storch.save_file({"tensor": tensor}, str(path))

    def get_tensor(self, key: str) -> torch.Tensor:
        """Load the tensor stored at `key` (expects the "tensor" entry written by put_tensor)."""
        loaded = storch.load_file(str(self._full(key)))
        return loaded["tensor"]

    def _validate_run_id(self, run_id: str) -> None:
        """Raise ValueError unless `run_id` is a non-empty, non-blank string."""
        if not run_id or not isinstance(run_id, str) or not run_id.strip():
            raise ValueError(f"run_id must be a non-empty string, got: {run_id!r}")

    def _validate_batch_index(self, batch_index: int) -> None:
        """Raise ValueError if `batch_index` is negative."""
        if batch_index < 0:
            raise ValueError(f"batch_index must be non-negative, got: {batch_index}")

    def _validate_layer_key(self, layer: str, key: str) -> None:
        """Raise ValueError unless both `layer` and `key` are non-empty strings."""
        if not layer or not isinstance(layer, str) or not layer.strip():
            raise ValueError(f"layer must be a non-empty string, got: {layer!r}")
        if not key or not isinstance(key, str) or not key.strip():
            raise ValueError(f"key must be a non-empty string, got: {key!r}")

    def _ensure_directory(self, path: Path) -> None:
        """Create `path` (and parents) if it does not already exist."""
        path.mkdir(parents=True, exist_ok=True)

    def put_run_batch(self, run_id: str, batch_index: int,
                      tensors: List[torch.Tensor] | Dict[str, torch.Tensor]) -> str:
        """Persist one batch of tensors for a run.

        Lists are encoded as {"item_0": ..., "item_1": ...}; an empty list is
        encoded as the sentinel {"_empty_list": empty tensor} so get_run_batch
        can round-trip it. Returns the store-relative key of the written file.

        NOTE(review): any other `tensors` type silently writes an empty file
        ({} branch) — confirm this is intentional rather than a TypeError.
        """
        if isinstance(tensors, dict):
            to_save = tensors
        elif isinstance(tensors, list):
            if len(tensors) == 0:
                to_save = {"_empty_list": torch.tensor([])}
            else:
                to_save = {f"item_{i}": t for i, t in enumerate(tensors)}
        else:
            to_save = {}
        batch_path = self.base_path / f"{self.runs_prefix}/{run_id}/batch_{batch_index:06d}.safetensors"
        self._ensure_directory(batch_path.parent)
        storch.save_file(to_save, str(batch_path))
        return f"{self.runs_prefix}/{run_id}/batch_{batch_index:06d}.safetensors"

    def get_run_batch(self, run_id: str, batch_index: int) -> List[torch.Tensor] | Dict[str, torch.Tensor]:
        """Load one batch of tensors for a run.

        Tries two on-disk formats in order:
        1. The flat file written by put_run_batch (decoding the "_empty_list"
           sentinel and contiguous "item_N" keys back into a list).
        2. A detector-style directory batch_{i}/ with per-layer
           activations.safetensors files (written by put_detector_metadata).

        Raises:
            FileNotFoundError: if neither format exists for this batch.
        """
        batch_path = self.base_path / f"{self.runs_prefix}/{run_id}/batch_{batch_index:06d}.safetensors"
        if batch_path.exists():
            loaded = storch.load_file(str(batch_path))
            keys = list(loaded.keys())
            if keys == ["_empty_list"]:
                return []
            if keys and all(k.startswith("item_") for k in keys):
                try:
                    # Reconstruct the original list order from the item_N suffixes;
                    # only return a list when the indices are exactly 0..len-1.
                    items = sorted(((int(k.split("_", 1)[1]), v) for k, v in loaded.items()), key=lambda x: x[0])
                    if [i for i, _ in items] == list(range(len(items))):
                        return [v for _, v in items]
                except Exception:
                    pass
            # Fall back to returning the raw mapping as stored.
            return loaded

        detector_base = self.base_path / self.runs_prefix / run_id / f"batch_{batch_index}"
        if detector_base.exists():
            result: Dict[str, torch.Tensor] = {}

            layer_dirs = [d for d in detector_base.iterdir() if d.is_dir()]
            for layer_dir in layer_dirs:
                activations_path = layer_dir / "activations.safetensors"
                if activations_path.exists():
                    try:
                        loaded_tensor = storch.load_file(str(activations_path))["tensor"]
                        # Use layer_signature as key, or "activations" if only one layer
                        layer_sig = layer_dir.name
                        if len(layer_dirs) == 1:
                            # Only one layer, use simple "activations" key for compatibility
                            result["activations"] = loaded_tensor
                        else:
                            # Multiple layers, use layer-specific key
                            result[f"activations_{layer_sig}"] = loaded_tensor
                    except Exception:
                        # Best-effort: skip unreadable layer files.
                        pass

            if result:
                return result

        # If neither exists, raise FileNotFoundError
        raise FileNotFoundError(f"Batch {batch_index} not found for run {run_id}")

    def list_run_batches(self, run_id: str) -> List[int]:
        """Return the sorted batch indices present for a run.

        Scans both storage formats: zero-padded batch_NNNNNN.safetensors files
        and unpadded batch_N/ detector directories; indices are deduplicated.
        Returns [] when the run directory does not exist.
        """
        base = self.base_path / self.runs_prefix / run_id
        if not base.exists():
            return []
        out: set[int] = set()

        # Flat batch files: index is the 6-digit zero-padded segment after "batch_".
        for p in sorted(base.glob("batch_*.safetensors")):
            name = p.name
            try:
                idx = int(name[len("batch_"): len("batch_") + 6])
                out.add(idx)
            except Exception:
                continue

        # Detector directories: index is whatever follows "batch_" (unpadded).
        for p in sorted(base.glob("batch_*")):
            if p.is_dir():
                name = p.name
                try:
                    idx = int(name[len("batch_"):])
                    out.add(idx)
                except Exception:
                    continue

        return sorted(list(out))

    def delete_run(self, run_id: str) -> None:
        """Delete a run's batch files, batch directories and meta.json.

        No-op when the run directory does not exist.

        NOTE(review): the run directory itself and any detectors/ subdirectory
        are left behind — confirm whether that is intentional.
        """
        base = self.base_path / self.runs_prefix / run_id
        if not base.exists():
            return
        for p in base.glob("batch_*.safetensors"):
            if p.is_file():
                p.unlink()
        for p in base.glob("batch_*"):
            if p.is_dir():
                shutil.rmtree(p, ignore_errors=True)
        metadata_path = self._run_metadata_key(run_id)
        if metadata_path.exists():
            metadata_path.unlink()

    def put_run_metadata(self, run_id: str, meta: Dict[str, Any]) -> str:
        """Write run metadata as JSON and return its store-relative key.

        Raises:
            ValueError: if `meta` is not JSON-serializable.
            OSError: if the metadata file cannot be written.
        """
        self._validate_run_id(run_id)

        metadata_path = self._run_metadata_key(run_id)
        self._ensure_directory(metadata_path.parent)

        try:
            with metadata_path.open("w", encoding="utf-8") as f:
                json.dump(meta, f, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Metadata is not JSON-serializable for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e
        except OSError as e:
            raise OSError(
                f"Failed to write metadata file at {metadata_path} for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e

        return f"{self.runs_prefix}/{run_id}/meta.json"

    def get_run_metadata(self, run_id: str) -> Dict[str, Any]:
        """Read run metadata; returns {} when no metadata file exists.

        Raises:
            json.JSONDecodeError: if the metadata file contains invalid JSON.
        """
        self._validate_run_id(run_id)

        metadata_path = self._run_metadata_key(run_id)
        if not metadata_path.exists():
            return {}

        try:
            with metadata_path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            # Re-raise with a path/run-id enriched message.
            raise json.JSONDecodeError(
                f"Invalid JSON in metadata file at {metadata_path} for run_id={run_id!r}",
                e.doc,
                e.pos
            ) from e

    def put_detector_metadata(
        self,
        run_id: str,
        batch_index: int,
        metadata: Dict[str, Any],
        tensor_metadata: TensorMetadata
    ) -> str:
        """Write per-batch detector metadata and tensors.

        Writes batch_{i}/metadata.json (the given `metadata` plus an injected
        "_tensor_metadata_names" index of which tensor keys exist per layer)
        and one {layer}/{key}.safetensors file per tensor. Returns the
        store-relative batch directory key.

        Raises:
            ValueError: if `metadata` is not JSON-serializable.
            OSError: if the metadata or a tensor file cannot be written.
        """
        self._validate_run_id(run_id)
        self._validate_batch_index(batch_index)

        batch_dir = self._run_batch_key(run_id, batch_index)
        self._ensure_directory(batch_dir)

        # Index of layer -> tensor keys; lets get_detector_metadata know which
        # files to load without globbing. Empty layers are skipped.
        tensor_metadata_names = {
            str(layer_signature): list(detector_tensors.keys())
            for layer_signature, detector_tensors in tensor_metadata.items()
            if detector_tensors
        }
        metadata_with_tensor_names = {
            **metadata,
            "_tensor_metadata_names": tensor_metadata_names
        }

        detector_metadata_path = batch_dir / "metadata.json"
        try:
            with detector_metadata_path.open("w", encoding="utf-8") as f:
                json.dump(metadata_with_tensor_names, f, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Metadata is not JSON-serializable for run_id={run_id!r}, "
                f"batch_index={batch_index}. Error: {e}"
            ) from e
        except OSError as e:
            raise OSError(
                f"Failed to write metadata file at {detector_metadata_path} for "
                f"run_id={run_id!r}, batch_index={batch_index}. Error: {e}"
            ) from e

        for layer_signature, detector_tensors in tensor_metadata.items():
            if not detector_tensors:
                continue

            layer_dir = batch_dir / layer_signature
            self._ensure_directory(layer_dir)

            # Save each tensor key (e.g., "activations") as a separate safetensors file
            for tensor_key, tensor in detector_tensors.items():
                tensor_filename = f"{tensor_key}.safetensors"
                tensor_path = layer_dir / tensor_filename
                try:
                    storch.save_file({"tensor": tensor}, str(tensor_path))
                except Exception as e:
                    raise OSError(
                        f"Failed to save tensor at {tensor_path} for run_id={run_id!r}, "
                        f"batch_index={batch_index}, layer={layer_signature!r}, "
                        f"key={tensor_key!r}. Error: {e}"
                    ) from e

        return f"{self.runs_prefix}/{run_id}/batch_{batch_index}"

    def get_detector_metadata(
        self,
        run_id: str,
        batch_index: int
    ) -> tuple[Dict[str, Any], TensorMetadata]:
        """Read per-batch detector metadata and its tensors.

        Returns ({}, {}) when the batch has no metadata.json. Otherwise the
        "_tensor_metadata_names" index (written by put_detector_metadata) is
        popped from the metadata and used to locate the tensor files.

        Raises:
            json.JSONDecodeError: if metadata.json contains invalid JSON.
            OSError: if a referenced tensor file fails to load.
            ValueError: if the "_tensor_metadata_names" index is missing.
        """
        self._validate_run_id(run_id)
        self._validate_batch_index(batch_index)

        batch_dir = self._run_batch_key(run_id, batch_index)
        metadata_path = batch_dir / "metadata.json"

        if not metadata_path.exists():
            return {}, {}

        try:
            with metadata_path.open("r", encoding="utf-8") as f:
                metadata = json.load(f)
        except json.JSONDecodeError as e:
            raise json.JSONDecodeError(
                f"Invalid JSON in metadata file at {metadata_path} for "
                f"run_id={run_id!r}, batch_index={batch_index}",
                e.doc,
                e.pos
            ) from e

        tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = {}
        # pop() so the internal index is not surfaced to callers.
        tensor_metadata_names = metadata.pop("_tensor_metadata_names", None)

        if tensor_metadata_names is not None:
            for layer_signature, tensor_keys in tensor_metadata_names.items():
                layer_dir = batch_dir / layer_signature
                detector_tensors: Dict[str, torch.Tensor] = {}
                for tensor_key in tensor_keys:
                    tensor_filename = f"{tensor_key}.safetensors"
                    tensor_path = layer_dir / tensor_filename
                    if tensor_path.exists():
                        try:
                            detector_tensors[tensor_key] = storch.load_file(str(tensor_path))["tensor"]
                        except Exception as e:
                            raise OSError(
                                f"Failed to load tensor at {tensor_path} for "
                                f"run_id={run_id!r}, batch_index={batch_index}, "
                                f"layer={layer_signature!r}, key={tensor_key!r}. Error: {e}"
                            ) from e
                if detector_tensors:
                    tensor_metadata[layer_signature] = detector_tensors
        else:
            raise ValueError(
                f"Field '_tensor_metadata_names' not found in detector metadata at "
                f"{metadata_path} for run_id={run_id!r}, batch_index={batch_index}. "
                f"Cannot retrieve tensors."
            )

        return metadata, tensor_metadata

    def get_detector_metadata_by_layer_by_key(
        self,
        run_id: str,
        batch_index: int,
        layer: str,
        key: str
    ) -> torch.Tensor:
        """Load a single detector tensor addressed by (run, batch, layer, key).

        Raises:
            FileNotFoundError: if the tensor file does not exist.
            OSError: if the tensor file exists but fails to load.
        """
        self._validate_run_id(run_id)
        self._validate_batch_index(batch_index)
        self._validate_layer_key(layer, key)

        batch_dir = self._run_batch_key(run_id, batch_index)
        layer_dir = batch_dir / layer
        tensor_path = layer_dir / f"{key}.safetensors"

        if not tensor_path.exists():
            raise FileNotFoundError(
                f"Tensor not found at {tensor_path} for run_id={run_id!r}, "
                f"batch_index={batch_index}, layer={layer!r}, key={key!r}"
            )

        try:
            return storch.load_file(str(tensor_path))["tensor"]
        except Exception as e:
            raise OSError(
                f"Failed to load tensor at {tensor_path} for run_id={run_id!r}, "
                f"batch_index={batch_index}, layer={layer!r}, key={key!r}. Error: {e}"
            ) from e

    def put_run_detector_metadata(
        self,
        run_id: str,
        metadata: Dict[str, Any],
        tensor_metadata: TensorMetadata,
    ) -> str:
        """Append a batch entry to the run-level (unified) detector store.

        The batch index is implicit: it is the current length of the "batches"
        list in detectors/metadata.json. Tensors are appended as a
        "batch_{i}" entry inside per-layer {key}.safetensors files
        (read-modify-write of the whole file on every call).

        Returns the store-relative detectors directory key.

        Raises:
            json.JSONDecodeError: if the existing aggregated JSON is invalid.
            ValueError: if the aggregated metadata is not JSON-serializable.
            OSError: on any file read/write failure.
        """
        self._validate_run_id(run_id)

        detectors_dir = self._run_key(run_id) / "detectors"
        self._ensure_directory(detectors_dir)

        metadata_path = detectors_dir / "metadata.json"

        # Load the existing aggregate (if any) so this batch can be appended.
        if metadata_path.exists():
            try:
                with metadata_path.open("r", encoding="utf-8") as f:
                    aggregated = json.load(f)
            except json.JSONDecodeError as e:
                raise json.JSONDecodeError(
                    f"Invalid JSON in unified detector metadata at {metadata_path} for run_id={run_id!r}",
                    e.doc,
                    e.pos,
                ) from e
        else:
            aggregated = {"batches": []}

        batches = aggregated.setdefault("batches", [])
        # Next index = number of batches recorded so far.
        batch_index = len(batches)

        tensor_metadata_names = {
            layer_signature: list(detector_tensors.keys())
            for layer_signature, detector_tensors in tensor_metadata.items()
            if detector_tensors
        }

        batch_entry = {
            **metadata,
            "batch_index": batch_index,
            "_tensor_metadata_names": tensor_metadata_names,
        }

        batches.append(batch_entry)

        try:
            with metadata_path.open("w", encoding="utf-8") as f:
                json.dump(aggregated, f, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Unified detector metadata is not JSON-serializable for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e
        except OSError as e:
            raise OSError(
                f"Failed to write unified detector metadata at {metadata_path} for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e

        for layer_signature, detector_tensors in tensor_metadata.items():
            if not detector_tensors:
                continue

            layer_dir = detectors_dir / str(layer_signature)
            self._ensure_directory(layer_dir)

            for tensor_key, tensor in detector_tensors.items():
                tensor_filename = f"{tensor_key}.safetensors"
                tensor_path = layer_dir / tensor_filename

                # Merge into the existing per-key file so earlier batches are kept.
                if tensor_path.exists():
                    try:
                        existing = storch.load_file(str(tensor_path))
                    except Exception as e:
                        raise OSError(
                            f"Failed to load existing unified tensor at {tensor_path} for "
                            f"run_id={run_id!r}, layer={layer_signature!r}, key={tensor_key!r}. "
                            f"Error: {e}"
                        ) from e
                else:
                    existing = {}

                batch_key = f"batch_{batch_index}"
                existing[batch_key] = tensor

                try:
                    storch.save_file(existing, str(tensor_path))
                except Exception as e:
                    raise OSError(
                        f"Failed to save unified tensor at {tensor_path} for run_id={run_id!r}, "
                        f"layer={layer_signature!r}, key={tensor_key!r}, batch_index={batch_index}. "
                        f"Error: {e}"
                    ) from e

        return f"{self.runs_prefix}/{run_id}/detectors"