invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2041 @@
1
+ """
2
+ InvarLock Core Runner
3
+ =====================
4
+
5
+ Main pipeline execution orchestrator: prepare → edit → guards → eval → finalize/rollback.
6
+ Torch-independent coordination with proper event logging and checkpoint management.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ import math
13
+ import os
14
+ import time
15
+ from array import array
16
+ from collections.abc import Sequence
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ from .api import Guard, ModelAdapter, ModelEdit, RunConfig, RunReport
22
+ from .auto_tuning import resolve_tier_policies
23
+ from .bootstrap import (
24
+ compute_logloss_ci,
25
+ compute_paired_delta_log_ci,
26
+ logspace_to_ratio_ci,
27
+ )
28
+ from .checkpoint import CheckpointManager
29
+ from .events import EventLogger
30
+ from .types import LogLevel, RunStatus
31
+
32
+ BOOTSTRAP_COVERAGE_REQUIREMENTS = {
33
+ # Minimum window counts and bootstrap replicates expected per policy tier.
34
+ # Individual configs can request more aggressive settings, but these values
35
+ # represent the guard-rail floor that CI profiles should maintain.
36
+ "conservative": {"preview": 220, "final": 220, "replicates": 1500},
37
+ "balanced": {"preview": 180, "final": 180, "replicates": 1200},
38
+ "aggressive": {"preview": 140, "final": 140, "replicates": 800},
39
+ }
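# Illustrative lookup of a tier's coverage floor, using the same fallback to
# "balanced" applied later in _compute_real_metrics:
#
#     floors = BOOTSTRAP_COVERAGE_REQUIREMENTS.get(tier, BOOTSTRAP_COVERAGE_REQUIREMENTS["balanced"])
#     # e.g. floors["preview"] == 180 and floors["replicates"] == 1200 for "balanced"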
40
+
41
+ __all__ = ["CoreRunner"]
42
+
43
+
44
+ def _collect_cuda_flags() -> dict[str, Any]:
45
+ """Capture deterministic CUDA configuration for provenance."""
46
+ flags: dict[str, Any] = {}
47
+ try:
48
+ import torch
49
+
50
+ flags["deterministic_algorithms"] = bool(
51
+ torch.are_deterministic_algorithms_enabled()
52
+ )
53
+ if hasattr(torch.backends, "cudnn"):
54
+ flags["cudnn_deterministic"] = bool(torch.backends.cudnn.deterministic)
55
+ flags["cudnn_benchmark"] = bool(torch.backends.cudnn.benchmark)
56
+ if hasattr(torch.backends.cudnn, "allow_tf32"):
57
+ flags["cudnn_allow_tf32"] = bool(torch.backends.cudnn.allow_tf32)
58
+ if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
59
+ matmul = torch.backends.cuda.matmul
60
+ if hasattr(matmul, "allow_tf32"):
61
+ flags["cuda_matmul_allow_tf32"] = bool(matmul.allow_tf32)
62
+ except Exception: # pragma: no cover - fallback when torch missing
63
+ pass
64
+
65
+ workspace = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
66
+ if workspace:
67
+ flags["CUBLAS_WORKSPACE_CONFIG"] = workspace
68
+ return flags
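# Illustrative return value on a deterministically configured CUDA host (values
# are examples only, not measured output):
#
#     {"deterministic_algorithms": True, "cudnn_deterministic": True,
#      "cudnn_benchmark": False, "cudnn_allow_tf32": False,
#      "cuda_matmul_allow_tf32": False, "CUBLAS_WORKSPACE_CONFIG": ":4096:8"}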
69
+
70
+
71
+ class CoreRunner:
72
+ """
73
+ Core pipeline execution orchestrator.
74
+
75
+ Coordinates the full InvarLock pipeline while maintaining torch-independence
76
+ in the core coordination logic. Provides event logging, checkpointing,
77
+ and rollback capabilities.
78
+ """
79
+
80
+ def __init__(self):
81
+ self.event_logger: EventLogger | None = None
82
+ self.checkpoint_manager: CheckpointManager | None = None
83
+
84
+ def execute(
85
+ self,
86
+ model: Any,
87
+ adapter: ModelAdapter,
88
+ edit: ModelEdit,
89
+ guards: list[Guard],
90
+ config: RunConfig,
91
+ calibration_data: Any = None,
92
+ auto_config: dict[str, Any] | None = None,
93
+ edit_config: dict[str, Any] | None = None,
94
+ preview_n: int | None = None,
95
+ final_n: int | None = None,
96
+ ) -> RunReport:
97
+ """
98
+ Execute the full InvarLock pipeline.
99
+
100
+ Args:
101
+ model: The model to process
102
+ adapter: Model adapter for model-specific operations
103
+ edit: Edit to apply
104
+ guards: Safety guards to run
105
+ config: Runtime configuration
106
+ calibration_data: Optional calibration/validation data
107
+ auto_config: Optional auto-tuning configuration
+ edit_config: Optional edit-specific kwargs forwarded to edit.apply
+ preview_n: Optional number of preview evaluation windows
+ final_n: Optional number of final evaluation windows
108
+
109
+ Returns:
110
+ RunReport with execution results
111
+ """
112
+ # Initialize services
113
+ self._initialize_services(config)
114
+
115
+ # Create report
116
+ report = RunReport()
117
+ report.meta["cuda_flags"] = _collect_cuda_flags()
118
+ report.meta["start_time"] = time.time()
119
+ report.meta["config"] = self._serialize_config(config)
120
+ if config.context:
121
+ try:
122
+ report.context.update(config.context)
123
+ except Exception:
124
+ # Defensive: ensure context remains a dict even if update fails
125
+ report.context = dict(config.context)
126
+
127
+ if isinstance(config.context, dict):
128
+ run_id = config.context.get("run_id")
129
+ if run_id:
130
+ report.meta["run_id"] = run_id
131
+ plugins_meta = config.context.get("plugins")
132
+ if plugins_meta:
133
+ report.meta["plugins"] = plugins_meta
134
+
135
+ # Store auto configuration for tier resolution
136
+ if auto_config:
137
+ report.meta["auto"] = auto_config
138
+
139
+ report.status = RunStatus.RUNNING.value
140
+
141
+ try:
142
+ # Log start
143
+ self._log_event(
144
+ "runner",
145
+ "start",
146
+ LogLevel.INFO,
147
+ {
148
+ "edit": edit.name,
149
+ "guards": [g.name for g in guards],
150
+ "context": report.context,
151
+ },
152
+ )
153
+
154
+ # Phase 1: Prepare (describe model, create checkpoint)
155
+ model_desc = self._prepare_phase(model, adapter, report)
156
+
157
+ # Phase 2: Prepare guards (must happen before edit)
158
+ self._prepare_guards_phase(
159
+ model, adapter, guards, calibration_data, report, auto_config
160
+ )
161
+
162
+ # Phase 3: Apply edit
163
+ self._edit_phase(model, adapter, edit, model_desc, report, edit_config)
164
+
165
+ # Phase 4: Run guards
166
+ guard_results = self._guard_phase(model, adapter, guards, report)
167
+
168
+ # Phase 5: Evaluate final metrics
169
+ metrics = self._eval_phase(
170
+ model,
171
+ adapter,
172
+ calibration_data,
173
+ report,
174
+ preview_n,
175
+ final_n,
176
+ config,
177
+ )
178
+
179
+ # Phase 6: Finalize or rollback
180
+ final_status = self._finalize_phase(
181
+ model, adapter, guard_results, metrics, config, report
182
+ )
183
+
184
+ report.status = final_status
185
+ report.meta["end_time"] = time.time()
186
+ report.meta["duration"] = (
187
+ report.meta["end_time"] - report.meta["start_time"]
188
+ )
189
+
190
+ self._log_event(
191
+ "runner",
192
+ "complete",
193
+ LogLevel.INFO,
194
+ {"status": final_status, "duration": report.meta["duration"]},
195
+ )
196
+
197
+ return report
198
+
199
+ except Exception as e:
200
+ self._handle_error(e, report)
201
+ return report
202
+
203
+ finally:
204
+ self._cleanup_services()
205
+
206
+ def _initialize_services(self, config: RunConfig) -> None:
207
+ """Initialize event logging and checkpoint services."""
208
+ if config.event_path:
209
+ run_id = None
210
+ if isinstance(config.context, dict):
211
+ run_id = config.context.get("run_id")
212
+ self.event_logger = EventLogger(config.event_path, run_id=run_id)
213
+
214
+ if config.checkpoint_interval > 0:
215
+ self.checkpoint_manager = CheckpointManager()
216
+
217
+ def _cleanup_services(self) -> None:
218
+ """Clean up services."""
219
+ if self.event_logger:
220
+ self.event_logger.close()
221
+
222
+ if self.checkpoint_manager:
223
+ self.checkpoint_manager.cleanup()
224
+
225
+ def _prepare_phase(
226
+ self, model: Any, adapter: ModelAdapter, report: RunReport
227
+ ) -> dict[str, Any]:
228
+ """Phase 1: Model preparation and analysis."""
229
+ self._log_event("prepare", "start", LogLevel.INFO)
230
+
231
+ # Describe model structure
232
+ model_desc = adapter.describe(model)
233
+ report.meta["model"] = model_desc
234
+
235
+ # Create initial checkpoint
236
+ if self.checkpoint_manager:
237
+ checkpoint_id = self.checkpoint_manager.create_checkpoint(model, adapter)
238
+ report.meta["initial_checkpoint"] = checkpoint_id
239
+ self._log_event(
240
+ "prepare", "checkpoint_created", LogLevel.INFO, {"id": checkpoint_id}
241
+ )
242
+
243
+ self._log_event(
244
+ "prepare",
245
+ "complete",
246
+ LogLevel.INFO,
247
+ {"layers": model_desc.get("n_layer", 0)},
248
+ )
249
+
250
+ return model_desc
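# report.meta["model"] holds whatever adapter.describe() returns; the key read
# directly here is "n_layer" (for logging), e.g. a hypothetical
# {"n_layer": 12, "d_model": 768}, and the full description is later passed to
# edit.can_edit().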
251
+
252
+ def _edit_phase(
253
+ self,
254
+ model: Any,
255
+ adapter: ModelAdapter,
256
+ edit: ModelEdit,
257
+ model_desc: dict[str, Any],
258
+ report: RunReport,
259
+ edit_config: dict[str, Any] | None = None,
260
+ ) -> dict[str, Any]:
261
+ """Phase 2: Apply edit operation."""
262
+ edit_label = "baseline" if edit.name == "baseline" else edit.name
263
+ self._log_event("edit", "start", LogLevel.INFO, {"edit": edit_label})
264
+
265
+ # Store edit name for tier resolution
266
+ report.meta["edit_name"] = edit.name
267
+
268
+ # Check if edit can be applied
269
+ if not edit.can_edit(model_desc):
270
+ raise ValueError(f"Edit '{edit.name}' cannot be applied to this model")
271
+
272
+ # Apply edit with configuration parameters
273
+ if edit_config:
274
+ edit_result = edit.apply(model, adapter, **edit_config)
275
+ else:
276
+ edit_result = edit.apply(model, adapter)
277
+ report.edit = edit_result
278
+ if not isinstance(report.context, dict):
279
+ report.context = {}
280
+ edit_context = report.context.setdefault("edit", {})
281
+ if isinstance(edit_result, dict):
282
+ edit_context.setdefault("name", edit_result.get("name", edit.name))
283
+ deltas = edit_result.get("deltas") or {}
284
+ if isinstance(deltas, dict):
285
+ edit_context["params_changed"] = deltas.get("params_changed", 0)
286
+ edit_context["layers_modified"] = deltas.get("layers_modified", 0)
287
+ else:
288
+ edit_context.setdefault("params_changed", 0)
289
+ else:
290
+ edit_context.setdefault("name", edit.name)
291
+ edit_context.setdefault("params_changed", 0)
292
+
293
+ self._log_event(
294
+ "edit",
295
+ "complete",
296
+ LogLevel.INFO,
297
+ {"edit": edit.name, "result": edit_result},
298
+ )
299
+
300
+ return edit_result
301
+
302
+ def _prepare_guards_phase(
303
+ self,
304
+ model: Any,
305
+ adapter: ModelAdapter,
306
+ guards: list[Guard],
307
+ calibration_data: Any,
308
+ report: RunReport,
309
+ auto_config: dict[str, Any] | None = None,
310
+ ) -> None:
311
+ """Phase 2: Prepare safety guards with tier-resolved policies."""
312
+ self._log_event(
313
+ "guards_prepare", "start", LogLevel.INFO, {"count": len(guards)}
314
+ )
315
+
316
+ # Resolve tier policies before guard preparation
317
+ tier_policies = self._resolve_guard_policies(report, auto_config)
318
+
319
+ for guard in guards:
320
+ self._log_event(
321
+ "guard_prepare", "start", LogLevel.INFO, {"guard": guard.name}
322
+ )
323
+
324
+ try:
325
+ guard_policy: dict[str, Any] = tier_policies.get(guard.name, {})
326
+
327
+ # Apply tier-resolved policy to guard
328
+ if guard_policy:
329
+ self._apply_guard_policy(guard, guard_policy)
330
+ self._log_event(
331
+ "guard_prepare",
332
+ "policy_applied",
333
+ LogLevel.INFO,
334
+ {"guard": guard.name, "policy": guard_policy},
335
+ )
336
+
337
+ if hasattr(guard, "set_run_context"):
338
+ try:
339
+ guard.set_run_context(report)
340
+ except Exception as exc:
341
+ self._log_event(
342
+ "guard_prepare",
343
+ "context_error",
344
+ LogLevel.WARNING,
345
+ {"guard": guard.name, "error": str(exc)},
346
+ )
347
+
348
+ # Call prepare method if it exists (most guards need this)
349
+ if hasattr(guard, "prepare"):
350
+ prepare_result = guard.prepare(
351
+ model, adapter, calibration_data, guard_policy
352
+ )
353
+ self._log_event(
354
+ "guard_prepare",
355
+ "complete",
356
+ LogLevel.INFO,
357
+ {
358
+ "guard": guard.name,
359
+ "ready": prepare_result.get("ready", False),
360
+ },
361
+ )
362
+ else:
363
+ self._log_event(
364
+ "guard_prepare",
365
+ "skipped",
366
+ LogLevel.INFO,
367
+ {"guard": guard.name, "reason": "no_prepare_method"},
368
+ )
369
+
370
+ except Exception as e:
371
+ self._log_event(
372
+ "guard_prepare",
373
+ "error",
374
+ LogLevel.ERROR,
375
+ {"guard": guard.name, "error": str(e)},
376
+ )
377
+
378
+ # Store resolved policies in report for certificate
379
+ report.meta["tier_policies"] = tier_policies
380
+
381
+ self._log_event(
382
+ "guards_prepare", "complete", LogLevel.INFO, {"count": len(guards)}
383
+ )
384
+
385
+ def _guard_phase(
386
+ self, model: Any, adapter: ModelAdapter, guards: list[Guard], report: RunReport
387
+ ) -> dict[str, dict[str, Any]]:
388
+ """Phase 4: Run safety guards."""
389
+ self._log_event("guards", "start", LogLevel.INFO, {"count": len(guards)})
390
+
391
+ guard_results = {}
392
+
393
+ for guard in guards:
394
+ self._log_event("guard", "start", LogLevel.INFO, {"guard": guard.name})
395
+
396
+ if hasattr(guard, "set_run_context"):
397
+ try:
398
+ guard.set_run_context(report)
399
+ except Exception as exc: # pragma: no cover - defensive
400
+ self._log_event(
401
+ "guard",
402
+ "context_error",
403
+ LogLevel.WARNING,
404
+ {"guard": guard.name, "error": str(exc)},
405
+ )
406
+
407
+ try:
408
+ result = guard.validate(model, adapter, report.context)
409
+ guard_results[guard.name] = result
410
+
411
+ # Log guard result
412
+ status = "passed" if result.get("passed", False) else "failed"
413
+ self._log_event(
414
+ "guard",
415
+ "complete",
416
+ LogLevel.INFO,
417
+ {"guard": guard.name, "status": status},
418
+ )
419
+
420
+ except Exception as e:
421
+ guard_results[guard.name] = {"passed": False, "error": str(e)}
422
+ self._log_event(
423
+ "guard",
424
+ "error",
425
+ LogLevel.ERROR,
426
+ {"guard": guard.name, "error": str(e)},
427
+ )
428
+
429
+ report.guards = guard_results
430
+
431
+ # Summary
432
+ passed_guards = sum(1 for r in guard_results.values() if r.get("passed", False))
433
+ self._log_event(
434
+ "guards",
435
+ "complete",
436
+ LogLevel.INFO,
437
+ {"total": len(guards), "passed": passed_guards},
438
+ )
439
+
440
+ return guard_results
441
+
442
+ def _eval_phase(
443
+ self,
444
+ model: Any,
445
+ adapter: ModelAdapter,
446
+ calibration_data: Any,
447
+ report: RunReport,
448
+ preview_n: int | None = None,
449
+ final_n: int | None = None,
450
+ config: RunConfig | None = None,
451
+ ) -> dict[str, Any]:
452
+ """Phase 4: Final evaluation metrics."""
453
+ self._log_event("eval", "start", LogLevel.INFO)
454
+
455
+ if calibration_data is not None:
456
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
457
+ length_hint = None
458
+ try:
459
+ length_hint = len(calibration_data) # type: ignore[arg-type]
460
+ except Exception: # pragma: no cover - defensive
461
+ length_hint = None
462
+ first_batch = None
463
+ indexable = hasattr(calibration_data, "__getitem__")
464
+ if isinstance(calibration_data, list | tuple):
465
+ if calibration_data:
466
+ first_batch = calibration_data[0]
467
+ elif indexable:
468
+ try:
469
+ first_batch = calibration_data[0] # type: ignore[index]
470
+ except Exception: # pragma: no cover - defensive
471
+ first_batch = None
472
+ masked_preview = None
473
+ first_keys = None
474
+ if isinstance(first_batch, dict):
475
+ first_keys = list(first_batch.keys())
476
+ labels_preview = first_batch.get("labels")
477
+ if isinstance(labels_preview, list | tuple):
478
+ try:
479
+ masked_preview = sum(
480
+ 1 for tok in labels_preview if tok != -100
481
+ )
482
+ except Exception: # pragma: no cover - defensive
483
+ masked_preview = None
484
+ self._log_event(
485
+ "eval",
486
+ "calibration_snapshot",
487
+ LogLevel.DEBUG,
488
+ {
489
+ "calibration_type": type(calibration_data).__name__,
490
+ "length_hint": length_hint,
491
+ "indexable": bool(indexable),
492
+ "first_batch_keys": first_keys,
493
+ "first_batch_masked": masked_preview,
494
+ },
495
+ )
496
+ # Compute real perplexity using calibration data
497
+ metrics, eval_windows = self._compute_real_metrics(
498
+ model,
499
+ calibration_data,
500
+ adapter,
501
+ preview_n,
502
+ final_n,
503
+ config,
504
+ )
505
+ else:
506
+ # Fallback to mock metrics if no calibration data
507
+ self._log_event(
508
+ "eval",
509
+ "warning",
510
+ LogLevel.WARNING,
511
+ {"message": "No calibration data provided, using mock metrics"},
512
+ )
513
+ # Provide a minimal primary_metric snapshot and basic perf counters
514
+ metrics = {
515
+ "primary_metric": {
516
+ "kind": "ppl_causal",
517
+ "preview": 25.0,
518
+ "final": 26.0,
519
+ },
520
+ "latency_ms_per_tok": 15.0,
521
+ "memory_mb_peak": 1024.0,
522
+ }
523
+ eval_windows = {"preview": {}, "final": {}}
524
+
525
+ # Store metrics in report
526
+ if hasattr(report, "metrics"):
527
+ report.metrics.update(metrics)
528
+ else:
529
+ report.metrics = metrics
530
+
531
+ report.evaluation_windows = eval_windows
532
+
533
+ self._log_event("eval", "complete", LogLevel.INFO, {"metrics": metrics})
534
+
535
+ return metrics
536
+
537
+ def _compute_real_metrics(
538
+ self,
539
+ model: Any,
540
+ calibration_data: Any,
541
+ adapter: ModelAdapter,
542
+ preview_n: int | None = None,
543
+ final_n: int | None = None,
544
+ config: RunConfig | None = None,
545
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
546
+ """Compute real evaluation metrics using calibration data."""
547
+ import os
548
+
549
+ import psutil
550
+ import torch
551
+
552
+ _ = adapter # Adapter kept for API parity; direct HF forward used.
553
+
554
+ model.eval()
555
+
556
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
557
+ print(
558
+ f"[debug] compute_real_metrics preview_n={preview_n} final_n={final_n} calibration_len={len(calibration_data) if hasattr(calibration_data, '__len__') else 'n/a'}"
559
+ )
560
+ device = next(model.parameters()).device
561
+
562
+ eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
563
+ if eval_device_override:
564
+ override_device = torch.device(eval_device_override)
565
+ if override_device != device:
566
+ model.to(override_device)
567
+ device = override_device
568
+
569
+ process = psutil.Process(os.getpid())
570
+ initial_memory = process.memory_info().rss / 1024 / 1024 # MB
571
+
572
+ total_available = (
573
+ len(calibration_data) if hasattr(calibration_data, "__len__") else 0
574
+ )
575
+ if total_available == 0:
576
+ raise ValueError("Calibration data is empty; cannot compute metrics.")
577
+
578
+ if preview_n is None:
579
+ preview_n = max(total_available // 2, 1)
580
+ if final_n is None:
581
+ final_n = preview_n
582
+
583
+ preview_n = max(int(preview_n), 0)
584
+ final_n = max(int(final_n), 0)
585
+ max_needed = max(preview_n, final_n)
586
+ if max_needed <= 0:
587
+ raise ValueError("preview_n and final_n cannot both be zero.")
588
+
589
+ if max_needed > total_available:
590
+ self._log_event(
591
+ "eval",
592
+ "data_scaled",
593
+ LogLevel.WARNING,
594
+ {
595
+ "requested_preview": preview_n,
596
+ "requested_final": final_n,
597
+ "available": total_available,
598
+ },
599
+ )
600
+ preview_n = min(preview_n, total_available)
601
+ final_n = min(final_n, total_available)
602
+ max_needed = max(preview_n, final_n)
603
+
604
+ requested_preview = preview_n
605
+ requested_final = final_n
606
+
607
+ max_needed = preview_n + final_n
608
+ if max_needed > total_available:
609
+ self._log_event(
610
+ "eval",
611
+ "window_shortage",
612
+ LogLevel.WARNING,
613
+ {
614
+ "requested_preview": preview_n,
615
+ "requested_final": final_n,
616
+ "available": total_available,
617
+ },
618
+ )
619
+ max_needed = min(total_available, max_needed)
620
+
621
+ preview_n = min(preview_n, total_available)
622
+ final_start = preview_n
623
+ remaining = max(total_available - preview_n, 0)
624
+ if final_n > remaining:
625
+ self._log_event(
626
+ "eval",
627
+ "final_window_shortage",
628
+ LogLevel.WARNING,
629
+ {
630
+ "requested_final": final_n,
631
+ "available_after_preview": remaining,
632
+ "requested_preview": requested_preview,
633
+ "requested_final_original": requested_final,
634
+ },
635
+ )
636
+ final_n = remaining
637
+
638
+ preview_data = calibration_data[:preview_n]
639
+ final_data = calibration_data[final_start : final_start + final_n]
640
+
641
+ eval_context: dict[str, Any] = {}
642
+ if config and isinstance(config.context, dict):
643
+ eval_context = config.context.get("eval", {}) or {}
644
+
645
+ loss_cfg = (
646
+ eval_context.get("loss", {}) if isinstance(eval_context, dict) else {}
647
+ )
648
+ resolved_loss_mode = str(
649
+ loss_cfg.get("resolved_type") or loss_cfg.get("type") or ""
650
+ ).lower()
651
+ bootstrap_cfg = eval_context.get("bootstrap", {}) or {}
652
+ bootstrap_enabled = bool(bootstrap_cfg.get("enabled", True))
653
+ bootstrap_method = str(
654
+ bootstrap_cfg.get("method", "bca_paired_delta_log")
655
+ ).lower()
656
+ bootstrap_replicates = int(
657
+ bootstrap_cfg.get("replicates", bootstrap_cfg.get("n", 1000) or 1000)
658
+ )
659
+ bootstrap_alpha = float(bootstrap_cfg.get("alpha", 0.05) or 0.05)
660
+ bootstrap_seed_cfg = bootstrap_cfg.get("seed")
661
+ ci_band = float(bootstrap_cfg.get("ci_band", 0.10) or 0.10)
662
+
663
+ single_method = "bca"
664
+ delta_method = "bca"
665
+ if bootstrap_method == "percentile":
666
+ single_method = "percentile"
667
+ delta_method = "percentile"
668
+ elif bootstrap_method == "bca_paired_delta_log":
669
+ single_method = "bca"
670
+ delta_method = "bca"
671
+ else:
672
+ single_method = bootstrap_method
673
+ delta_method = bootstrap_method
674
+
675
+ dataset_seed = None
676
+ profile_label = ""
677
+ pairing_context: dict[str, Any] = {}
678
+ if config and isinstance(config.context, dict):
679
+ dataset_cfg = config.context.get("dataset", {})
680
+ if isinstance(dataset_cfg, dict):
681
+ dataset_seed = dataset_cfg.get("seed")
682
+ profile_label = str(config.context.get("profile", "")).lower()
683
+ pairing_context = config.context.get("pairing_baseline", {}) or {}
684
+
685
+ bootstrap_seed = (
686
+ bootstrap_seed_cfg if bootstrap_seed_cfg is not None else dataset_seed
687
+ )
688
+ try:
689
+ bootstrap_seed = int(bootstrap_seed) if bootstrap_seed is not None else 0
690
+ except (TypeError, ValueError):
691
+ bootstrap_seed = 0
692
+
693
+ if bootstrap_replicates <= 0:
694
+ bootstrap_enabled = False
695
+ if not (0.0 < bootstrap_alpha < 1.0):
696
+ bootstrap_alpha = 0.05
697
+
698
+ pm_preview = 50.0
699
+ pm_final = 50.0
700
+ pm_ratio = 1.0
701
+ ratio_ci: tuple[float, float] = (pm_ratio, pm_ratio)
702
+ preview_log_ci: tuple[float, float] = (
703
+ math.log(pm_preview),
704
+ math.log(pm_preview),
705
+ )
706
+ final_log_ci: tuple[float, float] = (math.log(pm_final), math.log(pm_final))
707
+ delta_log_ci: tuple[float, float] = (0.0, 0.0)
708
+ preview_mean_log = math.log(pm_preview)
709
+ final_mean_log = math.log(pm_final)
710
+ delta_mean_log = 0.0
711
+ preview_log_losses: list[float] = []
712
+ final_log_losses: list[float] = []
713
+ preview_tokens_ct = 0
714
+ final_tokens_ct = 0
715
+ preview_batches_ct = 0
716
+ final_batches_ct = 0
717
+ window_overlap_fraction = 0.0
718
+ # Defaults for pairing metrics to avoid unbound locals on error paths
719
+ window_match_fraction = 1.0
720
+ pairing_reason = None
721
+ preview_pair_stats = {"matched": 0, "expected": 0}
722
+ final_pair_stats = {"matched": 0, "expected": 0}
723
+ preview_window_ids: list[int] = []
724
+ final_window_ids: list[int] = []
725
+ preview_tokens: list[list[int]] = []
726
+ final_tokens: list[list[int]] = []
727
+ preview_limit = min(preview_n, len(preview_data)) if preview_data else 0
728
+ final_limit = min(final_n, len(final_data)) if final_data else 0
729
+
730
+ # Safe defaults in case of early exceptions inside compute block
731
+ preview_actual_tokens_ct = int(preview_tokens_ct)
732
+ final_actual_tokens_ct = int(final_tokens_ct)
733
+ preview_masked_total = int(preview_tokens_ct)
734
+ final_masked_total = int(final_tokens_ct)
735
+ preview_token_counts = []
736
+ final_token_counts = []
737
+ preview_attention_masks: list[list[int]] = []
738
+ final_attention_masks: list[list[int]] = []
739
+ preview_mask_counts: list[int] = []
740
+ final_mask_counts: list[int] = []
741
+ preview_labels: list[list[int]] = []
742
+ final_labels: list[list[int]] = []
743
+ preview_actual_token_counts: list[int] = []
744
+ final_actual_token_counts: list[int] = []
745
+
746
+ # Defaults for degeneracy flags
747
+ degenerate_delta = False
748
+ degenerate_reason: str | None = None
749
+
750
+ bootstrap_info = {
751
+ "enabled": bool(bootstrap_enabled),
752
+ "method": bootstrap_method,
753
+ "alpha": float(bootstrap_alpha),
754
+ "replicates": int(bootstrap_replicates),
755
+ "seed": int(bootstrap_seed),
756
+ "ci_band": float(ci_band),
757
+ }
758
+
759
+ alignment_logged = False
760
+
761
+ # Initialize to safe defaults to ensure later metrics assembly succeeds
762
+ # even if an exception occurs during the main compute block.
763
+ delta_samples: list[float] = []
764
+ delta_weights: list[float] = []
765
+
766
+ try:
767
+
768
+ def _resolve_limit(batches: Sequence[Any], requested: int) -> int:
769
+ if not batches:
770
+ return 0
771
+ if requested <= 0:
772
+ return len(batches)
773
+ return min(len(batches), requested)
774
+
775
+ def _compute_slice_summary(
776
+ batches: Sequence[Any],
777
+ max_batches: int,
778
+ start_idx: int,
779
+ ) -> dict[str, Any]:
780
+ nonlocal alignment_logged
781
+
782
+ total_tokens_local = 0
783
+ actual_tokens_local = 0
784
+ weighted_log_loss = 0.0
785
+ log_losses: list[float] = []
786
+ window_ids: list[int] = []
787
+ collected_tokens: list[list[int]] = []
788
+ collected_attn: list[list[int]] = []
789
+ collected_labels: list[list[int]] = []
790
+ token_counts: list[int] = []
791
+ masked_token_counts: list[int] = []
792
+ actual_token_counts: list[int] = []
793
+ count = 0
794
+ zero_mask_batches = 0
795
+ any_labels_seen = False
796
+ store_windows = os.environ.get(
797
+ "INVARLOCK_STORE_EVAL_WINDOWS", "1"
798
+ ).lower() not in {"0", "false", "no"}
799
+
800
+ if not batches:
801
+ return {
802
+ "ppl": float("nan"),
803
+ "total_tokens": 0,
804
+ "num_batches": 0,
805
+ "log_losses": [],
806
+ "window_ids": [],
807
+ "tokens": [],
808
+ "attention_masks": [],
809
+ "weighted_log_loss": 0.0,
810
+ "window_token_counts": [],
811
+ }
812
+
813
+ limit = _resolve_limit(batches, max_batches)
814
+
815
+ for batch in batches[:limit]:
816
+ if max_batches > 0 and count >= max_batches:
817
+ break
818
+
819
+ labels = None
820
+ if isinstance(batch, dict):
821
+ input_ids = batch.get("input_ids", batch.get("inputs"))
822
+ attention_mask = batch.get("attention_mask")
823
+ labels = batch.get("labels")
824
+ else:
825
+ input_ids = batch
826
+ attention_mask = None
827
+
828
+ if input_ids is None:
829
+ continue
830
+
831
+ if isinstance(input_ids, torch.Tensor):
832
+ input_ids_t = input_ids.to(device=device, dtype=torch.long)
833
+ else:
834
+ input_ids_t = torch.as_tensor(
835
+ input_ids, device=device, dtype=torch.long
836
+ )
837
+
838
+ if input_ids_t.dim() == 1:
839
+ input_ids_t = input_ids_t.unsqueeze(0)
840
+
841
+ attn_t = None
842
+ if attention_mask is not None:
843
+ if isinstance(attention_mask, torch.Tensor):
844
+ attn_t = attention_mask.to(device=device, dtype=torch.long)
845
+ else:
846
+ attn_t = torch.as_tensor(
847
+ attention_mask, device=device, dtype=torch.long
848
+ )
849
+ if attn_t.dim() == 1:
850
+ attn_t = attn_t.unsqueeze(0)
851
+
852
+ if labels is not None:
853
+ any_labels_seen = True
854
+ if isinstance(labels, torch.Tensor):
855
+ labels_t = labels.to(device=device, dtype=torch.long)
856
+ else:
857
+ labels_t = torch.as_tensor(
858
+ labels, device=device, dtype=torch.long
859
+ )
860
+ if labels_t.dim() == 1:
861
+ labels_t = labels_t.unsqueeze(0)
862
+ else:
863
+ labels_t = input_ids_t.clone()
864
+ if attn_t is not None:
865
+ labels_t = labels_t.masked_fill(attn_t == 0, -100)
866
+
867
+ snapshot = input_ids_t.detach().cpu()
868
+ attn_snapshot = (
869
+ attn_t.detach().cpu() if attn_t is not None else None
870
+ )
871
+
872
+ with torch.no_grad():
873
+ if attn_t is not None:
874
+ outputs = model(
875
+ input_ids_t, attention_mask=attn_t, labels=labels_t
876
+ )
877
+ else:
878
+ outputs = model(input_ids_t, labels=labels_t)
879
+
880
+ loss_val = (
881
+ outputs.loss.item()
882
+ if hasattr(outputs, "loss") and hasattr(outputs.loss, "item")
883
+ else None
884
+ )
885
+ if loss_val is None:
886
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
887
+ self._log_event(
888
+ "eval",
889
+ "missing_loss",
890
+ LogLevel.DEBUG,
891
+ {
892
+ "has_loss_attr": bool(hasattr(outputs, "loss")),
893
+ "labels_provided": bool(labels is not None),
894
+ "window_index": start_idx + count,
895
+ },
896
+ )
897
+ continue
898
+
899
+ if attn_snapshot is not None:
900
+ tokens_in_batch = int(attn_snapshot.sum().item())
901
+ else:
902
+ tokens_in_batch = int(input_ids_t.numel())
903
+
904
+ if tokens_in_batch <= 0:
905
+ continue
906
+
907
+ masked_tokens_batch = int((labels_t != -100).sum().item())
908
+ effective_masked = masked_tokens_batch
909
+ if labels is not None and masked_tokens_batch <= 0:
910
+ zero_mask_batches += 1
911
+ effective_masked = tokens_in_batch
912
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
913
+ sample_labels = None
914
+ try:
915
+ sample_labels = labels_t[0].detach().cpu().tolist()[:8]
916
+ except Exception: # pragma: no cover - defensive
917
+ sample_labels = None
918
+ self._log_event(
919
+ "eval",
920
+ "zero_mask_batch",
921
+ LogLevel.WARNING,
922
+ {
923
+ "window_index": start_idx + count,
924
+ "tokens_in_batch": tokens_in_batch,
925
+ "masked_tokens": masked_tokens_batch,
926
+ "labels_sample": sample_labels,
927
+ "fallback_weight": effective_masked,
928
+ },
929
+ )
930
+ effective_weight = (
931
+ effective_masked if labels is not None else tokens_in_batch
932
+ )
933
+ if effective_weight <= 0:
934
+ continue
935
+
936
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
937
+ print(
938
+ f"[debug] eval batch loss={float(loss_val):.6f} masked_tokens={masked_tokens_batch} tokens_in_batch={tokens_in_batch}"
939
+ )
940
+
941
+ if store_windows:
942
+ for row in snapshot:
943
+ collected_tokens.append(row.tolist())
944
+
945
+ if attn_snapshot is not None:
946
+ for row in attn_snapshot:
947
+ collected_attn.append(row.tolist())
948
+ else:
949
+ for row in snapshot:
950
+ collected_attn.append([1] * len(row))
951
+ collected_labels.extend(labels_t.detach().cpu().tolist())
952
+
953
+ if not alignment_logged:
954
+ self._log_event(
955
+ "eval",
956
+ "label_alignment",
957
+ LogLevel.INFO,
958
+ {
959
+ "ignore_index": -100,
960
+ "used_attention_mask": bool(attn_snapshot is not None),
961
+ "tokens_in_batch": tokens_in_batch,
962
+ "masked_tokens": masked_tokens_batch,
963
+ },
964
+ )
965
+ alignment_logged = True
966
+
967
+ log_losses.append(float(loss_val))
968
+ actual_tokens_local += tokens_in_batch
969
+ total_tokens_local += effective_weight
970
+ weighted_log_loss += float(loss_val) * effective_weight
971
+ token_counts.append(effective_weight)
972
+ masked_token_counts.append(masked_tokens_batch)
973
+ if labels is not None and masked_tokens_batch <= 0:
974
+ masked_token_counts[-1] = effective_masked
975
+ actual_token_counts.append(tokens_in_batch)
976
+ window_ids.append(start_idx + count)
977
+ count += 1
978
+
979
+ if count == 0:
980
+ if zero_mask_batches and os.environ.get("INVARLOCK_DEBUG_TRACE"):
981
+ self._log_event(
982
+ "eval",
983
+ "zero_mask_total",
984
+ LogLevel.ERROR,
985
+ {
986
+ "zero_mask_batches": zero_mask_batches,
987
+ "requested": limit,
988
+ },
989
+ )
990
+ if resolved_loss_mode == "mlm":
991
+ error_msg = (
992
+ "MLM evaluation produced zero usable batches; "
993
+ "ensure baseline pairing includes masked tokens."
994
+ )
995
+ if any_labels_seen:
996
+ error_msg = (
997
+ "MLM evaluation saw labels but zero masked tokens were accumulated; "
998
+ "check calibration data integrity."
999
+ )
1000
+ self._log_event(
1001
+ "eval",
1002
+ "mlm_missing_masks",
1003
+ LogLevel.ERROR,
1004
+ {
1005
+ "any_labels": bool(any_labels_seen),
1006
+ "requested": limit,
1007
+ "zero_mask_batches": zero_mask_batches,
1008
+ },
1009
+ )
1010
+ raise ValueError(error_msg)
1011
+ return {
1012
+ "ppl": float("nan"),
1013
+ "total_tokens": total_tokens_local,
1014
+ "actual_total_tokens": actual_tokens_local,
1015
+ "num_batches": 0,
1016
+ "log_losses": [],
1017
+ "window_ids": [],
1018
+ "tokens": [],
1019
+ "attention_masks": [],
1020
+ "weighted_log_loss": 0.0,
1021
+ "window_token_counts": [],
1022
+ "masked_token_counts": [],
1023
+ "actual_token_counts": [],
1024
+ "labels": [],
1025
+ }
1026
+
1027
+ mean_loss = (
1028
+ weighted_log_loss / total_tokens_local
1029
+ if total_tokens_local > 0
1030
+ else sum(log_losses) / max(count, 1)
1031
+ )
1032
+ return {
1033
+ "ppl": float(math.exp(mean_loss)),
1034
+ "total_tokens": total_tokens_local,
1035
+ "num_batches": count,
1036
+ "log_losses": log_losses,
1037
+ "window_ids": window_ids,
1038
+ "tokens": collected_tokens,
1039
+ "attention_masks": collected_attn,
1040
+ "weighted_log_loss": weighted_log_loss,
1041
+ "window_token_counts": token_counts,
1042
+ "masked_token_counts": masked_token_counts,
1043
+ "actual_token_counts": actual_token_counts,
1044
+ "labels": collected_labels,
1045
+ "actual_total_tokens": actual_tokens_local,
1046
+ }
1047
+
1048
+ preview_limit = _resolve_limit(preview_data, preview_n)
1049
+ final_limit = _resolve_limit(final_data, final_n)
1050
+
1051
+ preview_summary = _compute_slice_summary(preview_data, preview_limit, 0)
1052
+ final_summary = _compute_slice_summary(
1053
+ final_data, final_limit, preview_summary["num_batches"]
1054
+ )
1055
+
1056
+ preview_log_losses = preview_summary["log_losses"]
1057
+ final_log_losses = final_summary["log_losses"]
1058
+ preview_tokens_ct = preview_summary["total_tokens"]
1059
+ final_tokens_ct = final_summary["total_tokens"]
1060
+ preview_batches_ct = preview_summary["num_batches"]
1061
+ final_batches_ct = final_summary["num_batches"]
1062
+ preview_window_ids = list(preview_summary["window_ids"])
1063
+ final_window_ids = list(final_summary["window_ids"])
1064
+ preview_tokens = list(preview_summary["tokens"])
1065
+ final_tokens = list(final_summary["tokens"])
1066
+ preview_token_counts = list(preview_summary.get("window_token_counts", []))
1067
+ final_token_counts = list(final_summary.get("window_token_counts", []))
1068
+ preview_attention_masks = list(preview_summary.get("attention_masks", []))
1069
+ final_attention_masks = list(final_summary.get("attention_masks", []))
1070
+ preview_mask_counts = list(preview_summary.get("masked_token_counts", []))
1071
+ final_mask_counts = list(final_summary.get("masked_token_counts", []))
1072
+ preview_labels = list(preview_summary.get("labels", []))
1073
+ final_labels = list(final_summary.get("labels", []))
1074
+ preview_actual_token_counts = list(
1075
+ preview_summary.get("actual_token_counts", [])
1076
+ )
1077
+ final_actual_token_counts = list(
1078
+ final_summary.get("actual_token_counts", [])
1079
+ )
1080
+ preview_actual_tokens_ct = int(
1081
+ preview_summary.get("actual_total_tokens", preview_tokens_ct)
1082
+ )
1083
+ final_actual_tokens_ct = int(
1084
+ final_summary.get("actual_total_tokens", final_tokens_ct)
1085
+ )
1086
+ preview_masked_total = (
1087
+ sum(preview_mask_counts)
1088
+ if preview_mask_counts
1089
+ else int(preview_tokens_ct)
1090
+ )
1091
+ final_masked_total = (
1092
+ sum(final_mask_counts) if final_mask_counts else int(final_tokens_ct)
1093
+ )
1094
+ preview_weighted_loss = float(preview_summary.get("weighted_log_loss", 0.0))
1095
+ final_weighted_loss = float(final_summary.get("weighted_log_loss", 0.0))
1096
+
1097
+ if preview_tokens_ct > 0:
1098
+ preview_mean_log = float(preview_weighted_loss / preview_tokens_ct)
1099
+ pm_preview = math.exp(preview_mean_log)
1100
+ elif preview_log_losses:
1101
+ preview_mean_log = float(np.mean(preview_log_losses))
1102
+ pm_preview = math.exp(preview_mean_log)
1103
+ else:
1104
+ pm_preview = preview_summary["ppl"]
1105
+ if not math.isfinite(pm_preview) or pm_preview <= 0:
1106
+ pm_preview = 50.0
1107
+ preview_mean_log = math.log(pm_preview)
1108
+
1109
+ if final_tokens_ct > 0:
1110
+ final_mean_log = float(final_weighted_loss / final_tokens_ct)
1111
+ pm_final = math.exp(final_mean_log)
1112
+ elif final_log_losses:
1113
+ final_mean_log = float(np.mean(final_log_losses))
1114
+ pm_final = math.exp(final_mean_log)
1115
+ else:
1116
+ pm_final = final_summary["ppl"]
1117
+ if not math.isfinite(pm_final) or pm_final <= 0:
1118
+ pm_final = 50.0
1119
+ final_mean_log = math.log(pm_final)
1120
+
1121
+ delta_mean_log = final_mean_log - preview_mean_log
1122
+ pm_ratio = math.exp(delta_mean_log)
1123
+
1124
+ if not (math.isfinite(delta_mean_log) and math.isfinite(pm_ratio)):
1125
+ raise RuntimeError("Invalid perplexity ratio or delta")
1126
+
1127
+ expected_ratio = math.exp(delta_mean_log)
1128
+ if abs(pm_ratio - expected_ratio) > 1e-6:
1129
+ raise RuntimeError(
1130
+ "Primary-metric ratio mismatch with exp(mean ΔlogNLL)"
1131
+ )
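# Worked example of the identity checked above (illustrative numbers): a preview
# mean log-loss of ln(50.0) ≈ 3.912 and a final mean log-loss of ln(51.0) ≈ 3.932
# give delta_mean_log ≈ 0.0198, so pm_ratio = exp(0.0198) ≈ 1.02 == 51.0 / 50.0.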
1132
+
1133
+ if bootstrap_enabled and preview_log_losses:
1134
+ preview_log_ci = compute_logloss_ci(
1135
+ preview_log_losses,
1136
+ method=single_method,
1137
+ replicates=bootstrap_replicates,
1138
+ alpha=bootstrap_alpha,
1139
+ seed=bootstrap_seed + 7,
1140
+ )
1141
+ else:
1142
+ preview_log_ci = (preview_mean_log, preview_mean_log)
1143
+
1144
+ # primary_metric consumers use log-space intervals; skip ppl-space tuple here
1145
+
1146
+ if bootstrap_enabled and final_log_losses:
1147
+ final_log_ci = compute_logloss_ci(
1148
+ final_log_losses,
1149
+ method=single_method,
1150
+ replicates=bootstrap_replicates,
1151
+ alpha=bootstrap_alpha,
1152
+ seed=bootstrap_seed + 13,
1153
+ )
1154
+ else:
1155
+ final_log_ci = (final_mean_log, final_mean_log)
1156
+
1157
+ # primary_metric consumers use log-space intervals; skip ppl-space tuple here
1158
+
1159
+ if (
1160
+ bootstrap_enabled
1161
+ and final_log_losses
1162
+ and preview_log_losses
1163
+ and len(final_log_losses)
1164
+ and len(preview_log_losses)
1165
+ ):
1166
+ delta_log_ci = compute_paired_delta_log_ci(
1167
+ final_log_losses,
1168
+ preview_log_losses,
1169
+ method=delta_method,
1170
+ replicates=bootstrap_replicates,
1171
+ alpha=bootstrap_alpha,
1172
+ seed=bootstrap_seed + 97,
1173
+ )
1174
+ ratio_ci = logspace_to_ratio_ci(delta_log_ci)
1175
+ expected_ratio_ci = tuple(math.exp(bound) for bound in delta_log_ci)
1176
+ if any(
1177
+ abs(r - e) > 1e-6
1178
+ for r, e in zip(ratio_ci, expected_ratio_ci, strict=False)
1179
+ ):
1180
+ raise RuntimeError("Ratio CI inconsistent with Δlog CI")
1181
+ else:
1182
+ delta_log_ci = (delta_mean_log, delta_mean_log)
1183
+ ratio_ci = (pm_ratio, pm_ratio)
1184
+
1185
+ delta_samples: list[float] = []
1186
+ delta_weights: list[float] = []
1187
+ if final_log_losses and preview_log_losses:
1188
+ limit = min(len(final_log_losses), len(preview_log_losses))
1189
+ if limit:
1190
+ delta_samples = [
1191
+ final_log_losses[i] - preview_log_losses[i]
1192
+ for i in range(limit)
1193
+ ]
1194
+ if preview_token_counts and len(preview_token_counts) >= limit:
1195
+ delta_weights = [
1196
+ float(max(preview_token_counts[i], 1)) for i in range(limit)
1197
+ ]
1198
+
1199
+ degenerate_delta = False
1200
+ degenerate_reason: str | None = None
1201
+ if len(delta_samples) < 2:
1202
+ if len(delta_samples) == 0:
1203
+ degenerate_delta = True
1204
+ degenerate_reason = "no_pairs"
1205
+ else:
1206
+ degenerate_delta = True
1207
+ degenerate_reason = "single_pair"
1208
+ elif np.allclose(delta_samples, delta_samples[0]):
1209
+ degenerate_delta = True
1210
+ degenerate_reason = "no_variation"
1211
+
1212
+ if degenerate_delta:
1213
+ self._log_event(
1214
+ "eval",
1215
+ "degenerate_delta_samples",
1216
+ LogLevel.ERROR,
1217
+ {
1218
+ "reason": degenerate_reason,
1219
+ "sample_count": len(delta_samples),
1220
+ },
1221
+ )
1222
+ if profile_label in {"ci", "release"}:
1223
+ raise RuntimeError(
1224
+ f"Degenerate paired ΔlogNLL distribution ({degenerate_reason})"
1225
+ )
1226
+
1227
+ def _hash_tokens(tokens: list[int]) -> bytes:
1228
+ if not tokens:
1229
+ return b""
1230
+ token_array = array("I", (int(token) & 0xFFFFFFFF for token in tokens))
1231
+ return hashlib.blake2b(token_array.tobytes(), digest_size=16).digest()
1232
+
1233
+ def _duplicate_fraction(seqs: list[list[int]]) -> float:
1234
+ if not seqs:
1235
+ return 0.0
1236
+ hashes = [_hash_tokens(seq) for seq in seqs]
1237
+ unique = len(set(hashes))
1238
+ if not hashes:
1239
+ return 0.0
1240
+ return max(0.0, (len(hashes) - unique) / len(hashes))
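# Example: four collected windows where one token sequence appears twice hash to
# three unique digests, giving a duplicate fraction of (4 - 3) / 4 = 0.25.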
1241
+
1242
+ def _compare_with_baseline(
1243
+ run_ids: list[int],
1244
+ run_tokens: list[list[int]],
1245
+ baseline_section: dict[str, Any] | None,
1246
+ split_label: str,
1247
+ ) -> dict[str, Any]:
1248
+ stats = {
1249
+ "matched": 0,
1250
+ "expected": 0,
1251
+ "missing_ids": [],
1252
+ "mismatched_ids": [],
1253
+ "unexpected_ids": [],
1254
+ "reason": None,
1255
+ }
1256
+
1257
+ if not baseline_section:
1258
+ stats["matched"] = len(run_tokens)
1259
+ stats["expected"] = len(run_tokens)
1260
+ stats["reason"] = "no_baseline_reference"
1261
+ return stats
1262
+
1263
+ base_ids = baseline_section.get("window_ids") or []
1264
+ base_tokens = baseline_section.get("input_ids") or []
1265
+ if not isinstance(base_ids, list) or not isinstance(base_tokens, list):
1266
+ stats["matched"] = len(run_tokens)
1267
+ stats["expected"] = len(run_tokens)
1268
+ stats["reason"] = "invalid_baseline_reference"
1269
+ return stats
1270
+
1271
+ base_map: dict[int, bytes] = {}
1272
+ for bid, seq in zip(base_ids, base_tokens, strict=False):
1273
+ try:
1274
+ bid_int = int(bid)
1275
+ except Exception:
1276
+ continue
1277
+ seq_list = list(seq) if not isinstance(seq, list) else seq
1278
+ base_map[bid_int] = _hash_tokens(seq_list)
1279
+
1280
+ stats["expected"] = len(base_map)
1281
+ matched = 0
1282
+ seen_ids: set[int] = set()
1283
+ mismatched: list[int] = []
1284
+ unexpected: list[int] = []
1285
+
1286
+ for rid, seq in zip(run_ids, run_tokens, strict=False):
1287
+ try:
1288
+ rid_int = int(rid)
1289
+ except Exception:
1290
+ unexpected.append(rid)
1291
+ continue
1292
+
1293
+ hashed = _hash_tokens(seq)
1294
+ if rid_int not in base_map:
1295
+ unexpected.append(rid_int)
1296
+ continue
1297
+
1298
+ seen_ids.add(rid_int)
1299
+ if hashed == base_map[rid_int]:
1300
+ matched += 1
1301
+ else:
1302
+ mismatched.append(rid_int)
1303
+
1304
+ missing = [bid for bid in base_map if bid not in seen_ids]
1305
+ stats.update(
1306
+ {
1307
+ "matched": matched,
1308
+ "missing_ids": missing,
1309
+ "mismatched_ids": mismatched,
1310
+ "unexpected_ids": unexpected,
1311
+ }
1312
+ )
1313
+
1314
+ if missing:
1315
+ stats["reason"] = f"{split_label}_missing_ids:{missing[:3]}"
1316
+ elif mismatched:
1317
+ stats["reason"] = f"{split_label}_token_mismatch:{mismatched[:3]}"
1318
+ elif unexpected:
1319
+ stats["reason"] = f"{split_label}_unexpected_ids:{unexpected[:3]}"
1320
+ else:
1321
+ stats["reason"] = None
1322
+
1323
+ return stats
1324
+
1325
+ baseline_preview = (
1326
+ pairing_context.get("preview")
1327
+ if isinstance(pairing_context, dict)
1328
+ else {}
1329
+ )
1330
+ baseline_final = (
1331
+ pairing_context.get("final")
1332
+ if isinstance(pairing_context, dict)
1333
+ else {}
1334
+ )
1335
+
1336
+ preview_pair_stats = _compare_with_baseline(
1337
+ preview_window_ids, preview_tokens, baseline_preview, "preview"
1338
+ )
1339
+ final_pair_stats = _compare_with_baseline(
1340
+ final_window_ids, final_tokens, baseline_final, "final"
1341
+ )
1342
+
1343
+ total_expected = (
1344
+ preview_pair_stats["expected"] + final_pair_stats["expected"]
1345
+ )
1346
+ total_matched = preview_pair_stats["matched"] + final_pair_stats["matched"]
1347
+ window_match_fraction = (
1348
+ float(total_matched / total_expected) if total_expected > 0 else 1.0
1349
+ )
1350
+ window_overlap_fraction = _duplicate_fraction(preview_tokens + final_tokens)
1351
+
1352
+ pairing_reason = None
1353
+ if total_expected > 0:
1354
+ for stats_dict, label in (
1355
+ (preview_pair_stats, "preview"),
1356
+ (final_pair_stats, "final"),
1357
+ ):
1358
+ if (
1359
+ stats_dict["expected"]
1360
+ and stats_dict["matched"] < stats_dict["expected"]
1361
+ ):
1362
+ pairing_reason = stats_dict.get("reason") or f"{label}_mismatch"
1363
+ break
1364
+ if pairing_reason is None:
1365
+ if window_overlap_fraction > 0.0:
1366
+ pairing_reason = "duplicate_windows"
1367
+ elif not pairing_context:
1368
+ pairing_reason = preview_pair_stats.get(
1369
+ "reason"
1370
+ ) or final_pair_stats.get("reason")
1371
+
1372
+ if pairing_context and window_match_fraction < 0.999999:
1373
+ self._log_event(
1374
+ "eval",
1375
+ "window_pairing_mismatch",
1376
+ LogLevel.ERROR,
1377
+ {
1378
+ "match_fraction": window_match_fraction,
1379
+ "overlap_fraction": window_overlap_fraction,
1380
+ "reason": pairing_reason,
1381
+ "preview": preview_pair_stats,
1382
+ "final": final_pair_stats,
1383
+ },
1384
+ )
1385
+
1386
+ if window_overlap_fraction > 0.0 and pairing_context:
1387
+ self._log_event(
1388
+ "eval",
1389
+ "window_overlap_warning",
1390
+ LogLevel.WARNING,
1391
+ {
1392
+ "duplicate_fraction": window_overlap_fraction,
1393
+ "match_fraction": window_match_fraction,
1394
+ "preview": preview_pair_stats,
1395
+ "final": final_pair_stats,
1396
+ },
1397
+ )
1398
+
1399
+ if pairing_context and profile_label in {"ci", "release"}:
1400
+ if window_match_fraction < 0.999999:
1401
+ raise RuntimeError(
1402
+ f"Window pairing mismatch detected (fraction={window_match_fraction:.3f}, reason={pairing_reason})"
1403
+ )
1404
+ if window_overlap_fraction > 0.0:
1405
+ raise RuntimeError(
1406
+ f"Window duplication detected (overlap_fraction={window_overlap_fraction:.3f})"
1407
+ )
1408
+
1409
+ tier = "balanced"
1410
+ if config and isinstance(config.context, dict):
1411
+ auto_section = config.context.get("auto", {})
1412
+ if isinstance(auto_section, dict):
1413
+ tier = str(auto_section.get("tier", tier)).lower()
1414
+
1415
+ coverage_requirements = BOOTSTRAP_COVERAGE_REQUIREMENTS.get(
1416
+ tier, BOOTSTRAP_COVERAGE_REQUIREMENTS["balanced"]
1417
+ )
1418
+
1419
+ def _meets_requirement(actual: int, required: int) -> bool:
1420
+ if required <= 0:
1421
+ return True
1422
+ slack = max(1, int(required * 0.95))
1423
+ return actual >= slack
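# Example: a required floor of 180 windows leaves a 5% slack of
# max(1, int(180 * 0.95)) = 171, so 171 observed batches still satisfies the floor.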
1424
+
1425
+ preview_required = int(coverage_requirements.get("preview", 0))
1426
+ final_required = int(coverage_requirements.get("final", 0))
1427
+ replicates_required = int(coverage_requirements.get("replicates", 0))
1428
+
1429
+ preview_ok = _meets_requirement(preview_batches_ct, preview_required)
1430
+ final_ok = _meets_requirement(final_batches_ct, final_required)
1431
+ replicates_ok = (
1432
+ _meets_requirement(bootstrap_replicates, replicates_required)
1433
+ if bootstrap_enabled
1434
+ else True
1435
+ )
1436
+
1437
+ if not (preview_ok and final_ok and replicates_ok):
1438
+ self._log_event(
1439
+ "eval",
1440
+ "bootstrap_coverage_warning",
1441
+ LogLevel.WARNING,
1442
+ {
1443
+ "tier": tier,
1444
+ "preview_used": preview_batches_ct,
1445
+ "preview_required": preview_required,
1446
+ "final_used": final_batches_ct,
1447
+ "final_required": final_required,
1448
+ "replicates_used": bootstrap_replicates,
1449
+ "replicates_required": replicates_required,
1450
+ },
1451
+ )
1452
+ # In CI/Release profiles, treat insufficient coverage as a hard error
1453
+ if pairing_context and profile_label in {"ci", "release"}:
1454
+ from invarlock.cli.errors import InvarlockError
1455
+
1456
+ raise InvarlockError(
1457
+ code="E005",
1458
+ message=(
1459
+ "INSUFFICIENT-SAMPLE: bootstrap coverage below policy floors in CI/Release"
1460
+ ),
1461
+ )
1462
+
1463
+ bootstrap_info.update(
1464
+ {
1465
+ "enabled": bool(bootstrap_enabled),
1466
+ "method": bootstrap_method,
1467
+ "alpha": float(bootstrap_alpha),
1468
+ "replicates": int(bootstrap_replicates),
1469
+ "seed": int(bootstrap_seed),
1470
+ "ci_band": float(ci_band),
1471
+ "window_duplicate_fraction": float(window_overlap_fraction),
1472
+ "window_match_fraction": float(window_match_fraction),
1473
+ "coverage": {
1474
+ "tier": tier,
1475
+ "preview": {
1476
+ "used": int(preview_batches_ct),
1477
+ "required": preview_required,
1478
+ "ok": bool(preview_ok),
1479
+ },
1480
+ "final": {
1481
+ "used": int(final_batches_ct),
1482
+ "required": final_required,
1483
+ "ok": bool(final_ok),
1484
+ },
1485
+ "replicates": {
1486
+ "used": int(bootstrap_replicates),
1487
+ "required": replicates_required,
1488
+ "ok": bool(replicates_ok),
1489
+ },
1490
+ },
1491
+ }
1492
+ )
1493
+
1494
+ except Exception as exc: # pragma: no cover - defensive fallback
1495
+ self._log_event(
1496
+ "eval",
1497
+ "error",
1498
+ LogLevel.ERROR,
1499
+ {"message": f"Primary-metric computation failed: {exc}"},
1500
+ )
1501
+
1502
+ pm_ratio = pm_final / pm_preview if pm_preview > 0 else 1.0
1503
+
1504
+ latency_ms_per_tok = self._measure_latency(
1505
+ model, preview_data[:1] if preview_data else final_data[:1], device
1506
+ )
1507
+
1508
+ current_memory = process.memory_info().rss / 1024 / 1024
1509
+ peak_memory = max(initial_memory, current_memory)
1510
+
1511
+ eval_samples = 0
1512
+ total_tokens = 0
1513
+ masked_total_tokens = 0
1514
+ try:
1515
+ eval_samples = int(preview_batches_ct) + int(final_batches_ct)
1516
+ total_tokens = int(preview_actual_tokens_ct) + int(final_actual_tokens_ct)
1517
+ masked_total_tokens = int(preview_masked_total) + int(final_masked_total)
1518
+ except Exception:
1519
+ pass
1520
+
1521
+ paired_windows_count = len(delta_samples)
1522
+ unweighted_delta_mean = (
1523
+ float(np.mean(delta_samples)) if delta_samples else float(delta_mean_log)
1524
+ )
1525
+ preview_weighted_delta_mean: float | None = None
1526
+ if delta_weights:
1527
+ total_weight = float(sum(delta_weights))
1528
+ if total_weight > 0.0:
1529
+ preview_weighted_delta_mean = float(
1530
+ np.dot(delta_samples, delta_weights) / total_weight
1531
+ )
1532
+ paired_delta_mean = float(delta_mean_log)
1533
+ paired_delta_std = (
1534
+ float(np.std(delta_samples, ddof=1)) if len(delta_samples) > 1 else 0.0
1535
+ )
1536
+ paired_delta_min = float(min(delta_samples)) if delta_samples else None
1537
+ paired_delta_max = float(max(delta_samples)) if delta_samples else None
1538
+
1539
+ # Resolve primary metric kind from resolved loss mode
1540
+ pm_kind = "ppl_causal"
1541
+ if resolved_loss_mode == "mlm":
1542
+ pm_kind = "ppl_mlm"
1543
+ elif resolved_loss_mode in {"seq2seq", "s2s", "t5"}:
1544
+ pm_kind = "ppl_seq2seq"
1545
+
1546
+ metrics = {
1547
+ "primary_metric": {
1548
+ "kind": pm_kind,
1549
+ "preview": float(pm_preview),
1550
+ "final": float(pm_final),
1551
+ },
1552
+ "logloss_preview": float(preview_mean_log),
1553
+ "logloss_final": float(final_mean_log),
1554
+ "logloss_delta": float(delta_mean_log),
1555
+ "logloss_preview_ci": tuple(map(float, preview_log_ci)),
1556
+ "logloss_final_ci": tuple(map(float, final_log_ci)),
1557
+ "logloss_delta_ci": tuple(map(float, delta_log_ci)),
1558
+ "latency_ms_per_tok": latency_ms_per_tok,
1559
+ "memory_mb_peak": peak_memory,
1560
+ "eval_samples": eval_samples,
1561
+ "total_tokens": total_tokens,
1562
+ "preview_total_tokens": int(preview_actual_tokens_ct),
1563
+ "final_total_tokens": int(final_actual_tokens_ct),
1564
+ "masked_tokens_total": masked_total_tokens,
1565
+ "masked_tokens_preview": int(preview_masked_total),
1566
+ "masked_tokens_final": int(final_masked_total),
1567
+ "reduction": {
1568
+ "mode": "token_mean",
1569
+ "implementation": "huggingface_cross_entropy",
1570
+ },
1571
+ "window_overlap_fraction": float(window_overlap_fraction),
1572
+ "window_match_fraction": float(window_match_fraction),
1573
+ "window_pairing_reason": pairing_reason,
1574
+ "window_pairing_preview": {
1575
+ "matched": preview_pair_stats["matched"],
1576
+ "expected": preview_pair_stats["expected"],
1577
+ "reason": preview_pair_stats.get("reason"),
1578
+ },
1579
+ "window_pairing_final": {
1580
+ "matched": final_pair_stats["matched"],
1581
+ "expected": final_pair_stats["expected"],
1582
+ "reason": final_pair_stats.get("reason"),
1583
+ },
1584
+ "bootstrap": bootstrap_info,
1585
+ "paired_windows": paired_windows_count,
1586
+ "paired_delta_summary": {
1587
+ "mean": paired_delta_mean,
1588
+ "mean_unweighted": unweighted_delta_mean,
1589
+ "mean_preview_weighted": (
1590
+ preview_weighted_delta_mean
1591
+ if preview_weighted_delta_mean is not None
1592
+ else unweighted_delta_mean
1593
+ ),
1594
+ "std": paired_delta_std,
1595
+ "min": paired_delta_min,
1596
+ "max": paired_delta_max,
1597
+ "degenerate": degenerate_delta,
1598
+ "degenerate_reason": degenerate_reason,
1599
+ },
1600
+ }
1601
+
1602
+ eval_windows = {
1603
+ "preview": {
1604
+ "window_ids": preview_window_ids[:preview_limit],
1605
+ "logloss": list(preview_log_losses),
1606
+ "input_ids": preview_tokens,
1607
+ "attention_masks": preview_attention_masks,
1608
+ "token_counts": preview_token_counts,
1609
+ "masked_token_counts": preview_mask_counts,
1610
+ "actual_token_counts": preview_actual_token_counts,
1611
+ "labels": preview_labels,
1612
+ },
1613
+ "final": {
1614
+ "window_ids": final_window_ids[:final_limit],
1615
+ "logloss": list(final_log_losses),
1616
+ "input_ids": final_tokens,
1617
+ "attention_masks": final_attention_masks,
1618
+ "token_counts": final_token_counts,
1619
+ "masked_token_counts": final_mask_counts,
1620
+ "actual_token_counts": final_actual_token_counts,
1621
+ "labels": final_labels,
1622
+ },
1623
+ }
1624
+
1625
+ return metrics, eval_windows
1626
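
Illustrative note: the metrics payload built above nests the primary metric under primary_metric.preview / primary_metric.final and the bootstrap coverage flags under bootstrap.coverage. A minimal standalone sketch of how a downstream consumer might derive the drift ratio and check the coverage flags, assuming only that dict shape; the helper names (drift_ratio, coverage_ok) and the example numbers are made up for illustration:

    def drift_ratio(metrics: dict) -> float:
        # Ratio of final to preview primary metric; inf when preview is unusable.
        pm = metrics.get("primary_metric", {})
        preview, final = pm.get("preview"), pm.get("final")
        if isinstance(preview, (int, float)) and isinstance(final, (int, float)) and preview > 0:
            return float(final) / float(preview)
        return float("inf")

    def coverage_ok(metrics: dict) -> bool:
        # True only when preview, final, and replicate coverage all pass their floors.
        cov = metrics.get("bootstrap", {}).get("coverage", {})
        return all(cov.get(k, {}).get("ok", False) for k in ("preview", "final", "replicates"))

    example = {
        "primary_metric": {"kind": "ppl_causal", "preview": 12.0, "final": 12.6},
        "bootstrap": {"coverage": {
            "preview": {"used": 64, "required": 64, "ok": True},
            "final": {"used": 64, "required": 64, "ok": True},
            "replicates": {"used": 1000, "required": 1000, "ok": True},
        }},
    }
    print(drift_ratio(example), coverage_ok(example))  # 1.05 True
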
+
1627
+ def _measure_latency(self, model: Any, sample_data: Any, device: Any) -> float:
1628
+        """Measure approximate latency in ms per token for a single sample."""
1629
+ import time
1630
+
1631
+ import torch
1632
+
1633
+ if not sample_data:
1634
+ return 0.0
1635
+
1636
+ # Model eval is managed by caller to avoid duplicate invocations in tests
1637
+
1638
+ # Get a sample for timing
1639
+ sample = sample_data[0] if sample_data else None
1640
+ if sample is None:
1641
+ return 0.0
1642
+
1643
+ if isinstance(sample, dict):
1644
+ input_ids = sample.get("input_ids", sample.get("inputs"))
1645
+ else:
1646
+ input_ids = sample
1647
+
1648
+ if input_ids is None:
1649
+ return 0.0
1650
+
1651
+ # Convert to tensor if needed
1652
+ if not isinstance(input_ids, torch.Tensor):
1653
+ input_ids = torch.tensor(input_ids)
1654
+
1655
+ # Some tests patch torch.tensor with dim side effects; guard against exceptions
1656
+ try:
1657
+ dim_val = input_ids.dim()
1658
+ except Exception:
1659
+ dim_val = 2 # assume already batched
1660
+ if dim_val == 1:
1661
+ try:
1662
+ input_ids = input_ids.unsqueeze(0)
1663
+ except Exception:
1664
+ pass
1665
+
1666
+ # to(device) may be a Mock; guard call
1667
+ try:
1668
+ input_ids = input_ids.to(device)
1669
+ except Exception:
1670
+ pass
1671
+
1672
+ # Simple timing measurement
1673
+ with torch.no_grad():
1674
+ try:
1675
+ # Prepare labels and attention mask if available
1676
+ labels_t = input_ids
1677
+ attn_t = None
1678
+ token_type_t = None
1679
+ if isinstance(sample, dict) and "attention_mask" in sample:
1680
+ try:
1681
+ attn_t = torch.tensor(sample["attention_mask"])
1682
+ try:
1683
+ if attn_t.dim() == 1:
1684
+ attn_t = attn_t.unsqueeze(0)
1685
+ except Exception:
1686
+ pass
1687
+ try:
1688
+ attn_t = attn_t.to(device)
1689
+ except Exception:
1690
+ pass
1691
+ except Exception:
1692
+ attn_t = None
1693
+ if isinstance(sample, dict) and "token_type_ids" in sample:
1694
+ try:
1695
+ token_type_t = torch.tensor(sample["token_type_ids"])
1696
+ if token_type_t.dim() == 1:
1697
+ token_type_t = token_type_t.unsqueeze(0)
1698
+ token_type_t = token_type_t.to(device)
1699
+ except Exception:
1700
+ token_type_t = None
1701
+
1702
+ def _call_model():
1703
+ if attn_t is not None:
1704
+ return model(
1705
+ input_ids,
1706
+ attention_mask=attn_t,
1707
+ labels=labels_t,
1708
+ token_type_ids=token_type_t,
1709
+ )
1710
+ else:
1711
+ return model(
1712
+ input_ids,
1713
+ labels=labels_t,
1714
+ token_type_ids=token_type_t,
1715
+ )
1716
+
1717
+ # Warmup
1718
+ for _ in range(3):
1719
+ _ = _call_model()
1720
+
1721
+ # Measure
1722
+ start_time = time.time()
1723
+ for _ in range(10):
1724
+ _ = _call_model()
1725
+ end_time = time.time()
1726
+
1727
+ total_time = (end_time - start_time) * 1000 # Convert to ms
1728
+ try:
1729
+ total_tokens = input_ids.numel() * 10 # 10 iterations
1730
+ except Exception:
1731
+ total_tokens = 0
1732
+
1733
+ return total_time / total_tokens if total_tokens > 0 else 0.0
1734
+
1735
+ except Exception:
1736
+ return 0.0
1737
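
Illustrative note: the per-token latency above is a plain wall-clock average: warm up, run a fixed number of forward passes under torch.no_grad(), then divide the elapsed milliseconds by the total tokens processed. A standalone sketch of the same idea; ms_per_token, the toy Embedding "model", and the iteration counts are stand-ins, not package APIs:

    import time
    import torch

    def ms_per_token(model, input_ids, iters: int = 10, warmup: int = 3) -> float:
        # Warm up, then time `iters` forward passes and normalise by tokens seen.
        with torch.no_grad():
            for _ in range(warmup):
                model(input_ids)
            start = time.time()
            for _ in range(iters):
                model(input_ids)
            elapsed_ms = (time.time() - start) * 1000
        total_tokens = input_ids.numel() * iters
        return elapsed_ms / total_tokens if total_tokens else 0.0

    # Toy usage with a tiny stand-in "model":
    toy_model = torch.nn.Embedding(100, 8)
    print(ms_per_token(toy_model, torch.randint(0, 100, (1, 16))))
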
+
1738
+ def _samples_to_dataloader(self, samples: list) -> Any:
1739
+ """
1740
+ Convert list of samples to DataLoader-compatible format.
1741
+
1742
+ Args:
1743
+ samples: List of sample dictionaries with 'input_ids' and 'attention_mask'
1744
+
1745
+ Returns:
1746
+ Simple iterable that yields batches compatible with compute_perplexity()
1747
+ """
1748
+
1749
+ class SampleDataLoader:
1750
+ def __init__(self, samples):
1751
+ self.samples = samples
1752
+
1753
+ def __iter__(self):
1754
+ for sample in self.samples:
1755
+ # Each sample is already a dict with 'input_ids' and 'attention_mask'
1756
+ # Convert to tensor format that compute_perplexity expects
1757
+ import torch
1758
+
1759
+ input_ids = sample.get("input_ids", sample.get("inputs"))
1760
+ attention_mask = sample.get("attention_mask")
1761
+
1762
+ if input_ids is None:
1763
+ continue
1764
+
1765
+ # Convert to tensors if needed and add batch dimension
1766
+ if not isinstance(input_ids, torch.Tensor):
1767
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
1768
+ if input_ids.dim() == 1:
1769
+ input_ids = input_ids.unsqueeze(0)
1770
+
1771
+ if attention_mask is not None:
1772
+ if not isinstance(attention_mask, torch.Tensor):
1773
+ attention_mask = torch.tensor(
1774
+ attention_mask, dtype=torch.long
1775
+ )
1776
+ if attention_mask.dim() == 1:
1777
+ attention_mask = attention_mask.unsqueeze(0)
1778
+
1779
+ batch = {"input_ids": input_ids}
1780
+ if attention_mask is not None:
1781
+ batch["attention_mask"] = attention_mask
1782
+
1783
+ token_type = sample.get("token_type_ids")
1784
+ if token_type is not None:
1785
+ if not isinstance(token_type, torch.Tensor):
1786
+ token_type = torch.tensor(token_type, dtype=torch.long)
1787
+ if token_type.dim() == 1:
1788
+ token_type = token_type.unsqueeze(0)
1789
+ batch["token_type_ids"] = token_type
1790
+
1791
+ labels = sample.get("labels")
1792
+ if labels is None:
1793
+ labels = input_ids.clone()
1794
+ if attention_mask is not None:
1795
+ labels = labels.masked_fill(attention_mask == 0, -100)
1796
+ else:
1797
+ if not isinstance(labels, torch.Tensor):
1798
+ labels = torch.tensor(labels, dtype=torch.long)
1799
+ if labels.dim() == 1:
1800
+ labels = labels.unsqueeze(0)
1801
+ batch["labels"] = labels
1802
+
1803
+ yield batch
1804
+
1805
+ def __len__(self):
1806
+ return len(self.samples)
1807
+
1808
+ return SampleDataLoader(samples)
1809
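
Illustrative note: each yielded batch is a dict of 2-D long tensors, and when a sample carries no labels the input ids are cloned and padding positions (attention_mask == 0) are set to -100 so the loss ignores them. A tiny standalone sketch of that labelling convention, with toy values only:

    import torch

    input_ids = torch.tensor([[5, 6, 7, 0]])          # last position is padding
    attention_mask = torch.tensor([[1, 1, 1, 0]])
    labels = input_ids.clone().masked_fill(attention_mask == 0, -100)
    batch = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
    print(batch["labels"])  # tensor([[   5,    6,    7, -100]])
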
+
1810
+ def _finalize_phase(
1811
+ self,
1812
+ model: Any,
1813
+ adapter: ModelAdapter,
1814
+ guard_results: dict[str, dict[str, Any]],
1815
+ metrics: dict[str, Any],
1816
+ config: RunConfig,
1817
+ report: RunReport,
1818
+ ) -> str:
1819
+ """Phase 5: Finalize or rollback based on results."""
1820
+ self._log_event("finalize", "start", LogLevel.INFO)
1821
+
1822
+ # Check if guards passed
1823
+ all_guards_passed = all(r.get("passed", False) for r in guard_results.values())
1824
+
1825
+ # Check for catastrophic drift spike using primary metric preview/final
1826
+ pm = metrics.get("primary_metric", {}) if isinstance(metrics, dict) else {}
1827
+ pm_prev = pm.get("preview") if isinstance(pm, dict) else None
1828
+ pm_fin = pm.get("final") if isinstance(pm, dict) else None
1829
+ try:
1830
+ drift_ratio = (
1831
+ float(pm_fin) / float(pm_prev)
1832
+                if isinstance(pm_fin, int | float)
1833
+                and isinstance(pm_prev, int | float)
1834
+ and float(pm_prev) > 0.0
1835
+ else float("inf")
1836
+ )
1837
+ except Exception:
1838
+ drift_ratio = float("inf")
1839
+ spike_threshold = getattr(config, "spike_threshold", 2.0)
1840
+ is_catastrophic_spike = drift_ratio > spike_threshold
1841
+
1842
+ # Check if standard metrics are acceptable against configured max ratio
1843
+ metrics_acceptable = drift_ratio <= getattr(config, "max_pm_ratio", 2.0)
1844
+
1845
+ # Determine rollback reason and status
1846
+ rollback_reason = None
1847
+ if is_catastrophic_spike:
1848
+ rollback_reason = (
1849
+ f"catastrophic_ppl_spike (ratio: {drift_ratio:.3f} > {spike_threshold})"
1850
+ )
1851
+ status = RunStatus.ROLLBACK.value
1852
+
1853
+ self._log_event(
1854
+ "finalize",
1855
+ "catastrophic_spike_detected",
1856
+ LogLevel.ERROR,
1857
+ {
1858
+ "primary_metric_drift_ratio": drift_ratio,
1859
+ "spike_threshold": spike_threshold,
1860
+ "immediate_rollback": True,
1861
+ },
1862
+ )
1863
+ elif (not all_guards_passed) or (not metrics_acceptable):
1864
+ # Match historical/test expectation string exactly
1865
+ rollback_reason = "guards_failed or metrics_unacceptable"
1866
+ status = RunStatus.ROLLBACK.value
1867
+ else:
1868
+ status = RunStatus.SUCCESS.value
1869
+
1870
+ # Execute the determined action
1871
+ if status == RunStatus.SUCCESS.value:
1872
+ self._log_event(
1873
+ "finalize",
1874
+ "success",
1875
+ LogLevel.INFO,
1876
+ {
1877
+ "guards_passed": all_guards_passed,
1878
+ "metrics_ok": metrics_acceptable,
1879
+ },
1880
+ )
1881
+ else:
1882
+ # Perform rollback if checkpoint available
1883
+ if self.checkpoint_manager and "initial_checkpoint" in report.meta:
1884
+ checkpoint_id = report.meta["initial_checkpoint"]
1885
+ self.checkpoint_manager.restore_checkpoint(
1886
+ model, adapter, checkpoint_id
1887
+ )
1888
+ # Match test expectation: only include checkpoint and reason
1889
+ self._log_event(
1890
+ "finalize",
1891
+ "rollback",
1892
+ LogLevel.WARNING,
1893
+ {
1894
+ "checkpoint": checkpoint_id,
1895
+ "reason": rollback_reason,
1896
+ },
1897
+ )
1898
+
1899
+ # Store rollback metadata in report
1900
+ report.meta["rollback_reason"] = rollback_reason
1901
+ report.meta["rollback_checkpoint"] = checkpoint_id
1902
+ report.meta["guard_recovered"] = True
1903
+
1904
+ else:
1905
+ # Match test expectation: log without additional data payload
1906
+ self._log_event("finalize", "rollback_unavailable", LogLevel.ERROR)
1907
+
1908
+ return status
1909
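
Illustrative note: the finalize decision reduces to three checks: all guards passed, the primary-metric ratio stays at or below max_pm_ratio, and the ratio does not exceed the catastrophic spike_threshold. A compact sketch of the same decision; decide is a made-up helper, and the 2.0 defaults simply mirror the getattr fallbacks used above, not guaranteed package values:

    def decide(guards_passed: bool, ratio: float,
               spike_threshold: float = 2.0, max_pm_ratio: float = 2.0) -> str:
        # Catastrophic spike forces an immediate rollback, regardless of guards.
        if ratio > spike_threshold:
            return "rollback:catastrophic_ppl_spike"
        if not guards_passed or ratio > max_pm_ratio:
            return "rollback:guards_failed or metrics_unacceptable"
        return "success"

    print(decide(True, 1.1))   # success
    print(decide(True, 2.5))   # rollback:catastrophic_ppl_spike
    print(decide(False, 1.1))  # rollback:guards_failed or metrics_unacceptable
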
+
1910
+ def _handle_error(self, error: Exception, report: RunReport) -> None:
1911
+ """Handle pipeline errors."""
1912
+ report.status = RunStatus.FAILED.value
1913
+ report.error = str(error)
1914
+ report.meta["end_time"] = time.time()
1915
+
1916
+ if "start_time" in report.meta:
1917
+ report.meta["duration"] = (
1918
+ report.meta["end_time"] - report.meta["start_time"]
1919
+ )
1920
+
1921
+ self._log_event("runner", "error", LogLevel.ERROR, {"error": str(error)})
1922
+
1923
+ # Attempt rollback on error
1924
+ if self.checkpoint_manager and "initial_checkpoint" in report.meta:
1925
+ try:
1926
+ checkpoint_id = report.meta["initial_checkpoint"]
1927
+ # Would need model and adapter here for actual rollback
1928
+ self._log_event(
1929
+ "runner",
1930
+ "emergency_rollback",
1931
+ LogLevel.WARNING,
1932
+ {"checkpoint": checkpoint_id},
1933
+ )
1934
+ except Exception as rollback_error:
1935
+ self._log_event(
1936
+ "runner",
1937
+ "rollback_failed",
1938
+ LogLevel.CRITICAL,
1939
+ {"error": str(rollback_error)},
1940
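
Illustrative note: on failure the report records the error string plus an end_time, and a duration is filled in only when a start_time was captured earlier in the run. A tiny sketch of that bookkeeping, with a plain dict standing in for report.meta:

    import time

    meta = {"start_time": time.time() - 3.0}  # pretend the run started 3 s ago
    meta["end_time"] = time.time()
    if "start_time" in meta:
        meta["duration"] = meta["end_time"] - meta["start_time"]
    print(round(meta["duration"], 1))         # ~3.0
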
+ )
1941
+
1942
+ def _resolve_guard_policies(
1943
+ self, report: RunReport, auto_config: dict[str, Any] | None = None
1944
+ ) -> dict[str, dict[str, Any]]:
1945
+ """Resolve tier-based guard policies from configuration."""
1946
+ # Use passed auto_config if available, otherwise extract from report meta
1947
+ if auto_config is None:
1948
+ config_meta = report.meta.get("config", {})
1949
+
1950
+ # Try to get auto config from various possible locations
1951
+ if hasattr(report, "auto_config"):
1952
+ auto_config = report.auto_config
1953
+ elif "auto" in config_meta:
1954
+ auto_config = config_meta["auto"]
1955
+ else:
1956
+ # Fallback to default balanced tier
1957
+ auto_config = {"tier": "balanced", "enabled": True}
1958
+
1959
+ # Extract tier and edit name
1960
+ tier = auto_config.get("tier", "balanced")
1961
+ edit_name = None
1962
+ if hasattr(report, "edit") and report.edit:
1963
+ edit_name = report.edit.get("name")
1964
+
1965
+ # Also try to get edit name from stored edit result in meta
1966
+ if not edit_name and "edit_name" in report.meta:
1967
+ edit_name = report.meta["edit_name"]
1968
+
1969
+ # Get explicit guard overrides from config
1970
+ config_meta = report.meta.get("config", {})
1971
+ explicit_overrides = config_meta.get("guards", {})
1972
+
1973
+ try:
1974
+ # Resolve tier policies
1975
+ policies = resolve_tier_policies(tier, edit_name, explicit_overrides)
1976
+
1977
+ self._log_event(
1978
+ "auto_tuning",
1979
+ "tier_resolved",
1980
+ LogLevel.INFO,
1981
+ {"tier": tier, "edit": edit_name, "policies_count": len(policies)},
1982
+ )
1983
+
1984
+ return policies
1985
+
1986
+ except Exception as e:
1987
+ self._log_event(
1988
+ "auto_tuning",
1989
+ "tier_resolution_failed",
1990
+ LogLevel.ERROR,
1991
+ {"tier": tier, "error": str(e)},
1992
+ )
1993
+ # Return empty policies dict on failure
1994
+ return {}
1995
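
Illustrative note: tier resolution takes the tier name, the edit name when one is recorded, and any explicit guard overrides from the config; the real lookup lives in resolve_tier_policies. The sketch below shows only the override-merge idea with invented tier defaults and guard names, not the package's actual tier tables:

    # Hypothetical defaults; the package resolves real values via
    # resolve_tier_policies(tier, edit_name, explicit_overrides).
    TIER_DEFAULTS = {
        "balanced": {"invariants": {"max_violations": 0}, "spectral": {"sigma_quantile": 0.95}},
    }

    def resolve(tier: str, overrides: dict) -> dict:
        # Copy the tier defaults, then let explicit per-guard overrides win.
        policies = {name: dict(params) for name, params in TIER_DEFAULTS.get(tier, {}).items()}
        for guard_name, params in overrides.items():
            policies.setdefault(guard_name, {}).update(params)
        return policies

    print(resolve("balanced", {"spectral": {"sigma_quantile": 0.99}}))
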
+
1996
+ def _apply_guard_policy(self, guard: Guard, policy: dict[str, Any]) -> None:
1997
+ """Apply resolved policy parameters to a guard instance."""
1998
+ try:
1999
+ # Apply policy parameters to guard
2000
+ for param_name, param_value in policy.items():
2001
+ if hasattr(guard, param_name):
2002
+ setattr(guard, param_name, param_value)
2003
+ elif hasattr(guard, "config") and isinstance(guard.config, dict):
2004
+ # Try to set in guard's config dict
2005
+ guard.config[param_name] = param_value
2006
+ elif hasattr(guard, "policy") and isinstance(guard.policy, dict):
2007
+ # Try to set in guard's policy dict
2008
+ guard.policy[param_name] = param_value
2009
+ else:
2010
+ # Last resort: add to guard as attribute
2011
+ setattr(guard, param_name, param_value)
2012
+
2013
+ except Exception as e:
2014
+ self._log_event(
2015
+ "auto_tuning",
2016
+ "policy_application_failed",
2017
+ LogLevel.WARNING,
2018
+ {"guard": guard.name, "policy": policy, "error": str(e)},
2019
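
Illustrative note: policy application tries, in order, a matching attribute on the guard, then the guard's config dict, then its policy dict, and finally falls back to setting a new attribute. A standalone sketch of that precedence; ToyGuard and apply_policy are stand-ins that mirror the fallback order above:

    class ToyGuard:
        def __init__(self):
            self.threshold = 0.1   # existing attribute wins first
            self.config = {}       # dict fallback for unknown params

    def apply_policy(guard, policy: dict) -> None:
        for name, value in policy.items():
            if hasattr(guard, name):
                setattr(guard, name, value)
            elif isinstance(getattr(guard, "config", None), dict):
                guard.config[name] = value
            elif isinstance(getattr(guard, "policy", None), dict):
                guard.policy[name] = value
            else:
                setattr(guard, name, value)

    g = ToyGuard()
    apply_policy(g, {"threshold": 0.2, "window": 8})
    print(g.threshold, g.config)   # 0.2 {'window': 8}
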
+ )
2020
+
2021
+ def _log_event(
2022
+ self,
2023
+ component: str,
2024
+ operation: str,
2025
+ level: LogLevel,
2026
+ data: dict[str, Any] | None = None,
2027
+ ) -> None:
2028
+ """Log an event if event logger is available."""
2029
+ if self.event_logger:
2030
+ self.event_logger.log(component, operation, level, data)
2031
+
2032
+ def _serialize_config(self, config: RunConfig) -> dict[str, Any]:
2033
+ """Serialize RunConfig for storage in report."""
2034
+ return {
2035
+ "device": config.device,
2036
+ "max_pm_ratio": config.max_pm_ratio,
2037
+ "checkpoint_interval": config.checkpoint_interval,
2038
+ "dry_run": config.dry_run,
2039
+ "verbose": config.verbose,
2040
+ "guards": config.context.get("guards", {}) if config.context else {},
2041
+ }
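
Illustrative note: the serialized config is a small, JSON-friendly dict, with guard overrides pulled from config.context["guards"] when a context is present. An example of the resulting shape, using made-up values rather than real defaults:

    import json

    serialized = {
        "device": "cpu",
        "max_pm_ratio": 2.0,
        "checkpoint_interval": 1,
        "dry_run": False,
        "verbose": False,
        "guards": {"spectral": {"enabled": True}},
    }
    print(json.dumps(serialized, indent=2))
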