subset-mixture-model 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,152 @@
1
+ Metadata-Version: 2.4
2
+ Name: subset-mixture-model
3
+ Version: 0.1.0
4
+ Summary: Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
5
+ Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
6
+ License: MIT
7
+ Project-URL: Repository, https://github.com/aaronjdanielson/subset-mixture-model
8
+ Keywords: interpretable machine learning,categorical features,empirical Bayes,uncertainty quantification,mixture model
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.9
18
+ Description-Content-Type: text/markdown
19
+ Requires-Dist: torch>=2.0
20
+ Requires-Dist: numpy>=1.24
21
+ Requires-Dist: pandas>=2.0
22
+ Requires-Dist: scikit-learn>=1.3
23
+ Requires-Dist: scipy>=1.11
24
+ Provides-Extra: experiments
25
+ Requires-Dist: matplotlib>=3.7; extra == "experiments"
26
+ Requires-Dist: seaborn>=0.12; extra == "experiments"
27
+ Requires-Dist: joblib>=1.3; extra == "experiments"
28
+ Requires-Dist: lightgbm>=4.0; extra == "experiments"
29
+ Requires-Dist: ngboost>=0.4; extra == "experiments"
30
+ Requires-Dist: mapie>=0.6; extra == "experiments"
31
+ Provides-Extra: dev
32
+ Requires-Dist: pytest>=7; extra == "dev"
33
+ Requires-Dist: pytest-cov; extra == "dev"
34
+
35
+ # Subset Mixture Model (SMM)
36
+
37
+ **SMM** is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.
38
+
39
+ ## Key idea
40
+
41
+ Each feature subset $s$ induces a partition of the covariate space and a natural estimator of the conditional expectation — its empirical cell mean. SMM learns a convex combination of these estimators:
42
+
43
+ $$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
44
+
45
+ The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
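As a toy illustration of this decomposition (all numbers below are made up), the total predictive variance splits into a weight-averaged cell variance plus a quadratic form in the posterior covariance of the weights:

```python
import numpy as np

# Hypothetical 3-subset example: pi are learned weights, mu_x / var_x the
# empirical cell means and variances at a query point x, Sigma_pi a toy
# Laplace posterior covariance of the weights (all values illustrative).
pi = np.array([0.6, 0.3, 0.1])
mu_x = np.array([4.2, 3.8, 5.0])
var_x = np.array([0.5, 0.9, 0.4])
Sigma_pi = np.diag([1e-3, 2e-3, 1e-3])

aleatoric = float(pi @ var_x)              # E_pi[sigma_s^2(x)]
epistemic = float(mu_x @ Sigma_pi @ mu_x)  # mu_x^T Sigma_pi mu_x
total_std = float(np.sqrt(aleatoric + epistemic))
```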
46
+
47
+ ## Installation
48
+
49
+ ```bash
50
+ pip install subset-mixture-model
51
+ ```
52
+
53
+ Or from source:
54
+
55
+ ```bash
56
+ git clone https://github.com/aaronjdanielson/subset-mixture-model
57
+ cd subset-mixture-model
58
+ pip install -e .
59
+ ```
60
+
61
+ ## Quick start
62
+
63
+ ```python
64
+ import pandas as pd
65
+ import torch
66
+ import torch.nn.functional as F
67
+ from torch.utils.data import DataLoader
68
+
69
+ from smm import (
70
+ SubsetMaker, SubsetWeightsModel, SubsetDataset,
71
+ subset_mixture_neg_log_posterior, SubsetMixturePredictor,
72
+ compute_posterior_covariance, predict_with_uncertainty,
73
+ )
74
+
75
+ # --- your data (integer-coded categorical features) ---
76
+ train_df = pd.read_csv("train.csv")
77
+ val_df = pd.read_csv("val.csv")
78
+ test_df = pd.read_csv("test.csv")
79
+
80
+ cat_cols = ["feature_a", "feature_b", "feature_c"]
81
+ target = "y"
82
+
83
+ # --- build lookup table ---
84
+ subset_maker = SubsetMaker(train_df, cat_cols, [target])
85
+ n_subsets = len(subset_maker.lookup)
86
+
87
+ # --- train ---
88
+ model = SubsetWeightsModel(n_subsets)
89
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
90
+ loader = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
91
+ batch_size=64, shuffle=True)
92
+
93
+ for epoch in range(100):
94
+ for x, y in loader:
95
+ optimizer.zero_grad()
96
+ mus, variances, mask = subset_maker.batch_lookup(x)
97
+ loss = subset_mixture_neg_log_posterior(
98
+ model(), y, mus, variances, mask, alpha=1.1)
99
+ loss.backward()
100
+ optimizer.step()
101
+
102
+ # --- predict with uncertainty ---
103
+ pi_hat = F.softmax(model.eta.detach(), dim=0)
104
+ predictor = SubsetMixturePredictor(subset_maker, pi_hat)
105
+ sigma_pi = compute_posterior_covariance(
106
+ model, subset_maker, train_df, cat_cols, target, alpha=1.1)
107
+
108
+ y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
109
+ # y_mean: point predictions
110
+ # y_std: total predictive standard deviation (aleatoric + epistemic)
111
+ ```
112
+
113
+ ## Interpretability
114
+
115
+ ```python
116
+ import numpy as np
117
+
118
+ subsets = list(subset_maker.lookup.keys())
119
+ top_idx = np.argsort(pi_hat.numpy())[::-1][:10]
120
+
121
+ for rank, i in enumerate(top_idx):
122
+ print(f"{rank+1:2d}. {subsets[i]} π={pi_hat[i]:.4f}")
123
+ ```
124
+
125
+ ## Features
126
+
127
+ - **Interpretable by construction**: learned weights reveal which feature interactions matter
128
+ - **Principled uncertainty**: aleatoric/epistemic decomposition via Laplace approximation
129
+ - **Efficient training**: only $2^D - 1$ logits optimized; lookup table precomputed once
130
+ - **No post-hoc calibration**: well-calibrated predictive intervals out of the box
131
+ - **Scalable**: handles $D \le 15$ features directly; $k$-way truncation available for larger $D$
132
+
133
+ ## Datasets supported
134
+
135
+ Any tabular dataset with integer-coded (or string, with encoding) categorical features and a continuous target.
136
+
137
+ ## Citation
138
+
139
+ ```bibtex
140
+ @article{danielson2025smm,
141
+ title = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
142
+ of Partition Estimators for Categorical Regression},
143
+ author = {Danielson, Aaron John},
144
+ journal = {Machine Learning},
145
+ year = {2025},
146
+ note = {Under review}
147
+ }
148
+ ```
149
+
150
+ ## License
151
+
152
+ MIT
@@ -0,0 +1,118 @@
1
+ # Subset Mixture Model (SMM)
2
+
3
+ **SMM** is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.
4
+
5
+ ## Key idea
6
+
7
+ Each feature subset $s$ induces a partition of the covariate space and a natural estimator of the conditional expectation — its empirical cell mean. SMM learns a convex combination of these estimators:
8
+
9
+ $$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
10
+
11
+ The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
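Concretely, with illustrative numbers (three subsets, made-up weights and cell means), the prediction is just a weighted average of cell means:

```python
import numpy as np

# Illustrative values only: pi_hat are learned simplex weights over three
# feature subsets; mu holds each subset's empirical cell mean at a fixed x.
pi_hat = np.array([0.6, 0.3, 0.1])
mu = np.array([4.2, 3.8, 5.0])
assert abs(pi_hat.sum() - 1.0) < 1e-12   # weights live on the simplex

f_hat = float(pi_hat @ mu)               # sum_s pi_s * mu_{m(s,x)}(s)
# f_hat = 0.6*4.2 + 0.3*3.8 + 0.1*5.0 = 4.16
```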
12
+
13
+ ## Installation
14
+
15
+ ```bash
16
+ pip install subset-mixture-model
17
+ ```
18
+
19
+ Or from source:
20
+
21
+ ```bash
22
+ git clone https://github.com/aaronjdanielson/subset-mixture-model
23
+ cd subset-mixture-model
24
+ pip install -e .
25
+ ```
26
+
27
+ ## Quick start
28
+
29
+ ```python
30
+ import pandas as pd
31
+ import torch
32
+ import torch.nn.functional as F
33
+ from torch.utils.data import DataLoader
34
+
35
+ from smm import (
36
+ SubsetMaker, SubsetWeightsModel, SubsetDataset,
37
+ subset_mixture_neg_log_posterior, SubsetMixturePredictor,
38
+ compute_posterior_covariance, predict_with_uncertainty,
39
+ )
40
+
41
+ # --- your data (integer-coded categorical features) ---
42
+ train_df = pd.read_csv("train.csv")
43
+ val_df = pd.read_csv("val.csv")
44
+ test_df = pd.read_csv("test.csv")
45
+
46
+ cat_cols = ["feature_a", "feature_b", "feature_c"]
47
+ target = "y"
48
+
49
+ # --- build lookup table ---
50
+ subset_maker = SubsetMaker(train_df, cat_cols, [target])
51
+ n_subsets = len(subset_maker.lookup)
52
+
53
+ # --- train ---
54
+ model = SubsetWeightsModel(n_subsets)
55
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
56
+ loader = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
57
+ batch_size=64, shuffle=True)
58
+
59
+ for epoch in range(100):
60
+ for x, y in loader:
61
+ optimizer.zero_grad()
62
+ mus, variances, mask = subset_maker.batch_lookup(x)
63
+ loss = subset_mixture_neg_log_posterior(
64
+ model(), y, mus, variances, mask, alpha=1.1)
65
+ loss.backward()
66
+ optimizer.step()
67
+
68
+ # --- predict with uncertainty ---
69
+ pi_hat = F.softmax(model.eta.detach(), dim=0)
70
+ predictor = SubsetMixturePredictor(subset_maker, pi_hat)
71
+ sigma_pi = compute_posterior_covariance(
72
+ model, subset_maker, train_df, cat_cols, target, alpha=1.1)
73
+
74
+ y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
75
+ # y_mean: point predictions
76
+ # y_std: total predictive standard deviation (aleatoric + epistemic)
77
+ ```
78
+
79
+ ## Interpretability
80
+
81
+ ```python
82
+ import numpy as np
83
+
84
+ subsets = list(subset_maker.lookup.keys())
85
+ top_idx = np.argsort(pi_hat.numpy())[::-1][:10]
86
+
87
+ for rank, i in enumerate(top_idx):
88
+ print(f"{rank+1:2d}. {subsets[i]} π={pi_hat[i]:.4f}")
89
+ ```
90
+
91
+ ## Features
92
+
93
+ - **Interpretable by construction**: learned weights reveal which feature interactions matter
94
+ - **Principled uncertainty**: aleatoric/epistemic decomposition via Laplace approximation
95
+ - **Efficient training**: only $2^D - 1$ logits optimized; lookup table precomputed once
96
+ - **No post-hoc calibration**: well-calibrated predictive intervals out of the box
97
+ - **Scalable**: handles $D \le 15$ features directly; $k$-way truncation available for larger $D$
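The $2^D - 1$ count comes from enumerating all non-empty subsets of the feature columns, mirroring what `SubsetMaker` does internally; a minimal sketch:

```python
from itertools import combinations

features = ["feature_a", "feature_b", "feature_c"]  # D = 3
subsets = [
    list(c)
    for r in range(1, len(features) + 1)
    for c in combinations(features, r)
]
# One logit per non-empty subset: 2^3 - 1 = 7
assert len(subsets) == 2 ** len(features) - 1
```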
98
+
99
+ ## Datasets supported
100
+
101
+ Any tabular dataset with integer-coded (or string, with encoding) categorical features and a continuous target.
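For string-valued categoricals, integer coding can be done with pandas before fitting; a minimal sketch (column names hypothetical):

```python
import pandas as pd

df = pd.DataFrame({
    "feature_a": ["red", "blue", "red", "green"],
    "y": [1.0, 2.0, 1.5, 3.0],
})

# pd.factorize maps each distinct string to a stable integer code.
codes, uniques = pd.factorize(df["feature_a"])
df["feature_a"] = codes
```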
102
+
103
+ ## Citation
104
+
105
+ ```bibtex
106
+ @article{danielson2025smm,
107
+ title = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
108
+ of Partition Estimators for Categorical Regression},
109
+ author = {Danielson, Aaron John},
110
+ journal = {Machine Learning},
111
+ year = {2025},
112
+ note = {Under review}
113
+ }
114
+ ```
115
+
116
+ ## License
117
+
118
+ MIT
@@ -0,0 +1,59 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "subset-mixture-model"
7
+ version = "0.1.0"
8
+ description = "Interpretable empirical-Bayes aggregation of partition estimators for categorical regression"
9
+ readme = "README.md"
10
+ license = { text = "MIT" }
11
+ authors = [
12
+ { name = "Aaron John Danielson", email = "aaron.danielson@austin.utexas.edu" },
13
+ ]
14
+ keywords = [
15
+ "interpretable machine learning",
16
+ "categorical features",
17
+ "empirical Bayes",
18
+ "uncertainty quantification",
19
+ "mixture model",
20
+ ]
21
+ classifiers = [
22
+ "Development Status :: 4 - Beta",
23
+ "Intended Audience :: Science/Research",
24
+ "License :: OSI Approved :: MIT License",
25
+ "Programming Language :: Python :: 3",
26
+ "Programming Language :: Python :: 3.9",
27
+ "Programming Language :: Python :: 3.10",
28
+ "Programming Language :: Python :: 3.11",
29
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
30
+ ]
31
+ requires-python = ">=3.9"
32
+ dependencies = [
33
+ "torch>=2.0",
34
+ "numpy>=1.24",
35
+ "pandas>=2.0",
36
+ "scikit-learn>=1.3",
37
+ "scipy>=1.11",
38
+ ]
39
+
40
+ [project.optional-dependencies]
41
+ experiments = [
42
+ "matplotlib>=3.7",
43
+ "seaborn>=0.12",
44
+ "joblib>=1.3",
45
+ "lightgbm>=4.0",
46
+ "ngboost>=0.4",
47
+ "mapie>=0.6",
48
+ ]
49
+ dev = [
50
+ "pytest>=7",
51
+ "pytest-cov",
52
+ ]
53
+
54
+ [project.urls]
55
+ Repository = "https://github.com/aaronjdanielson/subset-mixture-model"
56
+
57
+ [tool.setuptools.packages.find]
58
+ where = ["."]
59
+ include = ["smm*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,20 @@
1
+ from .subset_maker import SubsetMaker
2
+ from .model import SubsetWeightsModel, SubsetDataset, subset_mixture_neg_log_posterior, subset_mixture_mse
3
+ from .predictor import SubsetMixturePredictor
4
+ from .laplace import (
5
+ compute_posterior_covariance,
6
+ predict_with_uncertainty,
7
+ coverage,
8
+ )
9
+
10
+ __all__ = [
11
+ "SubsetMaker",
12
+ "SubsetWeightsModel",
13
+ "SubsetDataset",
14
+ "subset_mixture_neg_log_posterior",
15
+ "subset_mixture_mse",
16
+ "SubsetMixturePredictor",
17
+ "compute_posterior_covariance",
18
+ "predict_with_uncertainty",
19
+ "coverage",
20
+ ]
@@ -0,0 +1,266 @@
1
+ """
2
+ Laplace approximation for posterior uncertainty in the Subset Mixture Model.
3
+
4
+ After MAP estimation of eta (the logit weight vector), the Laplace approximation
5
+ treats the negative log-posterior as a quadratic around the MAP:
6
+
7
+ q(eta) = N(eta | eta_hat, H^{-1})
8
+
9
+ where H is the Hessian of the negative log-posterior at eta_hat.
10
+
11
+ Transforming to the simplex via pi = softmax(eta) and applying the delta method:
12
+
13
+ Cov[pi | D] ≈ Sigma_pi = J H^{-1} J^T
14
+
15
+ where J = d(pi)/d(eta) is the |S| x |S| softmax Jacobian.
16
+
17
+ Predictive variance for a new point x_tilde then decomposes as:
18
+
19
+ Var[y | x_tilde] ≈ sum_s pi_s * sigma^2_s(x_tilde) [aleatoric]
20
+ + mu_{x_tilde}^T Sigma_pi mu_{x_tilde} [epistemic]
21
+
22
+ References:
23
+ MacKay, D.J.C. (1992). A Practical Bayesian Framework for Backpropagation Networks.
24
+ Daxberger et al. (2021). Laplace Redux -- Effortless Bayesian Deep Learning.
25
+ """
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from torch.utils.data import DataLoader
31
+
32
+ from .model import SubsetDataset, subset_mixture_neg_log_posterior
33
+ from .subset_maker import SubsetMaker
34
+ from .predictor import SubsetMixturePredictor
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Step 1: Compute posterior covariance of pi
39
+ # ---------------------------------------------------------------------------
40
+
41
+ def softmax_jacobian(pi: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Jacobian of softmax(eta) w.r.t. eta, evaluated at pi = softmax(eta).
44
+
45
+ J_ij = pi_i * (delta_ij - pi_j)
46
+
47
+ Args:
48
+ pi: [S] mixture weights (must sum to 1).
49
+
50
+ Returns:
51
+ J: [S, S] Jacobian matrix.
52
+ """
53
+ return torch.diag(pi) - torch.outer(pi, pi)
54
+
55
+
56
+ def compute_hessian(
57
+ model: torch.nn.Module,
58
+ subset_maker: SubsetMaker,
59
+ train_df,
60
+ cat_cols: list,
61
+ target: str,
62
+ alpha: float = 1.1,
63
+ batch_size: int = 512,
64
+ ) -> torch.Tensor:
65
+ """
66
+ Compute the exact Hessian of the negative log-posterior w.r.t. eta.
67
+
68
+ Because |S| is small relative to the training set, the Hessian can be
69
+ computed exactly via torch.autograd.functional.hessian on the full
70
+ training set in a single batch.
71
+
72
+ The full-batch pass is feasible when n < ~20K. For larger datasets an
73
+ approximate curvature estimate (e.g. an empirical Fisher accumulated
74
+ batch by batch) would be needed; this function implements only the
75
+ exact full-batch Hessian.
76
+
77
+ Args:
78
+ model: Trained SubsetWeightsModel.
79
+ subset_maker: Fitted SubsetMaker.
80
+ train_df: Training DataFrame (integer-encoded).
81
+ cat_cols: Categorical feature column names.
82
+ target: Target column name.
83
+ alpha: Dirichlet concentration (must match training value).
84
+ batch_size: Batch size for loading data (only affects memory).
85
+
86
+ Returns:
87
+ H: [S, S] Hessian tensor.
88
+ """
89
+ model.eval()
90
+ eta_hat = model.eta.detach().clone()
91
+
92
+ # Load all training data (full batch for exact Hessian)
93
+ ds = SubsetDataset(train_df, cat_cols, [target])
94
+ loader = DataLoader(ds, batch_size=len(ds), shuffle=False)
95
+ x_all, y_all = next(iter(loader))
96
+
97
+ mus, variances, mask = subset_maker.batch_lookup(x_all)
98
+ # Detach and keep on CPU
99
+ mus = mus.detach()
100
+ variances = variances.detach()
101
+ mask = mask.detach()
102
+
103
+ def loss_fn(eta):
104
+ return subset_mixture_neg_log_posterior(
105
+ eta, y_all, mus, variances, mask, alpha=alpha
106
+ )
107
+
108
+ eta_param = eta_hat.requires_grad_(True)
109
+ H = torch.autograd.functional.hessian(loss_fn, eta_param) # [S, S]
110
+ return H.detach()
111
+
112
+
113
+ def compute_posterior_covariance(
114
+ model: torch.nn.Module,
115
+ subset_maker: SubsetMaker,
116
+ train_df,
117
+ cat_cols: list,
118
+ target: str,
119
+ alpha: float = 1.1,
120
+ hessian_reg: float = 1e-4,
121
+ ) -> torch.Tensor:
122
+ """
123
+ Compute the Laplace approximation to the posterior covariance of pi.
124
+
125
+ Sigma_pi = J H^{-1} J^T
126
+
127
+ Args:
128
+ model: Trained SubsetWeightsModel (provides eta_hat).
129
+ subset_maker: Fitted SubsetMaker.
130
+ train_df: Training DataFrame (integer-encoded).
131
+ cat_cols: Categorical feature column names.
132
+ target: Target column name.
133
+ alpha: Dirichlet concentration (must match training value).
134
+ hessian_reg: Ridge added to H before inversion for numerical stability.
135
+
136
+ Returns:
137
+ sigma_pi: [S, S] posterior covariance of mixture weights.
138
+ """
139
+ eta_hat = model.eta.detach()
140
+ pi_hat = F.softmax(eta_hat, dim=0)
141
+
142
+ S = len(eta_hat)
143
+ H = compute_hessian(model, subset_maker, train_df, cat_cols, target, alpha)
144
+
145
+ # Regularise to ensure positive definiteness
146
+ H_reg = H + hessian_reg * torch.eye(S)
147
+ H_inv = torch.linalg.inv(H_reg)
148
+
149
+ # Softmax Jacobian at MAP estimate
150
+ J = softmax_jacobian(pi_hat) # [S, S]
151
+ sigma_pi = J @ H_inv @ J.T # [S, S]
152
+
153
+ return sigma_pi
154
+
155
+
156
+ # ---------------------------------------------------------------------------
157
+ # Step 2: Predictive distribution for new points
158
+ # ---------------------------------------------------------------------------
159
+
160
+ def predict_with_uncertainty(
161
+ predictor: SubsetMixturePredictor,
162
+ sigma_pi: torch.Tensor,
163
+ df,
164
+ return_components: bool = False,
165
+ ):
166
+ """
167
+ Compute predictive mean and uncertainty for a batch of test points.
168
+
169
+ Predictive mean:
170
+ y_hat(x) = pi_hat^T mu_x
171
+
172
+ Predictive variance:
173
+ Var[y|x] = sum_s pi_s * sigma_s^2(x) [aleatoric]
174
+ + mu_x^T Sigma_pi mu_x [epistemic]
175
+
176
+ For test points with no valid subset cell (all masked), the predictor
177
+ falls back to the global training mean with aleatoric variance equal to
178
+ the global training variance.
179
+
180
+ Args:
181
+ predictor: SubsetMixturePredictor with learned weight vector.
182
+ sigma_pi: [S, S] posterior covariance from compute_posterior_covariance.
183
+ df: Test DataFrame (integer-encoded).
184
+ return_components: If True, also return aleatoric and epistemic stds.
185
+
186
+ Returns:
187
+ y_mean (np.ndarray): [B] predicted means.
188
+ y_std (np.ndarray): [B] total predictive standard deviations.
189
+ (optional) aleatoric_std (np.ndarray): [B]
190
+ (optional) epistemic_std (np.ndarray): [B]
191
+ """
192
+ batch_tensor = torch.tensor(
193
+ df[predictor.subset_maker.subset_features].astype(float).values,
194
+ dtype=torch.float32,
195
+ )
196
+ mus, variances, mask = predictor.subset_maker.batch_lookup(batch_tensor)
197
+
198
+ B, S = mus.shape
199
+ pi = predictor.weight_vector # [S]
200
+ weights = pi.unsqueeze(0).expand(B, S).clone() # [B, S]
201
+
202
+ # Mask invalid cells
203
+ weights = weights.masked_fill(~mask, 0.0)
204
+ mus_m = mus.masked_fill(~mask, 0.0)
205
+ vars_m = variances.masked_fill(~mask, 0.0)
206
+
207
+ weight_sums = weights.sum(dim=1, keepdim=True) # [B, 1]
208
+ norm_weights = weights / (weight_sums + 1e-9) # [B, S]
209
+
210
+ # --- Predictive mean ---
211
+ fallback_mask = weight_sums.squeeze() < 1e-6 # [B]
212
+ y_mean_weighted = (mus_m * norm_weights).sum(dim=1) # [B]
213
+ y_mean = torch.where(
214
+ fallback_mask,
215
+ torch.full_like(y_mean_weighted, predictor.subset_maker.fallback_mean),
216
+ y_mean_weighted,
217
+ )
218
+
219
+ # --- Aleatoric variance: E_pi[sigma^2(x)] = sum_s pi_s * sigma_s^2(x) ---
220
+ aleatoric_var = (norm_weights * vars_m).sum(dim=1) # [B]
221
+ # Fallback: use global training variance
222
+ aleatoric_var = torch.where(
223
+ fallback_mask,
224
+ torch.full_like(aleatoric_var, predictor.subset_maker.fallback_var),
225
+ aleatoric_var,
226
+ )
227
+
228
+ # --- Epistemic variance: mu_x^T Sigma_pi mu_x ---
229
+ # mu_x is mus_m[i] — the vector of subset-conditional means for example i
230
+ # (zeroed for invalid cells, consistent with the masked prediction)
231
+ epistemic_var = torch.einsum("bi,ij,bj->b", mus_m, sigma_pi, mus_m) # [B]
232
+ epistemic_var = epistemic_var.clamp(min=0.0)
233
+
234
+ total_var = aleatoric_var + epistemic_var
235
+ total_std = torch.sqrt(total_var.clamp(min=1e-12))
236
+
237
+ y_mean_np = y_mean.detach().numpy()
238
+ total_std_np = total_std.detach().numpy()
239
+
240
+ if return_components:
241
+ aleatoric_std_np = torch.sqrt(aleatoric_var.clamp(min=1e-12)).detach().numpy()
242
+ epistemic_std_np = torch.sqrt(epistemic_var.clamp(min=1e-12)).detach().numpy()
243
+ return y_mean_np, total_std_np, aleatoric_std_np, epistemic_std_np
244
+
245
+ return y_mean_np, total_std_np
246
+
247
+
248
+ def coverage(y_true: np.ndarray, y_mean: np.ndarray, y_std: np.ndarray,
249
+ level: float = 0.95) -> float:
250
+ """
251
+ Empirical coverage at a given credible level.
252
+
253
+ Args:
254
+ y_true: [B] true targets.
255
+ y_mean: [B] predictive means.
256
+ y_std: [B] predictive standard deviations.
257
+ level: Nominal coverage (default 0.95).
258
+
259
+ Returns:
260
+ Scalar empirical coverage in [0, 1].
261
+ """
262
+ from scipy.stats import norm as scipy_norm
263
+ z = scipy_norm.ppf((1 + level) / 2)
264
+ lo = y_mean - z * y_std
265
+ hi = y_mean + z * y_std
266
+ return float(((y_true >= lo) & (y_true <= hi)).mean())
@@ -0,0 +1,125 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SubsetWeightsModel(torch.nn.Module):
7
+ """
8
+ Learns a single global weight vector over all feature subsets.
9
+
10
+ A real-valued parameter vector eta of length |S| is initialized to zero and
11
+ passed through softmax to produce the mixture weights pi.
12
+
13
+ Args:
14
+ num_subsets (int): Number of subsets in the powerset (|S|).
15
+ """
16
+
17
+ def __init__(self, num_subsets: int):
18
+ super().__init__()
19
+ self.eta = torch.nn.Parameter(torch.zeros(num_subsets))
20
+
21
+ def forward(self, x=None):
22
+ # Return raw logits; softmax is applied inside the loss function.
23
+ return self.eta
24
+
25
+
26
+ class SubsetDataset(torch.utils.data.Dataset):
27
+ """
28
+ PyTorch Dataset wrapping a DataFrame of integer-coded categorical features.
29
+
30
+ Args:
31
+ df (pd.DataFrame): Data split (train / val / test).
32
+ subset_features (list[str]): Categorical feature column names.
33
+ target (list[str]): Single-element list with the target column name.
34
+ """
35
+
36
+ def __init__(self, df, subset_features: list, target: list):
37
+ self.df = df
38
+ self.subset_features = subset_features
39
+ self.target = target
40
+
41
+ def __len__(self):
42
+ return len(self.df)
43
+
44
+ def __getitem__(self, idx):
45
+ row = self.df.iloc[idx]
46
+ x = torch.tensor(row[self.subset_features].astype(np.float32).values)
47
+ y = torch.tensor(np.float32(row[self.target[0]]))
48
+ return x, y
49
+
50
+
51
+ def subset_mixture_neg_log_posterior(
52
+ pi_logits: torch.Tensor,
53
+ y: torch.Tensor,
54
+ mus: torch.Tensor,
55
+ variances: torch.Tensor,
56
+ mask: torch.Tensor = None,
57
+ alpha: float = 1.1,
58
+ ) -> torch.Tensor:
59
+ """
60
+ Negative log-posterior of the subset mixture model.
61
+
62
+ Loss = -sum_i log( sum_s pi_s * N(y_i | mu_s(x_i), sigma^2_s(x_i)) )
63
+ - (alpha - 1) * sum_s log(pi_s) / B    (B = batch size)
64
+
65
+ Args:
66
+ pi_logits: [|S|] unnormalized logits (global parameter).
67
+ y: [B] target values.
68
+ mus: [B, |S|] empirical conditional means.
69
+ variances: [B, |S|] empirical conditional variances.
70
+ mask: [B, |S|] bool tensor; True where the cell exists in training data.
71
+ alpha: Dirichlet concentration parameter (>1 encourages non-degenerate weights).
72
+
73
+ Returns:
74
+ Scalar loss tensor.
75
+ """
76
+ pi = F.softmax(pi_logits, dim=0) # [|S|]
77
+ log_pi = torch.log(pi + 1e-9)
78
+ log_pi_b = log_pi.unsqueeze(0) # [1, |S|]
79
+
80
+ log_probs = (
81
+ -0.5 * torch.log(2 * torch.pi * variances)
82
+ - 0.5 * (y.unsqueeze(1) - mus) ** 2 / variances
83
+ ) # [B, |S|]
84
+
85
+ log_weighted = log_probs + log_pi_b # [B, |S|]
86
+
87
+ if mask is not None:
88
+ log_weighted = log_weighted.masked_fill(~mask, float("-inf"))
89
+
90
+ log_likelihoods = torch.logsumexp(log_weighted, dim=1) # [B]
91
+ nll = -log_likelihoods.sum()
92
+
93
+ batch_size = y.size(0)
94
+ log_prior = (alpha - 1.0) * log_pi.sum()
95
+
96
+ return nll - log_prior / batch_size
97
+
98
+
99
+ def subset_mixture_mse(
100
+ pi_logits: torch.Tensor,
101
+ y: torch.Tensor,
102
+ mus: torch.Tensor,
103
+ mask: torch.Tensor = None,
104
+ ) -> torch.Tensor:
105
+ """
106
+ MSE loss for the subset mixture model (useful for warmup / debugging).
107
+
108
+ Args:
109
+ pi_logits: [|S|] unnormalized logits.
110
+ y: [B] target values.
111
+ mus: [B, |S|] empirical conditional means.
112
+ mask: [B, |S|] bool tensor.
113
+
114
+ Returns:
115
+ Scalar MSE tensor.
116
+ """
117
+ pi = F.softmax(pi_logits, dim=0) # [|S|]
118
+ weights = pi.unsqueeze(0).expand_as(mus) # [B, |S|]
119
+
120
+ if mask is not None:
121
+ weights = weights.masked_fill(~mask, 0.0)
122
+ weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)
123
+
124
+ preds = (mus * weights).sum(dim=1)
125
+ return F.mse_loss(preds, y, reduction="sum")
@@ -0,0 +1,60 @@
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+ from .subset_maker import SubsetMaker
5
+
6
+
7
+ class SubsetMixturePredictor:
8
+ """
9
+ Inference wrapper for a trained Subset Mixture Model.
10
+
11
+ Args:
12
+ subset_maker (SubsetMaker): Trained SubsetMaker with lookup table.
13
+ weight_vector (torch.Tensor): Softmaxed mixture weights [|S|].
14
+ """
15
+
16
+ def __init__(self, subset_maker: SubsetMaker, weight_vector: torch.Tensor):
17
+ self.subset_maker = subset_maker
18
+ self.weight_vector = weight_vector
19
+
20
+ def predict(self, df: pd.DataFrame, return_debug: bool = False):
21
+ """
22
+ Predict target values for a new DataFrame.
23
+
24
+ For each example, weights are masked to valid subsets (those seen during
25
+ training) and re-normalized. Examples with no valid subsets fall back to
26
+ the global training mean.
27
+
28
+ Args:
29
+ df: Input DataFrame containing all subset_features columns.
30
+ return_debug: If True, also return per-example normalized weights
31
+ and a fallback mask.
32
+
33
+ Returns:
34
+ preds (np.ndarray): Predicted values [B].
35
+ (optional) norm_weights (np.ndarray): [B, |S|]
36
+ (optional) fallback_mask (np.ndarray): [B] bool
37
+ """
38
+ batch_tensor = torch.tensor(
39
+ df[self.subset_maker.subset_features].astype(np.float32).values
40
+ )
41
+ mus, _, mask = self.subset_maker.batch_lookup(batch_tensor) # [B, S]
42
+
43
+ B, S = mus.shape
44
+ weights = self.weight_vector.unsqueeze(0).expand(B, S) # [B, S]
45
+ weights = weights.masked_fill(~mask, 0.0)
46
+ mus = mus.masked_fill(~mask, 0.0)
47
+
48
+ weight_sums = weights.sum(dim=1, keepdim=True) # [B, 1]
49
+ norm_weights = weights / (weight_sums + 1e-9)
50
+
51
+ fallback_mask = weight_sums.squeeze() < 1e-6 # [B]
52
+ pred_weighted = (mus * norm_weights).sum(dim=1)
53
+ pred_fallback = torch.full_like(
54
+ pred_weighted, self.subset_maker.fallback_mean
55
+ )
56
+ preds = torch.where(fallback_mask, pred_fallback, pred_weighted)
57
+
58
+ if return_debug:
59
+ return preds.numpy(), norm_weights.numpy(), fallback_mask.numpy()
60
+ return preds.numpy()
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+ import itertools
5
+
6
+
7
+ class SubsetMaker:
8
+ """
9
+ Builds a powerset lookup table from a training DataFrame.
10
+
11
+ For each non-empty subset of `subset_features`, groups training examples by
12
+ their unique value combination and stores the empirical mean and variance of
13
+ the target within each group.
14
+
15
+ Args:
16
+ df (pd.DataFrame): Training data. All `subset_features` columns must be
17
+ integer-coded categorical variables.
18
+ subset_features (list[str]): Names of the categorical feature columns.
19
+ target (list[str]): Single-element list with the target column name.
20
+ """
21
+
22
+ def __init__(self, df: pd.DataFrame, subset_features: list, target: list):
23
+ self.df = df
24
+ self.subset_features = subset_features
25
+ self.target = target
26
+ self.lookup = self._build_lookup()
27
+ valid_target = self.df[self.target[0]].dropna()
28
+ self.fallback_mean = float(valid_target.mean())
29
+ self.fallback_var = float(valid_target.var())
30
+
31
+ def _get_powerset(self) -> list:
32
+ powerset = []
33
+ for r in range(1, len(self.subset_features) + 1):
34
+ powerset += list(itertools.combinations(self.subset_features, r))
35
+ return [list(s) for s in powerset]
36
+
37
+     def _build_lookup(self, drop_missing_rows: bool = True) -> dict:
+         lookup = {}
+         for subset in self._get_powerset():
+             grouped = (
+                 self.df[subset + self.target]
+                 .groupby(subset)
+                 .agg({self.target[0]: ["mean", "var"]})
+                 .reset_index()
+             )
+             # Flatten the MultiIndex produced by .agg into e.g. "y_mean" / "y_var"
+             grouped.columns = [
+                 "_".join(c).strip("_") if isinstance(c, tuple) else c
+                 for c in grouped.columns
+             ]
+             if drop_missing_rows:
+                 grouped = grouped.dropna()
+                 # drop degenerate cells with (near-)zero empirical variance
+                 grouped = grouped[grouped[f"{self.target[0]}_var"] > 1e-6]
+
+             mean_col = f"{self.target[0]}_mean"
+             var_col = f"{self.target[0]}_var"
+             group_dict = {
+                 tuple(row[subset]): (row[mean_col], row[var_col])
+                 for _, row in grouped.iterrows()
+             }
+             lookup[tuple(subset)] = (subset, grouped, group_dict)
+         return lookup
+
+     def batch_lookup(self, batch_tensor: torch.Tensor):
+         """
+         Look up empirical means/variances for a batch of integer-coded examples.
+
+         Args:
+             batch_tensor: [B, D] integer tensor matching the order of subset_features.
+
+         Returns:
+             means: [B, |S|] float tensor
+             variances: [B, |S|] float tensor
+             mask: [B, |S|] bool tensor (True where the cell exists in training data)
+         """
+         # .cpu() so tensors on an accelerator can cross into pandas
+         batch_df = pd.DataFrame(
+             batch_tensor.cpu().numpy(), columns=self.subset_features
+         ).astype(int)
+
+         all_means, all_vars, all_masks = [], [], []
+
+         for _, (subset_cols, _, group_dict) in self.lookup.items():
+             subset_vals = batch_df[subset_cols].apply(tuple, axis=1)
+             means, vars_, mask = [], [], []
+             for key in subset_vals:
+                 if key in group_dict:
+                     m, v = group_dict[key]
+                     means.append(m)
+                     vars_.append(v)
+                     mask.append(True)
+                 else:
+                     means.append(self.fallback_mean)
+                     vars_.append(self.fallback_var)
+                     mask.append(False)
+             all_means.append(means)
+             all_vars.append(vars_)
+             all_masks.append(mask)
+
+         # Transpose from [|S|, B] → [B, |S|]
+         return (
+             torch.tensor(all_means, dtype=torch.float32).T,
+             torch.tensor(all_vars, dtype=torch.float32).T,
+             torch.tensor(all_masks, dtype=torch.bool).T,
+         )
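An editorial aside on the lookup above: the per-row dictionary probing can equivalently be vectorized with a left merge against the per-subset statistics table. A standalone sketch on toy data (single subset, global-mean fallback; the identifiers here are illustrative, not the package API):

```python
import numpy as np
import pandas as pd

# Toy training frame: two integer-coded features, one continuous target.
df = pd.DataFrame({
    "a": [0, 0, 1, 1, 0, 1],
    "b": [0, 1, 0, 1, 1, 0],
    "y": [1.0, 2.0, 3.0, 4.0, 2.5, 3.5],
})

# Per-cell statistics for one subset, analogous to a single lookup entry.
subset = ["a", "b"]
stats = df.groupby(subset)["y"].agg(["mean", "var"]).reset_index()

# Batch of queries; (2, 2) is a cell never seen in training.
batch = pd.DataFrame({"a": [0, 1, 2], "b": [1, 0, 2]})
merged = batch.merge(stats, on=subset, how="left")

mask = merged["mean"].notna().to_numpy()                  # True where the cell exists
means = merged["mean"].fillna(df["y"].mean()).to_numpy()  # global-mean fallback
```

A merge scales much better than `iterrows` on large batches; converting `means` and `mask` to tensors would then mirror `batch_lookup`'s per-subset output.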
@@ -0,0 +1,152 @@
+ # Subset Mixture Model (SMM)
+
+ **SMM** is an interpretable empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.
+
+ ## Key idea
+
+ Each feature subset $s$ induces a partition of the covariate space, and each cell of that partition carries a natural estimator of the conditional expectation: its empirical mean. SMM learns a convex combination of these estimators:
+
+ $$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
+
+ The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding an aleatoric/epistemic decomposition without post-hoc calibration.
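As a concrete toy reading of this paragraph: the point prediction is the weight-averaged cell mean, and the delta method propagates the Laplace covariance of the weights into an epistemic term. All numbers below are made up, and this decomposition is one plausible sketch rather than necessarily the package's exact computation:

```python
import numpy as np

# Hypothetical quantities at one test point, over |S| = 3 subsets.
mu = np.array([1.0, 2.0, 4.0])    # empirical cell means, one per subset
v = np.array([0.5, 0.25, 0.1])    # empirical cell variances
pi = np.array([0.2, 0.3, 0.5])    # MAP simplex weights
Sigma_pi = 0.01 * np.eye(3)       # toy Laplace covariance of the weights

y_mean = mu @ pi                  # convex combination of cell means
epistemic = mu @ Sigma_pi @ mu    # delta-method variance from weight uncertainty
aleatoric = pi @ v                # weight-averaged cell noise
y_std = np.sqrt(epistemic + aleatoric)
```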
+
+ ## Installation
+
+ ```bash
+ pip install subset-mixture-model
+ ```
+
+ Or from source:
+
+ ```bash
+ git clone https://github.com/aaronjdanielson/subset-mixture-model
+ cd subset-mixture-model
+ pip install -e .
+ ```
+
+ ## Quick start
+
+ ```python
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from smm import (
+     SubsetMaker, SubsetWeightsModel, SubsetDataset,
+     subset_mixture_neg_log_posterior, SubsetMixturePredictor,
+     compute_posterior_covariance, predict_with_uncertainty,
+ )
+
+ # --- your data (integer-coded categorical features) ---
+ train_df = pd.read_csv("train.csv")
+ val_df = pd.read_csv("val.csv")    # held out for model selection / early stopping
+ test_df = pd.read_csv("test.csv")
+
+ cat_cols = ["feature_a", "feature_b", "feature_c"]
+ target = "y"
+
+ # --- build lookup table ---
+ subset_maker = SubsetMaker(train_df, cat_cols, [target])
+ n_subsets = len(subset_maker.lookup)
+
+ # --- train ---
+ model = SubsetWeightsModel(n_subsets)
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ loader = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
+                     batch_size=64, shuffle=True)
+
+ for epoch in range(100):
+     for x, y in loader:
+         optimizer.zero_grad()
+         mus, variances, mask = subset_maker.batch_lookup(x)
+         loss = subset_mixture_neg_log_posterior(
+             model(), y, mus, variances, mask, alpha=1.1)
+         loss.backward()
+         optimizer.step()
+
+ # --- predict with uncertainty ---
+ pi_hat = F.softmax(model.eta.detach(), dim=0)
+ predictor = SubsetMixturePredictor(subset_maker, pi_hat)
+ sigma_pi = compute_posterior_covariance(
+     model, subset_maker, train_df, cat_cols, target, alpha=1.1)
+
+ y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
+ # y_mean: point predictions
+ # y_std: total predictive standard deviation (aleatoric + epistemic)
+ ```
+
+ ## Interpretability
+
+ ```python
+ import numpy as np
+
+ subsets = list(subset_maker.lookup.keys())
+ top_idx = np.argsort(pi_hat.numpy())[::-1][:10]
+
+ for rank, i in enumerate(top_idx):
+     print(f"{rank + 1:2d}. {subsets[i]} π={pi_hat[i]:.4f}")
+ ```
+
+ ## Features
+
+ - **Interpretable by construction**: learned weights reveal which feature interactions matter
+ - **Principled uncertainty**: aleatoric/epistemic decomposition via Laplace approximation
+ - **Efficient training**: only $2^D - 1$ logits are optimized; the lookup table is precomputed once
+ - **No post-hoc calibration**: well-calibrated predictive intervals out of the box
+ - **Scalable**: handles up to $D \le 15$ features directly; $k$-way truncation available for larger $D$
+
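The $k$-way truncation mentioned in the last bullet can be sketched with `itertools`; this helper is illustrative (the package may expose it under a different name):

```python
from itertools import chain, combinations

def subsets_up_to_k(features, k=None):
    """All non-empty feature subsets, optionally truncated to size <= k."""
    k = len(features) if k is None else k
    return list(chain.from_iterable(
        combinations(features, r) for r in range(1, k + 1)))

cols = ["a", "b", "c", "d"]
full = subsets_up_to_k(cols)        # 2^4 - 1 = 15 subsets
pairs = subsets_up_to_k(cols, k=2)  # 4 singletons + 6 pairs = 10 subsets
```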
+ ## Datasets supported
+
+ Any tabular dataset with integer-coded (or string, after encoding) categorical features and a continuous target.
+
+ ## Citation
+
+ ```bibtex
+ @article{danielson2025smm,
+   title   = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
+              of Partition Estimators for Categorical Regression},
+   author  = {Danielson, Aaron John},
+   journal = {Machine Learning},
+   year    = {2025},
+   note    = {Under review}
+ }
+ ```
+
+ ## License
+
+ MIT
@@ -0,0 +1,13 @@
+ README.md
+ pyproject.toml
+ smm/__init__.py
+ smm/laplace.py
+ smm/model.py
+ smm/predictor.py
+ smm/subset_maker.py
+ subset_mixture_model.egg-info/PKG-INFO
+ subset_mixture_model.egg-info/SOURCES.txt
+ subset_mixture_model.egg-info/dependency_links.txt
+ subset_mixture_model.egg-info/requires.txt
+ subset_mixture_model.egg-info/top_level.txt
+ tests/test_smm.py
@@ -0,0 +1,17 @@
+ torch>=2.0
+ numpy>=1.24
+ pandas>=2.0
+ scikit-learn>=1.3
+ scipy>=1.11
+
+ [dev]
+ pytest>=7
+ pytest-cov
+
+ [experiments]
+ matplotlib>=3.7
+ seaborn>=0.12
+ joblib>=1.3
+ lightgbm>=4.0
+ ngboost>=0.4
+ mapie>=0.6
@@ -0,0 +1,259 @@
+ """
+ Core tests for the Subset Mixture Model package.
+ """
+
+ import numpy as np
+ import pandas as pd
+ import pytest
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from smm import (
+     SubsetMaker,
+     SubsetWeightsModel,
+     SubsetDataset,
+     subset_mixture_neg_log_posterior,
+     SubsetMixturePredictor,
+     compute_posterior_covariance,
+     predict_with_uncertainty,
+     coverage,
+ )
+
+
+ # ---------------------------------------------------------------------------
+ # Fixtures
+ # ---------------------------------------------------------------------------
+
+ @pytest.fixture
+ def synthetic_df():
+     """3-feature synthetic dataset with a known 3-way interaction."""
+     rng = np.random.default_rng(0)
+     n = 300
+     a = rng.integers(0, 4, n)
+     b = rng.integers(0, 3, n)
+     c = rng.integers(0, 2, n)
+     # target = interaction cell mean + small noise
+     cell_means = {(i, j, k): rng.normal(0, 2)
+                   for i in range(4) for j in range(3) for k in range(2)}
+     y = np.array([cell_means[(a[i], b[i], c[i])] + rng.normal(0, 0.1)
+                   for i in range(n)])
+     return pd.DataFrame({"a": a, "b": b, "c": c, "y": y})
+
+
+ @pytest.fixture
+ def splits(synthetic_df):
+     from sklearn.model_selection import train_test_split
+     tr, te = train_test_split(synthetic_df, test_size=0.2, random_state=0)
+     tr, va = train_test_split(tr, test_size=0.15, random_state=0)
+     return tr.reset_index(drop=True), va.reset_index(drop=True), te.reset_index(drop=True)
+
+
+ CAT_COLS = ["a", "b", "c"]
+ TARGET = "y"
+
+
+ # ---------------------------------------------------------------------------
+ # SubsetMaker
+ # ---------------------------------------------------------------------------
+
+ def test_subset_maker_powerset_size(splits):
+     tr, _, _ = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     # 3 features → 2^3 - 1 = 7 subsets
+     assert len(sm.lookup) == 7
+
+
+ def test_subset_maker_batch_lookup_shapes(splits):
+     tr, _, _ = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     x = torch.tensor(tr[CAT_COLS].values[:16], dtype=torch.long)
+     mus, variances, mask = sm.batch_lookup(x)
+     assert mus.shape == (16, 7)
+     assert variances.shape == (16, 7)
+     assert mask.shape == (16, 7)
+     assert mask.dtype == torch.bool
+
+
+ def test_subset_maker_fallback(splits):
+     tr, _, _ = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     assert np.isfinite(sm.fallback_mean)
+     assert np.isfinite(sm.fallback_var)
+
+
+ # ---------------------------------------------------------------------------
+ # SubsetWeightsModel
+ # ---------------------------------------------------------------------------
+
+ def test_model_forward_returns_logits():
+     model = SubsetWeightsModel(7)
+     eta = model()
+     assert eta.shape == (7,)
+     assert not torch.isnan(eta).any()
+
+
+ def test_softmax_sums_to_one():
+     model = SubsetWeightsModel(7)
+     pi = F.softmax(model(), dim=0)
+     assert torch.allclose(pi.sum(), torch.tensor(1.0), atol=1e-6)
+
+
+ # ---------------------------------------------------------------------------
+ # Loss function
+ # ---------------------------------------------------------------------------
+
+ def test_loss_finite(splits):
+     tr, _, _ = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     model = SubsetWeightsModel(len(sm.lookup))
+     loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
+                         batch_size=32, shuffle=False)
+     x, y = next(iter(loader))
+     mus, variances, mask = sm.batch_lookup(x)
+     loss = subset_mixture_neg_log_posterior(model(), y, mus, variances, mask)
+     assert torch.isfinite(loss)
+
+
+ def test_loss_decreases_after_training(splits):
+     tr, _, _ = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     model = SubsetWeightsModel(len(sm.lookup))
+     opt = torch.optim.Adam(model.parameters(), lr=1e-2)
+     loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
+                         batch_size=64, shuffle=True)
+
+     def epoch_loss():
+         total = 0.0
+         model.eval()
+         with torch.no_grad():
+             for x, y in loader:
+                 mus, variances, mask = sm.batch_lookup(x)
+                 total += subset_mixture_neg_log_posterior(
+                     model(), y, mus, variances, mask).item()
+         return total
+
+     loss_before = epoch_loss()
+     model.train()
+     for _ in range(5):
+         for x, y in loader:
+             opt.zero_grad()
+             mus, variances, mask = sm.batch_lookup(x)
+             subset_mixture_neg_log_posterior(
+                 model(), y, mus, variances, mask).backward()
+             opt.step()
+     loss_after = epoch_loss()
+     assert loss_after < loss_before
+
+
+ # ---------------------------------------------------------------------------
+ # Predictor
+ # ---------------------------------------------------------------------------
+
+ def test_predictor_output_shape(splits):
+     tr, _, te = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     model = SubsetWeightsModel(len(sm.lookup))
+     pi = F.softmax(model.eta.detach(), dim=0)
+     predictor = SubsetMixturePredictor(sm, pi)
+     preds = predictor.predict(te)
+     assert preds.shape == (len(te),)
+     assert np.isfinite(preds).all()
+
+
+ def test_predictor_fallback_for_unseen_cells():
+     """A test point with values not seen in training should fall back gracefully."""
+     tr = pd.DataFrame({"a": [0, 1, 0, 1], "b": [0, 0, 1, 1], "y": [1.0, 2.0, 3.0, 4.0]})
+     sm = SubsetMaker(tr, ["a", "b"], ["y"])
+     model = SubsetWeightsModel(len(sm.lookup))
+     pi = F.softmax(model.eta.detach(), dim=0)
+     predictor = SubsetMixturePredictor(sm, pi)
+     # value 99 was never seen, so the predictor should return fallback_mean
+     unseen = pd.DataFrame({"a": [99], "b": [99], "y": [0.0]})
+     preds = predictor.predict(unseen)
+     assert np.isfinite(preds).all()
+
+
+ # ---------------------------------------------------------------------------
+ # Laplace / uncertainty
+ # ---------------------------------------------------------------------------
+
+ def test_predict_with_uncertainty_shapes(splits):
+     tr, _, te = splits
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     model = SubsetWeightsModel(len(sm.lookup))
+     opt = torch.optim.Adam(model.parameters(), lr=1e-2)
+     loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
+                         batch_size=64, shuffle=True)
+     for _ in range(3):
+         for x, y in loader:
+             opt.zero_grad()
+             mus, variances, mask = sm.batch_lookup(x)
+             subset_mixture_neg_log_posterior(
+                 model(), y, mus, variances, mask).backward()
+             opt.step()
+     pi = F.softmax(model.eta.detach(), dim=0)
+     predictor = SubsetMixturePredictor(sm, pi)
+     sigma_pi = compute_posterior_covariance(
+         model, sm, tr, CAT_COLS, TARGET, alpha=1.1)
+     y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, te)
+     assert y_mean.shape == (len(te),)
+     assert y_std.shape == (len(te),)
+     assert (y_std >= 0).all()
+
+
+ def test_coverage_function():
+     y_true = np.array([0.0, 1.0, 2.0, 3.0])
+     y_mean = np.array([0.0, 1.0, 2.0, 3.0])
+     y_std = np.array([1.0, 1.0, 1.0, 1.0])
+     cov = coverage(y_true, y_mean, y_std, level=0.95)
+     assert 0.0 <= cov <= 1.0
+
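For context on what `coverage` is checking, here is a hedged standalone sketch of one plausible implementation (the packaged version may differ in detail): the fraction of true values falling inside the central Gaussian interval `mean ± z·std` implied by each prediction.

```python
from statistics import NormalDist

import numpy as np

def coverage_sketch(y_true, y_mean, y_std, level=0.95):
    # z is the standard-normal quantile giving a central interval at `level`
    z = NormalDist().inv_cdf((1 + level) / 2)
    lo, hi = y_mean - z * y_std, y_mean + z * y_std
    return float(np.mean((y_true >= lo) & (y_true <= hi)))

y_true = np.array([0.0, 1.0, 10.0])
y_mean = np.array([0.0, 1.0, 2.0])
y_std = np.ones(3)
cov = coverage_sketch(y_true, y_mean, y_std)  # third point falls outside
```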
+
+
+ # ---------------------------------------------------------------------------
+ # Synthetic recovery (integration test)
+ # ---------------------------------------------------------------------------
+
+ def test_synthetic_weight_recovery(splits):
+     """
+     With the DGP driven purely by the 3-way interaction,
+     the full-powerset subset should receive the highest weight.
+     """
+     tr, va, _ = splits
+     torch.manual_seed(0)  # seed before model init for reproducibility
+     sm = SubsetMaker(tr, CAT_COLS, [TARGET])
+     model = SubsetWeightsModel(len(sm.lookup))
+     opt = torch.optim.Adam(model.parameters(), lr=5e-3)
+     loader = DataLoader(SubsetDataset(tr, CAT_COLS, [TARGET]),
+                         batch_size=64, shuffle=True)
+     best_state, best_val = None, float("inf")
+     val_loader = DataLoader(SubsetDataset(va, CAT_COLS, [TARGET]),
+                             batch_size=64, shuffle=False)
+     for _ in range(60):
+         model.train()
+         for x, y in loader:
+             opt.zero_grad()
+             mus, variances, mask = sm.batch_lookup(x)
+             subset_mixture_neg_log_posterior(
+                 model(), y, mus, variances, mask, alpha=1.1).backward()
+             opt.step()
+         model.eval()
+         vl = 0.0
+         with torch.no_grad():
+             for x, y in val_loader:
+                 mus, variances, mask = sm.batch_lookup(x)
+                 vl += subset_mixture_neg_log_posterior(
+                     model(), y, mus, variances, mask, alpha=1.1).item()
+         vl /= len(val_loader)
+         if vl < best_val:
+             best_val = vl
+             best_state = {k: v.clone() for k, v in model.state_dict().items()}
+     model.load_state_dict(best_state)
+     pi = F.softmax(model.eta.detach(), dim=0).numpy()
+     subsets = list(sm.lookup.keys())
+     # The full 3-way subset should carry the highest weight
+     full_subset = tuple(sorted(CAT_COLS))
+     idx = next(i for i, s in enumerate(subsets) if tuple(sorted(s)) == full_subset)
+     assert pi[idx] == pi.max(), (
+         f"3-way subset not top-weighted: pi={pi[idx]:.3f} vs max={pi.max():.3f}")