cbps 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbps/__init__.py +3462 -0
- cbps/constants.py +46 -0
- cbps/core/__init__.py +93 -0
- cbps/core/cbps_binary.py +1943 -0
- cbps/core/cbps_continuous.py +945 -0
- cbps/core/cbps_multitreat.py +1123 -0
- cbps/core/cbps_optimal.py +507 -0
- cbps/core/results.py +1447 -0
- cbps/data/Blackwell.csv +571 -0
- cbps/data/LaLonde.csv +3213 -0
- cbps/data/npcbps_continuous_sim.csv +501 -0
- cbps/data/nsw.csv +723 -0
- cbps/data/nsw_dw.csv +446 -0
- cbps/data/political_ads_urban_niebler.csv +16266 -0
- cbps/data/psid_controls.csv +2491 -0
- cbps/data/psid_controls2.csv +254 -0
- cbps/data/psid_controls3.csv +129 -0
- cbps/data/simulation_dgp1_seed12345.csv +201 -0
- cbps/data/simulation_dgp2_seed12345.csv +201 -0
- cbps/data/simulation_dgp3_seed12345.csv +201 -0
- cbps/data/simulation_dgp4_seed12345.csv +201 -0
- cbps/datasets/__init__.py +78 -0
- cbps/datasets/blackwell.py +112 -0
- cbps/datasets/continuous.py +223 -0
- cbps/datasets/lalonde.py +272 -0
- cbps/datasets/npcbps_sim.py +101 -0
- cbps/diagnostics/__init__.py +101 -0
- cbps/diagnostics/balance.py +760 -0
- cbps/diagnostics/balance_cbmsm_addon.py +162 -0
- cbps/diagnostics/continuous_diagnostics.py +259 -0
- cbps/diagnostics/normality.py +173 -0
- cbps/diagnostics/ocbps_conditions.py +197 -0
- cbps/diagnostics/overlap.py +198 -0
- cbps/diagnostics/plots.py +1193 -0
- cbps/diagnostics/weights_diag.py +205 -0
- cbps/highdim/__init__.py +84 -0
- cbps/highdim/gmm_loss.py +340 -0
- cbps/highdim/hdcbps.py +1078 -0
- cbps/highdim/lasso_utils.py +498 -0
- cbps/highdim/weight_funcs.py +298 -0
- cbps/inference/__init__.py +42 -0
- cbps/inference/asyvar.py +621 -0
- cbps/inference/vcov_outcome.py +217 -0
- cbps/iv/__init__.py +48 -0
- cbps/iv/cbiv.py +2603 -0
- cbps/logging_config.py +45 -0
- cbps/msm/__init__.py +45 -0
- cbps/msm/cbmsm.py +1871 -0
- cbps/msm/rank_diagnostics.py +112 -0
- cbps/nonparametric/__init__.py +58 -0
- cbps/nonparametric/cholesky_whitening.py +232 -0
- cbps/nonparametric/empirical_likelihood.py +339 -0
- cbps/nonparametric/npcbps.py +1036 -0
- cbps/nonparametric/taylor_approx.py +207 -0
- cbps/py.typed +0 -0
- cbps/sklearn/__init__.py +42 -0
- cbps/sklearn/estimator.py +378 -0
- cbps/utils/__init__.py +82 -0
- cbps/utils/formula.py +415 -0
- cbps/utils/helpers.py +378 -0
- cbps/utils/numerics.py +438 -0
- cbps/utils/r_compat.py +109 -0
- cbps/utils/validation.py +224 -0
- cbps/utils/variance_transform.py +483 -0
- cbps/utils/weights.py +586 -0
- cbps-0.2.0.dist-info/METADATA +1090 -0
- cbps-0.2.0.dist-info/RECORD +70 -0
- cbps-0.2.0.dist-info/WHEEL +5 -0
- cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
- cbps-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,945 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Covariate Balancing Propensity Score for Continuous Treatments
|
|
3
|
+
===============================================================
|
|
4
|
+
|
|
5
|
+
This module implements the Covariate Balancing Propensity Score (CBPS) methodology
|
|
6
|
+
for continuous treatments using generalized propensity scores (GPS). The implementation
|
|
7
|
+
extends the binary CBPS framework to handle continuous treatment variables through
|
|
8
|
+
covariate whitening and normal density estimation.
|
|
9
|
+
|
|
10
|
+
Methodology
|
|
11
|
+
-----------
|
|
12
|
+
The continuous CBPS estimates the generalized propensity score by maximizing the
|
|
13
|
+
covariate balance. The method involves:
|
|
14
|
+
1. Cholesky whitening of covariates with sample weights.
|
|
15
|
+
2. Log-space normal density computation for numerical stability.
|
|
16
|
+
3. GMM optimization with multiple starting values.
|
|
17
|
+
4. Coefficient inverse transformation from whitened to original space.
|
|
18
|
+
|
|
19
|
+
References
|
|
20
|
+
----------
|
|
21
|
+
Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity score
|
|
22
|
+
for a continuous treatment: Application to the efficacy of political advertisements.
|
|
23
|
+
The Annals of Applied Statistics, 12(1), 156-177.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from typing import Dict, Any, Optional
|
|
27
|
+
import warnings
|
|
28
|
+
import numpy as np
|
|
29
|
+
import scipy.stats
|
|
30
|
+
import scipy.optimize
|
|
31
|
+
import scipy.linalg
|
|
32
|
+
import statsmodels.api as sm
|
|
33
|
+
|
|
34
|
+
from cbps.utils.validation import validate_cbps_input
|
|
35
|
+
from cbps.utils.validation import ensure_dense
|
|
36
|
+
from cbps.logging_config import logger, set_verbosity
|
|
37
|
+
from cbps.constants import DEFAULT_CONFIG
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# ========== Numerical constants ==========
# Values taken from the package-wide DEFAULT_CONFIG (NumericalConfig):
#   PROBS_MIN           - lower bound (and 1 - PROBS_MIN upper bound) used when
#                         clipping normal densities / log-densities.
#   CONST_COL_THRESHOLD - covariate columns whose sample std is <= this value are
#                         treated as constant and excluded from whitening.
#   CLIP_RANGE          - symmetric bound applied to the log density ratio
#                         before it is exponentiated into a weight.
PROBS_MIN, CONST_COL_THRESHOLD, CLIP_RANGE = (
    DEFAULT_CONFIG.probs_min,
    DEFAULT_CONFIG.const_col_threshold,
    DEFAULT_CONFIG.log_clip_range,
)

# Module-local constant (NOT taken from DEFAULT_CONFIG): search interval for the
# scalar factor that rescales the initial [beta, log(sigma^2)] vector before
# the main optimization.
ALPHA_BOUNDS = (0.8, 1.1)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def cbps_continuous_fit(
    treat: np.ndarray,
    X: np.ndarray,
    method: str = 'over',
    two_step: bool = True,
    iterations: int = 1000,
    standardize: bool = True,
    sample_weights: Optional[np.ndarray] = None,
    verbose: int = 0
) -> Dict[str, Any]:
    """
    Fit the Covariate Balancing Propensity Score model for continuous treatments.

    Parameters
    ----------
    treat : np.ndarray
        Continuous treatment vector, shape (n,).
    X : np.ndarray
        Covariate matrix (including intercept column), shape (n, k).
    method : {'over', 'exact'}, default='over'
        Estimation method:
        - 'over': Over-identified GMM (score + balance + sigma conditions).
        - 'exact': Exactly identified GMM (balance + sigma conditions).
    two_step : bool, default=True
        If True, use two-step GMM with fixed weight matrix.
        If False, use continuously updating GMM.
    iterations : int, default=1000
        Maximum number of optimization iterations.
    standardize : bool, default=True
        If True, standardize weights to sum to the sample size.
    sample_weights : np.ndarray, optional
        Sampling weights. Defaults to uniform weights if None.
        Weights will be normalized to sum to n.
    verbose : int, default=0
        Verbosity level.

    Returns
    -------
    dict
        Dictionary containing estimation results:
        - coefficients: Estimated parameters.
        - fitted_values: Estimated propensity scores.
        - weights: Inverse probability weights.
        - deviance: Model deviance.
        - converged: Convergence status.
        - J: GMM loss function value.
        - var: Variance-covariance matrix.
        - sigmasq: Estimated residual variance.
        - Ttilde: Standardized treatment.
        - Xtilde: Whitened covariates.

    Raises
    ------
    ValueError
        Re-raised from the underlying fit when ``method='over'`` fails for a
        reason other than a non-finite / overflowing GMM weighting matrix, or
        when the fit (including the ``'exact'`` fallback) fails for any other
        method. Errors raised by input validation propagate unchanged.

    Warns
    -----
    UserWarning
        When ``method='over'`` hits a non-finite weighting matrix and the fit
        is automatically retried with ``method='exact'``.

    Notes
    -----
    The algorithm performs Cholesky whitening on covariates, standardizes the treatment,
    and then optimizes the GMM objective function. It handles potential numerical
    instability in the weight matrix calculation through regularization when necessary.

    For ``method='over'``, a weighting matrix V that is infinite (or would
    overflow while being built) triggers an automatic refit with
    ``method='exact'``, which does not use V.
    """
    # Input validation
    X = ensure_dense(X)
    validate_cbps_input(
        treat, X,
        min_observations=2,
        module_name="Continuous CBPS",
        check_treatment_variance=True
    )

    # Arguments shared by every call into the implementation.
    fit_kwargs = dict(
        two_step=two_step,
        iterations=iterations,
        standardize=standardize,
        sample_weights=sample_weights,
        verbose=verbose,
    )

    if method != 'over':
        return _cbps_continuous_fit_impl(treat, X, method=method, **fit_kwargs)

    # Auto-fallback: if method='over' encounters an infinite V matrix,
    # fall back to method='exact' (matching R CBPS behavior).
    # The implementation signals the condition with differently worded
    # ValueErrors depending on where it is detected: an explicit overflow
    # pre-check before exponentiating, a non-finite scaling-vector check, or
    # the final inf check on V. All three mean "V is unusable", so all three
    # must trigger the fallback (matching is done on the lowered message).
    v_failure_markers = (
        "infinite value in the weighting matrix",
        "potential overflow in v matrix calculation",
        "v22 scaling vector contains non-finite values",
    )

    try:
        return _cbps_continuous_fit_impl(treat, X, method='over', **fit_kwargs)
    except ValueError as e:
        message = str(e).lower()
        if not any(marker in message for marker in v_failure_markers):
            raise
        warnings.warn(
            "Over-identified GMM failed due to infinite V matrix values. "
            'Automatically falling back to method="exact" '
            f"(just-identified). Original error: {e}",
            UserWarning,
            stacklevel=2,
        )

    # Refit outside the except clause so a failure here is reported on its
    # own rather than as raised "during handling" of the first error.
    return _cbps_continuous_fit_impl(treat, X, method='exact', **fit_kwargs)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _cbps_continuous_fit_impl(
|
|
145
|
+
treat: np.ndarray,
|
|
146
|
+
X: np.ndarray,
|
|
147
|
+
method: str = 'over',
|
|
148
|
+
two_step: bool = True,
|
|
149
|
+
iterations: int = 1000,
|
|
150
|
+
standardize: bool = True,
|
|
151
|
+
sample_weights: Optional[np.ndarray] = None,
|
|
152
|
+
verbose: int = 0
|
|
153
|
+
) -> Dict[str, Any]:
|
|
154
|
+
"""Internal implementation of cbps_continuous_fit."""
|
|
155
|
+
|
|
156
|
+
# Configure logging from verbose parameter (backward compatibility)
|
|
157
|
+
if verbose >= 2:
|
|
158
|
+
set_verbosity(2)
|
|
159
|
+
elif verbose >= 1:
|
|
160
|
+
set_verbosity(1)
|
|
161
|
+
|
|
162
|
+
# Initialization
|
|
163
|
+
n = len(treat)
|
|
164
|
+
k = X.shape[1]
|
|
165
|
+
bal_only = (method == 'exact')
|
|
166
|
+
|
|
167
|
+
# Normalize sample weights
|
|
168
|
+
if sample_weights is None:
|
|
169
|
+
sample_weights = np.ones(n)
|
|
170
|
+
sample_weights = sample_weights / sample_weights.mean()
|
|
171
|
+
if not np.isclose(sample_weights.sum(), n, atol=1e-10):
|
|
172
|
+
warnings.warn(f"Sample weights normalization check failed: sum={sample_weights.sum():.6f} != n={n}")
|
|
173
|
+
|
|
174
|
+
# Save original X
|
|
175
|
+
X_orig = X.copy()
|
|
176
|
+
|
|
177
|
+
# ========== Covariate Whitening Preprocessing ==========
|
|
178
|
+
|
|
179
|
+
# Detect constant columns
|
|
180
|
+
col_std = np.std(X, axis=0, ddof=1)
|
|
181
|
+
int_ind = np.where(col_std <= CONST_COL_THRESHOLD)[0]
|
|
182
|
+
non_const_ind = np.where(col_std > CONST_COL_THRESHOLD)[0]
|
|
183
|
+
|
|
184
|
+
if len(non_const_ind) == 0:
|
|
185
|
+
warnings.warn(
|
|
186
|
+
"All columns are constant (sd <= 1e-10). "
|
|
187
|
+
"Continuous CBPS will degenerate to no-covariate model. "
|
|
188
|
+
"This is a valid edge case where the model only standardizes the treatment distribution.",
|
|
189
|
+
UserWarning
|
|
190
|
+
)
|
|
191
|
+
# Degenerate case: Xtilde is just X
|
|
192
|
+
Xtilde = X.copy()
|
|
193
|
+
else:
|
|
194
|
+
# Perform Cholesky whitening on non-constant columns
|
|
195
|
+
X_non_const = X[:, non_const_ind]
|
|
196
|
+
sw_X_non_const = sample_weights[:, None] * X_non_const
|
|
197
|
+
cov_weighted = np.cov(sw_X_non_const.T, ddof=1)
|
|
198
|
+
|
|
199
|
+
assert np.allclose(cov_weighted, cov_weighted.T, atol=1e-12), \
|
|
200
|
+
"Weighted covariance matrix must be symmetric"
|
|
201
|
+
|
|
202
|
+
# Cholesky decomposition to get upper triangular U
|
|
203
|
+
U = scipy.linalg.cholesky(cov_weighted, lower=False)
|
|
204
|
+
|
|
205
|
+
assert np.allclose(np.tril(U, k=-1), 0, atol=1e-12), "U must be upper triangular"
|
|
206
|
+
assert np.all(np.diag(U) > 0), "Diagonal elements of U must be positive"
|
|
207
|
+
|
|
208
|
+
U_inv = np.linalg.inv(U)
|
|
209
|
+
|
|
210
|
+
# Whitening transformation
|
|
211
|
+
X_white = sw_X_non_const @ U_inv
|
|
212
|
+
|
|
213
|
+
# Centering (no scaling)
|
|
214
|
+
X_white_centered = X_white - X_white.mean(axis=0)
|
|
215
|
+
|
|
216
|
+
assert abs(X_white_centered.mean()) < 1e-10, "Whitened data should be centered"
|
|
217
|
+
|
|
218
|
+
# Combine constant and whitened columns
|
|
219
|
+
if len(int_ind) > 0:
|
|
220
|
+
X_const = X[:, int_ind]
|
|
221
|
+
Xtilde = np.column_stack([X_const, X_white_centered])
|
|
222
|
+
else:
|
|
223
|
+
Xtilde = X_white_centered
|
|
224
|
+
|
|
225
|
+
# Verify shape consistency
|
|
226
|
+
if Xtilde.shape != X.shape:
|
|
227
|
+
raise ValueError(f"Xtilde shape {Xtilde.shape} != X shape {X.shape}")
|
|
228
|
+
|
|
229
|
+
# ========== Auxiliary Matrix Calculation ==========
|
|
230
|
+
|
|
231
|
+
# Pre-compute weighted Xtilde
|
|
232
|
+
wtXilde = sample_weights[:, None] * Xtilde
|
|
233
|
+
|
|
234
|
+
# Standardize treatment (zero mean, unit variance)
|
|
235
|
+
sw_treat = sample_weights * treat
|
|
236
|
+
Ttilde = (sw_treat - sw_treat.mean()) / sw_treat.std(ddof=1)
|
|
237
|
+
|
|
238
|
+
# Internal consistency checks
|
|
239
|
+
assert abs(Ttilde.mean()) < 1e-10
|
|
240
|
+
assert abs(Ttilde.std(ddof=1) - 1) < 1e-10
|
|
241
|
+
|
|
242
|
+
n_identity_vec = np.ones((n, 1))
|
|
243
|
+
|
|
244
|
+
# ========== Stabilizers Calculation ==========
|
|
245
|
+
# Calculate log marginal density log f(T*)
|
|
246
|
+
# Ideally constant, but computed per observation for robustness
|
|
247
|
+
|
|
248
|
+
pdf_vals = scipy.stats.norm.pdf(Ttilde, 0, 1)
|
|
249
|
+
pdf_clipped = np.clip(pdf_vals, PROBS_MIN, 1 - PROBS_MIN)
|
|
250
|
+
stabilizers = np.log(pdf_clipped)
|
|
251
|
+
|
|
252
|
+
# ========== GMM Objective Function ==========
|
|
253
|
+
|
|
254
|
+
def gmm_func(params_curr: np.ndarray, invV: Optional[np.ndarray] = None) -> Dict[str, Any]:
|
|
255
|
+
"""
|
|
256
|
+
GMM objective function for over-identified case.
|
|
257
|
+
|
|
258
|
+
Parameters
|
|
259
|
+
----------
|
|
260
|
+
params_curr : np.ndarray
|
|
261
|
+
Parameter vector [beta, log(sigma^2)].
|
|
262
|
+
invV : np.ndarray, optional
|
|
263
|
+
Inverse weight matrix V.
|
|
264
|
+
|
|
265
|
+
Returns
|
|
266
|
+
-------
|
|
267
|
+
dict
|
|
268
|
+
Dictionary containing loss value and inverse V matrix.
|
|
269
|
+
"""
|
|
270
|
+
beta_curr = params_curr[:-1]
|
|
271
|
+
sigmasq = np.exp(params_curr[-1])
|
|
272
|
+
|
|
273
|
+
# Log conditional density
|
|
274
|
+
log_dens = scipy.stats.norm.logpdf(
|
|
275
|
+
Ttilde,
|
|
276
|
+
loc=Xtilde @ beta_curr,
|
|
277
|
+
scale=np.sqrt(sigmasq)
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
# Log-space clipping
|
|
281
|
+
log_dens = np.minimum(np.log(1 - PROBS_MIN), log_dens)
|
|
282
|
+
log_dens = np.maximum(np.log(PROBS_MIN), log_dens)
|
|
283
|
+
|
|
284
|
+
# Weight calculation in log space
|
|
285
|
+
log_diff = stabilizers - log_dens
|
|
286
|
+
log_diff_clipped = np.clip(log_diff, -CLIP_RANGE, CLIP_RANGE)
|
|
287
|
+
w_curr = Ttilde * np.exp(log_diff_clipped)
|
|
288
|
+
|
|
289
|
+
if not np.all(np.isfinite(w_curr)):
|
|
290
|
+
raise ValueError("Weights contain non-finite values")
|
|
291
|
+
|
|
292
|
+
# Construct sample moment conditions gbar
|
|
293
|
+
# Moment 1: Score condition for sigma^2
|
|
294
|
+
gbar_1 = (1/n) * wtXilde.T @ ((Ttilde - Xtilde @ beta_curr) / sigmasq)
|
|
295
|
+
|
|
296
|
+
# Moment 2: Balance condition
|
|
297
|
+
w_curr_del = (1/n) * wtXilde.T @ w_curr
|
|
298
|
+
gbar_2 = w_curr_del.ravel()
|
|
299
|
+
|
|
300
|
+
# Moment 3: Score condition for beta
|
|
301
|
+
gbar_3 = (1/n) * sample_weights.T @ (
|
|
302
|
+
(Ttilde - Xtilde @ beta_curr)**2 / sigmasq - 1
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
gbar = np.concatenate([gbar_1.ravel(), gbar_2, [gbar_3]])
|
|
306
|
+
|
|
307
|
+
# Compute V matrix or use pre-computed invV
|
|
308
|
+
if invV is None:
|
|
309
|
+
# Construct V matrix blocks
|
|
310
|
+
V11 = (1/sigmasq) * wtXilde.T @ Xtilde
|
|
311
|
+
V12 = wtXilde.T @ Xtilde / sigmasq
|
|
312
|
+
V13 = wtXilde.T @ n_identity_vec * 0
|
|
313
|
+
|
|
314
|
+
# V22 calculation with scaling vector
|
|
315
|
+
linear_pred = Xtilde @ beta_curr
|
|
316
|
+
linear_pred_sq = linear_pred**2
|
|
317
|
+
term_A = linear_pred_sq / sigmasq
|
|
318
|
+
term_B = np.log(sigmasq + linear_pred_sq)
|
|
319
|
+
|
|
320
|
+
exponent = term_A + term_B
|
|
321
|
+
if np.any(exponent > 700):
|
|
322
|
+
raise ValueError(
|
|
323
|
+
f"Potential overflow in V matrix calculation (max exponent={exponent.max():.2f}). "
|
|
324
|
+
f"Residual variance sigma^2={sigmasq:.6f} might be too small. "
|
|
325
|
+
f"Consider using method='exact'."
|
|
326
|
+
)
|
|
327
|
+
|
|
328
|
+
vec_scaling = np.exp(exponent)
|
|
329
|
+
|
|
330
|
+
if not np.all(np.isfinite(vec_scaling)):
|
|
331
|
+
raise ValueError("V22 scaling vector contains non-finite values.")
|
|
332
|
+
|
|
333
|
+
Xtilde_swept = vec_scaling[:, None] * Xtilde
|
|
334
|
+
V22 = wtXilde.T @ Xtilde_swept
|
|
335
|
+
|
|
336
|
+
V23 = (wtXilde.T @ (-Xtilde @ beta_curr) * (-2/sigmasq)).reshape(-1, 1)
|
|
337
|
+
|
|
338
|
+
V33_scalar = sample_weights.T @ n_identity_vec.ravel() * 2
|
|
339
|
+
V33 = np.array([[V33_scalar]])
|
|
340
|
+
|
|
341
|
+
# Assemble V
|
|
342
|
+
V = (1/n) * np.block([
|
|
343
|
+
[V11, V12, V13],
|
|
344
|
+
[V12, V22, V23],
|
|
345
|
+
[V13.T, V23.T, V33]
|
|
346
|
+
])
|
|
347
|
+
|
|
348
|
+
if not np.allclose(V, V.T, atol=1e-12):
|
|
349
|
+
warnings.warn("V matrix is not symmetric within tolerance")
|
|
350
|
+
|
|
351
|
+
if np.any(np.isinf(V)):
|
|
352
|
+
raise ValueError(
|
|
353
|
+
"Encountered an infinite value in the weighting matrix. "
|
|
354
|
+
'Use the just-identified version of CBPS instead by setting method="exact".'
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
invV = scipy.linalg.pinv(V)
|
|
358
|
+
|
|
359
|
+
loss = gbar.T @ invV @ gbar
|
|
360
|
+
|
|
361
|
+
if loss < -1e-6:
|
|
362
|
+
warnings.warn(
|
|
363
|
+
f"GMM loss is negative ({loss:.2e}). Check numerical stability.",
|
|
364
|
+
UserWarning
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
return {'loss': float(loss), 'invV': invV}
|
|
368
|
+
|
|
369
|
+
def gmm_loss(params_curr: np.ndarray, invV: Optional[np.ndarray] = None) -> float:
|
|
370
|
+
return gmm_func(params_curr, invV)['loss']
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# ========== bal_func (exactly-identified, 2 moment conditions) ==========
|
|
374
|
+
|
|
375
|
+
def bal_func(params_curr: np.ndarray) -> Dict[str, float]:
|
|
376
|
+
"""
|
|
377
|
+
Balance objective function for the exactly-identified case.
|
|
378
|
+
|
|
379
|
+
Parameters
|
|
380
|
+
----------
|
|
381
|
+
params_curr : np.ndarray
|
|
382
|
+
Parameter vector [beta, log(sigma^2)].
|
|
383
|
+
|
|
384
|
+
Returns
|
|
385
|
+
-------
|
|
386
|
+
dict
|
|
387
|
+
Dictionary containing the balance loss value.
|
|
388
|
+
"""
|
|
389
|
+
beta_curr = params_curr[:-1]
|
|
390
|
+
sigmasq = np.exp(params_curr[-1])
|
|
391
|
+
|
|
392
|
+
# Log conditional density
|
|
393
|
+
log_dens = scipy.stats.norm.logpdf(
|
|
394
|
+
Ttilde,
|
|
395
|
+
loc=Xtilde @ beta_curr,
|
|
396
|
+
scale=np.sqrt(sigmasq)
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
log_dens = np.minimum(np.log(1 - PROBS_MIN), log_dens)
|
|
400
|
+
log_dens = np.maximum(np.log(PROBS_MIN), log_dens)
|
|
401
|
+
|
|
402
|
+
# Weight calculation
|
|
403
|
+
log_diff = stabilizers - log_dens
|
|
404
|
+
log_diff_clipped = np.clip(log_diff, -CLIP_RANGE, CLIP_RANGE)
|
|
405
|
+
w_curr = Ttilde * np.exp(log_diff_clipped)
|
|
406
|
+
|
|
407
|
+
# Construct sample moment conditions
|
|
408
|
+
w_curr_del = (1/n) * wtXilde.T @ w_curr
|
|
409
|
+
|
|
410
|
+
gbar = np.concatenate([
|
|
411
|
+
w_curr_del.ravel(), # Balance condition
|
|
412
|
+
[(1/n) * sample_weights.T @ (
|
|
413
|
+
(Ttilde - Xtilde @ beta_curr)**2 / sigmasq - 1
|
|
414
|
+
)] # Sigma^2 condition
|
|
415
|
+
])
|
|
416
|
+
|
|
417
|
+
if gbar.shape != (k + 1,):
|
|
418
|
+
raise ValueError(f"Gradient vector shape mismatch: {gbar.shape}")
|
|
419
|
+
|
|
420
|
+
# Loss calculation with identity weight matrix
|
|
421
|
+
loss = gbar.T @ np.eye(k + 1) @ gbar
|
|
422
|
+
|
|
423
|
+
return {'loss': float(loss)}
|
|
424
|
+
|
|
425
|
+
def bal_loss(params_curr: np.ndarray) -> float:
|
|
426
|
+
"""Wrapper for balance loss function."""
|
|
427
|
+
return bal_func(params_curr)['loss']
|
|
428
|
+
|
|
429
|
+
# ========== GMM Gradient Calculation ==========
|
|
430
|
+
|
|
431
|
+
def gmm_gradient(params_curr: np.ndarray, invV: np.ndarray) -> np.ndarray:
|
|
432
|
+
"""
|
|
433
|
+
Gradient of the GMM objective function.
|
|
434
|
+
|
|
435
|
+
Parameters
|
|
436
|
+
----------
|
|
437
|
+
params_curr : np.ndarray
|
|
438
|
+
Parameter vector.
|
|
439
|
+
invV : np.ndarray
|
|
440
|
+
Inverse weight matrix V.
|
|
441
|
+
|
|
442
|
+
Returns
|
|
443
|
+
-------
|
|
444
|
+
np.ndarray
|
|
445
|
+
Gradient vector.
|
|
446
|
+
"""
|
|
447
|
+
beta_curr = params_curr[:-1]
|
|
448
|
+
sigmasq = np.exp(params_curr[-1])
|
|
449
|
+
|
|
450
|
+
# Log conditional density
|
|
451
|
+
log_dens = scipy.stats.norm.logpdf(
|
|
452
|
+
Ttilde,
|
|
453
|
+
loc=Xtilde @ beta_curr,
|
|
454
|
+
scale=np.sqrt(sigmasq)
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
log_dens = np.minimum(np.log(1 - PROBS_MIN), log_dens)
|
|
458
|
+
log_dens = np.maximum(np.log(PROBS_MIN), log_dens)
|
|
459
|
+
|
|
460
|
+
# Weights
|
|
461
|
+
log_diff = stabilizers - log_dens
|
|
462
|
+
log_diff_clipped = np.clip(log_diff, -CLIP_RANGE, CLIP_RANGE)
|
|
463
|
+
w_curr = Ttilde * np.exp(log_diff_clipped)
|
|
464
|
+
|
|
465
|
+
# Recompute gbar
|
|
466
|
+
gbar_1 = (1/n) * wtXilde.T @ ((Ttilde - Xtilde @ beta_curr) / sigmasq)
|
|
467
|
+
w_curr_del = (1/n) * wtXilde.T @ w_curr
|
|
468
|
+
gbar_2 = w_curr_del.ravel()
|
|
469
|
+
gbar_3 = (1/n) * sample_weights.T @ (
|
|
470
|
+
(Ttilde - Xtilde @ beta_curr)**2 / sigmasq - 1
|
|
471
|
+
)
|
|
472
|
+
gbar = np.concatenate([gbar_1.ravel(), gbar_2, [gbar_3]])
|
|
473
|
+
|
|
474
|
+
# Calculate dgbar blocks
|
|
475
|
+
# dgbar.1.1 (k x k)
|
|
476
|
+
dgbar_1_1 = (-wtXilde.T @ Xtilde) / sigmasq
|
|
477
|
+
|
|
478
|
+
# dgbar.1.2 (1 x k)
|
|
479
|
+
dgbar_1_2 = (
|
|
480
|
+
-sample_weights * (Ttilde - Xtilde @ beta_curr) / (sigmasq**2)
|
|
481
|
+
).reshape(1, -1) @ Xtilde
|
|
482
|
+
|
|
483
|
+
# dgbar.2.1 (k x k)
|
|
484
|
+
vec_L110 = -(Ttilde - Xtilde @ beta_curr) / sigmasq * w_curr
|
|
485
|
+
dgbar_2_1 = (wtXilde.T * vec_L110) @ Xtilde
|
|
486
|
+
|
|
487
|
+
# dgbar.2.2 (1 x k)
|
|
488
|
+
dgbar_2_2 = (
|
|
489
|
+
w_curr * (1/(2*sigmasq) - (Ttilde - Xtilde @ beta_curr)**2 / (2*sigmasq**2))
|
|
490
|
+
).reshape(1, -1) @ Xtilde
|
|
491
|
+
|
|
492
|
+
# dgbar.3.1 (k x 1)
|
|
493
|
+
dgbar_3_1 = wtXilde.T @ (
|
|
494
|
+
-2 * (Ttilde - Xtilde @ beta_curr) / sigmasq
|
|
495
|
+
).reshape(-1, 1)
|
|
496
|
+
|
|
497
|
+
# dgbar.3.2 (scalar)
|
|
498
|
+
dgbar_3_2 = sample_weights.T @ (
|
|
499
|
+
-(Ttilde - Xtilde @ beta_curr)**2 / (sigmasq**2)
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
# Assemble dgbar
|
|
503
|
+
col1 = np.vstack([dgbar_1_1, dgbar_1_2 * sigmasq])
|
|
504
|
+
col2 = np.vstack([dgbar_2_1, dgbar_2_2 * sigmasq])
|
|
505
|
+
col3 = np.vstack([dgbar_3_1, dgbar_3_2.reshape(1, 1) * sigmasq])
|
|
506
|
+
|
|
507
|
+
dgbar = (1/n) * np.hstack([col1, col2, col3])
|
|
508
|
+
|
|
509
|
+
# Gradient calculation: 2 * dgbar @ invV @ gbar
|
|
510
|
+
gradient = 2 * dgbar @ invV @ gbar
|
|
511
|
+
|
|
512
|
+
return gradient.ravel()
|
|
513
|
+
|
|
514
|
+
# ========== Balance Gradient Calculation ==========
|
|
515
|
+
|
|
516
|
+
def bal_gradient(params_curr: np.ndarray) -> np.ndarray:
|
|
517
|
+
"""
|
|
518
|
+
Gradient of the balance objective function.
|
|
519
|
+
|
|
520
|
+
Parameters
|
|
521
|
+
----------
|
|
522
|
+
params_curr : np.ndarray
|
|
523
|
+
Parameter vector.
|
|
524
|
+
|
|
525
|
+
Returns
|
|
526
|
+
-------
|
|
527
|
+
np.ndarray
|
|
528
|
+
Gradient vector.
|
|
529
|
+
"""
|
|
530
|
+
beta_curr = params_curr[:-1]
|
|
531
|
+
sigmasq = np.exp(params_curr[-1])
|
|
532
|
+
|
|
533
|
+
log_dens = scipy.stats.norm.logpdf(
|
|
534
|
+
Ttilde,
|
|
535
|
+
loc=Xtilde @ beta_curr,
|
|
536
|
+
scale=np.sqrt(sigmasq)
|
|
537
|
+
)
|
|
538
|
+
|
|
539
|
+
log_dens = np.minimum(np.log(1 - PROBS_MIN), log_dens)
|
|
540
|
+
log_dens = np.maximum(np.log(PROBS_MIN), log_dens)
|
|
541
|
+
|
|
542
|
+
log_diff = stabilizers - log_dens
|
|
543
|
+
log_diff_clipped = np.clip(log_diff, -CLIP_RANGE, CLIP_RANGE)
|
|
544
|
+
w_curr = Ttilde * np.exp(log_diff_clipped)
|
|
545
|
+
|
|
546
|
+
w_curr_del = (1/n) * wtXilde.T @ w_curr
|
|
547
|
+
gbar = np.concatenate([
|
|
548
|
+
w_curr_del.ravel(),
|
|
549
|
+
[(1/n) * sample_weights.T @ (
|
|
550
|
+
(Ttilde - Xtilde @ beta_curr)**2 / sigmasq - 1
|
|
551
|
+
)]
|
|
552
|
+
])
|
|
553
|
+
|
|
554
|
+
# Calculate dgbar blocks
|
|
555
|
+
vec_L145 = -(Ttilde - Xtilde @ beta_curr) / sigmasq * w_curr
|
|
556
|
+
dgbar_2_1 = (wtXilde.T * vec_L145) @ Xtilde
|
|
557
|
+
|
|
558
|
+
dgbar_2_2 = (
|
|
559
|
+
w_curr * (1/(2*sigmasq) - (Ttilde - Xtilde @ beta_curr)**2 / (2*sigmasq**2))
|
|
560
|
+
).reshape(1, -1) @ Xtilde
|
|
561
|
+
|
|
562
|
+
dgbar_3_1 = wtXilde.T @ (
|
|
563
|
+
-2 * (Ttilde - Xtilde @ beta_curr) / sigmasq
|
|
564
|
+
).reshape(-1, 1)
|
|
565
|
+
|
|
566
|
+
dgbar_3_2 = sample_weights.T @ (
|
|
567
|
+
-(Ttilde - Xtilde @ beta_curr)**2 / (sigmasq**2)
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
col1 = np.vstack([dgbar_2_1, dgbar_2_2 * sigmasq])
|
|
571
|
+
col2 = np.vstack([dgbar_3_1, dgbar_3_2.reshape(1, 1) * sigmasq])
|
|
572
|
+
|
|
573
|
+
dgbar = (1/n) * np.hstack([col1, col2])
|
|
574
|
+
|
|
575
|
+
gradient = 2 * dgbar @ np.eye(k + 1) @ gbar
|
|
576
|
+
|
|
577
|
+
return gradient.ravel()
|
|
578
|
+
|
|
579
|
+
# ========== Optimization Initialization and Scaling ==========
|
|
580
|
+
|
|
581
|
+
# Initial Linear Regression estimate
|
|
582
|
+
lm_model = sm.WLS(Ttilde, Xtilde, weights=sample_weights).fit()
|
|
583
|
+
|
|
584
|
+
mcoef = lm_model.params.copy()
|
|
585
|
+
mcoef[np.isnan(mcoef)] = 0
|
|
586
|
+
|
|
587
|
+
residuals = Ttilde - Xtilde @ mcoef
|
|
588
|
+
sigmasq_init = np.mean(residuals**2)
|
|
589
|
+
|
|
590
|
+
assert sigmasq_init > 0, f"Initial residual variance must be positive (got {sigmasq_init})"
|
|
591
|
+
|
|
592
|
+
# Calculate MLE probabilities
|
|
593
|
+
probs_mle = scipy.stats.norm.logpdf(
|
|
594
|
+
Ttilde,
|
|
595
|
+
loc=Xtilde @ mcoef,
|
|
596
|
+
scale=np.sqrt(sigmasq_init)
|
|
597
|
+
)
|
|
598
|
+
probs_mle = np.minimum(np.log(1 - PROBS_MIN), probs_mle)
|
|
599
|
+
probs_mle = np.maximum(np.log(PROBS_MIN), probs_mle)
|
|
600
|
+
|
|
601
|
+
# Construct initial parameter vector
|
|
602
|
+
params_curr = np.concatenate([mcoef, [np.log(sigmasq_init)]])
|
|
603
|
+
|
|
604
|
+
# Pre-compute MLE baseline loss for fallback
|
|
605
|
+
mle_J = np.nan
|
|
606
|
+
try:
|
|
607
|
+
mle_J = gmm_loss(params_curr)
|
|
608
|
+
except Exception as e:
|
|
609
|
+
warnings.warn(f"Failed to compute MLE J statistic: {e}")
|
|
610
|
+
|
|
611
|
+
mle_bal = bal_loss(params_curr)
|
|
612
|
+
|
|
613
|
+
# Alpha scaling optimization
|
|
614
|
+
# Implementation Note:
|
|
615
|
+
# We use a fixed V matrix (calculated at alpha=1.0) during the alpha scaling phase.
|
|
616
|
+
# While continuous updating of V is theoretically possible, fixed V provides better
|
|
617
|
+
# numerical stability in pathological cases (e.g., extremely poor initial fit)
|
|
618
|
+
# and matches the performance of standard two-step GMM approaches.
|
|
619
|
+
|
|
620
|
+
glm_invV = None
|
|
621
|
+
try:
|
|
622
|
+
# Pre-compute V inverse at alpha=1.0
|
|
623
|
+
glm_invV = gmm_func(params_curr, invV=None)['invV']
|
|
624
|
+
|
|
625
|
+
def alpha_func(alpha):
|
|
626
|
+
return gmm_loss(params_curr * alpha, invV=glm_invV)
|
|
627
|
+
|
|
628
|
+
alpha_result = scipy.optimize.minimize_scalar(
|
|
629
|
+
alpha_func,
|
|
630
|
+
bounds=ALPHA_BOUNDS,
|
|
631
|
+
method='bounded'
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
# Update parameters with optimal alpha scaling
|
|
635
|
+
params_curr = params_curr * alpha_result.x
|
|
636
|
+
|
|
637
|
+
except Exception as e:
|
|
638
|
+
warnings.warn(f"Alpha scaling failed, using unscaled LM initialization: {e}")
|
|
639
|
+
glm_invV = None
|
|
640
|
+
|
|
641
|
+
gmm_init = params_curr.copy()
|
|
642
|
+
|
|
643
|
+
# ========== Balance and GMM Optimization ==========
|
|
644
|
+
|
|
645
|
+
logger.info(f"Starting balance optimization (max_iter={iterations}, two_step={two_step})...")
|
|
646
|
+
|
|
647
|
+
if two_step:
|
|
648
|
+
# Two-step estimation using BFGS
|
|
649
|
+
opt_bal = scipy.optimize.minimize(
|
|
650
|
+
bal_loss, gmm_init,
|
|
651
|
+
method='BFGS',
|
|
652
|
+
jac=bal_gradient,
|
|
653
|
+
options={
|
|
654
|
+
'maxiter': iterations,
|
|
655
|
+
'gtol': 1e-05
|
|
656
|
+
}
|
|
657
|
+
)
|
|
658
|
+
else:
|
|
659
|
+
# Continuous updating with fallback
|
|
660
|
+
try:
|
|
661
|
+
opt_bal = scipy.optimize.minimize(
|
|
662
|
+
bal_loss, gmm_init,
|
|
663
|
+
method='BFGS',
|
|
664
|
+
options={'maxiter': iterations}
|
|
665
|
+
)
|
|
666
|
+
except (np.linalg.LinAlgError, ValueError, RuntimeWarning) as e:
|
|
667
|
+
warnings.warn(f"Balance BFGS failed, falling back to Nelder-Mead: {e}")
|
|
668
|
+
opt_bal = scipy.optimize.minimize(
|
|
669
|
+
bal_loss, gmm_init,
|
|
670
|
+
method='Nelder-Mead',
|
|
671
|
+
options={'maxiter': iterations}
|
|
672
|
+
)
|
|
673
|
+
|
|
674
|
+
params_bal = opt_bal.x
|
|
675
|
+
|
|
676
|
+
if bal_only:
|
|
677
|
+
opt1 = opt_bal
|
|
678
|
+
|
|
679
|
+
if not bal_only:
|
|
680
|
+
logger.info("Starting GMM optimization with dual initialization...")
|
|
681
|
+
|
|
682
|
+
if two_step:
|
|
683
|
+
# Initialize from GLM and Balance solutions
|
|
684
|
+
gmm_glm_init = scipy.optimize.minimize(
|
|
685
|
+
lambda p: gmm_loss(p, invV=glm_invV),
|
|
686
|
+
gmm_init,
|
|
687
|
+
method='BFGS',
|
|
688
|
+
jac=lambda p: gmm_gradient(p, glm_invV),
|
|
689
|
+
options={
|
|
690
|
+
'maxiter': iterations,
|
|
691
|
+
'gtol': 1e-05
|
|
692
|
+
}
|
|
693
|
+
)
|
|
694
|
+
gmm_bal_init = scipy.optimize.minimize(
|
|
695
|
+
lambda p: gmm_loss(p, invV=glm_invV),
|
|
696
|
+
params_bal,
|
|
697
|
+
method='BFGS',
|
|
698
|
+
jac=lambda p: gmm_gradient(p, glm_invV),
|
|
699
|
+
options={
|
|
700
|
+
'maxiter': iterations,
|
|
701
|
+
'gtol': 1e-05
|
|
702
|
+
}
|
|
703
|
+
)
|
|
704
|
+
else:
|
|
705
|
+
# Continuous updating
|
|
706
|
+
try:
|
|
707
|
+
gmm_glm_init = scipy.optimize.minimize(
|
|
708
|
+
gmm_loss, gmm_init,
|
|
709
|
+
method='BFGS',
|
|
710
|
+
options={
|
|
711
|
+
'maxiter': iterations,
|
|
712
|
+
'gtol': 1e-05
|
|
713
|
+
}
|
|
714
|
+
)
|
|
715
|
+
except (np.linalg.LinAlgError, ValueError, RuntimeWarning) as e:
|
|
716
|
+
warnings.warn(f"GMM-GLM BFGS failed, falling back to Nelder-Mead: {e}")
|
|
717
|
+
gmm_glm_init = scipy.optimize.minimize(
|
|
718
|
+
gmm_loss, gmm_init,
|
|
719
|
+
method='Nelder-Mead',
|
|
720
|
+
options={'maxiter': iterations}
|
|
721
|
+
)
|
|
722
|
+
|
|
723
|
+
try:
|
|
724
|
+
gmm_bal_init = scipy.optimize.minimize(
|
|
725
|
+
gmm_loss, params_bal,
|
|
726
|
+
method='BFGS',
|
|
727
|
+
options={
|
|
728
|
+
'maxiter': iterations,
|
|
729
|
+
'gtol': 1e-05
|
|
730
|
+
}
|
|
731
|
+
)
|
|
732
|
+
except (np.linalg.LinAlgError, ValueError, RuntimeWarning) as e:
|
|
733
|
+
warnings.warn(f"GMM-Balance BFGS failed, falling back to Nelder-Mead: {e}")
|
|
734
|
+
gmm_bal_init = scipy.optimize.minimize(
|
|
735
|
+
gmm_loss, params_bal,
|
|
736
|
+
method='Nelder-Mead',
|
|
737
|
+
options={'maxiter': iterations}
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
# Select best solution
|
|
741
|
+
if gmm_glm_init.fun < gmm_bal_init.fun:
|
|
742
|
+
opt1 = gmm_glm_init
|
|
743
|
+
pick_glm = 1
|
|
744
|
+
else:
|
|
745
|
+
opt1 = gmm_bal_init
|
|
746
|
+
pick_glm = 0
|
|
747
|
+
|
|
748
|
+
if verbose >= 1:
|
|
749
|
+
source = "GLM" if pick_glm == 1 else "Balance"
|
|
750
|
+
logger.info(f"GMM optimization complete: J={opt1.fun:.6f}, converged={opt1.success}, source={source}")
|
|
751
|
+
|
|
752
|
+
# ========== Parameter Extraction and MLE Fallback ==========
|
|
753
|
+
|
|
754
|
+
params_opt = opt1.x
|
|
755
|
+
beta_opt = params_opt[:-1]
|
|
756
|
+
sigmasq = np.exp(params_opt[-1])
|
|
757
|
+
|
|
758
|
+
# Recalculate probabilities
|
|
759
|
+
probs_opt = scipy.stats.norm.logpdf(
|
|
760
|
+
Ttilde,
|
|
761
|
+
loc=Xtilde @ beta_opt,
|
|
762
|
+
scale=np.sqrt(sigmasq)
|
|
763
|
+
)
|
|
764
|
+
probs_opt = np.minimum(np.log(1 - PROBS_MIN), probs_opt)
|
|
765
|
+
probs_opt = np.maximum(np.log(PROBS_MIN), probs_opt)
|
|
766
|
+
|
|
767
|
+
if not bal_only:
|
|
768
|
+
if two_step:
|
|
769
|
+
J_opt = gmm_func(params_opt, invV=glm_invV)['loss']
|
|
770
|
+
else:
|
|
771
|
+
J_opt = gmm_func(params_opt)['loss']
|
|
772
|
+
|
|
773
|
+
# MLE Fallback Logic
|
|
774
|
+
# Check 1: Significantly negative J statistic (theoretical violation)
|
|
775
|
+
if J_opt < -1e-6:
|
|
776
|
+
raise ValueError(
|
|
777
|
+
f"Encountered an infinite value in the weighting matrix. "
|
|
778
|
+
f"J statistic is significantly negative (J={J_opt:.6e}), "
|
|
779
|
+
f"indicating numerical instability in the V matrix. "
|
|
780
|
+
f'Use the just-identified version of CBPS instead by setting method="exact".'
|
|
781
|
+
)
|
|
782
|
+
|
|
783
|
+
# Check 2: Optimization result worse than MLE
|
|
784
|
+
# R code: if ((J.opt > mle.J) & (bal.loss(params.opt) > mle.bal))
|
|
785
|
+
elif (J_opt > mle_J) and (bal_loss(params_opt) > mle_bal):
|
|
786
|
+
warnings.warn(
|
|
787
|
+
f"Optimization produced worse results than MLE (|J_opt|={abs(J_opt):.6e} > "
|
|
788
|
+
f"|J_mle|={abs(mle_J):.6e}). Falling back to MLE.",
|
|
789
|
+
UserWarning
|
|
790
|
+
)
|
|
791
|
+
beta_opt = mcoef
|
|
792
|
+
probs_opt = probs_mle
|
|
793
|
+
J_opt = mle_J
|
|
794
|
+
|
|
795
|
+
# Check 3: Minor negative J
|
|
796
|
+
elif J_opt < 0:
|
|
797
|
+
warnings.warn(
|
|
798
|
+
f"J statistic is slightly negative (J={J_opt:.6e}). "
|
|
799
|
+
f"This may indicate minor numerical precision issues.",
|
|
800
|
+
UserWarning
|
|
801
|
+
)
|
|
802
|
+
else:
|
|
803
|
+
J_opt = bal_loss(params_opt)
|
|
804
|
+
|
|
805
|
+
# ========== Final Weight Calculation and Variance Estimation ==========
|
|
806
|
+
|
|
807
|
+
w_opt = np.exp(stabilizers - probs_opt)
|
|
808
|
+
|
|
809
|
+
if standardize:
|
|
810
|
+
w_opt = w_opt / np.sum(w_opt * sample_weights)
|
|
811
|
+
if not np.isclose(np.sum(w_opt * sample_weights), 1.0, atol=1e-10):
|
|
812
|
+
warnings.warn("Weight standardization failed to sum to 1")
|
|
813
|
+
|
|
814
|
+
if not np.all(np.isfinite(w_opt)):
|
|
815
|
+
raise ValueError("Final weights contain non-finite values")
|
|
816
|
+
|
|
817
|
+
deviance = -2 * np.sum(probs_opt)
|
|
818
|
+
|
|
819
|
+
# Compute XG matrix blocks (Gradient of moment conditions)
|
|
820
|
+
XG_1_1 = (-wtXilde.T @ Xtilde) / sigmasq
|
|
821
|
+
|
|
822
|
+
XG_2_1 = (wtXilde.T @ (
|
|
823
|
+
-2 * (Ttilde - Xtilde @ beta_opt) / sigmasq
|
|
824
|
+
)).reshape(-1, 1)
|
|
825
|
+
|
|
826
|
+
vec_L258 = -(Ttilde - Xtilde @ beta_opt) / sigmasq * Ttilde * w_opt
|
|
827
|
+
XG_3_1 = (wtXilde.T * vec_L258) @ Xtilde
|
|
828
|
+
|
|
829
|
+
XG_1_2 = ((-wtXilde.T @ (Ttilde - Xtilde @ beta_opt)) / (sigmasq**2)).reshape(-1, 1)
|
|
830
|
+
|
|
831
|
+
XG_2_2_scalar = sample_weights.T @ (
|
|
832
|
+
-(Ttilde - Xtilde @ beta_opt)**2 / (sigmasq**2)
|
|
833
|
+
)
|
|
834
|
+
XG_2_2 = np.array([[XG_2_2_scalar]])
|
|
835
|
+
|
|
836
|
+
XG_3_2 = (
|
|
837
|
+
-Ttilde * sample_weights * w_opt * (
|
|
838
|
+
(Ttilde - Xtilde @ beta_opt)**2 / (2*sigmasq**2) - 1/(2*sigmasq)
|
|
839
|
+
)
|
|
840
|
+
).reshape(1, -1) @ Xtilde
|
|
841
|
+
|
|
842
|
+
# Compute XW matrix blocks
|
|
843
|
+
XW_1 = Xtilde * (
|
|
844
|
+
(Ttilde - Xtilde @ beta_opt) / sigmasq * sample_weights**0.5
|
|
845
|
+
)[:, None]
|
|
846
|
+
|
|
847
|
+
XW_2 = (
|
|
848
|
+
(Ttilde - Xtilde @ beta_opt)**2 / sigmasq - 1
|
|
849
|
+
) * sample_weights**0.5
|
|
850
|
+
|
|
851
|
+
XW_3 = Xtilde * (Ttilde * w_opt * sample_weights)[:, None]
|
|
852
|
+
|
|
853
|
+
if bal_only:
|
|
854
|
+
W = np.eye(k + 1)
|
|
855
|
+
G = (1/n) * np.vstack([
|
|
856
|
+
np.hstack([XG_3_1, XG_3_2.T]),
|
|
857
|
+
np.hstack([XG_2_1.T, XG_2_2])
|
|
858
|
+
])
|
|
859
|
+
W1 = np.vstack([XW_3.T, XW_2.reshape(1, -1)])
|
|
860
|
+
else:
|
|
861
|
+
W = gmm_func(params_opt)['invV']
|
|
862
|
+
G = (1/n) * np.vstack([
|
|
863
|
+
np.hstack([XG_1_1, XG_1_2]),
|
|
864
|
+
np.hstack([XG_3_1, XG_3_2.T]),
|
|
865
|
+
np.hstack([XG_2_1.T, XG_2_2])
|
|
866
|
+
])
|
|
867
|
+
W1 = np.vstack([XW_1.T, XW_3.T, XW_2.reshape(1, -1)])
|
|
868
|
+
|
|
869
|
+
Omega = (1/n) * (W1 @ W1.T)
|
|
870
|
+
|
|
871
|
+
GWG_inv = scipy.linalg.pinv(G.T @ W @ G)
|
|
872
|
+
GWGinvGW = W @ G @ GWG_inv
|
|
873
|
+
|
|
874
|
+
vcov_tilde = (GWGinvGW.T @ Omega @ GWGinvGW)[0:k, 0:k]
|
|
875
|
+
vcov_tilde = (vcov_tilde + vcov_tilde.T) / 2
|
|
876
|
+
|
|
877
|
+
# Inverse transformation to original space
|
|
878
|
+
beta_tilde = beta_opt.copy()
|
|
879
|
+
|
|
880
|
+
XtX_inv = scipy.linalg.pinv(X.T @ X)
|
|
881
|
+
beta_opt = XtX_inv @ X.T @ (
|
|
882
|
+
Xtilde @ beta_tilde * np.std(sw_treat, ddof=1) + np.mean(sw_treat)
|
|
883
|
+
)
|
|
884
|
+
|
|
885
|
+
sigmasq_tilde = sigmasq
|
|
886
|
+
sigmasq = sigmasq_tilde * np.var(sw_treat, ddof=1)
|
|
887
|
+
|
|
888
|
+
# Variance-covariance transformation
|
|
889
|
+
sw_treat_var = np.var(sw_treat, ddof=1)
|
|
890
|
+
middle = Xtilde @ vcov_tilde @ Xtilde.T * sw_treat_var
|
|
891
|
+
vcov = XtX_inv @ X.T @ middle @ X @ XtX_inv
|
|
892
|
+
vcov = (vcov + vcov.T) / 2
|
|
893
|
+
|
|
894
|
+
result = {
|
|
895
|
+
'coefficients': beta_opt.reshape(-1, 1),
|
|
896
|
+
'fitted_values': np.clip(
|
|
897
|
+
scipy.stats.norm.pdf(
|
|
898
|
+
Ttilde,
|
|
899
|
+
loc=Xtilde @ beta_tilde,
|
|
900
|
+
scale=np.sqrt(sigmasq_tilde)
|
|
901
|
+
),
|
|
902
|
+
PROBS_MIN,
|
|
903
|
+
1 - PROBS_MIN
|
|
904
|
+
),
|
|
905
|
+
'linear_predictor': Xtilde @ beta_tilde,
|
|
906
|
+
'deviance': deviance,
|
|
907
|
+
'weights': w_opt * sample_weights,
|
|
908
|
+
'y': treat,
|
|
909
|
+
'x': X,
|
|
910
|
+
'converged': opt1.success,
|
|
911
|
+
'J': J_opt,
|
|
912
|
+
'var': vcov,
|
|
913
|
+
'mle_J': mle_J,
|
|
914
|
+
'sigmasq': sigmasq,
|
|
915
|
+
'Ttilde': Ttilde,
|
|
916
|
+
'Xtilde': Xtilde,
|
|
917
|
+
'beta_tilde': beta_tilde,
|
|
918
|
+
'sigmasq_tilde': sigmasq_tilde,
|
|
919
|
+
'stabilizers': stabilizers
|
|
920
|
+
}
|
|
921
|
+
|
|
922
|
+
# ========== Normality Diagnostics (P1-17) ==========
|
|
923
|
+
# Test the conditional normality assumption T|X ~ N(X'beta, sigma^2)
|
|
924
|
+
# using the original (un-whitened) treatment and covariates.
|
|
925
|
+
try:
|
|
926
|
+
from cbps.diagnostics.normality import test_treatment_normality
|
|
927
|
+
normality_diag = test_treatment_normality(treat, X)
|
|
928
|
+
result['normality_diagnostics'] = normality_diag
|
|
929
|
+
|
|
930
|
+
if normality_diag['reject_normality']:
|
|
931
|
+
warnings.warn(
|
|
932
|
+
f"Treatment normality assumption rejected "
|
|
933
|
+
f"(p={normality_diag['p_value']:.4f}). "
|
|
934
|
+
f"Consider using npCBPS for nonparametric estimation.",
|
|
935
|
+
UserWarning
|
|
936
|
+
)
|
|
937
|
+
except (ImportError, Exception) as e:
|
|
938
|
+
# Diagnostics should never block estimation
|
|
939
|
+
result['normality_diagnostics'] = {
|
|
940
|
+
'error': str(e),
|
|
941
|
+
'reject_normality': None
|
|
942
|
+
}
|
|
943
|
+
|
|
944
|
+
return result
|
|
945
|
+
|