subset-mixture-model 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. subset_mixture_model-0.2.0/PKG-INFO +197 -0
  2. subset_mixture_model-0.2.0/README.md +163 -0
  3. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/pyproject.toml +2 -2
  4. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/smm/__init__.py +23 -1
  5. subset_mixture_model-0.2.0/smm/cells.py +183 -0
  6. subset_mixture_model-0.2.0/smm/crossfit.py +89 -0
  7. subset_mixture_model-0.2.0/smm/diagnostics.py +91 -0
  8. subset_mixture_model-0.2.0/smm/laplace.py +81 -0
  9. subset_mixture_model-0.2.0/smm/model.py +152 -0
  10. subset_mixture_model-0.2.0/smm/predictor.py +157 -0
  11. subset_mixture_model-0.2.0/smm/smm.py +168 -0
  12. subset_mixture_model-0.2.0/smm/subset_maker.py +177 -0
  13. subset_mixture_model-0.2.0/subset_mixture_model.egg-info/PKG-INFO +197 -0
  14. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/SOURCES.txt +3 -0
  15. subset_mixture_model-0.2.0/tests/test_smm.py +117 -0
  16. subset_mixture_model-0.1.2/PKG-INFO +0 -349
  17. subset_mixture_model-0.1.2/README.md +0 -315
  18. subset_mixture_model-0.1.2/smm/diagnostics.py +0 -201
  19. subset_mixture_model-0.1.2/smm/laplace.py +0 -272
  20. subset_mixture_model-0.1.2/smm/model.py +0 -125
  21. subset_mixture_model-0.1.2/smm/predictor.py +0 -60
  22. subset_mixture_model-0.1.2/smm/subset_maker.py +0 -103
  23. subset_mixture_model-0.1.2/subset_mixture_model.egg-info/PKG-INFO +0 -349
  24. subset_mixture_model-0.1.2/tests/test_smm.py +0 -259
  25. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/setup.cfg +0 -0
  26. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/dependency_links.txt +0 -0
  27. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/requires.txt +1 -1
  28. {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/top_level.txt +0 -0
@@ -0,0 +1,197 @@
1
+ Metadata-Version: 2.4
2
+ Name: subset-mixture-model
3
+ Version: 0.2.0
4
+ Summary: Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
5
+ Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
6
+ License: MIT
7
+ Project-URL: Repository, https://github.com/aaronjdanielson/subset-mixture-model
8
+ Keywords: interpretable machine learning,categorical features,empirical Bayes,uncertainty quantification,mixture model
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.9
18
+ Description-Content-Type: text/markdown
19
+ Requires-Dist: torch>=2.0
20
+ Requires-Dist: numpy>=1.24
21
+ Requires-Dist: pandas>=2.0
22
+ Requires-Dist: scipy>=1.11
23
+ Provides-Extra: experiments
24
+ Requires-Dist: matplotlib>=3.7; extra == "experiments"
25
+ Requires-Dist: seaborn>=0.12; extra == "experiments"
26
+ Requires-Dist: joblib>=1.3; extra == "experiments"
27
+ Requires-Dist: scikit-learn>=1.3; extra == "experiments"
28
+ Requires-Dist: lightgbm>=4.0; extra == "experiments"
29
+ Requires-Dist: ngboost>=0.4; extra == "experiments"
30
+ Requires-Dist: mapie>=0.6; extra == "experiments"
31
+ Provides-Extra: dev
32
+ Requires-Dist: pytest>=7; extra == "dev"
33
+ Requires-Dist: pytest-cov; extra == "dev"
34
+
35
+ # Subset Mixture Model (SMM)
36
+
37
+ [![PyPI version](https://badge.fury.io/py/subset-mixture-model.svg)](https://pypi.org/project/subset-mixture-model/)
38
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
39
+
40
+ **SMM** is an interpretable, uncertainty-aware regression method for data with
41
+ categorical features. It learns a global convex mixture over *subset-induced
42
+ partition estimators*—one per non-empty feature subset—and returns a full
43
+ predictive distribution together with a transparent account of *why* each
44
+ prediction has the value and the uncertainty it does.
45
+
46
+ Version 0.2 upgrades the method to be genuinely probabilistic:
47
+
48
+ - **Cross-fitted weights** — subset cell means are multiway target encodings, so
49
+ the mixture weights are learned on out-of-fold statistics to avoid target
50
+ leakage (small, high-order cells no longer get to memorize the response).
51
+ - **Conjugate Student-t cells** — each cell uses a Normal-Inverse-Gamma
52
+ posterior, so its predictive component is a Student-t that widens and grows
53
+ heavy tails in sparse cells (the plug-in Gaussian is the large-sample limit).
54
+ - **Exact mixture inference** — NLL, CDF, central intervals (by bisection) and
55
+ CRPS are computed exactly from the mixture of Student-t components.
56
+ - **Four-part predictive variance** — within-cell noise, cell-estimation
57
+ uncertainty, subset-resolution disagreement, and weight uncertainty (Laplace).
58
+ - **Interpretation as diagnostics** — the learned weights are summarized by
59
+ entropy, effective number of subsets, order-level mass, and concentration,
60
+ rather than claimed to be a sparse ANOVA decomposition.
61
+
62
+ ---
63
+
64
+ ## Installation
65
+
66
+ ```bash
67
+ pip install subset-mixture-model
68
+ ```
69
+
70
+ Import as `smm`:
71
+
72
+ ```python
73
+ import smm
74
+ ```
75
+
76
+ Requires `torch>=2.0`, `numpy`, `pandas`, `scipy`. `plot_calibration` also needs
77
+ `matplotlib`.
78
+
79
+ ---
80
+
81
+ ## Quickstart
82
+
83
+ ```python
84
+ import numpy as np
85
+ import pandas as pd
86
+ from smm import SMM
87
+
88
+ # --- synthetic data: a main effect (region) + an interaction (season x tier) ---
89
+ rng = np.random.default_rng(0)
90
+ N = 2000
91
+ region = rng.integers(0, 2, N)
92
+ season = rng.integers(0, 3, N)
93
+ tier = rng.integers(0, 2, N)
94
+ y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
95
+ df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})
96
+
97
+ train, val, test = df[:1400], df[1400:1700], df[1700:]
98
+ FEATURES, TARGET = ["region", "season", "tier"], "y"
99
+
100
+ # --- fit: cross-fitted weights + conjugate Student-t cells + Laplace UQ ---
101
+ model = SMM(FEATURES, TARGET, kappa0=1.0, lam=0.5).fit(train, val)
102
+
103
+ # --- point prediction and a full predictive distribution ---
104
+ mean = model.predict(test)
105
+ mean, std = model.predict_with_uncertainty(test)
106
+ lo, hi = model.interval(test, level=0.95) # exact mixture interval
107
+
108
+ y_test = test[TARGET].values
109
+ print("RMSE :", np.sqrt(np.mean((mean - y_test) ** 2)).round(3))
110
+ print("NLL :", round(model.nll(test, y_test), 3))
111
+ print("CRPS :", round(model.crps(test, y_test), 3))
112
+ print("cov95:", round(float(((y_test >= lo) & (y_test <= hi)).mean()), 3))
113
+ ```
114
+
115
+ ### How uncertain is this prediction? Uncertainty decomposition
116
+
117
+ ```python
118
+ mean, std, aleatoric_std, epistemic_std = model.predict_with_uncertainty(
119
+ test, return_components=True
120
+ )
121
+ # std**2 == aleatoric_std**2 + epistemic_std**2 (within-cell noise vs. everything else)
122
+ ```
123
+
124
+ ### Which interactions matter? Global diagnostics
125
+
126
+ ```python
127
+ print(model.weight_table(top_k=5)) # subsets ranked by learned weight
128
+ print(model.diagnostics()) # H, N_eff, HHI, order-level mass M_k
129
+ ```
130
+
131
+ `diagnostics()` returns the concentration of the weight distribution. Diffuse
132
+ weights (large `N_eff`) are an honest signal that *no* sparse subset explanation
133
+ dominates—prediction draws on many comparable resolutions—while concentrated
134
+ weights identify a few dominant interactions.
135
+
136
+ ### Why *this* prediction? Local contributions
137
+
138
+ ```python
139
+ row = test.iloc[[0]]
140
+ print(model.explain(row)) # per-subset contributions; they sum to the prediction
141
+ ```
142
+
143
+ ---
144
+
145
+ ## What the model gives you
146
+
147
+ | Method | Returns |
148
+ |---|---|
149
+ | `SMM(features, target, ...)` | estimator; key knobs `kappa0` (shrinkage), `alpha0`, `lam` (order-aware prior), `K` (folds), `mode` (`"nig"`/`"plugin"`), `max_order` |
150
+ | `.fit(df, val_df=None)` | cross-fit → optimize weights (early-stop on `val_df`) → Laplace |
151
+ | `.predict(df)` | predictive mean |
152
+ | `.predict_with_uncertainty(df, return_components=)` | mean, std [, aleatoric, epistemic] |
153
+ | `.nll / .crps` | exact mixture proper scores |
154
+ | `.interval(df, level)` | exact central interval via bisection on the mixture CDF |
155
+ | `.weight_table / .diagnostics / .explain` | interpretability outputs |
156
+
157
+ Lower-level building blocks are also exported: `SubsetMaker`, `crossfit_components`,
158
+ `SubsetMixturePredictor`, `compute_posterior_covariance`, `predict_with_uncertainty`,
159
+ `NIGPrior`, `weight_table`, `weight_diagnostics`, `calibration_stats`.
160
+
161
+ To reproduce the first-version plug-in behavior, use
162
+ `SMM(..., mode="plugin", cross_fit=False, kappa0=0)`.
163
+
164
+ ---
165
+
166
+ ## Method summary
167
+
168
+ For each non-empty feature subset *s*, SMM groups the training data by the values
169
+ of *s* and models each resulting cell with a conjugate Normal-Inverse-Gamma
170
+ posterior, giving a Student-t predictive component. A single global simplex
171
+ weight vector π over all subsets is learned by MAP under an (optionally
172
+ order-aware) Dirichlet prior, using out-of-fold cell statistics. The predictive
173
+ distribution is the mixture ∑ₛ πₛ · tₛ, and uncertainty in π is propagated by a
174
+ Laplace approximation in the low-dimensional logit space.
175
+
176
+ SMM is intended for problems whose predictive structure is concentrated in a
177
+ modest number of naturally categorical features (full powerset for *D ≤ 8*;
178
+ order-restricted for larger *D*). It is a transparent, uncertainty-aware
179
+ alternative to gradient-boosted trees in that regime, not a general replacement.
180
+
181
+ ---
182
+
183
+ ## Citation
184
+
185
+ ```bibtex
186
+ @article{danielson2026smm,
187
+ title = {Subset Mixture Models: Interpretable Probabilistic Aggregation
188
+ of Partition Estimators for Categorical Regression},
189
+ author = {Danielson, Aaron John},
190
+ journal = {Under review},
191
+ year = {2026},
192
+ }
193
+ ```
194
+
195
+ ## License
196
+
197
+ MIT © Aaron John Danielson
@@ -0,0 +1,163 @@
1
+ # Subset Mixture Model (SMM)
2
+
3
+ [![PyPI version](https://badge.fury.io/py/subset-mixture-model.svg)](https://pypi.org/project/subset-mixture-model/)
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
5
+
6
+ **SMM** is an interpretable, uncertainty-aware regression method for data with
7
+ categorical features. It learns a global convex mixture over *subset-induced
8
+ partition estimators*—one per non-empty feature subset—and returns a full
9
+ predictive distribution together with a transparent account of *why* each
10
+ prediction has the value and the uncertainty it does.
11
+
12
+ Version 0.2 upgrades the method to be genuinely probabilistic:
13
+
14
+ - **Cross-fitted weights** — subset cell means are multiway target encodings, so
15
+ the mixture weights are learned on out-of-fold statistics to avoid target
16
+ leakage (small, high-order cells no longer get to memorize the response).
17
+ - **Conjugate Student-t cells** — each cell uses a Normal-Inverse-Gamma
18
+ posterior, so its predictive component is a Student-t that widens and grows
19
+ heavy tails in sparse cells (the plug-in Gaussian is the large-sample limit).
20
+ - **Exact mixture inference** — NLL, CDF, central intervals (by bisection) and
21
+ CRPS are computed exactly from the mixture of Student-t components.
22
+ - **Four-part predictive variance** — within-cell noise, cell-estimation
23
+ uncertainty, subset-resolution disagreement, and weight uncertainty (Laplace).
24
+ - **Interpretation as diagnostics** — the learned weights are summarized by
25
+ entropy, effective number of subsets, order-level mass, and concentration,
26
+ rather than claimed to be a sparse ANOVA decomposition.
27
+
28
+ ---
29
+
30
+ ## Installation
31
+
32
+ ```bash
33
+ pip install subset-mixture-model
34
+ ```
35
+
36
+ Import as `smm`:
37
+
38
+ ```python
39
+ import smm
40
+ ```
41
+
42
+ Requires `torch>=2.0`, `numpy`, `pandas`, `scipy`. `plot_calibration` also needs
43
+ `matplotlib`.
44
+
45
+ ---
46
+
47
+ ## Quickstart
48
+
49
+ ```python
50
+ import numpy as np
51
+ import pandas as pd
52
+ from smm import SMM
53
+
54
+ # --- synthetic data: a main effect (region) + an interaction (season x tier) ---
55
+ rng = np.random.default_rng(0)
56
+ N = 2000
57
+ region = rng.integers(0, 2, N)
58
+ season = rng.integers(0, 3, N)
59
+ tier = rng.integers(0, 2, N)
60
+ y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
61
+ df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})
62
+
63
+ train, val, test = df[:1400], df[1400:1700], df[1700:]
64
+ FEATURES, TARGET = ["region", "season", "tier"], "y"
65
+
66
+ # --- fit: cross-fitted weights + conjugate Student-t cells + Laplace UQ ---
67
+ model = SMM(FEATURES, TARGET, kappa0=1.0, lam=0.5).fit(train, val)
68
+
69
+ # --- point prediction and a full predictive distribution ---
70
+ mean = model.predict(test)
71
+ mean, std = model.predict_with_uncertainty(test)
72
+ lo, hi = model.interval(test, level=0.95) # exact mixture interval
73
+
74
+ y_test = test[TARGET].values
75
+ print("RMSE :", np.sqrt(np.mean((mean - y_test) ** 2)).round(3))
76
+ print("NLL :", round(model.nll(test, y_test), 3))
77
+ print("CRPS :", round(model.crps(test, y_test), 3))
78
+ print("cov95:", round(float(((y_test >= lo) & (y_test <= hi)).mean()), 3))
79
+ ```
80
+
81
+ ### How uncertain is this prediction? Uncertainty decomposition
82
+
83
+ ```python
84
+ mean, std, aleatoric_std, epistemic_std = model.predict_with_uncertainty(
85
+ test, return_components=True
86
+ )
87
+ # std**2 == aleatoric_std**2 + epistemic_std**2 (within-cell noise vs. everything else)
88
+ ```
89
+
90
+ ### Which interactions matter? Global diagnostics
91
+
92
+ ```python
93
+ print(model.weight_table(top_k=5)) # subsets ranked by learned weight
94
+ print(model.diagnostics()) # H, N_eff, HHI, order-level mass M_k
95
+ ```
96
+
97
+ `diagnostics()` returns the concentration of the weight distribution. Diffuse
98
+ weights (large `N_eff`) are an honest signal that *no* sparse subset explanation
99
+ dominates—prediction draws on many comparable resolutions—while concentrated
100
+ weights identify a few dominant interactions.
101
+
102
+ ### Why *this* prediction? Local contributions
103
+
104
+ ```python
105
+ row = test.iloc[[0]]
106
+ print(model.explain(row)) # per-subset contributions; they sum to the prediction
107
+ ```
108
+
109
+ ---
110
+
111
+ ## What the model gives you
112
+
113
+ | Method | Returns |
114
+ |---|---|
115
+ | `SMM(features, target, ...)` | estimator; key knobs `kappa0` (shrinkage), `alpha0`, `lam` (order-aware prior), `K` (folds), `mode` (`"nig"`/`"plugin"`), `max_order` |
116
+ | `.fit(df, val_df=None)` | cross-fit → optimize weights (early-stop on `val_df`) → Laplace |
117
+ | `.predict(df)` | predictive mean |
118
+ | `.predict_with_uncertainty(df, return_components=)` | mean, std [, aleatoric, epistemic] |
119
+ | `.nll / .crps` | exact mixture proper scores |
120
+ | `.interval(df, level)` | exact central interval via bisection on the mixture CDF |
121
+ | `.weight_table / .diagnostics / .explain` | interpretability outputs |
122
+
123
+ Lower-level building blocks are also exported: `SubsetMaker`, `crossfit_components`,
124
+ `SubsetMixturePredictor`, `compute_posterior_covariance`, `predict_with_uncertainty`,
125
+ `NIGPrior`, `weight_table`, `weight_diagnostics`, `calibration_stats`.
126
+
127
+ To reproduce the first-version plug-in behavior, use
128
+ `SMM(..., mode="plugin", cross_fit=False, kappa0=0)`.
129
+
130
+ ---
131
+
132
+ ## Method summary
133
+
134
+ For each non-empty feature subset *s*, SMM groups the training data by the values
135
+ of *s* and models each resulting cell with a conjugate Normal-Inverse-Gamma
136
+ posterior, giving a Student-t predictive component. A single global simplex
137
+ weight vector π over all subsets is learned by MAP under an (optionally
138
+ order-aware) Dirichlet prior, using out-of-fold cell statistics. The predictive
139
+ distribution is the mixture ∑ₛ πₛ · tₛ, and uncertainty in π is propagated by a
140
+ Laplace approximation in the low-dimensional logit space.
141
+
142
+ SMM is intended for problems whose predictive structure is concentrated in a
143
+ modest number of naturally categorical features (full powerset for *D ≤ 8*;
144
+ order-restricted for larger *D*). It is a transparent, uncertainty-aware
145
+ alternative to gradient-boosted trees in that regime, not a general replacement.
146
+
147
+ ---
148
+
149
+ ## Citation
150
+
151
+ ```bibtex
152
+ @article{danielson2026smm,
153
+ title = {Subset Mixture Models: Interpretable Probabilistic Aggregation
154
+ of Partition Estimators for Categorical Regression},
155
+ author = {Danielson, Aaron John},
156
+ journal = {Under review},
157
+ year = {2026},
158
+ }
159
+ ```
160
+
161
+ ## License
162
+
163
+ MIT © Aaron John Danielson
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "subset-mixture-model"
7
- version = "0.1.2"
7
+ version = "0.2.0"
8
8
  description = "Interpretable empirical-Bayes aggregation of partition estimators for categorical regression"
9
9
  readme = "README.md"
10
10
  license = { text = "MIT" }
@@ -33,7 +33,6 @@ dependencies = [
33
33
  "torch>=2.0",
34
34
  "numpy>=1.24",
35
35
  "pandas>=2.0",
36
- "scikit-learn>=1.3",
37
36
  "scipy>=1.11",
38
37
  ]
39
38
 
@@ -42,6 +41,7 @@ experiments = [
42
41
  "matplotlib>=3.7",
43
42
  "seaborn>=0.12",
44
43
  "joblib>=1.3",
44
+ "scikit-learn>=1.3",
45
45
  "lightgbm>=4.0",
46
46
  "ngboost>=0.4",
47
47
  "mapie>=0.6",
@@ -1,5 +1,17 @@
1
+ """Subset Mixture Model (SMM): interpretable probabilistic aggregation of
2
+ subset-induced partition estimators for categorical regression."""
3
+
4
+ from .cells import NIGPrior
1
5
  from .subset_maker import SubsetMaker
2
- from .model import SubsetWeightsModel, SubsetDataset, subset_mixture_neg_log_posterior, subset_mixture_mse
6
+ from .model import (
7
+ SubsetWeightsModel,
8
+ SubsetDataset,
9
+ ComponentDataset,
10
+ order_aware_alpha,
11
+ subset_mixture_neg_log_posterior,
12
+ subset_mixture_mse,
13
+ )
14
+ from .crossfit import crossfit_components
3
15
  from .predictor import SubsetMixturePredictor
4
16
  from .laplace import (
5
17
  compute_posterior_covariance,
@@ -8,18 +20,27 @@ from .laplace import (
8
20
  )
9
21
  from .diagnostics import (
10
22
  weight_table,
23
+ weight_diagnostics,
11
24
  explain_prediction,
12
25
  calibration_stats,
13
26
  plot_calibration,
14
27
  )
28
+ from .smm import SMM
15
29
 
16
30
  __all__ = [
31
+ # High-level estimator
32
+ "SMM",
33
+ # Cells / prior
34
+ "NIGPrior",
17
35
  # Core
18
36
  "SubsetMaker",
19
37
  "SubsetWeightsModel",
20
38
  "SubsetDataset",
39
+ "ComponentDataset",
40
+ "order_aware_alpha",
21
41
  "subset_mixture_neg_log_posterior",
22
42
  "subset_mixture_mse",
43
+ "crossfit_components",
23
44
  "SubsetMixturePredictor",
24
45
  # Uncertainty
25
46
  "compute_posterior_covariance",
@@ -27,6 +48,7 @@ __all__ = [
27
48
  "coverage",
28
49
  # Diagnostics
29
50
  "weight_table",
51
+ "weight_diagnostics",
30
52
  "explain_prediction",
31
53
  "calibration_stats",
32
54
  "plot_calibration",
@@ -0,0 +1,183 @@
1
+ """
2
+ Conjugate Normal-Inverse-Gamma (NIG) cell model and Student-t component math.
3
+
4
+ This module implements the principled cell model of the Subset Mixture Model
5
+ (SMM). Each subset cell is modeled as Gaussian data with a Normal-Inverse-Gamma
6
+ prior on (mu, sigma^2); the posterior predictive of a new response in the cell
7
+ is Student-t (Theorem "Student-t posterior predictive" in the paper):
8
+
9
+ (mu, sigma^2) ~ NIG(m0, kappa0, a0, b0)
10
+ posterior NIG(m_n, kappa_n, a_n, b_n) given (n, ybar, S)
11
+ predictive t_{2 a_n}( m_n, b_n (kappa_n + 1) / (a_n kappa_n) )
12
+
13
+ with sufficient statistics
14
+ n = cell size,
15
+ ybar = cell mean,
16
+ S = sum of squared deviations = sum_i (y_i - ybar)^2.
17
+
18
+ The plug-in Gaussian component of the first version of the method is recovered
19
+ as the limit kappa0 -> 0, a0, b0 -> 0 (see `plugin_params`), which this module
20
+ supports so the plug-in path remains reachable as an ablation.
21
+
22
+ Empirical-Bayes hyperparameters (default): match the prior to the global
23
+ marginal moments of the target,
24
+ m0 = global mean,
25
+ a0 = 2 (weakly informative; finite prior variance),
26
+ b0 = (a0 - 1) * global_var = global_var (so E[sigma^2] = global_var),
27
+ kappa0 = shrinkage strength (the one substantive hyperparameter).
28
+ """
29
+
30
+ import math
31
+ from dataclasses import dataclass
32
+
33
+ import numpy as np
34
+ import torch
35
+
36
+
37
+ # --------------------------------------------------------------------------- #
38
+ # Hyperparameters
39
+ # --------------------------------------------------------------------------- #
40
@dataclass
class NIGPrior:
    """Normal-Inverse-Gamma prior hyperparameters ``(m0, kappa0, a0, b0)``.

    In the parameterization used by ``nig_posterior`` / ``student_t_params``,
    ``E[sigma^2] = b0 / (a0 - 1)`` and ``kappa0`` acts as a prior pseudo-count
    pulling each cell mean toward ``m0``.
    """

    m0: float  # prior location of the cell mean
    kappa0: float  # prior pseudo-count (shrinkage strength toward m0)
    a0: float  # inverse-gamma shape for the cell variance
    b0: float  # inverse-gamma scale for the cell variance

    @classmethod
    def empirical_bayes(cls, y, kappa0: float = 1.0, a0: float = 2.0) -> "NIGPrior":
        """
        Build an empirical-Bayes prior by moment-matching the global marginal.

        m0 is the mean of the non-NaN entries of ``y``; b0 is chosen so that
            E[sigma^2] = b0 / (a0 - 1) = Var(y)  ->  b0 = (a0 - 1) * Var(y).
        A zero or non-finite Var(y) falls back to 1.0. ``kappa0`` (the prior
        pseudo-count / shrinkage strength) is the one substantive knob;
        ``kappa0 = 0`` turns shrinkage toward m0 off.

        Raises ValueError if ``a0 <= 1`` (infinite prior variance), if
        ``kappa0 < 0``, or if ``y`` has no non-NaN values (otherwise the prior
        mean would silently be NaN and poison every cell).
        """
        # Validate the cheap scalar knobs before touching the data.
        if a0 <= 1.0:
            raise ValueError("a0 must be > 1 for a finite prior variance.")
        if kappa0 < 0:
            raise ValueError("kappa0 must be >= 0.")
        y = np.asarray(y, dtype=float)
        y = y[~np.isnan(y)]
        if y.size == 0:
            raise ValueError("y must contain at least one non-NaN value.")
        global_mean = float(y.mean())
        global_var = float(y.var(ddof=0))
        # A constant target has zero variance; fall back to a unit scale so
        # b0 stays strictly positive.
        if not np.isfinite(global_var) or global_var <= 0:
            global_var = 1.0
        b0 = (a0 - 1.0) * global_var
        return cls(m0=global_mean, kappa0=float(kappa0), a0=float(a0), b0=b0)
68
+
69
+
70
+ # --------------------------------------------------------------------------- #
71
+ # Conjugate update (n, ybar, S) -> posterior -> Student-t predictive
72
+ # --------------------------------------------------------------------------- #
73
def nig_posterior(n, ybar, S, prior: "NIGPrior"):
    """
    Conjugate NIG update. Accepts scalars or numpy arrays (vectorized over cells).

    Returns (m_n, kappa_n, a_n, b_n) with the paper's formulas:
        kappa_n = kappa0 + n
        m_n     = (kappa0 m0 + n ybar) / (kappa0 + n)
        a_n     = a0 + n / 2
        b_n     = b0 + S/2 + (kappa0 n) / (2 (kappa0 + n)) (ybar - m0)^2

    Empty cells (n == 0) carry no data: their ybar and S are ignored, so a NaN
    placeholder mean cannot poison the result, and the posterior equals the
    prior. When kappa_n == 0 (kappa0 == 0 and n == 0) the location falls back
    to m0 instead of dividing by zero.
    """
    n = np.asarray(n, dtype=float)
    ybar = np.asarray(ybar, dtype=float)
    S = np.asarray(S, dtype=float)

    # Sum over an empty set: ybar is undefined and S is 0. Substituting m0
    # for ybar makes n * ybar and the (ybar - m0)^2 term exactly zero.
    empty = n == 0
    ybar = np.where(empty, prior.m0, ybar)
    S = np.where(empty, 0.0, S)

    kappa_n = prior.kappa0 + n
    has_mass = kappa_n > 0
    safe_kappa = np.where(has_mass, kappa_n, 1.0)  # avoid 0/0 when kappa_n == 0
    m_n = np.where(
        has_mass,
        (prior.kappa0 * prior.m0 + n * ybar) / safe_kappa,
        prior.m0,
    )
    a_n = prior.a0 + n / 2.0
    b_n = (
        prior.b0
        + 0.5 * S
        + (prior.kappa0 * n) / (2.0 * safe_kappa) * (ybar - prior.m0) ** 2
    )
    return m_n, kappa_n, a_n, b_n
97
+
98
+
99
def student_t_params(n, ybar, S, prior: NIGPrior):
    """
    Student-t posterior-predictive parameters for every cell.

    Runs the conjugate update through ``nig_posterior`` and returns a dict of
    numpy arrays with keys
        "loc"      -> m_n                                   (location)
        "scale2"   -> b_n (kappa_n + 1) / (a_n kappa_n)     (squared scale tau^2)
        "df"       -> 2 a_n                                 (degrees of freedom nu)
        "e_sigma2" -> b_n / (a_n - 1)          E[sigma^2 | D]  (within-cell)
        "var_mu"   -> b_n / (kappa_n (a_n - 1)) Var(mu | D)    (cell-estimation)
    The last two are the moment pieces of the four-part variance decomposition.
    """
    m_post, kappa_post, a_post, b_post = nig_posterior(n, ybar, S, prior)
    a_minus_one = a_post - 1.0
    return dict(
        loc=m_post,
        scale2=b_post * (kappa_post + 1.0) / (a_post * kappa_post),
        df=2.0 * a_post,
        e_sigma2=b_post / a_minus_one,
        var_mu=b_post / (kappa_post * a_minus_one),
    )
123
+
124
+
125
def plugin_params(ybar, var, eps: float = 1e-9):
    """
    Component parameters for the plug-in Gaussian ablation (first version).

    The Gaussian is encoded as a Student-t with a huge finite df (1e8) so the
    shared Student-t code path can evaluate it:
        loc      = ybar
        scale2   = max(var, eps)
        df       = 1e8 for every cell
        e_sigma2 = scale2
        var_mu   = 0   (the plug-in ignores cell-estimation error)
    """
    means = np.asarray(ybar, dtype=float)
    floored_var = np.maximum(np.asarray(var, dtype=float), eps)
    out = {"loc": means, "scale2": floored_var}
    out["df"] = np.full_like(means, 1e8)
    out["e_sigma2"] = floored_var
    out["var_mu"] = np.zeros_like(means)
    return out
142
+
143
+
144
+ # --------------------------------------------------------------------------- #
145
+ # Student-t density / cdf (torch for training, numpy for inference)
146
+ # --------------------------------------------------------------------------- #
147
def student_t_logpdf(y, loc, scale2, df):
    """
    Log density of a location-scale Student-t, elementwise (torch).

        log t_nu(y | loc, s^2) =
            lgamma((nu+1)/2) - lgamma(nu/2) - 0.5 log(nu pi s^2)
            - (nu+1)/2 * log(1 + (y-loc)^2 / (nu s^2)).

    ``scale2`` is s^2 and ``df`` is nu; ``df`` must be a tensor, and all
    inputs broadcast elementwise.

    For large nu, lgamma(nu/2 + 1/2) - lgamma(nu/2) is a small difference of
    two huge numbers. In float32 with nu = 1e8 (the df used for plug-in
    Gaussian components), nu/2 + 1/2 even rounds to nu/2. The naive difference
    is then exactly 0 instead of about 0.5 log(nu/2), shifting every
    log-density by about -8.9. For nu/2 > 100 the asymptotic expansion
        lgamma(a + 1/2) - lgamma(a) = 0.5 log a - 1/(8a) + 1/(192 a^3) + O(a^-5)
    is used instead; it is accurate to machine precision there.
    """
    z2 = (y - loc) ** 2 / scale2
    half_df = 0.5 * df
    threshold = 100.0  # nu/2 above which the asymptotic expansion is used
    # Clamp so the asymptotic branch stays finite where it is not selected;
    # torch.where evaluates (and back-propagates through) both branches.
    a = torch.clamp(half_df, min=threshold)
    lgamma_ratio = torch.where(
        half_df > threshold,
        0.5 * torch.log(a) - 1.0 / (8.0 * a) + 1.0 / (192.0 * a**3),
        torch.lgamma(half_df + 0.5) - torch.lgamma(half_df),
    )
    return (
        lgamma_ratio
        - 0.5 * torch.log(df * math.pi * scale2)
        - (half_df + 0.5) * torch.log1p(z2 / df)
    )
163
+
164
+
165
def student_t_logpdf_np(y, loc, scale2, df):
    """
    Numpy version of `student_t_logpdf` (for inference / scoring).

        log t_nu(y | loc, s^2) =
            gammaln((nu+1)/2) - gammaln(nu/2) - 0.5 log(nu pi s^2)
            - (nu+1)/2 * log(1 + (y-loc)^2 / (nu s^2)).

    For nu/2 > 100 the gamma-function difference is replaced by the asymptotic
    expansion
        gammaln(a + 1/2) - gammaln(a) = 0.5 log a - 1/(8a) + 1/(192 a^3) + O(a^-5).
    The direct difference of two huge log-gamma values loses precision: about
    1e-7 absolute in float64 at nu = 1e8 (the df of plug-in Gaussian
    components), and all of it in float32.
    """
    from scipy.special import gammaln

    z2 = (y - loc) ** 2 / scale2
    half_df = 0.5 * df
    threshold = 100.0  # nu/2 above which the asymptotic expansion is used
    a = np.maximum(half_df, threshold)  # keeps the unselected branch finite
    lgamma_ratio = np.where(
        half_df > threshold,
        0.5 * np.log(a) - 1.0 / (8.0 * a) + 1.0 / (192.0 * a**3),
        gammaln(half_df + 0.5) - gammaln(half_df),
    )
    return (
        lgamma_ratio
        - 0.5 * np.log(df * np.pi * scale2)
        - (half_df + 0.5) * np.log1p(z2 / df)
    )
177
+
178
+
179
def student_t_cdf_np(y, loc, scale2, df):
    """Cumulative distribution function of a location-scale Student-t.

    Standardizes ``y`` by ``loc`` and ``sqrt(scale2)`` and evaluates the
    standard Student-t CDF with ``df`` degrees of freedom via scipy.
    """
    from scipy.stats import t as _t

    standardized = (y - loc) / np.sqrt(scale2)
    return _t.cdf(standardized, df=df)