cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,945 @@
1
+ """
2
+ Covariate Balancing Propensity Score for Continuous Treatments
3
+ ===============================================================
4
+
5
+ This module implements the Covariate Balancing Propensity Score (CBPS) methodology
6
+ for continuous treatments using generalized propensity scores (GPS). The implementation
7
+ extends the binary CBPS framework to handle continuous treatment variables through
8
+ covariate whitening and normal density estimation.
9
+
10
+ Methodology
11
+ -----------
12
+ The continuous CBPS estimates the generalized propensity score by maximizing the
13
+ covariate balance. The method involves:
14
+ 1. Cholesky whitening of covariates with sample weights.
15
+ 2. Log-space normal density computation for numerical stability.
16
+ 3. GMM optimization with multiple starting values.
17
+ 4. Coefficient inverse transformation from whitened to original space.
18
+
19
+ References
20
+ ----------
21
+ Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity score
22
+ for a continuous treatment: Application to the efficacy of political advertisements.
23
+ The Annals of Applied Statistics, 12(1), 156-177.
24
+ """
25
+
26
+ from typing import Dict, Any, Optional
27
+ import warnings
28
+ import numpy as np
29
+ import scipy.stats
30
+ import scipy.optimize
31
+ import scipy.linalg
32
+ import statsmodels.api as sm
33
+
34
+ from cbps.utils.validation import validate_cbps_input
35
+ from cbps.utils.validation import ensure_dense
36
+ from cbps.logging_config import logger, set_verbosity
37
+ from cbps.constants import DEFAULT_CONFIG
38
+
39
+
40
# ---------------------------------------------------------------------------
# Numerical constants. The alpha search interval is local to this module;
# every other threshold is read from the package-wide DEFAULT_CONFIG.
# ---------------------------------------------------------------------------
# Interval searched when rescaling the initial [beta, log(sigma^2)] vector.
ALPHA_BOUNDS = (0.8, 1.1)
# Floor/ceiling applied to densities (and, via log, to log densities).
PROBS_MIN = DEFAULT_CONFIG.probs_min
# Columns whose ddof=1 standard deviation is at or below this are constant.
CONST_COL_THRESHOLD = DEFAULT_CONFIG.const_col_threshold
# Symmetric bound for the log density ratio used to build balancing weights.
CLIP_RANGE = DEFAULT_CONFIG.log_clip_range
45
+
46
+
47
def cbps_continuous_fit(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    two_step: bool = True,
    iterations: int = 1000,
    standardize: bool = True,
    sample_weights: Optional[np.ndarray] = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    Fit the Covariate Balancing Propensity Score model for continuous treatments.

    Parameters
    ----------
    treat : np.ndarray
        Continuous treatment vector, shape (n,).
    X : np.ndarray
        Covariate matrix (including intercept column), shape (n, k).
    method : {'over', 'exact'}, default='over'
        Estimation method:
        - 'over': Over-identified GMM (score + balance + sigma conditions).
        - 'exact': Exactly identified GMM (balance + sigma conditions).
    two_step : bool, default=True
        If True, use two-step GMM with fixed weight matrix.
        If False, use continuously updating GMM.
    iterations : int, default=1000
        Maximum number of optimization iterations.
    standardize : bool, default=True
        If True, rescale the balancing weights so that
        ``sum(w * sample_weights) == 1``; the returned ``weights`` entry
        (which already includes the sample weights) then sums to 1.
    sample_weights : np.ndarray, optional
        Sampling weights. Defaults to uniform weights if None.
        Weights will be normalized to sum to n.
    verbose : int, default=0
        Verbosity level; values >= 1 and >= 2 are forwarded to the package
        logging configuration.

    Returns
    -------
    dict
        Dictionary containing estimation results, including:
        - coefficients: Estimated parameters in the original covariate scale.
        - fitted_values: Estimated generalized propensity scores (clipped
          normal densities of the standardized treatment).
        - weights: Balancing weights multiplied by the sample weights.
        - deviance: Model deviance.
        - converged: Convergence status of the selected optimizer run.
        - J: GMM loss function value (balance loss for method='exact').
        - var: Variance-covariance matrix of the coefficients.
        - sigmasq: Estimated residual variance (original treatment scale).
        - Ttilde: Standardized treatment.
        - Xtilde: Whitened covariates.
        Also present: linear_predictor, y, x, mle_J, beta_tilde,
        sigmasq_tilde, stabilizers, normality_diagnostics.

    Raises
    ------
    ValueError
        On invalid input or numerical failure. With ``method='over'``, an
        error whose message reports an infinite weighting matrix or a
        potential V-matrix overflow is not raised; instead a UserWarning is
        issued and the model is refit with ``method='exact'``.

    Notes
    -----
    The algorithm performs Cholesky whitening on covariates, standardizes the treatment,
    and then optimizes the GMM objective function. It handles potential numerical
    instability in the weight matrix calculation through regularization when necessary.
    """
    # Input validation
    X = ensure_dense(X)
    validate_cbps_input(
        treat, X,
        min_observations=2,
        module_name="Continuous CBPS",
        check_treatment_variance=True
    )

    fit_options = dict(
        two_step=two_step,
        iterations=iterations,
        standardize=standardize,
        sample_weights=sample_weights,
        verbose=verbose,
    )

    if method != 'over':
        return _cbps_continuous_fit_impl(treat, X, method=method, **fit_options)

    # Auto-fallback: if method='over' runs into an infinite V matrix, fall
    # back to method='exact' (matching R CBPS behavior). The implementation
    # reports this condition in two ways: the explicit "infinite value" check
    # and a pre-emptive overflow guard on the V22 scaling exponent that fires
    # *before* V can become infinite. Both must trigger the fallback.
    fallback_markers = (
        "infinite value in the weighting matrix",
        "potential overflow in v matrix",
    )
    try:
        return _cbps_continuous_fit_impl(treat, X, method=method, **fit_options)
    except ValueError as e:
        message = str(e).lower()
        if not any(marker in message for marker in fallback_markers):
            raise
        warnings.warn(
            f"Over-identified GMM failed due to infinite V matrix values. "
            f'Automatically falling back to method="exact" '
            f"(just-identified). Original error: {e}",
            UserWarning
        )
        return _cbps_continuous_fit_impl(treat, X, method='exact', **fit_options)
142
+
143
+
144
def _whiten_covariates(X: np.ndarray, sample_weights: np.ndarray) -> np.ndarray:
    """
    Cholesky-whiten the non-constant covariate columns.

    Columns whose (ddof=1) standard deviation is at or below
    ``CONST_COL_THRESHOLD`` are passed through unchanged and placed first.
    The remaining columns are multiplied by ``sample_weights``, whitened with
    the inverse of the upper Cholesky factor of their covariance, and centered.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).

    Returns
    -------
    np.ndarray
        ``Xtilde`` with the same shape as ``X`` (column order: constant
        columns first, then whitened columns).

    Raises
    ------
    ValueError
        If the transformed matrix does not have the shape of ``X``.
    """
    col_std = np.std(X, axis=0, ddof=1)
    const_idx = np.where(col_std <= CONST_COL_THRESHOLD)[0]
    varying_idx = np.where(col_std > CONST_COL_THRESHOLD)[0]

    if len(varying_idx) == 0:
        warnings.warn(
            "All columns are constant (sd <= 1e-10). "
            "Continuous CBPS will degenerate to no-covariate model. "
            "This is a valid edge case where the model only standardizes the treatment distribution.",
            UserWarning
        )
        # Degenerate case: Xtilde is just X
        Xtilde = X.copy()
    else:
        sw_X_varying = sample_weights[:, None] * X[:, varying_idx]
        # np.cov squeezes the result for a single variable to a 0-d array;
        # keep it 2-D so the Cholesky/inverse steps always receive a matrix.
        cov_weighted = np.atleast_2d(np.cov(sw_X_varying.T, ddof=1))
        assert np.allclose(cov_weighted, cov_weighted.T, atol=1e-12), \
            "Weighted covariance matrix must be symmetric"

        # Upper-triangular factor U with cov_weighted = U.T @ U
        U = scipy.linalg.cholesky(cov_weighted, lower=False)
        assert np.allclose(np.tril(U, k=-1), 0, atol=1e-12), "U must be upper triangular"
        assert np.all(np.diag(U) > 0), "Diagonal elements of U must be positive"

        X_white = sw_X_varying @ np.linalg.inv(U)
        # Center only: multiplying by U^{-1} already yields unit covariance.
        X_white_centered = X_white - X_white.mean(axis=0)
        assert abs(X_white_centered.mean()) < 1e-10, "Whitened data should be centered"

        if len(const_idx) > 0:
            Xtilde = np.column_stack([X[:, const_idx], X_white_centered])
        else:
            Xtilde = X_white_centered

    if Xtilde.shape != X.shape:
        raise ValueError(f"Xtilde shape {Xtilde.shape} != X shape {X.shape}")
    return Xtilde


def _attach_normality_diagnostics(result: Dict[str, Any], treat: np.ndarray, X: np.ndarray) -> None:
    """
    Store a conditional-normality test of T|X in ``result['normality_diagnostics']``.

    Diagnostics must never block estimation: any failure while running the
    test is recorded as ``{'error': ..., 'reject_normality': None}``. The
    rejection warning is emitted outside the ``try`` so that a
    warnings-as-errors configuration surfaces it instead of replacing a
    valid diagnostic with an error record.
    """
    warning_message = None
    try:
        from cbps.diagnostics.normality import test_treatment_normality
        normality_diag = test_treatment_normality(treat, X)
        result['normality_diagnostics'] = normality_diag
        if normality_diag['reject_normality']:
            warning_message = (
                f"Treatment normality assumption rejected "
                f"(p={normality_diag['p_value']:.4f}). "
                f"Consider using npCBPS for nonparametric estimation."
            )
    except Exception as e:
        result['normality_diagnostics'] = {
            'error': str(e),
            'reject_normality': None
        }
    if warning_message is not None:
        warnings.warn(warning_message, UserWarning)


def _cbps_continuous_fit_impl(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    two_step: bool = True,
    iterations: int = 1000,
    standardize: bool = True,
    sample_weights: Optional[np.ndarray] = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    Internal implementation of cbps_continuous_fit.

    Parameters and return value are as documented on ``cbps_continuous_fit``;
    this function performs no input validation and no method fallback.

    The parameter vector optimized throughout is ``[beta, log(sigma^2)]`` for
    the normal model ``Ttilde | Xtilde ~ N(Xtilde @ beta, sigma^2)``.

    Raises
    ------
    ValueError
        If the GMM weighting matrix is infinite or its V22 scaling exponent
        would overflow (method='over'), if the final J statistic is
        significantly negative, or if weights become non-finite.
    """
    # Configure logging from verbose parameter (backward compatibility)
    if verbose >= 2:
        set_verbosity(2)
    elif verbose >= 1:
        set_verbosity(1)

    n = len(treat)
    k = X.shape[1]
    bal_only = (method == 'exact')

    # Normalize sample weights to mean 1 (so they sum to n)
    if sample_weights is None:
        sample_weights = np.ones(n)
    sample_weights = sample_weights / sample_weights.mean()
    if not np.isclose(sample_weights.sum(), n, atol=1e-10):
        warnings.warn(f"Sample weights normalization check failed: sum={sample_weights.sum():.6f} != n={n}")

    # ========== Preprocessing ==========
    Xtilde = _whiten_covariates(X, sample_weights)
    wtXilde = sample_weights[:, None] * Xtilde

    # Standardize the weighted treatment (zero mean, unit variance)
    sw_treat = sample_weights * treat
    Ttilde = (sw_treat - sw_treat.mean()) / sw_treat.std(ddof=1)
    assert abs(Ttilde.mean()) < 1e-10
    assert abs(Ttilde.std(ddof=1) - 1) < 1e-10

    n_identity_vec = np.ones((n, 1))

    # Stabilizers: clipped log marginal density log f(T*) under N(0, 1),
    # computed per observation.
    stabilizers = np.log(np.clip(scipy.stats.norm.pdf(Ttilde, 0, 1), PROBS_MIN, 1 - PROBS_MIN))

    # ========== Shared building blocks (closures over the data above) ==========

    def _unpack(params_curr):
        """Split ``[beta, log(sigma^2)]`` into ``(beta, sigma^2)``."""
        return params_curr[:-1], np.exp(params_curr[-1])

    def _clipped_logpdf(beta, sigmasq):
        """Log N(Ttilde; Xtilde @ beta, sigmasq), clipped to [log PROBS_MIN, log(1 - PROBS_MIN)]."""
        log_dens = scipy.stats.norm.logpdf(Ttilde, loc=Xtilde @ beta, scale=np.sqrt(sigmasq))
        log_dens = np.minimum(np.log(1 - PROBS_MIN), log_dens)
        return np.maximum(np.log(PROBS_MIN), log_dens)

    def _balance_weights(beta, sigmasq):
        """Ttilde * f(T*) / f(T*|X), formed in log space with the log ratio clipped."""
        log_diff = stabilizers - _clipped_logpdf(beta, sigmasq)
        return Ttilde * np.exp(np.clip(log_diff, -CLIP_RANGE, CLIP_RANGE))

    def _balance_moment(w_curr):
        """Weighted covariate balance condition, shape (k,)."""
        return ((1/n) * wtXilde.T @ w_curr).ravel()

    def _sigma_moment(resid, sigmasq):
        """Scalar moment condition for sigma^2."""
        return (1/n) * sample_weights.T @ (resid**2 / sigmasq - 1)

    def _over_moments(resid, sigmasq, w_curr):
        """gbar for method='over': [score (k), balance (k), sigma^2 (1)]."""
        score = (1/n) * wtXilde.T @ (resid / sigmasq)
        return np.concatenate([score.ravel(), _balance_moment(w_curr), [_sigma_moment(resid, sigmasq)]])

    def _exact_moments(resid, sigmasq, w_curr):
        """gbar for method='exact': [balance (k), sigma^2 (1)]."""
        return np.concatenate([_balance_moment(w_curr), [_sigma_moment(resid, sigmasq)]])

    def _gmm_inverse_weight_matrix(beta_curr, sigmasq):
        """Build the (2k+1)x(2k+1) V matrix at the given parameters and return pinv(V)."""
        V11 = (1/sigmasq) * wtXilde.T @ Xtilde
        V12 = wtXilde.T @ Xtilde / sigmasq
        V13 = wtXilde.T @ n_identity_vec * 0

        linear_pred = Xtilde @ beta_curr
        linear_pred_sq = linear_pred**2
        # V22 row scaling is exp(lp^2 / sigma^2) * (sigma^2 + lp^2); it is
        # formed in log space so overflow is detected before exponentiating.
        exponent = linear_pred_sq / sigmasq + np.log(sigmasq + linear_pred_sq)
        if np.any(exponent > 700):
            raise ValueError(
                f"Potential overflow in V matrix calculation (max exponent={exponent.max():.2f}). "
                f"Residual variance sigma^2={sigmasq:.6f} might be too small. "
                f"Consider using method='exact'."
            )
        vec_scaling = np.exp(exponent)
        if not np.all(np.isfinite(vec_scaling)):
            raise ValueError("V22 scaling vector contains non-finite values.")

        V22 = wtXilde.T @ (vec_scaling[:, None] * Xtilde)
        V23 = (wtXilde.T @ (-Xtilde @ beta_curr) * (-2/sigmasq)).reshape(-1, 1)
        V33 = np.array([[sample_weights.T @ n_identity_vec.ravel() * 2]])

        V = (1/n) * np.block([
            [V11, V12, V13],
            [V12, V22, V23],
            [V13.T, V23.T, V33]
        ])

        if not np.allclose(V, V.T, atol=1e-12):
            warnings.warn("V matrix is not symmetric within tolerance")
        if np.any(np.isinf(V)):
            raise ValueError(
                "Encountered an infinite value in the weighting matrix. "
                'Use the just-identified version of CBPS instead by setting method="exact".'
            )
        return scipy.linalg.pinv(V)

    def _balance_sigma_jacobian(resid, sigmasq, w_curr):
        """
        d(gbar)/d(params) columns for the balance and sigma^2 moments.

        Rows are [beta (k), log(sigma^2) (1)]; derivatives w.r.t. sigma^2
        are multiplied by sigma^2 (chain rule for the log parameterization).
        """
        dgbar_2_1 = (wtXilde.T * (-resid / sigmasq * w_curr)) @ Xtilde
        # The balance moment carries sample weights, so its sigma^2
        # derivative must carry them too (as XG_3_2 in the variance does).
        dgbar_2_2 = (
            sample_weights * w_curr * (1/(2*sigmasq) - resid**2 / (2*sigmasq**2))
        ).reshape(1, -1) @ Xtilde
        dgbar_3_1 = wtXilde.T @ (-2 * resid / sigmasq).reshape(-1, 1)
        dgbar_3_2 = sample_weights.T @ (-resid**2 / (sigmasq**2))

        col_balance = np.vstack([dgbar_2_1, dgbar_2_2 * sigmasq])
        col_sigma = np.vstack([dgbar_3_1, dgbar_3_2.reshape(1, 1) * sigmasq])
        return col_balance, col_sigma

    # ========== Objective functions and gradients ==========

    def gmm_func(params_curr: np.ndarray, invV: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        GMM objective for the over-identified case.

        If ``invV`` is None the weighting matrix is recomputed at
        ``params_curr`` (continuous updating); otherwise it is held fixed.
        Returns ``{'loss': float, 'invV': ndarray}``.
        """
        beta_curr, sigmasq = _unpack(params_curr)
        w_curr = _balance_weights(beta_curr, sigmasq)
        if not np.all(np.isfinite(w_curr)):
            raise ValueError("Weights contain non-finite values")

        resid = Ttilde - Xtilde @ beta_curr
        gbar = _over_moments(resid, sigmasq, w_curr)

        if invV is None:
            invV = _gmm_inverse_weight_matrix(beta_curr, sigmasq)

        loss = gbar.T @ invV @ gbar
        if loss < -1e-6:
            warnings.warn(
                f"GMM loss is negative ({loss:.2e}). Check numerical stability.",
                UserWarning
            )
        return {'loss': float(loss), 'invV': invV}

    def gmm_loss(params_curr: np.ndarray, invV: Optional[np.ndarray] = None) -> float:
        """Scalar GMM loss (see ``gmm_func``)."""
        return gmm_func(params_curr, invV)['loss']

    def bal_func(params_curr: np.ndarray) -> Dict[str, float]:
        """Balance objective for the exactly-identified case (identity weight matrix)."""
        beta_curr, sigmasq = _unpack(params_curr)
        w_curr = _balance_weights(beta_curr, sigmasq)
        resid = Ttilde - Xtilde @ beta_curr
        gbar = _exact_moments(resid, sigmasq, w_curr)

        if gbar.shape != (k + 1,):
            raise ValueError(f"Gradient vector shape mismatch: {gbar.shape}")

        loss = gbar.T @ np.eye(k + 1) @ gbar
        return {'loss': float(loss)}

    def bal_loss(params_curr: np.ndarray) -> float:
        """Wrapper for balance loss function."""
        return bal_func(params_curr)['loss']

    def gmm_gradient(params_curr: np.ndarray, invV: np.ndarray) -> np.ndarray:
        """Gradient of ``gmm_loss`` with a fixed ``invV``: 2 * dgbar @ invV @ gbar."""
        beta_curr, sigmasq = _unpack(params_curr)
        w_curr = _balance_weights(beta_curr, sigmasq)
        resid = Ttilde - Xtilde @ beta_curr
        gbar = _over_moments(resid, sigmasq, w_curr)

        dgbar_1_1 = (-wtXilde.T @ Xtilde) / sigmasq
        dgbar_1_2 = (-sample_weights * resid / (sigmasq**2)).reshape(1, -1) @ Xtilde
        col_score = np.vstack([dgbar_1_1, dgbar_1_2 * sigmasq])
        col_balance, col_sigma = _balance_sigma_jacobian(resid, sigmasq, w_curr)

        dgbar = (1/n) * np.hstack([col_score, col_balance, col_sigma])
        return (2 * dgbar @ invV @ gbar).ravel()

    def bal_gradient(params_curr: np.ndarray) -> np.ndarray:
        """Gradient of ``bal_loss``."""
        beta_curr, sigmasq = _unpack(params_curr)
        w_curr = _balance_weights(beta_curr, sigmasq)
        resid = Ttilde - Xtilde @ beta_curr
        gbar = _exact_moments(resid, sigmasq, w_curr)

        col_balance, col_sigma = _balance_sigma_jacobian(resid, sigmasq, w_curr)
        dgbar = (1/n) * np.hstack([col_balance, col_sigma])
        return (2 * dgbar @ np.eye(k + 1) @ gbar).ravel()

    # ========== Initialization (weighted least squares / MLE) ==========

    lm_model = sm.WLS(Ttilde, Xtilde, weights=sample_weights).fit()
    mcoef = lm_model.params.copy()
    mcoef[np.isnan(mcoef)] = 0

    residuals = Ttilde - Xtilde @ mcoef
    sigmasq_init = np.mean(residuals**2)
    assert sigmasq_init > 0, f"Initial residual variance must be positive (got {sigmasq_init})"

    probs_mle = _clipped_logpdf(mcoef, sigmasq_init)
    params_curr = np.concatenate([mcoef, [np.log(sigmasq_init)]])

    # MLE baseline losses, used by the fallback check after optimization
    mle_J = np.nan
    try:
        mle_J = gmm_loss(params_curr)
    except Exception as e:
        warnings.warn(f"Failed to compute MLE J statistic: {e}")
    mle_bal = bal_loss(params_curr)

    # Alpha scaling of the initial parameters, with V fixed at alpha=1.0.
    # A fixed V is numerically more stable than re-computing it for every
    # alpha in pathological cases (e.g. extremely poor initial fit).
    glm_invV = None
    try:
        glm_invV = gmm_func(params_curr, invV=None)['invV']

        def alpha_func(alpha):
            return gmm_loss(params_curr * alpha, invV=glm_invV)

        alpha_result = scipy.optimize.minimize_scalar(
            alpha_func,
            bounds=ALPHA_BOUNDS,
            method='bounded'
        )
        params_curr = params_curr * alpha_result.x
    except Exception as e:
        # Only the rescaling is abandoned; a successfully computed glm_invV
        # is kept, since two-step GMM below needs a real matrix.
        warnings.warn(f"Alpha scaling failed, using unscaled LM initialization: {e}")

    gmm_init = params_curr.copy()

    # ========== Balance optimization ==========

    logger.info(f"Starting balance optimization (max_iter={iterations}, two_step={two_step})...")

    if two_step:
        opt_bal = scipy.optimize.minimize(
            bal_loss, gmm_init,
            method='BFGS',
            jac=bal_gradient,
            options={'maxiter': iterations, 'gtol': 1e-05}
        )
    else:
        try:
            opt_bal = scipy.optimize.minimize(
                bal_loss, gmm_init,
                method='BFGS',
                options={'maxiter': iterations}
            )
        except (np.linalg.LinAlgError, ValueError, RuntimeWarning) as e:
            warnings.warn(f"Balance BFGS failed, falling back to Nelder-Mead: {e}")
            opt_bal = scipy.optimize.minimize(
                bal_loss, gmm_init,
                method='Nelder-Mead',
                options={'maxiter': iterations}
            )

    params_bal = opt_bal.x

    # ========== GMM optimization (over-identified only) ==========

    if bal_only:
        opt1 = opt_bal
    else:
        logger.info("Starting GMM optimization with dual initialization...")

        if two_step:
            if glm_invV is None:
                # The fixed weight matrix could not be built during alpha
                # scaling; build it now so a failure raises its original
                # ValueError rather than a TypeError inside gmm_gradient.
                glm_invV = gmm_func(gmm_init, invV=None)['invV']

            def fixed_loss(p):
                return gmm_loss(p, invV=glm_invV)

            def fixed_grad(p):
                return gmm_gradient(p, glm_invV)

            # Start from both the GLM and the balance solutions
            gmm_glm_init = scipy.optimize.minimize(
                fixed_loss, gmm_init,
                method='BFGS',
                jac=fixed_grad,
                options={'maxiter': iterations, 'gtol': 1e-05}
            )
            gmm_bal_init = scipy.optimize.minimize(
                fixed_loss, params_bal,
                method='BFGS',
                jac=fixed_grad,
                options={'maxiter': iterations, 'gtol': 1e-05}
            )
        else:
            # Continuous updating (V recomputed at every evaluation)
            try:
                gmm_glm_init = scipy.optimize.minimize(
                    gmm_loss, gmm_init,
                    method='BFGS',
                    options={'maxiter': iterations, 'gtol': 1e-05}
                )
            except (np.linalg.LinAlgError, ValueError, RuntimeWarning) as e:
                warnings.warn(f"GMM-GLM BFGS failed, falling back to Nelder-Mead: {e}")
                gmm_glm_init = scipy.optimize.minimize(
                    gmm_loss, gmm_init,
                    method='Nelder-Mead',
                    options={'maxiter': iterations}
                )

            try:
                gmm_bal_init = scipy.optimize.minimize(
                    gmm_loss, params_bal,
                    method='BFGS',
                    options={'maxiter': iterations, 'gtol': 1e-05}
                )
            except (np.linalg.LinAlgError, ValueError, RuntimeWarning) as e:
                warnings.warn(f"GMM-Balance BFGS failed, falling back to Nelder-Mead: {e}")
                gmm_bal_init = scipy.optimize.minimize(
                    gmm_loss, params_bal,
                    method='Nelder-Mead',
                    options={'maxiter': iterations}
                )

        # Select the better of the two starts (ties go to the balance start)
        if gmm_glm_init.fun < gmm_bal_init.fun:
            opt1 = gmm_glm_init
            pick_glm = 1
        else:
            opt1 = gmm_bal_init
            pick_glm = 0

        if verbose >= 1:
            source = "GLM" if pick_glm == 1 else "Balance"
            logger.info(f"GMM optimization complete: J={opt1.fun:.6f}, converged={opt1.success}, source={source}")

    # ========== Parameter extraction and MLE fallback ==========

    params_opt = opt1.x
    beta_opt, sigmasq = _unpack(params_opt)
    probs_opt = _clipped_logpdf(beta_opt, sigmasq)

    if not bal_only:
        if two_step:
            J_opt = gmm_func(params_opt, invV=glm_invV)['loss']
        else:
            J_opt = gmm_func(params_opt)['loss']

        # Check 1: significantly negative J statistic (theoretical violation)
        if J_opt < -1e-6:
            raise ValueError(
                f"Encountered an infinite value in the weighting matrix. "
                f"J statistic is significantly negative (J={J_opt:.6e}), "
                f"indicating numerical instability in the V matrix. "
                f'Use the just-identified version of CBPS instead by setting method="exact".'
            )
        # Check 2: optimization result worse than MLE on both criteria
        # R code: if ((J.opt > mle.J) & (bal.loss(params.opt) > mle.bal))
        elif (J_opt > mle_J) and (bal_loss(params_opt) > mle_bal):
            warnings.warn(
                f"Optimization produced worse results than MLE (|J_opt|={abs(J_opt):.6e} > "
                f"|J_mle|={abs(mle_J):.6e}). Falling back to MLE.",
                UserWarning
            )
            # NOTE(review): sigmasq is deliberately left at the optimized
            # value here; weights use probs_mle (built with sigmasq_init)
            # while fitted_values/vcov use this sigmasq. Confirm parity with
            # R before changing.
            beta_opt = mcoef
            probs_opt = probs_mle
            J_opt = mle_J
        # Check 3: minor negative J
        elif J_opt < 0:
            warnings.warn(
                f"J statistic is slightly negative (J={J_opt:.6e}). "
                f"This may indicate minor numerical precision issues.",
                UserWarning
            )
    else:
        J_opt = bal_loss(params_opt)

    # ========== Final weights ==========

    w_opt = np.exp(stabilizers - probs_opt)

    if standardize:
        w_opt = w_opt / np.sum(w_opt * sample_weights)
        if not np.isclose(np.sum(w_opt * sample_weights), 1.0, atol=1e-10):
            warnings.warn("Weight standardization failed to sum to 1")

    if not np.all(np.isfinite(w_opt)):
        raise ValueError("Final weights contain non-finite values")

    deviance = -2 * np.sum(probs_opt)

    # ========== Sandwich variance in the whitened space ==========

    resid_opt = Ttilde - Xtilde @ beta_opt

    # Gradient of the moment conditions (G) blocks
    XG_1_1 = (-wtXilde.T @ Xtilde) / sigmasq
    XG_2_1 = (wtXilde.T @ (-2 * resid_opt / sigmasq)).reshape(-1, 1)
    XG_3_1 = (wtXilde.T * (-resid_opt / sigmasq * Ttilde * w_opt)) @ Xtilde
    XG_1_2 = ((-wtXilde.T @ resid_opt) / (sigmasq**2)).reshape(-1, 1)
    XG_2_2 = np.array([[sample_weights.T @ (-resid_opt**2 / (sigmasq**2))]])
    XG_3_2 = (
        -Ttilde * sample_weights * w_opt * (
            resid_opt**2 / (2*sigmasq**2) - 1/(2*sigmasq)
        )
    ).reshape(1, -1) @ Xtilde

    # Per-observation moment contributions (for Omega)
    XW_1 = Xtilde * (resid_opt / sigmasq * sample_weights**0.5)[:, None]
    XW_2 = (resid_opt**2 / sigmasq - 1) * sample_weights**0.5
    XW_3 = Xtilde * (Ttilde * w_opt * sample_weights)[:, None]

    if bal_only:
        W = np.eye(k + 1)
        G = (1/n) * np.vstack([
            np.hstack([XG_3_1, XG_3_2.T]),
            np.hstack([XG_2_1.T, XG_2_2])
        ])
        W1 = np.vstack([XW_3.T, XW_2.reshape(1, -1)])
    else:
        W = gmm_func(params_opt)['invV']
        G = (1/n) * np.vstack([
            np.hstack([XG_1_1, XG_1_2]),
            np.hstack([XG_3_1, XG_3_2.T]),
            np.hstack([XG_2_1.T, XG_2_2])
        ])
        W1 = np.vstack([XW_1.T, XW_3.T, XW_2.reshape(1, -1)])

    Omega = (1/n) * (W1 @ W1.T)

    GWG_inv = scipy.linalg.pinv(G.T @ W @ G)
    GWGinvGW = W @ G @ GWG_inv

    vcov_tilde = (GWGinvGW.T @ Omega @ GWGinvGW)[0:k, 0:k]
    vcov_tilde = (vcov_tilde + vcov_tilde.T) / 2

    # ========== Back-transformation to the original covariate scale ==========

    beta_tilde = beta_opt.copy()
    sw_treat_var = np.var(sw_treat, ddof=1)

    XtX_inv = scipy.linalg.pinv(X.T @ X)
    beta_opt = XtX_inv @ X.T @ (
        Xtilde @ beta_tilde * np.std(sw_treat, ddof=1) + np.mean(sw_treat)
    )

    sigmasq_tilde = sigmasq
    sigmasq = sigmasq_tilde * sw_treat_var

    middle = Xtilde @ vcov_tilde @ Xtilde.T * sw_treat_var
    vcov = XtX_inv @ X.T @ middle @ X @ XtX_inv
    vcov = (vcov + vcov.T) / 2

    result = {
        'coefficients': beta_opt.reshape(-1, 1),
        'fitted_values': np.clip(
            scipy.stats.norm.pdf(
                Ttilde,
                loc=Xtilde @ beta_tilde,
                scale=np.sqrt(sigmasq_tilde)
            ),
            PROBS_MIN,
            1 - PROBS_MIN
        ),
        'linear_predictor': Xtilde @ beta_tilde,
        'deviance': deviance,
        'weights': w_opt * sample_weights,
        'y': treat,
        'x': X,
        'converged': opt1.success,
        'J': J_opt,
        'var': vcov,
        'mle_J': mle_J,
        'sigmasq': sigmasq,
        'Ttilde': Ttilde,
        'Xtilde': Xtilde,
        'beta_tilde': beta_tilde,
        'sigmasq_tilde': sigmasq_tilde,
        'stabilizers': stabilizers
    }

    # Test the conditional normality assumption T|X ~ N(X'beta, sigma^2)
    # on the original (un-whitened) treatment and covariates.
    _attach_normality_diagnostics(result, treat, X)

    return result
945
+