cbps 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbps/__init__.py +3462 -0
- cbps/constants.py +46 -0
- cbps/core/__init__.py +93 -0
- cbps/core/cbps_binary.py +1943 -0
- cbps/core/cbps_continuous.py +945 -0
- cbps/core/cbps_multitreat.py +1123 -0
- cbps/core/cbps_optimal.py +507 -0
- cbps/core/results.py +1447 -0
- cbps/data/Blackwell.csv +571 -0
- cbps/data/LaLonde.csv +3213 -0
- cbps/data/npcbps_continuous_sim.csv +501 -0
- cbps/data/nsw.csv +723 -0
- cbps/data/nsw_dw.csv +446 -0
- cbps/data/political_ads_urban_niebler.csv +16266 -0
- cbps/data/psid_controls.csv +2491 -0
- cbps/data/psid_controls2.csv +254 -0
- cbps/data/psid_controls3.csv +129 -0
- cbps/data/simulation_dgp1_seed12345.csv +201 -0
- cbps/data/simulation_dgp2_seed12345.csv +201 -0
- cbps/data/simulation_dgp3_seed12345.csv +201 -0
- cbps/data/simulation_dgp4_seed12345.csv +201 -0
- cbps/datasets/__init__.py +78 -0
- cbps/datasets/blackwell.py +112 -0
- cbps/datasets/continuous.py +223 -0
- cbps/datasets/lalonde.py +272 -0
- cbps/datasets/npcbps_sim.py +101 -0
- cbps/diagnostics/__init__.py +101 -0
- cbps/diagnostics/balance.py +760 -0
- cbps/diagnostics/balance_cbmsm_addon.py +162 -0
- cbps/diagnostics/continuous_diagnostics.py +259 -0
- cbps/diagnostics/normality.py +173 -0
- cbps/diagnostics/ocbps_conditions.py +197 -0
- cbps/diagnostics/overlap.py +198 -0
- cbps/diagnostics/plots.py +1193 -0
- cbps/diagnostics/weights_diag.py +205 -0
- cbps/highdim/__init__.py +84 -0
- cbps/highdim/gmm_loss.py +340 -0
- cbps/highdim/hdcbps.py +1078 -0
- cbps/highdim/lasso_utils.py +498 -0
- cbps/highdim/weight_funcs.py +298 -0
- cbps/inference/__init__.py +42 -0
- cbps/inference/asyvar.py +621 -0
- cbps/inference/vcov_outcome.py +217 -0
- cbps/iv/__init__.py +48 -0
- cbps/iv/cbiv.py +2603 -0
- cbps/logging_config.py +45 -0
- cbps/msm/__init__.py +45 -0
- cbps/msm/cbmsm.py +1871 -0
- cbps/msm/rank_diagnostics.py +112 -0
- cbps/nonparametric/__init__.py +58 -0
- cbps/nonparametric/cholesky_whitening.py +232 -0
- cbps/nonparametric/empirical_likelihood.py +339 -0
- cbps/nonparametric/npcbps.py +1036 -0
- cbps/nonparametric/taylor_approx.py +207 -0
- cbps/py.typed +0 -0
- cbps/sklearn/__init__.py +42 -0
- cbps/sklearn/estimator.py +378 -0
- cbps/utils/__init__.py +82 -0
- cbps/utils/formula.py +415 -0
- cbps/utils/helpers.py +378 -0
- cbps/utils/numerics.py +438 -0
- cbps/utils/r_compat.py +109 -0
- cbps/utils/validation.py +224 -0
- cbps/utils/variance_transform.py +483 -0
- cbps/utils/weights.py +586 -0
- cbps-0.2.0.dist-info/METADATA +1090 -0
- cbps-0.2.0.dist-info/RECORD +70 -0
- cbps-0.2.0.dist-info/WHEEL +5 -0
- cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
- cbps-0.2.0.dist-info/top_level.txt +1 -0
cbps/core/cbps_binary.py
ADDED
|
@@ -0,0 +1,1943 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Binary Treatment Covariate Balancing Propensity Score
|
|
3
|
+
|
|
4
|
+
This module implements the CBPS algorithm for binary treatments, supporting
|
|
5
|
+
both exactly-identified and over-identified generalized method of moments
|
|
6
|
+
(GMM) estimation.
|
|
7
|
+
|
|
8
|
+
The covariate balancing propensity score (CBPS) methodology estimates
|
|
9
|
+
propensity scores that optimize covariate balance while maintaining good
|
|
10
|
+
prediction of treatment assignment.
|
|
11
|
+
|
|
12
|
+
References
|
|
13
|
+
----------
|
|
14
|
+
Imai, K. and Ratkovic, M. (2014). Covariate Balancing Propensity Score.
|
|
15
|
+
Journal of the Royal Statistical Society, Series B 76(1), 243-263.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import warnings
|
|
19
|
+
from typing import Dict, Optional, Tuple, Callable
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import scipy.linalg
|
|
23
|
+
import scipy.special
|
|
24
|
+
import scipy.optimize
|
|
25
|
+
import statsmodels.api as sm
|
|
26
|
+
from statsmodels.genmod.families import Binomial
|
|
27
|
+
|
|
28
|
+
from ..utils.weights import standardize_weights
|
|
29
|
+
from ..utils.helpers import normalize_sample_weights
|
|
30
|
+
from ..utils.numerics import r_ginv_with_diagnostics
|
|
31
|
+
from ..utils.validation import ensure_dense, validate_cbps_input
|
|
32
|
+
from ..constants import DEFAULT_CONFIG
|
|
33
|
+
from ..logging_config import logger, set_verbosity
|
|
34
|
+
|
|
35
|
+
# Numerical constants, taken from the shared NumericalConfig instance.
# PROBS_MIN bounds fitted propensity scores to [PROBS_MIN, 1 - PROBS_MIN]
# before they are used in weights and moment conditions.
PROBS_MIN = DEFAULT_CONFIG.probs_min

# Accepted spellings of the target estimand, mapped to the integer code used
# internally: 0 = ATE, 1 = ATT, 2 = ATC. Integer codes map to themselves so
# callers may pass either form.
_ATT_MAP = {
    'ate': 0,
    'att': 1,
    'atc': 2,
    0: 0,
    1: 1,
    2: 2,
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _normalize_att(att):
    """Convert an estimand specification to its integer code.

    Strings are matched case-insensitively after stripping surrounding
    whitespace against ``'ate'``, ``'att'`` and ``'atc'``. Non-string values
    must compare equal to 0, 1 or 2.

    Parameters
    ----------
    att : str or int
        Target estimand specification.

    Returns
    -------
    int
        0 for ATE, 1 for ATT, 2 for ATC.

    Raises
    ------
    ValueError
        If ``att`` is an unrecognised string or a value other than 0, 1, 2.
    """
    if not isinstance(att, str):
        # Numeric path: anything equal to 0/1/2 (including True/False)
        # is accepted and coerced to a plain int.
        if att in (0, 1, 2):
            return int(att)
        raise ValueError(
            f"Invalid att={att}. Use 'ate'(0), 'att'(1), 'atc'(2)."
        )

    key = att.lower().strip()
    if key in _ATT_MAP:
        return _ATT_MAP[key]
    raise ValueError(
        f"Invalid att='{att}'. Use 'ate', 'att', 'atc' or 0, 1, 2."
    )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _r_ginv(X: np.ndarray, tol: Optional[float] = None) -> np.ndarray:
    """
    Compute the Moore-Penrose pseudoinverse in the manner of R's ``MASS::ginv``.

    Singular values are retained only when they exceed ``tol`` times the
    largest singular value; the remaining ones are treated as zero.

    Parameters
    ----------
    X : np.ndarray
        2-D matrix (real or complex) of shape (m, n).
    tol : float, optional
        Relative tolerance for singular value selection. Default is
        sqrt(machine_epsilon) ≈ 1.49e-08. Singular values d are kept if
        d > tol * max(d).

    Returns
    -------
    np.ndarray
        The Moore-Penrose pseudoinverse of X, shape (n, m).

    Notes
    -----
    The implementation follows a three-branch logic:

    1. If no singular values exist, or the largest is below machine epsilon,
       or none pass the threshold: return a zero matrix. The machine-epsilon
       guard avoids amplifying pure numerical noise.
    2. Otherwise compute ``V[:, keep] @ diag(1/d[keep]) @ U[:, keep]^H``.

    The conjugate transpose (``^H``) is used for both ``U`` and ``V``. NumPy's
    ``svd`` returns ``Vh = V^H``, so ``V = Vh.conj().T``. For real input this
    is the ordinary transpose and the result is unchanged. For complex input
    it yields the true pseudoinverse rather than its complex conjugate.
    """
    eps = np.finfo(float).eps
    if tol is None:
        # Default tolerance: sqrt(machine epsilon) ≈ 1.49e-08
        tol = np.sqrt(eps)

    # Reduced SVD (R's svd() default): U (m, r), d (r,), Vh (r, n),
    # with r = min(m, n) and d sorted in descending order.
    u, d, vh = np.linalg.svd(X, full_matrices=False)

    def _zeros() -> np.ndarray:
        # Shape of the pseudoinverse is the transpose of X's shape.
        return np.zeros((X.shape[1], X.shape[0]))

    if len(d) == 0 or d[0] < eps:
        return _zeros()

    # Keep singular values with d > max(tol * d[0], 0), as in MASS::ginv.
    keep = d > max(tol * d[0], 0.0)
    if not np.any(keep):
        return _zeros()

    v_keep = vh[keep].conj().T
    u_keep = u[:, keep]
    # Scaling the columns of V by 1/d equals V @ diag(1/d) without
    # materialising the (r, r) diagonal matrix.
    return (v_keep * (1.0 / d[keep])) @ u_keep.conj().T
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _att_wt_func(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray
) -> np.ndarray:
    """
    Return the ATT balancing weights implied by the coefficients ``beta_curr``.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates, shape (k,).
    X : np.ndarray
        Covariate matrix including intercept, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator (0/1), shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights summing to n, shape (n,).

    Returns
    -------
    np.ndarray
        Signed ATT weights, shape (n,). Treated units get ``n/n_t``; control
        units get the negative value ``-(n/n_t) * pi / (1 - pi)``. The caller
        is expected to take absolute values when weights are reported.

    Notes
    -----
    Formula: ``w_i = (n/n_t) * (T_i - pi_i) / (1 - pi_i)``, where ``n_t`` is
    the sample-weighted count of treated units, ``n`` the weighted total, and
    ``pi_i = expit(X_i @ beta_curr)`` clipped to ``[PROBS_MIN, 1 - PROBS_MIN]``.
    """
    # Sample-weighted group sizes.
    treated_total = np.sum(sample_weights[treat == 1])
    grand_total = np.sum(sample_weights[treat == 0]) + treated_total

    # Fitted propensity scores, kept away from 0 and 1.
    pscore = scipy.special.expit(X @ beta_curr)
    pscore = np.maximum(PROBS_MIN, np.minimum(1 - PROBS_MIN, pscore))

    scale = grand_total / treated_total
    return scale * (treat - pscore) / (1 - pscore)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _compute_V_matrix(
    X: np.ndarray,
    probs_curr: np.ndarray,
    sample_weights: np.ndarray,
    treat: np.ndarray,
    att: int,
    n: int
) -> np.ndarray:
    """
    Build the GMM moment covariance matrix V and return its pseudoinverse.

    V is a 2x2 block matrix over the score moments (first k rows/columns)
    and the balance moments (last k rows/columns). Each block has the form
    ``(1/n) * Z'Z``, where ``Z`` is ``sqrt(sample_weights) * X`` with rows
    rescaled by a function of the propensity scores.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix, shape (n, k).
    probs_curr : np.ndarray
        Current (clipped) propensity scores, shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    att : int
        0 selects the ATE form; any nonzero value selects the ATT form.
    n : int
        Number of observations.

    Returns
    -------
    np.ndarray
        Moore-Penrose pseudoinverse of V, shape (2k, 2k), from
        ``r_ginv_with_diagnostics``. That helper emits a UserWarning when
        the condition number exceeds 1e12.

    Notes
    -----
    For the ATT form, the score/score and score/balance blocks are scaled by
    ``n / n_treat``. The balance/balance block is scaled by
    ``n**2 / n_treat**2``. Here ``n_treat`` is the unweighted count of treated
    units.
    """
    # Every block starts from the rows of X scaled by sqrt(sample weight).
    root_wX = np.sqrt(sample_weights)[:, None] * X
    inv_n = 1 / n

    if not att:
        # ATE form: no treated-share scaling.
        z_score = root_wX * np.sqrt(probs_curr * (1 - probs_curr))[:, None]
        z_balance = root_wX / np.sqrt(probs_curr * (1 - probs_curr))[:, None]
        block_ss = inv_n * (z_score.T @ z_score)
        block_sb = inv_n * (root_wX.T @ root_wX)
        block_bb = inv_n * (z_balance.T @ z_balance)
    else:
        # ATT form: rescale by the proportion of treated observations.
        z_score = root_wX * np.sqrt(probs_curr * (1 - probs_curr))[:, None]
        z_balance = root_wX * np.sqrt(probs_curr / (1 - probs_curr))[:, None]
        z_cross = root_wX * np.sqrt(probs_curr)[:, None]
        n_treat = np.sum(treat == 1)
        block_ss = inv_n * (z_score.T @ z_score) * n / n_treat
        block_sb = inv_n * (z_cross.T @ z_cross) * n / n_treat
        block_bb = inv_n * (z_balance.T @ z_balance) * n**2 / n_treat**2

    # The off-diagonal blocks are the same symmetric matrix.
    V = np.block([[block_ss, block_sb],
                  [block_sb, block_bb]])

    # Internal invariant: V is symmetric by construction.
    assert np.allclose(V, V.T, atol=1e-15), "V matrix must be symmetric"

    inv_V, _ = r_ginv_with_diagnostics(V, warn_threshold=1e12)
    return inv_V
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _gmm_func(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    inv_V: Optional[np.ndarray] = None
) -> Dict:
    """
    Evaluate the CBPS GMM objective at ``beta_curr``.

    The moment vector ``gbar`` stacks two parts. The first is the logistic
    score condition ``(1/n) X' W (T - pi)``. The second is the covariate
    balance condition ``(1/n) X' W w``. Here ``W = diag(sample_weights)``
    and ``w`` are the estimand-specific balancing weights.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).
    att : int
        0 for ATE weights ``1 / (pi - 1 + T)``; nonzero for ATT weights
        from ``_att_wt_func``.
    inv_V : np.ndarray or None
        Precomputed inverse covariance matrix, shape (2k, 2k). If None, it
        is computed at ``beta_curr`` via ``_compute_V_matrix``.

    Returns
    -------
    dict
        ``'loss'``: float GMM loss ``gbar' @ inv_V @ gbar``.
        ``'inv_V'``: the inverse covariance matrix that was used.

    Notes
    -----
    When two_step=True, the inverse covariance matrix is precomputed and
    passed in; when two_step=False, it is recomputed at each iteration.
    """
    n = len(treat)

    # Fitted propensity scores, clipped away from 0 and 1.
    pscore = scipy.special.expit(X @ beta_curr)
    pscore = np.maximum(PROBS_MIN, np.minimum(1 - PROBS_MIN, pscore)).ravel()

    # ATE weight: 1 / (pi - 1 + T) = T/pi - (1-T)/(1-pi)
    if not att:
        balance_weights = 1 / (pscore - 1 + treat)
    else:
        balance_weights = _att_wt_func(beta_curr, X, treat, sample_weights)

    weighted_X = sample_weights[:, None] * X
    score_moment = (1 / n) * weighted_X.T @ (treat - pscore)
    balance_moment = (1 / n) * weighted_X.T @ balance_weights
    gbar = np.concatenate([score_moment.ravel(), balance_moment.ravel()])

    if inv_V is None:
        inv_V = _compute_V_matrix(X, pscore, sample_weights, treat, att, n)

    return {'loss': float(gbar @ inv_V @ gbar), 'inv_V': inv_V}
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _gmm_loss(
    beta: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    inv_V: Optional[np.ndarray]
) -> float:
    """
    Return only the scalar GMM objective from ``_gmm_func``.

    This is a thin adapter for optimizers that expect ``f(beta) -> float``.

    Parameters
    ----------
    beta : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).
    att : int
        Estimand type: 0 for ATE, nonzero for ATT.
    inv_V : np.ndarray or None
        Precomputed inverse covariance matrix. When None, ``_gmm_func``
        recomputes it at ``beta``.

    Returns
    -------
    float
        The GMM objective function value.
    """
    evaluation = _gmm_func(beta, X, treat, sample_weights, att, inv_V=inv_V)
    return evaluation['loss']
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _gmm_gradient(
    beta_curr: np.ndarray,
    inv_V: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int
) -> np.ndarray:
    """
    Analytical gradient of the GMM objective ``gbar' inv_V gbar``.

    ``inv_V`` is treated as a constant. Its dependence on ``beta_curr`` is not
    differentiated.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates, shape (k,).
    inv_V : np.ndarray
        Inverse covariance matrix, shape (2k, 2k).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).
    att : int
        Estimand type: 0 for ATE, nonzero for ATT.

    Returns
    -------
    np.ndarray
        Gradient vector, shape (k,).

    Notes
    -----
    ``grad = 2 * dgbar @ inv_V @ gbar``, where ``dgbar`` (k, 2k) stacks the
    symmetric Jacobian blocks of the score and balance moments.

    For ATE:
        dgbar = [-1/n * X' * diag(sw * pi * (1-pi)) * X,
                 -1/n * X' * diag(sw * (T-pi)^2 / (pi*(1-pi))) * X]

    For ATT:
        dw = -n/n_t * pi / (1-pi), with dw[treat==1] = 0
        dgbar = [1/n * X' * diag(-sw * pi * (1-pi)) * X,
                 1/n * X' * diag(dw * sw) * X]

    Under the ATT weight ``w = (n/n_t)(T - pi)/(1 - pi)``, treated weights
    are constant. Control weights are ``-(n/n_t) * pi/(1-pi)``. With a
    logistic link, d/dtheta of ``pi/(1-pi)`` is ``pi/(1-pi)`` itself, which
    gives ``dw`` above. The balance block uses the mathematically correct
    ``1/n`` prefactor; the R implementation uses ``1/n.t`` here.

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    n = len(treat)
    treated_total = np.sum(sample_weights[treat == 1])

    pscore = np.clip(scipy.special.expit(X @ beta_curr), PROBS_MIN, 1 - PROBS_MIN)
    weighted_X = sample_weights[:, None] * X

    if att:
        balance_weights = _att_wt_func(beta_curr, X, treat, sample_weights)
    else:
        balance_weights = 1 / (pscore - 1 + treat)

    # Moment vector: score conditions followed by balance conditions.
    gbar = np.concatenate([
        ((1 / n) * weighted_X.T @ (treat - pscore)).ravel(),
        ((1 / n) * weighted_X.T @ balance_weights).ravel(),
    ])

    if not att:
        score_jac = (-1 / n) * (X * (sample_weights * pscore * (1 - pscore))[:, None]).T @ X
        row_scale = sample_weights * (treat - pscore)**2 / (pscore * (1 - pscore))
        balance_jac = (-1 / n) * (X * row_scale[:, None]).T @ X
    else:
        # Derivative of the control-unit ATT weight; treated weights are constant.
        weight_deriv = -n / treated_total * pscore / (1 - pscore)
        weight_deriv[treat == 1] = 0
        score_jac = (1 / n) * (X * (-sample_weights * pscore * (1 - pscore))[:, None]).T @ X
        balance_jac = (1 / n) * (X * (weight_deriv * sample_weights)[:, None]).T @ X

    jacobian = np.hstack([score_jac, balance_jac])
    return 2 * jacobian @ inv_V @ gbar
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def _bal_loss(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    XprimeX_inv: np.ndarray,
    att: int
) -> float:
    """
    Evaluate the balance-only CBPS objective at ``beta_curr``.

    The objective is the absolute value of the quadratic form
    ``m' XprimeX_inv m``. The imbalance vector is ``m = X' W w``, with
    ``W = diag(sample_weights)``, and ``w`` are the balancing weights implied
    by ``beta_curr``, scaled by ``1/n``. Unlike the GMM loss, the score
    (likelihood) moment conditions play no role.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).
    XprimeX_inv : np.ndarray
        Pre-computed inverse of the X'X-type matrix, shape (k, k),
        supplied by the caller.
    att : int
        Estimand type: 0 for ATE, nonzero for ATT.

    Returns
    -------
    float
        Balance loss ``|m' XprimeX_inv m|``.
    """
    n = len(treat)

    if att:
        raw_weights = _att_wt_func(beta_curr, X, treat, sample_weights)
    else:
        pscore = np.clip(scipy.special.expit(X @ beta_curr), PROBS_MIN, 1 - PROBS_MIN)
        raw_weights = 1 / (pscore - 1 + treat)

    # Both estimands carry the same 1/n scaling of the weights.
    scaled_weights = (1 / n) * raw_weights

    imbalance = (sample_weights[:, None] * X).T @ scaled_weights
    return float(np.abs(imbalance @ XprimeX_inv @ imbalance))
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def _bal_gradient(
    beta_curr: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    XprimeX_inv: np.ndarray,
    att: int
) -> np.ndarray:
    """
    Analytical gradient of the balance loss function ``_bal_loss``.

    The balance loss is ``|q|`` with ``q = m' XprimeX_inv m`` and imbalance
    vector ``m = X' diag(sw) w``, where ``w`` are the 1/n-scaled balancing
    weights. The gradient of ``q`` is ``2 * J @ XprimeX_inv @ m``, where
    ``J = dm/dbeta = X' diag(sw * dw/dtheta) X`` (symmetric, shape (k, k)).
    A sign adjustment then converts it into a (sub)gradient of ``|q|``.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,).
    XprimeX_inv : np.ndarray
        Inverse of the X'X-type matrix, shape (k, k); assumed symmetric.
    att : int
        Estimand type: 0 for ATE, nonzero for ATT.

    Returns
    -------
    np.ndarray
        Gradient vector, shape (k,).

    Notes
    -----
    Per-observation weight derivatives with respect to theta = X beta,
    before the common 1/n factor:

    - ATE: ``-(T - pi)^2 / (pi * (1 - pi))``
    - ATT: ``-n/n_t * pi / (1 - pi)`` for controls, 0 for treated units.

    The sampling weights multiply these derivatives because they multiply
    ``w`` inside ``m``. The R formulas (``dw = 1/n * t(X * dw2)``) omit them.
    That makes the analytic gradient inconsistent with the loss whenever
    sampling weights are non-uniform. With uniform weights the result is
    identical to R. This matches ``_gmm_gradient``, which also includes the
    sampling weights in its Jacobian.

    The sign adjustment reproduces R's rule
    ``ifelse((x>0 & loss1>0) | (x<0 & loss1<0), abs(x), -abs(x))``.
    It returns ``sign(q) * x`` for ``q != 0`` and ``-|x|`` when ``q == 0``.

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    n = len(treat)
    n_t = np.sum(sample_weights[treat == 1])

    # Propensity scores with numerical clipping.
    probs_curr = scipy.special.expit(X @ beta_curr)
    probs_curr = np.clip(probs_curr, PROBS_MIN, 1 - PROBS_MIN)

    # Balancing weights (with 1/n scaling) and their derivative w.r.t. theta.
    if att:
        w_curr = (1 / n) * _att_wt_func(beta_curr, X, treat, sample_weights)
        dw_dtheta = -n / n_t * probs_curr / (1 - probs_curr)
        # Treated ATT weights are constant in beta.
        dw_dtheta[treat == 1] = 0
    else:
        w_curr = (1 / n) * (1 / (probs_curr - 1 + treat))
        dw_dtheta = -(treat - probs_curr)**2 / (probs_curr * (1 - probs_curr))

    # dw has shape (k, n); dw @ X = (1/n) X' diag(sw * dw_dtheta) X = dm/dbeta.
    dw = (1 / n) * (X * (sample_weights * dw_dtheta)[:, None]).T

    # Imbalance vector m = X' (w * sw), shape (k,), and the signed loss q.
    Xprimew = X.T @ (w_curr * sample_weights)
    loss1 = Xprimew @ XprimeX_inv @ Xprimew

    raw_grad = 2 * dw @ X @ XprimeX_inv @ Xprimew

    # Sign adjustment for the absolute value (R's sapply/ifelse logic).
    same_sign = ((raw_grad > 0) & (loss1 > 0)) | ((raw_grad < 0) & (loss1 < 0))
    return np.where(same_sign, np.abs(raw_grad), -np.abs(raw_grad))
|
|
698
|
+
|
|
699
|
+
def _vmmin_bfgs(
    b0: np.ndarray,
    fn: Callable,
    gr: Optional[Callable],
    maxit: int = 10000,
    abstol: float = -np.inf,
    reltol: float = np.sqrt(np.finfo(float).eps),
    trace: bool = False,
    nREPORT: int = 10,
    show_progress: bool = False,
    ndeps: Optional[float] = None,
) -> scipy.optimize.OptimizeResult:
    """
    R's vmmin BFGS optimizer, translated from C source.

    This is a line-by-line translation of R's ``vmmin`` function from
    ``src/appl/optim.c`` (the backend of ``optim(..., method="BFGS")``).
    It uses a simple Armijo backtracking line search and a relative-
    tolerance convergence criterion, which differ fundamentally from
    scipy's Strong-Wolfe / gradient-norm approach.

    Parameters
    ----------
    b0 : np.ndarray
        Initial parameter vector, shape (n,). A float copy is made, so the
        caller's array is never modified.
    fn : callable
        Objective function ``fn(b) -> float``.
    gr : callable or None
        Gradient function ``gr(b) -> np.ndarray`` of shape (n,).
        If None, a central finite-difference gradient
        ``(fn(b + h e_i) - fn(b - h e_i)) / (2 h)`` is used. This matches the
        numerical-derivative branch of R's ``fmingr`` in optim.c.
    maxit : int
        Maximum number of BFGS iterations (default 10000). Note that R's own
        ``optim`` default for BFGS is 100; CBPS callers pass their own value.
    abstol : float
        Stop once the objective falls to or below this value
        (default -inf, R default).
    reltol : float
        Relative tolerance on function value change
        (default ``sqrt(eps) ≈ 1.49e-8``, R default).
    trace : bool
        If True, print iteration information (default False).
    nREPORT : int
        Report every *nREPORT* iterations when *trace* is True. Must be > 0.
    show_progress : bool
        If True and ``tqdm`` is importable, show a progress bar.
    ndeps : float, optional
        Step ``h`` of the finite-difference gradient used when ``gr`` is None.
        Defaults to ``DEFAULT_CONFIG.ndeps``.

    Returns
    -------
    scipy.optimize.OptimizeResult
        Result object with fields ``x``, ``fun``, ``nit``, ``nfev``,
        ``njev``, ``success``, ``message``. ``success`` is False when the
        iteration limit was reached.

    Raises
    ------
    ValueError
        If ``nREPORT <= 0`` (with ``maxit > 0``) or if ``fn(b0)`` is not finite.

    Notes
    -----
    Constants hard-coded to match R exactly:

    * ``stepredn = 0.2`` – step reduction factor in backtracking
    * ``acctol = 0.0001`` – Armijo sufficient-decrease parameter
    * ``reltest = 10.0`` – used to detect "no change" in parameters

    Convergence criterion (R ``reltol``): the iteration keeps going only
    while ``f > abstol`` and
    ``|f_new - f_old| > reltol * (|f_old| + reltol)``.

    References
    ----------
    J.C. Nash, *Compact Numerical Methods for Computers*, 2nd ed.
    R Core Team, ``src/appl/optim.c`` (vmmin, fmingr).
    """
    # Without an analytical gradient, R's fmingr uses *central* differences
    # with step ndeps: df[i] = (f(p + eps) - f(p - eps)) / (2 * eps).
    if gr is None:
        step = DEFAULT_CONFIG.ndeps if ndeps is None else float(ndeps)

        def gr(b):
            """Central finite-difference gradient, matching R's fmingr."""
            grad = np.empty_like(b)
            for i in range(len(b)):
                b_plus = b.copy()
                b_plus[i] += step
                b_minus = b.copy()
                b_minus[i] -= step
                grad[i] = (fn(b_plus) - fn(b_minus)) / (2.0 * step)
            return grad

    # ---- constants (must match R exactly) ----
    STEPREDN = 0.2
    ACCTOL = 0.0001
    RELTEST = 10.0

    b = np.array(b0, dtype=float)  # always a fresh float copy
    n = len(b)

    if maxit <= 0:
        f = fn(b)
        return scipy.optimize.OptimizeResult(
            x=b, fun=f, nit=0, nfev=1, njev=0,
            success=True, message="maxit <= 0, returning initial value",
        )

    # R errors on REPORT <= 0; reject it here instead of hitting a
    # ZeroDivisionError in the trace modulus below.
    if nREPORT <= 0:
        raise ValueError("nREPORT must be > 0 (method = 'BFGS')")

    # Optional tqdm progress bar (soft dependency)
    pbar = None
    if show_progress:
        try:
            from tqdm import tqdm
            pbar = tqdm(total=maxit, desc="BFGS optimization", leave=False)
        except ImportError:
            pass

    # All parameters are free (mask = all True). In R, l[] maps
    # free-parameter indices; here l[i] = i, so the indirection is a no-op.

    # ---- allocate working arrays ----
    g = np.empty(n)  # gradient
    t = np.empty(n)  # search direction, later the actual step taken
    X = np.empty(n)  # saved parameters, later reused for B @ c
    c = np.empty(n)  # saved gradient, later the gradient change
    # B: BFGS inverse-Hessian approximation. R keeps only the lower triangle;
    # the full symmetric matrix is maintained here so B @ g is a plain product.
    B = np.zeros((n, n))

    # try/finally so the progress bar is closed even if fn/gr raise or the
    # initial value is not finite.
    try:
        # ---- initial evaluation ----
        f = fn(b)
        if not np.isfinite(f):
            raise ValueError(
                "initial value in vmmin is not finite. "
                "Suggestions: (1) Check for extreme covariate values, "
                "(2) Scale covariates to have similar ranges, "
                "(3) Remove covariates with very low variance, "
                "(4) Try init_params with values closer to zero."
            )
        if trace:
            print(f"initial value {f}")
        Fmin = f
        funcount = 1
        gradcount = 1
        g[:] = gr(b)
        iter_ = 1
        ilast = gradcount

        while True:
            # ---- Hessian reset when needed ----
            if ilast == gradcount:
                B[:, :] = 0.0
                np.fill_diagonal(B, 1.0)

            # ---- save current state ----
            X[:] = b
            c[:] = g

            # ---- compute search direction t = -B g ----
            t[:] = -(B @ g)
            gradproj = float(t @ g)

            if gradproj < 0.0:
                # ---- downhill: backtracking line search ----
                steplength = 1.0
                accpoint = False
                while True:
                    b[:] = X + steplength * t
                    # Parameters whose change vanishes relative to RELTEST
                    # count as unchanged.
                    count = int(np.sum(RELTEST + X == RELTEST + b))
                    if count < n:
                        f = fn(b)
                        funcount += 1
                        accpoint = (
                            np.isfinite(f)
                            and f <= Fmin + gradproj * steplength * ACCTOL
                        )
                        if not accpoint:
                            steplength *= STEPREDN
                    if count == n or accpoint:
                        break

                # Stop if the value is small or the relative change is low.
                enough = (
                    f > abstol
                    and abs(f - Fmin) > reltol * (abs(Fmin) + reltol)
                )
                if not enough:
                    count = n
                    Fmin = f

                if count < n:
                    # ---- making progress ----
                    Fmin = f
                    g[:] = gr(b)
                    gradcount += 1
                    iter_ += 1

                    # prepare for BFGS update
                    t *= steplength  # actual step
                    c[:] = g - c     # gradient change
                    D1 = float(t @ c)

                    if D1 > 0:
                        # ---- BFGS inverse-Hessian update ----
                        X[:] = B @ c
                        D2 = 1.0 + float(X @ c) / D1
                        # Rank-2 symmetric update of the full matrix.
                        B += (D2 * np.outer(t, t)
                              - np.outer(X, t)
                              - np.outer(t, X)) / D1
                    else:
                        # D1 <= 0: curvature condition violated -> reset
                        ilast = gradcount
                else:
                    # ---- no progress ----
                    if ilast < gradcount:
                        count = 0
                        ilast = gradcount
            else:
                # ---- uphill search direction ----
                count = 0
                if ilast == gradcount:
                    count = n  # already reset -> give up
                else:
                    ilast = gradcount  # reset Hessian

            if pbar is not None:
                pbar.update(1)

            if trace and (iter_ % nREPORT == 0):
                print(f"iter{iter_:4d} value {f}")

            if iter_ >= maxit:
                break

            # ---- periodic restart ----
            if gradcount - ilast > 2 * n:
                ilast = gradcount

            if count == n and ilast == gradcount:
                break
    finally:
        if pbar is not None:
            pbar.close()

    if trace:
        print(f"final value {Fmin}")
        if iter_ < maxit:
            print("converged")
        else:
            print(f"stopped after {iter_} iterations")

    success = iter_ < maxit
    return scipy.optimize.OptimizeResult(
        x=b,
        fun=Fmin,
        nit=iter_,
        nfev=funcount,
        njev=gradcount,
        success=success,
        message="converged" if success else f"stopped after {iter_} iterations",
    )
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
def _glm_init(
    treat: np.ndarray,
    X: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    gmm_loss_func: Callable
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Initialize CBPS coefficients from a logistic GLM fit plus a scale search.

    A plain (unweighted) binomial GLM of ``treat`` on ``X`` is fitted, and its
    coefficient vector is then rescaled by the factor ``alpha`` in
    [0.8, 1.1] that minimizes ``gmm_loss_func``.

    Parameters
    ----------
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Normalized sampling weights, shape (n,). Not used here; the GLM is
        fitted without weights and the weights only enter the GMM steps.
    att : int
        Estimand type: 0 for ATE, 1 for ATT. Not used here; it is kept for
        interface compatibility.
    gmm_loss_func : Callable
        GMM loss ``gmm_loss_func(beta) -> float`` used for the alpha search.

    Returns
    -------
    beta_init : np.ndarray
        Initial coefficients after GLM fitting and alpha scaling.
    beta_glm : np.ndarray
        GLM coefficients with NaN entries replaced by 0, used for computing
        the MLE J-statistic.

    Notes
    -----
    Initialization steps:

    1. Fit a binomial GLM (IRLS, at most 25 iterations) with warnings suppressed.
    2. Replace NaN coefficients with 0.
    3. Find ``alpha`` in [0.8, 1.1] that minimizes
       ``gmm_loss_func(alpha * beta_glm)`` (bounded scalar search), and
       return ``alpha * beta_glm``.
    """
    # Step 1: unweighted GLM fit; convergence warnings are deliberately silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = sm.GLM(treat, X, family=Binomial())
        glm_fit = model.fit(tol=DEFAULT_CONFIG.glm_tol, maxiter=25)  # Standard IRLS algorithm

    # Step 2: float ndarray copy of the coefficients with NaN -> 0.
    beta_glm = np.array(glm_fit.params, dtype=float)
    beta_glm[np.isnan(beta_glm)] = 0.0

    # Step 3: 1-D bounded search for the best scaling factor.
    def alpha_loss(alpha):
        return gmm_loss_func(beta_glm * alpha)

    alpha_opt = scipy.optimize.minimize_scalar(
        alpha_loss, bounds=(0.8, 1.1), method='bounded'
    ).x
    beta_init = beta_glm * alpha_opt

    # Return: scaled coefficients and the (NaN-cleaned) GLM coefficients.
    return beta_init, beta_glm
|
|
1024
|
+
|
|
1025
|
+
|
|
1026
|
+
def _compute_moment_conditions(
    beta: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    n: int
) -> np.ndarray:
    """
    Evaluate the CBPS covariate-balance moment vector at ``beta``.

    The propensity score is ``pi = expit(X @ beta)``, clipped to
    ``[PROBS_MIN, 1 - PROBS_MIN]``. A per-unit weight ``w`` is built from the
    residual ``treat - pi``, and the moments are the sample-weighted column
    sums ``X' diag(sample_weights) w / n`` (Imai & Ratkovic 2014, JRSS-B,
    Eq. (10) for ATE and Eq. (11) for ATT).

    Parameters
    ----------
    beta : np.ndarray
        Coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment vector (0/1), shape (n,).
    sample_weights : np.ndarray
        Sample weights, shape (n,).
    att : int
        0 = ATE; 1 = ATT with T=1 treated; 2 = ATT with T=0 treated.
    n : int
        Sample size.

    Returns
    -------
    np.ndarray
        Length-k moment vector. For just-identified GMM the solution makes
        every entry (approximately) zero.
    """
    pi = np.clip(scipy.special.expit(X @ beta), PROBS_MIN, 1 - PROBS_MIN)
    resid = treat - pi

    if att == 1:
        # ATT, T = 1 treated: (n / weighted # treated) * (T - pi) / (1 - pi)
        scale = n / np.sum(treat * sample_weights)
        w = scale * resid / (1 - pi)
    elif att == 2:
        # ATT, T = 0 treated: (n / weighted # controls) * (T - pi) / pi
        scale = n / np.sum((1 - treat) * sample_weights)
        w = scale * resid / pi
    else:
        # ATE: (T - pi) / (pi * (1 - pi))
        w = resid / (pi * (1 - pi))

    weighted_X = sample_weights[:, None] * X
    return weighted_X.T @ w / n
|
|
1088
|
+
|
|
1089
|
+
|
|
1090
|
+
def _solve_moment_equations(
    beta_init: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    n: int,
    iterations: int = 1000
) -> Tuple[np.ndarray, bool, np.ndarray, str]:
    """
    Solve the k balance equations ``moments(beta) = 0`` for k unknowns.

    This is just-identified GMM (Hansen 1982) applied to the CBPS balance
    conditions of Imai & Ratkovic (2014), Eqs. (10)/(11). The root finder
    works on ``_compute_moment_conditions`` directly, rather than minimizing
    a loss.

    Parameters
    ----------
    beta_init : np.ndarray
        Starting values (e.g. GLM or balance-optimized), shape (k,).
    X, treat, sample_weights, att, n
        Passed unchanged to ``_compute_moment_conditions``.
    iterations : int, default 1000
        Scales the evaluation budgets: ``maxfev = 10 * iterations`` for
        'hybr' and ``maxiter = 5 * iterations`` for 'lm'.

    Returns
    -------
    beta_opt : np.ndarray
        The root found, or ``beta_init`` when both solvers fail.
    success : bool
        Whether a solver reported success.
    moments_final : np.ndarray
        Moment vector evaluated at ``beta_opt``.
    method : str
        'hybr', 'lm', or 'failed'.

    Notes
    -----
    Solver order:

    1. ``scipy.optimize.root`` with 'hybr' (MINPACK hybrid Powell).
    2. If that does not report success, 'lm' (Levenberg-Marquardt), also
       started from ``beta_init``. ValueError, RuntimeError and LinAlgError
       raised by this attempt are swallowed.
    3. Otherwise return ``beta_init`` with ``success=False``.
    """
    from scipy.optimize import root

    def residuals(beta):
        """The k moment conditions as a function of beta."""
        return _compute_moment_conditions(beta, X, treat, sample_weights, att, n)

    hybr_fit = root(
        residuals,
        x0=beta_init,
        method='hybr',
        options={'xtol': DEFAULT_CONFIG.optim_xtol, 'maxfev': iterations * 10},
    )
    if hybr_fit.success:
        return hybr_fit.x, True, residuals(hybr_fit.x), 'hybr'

    # Best-effort fallback: any listed error simply means "lm did not help".
    try:
        lm_fit = root(
            residuals,
            x0=beta_init,
            method='lm',
            options={'xtol': DEFAULT_CONFIG.optim_xtol, 'maxiter': iterations * 5},
        )
        if lm_fit.success:
            return lm_fit.x, True, residuals(lm_fit.x), 'lm'
    except (ValueError, RuntimeError, np.linalg.LinAlgError):
        pass

    return beta_init, False, residuals(beta_init), 'failed'
|
|
1170
|
+
|
|
1171
|
+
|
|
1172
|
+
def _optimize_balance(
    gmm_init: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    XprimeX_inv: np.ndarray,
    att: int,
    two_step: bool,
    iterations: int,
    bal_only: bool = False,
    show_progress: bool = False,
    **kwargs
) -> scipy.optimize.OptimizeResult:
    """
    Minimize the covariate-balance loss, starting from ``gmm_init``.

    The minimization is done by ``_vmmin_bfgs``, the port of R's
    ``optim(method="BFGS")``. It uses Armijo backtracking with a reltol
    stopping rule, mirroring the R CBPS package.

    Parameters
    ----------
    gmm_init : np.ndarray
        GLM-initialized coefficients, shape (k,).
    X, treat, sample_weights, XprimeX_inv, att
        Passed unchanged to ``_bal_loss`` / ``_bal_gradient``.
    two_step : bool
        If True, use the analytical ``_bal_gradient``. Otherwise ``_vmmin_bfgs``
        falls back to finite differences, as R CBPS does for
        ``twostep=FALSE``.
    iterations : int
        Maximum BFGS iterations.
    bal_only : bool
        Just-identified flag (method='exact'). Accepted but not referenced here.
    show_progress : bool
        Forwarded to ``_vmmin_bfgs``.
    **kwargs
        Only ``verbose`` is read; it is forwarded as ``trace``.

    Returns
    -------
    scipy.optimize.OptimizeResult
        Result of the balance optimization.
    """
    def balance_objective(beta):
        return _bal_loss(beta, X, treat, sample_weights, XprimeX_inv, att)

    if two_step:
        def balance_gradient(beta):
            return _bal_gradient(beta, X, treat, sample_weights, XprimeX_inv, att)
    else:
        # R only supplies gr = bal.gradient when twostep = TRUE.
        balance_gradient = None

    return _vmmin_bfgs(
        gmm_init,
        fn=balance_objective,
        gr=balance_gradient,
        maxit=iterations,
        trace=kwargs.get('verbose', False),
        show_progress=show_progress,
    )
|
|
1235
|
+
|
|
1236
|
+
|
|
1237
|
+
def _optimize_gmm_dual_init(
    gmm_init: np.ndarray,
    beta_bal: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    this_inv_V: np.ndarray,
    two_step: bool,
    iterations: int,
    show_progress: bool = False,
    **kwargs
) -> scipy.optimize.OptimizeResult:
    """
    Minimize the GMM loss from two starting points and keep the better fit.

    ``_vmmin_bfgs`` is run first from ``gmm_init`` (GLM-based), then from
    ``beta_bal`` (balance-optimized). The run with the strictly smaller
    ``fun`` wins. On ties, or when the comparison is False (e.g. a NaN
    objective), the balance-started run is returned.

    Parameters
    ----------
    gmm_init : np.ndarray
        GLM-initialized coefficients (after alpha scaling), shape (k,).
    beta_bal : np.ndarray
        Balance-optimized coefficients, shape (k,).
    X, treat, sample_weights, att
        Passed unchanged to the GMM loss/gradient helpers.
    this_inv_V : np.ndarray
        Fixed inverse weighting matrix, used only when ``two_step`` is True.
    two_step : bool
        True: two-step GMM with fixed ``this_inv_V`` and the analytical
        ``_gmm_gradient``. False: continuous updating (``inv_V=None``) with
        a finite-difference gradient.
    iterations : int
        Maximum BFGS iterations per run.
    show_progress : bool
        Forwarded to ``_vmmin_bfgs``.
    **kwargs
        Only ``verbose`` is read; it is forwarded as ``trace``.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The better of the two optimization results.
    """
    verbose = kwargs.get('verbose', False)

    if two_step:
        def objective(beta):
            return _gmm_loss(beta, X, treat, sample_weights, att, this_inv_V)

        def gradient(beta):
            return _gmm_gradient(beta, this_inv_V, X, treat, sample_weights, att)
    else:
        # _gmm_gradient holds inv_V fixed, which is only exact for two-step
        # GMM; under continuous updating V changes with beta, so (like R CBPS)
        # fall back to finite differences.
        def objective(beta):
            return _gmm_loss(beta, X, treat, sample_weights, att, None)

        gradient = None

    # GLM start first, then balance start.
    from_glm, from_bal = [
        _vmmin_bfgs(
            start,
            fn=objective,
            gr=gradient,
            maxit=iterations,
            trace=verbose,
            show_progress=show_progress,
        )
        for start in (gmm_init, beta_bal)
    ]

    return from_glm if from_glm.fun < from_bal.fun else from_bal
|
|
1337
|
+
|
|
1338
|
+
|
|
1339
|
+
def _classify_separation(
    probs_opt_raw: np.ndarray,
    beta_opt: np.ndarray,
    extreme_coef_threshold: float = 10.0
) -> Optional[Tuple[str, str]]:
    """
    Grade how many fitted propensity scores sit at the clipping bounds.

    An observation counts as "at the boundary" when its unclipped score is
    ``<= PROBS_MIN`` or ``>= 1 - PROBS_MIN``. The boundary share (in percent)
    is mapped to a severity tier: >= 100 complete, >= 50 quasi, >= 10
    moderate, otherwise minor.

    Parameters
    ----------
    probs_opt_raw : np.ndarray
        Unclipped propensity scores ``expit(X @ beta)``, shape (n,).
    beta_opt : np.ndarray
        Optimized coefficients, shape (k,). Only used to flag coefficients
        whose magnitude exceeds ``extreme_coef_threshold``.
    extreme_coef_threshold : float, default 10.0
        Magnitude above which a coefficient is reported as extreme.

    Returns
    -------
    tuple or None
        None when no observation is at a boundary. Otherwise
        ``(severity_level, warning_msg)`` with severity_level one of
        'MINOR', 'MODERATE SEPARATION', 'QUASI-SEPARATION',
        'COMPLETE SEPARATION'.
    """
    total = len(probs_opt_raw)
    low_count = np.sum(probs_opt_raw <= PROBS_MIN)
    high_count = np.sum(probs_opt_raw >= 1 - PROBS_MIN)
    boundary_count = low_count + high_count

    if boundary_count == 0:
        return None

    share = 100.0 * boundary_count / total
    any_extreme = np.any(np.abs(beta_opt) > extreme_coef_threshold)

    # (minimum share in percent, label, suggested actions), most severe first.
    tiers = (
        (100.0, "COMPLETE SEPARATION", (
            "Check for perfect predictors: examine if any covariate "
            "perfectly separates treatment groups",
            "Consider penalized estimation: use hdCBPS with LASSO regularization",
            "Remove or combine highly predictive variables",
            "Consider Firth's penalized likelihood as initialization",
            "Verify data coding: check for data entry errors in treatment variable",
        )),
        (50.0, "QUASI-SEPARATION", (
            "Check for multicollinearity: compute VIF for covariates",
            "Consider trimming: remove units with extreme propensity scores "
            "(Crump et al. 2009)",
            "Use regularized estimation (hdCBPS with LASSO)",
            "Report sensitivity analysis with different trimming thresholds",
            "Consider weight truncation at the 1st/99th percentile",
        )),
        (10.0, "MODERATE SEPARATION", (
            "Examine covariate balance after weighting",
            "Report effective sample size (ESS = (sum(w))^2 / sum(w^2))",
            "Verify stability: compare results with 'exact' vs 'over' method",
            "Consider standardize=True to reduce weight variability",
        )),
    )

    severity_level = "MINOR"
    suggestions = ("This is usually acceptable but check covariate balance.",)
    for cutoff, label, advice in tiers:
        if share >= cutoff:
            severity_level, suggestions = label, advice
            break

    counts_text = (
        f"Detected: {boundary_count} observations ({share:.1f}%) "
        "at probability boundary\n"
        f" - Low boundary (\u03c0 \u2264 {PROBS_MIN}): {low_count}\n"
        f" - High boundary (\u03c0 \u2265 {1-PROBS_MIN}): {high_count}"
    )
    extreme_text = ""
    if any_extreme:
        extreme_text = (
            f"\n - Extreme coefficients (|\u03b2| > {extreme_coef_threshold}): detected"
        )
    advice_text = "\n".join(
        f" {pos}. {item}" for pos, item in enumerate(suggestions, 1)
    )

    warning_msg = (
        f"[CBPS Separation Warning - {severity_level}]\n"
        + counts_text
        + extreme_text
        + "\n"
        + "Suggested actions:\n"
        + advice_text
    )
    return severity_level, warning_msg
|
|
1441
|
+
|
|
1442
|
+
|
|
1443
|
+
def _compute_final_weights(
    beta_opt: np.ndarray,
    X: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    standardize: bool
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Turn optimized coefficients into propensity scores and final weights.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficient vector, shape (k,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand type; any truthy value selects the ATT weight function,
        0 selects ATE.
    standardize : bool
        Forwarded to ``standardize_weights``.

    Returns
    -------
    probs_opt : np.ndarray
        Propensity scores clipped to ``[PROBS_MIN, 1 - PROBS_MIN]``, shape (n,).
    w_opt : np.ndarray
        Output of ``standardize_weights`` applied to the absolute raw weights.

    Notes
    -----
    The unclipped scores are first checked with ``_classify_separation``.
    Its message, if any, is emitted as a ``UserWarning``. Raw weights are
    ``_att_wt_func(...)`` for ATT and ``1 / (pi - 1 + T)`` for ATE. Their
    absolute values are passed to ``standardize_weights``.
    """
    raw_scores = scipy.special.expit(X @ beta_opt)
    probs_opt = np.clip(raw_scores, PROBS_MIN, 1 - PROBS_MIN)

    separation = _classify_separation(raw_scores, beta_opt)
    if separation is not None:
        # stacklevel=3 attributes the warning to the caller of this
        # function's caller.
        warnings.warn(separation[1], UserWarning, stacklevel=3)

    if att:
        signed_weights = _att_wt_func(beta_opt, X, treat, sample_weights)
    else:
        # 1/pi for treated (T=1), -1/(1-pi) for controls (T=0).
        signed_weights = 1 / (probs_opt - 1 + treat)

    w_opt = standardize_weights(
        np.abs(signed_weights), treat, probs_opt, sample_weights, att, standardize
    )
    return probs_opt, w_opt
|
|
1508
|
+
|
|
1509
|
+
|
|
1510
|
+
def _compute_diagnostics(
    beta_opt: np.ndarray,
    beta_glm: np.ndarray,
    probs_opt: np.ndarray,
    treat: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    two_step: bool,
    this_inv_V: np.ndarray,
    X: np.ndarray
) -> Tuple[float, float, float, float]:
    """
    Evaluate post-fit diagnostics: GMM J-statistics and (null) deviance.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized CBPS coefficients.
    beta_glm : np.ndarray
        Baseline coefficients (GLM starting fit or user-supplied warm start)
        at which the reference J-statistic is evaluated.
    probs_opt : np.ndarray
        Fitted propensity scores corresponding to ``beta_opt``, shape (n,).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Estimand code forwarded to ``_gmm_func``.
    two_step : bool
        If True, both GMM losses are evaluated with the fixed weight matrix
        ``this_inv_V``; otherwise ``inv_V=None`` is passed so that
        ``_gmm_func`` builds its own (continuous-updating) weight matrix.
    this_inv_V : np.ndarray
        Pre-computed inverse moment covariance used in the two-step case.
    X : np.ndarray
        Covariate matrix, shape (n, k).

    Returns
    -------
    J_opt : float
        GMM loss (``_gmm_func(...)['loss']``) at ``beta_opt``.
    mle_J : float
        GMM loss at ``beta_glm``, a baseline for comparison.
    deviance : float
        -2 times the sample-weighted Bernoulli log-likelihood at ``probs_opt``.
    nulldeviance : float
        Same quantity for the intercept-only model whose fitted probability
        is the weighted treatment share (clipped to [1e-10, 1 - 1e-10]).

    Notes
    -----
    The J-statistic can be used to test the over-identifying restrictions
    in the GMM framework. Under the null hypothesis of correct specification,
    J ~ chi-squared with degrees of freedom equal to the number of
    over-identifying restrictions.
    """
    # One weight-matrix choice serves both J evaluations.
    weight_matrix = this_inv_V if two_step else None

    J_opt = _gmm_func(beta_opt, X, treat, sample_weights, att, inv_V=weight_matrix)['loss']
    mle_J = _gmm_func(beta_glm, X, treat, sample_weights, att, inv_V=weight_matrix)['loss']

    # Per-observation weights for the treated / control log-likelihood terms.
    treated_wt = treat * sample_weights
    control_wt = (1 - treat) * sample_weights

    deviance = -2 * np.sum(
        treated_wt * np.log(probs_opt) + control_wt * np.log(1 - probs_opt)
    )

    # Intercept-only model: constant probability equal to the weighted mean
    # treatment share, kept away from 0/1 so the logs stay finite.
    null_prob = np.clip(np.average(treat, weights=sample_weights), 1e-10, 1 - 1e-10)
    nulldeviance = -2 * np.sum(
        treated_wt * np.log(null_prob) + control_wt * np.log(1 - null_prob)
    )

    return J_opt, mle_J, deviance, nulldeviance
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
def _compute_vcov(
    beta_opt: np.ndarray,
    probs_opt: np.ndarray,
    treat: np.ndarray,
    X: np.ndarray,
    sample_weights: np.ndarray,
    att: int,
    bal_only: bool,
    XprimeX_inv: np.ndarray,
    this_inv_V: np.ndarray,
    two_step: bool,
    n: int
) -> np.ndarray:
    """
    Compute sandwich variance-covariance matrix.

    Parameters
    ----------
    beta_opt : np.ndarray
        Optimized coefficients, shape (k,).
    probs_opt : np.ndarray
        Fitted propensity scores, shape (n,).
    treat : np.ndarray
        Binary treatment indicator, shape (n,).
    X : np.ndarray
        Covariate matrix, shape (n, k).
    sample_weights : np.ndarray
        Sampling weights, shape (n,).
    att : int
        Truthy selects the ATT balance moments, 0 selects the ATE ones.
    bal_only : bool
        True for method='exact' (balance moments only), False for
        method='over' (score + balance moments).
    XprimeX_inv : np.ndarray
        Weight matrix used when ``bal_only`` is True, shape (k, k).
    this_inv_V : np.ndarray
        Pre-computed GMM weight matrix used for two-step, over-identified fits.
    two_step : bool
        For method='over': reuse ``this_inv_V`` (True) or recompute the
        weight matrix at ``beta_opt`` via ``_gmm_func`` (False).
    n : int
        Number of observations used to average G and Omega.

    Returns
    -------
    np.ndarray
        Coefficient variance-covariance matrix, shape (k, k).

    Notes
    -----
    Implements the sandwich estimator (Newey & McFadden 1994, Eq. 6.17).
    With G laid out as (k, m) (parameters by moments), the code computes

        Var(beta_hat) = (G W G')^+ G W Omega W' G' (G W G')^+'

    where ``^+`` is the Moore-Penrose pseudoinverse from ``_r_ginv``.
    G and Omega are averages over ``n``; no further division by ``n``
    happens here.

    Every moment contribution in W1 (and hence in Omega = W1 W1' / n) is
    scaled by sqrt(sample_weights), so all blocks of Omega carry the same
    weighting. Previously the ATT balance block used the full sample weights,
    which mixed weight powers across blocks whenever sampling weights were
    non-uniform.
    """
    n_t = np.sum(sample_weights[treat == 1])

    sw = sample_weights[:, None]
    sw_sqrt = np.sqrt(sample_weights)[:, None]

    # Score condition components (shared by ATT/ATE)
    XG_1 = -X * (probs_opt * (1 - probs_opt))[:, None] * sw
    XW_1 = X * (treat - probs_opt)[:, None] * sw_sqrt

    # Balance condition components (ATT/ATE branches)
    if att:
        # sqrt weighting here matches XW_1 and the ATE branch below.
        XW_2 = X * _att_wt_func(beta_opt, X, treat, sample_weights)[:, None] * sw_sqrt
        dw2 = -n / n_t * probs_opt / (1 - probs_opt)
        dw2[treat == 1] = 0  # Treated units' ATT weight does not depend on beta
        XG_2 = X * dw2[:, None] * sw
    else:
        XW_2 = X * (1 / (probs_opt - 1 + treat))[:, None] * sw_sqrt
        XG_2 = -X * ((treat - probs_opt)**2 / (probs_opt * (1 - probs_opt)))[:, None] * sw

    # Assemble G, W1 and the GMM weight matrix W for the identification mode
    if bal_only:  # method='exact': balance conditions only
        G = (XG_2.T @ X) / n
        W1 = XW_2.T
        W = XprimeX_inv
    else:  # method='over': score + balance conditions
        G = np.hstack([(XG_1.T @ X), (XG_2.T @ X)]) / n
        W1 = np.vstack([XW_1.T, XW_2.T])
        if two_step:
            W = this_inv_V  # Reuse precomputed
        else:
            W = _gmm_func(beta_opt, X, treat, sample_weights, att, inv_V=None)['inv_V']

    # Sandwich formula
    Omega = (W1 @ W1.T) / n  # Moment condition covariance
    GWGinv = _r_ginv(G @ W @ G.T)  # Moore-Penrose pseudoinverse
    GWGinvGW = GWGinv @ G @ W
    vcov = GWGinvGW @ Omega @ GWGinvGW.T

    return vcov
|
|
1646
|
+
|
|
1647
|
+
|
|
1648
|
+
def cbps_binary_fit(
    treat: np.ndarray,
    X: np.ndarray,
    att: int = 1,
    method: str = 'over',
    two_step: bool = True,
    standardize: bool = True,
    sample_weights: Optional[np.ndarray] = None,
    iterations: int = 1000,
    XprimeX_inv: Optional[np.ndarray] = None,
    verbose: int = 0,
    init_params: Optional[np.ndarray] = None,
    show_progress: bool = False,
    **kwargs
) -> Dict:
    """
    Estimate covariate balancing propensity scores for binary treatments.

    Implements the covariate balancing propensity score (CBPS) methodology
    for binary treatment assignments using generalized method of moments
    (GMM) estimation. The function simultaneously optimizes covariate balance
    and prediction of treatment assignment.

    Parameters
    ----------
    treat : np.ndarray
        Binary treatment indicator vector coded as 0/1, shape (n,).
    X : np.ndarray
        Covariate matrix including intercept column, shape (n, k).
        The intercept should be the first column. Sparse input is densified.
    att : int, default 1
        Target estimand for estimation (normalized by ``_normalize_att``,
        which per the in-code note also accepts 'ate', 'att', 'atc'):
        - 0: Average treatment effect (ATE)
        - 1: Average treatment effect on the treated (ATT) with treatment=1
        - 2: Average treatment effect on the treated (ATT) with treatment=0
    method : {'over', 'exact'}, default 'over'
        GMM estimation method:
        - 'over': Over-identified GMM combining likelihood and balance conditions
        - 'exact': Exactly-identified GMM using balance conditions only
        Any value other than 'exact' is treated as 'over'.
    two_step : bool, default True
        GMM estimator type:
        - True: Two-step GMM with pre-computed weight matrix (faster)
        - False: Continuous-updating GMM with iterative weight updates
    standardize : bool, default True
        Weight standardization:
        - True: Weights sum to 1 within each treatment group
        - False: Return Horvitz-Thompson weights
    sample_weights : np.ndarray, optional
        Sampling weights for observations. If None, defaults to equal weights.
    iterations : int, default 1000
        Maximum number of iterations for the optimization algorithm.
    XprimeX_inv : np.ndarray, optional
        Pre-computed inverse of X'X matrix for balance loss computation.
        If None, the pseudoinverse of the sample-weighted X'X is computed.
    verbose : int, default 0
        1 or 2 raises the package verbosity via ``set_verbosity``;
        0 leaves the current setting untouched.
    init_params : np.ndarray, optional
        Initial parameter values for warm start. If provided, skips GLM
        initialization and uses these values directly (they also serve as
        the baseline for ``mle_J``). Length must equal the number of
        columns in X.
    show_progress : bool, default False
        Forwarded to the balance and GMM optimizers.
    theoretical_exact : bool, default False (passed via **kwargs)
        Only applicable when method='exact':
        - True: Direct equation solver for moment conditions (precision ~1e-15)
        - False: Balance loss optimization (R-compatible, precision ~1e-6)
    **kwargs
        Additional arguments passed to scipy.optimize.minimize.

    Returns
    -------
    dict
        Fitted CBPS object containing:
        - coefficients: Estimated propensity score coefficients, shape (k, 1)
        - fitted_values: Estimated propensity scores, shape (n,)
        - linear_predictor: Linear predictor values (X @ coefficients)
        - deviance: Model deviance
        - nulldeviance: Intercept-only model deviance
        - weights: CBPS weights for causal effect estimation, shape (n,)
        - y: Treatment indicator vector
        - x: Covariate matrix
        - converged: Boolean convergence indicator
        - J: J-statistic for overidentification test
        - var: Asymptotic variance-covariance matrix, shape (k, k)
        - mle_J: Maximum likelihood J-statistic

    Raises
    ------
    ValueError
        If X is not of full column rank, or if ``init_params`` does not
        have one entry per column of X. Errors raised by
        ``validate_cbps_input`` propagate unchanged.

    Warns
    -----
    UserWarning
        For method='exact', when the moment conditions are not satisfied to
        the internal tolerance or when the equation solver fails.

    Notes
    -----
    The algorithm implements the following key steps:
    1. Initial MLE estimation for starting values
    2. Balance loss optimization for initial GMM values
    3. GMM optimization to satisfy both score and balance conditions
    4. Final weight computation and diagnostics

    For ATT estimation, weights are constructed to balance covariates between
    the treated group and the weighted control group. For ATE estimation,
    weights balance all groups simultaneously.

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    https://doi.org/10.1111/rssb.12027
    """
    # Moment-satisfaction tolerances for method='exact'. They are shared by
    # the checks and the warning texts so the two cannot disagree.
    solver_moment_tol = 1e-8   # theoretical_exact=True (equation solver)
    balance_moment_tol = 1e-6  # theoretical_exact=False (balance loss)

    # Ensure dense matrix (sparse input auto-converted)
    X = ensure_dense(X)
    treat = np.asarray(treat, dtype=float).ravel()

    # Input validation: NaN/Inf check (before any computation)
    validate_cbps_input(
        treat, X,
        min_observations=2,
        module_name="Binary CBPS",
        check_treatment_variance=False
    )

    # Normalize att parameter (support string: 'ate', 'att', 'atc')
    att = _normalize_att(att)

    n = len(treat)
    bal_only = (method == 'exact')

    # Note: SVD preprocessing is applied in the CBPS() main function before
    # calling this function, matching R package's CBPSMain.R behavior.
    # The X passed here may already be SVD-transformed (U matrix).
    # X is never modified in place, so it is returned as-is under 'x'.

    # Full rank check
    k = np.linalg.matrix_rank(X)
    if k < X.shape[1]:
        raise ValueError(
            f"X is not full rank: rank={k} < ncol={X.shape[1]}. "
            f"Suggestions: (1) Remove collinear variables, "
            f"(2) Check for duplicate columns, "
            f"(3) Use hdCBPS for automatic variable selection."
        )

    # Step 1: Normalize sample weights
    sample_weights = normalize_sample_weights(sample_weights, n)

    # Weighted (X'X)^+ used by the balance loss
    if XprimeX_inv is None:
        sw_sqrt_X = np.sqrt(sample_weights)[:, None] * X
        XprimeX_inv = _r_ginv(sw_sqrt_X.T @ sw_sqrt_X)

    # Step 2: GLM initialization (or warm start)
    if init_params is not None:
        init_params = np.asarray(init_params, dtype=float)
        if len(init_params) != X.shape[1]:
            raise ValueError(
                f"init_params length {len(init_params)} != {X.shape[1]} covariates. "
                f"Ensure init_params matches the number of columns in the design matrix."
            )
        # init_params is never modified in-place downstream;
        # _vmmin_bfgs copies internally, _compute_diagnostics is read-only
        beta_init = init_params
        beta_glm = init_params
    else:
        def gmm_loss_func_for_init(b):
            # Continuous-updating GMM loss (inv_V=None) used by _glm_init
            return _gmm_loss(b, X, treat, sample_weights, att, None)

        beta_init, beta_glm = _glm_init(
            treat, X, sample_weights, att, gmm_loss_func_for_init
        )

    # Step 3: Pre-compute inverse covariance matrix (for two-step GMM)
    gmm_init = beta_init
    this_inv_V = _gmm_func(gmm_init, X, treat, sample_weights, att, inv_V=None)['inv_V']

    # Configure logging from verbose parameter (backward compatibility)
    if verbose >= 2:
        set_verbosity(2)
    elif verbose >= 1:
        set_verbosity(1)

    # Step 4: Balance loss optimization for initial values
    logger.info("Starting balance optimization (max_iter=%s)...", iterations)

    # NOTE(review): 'theoretical_exact' is not removed from kwargs, so it is
    # also forwarded to the optimizers below; confirm they tolerate it.
    opt_bal = _optimize_balance(
        gmm_init, X, treat, sample_weights, XprimeX_inv, att,
        two_step, iterations, bal_only=bal_only, show_progress=show_progress, **kwargs
    )

    logger.info(
        "Balance optimization complete: loss=%.6f, converged=%s",
        opt_bal.fun, opt_bal.success,
    )
    beta_bal = opt_bal.x  # Extract balance-optimized coefficients

    # Step 5: GMM optimization (for method='over') or exact moment solving
    if bal_only:
        # For just-identified GMM, user can choose:
        # - theoretical_exact=True: Use equation solver (precision ~1e-15)
        # - theoretical_exact=False: Use balance loss (R-compatible, precision ~1e-6)
        use_theoretical_exact = kwargs.get('theoretical_exact', False)

        if use_theoretical_exact:
            # Direct moment equation solving, started from the balance optimum
            beta_opt, root_success, moments_final, solver_method = _solve_moment_equations(
                beta_bal, X, treat, sample_weights, att, n, iterations
            )
            max_moment = np.max(np.abs(moments_final))

            if not root_success:
                # Equation solver failed, fall back to balance optimization
                warnings.warn(
                    f"theoretical_exact=True: Equation solver failed ({solver_method}), "
                    "falling back to balance loss optimization result.",
                    UserWarning
                )
                beta_opt = beta_bal
            elif not max_moment < solver_moment_tol:
                # Written as `not <` so a NaN moment still triggers the warning.
                warnings.warn(
                    f"theoretical_exact=True: Equation solver converged but moment={max_moment:.2e}, "
                    f"above the tolerance {solver_moment_tol:.0e}. "
                    "Consider better variable preprocessing.",
                    UserWarning
                )

            # Reuse the balance result object for interface compatibility
            opt1 = opt_bal
            opt1.x = beta_opt
        else:
            # R-compatible implementation: balance loss optimization
            opt1 = opt_bal

            # Check moment convergence
            moments_final = _compute_moment_conditions(
                opt1.x, X, treat, sample_weights, att, n
            )
            max_moment = np.max(np.abs(moments_final))

            # Note: For method='exact', the J-statistic is computed using
            # over-identified GMM conditions (score + balance). This means
            # J > 0 even for just-identified models, reflecting the degree to
            # which score conditions are violated.
            if max_moment > balance_moment_tol:
                warnings.warn(
                    f"method='exact': Moment conditions converged to {max_moment:.2e}, "
                    f"above the tolerance {balance_moment_tol:.0e} "
                    "(theoretical GMM precision is <1e-10). This is a known limitation "
                    "of balance loss optimization.\n"
                    "For exact moment=0 satisfaction (~1e-15 precision), "
                    "use theoretical_exact=True in CBPS() call.",
                    UserWarning
                )
    else:
        logger.info("Starting GMM optimization with dual initialization...")

        opt1 = _optimize_gmm_dual_init(
            gmm_init, beta_bal, X, treat, sample_weights, att,
            this_inv_V, two_step, iterations, show_progress=show_progress, **kwargs
        )

        logger.info(
            "GMM optimization complete: J=%.6f, converged=%s",
            opt1.fun, opt1.success,
        )

    # Step 6: Final probabilities and weights
    beta_opt = opt1.x
    probs_opt, w_opt = _compute_final_weights(
        beta_opt, X, treat, sample_weights, att, standardize
    )

    # Step 7: Compute J-statistic, deviance, and null deviance.
    # For method='exact' J uses the over-identified conditions: theoretically
    # J is 0 for just-identified models, but the full conditions provide a
    # useful diagnostic.
    J_opt, mle_J, deviance, nulldeviance = _compute_diagnostics(
        beta_opt, beta_glm, probs_opt, treat, sample_weights,
        att, two_step, this_inv_V, X
    )

    # Step 8: Variance-covariance matrix
    vcov = _compute_vcov(
        beta_opt, probs_opt, treat, X, sample_weights, att,
        bal_only, XprimeX_inv, this_inv_V, two_step, n
    )

    # Step 9: Construct return dictionary
    return {
        'coefficients': beta_opt.reshape(-1, 1),  # (k, 1) column vector
        'fitted_values': probs_opt,
        'linear_predictor': X @ beta_opt,
        'deviance': deviance,
        'nulldeviance': nulldeviance,
        'weights': w_opt,
        'y': treat,
        'x': X,
        'converged': opt1.success,
        'J': J_opt,
        'var': vcov,
        'mle_J': mle_J
    }
|
|
1943
|
+
|