cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1123 @@
1
+ """
2
+ Covariate Balancing Propensity Score for Multi-valued Treatments.
3
+
4
+ This module implements CBPS for categorical treatments with 3 or 4 levels,
5
+ using multinomial logistic regression and contrast weights within the GMM
6
+ framework.
7
+
8
+ Algorithm Overview
9
+ ------------------
10
+ 1. Multinomial logistic regression for MLE initialization
11
+ 2. GMM optimization with covariate balance constraints
12
+ 3. Contrast weight computation for treatment effects
13
+
14
+ Notes on Implementation
15
+ -----------------------
16
+ This implementation uses statsmodels.MNLogit for multinomial logistic
17
+ initialization. Baseline-category logit models may have minor numerical
18
+ variations across different statistical libraries due to optimization
19
+ algorithms (±1e-2 to ±1e-1 in MLE estimates).
20
+
21
+ The CBPS optimization process typically reduces these differences,
22
+ with final results usually achieving ±1e-3 accuracy depending on the data.
23
+
24
+ References
25
+ ----------
26
+ Imai, Kosuke and Marc Ratkovic. 2014. "Covariate Balancing Propensity Score."
27
+ Journal of the Royal Statistical Society, Series B (Statistical Methodology).
28
+ §4.1 Multi-valued Treatments, Eq.22-24 (p.260)
29
+ DOI:10.1111/rssb.12027
30
+ http://imai.princeton.edu/research/CBPS.html
31
+ """
32
+
33
+ import warnings
34
+ from typing import Dict, Optional, Tuple, List, Any
35
+ import numpy as np
36
+ import scipy.linalg
37
+ import scipy.special
38
+ import scipy.optimize
39
+ import statsmodels.api as sm
40
+
41
+ from .results import CBPSResults
42
+ from ..utils.helpers import normalize_sample_weights
43
+ from ..utils.numerics import r_ginv_like, pinv_match_r
44
+ from ..utils.validation import ensure_dense
45
+ from ..logging_config import logger, set_verbosity
46
+
47
+ # Constants
48
+ PROBS_MIN = 1e-6 # Minimum probability clipping threshold
49
+
50
+
51
+ from typing import Optional
52
+
53
+
54
def _r_ginv(X: np.ndarray, tol: Optional[float] = None) -> np.ndarray:
    """
    Pseudoinverse with R-style cutoff handling.

    Parameters
    ----------
    X : np.ndarray
        Matrix to pseudo-invert.
    tol : float, optional
        Absolute cutoff. When given, it is forwarded to ``r_ginv_like``
        (explicit-SVD route, independent of SciPy keyword differences).
        When omitted, ``pinv_match_r`` is used, which is intended to
        reproduce the MASS::ginv default
        (tol = max(dim) * smax * eps).

    Returns
    -------
    np.ndarray
        The pseudoinverse produced by the selected helper.
    """
    # An explicit absolute tolerance takes the SVD-based helper.
    if tol is not None:
        return r_ginv_like(X, tol=tol)
    # Default path: preferred for parity with R.
    return pinv_match_r(X)
68
+
69
+
70
def _compute_softmax_probs_3treat(
    theta: np.ndarray,
    probs_min: float = PROBS_MIN
) -> np.ndarray:
    """
    Softmax probabilities for a 3-level treatment with a probability floor.

    The first category is the reference (its logit is fixed at 0). The row
    maximum is subtracted before exponentiation so ``exp`` cannot overflow,
    then probabilities are repeatedly floored at ``probs_min`` and
    renormalized (at most 10 passes).

    Parameters
    ----------
    theta : np.ndarray
        Logits for categories 2 and 3, shape (n, 2).
    probs_min : float, default PROBS_MIN
        Lower bound applied to every probability before renormalizing.

    Returns
    -------
    np.ndarray
        Shape (n, 3); each row sums to 1 and every entry is at least
        ``probs_min * 0.999``.

    Raises
    ------
    AssertionError
        If the final matrix has the wrong shape, rows do not sum to 1, or
        the floor is violated.
    """
    n_obs = theta.shape[0]

    # Prepend the reference logit (0) and shift each row by its maximum.
    logits = np.column_stack([np.zeros(n_obs), theta])  # (n, 3)
    shifted = logits - logits.max(axis=1, keepdims=True)
    unnormalized = np.exp(shifted)
    probs = unnormalized / unnormalized.sum(axis=1, keepdims=True)

    # Renormalizing after flooring can push entries slightly below the floor
    # again, hence the loop; 0.1% slack is allowed on the comparison.
    floor_with_slack = probs_min * 0.999
    max_iter = 10
    for _ in range(max_iter):
        floored = np.maximum(probs_min, probs)
        probs = floored / floored.sum(axis=1, keepdims=True)
        if np.all(probs >= floor_with_slack):
            break
    else:
        # All passes used without meeting the floor everywhere.
        smallest = probs.min()
        if smallest < floor_with_slack:
            warnings.warn(
                f"Iterative clipping did not fully converge: min_prob={smallest:.2e} < {probs_min:.2e}. "
                f"This may occur in extremely imbalanced data (probabilities > 99.9999%).",
                UserWarning
            )

    assert probs.shape == (n_obs, 3) and np.allclose(probs.sum(axis=1), 1.0, atol=1e-10), \
        f"Softmax probability anomaly: shape={probs.shape}, sum range=[{probs.sum(axis=1).min()}, {probs.sum(axis=1).max()}]"

    # Final guard on the floor (same 0.1% tolerance as above).
    min_prob_actual = probs.min()
    assert min_prob_actual >= floor_with_slack, \
        f"Probability threshold violation: min={min_prob_actual:.2e} < {probs_min:.2e}"

    return probs
141
+
142
+
143
def _compute_softmax_probs_4treat(theta: np.ndarray, probs_min: float = PROBS_MIN) -> np.ndarray:
    """
    Softmax probabilities for a 4-level treatment with a probability floor.

    Same procedure as the 3-level version: the first category is the
    reference (logit 0), rows are shifted by their maximum before ``exp``,
    and the result is floored at ``probs_min`` and renormalized for at most
    10 passes.

    Parameters
    ----------
    theta : np.ndarray
        Logits for categories 2, 3 and 4, shape (n, 3).
    probs_min : float, default PROBS_MIN
        Lower bound applied to every probability before renormalizing.

    Returns
    -------
    np.ndarray
        Shape (n, 4); each row sums to 1 and every entry is at least
        ``probs_min * 0.999``.

    Raises
    ------
    AssertionError
        If the final matrix has the wrong shape, rows do not sum to 1, or
        the floor is violated.
    """
    n_obs = theta.shape[0]

    # Reference logit (0) first, then a per-row shift for overflow safety.
    logits = np.column_stack([np.zeros(n_obs), theta])  # (n, 4)
    shifted = logits - logits.max(axis=1, keepdims=True)
    unnormalized = np.exp(shifted)
    probs = unnormalized / unnormalized.sum(axis=1, keepdims=True)

    floor_with_slack = probs_min * 0.999
    max_iter = 10
    for _ in range(max_iter):
        floored = np.maximum(probs_min, probs)
        probs = floored / floored.sum(axis=1, keepdims=True)
        if np.all(probs >= floor_with_slack):
            break
    else:
        # All passes used without meeting the floor everywhere.
        smallest = probs.min()
        if smallest < floor_with_slack:
            warnings.warn(
                f"Iterative clipping did not fully converge: min_prob={smallest:.2e} < {probs_min:.2e}. "
                f"This may occur with extremely imbalanced data (probability >99.9999%).",
                UserWarning
            )

    assert probs.shape == (n_obs, 4) and np.allclose(probs.sum(axis=1), 1.0, atol=1e-10), \
        f"Softmax probability error: shape={probs.shape}, sum range=[{probs.sum(axis=1).min()}, {probs.sum(axis=1).max()}]"

    # Final guard on the floor.
    min_prob_actual = probs.min()
    assert min_prob_actual >= floor_with_slack, \
        f"Probability threshold violated: min={min_prob_actual:.2e} < {probs_min:.2e}"

    return probs
194
+
195
+
196
def _compute_contrast_weights_3treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray, probs: np.ndarray) -> np.ndarray:
    """
    Contrast weights for a 3-level treatment.

    Parameters
    ----------
    T1, T2, T3 : np.ndarray
        Per-unit indicators for the three treatment levels, shape (n,).
    probs : np.ndarray
        Treatment probabilities; columns 0..2 pair with T1..T3.

    Returns
    -------
    np.ndarray
        Shape (n, 2). Column 0 is ``2*T1/p1 - T2/p2 - T3/p3`` and
        column 1 is ``T2/p2 - T3/p3``.
    """
    # Inverse-probability terms shared by both contrasts.
    ipw_2 = T2 / probs[:, 1]
    ipw_3 = T3 / probs[:, 2]

    level1_vs_rest = 2 * T1 / probs[:, 0] - ipw_2 - ipw_3
    level2_vs_level3 = ipw_2 - ipw_3

    w_contrast = np.column_stack([level1_vs_rest, level2_vs_level3])
    assert w_contrast.shape == (len(T1), 2)
    return w_contrast
204
+
205
+
206
def _compute_contrast_weights_4treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray, T4: np.ndarray, probs: np.ndarray) -> np.ndarray:
    """
    Contrast weights for a 4-level treatment.

    Parameters
    ----------
    T1, T2, T3, T4 : np.ndarray
        Per-unit indicators for the four treatment levels, shape (n,).
    probs : np.ndarray
        Treatment probabilities; columns 0..3 pair with T1..T4.

    Returns
    -------
    np.ndarray
        Shape (n, 3). With ``r_j = T_j / p_j`` the columns are
        ``r1 + r2 - r3 - r4``, ``r1 - r2 - r3 + r4`` and
        ``-r1 + r2 - r3 + r4``.
    """
    # Inverse-probability term for each level.
    r1 = T1 / probs[:, 0]
    r2 = T2 / probs[:, 1]
    r3 = T3 / probs[:, 2]
    r4 = T4 / probs[:, 3]

    columns = [
        r1 + r2 - r3 - r4,
        r1 - r2 - r3 + r4,
        -r1 + r2 - r3 + r4,
    ]
    w_contrast = np.column_stack(columns)
    assert w_contrast.shape == (len(T1), 3)
    return w_contrast
215
+
216
+
217
def _compute_V_matrix_3treat(X: np.ndarray, probs: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                             T3: np.ndarray, wtX: np.ndarray, n: int) -> np.ndarray:
    """
    Assemble the (4k x 4k) V matrix for a 3-level treatment.

    V is a symmetric 4x4 grid of (k x k) blocks, each of the form
    ``(wtX * c[:, None]).T @ X`` for a per-unit coefficient ``c`` built from
    the probabilities, scaled by ``1/n``.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    probs : np.ndarray
        Treatment probabilities, columns 0..2.
    T1, T2, T3 : np.ndarray
        Treatment indicators (not used in the computation; kept for the
        call signature).
    wtX : np.ndarray
        Sample-weighted covariates, same shape as X.
    n : int
        Divisor applied to the assembled matrix.

    Returns
    -------
    np.ndarray
        Symmetric matrix of shape (4k, 4k).
    """
    k = X.shape[1]
    p1, p2, p3 = probs[:, 0], probs[:, 1], probs[:, 2]

    def gram(row_scale: np.ndarray) -> np.ndarray:
        # One (k x k) block: X' diag(weights * row_scale) X.
        return (wtX * row_scale[:, None]).T @ X

    # Diagonal blocks.
    v11 = gram(p2 * (1 - p2))
    v22 = gram(p3 * (1 - p3))
    v33 = gram(4 * p1 ** (-1) + p2 ** (-1) + p3 ** (-1))
    v44 = gram(p2 ** (-1) + p3 ** (-1))
    # Off-diagonal blocks that depend on the probabilities.
    v12 = gram(-p2 * p3)
    v34 = gram(-p2 ** (-1) + p3 ** (-1))
    # Off-diagonal blocks with constant coefficient +1 / -1.
    plus = (wtX * 1).T @ X
    minus = (wtX * (-1)).T @ X

    grid = [
        [v11, v12, minus, plus],
        [v12, v22, minus, minus],
        [minus, minus, v33, v34],
        [plus, minus, v34, v44],
    ]
    V = (1.0 / n) * np.block(grid)
    assert V.shape == (4*k, 4*k) and np.allclose(V, V.T, atol=1e-12)
    return V
239
+
240
+
241
def _compute_V_matrix_4treat(X: np.ndarray, probs: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                             T3: np.ndarray, T4: np.ndarray, wtX: np.ndarray, n: int) -> np.ndarray:
    """
    Assemble the (6k x 6k) V matrix for a 4-level treatment.

    V is a symmetric 6x6 grid of (k x k) blocks, each of the form
    ``(wtX * c[:, None]).T @ X`` for a per-unit coefficient ``c`` built from
    the probabilities, scaled by ``1/n``.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    probs : np.ndarray
        Treatment probabilities, columns 0..3.
    T1, T2, T3, T4 : np.ndarray
        Treatment indicators (not used in the computation; kept for the
        call signature).
    wtX : np.ndarray
        Sample-weighted covariates, same shape as X.
    n : int
        Divisor applied to the assembled matrix.

    Returns
    -------
    np.ndarray
        Symmetric matrix of shape (6k, 6k).
    """
    k = X.shape[1]
    p2, p3, p4 = probs[:, 1], probs[:, 2], probs[:, 3]
    # Reciprocal probabilities used by the contrast blocks.
    inv1 = probs[:, 0] ** (-1)
    inv2 = p2 ** (-1)
    inv3 = p3 ** (-1)
    inv4 = p4 ** (-1)

    def gram(row_scale: np.ndarray) -> np.ndarray:
        # One (k x k) block: X' diag(weights * row_scale) X.
        return (wtX * row_scale[:, None]).T @ X

    # Score-by-score blocks (rows/cols 1-3).
    s11 = gram(p2 * (1 - p2))
    s22 = gram(p3 * (1 - p3))
    s33 = gram(p4 * (1 - p4))
    s12 = gram(-p2 * p3)
    s13 = gram(-p2 * p4)
    s23 = gram(-p3 * p4)
    # Contrast-by-contrast blocks (rows/cols 4-6); the three diagonal
    # blocks are identical.
    c_diag = gram(inv1 + inv2 + inv3 + inv4)
    c45 = gram(inv1 - inv2 + inv3 - inv4)
    c46 = gram(-inv1 + inv2 + inv3 - inv4)
    c56 = gram(-inv1 - inv2 + inv3 + inv4)
    # Score-by-contrast blocks have constant coefficient +1 or -1.
    plus = wtX.T @ X
    minus = (wtX * (-1)).T @ X

    grid = [
        [s11, s12, s13, plus, minus, plus],
        [s12, s22, s23, minus, minus, minus],
        [s13, s23, s33, minus, plus, plus],
        [plus, minus, minus, c_diag, c45, c46],
        [minus, minus, plus, c45, c_diag, c56],
        [plus, minus, plus, c46, c56, c_diag],
    ]
    V = (1.0 / n) * np.block(grid)
    assert V.shape == (6*k, 6*k) and np.allclose(V, V.T, atol=1e-12)
    return V
269
+
270
+
271
def _gmm_func_3treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, sample_weights: np.ndarray, n: int,
                     inv_V: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    GMM objective for a 3-level treatment.

    Stacks the two score moments and the two covariate-balance moments into
    ``gbar`` and evaluates the quadratic form ``gbar' inv_V gbar``.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients, either flat (length 2k) or shaped (k, 2).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    T1, T2, T3 : np.ndarray
        Treatment-level indicators, shape (n,).
    sample_weights : np.ndarray
        Per-unit weights, shape (n,).
    n : int
        Divisor used for the moment averages.
    inv_V : np.ndarray, optional
        Weighting matrix. When None it is computed from the current
        probabilities via the V matrix and its pseudoinverse.

    Returns
    -------
    dict
        ``{'loss': float, 'inv_V': np.ndarray}``.
    """
    k = X.shape[1]
    coef = beta_curr if beta_curr.ndim != 1 else beta_curr.reshape(k, 2)

    probs = _compute_softmax_probs_3treat(X @ coef, PROBS_MIN)
    contrasts = _compute_contrast_weights_3treat(T1, T2, T3, probs)

    wtX = sample_weights[:, None] * X
    # (1/n) * wtX' is the common left factor of every moment condition.
    averaged_wtX_t = (1.0 / n) * wtX.T

    score_moments = [
        averaged_wtX_t @ (T2 - probs[:, 1]),
        averaged_wtX_t @ (T3 - probs[:, 2]),
    ]
    # Balance moments are stacked column by column (Fortran order).
    balance_moments = (averaged_wtX_t @ contrasts).ravel(order='F')
    gbar = np.concatenate(score_moments + [balance_moments])

    if inv_V is None:
        V = _compute_V_matrix_3treat(X, probs, T1, T2, T3, wtX, n)
        inv_V = _r_ginv(V)

    loss = float(gbar.T @ inv_V @ gbar)
    return {'loss': loss, 'inv_V': inv_V}
290
+
291
+
292
def _gmm_func_4treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, T4: np.ndarray, sample_weights: np.ndarray, n: int,
                     inv_V: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    GMM objective for a 4-level treatment.

    Stacks the three score moments and the three covariate-balance moments
    into ``gbar`` and evaluates the quadratic form ``gbar' inv_V gbar``.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients, either flat (length 3k) or shaped (k, 3).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    T1, T2, T3, T4 : np.ndarray
        Treatment-level indicators, shape (n,).
    sample_weights : np.ndarray
        Per-unit weights, shape (n,).
    n : int
        Divisor used for the moment averages.
    inv_V : np.ndarray, optional
        Weighting matrix. When None it is computed from the current
        probabilities via the V matrix and its pseudoinverse.

    Returns
    -------
    dict
        ``{'loss': float, 'inv_V': np.ndarray}``.
    """
    k = X.shape[1]
    coef = beta_curr if beta_curr.ndim != 1 else beta_curr.reshape(k, 3)

    probs = _compute_softmax_probs_4treat(X @ coef, PROBS_MIN)
    contrasts = _compute_contrast_weights_4treat(T1, T2, T3, T4, probs)

    wtX = sample_weights[:, None] * X
    # (1/n) * wtX' is the common left factor of every moment condition.
    averaged_wtX_t = (1.0 / n) * wtX.T

    score_moments = [
        averaged_wtX_t @ (T2 - probs[:, 1]),
        averaged_wtX_t @ (T3 - probs[:, 2]),
        averaged_wtX_t @ (T4 - probs[:, 3]),
    ]
    # Balance moments are stacked column by column (Fortran order).
    balance_moments = (averaged_wtX_t @ contrasts).ravel(order='F')
    gbar = np.concatenate(score_moments + [balance_moments])

    if inv_V is None:
        V = _compute_V_matrix_4treat(X, probs, T1, T2, T3, T4, wtX, n)
        inv_V = _r_ginv(V)

    loss = float(gbar.T @ inv_V @ gbar)
    return {'loss': loss, 'inv_V': inv_V}
312
+
313
+
314
def _bal_loss_3treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, sample_weights: np.ndarray, XprimeX_inv: np.ndarray,
                     k: int, n: int) -> float:
    """
    Covariate-balance loss for a 3-level treatment.

    Computes ``trace(M' XprimeX_inv M)`` where ``M = wtX' (w_contrast / n)``
    and ``w_contrast`` are the contrast weights at the current coefficients.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients, either flat (length 2k) or shaped (k, 2).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    T1, T2, T3 : np.ndarray
        Treatment-level indicators, shape (n,).
    sample_weights : np.ndarray
        Per-unit weights, shape (n,).
    XprimeX_inv : np.ndarray
        (k, k) matrix used as the metric of the quadratic form.
    k : int
        Number of covariate columns (used to reshape a flat ``beta_curr``).
    n : int
        Divisor applied to the contrast weights.

    Returns
    -------
    float
        The balance loss.
    """
    coef = beta_curr if beta_curr.ndim != 1 else beta_curr.reshape(k, 2)
    probs = _compute_softmax_probs_3treat(X @ coef, PROBS_MIN)

    # Contrast weights averaged over n.
    scaled_contrasts = _compute_contrast_weights_3treat(T1, T2, T3, probs) / n
    weighted_X = sample_weights[:, None] * X
    imbalance = weighted_X.T @ scaled_contrasts

    quad_form = imbalance.T @ XprimeX_inv @ imbalance
    return float(np.sum(np.diag(quad_form)))
326
+
327
+
328
def _bal_loss_4treat(beta_curr: np.ndarray, X: np.ndarray, T1: np.ndarray, T2: np.ndarray,
                     T3: np.ndarray, T4: np.ndarray, sample_weights: np.ndarray,
                     XprimeX_inv: np.ndarray, k: int, n: int) -> float:
    """
    Covariate-balance loss for a 4-level treatment.

    Computes ``trace(M' XprimeX_inv M)`` where ``M = wtX' (w_contrast / n)``
    and ``w_contrast`` are the contrast weights at the current coefficients.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients, either flat (length 3k) or shaped (k, 3).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    T1, T2, T3, T4 : np.ndarray
        Treatment-level indicators, shape (n,).
    sample_weights : np.ndarray
        Per-unit weights, shape (n,).
    XprimeX_inv : np.ndarray
        (k, k) matrix used as the metric of the quadratic form.
    k : int
        Number of covariate columns (used to reshape a flat ``beta_curr``).
    n : int
        Divisor applied to the contrast weights.

    Returns
    -------
    float
        The balance loss.
    """
    coef = beta_curr if beta_curr.ndim != 1 else beta_curr.reshape(k, 3)
    probs = _compute_softmax_probs_4treat(X @ coef, PROBS_MIN)

    # Contrast weights averaged over n.
    scaled_contrasts = _compute_contrast_weights_4treat(T1, T2, T3, T4, probs) / n
    weighted_X = sample_weights[:, None] * X
    imbalance = weighted_X.T @ scaled_contrasts

    quad_form = imbalance.T @ XprimeX_inv @ imbalance
    return float(np.sum(np.diag(quad_form)))
340
+
341
+
342
def _encode_treatment_levels(treat: np.ndarray, treat_levels: np.ndarray) -> np.ndarray:
    """
    Map treatment labels to integer codes 0..len(treat_levels)-1.

    Values that already are whole numbers inside that range are used as-is
    (checked on values, not dtype). Anything else (strings, categoricals,
    out-of-range integers) is looked up by its position in ``treat_levels``.
    """
    treat_array = np.asarray(treat)
    try:
        as_int = treat_array.astype(int)
        already_encoded = bool(
            np.array_equal(as_int, treat_array)
            and np.all((as_int >= 0) & (as_int < len(treat_levels)))
        )
    except (ValueError, TypeError):
        # Not convertible to int at all -> must be re-encoded.
        already_encoded = False
    if already_encoded:
        return as_int
    level_index = {level: i for i, level in enumerate(treat_levels)}
    return np.array([level_index[t] for t in treat_array])


def _replication_counts(sample_weights: np.ndarray) -> Optional[np.ndarray]:
    """
    Turn sample weights into integer row-replication counts.

    Returns None when all weights are equal (no replication needed).
    Otherwise weights are divided by their minimum; if the result is integral
    (tolerance 1e-6) it is used directly, else it is scaled by 100, rounded,
    and floored at 1 as an approximation of fractional weights.
    """
    if len(np.unique(sample_weights)) == 1:
        return None
    normalized = sample_weights / sample_weights.min()
    rounded = np.round(normalized)
    if np.allclose(normalized, rounded, atol=1e-6):
        return rounded.astype(int)
    scale_factor = 100  # resolution of the fractional-weight approximation
    approx = np.round(normalized * scale_factor).astype(int)
    return np.maximum(approx, 1)


def _fit_mnlogit_coefs(treat_encoded: np.ndarray, X: np.ndarray,
                       sample_weights: np.ndarray, n_levels: int) -> np.ndarray:
    """
    Fit ``sm.MNLogit`` (BFGS, 100 iterations max) and return its coefficients.

    Sample weights are honoured by replicating rows, because the fit is done
    on unweighted data. NaN coefficients in the first ``n_levels - 1``
    columns are replaced by 0 on a copy, leaving the fitted result untouched.
    """
    counts = _replication_counts(sample_weights)
    if counts is None:
        endog, exog = treat_encoded, X
    else:
        exog = np.repeat(X, counts, axis=0)
        endog = np.repeat(treat_encoded, counts)

    mnl_result = sm.MNLogit(endog, exog).fit(maxiter=100, disp=False, method='bfgs')

    # statsmodels returns params as (k, n_levels - 1); no transpose needed.
    mcoef = np.array(mnl_result.params)
    for col in range(n_levels - 1):
        mcoef[np.isnan(mcoef[:, col]), col] = 0
    return mcoef


def _mnlogit_init_3treat(treat: np.ndarray, X: np.ndarray, sample_weights: np.ndarray,
                         treat_levels: np.ndarray, k: int, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Multinomial-logit starting values for a 3-level treatment.

    Parameters
    ----------
    treat : np.ndarray
        Treatment labels (integers, strings, categoricals, ...).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Per-unit weights; non-uniform weights are applied by row replication.
    treat_levels : np.ndarray
        Ordered treatment levels; position defines the integer code.
    k, n : int
        Not used in the computation; kept for the call signature.

    Returns
    -------
    mcoef : np.ndarray
        MLE coefficients, shape (k, 2), NaNs replaced by 0.
    probs_mnl : np.ndarray
        Fitted probabilities on the original rows, shape (n, 3).
    """
    treat_encoded = _encode_treatment_levels(treat, treat_levels)
    mcoef = _fit_mnlogit_coefs(treat_encoded, X, sample_weights, n_levels=3)
    theta_mnl = X @ mcoef  # (n, 2)
    probs_mnl = _compute_softmax_probs_3treat(theta_mnl, PROBS_MIN)
    return mcoef, probs_mnl


def _mnlogit_init_4treat(treat: np.ndarray, X: np.ndarray, sample_weights: np.ndarray,
                         treat_levels: np.ndarray, k: int, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Multinomial-logit starting values for a 4-level treatment.

    Parameters
    ----------
    treat : np.ndarray
        Treatment labels (integers, strings, categoricals, ...).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Per-unit weights; non-uniform weights are applied by row replication.
    treat_levels : np.ndarray
        Ordered treatment levels; position defines the integer code.
    k, n : int
        Not used in the computation; kept for the call signature.

    Returns
    -------
    mcoef : np.ndarray
        MLE coefficients, shape (k, 3), NaNs replaced by 0.
    probs_mnl : np.ndarray
        Fitted probabilities on the original rows, shape (n, 4).
    """
    treat_encoded = _encode_treatment_levels(treat, treat_levels)
    mcoef = _fit_mnlogit_coefs(treat_encoded, X, sample_weights, n_levels=4)
    theta_mnl = X @ mcoef  # (n, 3)
    probs_mnl = _compute_softmax_probs_4treat(theta_mnl, PROBS_MIN)
    return mcoef, probs_mnl
471
+
472
+
473
def _standardize_weights_3treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray,
                                probs_opt: np.ndarray, sample_weights: np.ndarray,
                                standardize: bool) -> np.ndarray:
    """
    Inverse-probability weights for a 3-level treatment.

    Parameters
    ----------
    T1, T2, T3 : np.ndarray
        Treatment-level indicators, shape (n,).
    probs_opt : np.ndarray
        Treatment probabilities; columns 0..2 pair with T1..T3.
    sample_weights : np.ndarray
        Per-unit weights, used only inside the normalizing constants.
    standardize : bool
        If True, each level's weights are divided by
        ``sum(T_j * sample_weights / p_j)``; if False the divisor is 1.

    Returns
    -------
    np.ndarray
        Weight per unit, shape (n,).
    """
    indicators = (T1, T2, T3)
    if not standardize:
        norms = [1.0, 1.0, 1.0]
    else:
        norms = [np.sum(T * sample_weights / probs_opt[:, j])
                 for j, T in enumerate(indicators)]

    per_level = [T / probs_opt[:, j] / norms[j] for j, T in enumerate(indicators)]
    w_opt = per_level[0] + per_level[1] + per_level[2]
    return w_opt
487
+
488
+
489
def _standardize_weights_4treat(T1: np.ndarray, T2: np.ndarray, T3: np.ndarray, T4: np.ndarray,
                                probs_opt: np.ndarray, sample_weights: np.ndarray,
                                standardize: bool) -> np.ndarray:
    """
    Inverse-probability weights for a 4-level treatment.

    Parameters
    ----------
    T1, T2, T3, T4 : np.ndarray
        Treatment-level indicators, shape (n,).
    probs_opt : np.ndarray
        Treatment probabilities; columns 0..3 pair with T1..T4.
    sample_weights : np.ndarray
        Per-unit weights, used only inside the normalizing constants.
    standardize : bool
        If True, each level's weights are divided by
        ``sum(T_j * sample_weights / p_j)``; if False the divisor is 1.

    Returns
    -------
    np.ndarray
        Weight per unit, shape (n,).
    """
    indicators = (T1, T2, T3, T4)
    if not standardize:
        norms = [1.0, 1.0, 1.0, 1.0]
    else:
        norms = [np.sum(T * sample_weights / probs_opt[:, j])
                 for j, T in enumerate(indicators)]

    per_level = [T / probs_opt[:, j] / norms[j] for j, T in enumerate(indicators)]
    w_opt = per_level[0] + per_level[1] + per_level[2] + per_level[3]
    return w_opt
503
+
504
+
505
def _check_and_fallback_to_mle(J_opt: float, beta_opt: np.ndarray, probs_opt: np.ndarray,
                               mcoef: np.ndarray, probs_mnl: np.ndarray,
                               gmm_loss_func: Any, bal_loss_func: Any) -> Tuple[np.ndarray, np.ndarray, float, bool]:
    """
    Decide whether to discard the optimized solution in favour of the MLE.

    The MLE is returned only when the optimized solution is worse on BOTH
    criteria: its GMM loss exceeds the MLE's GMM loss and its balance loss
    exceeds the MLE's balance loss.

    Parameters
    ----------
    J_opt : float
        GMM loss at the optimized coefficients.
    beta_opt, probs_opt : np.ndarray
        Optimized coefficients and their probabilities.
    mcoef, probs_mnl : np.ndarray
        MLE coefficients and their probabilities.
    gmm_loss_func, bal_loss_func : callable
        Each takes a flattened coefficient vector and returns a loss.

    Returns
    -------
    tuple
        ``(coefficients, probabilities, J, fell_back)`` where ``fell_back``
        is True when the MLE values are returned.
    """
    mle_flat = mcoef.ravel()
    mle_J = gmm_loss_func(mle_flat)
    mle_bal = bal_loss_func(mle_flat)
    opt_bal = bal_loss_func(beta_opt.ravel())

    optimizer_is_worse = (J_opt > mle_J) and (opt_bal > mle_bal)
    if not optimizer_is_worse:
        return beta_opt, probs_opt, J_opt, False

    warnings.warn("Optimization failed. Results returned are for MLE.")
    return mcoef, probs_mnl, mle_J, True
516
+
517
+
518
def _compute_vcov_3treat(beta_opt: np.ndarray, probs_opt: np.ndarray, T1: np.ndarray,
                         T2: np.ndarray, T3: np.ndarray, X: np.ndarray,
                         sample_weights: np.ndarray, gmm_func: Any, n: int, k: int) -> np.ndarray:
    """
    Sandwich variance-covariance matrix for a 3-level treatment.

    Builds the (2k x 4k) moment-derivative matrix G, the (4k x 4k) moment
    covariance Omega, takes the weighting matrix W from ``gmm_func`` and
    returns ``(G W G')^- G W Omega W' G' (G W G')^-'``.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficients (flattened before being passed on).
    probs_opt : np.ndarray
        Treatment probabilities at ``beta_opt``, columns 0..2.
    T1, T2, T3 : np.ndarray
        Treatment-level indicators, shape (n,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Per-unit weights, shape (n,).
    gmm_func : callable
        Called as ``gmm_func(beta, inv_V=None)``; must return a dict with
        key ``'inv_V'``.
    n : int
        Divisor for G and Omega.
    k : int
        Number of covariate columns (used for the shape check).

    Returns
    -------
    np.ndarray
        Matrix of shape (2k, 2k).
    """
    wtX = sample_weights[:, None] * X
    p1, p2, p3 = probs_opt[:, 0], probs_opt[:, 1], probs_opt[:, 2]

    # Weighting matrix recomputed at the optimum.
    W = gmm_func(beta_opt.ravel(), inv_V=None)['inv_V']

    def gram(left: np.ndarray, row_scale: np.ndarray) -> np.ndarray:
        # One (k x k) block: (left * row_scale)' X.
        return (left * row_scale[:, None]).T @ X

    # Derivative blocks; the cross block is shared by both rows.
    cross = p2 * p3
    g_row_level2 = [
        gram(-wtX, p2 * (1 - p2)),
        gram(wtX, cross),
        gram(wtX, 2*T1*p2/p1 + T2*(1-p2)/p2 - T3*p2/p3),
        gram(wtX, -T2*(1-p2)/p2 - T3*p2/p3),
    ]
    g_row_level3 = [
        gram(wtX, cross),
        gram(-wtX, p3 * (1 - p3)),
        gram(wtX, 2*T1*p3/p1 - T2*p3/p2 + T3*(1-p3)/p3),
        gram(wtX, T2*p3/p2 + T3*(1-p3)/p3),
    ]
    G = (1.0 / n) * np.block([g_row_level2, g_row_level3])  # (2k, 4k)

    # Per-unit moment contributions, scaled by sqrt(weight).
    root_w = (sample_weights**0.5)[:, None]
    moment_scales = [
        T2 - p2,
        T3 - p3,
        2*T1/p1 - T2/p2 - T3/p3,
        T2/p2 - T3/p3,
    ]
    W1 = np.vstack([(X * scale[:, None] * root_w).T for scale in moment_scales])  # (4k, n)
    Omega = (1.0 / n) * (W1 @ W1.T)

    # Sandwich formula.
    bread = _r_ginv(G @ W @ G.T) @ G @ W
    vcov = bread @ Omega @ bread.T
    assert vcov.shape == (2*k, 2*k)
    return vcov
557
+
558
+
559
def _compute_vcov_4treat(beta_opt: np.ndarray, probs_opt: np.ndarray, T1: np.ndarray,
                         T2: np.ndarray, T3: np.ndarray, T4: np.ndarray, X: np.ndarray,
                         sample_weights: np.ndarray, gmm_func: Any, n: int, k: int) -> np.ndarray:
    """
    Sandwich variance-covariance matrix for a 4-level treatment.

    Builds the (3k x 6k) moment-derivative matrix G, the (6k x 6k) moment
    covariance Omega, takes the weighting matrix W from ``gmm_func`` and
    returns ``(G W G')^- G W Omega W' G' (G W G')^-'``.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficients (flattened before being passed on).
    probs_opt : np.ndarray
        Treatment probabilities at ``beta_opt``, columns 0..3.
    T1, T2, T3, T4 : np.ndarray
        Treatment-level indicators, shape (n,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Per-unit weights, shape (n,).
    gmm_func : callable
        Called as ``gmm_func(beta, inv_V=None)``; must return a dict with
        key ``'inv_V'``.
    n : int
        Divisor for G and Omega.
    k : int
        Number of covariate columns (used for the shape check).

    Returns
    -------
    np.ndarray
        Matrix of shape (3k, 3k).
    """
    wtX = sample_weights[:, None] * X
    p1, p2, p3, p4 = probs_opt[:, 0], probs_opt[:, 1], probs_opt[:, 2], probs_opt[:, 3]
    # Inverse-probability terms reused by the derivative and moment blocks.
    r1, r2, r3, r4 = T1 / p1, T2 / p2, T3 / p3, T4 / p4

    # Weighting matrix recomputed at the optimum.
    W = gmm_func(beta_opt.ravel(), inv_V=None)['inv_V']

    def gram(left: np.ndarray, row_scale: np.ndarray) -> np.ndarray:
        # One (k x k) block: (left * row_scale)' X.
        return (left * row_scale[:, None]).T @ X

    # Own-level curvature terms T_j (1 - p_j) / p_j^2.
    own2 = T2*(1-p2)/p2**2
    own3 = T3*(1-p3)/p3**2
    own4 = T4*(1-p4)/p4**2

    g_row_level2 = [
        gram(-wtX, p2 * (1 - p2)),
        gram(wtX, p2 * p3),
        gram(wtX, p2 * p4),
        gram(wtX, p2 * (r1 - own2 - r3 - r4)),
        gram(wtX, p2 * (r1 + own2 - r3 + r4)),
        gram(wtX, p2 * (-r1 - own2 - r3 + r4)),
    ]
    g_row_level3 = [
        gram(wtX, p2 * p3),
        gram(-wtX, p3 * (1 - p3)),
        gram(wtX, p3 * p4),
        gram(wtX, p3 * (r1 + r2 + own3 - r4)),
        gram(wtX, p3 * (r1 - r2 + own3 + r4)),
        gram(wtX, p3 * (-r1 + r2 + own3 + r4)),
    ]
    g_row_level4 = [
        gram(wtX, p2 * p4),
        gram(wtX, p3 * p4),
        gram(-wtX, p4 * (1 - p4)),
        gram(wtX, p4 * (r1 + r2 - r3 + own4)),
        gram(wtX, p4 * (r1 - r2 - r3 - own4)),
        gram(wtX, p4 * (-r1 + r2 - r3 - own4)),
    ]
    G = (1.0 / n) * np.block([g_row_level2, g_row_level3, g_row_level4])  # (3k, 6k)

    # Per-unit moment contributions, scaled by sqrt(weight).
    root_w = (sample_weights**0.5)[:, None]
    moment_scales = [
        T2 - p2,
        T3 - p3,
        T4 - p4,
        r1 + r2 - r3 - r4,
        r1 - r2 - r3 + r4,
        -r1 + r2 - r3 + r4,
    ]
    W1 = np.vstack([(X * scale[:, None] * root_w).T for scale in moment_scales])  # (6k, n)
    Omega = (1.0 / n) * (W1 @ W1.T)

    # Sandwich formula.
    bread = _r_ginv(G @ W @ G.T) @ G @ W
    vcov = bread @ Omega @ bread.T
    assert vcov.shape == (3*k, 3*k)
    return vcov
617
+
618
+
619
def cbps_3treat_fit(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    k: int = None,
    XprimeX_inv: np.ndarray = None,
    bal_only: bool = False,
    iterations: int = 1000,
    standardize: bool = True,
    two_step: bool = True,
    sample_weights: np.ndarray = None,
    treat_levels: np.ndarray = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    Fit CBPS for 3-level categorical treatments.

    The fit starts from a multinomial-logit initialization, rescales it by a
    scalar ``alpha`` chosen to minimize the GMM loss, runs a balance-loss
    optimization, and (unless ``bal_only`` is set) runs the GMM optimization
    from both starting points and keeps the one with the lower loss.

    Parameters
    ----------
    treat : np.ndarray
        Treatment vector with 3 distinct levels, shape (n,).
    X : np.ndarray, shape (n, k)
        Covariate matrix (SVD-orthogonalized if applicable). Passed through
        ``ensure_dense`` before use.
    method : str, default 'over'
        Estimation method label ('over' / 'exact'). NOTE(review): this
        argument is not read anywhere in this function body; the
        balance-only path is selected through ``bal_only``.
    k : int, optional
        Number of coefficient rows per contrast. Defaults to ``X.shape[1]``.
    XprimeX_inv : np.ndarray, shape (k, k), optional
        Inverse of the weighted X'X matrix used by the balance loss. If None,
        it is computed as the generalized inverse of
        ``(sqrt(w) * X)' (sqrt(w) * X)``.
    bal_only : bool, default False
        If True, return right after the balance optimization.
    iterations : int, default 1000
        ``maxiter`` passed to every ``scipy.optimize.minimize`` call.
    standardize : bool, default True
        Forwarded to ``_standardize_weights_3treat``.
    two_step : bool, default True
        If True, the GMM loss is evaluated with the ``inv_V`` matrix
        pre-computed at the alpha-scaled initial value. If False, ``inv_V``
        is not fixed (``None`` is passed to the GMM function) and a
        Nelder-Mead retry is attempted when BFGS raises.
    sample_weights : np.ndarray, optional
        Sampling weights; normalized by ``normalize_sample_weights``.
    treat_levels : np.ndarray, optional
        The three treatment level values, in the order used to build the
        indicators. Defaults to ``np.unique(treat)``.
    verbose : int, default 0
        ``>= 2`` calls ``set_verbosity(2)``, ``>= 1`` calls
        ``set_verbosity(1)``; 0 leaves the logging configuration untouched.

    Returns
    -------
    Dict[str, Any]
        Dictionary with the keys:

        - coefficients: coefficient matrix (reshaped to (k, 2) on the
          GMM/balance paths)
        - fitted_values: probability matrix (columns 0..2 are used here)
        - linear_predictor: ``X @ coefficients``
        - deviance: -2 * log-likelihood at the fitted probabilities
        - nulldeviance: -2 * log-likelihood under the (weighted) sample
          proportions of each level
        - weights: standardized weights multiplied by ``sample_weights``
        - y: the ``treat`` vector as passed in
        - x: the (dense) covariate matrix
        - converged: ``success`` flag of the selected optimizer run
        - J: J-statistic (balance loss when ``bal_only`` is True)
        - var: covariance matrix from ``_compute_vcov_3treat``
        - mle_J: J-statistic evaluated at the multinomial-logit coefficients

    Algorithm Flow
    --------------
    1. Initialize constants and treatment indicators
    2. MNLogit initialization
    3. Alpha scaling
    4. Balance optimization
    5. Return if bal_only=True
    6. GMM dual initialization optimization
    7. Compute optimal probabilities
    8. Calculate J-statistic
    9. Check for MLE fallback
    10. Compute deviance and weight standardization
    11. Compute covariance matrix
    12. Construct return object

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    # ========== Initialization ==========
    # Ensure dense matrix (sparse input auto-converted)
    X = ensure_dense(X)

    # Configure logging from verbose parameter (backward compatibility)
    if verbose >= 2:
        set_verbosity(2)
    elif verbose >= 1:
        set_verbosity(1)

    n = len(treat)

    # Treatment levels and 0/1 indicators for each level
    if treat_levels is None:
        treat_levels = np.unique(treat)
    assert len(treat_levels) == 3, "Must be 3-valued treatment"

    T1 = (treat == treat_levels[0]).astype(float)
    T2 = (treat == treat_levels[1]).astype(float)
    T3 = (treat == treat_levels[2]).astype(float)

    sample_weights = normalize_sample_weights(sample_weights, n)

    if k is None:
        k = X.shape[1]
    if XprimeX_inv is None:
        wtX_sqrt = (sample_weights**0.5)[:, None] * X
        XprimeX_inv = _r_ginv(wtX_sqrt.T @ wtX_sqrt)

    # ========== Closures over the data ==========
    def gmm_eval(beta, inv_V=None):
        return _gmm_func_3treat(beta, X, T1, T2, T3, sample_weights, n, inv_V)

    def gmm_loss(beta):
        # Loss with inv_V left unspecified (None is passed through)
        return gmm_eval(beta, None)['loss']

    def bal_loss(beta):
        return _bal_loss_3treat(beta, X, T1, T2, T3, sample_weights, XprimeX_inv, k, n)

    # ========== MNLogit initialization ==========
    mcoef, probs_mnl = _mnlogit_init_3treat(treat, X, sample_weights, treat_levels, k, n)

    # ========== Alpha scaling ==========
    def alpha_func(alpha):
        return gmm_loss(mcoef.ravel() * alpha)
    alpha_result = scipy.optimize.minimize_scalar(alpha_func, bounds=(0.8, 1.1), method='bounded')
    gmm_init = mcoef.ravel() * alpha_result.x

    # ========== Pre-compute invV (two-step method) ==========
    this_invV = gmm_eval(gmm_init, None)['inv_V']

    def gmm_loss_with_invV(beta):
        return gmm_eval(beta, this_invV)['loss']

    def j_stat(beta):
        # Two-step: loss at the fixed invV; otherwise the loss with inv_V=None
        return gmm_loss_with_invV(beta) if two_step else gmm_loss(beta)

    # ========== Balance optimization ==========
    logger.info(f"Starting balance optimization (max_iter={iterations})...")

    if two_step:
        opt_bal = scipy.optimize.minimize(bal_loss, gmm_init, method='BFGS',
                                          options={'maxiter': iterations})
        logger.info(f"Balance optimization complete: loss={opt_bal.fun:.6f}, converged={opt_bal.success}")
    else:
        try:
            opt_bal = scipy.optimize.minimize(bal_loss, gmm_init, method='BFGS',
                                              options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            opt_bal = scipy.optimize.minimize(bal_loss, gmm_init, method='Nelder-Mead',
                                              options={'maxiter': iterations})

    beta_bal = opt_bal.x

    # ========== Compute nulldeviance (before all return paths) ==========
    # Null model: each category's probability = its (weighted) sample proportion,
    # clipped away from zero to prevent log(0)
    T1_mean = np.clip(np.average(T1, weights=sample_weights), 1e-10, 1.0)
    T2_mean = np.clip(np.average(T2, weights=sample_weights), 1e-10, 1.0)
    T3_mean = np.clip(np.average(T3, weights=sample_weights), 1e-10, 1.0)
    nulldeviance = -2 * np.sum(T1 * np.log(T1_mean) + T2 * np.log(T2_mean) + T3 * np.log(T3_mean))

    # ========== bal_only early return ==========
    if bal_only:
        beta_opt = beta_bal.reshape(k, 2)
        theta_opt = X @ beta_opt
        probs_opt = _compute_softmax_probs_3treat(theta_opt, PROBS_MIN)
        w_opt = _standardize_weights_3treat(T1, T2, T3, probs_opt, sample_weights, standardize)
        J_opt = bal_loss(beta_opt.ravel())
        deviance = -2 * np.sum(T1 * np.log(probs_opt[:, 0]) + T2 * np.log(probs_opt[:, 1])
                               + T3 * np.log(probs_opt[:, 2]))
        vcov = _compute_vcov_3treat(beta_opt, probs_opt, T1, T2, T3, X, sample_weights,
                                    lambda b, inv_V=None: _gmm_func_3treat(b, X, T1, T2, T3, sample_weights, n, inv_V),
                                    n, k)
        mle_J_val = j_stat(mcoef.ravel())
        return {'coefficients': beta_opt, 'fitted_values': probs_opt, 'linear_predictor': theta_opt,
                'deviance': deviance, 'nulldeviance': nulldeviance, 'weights': w_opt * sample_weights,
                'y': treat, 'x': X,
                'converged': opt_bal.success, 'J': J_opt, 'var': vcov, 'mle_J': mle_J_val}

    # ========== GMM dual initialization selection ==========
    if two_step:
        gmm_glm_init = scipy.optimize.minimize(gmm_loss_with_invV, gmm_init, method='BFGS',
                                               options={'maxiter': iterations})
        gmm_bal_init = scipy.optimize.minimize(gmm_loss_with_invV, beta_bal, method='BFGS',
                                               options={'maxiter': iterations})
    else:
        try:
            gmm_glm_init = scipy.optimize.minimize(gmm_loss, gmm_init, method='BFGS',
                                                   options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            gmm_glm_init = scipy.optimize.minimize(gmm_loss, gmm_init, method='Nelder-Mead',
                                                   options={'maxiter': iterations})
        try:
            gmm_bal_init = scipy.optimize.minimize(gmm_loss, beta_bal, method='BFGS',
                                                   options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            gmm_bal_init = scipy.optimize.minimize(gmm_loss, beta_bal, method='Nelder-Mead',
                                                   options={'maxiter': iterations})

    # Select the optimization result with lower loss
    opt1 = gmm_glm_init if gmm_glm_init.fun < gmm_bal_init.fun else gmm_bal_init

    # ========== Optimal probabilities and J-statistic ==========
    beta_opt = opt1.x.reshape(k, 2)
    theta_opt = X @ beta_opt
    probs_opt = _compute_softmax_probs_3treat(theta_opt, PROBS_MIN)
    J_opt = j_stat(beta_opt.ravel())

    # ========== MLE fallback check ==========
    beta_opt, probs_opt, J_opt, used_mle = _check_and_fallback_to_mle(
        J_opt, beta_opt, probs_opt, mcoef, probs_mnl, gmm_loss, bal_loss
    )
    if used_mle:
        # The fallback may have swapped in different coefficients; refresh the
        # linear predictor so it stays consistent with what is returned.
        theta_opt = X @ np.asarray(beta_opt).reshape(k, 2)

    # ========== Deviance and weights ==========
    deviance = -2 * np.sum(T1 * np.log(probs_opt[:, 0]) + T2 * np.log(probs_opt[:, 1])
                           + T3 * np.log(probs_opt[:, 2]))

    # Null deviance already computed above

    w_opt = _standardize_weights_3treat(T1, T2, T3, probs_opt, sample_weights, standardize)

    # ========== Vcov computation ==========
    vcov = _compute_vcov_3treat(beta_opt, probs_opt, T1, T2, T3, X, sample_weights,
                                lambda b, inv_V=None: _gmm_func_3treat(b, X, T1, T2, T3, sample_weights, n, inv_V),
                                n, k)

    # ========== Return dict ==========
    mle_J_val = j_stat(mcoef.ravel())

    # Enhanced non-convergence warning
    if not opt1.success:
        warnings.warn(
            f"Multi-valued CBPS (3-treat) optimization did not converge (converged=False). "
            f"Results may be unreliable. Consider:\n"
            f" 1. Increasing iterations (current: {iterations})\n"
            f" 2. Checking for perfect separation or collinearity\n"
            f" 3. Examining the balance diagnostics\n"
            f" 4. J-statistic: {J_opt:.6f}\n"
            f" 5. Trying different starting values or method='exact'",
            UserWarning,
            stacklevel=2
        )

    return {
        'coefficients': beta_opt,
        'fitted_values': probs_opt,
        'linear_predictor': theta_opt,
        'deviance': deviance,
        'nulldeviance': nulldeviance,
        'weights': w_opt * sample_weights,
        'y': treat,
        'x': X,
        'converged': opt1.success,
        'J': J_opt,
        'var': vcov,
        'mle_J': mle_J_val
    }
886
+
887
+
888
def cbps_4treat_fit(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    k: int = None,
    XprimeX_inv: np.ndarray = None,
    bal_only: bool = False,
    iterations: int = 1000,
    standardize: bool = True,
    two_step: bool = True,
    sample_weights: np.ndarray = None,
    treat_levels: np.ndarray = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    Fit CBPS for 4-level categorical treatments (complete workflow).

    The workflow mirrors the 3-level fit: multinomial-logit initialization,
    scalar alpha rescaling, balance-loss optimization, then (unless
    ``bal_only`` is set) GMM optimization from both starting points, keeping
    the run with the lower loss.

    Parameters
    ----------
    treat : np.ndarray
        Treatment vector with 4 distinct levels, shape (n,).
    X : np.ndarray, shape (n, k)
        Covariate matrix. Passed through ``ensure_dense`` before use.
    method : str, default 'over'
        Estimation method label. NOTE(review): this argument is not read
        anywhere in this function body.
    k : int, optional
        Number of coefficient rows per contrast. Defaults to ``X.shape[1]``.
    XprimeX_inv : np.ndarray, shape (k, k), optional
        Inverse of the weighted X'X matrix used by the balance loss. If None,
        computed as the generalized inverse of
        ``(sqrt(w) * X)' (sqrt(w) * X)``.
    bal_only : bool, default False
        If True, return right after the balance optimization.
    iterations : int, default 1000
        ``maxiter`` passed to every ``scipy.optimize.minimize`` call.
    standardize : bool, default True
        Forwarded to ``_standardize_weights_4treat``.
    two_step : bool, default True
        If True, the GMM loss is evaluated with a fixed ``inv_V`` (see
        Notes). If False, ``None`` is passed as ``inv_V`` to the GMM
        function and a Nelder-Mead retry is attempted when BFGS raises.
    sample_weights : np.ndarray, optional
        Sampling weights; normalized by ``normalize_sample_weights``.
    treat_levels : np.ndarray, optional
        The four treatment level values, in the order used to build the
        indicators. Defaults to ``np.unique(treat)``.
    verbose : int, default 0
        ``>= 2`` calls ``set_verbosity(2)``, ``>= 1`` calls
        ``set_verbosity(1)``; 0 leaves the logging configuration untouched.

    Returns
    -------
    Dict[str, Any]
        Dictionary with 12 keys:

        - coefficients: coefficient matrix (reshaped to (k, 3) on the
          GMM/balance paths; 4 levels need 3 columns)
        - fitted_values: probability matrix (columns 0..3 are used here)
        - linear_predictor: ``X @ coefficients``
        - deviance: -2 * log-likelihood at the fitted probabilities
        - nulldeviance: -2 * log-likelihood under the (weighted) sample
          proportions of each level
        - weights: standardized weights multiplied by ``sample_weights``
        - y: the ``treat`` vector as passed in
        - x: the (dense) covariate matrix
        - converged: ``success`` flag of the selected optimizer run
        - J: J-statistic (balance loss when ``bal_only`` is True)
        - var: covariance matrix from ``_compute_vcov_4treat``
        - mle_J: J-statistic evaluated at the multinomial-logit coefficients

    Notes
    -----
    4-treat specific ``inv_V`` selection when ``two_step`` is True: the
    matrix is evaluated at the alpha-scaled MNLogit start if that point has
    the strictly lower GMM loss, and at the balance solution otherwise. With
    ``bal_only`` the balance solution is always used. When ``two_step`` is
    False, the matrix evaluated at the alpha-scaled start is kept.
    """
    # ========== Initialization ==========
    # Ensure dense matrix (sparse input auto-converted)
    X = ensure_dense(X)

    # Configure logging from verbose parameter (backward compatibility)
    if verbose >= 2:
        set_verbosity(2)
    elif verbose >= 1:
        set_verbosity(1)

    n = len(treat)

    # Treatment levels and 0/1 indicators for each level
    if treat_levels is None:
        treat_levels = np.unique(treat)
    assert len(treat_levels) == 4, "Must be 4-valued treatment"

    T1 = (treat == treat_levels[0]).astype(float)
    T2 = (treat == treat_levels[1]).astype(float)
    T3 = (treat == treat_levels[2]).astype(float)
    T4 = (treat == treat_levels[3]).astype(float)

    sample_weights = normalize_sample_weights(sample_weights, n)

    if k is None:
        k = X.shape[1]
    if XprimeX_inv is None:
        wtX_sqrt = (sample_weights**0.5)[:, None] * X
        XprimeX_inv = _r_ginv(wtX_sqrt.T @ wtX_sqrt)

    # ========== Closures over the data ==========
    def gmm_eval(beta, inv_V=None):
        return _gmm_func_4treat(beta, X, T1, T2, T3, T4, sample_weights, n, inv_V)

    def gmm_loss(beta):
        # Loss with inv_V left unspecified (None is passed through)
        return gmm_eval(beta, None)['loss']

    def bal_loss(beta):
        return _bal_loss_4treat(beta, X, T1, T2, T3, T4, sample_weights, XprimeX_inv, k, n)

    # ========== MNLogit initialization ==========
    mcoef, probs_mnl = _mnlogit_init_4treat(treat, X, sample_weights, treat_levels, k, n)

    # ========== Alpha scaling ==========
    def alpha_func(alpha):
        return gmm_loss(mcoef.ravel() * alpha)
    alpha_result = scipy.optimize.minimize_scalar(alpha_func, bounds=(0.8, 1.1), method='bounded')
    gmm_init = mcoef.ravel() * alpha_result.x

    # ========== Pre-compute invV ==========
    temp_invV = gmm_eval(gmm_init, None)['inv_V']

    # ========== Balance optimization ==========
    logger.info(f"Starting balance optimization (max_iter={iterations})...")

    if two_step:
        opt_bal = scipy.optimize.minimize(bal_loss, gmm_init, method='BFGS',
                                          options={'maxiter': iterations})
        logger.info(f"Balance optimization complete: loss={opt_bal.fun:.6f}, converged={opt_bal.success}")
    else:
        try:
            opt_bal = scipy.optimize.minimize(bal_loss, gmm_init, method='BFGS',
                                              options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            opt_bal = scipy.optimize.minimize(bal_loss, gmm_init, method='Nelder-Mead',
                                              options={'maxiter': iterations})

    beta_bal = opt_bal.x

    # ========== Compute nulldeviance (before all return paths) ==========
    # Null model: each category's probability = its (weighted) sample proportion,
    # clipped away from zero to prevent log(0)
    T1_mean = np.clip(np.average(T1, weights=sample_weights), 1e-10, 1.0)
    T2_mean = np.clip(np.average(T2, weights=sample_weights), 1e-10, 1.0)
    T3_mean = np.clip(np.average(T3, weights=sample_weights), 1e-10, 1.0)
    T4_mean = np.clip(np.average(T4, weights=sample_weights), 1e-10, 1.0)
    nulldeviance = -2 * np.sum(T1 * np.log(T1_mean) + T2 * np.log(T2_mean) +
                               T3 * np.log(T3_mean) + T4 * np.log(T4_mean))

    # ========== 4-treat specific: invV selection logic ==========
    # temp_invV already holds inv_V at gmm_init, so it is reused instead of
    # being recomputed (assumes _gmm_func_4treat is deterministic). Under
    # bal_only the balance solution always wins, so the loss comparison is
    # skipped.
    if two_step:
        if not bal_only and gmm_loss(gmm_init) < gmm_loss(beta_bal):
            this_invV = temp_invV
        else:
            this_invV = gmm_eval(beta_bal, None)['inv_V']
    else:
        this_invV = temp_invV

    def gmm_loss_with_invV(beta):
        return gmm_eval(beta, this_invV)['loss']

    def j_stat(beta):
        # Two-step: loss at the fixed invV; otherwise the loss with inv_V=None
        return gmm_loss_with_invV(beta) if two_step else gmm_loss(beta)

    # ========== bal_only early return ==========
    if bal_only:
        beta_opt = beta_bal.reshape(k, 3)
        theta_opt = X @ beta_opt
        probs_opt = _compute_softmax_probs_4treat(theta_opt, PROBS_MIN)
        w_opt = _standardize_weights_4treat(T1, T2, T3, T4, probs_opt, sample_weights, standardize)
        J_opt = bal_loss(beta_opt.ravel())
        deviance = -2 * np.sum(T1 * np.log(probs_opt[:, 0]) + T2 * np.log(probs_opt[:, 1]) +
                               T3 * np.log(probs_opt[:, 2]) + T4 * np.log(probs_opt[:, 3]))
        vcov = _compute_vcov_4treat(beta_opt, probs_opt, T1, T2, T3, T4, X, sample_weights,
                                    lambda b, inv_V=None: _gmm_func_4treat(b, X, T1, T2, T3, T4, sample_weights, n, inv_V),
                                    n, k)
        mle_J_val = j_stat(mcoef.ravel())
        return {'coefficients': beta_opt, 'fitted_values': probs_opt, 'linear_predictor': theta_opt,
                'deviance': deviance, 'nulldeviance': nulldeviance, 'weights': w_opt * sample_weights,
                'y': treat, 'x': X,
                'converged': opt_bal.success, 'J': J_opt, 'var': vcov, 'mle_J': mle_J_val}

    # ========== GMM dual initialization selection ==========
    if two_step:
        gmm_glm_init = scipy.optimize.minimize(gmm_loss_with_invV, gmm_init, method='BFGS',
                                               options={'maxiter': iterations})
        gmm_bal_init = scipy.optimize.minimize(gmm_loss_with_invV, beta_bal, method='BFGS',
                                               options={'maxiter': iterations})
    else:
        try:
            gmm_glm_init = scipy.optimize.minimize(gmm_loss, gmm_init, method='BFGS',
                                                   options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            gmm_glm_init = scipy.optimize.minimize(gmm_loss, gmm_init, method='Nelder-Mead',
                                                   options={'maxiter': iterations})
        try:
            gmm_bal_init = scipy.optimize.minimize(gmm_loss, beta_bal, method='BFGS',
                                                   options={'maxiter': iterations})
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            gmm_bal_init = scipy.optimize.minimize(gmm_loss, beta_bal, method='Nelder-Mead',
                                                   options={'maxiter': iterations})

    # Select the optimization result with lower loss
    opt1 = gmm_glm_init if gmm_glm_init.fun < gmm_bal_init.fun else gmm_bal_init

    # ========== Optimal probabilities and J-statistic ==========
    beta_opt = opt1.x.reshape(k, 3)
    theta_opt = X @ beta_opt
    probs_opt = _compute_softmax_probs_4treat(theta_opt, PROBS_MIN)
    J_opt = j_stat(beta_opt.ravel())

    # ========== MLE fallback check ==========
    beta_opt, probs_opt, J_opt, used_mle = _check_and_fallback_to_mle(
        J_opt, beta_opt, probs_opt, mcoef, probs_mnl, gmm_loss, bal_loss
    )
    if used_mle:
        # The fallback may have swapped in different coefficients; refresh the
        # linear predictor so it stays consistent with what is returned.
        theta_opt = X @ np.asarray(beta_opt).reshape(k, 3)

    # ========== Deviance and weights ==========
    deviance = -2 * np.sum(T1 * np.log(probs_opt[:, 0]) + T2 * np.log(probs_opt[:, 1]) +
                           T3 * np.log(probs_opt[:, 2]) + T4 * np.log(probs_opt[:, 3]))

    # Null deviance already computed above

    w_opt = _standardize_weights_4treat(T1, T2, T3, T4, probs_opt, sample_weights, standardize)

    # ========== Vcov computation ==========
    vcov = _compute_vcov_4treat(beta_opt, probs_opt, T1, T2, T3, T4, X, sample_weights,
                                lambda b, inv_V=None: _gmm_func_4treat(b, X, T1, T2, T3, T4, sample_weights, n, inv_V),
                                n, k)

    # ========== Return dict ==========
    mle_J_val = j_stat(mcoef.ravel())

    # Enhanced non-convergence warning
    if not opt1.success:
        warnings.warn(
            f"Multi-valued CBPS (4-treat) optimization did not converge (converged=False). "
            f"Results may be unreliable. Consider:\n"
            f" 1. Increasing iterations (current: {iterations})\n"
            f" 2. Checking for perfect separation or collinearity\n"
            f" 3. Examining the balance diagnostics\n"
            f" 4. J-statistic: {J_opt:.6f}\n"
            f" 5. Trying different starting values or method='exact'",
            UserWarning,
            stacklevel=2
        )

    return {
        'coefficients': beta_opt,
        'fitted_values': probs_opt,
        'linear_predictor': theta_opt,
        'deviance': deviance,
        'nulldeviance': nulldeviance,
        'weights': w_opt * sample_weights,
        'y': treat,
        'x': X,
        'converged': opt1.success,
        'J': J_opt,
        'var': vcov,
        'mle_J': mle_J_val
    }