cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1943 @@
1
+ """
2
+ Binary Treatment Covariate Balancing Propensity Score
3
+
4
+ This module implements the CBPS algorithm for binary treatments, supporting
5
+ both exactly-identified and over-identified generalized method of moments
6
+ (GMM) estimation.
7
+
8
+ The covariate balancing propensity score (CBPS) methodology estimates
9
+ propensity scores that optimize covariate balance while maintaining good
10
+ prediction of treatment assignment.
11
+
12
+ References
13
+ ----------
14
+ Imai, K. and Ratkovic, M. (2014). Covariate Balancing Propensity Score.
15
+ Journal of the Royal Statistical Society, Series B 76(1), 243-263.
16
+ """
17
+
18
+ import warnings
19
+ from typing import Dict, Optional, Tuple, Callable
20
+
21
+ import numpy as np
22
+ import scipy.linalg
23
+ import scipy.special
24
+ import scipy.optimize
25
+ import statsmodels.api as sm
26
+ from statsmodels.genmod.families import Binomial
27
+
28
+ from ..utils.weights import standardize_weights
29
+ from ..utils.helpers import normalize_sample_weights
30
+ from ..utils.numerics import r_ginv_with_diagnostics
31
+ from ..utils.validation import ensure_dense, validate_cbps_input
32
+ from ..constants import DEFAULT_CONFIG
33
+ from ..logging_config import logger, set_verbosity
34
+
35
# Constants (sourced from the shared DEFAULT_CONFIG object).
# PROBS_MIN is the clipping bound for fitted propensity scores: the functions
# in this module clamp probabilities into [PROBS_MIN, 1 - PROBS_MIN] before
# dividing by pi or (1 - pi).
PROBS_MIN = DEFAULT_CONFIG.probs_min  # Minimum probability clipping threshold
37
+
38
+ # att parameter normalization: support string and integer
39
+ _ATT_MAP = {'ate': 0, 'att': 1, 'atc': 2, 0: 0, 1: 1, 2: 2}
40
+
41
+
42
+ def _normalize_att(att):
43
+ """Normalize att parameter to integer.
44
+
45
+ Supports both string ('ate', 'att', 'atc') and integer (0, 1, 2) inputs.
46
+
47
+ Parameters
48
+ ----------
49
+ att : str or int
50
+ Target estimand specification.
51
+
52
+ Returns
53
+ -------
54
+ int
55
+ Normalized integer value (0, 1, or 2).
56
+
57
+ Raises
58
+ ------
59
+ ValueError
60
+ If att is not a valid string or integer value.
61
+ """
62
+ if isinstance(att, str):
63
+ att_lower = att.lower().strip()
64
+ if att_lower not in _ATT_MAP:
65
+ raise ValueError(
66
+ f"Invalid att='{att}'. Use 'ate', 'att', 'atc' or 0, 1, 2."
67
+ )
68
+ return _ATT_MAP[att_lower]
69
+ if att not in (0, 1, 2):
70
+ raise ValueError(
71
+ f"Invalid att={att}. Use 'ate'(0), 'att'(1), 'atc'(2)."
72
+ )
73
+ return int(att)
74
+
75
+
76
+ def _r_ginv(X: np.ndarray, tol: float = None) -> np.ndarray:
77
+ """
78
+ Compute the Moore-Penrose pseudoinverse with numerical stability.
79
+
80
+ This function computes the pseudoinverse using a tolerance based on the
81
+ square root of machine epsilon to determine which singular values to retain.
82
+
83
+ Parameters
84
+ ----------
85
+ X : np.ndarray
86
+ Input matrix for which to compute the pseudoinverse.
87
+ tol : float, optional
88
+ Tolerance parameter for singular value selection. Default is
89
+ sqrt(machine_epsilon) ≈ 1.49e-08. Singular values d are kept if
90
+ d > tol * max(d).
91
+
92
+ Returns
93
+ -------
94
+ np.ndarray
95
+ The Moore-Penrose pseudoinverse of X.
96
+
97
+ Notes
98
+ -----
99
+ The implementation follows a three-branch logic:
100
+ 1. If all singular values are positive: compute full pseudoinverse
101
+ 2. If no singular values are positive: return zero matrix
102
+ 3. If some singular values are positive: compute partial pseudoinverse
103
+ """
104
+ # Default tolerance: sqrt(machine epsilon) ≈ 1.49e-08
105
+ if tol is None:
106
+ machine_eps = np.finfo(float).eps # 2.220446049250313e-16
107
+ tol = np.sqrt(machine_eps) # ≈ 1.490116119384766e-08
108
+
109
+ # Compute reduced SVD decomposition (matches R's svd() default behavior)
110
+ # For X with shape (m, n):
111
+ # - U has shape (m, min(m,n))
112
+ # - d has shape (min(m,n),)
113
+ # - Vt has shape (min(m,n), n)
114
+ Xsvd_u, Xsvd_d, Xsvd_vt = np.linalg.svd(X, full_matrices=False)
115
+ Xsvd_v = Xsvd_vt.T # NumPy returns V^T, transpose to get V
116
+
117
+ # If no singular values or maximum is extremely small (< machine eps),
118
+ # return zero matrix to avoid numerical amplification
119
+ if len(Xsvd_d) == 0 or Xsvd_d[0] < np.finfo(float).eps:
120
+ return np.zeros((X.shape[1], X.shape[0]))
121
+
122
+ # Determine which singular values to retain: d > max(tol * d[0], 0)
123
+ # This matches R's MASS::ginv tolerance formula
124
+ tol_threshold = max(tol * Xsvd_d[0], 0.0)
125
+ Positive = Xsvd_d > tol_threshold
126
+
127
+ # Compute pseudoinverse based on retained singular values
128
+ # Formula: X+ = V @ diag(1/d) @ U.T (for retained singular values)
129
+ if np.all(Positive):
130
+ # All singular values retained: V @ diag(1/d) @ U.T
131
+ X_pinv = Xsvd_v @ np.diag(1.0 / Xsvd_d) @ Xsvd_u.T
132
+ elif not np.any(Positive):
133
+ # All singular values truncated: return zero matrix
134
+ X_pinv = np.zeros((X.shape[1], X.shape[0]))
135
+ else:
136
+ # Partial retention: V[:, pos] @ diag(1/d[pos]) @ U[:, pos].T
137
+ Xsvd_v_pos = Xsvd_v[:, Positive]
138
+ Xsvd_d_pos = Xsvd_d[Positive]
139
+ Xsvd_u_pos = Xsvd_u[:, Positive]
140
+ X_pinv = Xsvd_v_pos @ np.diag(1.0 / Xsvd_d_pos) @ Xsvd_u_pos.T
141
+
142
+ return X_pinv
143
+
144
+
145
def _att_wt_func(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray
) -> np.ndarray:
    """
    Evaluate the signed ATT balancing weights at the current coefficients.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Treatment indicator, shape (n,); units are grouped by the tests
        ``treat == 0`` and ``treat == 1``.
    sample_weights : np.ndarray
        Sampling weights, shape (n,).

    Returns
    -------
    np.ndarray
        Weights ``(n / n_t) * (T - pi) / (1 - pi)``, shape (n,), where
        ``n_t`` is the summed sampling weight of the treated units and ``n``
        is the summed sampling weight of treated plus control units.

    Notes
    -----
    With ``pi`` clipped into ``[PROBS_MIN, 1 - PROBS_MIN]``, treated units
    evaluate to ``n / n_t`` and control units to ``-(n / n_t) * pi / (1 - pi)``,
    so control weights are negative (assuming 0 < PROBS_MIN < 0.5).
    """
    # Weighted group masses; their sum stands in for the sample size.
    control_mass = np.sum(sample_weights[treat == 0])
    treated_mass = np.sum(sample_weights[treat == 1])
    total_mass = control_mass + treated_mass

    # Logistic propensity scores, clamped away from 0 and 1.
    pscore = scipy.special.expit(X @ beta_curr)
    pscore = np.maximum(PROBS_MIN, np.minimum(1 - PROBS_MIN, pscore))

    return (total_mass / treated_mass) * (treat - pscore) / (1 - pscore)
201
+
202
+
203
def _compute_V_matrix(
    X: np.ndarray,
    probs_curr: np.ndarray,
    sample_weights: np.ndarray,
    treat: np.ndarray,
    att: int,
    n: int
) -> np.ndarray:
    """
    Build the GMM moment covariance matrix V and return its pseudoinverse.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    probs_curr : np.ndarray
        Current propensity scores, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    treat : np.ndarray
        Treatment indicator, shape (n,). Only read on the ATT path, where
        the number of entries equal to 1 is counted (unweighted).
    att : int
        Estimand switch: a falsy value selects the ATE formulas, any truthy
        value selects the ATT formulas.
    n : int
        Number of observations used in the 1/n scaling.

    Returns
    -------
    np.ndarray
        Pseudoinverse of V, shape (2k, 2k), as returned by
        ``r_ginv_with_diagnostics``.

    Notes
    -----
    V is assembled as the block matrix ``[[V11, V12], [V12, V22]]`` where
    each block is a scaled Gram matrix ``M.T @ M`` of a row-reweighted copy
    of X. On the ATT path the blocks carry extra ``n / n_treat`` factors
    (squared for V22).
    """
    root_w = np.sqrt(sample_weights)

    # The score-block design is the same for both estimands.
    score_design = root_w[:, None] * X * np.sqrt(probs_curr * (1 - probs_curr))[:, None]

    if att:
        # ATT: balance block scaled by pi/(1-pi), cross block by pi.
        balance_design = root_w[:, None] * X * np.sqrt(probs_curr / (1 - probs_curr))[:, None]
        cross_design = root_w[:, None] * X * np.sqrt(probs_curr)[:, None]

        n_treat = np.sum(treat == 1)
        V11 = (1 / n) * (score_design.T @ score_design) * n / n_treat
        V12 = (1 / n) * (cross_design.T @ cross_design) * n / n_treat
        V22 = (1 / n) * (balance_design.T @ balance_design) * n**2 / n_treat**2
    else:
        # ATE: balance block scaled by 1/(pi(1-pi)), cross block unscaled.
        balance_design = root_w[:, None] * X / np.sqrt(probs_curr * (1 - probs_curr))[:, None]
        cross_design = root_w[:, None] * X

        V11 = (1 / n) * (score_design.T @ score_design)
        V12 = (1 / n) * (cross_design.T @ cross_design)
        V22 = (1 / n) * (balance_design.T @ balance_design)

    # The off-diagonal block is shared, which makes V symmetric by design.
    V = np.block([[V11, V12],
                  [V12, V22]])

    # Sanity check on the assembled matrix (also trips if V contains NaN).
    assert np.allclose(V, V.T, atol=1e-15), "V matrix must be symmetric"

    # The diagnostics object is discarded; per the original comment, the
    # helper emits a UserWarning when the condition number exceeds the
    # threshold.
    inv_V, _unused_diagnostics = r_ginv_with_diagnostics(V, warn_threshold=1e12)
    return inv_V
283
+
284
+
285
def _gmm_func(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    inv_V: Optional[np.ndarray] = None
) -> Dict:
    """
    Evaluate the GMM objective and the weighting matrix it used.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand switch: falsy selects the ATE weights, truthy the ATT
        weights from ``_att_wt_func``.
    inv_V : np.ndarray or None
        Weighting matrix, shape (2k, 2k). When None it is computed from the
        current propensity scores via ``_compute_V_matrix``.

    Returns
    -------
    dict
        ``'loss'``: float, the quadratic form ``gbar' @ inv_V @ gbar``;
        ``'inv_V'``: the weighting matrix that was used.

    Notes
    -----
    ``gbar`` stacks the k score conditions first and the k balance
    conditions second. Passing a fixed ``inv_V`` keeps the weighting matrix
    constant across evaluations; passing None recomputes it on every call.
    """
    n_obs = len(treat)

    # Logistic propensity scores, clamped for numerical stability.
    pscore = scipy.special.expit(X @ beta_curr)
    pscore = np.maximum(PROBS_MIN, np.minimum(1 - PROBS_MIN, pscore)).ravel()

    # Balancing weights; the ATE form equals T/pi - (1-T)/(1-pi).
    if att:
        bal_w = _att_wt_func(beta_curr, X, treat, sample_weights)
    else:
        bal_w = 1 / (pscore - 1 + treat)

    # Both moment blocks share the sample-weighted, transposed design.
    weighted_X_t = (sample_weights[:, None] * X).T
    balance_cond = ((1 / n_obs) * weighted_X_t @ bal_w).ravel()
    score_cond = (1 / n_obs) * weighted_X_t @ (treat - pscore)
    gbar = np.concatenate([score_cond.ravel(), balance_cond])

    if inv_V is None:
        inv_V = _compute_V_matrix(X, pscore, sample_weights, treat, att, n_obs)

    return {'loss': float(gbar.T @ inv_V @ gbar), 'inv_V': inv_V}
353
+
354
+
355
def _gmm_loss(
    beta: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    inv_V: Optional[np.ndarray]
) -> float:
    """
    Return only the scalar GMM loss at ``beta``.

    Thin adapter around ``_gmm_func`` that exposes just the objective value.

    Parameters
    ----------
    beta : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand switch, forwarded unchanged to ``_gmm_func``.
    inv_V : np.ndarray or None
        Weighting matrix; None makes ``_gmm_func`` compute it.

    Returns
    -------
    float
        The ``'loss'`` entry of the ``_gmm_func`` result.
    """
    evaluation = _gmm_func(beta, X, treat, sample_weights, att, inv_V)
    return evaluation['loss']
391
+
392
+
393
def _gmm_gradient(
    beta_curr: np.ndarray,
    inv_V: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int
) -> np.ndarray:
    """
    Compute the analytical gradient of the GMM objective for a fixed inv_V.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates, shape (k,).
    inv_V : np.ndarray
        Weighting matrix, shape (2k, 2k). It is treated as constant: its
        dependence on ``beta_curr`` is not differentiated.
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand switch: falsy selects the ATE formulas, truthy the ATT
        formulas.

    Returns
    -------
    np.ndarray
        Gradient vector, shape (k,).

    Notes
    -----
    The gradient is ``2 * dgbar @ inv_V @ gbar`` where ``gbar`` stacks the
    score conditions followed by the balance conditions and ``dgbar`` is the
    (k, 2k) Jacobian ``[dgbar_score, dgbar_balance]``.

    Score block (both estimands):
        ``-1/n * X' diag(sw * pi * (1 - pi)) X``

    Balance block, ATE:
        ``-1/n * X' diag(sw * (T - pi)^2 / (pi * (1 - pi))) X``

    Balance block, ATT:
        ``1/n * X' diag(sw * dw) X`` with
        ``dw = -n/n_t * pi / (1 - pi)`` for controls and 0 for treated.

    The clipping of ``pi`` is not differentiated.

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    n = len(treat)

    # Logistic propensity scores, clamped for numerical stability.
    theta_curr = X @ beta_curr
    probs_curr = scipy.special.expit(theta_curr)
    probs_curr = np.clip(probs_curr, PROBS_MIN, 1 - PROBS_MIN)

    # Sample-weighted design, shared by both moment blocks.
    sw_X = sample_weights[:, None] * X

    # Balancing weights for the chosen estimand.
    if att:
        w_curr = _att_wt_func(beta_curr, X, treat, sample_weights)
    else:
        w_curr = 1 / (probs_curr - 1 + treat)

    # Moment vector gbar = [score conditions, balance conditions].
    w_curr_del = ((1 / n) * sw_X.T @ w_curr).ravel()
    score_cond = (1 / n) * sw_X.T @ (treat - probs_curr)
    gbar = np.concatenate([score_cond.ravel(), w_curr_del])

    # Score Jacobian: d(T - pi)/dbeta = -pi(1 - pi) X under the logistic
    # link; the expression is the same for ATE and ATT.
    score_weight = sample_weights * probs_curr * (1 - probs_curr)
    dgbar_score = (-1 / n) * (X * score_weight[:, None]).T @ X

    if att:
        # ATT weight: w = (n/n_t) * (T - pi) / (1 - pi).
        #   treated: w = n/n_t is constant, so dw/dbeta = 0
        #   control: w = -(n/n_t) * pi / (1 - pi)
        # Chain rule with dpi/dbeta = pi(1 - pi) X:
        #   dw/dbeta = -(n/n_t) * [1/(1-pi)^2] * [pi(1-pi)] * X
        #            = -(n/n_t) * pi/(1-pi) * X
        # n_t is only needed on this path.
        n_t = np.sum(sample_weights[treat == 1])
        dw = -n / n_t * probs_curr / (1 - probs_curr)
        dw[treat == 1] = 0

        # The original author's note states that the R code scales this
        # block by 1/n_t instead; 1/n is used here because it is the
        # derivative of the 1/n-scaled balance condition computed above.
        dgbar_balance = (1 / n) * (X * (dw * sample_weights)[:, None]).T @ X
    else:
        # ATE weight: w = 1/(pi - 1 + T), whose derivative w.r.t. the linear
        # predictor is -(T - pi)^2 / (pi(1 - pi)).
        balance_weight = (
            sample_weights * (treat - probs_curr)**2
            / (probs_curr * (1 - probs_curr))
        )
        dgbar_balance = (-1 / n) * (X * balance_weight[:, None]).T @ X

    # Full Jacobian, shape (k, 2k).
    dgbar = np.hstack([dgbar_score, dgbar_balance])

    return 2 * dgbar @ inv_V @ gbar
521
+
522
+
523
def _bal_loss(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    XprimeX_inv: np.ndarray,
    att: int
) -> float:
    """
    Evaluate the balance-only loss at the current coefficients.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    XprimeX_inv : np.ndarray
        Pre-computed (k, k) matrix used as the metric of the quadratic form
        (named as the inverse of X'X; it is supplied by the caller).
    att : int
        Estimand switch: falsy selects the ATE weights, truthy the ATT
        weights from ``_att_wt_func``.

    Returns
    -------
    float
        ``|g' @ XprimeX_inv @ g|`` where ``g = (sw * X)' @ w`` and ``w`` are
        the balancing weights scaled by ``1/n``.

    Notes
    -----
    Unlike the GMM loss, this criterion contains no score conditions and
    takes the absolute value of the quadratic form.
    """
    n_obs = len(treat)

    # Logistic propensity scores, clamped for numerical stability.
    pscore = np.clip(
        scipy.special.expit(X @ beta_curr), PROBS_MIN, 1 - PROBS_MIN
    )

    # Balancing weights for the chosen estimand, scaled by 1/n.
    if att:
        raw_w = _att_wt_func(beta_curr, X, treat, sample_weights)
    else:
        raw_w = 1 / (pscore - 1 + treat)
    scaled_w = (1 / n_obs) * raw_w

    # Sample-weighted covariate imbalance vector, shape (k,).
    imbalance = (sample_weights[:, None] * X).T @ scaled_w

    return float(np.abs(imbalance.T @ XprimeX_inv @ imbalance))
588
+
589
+
590
def _bal_gradient(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    XprimeX_inv: np.ndarray,
    att: int
) -> np.ndarray:
    """
    Compute the analytical gradient of the balance loss.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    XprimeX_inv : np.ndarray
        Pre-computed (k, k) metric of the quadratic form (named as the
        inverse of X'X; supplied by the caller).
    att : int
        Estimand switch: falsy selects the ATE formulas, truthy the ATT
        formulas.

    Returns
    -------
    np.ndarray
        Gradient vector, shape (k,).

    Notes
    -----
    The raw gradient is ``2 * dw @ X @ XprimeX_inv @ Xprimew`` with
    ``Xprimew = X' @ (w * sw)`` and ``dw`` of shape (k, n):

    * ATE: ``dw = 1/n * t(-X * (T - pi)^2 / (pi * (1 - pi)))``
    * ATT: ``dw = 1/n * t(X * dw2)`` with
      ``dw2 = -n/n_t * pi / (1 - pi)`` for controls and 0 for treated.

    Because the loss is the absolute value of a quadratic form, each
    component is sign-adjusted with the rule quoted from the R code:

        ifelse((x>0 & loss1>0) | (x<0 & loss1<0), abs(x), -abs(x))

    i.e. a component keeps magnitude ``|x|`` and is positive only when it
    has the same strict sign as the un-absoluted quadratic form ``loss1``.

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    n = len(treat)

    # Logistic propensity scores, clamped for numerical stability.
    theta_curr = X @ beta_curr
    probs_curr = scipy.special.expit(theta_curr)
    probs_curr = np.clip(probs_curr, PROBS_MIN, 1 - PROBS_MIN)

    if att:
        # Balancing weights scaled by 1/n.
        w_curr = (1 / n) * _att_wt_func(beta_curr, X, treat, sample_weights)

        # Derivative of the ATT weight w.r.t. the linear predictor; treated
        # units have constant weight, hence zero derivative. n_t is only
        # needed on this path.
        n_t = np.sum(sample_weights[treat == 1])
        dw2 = -n / n_t * probs_curr / (1 - probs_curr)
        dw2[treat == 1] = 0
        dw = (1 / n) * (X * dw2[:, None]).T  # shape (k, n)
    else:
        w_curr = (1 / n) * (1 / (probs_curr - 1 + treat))

        # Derivative of the ATE weight 1/(pi - 1 + T) w.r.t. the linear
        # predictor.
        dw_weight = -(treat - probs_curr)**2 / (probs_curr * (1 - probs_curr))
        dw = (1 / n) * (X * dw_weight[:, None]).T  # shape (k, n)

    # Sample-weighted covariate imbalance vector, shape (k,).
    Xprimew = X.T @ (w_curr * sample_weights)

    # Quadratic form before the absolute value; its sign drives the
    # adjustment below.
    loss1 = Xprimew.T @ XprimeX_inv @ Xprimew

    # (k, n) @ (n, k) @ (k, k) @ (k,) -> (k,)
    raw_grad = 2 * dw @ X @ XprimeX_inv @ Xprimew

    # Sign adjustment for the absolute value (R's sapply/ifelse logic).
    same_sign = ((raw_grad > 0) & (loss1 > 0)) | ((raw_grad < 0) & (loss1 < 0))
    return np.where(same_sign, np.abs(raw_grad), -np.abs(raw_grad))
698
+
699
+ def _vmmin_bfgs(
700
+ b0: np.ndarray,
701
+ fn: Callable,
702
+ gr: Optional[Callable],
703
+ maxit: int = 10000,
704
+ abstol: float = -np.inf,
705
+ reltol: float = np.sqrt(np.finfo(float).eps),
706
+ trace: bool = False,
707
+ nREPORT: int = 10,
708
+ show_progress: bool = False,
709
+ ) -> scipy.optimize.OptimizeResult:
710
+ """
711
+ R's vmmin BFGS optimizer, faithfully translated from C source.
712
+
713
+ This is a line-by-line translation of R's ``vmmin`` function from
714
+ ``src/appl/optim.c`` (the backend of ``optim(..., method="BFGS")``).
715
+ It uses a simple Armijo backtracking line search and a relative-
716
+ tolerance convergence criterion, which differ fundamentally from
717
+ scipy's Strong-Wolfe / gradient-norm approach.
718
+
719
+ Parameters
720
+ ----------
721
+ b0 : np.ndarray
722
+ Initial parameter vector, shape (n,).
723
+ fn : callable
724
+ Objective function ``fn(b) -> float``.
725
+ gr : callable or None
726
+ Gradient function ``gr(b) -> np.ndarray`` of shape (n,).
727
+ If None, uses forward finite differences with step size 1e-3,
728
+ matching R's ``optim`` default behavior (``fmingr`` in optim.c).
729
+ maxit : int
730
+ Maximum number of BFGS iterations (default 10000, R default).
731
+ abstol : float
732
+ Absolute tolerance on function value (default -inf, R default).
733
+ reltol : float
734
+ Relative tolerance on function value change
735
+ (default ``sqrt(eps) ≈ 1.49e-8``, R default).
736
+ trace : bool
737
+ If True, print iteration information (default False).
738
+ nREPORT : int
739
+ Report every *nREPORT* iterations when *trace* is True.
740
+
741
+ Returns
742
+ -------
743
+ scipy.optimize.OptimizeResult
744
+ Result object with fields ``x``, ``fun``, ``nit``, ``nfev``,
745
+ ``njev``, ``success``, ``message``.
746
+
747
+ Notes
748
+ -----
749
+ Constants hard-coded to match R exactly:
750
+
751
+ * ``stepredn = 0.2`` – step reduction factor in backtracking
752
+ * ``acctol = 0.0001`` – Armijo sufficient-decrease parameter
753
+ * ``reltest = 10.0`` – used to detect "no change" in parameters
754
+
755
+ Convergence criterion (R ``reltol``):
756
+ ``|f_new - f_old| > reltol * (|f_old| + reltol)``
757
+
758
+ References
759
+ ----------
760
+ J.C. Nash, *Compact Numerical Methods for Computers*, 2nd ed.
761
+ R Core Team, ``src/appl/optim.c`` (vmmin).
762
+ """
763
+ # If no analytical gradient provided, use R's default numerical gradient.
764
+ # R's fmingr in optim.c uses forward finite differences with ndeps=1e-3.
765
+ if gr is None:
766
+ _ndeps = DEFAULT_CONFIG.ndeps
767
+ def gr(b):
768
+ """Forward finite difference gradient, matching R's fmingr."""
769
+ f0 = fn(b)
770
+ g = np.empty_like(b)
771
+ for i in range(len(b)):
772
+ b_pert = b.copy()
773
+ b_pert[i] += _ndeps
774
+ g[i] = (fn(b_pert) - f0) / _ndeps
775
+ return g
776
+ # ---- constants (must match R exactly) ----
777
+ STEPREDN = 0.2
778
+ ACCTOL = 0.0001
779
+ RELTEST = 10.0
780
+
781
+ # Optional tqdm progress bar (soft dependency)
782
+ pbar = None
783
+ if show_progress:
784
+ try:
785
+ from tqdm import tqdm
786
+ pbar = tqdm(total=maxit, desc="BFGS optimization", leave=False)
787
+ except ImportError:
788
+ pass
789
+
790
+ n = len(b0)
791
+ b = b0.astype(float).copy()
792
+
793
+ if maxit <= 0:
794
+ f = fn(b)
795
+ if pbar:
796
+ pbar.close()
797
+ return scipy.optimize.OptimizeResult(
798
+ x=b, fun=f, nit=0, nfev=1, njev=0,
799
+ success=True, message="maxit <= 0, returning initial value",
800
+ )
801
+
802
+ # All parameters are free (mask = all True).
803
+ # In R, l[] maps free-parameter indices; here every index is free,
804
+ # so l[i] = i and the indirection is a no-op.
805
+
806
+ # ---- allocate working arrays ----
807
+ g = np.empty(n) # gradient
808
+ t = np.empty(n) # search direction
809
+ X = np.empty(n) # saved parameters
810
+ c = np.empty(n) # saved gradient
811
+ # B: lower-triangular BFGS Hessian-inverse approximation
812
+ # B[i][j] stored for j <= i (symmetric, only lower triangle kept)
813
+ B = np.zeros((n, n))
814
+
815
+ # ---- initial evaluation ----
816
+ f = fn(b)
817
+ if not np.isfinite(f):
818
+ raise ValueError(
819
+ "initial value in vmmin is not finite. "
820
+ "Suggestions: (1) Check for extreme covariate values, "
821
+ "(2) Scale covariates to have similar ranges, "
822
+ "(3) Remove covariates with very low variance, "
823
+ "(4) Try init_params with values closer to zero."
824
+ )
825
+ if trace:
826
+ print(f"initial value {f}")
827
+ Fmin = f
828
+ funcount = 1
829
+ gradcount = 1
830
+ g[:] = gr(b)
831
+ iter_ = 1
832
+ ilast = gradcount
833
+
834
+ while True:
835
+ # ---- Hessian reset when needed ----
836
+ if ilast == gradcount:
837
+ B[:, :] = 0.0
838
+ np.fill_diagonal(B, 1.0)
839
+
840
+ # ---- save current state ----
841
+ X[:] = b
842
+ c[:] = g
843
+
844
+ # ---- compute search direction t = -B g ----
845
+ # B is symmetric; use full matrix-vector product
846
+ t[:] = -(B @ g)
847
+ gradproj = float(t @ g)
848
+
849
+ if gradproj < 0.0:
850
+ # ---- downhill: backtracking line search ----
851
+ steplength = 1.0
852
+ accpoint = False
853
+ while True:
854
+ b[:] = X + steplength * t
855
+ count = int(np.sum(RELTEST + X == RELTEST + b))
856
+ if count < n:
857
+ f = fn(b)
858
+ funcount += 1
859
+ accpoint = (
860
+ np.isfinite(f)
861
+ and f <= Fmin + gradproj * steplength * ACCTOL
862
+ )
863
+ if not accpoint:
864
+ steplength *= STEPREDN
865
+ if count == n or accpoint:
866
+ break
867
+
868
+ enough = (
869
+ f > abstol
870
+ and abs(f - Fmin) > reltol * (abs(Fmin) + reltol)
871
+ )
872
+ if not enough:
873
+ count = n
874
+ Fmin = f
875
+
876
+ if count < n:
877
+ # ---- making progress ----
878
+ Fmin = f
879
+ g[:] = gr(b)
880
+ gradcount += 1
881
+ iter_ += 1
882
+
883
+ # prepare for BFGS update
884
+ t *= steplength # actual step
885
+ c[:] = g - c # gradient change
886
+ D1 = float(t @ c)
887
+
888
+ if D1 > 0:
889
+ # ---- BFGS Hessian-inverse update ----
890
+ # Compute X_tmp = B @ c (vectorized)
891
+ X[:] = B @ c
892
+ D2 = 1.0 + float(X @ c) / D1
893
+
894
+ # Rank-2 symmetric update (only lower triangle matters
895
+ # but we maintain full symmetric matrix for B @ g)
896
+ B += (D2 * np.outer(t, t)
897
+ - np.outer(X, t)
898
+ - np.outer(t, X)) / D1
899
+ else:
900
+ # D1 <= 0: curvature condition violated → reset
901
+ ilast = gradcount
902
+ else:
903
+ # ---- no progress ----
904
+ if ilast < gradcount:
905
+ count = 0
906
+ ilast = gradcount
907
+ else:
908
+ # ---- uphill search direction ----
909
+ count = 0
910
+ if ilast == gradcount:
911
+ count = n # already reset → give up
912
+ else:
913
+ ilast = gradcount # reset Hessian
914
+
915
+ if pbar:
916
+ pbar.update(1)
917
+
918
+ if trace and (iter_ % nREPORT == 0):
919
+ print(f"iter{iter_:4d} value {f}")
920
+
921
+ if iter_ >= maxit:
922
+ break
923
+
924
+ # ---- periodic restart ----
925
+ if gradcount - ilast > 2 * n:
926
+ ilast = gradcount
927
+
928
+ if count == n and ilast == gradcount:
929
+ break
930
+
931
+ if pbar:
932
+ pbar.close()
933
+
934
+ if trace:
935
+ print(f"final value {Fmin}")
936
+ if iter_ < maxit:
937
+ print("converged")
938
+ else:
939
+ print(f"stopped after {iter_} iterations")
940
+
941
+ success = iter_ < maxit
942
+ return scipy.optimize.OptimizeResult(
943
+ x=b,
944
+ fun=Fmin,
945
+ nit=iter_,
946
+ nfev=funcount,
947
+ njev=gradcount,
948
+ success=success,
949
+ message="converged" if success else f"stopped after {iter_} iterations",
950
+ )
951
+
952
+
953
+
954
def _glm_init(
    treat: np.ndarray,
    X: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    gmm_loss_func: Callable
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build CBPS starting coefficients from a logistic GLM fit.

    A binomial GLM of ``treat`` on ``X`` is fitted, undefined (NaN)
    coefficients are replaced by zero, and the coefficient vector is then
    rescaled by a scalar ``alpha`` chosen to minimize ``gmm_loss_func``.

    Parameters
    ----------
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Sampling weights, shape (n,). Accepted for signature compatibility;
        not referenced inside this function (the GLM fit is unweighted).
    att : int
        Estimand code. Accepted for signature compatibility; not referenced
        inside this function.
    gmm_loss_func : Callable
        Maps a coefficient vector of shape (k,) to a scalar loss. Used as
        the objective of the one-dimensional alpha search.

    Returns
    -------
    beta_init : np.ndarray
        GLM coefficients multiplied by the selected alpha.
    beta_glm : np.ndarray
        NaN-cleaned GLM coefficients before scaling.

    Notes
    -----
    Initialization steps:
        1. Fit the binomial GLM with all warnings suppressed.
        2. Replace NaN coefficients with 0.
        3. Search alpha over the interval [0.8, 1.1] with a bounded scalar
           minimizer and scale the coefficients by the result.
    """
    # Step 1: GLM fit. Warnings are silenced on purpose so that a noisy
    # starting-value fit does not surface to the caller.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = sm.GLM(treat, X, family=Binomial())
        glm_fit = model.fit(tol=DEFAULT_CONFIG.glm_tol, maxiter=25)

    # Step 2: coefficients the fit could not determine come back as NaN;
    # treat them as zero so the downstream loss stays evaluable.
    beta_glm = glm_fit.params.copy()
    beta_glm[np.isnan(beta_glm)] = 0

    # Step 3: one-dimensional search for the scaling factor alpha.
    # ``beta_glm * alpha`` allocates a new array on every call, so beta_glm
    # itself is never modified by the search.
    def alpha_func(alpha):
        return gmm_loss_func(beta_glm * alpha)

    result = scipy.optimize.minimize_scalar(
        alpha_func, bounds=(0.8, 1.1), method='bounded'
    )
    beta_init = beta_glm * result.x

    return beta_init, beta_glm
1024
+
1025
+
1026
def _compute_moment_conditions(
    beta: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    n: int
) -> np.ndarray:
    """
    Evaluate the covariate-balance moment vector at ``beta``.

    The propensity score is the logistic transform of ``X @ beta``, clipped
    to ``[PROBS_MIN, 1 - PROBS_MIN]``. A per-observation balancing weight is
    built from it according to ``att`` and the weighted covariate sums are
    divided by ``n``.

    Parameters
    ----------
    beta : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment vector (0/1), shape (n,).
    sample_weights : np.ndarray
        Sample weights, shape (n,).
    att : int
        Estimand selector: 1 uses the ATT weight, 2 uses the ATT weight with
        the roles of the groups reversed (T=0 treated), any other value uses
        the ATE weight.
    n : int
        Sample size used as the normalizing constant.

    Returns
    -------
    np.ndarray
        Moment vector of length k; a solution of the just-identified
        balance problem drives it to (approximately) zero.

    Notes
    -----
    Weight formulas used below (``pi`` is the clipped propensity score):
        - att == 1: ``(n / sum(T * sw)) * (T - pi) / (1 - pi)``
        - att == 2: ``(n / sum((1 - T) * sw)) * (T - pi) / pi``
        - otherwise: ``(T - pi) / (pi * (1 - pi))``
    """
    propensity = np.clip(
        scipy.special.expit(X @ beta), PROBS_MIN, 1 - PROBS_MIN
    )
    residual = treat - propensity

    if att == 1:
        # ATT: scale by n over the weighted number of treated units.
        weighted_treated = np.sum(treat * sample_weights)
        w = (n / weighted_treated) * residual / (1 - propensity)
    elif att == 2:
        # Reversed ATT: the T=0 group plays the role of the treated group.
        weighted_controls = np.sum((1 - treat) * sample_weights)
        w = (n / weighted_controls) * residual / propensity
    else:
        # ATE
        w = residual / (propensity * (1 - propensity))

    # Weighted covariate sums, one entry per column of X.
    return (sample_weights[:, None] * X).T @ w / n
1088
+
1089
+
1090
def _solve_moment_equations(
    beta_init: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    n: int,
    iterations: int = 1000
) -> Tuple[np.ndarray, bool, np.ndarray, str]:
    """
    Find coefficients that zero the balance moment conditions.

    The k moment equations returned by ``_compute_moment_conditions`` are
    handed to ``scipy.optimize.root``. Two solvers are tried in order, both
    starting from ``beta_init``.

    Parameters
    ----------
    beta_init : np.ndarray
        Starting coefficients, shape (k,).
    X, treat, sample_weights, att, n
        Forwarded unchanged to ``_compute_moment_conditions``.
    iterations : int, default 1000
        Budget multiplier: the first solver gets ``maxfev = iterations * 10``
        and the fallback gets ``maxiter = iterations * 5``.

    Returns
    -------
    beta_opt : np.ndarray
        Solver solution, or ``beta_init`` itself when both solvers fail.
    success : bool
        True when one of the solvers reported success.
    moments_final : np.ndarray
        Moment vector evaluated at the returned coefficients.
    method : str
        ``'hybr'``, ``'lm'`` or ``'failed'``.

    Notes
    -----
    Solver order:
        1. ``'hybr'`` (called without exception guarding).
        2. ``'lm'``; a ValueError, RuntimeError or LinAlgError raised here
           is swallowed and treated as a failed attempt.
        3. Otherwise the initial values are returned with ``success=False``.
    """
    from scipy.optimize import root

    def residuals(coefs):
        """k moment equations in k unknowns."""
        return _compute_moment_conditions(
            coefs, X, treat, sample_weights, att, n
        )

    # First attempt: 'hybr'.
    primary = root(
        residuals,
        x0=beta_init,
        method='hybr',
        options={
            'xtol': DEFAULT_CONFIG.optim_xtol,
            'maxfev': iterations * 10,
        },
    )
    if primary.success:
        return primary.x, True, residuals(primary.x), 'hybr'

    # Second attempt: 'lm', restarted from the original initial values.
    # Errors listed below are deliberately ignored so that the caller
    # receives a failure status instead of an exception.
    try:
        fallback = root(
            residuals,
            x0=beta_init,
            method='lm',
            options={
                'xtol': DEFAULT_CONFIG.optim_xtol,
                'maxiter': iterations * 5,
            },
        )
        if fallback.success:
            return fallback.x, True, residuals(fallback.x), 'lm'
    except (ValueError, RuntimeError, np.linalg.LinAlgError):
        pass

    # Neither solver succeeded: report the moments at the starting point.
    return beta_init, False, residuals(beta_init), 'failed'
1170
+
1171
+
1172
def _optimize_balance(
    gmm_init: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    XprimeX_inv: np.ndarray,
    att: int,
    two_step: bool,
    iterations: int,
    bal_only: bool = False,
    show_progress: bool = False,
    **kwargs
) -> scipy.optimize.OptimizeResult:
    """
    Minimize the balance loss with the vmmin-style BFGS routine.

    Parameters
    ----------
    gmm_init : np.ndarray
        Starting coefficients, shape (k,).
    X, treat, sample_weights, XprimeX_inv, att
        Forwarded unchanged to ``_bal_loss`` and ``_bal_gradient``.
    two_step : bool
        When True the analytical gradient ``_bal_gradient`` is supplied to
        the optimizer; when False no gradient is supplied (``gr=None``) and
        the optimizer is left to differentiate numerically.
    iterations : int
        Passed to the optimizer as ``maxit``.
    bal_only : bool, default False
        Accepted for signature compatibility; not referenced in this
        function.
    show_progress : bool, default False
        Forwarded to the optimizer.
    **kwargs
        Only ``verbose`` is read (default False); it becomes the optimizer's
        ``trace`` flag.

    Returns
    -------
    scipy.optimize.OptimizeResult
        Result object returned by ``_vmmin_bfgs``.
    """
    def balance_loss(coefs):
        return _bal_loss(coefs, X, treat, sample_weights, XprimeX_inv, att)

    def balance_gradient(coefs):
        return _bal_gradient(
            coefs, X, treat, sample_weights, XprimeX_inv, att
        )

    # Analytical gradient only in two-step mode; continuous updating runs
    # without one, as noted by the original author for R compatibility.
    if two_step:
        gradient = balance_gradient
    else:
        gradient = None

    return _vmmin_bfgs(
        gmm_init,
        fn=balance_loss,
        gr=gradient,
        maxit=iterations,
        trace=kwargs.get('verbose', False),
        show_progress=show_progress,
    )
1235
+
1236
+
1237
def _optimize_gmm_dual_init(
    gmm_init: np.ndarray,
    beta_bal: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    this_inv_V: np.ndarray,
    two_step: bool,
    iterations: int,
    show_progress: bool = False,
    **kwargs
) -> scipy.optimize.OptimizeResult:
    """
    Run the GMM optimization from two starting points and keep the better.

    The same objective is minimized with ``_vmmin_bfgs`` once from
    ``gmm_init`` and once from ``beta_bal`` (in that order); the run with
    the lower final objective is returned.

    Parameters
    ----------
    gmm_init : np.ndarray
        GLM-based starting coefficients, shape (k,).
    beta_bal : np.ndarray
        Balance-optimized starting coefficients, shape (k,).
    X, treat, sample_weights, att
        Forwarded unchanged to ``_gmm_loss`` / ``_gmm_gradient``.
    this_inv_V : np.ndarray
        Precomputed inverse covariance matrix. Used only when
        ``two_step`` is True.
    two_step : bool
        True: the loss is evaluated with the fixed ``this_inv_V`` and the
        analytical gradient ``_gmm_gradient`` is supplied.
        False: the loss is evaluated with ``inv_V=None`` and no gradient is
        supplied (``gr=None``), because the fixed-V gradient is only an
        approximation when V changes with the coefficients.
    iterations : int
        Passed to the optimizer as ``maxit``.
    show_progress : bool, default False
        Forwarded to the optimizer.
    **kwargs
        Only ``verbose`` is read (default False); it becomes the optimizer's
        ``trace`` flag.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The run with the lower objective value. Ties go to the
        balance-started run. If exactly one run ends with a NaN objective,
        the other run is returned.
    """
    verbose = kwargs.get('verbose', False)

    if two_step:
        # Fixed weight matrix: the analytical gradient is exact here.
        inv_V = this_inv_V

        def gradient(b):
            return _gmm_gradient(b, this_inv_V, X, treat, sample_weights, att)
    else:
        # Continuous updating: V is recomputed inside the loss, so no
        # analytical gradient is passed.
        inv_V = None
        gradient = None

    def objective(b):
        return _gmm_loss(b, X, treat, sample_weights, att, inv_V)

    def run_from(start):
        return _vmmin_bfgs(
            start,
            fn=objective,
            gr=gradient,
            maxit=iterations,
            trace=verbose,
            show_progress=show_progress,
        )

    gmm_glm_init = run_from(gmm_init)
    gmm_bal_init = run_from(beta_bal)

    # A plain ``<`` is False whenever either side is NaN, which would hand
    # back a NaN-valued balance run even though the GLM run is usable.
    # Prefer the GLM run explicitly in that one situation.
    glm_is_nan = np.isnan(gmm_glm_init.fun)
    bal_is_nan = np.isnan(gmm_bal_init.fun)
    if gmm_glm_init.fun < gmm_bal_init.fun or (bal_is_nan and not glm_is_nan):
        return gmm_glm_init
    return gmm_bal_init
1337
+
1338
+
1339
def _classify_separation(
    probs_opt_raw: np.ndarray,
    beta_opt: np.ndarray,
    extreme_coef_threshold: float = 10.0
) -> Optional[Tuple[str, str]]:
    """
    Grade how many propensity scores sit on the clipping boundaries.

    Counts observations with a raw score ``<= PROBS_MIN`` or
    ``>= 1 - PROBS_MIN`` and, if there are any, builds a severity label and
    a multi-line warning text. The function itself emits no warning.

    Parameters
    ----------
    probs_opt_raw : np.ndarray
        Raw (unclipped) propensity scores, shape (n,).
    beta_opt : np.ndarray
        Coefficient vector, shape (k,). Only used to report whether any
        absolute coefficient exceeds ``extreme_coef_threshold``.
    extreme_coef_threshold : float, default 10.0
        Cut-off for the "extreme coefficients" line of the message.

    Returns
    -------
    tuple or None
        None when no observation is on a boundary. Otherwise
        ``(severity_level, warning_msg)`` with ``severity_level`` chosen by
        the share of boundary observations:
        ``>= 100%`` 'COMPLETE SEPARATION', ``>= 50%`` 'QUASI-SEPARATION',
        ``>= 10%`` 'MODERATE SEPARATION', otherwise 'MINOR'.
    """
    low_hits = np.sum(probs_opt_raw <= PROBS_MIN)
    high_hits = np.sum(probs_opt_raw >= 1 - PROBS_MIN)
    total_hits = low_hits + high_hits

    if total_hits == 0:
        return None

    share = 100.0 * total_hits / len(probs_opt_raw)

    # Severity tiers, strictest first: (minimum share in %, label, advice).
    tiers = (
        (
            100.0,
            "COMPLETE SEPARATION",
            [
                "Check for perfect predictors: examine if any covariate "
                "perfectly separates treatment groups",
                "Consider penalized estimation: use hdCBPS with LASSO regularization",
                "Remove or combine highly predictive variables",
                "Consider Firth's penalized likelihood as initialization",
                "Verify data coding: check for data entry errors in treatment variable",
            ],
        ),
        (
            50.0,
            "QUASI-SEPARATION",
            [
                "Check for multicollinearity: compute VIF for covariates",
                "Consider trimming: remove units with extreme propensity scores "
                "(Crump et al. 2009)",
                "Use regularized estimation (hdCBPS with LASSO)",
                "Report sensitivity analysis with different trimming thresholds",
                "Consider weight truncation at the 1st/99th percentile",
            ],
        ),
        (
            10.0,
            "MODERATE SEPARATION",
            [
                "Examine covariate balance after weighting",
                "Report effective sample size (ESS = (sum(w))^2 / sum(w^2))",
                "Verify stability: compare results with 'exact' vs 'over' method",
                "Consider standardize=True to reduce weight variability",
            ],
        ),
    )

    # Fallback tier when the share is below every threshold above.
    severity_level = "MINOR"
    suggestions = [
        "This is usually acceptable but check covariate balance.",
    ]
    for floor, label, advice in tiers:
        if share >= floor:
            severity_level, suggestions = label, advice
            break

    # Assemble the message line by line and join once at the end.
    report = [
        f"[CBPS Separation Warning - {severity_level}]",
        f"Detected: {total_hits} observations ({share:.1f}%) "
        f"at probability boundary",
        f" - Low boundary (\u03c0 \u2264 {PROBS_MIN}): {low_hits}",
        f" - High boundary (\u03c0 \u2265 {1-PROBS_MIN}): {high_hits}",
    ]
    if np.any(np.abs(beta_opt) > extreme_coef_threshold):
        report.append(
            f" - Extreme coefficients (|\u03b2| > {extreme_coef_threshold}): detected"
        )
    report.append("Suggested actions:")
    report.extend(f" {i}. {s}" for i, s in enumerate(suggestions, 1))

    return severity_level, "\n".join(report)
1441
+
1442
+
1443
def _compute_final_weights(
    beta_opt: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    standardize: bool
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Turn optimized coefficients into propensity scores and weights.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand code. Any non-zero value selects the ATT weight function
        ``_att_wt_func``; zero selects the ATE formula.
    standardize : bool
        Forwarded to ``standardize_weights``.

    Returns
    -------
    probs_opt : np.ndarray
        Propensity scores clipped to ``[PROBS_MIN, 1 - PROBS_MIN]``,
        shape (n,).
    w_opt : np.ndarray
        Output of ``standardize_weights`` applied to the absolute raw
        weights, shape (n,).

    Warns
    -----
    UserWarning
        When ``_classify_separation`` reports raw scores on the clipping
        boundaries.
    """
    raw_scores = scipy.special.expit(X @ beta_opt)
    probs_opt = np.clip(raw_scores, PROBS_MIN, 1 - PROBS_MIN)

    # The separation check looks at the scores *before* clipping.
    separation = _classify_separation(raw_scores, beta_opt)
    if separation is not None:
        warnings.warn(separation[1], UserWarning, stacklevel=3)

    if att:
        raw_weights = _att_wt_func(beta_opt, X, treat, sample_weights)
    else:
        # 1 / (pi - 1 + T) is 1/pi for T=1 and 1/(pi - 1) for T=0; the
        # absolute value below turns the latter into 1/(1 - pi).
        raw_weights = 1 / (probs_opt - 1 + treat)
    w_opt = np.abs(raw_weights)

    w_opt = standardize_weights(
        w_opt, treat, probs_opt, sample_weights, att, standardize
    )
    return probs_opt, w_opt
1508
+
1509
+
1510
def _compute_diagnostics(
    beta_opt: np.ndarray,
    beta_glm: np.ndarray,
    probs_opt: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    two_step: bool,
    this_inv_V: np.ndarray,
    X: np.ndarray
) -> Tuple[float, float, float, float]:
    """
    Compute the J-statistics and the (null) deviance of a fitted model.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficients; used for ``J_opt``.
    beta_glm : np.ndarray
        GLM coefficients; used for the baseline ``mle_J``.
    probs_opt : np.ndarray
        Fitted propensity scores, shape (n,); used for the deviance.
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand code, forwarded to ``_gmm_func``.
    two_step : bool
        True: both J-statistics are evaluated with ``this_inv_V``.
        False: both are evaluated with ``inv_V=None``.
    this_inv_V : np.ndarray
        Precomputed inverse covariance matrix (two-step mode only).
    X : np.ndarray
        Covariate matrix, shape (n, k).

    Returns
    -------
    J_opt : float
        ``'loss'`` entry of ``_gmm_func`` at ``beta_opt``.
    mle_J : float
        ``'loss'`` entry of ``_gmm_func`` at ``beta_glm``.
    deviance : float
        Minus two times the weighted Bernoulli log-likelihood at
        ``probs_opt``.
    nulldeviance : float
        Same quantity with every probability set to the weighted treatment
        mean (clipped to ``[1e-10, 1 - 1e-10]``).
    """
    # One weight-matrix choice shared by both J evaluations.
    inv_V = this_inv_V if two_step else None

    J_opt = _gmm_func(
        beta_opt, X, treat, sample_weights, att, inv_V=inv_V
    )['loss']
    mle_J = _gmm_func(
        beta_glm, X, treat, sample_weights, att, inv_V=inv_V
    )['loss']

    def minus_two_loglik(prob):
        """-2 * weighted log-likelihood for success probability ``prob``."""
        return -2 * np.sum(
            treat * sample_weights * np.log(prob) +
            (1 - treat) * sample_weights * np.log(1 - prob)
        )

    deviance = minus_two_loglik(probs_opt)

    # Intercept-only benchmark: constant probability = weighted mean of
    # treat, clipped so that neither logarithm receives zero.
    baseline = np.clip(
        np.average(treat, weights=sample_weights), 1e-10, 1 - 1e-10
    )
    nulldeviance = minus_two_loglik(baseline)

    return J_opt, mle_J, deviance, nulldeviance
1569
+
1570
+
1571
def _compute_vcov(
    beta_opt: np.ndarray,
    probs_opt: np.ndarray,
    treat: np.ndarray,
    X: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    bal_only: bool,
    XprimeX_inv: np.ndarray,
    this_inv_V: np.ndarray,
    two_step: bool,
    n: int
) -> np.ndarray:
    """
    Compute the sandwich variance-covariance matrix of the coefficients.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficient vector, shape (k,).
    probs_opt : np.ndarray
        Fitted propensity scores, shape (n,).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand code; any non-zero value selects the ATT branch.
    bal_only : bool
        True (method='exact'): only balance conditions enter G and the
        weight matrix is ``XprimeX_inv``.
        False (method='over'): score and balance conditions are stacked.
    XprimeX_inv : np.ndarray
        Weight matrix used when ``bal_only`` is True.
    this_inv_V : np.ndarray
        Weight matrix used when ``bal_only`` is False and ``two_step`` is
        True.
    two_step : bool
        When False (and ``bal_only`` is False) the weight matrix is taken
        from the ``'inv_V'`` entry of ``_gmm_func`` evaluated at
        ``beta_opt``.
    n : int
        Sample size used as the normalizing constant.

    Returns
    -------
    np.ndarray
        Coefficient variance-covariance matrix, shape (k, k).

    Notes
    -----
    With ``A = ginv(G W G') G W`` and ``Omega = W1 W1' / n`` the result is
    ``A Omega A'``, i.e. the sandwich form
    ``(G'WG)^-1 G'W Omega W'G (G'WG)^-1`` cited by the original author
    (Newey & McFadden 1994, Eq. 6.17).
    """
    # ---- balance-condition components ----
    if att:
        # ATT branch
        n_t = np.sum(sample_weights[treat == 1])
        # NOTE(review): this branch scales XW_2 by sample_weights while the
        # ATE branch below uses sqrt(sample_weights) - confirm against the
        # reference implementation before changing either.
        XW_2 = (
            X
            * _att_wt_func(beta_opt, X, treat, sample_weights)[:, None]
            * sample_weights[:, None]
        )
        dw2 = -n / n_t * probs_opt / (1 - probs_opt)
        dw2[treat == 1] = 0  # derivative term is zeroed for treated units
        XG_2 = X * dw2[:, None] * sample_weights[:, None]
    else:
        # ATE branch
        XW_2 = (
            X
            * (1 / (probs_opt - 1 + treat))[:, None]
            * np.sqrt(sample_weights)[:, None]
        )
        XG_2 = (
            -X
            * ((treat - probs_opt)**2 / (probs_opt * (1 - probs_opt)))[:, None]
            * sample_weights[:, None]
        )

    # ---- assemble G, W1 and W for the chosen identification mode ----
    if bal_only:  # method='exact'
        G = (XG_2.T @ X) / n
        W1 = XW_2.T
        W = XprimeX_inv
    else:  # method='over'
        # Score-condition components are only needed in this mode.
        XG_1 = (
            -X
            * (probs_opt * (1 - probs_opt))[:, None]
            * sample_weights[:, None]
        )
        XW_1 = (
            X
            * (treat - probs_opt)[:, None]
            * np.sqrt(sample_weights)[:, None]
        )

        G = np.hstack([(XG_1.T @ X), (XG_2.T @ X)]) / n
        W1 = np.vstack([XW_1.T, XW_2.T])

        if two_step:
            W = this_inv_V  # reuse the precomputed matrix
        else:
            W = _gmm_func(
                beta_opt, X, treat, sample_weights, att, inv_V=None
            )['inv_V']

    # ---- sandwich formula ----
    Omega = (W1 @ W1.T) / n  # moment-condition covariance
    GWG = G @ W @ G.T
    GWGinv = _r_ginv(GWG)  # generalized inverse via project helper
    GWGinvGW = GWGinv @ G @ W
    vcov = GWGinvGW @ Omega @ GWGinvGW.T

    return vcov
1646
+
1647
+
1648
+ def cbps_binary_fit(
1649
+ treat: np.ndarray,
1650
+ X: np.ndarray,
1651
+ att: int = 1,
1652
+ method: str = 'over',
1653
+ two_step: bool = True,
1654
+ standardize: bool = True,
1655
+ sample_weights: Optional[np.ndarray] = None,
1656
+ iterations: int = 1000,
1657
+ XprimeX_inv: Optional[np.ndarray] = None,
1658
+ verbose: int = 0,
1659
+ init_params: Optional[np.ndarray] = None,
1660
+ show_progress: bool = False,
1661
+ **kwargs
1662
+ ) -> Dict:
1663
+ """
1664
+ Estimate covariate balancing propensity scores for binary treatments.
1665
+
1666
+ Implements the covariate balancing propensity score (CBPS) methodology
1667
+ for binary treatment assignments using generalized method of moments
1668
+ (GMM) estimation. The function simultaneously optimizes covariate balance
1669
+ and prediction of treatment assignment.
1670
+
1671
+ Parameters
1672
+ ----------
1673
+ treat : np.ndarray
1674
+ Binary treatment indicator vector coded as 0/1, shape (n,).
1675
+ X : np.ndarray
1676
+ Covariate matrix including intercept column, shape (n, k).
1677
+ The intercept should be the first column.
1678
+ att : int, default 1
1679
+ Target estimand for estimation:
1680
+ - 0: Average treatment effect (ATE)
1681
+ - 1: Average treatment effect on the treated (ATT) with treatment=1
1682
+ - 2: Average treatment effect on the treated (ATT) with treatment=0
1683
+ method : {'over', 'exact'}, default 'over'
1684
+ GMM estimation method:
1685
+ - 'over': Over-identified GMM combining likelihood and balance conditions
1686
+ - 'exact': Exactly-identified GMM using balance conditions only
1687
+ two_step : bool, default True
1688
+ GMM estimator type:
1689
+ - True: Two-step GMM with pre-computed weight matrix (faster)
1690
+ - False: Continuous-updating GMM with iterative weight updates
1691
+ standardize : bool, default True
1692
+ Weight standardization:
1693
+ - True: Weights sum to 1 within each treatment group
1694
+ - False: Return Horvitz-Thompson weights
1695
+ sample_weights : np.ndarray, optional
1696
+ Sampling weights for observations. If None, defaults to equal weights.
1697
+ iterations : int, default 1000
1698
+ Maximum number of iterations for the optimization algorithm.
1699
+ XprimeX_inv : np.ndarray, optional
1700
+ Pre-computed inverse of X'X matrix for balance loss computation.
1701
+ init_params : np.ndarray, optional
1702
+ Initial parameter values for warm start. If provided, skips GLM
1703
+ initialization and uses these values directly. Length must equal
1704
+ the number of columns in X.
1705
+ theoretical_exact : bool, default False (passed via **kwargs)
1706
+ Only applicable when method='exact':
1707
+ - True: Direct equation solver for moment conditions (precision ~1e-15)
1708
+ - False: Balance loss optimization (R-compatible, precision ~1e-6)
1709
+ **kwargs
1710
+ Additional arguments passed to scipy.optimize.minimize.
1711
+
1712
+ Returns
1713
+ -------
1714
+ dict
1715
+ Fitted CBPS object containing:
1716
+ - coefficients: Estimated propensity score coefficients, shape (k, 1)
1717
+ - fitted_values: Estimated propensity scores, shape (n,)
1718
+ - weights: CBPS weights for causal effect estimation, shape (n,)
1719
+ - J: J-statistic for overidentification test
1720
+ - var: Asymptotic variance-covariance matrix, shape (k, k)
1721
+ - converged: Boolean convergence indicator
1722
+ - mle_J: Maximum likelihood J-statistic
1723
+ - deviance: Model deviance
1724
+ - linear_predictor: Linear predictor values (X @ coefficients)
1725
+ - y: Treatment indicator vector
1726
+ - x: Covariate matrix
1727
+
1728
+ Notes
1729
+ -----
1730
+ The algorithm implements the following key steps:
1731
+ 1. Initial MLE estimation for starting values
1732
+ 2. Balance loss optimization for initial GMM values
1733
+ 3. GMM optimization to satisfy both score and balance conditions
1734
+ 4. Final weight computation and diagnostics
1735
+
1736
+ For ATT estimation, weights are constructed to balance covariates between
1737
+ the treated group and the weighted control group. For ATE estimation,
1738
+ weights balance all groups simultaneously.
1739
+
1740
+ References
1741
+ ----------
1742
+ Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
1743
+ Journal of the Royal Statistical Society, Series B 76(1), 243-263.
1744
+ https://doi.org/10.1111/rssb.12027
1745
+ """
1746
+ # Ensure dense matrix (sparse input auto-converted)
1747
+ X = ensure_dense(X)
1748
+ treat = np.asarray(treat, dtype=float).ravel()
1749
+
1750
+ # Input validation: NaN/Inf check (before any computation)
1751
+ validate_cbps_input(
1752
+ treat, X,
1753
+ min_observations=2,
1754
+ module_name="Binary CBPS",
1755
+ check_treatment_variance=False
1756
+ )
1757
+
1758
+ # Normalize att parameter (support string: 'ate', 'att', 'atc')
1759
+ att = _normalize_att(att)
1760
+
1761
+ n = len(treat)
1762
+ bal_only = (method == 'exact')
1763
+
1764
+ # Note: SVD preprocessing is applied in the CBPS() main function before
1765
+ # calling this function, matching R package's CBPSMain.R behavior.
1766
+ # The X passed here may already be SVD-transformed (U matrix).
1767
+ # X is not modified in-place, so X_orig just references the same array (no copy is made)
1768
+ X_orig = X
1769
+
1770
+ # Full rank check
1771
+ k = np.linalg.matrix_rank(X)
1772
+ if k < X.shape[1]:
1773
+ raise ValueError(
1774
+ f"X is not full rank: rank={k} < ncol={X.shape[1]}. "
1775
+ f"Suggestions: (1) Remove collinear variables, "
1776
+ f"(2) Check for duplicate columns, "
1777
+ f"(3) Use hdCBPS for automatic variable selection."
1778
+ )
1779
+
1780
+ # Step 1: Normalize sample weights
1781
+ sample_weights = normalize_sample_weights(sample_weights, n)
1782
+ n_c = np.sum(sample_weights[treat == 0])
1783
+ n_t = np.sum(sample_weights[treat == 1])
1784
+
1785
+ # Compute XprimeX_inv
1786
+ if XprimeX_inv is None:
1787
+ sw_sqrt_X = np.sqrt(sample_weights)[:, None] * X
1788
+ XprimeX = sw_sqrt_X.T @ sw_sqrt_X
1789
+ XprimeX_inv = _r_ginv(XprimeX)
1790
+
1791
+ # Step 2: GLM initialization (or warm start)
1792
+ if init_params is not None:
1793
+ init_params = np.asarray(init_params, dtype=float)
1794
+ if len(init_params) != X.shape[1]:
1795
+ raise ValueError(
1796
+ f"init_params length {len(init_params)} != {X.shape[1]} covariates. "
1797
+ f"Ensure init_params matches the number of columns in the design matrix."
1798
+ )
1799
+ # init_params is never modified in-place downstream;
1800
+ # _vmmin_bfgs copies internally, _compute_diagnostics is read-only
1801
+ beta_init = init_params
1802
+ beta_glm = init_params
1803
+ else:
1804
+ gmm_loss_func_for_init = lambda b: _gmm_loss(b, X, treat, sample_weights, att, None)
1805
+ beta_init, beta_glm = _glm_init(
1806
+ treat, X, sample_weights, att, gmm_loss_func_for_init
1807
+ )
1808
+
1809
+ # Step 3: Pre-compute inverse covariance matrix (for two-step GMM)
1810
+ gmm_init = beta_init
1811
+ gmm_result_init = _gmm_func(gmm_init, X, treat, sample_weights, att, inv_V=None)
1812
+ this_inv_V = gmm_result_init['inv_V']
1813
+
1814
+ # Configure logging from verbose parameter (backward compatibility)
1815
+ if verbose >= 2:
1816
+ set_verbosity(2)
1817
+ elif verbose >= 1:
1818
+ set_verbosity(1)
1819
+
1820
+ # Step 4: Balance loss optimization for initial values
1821
+ logger.info(f"Starting balance optimization (max_iter={iterations})...")
1822
+
1823
+ opt_bal = _optimize_balance(
1824
+ gmm_init, X, treat, sample_weights, XprimeX_inv, att,
1825
+ two_step, iterations, bal_only=bal_only, show_progress=show_progress, **kwargs
1826
+ )
1827
+
1828
+ logger.info(f"Balance optimization complete: loss={opt_bal.fun:.6f}, converged={opt_bal.success}")
1829
+ beta_bal = opt_bal.x # Extract balance-optimized coefficients
1830
+
1831
+ # Step 5: GMM optimization (for method='over') or exact moment solving
1832
+ if bal_only:
1833
+ # For just-identified GMM, user can choose:
1834
+ # - theoretical_exact=True: Use equation solver (precision ~1e-15)
1835
+ # - theoretical_exact=False: Use balance loss (R-compatible, precision ~1e-6)
1836
+
1837
+ use_theoretical_exact = kwargs.get('theoretical_exact', False)
1838
+
1839
+ if use_theoretical_exact:
1840
+ # Direct moment equation solving (theoretically correct)
1841
+ beta_opt, root_success, moments_final, solver_method = _solve_moment_equations(
1842
+ beta_bal, # Use balance-optimized result as initial value
1843
+ X, treat, sample_weights, att, n, iterations
1844
+ )
1845
+
1846
+ max_moment = np.max(np.abs(moments_final))
1847
+
1848
+ if root_success:
1849
+ if max_moment < 1e-8:
1850
+ # Perfect convergence to theoretical precision
1851
+ pass
1852
+ else:
1853
+ # Solver converged but moment not satisfied (rare)
1854
+ warnings.warn(
1855
+ f"theoretical_exact=True: Equation solver converged but moment={max_moment:.2e}, "
1856
+ f"below theoretical requirement <1e-10. Consider better variable preprocessing.",
1857
+ UserWarning
1858
+ )
1859
+ else:
1860
+ # Equation solver failed, fall back to balance optimization
1861
+ warnings.warn(
1862
+ f"theoretical_exact=True: Equation solver failed ({solver_method}), "
1863
+ f"falling back to balance loss optimization result.",
1864
+ UserWarning
1865
+ )
1866
+ beta_opt = beta_bal
1867
+
1868
+ # Update opt1 object for interface compatibility
1869
+ opt1 = opt_bal
1870
+ opt1.x = beta_opt
1871
+ else:
1872
+ # R-compatible implementation: balance loss optimization
1873
+ opt1 = opt_bal
1874
+
1875
+ # Check moment convergence
1876
+ moments_final = _compute_moment_conditions(
1877
+ opt1.x, X, treat, sample_weights, att, n
1878
+ )
1879
+ max_moment = np.max(np.abs(moments_final))
1880
+
1881
+ # Note: For method='exact', the J-statistic is computed using over-identified
1882
+ # GMM conditions (score + balance). This means J > 0 even for just-identified
1883
+ # models, reflecting the degree to which score conditions are violated.
1884
+
1885
+ if max_moment > 1e-6:
1886
+ warnings.warn(
1887
+ f"method='exact': Moment conditions converged to {max_moment:.2e}, "
1888
+ f"below theoretical GMM precision <1e-10. This is a known limitation "
1889
+ f"of balance loss optimization.\n"
1890
+ f"For exact moment=0 satisfaction (~1e-15 precision), "
1891
+ f"use theoretical_exact=True in CBPS() call.",
1892
+ UserWarning
1893
+ )
1894
+ else:
1895
+ logger.info("Starting GMM optimization with dual initialization...")
1896
+
1897
+ opt1 = _optimize_gmm_dual_init(
1898
+ gmm_init, beta_bal, X, treat, sample_weights, att,
1899
+ this_inv_V, two_step, iterations, show_progress=show_progress, **kwargs
1900
+ )
1901
+
1902
+ logger.info(f"GMM optimization complete: J={opt1.fun:.6f}, converged={opt1.success}")
1903
+
1904
+ # Step 6: Final probabilities and weights
1905
+ beta_opt = opt1.x
1906
+ probs_opt, w_opt = _compute_final_weights(
1907
+ beta_opt, X, treat, sample_weights, att, standardize
1908
+ )
1909
+
1910
+ # Step 7: Compute J-statistic, deviance, and null deviance
1911
+ J_opt, mle_J, deviance, nulldeviance = _compute_diagnostics(
1912
+ beta_opt, beta_glm, probs_opt, treat, sample_weights,
1913
+ att, two_step, this_inv_V, X
1914
+ )
1915
+
1916
+ # Note: For method='exact', the J-statistic is computed using over-identified
1917
+ # GMM conditions. Theoretically, J should be 0 for just-identified models,
1918
+ # but the full GMM conditions provide a useful diagnostic.
1919
+
1920
+ # Step 8: Variance-covariance matrix
1921
+ vcov = _compute_vcov(
1922
+ beta_opt, probs_opt, treat, X, sample_weights, att,
1923
+ bal_only, XprimeX_inv, this_inv_V, two_step, n
1924
+ )
1925
+
1926
+ # Step 9: Construct return dictionary
1927
+ output = {
1928
+ 'coefficients': beta_opt.reshape(-1, 1), # (k, 1) column vector
1929
+ 'fitted_values': probs_opt,
1930
+ 'linear_predictor': X @ beta_opt,
1931
+ 'deviance': deviance,
1932
+ 'nulldeviance': nulldeviance,
1933
+ 'weights': w_opt,
1934
+ 'y': treat,
1935
+ 'x': X_orig,
1936
+ 'converged': opt1.success,
1937
+ 'J': J_opt,
1938
+ 'var': vcov,
1939
+ 'mle_J': mle_J
1940
+ }
1941
+
1942
+ return output
1943
+