representation-geometry 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ from representation_geometry.api import analyze_model
2
+ from representation_geometry.metrics import metric_from_eigenvalues, subspace_novelty
3
+ from representation_geometry.online import RunningCovariance
4
+ from representation_geometry.results import AnalysisResults
5
+
6
# Public API re-exported at the package root; everything else is reached via
# the submodules (api, metrics, online, results).
__all__ = [
    "AnalysisResults",
    "RunningCovariance",
    "analyze_model",
    "metric_from_eigenvalues",
    "subspace_novelty",
]
@@ -0,0 +1,366 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib.metadata
4
+ import platform
5
+ import time
6
+ from collections.abc import Iterable, Mapping, Sequence
7
+ from typing import Any
8
+
9
+ import pandas as pd
10
+ import torch
11
+
12
+ from representation_geometry.diagnostics import DIAGNOSTIC_REGISTRY, available_diagnostics
13
+ from representation_geometry.hooks import (
14
+ ResidualStatsStore,
15
+ get_input_device,
16
+ move_to_device,
17
+ register_activation_hooks,
18
+ resolve_hook_targets,
19
+ )
20
+ from representation_geometry.metrics import compute_metrics_from_stats
21
+ from representation_geometry.moe import MoERouterStats
22
+ from representation_geometry.results import AnalysisResults
23
+
24
# Version tag written into every result's metadata under "artifact_schema_version".
# NOTE(review): presumably bumped whenever the saved artifact layout changes — confirm
# against AnalysisResults.save / loaders.
ARTIFACT_SCHEMA_VERSION = "0.1"
25
+
26
+
27
def analyze_model(
    model_name: str | torch.nn.Module | None = None,
    dataloader: Iterable[Any] | None = None,
    *,
    model: torch.nn.Module | None = None,
    max_tokens: int | None = 100_000,
    metrics: str | Sequence[str] | None = "default",
    diagnostics: Sequence[str] | None = None,
    layers: str | Sequence[int] = "all",
    hook_point: str = "residual_input",
    save_mode: str | None = None,
    save_dir: str | None = None,
    sample_limit: int = 512,
    novelty_k: int = 32,
    device: str | torch.device | None = None,
    model_kwargs: Mapping[str, Any] | None = None,
    trust_remote_code: bool = False,
    output_router_logits: bool | str = "auto",
    run_name: str | None = None,
    tokenizer_name: str | None = None,
    dataset_name: str | None = None,
    run_metadata: Mapping[str, Any] | None = None,
) -> AnalysisResults:
    """Analyze residual-stream covariance geometry for a transformer model.

    Forward hooks capture activations while the model runs over ``dataloader``
    under ``torch.no_grad()``. The per-layer streaming statistics are then
    turned into metric tables and packaged as ``AnalysisResults``. Hooks are
    always removed before this function returns or raises.

    Parameters
    ----------
    model_name:
        Either a HuggingFace model id or an already loaded ``torch.nn.Module``.
        Passing the model object here is supported to match the rough API used
        in the paper workflow.
    dataloader:
        Iterable of batches (required). Dict batches with ``input_ids`` are
        preferred.
    model:
        Explicit loaded model. If supplied, it takes precedence over
        ``model_name``.
    max_tokens:
        Maximum token budget. Full batches are processed until this budget is
        reached; the recorded metadata reports the actual observed tokens.
        ``None`` consumes the whole dataloader, and a value ``<= 0`` processes
        no batches. Once the budget is met, no further batch is pulled from
        ``dataloader``.
    metrics:
        ``"default"``, ``None``, or ``[]`` computes the default spectral metrics.
        Otherwise, ``"spectrum"``, ``"sample"`` and ``"novelty"`` select the
        corresponding metric groups.
    diagnostics:
        Optional diagnostics. Currently supports ``"normalization_ablation"``.
    layers:
        ``"all"`` or a sequence of integer block indices for residual hooks.
    hook_point:
        ``"residual_input"``, ``"residual_output"``, ``"module_input:<path>"``,
        or ``"module_output:<path>"``.
    save_mode:
        ``None``, ``"json"``, ``"csv"``, or ``"bundle"``. When not ``None``,
        the results are saved through ``AnalysisResults.save`` before returning.
    save_dir:
        Output directory for saved artifacts.
    sample_limit:
        Forwarded to ``ResidualStatsStore``. It presumably caps the activation
        rows retained per layer for sample-based metrics.
    novelty_k:
        Forwarded to ``compute_metrics_from_stats``.
    device:
        Device that inputs are moved to. When given and the model has no
        ``hf_device_map``, the model is moved there as well. Otherwise the
        device of the model's first non-meta parameter is used.
    model_kwargs, trust_remote_code:
        Forwarded to ``AutoModelForCausalLM.from_pretrained`` when
        ``model_name`` is a model id.
    output_router_logits:
        ``False`` disables MoE router statistics. ``"auto"`` / ``True``
        collects them when the model config exposes an expert count.
    run_name, tokenizer_name, dataset_name:
        Recorded verbatim in the result metadata.
    run_metadata:
        Extra mapping copied into ``metadata["run_metadata"]``.

    Returns
    -------
    AnalysisResults
        Metric, novelty, router and diagnostic tables, eigenvalues, and run
        metadata.

    Raises
    ------
    ValueError
        If ``dataloader`` is missing, or if the metric, diagnostic or hook
        settings are invalid.
    """

    if dataloader is None:
        raise ValueError("dataloader is required; pass tokenized batches for the model.")

    metric_names = _normalize_metrics(metrics)
    diagnostic_names = set(diagnostics or [])
    unknown_diagnostics = diagnostic_names - set(available_diagnostics())
    if unknown_diagnostics:
        raise ValueError(f"Unknown diagnostics: {sorted(unknown_diagnostics)}")

    model_obj, model_id = _resolve_model(model_name, model, model_kwargs, trust_remote_code)
    model_obj.eval()
    model_key = _safe_model_key(model_id)

    input_device = torch.device(device) if device is not None else get_input_device(model_obj)
    if device is not None and not _has_device_map(model_obj):
        model_obj.to(input_device)

    hook_targets, hook_metadata = resolve_hook_targets(
        model_obj,
        hook_point=hook_point,
        layers=layers,
    )

    # Build router stats *before* registering hooks. If this raised after
    # registration but outside the try/finally below, the hooks would stay
    # attached to the caller's model for good.
    router_stats = _make_router_stats(model_obj, len(hook_targets), output_router_logits)
    ask_router_logits = router_stats is not None and output_router_logits is not False

    store = ResidualStatsStore(sample_limit=sample_limit)
    handles = register_activation_hooks(hook_targets, store, hook_point=hook_point)

    # perf_counter is monotonic, so runtime_sec is immune to wall-clock jumps.
    started = time.perf_counter()
    batches_seen = 0
    tokens_seen = 0
    budget_open = max_tokens is None or max_tokens > 0

    try:
        with torch.no_grad():
            if budget_open:
                for batch in dataloader:
                    batch_tokens = _count_tokens(batch)
                    moved_batch = move_to_device(batch, input_device)
                    outputs = _call_model(
                        model_obj,
                        moved_batch,
                        output_router_logits=ask_router_logits,
                    )
                    if router_stats is not None:
                        router_stats.update_from_outputs(outputs)
                    batches_seen += 1
                    tokens_seen += batch_tokens
                    # Check after processing so an exhausted budget never
                    # pulls (and silently discards) another batch.
                    if max_tokens is not None and tokens_seen >= max_tokens:
                        break
    finally:
        for handle in handles:
            handle.remove()

    metrics_df, novelty_df, eigenvalues = compute_metrics_from_stats(
        store.stats,
        model_key=model_key,
        model_label=model_id,
        novelty_k=novelty_k,
        include_spectrum_metrics="spectrum" in metric_names or "default" in metric_names,
        include_sample_metrics="sample" in metric_names or "default" in metric_names,
        include_novelty_metrics="novelty" in metric_names or "default" in metric_names,
    )

    router_df = router_stats.to_frame() if router_stats is not None else pd.DataFrame()
    diagnostic_tables: dict[str, pd.DataFrame] = {
        diagnostic_name: DIAGNOSTIC_REGISTRY[diagnostic_name](
            store.stats,
            model_key=model_key,
            model_label=model_id,
        )
        for diagnostic_name in sorted(diagnostic_names)
    }

    runtime_sec = time.perf_counter() - started
    metadata: dict[str, Any] = {
        "artifact_schema_version": ARTIFACT_SCHEMA_VERSION,
        "model_id": model_id,
        "model_class": type(model_obj).__name__,
        "max_tokens": max_tokens,
        "tokens_observed": int(tokens_seen),
        "batches_observed": int(batches_seen),
        "num_blocks": hook_metadata.get("num_blocks"),
        "layers": hook_metadata.get("layers"),
        "hook_point": hook_point,
        "hook_targets": hook_metadata.get("hook_targets", []),
        "sample_limit": sample_limit,
        "novelty_k": novelty_k,
        "metrics": sorted(metric_names),
        "diagnostics": sorted(diagnostic_names),
        "input_device": str(input_device),
        "runtime_sec": runtime_sec,
        "save_mode": save_mode,
        "run_name": run_name,
        "tokenizer": tokenizer_name,
        "dataset": dataset_name,
        "model_config": _model_config_summary(model_obj),
        "software": _software_versions(),
    }
    if run_metadata:
        metadata["run_metadata"] = dict(run_metadata)

    results = AnalysisResults(
        metrics=metrics_df,
        novelty=novelty_df,
        router=router_df,
        diagnostics=diagnostic_tables,
        eigenvalues=eigenvalues,
        metadata=metadata,
    )

    if save_mode is not None:
        results.save(save_dir=save_dir, save_mode=save_mode)

    return results
193
+
194
+
195
def _normalize_metrics(metrics: str | Sequence[str] | None) -> set[str]:
    """Normalize the ``metrics`` argument into a validated set of group names.

    ``None``, ``"default"`` and empty sequences select the default groups
    (``spectrum``, ``novelty``, ``sample``, plus the ``default`` marker). A
    single string is treated like a one-element sequence. ``"default"`` inside
    a sequence expands to the same default groups, so ``["default"]`` and
    ``"default"`` are equivalent.

    Raises
    ------
    ValueError
        If any name is not one of ``default``, ``spectrum``, ``novelty``,
        ``sample`` or ``router``. String input is validated too, so a typo
        no longer silently disables every metric.
    """
    default_groups = {"default", "spectrum", "novelty", "sample"}
    known_groups = default_groups | {"router"}

    if metrics is None:
        return set(default_groups)
    names = {metrics} if isinstance(metrics, str) else set(metrics)
    if not names:
        return set(default_groups)

    unknown = names - known_groups
    if unknown:
        raise ValueError(f"Unknown metrics: {sorted(unknown)}")
    if "default" in names:
        names |= default_groups
    return names
205
+
206
+
207
def _resolve_model(
    model_name: str | torch.nn.Module | None,
    model: torch.nn.Module | None,
    model_kwargs: Mapping[str, Any] | None,
    trust_remote_code: bool,
) -> tuple[torch.nn.Module, str]:
    """Return ``(model, model_id)`` from the flexible ``analyze_model`` inputs.

    Precedence:

    1. An explicit ``model`` wins. Its id is ``model_name`` when that is a
       non-empty string; otherwise the id is inferred from the model.
    2. A module passed as ``model_name`` is used directly.
    3. A string ``model_name`` is loaded with
       ``transformers.AutoModelForCausalLM.from_pretrained``. ``model_kwargs``
       entries override the default ``trust_remote_code`` keyword.

    Raises
    ------
    ImportError
        If a model id must be loaded but ``transformers`` is not installed.
    ValueError
        If neither a usable model nor a model id was supplied.
    """
    if model is not None:
        # Only a string is a usable label. str() of a module passed as
        # model_name would be its multi-line repr, which is useless as a model
        # id and unsafe as an artifact key.
        if isinstance(model_name, str) and model_name:
            return model, model_name
        return model, _infer_model_id(model)
    if isinstance(model_name, torch.nn.Module):
        return model_name, _infer_model_id(model_name)
    if isinstance(model_name, str):
        try:
            from transformers import AutoModelForCausalLM
        except ImportError as exc:
            raise ImportError(
                "Install HuggingFace support with: python -m pip install -e '.[hf]'"
            ) from exc
        kwargs = {"trust_remote_code": trust_remote_code}
        kwargs.update(dict(model_kwargs or {}))
        loaded = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        return loaded, model_name
    raise ValueError(
        "Pass either model_name=<hf id>, model_name=<model object>, or model=<model object>."
    )
231
+
232
+
233
def _has_device_map(model: torch.nn.Module) -> bool:
    """Report whether ``model`` carries a non-empty ``hf_device_map`` attribute.

    A missing, ``None`` or empty attribute counts as "no device map".
    """
    device_map = getattr(model, "hf_device_map", None)
    return bool(device_map)
235
+
236
+
237
def _make_router_stats(
    model: torch.nn.Module,
    num_layers: int,
    output_router_logits: bool | str,
) -> MoERouterStats | None:
    """Create a router-statistics accumulator for mixture-of-experts models.

    Returns ``None`` in any of these cases:

    * router logits are disabled (``output_router_logits is False``);
    * the model has no ``config``;
    * the config reports no experts (``num_local_experts`` or ``num_experts``
      is missing, ``None`` or ``0``).

    ``top_k`` comes from ``config.num_experts_per_tok`` and falls back to 2
    when that attribute is missing or ``None``.
    """
    if output_router_logits is False:
        return None
    config = getattr(model, "config", None)
    if config is None:
        return None

    # A falsy num_local_experts (None/0) must not hide a valid num_experts.
    # A zero expert count means the model is not MoE.
    num_experts = getattr(config, "num_local_experts", None)
    if not num_experts:
        num_experts = getattr(config, "num_experts", None)
    if not num_experts:
        return None

    # Some configs define num_experts_per_tok = None; int(None) would raise.
    top_k = getattr(config, "num_experts_per_tok", None) or 2
    return MoERouterStats(num_layers=num_layers, num_experts=int(num_experts), top_k=int(top_k))
252
+
253
+
254
def _call_model(model: torch.nn.Module, batch: Any, *, output_router_logits: bool) -> Any:
    """Run ``model`` on ``batch``, dropping optional keyword extras on ``TypeError``.

    How the batch is passed:

    * mapping: splatted as keyword arguments;
    * tuple/list: splatted as positional arguments;
    * anything else: passed as the single positional argument.

    The first attempt adds ``use_cache=False``, plus
    ``output_router_logits=True`` when requested. Mapping batches then retry
    with only ``use_cache=False``. The last attempt passes no extras and is
    not guarded, so its errors propagate.
    """
    preferred = {"use_cache": False}
    if output_router_logits:
        preferred["output_router_logits"] = True

    if isinstance(batch, Mapping):
        positional, keywords = (), batch
        attempts = (preferred, {"use_cache": False})
    elif isinstance(batch, (tuple, list)):
        positional, keywords = tuple(batch), {}
        attempts = (preferred,)
    else:
        positional, keywords = (batch,), {}
        attempts = (preferred,)

    for extra in attempts:
        try:
            return model(*positional, **keywords, **extra)
        except TypeError:
            # Most likely the forward signature rejects one of the extras.
            continue
    return model(*positional, **keywords)
284
+
285
+
286
def _count_tokens(batch: Any) -> int:
    """Estimate how many tokens ``batch`` contributes to the ``max_tokens`` budget.

    * A mapping with a tensor ``input_ids`` counts ``input_ids.numel()``.
    * Otherwise the first tensor found is counted: a mapping value, a
      tuple/list item, or the batch itself.
    * Tensors with at most 2 dims count ``numel()``. Higher-rank tensors count
      ``shape[0] * shape[1]``, assuming a ``(batch, seq, ...)`` layout such as
      ``inputs_embeds`` (TODO: confirm for non-text inputs).
    * Batches without any tensor count as 0.

    The mapping fallback now uses the same rule as the tuple/list/tensor paths.
    It previously returned ``shape[0]`` (the batch size), which made the token
    budget far too loose for mapping batches without ``input_ids``.
    """

    def tensor_tokens(tensor: torch.Tensor) -> int:
        if tensor.ndim <= 2:
            return int(tensor.numel())
        return int(tensor.shape[0] * tensor.shape[1])

    if isinstance(batch, torch.Tensor):
        return tensor_tokens(batch)
    if isinstance(batch, Mapping):
        input_ids = batch.get("input_ids")
        if isinstance(input_ids, torch.Tensor):
            return int(input_ids.numel())
        candidates = batch.values()
    elif isinstance(batch, (tuple, list)):
        candidates = batch
    else:
        return 0

    for value in candidates:
        if isinstance(value, torch.Tensor):
            return tensor_tokens(value)
    return 0
304
+
305
+
306
def _safe_model_key(model_id: str) -> str:
    """Turn a model id into a lowercase, filesystem-friendly key.

    Forward slashes, backslashes, spaces and colons each become ``_``.
    """
    separators = str.maketrans({"/": "_", "\\": "_", " ": "_", ":": "_"})
    return model_id.translate(separators).lower()
314
+
315
+
316
def _infer_model_id(model: torch.nn.Module) -> str:
    """Return a best-effort human-readable id for an already-loaded model.

    The first truthy value wins, in this order: ``model.name_or_path``,
    ``model.config._name_or_path``, and finally the model's class name.
    """
    identifier = getattr(model, "name_or_path", None)
    if not identifier:
        config = getattr(model, "config", None)
        identifier = getattr(config, "_name_or_path", None)
    return str(identifier) if identifier else type(model).__name__
325
+
326
+
327
def _model_config_summary(model: torch.nn.Module) -> dict[str, Any]:
    """Collect a few architecture fields from ``model.config`` for metadata.

    Only fields present with a non-``None`` value are kept, in the fixed order
    listed below. Both HF-style (``hidden_size``) and GPT-2-style (``n_embd``)
    names are probed. A model without ``config`` yields ``{}``.
    """
    config = getattr(model, "config", None)
    if config is None:
        return {}

    wanted = (
        "model_type",
        "architectures",
        "hidden_size",
        "n_embd",
        "num_hidden_layers",
        "n_layer",
        "num_attention_heads",
        "n_head",
        "intermediate_size",
        "vocab_size",
    )
    return {
        name: value
        for name in wanted
        if (value := getattr(config, name, None)) is not None
    }
350
+
351
+
352
def _software_versions() -> dict[str, str | None]:
    """Snapshot interpreter and key library versions for run metadata.

    Installed-distribution lookups yield ``None`` when the package is absent.
    """
    versions: dict[str, str | None] = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "pandas": pd.__version__,
    }
    optional_distributions = (
        ("representation_geometry", "representation-geometry"),
        ("transformers", "transformers"),
    )
    for key, distribution in optional_distributions:
        versions[key] = _distribution_version(distribution)
    return versions
360
+
361
+
362
def _distribution_version(package: str) -> str | None:
    """Return the installed version of distribution ``package``, or ``None`` if absent."""
    try:
        found = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        found = None
    return found
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from representation_geometry.metrics import metric_from_eigenvalues
10
+ from representation_geometry.online import RunningCovariance
11
+
12
# Type alias for entries of DIAGNOSTIC_REGISTRY. Left as Any; analyze_model invokes
# each entry as fn(stats, model_key=..., model_label=...) and stores the returned
# DataFrame. NOTE(review): could be narrowed to Callable[..., pd.DataFrame].
DiagnosticFn = Any
13
+
14
+
15
def transform_activation(x: torch.Tensor, variant: str) -> torch.Tensor:
    """Apply a named normalization to an activation matrix (cast to float32).

    Variants:

    * ``raw``: unchanged.
    * ``token_l2``: each row scaled to unit L2 norm along the last dim.
    * ``token_rms``: each row divided by its RMS along the last dim
      (floored at 1e-12).
    * ``feature_standardized``: each column centered and divided by its
      standard deviation over dim 0 (floored at 1e-12).

    Raises
    ------
    ValueError
        For an unrecognized ``variant``.
    """

    def by_rms(values: torch.Tensor) -> torch.Tensor:
        rms = values.pow(2).mean(dim=-1, keepdim=True).sqrt().clamp(min=1e-12)
        return values / rms

    def by_feature(values: torch.Tensor) -> torch.Tensor:
        centered = values - values.mean(dim=0, keepdim=True)
        return centered / values.std(dim=0, keepdim=True).clamp(min=1e-12)

    transforms = {
        "raw": lambda values: values,
        "token_l2": lambda values: F.normalize(values, dim=-1),
        "token_rms": by_rms,
        "feature_standardized": by_feature,
    }
    as_float = x.float()
    chosen = transforms.get(variant)
    if chosen is None:
        raise ValueError(f"Unknown activation transform: {variant}")
    return chosen(as_float)
27
+
28
+
29
def spectrum_metrics_from_activations(x: torch.Tensor) -> dict[str, Any]:
    """Compute spectrum metrics of the column-centered activation matrix ``x``.

    Returns ``{}`` when ``x`` has fewer than two rows. Otherwise the per-column
    mean over rows is subtracted in float32. The squared singular values of
    the result are passed to ``metric_from_eigenvalues``; these are the top
    ``min(n, d)`` eigenvalues of ``Xc^T Xc``, with no ``1/(n-1)`` scaling.
    """
    if x.shape[0] < 2:
        return {}
    centered = x.float()
    centered = centered - centered.mean(dim=0, keepdim=True)
    squared = torch.linalg.svdvals(centered).pow(2).clamp_min(0)
    return metric_from_eigenvalues(squared.cpu().numpy())
37
+
38
+
39
def normalization_ablation_metrics(
    stats: dict[int, RunningCovariance],
    *,
    model_key: str,
    model_label: str,
    variants: tuple[str, ...] = ("raw", "token_l2", "token_rms", "feature_standardized"),
) -> pd.DataFrame:
    """Compare spectrum metrics of each layer's retained sample across normalizations.

    Layers are visited in ascending index order, and layers whose
    ``stat.sample()`` is empty are skipped. For each transform in
    ``variants``, spectrum metrics are computed on the transformed sample;
    variants that yield no metrics (fewer than two rows) produce no row. Each
    row holds the metric columns followed by model_key, model_label, layer,
    variant, tokens (sample rows) and hidden_dim (sample columns).
    """
    records: list[dict[str, Any]] = []
    for layer_idx in sorted(stats):
        sample = stats[layer_idx].sample()
        if sample.numel() == 0:
            continue
        for variant in variants:
            spectrum = spectrum_metrics_from_activations(transform_activation(sample, variant))
            if not spectrum:
                continue
            records.append(
                {
                    **spectrum,
                    "model_key": model_key,
                    "model_label": model_label,
                    "layer": layer_idx,
                    "variant": variant,
                    "tokens": int(sample.shape[0]),
                    "hidden_dim": int(sample.shape[1]),
                }
            )
    return pd.DataFrame(records)
68
+
69
+
70
# Diagnostic name -> callable. analyze_model rejects names missing from
# available_diagnostics() and calls each requested entry as
# fn(stats, model_key=..., model_label=...).
DIAGNOSTIC_REGISTRY: dict[str, DiagnosticFn] = {
    "normalization_ablation": normalization_ablation_metrics,
}
73
+
74
+
75
def available_diagnostics() -> tuple[str, ...]:
    """Return the registered diagnostic names in sorted order."""
    names = list(DIAGNOSTIC_REGISTRY)
    names.sort()
    return tuple(names)
@@ -0,0 +1,215 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping, Sequence
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+ from representation_geometry.online import RunningCovariance
10
+
11
+
12
class ResidualStatsStore:
    """Lazy per-layer streaming covariance store.

    A ``RunningCovariance`` is created for a layer index the first time that
    layer reports an activation. Its feature dimension is taken from the last
    axis of that first tensor.
    """

    def __init__(self, *, sample_limit: int = 512, dtype: torch.dtype = torch.float32):
        # Both settings are forwarded unchanged to every RunningCovariance.
        self.sample_limit = sample_limit
        self.dtype = dtype
        # Layer index -> accumulator, filled lazily by update().
        self.stats: dict[int, RunningCovariance] = {}

    @torch.no_grad()
    def update(self, layer_idx: int, value: Any) -> None:
        """Fold the first tensor found in ``value`` into layer ``layer_idx``'s stats.

        Values without a tensor, or whose first tensor has fewer than two
        dimensions, are ignored.
        """
        activations = first_tensor(value)
        if activations is None or activations.ndim < 2:
            return
        tracker = self.stats.get(layer_idx)
        if tracker is None:
            tracker = RunningCovariance(
                dim=int(activations.shape[-1]),
                dtype=self.dtype,
                sample_limit=self.sample_limit,
            )
            self.stats[layer_idx] = tracker
        tracker.update(activations)
31
+
32
+
33
@dataclass(frozen=True)
class HookTarget:
    """A module selected for activation capture by resolve_hook_targets."""

    # Key under which this module's activations land in ResidualStatsStore.stats.
    layer: int
    # Module the forward (pre-)hook is registered on.
    module: torch.nn.Module
    # Label reported in metadata["hook_targets"]: "block_<i>" or the module path.
    name: str
38
+
39
+
40
def resolve_hook_targets(
    model: torch.nn.Module,
    *,
    hook_point: str,
    layers: str | Sequence[int],
) -> tuple[list[HookTarget], dict[str, Any]]:
    """Resolve a public hook-point string into concrete PyTorch modules.

    ``"residual_input"`` / ``"residual_output"`` select the transformer blocks
    chosen by ``layers``. ``"module_input:<path>"`` / ``"module_output:<path>"``
    select one named submodule, recorded as layer 0, and require
    ``layers == "all"``.

    Returns
    -------
    tuple
        ``(targets, metadata)``. ``metadata`` holds ``num_blocks``, ``layers``
        and ``hook_targets``, plus ``module_capture`` for module hook points.

    Raises
    ------
    ValueError
        For an unknown hook point, an empty or unknown module path, or
        ``layers`` other than ``"all"`` with a module hook point. The
        unknown-path message lists at most 50 module names, so large models do
        not produce multi-thousand-line errors.
    """
    max_listed_modules = 50

    if hook_point in {"residual_input", "residual_output"}:
        blocks = get_transformer_blocks(model)
        layer_ids = parse_layers(layers, len(blocks))
        targets = [
            HookTarget(layer=layer_idx, module=blocks[layer_idx], name=f"block_{layer_idx}")
            for layer_idx in layer_ids
        ]
        return targets, {
            "num_blocks": len(blocks),
            "layers": layer_ids,
            "hook_targets": [target.name for target in targets],
        }

    module_prefixes = {
        "module_input:": "input",
        "module_output:": "output",
    }
    for prefix, capture in module_prefixes.items():
        if not hook_point.startswith(prefix):
            continue
        # isinstance guard: `layers != "all"` is ambiguous for array-like layers.
        if not (isinstance(layers, str) and layers == "all"):
            raise ValueError(
                "layers is only supported with residual_input and residual_output hooks."
            )
        module_name = hook_point.removeprefix(prefix)
        if not module_name:
            raise ValueError(f"Expected a module path after {prefix!r}.")
        modules = dict(model.named_modules())
        if module_name not in modules:
            names = [name for name in modules if name]
            available = ", ".join(names[:max_listed_modules]) or "<root only>"
            if len(names) > max_listed_modules:
                available += f", ... ({len(names) - max_listed_modules} more)"
            raise ValueError(
                f"Unknown module path {module_name!r}. "
                f"Available modules include: {available}"
            )
        target = HookTarget(layer=0, module=modules[module_name], name=module_name)
        return [target], {
            "num_blocks": None,
            "layers": [0],
            "hook_targets": [module_name],
            "module_capture": capture,
        }

    raise ValueError(
        "hook_point must be one of 'residual_input', 'residual_output', "
        "'module_input:<module_path>', or 'module_output:<module_path>'."
    )
93
+
94
+
95
def register_activation_hooks(
    targets: Sequence[HookTarget],
    store: ResidualStatsStore,
    *,
    hook_point: str,
) -> list[torch.utils.hooks.RemovableHandle]:
    """Attach one capture hook per target and return the removable handles.

    ``"residual_input"`` and ``"module_input:*"`` hook points use forward
    pre-hooks, which capture the first positional input. Every other hook
    point uses forward hooks, which capture the output. Handles are returned
    in target order.
    """
    wants_inputs = hook_point == "residual_input" or hook_point.startswith("module_input:")
    if wants_inputs:
        return [
            target.module.register_forward_pre_hook(_make_input_hook(target.layer, store))
            for target in targets
        ]
    return [
        target.module.register_forward_hook(_make_output_hook(target.layer, store))
        for target in targets
    ]
114
+
115
+
116
def register_residual_input_hooks(
    blocks: Sequence[torch.nn.Module],
    layers: Sequence[int],
    store: ResidualStatsStore,
) -> list[torch.utils.hooks.RemovableHandle]:
    """Backward-compatible wrapper for the original residual-input collector.

    Registers residual-input (forward pre-) hooks on every block whose index
    appears in ``layers``. Hooks follow block order, indices outside
    ``range(len(blocks))`` are ignored, and duplicate indices yield one hook.
    """

    # Build the lookup set once. Rebuilding it per block cost
    # O(len(blocks) * len(layers)), and a one-shot iterator passed as `layers`
    # was exhausted on the first block, so all later blocks were silently
    # skipped.
    wanted = set(layers)
    targets = [
        HookTarget(layer=idx, module=block, name=f"block_{idx}")
        for idx, block in enumerate(blocks)
        if idx in wanted
    ]
    return register_activation_hooks(targets, store, hook_point="residual_input")
129
+
130
+
131
def _make_input_hook(layer_idx: int, store: ResidualStatsStore):
    """Build a forward pre-hook that sends the module's first input to ``store``."""

    def record_first_input(module: torch.nn.Module, inputs: tuple[Any, ...]) -> None:
        # Modules called without positional inputs have nothing to record.
        # Returning None leaves the inputs unchanged.
        if not inputs:
            return
        store.update(layer_idx, inputs[0])

    return record_first_input
137
+
138
+
139
def _make_output_hook(layer_idx: int, store: ResidualStatsStore):
    """Build a forward hook that sends the module's output to ``store``."""

    def record_output(module: torch.nn.Module, inputs: tuple[Any, ...], output: Any) -> None:
        # Returning None leaves the module output unchanged.
        store.update(layer_idx, output)

    return record_output
144
+
145
+
146
def get_transformer_blocks(model: torch.nn.Module) -> list[torch.nn.Module]:
    """Locate common transformer block containers.

    Attribute paths are tried in order. The first path whose every attribute
    exists and is not ``None`` is materialized with ``list(...)``.

    Raises
    ------
    ValueError
        If none of the known paths resolves on ``model``.
    """

    candidate_paths = (
        ("transformer", "h"),
        ("model", "layers"),
        ("gpt_neox", "layers"),
        ("decoder", "layers"),
        # Nested decoder inside a *ForCausalLM wrapper (e.g. HF OPT).
        ("model", "decoder", "layers"),
        # Block list named "blocks" under .transformer (e.g. HF MPT).
        ("transformer", "blocks"),
        # Bare backbones that expose the block list directly.
        ("layers",),
        ("blocks",),
        ("h",),
    )
    for path in candidate_paths:
        node = model
        for attr in path:
            node = getattr(node, attr, None)
            if node is None:
                break
        # A None container would crash list(); treat it as "not this path".
        if node is not None:
            return list(node)

    tried = ", ".join("." + ".".join(path) for path in candidate_paths)
    raise ValueError(
        f"Could not locate transformer blocks for {type(model).__name__}. "
        f"Pass a model with one of: {tried}."
    )
169
+
170
+
171
def get_input_device(model: torch.nn.Module) -> torch.device:
    """Return the device of the model's first parameter that is not on ``meta``.

    If every parameter is on ``meta``, or the model has none, this falls back
    to ``cuda`` when available and ``cpu`` otherwise.
    """
    materialized = (param.device for param in model.parameters() if param.device.type != "meta")
    found = next(materialized, None)
    if found is not None:
        return found
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
176
+
177
+
178
def parse_layers(layers: str | Sequence[int], num_blocks: int) -> list[int]:
    """Turn a layer selection into a validated, de-duplicated list of block indices.

    ``"all"`` selects every block. Any other string is parsed as a
    comma-separated list (``"0, 3, 5"``); a sequence is converted with
    ``int``. Duplicates are dropped while keeping first-seen order, because
    each index becomes one hook and repeated hooks on one block would
    double-count its activations.

    Raises
    ------
    ValueError
        If any index is outside ``[0, num_blocks)``. String input is validated
        too, so an out-of-range spec fails here rather than as an IndexError
        later.
    """
    if isinstance(layers, str):
        if layers == "all":
            return list(range(num_blocks))
        requested = [int(part.strip()) for part in layers.split(",") if part.strip()]
    else:
        requested = [int(layer) for layer in layers]

    invalid = [layer for layer in requested if layer < 0 or layer >= num_blocks]
    if invalid:
        raise ValueError(f"Layer indices out of range for {num_blocks} blocks: {invalid}")
    return list(dict.fromkeys(requested))
188
+
189
+
190
def first_tensor(value: Any) -> torch.Tensor | None:
    """Depth-first search for the first tensor inside ``value``.

    Mapping values and tuple/list items are searched recursively in iteration
    order. Returns ``None`` when no tensor is found.
    """
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, Mapping):
        children = value.values()
    elif isinstance(value, (tuple, list)):
        children = value
    else:
        return None
    for child in children:
        found = first_tensor(child)
        if found is not None:
            return found
    return None
204
+
205
+
206
def move_to_device(value: Any, device: torch.device) -> Any:
    """Recursively move every tensor in ``value`` to ``device``.

    Mappings become plain ``dict`` objects. Tuples and lists keep their
    built-in container kind (tuple stays tuple, list stays list). Non-tensor
    leaves are returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, Mapping):
        return {key: move_to_device(item, device) for key, item in value.items()}
    if isinstance(value, (tuple, list)):
        moved = [move_to_device(item, device) for item in value]
        return tuple(moved) if isinstance(value, tuple) else moved
    return value