subset-mixture-model 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- subset_mixture_model-0.1.0/PKG-INFO +152 -0
- subset_mixture_model-0.1.0/README.md +118 -0
- subset_mixture_model-0.1.0/pyproject.toml +59 -0
- subset_mixture_model-0.1.0/setup.cfg +4 -0
- subset_mixture_model-0.1.0/smm/__init__.py +20 -0
- subset_mixture_model-0.1.0/smm/laplace.py +266 -0
- subset_mixture_model-0.1.0/smm/model.py +125 -0
- subset_mixture_model-0.1.0/smm/predictor.py +60 -0
- subset_mixture_model-0.1.0/smm/subset_maker.py +103 -0
- subset_mixture_model-0.1.0/subset_mixture_model.egg-info/PKG-INFO +152 -0
- subset_mixture_model-0.1.0/subset_mixture_model.egg-info/SOURCES.txt +13 -0
- subset_mixture_model-0.1.0/subset_mixture_model.egg-info/dependency_links.txt +1 -0
- subset_mixture_model-0.1.0/subset_mixture_model.egg-info/requires.txt +17 -0
- subset_mixture_model-0.1.0/subset_mixture_model.egg-info/top_level.txt +1 -0
- subset_mixture_model-0.1.0/tests/test_smm.py +259 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: subset-mixture-model
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
|
|
5
|
+
Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/aaronjdanielson/subset-mixture-model
|
|
8
|
+
Keywords: interpretable machine learning,categorical features,empirical Bayes,uncertainty quantification,mixture model
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
Requires-Dist: torch>=2.0
|
|
20
|
+
Requires-Dist: numpy>=1.24
|
|
21
|
+
Requires-Dist: pandas>=2.0
|
|
22
|
+
Requires-Dist: scikit-learn>=1.3
|
|
23
|
+
Requires-Dist: scipy>=1.11
|
|
24
|
+
Provides-Extra: experiments
|
|
25
|
+
Requires-Dist: matplotlib>=3.7; extra == "experiments"
|
|
26
|
+
Requires-Dist: seaborn>=0.12; extra == "experiments"
|
|
27
|
+
Requires-Dist: joblib>=1.3; extra == "experiments"
|
|
28
|
+
Requires-Dist: lightgbm>=4.0; extra == "experiments"
|
|
29
|
+
Requires-Dist: ngboost>=0.4; extra == "experiments"
|
|
30
|
+
Requires-Dist: mapie>=0.6; extra == "experiments"
|
|
31
|
+
Provides-Extra: dev
|
|
32
|
+
Requires-Dist: pytest>=7; extra == "dev"
|
|
33
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
34
|
+
|
|
35
|
+
# Subset Mixture Model (SMM)
|
|
36
|
+
|
|
37
|
+
**SMM** is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.
|
|
38
|
+
|
|
39
|
+
## Key idea
|
|
40
|
+
|
|
41
|
+
Each feature subset $s$ induces a partition of the covariate space and a natural estimator of the conditional expectation — its empirical cell mean. SMM learns a convex combination of these estimators:
|
|
42
|
+
|
|
43
|
+
$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
|
|
44
|
+
|
|
45
|
+
The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
|
|
46
|
+
|
|
47
|
+
## Installation
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
pip install subset-mixture-model
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
Or from source:
|
|
54
|
+
|
|
55
|
+
```bash
|
|
56
|
+
git clone https://github.com/aaronjdanielson/subset-mixture-model
|
|
57
|
+
cd subset-mixture-model
|
|
58
|
+
pip install -e .
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Quick start
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import pandas as pd
|
|
65
|
+
import torch
|
|
66
|
+
import torch.nn.functional as F
|
|
67
|
+
from torch.utils.data import DataLoader
|
|
68
|
+
|
|
69
|
+
from smm import (
|
|
70
|
+
SubsetMaker, SubsetWeightsModel, SubsetDataset,
|
|
71
|
+
subset_mixture_neg_log_posterior, SubsetMixturePredictor,
|
|
72
|
+
compute_posterior_covariance, predict_with_uncertainty,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# --- your data (integer-coded categorical features) ---
|
|
76
|
+
train_df = pd.read_csv("train.csv")
|
|
77
|
+
val_df = pd.read_csv("val.csv")
|
|
78
|
+
test_df = pd.read_csv("test.csv")
|
|
79
|
+
|
|
80
|
+
cat_cols = ["feature_a", "feature_b", "feature_c"]
|
|
81
|
+
target = "y"
|
|
82
|
+
|
|
83
|
+
# --- build lookup table ---
|
|
84
|
+
subset_maker = SubsetMaker(train_df, cat_cols, [target])
|
|
85
|
+
n_subsets = len(subset_maker.lookup)
|
|
86
|
+
|
|
87
|
+
# --- train ---
|
|
88
|
+
model = SubsetWeightsModel(n_subsets)
|
|
89
|
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
|
|
90
|
+
loader = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
|
|
91
|
+
batch_size=64, shuffle=True)
|
|
92
|
+
|
|
93
|
+
for epoch in range(100):
|
|
94
|
+
for x, y in loader:
|
|
95
|
+
optimizer.zero_grad()
|
|
96
|
+
mus, variances, mask = subset_maker.batch_lookup(x)
|
|
97
|
+
loss = subset_mixture_neg_log_posterior(
|
|
98
|
+
model(), y, mus, variances, mask, alpha=1.1)
|
|
99
|
+
loss.backward()
|
|
100
|
+
optimizer.step()
|
|
101
|
+
|
|
102
|
+
# --- predict with uncertainty ---
|
|
103
|
+
pi_hat = F.softmax(model.eta.detach(), dim=0)
|
|
104
|
+
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
|
|
105
|
+
sigma_pi = compute_posterior_covariance(
|
|
106
|
+
model, subset_maker, train_df, cat_cols, target, alpha=1.1)
|
|
107
|
+
|
|
108
|
+
y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
|
|
109
|
+
# y_mean: point predictions
|
|
110
|
+
# y_std: total predictive standard deviation (aleatoric + epistemic)
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Interpretability
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
import numpy as np
|
|
117
|
+
|
|
118
|
+
subsets = list(subset_maker.lookup.keys())
|
|
119
|
+
top_idx = np.argsort(pi_hat.numpy())[::-1][:10]
|
|
120
|
+
|
|
121
|
+
for rank, i in enumerate(top_idx):
|
|
122
|
+
print(f"{rank+1:2d}. {subsets[i]} π={pi_hat[i]:.4f}")
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
## Features
|
|
126
|
+
|
|
127
|
+
- **Interpretable by construction**: learned weights reveal which feature interactions matter
|
|
128
|
+
- **Principled uncertainty**: aleatoric/epistemic decomposition via Laplace approximation
|
|
129
|
+
- **Efficient training**: only $2^D - 1$ logits optimized; lookup table precomputed once
|
|
130
|
+
- **No post-hoc calibration**: well-calibrated predictive intervals out of the box
|
|
131
|
+
- **Scalable to D ≤ 15** features; $k$-way truncation available for larger $D$
|
|
132
|
+
|
|
133
|
+
## Datasets supported
|
|
134
|
+
|
|
135
|
+
Any tabular dataset with integer-coded (or string, with encoding) categorical features and a continuous target.
|
|
136
|
+
|
|
137
|
+
## Citation
|
|
138
|
+
|
|
139
|
+
```bibtex
|
|
140
|
+
@article{danielson2025smm,
|
|
141
|
+
title = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
|
|
142
|
+
of Partition Estimators for Categorical Regression},
|
|
143
|
+
author = {Danielson, Aaron John},
|
|
144
|
+
journal = {Machine Learning},
|
|
145
|
+
year = {2025},
|
|
146
|
+
note = {Under review}
|
|
147
|
+
}
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
## License
|
|
151
|
+
|
|
152
|
+
MIT
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Subset Mixture Model (SMM)
|
|
2
|
+
|
|
3
|
+
**SMM** is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.
|
|
4
|
+
|
|
5
|
+
## Key idea
|
|
6
|
+
|
|
7
|
+
Each feature subset $s$ induces a partition of the covariate space and a natural estimator of the conditional expectation — its empirical cell mean. SMM learns a convex combination of these estimators:
|
|
8
|
+
|
|
9
|
+
$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
|
|
10
|
+
|
|
11
|
+
The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
|
|
12
|
+
|
|
13
|
+
## Installation
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
pip install subset-mixture-model
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
Or from source:
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
git clone https://github.com/aaronjdanielson/subset-mixture-model
|
|
23
|
+
cd subset-mixture-model
|
|
24
|
+
pip install -e .
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Quick start
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
import pandas as pd
|
|
31
|
+
import torch
|
|
32
|
+
import torch.nn.functional as F
|
|
33
|
+
from torch.utils.data import DataLoader
|
|
34
|
+
|
|
35
|
+
from smm import (
|
|
36
|
+
SubsetMaker, SubsetWeightsModel, SubsetDataset,
|
|
37
|
+
subset_mixture_neg_log_posterior, SubsetMixturePredictor,
|
|
38
|
+
compute_posterior_covariance, predict_with_uncertainty,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# --- your data (integer-coded categorical features) ---
|
|
42
|
+
train_df = pd.read_csv("train.csv")
|
|
43
|
+
val_df = pd.read_csv("val.csv")
|
|
44
|
+
test_df = pd.read_csv("test.csv")
|
|
45
|
+
|
|
46
|
+
cat_cols = ["feature_a", "feature_b", "feature_c"]
|
|
47
|
+
target = "y"
|
|
48
|
+
|
|
49
|
+
# --- build lookup table ---
|
|
50
|
+
subset_maker = SubsetMaker(train_df, cat_cols, [target])
|
|
51
|
+
n_subsets = len(subset_maker.lookup)
|
|
52
|
+
|
|
53
|
+
# --- train ---
|
|
54
|
+
model = SubsetWeightsModel(n_subsets)
|
|
55
|
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
|
|
56
|
+
loader = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
|
|
57
|
+
batch_size=64, shuffle=True)
|
|
58
|
+
|
|
59
|
+
for epoch in range(100):
|
|
60
|
+
for x, y in loader:
|
|
61
|
+
optimizer.zero_grad()
|
|
62
|
+
mus, variances, mask = subset_maker.batch_lookup(x)
|
|
63
|
+
loss = subset_mixture_neg_log_posterior(
|
|
64
|
+
model(), y, mus, variances, mask, alpha=1.1)
|
|
65
|
+
loss.backward()
|
|
66
|
+
optimizer.step()
|
|
67
|
+
|
|
68
|
+
# --- predict with uncertainty ---
|
|
69
|
+
pi_hat = F.softmax(model.eta.detach(), dim=0)
|
|
70
|
+
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
|
|
71
|
+
sigma_pi = compute_posterior_covariance(
|
|
72
|
+
model, subset_maker, train_df, cat_cols, target, alpha=1.1)
|
|
73
|
+
|
|
74
|
+
y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
|
|
75
|
+
# y_mean: point predictions
|
|
76
|
+
# y_std: total predictive standard deviation (aleatoric + epistemic)
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
## Interpretability
|
|
80
|
+
|
|
81
|
+
```python
|
|
82
|
+
import numpy as np
|
|
83
|
+
|
|
84
|
+
subsets = list(subset_maker.lookup.keys())
|
|
85
|
+
top_idx = np.argsort(pi_hat.numpy())[::-1][:10]
|
|
86
|
+
|
|
87
|
+
for rank, i in enumerate(top_idx):
|
|
88
|
+
print(f"{rank+1:2d}. {subsets[i]} π={pi_hat[i]:.4f}")
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
## Features
|
|
92
|
+
|
|
93
|
+
- **Interpretable by construction**: learned weights reveal which feature interactions matter
|
|
94
|
+
- **Principled uncertainty**: aleatoric/epistemic decomposition via Laplace approximation
|
|
95
|
+
- **Efficient training**: only $2^D - 1$ logits optimized; lookup table precomputed once
|
|
96
|
+
- **No post-hoc calibration**: well-calibrated predictive intervals out of the box
|
|
97
|
+
- **Scalable to D ≤ 15** features; $k$-way truncation available for larger $D$
|
|
98
|
+
|
|
99
|
+
## Datasets supported
|
|
100
|
+
|
|
101
|
+
Any tabular dataset with integer-coded (or string, with encoding) categorical features and a continuous target.
|
|
102
|
+
|
|
103
|
+
## Citation
|
|
104
|
+
|
|
105
|
+
```bibtex
|
|
106
|
+
@article{danielson2025smm,
|
|
107
|
+
title = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
|
|
108
|
+
of Partition Estimators for Categorical Regression},
|
|
109
|
+
author = {Danielson, Aaron John},
|
|
110
|
+
journal = {Machine Learning},
|
|
111
|
+
year = {2025},
|
|
112
|
+
note = {Under review}
|
|
113
|
+
}
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
## License
|
|
117
|
+
|
|
118
|
+
MIT
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "subset-mixture-model"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Interpretable empirical-Bayes aggregation of partition estimators for categorical regression"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { text = "MIT" }
|
|
11
|
+
authors = [
|
|
12
|
+
{ name = "Aaron John Danielson", email = "aaron.danielson@austin.utexas.edu" },
|
|
13
|
+
]
|
|
14
|
+
keywords = [
|
|
15
|
+
"interpretable machine learning",
|
|
16
|
+
"categorical features",
|
|
17
|
+
"empirical Bayes",
|
|
18
|
+
"uncertainty quantification",
|
|
19
|
+
"mixture model",
|
|
20
|
+
]
|
|
21
|
+
classifiers = [
|
|
22
|
+
"Development Status :: 4 - Beta",
|
|
23
|
+
"Intended Audience :: Science/Research",
|
|
24
|
+
"License :: OSI Approved :: MIT License",
|
|
25
|
+
"Programming Language :: Python :: 3",
|
|
26
|
+
"Programming Language :: Python :: 3.9",
|
|
27
|
+
"Programming Language :: Python :: 3.10",
|
|
28
|
+
"Programming Language :: Python :: 3.11",
|
|
29
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
30
|
+
]
|
|
31
|
+
requires-python = ">=3.9"
|
|
32
|
+
dependencies = [
|
|
33
|
+
"torch>=2.0",
|
|
34
|
+
"numpy>=1.24",
|
|
35
|
+
"pandas>=2.0",
|
|
36
|
+
"scikit-learn>=1.3",
|
|
37
|
+
"scipy>=1.11",
|
|
38
|
+
]
|
|
39
|
+
|
|
40
|
+
[project.optional-dependencies]
|
|
41
|
+
experiments = [
|
|
42
|
+
"matplotlib>=3.7",
|
|
43
|
+
"seaborn>=0.12",
|
|
44
|
+
"joblib>=1.3",
|
|
45
|
+
"lightgbm>=4.0",
|
|
46
|
+
"ngboost>=0.4",
|
|
47
|
+
"mapie>=0.6",
|
|
48
|
+
]
|
|
49
|
+
dev = [
|
|
50
|
+
"pytest>=7",
|
|
51
|
+
"pytest-cov",
|
|
52
|
+
]
|
|
53
|
+
|
|
54
|
+
[project.urls]
|
|
55
|
+
Repository = "https://github.com/aaronjdanielson/subset-mixture-model"
|
|
56
|
+
|
|
57
|
+
[tool.setuptools.packages.find]
|
|
58
|
+
where = ["."]
|
|
59
|
+
include = ["smm*"]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from .subset_maker import SubsetMaker
|
|
2
|
+
from .model import SubsetWeightsModel, SubsetDataset, subset_mixture_neg_log_posterior, subset_mixture_mse
|
|
3
|
+
from .predictor import SubsetMixturePredictor
|
|
4
|
+
from .laplace import (
|
|
5
|
+
compute_posterior_covariance,
|
|
6
|
+
predict_with_uncertainty,
|
|
7
|
+
coverage,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"SubsetMaker",
|
|
12
|
+
"SubsetWeightsModel",
|
|
13
|
+
"SubsetDataset",
|
|
14
|
+
"subset_mixture_neg_log_posterior",
|
|
15
|
+
"subset_mixture_mse",
|
|
16
|
+
"SubsetMixturePredictor",
|
|
17
|
+
"compute_posterior_covariance",
|
|
18
|
+
"predict_with_uncertainty",
|
|
19
|
+
"coverage",
|
|
20
|
+
]
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Laplace approximation for posterior uncertainty in the Subset Mixture Model.
|
|
3
|
+
|
|
4
|
+
After MAP estimation of eta (the logit weight vector), the Laplace approximation
|
|
5
|
+
treats the negative log-posterior as a quadratic around the MAP:
|
|
6
|
+
|
|
7
|
+
q(eta) = N(eta | eta_hat, H^{-1})
|
|
8
|
+
|
|
9
|
+
where H is the Hessian of the negative log-posterior at eta_hat.
|
|
10
|
+
|
|
11
|
+
Transforming to the simplex via pi = softmax(eta) and applying the delta method:
|
|
12
|
+
|
|
13
|
+
Cov[pi | D] ≈ Sigma_pi = J H^{-1} J^T
|
|
14
|
+
|
|
15
|
+
where J = d(pi)/d(eta) is the |S| x |S| softmax Jacobian.
|
|
16
|
+
|
|
17
|
+
Predictive variance for a new point x_tilde then decomposes as:
|
|
18
|
+
|
|
19
|
+
Var[y | x_tilde] ≈ sum_s pi_s * sigma^2_s(x_tilde) [aleatoric]
|
|
20
|
+
+ mu_{x_tilde}^T Sigma_pi mu_{x_tilde} [epistemic]
|
|
21
|
+
|
|
22
|
+
References:
|
|
23
|
+
MacKay, D.J.C. (1992). A Practical Bayesian Framework for Backpropagation Networks.
|
|
24
|
+
Daxberger et al. (2021). Laplace Redux -- Effortless Bayesian Deep Learning.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
import torch
|
|
29
|
+
import torch.nn.functional as F
|
|
30
|
+
from torch.utils.data import DataLoader
|
|
31
|
+
|
|
32
|
+
from .model import SubsetDataset, subset_mixture_neg_log_posterior
|
|
33
|
+
from .subset_maker import SubsetMaker
|
|
34
|
+
from .predictor import SubsetMixturePredictor
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ---------------------------------------------------------------------------
|
|
38
|
+
# Step 1: Compute posterior covariance of pi
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
|
|
41
|
+
def softmax_jacobian(pi: torch.Tensor) -> torch.Tensor:
    """
    Jacobian of softmax(eta) with respect to eta, evaluated at pi = softmax(eta).

    Entries satisfy J_ij = pi_i * (delta_ij - pi_j), i.e. a diagonal matrix of
    the weights minus the rank-one outer product pi pi^T. The result is
    symmetric and its rows sum to zero (softmax is invariant to constant
    shifts of eta).

    Args:
        pi: [S] mixture weights on the simplex (must sum to 1).

    Returns:
        J: [S, S] Jacobian matrix.
    """
    # diag(pi) - pi pi^T, built via broadcasting instead of torch.outer.
    rank_one = pi.unsqueeze(1) * pi.unsqueeze(0)
    return pi.diag() - rank_one
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def compute_hessian(
    model: torch.nn.Module,
    subset_maker: SubsetMaker,
    train_df,
    cat_cols: list,
    target: str,
    alpha: float = 1.1,
    batch_size: int = 512,
) -> torch.Tensor:
    """
    Compute the exact Hessian of the negative log-posterior w.r.t. eta.

    Because |S| is at most a few hundred, the Hessian is small and can be
    computed exactly via torch.autograd.functional.hessian on the full
    training set. The data and lookup tensors are streamed in chunks of
    `batch_size` (memory only; the loss closure still sees the full batch,
    so the result is the exact full-data Hessian).

    Args:
        model: Trained SubsetWeightsModel.
        subset_maker: Fitted SubsetMaker.
        train_df: Training DataFrame (integer-encoded).
        cat_cols: Categorical feature column names.
        target: Target column name.
        alpha: Dirichlet concentration (must match training value).
        batch_size: Chunk size used when loading data and running the
            lookup (only affects peak memory, not the result).

    Returns:
        H: [S, S] Hessian tensor.
    """
    model.eval()
    eta_hat = model.eta.detach().clone()

    # Fix: `batch_size` was previously accepted but ignored (the loader
    # hard-coded batch_size=len(ds)). Stream the data and per-chunk lookups,
    # then concatenate, so the Hessian is unchanged but memory is bounded
    # by the chunk size during loading.
    ds = SubsetDataset(train_df, cat_cols, [target])
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    y_chunks, mu_chunks, var_chunks, mask_chunks = [], [], [], []
    for x_b, y_b in loader:
        mus_b, vars_b, mask_b = subset_maker.batch_lookup(x_b)
        y_chunks.append(y_b)
        # Detach and keep on CPU; the lookup tensors are constants w.r.t. eta.
        mu_chunks.append(mus_b.detach())
        var_chunks.append(vars_b.detach())
        mask_chunks.append(mask_b.detach())

    y_all = torch.cat(y_chunks, dim=0)
    mus = torch.cat(mu_chunks, dim=0)
    variances = torch.cat(var_chunks, dim=0)
    mask = torch.cat(mask_chunks, dim=0)

    def loss_fn(eta):
        # Full-batch negative log-posterior; alpha must match training.
        return subset_mixture_neg_log_posterior(
            eta, y_all, mus, variances, mask, alpha=alpha
        )

    eta_param = eta_hat.requires_grad_(True)
    H = torch.autograd.functional.hessian(loss_fn, eta_param)  # [S, S]
    return H.detach()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def compute_posterior_covariance(
    model: torch.nn.Module,
    subset_maker: SubsetMaker,
    train_df,
    cat_cols: list,
    target: str,
    alpha: float = 1.1,
    hessian_reg: float = 1e-4,
) -> torch.Tensor:
    """
    Laplace approximation to the posterior covariance of the mixture weights.

    Computes Sigma_pi = J H^{-1} J^T, where H is the Hessian of the negative
    log-posterior at the MAP logits eta_hat and J is the softmax Jacobian
    evaluated at pi_hat = softmax(eta_hat) (delta method on the simplex map).

    Args:
        model: Trained SubsetWeightsModel (provides eta_hat).
        subset_maker: Fitted SubsetMaker.
        train_df: Training DataFrame (integer-encoded).
        cat_cols: Categorical feature column names.
        target: Target column name.
        alpha: Dirichlet concentration (must match training value).
        hessian_reg: Ridge added to H before inversion for numerical stability.

    Returns:
        sigma_pi: [S, S] posterior covariance of mixture weights.
    """
    eta_map = model.eta.detach()
    pi_map = F.softmax(eta_map, dim=0)

    hess = compute_hessian(model, subset_maker, train_df, cat_cols, target, alpha)

    # Ridge term guards against an ill-conditioned (or indefinite) Hessian.
    n_subsets = len(eta_map)
    hess_inv = torch.linalg.inv(hess + hessian_reg * torch.eye(n_subsets))

    # Delta-method push-forward through softmax: J H^{-1} J^T.
    jac = softmax_jacobian(pi_map)
    return jac @ hess_inv @ jac.T
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# ---------------------------------------------------------------------------
|
|
157
|
+
# Step 2: Predictive distribution for new points
|
|
158
|
+
# ---------------------------------------------------------------------------
|
|
159
|
+
|
|
160
|
+
def predict_with_uncertainty(
    predictor: SubsetMixturePredictor,
    sigma_pi: torch.Tensor,
    df,
    return_components: bool = False,
):
    """
    Compute predictive mean and uncertainty for a batch of test points.

    Predictive mean:
        y_hat(x) = pi_hat^T mu_x

    Predictive variance:
        Var[y|x] = sum_s pi_s * sigma_s^2(x)      [aleatoric]
                 + mu_x^T Sigma_pi mu_x           [epistemic]

    For test points with no valid subset cell (all masked), the predictor
    falls back to the global training mean with aleatoric variance equal to
    the global training variance.

    Args:
        predictor: SubsetMixturePredictor with learned weight vector.
        sigma_pi: [S, S] posterior covariance from compute_posterior_covariance.
        df: Test DataFrame (integer-encoded).
        return_components: If True, also return aleatoric and epistemic stds.

    Returns:
        y_mean (np.ndarray): [B] predicted means.
        y_std (np.ndarray): [B] total predictive standard deviations.
        (optional) aleatoric_std (np.ndarray): [B]
        (optional) epistemic_std (np.ndarray): [B]
    """
    batch_tensor = torch.tensor(
        df[predictor.subset_maker.subset_features].astype(float).values,
        dtype=torch.float32,
    )
    mus, variances, mask = predictor.subset_maker.batch_lookup(batch_tensor)

    B, S = mus.shape
    pi = predictor.weight_vector  # [S]
    # expand() yields a read-only view; clone() so masked_fill below is safe.
    weights = pi.unsqueeze(0).expand(B, S).clone()  # [B, S]

    # Mask invalid cells (subsets whose cell was never seen in training).
    weights = weights.masked_fill(~mask, 0.0)
    mus_m = mus.masked_fill(~mask, 0.0)
    vars_m = variances.masked_fill(~mask, 0.0)

    # Re-normalize surviving weights per example; 1e-9 avoids divide-by-zero
    # for all-masked rows (those rows are overwritten by the fallback below).
    weight_sums = weights.sum(dim=1, keepdim=True)  # [B, 1]
    norm_weights = weights / (weight_sums + 1e-9)  # [B, S]

    # --- Predictive mean ---
    # NOTE(review): squeeze() on a [1, 1] tensor gives a 0-dim mask when
    # B == 1; torch.where still broadcasts correctly in that case.
    fallback_mask = weight_sums.squeeze() < 1e-6  # [B]
    y_mean_weighted = (mus_m * norm_weights).sum(dim=1)  # [B]
    y_mean = torch.where(
        fallback_mask,
        torch.full_like(y_mean_weighted, predictor.subset_maker.fallback_mean),
        y_mean_weighted,
    )

    # --- Aleatoric variance: E_pi[sigma^2(x)] = sum_s pi_s * sigma_s^2(x) ---
    aleatoric_var = (norm_weights * vars_m).sum(dim=1)  # [B]
    # Fallback: use global training variance
    aleatoric_var = torch.where(
        fallback_mask,
        torch.full_like(aleatoric_var, predictor.subset_maker.fallback_var),
        aleatoric_var,
    )

    # --- Epistemic variance: mu_x^T Sigma_pi mu_x ---
    # mu_x is mus_m[i] — the vector of subset-conditional means for example i
    # (zeroed for invalid cells, consistent with the masked prediction)
    epistemic_var = torch.einsum("bi,ij,bj->b", mus_m, sigma_pi, mus_m)  # [B]
    # Clamp: sigma_pi is only PSD up to numerical error after regularized
    # inversion, so the quadratic form can go slightly negative.
    epistemic_var = epistemic_var.clamp(min=0.0)

    total_var = aleatoric_var + epistemic_var
    total_std = torch.sqrt(total_var.clamp(min=1e-12))

    y_mean_np = y_mean.detach().numpy()
    total_std_np = total_std.detach().numpy()

    if return_components:
        aleatoric_std_np = torch.sqrt(aleatoric_var.clamp(min=1e-12)).detach().numpy()
        epistemic_std_np = torch.sqrt(epistemic_var.clamp(min=1e-12)).detach().numpy()
        return y_mean_np, total_std_np, aleatoric_std_np, epistemic_std_np

    return y_mean_np, total_std_np
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def coverage(y_true: np.ndarray, y_mean: np.ndarray, y_std: np.ndarray,
             level: float = 0.95) -> float:
    """
    Empirical coverage of symmetric Gaussian credible intervals.

    An observation is "covered" when it lies within z * y_std of its
    predictive mean, where z is the two-sided normal critical value for
    the requested level.

    Args:
        y_true: [B] true targets.
        y_mean: [B] predictive means.
        y_std: [B] predictive standard deviations.
        level: Nominal coverage (default 0.95).

    Returns:
        Scalar empirical coverage in [0, 1].
    """
    from scipy.stats import norm as scipy_norm
    # Two-sided critical value: P(|Z| <= z) = level.
    half_width = scipy_norm.ppf((1 + level) / 2) * y_std
    inside = (y_true >= y_mean - half_width) & (y_true <= y_mean + half_width)
    return float(inside.mean())
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SubsetWeightsModel(torch.nn.Module):
    """
    Global mixture-weight model over the subset powerset.

    Holds one real-valued logit per subset (eta, length |S|), initialized at
    zero so the implied softmax weights start uniform. The softmax is applied
    downstream (inside the loss), so ``forward`` returns raw logits.

    Args:
        num_subsets (int): Number of subsets in the powerset (|S|).
    """

    def __init__(self, num_subsets: int):
        super().__init__()
        initial_logits = torch.zeros(num_subsets)
        self.eta = torch.nn.Parameter(initial_logits)

    def forward(self, x=None):
        """Return the raw logit vector eta; callers apply softmax themselves."""
        return self.eta
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SubsetDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset view over a DataFrame of integer-coded categorical features.

    Each item is a (features, target) pair of float32 tensors.

    Args:
        df (pd.DataFrame): Data split (train / val / test).
        subset_features (list[str]): Categorical feature column names.
        target (list[str]): Single-element list with the target column name.
    """

    def __init__(self, df, subset_features: list, target: list):
        self.df = df
        self.subset_features = subset_features
        self.target = target

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        feature_values = record[self.subset_features].astype(np.float32).values
        label = np.float32(record[self.target[0]])
        return torch.tensor(feature_values), torch.tensor(label)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def subset_mixture_neg_log_posterior(
    pi_logits: torch.Tensor,
    y: torch.Tensor,
    mus: torch.Tensor,
    variances: torch.Tensor,
    mask: torch.Tensor = None,
    alpha: float = 1.1,
) -> torch.Tensor:
    """
    Negative log-posterior of the subset mixture model.

    Loss = -sum_i log( sum_s pi_s * N(y_i | mu_s(x_i), sigma^2_s(x_i)) )
           - (alpha - 1) * sum_s log(pi_s)

    with the Dirichlet prior term scaled by 1/batch_size so minibatch
    gradients are consistent.

    Args:
        pi_logits: [|S|] unnormalized logits (global parameter).
        y: [B] target values.
        mus: [B, |S|] empirical conditional means.
        variances: [B, |S|] empirical conditional variances.
        mask: [B, |S|] bool tensor; True where the cell exists in training data.
        alpha: Dirichlet concentration parameter (>1 encourages non-degenerate weights).

    Returns:
        Scalar loss tensor.
    """
    pi = F.softmax(pi_logits, dim=0)  # [|S|]
    log_pi = torch.log(pi + 1e-9)     # epsilon guards log(0)

    # Per-example, per-subset Gaussian log-density of y under the cell estimate.
    resid = y.unsqueeze(1) - mus
    gauss_logpdf = (
        -0.5 * torch.log(2 * torch.pi * variances)
        - 0.5 * resid ** 2 / variances
    )  # [B, |S|]

    # Joint log pi_s + log N(...); invalid cells are excluded from the
    # logsumexp by setting them to -inf.
    joint = gauss_logpdf + log_pi.unsqueeze(0)  # [B, |S|]
    if mask is not None:
        joint = joint.masked_fill(~mask, float("-inf"))

    nll = -torch.logsumexp(joint, dim=1).sum()

    dirichlet_term = (alpha - 1.0) * log_pi.sum()
    return nll - dirichlet_term / y.size(0)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def subset_mixture_mse(
    pi_logits: torch.Tensor,
    y: torch.Tensor,
    mus: torch.Tensor,
    mask: torch.Tensor = None,
) -> torch.Tensor:
    """
    MSE loss for the subset mixture model (useful for warmup / debugging).

    Predictions are the pi-weighted combination of per-subset cell means;
    when a mask is given, weights over invalid cells are zeroed and the
    remainder re-normalized per example.

    Args:
        pi_logits: [|S|] unnormalized logits.
        y: [B] target values.
        mus: [B, |S|] empirical conditional means.
        mask: [B, |S|] bool tensor.

    Returns:
        Scalar MSE tensor (summed over the batch).
    """
    pi = F.softmax(pi_logits, dim=0)  # [|S|]
    w = pi.unsqueeze(0).expand_as(mus)  # [B, |S|]

    if mask is not None:
        w = w.masked_fill(~mask, 0.0)
        # 1e-9 avoids division by zero when every cell in a row is masked.
        w = w / (w.sum(dim=1, keepdim=True) + 1e-9)

    preds = torch.sum(w * mus, dim=1)
    return F.mse_loss(preds, y, reduction="sum")
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
from .subset_maker import SubsetMaker
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SubsetMixturePredictor:
    """
    Inference wrapper for a trained Subset Mixture Model.

    Args:
        subset_maker (SubsetMaker): Trained SubsetMaker with lookup table.
        weight_vector (torch.Tensor): Softmaxed mixture weights [|S|].
    """

    def __init__(self, subset_maker: "SubsetMaker", weight_vector: torch.Tensor):
        self.subset_maker = subset_maker
        self.weight_vector = weight_vector

    def predict(self, df: pd.DataFrame, return_debug: bool = False):
        """
        Predict target values for a new DataFrame.

        For each example, weights are masked to valid subsets (those seen during
        training) and re-normalized. Examples with no valid subsets fall back to
        the global training mean.

        Args:
            df: Input DataFrame containing all subset_features columns.
            return_debug: If True, also return per-example normalized weights
                and a fallback mask.

        Returns:
            preds (np.ndarray): Predicted values [B].
            (optional) norm_weights (np.ndarray): [B, |S|]
            (optional) fallback_mask (np.ndarray): [B] bool
        """
        batch_tensor = torch.tensor(
            df[self.subset_maker.subset_features].astype(np.float32).values
        )
        # Inference only: avoid autograd tracking so .numpy() below is always
        # safe even if weight_vector still requires grad.
        with torch.no_grad():
            mus, _, mask = self.subset_maker.batch_lookup(batch_tensor)  # [B, S]

            B, S = mus.shape
            weights = self.weight_vector.unsqueeze(0).expand(B, S)  # [B, S]
            weights = weights.masked_fill(~mask, 0.0)
            mus = mus.masked_fill(~mask, 0.0)

            weight_sums = weights.sum(dim=1, keepdim=True)  # [B, 1]
            # Epsilon guards rows where every subset is masked out.
            norm_weights = weights / (weight_sums + 1e-9)

            # squeeze(-1), not squeeze(): a bare squeeze() would also drop the
            # batch dim when B == 1, yielding a 0-d fallback mask in debug output.
            fallback_mask = weight_sums.squeeze(-1) < 1e-6  # [B]
            pred_weighted = (mus * norm_weights).sum(dim=1)
            pred_fallback = torch.full_like(
                pred_weighted, self.subset_maker.fallback_mean
            )
            preds = torch.where(fallback_mask, pred_fallback, pred_weighted)

        if return_debug:
            return preds.numpy(), norm_weights.numpy(), fallback_mask.numpy()
        return preds.numpy()
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
import itertools
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SubsetMaker:
    """
    Builds a powerset lookup table from a training DataFrame.

    For each non-empty subset of `subset_features`, groups training examples by
    their unique value combination and stores the empirical mean and variance of
    the target within each group.

    Args:
        df (pd.DataFrame): Training data. All `subset_features` columns must be
            integer-coded categorical variables.
        subset_features (list[str]): Names of the categorical feature columns.
        target (list[str]): Single-element list with the target column name.
    """

    def __init__(self, df: pd.DataFrame, subset_features: list, target: list):
        self.df = df
        self.subset_features = subset_features
        self.target = target
        # Powerset lookup is built eagerly; cost grows as 2^D in |subset_features|.
        self.lookup = self._build_lookup()
        # Global fallback statistics used for cells never seen in training.
        valid_target = self.df[self.target[0]].dropna()
        self.fallback_mean = float(valid_target.mean())
        self.fallback_var = float(valid_target.var())

    def _get_powerset(self) -> list:
        # All non-empty subsets of the feature list, smallest first
        # (2^D - 1 subsets total); each returned as a list of column names.
        powerset = []
        for r in range(1, len(self.subset_features) + 1):
            powerset += list(itertools.combinations(self.subset_features, r))
        return [list(s) for s in powerset]

    def _build_lookup(self, drop_missing_rows: bool = True) -> dict:
        """Group the target by every feature subset and cache per-cell stats.

        Returns a dict keyed by the subset tuple, mapping to
        (subset column list, flattened groupby frame, {value-tuple: (mean, var)}).
        """
        lookup = {}
        for subset in self._get_powerset():
            grouped = (
                self.df[subset + self.target]
                .groupby(subset)
                .agg({self.target[0]: ["mean", "var"]})
                .reset_index()
            )
            # agg with a dict yields MultiIndex columns; flatten to
            # "<target>_mean" / "<target>_var" (strip("_") keeps plain
            # key columns like ("a", "") as just "a").
            grouped.columns = [
                "_".join(c).strip("_") if isinstance(c, tuple) else c
                for c in grouped.columns
            ]
            if drop_missing_rows:
                # dropna removes single-observation cells (var is NaN);
                # the variance floor also drops (near-)constant cells,
                # which would break the Gaussian likelihood downstream.
                grouped = grouped.dropna()
                grouped = grouped[grouped[f"{self.target[0]}_var"] > 1e-6]

            mean_col = f"{self.target[0]}_mean"
            var_col = f"{self.target[0]}_var"
            # Keys are tuples of the subset's cell values; numpy integer
            # scalars hash like Python ints, so batch_lookup keys match.
            group_dict = {
                tuple(row[subset]): (row[mean_col], row[var_col])
                for _, row in grouped.iterrows()
            }
            lookup[tuple(subset)] = (subset, grouped, group_dict)
        return lookup

    def batch_lookup(self, batch_tensor: torch.Tensor):
        """
        Look up empirical means/variances for a batch of integer-coded examples.

        Args:
            batch_tensor: [B, D] integer tensor matching the order of subset_features.

        Returns:
            means: [B, |S|] float tensor
            variances: [B, |S|] float tensor
            mask: [B, |S|] bool tensor (True where the cell exists in training data)
        """
        # Round-trip through pandas so rows can be keyed by value tuples;
        # astype(int) recovers exact category codes from the float tensor.
        batch_df = pd.DataFrame(
            batch_tensor.numpy(), columns=self.subset_features
        ).astype(int)

        all_means, all_vars, all_masks = [], [], []

        # One column of output per subset, in lookup insertion order
        # (matches the weight-vector indexing used elsewhere).
        for _, (subset_cols, _, group_dict) in self.lookup.items():
            subset_vals = batch_df[subset_cols].apply(tuple, axis=1)
            means, vars_, mask = [], [], []
            for key in subset_vals:
                if key in group_dict:
                    m, v = group_dict[key]
                    means.append(m)
                    vars_.append(v)
                    mask.append(True)
                else:
                    # Unseen cell: substitute global stats, flag as invalid.
                    means.append(self.fallback_mean)
                    vars_.append(self.fallback_var)
                    mask.append(False)
            all_means.append(means)
            all_vars.append(vars_)
            all_masks.append(mask)

        # Transpose from [|S|, B] → [B, |S|]
        return (
            torch.tensor(all_means, dtype=torch.float32).T,
            torch.tensor(all_vars, dtype=torch.float32).T,
            torch.tensor(all_masks, dtype=torch.bool).T,
        )
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: subset-mixture-model
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
|
|
5
|
+
Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/aaronjdanielson/subset-mixture-model
|
|
8
|
+
Keywords: interpretable machine learning,categorical features,empirical Bayes,uncertainty quantification,mixture model
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
Requires-Dist: torch>=2.0
|
|
20
|
+
Requires-Dist: numpy>=1.24
|
|
21
|
+
Requires-Dist: pandas>=2.0
|
|
22
|
+
Requires-Dist: scikit-learn>=1.3
|
|
23
|
+
Requires-Dist: scipy>=1.11
|
|
24
|
+
Provides-Extra: experiments
|
|
25
|
+
Requires-Dist: matplotlib>=3.7; extra == "experiments"
|
|
26
|
+
Requires-Dist: seaborn>=0.12; extra == "experiments"
|
|
27
|
+
Requires-Dist: joblib>=1.3; extra == "experiments"
|
|
28
|
+
Requires-Dist: lightgbm>=4.0; extra == "experiments"
|
|
29
|
+
Requires-Dist: ngboost>=0.4; extra == "experiments"
|
|
30
|
+
Requires-Dist: mapie>=0.6; extra == "experiments"
|
|
31
|
+
Provides-Extra: dev
|
|
32
|
+
Requires-Dist: pytest>=7; extra == "dev"
|
|
33
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
34
|
+
|
|
35
|
+
# Subset Mixture Model (SMM)
|
|
36
|
+
|
|
37
|
+
**SMM** is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.
|
|
38
|
+
|
|
39
|
+
## Key idea
|
|
40
|
+
|
|
41
|
+
Each feature subset $s$ induces a partition of the covariate space and a natural estimator of the conditional expectation — its empirical cell mean. SMM learns a convex combination of these estimators:
|
|
42
|
+
|
|
43
|
+
$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
|
|
44
|
+
|
|
45
|
+
The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
|
|
46
|
+
|
|
47
|
+
## Installation
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
pip install subset-mixture-model
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
Or from source:
|
|
54
|
+
|
|
55
|
+
```bash
|
|
56
|
+
git clone https://github.com/aaronjdanielson/subset-mixture-model
|
|
57
|
+
cd subset-mixture-model
|
|
58
|
+
pip install -e .
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Quick start
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import pandas as pd
|
|
65
|
+
import torch
|
|
66
|
+
import torch.nn.functional as F
|
|
67
|
+
from torch.utils.data import DataLoader
|
|
68
|
+
|
|
69
|
+
from smm import (
|
|
70
|
+
SubsetMaker, SubsetWeightsModel, SubsetDataset,
|
|
71
|
+
subset_mixture_neg_log_posterior, SubsetMixturePredictor,
|
|
72
|
+
compute_posterior_covariance, predict_with_uncertainty,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# --- your data (integer-coded categorical features) ---
|
|
76
|
+
train_df = pd.read_csv("train.csv")
|
|
77
|
+
val_df = pd.read_csv("val.csv")
|
|
78
|
+
test_df = pd.read_csv("test.csv")
|
|
79
|
+
|
|
80
|
+
cat_cols = ["feature_a", "feature_b", "feature_c"]
|
|
81
|
+
target = "y"
|
|
82
|
+
|
|
83
|
+
# --- build lookup table ---
|
|
84
|
+
subset_maker = SubsetMaker(train_df, cat_cols, [target])
|
|
85
|
+
n_subsets = len(subset_maker.lookup)
|
|
86
|
+
|
|
87
|
+
# --- train ---
|
|
88
|
+
model = SubsetWeightsModel(n_subsets)
|
|
89
|
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
|
|
90
|
+
loader = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
|
|
91
|
+
batch_size=64, shuffle=True)
|
|
92
|
+
|
|
93
|
+
for epoch in range(100):
|
|
94
|
+
for x, y in loader:
|
|
95
|
+
optimizer.zero_grad()
|
|
96
|
+
mus, variances, mask = subset_maker.batch_lookup(x)
|
|
97
|
+
loss = subset_mixture_neg_log_posterior(
|
|
98
|
+
model(), y, mus, variances, mask, alpha=1.1)
|
|
99
|
+
loss.backward()
|
|
100
|
+
optimizer.step()
|
|
101
|
+
|
|
102
|
+
# --- predict with uncertainty ---
|
|
103
|
+
pi_hat = F.softmax(model.eta.detach(), dim=0)
|
|
104
|
+
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
|
|
105
|
+
sigma_pi = compute_posterior_covariance(
|
|
106
|
+
model, subset_maker, train_df, cat_cols, target, alpha=1.1)
|
|
107
|
+
|
|
108
|
+
y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
|
|
109
|
+
# y_mean: point predictions
|
|
110
|
+
# y_std: total predictive standard deviation (aleatoric + epistemic)
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Interpretability
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
import numpy as np
|
|
117
|
+
|
|
118
|
+
subsets = list(subset_maker.lookup.keys())
|
|
119
|
+
top_idx = np.argsort(pi_hat.numpy())[::-1][:10]
|
|
120
|
+
|
|
121
|
+
for rank, i in enumerate(top_idx):
|
|
122
|
+
print(f"{rank+1:2d}. {subsets[i]} π={pi_hat[i]:.4f}")
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
## Features
|
|
126
|
+
|
|
127
|
+
- **Interpretable by construction**: learned weights reveal which feature interactions matter
|
|
128
|
+
- **Principled uncertainty**: aleatoric/epistemic decomposition via Laplace approximation
|
|
129
|
+
- **Efficient training**: only $2^D - 1$ logits optimized; lookup table precomputed once
|
|
130
|
+
- **No post-hoc calibration**: well-calibrated predictive intervals out of the box
|
|
131
|
+
- **Scalable to D ≤ 15** features; $k$-way truncation available for larger $D$
|
|
132
|
+
|
|
133
|
+
## Datasets supported
|
|
134
|
+
|
|
135
|
+
Any tabular dataset with a continuous target and categorical features that are integer-coded (string-valued features work too, after label encoding).
|
|
136
|
+
|
|
137
|
+
## Citation
|
|
138
|
+
|
|
139
|
+
```bibtex
|
|
140
|
+
@article{danielson2025smm,
|
|
141
|
+
title = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
|
|
142
|
+
of Partition Estimators for Categorical Regression},
|
|
143
|
+
author = {Danielson, Aaron John},
|
|
144
|
+
journal = {Machine Learning},
|
|
145
|
+
year = {2025},
|
|
146
|
+
note = {Under review}
|
|
147
|
+
}
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
## License
|
|
151
|
+
|
|
152
|
+
MIT
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
smm/__init__.py
|
|
4
|
+
smm/laplace.py
|
|
5
|
+
smm/model.py
|
|
6
|
+
smm/predictor.py
|
|
7
|
+
smm/subset_maker.py
|
|
8
|
+
subset_mixture_model.egg-info/PKG-INFO
|
|
9
|
+
subset_mixture_model.egg-info/SOURCES.txt
|
|
10
|
+
subset_mixture_model.egg-info/dependency_links.txt
|
|
11
|
+
subset_mixture_model.egg-info/requires.txt
|
|
12
|
+
subset_mixture_model.egg-info/top_level.txt
|
|
13
|
+
tests/test_smm.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
smm
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core tests for the Subset Mixture Model package.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pytest
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from torch.utils.data import DataLoader
|
|
11
|
+
|
|
12
|
+
from smm import (
|
|
13
|
+
SubsetMaker,
|
|
14
|
+
SubsetWeightsModel,
|
|
15
|
+
SubsetDataset,
|
|
16
|
+
subset_mixture_neg_log_posterior,
|
|
17
|
+
SubsetMixturePredictor,
|
|
18
|
+
compute_posterior_covariance,
|
|
19
|
+
predict_with_uncertainty,
|
|
20
|
+
coverage,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
# Fixtures
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
|
|
28
|
+
@pytest.fixture
def synthetic_df():
    """3-feature synthetic dataset with a known 3-way interaction."""
    # One seeded generator; the exact order of draws (a, b, c, the cell
    # means, then per-row noise) fixes the dataset, so do not reorder calls.
    rng = np.random.default_rng(0)
    n = 300
    a = rng.integers(0, 4, n)
    b = rng.integers(0, 3, n)
    c = rng.integers(0, 2, n)
    # target = interaction mean + noise
    # Every (a, b, c) cell gets its own mean, so only the full 3-way
    # partition explains the target — the signal the recovery test checks.
    cell_means = {(i, j, k): rng.normal(0, 2)
                  for i in range(4) for j in range(3) for k in range(2)}
    y = np.array([cell_means[(a[i], b[i], c[i])] + rng.normal(0, 0.1)
                  for i in range(n)])
    return pd.DataFrame({"a": a, "b": b, "c": c, "y": y})
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture
def splits(synthetic_df):
    """Deterministic train/val/test split of the synthetic dataset."""
    # Imported lazily so collecting other tests doesn't require sklearn.
    from sklearn.model_selection import train_test_split

    remainder, test_part = train_test_split(
        synthetic_df, test_size=0.2, random_state=0)
    train_part, val_part = train_test_split(
        remainder, test_size=0.15, random_state=0)
    frames = (train_part, val_part, test_part)
    return tuple(frame.reset_index(drop=True) for frame in frames)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Shared test configuration: categorical feature columns and the target
# column name, matching the synthetic_df fixture above.
CAT_COLS = ["a", "b", "c"]
TARGET = "y"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
# SubsetMaker
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
|
|
60
|
+
def test_subset_maker_powerset_size(splits):
    """The lookup holds one entry per non-empty feature subset."""
    train_df = splits[0]
    maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
    expected = 2 ** len(CAT_COLS) - 1  # 3 features → 7 subsets
    assert len(maker.lookup) == expected
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def test_subset_maker_batch_lookup_shapes(splits):
    """batch_lookup returns aligned [B, |S|] tensors plus a bool mask."""
    train_df = splits[0]
    maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
    batch = torch.tensor(train_df[CAT_COLS].astype(np.float32).values[:16])
    means, variances, mask = maker.batch_lookup(batch)
    for tensor in (means, variances, mask):
        assert tensor.shape == (16, 7)
    assert mask.dtype == torch.bool
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_subset_maker_fallback(splits):
    """Global fallback statistics must come out as finite numbers."""
    train_df = splits[0]
    maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
    assert np.isfinite(maker.fallback_mean)
    assert np.isfinite(maker.fallback_var)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# ---------------------------------------------------------------------------
|
|
86
|
+
# SubsetWeightsModel
|
|
87
|
+
# ---------------------------------------------------------------------------
|
|
88
|
+
|
|
89
|
+
def test_model_forward_returns_logits():
    """Calling the model yields a NaN-free logit vector of length |S|."""
    logits = SubsetWeightsModel(7)()
    assert logits.shape == (7,)
    assert not bool(torch.isnan(logits).any())
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_softmax_sums_to_one():
    """Softmaxed logits form a valid probability vector."""
    logits = SubsetWeightsModel(7)()
    pi = F.softmax(logits, dim=0)
    total = pi.sum()
    assert torch.allclose(total, torch.tensor(1.0), atol=1e-6)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# ---------------------------------------------------------------------------
|
|
103
|
+
# Loss function
|
|
104
|
+
# ---------------------------------------------------------------------------
|
|
105
|
+
|
|
106
|
+
def test_loss_finite(splits):
    """One batch through the negative log posterior yields a finite scalar."""
    train_df = splits[0]
    maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
    model = SubsetWeightsModel(len(maker.lookup))
    batches = DataLoader(SubsetDataset(train_df, CAT_COLS, [TARGET]),
                         batch_size=32, shuffle=False)
    x, y = next(iter(batches))
    means, variances, mask = maker.batch_lookup(x)
    loss = subset_mixture_neg_log_posterior(model(), y, means, variances, mask)
    assert torch.isfinite(loss)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_loss_decreases_after_training(splits):
    """A few epochs of Adam should reduce the training-set loss."""
    # NOTE(review): neither torch nor the DataLoader shuffle is seeded here,
    # so exact losses vary run to run; the test only asserts improvement.
    tr, va, _ = splits
    sm = SubsetMaker(tr, CAT_COLS, [TARGET])
    model = SubsetWeightsModel(len(sm.lookup))
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
                        batch_size=64, shuffle=True)

    def epoch_loss():
        # Full-pass loss in eval mode, without touching gradients.
        total = 0.0
        model.eval()
        with torch.no_grad():
            for x, y in loader:
                mus, variances, mask = sm.batch_lookup(x)
                total += subset_mixture_neg_log_posterior(
                    model(), y, mus, variances, mask).item()
        return total

    loss_before = epoch_loss()
    model.train()
    for _ in range(5):
        for x, y in loader:
            opt.zero_grad()
            mus, variances, mask = sm.batch_lookup(x)
            subset_mixture_neg_log_posterior(
                model(), y, mus, variances, mask).backward()
            opt.step()
    loss_after = epoch_loss()
    assert loss_after < loss_before
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# ---------------------------------------------------------------------------
|
|
150
|
+
# Predictor
|
|
151
|
+
# ---------------------------------------------------------------------------
|
|
152
|
+
|
|
153
|
+
def test_predictor_output_shape(splits):
    """Predictions come back as a finite 1-D array with one entry per row."""
    train_df, _, test_df = splits
    maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
    model = SubsetWeightsModel(len(maker.lookup))
    weights = F.softmax(model.eta.detach(), dim=0)
    preds = SubsetMixturePredictor(maker, weights).predict(test_df)
    assert preds.shape == (len(test_df),)
    assert np.isfinite(preds).all()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def test_predictor_fallback_for_unseen_cells():
    """Test point with values not seen in training should fall back gracefully."""
    # (removed an unused seeded RNG that was never drawn from)
    tr = pd.DataFrame({"a": [0, 1, 0, 1], "b": [0, 0, 1, 1], "y": [1.0, 2.0, 3.0, 4.0]})
    sm = SubsetMaker(tr, ["a", "b"], ["y"])
    model = SubsetWeightsModel(len(sm.lookup))
    pi = F.softmax(model.eta.detach(), dim=0)
    predictor = SubsetMixturePredictor(sm, pi)
    # value 99 was never seen — should return fallback_mean
    unseen = pd.DataFrame({"a": [99], "b": [99], "y": [0.0]})
    preds = predictor.predict(unseen)
    assert np.isfinite(preds).all()
    # Stronger than the original finiteness check: a fully-unseen row must
    # be predicted as exactly the global training mean.
    assert np.allclose(preds, sm.fallback_mean)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# ---------------------------------------------------------------------------
|
|
179
|
+
# Laplace / uncertainty
|
|
180
|
+
# ---------------------------------------------------------------------------
|
|
181
|
+
|
|
182
|
+
def test_predict_with_uncertainty_shapes(splits):
    """Laplace pipeline produces per-example means and non-negative stds."""
    tr, va, te = splits
    sm = SubsetMaker(tr, CAT_COLS, [TARGET])
    model = SubsetWeightsModel(len(sm.lookup))
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
                        batch_size=64, shuffle=True)
    # Short training run — enough to give the posterior covariance a
    # meaningful MAP point; exact values don't matter for shape checks.
    for _ in range(3):
        for x, y in loader:
            opt.zero_grad()
            mus, variances, mask = sm.batch_lookup(x)
            subset_mixture_neg_log_posterior(
                model(), y, mus, variances, mask).backward()
            opt.step()
    pi = F.softmax(model.eta.detach(), dim=0)
    predictor = SubsetMixturePredictor(sm, pi)
    sigma_pi = compute_posterior_covariance(
        model, sm, tr, CAT_COLS, TARGET, alpha=1.1)
    y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, te)
    assert y_mean.shape == (len(te),)
    assert y_std.shape == (len(te),)
    # Std combines aleatoric + epistemic terms, so it must be >= 0 everywhere.
    assert (y_std >= 0).all()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def test_coverage_function():
    """coverage() returns a fraction inside [0, 1]."""
    truth = np.arange(4, dtype=float)
    predicted_mean = truth.copy()
    predicted_std = np.ones(4)
    cov = coverage(truth, predicted_mean, predicted_std, level=0.95)
    assert 0.0 <= cov <= 1.0
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# ---------------------------------------------------------------------------
|
|
215
|
+
# Synthetic recovery (integration test)
|
|
216
|
+
# ---------------------------------------------------------------------------
|
|
217
|
+
|
|
218
|
+
def test_synthetic_weight_recovery(splits):
    """
    With the DGP driven purely by the 3-way interaction,
    the full-powerset subset should receive the highest weight.
    """
    tr, va, te = splits
    sm = SubsetMaker(tr, CAT_COLS, [TARGET])
    model = SubsetWeightsModel(len(sm.lookup))
    opt = torch.optim.Adam(model.parameters(), lr=5e-3)
    loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
                        batch_size=64, shuffle=True)
    best_state, best_val = None, float("inf")
    val_loader = DataLoader(SubsetDataset(va, CAT_COLS, [TARGET]),
                            batch_size=64, shuffle=False)
    # NOTE(review): seeding happens after model/loader construction, so the
    # initial weights are not reproducible — confirm whether that is intended.
    torch.manual_seed(0)
    # Early-stopping loop: keep the state dict with the best validation loss.
    for epoch in range(60):
        model.train()
        for x, y in loader:
            opt.zero_grad()
            mus, variances, mask = sm.batch_lookup(x)
            subset_mixture_neg_log_posterior(
                model(), y, mus, variances, mask, alpha=1.1).backward()
            opt.step()
        model.eval()
        vl = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                mus, variances, mask = sm.batch_lookup(x)
                vl += subset_mixture_neg_log_posterior(
                    model(), y, mus, variances, mask, alpha=1.1).item()
        vl /= len(val_loader)
        if vl < best_val:
            best_val = vl
            # Clone tensors so later training steps can't mutate the snapshot.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
    model.load_state_dict(best_state)
    pi = F.softmax(model.eta.detach(), dim=0).numpy()
    subsets = list(sm.lookup.keys())
    # Full 3-way subset should be highest weight
    full_subset = tuple(sorted(CAT_COLS))
    idx = next(i for i, s in enumerate(subsets) if tuple(sorted(s)) == full_subset)
    assert pi[idx] == pi.max(), (
        f"3-way subset not top-weighted: pi={pi[idx]:.3f} vs max={pi.max():.3f}")
|