subset-mixture-model 0.1.2__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- subset_mixture_model-0.2.0/PKG-INFO +197 -0
- subset_mixture_model-0.2.0/README.md +163 -0
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/pyproject.toml +2 -2
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/smm/__init__.py +23 -1
- subset_mixture_model-0.2.0/smm/cells.py +183 -0
- subset_mixture_model-0.2.0/smm/crossfit.py +89 -0
- subset_mixture_model-0.2.0/smm/diagnostics.py +91 -0
- subset_mixture_model-0.2.0/smm/laplace.py +81 -0
- subset_mixture_model-0.2.0/smm/model.py +152 -0
- subset_mixture_model-0.2.0/smm/predictor.py +157 -0
- subset_mixture_model-0.2.0/smm/smm.py +168 -0
- subset_mixture_model-0.2.0/smm/subset_maker.py +177 -0
- subset_mixture_model-0.2.0/subset_mixture_model.egg-info/PKG-INFO +197 -0
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/SOURCES.txt +3 -0
- subset_mixture_model-0.2.0/tests/test_smm.py +117 -0
- subset_mixture_model-0.1.2/PKG-INFO +0 -349
- subset_mixture_model-0.1.2/README.md +0 -315
- subset_mixture_model-0.1.2/smm/diagnostics.py +0 -201
- subset_mixture_model-0.1.2/smm/laplace.py +0 -272
- subset_mixture_model-0.1.2/smm/model.py +0 -125
- subset_mixture_model-0.1.2/smm/predictor.py +0 -60
- subset_mixture_model-0.1.2/smm/subset_maker.py +0 -103
- subset_mixture_model-0.1.2/subset_mixture_model.egg-info/PKG-INFO +0 -349
- subset_mixture_model-0.1.2/tests/test_smm.py +0 -259
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/setup.cfg +0 -0
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/dependency_links.txt +0 -0
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/requires.txt +1 -1
- {subset_mixture_model-0.1.2 → subset_mixture_model-0.2.0}/subset_mixture_model.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: subset-mixture-model
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
|
|
5
|
+
Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/aaronjdanielson/subset-mixture-model
|
|
8
|
+
Keywords: interpretable machine learning,categorical features,empirical Bayes,uncertainty quantification,mixture model
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
Requires-Dist: torch>=2.0
|
|
20
|
+
Requires-Dist: numpy>=1.24
|
|
21
|
+
Requires-Dist: pandas>=2.0
|
|
22
|
+
Requires-Dist: scipy>=1.11
|
|
23
|
+
Provides-Extra: experiments
|
|
24
|
+
Requires-Dist: matplotlib>=3.7; extra == "experiments"
|
|
25
|
+
Requires-Dist: seaborn>=0.12; extra == "experiments"
|
|
26
|
+
Requires-Dist: joblib>=1.3; extra == "experiments"
|
|
27
|
+
Requires-Dist: scikit-learn>=1.3; extra == "experiments"
|
|
28
|
+
Requires-Dist: lightgbm>=4.0; extra == "experiments"
|
|
29
|
+
Requires-Dist: ngboost>=0.4; extra == "experiments"
|
|
30
|
+
Requires-Dist: mapie>=0.6; extra == "experiments"
|
|
31
|
+
Provides-Extra: dev
|
|
32
|
+
Requires-Dist: pytest>=7; extra == "dev"
|
|
33
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
34
|
+
|
|
35
|
+
# Subset Mixture Model (SMM)
|
|
36
|
+
|
|
37
|
+
[![PyPI version](https://img.shields.io/pypi/v/subset-mixture-model.svg)](https://pypi.org/project/subset-mixture-model/)
|
|
38
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
|
|
39
|
+
|
|
40
|
+
**SMM** is an interpretable, uncertainty-aware regression method for data with
|
|
41
|
+
categorical features. It learns a global convex mixture over *subset-induced
|
|
42
|
+
partition estimators*—one per non-empty feature subset—and returns a full
|
|
43
|
+
predictive distribution together with a transparent account of *why* each
|
|
44
|
+
prediction has the value and the uncertainty it does.
|
|
45
|
+
|
|
46
|
+
Version 0.2 upgrades the method to be genuinely probabilistic:
|
|
47
|
+
|
|
48
|
+
- **Cross-fitted weights** — subset cell means are multiway target encodings, so
|
|
49
|
+
the mixture weights are learned on out-of-fold statistics to avoid target
|
|
50
|
+
leakage (small, high-order cells no longer get to memorize the response).
|
|
51
|
+
- **Conjugate Student-t cells** — each cell uses a Normal-Inverse-Gamma
|
|
52
|
+
posterior, so its predictive component is a Student-t that widens and grows
|
|
53
|
+
heavy tails in sparse cells (the plug-in Gaussian is the large-sample limit).
|
|
54
|
+
- **Exact mixture inference** — NLL, CDF, central intervals (by bisection) and
|
|
55
|
+
CRPS are computed exactly from the mixture of Student-t components.
|
|
56
|
+
- **Four-part predictive variance** — within-cell noise, cell-estimation
|
|
57
|
+
uncertainty, subset-resolution disagreement, and weight uncertainty (Laplace).
|
|
58
|
+
- **Interpretation as diagnostics** — the learned weights are summarized by
|
|
59
|
+
entropy, effective number of subsets, order-level mass, and concentration,
|
|
60
|
+
rather than claimed to be a sparse ANOVA decomposition.
|
|
61
|
+
|
|
62
|
+
---
|
|
63
|
+
|
|
64
|
+
## Installation
|
|
65
|
+
|
|
66
|
+
```bash
|
|
67
|
+
pip install subset-mixture-model
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
Import as `smm`:
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
import smm
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
Requires `torch>=2.0`, `numpy`, `pandas`, `scipy`. `plot_calibration` also needs
|
|
77
|
+
`matplotlib`.
|
|
78
|
+
|
|
79
|
+
---
|
|
80
|
+
|
|
81
|
+
## Quickstart
|
|
82
|
+
|
|
83
|
+
```python
|
|
84
|
+
import numpy as np
|
|
85
|
+
import pandas as pd
|
|
86
|
+
from smm import SMM
|
|
87
|
+
|
|
88
|
+
# --- synthetic data: a main effect (region) + an interaction (season x tier) ---
|
|
89
|
+
rng = np.random.default_rng(0)
|
|
90
|
+
N = 2000
|
|
91
|
+
region = rng.integers(0, 2, N)
|
|
92
|
+
season = rng.integers(0, 3, N)
|
|
93
|
+
tier = rng.integers(0, 2, N)
|
|
94
|
+
y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
|
|
95
|
+
df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})
|
|
96
|
+
|
|
97
|
+
train, val, test = df[:1400], df[1400:1700], df[1700:]
|
|
98
|
+
FEATURES, TARGET = ["region", "season", "tier"], "y"
|
|
99
|
+
|
|
100
|
+
# --- fit: cross-fitted weights + conjugate Student-t cells + Laplace UQ ---
|
|
101
|
+
model = SMM(FEATURES, TARGET, kappa0=1.0, lam=0.5).fit(train, val)
|
|
102
|
+
|
|
103
|
+
# --- point prediction and a full predictive distribution ---
|
|
104
|
+
mean = model.predict(test)
|
|
105
|
+
mean, std = model.predict_with_uncertainty(test)
|
|
106
|
+
lo, hi = model.interval(test, level=0.95) # exact mixture interval
|
|
107
|
+
|
|
108
|
+
y_test = test[TARGET].values
|
|
109
|
+
print("RMSE :", np.sqrt(np.mean((mean - y_test) ** 2)).round(3))
|
|
110
|
+
print("NLL :", round(model.nll(test, y_test), 3))
|
|
111
|
+
print("CRPS :", round(model.crps(test, y_test), 3))
|
|
112
|
+
print("cov95:", round(float(((y_test >= lo) & (y_test <= hi)).mean()), 3))
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
### Why this prediction? Uncertainty decomposition
|
|
116
|
+
|
|
117
|
+
```python
|
|
118
|
+
mean, std, aleatoric_std, epistemic_std = model.predict_with_uncertainty(
|
|
119
|
+
test, return_components=True
|
|
120
|
+
)
|
|
121
|
+
# std**2 == aleatoric_std**2 + epistemic_std**2  (within-cell noise vs. everything else)
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
### Which interactions matter? Global diagnostics
|
|
125
|
+
|
|
126
|
+
```python
|
|
127
|
+
print(model.weight_table(top_k=5)) # subsets ranked by learned weight
|
|
128
|
+
print(model.diagnostics()) # H, N_eff, HHI, order-level mass M_k
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
`diagnostics()` summarizes how concentrated the learned weight distribution is. Diffuse
|
|
132
|
+
weights (large `N_eff`) are an honest signal that *no* sparse subset explanation
|
|
133
|
+
dominates—prediction draws on many comparable resolutions—while concentrated
|
|
134
|
+
weights identify a few dominant interactions.
|
|
135
|
+
|
|
136
|
+
### Why *this* prediction? Local contributions
|
|
137
|
+
|
|
138
|
+
```python
|
|
139
|
+
row = test.iloc[[0]]
|
|
140
|
+
print(model.explain(row)) # per-subset contributions; they sum to the prediction
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
---
|
|
144
|
+
|
|
145
|
+
## What the model gives you
|
|
146
|
+
|
|
147
|
+
| Method | Returns |
|
|
148
|
+
|---|---|
|
|
149
|
+
| `SMM(features, target, ...)` | estimator; key knobs `kappa0` (shrinkage), `alpha0`, `lam` (order-aware prior), `K` (folds), `mode` (`"nig"`/`"plugin"`), `max_order` |
|
|
150
|
+
| `.fit(df, val_df=None)` | cross-fit → optimize weights (early-stop on `val_df`) → Laplace |
|
|
151
|
+
| `.predict(df)` | predictive mean |
|
|
152
|
+
| `.predict_with_uncertainty(df, return_components=...)` | mean, std [, aleatoric_std, epistemic_std when `return_components=True`] |
|
|
153
|
+
| `.nll / .crps` | exact mixture proper scores |
|
|
154
|
+
| `.interval(df, level)` | exact central interval via bisection on the mixture CDF |
|
|
155
|
+
| `.weight_table / .diagnostics / .explain` | interpretability outputs |
|
|
156
|
+
|
|
157
|
+
Lower-level building blocks are also exported: `SubsetMaker`, `crossfit_components`,
|
|
158
|
+
`SubsetMixturePredictor`, `compute_posterior_covariance`, `predict_with_uncertainty`,
|
|
159
|
+
`NIGPrior`, `weight_table`, `weight_diagnostics`, `calibration_stats`.
|
|
160
|
+
|
|
161
|
+
To reproduce the original (v0.1) plug-in behavior, use
|
|
162
|
+
`SMM(..., mode="plugin", cross_fit=False, kappa0=0)`.
|
|
163
|
+
|
|
164
|
+
---
|
|
165
|
+
|
|
166
|
+
## Method summary
|
|
167
|
+
|
|
168
|
+
For each non-empty feature subset *s*, SMM groups the training data by the values
|
|
169
|
+
of *s* and models each resulting cell with a conjugate Normal-Inverse-Gamma
|
|
170
|
+
posterior, giving a Student-t predictive component. A single global simplex
|
|
171
|
+
weight vector π over all subsets is learned by MAP under an (optionally
|
|
172
|
+
order-aware) Dirichlet prior, using out-of-fold cell statistics. The predictive
|
|
173
|
+
distribution is the mixture ∑ₛ πₛ · tₛ, and uncertainty in π is propagated by a
|
|
174
|
+
Laplace approximation in the low-dimensional logit space.
|
|
175
|
+
|
|
176
|
+
SMM is intended for problems whose predictive structure is concentrated in a
|
|
177
|
+
modest number of naturally categorical features (full powerset for *D ≤ 8*;
|
|
178
|
+
order-restricted for larger *D*). It is a transparent, uncertainty-aware
|
|
179
|
+
alternative to gradient-boosted trees in that regime, not a general replacement.
|
|
180
|
+
|
|
181
|
+
---
|
|
182
|
+
|
|
183
|
+
## Citation
|
|
184
|
+
|
|
185
|
+
```bibtex
|
|
186
|
+
@article{danielson2026smm,
|
|
187
|
+
title = {Subset Mixture Models: Interpretable Probabilistic Aggregation
|
|
188
|
+
of Partition Estimators for Categorical Regression},
|
|
189
|
+
author = {Danielson, Aaron John},
|
|
190
|
+
journal = {Under review},
|
|
191
|
+
year = {2026},
|
|
192
|
+
}
|
|
193
|
+
```
|
|
194
|
+
|
|
195
|
+
## License
|
|
196
|
+
|
|
197
|
+
MIT © Aaron John Danielson
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Subset Mixture Model (SMM)
|
|
2
|
+
|
|
3
|
+
[![PyPI version](https://img.shields.io/pypi/v/subset-mixture-model.svg)](https://pypi.org/project/subset-mixture-model/)
|
|
4
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
|
|
5
|
+
|
|
6
|
+
**SMM** is an interpretable, uncertainty-aware regression method for data with
|
|
7
|
+
categorical features. It learns a global convex mixture over *subset-induced
|
|
8
|
+
partition estimators*—one per non-empty feature subset—and returns a full
|
|
9
|
+
predictive distribution together with a transparent account of *why* each
|
|
10
|
+
prediction has the value and the uncertainty it does.
|
|
11
|
+
|
|
12
|
+
Version 0.2 upgrades the method to be genuinely probabilistic:
|
|
13
|
+
|
|
14
|
+
- **Cross-fitted weights** — subset cell means are multiway target encodings, so
|
|
15
|
+
the mixture weights are learned on out-of-fold statistics to avoid target
|
|
16
|
+
leakage (small, high-order cells no longer get to memorize the response).
|
|
17
|
+
- **Conjugate Student-t cells** — each cell uses a Normal-Inverse-Gamma
|
|
18
|
+
posterior, so its predictive component is a Student-t that widens and grows
|
|
19
|
+
heavy tails in sparse cells (the plug-in Gaussian is the large-sample limit).
|
|
20
|
+
- **Exact mixture inference** — NLL, CDF, central intervals (by bisection) and
|
|
21
|
+
CRPS are computed exactly from the mixture of Student-t components.
|
|
22
|
+
- **Four-part predictive variance** — within-cell noise, cell-estimation
|
|
23
|
+
uncertainty, subset-resolution disagreement, and weight uncertainty (Laplace).
|
|
24
|
+
- **Interpretation as diagnostics** — the learned weights are summarized by
|
|
25
|
+
entropy, effective number of subsets, order-level mass, and concentration,
|
|
26
|
+
rather than claimed to be a sparse ANOVA decomposition.
|
|
27
|
+
|
|
28
|
+
---
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install subset-mixture-model
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
Import as `smm`:
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
import smm
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Requires `torch>=2.0`, `numpy`, `pandas`, `scipy`. `plot_calibration` also needs
|
|
43
|
+
`matplotlib`.
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
## Quickstart
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
import numpy as np
|
|
51
|
+
import pandas as pd
|
|
52
|
+
from smm import SMM
|
|
53
|
+
|
|
54
|
+
# --- synthetic data: a main effect (region) + an interaction (season x tier) ---
|
|
55
|
+
rng = np.random.default_rng(0)
|
|
56
|
+
N = 2000
|
|
57
|
+
region = rng.integers(0, 2, N)
|
|
58
|
+
season = rng.integers(0, 3, N)
|
|
59
|
+
tier = rng.integers(0, 2, N)
|
|
60
|
+
y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
|
|
61
|
+
df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})
|
|
62
|
+
|
|
63
|
+
train, val, test = df[:1400], df[1400:1700], df[1700:]
|
|
64
|
+
FEATURES, TARGET = ["region", "season", "tier"], "y"
|
|
65
|
+
|
|
66
|
+
# --- fit: cross-fitted weights + conjugate Student-t cells + Laplace UQ ---
|
|
67
|
+
model = SMM(FEATURES, TARGET, kappa0=1.0, lam=0.5).fit(train, val)
|
|
68
|
+
|
|
69
|
+
# --- point prediction and a full predictive distribution ---
|
|
70
|
+
mean = model.predict(test)
|
|
71
|
+
mean, std = model.predict_with_uncertainty(test)
|
|
72
|
+
lo, hi = model.interval(test, level=0.95) # exact mixture interval
|
|
73
|
+
|
|
74
|
+
y_test = test[TARGET].values
|
|
75
|
+
print("RMSE :", np.sqrt(np.mean((mean - y_test) ** 2)).round(3))
|
|
76
|
+
print("NLL :", round(model.nll(test, y_test), 3))
|
|
77
|
+
print("CRPS :", round(model.crps(test, y_test), 3))
|
|
78
|
+
print("cov95:", round(float(((y_test >= lo) & (y_test <= hi)).mean()), 3))
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Why this prediction? Uncertainty decomposition
|
|
82
|
+
|
|
83
|
+
```python
|
|
84
|
+
mean, std, aleatoric_std, epistemic_std = model.predict_with_uncertainty(
|
|
85
|
+
test, return_components=True
|
|
86
|
+
)
|
|
87
|
+
# std**2 == aleatoric_std**2 + epistemic_std**2  (within-cell noise vs. everything else)
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### Which interactions matter? Global diagnostics
|
|
91
|
+
|
|
92
|
+
```python
|
|
93
|
+
print(model.weight_table(top_k=5)) # subsets ranked by learned weight
|
|
94
|
+
print(model.diagnostics()) # H, N_eff, HHI, order-level mass M_k
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
`diagnostics()` summarizes how concentrated the learned weight distribution is. Diffuse
|
|
98
|
+
weights (large `N_eff`) are an honest signal that *no* sparse subset explanation
|
|
99
|
+
dominates—prediction draws on many comparable resolutions—while concentrated
|
|
100
|
+
weights identify a few dominant interactions.
|
|
101
|
+
|
|
102
|
+
### Why *this* prediction? Local contributions
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
row = test.iloc[[0]]
|
|
106
|
+
print(model.explain(row)) # per-subset contributions; they sum to the prediction
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
---
|
|
110
|
+
|
|
111
|
+
## What the model gives you
|
|
112
|
+
|
|
113
|
+
| Method | Returns |
|
|
114
|
+
|---|---|
|
|
115
|
+
| `SMM(features, target, ...)` | estimator; key knobs `kappa0` (shrinkage), `alpha0`, `lam` (order-aware prior), `K` (folds), `mode` (`"nig"`/`"plugin"`), `max_order` |
|
|
116
|
+
| `.fit(df, val_df=None)` | cross-fit → optimize weights (early-stop on `val_df`) → Laplace |
|
|
117
|
+
| `.predict(df)` | predictive mean |
|
|
118
|
+
| `.predict_with_uncertainty(df, return_components=...)` | mean, std [, aleatoric_std, epistemic_std when `return_components=True`] |
|
|
119
|
+
| `.nll / .crps` | exact mixture proper scores |
|
|
120
|
+
| `.interval(df, level)` | exact central interval via bisection on the mixture CDF |
|
|
121
|
+
| `.weight_table / .diagnostics / .explain` | interpretability outputs |
|
|
122
|
+
|
|
123
|
+
Lower-level building blocks are also exported: `SubsetMaker`, `crossfit_components`,
|
|
124
|
+
`SubsetMixturePredictor`, `compute_posterior_covariance`, `predict_with_uncertainty`,
|
|
125
|
+
`NIGPrior`, `weight_table`, `weight_diagnostics`, `calibration_stats`.
|
|
126
|
+
|
|
127
|
+
To reproduce the original (v0.1) plug-in behavior, use
|
|
128
|
+
`SMM(..., mode="plugin", cross_fit=False, kappa0=0)`.
|
|
129
|
+
|
|
130
|
+
---
|
|
131
|
+
|
|
132
|
+
## Method summary
|
|
133
|
+
|
|
134
|
+
For each non-empty feature subset *s*, SMM groups the training data by the values
|
|
135
|
+
of *s* and models each resulting cell with a conjugate Normal-Inverse-Gamma
|
|
136
|
+
posterior, giving a Student-t predictive component. A single global simplex
|
|
137
|
+
weight vector π over all subsets is learned by MAP under an (optionally
|
|
138
|
+
order-aware) Dirichlet prior, using out-of-fold cell statistics. The predictive
|
|
139
|
+
distribution is the mixture ∑ₛ πₛ · tₛ, and uncertainty in π is propagated by a
|
|
140
|
+
Laplace approximation in the low-dimensional logit space.
|
|
141
|
+
|
|
142
|
+
SMM is intended for problems whose predictive structure is concentrated in a
|
|
143
|
+
modest number of naturally categorical features (full powerset for *D ≤ 8*;
|
|
144
|
+
order-restricted for larger *D*). It is a transparent, uncertainty-aware
|
|
145
|
+
alternative to gradient-boosted trees in that regime, not a general replacement.
|
|
146
|
+
|
|
147
|
+
---
|
|
148
|
+
|
|
149
|
+
## Citation
|
|
150
|
+
|
|
151
|
+
```bibtex
|
|
152
|
+
@article{danielson2026smm,
|
|
153
|
+
title = {Subset Mixture Models: Interpretable Probabilistic Aggregation
|
|
154
|
+
of Partition Estimators for Categorical Regression},
|
|
155
|
+
author = {Danielson, Aaron John},
|
|
156
|
+
journal = {Under review},
|
|
157
|
+
year = {2026},
|
|
158
|
+
}
|
|
159
|
+
```
|
|
160
|
+
|
|
161
|
+
## License
|
|
162
|
+
|
|
163
|
+
MIT © Aaron John Danielson
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "subset-mixture-model"
|
|
7
|
-
version = "0.
|
|
7
|
+
version = "0.2.0"
|
|
8
8
|
description = "Interpretable empirical-Bayes aggregation of partition estimators for categorical regression"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { text = "MIT" }
|
|
@@ -33,7 +33,6 @@ dependencies = [
|
|
|
33
33
|
"torch>=2.0",
|
|
34
34
|
"numpy>=1.24",
|
|
35
35
|
"pandas>=2.0",
|
|
36
|
-
"scikit-learn>=1.3",
|
|
37
36
|
"scipy>=1.11",
|
|
38
37
|
]
|
|
39
38
|
|
|
@@ -42,6 +41,7 @@ experiments = [
|
|
|
42
41
|
"matplotlib>=3.7",
|
|
43
42
|
"seaborn>=0.12",
|
|
44
43
|
"joblib>=1.3",
|
|
44
|
+
"scikit-learn>=1.3",
|
|
45
45
|
"lightgbm>=4.0",
|
|
46
46
|
"ngboost>=0.4",
|
|
47
47
|
"mapie>=0.6",
|
|
@@ -1,5 +1,17 @@
|
|
|
1
|
+
"""Subset Mixture Model (SMM): interpretable probabilistic aggregation of
|
|
2
|
+
subset-induced partition estimators for categorical regression."""
|
|
3
|
+
|
|
4
|
+
from .cells import NIGPrior
|
|
1
5
|
from .subset_maker import SubsetMaker
|
|
2
|
-
from .model import
|
|
6
|
+
from .model import (
|
|
7
|
+
SubsetWeightsModel,
|
|
8
|
+
SubsetDataset,
|
|
9
|
+
ComponentDataset,
|
|
10
|
+
order_aware_alpha,
|
|
11
|
+
subset_mixture_neg_log_posterior,
|
|
12
|
+
subset_mixture_mse,
|
|
13
|
+
)
|
|
14
|
+
from .crossfit import crossfit_components
|
|
3
15
|
from .predictor import SubsetMixturePredictor
|
|
4
16
|
from .laplace import (
|
|
5
17
|
compute_posterior_covariance,
|
|
@@ -8,18 +20,27 @@ from .laplace import (
|
|
|
8
20
|
)
|
|
9
21
|
from .diagnostics import (
|
|
10
22
|
weight_table,
|
|
23
|
+
weight_diagnostics,
|
|
11
24
|
explain_prediction,
|
|
12
25
|
calibration_stats,
|
|
13
26
|
plot_calibration,
|
|
14
27
|
)
|
|
28
|
+
from .smm import SMM
|
|
15
29
|
|
|
16
30
|
__all__ = [
|
|
31
|
+
# High-level estimator
|
|
32
|
+
"SMM",
|
|
33
|
+
# Cells / prior
|
|
34
|
+
"NIGPrior",
|
|
17
35
|
# Core
|
|
18
36
|
"SubsetMaker",
|
|
19
37
|
"SubsetWeightsModel",
|
|
20
38
|
"SubsetDataset",
|
|
39
|
+
"ComponentDataset",
|
|
40
|
+
"order_aware_alpha",
|
|
21
41
|
"subset_mixture_neg_log_posterior",
|
|
22
42
|
"subset_mixture_mse",
|
|
43
|
+
"crossfit_components",
|
|
23
44
|
"SubsetMixturePredictor",
|
|
24
45
|
# Uncertainty
|
|
25
46
|
"compute_posterior_covariance",
|
|
@@ -27,6 +48,7 @@ __all__ = [
|
|
|
27
48
|
"coverage",
|
|
28
49
|
# Diagnostics
|
|
29
50
|
"weight_table",
|
|
51
|
+
"weight_diagnostics",
|
|
30
52
|
"explain_prediction",
|
|
31
53
|
"calibration_stats",
|
|
32
54
|
"plot_calibration",
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Conjugate Normal-Inverse-Gamma (NIG) cell model and Student-t component math.
|
|
3
|
+
|
|
4
|
+
This module implements the principled cell model of the Subset Mixture Model
|
|
5
|
+
(SMM). Each subset cell is modeled as Gaussian data with a Normal-Inverse-Gamma
|
|
6
|
+
prior on (mu, sigma^2); the posterior predictive of a new response in the cell
|
|
7
|
+
is Student-t (Theorem "Student-t posterior predictive" in the paper):
|
|
8
|
+
|
|
9
|
+
(mu, sigma^2) ~ NIG(m0, kappa0, a0, b0)
|
|
10
|
+
posterior NIG(m_n, kappa_n, a_n, b_n) given (n, ybar, S)
|
|
11
|
+
predictive t_{2 a_n}( m_n, b_n (kappa_n + 1) / (a_n kappa_n) )
|
|
12
|
+
|
|
13
|
+
with sufficient statistics
|
|
14
|
+
n = cell size,
|
|
15
|
+
ybar = cell mean,
|
|
16
|
+
S = sum of squared deviations = sum_i (y_i - ybar)^2.
|
|
17
|
+
|
|
18
|
+
The plug-in Gaussian component of the first version of the method is recovered
|
|
19
|
+
as the limit kappa0 -> 0, a0, b0 -> 0, n -> inf (see `plugin_params`), which this module
|
|
20
|
+
supports so the plug-in path remains reachable as an ablation.
|
|
21
|
+
|
|
22
|
+
Empirical-Bayes hyperparameters (default): match the prior to the global
|
|
23
|
+
marginal moments of the target,
|
|
24
|
+
m0 = global mean,
|
|
25
|
+
a0 = 2 (weakly informative; finite prior variance),
|
|
26
|
+
b0 = (a0 - 1) * global_var = global_var (so E[sigma^2] = global_var),
|
|
27
|
+
kappa0 = shrinkage strength (the one substantive hyperparameter).
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
import math
|
|
31
|
+
from dataclasses import dataclass
|
|
32
|
+
|
|
33
|
+
import numpy as np
|
|
34
|
+
import torch
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# --------------------------------------------------------------------------- #
|
|
38
|
+
# Hyperparameters
|
|
39
|
+
# --------------------------------------------------------------------------- #
|
|
40
|
+
@dataclass
class NIGPrior:
    """Normal-Inverse-Gamma prior hyperparameters (m0, kappa0, a0, b0).

    Attributes
    ----------
    m0 : float
        Prior location of the cell mean mu.
    kappa0 : float
        Prior pseudo-count, i.e. the shrinkage strength of each cell mean
        toward m0.
    a0, b0 : float
        Inverse-Gamma shape and scale hyperparameters for sigma^2.
    """

    m0: float
    kappa0: float
    a0: float
    b0: float

    @classmethod
    def empirical_bayes(cls, y, kappa0: float = 1.0, a0: float = 2.0):
        """
        Build an empirical-Bayes prior by moment-matching the global marginal.

            E[sigma^2] = b0 / (a0 - 1) = Var(y)  ->  b0 = (a0 - 1) * Var(y).

        The prior mean m0 is the global mean of y. `kappa0` (the prior
        pseudo-count / shrinkage strength) is the one substantive knob.

        Only finite values of `y` are used (NaN and +/-inf are dropped). If
        the remaining values have zero variance, a unit variance is used
        instead so that b0 stays positive.

        Raises
        ------
        ValueError
            If a0 <= 1 (infinite prior variance), kappa0 < 0, or `y` contains
            no finite values (the prior mean would be undefined).
        """
        if a0 <= 1.0:
            raise ValueError("a0 must be > 1 for a finite prior variance.")
        if kappa0 < 0:
            raise ValueError("kappa0 must be >= 0.")

        y = np.asarray(y, dtype=float)
        # Drop NaN *and* infinities: a single inf would otherwise make m0 inf.
        y = y[np.isfinite(y)]
        if y.size == 0:
            raise ValueError("y must contain at least one finite value.")

        global_mean = float(y.mean())
        global_var = float(y.var(ddof=0))
        # Constant targets give zero variance; fall back to a unit scale so
        # the Inverse-Gamma prior remains proper.
        if not np.isfinite(global_var) or global_var <= 0:
            global_var = 1.0
        b0 = (a0 - 1.0) * global_var
        return cls(m0=global_mean, kappa0=float(kappa0), a0=float(a0), b0=b0)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# --------------------------------------------------------------------------- #
|
|
71
|
+
# Conjugate update (n, ybar, S) -> posterior -> Student-t predictive
|
|
72
|
+
# --------------------------------------------------------------------------- #
|
|
73
|
+
def nig_posterior(n, ybar, S, prior: NIGPrior):
    """
    Conjugate NIG update. Accepts scalars or numpy arrays (vectorized over cells).

    Returns (m_n, kappa_n, a_n, b_n) with the paper's formulas:
        kappa_n = kappa0 + n
        m_n     = (kappa0 m0 + n ybar) / (kappa0 + n)
        a_n     = a0 + n / 2
        b_n     = b0 + S/2 + (kappa0 n) / (2 (kappa0 + n)) (ybar - m0)^2

    Edge cases:
      * Empty cells (n == 0) return the prior exactly. Their `ybar` and `S`
        are ignored, so NaN placeholders for empty cells (e.g. the mean of an
        empty group) do not propagate through 0 * NaN.
      * When kappa_n == 0 (kappa0 == 0 and n == 0) the location falls back
        to m0 instead of the undefined 0 / 0.
    """
    n = np.asarray(n, dtype=float)
    ybar = np.asarray(ybar, dtype=float)
    S = np.asarray(S, dtype=float)

    # An empty cell carries no data: substitute neutral statistics so the
    # update reduces to the prior even if ybar / S are NaN placeholders.
    observed = n > 0
    ybar = np.where(observed, ybar, prior.m0)
    S = np.where(observed, S, 0.0)

    kappa_n = prior.kappa0 + n
    informative = kappa_n > 0
    # Safe denominator; where kappa_n == 0 the numerator is replaced by m0
    # (for m_n) or is exactly 0 (kappa0 * n, for the shrinkage term).
    safe_kappa = np.where(informative, kappa_n, 1.0)
    m_n = np.where(informative, prior.kappa0 * prior.m0 + n * ybar, prior.m0) / safe_kappa
    a_n = prior.a0 + n / 2.0
    b_n = (
        prior.b0
        + 0.5 * S
        + (prior.kappa0 * n) / (2.0 * safe_kappa) * (ybar - prior.m0) ** 2
    )
    return m_n, kappa_n, a_n, b_n
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def student_t_params(n, ybar, S, prior: NIGPrior):
    """
    Map cell sufficient statistics (n, ybar, S) to posterior-predictive
    Student-t parameters.

    The result is a dict of numpy values (arrays for array inputs) keyed by:
        "loc"      : m_n                                  location
        "scale2"   : b_n (kappa_n + 1) / (a_n kappa_n)    squared scale tau^2
        "df"       : 2 a_n                                degrees of freedom nu
        "e_sigma2" : b_n / (a_n - 1)                      E[sigma^2 | D]
        "var_mu"   : b_n / (kappa_n (a_n - 1))            Var(mu | D)
    The last two entries are the within-cell and cell-estimation pieces of the
    four-part predictive-variance decomposition.
    """
    loc, kappa_post, shape_post, rate_post = nig_posterior(n, ybar, S, prior)
    shape_minus_one = shape_post - 1.0
    return dict(
        loc=loc,
        scale2=rate_post * (kappa_post + 1.0) / (shape_post * kappa_post),
        df=2.0 * shape_post,
        e_sigma2=rate_post / shape_minus_one,
        var_mu=rate_post / (kappa_post * shape_minus_one),
    )
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def plugin_params(ybar, var, eps: float = 1e-9):
    """
    Component parameters for the plug-in Gaussian ablation (the first version).

    A Gaussian is the df -> infinity limit of a Student-t, so a very large
    finite df (1e8) is returned and the shared Student-t code path can
    evaluate it. The squared scale is `var` floored at `eps`. The plug-in
    ignores cell-estimation error, so "var_mu" is identically zero.
    """
    loc = np.asarray(ybar, dtype=float)
    floored = np.maximum(np.asarray(var, dtype=float), eps)
    return {
        "loc": loc,
        "scale2": floored,
        "df": np.full_like(loc, 1e8),
        "e_sigma2": floored,
        "var_mu": np.zeros_like(loc),  # plug-in omits cell-estimation error
    }
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# --------------------------------------------------------------------------- #
|
|
145
|
+
# Student-t density / cdf (torch for training, numpy for inference)
|
|
146
|
+
# --------------------------------------------------------------------------- #
|
|
147
|
+
def _lgamma_half_ratio(h, threshold: float = 100.0):
    """Return lgamma(h + 1/2) - lgamma(h) elementwise, stable for large h.

    Above `threshold` the asymptotic expansion
        0.5 log h - 1/(8h) + 1/(192 h^3)
    is used; its truncation error is O(h^-5), negligible there. Below the
    threshold the direct lgamma difference is accurate. Each branch is
    evaluated on clamped inputs, so torch.where never selects from inf/NaN
    values and gradients stay finite.
    """
    h_big = h.clamp(min=threshold)
    asymptotic = 0.5 * torch.log(h_big) - 1.0 / (8.0 * h_big) + 1.0 / (192.0 * h_big ** 3)
    h_small = h.clamp(max=threshold)
    direct = torch.lgamma(h_small + 0.5) - torch.lgamma(h_small)
    return torch.where(h > threshold, asymptotic, direct)


def student_t_logpdf(y, loc, scale2, df):
    """
    Log density of a location-scale Student-t, elementwise (torch).

        log t_nu(y | loc, s^2) =
            lgamma((nu+1)/2) - lgamma(nu/2) - 0.5 log(nu pi s^2)
            - (nu+1)/2 * log(1 + (y-loc)^2 / (nu s^2)).

    The lgamma difference is computed by `_lgamma_half_ratio`. Subtracting two
    huge lgamma values directly loses all precision in float32 at large nu.
    For the plug-in path's nu = 1e8, nu/2 + 0.5 even rounds to nu/2, which
    would shift every log density by about -0.5 log(nu/2).

    Args:
        y, loc, scale2, df: broadcastable torch tensors (df must be a tensor);
            scale2 is the squared scale s^2.

    Returns:
        Tensor of elementwise log densities.
    """
    z2 = (y - loc) ** 2 / scale2
    half_df = 0.5 * df
    return (
        _lgamma_half_ratio(half_df)
        - 0.5 * torch.log(df * math.pi * scale2)
        - (half_df + 0.5) * torch.log1p(z2 / df)
    )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def student_t_logpdf_np(y, loc, scale2, df):
    """Numpy counterpart of `student_t_logpdf`, used for inference and scoring.

    Evaluates, elementwise,
        lgamma((nu+1)/2) - lgamma(nu/2) - 0.5 log(nu pi s^2)
        - (nu+1)/2 * log1p((y-loc)^2 / (nu s^2)),
    where s^2 = scale2 and nu = df.
    """
    from scipy.special import gammaln

    nu_half = 0.5 * df
    exponent = nu_half + 0.5
    normalizer = gammaln(exponent) - gammaln(nu_half) - 0.5 * np.log(df * np.pi * scale2)
    sq_dist = (y - loc) ** 2 / scale2
    return normalizer - exponent * np.log1p(sq_dist / df)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def student_t_cdf_np(y, loc, scale2, df):
    """Evaluate the CDF of a location-scale Student-t (numpy, via scipy).

    `scale2` is the squared scale; `y` is standardized by (y - loc) / sqrt(scale2)
    before the standard Student-t CDF with `df` degrees of freedom is applied.
    """
    from scipy.stats import t as _t

    standardized = (y - loc) / np.sqrt(scale2)
    return _t.cdf(standardized, df=df)
|