mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/mechanistic/sae/training/wandb_logger.py
ADDED
@@ -0,0 +1,222 @@
"""Wandb logging utilities for SAE training."""

from typing import Any, Optional

from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
from amber.utils import get_logger

logger = get_logger(__name__)


class WandbLogger:
    """
    Handles wandb logging for SAE training.

    Encapsulates all wandb-related operations including initialization,
    metric logging, and summary updates.
    """

    def __init__(self, config: SaeTrainingConfig, run_id: str):
        """
        Initialize WandbLogger.

        Args:
            config: Training configuration
            run_id: Training run identifier
        """
        self.config = config
        self.run_id = run_id
        self.wandb_run: Optional[Any] = None
        self._initialized = False

    def initialize(self) -> bool:
        """
        Initialize wandb run if enabled in config.

        Returns:
            True if wandb was successfully initialized, False otherwise
        """
        if not self.config.use_wandb:
            return False

        try:
            import wandb
        except ImportError:
            logger.warning("[WandbLogger] wandb not installed, skipping wandb logging")
            logger.warning("[WandbLogger] Install with: pip install wandb")
            return False

        try:
            wandb_project = self.config.wandb_project or "sae-training"
            wandb_name = self.config.wandb_name or self.run_id
            wandb_mode = self.config.wandb_mode.lower() if self.config.wandb_mode else "online"

            self.wandb_run = wandb.init(
                project=wandb_project,
                entity=self.config.wandb_entity,
                name=wandb_name,
                mode=wandb_mode,
                config=self._build_wandb_config(),
                tags=self.config.wandb_tags or [],
            )
            self._initialized = True
            return True
        except Exception as e:
            logger.warning(f"[WandbLogger] Unexpected error initializing wandb: {e}")
            logger.warning("[WandbLogger] Continuing training without wandb logging")
            return False

    def _build_wandb_config(self) -> dict[str, Any]:
        """Build wandb configuration dictionary."""
        return {
            "run_id": self.run_id,
            "epochs": self.config.epochs,
            "batch_size": self.config.batch_size,
            "lr": self.config.lr,
            "l1_lambda": self.config.l1_lambda,
            "device": str(self.config.device),
            "dtype": str(self.config.dtype) if self.config.dtype else None,
            "use_amp": self.config.use_amp,
            "clip_grad": self.config.clip_grad,
            "max_batches_per_epoch": self.config.max_batches_per_epoch,
            **(self.config.wandb_config or {}),
        }

    def log_metrics(
        self,
        history: dict[str, list[float | None]],
        verbose: bool = False
    ) -> None:
        """
        Log training metrics to wandb.

        Args:
            history: Dictionary with training history (loss, r2, l1, l0, etc.)
            verbose: Whether to log verbose information
        """
        if not self._initialized or self.wandb_run is None:
            return

        try:
            num_epochs = len(history.get("loss", []))
            slow_metrics_freq = self.config.wandb_slow_metrics_frequency

            # Helper to get last known value for slow metrics
            def get_last_known_value(values: list[float | None], idx: int) -> float:
                """Get the last non-None value up to idx, or 0.0 if none found."""
                for i in range(idx, -1, -1):
                    if i < len(values) and values[i] is not None:
                        return values[i]
                return 0.0

            # Log metrics for each epoch
            for epoch in range(1, num_epochs + 1):
                epoch_idx = epoch - 1
                should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)

                metrics = self._build_epoch_metrics(history, epoch_idx, should_log_slow, get_last_known_value)
                self.wandb_run.log(metrics)

            # Log final summary metrics
            self._log_summary_metrics(history, get_last_known_value)

            if verbose:
                self._log_wandb_url()

        except Exception as e:
            logger.warning(f"[WandbLogger] Failed to log metrics to wandb: {e}")

    def _build_epoch_metrics(
        self,
        history: dict[str, list[float | None]],
        epoch_idx: int,
        should_log_slow: bool,
        get_last_known_value: Any
    ) -> dict[str, Any]:
        """Build metrics dictionary for a single epoch."""
        metrics = {
            "epoch": epoch_idx + 1,
            "train/loss": self._safe_get(history["loss"], epoch_idx, 0.0),
            "train/reconstruction_mse": self._safe_get(history["recon_mse"], epoch_idx, 0.0),
            "train/r2_score": self._safe_get(history["r2"], epoch_idx, 0.0),
            "train/l1_penalty": self._safe_get(history["l1"], epoch_idx, 0.0),
            "train/learning_rate": self.config.lr,
        }

        # Add slow metrics if computed this epoch
        if should_log_slow:
            l0_val = self._get_slow_metric(history["l0"], epoch_idx, get_last_known_value)
            dead_pct = self._get_slow_metric(history["dead_features_pct"], epoch_idx, get_last_known_value)
            metrics["train/l0_sparsity"] = l0_val
            metrics["train/dead_features_pct"] = dead_pct

        return metrics

    def _get_slow_metric(
        self,
        values: list[float | None],
        epoch_idx: int,
        get_last_known_value: Any
    ) -> float:
        """Get slow metric value, using last known value if current is None."""
        if epoch_idx < len(values) and values[epoch_idx] is not None:
            return values[epoch_idx]
        return get_last_known_value(values, epoch_idx)

    def _safe_get(self, values: list[float | None], idx: int, default: float) -> float:
        """Safely get value from list, returning default if out of bounds."""
        if idx < len(values) and values[idx] is not None:
            return values[idx]
        return default

    def _log_summary_metrics(
        self,
        history: dict[str, list[float | None]],
        get_last_known_value: Any
    ) -> None:
        """Log final summary metrics to wandb."""
        if self.wandb_run is None:
            return

        # Get last computed values for slow metrics
        final_l0 = get_last_known_value(history["l0"], len(history["l0"]) - 1) if history.get("l0") else 0.0
        final_dead_pct = get_last_known_value(
            history["dead_features_pct"],
            len(history["dead_features_pct"]) - 1
        ) if history.get("dead_features_pct") else 0.0

        final_metrics = {
            "final/loss": history["loss"][-1] if history.get("loss") else 0.0,
            "final/reconstruction_mse": history["recon_mse"][-1] if history.get("recon_mse") else 0.0,
            "final/r2_score": history["r2"][-1] if history.get("r2") else 0.0,
            "final/l1_penalty": history["l1"][-1] if history.get("l1") else 0.0,
            "final/l0_sparsity": final_l0,
            "final/dead_features_pct": final_dead_pct,
            "training/num_epochs": len(history.get("loss", [])),
        }

        # Add best metrics
        if history.get("loss"):
            best_loss_idx = min(range(len(history["loss"])), key=lambda i: history["loss"][i] or float('inf'))
            final_metrics["best/loss"] = history["loss"][best_loss_idx] or 0.0
            final_metrics["best/loss_epoch"] = best_loss_idx + 1

        if history.get("r2"):
            best_r2_idx = max(range(len(history["r2"])), key=lambda i: history["r2"][i] or -float('inf'))
            final_metrics["best/r2_score"] = history["r2"][best_r2_idx] or 0.0
            final_metrics["best/r2_epoch"] = best_r2_idx + 1

        self.wandb_run.summary.update(final_metrics)

    def _log_wandb_url(self) -> None:
        """Log wandb run URL if available."""
        if self.wandb_run is None:
            return

        try:
            url = self.wandb_run.url
            logger.info(f"[WandbLogger] Metrics logged to wandb: {url}")
        except (AttributeError, RuntimeError):
            # Offline mode or URL not available
            logger.info("[WandbLogger] Metrics logged to wandb (offline mode)")
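The logger above only reads plain attributes from its config and a history dict, so it can be exercised in isolation. The sketch below is hypothetical: the SimpleNamespace stands in for a real SaeTrainingConfig and every value in it is an illustrative assumption, not a package default; only the WandbLogger calls and the history keys come from the file above.

# Hypothetical usage sketch for the WandbLogger shown above.
from types import SimpleNamespace

from amber.mechanistic.sae.training.wandb_logger import WandbLogger

# Stand-in for SaeTrainingConfig with only the attributes WandbLogger reads;
# all values here are illustrative assumptions.
config = SimpleNamespace(
    use_wandb=True,
    wandb_project="sae-training",
    wandb_entity=None,
    wandb_name=None,
    wandb_mode="offline",
    wandb_tags=["demo"],
    wandb_config={},
    wandb_slow_metrics_frequency=5,
    epochs=2,
    batch_size=64,
    lr=1e-4,
    l1_lambda=1e-3,
    device="cpu",
    dtype=None,
    use_amp=False,
    clip_grad=None,
    max_batches_per_epoch=None,
)

wb_logger = WandbLogger(config, run_id="run-001")
if wb_logger.initialize():
    # History keys match those read by log_metrics(); slow metrics may be None
    # on epochs where they were not computed.
    history = {
        "loss": [0.9, 0.7],
        "recon_mse": [0.8, 0.6],
        "r2": [0.1, 0.3],
        "l1": [0.05, 0.04],
        "l0": [None, 12.0],
        "dead_features_pct": [None, 3.5],
    }
    wb_logger.log_metrics(history, verbose=True)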
amber/store/local_store.py
ADDED
@@ -0,0 +1,437 @@
import json
from pathlib import Path
from typing import Dict, Any, List
import shutil

import torch

from amber.store.store import Store, TensorMetadata
import safetensors.torch as storch


class LocalStore(Store):
    """Local filesystem implementation of Store interface."""

    def __init__(
        self,
        base_path: Path | str = '',
        runs_prefix: str = "runs",
        dataset_prefix: str = "datasets",
        model_prefix: str = "models",
    ):
        """Initialize LocalStore.

        Args:
            base_path: Base directory path for the store
            runs_prefix: Prefix for runs directory
            dataset_prefix: Prefix for datasets directory
            model_prefix: Prefix for models directory
        """
        super().__init__(base_path, runs_prefix, dataset_prefix, model_prefix)

    def _full(self, key: str) -> Path:
        p = self.base_path / key
        p.parent.mkdir(parents=True, exist_ok=True)
        return p

    def put_tensor(self, key: str, tensor: torch.Tensor) -> None:
        path = self._full(key)
        storch.save_file({"tensor": tensor}, str(path))

    def get_tensor(self, key: str) -> torch.Tensor:
        loaded = storch.load_file(str(self._full(key)))
        return loaded["tensor"]

    def _validate_run_id(self, run_id: str) -> None:
        if not run_id or not isinstance(run_id, str) or not run_id.strip():
            raise ValueError(f"run_id must be a non-empty string, got: {run_id!r}")

    def _validate_batch_index(self, batch_index: int) -> None:
        if batch_index < 0:
            raise ValueError(f"batch_index must be non-negative, got: {batch_index}")

    def _validate_layer_key(self, layer: str, key: str) -> None:
        if not layer or not isinstance(layer, str) or not layer.strip():
            raise ValueError(f"layer must be a non-empty string, got: {layer!r}")
        if not key or not isinstance(key, str) or not key.strip():
            raise ValueError(f"key must be a non-empty string, got: {key!r}")

    def _ensure_directory(self, path: Path) -> None:
        path.mkdir(parents=True, exist_ok=True)

    def put_run_batch(self, run_id: str, batch_index: int,
                      tensors: List[torch.Tensor] | Dict[str, torch.Tensor]) -> str:
        if isinstance(tensors, dict):
            to_save = tensors
        elif isinstance(tensors, list):
            if len(tensors) == 0:
                to_save = {"_empty_list": torch.tensor([])}
            else:
                to_save = {f"item_{i}": t for i, t in enumerate(tensors)}
        else:
            to_save = {}
        batch_path = self.base_path / f"{self.runs_prefix}/{run_id}/batch_{batch_index:06d}.safetensors"
        self._ensure_directory(batch_path.parent)
        storch.save_file(to_save, str(batch_path))
        return f"{self.runs_prefix}/{run_id}/batch_{batch_index:06d}.safetensors"

    def get_run_batch(self, run_id: str, batch_index: int) -> List[torch.Tensor] | Dict[
        str, torch.Tensor]:

        batch_path = self.base_path / f"{self.runs_prefix}/{run_id}/batch_{batch_index:06d}.safetensors"
        if batch_path.exists():
            loaded = storch.load_file(str(batch_path))
            keys = list(loaded.keys())
            if keys == ["_empty_list"]:
                return []
            if keys and all(k.startswith("item_") for k in keys):
                try:
                    items = sorted(((int(k.split("_", 1)[1]), v) for k, v in loaded.items()), key=lambda x: x[0])
                    if [i for i, _ in items] == list(range(len(items))):
                        return [v for _, v in items]
                except Exception:
                    pass
            return loaded

        detector_base = self.base_path / self.runs_prefix / run_id / f"batch_{batch_index}"
        if detector_base.exists():
            result: Dict[str, torch.Tensor] = {}

            layer_dirs = [d for d in detector_base.iterdir() if d.is_dir()]
            for layer_dir in layer_dirs:
                activations_path = layer_dir / "activations.safetensors"
                if activations_path.exists():
                    try:
                        loaded_tensor = storch.load_file(str(activations_path))["tensor"]
                        # Use layer_signature as key, or "activations" if only one layer
                        layer_sig = layer_dir.name
                        if len(layer_dirs) == 1:
                            # Only one layer, use simple "activations" key for compatibility
                            result["activations"] = loaded_tensor
                        else:
                            # Multiple layers, use layer-specific key
                            result[f"activations_{layer_sig}"] = loaded_tensor
                    except Exception:
                        pass

            if result:
                return result

        # If neither exists, raise FileNotFoundError
        raise FileNotFoundError(f"Batch {batch_index} not found for run {run_id}")

    def list_run_batches(self, run_id: str) -> List[int]:
        base = self.base_path / self.runs_prefix / run_id
        if not base.exists():
            return []
        out: set[int] = set()

        for p in sorted(base.glob("batch_*.safetensors")):
            name = p.name
            try:
                idx = int(name[len("batch_"): len("batch_") + 6])
                out.add(idx)
            except Exception:
                continue

        for p in sorted(base.glob("batch_*")):
            if p.is_dir():
                name = p.name
                try:
                    idx = int(name[len("batch_"):])
                    out.add(idx)
                except Exception:
                    continue

        return sorted(list(out))

    def delete_run(self, run_id: str) -> None:
        base = self.base_path / self.runs_prefix / run_id
        if not base.exists():
            return
        for p in base.glob("batch_*.safetensors"):
            if p.is_file():
                p.unlink()
        for p in base.glob("batch_*"):
            if p.is_dir():
                shutil.rmtree(p, ignore_errors=True)
        metadata_path = self._run_metadata_key(run_id)
        if metadata_path.exists():
            metadata_path.unlink()

    def put_run_metadata(self, run_id: str, meta: Dict[str, Any]) -> str:
        self._validate_run_id(run_id)

        metadata_path = self._run_metadata_key(run_id)
        self._ensure_directory(metadata_path.parent)

        try:
            with metadata_path.open("w", encoding="utf-8") as f:
                json.dump(meta, f, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Metadata is not JSON-serializable for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e
        except OSError as e:
            raise OSError(
                f"Failed to write metadata file at {metadata_path} for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e

        return f"{self.runs_prefix}/{run_id}/meta.json"

    def get_run_metadata(self, run_id: str) -> Dict[str, Any]:
        self._validate_run_id(run_id)

        metadata_path = self._run_metadata_key(run_id)
        if not metadata_path.exists():
            return {}

        try:
            with metadata_path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            raise json.JSONDecodeError(
                f"Invalid JSON in metadata file at {metadata_path} for run_id={run_id!r}",
                e.doc,
                e.pos
            ) from e

    def put_detector_metadata(
        self,
        run_id: str,
        batch_index: int,
        metadata: Dict[str, Any],
        tensor_metadata: TensorMetadata
    ) -> str:
        self._validate_run_id(run_id)
        self._validate_batch_index(batch_index)

        batch_dir = self._run_batch_key(run_id, batch_index)
        self._ensure_directory(batch_dir)

        tensor_metadata_names = {
            str(layer_signature): list(detector_tensors.keys())
            for layer_signature, detector_tensors in tensor_metadata.items()
            if detector_tensors
        }
        metadata_with_tensor_names = {
            **metadata,
            "_tensor_metadata_names": tensor_metadata_names
        }

        detector_metadata_path = batch_dir / "metadata.json"
        try:
            with detector_metadata_path.open("w", encoding="utf-8") as f:
                json.dump(metadata_with_tensor_names, f, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Metadata is not JSON-serializable for run_id={run_id!r}, "
                f"batch_index={batch_index}. Error: {e}"
            ) from e
        except OSError as e:
            raise OSError(
                f"Failed to write metadata file at {detector_metadata_path} for "
                f"run_id={run_id!r}, batch_index={batch_index}. Error: {e}"
            ) from e

        for layer_signature, detector_tensors in tensor_metadata.items():
            if not detector_tensors:
                continue

            layer_dir = batch_dir / layer_signature
            self._ensure_directory(layer_dir)

            # Save each tensor key (e.g., "activations") as a separate safetensors file
            for tensor_key, tensor in detector_tensors.items():
                tensor_filename = f"{tensor_key}.safetensors"
                tensor_path = layer_dir / tensor_filename
                try:
                    storch.save_file({"tensor": tensor}, str(tensor_path))
                except Exception as e:
                    raise OSError(
                        f"Failed to save tensor at {tensor_path} for run_id={run_id!r}, "
                        f"batch_index={batch_index}, layer={layer_signature!r}, "
                        f"key={tensor_key!r}. Error: {e}"
                    ) from e

        return f"{self.runs_prefix}/{run_id}/batch_{batch_index}"

    def get_detector_metadata(
        self,
        run_id: str,
        batch_index: int
    ) -> tuple[Dict[str, Any], TensorMetadata]:
        self._validate_run_id(run_id)
        self._validate_batch_index(batch_index)

        batch_dir = self._run_batch_key(run_id, batch_index)
        metadata_path = batch_dir / "metadata.json"

        if not metadata_path.exists():
            return {}, {}

        try:
            with metadata_path.open("r", encoding="utf-8") as f:
                metadata = json.load(f)
        except json.JSONDecodeError as e:
            raise json.JSONDecodeError(
                f"Invalid JSON in metadata file at {metadata_path} for "
                f"run_id={run_id!r}, batch_index={batch_index}",
                e.doc,
                e.pos
            ) from e

        tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = {}
        tensor_metadata_names = metadata.pop("_tensor_metadata_names", None)

        if tensor_metadata_names is not None:
            for layer_signature, tensor_keys in tensor_metadata_names.items():
                layer_dir = batch_dir / layer_signature
                detector_tensors: Dict[str, torch.Tensor] = {}
                for tensor_key in tensor_keys:
                    tensor_filename = f"{tensor_key}.safetensors"
                    tensor_path = layer_dir / tensor_filename
                    if tensor_path.exists():
                        try:
                            detector_tensors[tensor_key] = storch.load_file(str(tensor_path))["tensor"]
                        except Exception as e:
                            raise OSError(
                                f"Failed to load tensor at {tensor_path} for "
                                f"run_id={run_id!r}, batch_index={batch_index}, "
                                f"layer={layer_signature!r}, key={tensor_key!r}. Error: {e}"
                            ) from e
                if detector_tensors:
                    tensor_metadata[layer_signature] = detector_tensors
        else:
            raise ValueError(
                f"Field '_tensor_metadata_names' not found in detector metadata at "
                f"{metadata_path} for run_id={run_id!r}, batch_index={batch_index}. "
                f"Cannot retrieve tensors."
            )

        return metadata, tensor_metadata

    def get_detector_metadata_by_layer_by_key(
        self,
        run_id: str,
        batch_index: int,
        layer: str,
        key: str
    ) -> torch.Tensor:
        self._validate_run_id(run_id)
        self._validate_batch_index(batch_index)
        self._validate_layer_key(layer, key)

        batch_dir = self._run_batch_key(run_id, batch_index)
        layer_dir = batch_dir / layer
        tensor_path = layer_dir / f"{key}.safetensors"

        if not tensor_path.exists():
            raise FileNotFoundError(
                f"Tensor not found at {tensor_path} for run_id={run_id!r}, "
                f"batch_index={batch_index}, layer={layer!r}, key={key!r}"
            )

        try:
            return storch.load_file(str(tensor_path))["tensor"]
        except Exception as e:
            raise OSError(
                f"Failed to load tensor at {tensor_path} for run_id={run_id!r}, "
                f"batch_index={batch_index}, layer={layer!r}, key={key!r}. Error: {e}"
            ) from e

    def put_run_detector_metadata(
        self,
        run_id: str,
        metadata: Dict[str, Any],
        tensor_metadata: TensorMetadata,
    ) -> str:
        self._validate_run_id(run_id)

        detectors_dir = self._run_key(run_id) / "detectors"
        self._ensure_directory(detectors_dir)

        metadata_path = detectors_dir / "metadata.json"

        if metadata_path.exists():
            try:
                with metadata_path.open("r", encoding="utf-8") as f:
                    aggregated = json.load(f)
            except json.JSONDecodeError as e:
                raise json.JSONDecodeError(
                    f"Invalid JSON in unified detector metadata at {metadata_path} for run_id={run_id!r}",
                    e.doc,
                    e.pos,
                ) from e
        else:
            aggregated = {"batches": []}

        batches = aggregated.setdefault("batches", [])
        batch_index = len(batches)

        tensor_metadata_names = {
            layer_signature: list(detector_tensors.keys())
            for layer_signature, detector_tensors in tensor_metadata.items()
            if detector_tensors
        }

        batch_entry = {
            **metadata,
            "batch_index": batch_index,
            "_tensor_metadata_names": tensor_metadata_names,
        }

        batches.append(batch_entry)

        try:
            with metadata_path.open("w", encoding="utf-8") as f:
                json.dump(aggregated, f, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Unified detector metadata is not JSON-serializable for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e
        except OSError as e:
            raise OSError(
                f"Failed to write unified detector metadata at {metadata_path} for run_id={run_id!r}. "
                f"Error: {e}"
            ) from e

        for layer_signature, detector_tensors in tensor_metadata.items():
            if not detector_tensors:
                continue

            layer_dir = detectors_dir / str(layer_signature)
            self._ensure_directory(layer_dir)

            for tensor_key, tensor in detector_tensors.items():
                tensor_filename = f"{tensor_key}.safetensors"
                tensor_path = layer_dir / tensor_filename

                if tensor_path.exists():
                    try:
                        existing = storch.load_file(str(tensor_path))
                    except Exception as e:
                        raise OSError(
                            f"Failed to load existing unified tensor at {tensor_path} for "
                            f"run_id={run_id!r}, layer={layer_signature!r}, key={tensor_key!r}. "
                            f"Error: {e}"
                        ) from e
                else:
                    existing = {}

                batch_key = f"batch_{batch_index}"
                existing[batch_key] = tensor

                try:
                    storch.save_file(existing, str(tensor_path))
                except Exception as e:
                    raise OSError(
                        f"Failed to save unified tensor at {tensor_path} for run_id={run_id!r}, "
                        f"layer={layer_signature!r}, key={tensor_key!r}, batch_index={batch_index}. "
                        f"Error: {e}"
                    ) from e

        return f"{self.runs_prefix}/{run_id}/detectors"
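For orientation, here is a hedged usage sketch of the LocalStore defined above. The base directory and key strings are made-up examples; the method names and signatures are the ones in this file, and the batch round-trip relies on the "item_<i>" packing logic shown in put_run_batch / get_run_batch.

# Hypothetical usage sketch for LocalStore; paths and keys are illustrative only.
from pathlib import Path

import torch

from amber.store.local_store import LocalStore

store = LocalStore(base_path=Path("./amber_store"))

# Single-tensor round trip via put_tensor / get_tensor.
store.put_tensor("models/demo/weights.safetensors", torch.randn(4, 8))
weights = store.get_tensor("models/demo/weights.safetensors")

# Batch round trip: a list is saved under "item_<i>" keys and reassembled on load.
store.put_run_batch("run-001", 0, [torch.zeros(2, 3), torch.ones(2, 3)])
batch = store.get_run_batch("run-001", 0)   # -> list of two tensors
print(store.list_run_batches("run-001"))    # -> [0]

store.delete_run("run-001")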