invarlock-0.3.5-py3-none-any.whl → invarlock-0.3.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. invarlock/__init__.py +2 -2
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +11 -15
  4. invarlock/adapters/auto.py +35 -40
  5. invarlock/adapters/capabilities.py +2 -2
  6. invarlock/adapters/hf_causal.py +418 -0
  7. invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
  8. invarlock/adapters/hf_mixin.py +25 -4
  9. invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
  10. invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
  11. invarlock/calibration/spectral_null.py +15 -10
  12. invarlock/calibration/variance_ve.py +0 -2
  13. invarlock/cli/adapter_auto.py +31 -21
  14. invarlock/cli/app.py +73 -2
  15. invarlock/cli/commands/calibrate.py +6 -2
  16. invarlock/cli/commands/certify.py +651 -91
  17. invarlock/cli/commands/doctor.py +11 -11
  18. invarlock/cli/commands/explain_gates.py +57 -8
  19. invarlock/cli/commands/plugins.py +13 -9
  20. invarlock/cli/commands/report.py +233 -69
  21. invarlock/cli/commands/run.py +1066 -244
  22. invarlock/cli/commands/verify.py +154 -15
  23. invarlock/cli/config.py +22 -6
  24. invarlock/cli/doctor_helpers.py +4 -5
  25. invarlock/cli/output.py +193 -0
  26. invarlock/cli/provenance.py +1 -1
  27. invarlock/core/api.py +45 -5
  28. invarlock/core/auto_tuning.py +65 -20
  29. invarlock/core/bootstrap.py +1 -1
  30. invarlock/core/contracts.py +7 -1
  31. invarlock/core/registry.py +11 -13
  32. invarlock/core/runner.py +425 -75
  33. invarlock/edits/quant_rtn.py +65 -37
  34. invarlock/eval/bench.py +3 -16
  35. invarlock/eval/data.py +82 -51
  36. invarlock/eval/metrics.py +63 -2
  37. invarlock/eval/primary_metric.py +23 -0
  38. invarlock/eval/tail_stats.py +230 -0
  39. invarlock/eval/tasks/__init__.py +12 -0
  40. invarlock/eval/tasks/classification.py +48 -0
  41. invarlock/eval/tasks/qa.py +36 -0
  42. invarlock/eval/tasks/text_generation.py +102 -0
  43. invarlock/guards/_estimators.py +154 -0
  44. invarlock/guards/invariants.py +19 -10
  45. invarlock/guards/policies.py +16 -6
  46. invarlock/guards/rmt.py +627 -546
  47. invarlock/guards/spectral.py +348 -110
  48. invarlock/guards/tier_config.py +32 -30
  49. invarlock/guards/variance.py +7 -31
  50. invarlock/guards_ref/rmt_ref.py +23 -23
  51. invarlock/model_profile.py +90 -42
  52. invarlock/observability/health.py +6 -6
  53. invarlock/observability/metrics.py +108 -0
  54. invarlock/reporting/certificate.py +384 -55
  55. invarlock/reporting/certificate_schema.py +3 -2
  56. invarlock/reporting/dataset_hashing.py +15 -2
  57. invarlock/reporting/guards_analysis.py +350 -277
  58. invarlock/reporting/html.py +55 -5
  59. invarlock/reporting/normalizer.py +13 -0
  60. invarlock/reporting/policy_utils.py +38 -36
  61. invarlock/reporting/primary_metric_utils.py +71 -17
  62. invarlock/reporting/render.py +852 -431
  63. invarlock/reporting/report.py +40 -4
  64. invarlock/reporting/report_types.py +11 -3
  65. invarlock/reporting/telemetry.py +86 -0
  66. invarlock/reporting/validate.py +1 -18
  67. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/METADATA +27 -13
  68. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/RECORD +72 -65
  69. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
  70. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
  71. invarlock/adapters/hf_gpt2.py +0 -404
  72. invarlock/adapters/hf_llama.py +0 -487
  73. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
  74. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
--- a/invarlock/eval/primary_metric.py
+++ b/invarlock/eval/primary_metric.py
@@ -623,6 +623,9 @@ def compute_primary_metric_from_report(
             "preview": float("nan"),
             "final": float("nan"),
             "ratio_vs_baseline": float("nan"),
+            "invalid": True,
+            "degraded": True,
+            "degraded_reason": "non_finite_pm",
         }
     # For accuracy kinds, derive counts from input_ids if aggregates are missing
     if kind in {"accuracy", "vqa_accuracy"}:
@@ -661,6 +664,11 @@ def compute_primary_metric_from_report(
     final_point = metric.point_from_windows(windows=final_win)

     ratio_vs_baseline = float("nan")
+    baseline_has_reference = False
+
+    def _is_finite(value: Any) -> bool:
+        return isinstance(value, (int, float)) and math.isfinite(float(value))
+
     if isinstance(baseline, dict):
         try:
             base_metrics = (
@@ -686,14 +694,25 @@ def compute_primary_metric_from_report(
                 is_ppl_like = str(kind).lower().startswith("ppl")
                 if is_ppl_like and base_ref > 0:
                     ratio_vs_baseline = float(final_point) / float(base_ref)
+                    baseline_has_reference = True
                 elif (
                     str(kind).lower() in {"accuracy", "vqa_accuracy"}
                     and 0 <= base_ref <= 1
                 ):
                     ratio_vs_baseline = float(final_point) - float(base_ref)
+                    baseline_has_reference = True
         except Exception:
             ratio_vs_baseline = float("nan")

+    invalid = not (_is_finite(preview_point) and _is_finite(final_point))
+    degraded_reason = None
+    if invalid:
+        degraded_reason = "non_finite_pm"
+    elif baseline_has_reference and not _is_finite(ratio_vs_baseline):
+        degraded_reason = "non_finite_delta"
+
+    degraded = bool(degraded_reason or invalid)
+
     payload = {
         "kind": metric.kind,
         "unit": metric.unit,
@@ -705,7 +724,11 @@ def compute_primary_metric_from_report(
         "preview": preview_point,
         "final": final_point,
         "ratio_vs_baseline": ratio_vs_baseline,
+        "invalid": invalid,
+        "degraded": degraded,
     }
+    if degraded and degraded_reason:
+        payload["degraded_reason"] = degraded_reason
     # Carry counts for accuracy to aid gating
     if kind in {"accuracy", "vqa_accuracy"}:
         if "n_prev" in locals() and n_prev is not None:
--- /dev/null
+++ b/invarlock/eval/tail_stats.py
@@ -0,0 +1,230 @@
+from __future__ import annotations
+
+import math
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+__all__ = [
+    "compute_tail_summary",
+    "evaluate_metric_tail",
+]
+
+
+def _as_finite_float(value: Any) -> float | None:
+    try:
+        out = float(value)
+    except Exception:
+        return None
+    return out if math.isfinite(out) else None
+
+
+def _linear_quantile(sorted_values: Sequence[float], q: float) -> float:
+    """Deterministic linear-interpolated quantile on sorted values (q in [0, 1])."""
+    n = len(sorted_values)
+    if n == 0:
+        return float("nan")
+    if n == 1:
+        return float(sorted_values[0])
+    if q <= 0.0:
+        return float(sorted_values[0])
+    if q >= 1.0:
+        return float(sorted_values[-1])
+    pos = float(q) * float(n - 1)
+    lo = int(math.floor(pos))
+    hi = int(math.ceil(pos))
+    if lo == hi:
+        return float(sorted_values[lo])
+    frac = pos - float(lo)
+    a = float(sorted_values[lo])
+    b = float(sorted_values[hi])
+    return a + frac * (b - a)
+
+
+def compute_tail_summary(
+    deltas: Sequence[float] | Sequence[Any],
+    *,
+    quantiles: Sequence[float] = (0.5, 0.9, 0.95, 0.99),
+    epsilon: float = 1e-4,
+    weights: Sequence[float] | Sequence[Any] | None = None,
+) -> dict[str, Any]:
+    """Compute deterministic tail summaries for Δlog-loss samples.
+
+    - Quantiles are computed unweighted using linear interpolation on sorted values.
+    - tail_mass is Pr[delta > epsilon] (unweighted).
+    - tail_mass_weighted is included when weights are provided and finite.
+    """
+    eps = _as_finite_float(epsilon)
+    if eps is None or eps < 0.0:
+        eps = 0.0
+
+    values: list[float] = []
+    paired_weights: list[float] | None = [] if weights is not None else None
+
+    if weights is None:
+        for d in deltas:
+            dv = _as_finite_float(d)
+            if dv is None:
+                continue
+            values.append(float(dv))
+    else:
+        for d, w in zip(deltas, weights, strict=False):
+            dv = _as_finite_float(d)
+            if dv is None:
+                continue
+            wv = _as_finite_float(w)
+            if wv is None or wv < 0.0:
+                wv = 0.0
+            values.append(float(dv))
+            if paired_weights is not None:
+                paired_weights.append(float(wv))
+
+    n = int(len(values))
+    values_sorted = sorted(values)
+
+    summary: dict[str, Any] = {
+        "n": n,
+        "epsilon": float(eps),
+    }
+    if n == 0:
+        summary.update({"max": float("nan"), "tail_mass": 0.0})
+        for q in quantiles:
+            try:
+                qf = float(q)
+            except Exception:
+                continue
+            label = f"q{int(round(100.0 * max(0.0, min(1.0, qf))))}"
+            summary[label] = float("nan")
+        return summary
+
+    summary["max"] = float(values_sorted[-1])
+    tail_ct = sum(1 for v in values if v > eps)
+    summary["tail_mass"] = float(tail_ct / n)
+
+    if paired_weights is not None:
+        total_w = 0.0
+        tail_w = 0.0
+        for v, w in zip(values, paired_weights, strict=False):
+            total_w += float(w)
+            if v > eps:
+                tail_w += float(w)
+        if total_w > 0.0:
+            summary["tail_mass_weighted"] = float(tail_w / total_w)
+            summary["tail_mass_weighted_by"] = "weights"
+
+    for q in quantiles:
+        try:
+            qf = float(q)
+        except Exception:
+            continue
+        qf = max(0.0, min(1.0, qf))
+        label = f"q{int(round(100.0 * qf))}"
+        summary[label] = float(_linear_quantile(values_sorted, qf))
+
+    return summary
+
+
+def evaluate_metric_tail(
+    *,
+    deltas: Sequence[float] | Sequence[Any],
+    policy: Mapping[str, Any] | None = None,
+    weights: Sequence[float] | Sequence[Any] | None = None,
+) -> dict[str, Any]:
+    """Evaluate a tail policy against Δlog-loss samples.
+
+    Policy keys:
+    - mode: "off" | "warn" | "fail" (default: "warn")
+    - min_windows: int (default: 1)
+    - quantile: float in [0, 1] (default: 0.95)
+    - quantile_max: float threshold in Δlog-loss (optional)
+    - epsilon: float deadband for tail_mass (default: 1e-4)
+    - mass_max: float in [0, 1] (optional)
+    """
+    pol = dict(policy or {})
+    mode = str(pol.get("mode", "warn") or "warn").strip().lower()
+    if mode not in {"off", "warn", "fail"}:
+        mode = "warn"
+
+    min_windows = pol.get("min_windows", 1)
+    try:
+        min_windows_i = int(min_windows)
+    except Exception:
+        min_windows_i = 1
+    min_windows_i = max(1, min_windows_i)
+
+    q = _as_finite_float(pol.get("quantile", 0.95))
+    if q is None:
+        q = 0.95
+    q = max(0.0, min(1.0, float(q)))
+
+    eps = _as_finite_float(pol.get("epsilon", 1e-4))
+    if eps is None or eps < 0.0:
+        eps = 0.0
+
+    qmax = _as_finite_float(pol.get("quantile_max"))
+    mmax = _as_finite_float(pol.get("mass_max"))
+    if mmax is not None:
+        mmax = max(0.0, min(1.0, float(mmax)))
+
+    quantiles = sorted({0.5, 0.9, 0.95, 0.99, float(q)})
+    stats = compute_tail_summary(
+        deltas, quantiles=tuple(quantiles), epsilon=float(eps), weights=weights
+    )
+    n = int(stats.get("n", 0) or 0)
+
+    thresholds_present = (qmax is not None) or (mmax is not None)
+    evaluated = bool(mode != "off" and thresholds_present and n >= min_windows_i)
+
+    violations: list[dict[str, Any]] = []
+    passed = True
+    if evaluated:
+        passed = True
+        q_label = f"q{int(round(100.0 * q))}"
+        q_obs = stats.get(q_label)
+        if not (isinstance(q_obs, int | float) and math.isfinite(float(q_obs))):
+            q_obs = float("nan")
+        if qmax is not None and math.isfinite(q_obs) and q_obs > float(qmax):
+            passed = False
+            violations.append(
+                {
+                    "type": "quantile_max_exceeded",
+                    "quantile": float(q),
+                    "observed": float(q_obs),
+                    "threshold": float(qmax),
+                }
+            )
+
+        tail_mass = stats.get("tail_mass")
+        if (
+            mmax is not None
+            and isinstance(tail_mass, int | float)
+            and math.isfinite(float(tail_mass))
+            and float(tail_mass) > float(mmax)
+        ):
+            passed = False
+            violations.append(
+                {
+                    "type": "tail_mass_exceeded",
+                    "epsilon": float(eps),
+                    "observed": float(tail_mass),
+                    "threshold": float(mmax),
+                }
+            )
+
+    warned = bool(evaluated and (not passed) and mode == "warn")
+
+    return {
+        "mode": mode,
+        "evaluated": evaluated,
+        "passed": bool(passed),
+        "warned": warned,
+        "violations": violations,
+        "policy": {
+            "mode": mode,
+            "min_windows": int(min_windows_i),
+            "quantile": float(q),
+            "quantile_max": float(qmax) if qmax is not None else None,
+            "epsilon": float(eps),
+            "mass_max": float(mmax) if mmax is not None else None,
+        },
+        "stats": stats,
+    }
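
The docstrings above pin down the contract: evaluate_metric_tail only evaluates when mode is not "off", at least one threshold (quantile_max or mass_max) is present, and n >= min_windows; otherwise it reports evaluated=False and passes vacuously. A minimal usage sketch against the new invarlock/eval/tail_stats.py module (the Δlog-loss samples here are made up):

from invarlock.eval.tail_stats import evaluate_metric_tail

deltas = [0.001, -0.002, 0.0005, 0.003, 0.12]  # per-window Δlog-loss samples
result = evaluate_metric_tail(
    deltas=deltas,
    policy={"mode": "warn", "quantile": 0.95, "quantile_max": 0.05, "mass_max": 0.25},
)
print(result["evaluated"], result["passed"], result["warned"])
print(result["stats"]["q95"], result["stats"]["tail_mass"])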
--- /dev/null
+++ b/invarlock/eval/tasks/__init__.py
@@ -0,0 +1,12 @@
+from __future__ import annotations
+
+from .classification import accuracy_from_records
+from .qa import exact_match_from_records
+from .text_generation import bleu1_from_records, rouge_l_from_records
+
+__all__ = [
+    "accuracy_from_records",
+    "exact_match_from_records",
+    "bleu1_from_records",
+    "rouge_l_from_records",
+]
--- /dev/null
+++ b/invarlock/eval/tasks/classification.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+
+def _iter_pairs(record: dict[str, Any]) -> list[tuple[Any, Any]]:
+    if "correct" in record:
+        return [(bool(record.get("correct")), True)]
+
+    label = record.get("label")
+    pred = record.get("prediction")
+    if label is None:
+        label = record.get("labels")
+    if pred is None:
+        pred = record.get("pred")
+    if pred is None:
+        pred = record.get("predictions")
+
+    if isinstance(label, list) and isinstance(pred, list):
+        return list(zip(label, pred, strict=False))
+    if label is None or pred is None:
+        return []
+    return [(label, pred)]
+
+
+def accuracy_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute accuracy from records with labels/predictions.
+
+    Accepted record shapes:
+    - {"label": <label>, "prediction": <label>}
+    - {"labels": [...], "predictions": [...]}
+    - {"correct": <bool>}
+    """
+    total = 0
+    correct = 0
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        for label, pred in _iter_pairs(record):
+            total += 1
+            if isinstance(label, bool):
+                correct += int(label is pred)
+            else:
+                correct += int(label == pred)
+    if total == 0:
+        return float("nan")
+    return float(correct / total)
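
accuracy_from_records accepts all three record shapes in one stream; each shape contributes pairs to the same running total. A small sketch using the re-export from the new tasks package (the records are invented):

from invarlock.eval.tasks import accuracy_from_records

records = [
    {"label": "cat", "prediction": "cat"},            # single labeled pair
    {"labels": [0, 1, 1], "predictions": [0, 1, 0]},  # zipped list pairs
    {"correct": True},                                # pre-judged record
]
print(accuracy_from_records(records))  # 4 correct of 5 pairs -> 0.8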
--- /dev/null
+++ b/invarlock/eval/tasks/qa.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+
+def _normalize(text: str) -> str:
+    return " ".join(str(text).strip().lower().split())
+
+
+def exact_match_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute exact-match accuracy for QA-style records.
+
+    Accepted record shapes:
+    - {"prediction": "...", "answer": "..."}
+    - {"prediction": "...", "answers": ["...", ...]}
+    """
+    total = 0
+    correct = 0
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        pred = record.get("prediction")
+        answers = record.get("answers")
+        if answers is None and "answer" in record:
+            answers = [record.get("answer")]
+        if pred is None or answers is None:
+            continue
+        pred_norm = _normalize(pred)
+        answer_list = answers if isinstance(answers, list) else [answers]
+        total += 1
+        if any(_normalize(a) == pred_norm for a in answer_list if a is not None):
+            correct += 1
+    if total == 0:
+        return float("nan")
+    return float(correct / total)
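
Normalization lowercases and collapses whitespace before comparison, and a record scores as correct if any listed answer matches. A small sketch (invented records):

from invarlock.eval.tasks import exact_match_from_records

records = [
    {"prediction": "  Paris ", "answer": "paris"},         # matches after normalization
    {"prediction": "42", "answers": ["forty-two", "42"]},  # any answer may match
    {"prediction": "blue"},                                # skipped: no answer key
]
print(exact_match_from_records(records))  # 2 of 2 scored records -> 1.0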
--- /dev/null
+++ b/invarlock/eval/tasks/text_generation.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+from collections import Counter
+from collections.abc import Iterable
+from typing import Any
+
+
+def _tokenize(text: str) -> list[str]:
+    return [tok for tok in str(text).strip().lower().split() if tok]
+
+
+def _bleu1(pred: str, ref: str) -> float:
+    pred_tokens = _tokenize(pred)
+    ref_tokens = _tokenize(ref)
+    if not pred_tokens or not ref_tokens:
+        return 0.0
+    pred_counts = Counter(pred_tokens)
+    ref_counts = Counter(ref_tokens)
+    overlap = sum(min(pred_counts[tok], ref_counts.get(tok, 0)) for tok in pred_counts)
+    precision = overlap / float(len(pred_tokens))
+    bp = 1.0
+    if len(pred_tokens) < len(ref_tokens):
+        bp = pow(2.718281828, 1.0 - (len(ref_tokens) / float(len(pred_tokens))))
+    return float(precision * bp)
+
+
+def bleu1_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute BLEU-1 from records with predictions and references."""
+    scores: list[float] = []
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        pred = record.get("prediction")
+        refs = record.get("references")
+        if pred is None:
+            continue
+        if refs is None and "reference" in record:
+            refs = [record.get("reference")]
+        if refs is None:
+            continue
+        ref_list = refs if isinstance(refs, list) else [refs]
+        best = 0.0
+        for ref in ref_list:
+            if ref is None:
+                continue
+            best = max(best, _bleu1(str(pred), str(ref)))
+        scores.append(best)
+    if not scores:
+        return float("nan")
+    return float(sum(scores) / float(len(scores)))
+
+
+def _lcs_len(a: list[str], b: list[str]) -> int:
+    if not a or not b:
+        return 0
+    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
+    for i, tok_a in enumerate(a, start=1):
+        for j, tok_b in enumerate(b, start=1):
+            if tok_a == tok_b:
+                dp[i][j] = dp[i - 1][j - 1] + 1
+            else:
+                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
+    return dp[-1][-1]
+
+
+def _rouge_l(pred: str, ref: str) -> float:
+    pred_tokens = _tokenize(pred)
+    ref_tokens = _tokenize(ref)
+    if not pred_tokens or not ref_tokens:
+        return 0.0
+    lcs = _lcs_len(pred_tokens, ref_tokens)
+    prec = lcs / float(len(pred_tokens))
+    rec = lcs / float(len(ref_tokens))
+    if prec + rec == 0:
+        return 0.0
+    return float(2 * prec * rec / (prec + rec))
+
+
+def rouge_l_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute ROUGE-L (F1) from records with predictions and references."""
+    scores: list[float] = []
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        pred = record.get("prediction")
+        refs = record.get("references")
+        if pred is None:
+            continue
+        if refs is None and "reference" in record:
+            refs = [record.get("reference")]
+        if refs is None:
+            continue
+        ref_list = refs if isinstance(refs, list) else [refs]
+        best = 0.0
+        for ref in ref_list:
+            if ref is None:
+                continue
+            best = max(best, _rouge_l(str(pred), str(ref)))
+        scores.append(best)
+    if not scores:
+        return float("nan")
+    return float(sum(scores) / float(len(scores)))
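
Both scorers take the best score over a record's references, then average the per-record bests; BLEU-1 is clipped unigram precision with a brevity penalty (the 2.718281828 literal approximates e), and ROUGE-L is the LCS-based F1. A small sketch (invented records):

from invarlock.eval.tasks import bleu1_from_records, rouge_l_from_records

records = [
    {"prediction": "the cat sat", "reference": "the cat sat on the mat"},
    {"prediction": "hello world", "references": ["hello there world", "hi world"]},
]
print(bleu1_from_records(records))    # mean of per-record best BLEU-1
print(rouge_l_from_records(records))  # mean of per-record best ROUGE-L F1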
--- /dev/null
+++ b/invarlock/guards/_estimators.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import math
+from typing import Any
+
+import torch
+
+__all__ = [
+    "power_iter_sigma_max",
+    "frobenius_norm_sq",
+    "row_col_norm_extrema",
+    "stable_rank_estimate",
+]
+
+
+def _as_matrix(tensor: torch.Tensor) -> torch.Tensor:
+    if tensor.ndim == 2:
+        return tensor
+    return tensor.view(tensor.shape[0], -1)
+
+
+def power_iter_sigma_max(
+    matrix: Any,
+    *,
+    iters: int,
+    init: str = "ones",
+    eps: float = 1e-12,
+) -> float:
+    """Estimate the largest singular value (spectral norm) via fixed-iter power iteration.
+
+    Contract properties (vNext):
+    - fixed iteration budget (no convergence stopping)
+    - deterministic initialization (`init`)
+    - device-resident matvecs (no `.cpu()` transfers)
+    """
+    try:
+        iters_i = int(iters)
+    except Exception:
+        iters_i = 4
+    if iters_i < 1:
+        iters_i = 1
+
+    if not isinstance(matrix, torch.Tensor):
+        return 0.0
+    if matrix.numel() == 0:
+        return 0.0
+    if matrix.dtype in {torch.int8, torch.uint8}:
+        return 0.0
+
+    W = _as_matrix(matrix.detach())
+    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
+        return 0.0
+
+    device = W.device
+    dtype = W.dtype
+    n = int(W.shape[1])
+
+    with torch.no_grad():
+        if init == "ones":
+            v = torch.ones((n,), device=device, dtype=dtype)
+        else:
+            # Deterministic fallback: unit vector e0.
+            v = torch.zeros((n,), device=device, dtype=dtype)
+            v[0] = 1
+
+        v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
+        v = v / v_norm.to(dtype)
+
+        sigma = 0.0
+        for _ in range(iters_i):
+            u = W @ v
+            u_norm = torch.linalg.vector_norm(u.float()).clamp_min(eps)
+            sigma_val = float(u_norm.item())
+            if not math.isfinite(sigma_val):
+                return 0.0
+            u = u / u_norm.to(dtype)
+            v = W.T @ u
+            v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
+            v = v / v_norm.to(dtype)
+            sigma = sigma_val
+    return float(sigma)
+
+
+def frobenius_norm_sq(matrix: torch.Tensor) -> float:
+    """Return ||matrix||_F^2 with float32 accumulation (device-resident)."""
+    W = _as_matrix(matrix.detach())
+    if W.numel() == 0:
+        return 0.0
+    with torch.no_grad():
+        # Use a fused reduction to avoid materializing a W*W intermediate.
+        norm = torch.linalg.vector_norm(W.reshape(-1), ord=2, dtype=torch.float32)
+        out = float((norm * norm).item())
+    return out if math.isfinite(out) else 0.0
+
+
+def row_col_norm_extrema(
+    matrix: torch.Tensor, *, eps: float = 1e-12
+) -> dict[str, float]:
+    """Compute min/median/max of row/col L2 norms with float32 accumulation."""
+    W = _as_matrix(matrix.detach())
+    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
+        return {
+            "row_min": 0.0,
+            "row_median": 0.0,
+            "row_max": 0.0,
+            "col_min": 0.0,
+            "col_median": 0.0,
+            "col_max": 0.0,
+        }
+    with torch.no_grad():
+        # Avoid materializing W*W: use fused reductions.
+        row = torch.linalg.vector_norm(W, ord=2, dim=1, dtype=torch.float32).clamp_min(
+            eps
+        )
+        col = torch.linalg.vector_norm(W, ord=2, dim=0, dtype=torch.float32).clamp_min(
+            eps
+        )
+
+        row_sorted, _ = torch.sort(row)
+        col_sorted, _ = torch.sort(col)
+
+        def _median(sorted_vec: torch.Tensor) -> float:
+            n = int(sorted_vec.numel())
+            if n <= 0:
+                return 0.0
+            mid = n // 2
+            if n % 2 == 1:
+                return float(sorted_vec[mid].item())
+            return float((sorted_vec[mid - 1] + sorted_vec[mid]).mul(0.5).item())
+
+        return {
+            "row_min": float(row_sorted[0].item()),
+            "row_median": _median(row_sorted),
+            "row_max": float(row_sorted[-1].item()),
+            "col_min": float(col_sorted[0].item()),
+            "col_median": _median(col_sorted),
+            "col_max": float(col_sorted[-1].item()),
+        }
+
+
+def stable_rank_estimate(
+    matrix: torch.Tensor, *, sigma_max: float, eps: float = 1e-12
+) -> float:
+    """Estimate stable rank: ||W||_F^2 / ||W||_2^2, using a provided σ̂max."""
+    try:
+        denom = float(sigma_max) ** 2
+    except Exception:
+        return 0.0
+    if not math.isfinite(denom) or denom <= 0.0:
+        return 0.0
+    denom = max(denom, eps)
+    num = frobenius_norm_sq(matrix)
+    out = float(num) / denom if denom > 0 else 0.0
+    return out if math.isfinite(out) else 0.0
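
The estimators are designed to compose: power_iter_sigma_max produces a fixed-budget σ̂max that stable_rank_estimate then divides into ||W||_F². A minimal sketch assuming a plain float tensor (the exact-SVD check is only illustrative; power iteration under-estimates σmax when the budget is small):

import torch

from invarlock.guards._estimators import power_iter_sigma_max, stable_rank_estimate

W = torch.randn(64, 32)
sigma = power_iter_sigma_max(W, iters=8)          # fixed budget, deterministic "ones" init
srank = stable_rank_estimate(W, sigma_max=sigma)  # ||W||_F^2 / sigma^2
print(sigma, srank)

# Reference value via the exact spectral norm:
print(float(torch.linalg.matrix_norm(W.float(), ord=2)))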