invarlock 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +61 -0
  3. invarlock/adapters/hf_loading.py +97 -0
  4. invarlock/calibration/__init__.py +6 -0
  5. invarlock/calibration/spectral_null.py +301 -0
  6. invarlock/calibration/variance_ve.py +154 -0
  7. invarlock/cli/app.py +15 -0
  8. invarlock/cli/commands/calibrate.py +576 -0
  9. invarlock/cli/commands/doctor.py +9 -3
  10. invarlock/cli/commands/explain_gates.py +53 -9
  11. invarlock/cli/commands/plugins.py +12 -2
  12. invarlock/cli/commands/run.py +175 -79
  13. invarlock/cli/commands/verify.py +40 -0
  14. invarlock/cli/determinism.py +237 -0
  15. invarlock/core/auto_tuning.py +215 -17
  16. invarlock/core/registry.py +9 -4
  17. invarlock/eval/bench.py +467 -141
  18. invarlock/eval/bench_regression.py +12 -0
  19. invarlock/eval/data.py +29 -7
  20. invarlock/guards/spectral.py +216 -9
  21. invarlock/guards/variance.py +6 -3
  22. invarlock/reporting/certificate.py +249 -37
  23. invarlock/reporting/certificate_schema.py +4 -1
  24. invarlock/reporting/guards_analysis.py +108 -10
  25. invarlock/reporting/normalizer.py +21 -1
  26. invarlock/reporting/policy_utils.py +100 -16
  27. {invarlock-0.3.1.dist-info → invarlock-0.3.2.dist-info}/METADATA +12 -10
  28. {invarlock-0.3.1.dist-info → invarlock-0.3.2.dist-info}/RECORD +32 -25
  29. {invarlock-0.3.1.dist-info → invarlock-0.3.2.dist-info}/WHEEL +0 -0
  30. {invarlock-0.3.1.dist-info → invarlock-0.3.2.dist-info}/entry_points.txt +0 -0
  31. {invarlock-0.3.1.dist-info → invarlock-0.3.2.dist-info}/licenses/LICENSE +0 -0
  32. {invarlock-0.3.1.dist-info → invarlock-0.3.2.dist-info}/top_level.txt +0 -0
@@ -897,11 +897,21 @@ def _check_plugin_extras(plugin_name: str, plugin_type: str) -> str:
897
897
  if not plugin_info or not plugin_info["packages"]:
898
898
  return "" # No extra dependencies needed
899
899
 
900
- # Check each required package using import to play nice with tests that mock __import__
900
+ # Check each required package. For most packages we use a light import so
901
+ # tests can monkeypatch __import__; for GPU-only stacks like bitsandbytes
902
+ # we only probe presence via importlib.util.find_spec to avoid crashing on
903
+ # CPU-only builds during simple listing.
901
904
  missing_packages: list[str] = []
902
905
  for pkg in plugin_info["packages"]:
903
906
  try:
904
- __import__(pkg)
907
+ if pkg == "bitsandbytes":
908
+ import importlib.util as _util
909
+
910
+ spec = _util.find_spec(pkg)
911
+ if spec is None:
912
+ raise ImportError("bitsandbytes not importable")
913
+ else:
914
+ __import__(pkg)
905
915
  except Exception:
906
916
  missing_packages.append(pkg)
907
917
 
@@ -9,6 +9,7 @@ prefer Compare & Certify via `invarlock certify --baseline ... --subject ...`.
9
9
 
10
10
  import copy
11
11
  import hashlib
12
+ import inspect
12
13
  import json
13
14
  import math
14
15
  import os
@@ -818,6 +819,51 @@ def _resolve_provider_and_split(
818
819
  return data_provider, resolved_split, used_fallback_split
819
820
 
820
821
 
822
+ def _extract_model_load_kwargs(cfg: InvarLockConfig) -> dict[str, Any]:
823
+ """Return adapter.load_model kwargs from config (excluding core fields)."""
824
+ try:
825
+ data = cfg.model_dump()
826
+ except Exception:
827
+ data = {}
828
+ model = data.get("model") if isinstance(data, dict) else None
829
+ if not isinstance(model, dict):
830
+ return {}
831
+ return {
832
+ key: value
833
+ for key, value in model.items()
834
+ if key not in {"id", "adapter", "device"} and value is not None
835
+ }
836
+
837
+
838
+ def _load_model_with_cfg(adapter: Any, cfg: InvarLockConfig, device: str) -> Any:
839
+ """Load a model with config-provided kwargs, filtering for strict adapters."""
840
+ try:
841
+ model_id = cfg.model.id
842
+ except Exception:
843
+ try:
844
+ model_id = (cfg.model_dump().get("model") or {}).get("id")
845
+ except Exception:
846
+ model_id = None
847
+ if not isinstance(model_id, str) or not model_id:
848
+ raise ValueError("Missing model.id in config")
849
+
850
+ extra = _extract_model_load_kwargs(cfg)
851
+ try:
852
+ sig = inspect.signature(adapter.load_model)
853
+ accepts_var_kw = any(
854
+ p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
855
+ )
856
+ if accepts_var_kw:
857
+ return adapter.load_model(model_id, device=device, **extra)
858
+ allowed = {k: v for k, v in extra.items() if k in sig.parameters}
859
+ if allowed:
860
+ return adapter.load_model(model_id, device=device, **allowed)
861
+ except Exception:
862
+ # Fall back to the strictest call shape.
863
+ pass
864
+ return adapter.load_model(model_id, device=device)
865
+
866
+
821
867
  def _run_bare_control(
822
868
  *,
823
869
  adapter: Any,
@@ -899,6 +945,7 @@ def _run_bare_control(
899
945
  "errors": [],
900
946
  "checks": {},
901
947
  "source": f"{profile_normalized or 'ci'}_profile",
948
+ "mode": "bare",
902
949
  }
903
950
 
904
951
  if getattr(bare_report, "status", "").lower() not in {"success", "completed", "ok"}:
@@ -977,7 +1024,7 @@ def _postprocess_and_summarize(
977
1024
  match_fraction: float | None,
978
1025
  overlap_fraction: float | None,
979
1026
  console: Console,
980
- ) -> None:
1027
+ ) -> dict[str, str]:
981
1028
  """Finalize report windows stats and print/save summary artifacts."""
982
1029
  try:
983
1030
  ds = report.setdefault("dataset", {}).setdefault("windows", {})
@@ -1001,6 +1048,7 @@ def _postprocess_and_summarize(
1001
1048
  console.print(f"📄 Report: {saved_files['json']}")
1002
1049
  if run_config.event_path:
1003
1050
  console.print(f"📝 Events: {run_config.event_path}")
1051
+ return saved_files
1004
1052
 
1005
1053
 
1006
1054
  def _compute_provider_digest(report: dict[str, Any]) -> dict[str, str] | None:
@@ -1537,6 +1585,7 @@ def run_command(
1537
1585
  no_cleanup = bool(_coerce_option(no_cleanup, False))
1538
1586
 
1539
1587
  # Use shared CLI coercers from invarlock.cli.utils
1588
+ report_path_out: str | None = None
1540
1589
 
1541
1590
  def _fail_run(message: str) -> None:
1542
1591
  console.print(f"[red]❌ {message}[/red]")
@@ -1673,6 +1722,26 @@ def run_command(
1673
1722
  cfg, device=device, out=out, console=console
1674
1723
  )
1675
1724
 
1725
+ determinism_meta: dict[str, Any] | None = None
1726
+ try:
1727
+ from invarlock.cli.determinism import apply_determinism_preset
1728
+
1729
+ preset = apply_determinism_preset(
1730
+ profile=profile_label,
1731
+ device=resolved_device,
1732
+ seed=int(seed_bundle.get("python") or seed_value),
1733
+ threads=int(os.environ.get("INVARLOCK_OMP_THREADS", 1) or 1),
1734
+ )
1735
+ if isinstance(preset, dict) and preset:
1736
+ determinism_meta = preset
1737
+ preset_seeds = preset.get("seeds")
1738
+ if isinstance(preset_seeds, dict) and preset_seeds:
1739
+ for key in ("python", "numpy", "torch"):
1740
+ if key in preset_seeds:
1741
+ seed_bundle[key] = preset_seeds.get(key)
1742
+ except Exception:
1743
+ determinism_meta = None
1744
+
1676
1745
  # Create run directory with timestamp
1677
1746
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1678
1747
  run_dir = output_dir / timestamp
@@ -2934,7 +3003,23 @@ def run_command(
2934
3003
  )
2935
3004
 
2936
3005
  guard_overhead_payload: dict[str, Any] | None = None
2937
- if measure_guard_overhead:
3006
+ if skip_overhead and profile_normalized in {"ci", "release"}:
3007
+ guard_overhead_payload = {
3008
+ "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
3009
+ "evaluated": False,
3010
+ "passed": True,
3011
+ "skipped": True,
3012
+ "skip_reason": "INVARLOCK_SKIP_OVERHEAD_CHECK",
3013
+ "mode": "skipped",
3014
+ "source": "env:INVARLOCK_SKIP_OVERHEAD_CHECK",
3015
+ "messages": [
3016
+ "Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK"
3017
+ ],
3018
+ "warnings": [],
3019
+ "errors": [],
3020
+ "checks": {},
3021
+ }
3022
+ elif measure_guard_overhead:
2938
3023
  guard_overhead_payload = _run_bare_control(
2939
3024
  adapter=adapter,
2940
3025
  edit_op=edit_op,
@@ -3076,6 +3161,8 @@ def run_command(
3076
3161
  meta_payload["invarlock_version"] = invarlock_version
3077
3162
  if env_flags:
3078
3163
  meta_payload["env_flags"] = env_flags
3164
+ if determinism_meta:
3165
+ meta_payload["determinism"] = determinism_meta
3079
3166
  report["meta"].update(meta_payload)
3080
3167
  if pm_acceptance_range:
3081
3168
  report["meta"]["pm_acceptance_range"] = pm_acceptance_range
@@ -3235,87 +3322,90 @@ def run_command(
3235
3322
  report["metrics"].update(metrics_payload)
3236
3323
 
3237
3324
  if guard_overhead_payload is not None:
3238
- # Compute guarded primary-metric snapshot; pass structured reports into validator
3239
- try:
3240
- # Map loss type to ppl family kind
3241
- lk = str(resolved_loss_type or "causal").lower()
3242
- if lk == "mlm":
3243
- pm_kind_for_overhead = "ppl_mlm"
3244
- elif lk in {"seq2seq", "s2s", "t5"}:
3245
- pm_kind_for_overhead = "ppl_seq2seq"
3246
- else:
3247
- pm_kind_for_overhead = "ppl_causal"
3325
+ if bool(guard_overhead_payload.get("skipped", False)):
3326
+ report["guard_overhead"] = guard_overhead_payload
3327
+ else:
3328
+ # Compute guarded primary-metric snapshot; pass structured reports into validator
3329
+ try:
3330
+ # Map loss type to ppl family kind
3331
+ lk = str(resolved_loss_type or "causal").lower()
3332
+ if lk == "mlm":
3333
+ pm_kind_for_overhead = "ppl_mlm"
3334
+ elif lk in {"seq2seq", "s2s", "t5"}:
3335
+ pm_kind_for_overhead = "ppl_seq2seq"
3336
+ else:
3337
+ pm_kind_for_overhead = "ppl_causal"
3248
3338
 
3249
- # Prefer computing from the in-memory core_report windows to avoid ordering issues
3250
- pm_guarded = _extract_pm_snapshot_for_overhead(
3251
- core_report, kind=pm_kind_for_overhead
3252
- )
3253
- if not isinstance(pm_guarded, dict) or not pm_guarded:
3339
+ # Prefer computing from the in-memory core_report windows to avoid ordering issues
3254
3340
  pm_guarded = _extract_pm_snapshot_for_overhead(
3255
- report, kind=pm_kind_for_overhead
3341
+ core_report, kind=pm_kind_for_overhead
3256
3342
  )
3343
+ if not isinstance(pm_guarded, dict) or not pm_guarded:
3344
+ pm_guarded = _extract_pm_snapshot_for_overhead(
3345
+ report, kind=pm_kind_for_overhead
3346
+ )
3257
3347
 
3258
- guard_overhead_payload["guarded_report"] = (
3259
- {"metrics": {"primary_metric": pm_guarded}}
3260
- if isinstance(pm_guarded, dict) and pm_guarded
3261
- else None
3348
+ guard_overhead_payload["guarded_report"] = (
3349
+ {"metrics": {"primary_metric": pm_guarded}}
3350
+ if isinstance(pm_guarded, dict) and pm_guarded
3351
+ else None
3352
+ )
3353
+ except Exception:
3354
+ guard_overhead_payload["guarded_report"] = None
3355
+ bare_struct = guard_overhead_payload.get("bare_report") or {}
3356
+ guarded_struct = guard_overhead_payload.get("guarded_report") or {}
3357
+ # Be robust to mocks or minimal objects returned by validators
3358
+ result = validate_guard_overhead(
3359
+ bare_struct,
3360
+ guarded_struct,
3361
+ overhead_threshold=guard_overhead_payload.get(
3362
+ "overhead_threshold", GUARD_OVERHEAD_THRESHOLD
3363
+ ),
3262
3364
  )
3263
- except Exception:
3264
- guard_overhead_payload["guarded_report"] = None
3265
- bare_struct = guard_overhead_payload.get("bare_report") or {}
3266
- guarded_struct = guard_overhead_payload.get("guarded_report") or {}
3267
- # Be robust to mocks or minimal objects returned by validators
3268
- result = validate_guard_overhead(
3269
- bare_struct,
3270
- guarded_struct,
3271
- overhead_threshold=guard_overhead_payload.get(
3272
- "overhead_threshold", GUARD_OVERHEAD_THRESHOLD
3273
- ),
3274
- )
3275
- try:
3276
- messages = list(getattr(result, "messages", []))
3277
- except Exception: # pragma: no cover - defensive
3278
- messages = []
3279
- try:
3280
- warnings = list(getattr(result, "warnings", []))
3281
- except Exception: # pragma: no cover - defensive
3282
- warnings = []
3283
- try:
3284
- errors = list(getattr(result, "errors", []))
3285
- except Exception: # pragma: no cover - defensive
3286
- errors = []
3287
- try:
3288
- checks = dict(getattr(result, "checks", {}))
3289
- except Exception: # pragma: no cover - defensive
3290
- checks = {}
3291
- metrics_obj = getattr(result, "metrics", {})
3292
- if not isinstance(metrics_obj, dict):
3293
- metrics_obj = {}
3294
- overhead_ratio = metrics_obj.get("overhead_ratio")
3295
- if overhead_ratio is None:
3296
- overhead_ratio = getattr(result, "overhead_ratio", None)
3297
- overhead_percent = metrics_obj.get("overhead_percent")
3298
- if overhead_percent is None:
3299
- overhead_percent = getattr(result, "overhead_percent", None)
3300
- passed_flag = bool(getattr(result, "passed", False))
3301
-
3302
- guard_overhead_payload.update(
3303
- {
3304
- "messages": messages,
3305
- "warnings": warnings,
3306
- "errors": errors,
3307
- "checks": checks,
3308
- "overhead_ratio": overhead_ratio,
3309
- "overhead_percent": overhead_percent,
3310
- "passed": passed_flag,
3311
- "evaluated": True,
3312
- }
3313
- )
3314
- # Normalize for non-finite/degenerate cases
3315
- guard_overhead_payload = _normalize_overhead_result(
3316
- guard_overhead_payload, profile=profile_normalized
3317
- )
3318
- report["guard_overhead"] = guard_overhead_payload
3365
+ try:
3366
+ messages = list(getattr(result, "messages", []))
3367
+ except Exception: # pragma: no cover - defensive
3368
+ messages = []
3369
+ try:
3370
+ warnings = list(getattr(result, "warnings", []))
3371
+ except Exception: # pragma: no cover - defensive
3372
+ warnings = []
3373
+ try:
3374
+ errors = list(getattr(result, "errors", []))
3375
+ except Exception: # pragma: no cover - defensive
3376
+ errors = []
3377
+ try:
3378
+ checks = dict(getattr(result, "checks", {}))
3379
+ except Exception: # pragma: no cover - defensive
3380
+ checks = {}
3381
+ metrics_obj = getattr(result, "metrics", {})
3382
+ if not isinstance(metrics_obj, dict):
3383
+ metrics_obj = {}
3384
+ overhead_ratio = metrics_obj.get("overhead_ratio")
3385
+ if overhead_ratio is None:
3386
+ overhead_ratio = getattr(result, "overhead_ratio", None)
3387
+ overhead_percent = metrics_obj.get("overhead_percent")
3388
+ if overhead_percent is None:
3389
+ overhead_percent = getattr(result, "overhead_percent", None)
3390
+ passed_flag = bool(getattr(result, "passed", False))
3391
+
3392
+ guard_overhead_payload.update(
3393
+ {
3394
+ "messages": messages,
3395
+ "warnings": warnings,
3396
+ "errors": errors,
3397
+ "checks": checks,
3398
+ "overhead_ratio": overhead_ratio,
3399
+ "overhead_percent": overhead_percent,
3400
+ "passed": passed_flag,
3401
+ "evaluated": True,
3402
+ }
3403
+ )
3404
+ # Normalize for non-finite/degenerate cases
3405
+ guard_overhead_payload = _normalize_overhead_result(
3406
+ guard_overhead_payload, profile=profile_normalized
3407
+ )
3408
+ report["guard_overhead"] = guard_overhead_payload
3319
3409
 
3320
3410
  had_baseline = bool(baseline and Path(baseline).exists())
3321
3411
  if (
@@ -3860,7 +3950,7 @@ def run_command(
3860
3950
  except Exception:
3861
3951
  pass
3862
3952
 
3863
- _postprocess_and_summarize(
3953
+ saved_files = _postprocess_and_summarize(
3864
3954
  report=report,
3865
3955
  run_dir=run_dir,
3866
3956
  run_config=run_config,
@@ -3870,6 +3960,11 @@ def run_command(
3870
3960
  overlap_fraction=overlap_fraction,
3871
3961
  console=console,
3872
3962
  )
3963
+ try:
3964
+ if isinstance(saved_files, dict) and saved_files.get("json"):
3965
+ report_path_out = str(saved_files["json"])
3966
+ except Exception:
3967
+ pass
3873
3968
 
3874
3969
  # Metrics display
3875
3970
  pm_obj = None
@@ -4060,6 +4155,7 @@ def run_command(
4060
4155
  pass
4061
4156
 
4062
4157
  # Normal path falls through; cleanup handled below in finally
4158
+ return report_path_out
4063
4159
 
4064
4160
  except FileNotFoundError as e:
4065
4161
  console.print(f"[red]❌ Configuration file not found: {e}[/red]")
@@ -35,6 +35,22 @@ from .run import _enforce_provider_parity, _resolve_exit_code
35
35
  console = Console()
36
36
 
37
37
 
38
+ def _coerce_float(value: Any) -> float | None:
39
+ try:
40
+ out = float(value)
41
+ except (TypeError, ValueError):
42
+ return None
43
+ return out if math.isfinite(out) else None
44
+
45
+
46
+ def _coerce_int(value: Any) -> int | None:
47
+ try:
48
+ out = int(value)
49
+ except (TypeError, ValueError):
50
+ return None
51
+ return out if out >= 0 else None
52
+
53
+
38
54
  def _load_certificate(path: Path) -> dict[str, Any]:
39
55
  """Load certificate JSON from disk."""
40
56
  with path.open("r", encoding="utf-8") as handle:
@@ -315,6 +331,30 @@ def _validate_certificate_payload(
315
331
  errors.extend(_validate_drift_band(certificate))
316
332
  errors.extend(_apply_profile_lints(certificate))
317
333
  errors.extend(_validate_tokenizer_hash(certificate))
334
+ # Release-only enforcement: guard overhead must be measured or explicitly skipped.
335
+ if prof == "release":
336
+ go = certificate.get("guard_overhead")
337
+ if not isinstance(go, dict) or not go:
338
+ errors.append(
339
+ "Release verification requires guard_overhead (missing). "
340
+ "Set INVARLOCK_SKIP_OVERHEAD_CHECK=1 to explicitly skip during certification."
341
+ )
342
+ else:
343
+ skipped = bool(go.get("skipped", False)) or (
344
+ str(go.get("mode", "")).strip().lower() == "skipped"
345
+ )
346
+ if not skipped:
347
+ evaluated = go.get("evaluated")
348
+ if evaluated is not True:
349
+ errors.append(
350
+ "Release verification requires evaluated guard_overhead (not evaluated). "
351
+ "Set INVARLOCK_SKIP_OVERHEAD_CHECK=1 to explicitly skip during certification."
352
+ )
353
+ ratio = go.get("overhead_ratio")
354
+ if ratio is None:
355
+ errors.append(
356
+ "Release verification requires guard_overhead.overhead_ratio (missing)."
357
+ )
318
358
  # Legacy cross-checks removed; primary_metric is canonical
319
359
 
320
360
  return errors
@@ -0,0 +1,237 @@
1
+ """Determinism presets for CI/release runs.
2
+
3
+ Centralizes:
4
+ - Seeds (python/numpy/torch)
5
+ - Thread caps (OMP/MKL/etc + torch threads)
6
+ - TF32 policy
7
+ - torch deterministic algorithms
8
+ - A structured "determinism level" for certificate provenance
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import random
15
+ from typing import Any
16
+
17
+ import numpy as np
18
+
19
+ from invarlock.model_utils import set_seed
20
+
21
+ try: # optional torch
22
+ import torch
23
+ except Exception: # pragma: no cover
24
+ torch = None # type: ignore[assignment]
25
+
26
+
27
+ _THREAD_ENV_VARS: tuple[str, ...] = (
28
+ "OMP_NUM_THREADS",
29
+ "MKL_NUM_THREADS",
30
+ "OPENBLAS_NUM_THREADS",
31
+ "NUMEXPR_NUM_THREADS",
32
+ "VECLIB_MAXIMUM_THREADS",
33
+ )
34
+
35
+
36
+ def _coerce_int(value: Any, default: int) -> int:
37
+ try:
38
+ return int(value)
39
+ except Exception:
40
+ return int(default)
41
+
42
+
43
+ def _coerce_profile(profile: str | None) -> str:
44
+ try:
45
+ return (profile or "").strip().lower()
46
+ except Exception:
47
+ return ""
48
+
49
+
50
+ def _coerce_device(device: str | None) -> str:
51
+ try:
52
+ return (device or "").strip().lower()
53
+ except Exception:
54
+ return "cpu"
55
+
56
+
57
+ def apply_determinism_preset(
58
+ *,
59
+ profile: str | None,
60
+ device: str | None,
61
+ seed: int,
62
+ threads: int = 1,
63
+ ) -> dict[str, Any]:
64
+ """Apply a determinism preset and return a provenance payload."""
65
+
66
+ prof = _coerce_profile(profile)
67
+ dev = _coerce_device(device)
68
+ threads_i = max(1, _coerce_int(threads, 1))
69
+
70
+ requested = "off"
71
+ if prof in {"ci", "release"}:
72
+ requested = "strict"
73
+
74
+ env_set: dict[str, Any] = {}
75
+ torch_flags: dict[str, Any] = {}
76
+ notes: list[str] = []
77
+
78
+ # Thread caps (best-effort): make CPU determinism explicit and reduce drift.
79
+ if requested == "strict":
80
+ for var in _THREAD_ENV_VARS:
81
+ os.environ[var] = str(threads_i)
82
+ env_set[var] = os.environ.get(var)
83
+
84
+ # CUDA determinism: cuBLAS workspace config.
85
+ if requested == "strict" and dev.startswith("cuda"):
86
+ os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
87
+ env_set["CUBLAS_WORKSPACE_CONFIG"] = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
88
+
89
+ # Seed all RNGs (python/numpy/torch) using the existing helper for parity.
90
+ set_seed(int(seed))
91
+
92
+ # Derive a stable seed bundle for provenance.
93
+ seed_bundle = {
94
+ "python": int(seed),
95
+ "numpy": int(seed),
96
+ "torch": None,
97
+ }
98
+ try:
99
+ numpy_seed = int(np.random.get_state()[1][0])
100
+ seed_bundle["numpy"] = int(numpy_seed)
101
+ except Exception:
102
+ pass
103
+ if torch is not None:
104
+ try:
105
+ seed_bundle["torch"] = int(torch.initial_seed())
106
+ except Exception:
107
+ seed_bundle["torch"] = int(seed)
108
+
109
+ # Torch-specific controls.
110
+ level = "off" if requested == "off" else "strict"
111
+ if requested == "strict":
112
+ if torch is None:
113
+ level = "tolerance"
114
+ notes.append("torch_unavailable")
115
+ else:
116
+ # Thread caps.
117
+ try:
118
+ if hasattr(torch, "set_num_threads"):
119
+ torch.set_num_threads(threads_i)
120
+ if hasattr(torch, "set_num_interop_threads"):
121
+ torch.set_num_interop_threads(threads_i)
122
+ torch_flags["torch_threads"] = threads_i
123
+ except Exception:
124
+ level = "tolerance"
125
+ notes.append("torch_thread_caps_failed")
126
+
127
+ # Disable TF32 for determinism.
128
+ try:
129
+ matmul = getattr(
130
+ getattr(torch.backends, "cuda", object()), "matmul", None
131
+ )
132
+ if matmul is not None and hasattr(matmul, "allow_tf32"):
133
+ matmul.allow_tf32 = False
134
+ cudnn_mod = getattr(torch.backends, "cudnn", None)
135
+ if cudnn_mod is not None and hasattr(cudnn_mod, "allow_tf32"):
136
+ cudnn_mod.allow_tf32 = False
137
+ except Exception:
138
+ level = "tolerance"
139
+ notes.append("tf32_policy_failed")
140
+
141
+ # Deterministic algorithms.
142
+ try:
143
+ if hasattr(torch, "use_deterministic_algorithms"):
144
+ torch.use_deterministic_algorithms(True, warn_only=False)
145
+ except Exception:
146
+ # Downgrade to tolerance-based determinism rather than crashing.
147
+ level = "tolerance"
148
+ notes.append("deterministic_algorithms_unavailable")
149
+ try:
150
+ if hasattr(torch, "use_deterministic_algorithms"):
151
+ torch.use_deterministic_algorithms(True, warn_only=True)
152
+ except Exception:
153
+ pass
154
+
155
+ # cuDNN knobs.
156
+ try:
157
+ cudnn_mod = getattr(torch.backends, "cudnn", None)
158
+ if cudnn_mod is not None:
159
+ cudnn_mod.benchmark = False
160
+ if hasattr(cudnn_mod, "deterministic"):
161
+ cudnn_mod.deterministic = True
162
+ except Exception:
163
+ level = "tolerance"
164
+ notes.append("cudnn_determinism_failed")
165
+
166
+ # Snapshot applied flags for provenance.
167
+ try:
168
+ det_enabled = getattr(
169
+ torch, "are_deterministic_algorithms_enabled", None
170
+ )
171
+ if callable(det_enabled):
172
+ torch_flags["deterministic_algorithms"] = bool(det_enabled())
173
+ except Exception:
174
+ pass
175
+ try:
176
+ cudnn_mod = getattr(torch.backends, "cudnn", None)
177
+ if cudnn_mod is not None:
178
+ torch_flags["cudnn_deterministic"] = bool(
179
+ getattr(cudnn_mod, "deterministic", False)
180
+ )
181
+ torch_flags["cudnn_benchmark"] = bool(
182
+ getattr(cudnn_mod, "benchmark", False)
183
+ )
184
+ if hasattr(cudnn_mod, "allow_tf32"):
185
+ torch_flags["cudnn_allow_tf32"] = bool(
186
+ getattr(cudnn_mod, "allow_tf32", False)
187
+ )
188
+ except Exception:
189
+ pass
190
+ try:
191
+ matmul = getattr(
192
+ getattr(torch.backends, "cuda", object()), "matmul", None
193
+ )
194
+ if matmul is not None and hasattr(matmul, "allow_tf32"):
195
+ torch_flags["cuda_matmul_allow_tf32"] = bool(matmul.allow_tf32)
196
+ except Exception:
197
+ pass
198
+
199
+ # Normalized level is always one of these.
200
+ if level not in {"off", "strict", "tolerance"}:
201
+ level = "tolerance" if requested == "strict" else "off"
202
+
203
+ # Extra breadcrumb: random module state is not easily serializable; include a coarse marker.
204
+ try:
205
+ torch_flags["python_random"] = isinstance(random.random(), float)
206
+ except Exception:
207
+ pass
208
+
209
+ payload = {
210
+ "requested": requested,
211
+ "level": level,
212
+ "profile": prof or None,
213
+ "device": dev,
214
+ "threads": threads_i if requested == "strict" else None,
215
+ "seed": int(seed),
216
+ "seeds": seed_bundle,
217
+ "env": env_set,
218
+ "torch": torch_flags,
219
+ "notes": notes,
220
+ }
221
+
222
+ # Remove empty sections for stable artifacts.
223
+ if not payload["env"]:
224
+ payload.pop("env", None)
225
+ if not payload["torch"]:
226
+ payload.pop("torch", None)
227
+ if not payload["notes"]:
228
+ payload.pop("notes", None)
229
+ if payload.get("threads") is None:
230
+ payload.pop("threads", None)
231
+ if payload.get("profile") is None:
232
+ payload.pop("profile", None)
233
+
234
+ return payload
235
+
236
+
237
+ __all__ = ["apply_determinism_preset"]