invarlock-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2167 @@
1
+ """
2
+ invarlock.metrics
3
+ =================
4
+
5
+ Enhanced diagnostic helpers used by the Phase-2 notebooks with improved
6
+ robustness, performance, and configurability.
7
+
8
+ Public entry point
9
+ ------------------
10
+ >>> from invarlock.metrics import calculate_lens_metrics_for_model, MetricsConfig
11
+ >>> config = MetricsConfig(oracle_windows=32, max_tokens=512)
12
+ >>> metrics = calculate_lens_metrics_for_model(model, dataloader, config=config)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import gc
18
+ import logging
19
+ import math
20
+ import time
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from collections.abc import Callable
+ from typing import Any
24
+
25
+ import numpy as np
26
+ import psutil
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from invarlock.core.error_utils import wrap_errors
31
+ from invarlock.core.exceptions import MetricsError, ValidationError
32
+
33
+ # ── Enhanced logging setup ─────────────────────────────────────────────────
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ try: # Optional dependency: tqdm (progress bars)
38
+ from tqdm.auto import tqdm as _tqdm
39
+ except Exception: # pragma: no cover - exercised only when tqdm is absent
40
+
41
+ class _TqdmShim:
42
+ def __init__(self, iterable=None, total=None, **kwargs):
43
+ self._iterable = iterable
44
+ self.total = total
45
+
46
+ def __iter__(self):
47
+ if self._iterable is None:
48
+ return iter(())
49
+ return iter(self._iterable)
50
+
51
+ def __enter__(self):
52
+ return self
53
+
54
+ def __exit__(self, exc_type, exc, tb):
55
+ return False
56
+
57
+ def update(self, n: int = 1) -> None:
58
+ return None
59
+
60
+ def _tqdm(iterable=None, *args, **kwargs):
61
+ return _TqdmShim(iterable=iterable, **kwargs)
62
+
63
+
64
+ tqdm = _tqdm
65
+
66
+
67
+ class DependencyError(MetricsError):
68
+ """Raised when required dependencies are missing."""
69
+
70
+ pass
71
+
72
+
73
+ class ResourceError(MetricsError):
74
+ """Raised when insufficient resources are available."""
75
+
76
+ pass
77
+
78
+
79
+ ## Note: Use ValidationError from invarlock.core.exceptions
80
+
81
+
82
+ def bootstrap_confidence_interval(
83
+ samples: list[float] | np.ndarray,
84
+ n_bootstrap: int = 500,
85
+ alpha: float = 0.05,
86
+ statistic: Callable[[np.ndarray], float] = np.mean,
87
+ random_state: np.random.Generator | None = None,
88
+ ) -> tuple[float, float]:
89
+ """
90
+ Compute a bootstrap confidence interval for a 1D sample.
91
+
92
+ Args:
93
+ samples: 1D iterable of numeric samples.
94
+ n_bootstrap: Number of bootstrap resamples.
95
+ alpha: Significance level (0 < alpha < 1).
96
+ statistic: Statistic function to apply to each resample.
97
+ random_state: Optional numpy random generator for reproducibility.
98
+
99
+ Returns:
100
+ (lower, upper) confidence bounds.
101
+
102
+ Raises:
103
+ ValidationError(E402): For invalid inputs (shape/empty/range).
104
+ MetricsError(E401): For compute/statistic failures during bootstrap.
105
+ """
106
+ data = np.asarray(samples, dtype=float)
107
+ if data.ndim != 1:
108
+ raise ValidationError(
109
+ code="E402",
110
+ message="METRICS-VALIDATION-FAILED",
111
+ details={"reason": "samples must be 1-dimensional"},
112
+ )
113
+ if data.size == 0:
114
+ raise ValidationError(
115
+ code="E402",
116
+ message="METRICS-VALIDATION-FAILED",
117
+ details={"reason": "samples cannot be empty"},
118
+ )
119
+ if not 0.0 < alpha < 1.0:
120
+ raise ValidationError(
121
+ code="E402",
122
+ message="METRICS-VALIDATION-FAILED",
123
+ details={"reason": "alpha must be between 0 and 1", "alpha": alpha},
124
+ )
125
+ if n_bootstrap <= 0:
126
+ raise ValidationError(
127
+ code="E402",
128
+ message="METRICS-VALIDATION-FAILED",
129
+ details={
130
+ "reason": "n_bootstrap must be positive",
131
+ "n_bootstrap": n_bootstrap,
132
+ },
133
+ )
134
+
135
+ with wrap_errors(MetricsError, "E401", "METRICS-COMPUTE-FAILED"):
136
+ rng = random_state or np.random.default_rng()
137
+ stats = np.empty(n_bootstrap, dtype=float)
138
+ for i in range(n_bootstrap):
139
+ indices = rng.integers(0, data.size, size=data.size)
140
+ stats[i] = statistic(data[indices])
141
+
142
+ lower = float(np.percentile(stats, 100 * (alpha / 2)))
143
+ upper = float(np.percentile(stats, 100 * (1 - alpha / 2)))
144
+ return lower, upper
145
+
146
+
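A minimal usage sketch for the bootstrap helper above, using synthetic data and a seeded generator purely for illustration:

    import numpy as np

    rng = np.random.default_rng(0)
    samples = rng.normal(loc=0.0, scale=1.0, size=200)

    # 95% percentile-bootstrap CI for the sample mean; passing the generator
    # makes the resamples reproducible.
    lower, upper = bootstrap_confidence_interval(
        samples, n_bootstrap=500, alpha=0.05, random_state=rng
    )
    print(f"mean={samples.mean():.3f}, 95% CI=({lower:.3f}, {upper:.3f})")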
147
+ @dataclass
148
+ class MetricsConfig:
149
+ """Configuration for metrics calculation with sensible defaults."""
150
+
151
+ # Core parameters
152
+ oracle_windows: int = 16
153
+ max_tokens: int = 256
154
+ max_samples_per_layer: int = 25_000
155
+
156
+ # Memory management
157
+ auto_batch_size: bool = True
158
+ memory_limit_gb: float | None = None
159
+ cpu_fallback_threshold_gb: float = 0.5
160
+
161
+ # Performance options
162
+ use_cache: bool = True
163
+ cache_dir: Path | None = None
164
+ progress_bars: bool = True
165
+
166
+ # Numerical stability
167
+ clip_value: float = 1e3
168
+ nan_replacement: float = 0.0
169
+ inf_replacement: float = 1e4
170
+
171
+ # Device management
172
+ device: torch.device | None = None
173
+ force_cpu: bool = False
174
+ cleanup_after: bool = True
175
+
176
+ # Validation options
177
+ strict_validation: bool = True
178
+ allow_empty_data: bool = False
179
+
180
+ # Lens-specific parameters
181
+ sigma_max_margin: float = 0.98
182
+ mi_gini_subsample_ratio: float = 0.05
183
+ head_energy_layers_filter: bool = True
184
+
185
+ def __post_init__(self):
186
+ """Validate configuration after initialization."""
187
+ if self.oracle_windows < 0:
188
+ raise ValidationError(
189
+ code="E402",
190
+ message="METRICS-VALIDATION-FAILED",
191
+ details={"reason": "oracle_windows must be non-negative"},
192
+ )
193
+ if self.max_tokens <= 0:
194
+ raise ValidationError(
195
+ code="E402",
196
+ message="METRICS-VALIDATION-FAILED",
197
+ details={"reason": "max_tokens must be positive"},
198
+ )
199
+ if self.memory_limit_gb is not None and self.memory_limit_gb <= 0:
200
+ raise ValidationError(
201
+ code="E402",
202
+ message="METRICS-VALIDATION-FAILED",
203
+ details={"reason": "memory_limit_gb must be positive"},
204
+ )
205
+
206
+ # Set default cache directory
207
+ if self.use_cache and self.cache_dir is None:
208
+ self.cache_dir = Path.home() / ".cache" / "invarlock_metrics"
209
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
210
+
211
+
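A sketch of a deliberately conservative configuration for small CI runs; only fields defined above are used, and everything else keeps its default:

    config = MetricsConfig(
        oracle_windows=8,
        max_tokens=128,
        force_cpu=True,
        use_cache=False,
        progress_bars=False,
        strict_validation=False,  # sanitize NaN/Inf instead of raising
    )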
212
+ class ResourceManager:
213
+ """Manages computational resources and memory usage."""
214
+
215
+ def __init__(self, config: MetricsConfig):
216
+ self.config = config
217
+ self.device = self._determine_device()
218
+ self.memory_info = self._get_memory_info()
219
+
220
+ def _determine_device(self) -> torch.device:
221
+ """Determine the best device to use."""
222
+ if self.config.force_cpu:
223
+ return torch.device("cpu")
224
+
225
+ if self.config.device is not None:
226
+ return self.config.device
227
+
228
+ if torch.cuda.is_available():
229
+ return torch.device("cuda")
230
+ elif torch.backends.mps.is_available():
231
+ return torch.device("mps")
232
+ else:
233
+ return torch.device("cpu")
234
+
235
+ def _get_memory_info(self) -> dict[str, float]:
236
+ """Get current memory information."""
237
+ info = {}
238
+
239
+ # System memory
240
+ vm = psutil.virtual_memory()
241
+ info["system_total_gb"] = vm.total / (1024**3)
242
+ info["system_available_gb"] = vm.available / (1024**3)
243
+
244
+ # GPU memory
245
+ if self.device.type == "cuda":
246
+ info["gpu_total_gb"] = torch.cuda.get_device_properties(0).total_memory / (
247
+ 1024**3
248
+ )
249
+ info["gpu_free_gb"] = (
250
+ torch.cuda.get_device_properties(0).total_memory
251
+ - torch.cuda.memory_allocated()
252
+ ) / (1024**3)
253
+
254
+ return info
255
+
256
+ def estimate_memory_usage(
257
+ self, model: nn.Module, batch_size: int, seq_length: int
258
+ ) -> float:
259
+ """Estimate memory usage in GB for given parameters."""
260
+ # Model parameters
261
+ param_memory = sum(p.numel() * p.element_size() for p in model.parameters()) / (
262
+ 1024**3
263
+ )
264
+
265
+ # Activation memory (rough estimate)
266
+ if hasattr(model, "config"):
267
+ hidden_size = getattr(
268
+ model.config, "n_embd", getattr(model.config, "hidden_size", 768)
269
+ )
270
+ num_layers = getattr(
271
+ model.config, "n_layer", getattr(model.config, "num_hidden_layers", 12)
272
+ )
273
+ activation_memory = (
274
+ batch_size * seq_length * hidden_size * num_layers * 4
275
+ ) / (1024**3)
276
+ else:
277
+ activation_memory = param_memory * 2 # Conservative estimate
278
+
279
+ return param_memory + activation_memory
280
+
281
+ def should_use_cpu_fallback(self, estimated_memory_gb: float) -> bool:
282
+ """Determine if CPU fallback should be used."""
283
+ if self.device.type == "cpu":
284
+ return False
285
+
286
+ available_memory = self.memory_info.get(
287
+ "gpu_free_gb", self.memory_info.get("system_available_gb", 8.0)
288
+ )
289
+
290
+ return estimated_memory_gb > (
291
+ available_memory - self.config.cpu_fallback_threshold_gb
292
+ )
293
+
294
+ def cleanup(self):
295
+ """Clean up GPU memory."""
296
+ if self.config.cleanup_after:
297
+ if torch.cuda.is_available():
298
+ torch.cuda.empty_cache()
299
+ gc.collect()
300
+
301
+
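A short sketch of driving the resource manager directly; `model` stands in for any nn.Module and is not defined here:

    rm = ResourceManager(MetricsConfig())
    estimated_gb = rm.estimate_memory_usage(model, batch_size=4, seq_length=256)
    if rm.should_use_cpu_fallback(estimated_gb):
        # Too little free accelerator memory within the configured margin:
        # prefer a CPU run over risking an OOM mid-evaluation.
        rm = ResourceManager(MetricsConfig(force_cpu=True))
    rm.cleanup()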
302
+ # ── Enhanced dependency management ─────────────────────────────────────────
303
+ class DependencyManager:
304
+ """Manages optional dependencies with graceful degradation."""
305
+
306
+ def __init__(self):
307
+ self.available_modules: dict[str, Any] = {}
308
+ self.missing_modules: list[tuple[str, str]] = []
309
+ self._check_dependencies()
310
+
311
+ def _check_dependencies(self):
312
+ """Check availability of optional dependencies."""
313
+ # Check lens2_mi
314
+ try:
315
+ from .lens2_mi import mi_scores
316
+
317
+ self.available_modules["mi_scores"] = mi_scores
318
+ logger.info("✓ lens2_mi module available")
319
+ except ImportError as e:
320
+ self.missing_modules.append(("lens2_mi", str(e)))
321
+ logger.warning("✗ lens2_mi module not available - MI-Gini will be NaN")
322
+
323
+ # Check lens3
324
+ try:
325
+ from .lens3 import scan_model_gains
326
+
327
+ self.available_modules["scan_model_gains"] = scan_model_gains
328
+ logger.info("✓ lens3 module available")
329
+ except ImportError as e:
330
+ self.missing_modules.append(("lens3", str(e)))
331
+ logger.warning("✗ lens3 module not available - σ_max will be NaN")
332
+
333
+ def get_module(self, name: str):
334
+ """Get a module if available, otherwise raise DependencyError."""
335
+ if name in self.available_modules:
336
+ return self.available_modules[name]
337
+ raise DependencyError(
338
+ code="E203",
339
+ message=f"DEPENDENCY-MISSING: module {name} is not available",
340
+ details={"module": name},
341
+ )
342
+
343
+ def is_available(self, name: str) -> bool:
344
+ """Check if a module is available."""
345
+ return name in self.available_modules
346
+
347
+ def get_missing_dependencies(self) -> list[tuple[str, str]]:
348
+ """Get list of missing dependencies with error messages."""
349
+ return self.missing_modules.copy()
350
+
351
+
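The degradation pattern this class supports, sketched: callers probe for an optional lens and fall back to NaN instead of failing the whole run:

    deps = DependencyManager()
    if deps.is_available("mi_scores"):
        mi_scores = deps.get_module("mi_scores")
    else:
        # Missing optional modules only disable their metric (reported as NaN).
        for name, err in deps.get_missing_dependencies():
            print(f"optional module {name} unavailable: {err}")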
352
+ # ── Input validation ───────────────────────────────────────────────────────
353
+ class InputValidator:
354
+ """Validates inputs for metrics calculation."""
355
+
356
+ @staticmethod
357
+ def validate_model(model: nn.Module, config: MetricsConfig) -> None:
358
+ """Validate model input."""
359
+ if not isinstance(model, nn.Module):
360
+ raise ValidationError(
361
+ code="E402",
362
+ message="METRICS-VALIDATION-FAILED",
363
+ details={"reason": f"Expected nn.Module, got {type(model)}"},
364
+ )
365
+
366
+ # Check if model has parameters
367
+ try:
368
+ param_count = sum(1 for _ in model.parameters())
369
+ if param_count == 0:
370
+ if config.strict_validation:
371
+ raise ValidationError(
372
+ code="E402",
373
+ message="METRICS-VALIDATION-FAILED",
374
+ details={"reason": "Model has no parameters"},
375
+ )
376
+ else:
377
+ logger.warning("Model has no parameters")
378
+ except Exception as e:
379
+ logger.debug(f"Could not count model parameters: {e}")
380
+
381
+ @staticmethod
382
+ def validate_dataloader(dataloader, config: MetricsConfig) -> None:
383
+ """Validate dataloader input."""
384
+ if dataloader is None:
385
+ raise ValidationError(
386
+ code="E402",
387
+ message="METRICS-VALIDATION-FAILED",
388
+ details={"reason": "Dataloader cannot be None"},
389
+ )
390
+
391
+ # Check if dataloader has data
392
+ try:
393
+ first_batch = next(iter(dataloader))
394
+ if not first_batch:
395
+ if not config.allow_empty_data:
396
+ raise ValidationError(
397
+ code="E402",
398
+ message="METRICS-VALIDATION-FAILED",
399
+ details={"reason": "Dataloader is empty"},
400
+ )
401
+ else:
402
+ logger.warning("Dataloader is empty")
403
+ except StopIteration as e:
404
+ if not config.allow_empty_data:
405
+ raise ValidationError(
406
+ code="E402",
407
+ message="METRICS-VALIDATION-FAILED",
408
+ details={"reason": "Dataloader is empty"},
409
+ ) from e
410
+ else:
411
+ logger.warning("Dataloader is empty")
412
+
413
+ @staticmethod
414
+ def validate_tensor(
415
+ tensor: torch.Tensor, name: str, config: MetricsConfig
416
+ ) -> torch.Tensor:
417
+ """Validate and sanitize tensor."""
418
+ if not isinstance(tensor, torch.Tensor):
419
+ raise ValidationError(
420
+ code="E402",
421
+ message="METRICS-VALIDATION-FAILED",
422
+ details={"reason": f"{name} must be a tensor, got {type(tensor)}"},
423
+ )
424
+
425
+ # Check for NaN/Inf
426
+ if torch.isnan(tensor).any():
427
+ if config.strict_validation:
428
+ raise ValidationError(
429
+ code="E402",
430
+ message="METRICS-VALIDATION-FAILED",
431
+ details={"reason": f"{name} contains NaN values"},
432
+ )
433
+ else:
434
+ logger.warning(
435
+ f"{name} contains NaN values, replacing with {config.nan_replacement}"
436
+ )
437
+ # Replace NaNs only; keep ±Inf intact so the Inf branch below can apply inf_replacement
+ tensor = torch.nan_to_num(tensor, nan=config.nan_replacement, posinf=math.inf, neginf=-math.inf)
438
+
439
+ if torch.isinf(tensor).any():
440
+ if config.strict_validation:
441
+ raise ValidationError(
442
+ code="E402",
443
+ message="METRICS-VALIDATION-FAILED",
444
+ details={"reason": f"{name} contains Inf values"},
445
+ )
446
+ else:
447
+ logger.warning(
448
+ f"{name} contains Inf values, replacing with ±{config.inf_replacement}"
449
+ )
450
+ tensor = torch.nan_to_num(
451
+ tensor,
452
+ posinf=config.inf_replacement,
453
+ neginf=-config.inf_replacement,
454
+ )
455
+
456
+ return tensor
457
+
458
+
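A small illustration of the non-strict sanitization path, using the default nan_replacement of 0.0:

    import torch

    cfg = MetricsConfig(strict_validation=False, use_cache=False)
    t = torch.tensor([1.0, float("nan"), 2.0])
    clean = InputValidator.validate_tensor(t, "demo", cfg)
    print(clean)  # tensor([1., 0., 2.]) -- NaN replaced, finite values untouched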
459
+ # ── Enhanced helper functions ──────────────────────────────────────────────
460
+ def _gini_vectorized(vec: torch.Tensor) -> float:
461
+ """Optimized Gini coefficient calculation."""
462
+ flat = vec.flatten().abs().float()
463
+ if flat.numel() == 0 or torch.sum(flat) == 0:
464
+ return float("nan")
465
+
466
+ # Use more efficient sorting and cumsum
467
+ sorted_vals = torch.sort(flat)[0]
468
+ n = sorted_vals.numel()
469
+
470
+ # Vectorized Gini calculation
471
+ indices = torch.arange(1, n + 1, dtype=torch.float32, device=flat.device)
472
+ gini = (2 * torch.sum(indices * sorted_vals) / torch.sum(sorted_vals) - (n + 1)) / n
473
+
474
+ return gini.item()
475
+
476
+
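Two quick sanity checks on the Gini helper: an evenly spread vector scores 0, a fully concentrated one approaches 1:

    import torch

    even = torch.ones(1000)    # equal mass everywhere
    spiky = torch.zeros(1000)
    spiky[0] = 1.0             # all mass on a single entry

    print(_gini_vectorized(even))   # 0.0
    print(_gini_vectorized(spiky))  # (n - 1) / n = 0.999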
477
+ def _mi_gini_optimized_cpu_path(
478
+ feats_cpu: torch.Tensor,
479
+ targ_cpu: torch.Tensor,
480
+ max_per_layer: int,
481
+ config: MetricsConfig,
482
+ ) -> float:
483
+ """Optimized MI Gini calculation on CPU with better memory management."""
484
+ L, N, _ = feats_cpu.shape
485
+
486
+ # Subsample if dataset is too large
487
+ if N > max_per_layer:
488
+ sel = torch.randperm(N)[:max_per_layer]
489
+ feats_cpu = feats_cpu[:, sel, :]
490
+ targ_cpu = targ_cpu[sel]
491
+
492
+ # Get MI function
493
+ dep_manager = DependencyManager()
494
+ if not dep_manager.is_available("mi_scores"):
495
+ return float("nan")
496
+
497
+ mi_scores_fn = dep_manager.get_module("mi_scores")
498
+
499
+ # Process in chunks to manage memory
500
+ chunk_size = min(8, L) # Process 8 layers at a time
501
+ mi_scores_all = []
502
+
503
+ progress_desc = "MI-Gini (CPU optimized)"
504
+ with tqdm(
505
+ total=L, desc=progress_desc, disable=not config.progress_bars, leave=False
506
+ ) as pbar:
507
+ for i in range(0, L, chunk_size):
508
+ end_idx = min(i + chunk_size, L)
509
+ chunk_feats = feats_cpu[i:end_idx]
510
+
511
+ # Vectorized processing for the chunk
512
+ chunk_scores = []
513
+ for j in range(chunk_feats.shape[0]):
514
+ try:
515
+ score = mi_scores_fn(chunk_feats[j], targ_cpu)
516
+ chunk_scores.append(score)
517
+ except Exception as e:
518
+ logger.warning(f"MI calculation failed for layer {i + j}: {e}")
519
+ chunk_scores.append(torch.zeros_like(chunk_feats[j, 0, :]))
520
+
521
+ mi_scores_all.extend(chunk_scores)
522
+ pbar.update(end_idx - i)
523
+
524
+ if not mi_scores_all:
525
+ return float("nan")
526
+
527
+ try:
528
+ mi_mat = torch.stack(mi_scores_all)
529
+ return _gini_vectorized(mi_mat)
530
+ except Exception as e:
531
+ logger.warning(f"Failed to stack MI scores: {e}")
532
+ return float("nan")
533
+
534
+
535
+ def _locate_transformer_blocks_enhanced(model: nn.Module) -> list[nn.Module] | None:
536
+ """Enhanced transformer block detection with better model support."""
537
+
538
+ # Standard GPT2 patterns - safer approach
539
+ def safe_getattr_chain(obj, *attrs):
540
+ """Safely get nested attributes."""
541
+ for attr in attrs:
542
+ if obj is None:
543
+ return None
544
+ obj = getattr(obj, attr, None)
545
+ return obj
546
+
547
+ patterns = [
548
+ lambda m: safe_getattr_chain(m, "transformer", "h"),
549
+ lambda m: safe_getattr_chain(m, "h"), # Bare GPT2Model
550
+ lambda m: safe_getattr_chain(m, "base_model", "h"), # Common wrappers
551
+ lambda m: safe_getattr_chain(m, "model", "h"), # Some wrappers
552
+ lambda m: safe_getattr_chain(m, "transformer", "layers"), # Alternative naming
553
+ ]
554
+
555
+ for pattern in patterns:
556
+ try:
557
+ blocks = pattern(model)
558
+ if blocks is not None and hasattr(blocks, "__len__") and len(blocks) > 0:
559
+ logger.debug(f"Found {len(blocks)} transformer blocks using pattern")
560
+ return list(blocks)
561
+ except (AttributeError, TypeError):
562
+ continue
563
+
564
+ # Fallback: search for transformer-like modules
565
+ transformer_modules = []
566
+ for name, module in model.named_modules():
567
+ if any(attr in name.lower() for attr in ["block", "layer", "transformer"]):
568
+ if hasattr(module, "attn") and hasattr(module, "mlp"):
569
+ transformer_modules.append(module)
570
+
571
+ if transformer_modules:
572
+ logger.debug(
573
+ f"Found {len(transformer_modules)} transformer blocks via fallback search"
574
+ )
575
+ return transformer_modules
576
+
577
+ logger.warning("Could not locate transformer blocks in model")
578
+ return None
579
+
580
+
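As a sketch, the first pattern above resolves GPT-2 style models; this assumes transformers is installed and the small gpt2 checkpoint is available:

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    blocks = _locate_transformer_blocks_enhanced(model)
    print(len(blocks))  # 12 decoder blocks for the base gpt2 checkpoint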
581
+ # ── Result caching ─────────────────────────────────────────────────────────
582
+ class ResultCache:
583
+ """Simple result caching for expensive operations."""
584
+
585
+ def __init__(self, config: MetricsConfig):
586
+ self.config = config
587
+ self.cache: dict[str, dict[str, float]] = {}
588
+ self.enabled = config.use_cache
589
+
590
+ def _get_cache_key(
591
+ self, model: nn.Module, dataloader, config: MetricsConfig
592
+ ) -> str:
593
+ """Generate cache key for model and data."""
594
+ # Simple hash based on model parameters and config
595
+ model_hash = hash(tuple(p.data_ptr() for p in model.parameters()))
596
+ config_hash = hash(
597
+ (config.oracle_windows, config.max_tokens, config.max_samples_per_layer)
598
+ )
599
+ return f"{model_hash}_{config_hash}"
600
+
601
+ def get(self, key: str) -> dict[str, float] | None:
602
+ """Get cached result."""
603
+ if not self.enabled:
604
+ return None
605
+ return self.cache.get(key)
606
+
607
+ def set(self, key: str, result: dict[str, float]) -> None:
608
+ """Cache result."""
609
+ if self.enabled:
610
+ self.cache[key] = result.copy()
611
+
612
+ def clear(self) -> None:
613
+ """Clear cache."""
614
+ self.cache.clear()
615
+
616
+
617
+ # ── Main metrics calculation function ──────────────────────────────────────
618
+ @torch.no_grad()
619
+ def calculate_lens_metrics_for_model(
620
+ model: nn.Module,
621
+ dataloader,
622
+ *,
623
+ config: MetricsConfig | None = None,
624
+ oracle_windows: int | None = None, # Backward compatibility
625
+ device: torch.device | None = None, # Backward compatibility
626
+ ) -> dict[str, float]:
627
+ """
628
+ Calculate comprehensive lens metrics for a model with enhanced robustness.
629
+
630
+ Args:
631
+ model: The neural network model to analyze
632
+ dataloader: DataLoader providing input data
633
+ config: MetricsConfig object with all parameters
634
+ oracle_windows: (deprecated) Number of windows to process
635
+ device: (deprecated) Device to use for computation
636
+
637
+ Returns:
638
+ Dictionary containing calculated metrics
639
+
640
+ Raises:
641
+ MetricsError: If calculation fails and config.strict_validation is enabled
642
+ """
643
+ # Handle backward compatibility
644
+ if config is None:
645
+ config = MetricsConfig()
646
+ if oracle_windows is not None:
647
+ config.oracle_windows = oracle_windows
648
+ if device is not None:
649
+ config.device = device
650
+
651
+ # Initialize managers
652
+ dep_manager = DependencyManager()
653
+ resource_manager = ResourceManager(config)
654
+ validator = InputValidator()
655
+ cache = ResultCache(config)
656
+
657
+ # Validate inputs
658
+ validator.validate_model(model, config)
659
+ validator.validate_dataloader(dataloader, config)
660
+
661
+ # Check cache
662
+ cache_key = cache._get_cache_key(model, dataloader, config)
663
+ cached_result = cache.get(cache_key)
664
+ if cached_result is not None:
665
+ logger.info("Using cached metrics result")
666
+ return cached_result
667
+
668
+ start_time = time.time()
669
+ logger.info(
670
+ f"Starting metrics calculation with config: oracle_windows={config.oracle_windows}, "
671
+ f"max_tokens={config.max_tokens}, device={resource_manager.device}"
672
+ )
673
+
674
+ # Pre-evaluation checks
675
+ try:
676
+ _perform_pre_eval_checks(model, dataloader, resource_manager.device, config)
677
+ except Exception as e:
678
+ logger.warning(f"Pre-evaluation checks failed: {e}")
679
+
680
+ # Unwrap common wrappers if present
681
+ if hasattr(model, "base_model"):
682
+ try:
683
+ model = model.base_model
684
+ except Exception:
685
+ pass
686
+
687
+ model.eval()
688
+ device = resource_manager.device
689
+
690
+ # Initialize results
691
+ results = {
692
+ "sigma_max": float("nan"),
693
+ "head_energy": float("nan"),
694
+ "mi_gini": float("nan"),
695
+ }
696
+
697
+ skipped_metrics: list[str] = []
698
+
699
+ try:
700
+ # Collect activations with progress tracking
701
+ logger.info("Collecting model activations...")
702
+ activation_data = _collect_activations(model, dataloader, config, device)
703
+
704
+ if not activation_data["hidden_states"]:
705
+ logger.warning("No activations collected - returning default values")
706
+ return _finalize_results(
707
+ results, skipped_metrics, cache, cache_key, start_time
708
+ )
709
+
710
+ # Calculate each metric
711
+ results["sigma_max"] = _calculate_sigma_max(
712
+ model, activation_data["first_batch"], dep_manager, config, device
713
+ )
714
+
715
+ results["head_energy"] = _calculate_head_energy(
716
+ activation_data["hidden_states"], config
717
+ )
718
+
719
+ results["mi_gini"] = _calculate_mi_gini(
720
+ model, activation_data, dep_manager, config, device
721
+ )
722
+
723
+ except Exception as e:
724
+ logger.error(f"Metrics calculation failed: {e}")
725
+ if config.strict_validation:
726
+ raise MetricsError(code="E401", message=f"METRICS-COMPUTE-FAILED: {e}") from e
727
+
728
+ finally:
729
+ resource_manager.cleanup()
730
+
731
+ return _finalize_results(results, skipped_metrics, cache, cache_key, start_time)
732
+
733
+
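An end-to-end sketch of the public entry point, assuming transformers and a dict-yielding DataLoader; metrics whose optional lens modules are absent come back as NaN rather than raising:

    import torch
    from torch.utils.data import DataLoader
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast

    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    texts = ["the quick brown fox jumps over the lazy dog"] * 8
    enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
    dataset = [{"input_ids": enc["input_ids"][i]} for i in range(len(texts))]
    loader = DataLoader(
        dataset,
        batch_size=4,
        collate_fn=lambda items: {"input_ids": torch.stack([x["input_ids"] for x in items])},
    )

    config = MetricsConfig(oracle_windows=2, max_tokens=64, force_cpu=True, progress_bars=False)
    metrics = calculate_lens_metrics_for_model(model, loader, config=config)
    print(metrics)  # {"sigma_max": ..., "head_energy": ..., "mi_gini": ...}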
734
+ def _perform_pre_eval_checks(
735
+ model: nn.Module, dataloader, device: torch.device, config: MetricsConfig
736
+ ) -> None:
737
+ """Perform pre-evaluation sanity checks."""
738
+ # Check model context length vs data
739
+ try:
740
+ tok_len_attr = getattr(model.config, "n_positions", None) or getattr(
741
+ model.config, "max_position_embeddings", None
742
+ )
743
+ if tok_len_attr:
744
+ sample_batch = next(iter(dataloader))
745
+ sample_ids = sample_batch["input_ids"]
746
+ if sample_ids.shape[1] > tok_len_attr:
747
+ logger.warning(
748
+ f"Input sequence length {sample_ids.shape[1]} exceeds "
749
+ f"model limit {tok_len_attr}"
750
+ )
751
+ except Exception as e:
752
+ logger.debug(f"Context length check failed: {e}")
753
+
754
+ # Dry run forward pass
755
+ try:
756
+ dry_batch = next(iter(dataloader))
757
+ model_input = {
758
+ k: v.to(device) if isinstance(v, torch.Tensor) else v
759
+ for k, v in dry_batch.items()
760
+ }
761
+ _ = model(**model_input)
762
+ logger.debug("Pre-evaluation dry run successful")
763
+ except Exception as e:
764
+ logger.warning(f"Pre-evaluation dry run failed: {e}")
765
+
766
+
767
+ def _collect_activations(
768
+ model: nn.Module, dataloader, config: MetricsConfig, device: torch.device
769
+ ) -> dict[str, Any]:
770
+ """Collect model activations with enhanced error handling."""
771
+ hidden_states_list = []
772
+ fc1_activations_list = []
773
+ targets_list = []
774
+ first_batch = None
775
+
776
+ # Progress tracking
777
+ total_batches = (
778
+ min(config.oracle_windows, len(dataloader))
779
+ if hasattr(dataloader, "__len__")
780
+ else config.oracle_windows
781
+ )
782
+
783
+ with tqdm(
784
+ total=total_batches,
785
+ desc="Collecting activations",
786
+ disable=not config.progress_bars,
787
+ ) as pbar:
788
+ for i, batch in enumerate(dataloader):
789
+ if i >= config.oracle_windows:
790
+ break
791
+
792
+ try:
793
+ # Store first batch for later use
794
+ if first_batch is None:
795
+ first_batch = {
796
+ k: v.to(device) if isinstance(v, torch.Tensor) else v
797
+ for k, v in batch.items()
798
+ }
799
+
800
+ # Move batch to device
801
+ input_ids = batch["input_ids"].to(device)
802
+
803
+ # Limit sequence length
804
+ if input_ids.shape[1] > config.max_tokens:
805
+ input_ids = input_ids[:, : config.max_tokens]
806
+
807
+ # Forward pass with hidden states
808
+ output = model(input_ids, output_hidden_states=True)
809
+
810
+ # Collect hidden states (exclude first and last)
811
+ if hasattr(output, "hidden_states") and len(output.hidden_states) > 2:
812
+ hidden_states = torch.stack(output.hidden_states[1:-1])
813
+ hidden_states = validator.validate_tensor(
814
+ hidden_states, f"hidden_states_batch_{i}", config
815
+ )
816
+ hidden_states_list.append(hidden_states)
817
+
818
+ # Collect FC1 activations for MI-Gini
819
+ fc1_acts = _extract_fc1_activations(model, output, config)
820
+ if fc1_acts is not None:
821
+ fc1_activations_list.append(fc1_acts)
822
+ targets_list.append(
823
+ input_ids[:, 1:]
824
+ ) # Shifted for next-token prediction
825
+
826
+ pbar.update(1)
827
+
828
+ except Exception as e:
829
+ logger.warning(f"Failed to process batch {i}: {e}")
830
+ continue
831
+
832
+ return {
833
+ "hidden_states": hidden_states_list,
834
+ "fc1_activations": fc1_activations_list,
835
+ "targets": targets_list,
836
+ "first_batch": first_batch,
837
+ }
838
+
839
+
840
+ def _extract_fc1_activations(
841
+ model: nn.Module, output, config: MetricsConfig
842
+ ) -> torch.Tensor | None:
843
+ """Extract FC1 activations for MI-Gini calculation."""
844
+ blocks = _locate_transformer_blocks_enhanced(model)
845
+ if blocks is None:
846
+ return None
847
+
848
+ try:
849
+ valid_activations = []
850
+ for idx, block in enumerate(blocks):
851
+ if hasattr(block, "mlp") and hasattr(block.mlp, "c_fc"):
852
+ try:
853
+ # Get hidden state for this layer
854
+ if (
855
+ hasattr(output, "hidden_states")
856
+ and len(output.hidden_states) > idx + 1
857
+ ):
858
+ hidden_state = output.hidden_states[idx + 1]
859
+ activation = block.mlp.c_fc(hidden_state)
860
+ activation = validator.validate_tensor(
861
+ activation, f"fc1_activation_{idx}", config
862
+ )
863
+ valid_activations.append(activation)
864
+ except Exception as e:
865
+ logger.debug(
866
+ f"Failed to extract FC1 activation for block {idx}: {e}"
867
+ )
868
+ continue
869
+
870
+ if valid_activations:
871
+ # Check for consistent shapes
872
+ shapes = [act.shape for act in valid_activations]
873
+ if len(set(shapes)) > 1:
874
+ logger.warning(f"Inconsistent FC1 activation shapes: {set(shapes)}")
875
+ # Use most common shape
876
+ from collections import Counter
877
+
878
+ most_common_shape = Counter(shapes).most_common(1)[0][0]
879
+ valid_activations = [
880
+ act for act in valid_activations if act.shape == most_common_shape
881
+ ]
882
+
883
+ return torch.stack(valid_activations)
884
+
885
+ except Exception as e:
886
+ logger.warning(f"FC1 activation extraction failed: {e}")
887
+
888
+ return None
889
+
890
+
891
+ def _calculate_sigma_max(
892
+ model: nn.Module,
893
+ first_batch: dict | None,
894
+ dep_manager: DependencyManager,
895
+ config: MetricsConfig,
896
+ device: torch.device,
897
+ ) -> float:
898
+ """Calculate sigma_max metric via Lens-3."""
899
+ if not dep_manager.is_available("scan_model_gains"):
900
+ logger.info("Skipping σ_max: scan_model_gains not available")
901
+ return float("nan")
902
+
903
+ if first_batch is None:
904
+ logger.info("Skipping σ_max: no data batch available")
905
+ return float("nan")
906
+
907
+ try:
908
+ scan_model_gains = dep_manager.get_module("scan_model_gains")
909
+ gains_df = scan_model_gains(model, first_batch)
910
+
911
+ if gains_df is None:
912
+ logger.warning("scan_model_gains returned None")
913
+ return float("nan")
914
+
915
+ # Filter out embedding and head layers if possible
916
+ if hasattr(gains_df, "columns") and "name" in gains_df.columns:
917
+ mask = ~gains_df["name"].str.contains(
918
+ "embed|lm_head", case=False, regex=True
919
+ )
920
+ filtered_gains = gains_df[mask]
921
+ else:
922
+ logger.info("Could not filter layers by name for σ_max")
923
+ filtered_gains = gains_df
924
+
925
+ if len(filtered_gains) == 0:
926
+ logger.warning("No valid layers found for σ_max computation")
927
+ return float("nan")
928
+
929
+ # Extract gains
930
+ gains_values = getattr(
931
+ filtered_gains, "gain", getattr(filtered_gains, "values", [])
932
+ )
933
+ gains_tensor = torch.as_tensor(gains_values, dtype=torch.float32, device=device)
934
+
935
+ if gains_tensor.numel() == 0:
936
+ logger.warning("No gain values found")
937
+ return float("nan")
938
+
939
+ # Validate and get max
940
+ gains_tensor = validator.validate_tensor(
941
+ gains_tensor, "sigma_max_gains", config
942
+ )
943
+ finite_mask = torch.isfinite(gains_tensor)
944
+
945
+ if not finite_mask.any():
946
+ logger.warning("All σ_max gains are NaN/Inf")
947
+ return float("nan")
948
+
949
+ sigma_max = torch.max(gains_tensor[finite_mask]).item()
950
+ logger.debug(f"Calculated σ_max: {sigma_max:.4f}")
951
+ return sigma_max
952
+
953
+ except Exception as e:
954
+ logger.warning(f"σ_max calculation failed: {e}")
955
+ return float("nan")
956
+
957
+
958
+ def _calculate_head_energy(
959
+ hidden_states_list: list[torch.Tensor], config: MetricsConfig
960
+ ) -> float:
961
+ """Calculate head energy metric (mean squared activation per layer)."""
962
+ if not hidden_states_list:
963
+ logger.info("Skipping head energy: no hidden states available")
964
+ return float("nan")
965
+
966
+ try:
967
+ # Concatenate all hidden states: [L, N, T, D]
968
+ hidden_stack = torch.cat(hidden_states_list, dim=1)
969
+
970
+ # Crop to max_tokens
971
+ hidden_crop = hidden_stack[:, :, : config.max_tokens, :]
972
+
973
+ # Sanitize
974
+ hidden_crop = validator.validate_tensor(
975
+ hidden_crop, "head_energy_hidden_states", config
976
+ )
977
+
978
+ # Calculate mean squared activation per layer
979
+ squared_activations = hidden_crop.float().pow(2).mean(dim=-1) # [L, N, T]
980
+ per_layer_energy = squared_activations.mean(dim=(1, 2)) # [L]
981
+
982
+ # Filter finite values
983
+ finite_mask = torch.isfinite(per_layer_energy)
984
+ if not finite_mask.any():
985
+ logger.warning("All head energies are NaN/Inf")
986
+ return float("nan")
987
+
988
+ head_energy = per_layer_energy[finite_mask].mean().item()
989
+ logger.debug(f"Calculated head energy: {head_energy:.6f}")
990
+ return head_energy
991
+
992
+ except Exception as e:
993
+ logger.warning(f"Head energy calculation failed: {e}")
994
+ return float("nan")
995
+
996
+
997
+ def _calculate_mi_gini(
998
+ model: nn.Module,
999
+ activation_data: dict[str, Any],
1000
+ dep_manager: DependencyManager,
1001
+ config: MetricsConfig,
1002
+ device: torch.device,
1003
+ ) -> float:
1004
+ """Calculate MI-based Gini coefficient."""
1005
+ if not dep_manager.is_available("mi_scores"):
1006
+ logger.info("Skipping MI-Gini: mi_scores not available")
1007
+ return float("nan")
1008
+
1009
+ if not activation_data["fc1_activations"] or not activation_data["targets"]:
1010
+ logger.info("Skipping MI-Gini: no FC1 activations available")
1011
+ return float("nan")
1012
+
1013
+ try:
1014
+ # Concatenate activations and targets
1015
+ fc1_all = torch.cat(activation_data["fc1_activations"], dim=1) # [L, N, T, D]
1016
+ targ_all = torch.cat(activation_data["targets"], dim=0) # [N, T]
1017
+
1018
+ # Trim to align dimensions (remove last token from activations)
1019
+ fc1_trim = fc1_all[:, :, :-1, :] # [L, N, T-1, D]
1020
+
1021
+ # Crop to max_tokens
1022
+ fc1_trim = fc1_trim[:, :, : config.max_tokens, :]
1023
+ targ_trim = targ_all[:, : config.max_tokens]
1024
+
1025
+ # Reshape for MI calculation
1026
+ L, N, T, D = fc1_trim.shape
1027
+ # Keep sample-major (n, t) ordering so rows stay aligned with targ_flat below
+ fc1_flat = fc1_trim.reshape(L, N * T, D)  # [L, N*T, D]
1028
+ targ_flat = targ_trim.flatten() # [N*T]
1029
+
1030
+ # Validate tensors
1031
+ fc1_flat = InputValidator.validate_tensor(fc1_flat, "mi_gini_features", config)
1032
+ targ_flat = InputValidator.validate_tensor(targ_flat, "mi_gini_targets", config)
1033
+
1034
+ # Get MI scores function
1035
+ mi_scores_fn = dep_manager.get_module("mi_scores")
1036
+
1037
+ # Try GPU calculation first
1038
+ try:
1039
+ logger.debug("Attempting MI-Gini calculation on GPU")
1040
+ mi_scores_result = mi_scores_fn(fc1_flat, targ_flat)
1041
+ mi_gini = _gini_vectorized(mi_scores_result)
1042
+ logger.debug(f"Calculated MI-Gini (GPU): {mi_gini:.6f}")
1043
+ return mi_gini
1044
+
1045
+ except RuntimeError as e:
1046
+ if "out of memory" in str(e).lower():
1047
+ logger.warning("GPU OOM for MI-Gini, falling back to CPU")
1048
+ if torch.cuda.is_available():
1049
+ torch.cuda.empty_cache()
1050
+
1051
+ # CPU fallback with subsampling
1052
+ mi_gini = _mi_gini_optimized_cpu_path(
1053
+ fc1_flat.cpu().float(),
1054
+ targ_flat.cpu(),
1055
+ config.max_samples_per_layer,
1056
+ config,
1057
+ )
1058
+ logger.debug(f"Calculated MI-Gini (CPU): {mi_gini:.6f}")
1059
+ return mi_gini
1060
+ else:
1061
+ raise
1062
+
1063
+ except Exception as e:
1064
+ logger.warning(f"MI-Gini calculation failed: {e}")
1065
+ return float("nan")
1066
+
1067
+
1068
+ def _finalize_results(
1069
+ results: dict[str, Any],
1070
+ skipped_metrics: list[str],
1071
+ cache: ResultCache,
1072
+ cache_key: str,
1073
+ start_time: float,
1074
+ ) -> dict[str, float]:
1075
+ """Finalize and validate results."""
1076
+ # Ensure all values are finite or NaN
1077
+ for key, value in results.items():
1078
+ if not isinstance(value, int | float):
1079
+ logger.warning(
1080
+ f"Metric {key} has invalid type {type(value)}, setting to NaN"
1081
+ )
1082
+ results[key] = float("nan")
1083
+ elif math.isinf(value):
1084
+ logger.warning(f"Metric {key} is infinite, setting to NaN")
1085
+ results[key] = float("nan")
1086
+
1087
+ # Log skipped metrics
1088
+ if skipped_metrics:
1089
+ logger.info(f"Skipped metrics: {', '.join(skipped_metrics)}")
1090
+
1091
+ # Cache results
1092
+ cache.set(cache_key, results)
1093
+
1094
+ # Log completion
1095
+ elapsed = time.time() - start_time
1096
+ logger.info(f"Metrics calculation completed in {elapsed:.2f}s: {results}")
1097
+
1098
+ return results
1099
+
1100
+
1101
+ # ── Backward compatibility functions ──────────────────────────────────────
1102
+ def _gini(vec: torch.Tensor) -> float:
1103
+ """Legacy Gini function for backward compatibility."""
1104
+ return _gini_vectorized(vec)
1105
+
1106
+
1107
+ def _mi_gini_cpu_safe_path(
1108
+ feats_cpu: torch.Tensor, targ_cpu: torch.Tensor, max_per_layer: int
1109
+ ) -> float:
1110
+ """Legacy CPU MI-Gini function for backward compatibility."""
1111
+ config = MetricsConfig(max_samples_per_layer=max_per_layer, progress_bars=True)
1112
+ return _mi_gini_optimized_cpu_path(feats_cpu, targ_cpu, max_per_layer, config)
1113
+
1114
+
1115
+ def _locate_transformer_blocks(model: nn.Module) -> list[nn.Module] | None:
1116
+ """Legacy transformer block locator for backward compatibility."""
1117
+ return _locate_transformer_blocks_enhanced(model)
1118
+
1119
+
1120
+ # ── Additional utility functions ───────────────────────────────────────────
1121
+ def get_metrics_info() -> dict[str, Any]:
1122
+ """Get information about available metrics and dependencies."""
1123
+ dep_manager = DependencyManager()
1124
+
1125
+ return {
1126
+ "available_metrics": ["sigma_max", "head_energy", "mi_gini"],
1127
+ "available_dependencies": list(dep_manager.available_modules.keys()),
1128
+ "missing_dependencies": dep_manager.get_missing_dependencies(),
1129
+ "default_config": MetricsConfig().__dict__,
1130
+ }
1131
+
1132
+
1133
+ def validate_metrics_environment() -> bool:
1134
+ """Validate that the metrics environment is properly set up."""
1135
+ try:
1136
+ dep_manager = DependencyManager()
1137
+ MetricsConfig()
1138
+
1139
+ # Check basic dependencies
1140
+
1141
+ logger.info("✓ Basic dependencies available")
1142
+
1143
+ # Check optional dependencies
1144
+ available_count = len(dep_manager.available_modules)
1145
+ total_count = available_count + len(dep_manager.missing_modules)
1146
+
1147
+ logger.info(
1148
+ f"✓ {available_count}/{total_count} optional dependencies available"
1149
+ )
1150
+
1151
+ if dep_manager.missing_modules:
1152
+ logger.warning("Some optional dependencies are missing:")
1153
+ for name, error in dep_manager.missing_modules:
1154
+ logger.warning(f" - {name}: {error}")
1155
+
1156
+ return True
1157
+
1158
+ except Exception as e:
1159
+ logger.error(f"Environment validation failed: {e}")
1160
+ return False
1161
+
1162
+
1163
+ # ── Import necessary modules for validation ────────────────────────────────
1164
+ # Note: math is already imported at top of file
1165
+
1166
+ # Global validator instance for use in helper functions
1167
+ validator = InputValidator()
1168
+
1169
+
1170
+ # ── Perplexity validation ──────────────────────────────────────────────────
1171
+ class PerplexityStatus:
1172
+ """Quality status levels for ppl-like primary metrics (perplexity)."""
1173
+
1174
+ EXCELLENT = "excellent" # < 50
1175
+ GOOD = "good" # 50-100
1176
+ ACCEPTABLE = "acceptable" # 100-200
1177
+ POOR = "poor" # 200-500
1178
+ UNUSABLE = "unusable" # > 500
1179
+
1180
+ @classmethod
1181
+ def from_value(cls, ppl: float, vocab_size: int | None = None) -> str:
1182
+ """Get status from perplexity value."""
1183
+ if ppl < 50:
1184
+ return cls.EXCELLENT
1185
+ elif ppl < 100:
1186
+ return cls.GOOD
1187
+ elif ppl < 200:
1188
+ return cls.ACCEPTABLE
1189
+ elif ppl < 500:
1190
+ return cls.POOR
1191
+ else:
1192
+ return cls.UNUSABLE
1193
+
1194
+
1195
+ def validate_perplexity(
1196
+ ppl: float,
1197
+ vocab_size: int | None = None,
1198
+ context: str = "evaluation",
1199
+ warn_threshold: float = 200.0,
1200
+ error_threshold: float = 2000.0,
1201
+ allow_high: bool = False,
1202
+ ) -> tuple[bool, str, str]:
1203
+ """
1204
+ Validate perplexity value and provide feedback.
1205
+
1206
+ Args:
1207
+ ppl: Perplexity value to validate
1208
+ vocab_size: Vocabulary size for context-aware validation
1209
+ context: Context string for error messages
1210
+ warn_threshold: Threshold for warning (default 200)
1211
+ error_threshold: Threshold for error (default 2000)
1212
+ allow_high: Allow high perplexity values (for testing)
1213
+
1214
+ Returns:
1215
+ Tuple of (is_valid, status, message)
1216
+ """
1217
+ # Check for invalid values
1218
+ if math.isnan(ppl) or math.isinf(ppl):
1219
+ return False, "invalid", f"Perplexity is {ppl}"
1220
+
1221
+ if ppl < 1.0:
1222
+ return False, "invalid", f"Perplexity {ppl:.2f} is less than 1.0"
1223
+
1224
+ # Get status
1225
+ status = PerplexityStatus.from_value(ppl, vocab_size)
1226
+
1227
+ # Adjust thresholds based on vocab size if provided
1228
+ if vocab_size is not None:
1229
+ # For untrained models, ppl-like PM ≈ vocab_size is expected
1230
+ # Adjust thresholds accordingly
1231
+ warn_threshold = max(warn_threshold, vocab_size * 0.5)
1232
+ error_threshold = max(error_threshold, vocab_size * 2.0)
1233
+
1234
+ # Generate message based on status
1235
+ if ppl > error_threshold and not allow_high:
1236
+ message = (
1237
+ f"Perplexity {ppl:.1f} exceeds error threshold {error_threshold:.0f} "
1238
+ f"in {context}. Model appears to be untrained or corrupted."
1239
+ )
1240
+ return False, status, message
1241
+
1242
+ elif ppl > warn_threshold:
1243
+ message = (
1244
+ f"Perplexity {ppl:.1f} exceeds warning threshold {warn_threshold:.0f} "
1245
+ f"in {context}. Model may be severely degraded."
1246
+ )
1247
+ if not allow_high:
1248
+ logger.warning(message)
1249
+ return True, status, message
1250
+
1251
+ elif status == PerplexityStatus.POOR:
1252
+ message = f"Perplexity {ppl:.1f} indicates poor model quality in {context}."
1253
+ logger.info(message)
1254
+ return True, status, message
1255
+
1256
+ elif status == PerplexityStatus.ACCEPTABLE:
1257
+ message = f"Perplexity {ppl:.1f} is acceptable for {context}."
1258
+ return True, status, message
1259
+
1260
+ else:
1261
+ message = f"Perplexity {ppl:.1f} is {status} for {context}."
1262
+ return True, status, message
1263
+
1264
+
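A small illustration of the thresholding above: with a GPT-2-sized vocabulary the warn/error bounds are lifted toward vocab_size, so a mid-range perplexity passes cleanly:

    ok, status, msg = validate_perplexity(
        85.3, vocab_size=50257, context="post-edit evaluation"
    )
    print(ok, status)  # True good
    print(msg)         # Perplexity 85.3 is good for post-edit evaluation.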
1265
+ # ── Helper function for robust forward pass ────────────────────────────────
1266
+ def _forward_loss_causal(
1267
+ model: nn.Module,
1268
+ input_ids: torch.Tensor,
1269
+ attention_mask: torch.Tensor | None = None,
1270
+ labels: torch.Tensor | None = None,
1271
+ ) -> tuple[float, torch.Tensor | None]:
1272
+ """
1273
+ Robust forward that handles HF ModelOutput or tuple, computes loss if needed.
1274
+ Returns (loss_value: float, logits: torch.Tensor or None).
1275
+ """
1276
+ import torch.nn.functional as F
1277
+
1278
+ # 1) Prefer dict-style outputs
1279
+ try:
1280
+ outputs = model(
1281
+ input_ids=input_ids,
1282
+ attention_mask=attention_mask,
1283
+ labels=labels,
1284
+ return_dict=True,
1285
+ )
1286
+ # If we got a ModelOutput, use it
1287
+ if hasattr(outputs, "loss") and outputs.loss is not None:
1288
+ return float(outputs.loss.detach().cpu()), getattr(outputs, "logits", None)
1289
+ logits = getattr(outputs, "logits", None)
1290
+ except (TypeError, AttributeError):
1291
+ # Some stub models/tests may not accept return_dict
1292
+ outputs = model(
1293
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels
1294
+ )
1295
+ if isinstance(outputs, tuple | list):
1296
+ # If labels were provided, many HF models put loss first, logits second
1297
+ if (
1298
+ labels is not None
1299
+ and len(outputs) >= 2
1300
+ and torch.is_tensor(outputs[0])
1301
+ and outputs[0].ndim == 0
1302
+ ):
1303
+ return float(outputs[0].detach().cpu()), outputs[1] if len(
1304
+ outputs
1305
+ ) > 1 else None
1306
+ # Otherwise first is logits
1307
+ logits = outputs[0] if len(outputs) > 0 else None
1308
+ else:
1309
+ # Custom object: try attributes
1310
+ maybe_loss = getattr(outputs, "loss", None)
1311
+ maybe_logits = getattr(outputs, "logits", None)
1312
+ if maybe_loss is not None:
1313
+ return float(maybe_loss.detach().cpu()), maybe_logits
1314
+ logits = maybe_logits
1315
+
1316
+ # 2) If we're here, we have logits but no loss → compute it manually
1317
+ if logits is None:
1318
+ raise MetricsError(
1319
+ code="E401",
1320
+ message="METRICS-COMPUTE-FAILED: model returned neither loss nor logits",
1321
+ )
1322
+
1323
+ if labels is None:
1324
+ raise ValidationError(
1325
+ code="E402",
1326
+ message="METRICS-VALIDATION-FAILED",
1327
+ details={"reason": "labels are required to compute perplexity loss"},
1328
+ )
1329
+
1330
+ # Causal LM shift
1331
+ shift_logits = logits[:, :-1, :].contiguous()
1332
+ shift_labels = labels[:, 1:].contiguous()
1333
+
1334
+ loss = F.cross_entropy(
1335
+ shift_logits.view(-1, shift_logits.size(-1)),
1336
+ shift_labels.view(-1),
1337
+ ignore_index=-100,
1338
+ reduction="mean",
1339
+ )
1340
+ return float(loss.detach().cpu()), logits
1341
+
1342
+
1343
+ def _resolve_eval_device(
1344
+ model: nn.Module, device: str | torch.device | None
1345
+ ) -> torch.device:
1346
+ """
1347
+ Resolve evaluation device with graceful MPS fallback.
1348
+
1349
+ If MPS is requested but unavailable (common in CI or non-macOS builds),
1350
+ fall back to CPU instead of raising at tensor .to(device) calls.
1351
+ """
1352
+ if device is None:
1353
+ try:
1354
+ resolved = next(model.parameters()).device
1355
+ except StopIteration:
1356
+ resolved = torch.device("cpu")
1357
+ else:
1358
+ resolved = torch.device(device) if isinstance(device, str) else device
1359
+
1360
+ # Handle MPS when backend is not actually usable
1361
+ try:
1362
+ if isinstance(resolved, torch.device) and resolved.type == "mps":
1363
+ mps_backend = getattr(torch.backends, "mps", None)
1364
+ is_available = bool(
1365
+ mps_backend is not None
1366
+ and hasattr(mps_backend, "is_available")
1367
+ and mps_backend.is_available()
1368
+ )
1369
+ if not is_available:
1370
+ logger.warning(
1371
+ "Requested device 'mps' for metrics evaluation but MPS backend "
1372
+ "is not available; falling back to CPU."
1373
+ )
1374
+ resolved = torch.device("cpu")
1375
+ except Exception:
1376
+ # On any introspection failure, be conservative and fall back to CPU
1377
+ resolved = torch.device("cpu")
1378
+
1379
+ return resolved
1380
+
1381
+
1382
+ # ── Perplexity calculation ─────────────────────────────────────────────────
1383
+ @torch.no_grad()
1384
+ def calculate_perplexity(
1385
+ model: nn.Module,
1386
+ dataloader,
1387
+ max_batches: int = 100,
1388
+ device: str | torch.device | None = None,
1389
+ ) -> float:
1390
+ """
1391
+ DEPRECATED: Use compute_perplexity for new code.
1392
+ This is an alias for backward compatibility with tests.
1393
+ """
1394
+ return compute_perplexity(model, dataloader, max_samples=max_batches, device=device)
1395
+
1396
+
1397
+ @torch.no_grad()
1398
+ def compute_perplexity_strict(
1399
+ model: nn.Module, dataloader, device: str | torch.device | None = None
1400
+ ) -> float:
1401
+ """
1402
+ Compute perplexity with strict token-level accounting.
1403
+
1404
+ Args:
1405
+ model: Language model to evaluate
1406
+ dataloader: DataLoader providing input sequences
1407
+ device: Device to use for computation
1408
+
1409
+ Returns:
1410
+ Perplexity value
1411
+
1412
+ Raises:
1413
+ ValidationError: If no valid tokens are found for perplexity computation (E402)
1414
+ """
1415
+ device = _resolve_eval_device(model, device)
1416
+
1417
+ model.eval()
1418
+ nll_sum = 0.0
1419
+ tok_count = 0
1420
+
1421
+ for batch in dataloader:
1422
+ # Handle different batch formats
1423
+ if isinstance(batch, dict):
1424
+ input_ids = batch.get("input_ids", batch.get("inputs", None))
1425
+ labels = batch.get("labels", None)
1426
+ attention_mask = batch.get("attention_mask", None)
1427
+ token_type_ids = batch.get("token_type_ids", None)
1428
+ elif isinstance(batch, tuple | list):
1429
+ input_ids = batch[0] if len(batch) > 0 else None
1430
+ labels = batch[1] if len(batch) > 1 else None
1431
+ attention_mask = batch[2] if len(batch) > 2 else None
1432
+ token_type_ids = batch[3] if len(batch) > 3 else None
1433
+ else:
1434
+ input_ids = batch
1435
+ labels = None
1436
+ attention_mask = None
1437
+ token_type_ids = None
1438
+
1439
+ if input_ids is None or not isinstance(input_ids, torch.Tensor):
1440
+ continue
1441
+
1442
+ input_ids = input_ids.to(device)
1443
+ attn = attention_mask.to(device) if attention_mask is not None else None
1444
+ token_type_ids_t = (
1445
+ token_type_ids.to(device) if token_type_ids is not None else None
1446
+ )
1447
+
1448
+ # Default causal labels
1449
+ if labels is None:
1450
+ labels = input_ids.clone()
1451
+ if attn is not None:
1452
+ labels[attn == 0] = -100
1453
+ else:
1454
+ labels = labels.to(device)
1455
+
1456
+ # Skip if sequence too short
1457
+ if input_ids.size(1) < 2:
1458
+ continue
1459
+
1460
+ is_masked_lm = hasattr(model, "config") and getattr(
1461
+ model.config, "model_type", ""
1462
+ ) in {"bert", "roberta", "distilbert", "albert"}
1463
+
1464
+ if is_masked_lm:
1465
+ masked_labels = labels.clone()
1466
+ if attn is not None:
1467
+ masked_labels = masked_labels.masked_fill(attn == 0, -100)
1468
+ outputs = model(
1469
+ input_ids=input_ids,
1470
+ attention_mask=attn,
1471
+ token_type_ids=token_type_ids_t,
1472
+ labels=masked_labels,
1473
+ return_dict=True,
1474
+ )
1475
+ loss = outputs.loss
1476
+ if loss is None:
1477
+ continue
1478
+ valid_tokens = int((masked_labels != -100).sum().item())
1479
+ if valid_tokens == 0:
1480
+ continue
1481
+ nll_sum += float(loss.item()) * valid_tokens
1482
+ tok_count += valid_tokens
1483
+ continue
1484
+
1485
+ # Forward (don't trust .loss, compute ourselves)
1486
+ try:
1487
+ outputs = model(input_ids=input_ids, attention_mask=attn, return_dict=True)
1488
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
1489
+ except Exception:
1490
+ # Fallback for non-standard models
1491
+ outputs = model(input_ids=input_ids, attention_mask=attn)
1492
+ if isinstance(outputs, tuple | list):
1493
+ logits = outputs[0]
1494
+ else:
1495
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs
1496
+
1497
+ # Causal shift
1498
+ shift_logits = logits[:, :-1, :]
1499
+ shift_labels = labels[:, 1:]
1500
+ shift_mask = attn[:, 1:] if attn is not None else None
1501
+
1502
+ valid = shift_labels != -100
1503
+ if shift_mask is not None:
1504
+ valid = valid & shift_mask.bool()
1505
+
1506
+ if not valid.any():
1507
+ continue
1508
+
1509
+ log_probs = shift_logits.log_softmax(dim=-1) # [B,T-1,V]
1510
+ tgt = shift_labels.clamp_min(0).unsqueeze(-1) # [B,T-1,1]
1511
+ nll = -log_probs.gather(-1, tgt).squeeze(-1) # [B,T-1]
1512
+
1513
+ nll_sum += nll[valid].sum().item()
1514
+ tok_count += int(valid.sum().item())
1515
+
1516
+ if tok_count == 0:
1517
+ raise ValidationError(
1518
+ code="E402",
1519
+ message="METRICS-VALIDATION-FAILED",
1520
+ details={
1521
+ "reason": "No valid tokens for perplexity (all masked or seq_len<=1)."
1522
+ },
1523
+ )
1524
+
1525
+ return float(math.exp(nll_sum / tok_count))
1526
+
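# Illustrative arithmetic sketch (editorial, not part of the package source): the loop
# above (and compute_perplexity below) accumulates a token-weighted NLL sum and divides
# by the total count of valid, non-padding tokens before exponentiating, which is why
# the masked-LM branch multiplies each batch's mean loss by its valid-token count.
# The numbers below are made up.
import math

batches = [
    (2.10, 500),  # (mean NLL reported for a batch, valid tokens in that batch)
    (2.30, 120),
    (1.95, 380),
]
nll_sum = sum(mean_nll * n_tok for mean_nll, n_tok in batches)
tok_count = sum(n_tok for _, n_tok in batches)
ppl = math.exp(nll_sum / tok_count)  # token-weighted PPL, not a mean of per-batch PPLs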
1527
+
1528
+ @torch.no_grad()
1529
+ def compute_perplexity(
1530
+ model: nn.Module,
1531
+ dataloader,
1532
+ max_samples: int = 100,
1533
+ device: str | torch.device | None = None,
1534
+ ) -> float:
1535
+ """
1536
+ Compute perplexity of a language model on a dataset.
1537
+
1538
+ ALWAYS uses strict token-level accounting to avoid padding issues.
1539
+
1540
+ Args:
1541
+ model: Language model to evaluate
1542
+ dataloader: DataLoader providing input sequences
1543
+ max_samples: Maximum number of batches to evaluate (the cap counts batches, not individual samples)
1544
+ device: Device to use for computation
1545
+
1546
+ Returns:
1547
+ Perplexity value
1548
+
1549
+ Raises:
1550
+ ValidationError: If no valid tokens are found (code E402)
1551
+ """
1552
+ device = _resolve_eval_device(model, device)
1553
+
1554
+ model.eval()
1555
+ nll_sum = 0.0
1556
+ tok_count = 0
1557
+ batch_count = 0
1558
+
1559
+ for i, batch in enumerate(dataloader):
1560
+ # Check max_samples limit
1561
+ if max_samples is not None and i >= max_samples:
1562
+ break
1563
+
1564
+ # Handle different batch formats
1565
+ if isinstance(batch, dict):
1566
+ input_ids = batch.get("input_ids", batch.get("inputs", None))
1567
+ labels = batch.get("labels", None)
1568
+ attention_mask = batch.get("attention_mask", None)
1569
+ elif isinstance(batch, tuple | list):
1570
+ input_ids = batch[0] if len(batch) > 0 else None
1571
+ labels = batch[1] if len(batch) > 1 else None
1572
+ attention_mask = batch[2] if len(batch) > 2 else None
1573
+ else:
1574
+ input_ids = batch
1575
+ labels = None
1576
+ attention_mask = None
1577
+
1578
+ if input_ids is None or not isinstance(input_ids, torch.Tensor):
1579
+ continue
1580
+
1581
+ input_ids = input_ids.to(device)
1582
+ attn = attention_mask.to(device) if attention_mask is not None else None
1583
+
1584
+ # Default causal labels
1585
+ if labels is None:
1586
+ labels = input_ids.clone()
1587
+ if attn is not None:
1588
+ labels[attn == 0] = -100
1589
+ else:
1590
+ labels = labels.to(device)
1591
+
1592
+ # Skip if sequence too short
1593
+ if input_ids.size(1) < 2:
1594
+ continue
1595
+
1596
+ # Forward pass - get logits
1597
+ try:
1598
+ outputs = model(input_ids=input_ids, attention_mask=attn, return_dict=True)
1599
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
1600
+ except Exception:
1601
+ # Fallback for non-standard models
1602
+ outputs = model(input_ids=input_ids, attention_mask=attn)
1603
+ if isinstance(outputs, tuple | list):
1604
+ logits = outputs[0]
1605
+ else:
1606
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs
1607
+
1608
+ # Causal shift for next-token prediction
1609
+ shift_logits = logits[:, :-1, :]
1610
+ shift_labels = labels[:, 1:]
1611
+ shift_mask = attn[:, 1:] if attn is not None else None
1612
+
1613
+ # Identify valid (non-padding) tokens
1614
+ valid = shift_labels != -100
1615
+ if shift_mask is not None:
1616
+ valid = valid & shift_mask.bool()
1617
+
1618
+ if not valid.any():
1619
+ continue
1620
+
1621
+ # Compute negative log-likelihood
1622
+ log_probs = shift_logits.log_softmax(dim=-1) # [B,T-1,V]
1623
+ tgt = shift_labels.clamp_min(0).unsqueeze(-1) # [B,T-1,1]
1624
+
1625
+ # MPS workaround: gather operation can fail on MPS, use CPU fallback
1626
+ if str(device).startswith("mps"):
1627
+ log_probs_cpu = log_probs.cpu()
1628
+ tgt_cpu = tgt.cpu()
1629
+ nll_cpu = -log_probs_cpu.gather(-1, tgt_cpu).squeeze(-1)
1630
+ nll = nll_cpu.to(device)
1631
+ else:
1632
+ nll = -log_probs.gather(-1, tgt).squeeze(-1) # [B,T-1]
1633
+
1634
+ # Accumulate only for valid tokens
1635
+ nll_sum += nll[valid].sum().item()
1636
+ tok_count += int(valid.sum().item())
1637
+ batch_count += 1
1638
+
1639
+ if tok_count == 0:
1640
+ raise ValidationError(
1641
+ code="E402",
1642
+ message="METRICS-VALIDATION-FAILED",
1643
+ details={
1644
+ "reason": (
1645
+ f"No valid tokens for perplexity computation after {batch_count} batches. "
1646
+ "All tokens were either padding or sequences were too short (<=1 token). "
1647
+ "Ensure your data contains sequences of at least 2 tokens."
1648
+ )
1649
+ },
1650
+ )
1651
+
1652
+ # Compute perplexity from average NLL
1653
+ avg_nll = nll_sum / tok_count
1654
+ ppl = float(math.exp(avg_nll))
1655
+
1656
+ # Sanity check
1657
+ if ppl < 1.0:
1658
+ logger.warning(
1659
+ f"Computed perplexity {ppl:.2f} is less than 1.0, setting to 1.0"
1660
+ )
1661
+ ppl = 1.0
1662
+ elif not math.isfinite(ppl):
1663
+ logger.warning(f"Computed perplexity is not finite: {ppl}")
1664
+ ppl = float("inf")
1665
+
1666
+ return ppl
1667
+
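# Illustrative usage sketch (editorial, not part of the package source): compute_perplexity
# accepts any iterable of dict batches carrying "input_ids"/"attention_mask" tensors.
# The model/tokenizer choice ("gpt2") and the toy texts are assumptions made for this sketch.
from transformers import AutoModelForCausalLM, AutoTokenizer

from invarlock.eval.metrics import compute_perplexity

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

texts = ["The quick brown fox jumps over the lazy dog.", "Perplexity measures surprise."]
enc = tokenizer(texts, return_tensors="pt", padding=True)
batches = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}]

ppl = compute_perplexity(model, batches, max_samples=10, device="cpu")
print(f"perplexity: {ppl:.2f}")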
1668
+
1669
+ # ── New Unified Evaluation Functions ──────────────────────────────────────
1670
+
1671
+
1672
+ @torch.no_grad()
1673
+ def compute_ppl(
1674
+ model: nn.Module,
1675
+ adapter: Any | None,
1676
+ window: Any, # EvaluationWindow
1677
+ device: str | torch.device | None = None,
1678
+ ) -> float:
1679
+ """
1680
+ Compute perplexity for a specific evaluation window.
1681
+
1682
+ This is the new unified evaluation function that works with EvaluationWindow objects
1683
+ from the data loading system.
1684
+
1685
+ Args:
1686
+ model: Language model to evaluate
1687
+ adapter: Model adapter (unused currently, for future extensibility)
1688
+ window: EvaluationWindow with tokenized samples
1689
+ device: Device to use for computation
1690
+
1691
+ Returns:
1692
+ Perplexity value for the window
1693
+ """
1694
+ device = _resolve_eval_device(model, device)
1695
+
1696
+ model.eval()
1697
+ nll_sum = 0.0
1698
+ tok_count = 0
1699
+
1700
+ # Process each sample in the window
1701
+ for input_ids, attention_mask in zip(
1702
+ window.input_ids, window.attention_masks, strict=False
1703
+ ):
1704
+ if not input_ids:
1705
+ continue
1706
+
1707
+ # Convert to tensors
1708
+ input_ids_tensor = (
1709
+ torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
1710
+ )
1711
+ attention_mask_tensor = (
1712
+ torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
1713
+ )
1714
+
1715
+ # Skip sequences that are too short
1716
+ if input_ids_tensor.size(1) < 2:
1717
+ continue
1718
+
1719
+ # Forward pass
1720
+ try:
1721
+ outputs = model(
1722
+ input_ids=input_ids_tensor,
1723
+ attention_mask=attention_mask_tensor,
1724
+ return_dict=True,
1725
+ )
1726
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
1727
+ except Exception:
1728
+ # Fallback for non-standard models
1729
+ outputs = model(
1730
+ input_ids=input_ids_tensor, attention_mask=attention_mask_tensor
1731
+ )
1732
+ if isinstance(outputs, tuple | list):
1733
+ logits = outputs[0]
1734
+ else:
1735
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs
1736
+
1737
+ # Causal shift for next-token prediction
1738
+ shift_logits = logits[:, :-1, :]
1739
+ shift_labels = input_ids_tensor[:, 1:]
1740
+ shift_mask = attention_mask_tensor[:, 1:]
1741
+
1742
+ # Identify valid (non-padding) tokens
1743
+ valid = (shift_labels != -100) & shift_mask.bool()
1744
+
1745
+ if not valid.any():
1746
+ continue
1747
+
1748
+ # Compute negative log-likelihood
1749
+ log_probs = shift_logits.log_softmax(dim=-1) # [B,T-1,V]
1750
+ tgt = shift_labels.clamp_min(0).unsqueeze(-1) # [B,T-1,1]
1751
+
1752
+ # Handle MPS device issues with gather
1753
+ if str(device).startswith("mps"):
1754
+ log_probs_cpu = log_probs.cpu()
1755
+ tgt_cpu = tgt.cpu()
1756
+ nll_cpu = -log_probs_cpu.gather(-1, tgt_cpu).squeeze(-1)
1757
+ nll = nll_cpu.to(device)
1758
+ else:
1759
+ nll = -log_probs.gather(-1, tgt).squeeze(-1) # [B,T-1]
1760
+
1761
+ # Accumulate only for valid tokens
1762
+ nll_sum += nll[valid].sum().item()
1763
+ tok_count += int(valid.sum().item())
1764
+
1765
+ if tok_count == 0:
1766
+ raise ValidationError(
1767
+ code="E402",
1768
+ message="METRICS-VALIDATION-FAILED",
1769
+ details={
1770
+ "reason": "No valid tokens for perplexity computation in evaluation window",
1771
+ },
1772
+ )
1773
+
1774
+ # Compute perplexity from average NLL
1775
+ avg_nll = nll_sum / tok_count
1776
+ ppl = float(math.exp(avg_nll))
1777
+
1778
+ # Sanity check
1779
+ if ppl < 1.0:
1780
+ logger.warning(
1781
+ f"Computed perplexity {ppl:.2f} is less than 1.0, setting to 1.0"
1782
+ )
1783
+ ppl = 1.0
1784
+ elif not math.isfinite(ppl):
1785
+ logger.warning(f"Computed perplexity is not finite: {ppl}")
1786
+ ppl = float("inf")
1787
+
1788
+ return ppl
1789
+
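# Illustrative usage sketch (editorial): compute_ppl only needs an object exposing parallel
# `input_ids` / `attention_masks` lists of per-sample token-id lists; the real EvaluationWindow
# comes from invarlock/eval/data.py. The ToyWindow stand-in and the token ids below are
# assumptions for the sketch, and `model` is the GPT-2 instance loaded in the sketch above.
from dataclasses import dataclass, field

from invarlock.eval.metrics import compute_ppl

@dataclass
class ToyWindow:
    input_ids: list[list[int]] = field(default_factory=list)
    attention_masks: list[list[int]] = field(default_factory=list)

window = ToyWindow(
    input_ids=[[464, 2068, 7586, 21831, 11687, 625, 262, 16931, 3290, 287, 262, 13990]],
    attention_masks=[[1] * 12],
)
ppl_window = compute_ppl(model, adapter=None, window=window, device="cpu")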
1790
+
1791
+ def measure_latency(
1792
+ model: nn.Module,
1793
+ window: Any, # EvaluationWindow
1794
+ device: str | torch.device | None = None,
1795
+ warmup_steps: int = 3,
1796
+ measurement_steps: int = 10,
1797
+ ) -> float:
1798
+ """
1799
+ Measure inference latency per token.
1800
+
1801
+ Args:
1802
+ model: Model to measure
1803
+ window: EvaluationWindow with samples to use for measurement
1804
+ device: Device to use for measurement
1805
+ warmup_steps: Number of warmup iterations
1806
+ measurement_steps: Number of measurement iterations
1807
+
1808
+ Returns:
1809
+ Average latency in milliseconds per token
1810
+ """
1811
+ if device is None:
1812
+ device = next(model.parameters()).device
1813
+ else:
1814
+ device = torch.device(device) if isinstance(device, str) else device
1815
+
1816
+ model.eval()
1817
+
1818
+ # Select a representative sample for timing
1819
+ if not window.input_ids:
1820
+ return 0.0
1821
+
1822
+ # Use the first valid sample
1823
+ sample_input_ids = None
1824
+ sample_attention_mask = None
1825
+
1826
+ for input_ids, attention_mask in zip(
1827
+ window.input_ids, window.attention_masks, strict=False
1828
+ ):
1829
+ if len(input_ids) > 10: # Ensure reasonable length
1830
+ sample_input_ids = (
1831
+ torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
1832
+ )
1833
+ sample_attention_mask = (
1834
+ torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
1835
+ )
1836
+ break
1837
+
1838
+ if sample_input_ids is None:
1839
+ return 0.0
1840
+
1841
+ # Warmup
1842
+ with torch.no_grad():
1843
+ for _ in range(warmup_steps):
1844
+ try:
1845
+ _ = model(
1846
+ input_ids=sample_input_ids, attention_mask=sample_attention_mask
1847
+ )
1848
+ except Exception:
1849
+ # If there are issues with the model, return 0
1850
+ return 0.0
1851
+
1852
+ # Synchronize for accurate timing
1853
+ if device.type == "cuda":
1854
+ torch.cuda.synchronize()
1855
+
1856
+ # Measure latency
1857
+ start_time = time.time()
1858
+
1859
+ with torch.no_grad():
1860
+ for _ in range(measurement_steps):
1861
+ _ = model(input_ids=sample_input_ids, attention_mask=sample_attention_mask)
1862
+
1863
+ if device.type == "cuda":
1864
+ torch.cuda.synchronize()
1865
+
1866
+ end_time = time.time()
1867
+
1868
+ # Calculate per-token latency
1869
+ total_time_ms = (end_time - start_time) * 1000 # Convert to milliseconds
1870
+ total_tokens = int(sample_attention_mask.sum().item()) * measurement_steps
1871
+
1872
+ if total_tokens == 0:
1873
+ return 0.0
1874
+
1875
+ latency_ms_per_token = total_time_ms / total_tokens
1876
+
1877
+ logger.debug(
1878
+ f"Measured latency: {latency_ms_per_token:.3f} ms/token over {measurement_steps} steps"
1879
+ )
1880
+ return latency_ms_per_token
1881
+
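# Illustrative usage sketch (editorial): the per-token figure is total forward time divided
# by (valid tokens in the chosen sample × measurement_steps); samples of ten tokens or fewer
# are skipped, in which case the function falls back to 0.0. Reuses `model` and `window`
# from the sketches above (assumptions, not package fixtures).
from invarlock.eval.metrics import measure_latency

ms_per_token = measure_latency(
    model, window, device="cpu", warmup_steps=3, measurement_steps=20
)
if ms_per_token > 0:
    print(f"{ms_per_token:.3f} ms/token ≈ {1000.0 / ms_per_token:.0f} tokens/s")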
1882
+
1883
+ def measure_memory(
1884
+ model: nn.Module,
1885
+ window: Any, # EvaluationWindow
1886
+ device: str | torch.device | None = None,
1887
+ ) -> float:
1888
+ """
1889
+ Measure peak memory usage during inference.
1890
+
1891
+ Args:
1892
+ model: Model to measure
1893
+ window: EvaluationWindow with samples to use for measurement
1894
+ device: Device to measure memory on
1895
+
1896
+ Returns:
1897
+ Peak memory usage in MB
1898
+ """
1899
+ if device is None:
1900
+ device = next(model.parameters()).device
1901
+ else:
1902
+ device = torch.device(device) if isinstance(device, str) else device
1903
+
1904
+ model.eval()
1905
+
1906
+ # Get baseline memory
1907
+ if device.type == "cuda":
1908
+ torch.cuda.empty_cache()
1909
+ baseline_memory = torch.cuda.memory_allocated() / (1024 * 1024)
1910
+ torch.cuda.reset_peak_memory_stats()
1911
+ else:
1912
+ # For CPU/MPS, use psutil for system memory
1913
+ import psutil
1914
+
1915
+ process = psutil.Process()
1916
+ baseline_memory = process.memory_info().rss / (1024 * 1024)
1917
+
1918
+ # Run inference on a few samples to measure memory
1919
+ max_memory = baseline_memory
1920
+
1921
+ with torch.no_grad():
1922
+ for i, (input_ids, attention_mask) in enumerate(
1923
+ zip(window.input_ids, window.attention_masks, strict=False)
1924
+ ):
1925
+ if i >= 5: # Only measure on first 5 samples
1926
+ break
1927
+
1928
+ if not input_ids:
1929
+ continue
1930
+
1931
+ try:
1932
+ input_ids_tensor = (
1933
+ torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
1934
+ )
1935
+ attention_mask_tensor = (
1936
+ torch.tensor(attention_mask, dtype=torch.long)
1937
+ .unsqueeze(0)
1938
+ .to(device)
1939
+ )
1940
+
1941
+ _ = model(
1942
+ input_ids=input_ids_tensor, attention_mask=attention_mask_tensor
1943
+ )
1944
+
1945
+ # Measure memory after forward pass
1946
+ if device.type == "cuda":
1947
+ current_memory = torch.cuda.memory_allocated() / (1024 * 1024)
1948
+ else:
1949
+ current_memory = process.memory_info().rss / (1024 * 1024)
1950
+
1951
+ max_memory = max(max_memory, current_memory)
1952
+
1953
+ except Exception as e:
1954
+ logger.debug(f"Memory measurement failed for sample {i}: {e}")
1955
+ continue
1956
+
1957
+ peak_memory_mb = max_memory
1958
+ logger.debug(f"Peak memory usage: {peak_memory_mb:.1f} MB")
1959
+
1960
+ return peak_memory_mb
1961
+
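# Illustrative usage sketch (editorial): on CUDA the peak comes from
# torch.cuda.memory_allocated() (tensor allocations only), while on CPU/MPS it is the
# whole-process RSS reported by psutil, so figures from different device types are not
# directly comparable. Reuses `model` and `window` from the sketches above.
from invarlock.eval.metrics import measure_memory

peak_mb = measure_memory(model, window, device="cpu")
print(f"peak memory during inference: {peak_mb:.1f} MB")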
1962
+
1963
+ def compute_parameter_deltas(
1964
+ model_before: nn.Module, model_after: nn.Module, adapter: Any | None = None
1965
+ ) -> dict[str, Any]:
1966
+ """
1967
+ Compute precise parameter deltas between before and after models.
1968
+
1969
+ Args:
1970
+ model_before: Model state before edit
1971
+ model_after: Model state after edit
1972
+ adapter: Model adapter for architecture-specific analysis
1973
+
1974
+ Returns:
1975
+ Dictionary with parameter delta information:
1976
+ - params_changed: Number of parameters that were modified
1977
+ - layers_modified: Number of layers that were changed
1978
+ - sparsity: Overall sparsity ratio (if applicable)
1979
+ """
1980
+ deltas = {
1981
+ "params_changed": 0,
1982
+ "layers_modified": 0,
1983
+ "sparsity": None,
1984
+ }
1985
+
1986
+ try:
1987
+ # Compare parameters
1988
+ before_params = dict(model_before.named_parameters())
1989
+ after_params = dict(model_after.named_parameters())
1990
+
1991
+ modified_layers = set()
1992
+ total_changed = 0
1993
+
1994
+ for name, before_param in before_params.items():
1995
+ if name not in after_params:
1996
+ continue
1997
+
1998
+ after_param = after_params[name]
1999
+
2000
+ # Check if parameter changed
2001
+ if not torch.equal(before_param.data, after_param.data):
2002
+ total_changed += before_param.numel()
2003
+
2004
+ # Extract layer information from parameter name
2005
+ layer_match = None
2006
+ if ".h." in name or ".layers." in name:
2007
+ # Extract layer number for transformer models
2008
+ import re
2009
+
2010
+ match = re.search(r"\.(?:h|layers)\.(\d+)\.", name)
2011
+ if match:
2012
+ layer_match = int(match.group(1))
2013
+ modified_layers.add(layer_match)
2014
+
2015
+ deltas["params_changed"] = total_changed
2016
+ deltas["layers_modified"] = len(modified_layers)
2017
+
2018
+ # Structural deltas (like head/neuron counts) are not tracked in this profile
2019
+
2020
+ # Compute overall sparsity if applicable
2021
+ total_params_before = sum(p.numel() for p in model_before.parameters())
2022
+ total_params_after = sum(p.numel() for p in model_after.parameters())
2023
+
2024
+ if total_params_after < total_params_before:
2025
+ deltas["sparsity"] = 1.0 - (total_params_after / total_params_before)
2026
+
2027
+ except Exception as e:
2028
+ logger.warning(f"Parameter delta computation failed: {e}")
2029
+
2030
+ return deltas
2031
+
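# Illustrative usage sketch (editorial): the delta report counts every element of each
# parameter tensor that differs between the two snapshots and the transformer blocks
# (".h.<i>." / ".layers.<i>.") those tensors belong to. The toy "edit" below is an
# assumption, applied to the GPT-2 model loaded in the earlier sketch.
import copy
import torch

from invarlock.eval.metrics import compute_parameter_deltas

model_before = copy.deepcopy(model)  # snapshot prior to the edit
with torch.no_grad():
    # hypothetical edit: zero out tiny weights in one projection of block 0
    w = model.transformer.h[0].mlp.c_fc.weight
    w[w.abs() < 1e-3] = 0.0

deltas = compute_parameter_deltas(model_before, model, adapter=None)
print(deltas)  # e.g. {'params_changed': ..., 'layers_modified': 1, 'sparsity': None}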
2032
+
2033
+ def analyze_spectral_changes(
2034
+ model_before: nn.Module, model_after: nn.Module, scope: str = "ffn"
2035
+ ) -> dict[str, Any]:
2036
+ """
2037
+ Analyze spectral norm changes between model states.
2038
+
2039
+ Args:
2040
+ model_before: Model before edit
2041
+ model_after: Model after edit
2042
+ scope: Scope for spectral analysis ("ffn", "all")
2043
+
2044
+ Returns:
2045
+ Dictionary with spectral analysis results
2046
+ """
2047
+ try:
2048
+ # Import spectral analysis if available
2049
+ from invarlock.guards.spectral import compute_spectral_norms
2050
+
2051
+ before_norms = compute_spectral_norms(model_before, scope=scope)
2052
+ after_norms = compute_spectral_norms(model_after, scope=scope)
2053
+
2054
+ # Compute changes
2055
+ changes = {}
2056
+ for layer_name in before_norms:
2057
+ if layer_name in after_norms:
2058
+ before_norm = before_norms[layer_name]
2059
+ after_norm = after_norms[layer_name]
2060
+ change_ratio = after_norm / before_norm if before_norm > 0 else 1.0
2061
+ changes[layer_name] = {
2062
+ "before": before_norm,
2063
+ "after": after_norm,
2064
+ "ratio": change_ratio,
2065
+ }
2066
+
2067
+ # Summary statistics
2068
+ ratios = [change["ratio"] for change in changes.values()]
2069
+ summary = {
2070
+ "layer_changes": changes,
2071
+ "mean_ratio": float(np.mean(ratios)) if ratios else 1.0,
2072
+ "max_ratio": float(np.max(ratios)) if ratios else 1.0,
2073
+ "min_ratio": float(np.min(ratios)) if ratios else 1.0,
2074
+ "layers_analyzed": len(changes),
2075
+ }
2076
+
2077
+ return summary
2078
+
2079
+ except ImportError:
2080
+ logger.debug("Spectral analysis not available")
2081
+ return {"error": "spectral_analysis_unavailable"}
2082
+ except Exception as e:
2083
+ logger.warning(f"Spectral analysis failed: {e}")
2084
+ return {"error": str(e)}
2085
+
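# Illustrative usage sketch (editorial): analyze_spectral_changes degrades gracefully,
# returning an {"error": ...} dict when the optional guards module is unavailable, so
# callers should branch on that key before reading the summary fields. Reuses
# `model_before`/`model` from the sketch above.
from invarlock.eval.metrics import analyze_spectral_changes

spectral = analyze_spectral_changes(model_before, model, scope="ffn")
if "error" in spectral:
    print(f"spectral analysis skipped: {spectral['error']}")
else:
    print(
        f"mean spectral-norm ratio {spectral['mean_ratio']:.3f} "
        f"over {spectral['layers_analyzed']} layers"
    )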
2086
+
2087
+ def analyze_rmt_changes(
2088
+ model_before: nn.Module, model_after: nn.Module
2089
+ ) -> dict[str, Any]:
2090
+ """
2091
+ Analyze RMT (Random Matrix Theory) changes between model states.
2092
+
2093
+ Args:
2094
+ model_before: Model before edit
2095
+ model_after: Model after edit
2096
+
2097
+ Returns:
2098
+ Dictionary with RMT analysis results
2099
+ """
2100
+ try:
2101
+ # Import RMT analysis if available
2102
+ from invarlock.guards.rmt import compute_mp_stats
2103
+
2104
+ before_stats = compute_mp_stats(model_before)
2105
+ after_stats = compute_mp_stats(model_after)
2106
+
2107
+ # Analyze changes in MP statistics
2108
+ changes = {}
2109
+ for layer_name in before_stats:
2110
+ if layer_name in after_stats:
2111
+ before_mp = before_stats[layer_name]
2112
+ after_mp = after_stats[layer_name]
2113
+ changes[layer_name] = {
2114
+ "before": before_mp,
2115
+ "after": after_mp,
2116
+ "stable": abs(before_mp - after_mp) < 0.1, # Stability threshold
2117
+ }
2118
+
2119
+ # Count stable vs unstable layers
2120
+ stable_count = sum(
2121
+ 1 for change in changes.values() if change.get("stable", False)
2122
+ )
2123
+ total_count = len(changes)
2124
+
2125
+ summary = {
2126
+ "layer_changes": changes,
2127
+ "stable_layers": stable_count,
2128
+ "total_layers": total_count,
2129
+ "stability_ratio": stable_count / total_count if total_count > 0 else 0.0,
2130
+ }
2131
+
2132
+ return summary
2133
+
2134
+ except ImportError:
2135
+ logger.debug("RMT analysis not available")
2136
+ return {"error": "rmt_analysis_unavailable"}
2137
+ except Exception as e:
2138
+ logger.warning(f"RMT analysis failed: {e}")
2139
+ return {"error": str(e)}
2140
+
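# Illustrative usage sketch (editorial): analyze_rmt_changes reports a per-layer MP
# statistic before/after plus a boolean "stable" flag (|Δ| < 0.1, per the threshold in
# the code above) and a whole-model stability_ratio; it too falls back to an
# {"error": ...} dict when the guards module is missing. Reuses `model_before`/`model`.
from invarlock.eval.metrics import analyze_rmt_changes

rmt = analyze_rmt_changes(model_before, model)
if "error" not in rmt:
    print(
        f"{rmt['stable_layers']}/{rmt['total_layers']} layers stable "
        f"(ratio {rmt['stability_ratio']:.2f})"
    )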
2141
+
2142
+ # ── Integration with existing system ───────────────────────────────────────
2143
+
2144
+ # Update exports to include new functions (add to existing __all__ if it exists)
2145
+ try:
2146
+ __all__.extend(
2147
+ [
2148
+ "bootstrap_confidence_interval",
2149
+ "compute_ppl",
2150
+ "measure_latency",
2151
+ "measure_memory",
2152
+ "compute_parameter_deltas",
2153
+ "analyze_spectral_changes",
2154
+ "analyze_rmt_changes",
2155
+ ]
2156
+ )
2157
+ except NameError:
2158
+ # If __all__ doesn't exist, create it with the new functions
2159
+ __all__ = [
2160
+ "bootstrap_confidence_interval",
2161
+ "compute_ppl",
2162
+ "measure_latency",
2163
+ "measure_memory",
2164
+ "compute_parameter_deltas",
2165
+ "analyze_spectral_changes",
2166
+ "analyze_rmt_changes",
2167
+ ]