invarlock-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/eval/metrics.py
@@ -0,0 +1,2167 @@
"""
invarlock.metrics
=================

Enhanced diagnostic helpers used by the Phase-2 notebooks with improved
robustness, performance, and configurability.

Public entry point
------------------
>>> from invarlock.metrics import calculate_lens_metrics_for_model, MetricsConfig
>>> config = MetricsConfig(oracle_windows=32, max_tokens=512)
>>> metrics = calculate_lens_metrics_for_model(model, dataloader, config=config)
"""

from __future__ import annotations

import gc
import logging
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import psutil
import torch
import torch.nn as nn

from invarlock.core.error_utils import wrap_errors
from invarlock.core.exceptions import MetricsError, ValidationError

# ── Enhanced logging setup ─────────────────────────────────────────────────
logger = logging.getLogger(__name__)


try:  # Optional dependency: tqdm (progress bars)
    from tqdm.auto import tqdm as _tqdm
except Exception:  # pragma: no cover - exercised only when tqdm is absent

    class _TqdmShim:
        def __init__(self, iterable=None, total=None, **kwargs):
            self._iterable = iterable
            self.total = total

        def __iter__(self):
            if self._iterable is None:
                return iter(())
            return iter(self._iterable)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def update(self, n: int = 1) -> None:
            return None

    def _tqdm(iterable=None, *args, **kwargs):
        return _TqdmShim(iterable=iterable, **kwargs)


tqdm = _tqdm


class DependencyError(MetricsError):
    """Raised when required dependencies are missing."""

    pass


class ResourceError(MetricsError):
    """Raised when insufficient resources are available."""

    pass


## Note: Use ValidationError from invarlock.core.exceptions


def bootstrap_confidence_interval(
    samples: list[float] | np.ndarray,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
    statistic: callable = np.mean,
    random_state: np.random.Generator | None = None,
) -> tuple[float, float]:
    """
    Compute a bootstrap confidence interval for a 1D sample.

    Args:
        samples: 1D iterable of numeric samples.
        n_bootstrap: Number of bootstrap resamples.
        alpha: Significance level (0 < alpha < 1).
        statistic: Statistic function to apply to each resample.
        random_state: Optional numpy random generator for reproducibility.

    Returns:
        (lower, upper) confidence bounds.

    Raises:
        ValidationError(E402): For invalid inputs (shape/empty/range).
        MetricsError(E401): For compute/statistic failures during bootstrap.
    """
    data = np.asarray(samples, dtype=float)
    if data.ndim != 1:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={"reason": "samples must be 1-dimensional"},
        )
    if data.size == 0:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={"reason": "samples cannot be empty"},
        )
    if not 0.0 < alpha < 1.0:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={"reason": "alpha must be between 0 and 1", "alpha": alpha},
        )
    if n_bootstrap <= 0:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={
                "reason": "n_bootstrap must be positive",
                "n_bootstrap": n_bootstrap,
            },
        )

    with wrap_errors(MetricsError, "E401", "METRICS-COMPUTE-FAILED"):
        rng = random_state or np.random.default_rng()
        stats = np.empty(n_bootstrap, dtype=float)
        for i in range(n_bootstrap):
            indices = rng.integers(0, data.size, size=data.size)
            stats[i] = statistic(data[indices])

        lower = float(np.percentile(stats, 100 * (alpha / 2)))
        upper = float(np.percentile(stats, 100 * (1 - alpha / 2)))
        return lower, upper
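

# Editorial note (not part of the packaged file): a minimal usage sketch for
# bootstrap_confidence_interval. The sample values and the fixed seed below are
# hypothetical; a seeded Generator makes the interval reproducible.
def _example_bootstrap_ci() -> tuple[float, float]:
    samples = [0.91, 0.87, 0.93, 0.89, 0.90, 0.92]  # hypothetical per-window scores
    rng = np.random.default_rng(seed=0)
    # 95% percentile-bootstrap interval around the sample mean
    return bootstrap_confidence_interval(
        samples, n_bootstrap=1000, alpha=0.05, random_state=rng
    )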


@dataclass
class MetricsConfig:
    """Configuration for metrics calculation with sensible defaults."""

    # Core parameters
    oracle_windows: int = 16
    max_tokens: int = 256
    max_samples_per_layer: int = 25_000

    # Memory management
    auto_batch_size: bool = True
    memory_limit_gb: float | None = None
    cpu_fallback_threshold_gb: float = 0.5

    # Performance options
    use_cache: bool = True
    cache_dir: Path | None = None
    progress_bars: bool = True

    # Numerical stability
    clip_value: float = 1e3
    nan_replacement: float = 0.0
    inf_replacement: float = 1e4

    # Device management
    device: torch.device | None = None
    force_cpu: bool = False
    cleanup_after: bool = True

    # Validation options
    strict_validation: bool = True
    allow_empty_data: bool = False

    # Lens-specific parameters
    sigma_max_margin: float = 0.98
    mi_gini_subsample_ratio: float = 0.05
    head_energy_layers_filter: bool = True

    def __post_init__(self):
        """Validate configuration after initialization."""
        if self.oracle_windows < 0:
            raise ValidationError(
                code="E402",
                message="METRICS-VALIDATION-FAILED",
                details={"reason": "oracle_windows must be non-negative"},
            )
        if self.max_tokens <= 0:
            raise ValidationError(
                code="E402",
                message="METRICS-VALIDATION-FAILED",
                details={"reason": "max_tokens must be positive"},
            )
        if self.memory_limit_gb is not None and self.memory_limit_gb <= 0:
            raise ValidationError(
                code="E402",
                message="METRICS-VALIDATION-FAILED",
                details={"reason": "memory_limit_gb must be positive"},
            )

        # Set default cache directory
        if self.use_cache and self.cache_dir is None:
            self.cache_dir = Path.home() / ".cache" / "invarlock_metrics"
            self.cache_dir.mkdir(parents=True, exist_ok=True)
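

# Editorial note (not part of the packaged file): a small configuration sketch.
# use_cache=False avoids creating ~/.cache/invarlock_metrics as a side effect of
# __post_init__; the other values are illustrative, not recommendations.
def _example_metrics_config() -> MetricsConfig:
    return MetricsConfig(
        oracle_windows=32,    # more windows -> slower but less noisy estimates
        max_tokens=512,       # truncate each window to 512 tokens
        use_cache=False,      # skip the on-disk cache directory
        progress_bars=False,  # quiet mode for CI logs
    )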


class ResourceManager:
    """Manages computational resources and memory usage."""

    def __init__(self, config: MetricsConfig):
        self.config = config
        self.device = self._determine_device()
        self.memory_info = self._get_memory_info()

    def _determine_device(self) -> torch.device:
        """Determine the best device to use."""
        if self.config.force_cpu:
            return torch.device("cpu")

        if self.config.device is not None:
            return self.config.device

        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    def _get_memory_info(self) -> dict[str, float]:
        """Get current memory information."""
        info = {}

        # System memory
        vm = psutil.virtual_memory()
        info["system_total_gb"] = vm.total / (1024**3)
        info["system_available_gb"] = vm.available / (1024**3)

        # GPU memory
        if self.device.type == "cuda":
            info["gpu_total_gb"] = torch.cuda.get_device_properties(0).total_memory / (
                1024**3
            )
            info["gpu_free_gb"] = (
                torch.cuda.get_device_properties(0).total_memory
                - torch.cuda.memory_allocated()
            ) / (1024**3)

        return info

    def estimate_memory_usage(
        self, model: nn.Module, batch_size: int, seq_length: int
    ) -> float:
        """Estimate memory usage in GB for given parameters."""
        # Model parameters
        param_memory = sum(p.numel() * p.element_size() for p in model.parameters()) / (
            1024**3
        )

        # Activation memory (rough estimate)
        if hasattr(model, "config"):
            hidden_size = getattr(
                model.config, "n_embd", getattr(model.config, "hidden_size", 768)
            )
            num_layers = getattr(
                model.config, "n_layer", getattr(model.config, "num_hidden_layers", 12)
            )
            activation_memory = (
                batch_size * seq_length * hidden_size * num_layers * 4
            ) / (1024**3)
        else:
            activation_memory = param_memory * 2  # Conservative estimate

        return param_memory + activation_memory

    def should_use_cpu_fallback(self, estimated_memory_gb: float) -> bool:
        """Determine if CPU fallback should be used."""
        if self.device.type == "cpu":
            return False

        available_memory = self.memory_info.get(
            "gpu_free_gb", self.memory_info.get("system_available_gb", 8.0)
        )

        return estimated_memory_gb > (
            available_memory - self.config.cpu_fallback_threshold_gb
        )

    def cleanup(self):
        """Clean up GPU memory."""
        if self.config.cleanup_after:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
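

# Editorial note (not part of the packaged file): a rough illustration of the
# memory estimate above. For a GPT-2-small-sized model (~124M float32
# parameters) the parameter term is roughly 0.46 GB; with hidden_size=768,
# 12 layers, batch 8 and seq_length 256, the activation term is
# 8*256*768*12*4 bytes, about 0.07 GB.
def _example_memory_estimate(model: nn.Module) -> float:
    rm = ResourceManager(MetricsConfig(force_cpu=True, use_cache=False))
    return rm.estimate_memory_usage(model, batch_size=8, seq_length=256)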


# ── Enhanced dependency management ─────────────────────────────────────────
class DependencyManager:
    """Manages optional dependencies with graceful degradation."""

    def __init__(self):
        self.available_modules: dict[str, Any] = {}
        self.missing_modules: list[tuple[str, str]] = []
        self._check_dependencies()

    def _check_dependencies(self):
        """Check availability of optional dependencies."""
        # Check lens2_mi
        try:
            from .lens2_mi import mi_scores

            self.available_modules["mi_scores"] = mi_scores
            logger.info("✓ lens2_mi module available")
        except ImportError as e:
            self.missing_modules.append(("lens2_mi", str(e)))
            logger.warning("✗ lens2_mi module not available - MI-Gini will be NaN")

        # Check lens3
        try:
            from .lens3 import scan_model_gains

            self.available_modules["scan_model_gains"] = scan_model_gains
            logger.info("✓ lens3 module available")
        except ImportError as e:
            self.missing_modules.append(("lens3", str(e)))
            logger.warning("✗ lens3 module not available - σ_max will be NaN")

    def get_module(self, name: str):
        """Get a module if available, otherwise raise DependencyError."""
        if name in self.available_modules:
            return self.available_modules[name]
        raise DependencyError(
            code="E203",
            message=f"DEPENDENCY-MISSING: module {name} is not available",
            details={"module": name},
        )

    def is_available(self, name: str) -> bool:
        """Check if a module is available."""
        return name in self.available_modules

    def get_missing_dependencies(self) -> list[tuple[str, str]]:
        """Get list of missing dependencies with error messages."""
        return self.missing_modules.copy()


# ── Input validation ───────────────────────────────────────────────────────
class InputValidator:
    """Validates inputs for metrics calculation."""

    @staticmethod
    def validate_model(model: nn.Module, config: MetricsConfig) -> None:
        """Validate model input."""
        if not isinstance(model, nn.Module):
            raise ValidationError(
                code="E402",
                message="METRICS-VALIDATION-FAILED",
                details={"reason": f"Expected nn.Module, got {type(model)}"},
            )

        # Check if model has parameters
        try:
            param_count = sum(1 for _ in model.parameters())
            if param_count == 0:
                if config.strict_validation:
                    raise ValidationError(
                        code="E402",
                        message="METRICS-VALIDATION-FAILED",
                        details={"reason": "Model has no parameters"},
                    )
                else:
                    logger.warning("Model has no parameters")
        except Exception as e:
            logger.debug(f"Could not count model parameters: {e}")

    @staticmethod
    def validate_dataloader(dataloader, config: MetricsConfig) -> None:
        """Validate dataloader input."""
        if dataloader is None:
            raise ValidationError(
                code="E402",
                message="METRICS-VALIDATION-FAILED",
                details={"reason": "Dataloader cannot be None"},
            )

        # Check if dataloader has data
        try:
            first_batch = next(iter(dataloader))
            if not first_batch:
                if not config.allow_empty_data:
                    raise ValidationError(
                        code="E402",
                        message="METRICS-VALIDATION-FAILED",
                        details={"reason": "Dataloader is empty"},
                    )
                else:
                    logger.warning("Dataloader is empty")
        except StopIteration as e:
            if not config.allow_empty_data:
                raise ValidationError(
                    code="E402",
                    message="METRICS-VALIDATION-FAILED",
                    details={"reason": "Dataloader is empty"},
                ) from e
            else:
                logger.warning("Dataloader is empty")

    @staticmethod
    def validate_tensor(
        tensor: torch.Tensor, name: str, config: MetricsConfig
    ) -> torch.Tensor:
        """Validate and sanitize tensor."""
        if not isinstance(tensor, torch.Tensor):
            raise ValidationError(
                code="E402",
                message="METRICS-VALIDATION-FAILED",
                details={"reason": f"{name} must be a tensor, got {type(tensor)}"},
            )

        # Check for NaN/Inf
        if torch.isnan(tensor).any():
            if config.strict_validation:
                raise ValidationError(
                    code="E402",
                    message="METRICS-VALIDATION-FAILED",
                    details={"reason": f"{name} contains NaN values"},
                )
            else:
                logger.warning(
                    f"{name} contains NaN values, replacing with {config.nan_replacement}"
                )
                tensor = torch.nan_to_num(tensor, nan=config.nan_replacement)

        if torch.isinf(tensor).any():
            if config.strict_validation:
                raise ValidationError(
                    code="E402",
                    message="METRICS-VALIDATION-FAILED",
                    details={"reason": f"{name} contains Inf values"},
                )
            else:
                logger.warning(
                    f"{name} contains Inf values, replacing with ±{config.inf_replacement}"
                )
                tensor = torch.nan_to_num(
                    tensor,
                    posinf=config.inf_replacement,
                    neginf=-config.inf_replacement,
                )

        return tensor


# ── Enhanced helper functions ──────────────────────────────────────────────
def _gini_vectorized(vec: torch.Tensor) -> float:
    """Optimized Gini coefficient calculation."""
    flat = vec.flatten().abs().float()
    if flat.numel() == 0 or torch.sum(flat) == 0:
        return float("nan")

    # Use more efficient sorting and cumsum
    sorted_vals = torch.sort(flat)[0]
    n = sorted_vals.numel()

    # Vectorized Gini calculation
    indices = torch.arange(1, n + 1, dtype=torch.float32, device=flat.device)
    gini = (2 * torch.sum(indices * sorted_vals) / torch.sum(sorted_vals) - (n + 1)) / n

    return gini.item()
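

# Editorial note (not part of the packaged file): two tiny worked examples of
# the Gini formula above. For [1, 1, 1, 1] the coefficient is 0 (perfectly
# even); for [0, 0, 0, 1] it is (2*4/1 - 5)/4 = 0.75 (highly concentrated).
def _example_gini_values() -> tuple[float, float]:
    even = _gini_vectorized(torch.tensor([1.0, 1.0, 1.0, 1.0]))
    concentrated = _gini_vectorized(torch.tensor([0.0, 0.0, 0.0, 1.0]))
    return even, concentrated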


def _mi_gini_optimized_cpu_path(
    feats_cpu: torch.Tensor,
    targ_cpu: torch.Tensor,
    max_per_layer: int,
    config: MetricsConfig,
) -> float:
    """Optimized MI Gini calculation on CPU with better memory management."""
    L, N, _ = feats_cpu.shape

    # Subsample if dataset is too large
    if N > max_per_layer:
        sel = torch.randperm(N)[:max_per_layer]
        feats_cpu = feats_cpu[:, sel, :]
        targ_cpu = targ_cpu[sel]

    # Get MI function
    dep_manager = DependencyManager()
    if not dep_manager.is_available("mi_scores"):
        return float("nan")

    mi_scores_fn = dep_manager.get_module("mi_scores")

    # Process in chunks to manage memory
    chunk_size = min(8, L)  # Process 8 layers at a time
    mi_scores_all = []

    progress_desc = "MI-Gini (CPU optimized)"
    with tqdm(
        total=L, desc=progress_desc, disable=not config.progress_bars, leave=False
    ) as pbar:
        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
            chunk_feats = feats_cpu[i:end_idx]

            # Vectorized processing for the chunk
            chunk_scores = []
            for j in range(chunk_feats.shape[0]):
                try:
                    score = mi_scores_fn(chunk_feats[j], targ_cpu)
                    chunk_scores.append(score)
                except Exception as e:
                    logger.warning(f"MI calculation failed for layer {i + j}: {e}")
                    chunk_scores.append(torch.zeros_like(chunk_feats[j, 0, :]))

            mi_scores_all.extend(chunk_scores)
            pbar.update(end_idx - i)

    if not mi_scores_all:
        return float("nan")

    try:
        mi_mat = torch.stack(mi_scores_all)
        return _gini_vectorized(mi_mat)
    except Exception as e:
        logger.warning(f"Failed to stack MI scores: {e}")
        return float("nan")


def _locate_transformer_blocks_enhanced(model: nn.Module) -> list[nn.Module] | None:
    """Enhanced transformer block detection with better model support."""

    # Standard GPT2 patterns - safer approach
    def safe_getattr_chain(obj, *attrs):
        """Safely get nested attributes."""
        for attr in attrs:
            if obj is None:
                return None
            obj = getattr(obj, attr, None)
        return obj

    patterns = [
        lambda m: safe_getattr_chain(m, "transformer", "h"),
        lambda m: safe_getattr_chain(m, "h"),  # Bare GPT2Model
        lambda m: safe_getattr_chain(m, "base_model", "h"),  # Common wrappers
        lambda m: safe_getattr_chain(m, "model", "h"),  # Some wrappers
        lambda m: safe_getattr_chain(m, "transformer", "layers"),  # Alternative naming
    ]

    for pattern in patterns:
        try:
            blocks = pattern(model)
            if blocks is not None and hasattr(blocks, "__len__") and len(blocks) > 0:
                logger.debug(f"Found {len(blocks)} transformer blocks using pattern")
                return list(blocks)
        except (AttributeError, TypeError):
            continue

    # Fallback: search for transformer-like modules
    transformer_modules = []
    for name, module in model.named_modules():
        if any(attr in name.lower() for attr in ["block", "layer", "transformer"]):
            if hasattr(module, "attn") and hasattr(module, "mlp"):
                transformer_modules.append(module)

    if transformer_modules:
        logger.debug(
            f"Found {len(transformer_modules)} transformer blocks via fallback search"
        )
        return transformer_modules

    logger.warning("Could not locate transformer blocks in model")
    return None
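

# Editorial note (not part of the packaged file): a sketch of the detector on a
# Hugging Face GPT-2 checkpoint. Assumes the optional `transformers` package is
# installed; for GPT2LMHeadModel the "transformer.h" pattern should match and
# return its 12 decoder blocks.
def _example_locate_blocks() -> int:
    from transformers import GPT2LMHeadModel  # local import: optional dependency

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    blocks = _locate_transformer_blocks_enhanced(model) or []
    return len(blocks)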


# ── Result caching ─────────────────────────────────────────────────────────
class ResultCache:
    """Simple result caching for expensive operations."""

    def __init__(self, config: MetricsConfig):
        self.config = config
        self.cache: dict[str, dict[str, float]] = {}
        self.enabled = config.use_cache

    def _get_cache_key(
        self, model: nn.Module, dataloader, config: MetricsConfig
    ) -> str:
        """Generate cache key for model and data."""
        # Simple hash based on model parameters and config
        model_hash = hash(tuple(p.data_ptr() for p in model.parameters()))
        config_hash = hash(
            (config.oracle_windows, config.max_tokens, config.max_samples_per_layer)
        )
        return f"{model_hash}_{config_hash}"

    def get(self, key: str) -> dict[str, float] | None:
        """Get cached result."""
        if not self.enabled:
            return None
        return self.cache.get(key)

    def set(self, key: str, result: dict[str, float]) -> None:
        """Cache result."""
        if self.enabled:
            self.cache[key] = result.copy()

    def clear(self) -> None:
        """Clear cache."""
        self.cache.clear()


# ── Main metrics calculation function ──────────────────────────────────────
@torch.no_grad()
def calculate_lens_metrics_for_model(
    model: nn.Module,
    dataloader,
    *,
    config: MetricsConfig | None = None,
    oracle_windows: int | None = None,  # Backward compatibility
    device: torch.device | None = None,  # Backward compatibility
) -> dict[str, float]:
    """
    Calculate comprehensive lens metrics for a model with enhanced robustness.

    Args:
        model: The neural network model to analyze
        dataloader: DataLoader providing input data
        config: MetricsConfig object with all parameters
        oracle_windows: (deprecated) Number of windows to process
        device: (deprecated) Device to use for computation

    Returns:
        Dictionary containing calculated metrics

    Raises:
        MetricsError: If calculation fails due to various reasons
    """
    # Handle backward compatibility
    if config is None:
        config = MetricsConfig()
    if oracle_windows is not None:
        config.oracle_windows = oracle_windows
    if device is not None:
        config.device = device

    # Initialize managers
    dep_manager = DependencyManager()
    resource_manager = ResourceManager(config)
    validator = InputValidator()
    cache = ResultCache(config)

    # Validate inputs
    validator.validate_model(model, config)
    validator.validate_dataloader(dataloader, config)

    # Check cache
    cache_key = cache._get_cache_key(model, dataloader, config)
    cached_result = cache.get(cache_key)
    if cached_result is not None:
        logger.info("Using cached metrics result")
        return cached_result

    start_time = time.time()
    logger.info(
        f"Starting metrics calculation with config: oracle_windows={config.oracle_windows}, "
        f"max_tokens={config.max_tokens}, device={resource_manager.device}"
    )

    # Pre-evaluation checks
    try:
        _perform_pre_eval_checks(model, dataloader, resource_manager.device, config)
    except Exception as e:
        logger.warning(f"Pre-evaluation checks failed: {e}")

    # Unwrap common wrappers if present
    if hasattr(model, "base_model"):
        try:
            model = model.base_model
        except Exception:
            pass

    model.eval()
    device = resource_manager.device

    # Initialize results
    results = {
        "sigma_max": float("nan"),
        "head_energy": float("nan"),
        "mi_gini": float("nan"),
    }

    skipped_metrics: list[str] = []

    try:
        # Collect activations with progress tracking
        logger.info("Collecting model activations...")
        activation_data = _collect_activations(model, dataloader, config, device)

        if not activation_data["hidden_states"]:
            logger.warning("No activations collected - returning default values")
            return _finalize_results(
                results, skipped_metrics, cache, cache_key, start_time
            )

        # Calculate each metric
        results["sigma_max"] = _calculate_sigma_max(
            model, activation_data["first_batch"], dep_manager, config, device
        )

        results["head_energy"] = _calculate_head_energy(
            activation_data["hidden_states"], config
        )

        results["mi_gini"] = _calculate_mi_gini(
            model, activation_data, dep_manager, config, device
        )

    except Exception as e:
        logger.error(f"Metrics calculation failed: {e}")
        if config.strict_validation:
            raise MetricsError(f"Metrics calculation failed: {e}") from e

    finally:
        resource_manager.cleanup()

    return _finalize_results(results, skipped_metrics, cache, cache_key, start_time)
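

# Editorial note (not part of the packaged file): a usage sketch mirroring the
# module docstring. `model` and `dataloader` are assumed to be a causal LM and
# a DataLoader yielding dicts with "input_ids"; each metric falls back to NaN
# when its optional dependency or data is unavailable.
def _example_lens_metrics(model: nn.Module, dataloader) -> dict[str, float]:
    cfg = MetricsConfig(oracle_windows=8, max_tokens=128, use_cache=False)
    metrics = calculate_lens_metrics_for_model(model, dataloader, config=cfg)
    for name in ("sigma_max", "head_energy", "mi_gini"):
        if math.isnan(metrics[name]):
            logger.info("metric %s unavailable for this run", name)
    return metrics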


def _perform_pre_eval_checks(
    model: nn.Module, dataloader, device: torch.device, config: MetricsConfig
) -> None:
    """Perform pre-evaluation sanity checks."""
    # Check model context length vs data
    try:
        tok_len_attr = getattr(model.config, "n_positions", None) or getattr(
            model.config, "max_position_embeddings", None
        )
        if tok_len_attr:
            sample_batch = next(iter(dataloader))
            sample_ids = sample_batch["input_ids"]
            if sample_ids.shape[1] > tok_len_attr:
                logger.warning(
                    f"Input sequence length {sample_ids.shape[1]} exceeds "
                    f"model limit {tok_len_attr}"
                )
    except Exception as e:
        logger.debug(f"Context length check failed: {e}")

    # Dry run forward pass
    try:
        dry_batch = next(iter(dataloader))
        model_input = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in dry_batch.items()
        }
        _ = model(**model_input)
        logger.debug("Pre-evaluation dry run successful")
    except Exception as e:
        logger.warning(f"Pre-evaluation dry run failed: {e}")


def _collect_activations(
    model: nn.Module, dataloader, config: MetricsConfig, device: torch.device
) -> dict[str, Any]:
    """Collect model activations with enhanced error handling."""
    hidden_states_list = []
    fc1_activations_list = []
    targets_list = []
    first_batch = None

    # Progress tracking
    total_batches = (
        min(config.oracle_windows, len(dataloader))
        if hasattr(dataloader, "__len__")
        else config.oracle_windows
    )

    with tqdm(
        total=total_batches,
        desc="Collecting activations",
        disable=not config.progress_bars,
    ) as pbar:
        for i, batch in enumerate(dataloader):
            if i >= config.oracle_windows:
                break

            try:
                # Store first batch for later use
                if first_batch is None:
                    first_batch = {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()
                    }

                # Move batch to device
                input_ids = batch["input_ids"].to(device)

                # Limit sequence length
                if input_ids.shape[1] > config.max_tokens:
                    input_ids = input_ids[:, : config.max_tokens]

                # Forward pass with hidden states
                output = model(input_ids, output_hidden_states=True)

                # Collect hidden states (exclude first and last)
                if hasattr(output, "hidden_states") and len(output.hidden_states) > 2:
                    hidden_states = torch.stack(output.hidden_states[1:-1])
                    hidden_states = validator.validate_tensor(
                        hidden_states, f"hidden_states_batch_{i}", config
                    )
                    hidden_states_list.append(hidden_states)

                # Collect FC1 activations for MI-Gini
                fc1_acts = _extract_fc1_activations(model, output, config)
                if fc1_acts is not None:
                    fc1_activations_list.append(fc1_acts)
                    targets_list.append(
                        input_ids[:, 1:]
                    )  # Shifted for next-token prediction

                pbar.update(1)

            except Exception as e:
                logger.warning(f"Failed to process batch {i}: {e}")
                continue

    return {
        "hidden_states": hidden_states_list,
        "fc1_activations": fc1_activations_list,
        "targets": targets_list,
        "first_batch": first_batch,
    }


def _extract_fc1_activations(
    model: nn.Module, output, config: MetricsConfig
) -> torch.Tensor | None:
    """Extract FC1 activations for MI-Gini calculation."""
    blocks = _locate_transformer_blocks_enhanced(model)
    if blocks is None:
        return None

    try:
        valid_activations = []
        for idx, block in enumerate(blocks):
            if hasattr(block, "mlp") and hasattr(block.mlp, "c_fc"):
                try:
                    # Get hidden state for this layer
                    if (
                        hasattr(output, "hidden_states")
                        and len(output.hidden_states) > idx + 1
                    ):
                        hidden_state = output.hidden_states[idx + 1]
                        activation = block.mlp.c_fc(hidden_state)
                        activation = validator.validate_tensor(
                            activation, f"fc1_activation_{idx}", config
                        )
                        valid_activations.append(activation)
                except Exception as e:
                    logger.debug(
                        f"Failed to extract FC1 activation for block {idx}: {e}"
                    )
                    continue

        if valid_activations:
            # Check for consistent shapes
            shapes = [act.shape for act in valid_activations]
            if len(set(shapes)) > 1:
                logger.warning(f"Inconsistent FC1 activation shapes: {set(shapes)}")
                # Use most common shape
                from collections import Counter

                most_common_shape = Counter(shapes).most_common(1)[0][0]
                valid_activations = [
                    act for act in valid_activations if act.shape == most_common_shape
                ]

            return torch.stack(valid_activations)

    except Exception as e:
        logger.warning(f"FC1 activation extraction failed: {e}")

    return None


def _calculate_sigma_max(
    model: nn.Module,
    first_batch: dict | None,
    dep_manager: DependencyManager,
    config: MetricsConfig,
    device: torch.device,
) -> float:
    """Calculate sigma_max metric via Lens-3."""
    if not dep_manager.is_available("scan_model_gains"):
        logger.info("Skipping σ_max: scan_model_gains not available")
        return float("nan")

    if first_batch is None:
        logger.info("Skipping σ_max: no data batch available")
        return float("nan")

    try:
        scan_model_gains = dep_manager.get_module("scan_model_gains")
        gains_df = scan_model_gains(model, first_batch)

        if gains_df is None:
            logger.warning("scan_model_gains returned None")
            return float("nan")

        # Filter out embedding and head layers if possible
        if hasattr(gains_df, "columns") and "name" in gains_df.columns:
            mask = ~gains_df["name"].str.contains(
                "embed|lm_head", case=False, regex=True
            )
            filtered_gains = gains_df[mask]
        else:
            logger.info("Could not filter layers by name for σ_max")
            filtered_gains = gains_df

        if len(filtered_gains) == 0:
            logger.warning("No valid layers found for σ_max computation")
            return float("nan")

        # Extract gains
        gains_values = getattr(
            filtered_gains, "gain", getattr(filtered_gains, "values", [])
        )
        gains_tensor = torch.as_tensor(gains_values, dtype=torch.float32, device=device)

        if gains_tensor.numel() == 0:
            logger.warning("No gain values found")
            return float("nan")

        # Validate and get max
        gains_tensor = validator.validate_tensor(
            gains_tensor, "sigma_max_gains", config
        )
        finite_mask = torch.isfinite(gains_tensor)

        if not finite_mask.any():
            logger.warning("All σ_max gains are NaN/Inf")
            return float("nan")

        sigma_max = torch.max(gains_tensor[finite_mask]).item()
        logger.debug(f"Calculated σ_max: {sigma_max:.4f}")
        return sigma_max

    except Exception as e:
        logger.warning(f"σ_max calculation failed: {e}")
        return float("nan")


def _calculate_head_energy(
    hidden_states_list: list[torch.Tensor], config: MetricsConfig
) -> float:
    """Calculate head energy metric (mean squared activation per layer)."""
    if not hidden_states_list:
        logger.info("Skipping head energy: no hidden states available")
        return float("nan")

    try:
        # Concatenate all hidden states: [L, N, T, D]
        hidden_stack = torch.cat(hidden_states_list, dim=1)

        # Crop to max_tokens
        hidden_crop = hidden_stack[:, :, : config.max_tokens, :]

        # Sanitize
        hidden_crop = validator.validate_tensor(
            hidden_crop, "head_energy_hidden_states", config
        )

        # Calculate mean squared activation per layer
        squared_activations = hidden_crop.float().pow(2).mean(dim=-1)  # [L, N, T]
        per_layer_energy = squared_activations.mean(dim=(1, 2))  # [L]

        # Filter finite values
        finite_mask = torch.isfinite(per_layer_energy)
        if not finite_mask.any():
            logger.warning("All head energies are NaN/Inf")
            return float("nan")

        head_energy = per_layer_energy[finite_mask].mean().item()
        logger.debug(f"Calculated head energy: {head_energy:.6f}")
        return head_energy

    except Exception as e:
        logger.warning(f"Head energy calculation failed: {e}")
        return float("nan")


def _calculate_mi_gini(
    model: nn.Module,
    activation_data: dict[str, Any],
    dep_manager: DependencyManager,
    config: MetricsConfig,
    device: torch.device,
) -> float:
    """Calculate MI-based Gini coefficient."""
    if not dep_manager.is_available("mi_scores"):
        logger.info("Skipping MI-Gini: mi_scores not available")
        return float("nan")

    if not activation_data["fc1_activations"] or not activation_data["targets"]:
        logger.info("Skipping MI-Gini: no FC1 activations available")
        return float("nan")

    try:
        # Concatenate activations and targets
        fc1_all = torch.cat(activation_data["fc1_activations"], dim=1)  # [L, N, T, D]
        targ_all = torch.cat(activation_data["targets"], dim=0)  # [N, T]

        # Trim to align dimensions (remove last token from activations)
        fc1_trim = fc1_all[:, :, :-1, :]  # [L, N, T-1, D]

        # Crop to max_tokens
        fc1_trim = fc1_trim[:, :, : config.max_tokens, :]
        targ_trim = targ_all[:, : config.max_tokens]

        # Reshape for MI calculation
        L, N, T, D = fc1_trim.shape
        fc1_flat = fc1_trim.permute(0, 2, 1, 3).reshape(L, -1, D)  # [L, N*T, D]
        targ_flat = targ_trim.flatten()  # [N*T]

        # Validate tensors
        fc1_flat = InputValidator.validate_tensor(fc1_flat, "mi_gini_features", config)
        targ_flat = InputValidator.validate_tensor(targ_flat, "mi_gini_targets", config)

        # Get MI scores function
        mi_scores_fn = dep_manager.get_module("mi_scores")

        # Try GPU calculation first
        try:
            logger.debug("Attempting MI-Gini calculation on GPU")
            mi_scores_result = mi_scores_fn(fc1_flat, targ_flat)
            mi_gini = _gini_vectorized(mi_scores_result)
            logger.debug(f"Calculated MI-Gini (GPU): {mi_gini:.6f}")
            return mi_gini

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.warning("GPU OOM for MI-Gini, falling back to CPU")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # CPU fallback with subsampling
                mi_gini = _mi_gini_optimized_cpu_path(
                    fc1_flat.cpu().float(),
                    targ_flat.cpu(),
                    config.max_samples_per_layer,
                    config,
                )
                logger.debug(f"Calculated MI-Gini (CPU): {mi_gini:.6f}")
                return mi_gini
            else:
                raise

    except Exception as e:
        logger.warning(f"MI-Gini calculation failed: {e}")
        return float("nan")


def _finalize_results(
    results: dict[str, Any],
    skipped_metrics: list[str],
    cache: ResultCache,
    cache_key: str,
    start_time: float,
) -> dict[str, float]:
    """Finalize and validate results."""
    # Ensure all values are finite or NaN
    for key, value in results.items():
        if not isinstance(value, int | float):
            logger.warning(
                f"Metric {key} has invalid type {type(value)}, setting to NaN"
            )
            results[key] = float("nan")
        elif not (math.isnan(value) or math.isfinite(value)):
            logger.warning(f"Metric {key} is infinite, setting to NaN")
            results[key] = float("nan")

    # Log skipped metrics
    if skipped_metrics:
        logger.info(f"Skipped metrics: {', '.join(skipped_metrics)}")

    # Cache results
    cache.set(cache_key, results)

    # Log completion
    elapsed = time.time() - start_time
    logger.info(f"Metrics calculation completed in {elapsed:.2f}s: {results}")

    return results


# ── Backward compatibility functions ──────────────────────────────────────
def _gini(vec: torch.Tensor) -> float:
    """Legacy Gini function for backward compatibility."""
    return _gini_vectorized(vec)


def _mi_gini_cpu_safe_path(
    feats_cpu: torch.Tensor, targ_cpu: torch.Tensor, max_per_layer: int
) -> float:
    """Legacy CPU MI-Gini function for backward compatibility."""
    config = MetricsConfig(max_samples_per_layer=max_per_layer, progress_bars=True)
    return _mi_gini_optimized_cpu_path(feats_cpu, targ_cpu, max_per_layer, config)


def _locate_transformer_blocks(model: nn.Module) -> list[nn.Module] | None:
    """Legacy transformer block locator for backward compatibility."""
    return _locate_transformer_blocks_enhanced(model)


# ── Additional utility functions ───────────────────────────────────────────
def get_metrics_info() -> dict[str, Any]:
    """Get information about available metrics and dependencies."""
    dep_manager = DependencyManager()

    return {
        "available_metrics": ["sigma_max", "head_energy", "mi_gini"],
        "available_dependencies": list(dep_manager.available_modules.keys()),
        "missing_dependencies": dep_manager.get_missing_dependencies(),
        "default_config": MetricsConfig().__dict__,
    }


def validate_metrics_environment() -> bool:
    """Validate that the metrics environment is properly set up."""
    try:
        dep_manager = DependencyManager()
        MetricsConfig()

        # Check basic dependencies

        logger.info("✓ Basic dependencies available")

        # Check optional dependencies
        available_count = len(dep_manager.available_modules)
        total_count = available_count + len(dep_manager.missing_modules)

        logger.info(
            f"✓ {available_count}/{total_count} optional dependencies available"
        )

        if dep_manager.missing_modules:
            logger.warning("Some optional dependencies are missing:")
            for name, error in dep_manager.missing_modules:
                logger.warning(f"  - {name}: {error}")

        return True

    except Exception as e:
        logger.error(f"Environment validation failed: {e}")
        return False


# ── Import necessary modules for validation ────────────────────────────────
# Note: math is already imported at top of file

# Global validator instance for use in helper functions
validator = InputValidator()


# ── Perplexity validation ──────────────────────────────────────────────────
class PerplexityStatus:
    """Quality status levels for ppl-like primary metrics (perplexity)."""

    EXCELLENT = "excellent"  # < 50
    GOOD = "good"  # 50-100
    ACCEPTABLE = "acceptable"  # 100-200
    POOR = "poor"  # 200-500
    UNUSABLE = "unusable"  # > 500

    @classmethod
    def from_value(cls, ppl: float, vocab_size: int | None = None) -> str:
        """Get status from perplexity value."""
        if ppl < 50:
            return cls.EXCELLENT
        elif ppl < 100:
            return cls.GOOD
        elif ppl < 200:
            return cls.ACCEPTABLE
        elif ppl < 500:
            return cls.POOR
        else:
            return cls.UNUSABLE


def validate_perplexity(
    ppl: float,
    vocab_size: int | None = None,
    context: str = "evaluation",
    warn_threshold: float = 200.0,
    error_threshold: float = 2000.0,
    allow_high: bool = False,
) -> tuple[bool, str, str]:
    """
    Validate perplexity value and provide feedback.

    Args:
        ppl: Perplexity value to validate
        vocab_size: Vocabulary size for context-aware validation
        context: Context string for error messages
        warn_threshold: Threshold for warning (default 200)
        error_threshold: Threshold for error (default 2000)
        allow_high: Allow high perplexity values (for testing)

    Returns:
        Tuple of (is_valid, status, message)
    """
    # Check for invalid values
    if math.isnan(ppl) or math.isinf(ppl):
        return False, "invalid", f"Perplexity is {ppl}"

    if ppl < 1.0:
        return False, "invalid", f"Perplexity {ppl:.2f} is less than 1.0"

    # Get status
    status = PerplexityStatus.from_value(ppl, vocab_size)

    # Adjust thresholds based on vocab size if provided
    if vocab_size is not None:
        # For untrained models, ppl-like PM ≈ vocab_size is expected
        # Adjust thresholds accordingly
        warn_threshold = max(warn_threshold, vocab_size * 0.5)
        error_threshold = max(error_threshold, vocab_size * 2.0)

    # Generate message based on status
    if ppl > error_threshold and not allow_high:
        message = (
            f"Perplexity {ppl:.1f} exceeds error threshold {error_threshold:.0f} "
            f"in {context}. Model appears to be untrained or corrupted."
        )
        return False, status, message

    elif ppl > warn_threshold:
        message = (
            f"Perplexity {ppl:.1f} exceeds warning threshold {warn_threshold:.0f} "
            f"in {context}. Model may be severely degraded."
        )
        if not allow_high:
            logger.warning(message)
        return True, status, message

    elif status == PerplexityStatus.POOR:
        message = f"Perplexity {ppl:.1f} indicates poor model quality in {context}."
        logger.info(message)
        return True, status, message

    elif status == PerplexityStatus.ACCEPTABLE:
        message = f"Perplexity {ppl:.1f} is acceptable for {context}."
        return True, status, message

    else:
        message = f"Perplexity {ppl:.1f} is {status} for {context}."
        return True, status, message
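

# Editorial note (not part of the packaged file): a small sketch of the
# validator. A perplexity of 35 on a trained GPT-2-sized model maps to
# "excellent"; passing vocab_size (50,257 for GPT-2) relaxes the thresholds so
# a freshly initialised model (ppl ≈ vocab_size) is reported rather than
# rejected outright.
def _example_validate_perplexity() -> tuple[bool, str, str]:
    ok, status, _ = validate_perplexity(35.2, context="smoke test")
    assert ok and status == PerplexityStatus.EXCELLENT
    return validate_perplexity(48_000.0, vocab_size=50_257, context="untrained baseline")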


# ── Helper function for robust forward pass ────────────────────────────────
def _forward_loss_causal(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    labels: torch.Tensor | None = None,
) -> tuple[float, torch.Tensor | None]:
    """
    Robust forward that handles HF ModelOutput or tuple, computes loss if needed.
    Returns (loss_value: float, logits: torch.Tensor or None).
    """
    import torch.nn.functional as F

    # 1) Prefer dict-style outputs
    try:
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        # If we got a ModelOutput, use it
        if hasattr(outputs, "loss") and outputs.loss is not None:
            return float(outputs.loss.detach().cpu()), getattr(outputs, "logits", None)
        logits = getattr(outputs, "logits", None)
    except (TypeError, AttributeError):
        # Some stub models/tests may not accept return_dict
        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        if isinstance(outputs, tuple | list):
            # If labels were provided, many HF models put loss first, logits second
            if (
                labels is not None
                and len(outputs) >= 2
                and torch.is_tensor(outputs[0])
                and outputs[0].ndim == 0
            ):
                return float(outputs[0].detach().cpu()), outputs[1] if len(
                    outputs
                ) > 1 else None
            # Otherwise first is logits
            logits = outputs[0] if len(outputs) > 0 else None
        else:
            # Custom object: try attributes
            maybe_loss = getattr(outputs, "loss", None)
            maybe_logits = getattr(outputs, "logits", None)
            if maybe_loss is not None:
                return float(maybe_loss.detach().cpu()), maybe_logits
            logits = maybe_logits

    # 2) If we're here, we have logits but no loss → compute it manually
    if logits is None:
        raise MetricsError(
            code="E401",
            message="METRICS-COMPUTE-FAILED: model returned neither loss nor logits",
        )

    if labels is None:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={"reason": "labels are required to compute perplexity loss"},
        )

    # Causal LM shift
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="mean",
    )
    return float(loss.detach().cpu()), logits
1341
|
+
|
|
1342
|
+
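# Illustrative usage sketch (editor addition, not part of the original module):
# exercises the helper above with a minimal stand-in model that only returns a
# tuple of logits, forcing the manual cross-entropy path. `_TinyLM` is hypothetical.
def _example_forward_loss_causal() -> None:  # pragma: no cover - illustration only
    import torch
    from torch import nn

    class _TinyLM(nn.Module):
        def __init__(self, vocab: int = 17) -> None:
            super().__init__()
            self.emb = nn.Embedding(vocab, 8)
            self.head = nn.Linear(8, vocab)

        def forward(self, input_ids, attention_mask=None, labels=None):
            # No return_dict kwarg: the helper falls back to the tuple branch.
            return (self.head(self.emb(input_ids)),)

    ids = torch.randint(0, 17, (2, 6))
    loss, logits = _forward_loss_causal(_TinyLM(), ids, labels=ids)
    print(loss, logits.shape)  # scalar loss, logits of shape (2, 6, 17)
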
def _resolve_eval_device(
    model: nn.Module, device: str | torch.device | None
) -> torch.device:
    """
    Resolve evaluation device with graceful MPS fallback.

    If MPS is requested but unavailable (common in CI or non‑MacOS builds),
    fall back to CPU instead of raising at tensor .to(device) calls.
    """
    if device is None:
        try:
            resolved = next(model.parameters()).device
        except StopIteration:
            resolved = torch.device("cpu")
    else:
        resolved = torch.device(device) if isinstance(device, str) else device

    # Handle MPS when backend is not actually usable
    try:
        if isinstance(resolved, torch.device) and resolved.type == "mps":
            mps_backend = getattr(torch.backends, "mps", None)
            is_available = bool(
                mps_backend is not None
                and hasattr(mps_backend, "is_available")
                and mps_backend.is_available()
            )
            if not is_available:
                logger.warning(
                    "Requested device 'mps' for metrics evaluation but MPS backend "
                    "is not available; falling back to CPU."
                )
                resolved = torch.device("cpu")
    except Exception:
        # On any introspection failure, be conservative and fall back to CPU
        resolved = torch.device("cpu")

    return resolved

# ── Perplexity calculation ─────────────────────────────────────────────────
@torch.no_grad()
def calculate_perplexity(
    model: nn.Module,
    dataloader,
    max_batches: int = 100,
    device: str | torch.device | None = None,
) -> float:
    """
    DEPRECATED: Use compute_perplexity for new code.
    This is an alias for backward compatibility with tests.
    """
    return compute_perplexity(model, dataloader, max_samples=max_batches, device=device)


@torch.no_grad()
def compute_perplexity_strict(
    model: nn.Module, dataloader, device: str | torch.device | None = None
) -> float:
    """
    Compute perplexity with strict token-level accounting.

    Args:
        model: Language model to evaluate
        dataloader: DataLoader providing input sequences
        device: Device to use for computation

    Returns:
        Perplexity value

    Raises:
        ValueError: If no valid tokens found for perplexity computation
    """
    device = _resolve_eval_device(model, device)

    model.eval()
    nll_sum = 0.0
    tok_count = 0

    for batch in dataloader:
        # Handle different batch formats
        if isinstance(batch, dict):
            input_ids = batch.get("input_ids", batch.get("inputs", None))
            labels = batch.get("labels", None)
            attention_mask = batch.get("attention_mask", None)
            token_type_ids = batch.get("token_type_ids", None)
        elif isinstance(batch, tuple | list):
            input_ids = batch[0] if len(batch) > 0 else None
            labels = batch[1] if len(batch) > 1 else None
            attention_mask = batch[2] if len(batch) > 2 else None
            token_type_ids = batch[3] if len(batch) > 3 else None
        else:
            input_ids = batch
            labels = None
            attention_mask = None
            token_type_ids = None

        if input_ids is None or not isinstance(input_ids, torch.Tensor):
            continue

        input_ids = input_ids.to(device)
        attn = attention_mask.to(device) if attention_mask is not None else None
        token_type_ids_t = (
            token_type_ids.to(device) if token_type_ids is not None else None
        )

        # Default causal labels
        if labels is None:
            labels = input_ids.clone()
            if attn is not None:
                labels[attn == 0] = -100
        else:
            labels = labels.to(device)

        # Skip if sequence too short
        if input_ids.size(1) < 2:
            continue

        is_masked_lm = hasattr(model, "config") and getattr(
            model.config, "model_type", ""
        ) in {"bert", "roberta", "distilbert", "albert"}

        if is_masked_lm:
            masked_labels = labels.clone()
            if attn is not None:
                masked_labels = masked_labels.masked_fill(attn == 0, -100)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attn,
                token_type_ids=token_type_ids_t,
                labels=masked_labels,
                return_dict=True,
            )
            loss = outputs.loss
            if loss is None:
                continue
            valid_tokens = int((masked_labels != -100).sum().item())
            if valid_tokens == 0:
                continue
            nll_sum += float(loss.item()) * valid_tokens
            tok_count += valid_tokens
            continue

        # Forward (don't trust .loss, compute ourselves)
        try:
            outputs = model(input_ids=input_ids, attention_mask=attn, return_dict=True)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        except Exception:
            # Fallback for non-standard models
            outputs = model(input_ids=input_ids, attention_mask=attn)
            if isinstance(outputs, tuple | list):
                logits = outputs[0]
            else:
                logits = outputs.logits if hasattr(outputs, "logits") else outputs

        # Causal shift
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        shift_mask = attn[:, 1:] if attn is not None else None

        valid = shift_labels != -100
        if shift_mask is not None:
            valid = valid & shift_mask.bool()

        if not valid.any():
            continue

        log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
        nll = -log_probs.gather(-1, tgt).squeeze(-1)  # [B,T-1]

        nll_sum += nll[valid].sum().item()
        tok_count += int(valid.sum().item())

    if tok_count == 0:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={
                "reason": "No valid tokens for perplexity (all masked or seq_len<=1)."
            },
        )

    return float(torch.exp(torch.tensor(nll_sum / tok_count)))

@torch.no_grad()
def compute_perplexity(
    model: nn.Module,
    dataloader,
    max_samples: int = 100,
    device: str | torch.device | None = None,
) -> float:
    """
    Compute perplexity of a language model on a dataset.

    ALWAYS uses strict token-level accounting to avoid padding issues.

    Args:
        model: Language model to evaluate
        dataloader: DataLoader providing input sequences
        max_samples: Maximum number of batches to evaluate
        device: Device to use for computation

    Returns:
        Perplexity value

    Raises:
        ValueError: If no valid tokens found
    """
    device = _resolve_eval_device(model, device)

    model.eval()
    nll_sum = 0.0
    tok_count = 0
    batch_count = 0

    for i, batch in enumerate(dataloader):
        # Check max_samples limit
        if max_samples is not None and i >= max_samples:
            break

        # Handle different batch formats
        if isinstance(batch, dict):
            input_ids = batch.get("input_ids", batch.get("inputs", None))
            labels = batch.get("labels", None)
            attention_mask = batch.get("attention_mask", None)
        elif isinstance(batch, tuple | list):
            input_ids = batch[0] if len(batch) > 0 else None
            labels = batch[1] if len(batch) > 1 else None
            attention_mask = batch[2] if len(batch) > 2 else None
        else:
            input_ids = batch
            labels = None
            attention_mask = None

        if input_ids is None or not isinstance(input_ids, torch.Tensor):
            continue

        input_ids = input_ids.to(device)
        attn = attention_mask.to(device) if attention_mask is not None else None

        # Default causal labels
        if labels is None:
            labels = input_ids.clone()
            if attn is not None:
                labels[attn == 0] = -100
        else:
            labels = labels.to(device)

        # Skip if sequence too short
        if input_ids.size(1) < 2:
            continue

        # Forward pass - get logits
        try:
            outputs = model(input_ids=input_ids, attention_mask=attn, return_dict=True)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        except Exception:
            # Fallback for non-standard models
            outputs = model(input_ids=input_ids, attention_mask=attn)
            if isinstance(outputs, tuple | list):
                logits = outputs[0]
            else:
                logits = outputs.logits if hasattr(outputs, "logits") else outputs

        # Causal shift for next-token prediction
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        shift_mask = attn[:, 1:] if attn is not None else None

        # Identify valid (non-padding) tokens
        valid = shift_labels != -100
        if shift_mask is not None:
            valid = valid & shift_mask.bool()

        if not valid.any():
            continue

        # Compute negative log-likelihood
        log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]

        # MPS workaround: gather operation can fail on MPS, use CPU fallback
        if str(device).startswith("mps"):
            log_probs_cpu = log_probs.cpu()
            tgt_cpu = tgt.cpu()
            nll_cpu = -log_probs_cpu.gather(-1, tgt_cpu).squeeze(-1)
            nll = nll_cpu.to(device)
        else:
            nll = -log_probs.gather(-1, tgt).squeeze(-1)  # [B,T-1]

        # Accumulate only for valid tokens
        nll_sum += nll[valid].sum().item()
        tok_count += int(valid.sum().item())
        batch_count += 1

    if tok_count == 0:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={
                "reason": (
                    f"No valid tokens for perplexity computation after {batch_count} batches. "
                    "All tokens were either padding or sequences were too short (<=1 token). "
                    "Ensure your data contains sequences of at least 2 tokens."
                )
            },
        )

    # Compute perplexity from average NLL
    avg_nll = nll_sum / tok_count
    ppl = float(math.exp(avg_nll))

    # Sanity check
    if ppl < 1.0:
        logger.warning(
            f"Computed perplexity {ppl:.2f} is less than 1.0, setting to 1.0"
        )
        ppl = 1.0
    elif not math.isfinite(ppl):
        logger.warning(f"Computed perplexity is not finite: {ppl}")
        ppl = float("inf")

    return ppl

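# Illustrative usage sketch (editor addition, not part of the original module):
# compute_perplexity only iterates its `dataloader` argument, so any iterable of
# dict batches works. The "gpt2" checkpoint and the batch construction below are
# assumptions for illustration; they require transformers and network access.
def _example_compute_perplexity() -> None:  # pragma: no cover - illustration only
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    enc = tok(["hello world", "perplexity check"], return_tensors="pt", padding=True)
    batches = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}]

    ppl = compute_perplexity(model, batches, max_samples=10, device="cpu")
    print(f"ppl={ppl:.2f}")
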
# ── New Unified Evaluation Functions ──────────────────────────────────────


@torch.no_grad()
def compute_ppl(
    model: nn.Module,
    adapter: Any | None,
    window: Any,  # EvaluationWindow
    device: str | torch.device | None = None,
) -> float:
    """
    Compute perplexity for a specific evaluation window.

    This is the new unified evaluation function that works with EvaluationWindow objects
    from the data loading system.

    Args:
        model: Language model to evaluate
        adapter: Model adapter (unused currently, for future extensibility)
        window: EvaluationWindow with tokenized samples
        device: Device to use for computation

    Returns:
        Perplexity value for the window
    """
    device = _resolve_eval_device(model, device)

    model.eval()
    nll_sum = 0.0
    tok_count = 0

    # Process each sample in the window
    for input_ids, attention_mask in zip(
        window.input_ids, window.attention_masks, strict=False
    ):
        if not input_ids:
            continue

        # Convert to tensors
        input_ids_tensor = (
            torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
        )
        attention_mask_tensor = (
            torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
        )

        # Skip sequences that are too short
        if input_ids_tensor.size(1) < 2:
            continue

        # Forward pass
        try:
            outputs = model(
                input_ids=input_ids_tensor,
                attention_mask=attention_mask_tensor,
                return_dict=True,
            )
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        except Exception:
            # Fallback for non-standard models
            outputs = model(
                input_ids=input_ids_tensor, attention_mask=attention_mask_tensor
            )
            if isinstance(outputs, tuple | list):
                logits = outputs[0]
            else:
                logits = outputs.logits if hasattr(outputs, "logits") else outputs

        # Causal shift for next-token prediction
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids_tensor[:, 1:]
        shift_mask = attention_mask_tensor[:, 1:]

        # Identify valid (non-padding) tokens
        valid = (shift_labels != -100) & shift_mask.bool()

        if not valid.any():
            continue

        # Compute negative log-likelihood
        log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]

        # Handle MPS device issues with gather
        if str(device).startswith("mps"):
            log_probs_cpu = log_probs.cpu()
            tgt_cpu = tgt.cpu()
            nll_cpu = -log_probs_cpu.gather(-1, tgt_cpu).squeeze(-1)
            nll = nll_cpu.to(device)
        else:
            nll = -log_probs.gather(-1, tgt).squeeze(-1)  # [B,T-1]

        # Accumulate only for valid tokens
        nll_sum += nll[valid].sum().item()
        tok_count += int(valid.sum().item())

    if tok_count == 0:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={
                "reason": "No valid tokens for perplexity computation in evaluation window",
            },
        )

    # Compute perplexity from average NLL
    avg_nll = nll_sum / tok_count
    ppl = float(math.exp(avg_nll))

    # Sanity check
    if ppl < 1.0:
        logger.warning(
            f"Computed perplexity {ppl:.2f} is less than 1.0, setting to 1.0"
        )
        ppl = 1.0
    elif not math.isfinite(ppl):
        logger.warning(f"Computed perplexity is not finite: {ppl}")
        ppl = float("inf")

    return ppl

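# Illustrative usage sketch (editor addition, not part of the original module):
# compute_ppl only reads `input_ids` and `attention_masks` (lists of token-id lists)
# from the window, so a SimpleNamespace can stand in for EvaluationWindow here. The
# token ids and the "gpt2" checkpoint are illustrative assumptions.
def _example_compute_ppl() -> None:  # pragma: no cover - illustration only
    from types import SimpleNamespace

    from transformers import AutoModelForCausalLM

    window = SimpleNamespace(
        input_ids=[[464, 3290, 318, 257, 1332], [15496, 995]],
        attention_masks=[[1, 1, 1, 1, 1], [1, 1]],
    )
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ppl = compute_ppl(model, adapter=None, window=window, device="cpu")
    print(f"window ppl={ppl:.2f}")
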
def measure_latency(
    model: nn.Module,
    window: Any,  # EvaluationWindow
    device: str | torch.device | None = None,
    warmup_steps: int = 3,
    measurement_steps: int = 10,
) -> float:
    """
    Measure inference latency per token.

    Args:
        model: Model to measure
        window: EvaluationWindow with samples to use for measurement
        device: Device to use for measurement
        warmup_steps: Number of warmup iterations
        measurement_steps: Number of measurement iterations

    Returns:
        Average latency in milliseconds per token
    """
    if device is None:
        device = next(model.parameters()).device
    else:
        device = torch.device(device) if isinstance(device, str) else device

    model.eval()

    # Select a representative sample for timing
    if not window.input_ids:
        return 0.0

    # Use the first valid sample
    sample_input_ids = None
    sample_attention_mask = None

    for input_ids, attention_mask in zip(
        window.input_ids, window.attention_masks, strict=False
    ):
        if len(input_ids) > 10:  # Ensure reasonable length
            sample_input_ids = (
                torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
            )
            sample_attention_mask = (
                torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
            )
            break

    if sample_input_ids is None:
        return 0.0

    # Warmup
    with torch.no_grad():
        for _ in range(warmup_steps):
            try:
                _ = model(
                    input_ids=sample_input_ids, attention_mask=sample_attention_mask
                )
            except Exception:
                # If there are issues with the model, return 0
                return 0.0

    # Synchronize for accurate timing
    if device.type == "cuda":
        torch.cuda.synchronize()

    # Measure latency
    start_time = time.time()

    with torch.no_grad():
        for _ in range(measurement_steps):
            _ = model(input_ids=sample_input_ids, attention_mask=sample_attention_mask)

    if device.type == "cuda":
        torch.cuda.synchronize()

    end_time = time.time()

    # Calculate per-token latency
    total_time_ms = (end_time - start_time) * 1000  # Convert to milliseconds
    total_tokens = int(sample_attention_mask.sum().item()) * measurement_steps

    if total_tokens == 0:
        return 0.0

    latency_ms_per_token = total_time_ms / total_tokens

    logger.debug(
        f"Measured latency: {latency_ms_per_token:.3f} ms/token over {measurement_steps} steps"
    )
    return latency_ms_per_token

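# Illustrative usage sketch (editor addition, not part of the original module):
# measure_latency skips samples of 10 tokens or fewer, so the stand-in window below
# carries a 12-token sample. The window object and token ids are hypothetical.
def _example_measure_latency() -> None:  # pragma: no cover - illustration only
    from types import SimpleNamespace

    from transformers import AutoModelForCausalLM

    ids = list(range(40, 52))  # 12 tokens, above the >10 length check
    window = SimpleNamespace(input_ids=[ids], attention_masks=[[1] * len(ids)])
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ms_per_token = measure_latency(
        model, window, device="cpu", warmup_steps=1, measurement_steps=3
    )
    print(f"latency: {ms_per_token:.3f} ms/token")
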
def measure_memory(
    model: nn.Module,
    window: Any,  # EvaluationWindow
    device: str | torch.device | None = None,
) -> float:
    """
    Measure peak memory usage during inference.

    Args:
        model: Model to measure
        window: EvaluationWindow with samples to use for measurement
        device: Device to measure memory on

    Returns:
        Peak memory usage in MB
    """
    if device is None:
        device = next(model.parameters()).device
    else:
        device = torch.device(device) if isinstance(device, str) else device

    model.eval()

    # Get baseline memory
    if device.type == "cuda":
        torch.cuda.empty_cache()
        baseline_memory = torch.cuda.memory_allocated() / (1024 * 1024)
        torch.cuda.reset_peak_memory_stats()
    else:
        # For CPU/MPS, use psutil for system memory
        import psutil

        process = psutil.Process()
        baseline_memory = process.memory_info().rss / (1024 * 1024)

    # Run inference on a few samples to measure memory
    max_memory = baseline_memory

    with torch.no_grad():
        for i, (input_ids, attention_mask) in enumerate(
            zip(window.input_ids, window.attention_masks, strict=False)
        ):
            if i >= 5:  # Only measure on first 5 samples
                break

            if not input_ids:
                continue

            try:
                input_ids_tensor = (
                    torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
                )
                attention_mask_tensor = (
                    torch.tensor(attention_mask, dtype=torch.long)
                    .unsqueeze(0)
                    .to(device)
                )

                _ = model(
                    input_ids=input_ids_tensor, attention_mask=attention_mask_tensor
                )

                # Measure memory after forward pass
                if device.type == "cuda":
                    current_memory = torch.cuda.memory_allocated() / (1024 * 1024)
                else:
                    current_memory = process.memory_info().rss / (1024 * 1024)

                max_memory = max(max_memory, current_memory)

            except Exception as e:
                logger.debug(f"Memory measurement failed for sample {i}: {e}")
                continue

    peak_memory_mb = max_memory
    logger.debug(f"Peak memory usage: {peak_memory_mb:.1f} MB")

    return peak_memory_mb

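# Illustrative usage sketch (editor addition, not part of the original module):
# peak-memory probe on the same kind of stand-in window; on CPU this path reports
# process RSS via psutil, which must be installed. The window object is hypothetical.
def _example_measure_memory() -> None:  # pragma: no cover - illustration only
    from types import SimpleNamespace

    from transformers import AutoModelForCausalLM

    window = SimpleNamespace(input_ids=[list(range(40, 60))], attention_masks=[[1] * 20])
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    peak_mb = measure_memory(model, window, device="cpu")
    print(f"peak memory: {peak_mb:.1f} MB")
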
def compute_parameter_deltas(
    model_before: nn.Module, model_after: nn.Module, adapter: Any | None = None
) -> dict[str, Any]:
    """
    Compute precise parameter deltas between before and after models.

    Args:
        model_before: Model state before edit
        model_after: Model state after edit
        adapter: Model adapter for architecture-specific analysis

    Returns:
        Dictionary with parameter delta information:
        - params_changed: Number of parameters that were modified
        - layers_modified: Number of layers that were changed
        - sparsity: Overall sparsity ratio (if applicable)
    """
    deltas = {
        "params_changed": 0,
        "layers_modified": 0,
        "sparsity": None,
    }

    try:
        # Compare parameters
        before_params = dict(model_before.named_parameters())
        after_params = dict(model_after.named_parameters())

        modified_layers = set()
        total_changed = 0

        for name, before_param in before_params.items():
            if name not in after_params:
                continue

            after_param = after_params[name]

            # Check if parameter changed
            if not torch.equal(before_param.data, after_param.data):
                total_changed += before_param.numel()

                # Extract layer information from parameter name
                layer_match = None
                if ".h." in name or ".layers." in name:
                    # Extract layer number for transformer models
                    import re

                    match = re.search(r"\.(?:h|layers)\.(\d+)\.", name)
                    if match:
                        layer_match = int(match.group(1))
                        modified_layers.add(layer_match)

        deltas["params_changed"] = total_changed
        deltas["layers_modified"] = len(modified_layers)

        # Structural deltas (like head/neuron counts) are not tracked in this profile

        # Compute overall sparsity if applicable
        total_params_before = sum(p.numel() for p in model_before.parameters())
        total_params_after = sum(p.numel() for p in model_after.parameters())

        if total_params_after < total_params_before:
            deltas["sparsity"] = 1.0 - (total_params_after / total_params_before)

    except Exception as e:
        logger.warning(f"Parameter delta computation failed: {e}")

    return deltas

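# Illustrative usage sketch (editor addition, not part of the original module):
# comparing a model against an edited copy. The zeroing of one GPT-2 MLP weight is
# only a toy "edit" to make the delta counters non-trivial.
def _example_compute_parameter_deltas() -> None:  # pragma: no cover - illustration only
    import copy

    import torch
    from transformers import AutoModelForCausalLM

    before = AutoModelForCausalLM.from_pretrained("gpt2")
    after = copy.deepcopy(before)
    with torch.no_grad():
        after.transformer.h[0].mlp.c_fc.weight.zero_()  # toy edit to layer 0

    deltas = compute_parameter_deltas(before, after)
    print(deltas["params_changed"], deltas["layers_modified"])
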
def analyze_spectral_changes(
    model_before: nn.Module, model_after: nn.Module, scope: str = "ffn"
) -> dict[str, Any]:
    """
    Analyze spectral norm changes between model states.

    Args:
        model_before: Model before edit
        model_after: Model after edit
        scope: Scope for spectral analysis ("ffn", "all")

    Returns:
        Dictionary with spectral analysis results
    """
    try:
        # Import spectral analysis if available
        from invarlock.guards.spectral import compute_spectral_norms

        before_norms = compute_spectral_norms(model_before, scope=scope)
        after_norms = compute_spectral_norms(model_after, scope=scope)

        # Compute changes
        changes = {}
        for layer_name in before_norms:
            if layer_name in after_norms:
                before_norm = before_norms[layer_name]
                after_norm = after_norms[layer_name]
                change_ratio = after_norm / before_norm if before_norm > 0 else 1.0
                changes[layer_name] = {
                    "before": before_norm,
                    "after": after_norm,
                    "ratio": change_ratio,
                }

        # Summary statistics
        ratios = [change["ratio"] for change in changes.values()]
        summary = {
            "layer_changes": changes,
            "mean_ratio": float(np.mean(ratios)) if ratios else 1.0,
            "max_ratio": float(np.max(ratios)) if ratios else 1.0,
            "min_ratio": float(np.min(ratios)) if ratios else 1.0,
            "layers_analyzed": len(changes),
        }

        return summary

    except ImportError:
        logger.debug("Spectral analysis not available")
        return {"error": "spectral_analysis_unavailable"}
    except Exception as e:
        logger.warning(f"Spectral analysis failed: {e}")
        return {"error": str(e)}

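# Illustrative usage sketch (editor addition, not part of the original module):
# the spectral comparison degrades gracefully when invarlock.guards.spectral is not
# importable, so callers can branch on the "error" key as shown here.
def _example_analyze_spectral_changes() -> None:  # pragma: no cover - illustration only
    import copy

    from transformers import AutoModelForCausalLM

    before = AutoModelForCausalLM.from_pretrained("gpt2")
    after = copy.deepcopy(before)
    result = analyze_spectral_changes(before, after, scope="ffn")
    if "error" in result:
        print(f"spectral analysis unavailable: {result['error']}")
    else:
        print(f"mean ratio {result['mean_ratio']:.3f} over {result['layers_analyzed']} layers")
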
def analyze_rmt_changes(
    model_before: nn.Module, model_after: nn.Module
) -> dict[str, Any]:
    """
    Analyze RMT (Random Matrix Theory) changes between model states.

    Args:
        model_before: Model before edit
        model_after: Model after edit

    Returns:
        Dictionary with RMT analysis results
    """
    try:
        # Import RMT analysis if available
        from invarlock.guards.rmt import compute_mp_stats

        before_stats = compute_mp_stats(model_before)
        after_stats = compute_mp_stats(model_after)

        # Analyze changes in MP statistics
        changes = {}
        for layer_name in before_stats:
            if layer_name in after_stats:
                before_mp = before_stats[layer_name]
                after_mp = after_stats[layer_name]
                changes[layer_name] = {
                    "before": before_mp,
                    "after": after_mp,
                    "stable": abs(before_mp - after_mp) < 0.1,  # Stability threshold
                }

        # Count stable vs unstable layers
        stable_count = sum(
            1 for change in changes.values() if change.get("stable", False)
        )
        total_count = len(changes)

        summary = {
            "layer_changes": changes,
            "stable_layers": stable_count,
            "total_layers": total_count,
            "stability_ratio": stable_count / total_count if total_count > 0 else 0.0,
        }

        return summary

    except ImportError:
        logger.debug("RMT analysis not available")
        return {"error": "rmt_analysis_unavailable"}
    except Exception as e:
        logger.warning(f"RMT analysis failed: {e}")
        return {"error": str(e)}


# ── Integration with existing system ───────────────────────────────────────

# Update exports to include new functions (add to existing __all__ if it exists)
try:
    __all__.extend(
        [
            "bootstrap_confidence_interval",
            "compute_ppl",
            "measure_latency",
            "measure_memory",
            "compute_parameter_deltas",
            "analyze_spectral_changes",
            "analyze_rmt_changes",
        ]
    )
except NameError:
    # If __all__ doesn't exist, create it with the new functions
    __all__ = [
        "bootstrap_confidence_interval",
        "compute_ppl",
        "measure_latency",
        "measure_memory",
        "compute_parameter_deltas",
        "analyze_spectral_changes",
        "analyze_rmt_changes",
    ]