mbe-eval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mbe_eval/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .core import MBEEvaluator, MBEReport
2
+ from .sample_eval import simulate_mbe_evaluation
3
+
4
+ __all__ = ["MBEEvaluator", "MBEReport", "simulate_mbe_evaluation"]
mbe_eval/core.py ADDED
@@ -0,0 +1,133 @@
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import pingouin as pg
5
+ from scipy.stats import pearsonr, spearmanr
6
+ from dataclasses import dataclass
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from rich.console import Console
10
+ from rich.table import Table
11
+ from rich.panel import Panel
12
+
13
+ console = Console()
14
+
15
@dataclass
class MBEReport:
    """Outcome of one Marginal Baseline Eval run for a single metric.

    The field order below is the positional constructor order.
    """

    # Display name of the metric that was audited.
    metric_name: str
    # Display name of the baseline that was controlled for.
    baseline_name: str
    # Stage 1: correlation between the metric and the target, and its p-value.
    absolute_r: float
    absolute_p: float
    # Stage 4: partial correlation given the baseline, and its p-value.
    partial_r: float
    partial_p: float
    # True when the metric showed no significant signal beyond the baseline.
    is_loss_proxy: bool
24
+
25
class MBEEvaluator:
    """
    The Marginal Baseline Eval (MBE) framework.

    Evaluates whether a proposed representation metric offers independent
    structural insight beyond a trivial baseline (e.g. early validation loss).
    """

    # Candidate names of the p-value column in pingouin's partial_corr output.
    # NOTE(review): the spelling appears to differ between pingouin releases
    # ('p-val' vs 'p_val'); both are accepted so either version works.
    _P_VALUE_COLUMNS = ("p_val", "p-val")

    def __init__(self, metric_name: str = "Proposed Metric", baseline_name: str = "Validation Loss"):
        """Store the display names used in the console report and the plot."""
        self.metric_name = metric_name
        self.baseline_name = baseline_name

    @staticmethod
    def _partial_p_value(pcorr: pd.DataFrame) -> float:
        """Return the p-value from a partial-correlation result frame.

        Raises:
            KeyError: if none of the known p-value columns is present.
        """
        for column in MBEEvaluator._P_VALUE_COLUMNS:
            if column in pcorr.columns:
                return float(pcorr[column].values[0])
        raise KeyError(
            f"no p-value column among {MBEEvaluator._P_VALUE_COLUMNS} in {list(pcorr.columns)}"
        )

    def evaluate(self, metric_vals: np.ndarray, baseline_vals: np.ndarray, target_vals: np.ndarray,
                 alpha: float = 0.05, output_dir: str = "mbe_reports") -> "MBEReport":
        """
        Runs Stage 1 (Absolute Correlation) and Stage 4 (Partial Correlation).

        Args:
            metric_vals: values of the metric under audit.
            baseline_vals: values of the baseline to control for.
            target_vals: values of the quantity the metric should predict.
            alpha: significance threshold applied to both stages.
            output_dir: directory the PNG report is written to (created if missing).

        Returns:
            An MBEReport holding both correlations, their p-values and the verdict.

        Side effects: prints to the console and writes a PNG file.
        """
        console.print(Panel(f"[bold cyan]Marginal Baseline Eval (MBE) - {self.metric_name}[/bold cyan]"))

        # Stage 1: Absolute Correlation (Does it predict the target at all?)
        # Spearman rank correlation is the primary measure.
        abs_r, abs_p = spearmanr(metric_vals, target_vals)

        # Stage 4: Partial Correlation (Does it beat the baseline?)
        # method='spearman' keeps this stage rank-based like Stage 1.
        df = pd.DataFrame({
            'Target': target_vals,
            'Metric': metric_vals,
            'Baseline': baseline_vals
        })
        pcorr = pg.partial_corr(data=df, x='Metric', y='Target', covar='Baseline', method='spearman')
        part_r = float(pcorr['r'].values[0])
        part_p = self._partial_p_value(pcorr)

        # `not (p <= alpha)` instead of `p > alpha`: a NaN p-value compares
        # False both ways, and an untestable metric must not be reported as PASS.
        is_proxy = not (part_p <= alpha)

        report = MBEReport(
            metric_name=self.metric_name,
            baseline_name=self.baseline_name,
            absolute_r=float(abs_r), absolute_p=float(abs_p),
            partial_r=part_r, partial_p=part_p,
            is_loss_proxy=is_proxy
        )

        self._print_rich_report(report, alpha)
        self._generate_plots(df, report, output_dir, alpha)

        return report

    def _print_rich_report(self, r: "MBEReport", alpha: float = 0.05):
        """Print the two-row results table and the diagnosis line.

        `alpha` is the threshold for the Stage 1 verdict; the Stage 4 verdict
        comes from `r.is_loss_proxy`.
        """
        table = Table(title="MBE Evaluation Results", show_header=True, header_style="bold magenta")
        table.add_column("Stage", style="dim", width=20)
        table.add_column("Test", width=40)
        table.add_column("Correlation", justify="right")
        table.add_column("p-value", justify="right")
        table.add_column("Verdict", justify="center")

        # Absolute
        abs_ok = r.absolute_p < alpha
        abs_verdict = "[bold green]PASS[/bold green]" if abs_ok else "[bold red]FAIL[/bold red]"
        abs_p_str = f"[green]{r.absolute_p:.2e}[/green]" if abs_ok else f"[red]{r.absolute_p:.2e}[/red]"
        table.add_row("1: Absolute", "Correlation with Final Target", f"{r.absolute_r:.3f}", abs_p_str, abs_verdict)

        # Partial
        part_verdict = "[bold red]FAIL[/bold red]" if r.is_loss_proxy else "[bold green]PASS[/bold green]"
        part_p_str = f"[red]{r.partial_p:.2e}[/red]" if r.is_loss_proxy else f"[green]{r.partial_p:.2e}[/green]"
        table.add_row("4: MBE Control", f"Controlling for {r.baseline_name}", f"{r.partial_r:.3f}", part_p_str, part_verdict)

        console.print(table)

        if r.is_loss_proxy:
            console.print(f"[bold red]DIAGNOSIS: {r.metric_name} is a disguised Loss Proxy.[/bold red]")
            console.print(f"It provides NO independent predictive signal beyond {r.baseline_name}.\n")
        else:
            console.print(f"[bold green]DIAGNOSIS: {r.metric_name} provides independent structural insight![/bold green]\n")

    def _generate_plots(self, df: pd.DataFrame, r: "MBEReport", output_dir: str, alpha: float = 0.05):
        """Write a two-panel PNG (scatter + bar chart) into `output_dir`.

        `alpha` only decides the colour of the Stage 1 bar.
        """
        os.makedirs(output_dir, exist_ok=True)
        sns.set_theme(style="whitegrid")

        # Spaces become underscores as before; path separators are replaced
        # too so a metric name cannot escape or break `output_dir`.
        safe_name = r.metric_name.replace(' ', '_')
        for sep in (os.sep, os.altsep):
            if sep:
                safe_name = safe_name.replace(sep, '_')
        save_path = os.path.join(output_dir, f"mbe_report_{safe_name}.png")

        fig, axes = plt.subplots(1, 2, figsize=(14, 6))
        try:
            fig.suptitle(f"MBE Report: {r.metric_name}", fontsize=16, fontweight='bold')

            # Plot 1: Metric vs Target (Absolute)
            sns.regplot(ax=axes[0], data=df, x='Metric', y='Target', scatter_kws={'alpha':0.6, 's':80}, color='indigo')
            axes[0].set_title(f"Stage 1: Absolute Correlation\nρ = {r.absolute_r:.3f} (p = {r.absolute_p:.2e})", fontsize=14)
            axes[0].set_xlabel(r.metric_name, fontsize=12)
            axes[0].set_ylabel("Final Target (e.g., Accuracy)", fontsize=12)

            # Plot 2: Bar chart of correlation collapse
            labels = ['Absolute\n(Uncontrolled)', f'Marginal\n(Controlling {r.baseline_name})']
            vals = [abs(r.absolute_r), abs(r.partial_r)]
            colors = ['#2ecc71' if r.absolute_p < alpha else '#e74c3c',
                      '#2ecc71' if not r.is_loss_proxy else '#e74c3c']

            axes[1].bar(labels, vals, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
            axes[1].set_ylim(0, 1.0)
            axes[1].set_ylabel("Absolute Magnitude of Spearman ρ", fontsize=12)
            axes[1].set_title("Stage 4: Predictive Power Collapse", fontsize=14)

            for i, v in enumerate(vals):
                axes[1].text(i, v + 0.02, f"|ρ|={v:.3f}", ha='center', fontsize=12, fontweight='bold')

            fig.tight_layout()
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Close this figure specifically, even if plotting failed.
            plt.close(fig)

        console.print(f"[dim]Generated graphical report at: {save_path}[/dim]")
@@ -0,0 +1,70 @@
1
+ import numpy as np
2
+ from scipy.stats import pearsonr
3
+ import pingouin as pg
4
+ import pandas as pd
5
+
6
def simulate_mbe_evaluation():
    """
    Simulates the Marginal Baseline Eval (MBE) on a dummy representation metric.

    MBE tests whether a proposed metric offers independent predictive signal
    beyond a trivial baseline (early validation loss). The synthetic metric
    is built as the baseline plus noise, so it is a loss proxy by construction.

    Prints its findings to stdout and returns None.
    """
    rule = "=================================================="
    print(rule)
    print("Marginal Baseline Eval (MBE) - Sample Test Run")
    print("==================================================\n")

    # 1. Generate heterogeneous dummy data (simulating 30 training runs).
    # A private RandomState(42) yields the same draws as seeding the global
    # generator with 42, without resetting the caller's global RNG state.
    rng = np.random.RandomState(42)
    n_runs = 30

    # Underlying true model capability (unobserved)
    true_capability = rng.randn(n_runs)

    # Early Validation Loss (our trivial baseline): negated capability plus noise
    early_val_loss = -true_capability + rng.randn(n_runs) * 0.2

    # Final Epoch 200 Accuracy (our target to predict): capability plus noise
    final_test_acc = true_capability + rng.randn(n_runs) * 0.1

    # Our proposed dummy metric: Gradient L2 Norm
    # IT IS A LOSS PROXY! It is the early validation loss plus noise.
    proposed_metric = early_val_loss + rng.randn(n_runs) * 0.3

    # 2. Stage 1: Absolute Correlation (The False Assurance)
    # Does the metric predict final accuracy?
    r_metric, p_metric = pearsonr(proposed_metric, final_test_acc)
    print("[Stage 1] Absolute Correlation Check:")
    print(f" Correlation of Proposed Metric with Final Accuracy: r = {r_metric:.3f} (p = {p_metric:.3e})")

    if p_metric < 0.05:
        print(" -> RESULT: PASS. Metric correlates significantly with generalization.\n")
    else:
        # Previously nothing was printed here, leaving Stage 1 without a verdict.
        print(" -> RESULT: FAIL. Metric does not correlate significantly with generalization.\n")

    # 3. The MBE Partial-Correlation Baseline Control
    # Does the metric offer MARGINAL signal beyond the trivial validation loss baseline?
    # NOTE(review): the label below says "Stage 2" while core.py and the
    # README call this step Stage 4; the printed text is left unchanged.
    print("[Stage 2] The MBE Baseline Control (Partial Correlation):")

    df = pd.DataFrame({
        'Final_Acc': final_test_acc,
        'Proposed_Metric': proposed_metric,
        'Baseline_Loss': early_val_loss
    })

    # Calculate partial correlation: Metric vs Final Acc, controlling for Baseline Loss
    pcorr_metric = pg.partial_corr(data=df, x='Proposed_Metric', y='Final_Acc', covar='Baseline_Loss')
    r_partial = pcorr_metric['r'].values[0]
    # NOTE(review): the p-value column spelling appears to differ between
    # pingouin releases ('p-val' vs 'p_val'); use whichever is present.
    p_column = 'p-val' if 'p-val' in pcorr_metric.columns else 'p_val'
    p_partial = pcorr_metric[p_column].values[0]

    print(f" Marginal Correlation (Controlling for Early Val Loss): r = {r_partial:.3f} (p = {p_partial:.3f})")

    if p_partial > 0.05:
        print(" -> RESULT: FAIL. The metric offers NO independent predictive signal.")
        print(" -> DIAGNOSIS: The proposed metric is a disguised Loss Proxy.\n")
    else:
        print(" -> RESULT: PASS. The metric offers independent structural insight.\n")

    print(rule)
    print("MBE Evaluation Complete.")
    print(rule)


if __name__ == "__main__":
    simulate_mbe_evaluation()
mbe_eval/utils.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
def compute_fim_norm(model, loss_fn, inputs, targets):
    """
    Computes Gradient Effective Rank (FIM_norm) via the dual Gram matrix.

    For each of the N samples the gradient of `loss_fn(model(x), y)` with
    respect to the trainable parameters is flattened into one row of G.
    The result is exp(entropy of the normalized eigenvalues of G G^T / N) / N.

    Args:
        model: torch module; its parameters with requires_grad=True are used.
        loss_fn: callable taking (model output, target) and returning a scalar loss.
        inputs: batch whose first dimension indexes samples.
        targets: targets sliced along the first dimension like `inputs`.

    Returns:
        float: effective rank divided by N, or 1.0 when no eigenvalue exceeds
        the noise floor.

    The model is evaluated in eval mode; every submodule's train/eval flag is
    restored afterwards, and parameter `.grad` buffers are left untouched.
    """
    N = inputs.shape[0]
    params = [p for p in model.parameters() if p.requires_grad]

    # Remember each submodule's mode individually so a mixed train/eval
    # configuration is restored exactly.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()  # ensure no batchnorm tracking during this

    # Per-sample gradients, one sample at a time. (functorch/vmap would be
    # faster; a plain loop keeps the demonstration readable.)
    grads = []
    try:
        for i in range(N):
            loss = loss_fn(model(inputs[i:i+1]), targets[i:i+1])

            # autograd.grad returns the gradients directly instead of
            # accumulating into p.grad, so the caller's gradients are not
            # zeroed or overwritten.
            per_param = torch.autograd.grad(loss, params, allow_unused=True)

            # Parameters this sample did not touch contribute zeros, which
            # keeps every row the same length without changing G G^T.
            g_vec = torch.cat([
                (torch.zeros_like(p) if g is None else g).flatten()
                for p, g in zip(params, per_param)
            ])
            grads.append(g_vec.detach())
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    G = torch.stack(grads)  # shape: (N, P)

    # The Dual Gram Matrix
    S_dual = (1.0 / N) * torch.matmul(G, G.T)  # shape: (N, N)

    # Eigendecomposition
    eigenvalues = torch.linalg.eigvalsh(S_dual)
    eigenvalues = eigenvalues[eigenvalues > 1e-12]  # numerical noise floor

    if len(eigenvalues) == 0:
        return 1.0  # degenerate spectrum is reported as the maximum normalized rank

    # Shannon entropy of normalized spectrum
    p = eigenvalues / eigenvalues.sum()
    H = -(p * torch.log(p)).sum().item()

    erank = math.exp(H)
    fim_norm = erank / N

    return fim_norm
@@ -0,0 +1,128 @@
1
+ Metadata-Version: 2.4
2
+ Name: mbe-eval
3
+ Version: 0.1.0
4
+ Summary: Marginal Baseline Eval (MBE): A framework for rigorously auditing representation metrics in deep neural networks.
5
+ Home-page: https://github.com/AparajeetS/metric-audit-paper-code
6
+ Author: Aparajeet Shadangi
7
+ Author-email: author@example.com
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: numpy
16
+ Requires-Dist: scipy
17
+ Requires-Dist: pandas
18
+ Requires-Dist: pingouin
19
+ Requires-Dist: torch
20
+ Requires-Dist: torchvision
21
+ Requires-Dist: matplotlib
22
+ Requires-Dist: seaborn
23
+ Requires-Dist: scikit-learn
24
+ Requires-Dist: rich
25
+ Dynamic: author
26
+ Dynamic: author-email
27
+ Dynamic: classifier
28
+ Dynamic: description
29
+ Dynamic: description-content-type
30
+ Dynamic: home-page
31
+ Dynamic: license-file
32
+ Dynamic: requires-dist
33
+ Dynamic: requires-python
34
+ Dynamic: summary
35
+
36
+ # The Marginal Baseline Eval (MBE)
37
+
38
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
39
+
40
+ Welcome to the **Marginal Baseline Eval (MBE)** repository!
41
+
42
+ This repository provides the formal implementation of the MBE protocol — a strict, 4-stage validation methodology designed to rigorously audit representation metrics in deep neural networks.
43
+
44
+ It was originally built during a large case study that falsified the Gradient Effective Rank (FIM_norm) metric as a predictor of generalization.
45
+
46
+ ## Why Do We Need MBE?
47
+
48
+ The AI safety and interpretability communities frequently propose internal structural metrics (e.g., representation geometry, effective rank, gradient coherence) to predict generalization or track model health.
49
+
50
+ However, many of these metrics are secretly **Loss Proxies**. Because early validation loss trivially predicts final test accuracy, any metric that mathematically correlates with the *magnitude* of the loss will automatically correlate with generalization. Such a metric provides **zero independent structural insight**.
51
+
52
+ The MBE protocol catches these false positive metrics using a rigorous **partial-correlation baseline control**.
53
+
54
+ ## Installation
55
+
56
+ You can install the framework directly from PyPI:
57
+
58
+ ```bash
59
+ pip install mbe-eval
60
+ ```
61
+
62
+ Or, if you want to run the PyTorch demos, clone the repository:
63
+
64
+ ```bash
65
+ git clone https://github.com/AparajeetS/metric-audit-paper-code.git
66
+ cd metric-audit-paper-code
67
+ pip install -r requirements.txt
68
+ ```
69
+
70
+ ## The MBE API
71
+
72
+ MBE is a fully importable Python framework powered by `pandas` and `pingouin`. You can integrate it directly into your own model evaluation pipelines.
73
+
74
+ ```python
75
+ from mbe_eval import MBEEvaluator
76
+
77
+ # Pass your experimental arrays (numpy arrays)
78
+ evaluator = MBEEvaluator(metric_name="My Cool Metric", baseline_name="Epoch 20 Val Loss")
79
+ report = evaluator.evaluate(metric_vals, baseline_vals, target_vals)
80
+ ```
81
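+ `evaluate` runs Stage 1 (absolute Spearman correlation) and Stage 4 (Spearman partial correlation controlling for the baseline), prints a `rich` diagnostics table to the console, saves a high-resolution `seaborn` graphical report to the `mbe_reports/` directory (change it with the `output_dir` argument), and returns an `MBEReport` with the correlations, p-values, and the `is_loss_proxy` verdict.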
82
+
83
+ ## Real PyTorch Demos
84
+
85
+ We provide two end-to-end PyTorch scripts in the `examples/` directory that actually train neural networks and run the evaluation live.
86
+
87
+ **1. The Acid Test (Stage 1)**
88
+ Shows how a metric can successfully track capacity and noise, giving false assurance.
89
+ ```bash
90
+ python examples/01_run_acid_test.py
91
+ ```
92
+
93
+ **2. The Heterogeneous Grid (Stage 4)**
94
+ The killer demo. Trains 20 models with randomized hyperparameters, computes the Gradient Effective Rank, and runs the final MBE Partial Correlation control to prove the metric is a disguised loss proxy.
95
+ ```bash
96
+ python examples/02_run_heterogeneous_grid.py
97
+ ```
98
+
99
+ ## Repository Structure
100
+
101
+ ```
102
+ metric-audit-paper-code/
103
+ ├── mbe_eval/ # The core MBE evaluation API
104
+ │ ├── __init__.py
105
+ │ ├── core.py # MBEEvaluator class
106
+ │ ├── utils.py # PyTorch FIM_norm extraction
107
+ │ └── sample_eval.py # Basic synthetic simulation
108
+ ├── examples/ # Real end-to-end PyTorch demos
109
+ │ ├── 01_run_acid_test.py
110
+ │ └── 02_run_heterogeneous_grid.py
111
+ ├── experiments/ # All 12 original paper experiment scripts
112
+ ├── metric_audit/ # Core FIM_norm computation library
113
+ ├── docs/
114
+ │ └── RESULTS.md # Raw numerical results for the paper
115
+ ├── PAPER.md # Full technical writeup
116
+ ├── requirements.txt
117
+ ├── LICENSE
118
+ └── README.md
119
+ ```
120
+
121
+ ## Citation
122
+
123
+ If you use the Marginal Baseline Eval in your own representation evaluation, please cite the accompanying manuscript:
124
+
125
+ ```
126
+ Shadangi, A. (2026). Does It Beat the Baseline? A Comprehensive Negative Result
127
+ on Gradient Effective Rank as a Generalization Predictor. arXiv preprint.
128
+ ```
@@ -0,0 +1,11 @@
1
+ mbe_eval/__init__.py,sha256=78eMbBLcXaHAmAkrHzQgNNIZsSiS7LBUyz-wjIbKqMA,159
2
+ mbe_eval/core.py,sha256=2455WF5dTCipo8UYr7nsEDNcJWl4oujs3rQflI-Z69c,6001
3
+ mbe_eval/sample_eval.py,sha256=8-H-_zd5rEEU05p837Zyj1pZeufxXrgZK_O_T7M4QSM,3092
4
+ mbe_eval/utils.py,sha256=L66LsmVwjFTsGu0oaH_7xXUVoEDac_wEhI_gNHiwXfc,1572
5
+ mbe_eval-0.1.0.dist-info/licenses/LICENSE,sha256=w9-S3P-lYZOpilpXgq_hhYezIacu2FZ9RsFIp_18fMY,1075
6
+ metric_audit/__init__.py,sha256=G1_XnKKJZWrZMuGXI1jUUfgZXJkf4mNvzgaRKP1uMUk,30
7
+ metric_audit/sci_tracker.py,sha256=9c-Lgaf660YrNQJVEWF3Hbjl6EaG0T1R2u_3itQ2Tn4,4483
8
+ mbe_eval-0.1.0.dist-info/METADATA,sha256=Kdp9qKRmcya-Os_qy47XdRi8NinMgpVqCw7C9t-BuHk,5058
9
+ mbe_eval-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
10
+ mbe_eval-0.1.0.dist-info/top_level.txt,sha256=7RGD-pg-boBEVp3aL7qM9ALR2Fa9An0TMy2lGwrwzfg,22
11
+ mbe_eval-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Aparajeet Shadangi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,2 @@
1
+ mbe_eval
2
+ metric_audit
@@ -0,0 +1 @@
1
+ # Metric Audit - Core Library
@@ -0,0 +1,134 @@
1
+ """
2
+ SCI Tracker — Structural Constraint Index (CEI v2)
3
+
4
+ Measures erank(W_l) / min(n_out, n_in) for each weight matrix.
5
+ No forward pass required. No BatchNorm contamination.
6
+ """
7
+
8
+ import math
9
+ import numpy as np
10
+
11
+
12
+ # ------------------------------------------------------------------ #
13
+ # Core math #
14
+ # ------------------------------------------------------------------ #
15
+
16
def effective_rank(M: np.ndarray) -> float:
    """
    erank(M) = exp( H(p) ) where H is Shannon entropy of
    normalized singular values p_k = σ_k / Σ σ_j.
    Range: [1, min(n_rows, n_cols)].

    Singular values at or below 1e-10 are discarded; if none remain
    (e.g. an all-zero matrix) the result is 1.0.
    """
    sv = np.linalg.svd(np.asarray(M, dtype=np.float64), compute_uv=False)
    sv = sv[sv > 1e-10]
    if sv.size == 0:
        return 1.0
    p = sv / sv.sum()
    # Every p_k is strictly positive after the filter above, so log(p) is
    # finite. No epsilon is added: it would bias the entropy and push a
    # rank-1 result slightly below 1.
    H = -(p * np.log(p)).sum()
    return float(math.exp(H))
29
+
30
+
31
def sci_from_weight(W: np.ndarray) -> float:
    """
    Structural Constraint Index of one weight matrix.

    Computed as effective_rank(W) divided by r = min(n_out, n_in), the
    largest rank W can have. Nominally falls in [1/r, 1].
    Lower = more constrained = more generalizable (hypothesis).
    """
    n_out, n_in = W.shape[0], W.shape[1]
    rank_cap = min(n_out, n_in)
    return effective_rank(W) / rank_cap
39
+
40
+
41
def sci_spectrum(W: np.ndarray) -> dict:
    """
    Full spectral breakdown for diagnostics.

    Returns a dict with keys: sci, erank, rank_max, spectral_entropy,
    sv_max, sv_min, stable_rank, nuclear_norm, spectral_norm.
    Singular values at or below 1e-10 are ignored.
    """
    sv = np.linalg.svd(W.astype(np.float64), compute_uv=False)
    r = min(W.shape[0], W.shape[1])
    sv_p = sv[sv > 1e-10]
    if sv_p.size == 0:
        # NOTE(review): sci is 1.0 here, whereas sci_from_weight yields
        # erank / r = 1 / r for the same input; confirm which is intended.
        return {"sci": 1.0, "erank": 1.0, "rank_max": r,
                "spectral_entropy": 0.0, "sv_max": 0.0, "sv_min": 0.0,
                "stable_rank": 0.0, "nuclear_norm": 0.0, "spectral_norm": 0.0}

    p = sv_p / sv_p.sum()
    # p is strictly positive after the filter, so no epsilon is needed in
    # the log; adding 0.0 turns a -0.0 entropy into 0.0.
    H = float(-(p * np.log(p)).sum()) + 0.0
    er = math.exp(H)
    # numpy returns singular values in descending order: first is largest.
    top = sv_p[0]
    return {
        "sci": er / r,
        "erank": er,
        "rank_max": r,
        "spectral_entropy": H,
        "sv_max": float(top),
        "sv_min": float(sv_p[-1]),
        # stable rank = ||W||_F^2 / ||W||_2^2  (Rudelson & Vershynin)
        # top > 1e-10, so the division is safe; an additive epsilon would
        # dominate the denominator for small-norm matrices.
        "stable_rank": float((sv_p**2).sum() / top**2),
        "nuclear_norm": float(sv_p.sum()),
        "spectral_norm": float(top),
    }
68
+
69
+
70
+ # ------------------------------------------------------------------ #
71
+ # NumPy MLP tracker (no PyTorch needed) #
72
+ # ------------------------------------------------------------------ #
73
+
74
class NumpySCITracker:
    """
    Tracks SCI across all weight matrices of a numpy MLP.

    Usage:
        tracker = NumpySCITracker(model)
        sci_vals, net_sci = tracker.compute()
    """

    def __init__(self, model):
        self.model = model  # expects model.W = list of np.ndarray

    # Return annotations are quoted: subscripting the builtin `tuple`/`list`
    # at definition time fails on Python 3.8, which the package metadata
    # declares as supported.
    def compute(self) -> "tuple[list[float], float]":
        """Returns (per_layer_sci, network_mean_sci).

        The mean is NaN when the model has no weight matrices.
        """
        vals = [sci_from_weight(W) for W in self.model.W]
        # Avoid np.mean([]), which emits a RuntimeWarning on its way to NaN.
        mean = float(np.mean(vals)) if vals else float("nan")
        return vals, mean

    def compute_full(self) -> "list[dict]":
        """Returns full spectral breakdown per layer."""
        return [sci_spectrum(W) for W in self.model.W]
94
+
95
+
96
+ # ------------------------------------------------------------------ #
97
+ # Optional PyTorch wrapper #
98
+ # ------------------------------------------------------------------ #
99
+
100
def pytorch_sci(model, layer_types=None) -> dict:
    """
    Compute SCI for all weight matrices in a PyTorch model.

    Args:
        model: nn.Module
        layer_types: tuple of types to include (default: Linear + Conv2d)

    Returns:
        dict mapping layer_name -> sci_spectrum dict, plus "network_sci" mean
        (NaN when no layer matched).

    Raises:
        ImportError: if PyTorch cannot be imported.
    """
    try:
        import torch.nn as nn
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("PyTorch not available; use NumpySCITracker for numpy models.") from err

    if layer_types is None:
        layer_types = (nn.Linear, nn.Conv2d)

    results = {}
    sci_list = []

    for name, module in model.named_modules():
        if not isinstance(module, layer_types):
            continue
        W = module.weight.detach().cpu().numpy()
        if W.ndim > 2:
            # Conv2d: (C_out, C_in, kH, kW) → reshape to (C_out, C_in*kH*kW)
            W = W.reshape(W.shape[0], -1)
        spec = sci_spectrum(W)
        results[name] = spec
        sci_list.append(spec["sci"])

    results["network_sci"] = float(np.mean(sci_list)) if sci_list else float("nan")
    return results