cbps 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbps/__init__.py +3462 -0
- cbps/constants.py +46 -0
- cbps/core/__init__.py +93 -0
- cbps/core/cbps_binary.py +1943 -0
- cbps/core/cbps_continuous.py +945 -0
- cbps/core/cbps_multitreat.py +1123 -0
- cbps/core/cbps_optimal.py +507 -0
- cbps/core/results.py +1447 -0
- cbps/data/Blackwell.csv +571 -0
- cbps/data/LaLonde.csv +3213 -0
- cbps/data/npcbps_continuous_sim.csv +501 -0
- cbps/data/nsw.csv +723 -0
- cbps/data/nsw_dw.csv +446 -0
- cbps/data/political_ads_urban_niebler.csv +16266 -0
- cbps/data/psid_controls.csv +2491 -0
- cbps/data/psid_controls2.csv +254 -0
- cbps/data/psid_controls3.csv +129 -0
- cbps/data/simulation_dgp1_seed12345.csv +201 -0
- cbps/data/simulation_dgp2_seed12345.csv +201 -0
- cbps/data/simulation_dgp3_seed12345.csv +201 -0
- cbps/data/simulation_dgp4_seed12345.csv +201 -0
- cbps/datasets/__init__.py +78 -0
- cbps/datasets/blackwell.py +112 -0
- cbps/datasets/continuous.py +223 -0
- cbps/datasets/lalonde.py +272 -0
- cbps/datasets/npcbps_sim.py +101 -0
- cbps/diagnostics/__init__.py +101 -0
- cbps/diagnostics/balance.py +760 -0
- cbps/diagnostics/balance_cbmsm_addon.py +162 -0
- cbps/diagnostics/continuous_diagnostics.py +259 -0
- cbps/diagnostics/normality.py +173 -0
- cbps/diagnostics/ocbps_conditions.py +197 -0
- cbps/diagnostics/overlap.py +198 -0
- cbps/diagnostics/plots.py +1193 -0
- cbps/diagnostics/weights_diag.py +205 -0
- cbps/highdim/__init__.py +84 -0
- cbps/highdim/gmm_loss.py +340 -0
- cbps/highdim/hdcbps.py +1078 -0
- cbps/highdim/lasso_utils.py +498 -0
- cbps/highdim/weight_funcs.py +298 -0
- cbps/inference/__init__.py +42 -0
- cbps/inference/asyvar.py +621 -0
- cbps/inference/vcov_outcome.py +217 -0
- cbps/iv/__init__.py +48 -0
- cbps/iv/cbiv.py +2603 -0
- cbps/logging_config.py +45 -0
- cbps/msm/__init__.py +45 -0
- cbps/msm/cbmsm.py +1871 -0
- cbps/msm/rank_diagnostics.py +112 -0
- cbps/nonparametric/__init__.py +58 -0
- cbps/nonparametric/cholesky_whitening.py +232 -0
- cbps/nonparametric/empirical_likelihood.py +339 -0
- cbps/nonparametric/npcbps.py +1036 -0
- cbps/nonparametric/taylor_approx.py +207 -0
- cbps/py.typed +0 -0
- cbps/sklearn/__init__.py +42 -0
- cbps/sklearn/estimator.py +378 -0
- cbps/utils/__init__.py +82 -0
- cbps/utils/formula.py +415 -0
- cbps/utils/helpers.py +378 -0
- cbps/utils/numerics.py +438 -0
- cbps/utils/r_compat.py +109 -0
- cbps/utils/validation.py +224 -0
- cbps/utils/variance_transform.py +483 -0
- cbps/utils/weights.py +586 -0
- cbps-0.2.0.dist-info/METADATA +1090 -0
- cbps-0.2.0.dist-info/RECORD +70 -0
- cbps-0.2.0.dist-info/WHEEL +5 -0
- cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
- cbps-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1123 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Covariate Balancing Propensity Score for Multi-valued Treatments.
|
|
3
|
+
|
|
4
|
+
This module implements CBPS for categorical treatments with 3 or 4 levels,
|
|
5
|
+
using multinomial logistic regression and contrast weights within the GMM
|
|
6
|
+
framework.
|
|
7
|
+
|
|
8
|
+
Algorithm Overview
|
|
9
|
+
------------------
|
|
10
|
+
1. Multinomial logistic regression for MLE initialization
|
|
11
|
+
2. GMM optimization with covariate balance constraints
|
|
12
|
+
3. Contrast weight computation for treatment effects
|
|
13
|
+
|
|
14
|
+
Notes on Implementation
|
|
15
|
+
-----------------------
|
|
16
|
+
This implementation uses statsmodels.MNLogit for multinomial logistic
|
|
17
|
+
initialization. Baseline-category logit models may have minor numerical
|
|
18
|
+
variations across different statistical libraries due to optimization
|
|
19
|
+
algorithms (±1e-2 to ±1e-1 in MLE estimates).
|
|
20
|
+
|
|
21
|
+
The CBPS optimization process typically reduces these differences,
|
|
22
|
+
with final results usually achieving ±1e-3 accuracy depending on the data.
|
|
23
|
+
|
|
24
|
+
References
|
|
25
|
+
----------
|
|
26
|
+
Imai, Kosuke and Marc Ratkovic. 2014. "Covariate Balancing Propensity Score."
|
|
27
|
+
Journal of the Royal Statistical Society, Series B (Statistical Methodology).
|
|
28
|
+
§4.1 Multi-valued Treatments, Eq.22-24 (p.260)
|
|
29
|
+
DOI:10.1111/rssb.12027
|
|
30
|
+
http://imai.princeton.edu/research/CBPS.html
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import warnings
|
|
34
|
+
from typing import Dict, Optional, Tuple, List, Any
|
|
35
|
+
import numpy as np
|
|
36
|
+
import scipy.linalg
|
|
37
|
+
import scipy.special
|
|
38
|
+
import scipy.optimize
|
|
39
|
+
import statsmodels.api as sm
|
|
40
|
+
|
|
41
|
+
from .results import CBPSResults
|
|
42
|
+
from ..utils.helpers import normalize_sample_weights
|
|
43
|
+
from ..utils.numerics import r_ginv_like, pinv_match_r
|
|
44
|
+
from ..utils.validation import ensure_dense
|
|
45
|
+
from ..logging_config import logger, set_verbosity
|
|
46
|
+
|
|
47
|
+
# Constants
# Floor applied to every fitted category probability by the softmax helpers in
# this module: probabilities are clipped up to this value and then renormalized,
# which keeps inverse-probability terms such as T1 / probs[:, 0] finite.
PROBS_MIN = 1e-6  # Minimum probability clipping threshold
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
from typing import Optional
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _r_ginv(X: np.ndarray, tol: Optional[float] = None) -> np.ndarray:
    """
    Pseudo-inverse with R-compatible cutoff behaviour.

    Parameters
    ----------
    X : np.ndarray
        Matrix to pseudo-invert.
    tol : float, optional
        Absolute singular-value cutoff. When omitted, the MASS::ginv-style
        default (tol = max(dim) * smax * eps) provided by ``pinv_match_r`` is
        used.

    Returns
    -------
    np.ndarray
        The pseudo-inverse of ``X``.
    """
    if tol is not None:
        # An explicit absolute cutoff goes through the SVD-based helper so the
        # result does not depend on SciPy's version-specific pinv keywords.
        return r_ginv_like(X, tol=tol)
    # Default path: mirror MASS::ginv for parity with the R implementation.
    return pinv_match_r(X)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _compute_softmax_probs_3treat(
    theta: np.ndarray,
    probs_min: float = PROBS_MIN
) -> np.ndarray:
    """
    Map baseline-category logits of a 3-level treatment to probabilities.

    The first category is the reference level with a fixed logit of zero and
    ``theta`` supplies the logits of categories 2 and 3. The row maximum is
    subtracted before exponentiating so ``np.exp`` cannot overflow. The
    probabilities are then floored at ``probs_min`` and renormalized. This is
    repeated (at most 10 passes) because renormalizing can push a floored entry
    slightly back below the floor.

    Parameters
    ----------
    theta : np.ndarray
        Logit parameters for categories 2 and 3, shape (n, 2).
    probs_min : float, default PROBS_MIN
        Probability floor.

    Returns
    -------
    np.ndarray
        Probability matrix of shape (n, 3) whose rows sum to 1.

    Warns
    -----
    UserWarning
        If the floor (with a 0.1% tolerance) is still violated after clipping.

    Raises
    ------
    AssertionError
        If the result does not have shape (n, 3) with unit row sums, or its
        minimum lies below the tolerated floor (e.g. for non-finite ``theta``).
    """
    n_rows = theta.shape[0]
    tolerated_floor = probs_min * 0.999

    # Prepend the zero logit of the reference category, then stabilize.
    logits = np.column_stack([np.zeros(n_rows), theta])
    logits = logits - logits.max(axis=1, keepdims=True)
    expd = np.exp(logits)
    probs = expd / expd.sum(axis=1, keepdims=True)

    # Floor-and-renormalize until every entry respects the tolerated floor.
    for _ in range(10):
        floored = np.maximum(probs_min, probs)
        probs = floored / floored.sum(axis=1, keepdims=True)
        if np.all(probs >= tolerated_floor):
            break

    smallest = probs.min()
    if smallest < tolerated_floor:
        warnings.warn(
            f"Iterative clipping did not fully converge: min_prob={smallest:.2e} < {probs_min:.2e}. "
            "This may occur in extremely imbalanced data (probabilities > 99.9999%).",
            UserWarning
        )

    row_sums = probs.sum(axis=1)
    assert probs.shape == (n_rows, 3) and np.allclose(row_sums, 1.0, atol=1e-10), \
        f"Softmax probability anomaly: shape={probs.shape}, sum range=[{row_sums.min()}, {row_sums.max()}]"

    # The floor must hold up to a 0.1% numerical tolerance.
    assert smallest >= tolerated_floor, \
        f"Probability threshold violation: min={smallest:.2e} < {probs_min:.2e}"

    return probs
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _compute_softmax_probs_4treat(theta: np.ndarray, probs_min: float = PROBS_MIN) -> np.ndarray:
    """
    Map baseline-category logits of a 4-level treatment to probabilities.

    Category 1 is the reference level (logit fixed at zero). ``theta`` holds
    the logits of categories 2-4. The computation subtracts the row maximum
    before exponentiating (overflow-safe). It then floors the probabilities at
    ``probs_min`` and renormalizes, for up to 10 passes.

    Parameters
    ----------
    theta : np.ndarray
        Logits for categories 2, 3 and 4, shape (n, 3).
    probs_min : float, default PROBS_MIN
        Probability floor.

    Returns
    -------
    np.ndarray
        Probability matrix of shape (n, 4) whose rows sum to 1.

    Warns
    -----
    UserWarning
        If the floor (with a 0.1% tolerance) is still violated after clipping.

    Raises
    ------
    AssertionError
        If the result does not have shape (n, 4) with unit row sums, or its
        minimum lies below the tolerated floor.
    """
    n_rows = theta.shape[0]
    tolerated_floor = probs_min * 0.999

    logits = np.column_stack([np.zeros(n_rows), theta])
    logits = logits - logits.max(axis=1, keepdims=True)
    expd = np.exp(logits)
    probs = expd / expd.sum(axis=1, keepdims=True)

    # Floor-and-renormalize until every entry respects the tolerated floor.
    for _ in range(10):
        floored = np.maximum(probs_min, probs)
        probs = floored / floored.sum(axis=1, keepdims=True)
        if np.all(probs >= tolerated_floor):
            break

    smallest = probs.min()
    if smallest < tolerated_floor:
        warnings.warn(
            f"Iterative clipping did not fully converge: min_prob={smallest:.2e} < {probs_min:.2e}. "
            "This may occur with extremely imbalanced data (probability >99.9999%).",
            UserWarning
        )

    row_sums = probs.sum(axis=1)
    assert probs.shape == (n_rows, 4) and np.allclose(row_sums, 1.0, atol=1e-10), \
        f"Softmax probability error: shape={probs.shape}, sum range=[{row_sums.min()}, {row_sums.max()}]"

    assert smallest >= tolerated_floor, \
        f"Probability threshold violated: min={smallest:.2e} < {probs_min:.2e}"

    return probs
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _compute_contrast_weights_3treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray, probs: np.ndarray) -> np.ndarray:
|
|
197
|
+
"""Compute contrast weights for 3-level treatment."""
|
|
198
|
+
w_contrast = np.column_stack([
|
|
199
|
+
2*T1/probs[:,0] - T2/probs[:,1] - T3/probs[:,2],
|
|
200
|
+
T2/probs[:,1] - T3/probs[:,2]
|
|
201
|
+
])
|
|
202
|
+
assert w_contrast.shape == (len(T1), 2)
|
|
203
|
+
return w_contrast
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _compute_contrast_weights_4treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray, T4: np.ndarray, probs: np.ndarray) -> np.ndarray:
|
|
207
|
+
"""Compute contrast weights for 4-level treatment."""
|
|
208
|
+
w_contrast = np.column_stack([
|
|
209
|
+
T1/probs[:,0] + T2/probs[:,1] - T3/probs[:,2] - T4/probs[:,3],
|
|
210
|
+
T1/probs[:,0] - T2/probs[:,1] - T3/probs[:,2] + T4/probs[:,3],
|
|
211
|
+
-T1/probs[:,0] + T2/probs[:,1] - T3/probs[:,2] + T4/probs[:,3]
|
|
212
|
+
])
|
|
213
|
+
assert w_contrast.shape == (len(T1), 3)
|
|
214
|
+
return w_contrast
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _compute_V_matrix_3treat(X: np.ndarray, probs: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                             T3: np.ndarray, wtX: np.ndarray, n: int) -> np.ndarray:
    """
    Compute V matrix (4k x 4k) for 3-level treatment.

    V consists of 4 x 4 blocks of size k x k, ordered like the stacked moment
    vector ``gbar`` in ``_gmm_func_3treat``:
    - the score conditions for categories 2 and 3 (T2 - p1, T3 - p2);
    - the two contrast-balance conditions.

    Block (a, b) equals ``(1/n) * X_a_b.T @ X``, where ``X_a_b`` is ``wtX`` with
    every row scaled by a probability-only coefficient. The coefficients match
    the expectations of the products of the corresponding per-observation
    moment terms. For example, the score/contrast cross terms collapse to the
    constants -1 or +1. ``_gmm_func_3treat`` pseudo-inverts V to obtain the
    GMM weighting matrix.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    probs : np.ndarray
        Fitted probabilities, shape (n, 3).
    T1, T2, T3 : np.ndarray
        Treatment indicators. They are not used in the body; V depends on
        ``probs`` only.
    wtX : np.ndarray
        Row-weighted covariates (``sample_weights[:, None] * X`` when called
        from ``_gmm_func_3treat``).
    n : int
        Number of observations used for scaling.

    Returns
    -------
    np.ndarray
        Symmetric (4k, 4k) matrix; shape and symmetry are asserted.
    """
    k = X.shape[1]
    # 10 block matrices with proper broadcasting
    # Score variances: p_j * (1 - p_j)
    X_1_1 = wtX * (probs[:,1] * (1 - probs[:,1]))[:, None]
    X_2_2 = wtX * (probs[:,2] * (1 - probs[:,2]))[:, None]
    # Contrast second moments: sums of (squared-coefficient) inverse probabilities
    X_3_3 = wtX * (4*probs[:,0]**(-1) + probs[:,1]**(-1) + probs[:,2]**(-1))[:, None]
    X_4_4 = wtX * (probs[:,1]**(-1) + probs[:,2]**(-1))[:, None]
    # Score covariance between categories 2 and 3: -p_1 * p_2
    X_1_2 = wtX * (-probs[:,1] * probs[:,2])[:, None]
    # Score/contrast cross terms reduce to constants
    X_1_3 = wtX * (-1)
    X_1_4 = wtX * 1
    X_2_3 = wtX * (-1)
    X_2_4 = wtX * (-1)
    # Cross moment of the two contrast columns
    X_3_4 = wtX * (-probs[:,1]**(-1) + probs[:,2]**(-1))[:, None]
    # Assemble 4x4 block matrix (lower triangle mirrors the upper one)
    V = (1.0/n) * np.block([[X_1_1.T @ X, X_1_2.T @ X, X_1_3.T @ X, X_1_4.T @ X],
                            [X_1_2.T @ X, X_2_2.T @ X, X_2_3.T @ X, X_2_4.T @ X],
                            [X_1_3.T @ X, X_2_3.T @ X, X_3_3.T @ X, X_3_4.T @ X],
                            [X_1_4.T @ X, X_2_4.T @ X, X_3_4.T @ X, X_4_4.T @ X]])
    assert V.shape == (4*k, 4*k) and np.allclose(V, V.T, atol=1e-12)
    return V
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _compute_V_matrix_4treat(X: np.ndarray, probs: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                             T3: np.ndarray, T4: np.ndarray, wtX: np.ndarray, n: int) -> np.ndarray:
    """
    Compute V matrix (6k x 6k) for 4-level treatment.

    V consists of 6 x 6 blocks of size k x k, ordered like ``gbar`` in
    ``_gmm_func_4treat``:
    - three score conditions (T2 - p1, T3 - p2, T4 - p3);
    - three contrast-balance conditions.

    Block (a, b) is ``(1/n) * X_a_b.T @ X`` with ``X_a_b`` equal to ``wtX``
    rescaled row-wise by a coefficient that depends only on ``probs``.
    Score/contrast cross terms collapse to the constants +1 or -1.
    ``_gmm_func_4treat`` pseudo-inverts V to form the GMM weighting matrix.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    probs : np.ndarray
        Fitted probabilities, shape (n, 4).
    T1, T2, T3, T4 : np.ndarray
        Treatment indicators. They are not used in the body.
    wtX : np.ndarray
        Row-weighted covariates (``sample_weights[:, None] * X`` from the caller).
    n : int
        Number of observations used for scaling.

    Returns
    -------
    np.ndarray
        Symmetric (6k, 6k) matrix; shape and symmetry are asserted.
    """
    k = X.shape[1]
    # 21 block matrices with proper broadcasting
    # Score variances p_j * (1 - p_j)
    X_1_1 = wtX * (probs[:,1] * (1 - probs[:,1]))[:, None]
    X_2_2 = wtX * (probs[:,2] * (1 - probs[:,2]))[:, None]
    X_3_3 = wtX * (probs[:,3] * (1 - probs[:,3]))[:, None]
    # All three contrasts use +/-1 coefficients, so their second moments coincide
    X_4_4 = wtX * (probs[:,0]**(-1) + probs[:,1]**(-1) + probs[:,2]**(-1) + probs[:,3]**(-1))[:, None]
    X_5_5 = X_4_4
    X_6_6 = X_4_4
    # Score covariances -p_i * p_j
    X_1_2 = wtX * (-probs[:,1] * probs[:,2])[:, None]
    X_1_3 = wtX * (-probs[:,1] * probs[:,3])[:, None]
    X_2_3 = wtX * (-probs[:,2] * probs[:,3])[:, None]
    # Score/contrast cross terms: constants +1 ...
    X_1_4, X_1_6, X_3_5, X_3_6 = wtX, wtX, wtX, wtX
    # ... and -1
    X_1_5, X_2_4, X_2_5, X_2_6, X_3_4 = wtX * (-1), wtX * (-1), wtX * (-1), wtX * (-1), wtX * (-1)
    # Cross moments between contrast columns (products of their sign patterns)
    X_4_5 = wtX * (probs[:,0]**(-1) - probs[:,1]**(-1) + probs[:,2]**(-1) - probs[:,3]**(-1))[:, None]
    X_4_6 = wtX * (-probs[:,0]**(-1) + probs[:,1]**(-1) + probs[:,2]**(-1) - probs[:,3]**(-1))[:, None]
    X_5_6 = wtX * (-probs[:,0]**(-1) - probs[:,1]**(-1) + probs[:,2]**(-1) + probs[:,3]**(-1))[:, None]
    # Assemble 6x6 block matrix (lower triangle mirrors the upper one)
    V = (1.0/n) * np.block([[X_1_1.T @ X, X_1_2.T @ X, X_1_3.T @ X, X_1_4.T @ X, X_1_5.T @ X, X_1_6.T @ X],
                            [X_1_2.T @ X, X_2_2.T @ X, X_2_3.T @ X, X_2_4.T @ X, X_2_5.T @ X, X_2_6.T @ X],
                            [X_1_3.T @ X, X_2_3.T @ X, X_3_3.T @ X, X_3_4.T @ X, X_3_5.T @ X, X_3_6.T @ X],
                            [X_1_4.T @ X, X_2_4.T @ X, X_3_4.T @ X, X_4_4.T @ X, X_4_5.T @ X, X_4_6.T @ X],
                            [X_1_5.T @ X, X_2_5.T @ X, X_3_5.T @ X, X_4_5.T @ X, X_5_5.T @ X, X_5_6.T @ X],
                            [X_1_6.T @ X, X_2_6.T @ X, X_3_6.T @ X, X_4_6.T @ X, X_5_6.T @ X, X_6_6.T @ X]])
    assert V.shape == (6*k, 6*k) and np.allclose(V, V.T, atol=1e-12)
    return V
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _gmm_func_3treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, sample_weights: np.ndarray, n: int,
                     inv_V: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    Evaluate the GMM objective for a 3-level treatment.

    The stacked moment vector holds the two score conditions followed by the
    two contrast-balance conditions (flattened column by column). The loss is
    the quadratic form ``gbar' inv_V gbar``.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients, either flat (reshaped to (k, 2)) or already (k, 2).
    X : np.ndarray
        Covariates, shape (n, k).
    T1, T2, T3 : np.ndarray
        Treatment indicators.
    sample_weights : np.ndarray
        Per-observation weights.
    n : int
        Number of observations used for scaling.
    inv_V : np.ndarray, optional
        Weighting matrix. When omitted it is recomputed as the pseudo-inverse
        of the V matrix at ``beta_curr``.

    Returns
    -------
    dict
        ``{'loss': float, 'inv_V': np.ndarray}``.
    """
    n_cov = X.shape[1]
    if beta_curr.ndim == 1:
        beta_curr = beta_curr.reshape(n_cov, 2)

    probs = _compute_softmax_probs_3treat(X @ beta_curr, PROBS_MIN)
    contrasts = _compute_contrast_weights_3treat(T1, T2, T3, probs)
    weighted_X = sample_weights[:, None] * X
    scale = 1.0 / n

    score_2 = scale * weighted_X.T @ (T2 - probs[:, 1])
    score_3 = scale * weighted_X.T @ (T3 - probs[:, 2])
    balance = scale * weighted_X.T @ contrasts
    gbar = np.concatenate([score_2, score_3, balance.ravel(order='F')])

    if inv_V is None:
        inv_V = _r_ginv(_compute_V_matrix_3treat(X, probs, T1, T2, T3, weighted_X, n))

    return {'loss': float(gbar.T @ inv_V @ gbar), 'inv_V': inv_V}
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _gmm_func_4treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, T4: np.ndarray, sample_weights: np.ndarray, n: int,
                     inv_V: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    Evaluate the GMM objective for a 4-level treatment.

    The moment vector stacks three score conditions and then the three
    contrast-balance conditions (column-major). The loss is
    ``gbar' inv_V gbar``.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients, flat (reshaped to (k, 3)) or already (k, 3).
    X : np.ndarray
        Covariates, shape (n, k).
    T1, T2, T3, T4 : np.ndarray
        Treatment indicators.
    sample_weights : np.ndarray
        Per-observation weights.
    n : int
        Number of observations used for scaling.
    inv_V : np.ndarray, optional
        Weighting matrix; recomputed from the V matrix at ``beta_curr`` when omitted.

    Returns
    -------
    dict
        ``{'loss': float, 'inv_V': np.ndarray}``.
    """
    n_cov = X.shape[1]
    if beta_curr.ndim == 1:
        beta_curr = beta_curr.reshape(n_cov, 3)

    probs = _compute_softmax_probs_4treat(X @ beta_curr, PROBS_MIN)
    contrasts = _compute_contrast_weights_4treat(T1, T2, T3, T4, probs)
    weighted_X = sample_weights[:, None] * X
    scale = 1.0 / n

    score_2 = scale * weighted_X.T @ (T2 - probs[:, 1])
    score_3 = scale * weighted_X.T @ (T3 - probs[:, 2])
    score_4 = scale * weighted_X.T @ (T4 - probs[:, 3])
    balance = scale * weighted_X.T @ contrasts
    gbar = np.concatenate([score_2, score_3, score_4, balance.ravel(order='F')])

    if inv_V is None:
        inv_V = _r_ginv(_compute_V_matrix_4treat(X, probs, T1, T2, T3, T4, weighted_X, n))

    return {'loss': float(gbar.T @ inv_V @ gbar), 'inv_V': inv_V}
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _bal_loss_3treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, sample_weights: np.ndarray, XprimeX_inv: np.ndarray,
                     k: int, n: int) -> float:
    """
    Covariate-balance loss for a 3-level treatment.

    Computes the weighted covariate imbalance of each contrast column and
    returns the trace of ``B' XprimeX_inv B``, where ``B`` is the
    (k, 2) matrix of imbalances.

    Returns
    -------
    float
        The balance loss.
    """
    coef = beta_curr.reshape(k, 2) if beta_curr.ndim == 1 else beta_curr
    probs = _compute_softmax_probs_3treat(X @ coef, PROBS_MIN)
    # Contrast weights are scaled by 1/n before forming the imbalance.
    scaled_contrasts = _compute_contrast_weights_3treat(T1, T2, T3, probs) / n
    imbalance = (sample_weights[:, None] * X).T @ scaled_contrasts
    quad = imbalance.T @ XprimeX_inv @ imbalance
    return float(np.sum(np.diag(quad)))
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _bal_loss_4treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, T4: np.ndarray, sample_weights: np.ndarray,
                     XprimeX_inv: np.ndarray, k: int, n: int) -> float:
    """
    Covariate-balance loss for a 4-level treatment.

    Returns the trace of ``B' XprimeX_inv B``, where ``B`` is the (k, 3)
    matrix of weighted covariate imbalances of the three contrast columns.

    Returns
    -------
    float
        The balance loss.
    """
    coef = beta_curr.reshape(k, 3) if beta_curr.ndim == 1 else beta_curr
    probs = _compute_softmax_probs_4treat(X @ coef, PROBS_MIN)
    scaled_contrasts = _compute_contrast_weights_4treat(T1, T2, T3, T4, probs) / n
    imbalance = (sample_weights[:, None] * X).T @ scaled_contrasts
    quad = imbalance.T @ XprimeX_inv @ imbalance
    return float(np.sum(np.diag(quad)))
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _mnlogit_init_3treat(treat: np.ndarray, X: np.ndarray, sample_weights: np.ndarray,
                         treat_levels: np.ndarray, k: int, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Multinomial logit initialization for a 3-level treatment.

    Fits ``statsmodels.MNLogit`` of the encoded treatment on ``X`` and returns
    the coefficients and the corresponding (floored) fitted probabilities.

    Parameters
    ----------
    treat : array-like
        Treatment assignment. Values are either labels contained in
        ``treat_levels`` or integer codes ``0..len(treat_levels)-1``.
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Non-negative per-observation weights. MNLogit has no weights argument,
        so non-uniform weights are emulated by row replication.
    treat_levels : array-like
        Treatment levels. A label is encoded by its position in this sequence,
        so ``treat_levels[0]`` becomes the reference category.
    k, n : int
        Number of covariates and observations (kept for interface compatibility).

    Returns
    -------
    mcoef : np.ndarray
        Coefficients, shape (k, 2); NaN entries are replaced by 0.
    probs_mnl : np.ndarray
        Fitted probabilities, shape (n, 3).

    Raises
    ------
    KeyError
        If ``treat`` contains a value that is neither a level nor a valid code.
    ValueError
        If ``sample_weights`` contains negative values.
    """
    treat_array = np.asarray(treat)
    n_levels = len(treat_levels)

    # Encode by position in treat_levels so a user-chosen level order (and
    # hence the baseline category) is honoured even for integer-valued labels.
    treat_map = {level: i for i, level in enumerate(treat_levels)}
    try:
        treat_encoded = np.array([treat_map[t] for t in treat_array], dtype=int)
    except KeyError as missing_level:
        # Not labels from treat_levels: accept the values only if they already
        # are valid integer codes 0..n_levels-1.
        try:
            treat_as_int = treat_array.astype(int)
        except (ValueError, TypeError):
            raise missing_level from None
        is_code = (np.array_equal(treat_as_int, treat_array)
                   and np.all((treat_as_int >= 0) & (treat_as_int < n_levels)))
        if not is_code:
            raise missing_level
        treat_encoded = treat_as_int

    sample_weights = np.asarray(sample_weights, dtype=float)
    if np.unique(sample_weights).size == 1:
        # Uniform weights: fit directly.
        X_fit, treat_fit = X, treat_encoded
    else:
        if np.any(sample_weights < 0):
            raise ValueError("sample_weights must be non-negative")
        # Zero-weight rows contribute nothing to a weighted likelihood; drop
        # them so the min-weight normalization below never divides by zero.
        keep = sample_weights > 0
        kept_weights = sample_weights[keep]
        weights_normalized = kept_weights / kept_weights.min()
        weights_int_candidate = np.round(weights_normalized)
        if np.allclose(weights_normalized, weights_int_candidate, atol=1e-6):
            # Integer multiples of the smallest weight: replication is exact.
            reps = weights_int_candidate.astype(int)
        else:
            # Fractional weights: approximate on a grid of 1/100 of the
            # smallest weight, keeping every row at least once.
            reps = np.maximum(np.round(weights_normalized * 100).astype(int), 1)
        X_fit = np.repeat(np.asarray(X)[keep], reps, axis=0)
        treat_fit = np.repeat(treat_encoded[keep], reps)

    mnl_result = sm.MNLogit(treat_fit, X_fit).fit(maxiter=100, disp=False, method='bfgs')

    # statsmodels returns params in (k, K-1) layout; copy into a plain float
    # array and zero out any NaN (unestimable) coefficients.
    mcoef = np.array(mnl_result.params, dtype=float)
    mcoef[np.isnan(mcoef)] = 0.0

    theta_mnl = X @ mcoef  # (n, 2)
    probs_mnl = _compute_softmax_probs_3treat(theta_mnl, PROBS_MIN)
    return mcoef, probs_mnl
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _mnlogit_init_4treat(treat: np.ndarray, X: np.ndarray, sample_weights: np.ndarray,
                         treat_levels: np.ndarray, k: int, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Multinomial logit initialization for a 4-level treatment.

    Fits ``statsmodels.MNLogit`` of the encoded treatment on ``X`` and returns
    the coefficients and the corresponding (floored) fitted probabilities.

    Parameters
    ----------
    treat : array-like
        Treatment assignment. Values are either labels contained in
        ``treat_levels`` or integer codes ``0..len(treat_levels)-1``.
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Non-negative per-observation weights, emulated by row replication
        because MNLogit has no weights argument.
    treat_levels : array-like
        Treatment levels; labels are encoded by their position here, so
        ``treat_levels[0]`` is the reference category.
    k, n : int
        Number of covariates and observations (kept for interface compatibility).

    Returns
    -------
    mcoef : np.ndarray
        Coefficients, shape (k, 3); NaN entries are replaced by 0.
    probs_mnl : np.ndarray
        Fitted probabilities, shape (n, 4).

    Raises
    ------
    KeyError
        If ``treat`` contains a value that is neither a level nor a valid code.
    ValueError
        If ``sample_weights`` contains negative values.
    """
    treat_array = np.asarray(treat)
    n_levels = len(treat_levels)

    # Encode by position in treat_levels so a user-chosen level order (and
    # hence the baseline category) is honoured even for integer-valued labels.
    treat_map = {level: i for i, level in enumerate(treat_levels)}
    try:
        treat_encoded = np.array([treat_map[t] for t in treat_array], dtype=int)
    except KeyError as missing_level:
        # Not labels from treat_levels: accept the values only if they already
        # are valid integer codes 0..n_levels-1.
        try:
            treat_as_int = treat_array.astype(int)
        except (ValueError, TypeError):
            raise missing_level from None
        is_code = (np.array_equal(treat_as_int, treat_array)
                   and np.all((treat_as_int >= 0) & (treat_as_int < n_levels)))
        if not is_code:
            raise missing_level
        treat_encoded = treat_as_int

    sample_weights = np.asarray(sample_weights, dtype=float)
    if np.unique(sample_weights).size == 1:
        # Uniform weights: fit directly.
        X_fit, treat_fit = X, treat_encoded
    else:
        if np.any(sample_weights < 0):
            raise ValueError("sample_weights must be non-negative")
        # Zero-weight rows contribute nothing to a weighted likelihood; drop
        # them so the min-weight normalization below never divides by zero.
        keep = sample_weights > 0
        kept_weights = sample_weights[keep]
        weights_normalized = kept_weights / kept_weights.min()
        weights_int_candidate = np.round(weights_normalized)
        if np.allclose(weights_normalized, weights_int_candidate, atol=1e-6):
            # Integer multiples of the smallest weight: replication is exact.
            reps = weights_int_candidate.astype(int)
        else:
            # Fractional weights: approximate on a grid of 1/100 of the
            # smallest weight, keeping every row at least once.
            reps = np.maximum(np.round(weights_normalized * 100).astype(int), 1)
        X_fit = np.repeat(np.asarray(X)[keep], reps, axis=0)
        treat_fit = np.repeat(treat_encoded[keep], reps)

    mnl_result = sm.MNLogit(treat_fit, X_fit).fit(maxiter=100, disp=False, method='bfgs')

    # statsmodels returns params in (k, K-1) layout; copy into a plain float
    # array and zero out any NaN (unestimable) coefficients.
    mcoef = np.array(mnl_result.params, dtype=float)
    mcoef[np.isnan(mcoef)] = 0.0

    theta_mnl = X @ mcoef  # (n, 3)
    probs_mnl = _compute_softmax_probs_4treat(theta_mnl, PROBS_MIN)
    return mcoef, probs_mnl
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _standardize_weights_3treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray,
|
|
474
|
+
probs_opt: np.ndarray, sample_weights: np.ndarray,
|
|
475
|
+
standardize: bool) -> np.ndarray:
|
|
476
|
+
"""Standardize weights for 3-level treatment."""
|
|
477
|
+
if standardize:
|
|
478
|
+
norm1 = np.sum(T1 * sample_weights / probs_opt[:,0])
|
|
479
|
+
norm2 = np.sum(T2 * sample_weights / probs_opt[:,1])
|
|
480
|
+
norm3 = np.sum(T3 * sample_weights / probs_opt[:,2])
|
|
481
|
+
else:
|
|
482
|
+
norm1 = norm2 = norm3 = 1.0
|
|
483
|
+
w_opt = (T1 / probs_opt[:,0] / norm1 +
|
|
484
|
+
T2 / probs_opt[:,1] / norm2 +
|
|
485
|
+
T3 / probs_opt[:,2] / norm3)
|
|
486
|
+
return w_opt
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _standardize_weights_4treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray, T4: np.ndarray,
|
|
490
|
+
probs_opt: np.ndarray, sample_weights: np.ndarray,
|
|
491
|
+
standardize: bool) -> np.ndarray:
|
|
492
|
+
"""Standardize weights for 4-level treatment."""
|
|
493
|
+
if standardize:
|
|
494
|
+
norm1 = np.sum(T1 * sample_weights / probs_opt[:,0])
|
|
495
|
+
norm2 = np.sum(T2 * sample_weights / probs_opt[:,1])
|
|
496
|
+
norm3 = np.sum(T3 * sample_weights / probs_opt[:,2])
|
|
497
|
+
norm4 = np.sum(T4 * sample_weights / probs_opt[:,3])
|
|
498
|
+
else:
|
|
499
|
+
norm1 = norm2 = norm3 = norm4 = 1.0
|
|
500
|
+
w_opt = (T1 / probs_opt[:,0] / norm1 + T2 / probs_opt[:,1] / norm2 +
|
|
501
|
+
T3 / probs_opt[:,2] / norm3 + T4 / probs_opt[:,3] / norm4)
|
|
502
|
+
return w_opt
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _check_and_fallback_to_mle(J_opt: float, beta_opt: np.ndarray, probs_opt: np.ndarray,
|
|
506
|
+
mcoef: np.ndarray, probs_mnl: np.ndarray,
|
|
507
|
+
gmm_loss_func: Any, bal_loss_func: Any) -> Tuple[np.ndarray, np.ndarray, float, bool]:
|
|
508
|
+
"""Check MLE fallback with dual AND condition."""
|
|
509
|
+
mle_J = gmm_loss_func(mcoef.ravel())
|
|
510
|
+
mle_bal = bal_loss_func(mcoef.ravel())
|
|
511
|
+
opt_bal = bal_loss_func(beta_opt.ravel())
|
|
512
|
+
if (J_opt > mle_J) and (opt_bal > mle_bal):
|
|
513
|
+
warnings.warn("Optimization failed. Results returned are for MLE.")
|
|
514
|
+
return mcoef, probs_mnl, mle_J, True
|
|
515
|
+
return beta_opt, probs_opt, J_opt, False
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def _compute_vcov_3treat(beta_opt: np.ndarray, probs_opt: np.ndarray, T1: np.ndarray,
                         T2: np.ndarray, T3: np.ndarray, X: np.ndarray,
                         sample_weights: np.ndarray, gmm_func: Any, n: int, k: int) -> np.ndarray:
    """
    Compute variance-covariance matrix for 3-level treatment.

    Uses the GMM sandwich form

        vcov = (G W G')^+ G W Omega W G' (G W G')^+

    where:
    - ``W`` is the ``inv_V`` weighting matrix that ``gmm_func`` returns at
      ``beta_opt``.
    - ``G`` (2k x 4k) holds, for each of the two coefficient columns, the
      derivatives of the four stacked moment conditions (two score conditions,
      two contrast-balance conditions).
    - ``Omega = (1/n) W1 W1'`` is built from the per-observation moment
      contributions scaled by ``sqrt(sample_weights)``.

    ``^+`` denotes the pseudo-inverse computed by ``_r_ginv``.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficients; passed flattened to ``gmm_func``.
    probs_opt : np.ndarray
        Fitted probabilities at ``beta_opt``, shape (n, 3).
    T1, T2, T3 : np.ndarray
        Treatment indicators.
    X : np.ndarray
        Covariates, shape (n, k).
    sample_weights : np.ndarray
        Per-observation weights.
    gmm_func : callable
        Called as ``gmm_func(beta_flat, inv_V=None)``; must return a dict
        containing ``'inv_V'``.
    n, k : int
        Number of observations and covariates.

    Returns
    -------
    np.ndarray
        (2k, 2k) covariance matrix of the coefficients.
    """
    wtX = sample_weights[:, None] * X
    # Recompute invV
    result = gmm_func(beta_opt.ravel(), inv_V=None)
    W = result['inv_V']
    # 8 XG block matrices with proper broadcasting
    # Row block 1: derivatives w.r.t. the category-2 coefficients
    XG_1_1 = (-wtX * (probs_opt[:,1] * (1 - probs_opt[:,1]))[:, None]).T @ X
    XG_1_2 = (wtX * (probs_opt[:,1] * probs_opt[:,2])[:, None]).T @ X
    XG_1_3 = (wtX * (2*T1*probs_opt[:,1]/probs_opt[:,0] + T2*(1-probs_opt[:,1])/probs_opt[:,1] -
                     T3*probs_opt[:,1]/probs_opt[:,2])[:, None]).T @ X
    XG_1_4 = (wtX * (-T2*(1-probs_opt[:,1])/probs_opt[:,1] - T3*probs_opt[:,1]/probs_opt[:,2])[:, None]).T @ X
    # Row block 2: derivatives w.r.t. the category-3 coefficients
    XG_2_1 = (wtX * (probs_opt[:,1] * probs_opt[:,2])[:, None]).T @ X
    XG_2_2 = (-wtX * (probs_opt[:,2] * (1 - probs_opt[:,2]))[:, None]).T @ X
    XG_2_3 = (wtX * (2*T1*probs_opt[:,2]/probs_opt[:,0] - T2*probs_opt[:,2]/probs_opt[:,1] +
                     T3*(1-probs_opt[:,2])/probs_opt[:,2])[:, None]).T @ X
    XG_2_4 = (wtX * (T2*probs_opt[:,2]/probs_opt[:,1] + T3*(1-probs_opt[:,2])/probs_opt[:,2])[:, None]).T @ X
    # Assemble G matrix (2k x 4k)
    G = (1.0/n) * np.vstack([
        np.hstack([XG_1_1, XG_1_2, XG_1_3, XG_1_4]),
        np.hstack([XG_2_1, XG_2_2, XG_2_3, XG_2_4])
    ])
    # W1 matrix (4k x n): per-observation moment contributions, each row of X
    # scaled by sqrt(weight) so that W1 @ W1.T carries the full weight.
    XW_1 = X * (T2 - probs_opt[:,1])[:, None] * (sample_weights**0.5)[:, None]
    XW_2 = X * (T3 - probs_opt[:,2])[:, None] * (sample_weights**0.5)[:, None]
    XW_3 = X * (2*T1/probs_opt[:,0] - T2/probs_opt[:,1] - T3/probs_opt[:,2])[:, None] * (sample_weights**0.5)[:, None]
    XW_4 = X * (T2/probs_opt[:,1] - T3/probs_opt[:,2])[:, None] * (sample_weights**0.5)[:, None]
    W1 = np.vstack([XW_1.T, XW_2.T, XW_3.T, XW_4.T])
    # Omega matrix
    Omega = (1.0/n) * (W1 @ W1.T)
    # Sandwich formula
    GWG = G @ W @ G.T
    GWGinv = _r_ginv(GWG)
    GWGinvGW = GWGinv @ G @ W
    vcov = GWGinvGW @ Omega @ GWGinvGW.T
    assert vcov.shape == (2*k, 2*k)
    return vcov
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def _compute_vcov_4treat(beta_opt: np.ndarray, probs_opt: np.ndarray, T1: np.ndarray,
                         T2: np.ndarray, T3: np.ndarray, T4: np.ndarray, X: np.ndarray,
                         sample_weights: np.ndarray, gmm_func: Any, n: int, k: int) -> np.ndarray:
    """
    Sandwich variance-covariance matrix for the 4-level treatment CBPS fit.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimal coefficients; raveled before being handed to ``gmm_func``.
    probs_opt : np.ndarray, shape (n, 4)
        Fitted probabilities; column j pairs with indicator T{j+1}.
    T1, T2, T3, T4 : np.ndarray, shape (n,)
        Treatment-level indicators.
    X : np.ndarray, shape (n, k)
        Covariate matrix.
    sample_weights : np.ndarray, shape (n,)
        Sampling weights.
    gmm_func : callable
        Called as ``gmm_func(beta, inv_V=None)``; the ``'inv_V'`` entry of its
        result is used as the GMM weighting matrix W (6k x 6k).
    n : int
        Number of observations.
    k : int
        Number of covariate columns.

    Returns
    -------
    np.ndarray, shape (3k, 3k)
        ``A @ Omega @ A.T`` with ``A = ginv(G W G') G W``, where G (3k x 6k)
        holds the derivatives of the six moment blocks with respect to the
        three linear predictors and ``Omega = (1/n) * sum_i w_i g_i g_i'``.
    """
    p0 = probs_opt[:, 0]
    p1 = probs_opt[:, 1]
    p2 = probs_opt[:, 2]
    p3 = probs_opt[:, 3]
    wtX = sample_weights[:, None] * X
    sqrt_w = (sample_weights**0.5)[:, None]

    W = gmm_func(beta_opt.ravel(), inv_V=None)['inv_V']

    def xg(coef):
        # One k x k block: sum_i w_i * coef_i * x_i x_i'
        return (wtX * coef[:, None]).T @ X

    # Row block r = derivative w.r.t. the linear predictor of level r+1.
    # Column blocks 1-3: score moments (T_{j+1} - p_j);
    # column blocks 4-6: the three +/-1 contrasts of T_m / p_m.
    G = (1.0 / n) * np.block([
        [xg(-(p1 * (1 - p1))), xg(p1 * p2), xg(p1 * p3),
         xg(p1 * (T1/p0 - T2*(1 - p1)/p1**2 - T3/p2 - T4/p3)),
         xg(p1 * (T1/p0 + T2*(1 - p1)/p1**2 - T3/p2 + T4/p3)),
         xg(p1 * (-T1/p0 - T2*(1 - p1)/p1**2 - T3/p2 + T4/p3))],
        [xg(p1 * p2), xg(-(p2 * (1 - p2))), xg(p2 * p3),
         xg(p2 * (T1/p0 + T2/p1 + T3*(1 - p2)/p2**2 - T4/p3)),
         xg(p2 * (T1/p0 - T2/p1 + T3*(1 - p2)/p2**2 + T4/p3)),
         xg(p2 * (-T1/p0 + T2/p1 + T3*(1 - p2)/p2**2 + T4/p3))],
        [xg(p1 * p3), xg(p2 * p3), xg(-(p3 * (1 - p3))),
         xg(p3 * (T1/p0 + T2/p1 - T3/p2 + T4*(1 - p3)/p3**2)),
         xg(p3 * (T1/p0 - T2/p1 - T3/p2 - T4*(1 - p3)/p3**2)),
         xg(p3 * (-T1/p0 + T2/p1 - T3/p2 - T4*(1 - p3)/p3**2))],
    ])

    # Per-observation moment columns (6k x n after stacking), sqrt-weighted so
    # that W1 @ W1.T carries one factor of the sampling weight.
    moments = (
        T2 - p1,
        T3 - p2,
        T4 - p3,
        T1/p0 + T2/p1 - T3/p2 - T4/p3,
        T1/p0 - T2/p1 - T3/p2 + T4/p3,
        -T1/p0 + T2/p1 - T3/p2 + T4/p3,
    )
    W1 = np.vstack([(X * m[:, None] * sqrt_w).T for m in moments])
    Omega = (1.0 / n) * (W1 @ W1.T)

    # Sandwich: ginv(G W G') G W  Omega  (ginv(G W G') G W)'
    GWG_inv = _r_ginv(G @ W @ G.T)
    A = GWG_inv @ G @ W
    vcov = A @ Omega @ A.T
    assert vcov.shape == (3 * k, 3 * k)
    return vcov
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def cbps_3treat_fit(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    k: int = None,
    XprimeX_inv: np.ndarray = None,
    bal_only: bool = False,
    iterations: int = 1000,
    standardize: bool = True,
    two_step: bool = True,
    sample_weights: np.ndarray = None,
    treat_levels: np.ndarray = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    Fit CBPS for 3-level categorical treatments.

    Runs the full CBPS algorithm for treatments with exactly three levels:
    multinomial-logit initialization followed by GMM optimization of the
    score and covariate-balance moment conditions.

    Parameters
    ----------
    treat : np.ndarray, shape (n,)
        Treatment vector taking exactly 3 distinct values.
    X : np.ndarray, shape (n, k)
        Covariate matrix (SVD-orthogonalized if applicable). Sparse input is
        converted with ``ensure_dense``.
    method : str, default 'over'
        'over' for overidentified GMM, 'exact' for exactly identified GMM.
        NOTE(review): this value is not read inside this function; presumably
        the caller maps 'exact' to ``bal_only=True`` -- confirm at call site.
    k : int, optional
        Rank of the covariate matrix after SVD. Defaults to ``X.shape[1]``.
    XprimeX_inv : np.ndarray, shape (k, k), optional
        Generalized inverse of the sample-weighted X'X, used by the balance
        loss. Computed from ``X`` and ``sample_weights`` when None.
    bal_only : bool, default False
        If True, return the balance-only solution (no score conditions).
    iterations : int, default 1000
        ``maxiter`` passed to every optimizer call.
    standardize : bool, default True
        If True, apply weight standardization; if False, use Horvitz-Thompson
        weights.
    two_step : bool, default True
        If True, two-step GMM with the weighting matrix fixed at the initial
        values. If False, continuous-updating GMM; there each BFGS run is
        retried with Nelder-Mead when it raises LinAlgError, ValueError or
        RuntimeError.
    sample_weights : np.ndarray, optional
        Sampling weights (normalized via ``normalize_sample_weights``).
        None means uniform weights.
    treat_levels : np.ndarray, optional
        Ordered treatment levels, default ``np.unique(treat)``. Coefficients
        are estimated for the 2nd and 3rd levels only, so the first level
        appears to act as the reference category.
    verbose : int, default 0
        1 or 2 calls ``set_verbosity`` with that level; 0 leaves logging as is.

    Returns
    -------
    Dict[str, Any]
        - coefficients: coefficients in orthogonal space, shape (k, 2)
        - fitted_values: probability matrix, shape (n, 3)
        - linear_predictor: ``X @ coefficients``, shape (n, 2)
        - deviance: -2 * log-likelihood of the fitted probabilities
        - nulldeviance: -2 * log-likelihood of the weighted-share null model
        - weights: standardized weights multiplied by ``sample_weights``
        - y: treatment vector
        - x: (dense) covariate matrix
        - converged: optimizer success flag
        - J: balance loss (bal_only) or GMM J-statistic
        - var: covariance matrix in orthogonal space, shape (2k, 2k)
        - mle_J: J-statistic evaluated at the MNLogit coefficients

    Algorithm Flow
    --------------
    1. Initialize constants and treatment indicators
    2. MNLogit initialization
    3. Alpha scaling
    4. Balance optimization
    5. Return if bal_only=True
    6. GMM dual initialization optimization
    7. Compute optimal probabilities and J-statistic
    8. Check for MLE fallback
    9. Compute deviance and weight standardization
    10. Compute covariance matrix
    11. Construct return object

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    # ========== Initialization ==========
    # Ensure dense matrix (sparse input auto-converted)
    X = ensure_dense(X)

    # Configure logging from verbose parameter (backward compatibility)
    if verbose >= 2:
        set_verbosity(2)
    elif verbose >= 1:
        set_verbosity(1)

    n = len(treat)

    # Treatment levels and one-hot indicators
    if treat_levels is None:
        treat_levels = np.unique(treat)
    assert len(treat_levels) == 3, "Must be 3-valued treatment"

    T1 = (treat == treat_levels[0]).astype(float)
    T2 = (treat == treat_levels[1]).astype(float)
    T3 = (treat == treat_levels[2]).astype(float)

    sample_weights = normalize_sample_weights(sample_weights, n)

    if k is None:
        k = X.shape[1]
    if XprimeX_inv is None:
        wtX_sqrt = (sample_weights**0.5)[:, None] * X
        XprimeX_inv = _r_ginv(wtX_sqrt.T @ wtX_sqrt)

    # ========== Closures over the data ==========
    def gmm_func(beta, inv_V=None):
        # Full GMM evaluation; the result is indexed by 'loss' and 'inv_V'.
        return _gmm_func_3treat(beta, X, T1, T2, T3, sample_weights, n, inv_V)

    def gmm_loss(beta):
        return gmm_func(beta)['loss']

    def bal_loss(beta):
        return _bal_loss_3treat(beta, X, T1, T2, T3, sample_weights, XprimeX_inv, k, n)

    def run_optimizer(fun, x0):
        # Two-step: plain BFGS (errors propagate). Continuous-updating: BFGS,
        # retried with Nelder-Mead when BFGS raises a numerical error.
        if two_step:
            return scipy.optimize.minimize(fun, x0, method='BFGS',
                                           options={'maxiter': iterations})
        try:
            return scipy.optimize.minimize(fun, x0, method='BFGS',
                                           options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            return scipy.optimize.minimize(fun, x0, method='Nelder-Mead',
                                           options={'maxiter': iterations})

    def neg2_loglik(probs):
        return -2 * np.sum(T1 * np.log(probs[:, 0]) + T2 * np.log(probs[:, 1])
                           + T3 * np.log(probs[:, 2]))

    # ========== MNLogit initialization ==========
    mcoef, probs_mnl = _mnlogit_init_3treat(treat, X, sample_weights, treat_levels, k, n)

    # ========== Alpha scaling ==========
    # Rescale the MNLogit coefficients by the alpha in [0.8, 1.1] with the
    # lowest continuous-updating GMM loss.
    alpha_result = scipy.optimize.minimize_scalar(
        lambda alpha: gmm_loss(mcoef.ravel() * alpha),
        bounds=(0.8, 1.1), method='bounded')
    gmm_init = mcoef.ravel() * alpha_result.x

    # ========== Pre-compute invV (two-step method) ==========
    this_invV = gmm_func(gmm_init)['inv_V']

    def gmm_loss_with_invV(beta):
        return gmm_func(beta, this_invV)['loss']

    def j_statistic(beta):
        # Two-step uses the fixed weighting matrix; otherwise continuous-updating.
        return gmm_loss_with_invV(beta) if two_step else gmm_loss(beta)

    # ========== Balance optimization ==========
    logger.info("Starting balance optimization (max_iter=%s)...", iterations)
    opt_bal = run_optimizer(bal_loss, gmm_init)
    logger.info("Balance optimization complete: loss=%.6f, converged=%s",
                opt_bal.fun, opt_bal.success)
    beta_bal = opt_bal.x

    # ========== Null deviance (shared by all return paths) ==========
    # Null model: each category's probability is its weighted sample share,
    # clipped away from 0 so log() stays finite.
    m1, m2, m3 = (np.clip(np.average(T, weights=sample_weights), 1e-10, 1.0)
                  for T in (T1, T2, T3))
    nulldeviance = -2 * np.sum(T1 * np.log(m1) + T2 * np.log(m2) + T3 * np.log(m3))

    mle_J_val = j_statistic(mcoef.ravel())

    def build_result(beta_opt, theta_opt, probs_opt, w_opt, converged, J_opt, vcov):
        return {
            'coefficients': beta_opt,
            'fitted_values': probs_opt,
            'linear_predictor': theta_opt,
            'deviance': neg2_loglik(probs_opt),
            'nulldeviance': nulldeviance,
            'weights': w_opt * sample_weights,
            'y': treat,
            'x': X,
            'converged': converged,
            'J': J_opt,
            'var': vcov,
            'mle_J': mle_J_val,
        }

    # ========== bal_only early return ==========
    if bal_only:
        beta_opt = beta_bal.reshape(k, 2)
        theta_opt = X @ beta_opt
        probs_opt = _compute_softmax_probs_3treat(theta_opt, PROBS_MIN)
        w_opt = _standardize_weights_3treat(T1, T2, T3, probs_opt, sample_weights, standardize)
        vcov = _compute_vcov_3treat(beta_opt, probs_opt, T1, T2, T3, X, sample_weights,
                                    gmm_func, n, k)
        return build_result(beta_opt, theta_opt, probs_opt, w_opt, opt_bal.success,
                            bal_loss(beta_opt.ravel()), vcov)

    # ========== GMM dual initialization selection ==========
    gmm_objective = gmm_loss_with_invV if two_step else gmm_loss
    gmm_glm_init = run_optimizer(gmm_objective, gmm_init)
    gmm_bal_init = run_optimizer(gmm_objective, beta_bal)
    # Keep whichever start reached the lower loss
    opt1 = gmm_glm_init if gmm_glm_init.fun < gmm_bal_init.fun else gmm_bal_init

    # ========== Optimal probabilities and J-statistic ==========
    beta_opt = opt1.x.reshape(k, 2)
    probs_opt = _compute_softmax_probs_3treat(X @ beta_opt, PROBS_MIN)
    J_opt = j_statistic(beta_opt.ravel())

    # ========== MLE fallback check ==========
    beta_opt, probs_opt, J_opt, _used_mle = _check_and_fallback_to_mle(
        J_opt, beta_opt, probs_opt, mcoef, probs_mnl, gmm_loss, bal_loss
    )
    # Derive the linear predictor AFTER the fallback check so it always agrees
    # with the returned coefficients (the fallback may presumably substitute
    # the MNLogit estimates).
    beta_opt = np.asarray(beta_opt).reshape(k, 2)
    theta_opt = X @ beta_opt

    # ========== Weights and vcov ==========
    w_opt = _standardize_weights_3treat(T1, T2, T3, probs_opt, sample_weights, standardize)
    vcov = _compute_vcov_3treat(beta_opt, probs_opt, T1, T2, T3, X, sample_weights,
                                gmm_func, n, k)

    # Enhanced non-convergence warning
    if not opt1.success:
        warnings.warn(
            f"Multi-valued CBPS (3-treat) optimization did not converge (converged=False). "
            f"Results may be unreliable. Consider:\n"
            f"  1. Increasing iterations (current: {iterations})\n"
            f"  2. Checking for perfect separation or collinearity\n"
            f"  3. Examining the balance diagnostics\n"
            f"  4. J-statistic: {J_opt:.6f}\n"
            f"  5. Trying different starting values or method='exact'",
            UserWarning,
            stacklevel=2
        )

    return build_result(beta_opt, theta_opt, probs_opt, w_opt, opt1.success, J_opt, vcov)
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
def cbps_4treat_fit(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    k: int = None,
    XprimeX_inv: np.ndarray = None,
    bal_only: bool = False,
    iterations: int = 1000,
    standardize: bool = True,
    two_step: bool = True,
    sample_weights: np.ndarray = None,
    treat_levels: np.ndarray = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    4-valued treatment CBPS fitting function (complete workflow).

    Four-valued treatment CBPS estimator using GMM optimization.

    Parameters
    ----------
    Same as cbps_3treat_fit, except that ``treat`` must take exactly 4
    distinct values (``treat_levels`` then has 4 entries). ``method`` is not
    read inside this function (NOTE(review): presumably handled by the caller
    via ``bal_only`` -- confirm).

    Returns
    -------
    Dict[str, Any]
        Dictionary with 12 keys (same structure as the 3-level fit):

        - coefficients: (k, 3) orthogonal-space coefficients
        - fitted_values: (n, 4) probability matrix
        - linear_predictor: (n, 3) ``X @ coefficients``
        - deviance: -2 * log-likelihood of the fitted probabilities
        - nulldeviance: -2 * log-likelihood of the weighted-share null model
        - weights: (n,) standardized weights multiplied by sample_weights
        - y: treatment vector
        - x: (dense) covariate matrix
        - converged: optimizer success flag
        - J: balance loss (bal_only) or GMM J-statistic
        - var: (3k, 3k) orthogonal-space vcov
        - mle_J: J-statistic evaluated at the MNLogit coefficients

    Algorithm Flow
    --------------
    Mostly the same as cbps_3treat_fit; main differences:
    - K=4 levels -> 3 coefficient columns
    - softmax computes 4 probability columns
    - 3 balance-contrast moment columns
    - the GMM weighting matrix is (6k, 6k); G is (3k, 6k)

    Notes
    -----
    With ``two_step=True`` the fixed weighting matrix is taken from whichever
    of the alpha-scaled MNLogit start or the balance solution has the lower
    continuous-updating GMM loss (always the balance solution when
    ``bal_only=True``). With ``two_step=False`` it is taken at the
    alpha-scaled MNLogit start.
    """
    # ========== Initialization ==========
    # Ensure dense matrix (sparse input auto-converted)
    X = ensure_dense(X)

    # Configure logging from verbose parameter (backward compatibility)
    if verbose >= 2:
        set_verbosity(2)
    elif verbose >= 1:
        set_verbosity(1)

    n = len(treat)

    # Treatment levels and one-hot indicators
    if treat_levels is None:
        treat_levels = np.unique(treat)
    assert len(treat_levels) == 4, "Must be 4-valued treatment"

    T1 = (treat == treat_levels[0]).astype(float)
    T2 = (treat == treat_levels[1]).astype(float)
    T3 = (treat == treat_levels[2]).astype(float)
    T4 = (treat == treat_levels[3]).astype(float)

    sample_weights = normalize_sample_weights(sample_weights, n)

    if k is None:
        k = X.shape[1]
    if XprimeX_inv is None:
        wtX_sqrt = (sample_weights**0.5)[:, None] * X
        XprimeX_inv = _r_ginv(wtX_sqrt.T @ wtX_sqrt)

    # ========== Closures over the data ==========
    def gmm_func(beta, inv_V=None):
        # Full GMM evaluation; the result is indexed by 'loss' and 'inv_V'.
        return _gmm_func_4treat(beta, X, T1, T2, T3, T4, sample_weights, n, inv_V)

    def gmm_loss(beta):
        return gmm_func(beta)['loss']

    def bal_loss(beta):
        return _bal_loss_4treat(beta, X, T1, T2, T3, T4, sample_weights, XprimeX_inv, k, n)

    def run_optimizer(fun, x0):
        # Two-step: plain BFGS (errors propagate). Continuous-updating: BFGS,
        # retried with Nelder-Mead when BFGS raises a numerical error.
        if two_step:
            return scipy.optimize.minimize(fun, x0, method='BFGS',
                                           options={'maxiter': iterations})
        try:
            return scipy.optimize.minimize(fun, x0, method='BFGS',
                                           options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            return scipy.optimize.minimize(fun, x0, method='Nelder-Mead',
                                           options={'maxiter': iterations})

    def neg2_loglik(probs):
        return -2 * np.sum(T1 * np.log(probs[:, 0]) + T2 * np.log(probs[:, 1])
                           + T3 * np.log(probs[:, 2]) + T4 * np.log(probs[:, 3]))

    # ========== MNLogit initialization ==========
    mcoef, probs_mnl = _mnlogit_init_4treat(treat, X, sample_weights, treat_levels, k, n)

    # ========== Alpha scaling ==========
    # Rescale the MNLogit coefficients by the alpha in [0.8, 1.1] with the
    # lowest continuous-updating GMM loss.
    alpha_result = scipy.optimize.minimize_scalar(
        lambda alpha: gmm_loss(mcoef.ravel() * alpha),
        bounds=(0.8, 1.1), method='bounded')
    gmm_init = mcoef.ravel() * alpha_result.x

    # ========== Pre-compute GMM at the initial values ==========
    # Evaluated once; both its loss and inv_V are reused below.
    init_eval = gmm_func(gmm_init)
    temp_invV = init_eval['inv_V']

    # ========== Balance optimization ==========
    logger.info("Starting balance optimization (max_iter=%s)...", iterations)
    opt_bal = run_optimizer(bal_loss, gmm_init)
    logger.info("Balance optimization complete: loss=%.6f, converged=%s",
                opt_bal.fun, opt_bal.success)
    beta_bal = opt_bal.x

    # ========== Null deviance (shared by all return paths) ==========
    # Null model: each category's probability is its weighted sample share,
    # clipped away from 0 so log() stays finite.
    m1, m2, m3, m4 = (np.clip(np.average(T, weights=sample_weights), 1e-10, 1.0)
                      for T in (T1, T2, T3, T4))
    nulldeviance = -2 * np.sum(T1 * np.log(m1) + T2 * np.log(m2) +
                               T3 * np.log(m3) + T4 * np.log(m4))

    # ========== 4-treat specific: invV selection ==========
    if not two_step:
        this_invV = temp_invV
    elif bal_only:
        this_invV = gmm_func(beta_bal)['inv_V']
    else:
        bal_eval = gmm_func(beta_bal)
        this_invV = temp_invV if init_eval['loss'] < bal_eval['loss'] else bal_eval['inv_V']

    def gmm_loss_with_invV(beta):
        return gmm_func(beta, this_invV)['loss']

    def j_statistic(beta):
        # Two-step uses the fixed weighting matrix; otherwise continuous-updating.
        return gmm_loss_with_invV(beta) if two_step else gmm_loss(beta)

    mle_J_val = j_statistic(mcoef.ravel())

    def build_result(beta_opt, theta_opt, probs_opt, w_opt, converged, J_opt, vcov):
        return {
            'coefficients': beta_opt,
            'fitted_values': probs_opt,
            'linear_predictor': theta_opt,
            'deviance': neg2_loglik(probs_opt),
            'nulldeviance': nulldeviance,
            'weights': w_opt * sample_weights,
            'y': treat,
            'x': X,
            'converged': converged,
            'J': J_opt,
            'var': vcov,
            'mle_J': mle_J_val,
        }

    # ========== bal_only early return ==========
    if bal_only:
        beta_opt = beta_bal.reshape(k, 3)
        theta_opt = X @ beta_opt
        probs_opt = _compute_softmax_probs_4treat(theta_opt, PROBS_MIN)
        w_opt = _standardize_weights_4treat(T1, T2, T3, T4, probs_opt, sample_weights, standardize)
        vcov = _compute_vcov_4treat(beta_opt, probs_opt, T1, T2, T3, T4, X, sample_weights,
                                    gmm_func, n, k)
        return build_result(beta_opt, theta_opt, probs_opt, w_opt, opt_bal.success,
                            bal_loss(beta_opt.ravel()), vcov)

    # ========== GMM dual initialization selection ==========
    gmm_objective = gmm_loss_with_invV if two_step else gmm_loss
    gmm_glm_init = run_optimizer(gmm_objective, gmm_init)
    gmm_bal_init = run_optimizer(gmm_objective, beta_bal)
    # Keep whichever start reached the lower loss
    opt1 = gmm_glm_init if gmm_glm_init.fun < gmm_bal_init.fun else gmm_bal_init

    # ========== Optimal probabilities and J-statistic ==========
    beta_opt = opt1.x.reshape(k, 3)
    probs_opt = _compute_softmax_probs_4treat(X @ beta_opt, PROBS_MIN)
    J_opt = j_statistic(beta_opt.ravel())

    # ========== MLE fallback check ==========
    beta_opt, probs_opt, J_opt, _used_mle = _check_and_fallback_to_mle(
        J_opt, beta_opt, probs_opt, mcoef, probs_mnl, gmm_loss, bal_loss
    )
    # Derive the linear predictor AFTER the fallback check so it always agrees
    # with the returned coefficients (the fallback may presumably substitute
    # the MNLogit estimates).
    beta_opt = np.asarray(beta_opt).reshape(k, 3)
    theta_opt = X @ beta_opt

    # ========== Weights and vcov ==========
    w_opt = _standardize_weights_4treat(T1, T2, T3, T4, probs_opt, sample_weights, standardize)
    vcov = _compute_vcov_4treat(beta_opt, probs_opt, T1, T2, T3, T4, X, sample_weights,
                                gmm_func, n, k)

    # Enhanced non-convergence warning
    if not opt1.success:
        warnings.warn(
            f"Multi-valued CBPS (4-treat) optimization did not converge (converged=False). "
            f"Results may be unreliable. Consider:\n"
            f"  1. Increasing iterations (current: {iterations})\n"
            f"  2. Checking for perfect separation or collinearity\n"
            f"  3. Examining the balance diagnostics\n"
            f"  4. J-statistic: {J_opt:.6f}\n"
            f"  5. Trying different starting values or method='exact'",
            UserWarning,
            stacklevel=2
        )

    return build_result(beta_opt, theta_opt, probs_opt, w_opt, opt1.success, J_opt, vcov)
|