cbps 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbps/__init__.py +3462 -0
- cbps/constants.py +46 -0
- cbps/core/__init__.py +93 -0
- cbps/core/cbps_binary.py +1943 -0
- cbps/core/cbps_continuous.py +945 -0
- cbps/core/cbps_multitreat.py +1123 -0
- cbps/core/cbps_optimal.py +507 -0
- cbps/core/results.py +1447 -0
- cbps/data/Blackwell.csv +571 -0
- cbps/data/LaLonde.csv +3213 -0
- cbps/data/npcbps_continuous_sim.csv +501 -0
- cbps/data/nsw.csv +723 -0
- cbps/data/nsw_dw.csv +446 -0
- cbps/data/political_ads_urban_niebler.csv +16266 -0
- cbps/data/psid_controls.csv +2491 -0
- cbps/data/psid_controls2.csv +254 -0
- cbps/data/psid_controls3.csv +129 -0
- cbps/data/simulation_dgp1_seed12345.csv +201 -0
- cbps/data/simulation_dgp2_seed12345.csv +201 -0
- cbps/data/simulation_dgp3_seed12345.csv +201 -0
- cbps/data/simulation_dgp4_seed12345.csv +201 -0
- cbps/datasets/__init__.py +78 -0
- cbps/datasets/blackwell.py +112 -0
- cbps/datasets/continuous.py +223 -0
- cbps/datasets/lalonde.py +272 -0
- cbps/datasets/npcbps_sim.py +101 -0
- cbps/diagnostics/__init__.py +101 -0
- cbps/diagnostics/balance.py +760 -0
- cbps/diagnostics/balance_cbmsm_addon.py +162 -0
- cbps/diagnostics/continuous_diagnostics.py +259 -0
- cbps/diagnostics/normality.py +173 -0
- cbps/diagnostics/ocbps_conditions.py +197 -0
- cbps/diagnostics/overlap.py +198 -0
- cbps/diagnostics/plots.py +1193 -0
- cbps/diagnostics/weights_diag.py +205 -0
- cbps/highdim/__init__.py +84 -0
- cbps/highdim/gmm_loss.py +340 -0
- cbps/highdim/hdcbps.py +1078 -0
- cbps/highdim/lasso_utils.py +498 -0
- cbps/highdim/weight_funcs.py +298 -0
- cbps/inference/__init__.py +42 -0
- cbps/inference/asyvar.py +621 -0
- cbps/inference/vcov_outcome.py +217 -0
- cbps/iv/__init__.py +48 -0
- cbps/iv/cbiv.py +2603 -0
- cbps/logging_config.py +45 -0
- cbps/msm/__init__.py +45 -0
- cbps/msm/cbmsm.py +1871 -0
- cbps/msm/rank_diagnostics.py +112 -0
- cbps/nonparametric/__init__.py +58 -0
- cbps/nonparametric/cholesky_whitening.py +232 -0
- cbps/nonparametric/empirical_likelihood.py +339 -0
- cbps/nonparametric/npcbps.py +1036 -0
- cbps/nonparametric/taylor_approx.py +207 -0
- cbps/py.typed +0 -0
- cbps/sklearn/__init__.py +42 -0
- cbps/sklearn/estimator.py +378 -0
- cbps/utils/__init__.py +82 -0
- cbps/utils/formula.py +415 -0
- cbps/utils/helpers.py +378 -0
- cbps/utils/numerics.py +438 -0
- cbps/utils/r_compat.py +109 -0
- cbps/utils/validation.py +224 -0
- cbps/utils/variance_transform.py +483 -0
- cbps/utils/weights.py +586 -0
- cbps-0.2.0.dist-info/METADATA +1090 -0
- cbps-0.2.0.dist-info/RECORD +70 -0
- cbps-0.2.0.dist-info/WHEEL +5 -0
- cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
- cbps-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Weight Quality Diagnostics
|
|
3
|
+
===========================
|
|
4
|
+
|
|
5
|
+
Comprehensive diagnostics for inverse probability weights produced by CBPS
|
|
6
|
+
estimation, including effective sample size (ESS), weight distribution
|
|
7
|
+
summaries, and extreme value detection.
|
|
8
|
+
|
|
9
|
+
The Kish (1965) effective sample size is the primary metric for assessing
|
|
10
|
+
whether extreme weights are degrading estimation precision.
|
|
11
|
+
|
|
12
|
+
References
|
|
13
|
+
----------
|
|
14
|
+
Kish, L. (1965). Survey Sampling. Wiley, New York.
|
|
15
|
+
|
|
16
|
+
Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
|
|
17
|
+
Journal of the Royal Statistical Society, Series B, 76(1), 243-263.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from typing import Optional
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def weight_diagnostics(weights, treat=None):
    """Compute comprehensive weight quality diagnostics.

    Based on Kish (1965) effective sample size and standard
    IPW weight quality indicators.

    Parameters
    ----------
    weights : array-like
        IPW weights from CBPS estimation. The input is converted to a
        flattened float array. Weights should be non-negative and finite
        for a meaningful ESS interpretation. If negative weights are
        present (e.g., from ATT balance conditions), a ``UserWarning`` is
        emitted and ESS is computed on absolute values.
    treat : array-like, optional
        Treatment indicator for group-specific diagnostics. It is flattened
        and must contain exactly one entry per weight; otherwise a
        ``UserWarning`` is emitted and no group diagnostics are produced.

    Returns
    -------
    dict with:
        - ess: Kish effective sample size
        - ess_ratio: ESS / n (closer to 1 = better)
        - weight_max: maximum absolute weight (0.0 if all weights are zero)
        - weight_min: minimum absolute weight among nonzero weights
          (0.0 if all weights are zero)
        - weight_ratio: max/min ratio of absolute values (NaN if all
          weights are zero)
        - cv: coefficient of variation of the (absolute) weights, using the
          population standard deviation; NaN if the mean is not positive
        - n_extreme: count of weights with abs(w) > 10*median(abs(w)); if
          the median is 0, every nonzero weight is counted
        - n_negative: count of negative weights (0 if all non-negative)
        - warning_level: 'ok'/'caution'/'severe'
        - group_diagnostics: dict per treatment level with keys ``n``,
          ``ess``, ``ess_ratio``, ``weight_mean``, ``weight_max`` and
          ``n_negative`` (``None`` if ``treat`` is missing or unusable)

    Warns
    -----
    UserWarning
        If any weight is negative, if any weight is NaN/inf, or if
        ``treat`` does not have one entry per weight.

    Notes
    -----
    Warning thresholds (based on Kish 1965, Chapter 11):
    - ESS/n < 0.5 → 'caution'
    - ESS/n < 0.2 → 'severe'

    A non-finite ESS ratio (caused by NaN or infinite weights) is also
    reported as 'severe': comparisons against NaN are always False, so
    without this rule corrupted weights would be labelled 'ok'.

    The ESS formula is: ESS = (sum(w))^2 / sum(w^2)
    For uniform weights, ESS = n. For highly variable weights, ESS << n.

    When negative weights are present, the Kish ESS formula does not have
    its standard interpretation. In this case, ESS is computed on abs(w),
    a warning is emitted, and ``n_negative`` in the result reports how many
    weights were affected.

    An empty ``weights`` input returns ESS 0 with warning level 'severe'.

    References
    ----------
    Kish, L. (1965). Survey Sampling. Wiley, New York. Chapter 11.
    """
    import warnings

    weights = np.asarray(weights, dtype=float).ravel()
    n = len(weights)

    # Degenerate case: nothing to diagnose.
    if n == 0:
        return {
            'ess': 0.0,
            'ess_ratio': 0.0,
            'weight_max': np.nan,
            'weight_min': np.nan,
            'weight_ratio': np.nan,
            'cv': np.nan,
            'n_extreme': 0,
            'n_negative': 0,
            'warning_level': 'severe',
            'group_diagnostics': None,
        }

    n_negative = int(np.sum(weights < 0))
    if n_negative > 0:
        warnings.warn(
            f"Kish ESS is defined for non-negative weights. "
            f"{n_negative} negative weight(s) detected; "
            f"ESS is computed on |weights| as an approximation. "
            f"Consider using only the final IPW weights (not balance weights) "
            f"for this diagnostic.",
            UserWarning,
            stacklevel=2
        )

    n_nonfinite = int(np.sum(~np.isfinite(weights)))
    if n_nonfinite > 0:
        warnings.warn(
            f"{n_nonfinite} non-finite weight(s) (NaN or inf) detected; "
            f"ESS cannot be computed reliably and warning_level is "
            f"reported as 'severe'.",
            UserWarning,
            stacklevel=2
        )

    # |w| equals w when no weight is negative, so one array serves the ESS,
    # the CV and the range statistics in both situations.
    abs_weights = np.abs(weights)

    sum_w = np.sum(abs_weights)
    sum_w2 = np.sum(abs_weights ** 2)

    # ESS computation (Kish 1965); all-zero weights carry no information.
    if sum_w2 == 0:
        ess = 0.0
        ess_ratio = 0.0
    else:
        ess = (sum_w ** 2) / sum_w2
        ess_ratio = ess / n

    # Weight range over nonzero absolute values (captures extreme negatives).
    nonzero_mask = abs_weights > 0
    if np.any(nonzero_mask):
        nonzero_abs = abs_weights[nonzero_mask]
        weight_min = float(np.min(nonzero_abs))
        weight_max = float(np.max(nonzero_abs))
    else:
        weight_min = 0.0
        weight_max = 0.0

    if weight_min > 0:
        weight_ratio = weight_max / weight_min
    else:
        weight_ratio = np.inf if weight_max > 0 else np.nan

    # Coefficient of variation (population standard deviation, ddof=0).
    w_mean = np.mean(abs_weights)
    if w_mean > 0:
        cv = float(np.std(abs_weights) / w_mean)
    else:
        cv = np.nan

    # Extreme weight count: abs(w) > 10 * median(abs(w)).
    median_abs_w = np.median(abs_weights)
    if median_abs_w > 0:
        n_extreme = int(np.sum(abs_weights > 10 * median_abs_w))
    else:
        # If median is 0, count all nonzero weights as extreme.
        n_extreme = int(np.sum(abs_weights > 0))

    # Warning level. The finiteness test must come first because every
    # ordered comparison with NaN evaluates to False.
    if not np.isfinite(ess_ratio) or ess_ratio < 0.2:
        warning_level = 'severe'
    elif ess_ratio < 0.5:
        warning_level = 'caution'
    else:
        warning_level = 'ok'

    result = {
        'ess': float(ess),
        'ess_ratio': float(ess_ratio),
        'weight_max': float(weight_max),
        'weight_min': float(weight_min),
        'weight_ratio': float(weight_ratio),
        'cv': float(cv),
        'n_extreme': n_extreme,
        'n_negative': n_negative,
        'warning_level': warning_level,
        'group_diagnostics': None,
    }

    if treat is None:
        return result

    treat = np.asarray(treat).ravel()
    if len(treat) != n:
        # Keep the best-effort behaviour (overall diagnostics are still
        # returned) but tell the caller why the per-group part is missing.
        warnings.warn(
            f"treat has {len(treat)} element(s) but weights has {n}; "
            f"group diagnostics are skipped.",
            UserWarning,
            stacklevel=2
        )
        return result

    group_diag = {}
    for level in np.unique(treat):
        g_weights = weights[treat == level]
        g_n = len(g_weights)

        # Same |w|-based Kish formula as for the full sample.
        g_abs_w = np.abs(g_weights)
        g_sum_w = np.sum(g_abs_w)
        g_sum_w2 = np.sum(g_abs_w ** 2)
        g_ess = (g_sum_w ** 2) / g_sum_w2 if g_sum_w2 > 0 else 0.0

        group_diag[level] = {
            'n': g_n,
            'ess': float(g_ess),
            'ess_ratio': float(g_ess / g_n) if g_n > 0 else 0.0,
            'weight_mean': float(np.mean(g_weights)),
            'weight_max': float(np.max(g_abs_w)) if g_n > 0 else np.nan,
            'n_negative': int(np.sum(g_weights < 0)),
        }
    result['group_diagnostics'] = group_diag

    return result
|
cbps/highdim/__init__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""
|
|
2
|
+
High-Dimensional Covariate Balancing Propensity Score (hdCBPS)
|
|
3
|
+
==============================================================
|
|
4
|
+
|
|
5
|
+
This module implements the High-Dimensional Covariate Balancing Propensity
|
|
6
|
+
Score (hdCBPS) methodology for robust causal inference in settings where
|
|
7
|
+
the number of covariates may exceed the sample size (p >> n).
|
|
8
|
+
|
|
9
|
+
Algorithm Overview
|
|
10
|
+
------------------
|
|
11
|
+
The hdCBPS algorithm proceeds in four steps as described in Ning et al. (2020):
|
|
12
|
+
|
|
13
|
+
1. **Propensity Score Estimation** (Equation 5): Fit penalized logistic
|
|
14
|
+
regression (LASSO) to obtain initial propensity score coefficients.
|
|
15
|
+
|
|
16
|
+
2. **Outcome Model Estimation** (Equation 6): Fit penalized regression
|
|
17
|
+
(LASSO) to estimate outcome model coefficients separately for treatment
|
|
18
|
+
and control groups. This implementation uses unweighted LASSO (w_2=1).
|
|
19
|
+
|
|
20
|
+
3. **Covariate Balancing** (Equation 7): Calibrate the propensity score by
|
|
21
|
+
minimizing the GMM objective to balance covariates selected in Step 2.
|
|
22
|
+
This achieves the weak covariate balancing property (Equation 9).
|
|
23
|
+
|
|
24
|
+
4. **Treatment Effect Estimation**: Compute ATE/ATT using the Horvitz-Thompson
|
|
25
|
+
estimator with calibrated propensity scores. Standard errors are computed
|
|
26
|
+
using the sandwich variance estimator (Equation 11).
|
|
27
|
+
|
|
28
|
+
Key Features
|
|
29
|
+
------------
|
|
30
|
+
- **Double Robustness**: Consistent and asymptotically normal when either
|
|
31
|
+
the propensity score model or outcome model is correctly specified.
|
|
32
|
+
- **Sample Boundedness**: Estimated ATE lies within the range of observed
|
|
33
|
+
outcomes, ensuring stable estimates.
|
|
34
|
+
- **Semiparametric Efficiency**: Achieves the efficiency bound when both
|
|
35
|
+
models are correctly specified.
|
|
36
|
+
- **High-Dimensional Support**: Handles p >> n through L1 regularization.
|
|
37
|
+
|
|
38
|
+
Requirements
|
|
39
|
+
------------
|
|
40
|
+
- **glmnetforpython**: Required for LASSO regularization with Fortran backend.
|
|
41
|
+
- numpy, scipy: Numerical computations.
|
|
42
|
+
- pandas: Data handling.
|
|
43
|
+
|
|
44
|
+
References
|
|
45
|
+
----------
|
|
46
|
+
Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
|
|
47
|
+
via a high-dimensional covariate balancing propensity score. Biometrika,
|
|
48
|
+
107(3), 533-554. https://doi.org/10.1093/biomet/asaa020
|
|
49
|
+
|
|
50
|
+
See Also
|
|
51
|
+
--------
|
|
52
|
+
cbps.CBPS : Standard CBPS for low-dimensional settings.
|
|
53
|
+
cbps.CBPSContinuous : CBPS for continuous treatments.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
# Public names are appended below as each group of imports succeeds.
__all__ = []

# hdCBPS and the LASSO helpers are optional: if importing them fails, the
# package still loads and only the weight functions below are exported.
try:
    from .hdcbps import hdCBPS, HDCBPSResults
    from .lasso_utils import cv_glmnet, select_variables
except ImportError as exc:
    import warnings
    # The underlying error is appended so that an ImportError with a
    # different cause is not misreported as just a missing glmnetforpython.
    # NOTE: ImportWarning is hidden by Python's default warning filters.
    warnings.warn(
        "hdCBPS requires glmnetforpython. Install with: "
        "pip install glmnetforpython"
        f" (original import error: {exc})",
        ImportWarning
    )
else:
    __all__.extend(['hdCBPS', 'HDCBPSResults', 'cv_glmnet', 'select_variables'])

# Weight functions (available regardless of glmnet)
from .weight_funcs import (
    ate_wt_func,
    ate_wt_nl_func,
    att_wt_func,
    att_wt_nl_func
)

__all__.extend([
    'ate_wt_func',
    'ate_wt_nl_func',
    'att_wt_func',
    'att_wt_nl_func'
])
|
cbps/highdim/gmm_loss.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GMM Loss Functions for High-Dimensional CBPS
|
|
3
|
+
=============================================
|
|
4
|
+
|
|
5
|
+
This module implements the Generalized Method of Moments (GMM) loss functions
|
|
6
|
+
used in Step 3 of the hdCBPS algorithm for covariate balance calibration.
|
|
7
|
+
|
|
8
|
+
The GMM objective minimizes the squared norm of the covariate balancing
|
|
9
|
+
moment conditions (Equation 7 in Ning et al., 2020):
|
|
10
|
+
|
|
11
|
+
.. math::
|
|
12
|
+
|
|
13
|
+
\\tilde{\\gamma} = \\arg\\min_{\\gamma} \\|g_n(\\gamma)\\|_2^2
|
|
14
|
+
|
|
15
|
+
where the moment function is:
|
|
16
|
+
|
|
17
|
+
.. math::
|
|
18
|
+
|
|
19
|
+
g_n(\\gamma) = \\sum_{i=1}^{n}
|
|
20
|
+
\\left( \\frac{T_i}{\\pi(\\gamma^T X_{i\\tilde{S}} +
|
|
21
|
+
\\hat{\\beta}_{\\tilde{S}^c}^T X_{i\\tilde{S}^c})} - 1 \\right) X_{i\\tilde{S}}
|
|
22
|
+
|
|
23
|
+
Note: The paper defines g_n with a 1/n factor, but this implementation uses
|
|
24
|
+
the sum (without 1/n) since minimizing ||g||^2 and ||g/n||^2 yield identical
|
|
25
|
+
solutions. This calibration step removes bias from the penalized estimators
|
|
26
|
+
and achieves the weak covariate balancing property (Equation 9).
|
|
27
|
+
|
|
28
|
+
References
|
|
29
|
+
----------
|
|
30
|
+
Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
|
|
31
|
+
via a high-dimensional covariate balancing propensity score.
|
|
32
|
+
Biometrika, 107(3), 533-554. https://doi.org/10.1093/biomet/asaa020
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
import numpy as np
|
|
36
|
+
from typing import Tuple
|
|
37
|
+
|
|
38
|
+
from .weight_funcs import ate_wt_func, ate_wt_nl_func, att_wt_func, att_wt_nl_func
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def gmm_func(
    beta_curr: np.ndarray,
    S: np.ndarray,
    tt: int,
    X_gmm: np.ndarray,
    method: str,
    cov1_coef: np.ndarray,
    cov0_coef: np.ndarray,
    treat: np.ndarray,
    beta_ini: np.ndarray
) -> float:
    """
    Compute the GMM loss function for ATE estimation.

    This function evaluates the covariate balancing objective for average
    treatment effect (ATE) estimation in Step 3 of hdCBPS. It computes
    the squared norm of the moment conditions that enforce balance between
    treatment groups.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates being optimized. It is passed through
        unchanged to ``ate_wt_func`` / ``ate_wt_nl_func``. Documented shape:

        - Linear method: shape ``(len(S),)`` without intercept
        - Nonlinear methods: shape ``(len(S)+1,)`` with intercept

    S : np.ndarray
        Indices of LASSO-selected variables from the outcome model.
        Corresponds to :math:`\\tilde{S}` in the paper. In this function the
        indices select columns of ``X_gmm`` *after* a column of ones has
        been prepended (see Notes).
    tt : int
        Treatment group indicator:

        - 1: Optimize for treated group (estimating :math:`\\mu_1`);
          ``cov1_coef`` is used
        - any other value (normally 0): control group
          (estimating :math:`\\mu_0`); ``cov0_coef`` is used

    X_gmm : np.ndarray, shape (n, k)
        Covariate matrix. A column of ones is prepended internally.
    method : str
        Outcome model specification:

        - ``'linear'``: Gaussian outcome model
        - ``'binomial'``: Logistic outcome model
        - ``'poisson'``: Poisson outcome model

    cov1_coef : np.ndarray, shape (k+1,)
        Outcome model coefficients for the treated group. Only used by the
        nonlinear methods; must have one entry per column of the augmented
        design matrix.
    cov0_coef : np.ndarray, shape (k+1,)
        Outcome model coefficients for the control group (same shape rule).
    treat : np.ndarray, shape (n,)
        Binary treatment indicator (0/1), forwarded to the weight function.
    beta_ini : np.ndarray
        Initial propensity score coefficients from LASSO (Step 1),
        forwarded to the weight function.

    Returns
    -------
    loss : float
        GMM loss value: :math:`\\|g_n(\\gamma)\\|_2^2`. With the linear
        method and an empty ``S`` the moment vector is empty and the loss
        is 0.

    Raises
    ------
    ValueError
        If ``method`` is not ``'linear'``, ``'binomial'`` or ``'poisson'``.

    Notes
    -----
    For the linear outcome model, the moment condition is:

    .. math::

        g_n(\\gamma) = \\sum_{i=1}^{n}
        \\left( \\frac{T_i}{\\pi_i} - 1 \\right) X_{i\\tilde{S}}

    For generalized linear models (binomial/poisson), the weighted covariates
    :math:`f(X) = b''(\\tilde{\\alpha}^T X) X_{\\tilde{S}}` are balanced instead,
    as described in Section 4 of the paper. The first moment additionally
    balances the fitted mean itself.

    The design matrix is built as ``[1, X_gmm]`` to mirror the reference R
    code (``X1 = cbind(rep(1, n1), x1)``); ``S`` indexes that augmented
    matrix. NOTE(review): the original porting note says the R ``x1``
    already contains an intercept, which would make this an *extra* column;
    whether ``X_gmm`` carries an intercept here cannot be confirmed from
    this module -- verify against the caller.

    The logistic mean and variance are evaluated through ``exp(-|eta|)`` so
    that large linear predictors no longer overflow to ``inf / inf = NaN``.
    """
    x1 = np.asarray(X_gmm)
    n1 = x1.shape[0]

    # Augmented design matrix: column of ones followed by the covariates.
    X1 = np.column_stack([np.ones(n1), x1])

    has_selected = len(S) > 0

    if method == "linear":
        W = ate_wt_func(beta_curr, S, tt, x1, beta_ini, treat)

        if has_selected:
            gbar = np.asarray(X1[:, S].T @ W).ravel()
        else:
            # No selected variables: empty moment vector, zero loss.
            gbar = np.array([])

    elif method == "poisson":
        W = ate_wt_nl_func(beta_curr, S, tt, x1, beta_ini, treat)

        # Outcome-model coefficients of the arm being calibrated.
        outcome_coef = cov1_coef if tt == 1 else cov0_coef
        # For the log link the fitted mean and b'' coincide.
        pweight = np.exp(X1 @ outcome_coef)

        if has_selected:
            weighted_X = np.column_stack([
                pweight,
                pweight[:, None] * X1[:, S]
            ])
            gbar = np.asarray(weighted_X.T @ W).ravel()
        else:
            # Only the fitted-mean moment remains.
            gbar = np.asarray([pweight @ W]).ravel()

    elif method == "binomial":
        W = ate_wt_nl_func(beta_curr, S, tt, x1, beta_ini, treat)

        outcome_coef = cov1_coef if tt == 1 else cov0_coef
        eta = np.asarray(X1 @ outcome_coef)

        # decay = exp(-|eta|) lies in [0, 1], so nothing below can overflow.
        decay = np.exp(-np.abs(eta))
        # Logistic mean: 1/(1+e^-eta) for eta >= 0, e^eta/(1+e^eta) otherwise.
        pweight1 = np.where(eta >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
        # Logistic variance p(1-p) is symmetric in eta.
        pweight2 = decay / (1.0 + decay) ** 2

        if has_selected:
            weighted_X = np.column_stack([
                pweight1,
                pweight2[:, None] * X1[:, S]
            ])
            gbar = np.asarray(weighted_X.T @ W).ravel()
        else:
            # Only the fitted-probability moment remains.
            gbar = np.asarray([pweight1 @ W]).ravel()

    else:
        raise ValueError(
            f"method '{method}' not supported. "
            f"Choose from: 'linear', 'binomial', 'poisson'"
        )

    # GMM loss: squared Euclidean norm of the moment vector.
    return float(gbar @ gbar)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def att_gmm_func(
    beta_curr: np.ndarray,
    S: np.ndarray,
    X_gmm: np.ndarray,
    method: str,
    cov0_coef: np.ndarray,
    treat: np.ndarray,
    beta_ini: np.ndarray
) -> float:
    """
    Compute the GMM loss function for ATT estimation.

    This function evaluates the covariate balancing objective for the average
    treatment effect on the treated (ATT) in Step 3 of hdCBPS. For ATT, only
    the control group propensity score is calibrated to match the treated
    group covariate distribution.

    Parameters
    ----------
    beta_curr : np.ndarray
        Current coefficient estimates being optimized. It is passed through
        unchanged to ``att_wt_func`` / ``att_wt_nl_func``. Documented shape:

        - Linear method: shape ``(len(S),)`` without intercept
        - Nonlinear methods: shape ``(len(S)+1,)`` with intercept

    S : np.ndarray
        Indices of LASSO-selected variables from the control outcome model.
        In this function the indices select columns of ``X_gmm`` *after* a
        column of ones has been prepended (see Notes).
    X_gmm : np.ndarray, shape (n, k)
        Covariate matrix. A column of ones is prepended internally.
    method : str
        Outcome model specification: ``'linear'``, ``'binomial'``, or ``'poisson'``.
    cov0_coef : np.ndarray, shape (k+1,)
        Outcome model coefficients for the control group. Only used by the
        nonlinear methods; must have one entry per column of the augmented
        design matrix.
    treat : np.ndarray, shape (n,)
        Binary treatment indicator (0/1), forwarded to the weight function.
    beta_ini : np.ndarray
        Initial propensity score coefficients from LASSO, forwarded to the
        weight function.

    Returns
    -------
    loss : float
        GMM loss value: :math:`\\|g_n(\\gamma)\\|_2^2`. With the linear
        method and an empty ``S`` the moment vector is empty and the loss
        is 0.

    Raises
    ------
    ValueError
        If ``method`` is not ``'linear'``, ``'binomial'`` or ``'poisson'``.

    Notes
    -----
    Unlike ATE estimation which requires separate optimization for treated
    and control groups, ATT estimation only requires calibrating the control
    group weights to match the treated group distribution. The ATT moment
    condition ensures that the reweighted control group has the same covariate
    means as the treated group.

    The design matrix is built as ``[1, X_gmm]`` to mirror the reference R
    code (``X1 = cbind(rep(1, n1), x1)``); ``S`` indexes that augmented
    matrix. NOTE(review): the original porting note says the R ``x1``
    already contains an intercept, which would make this an *extra* column;
    whether ``X_gmm`` carries an intercept here cannot be confirmed from
    this module -- verify against the caller.

    The logistic mean and variance are evaluated through ``exp(-|eta|)`` so
    that large linear predictors no longer overflow to ``inf / inf = NaN``.

    See the Supplementary Material of Ning et al. (2020) for theoretical
    details on ATT estimation in high-dimensional settings.
    """
    x1 = np.asarray(X_gmm)
    n1 = x1.shape[0]

    # Augmented design matrix: column of ones followed by the covariates.
    X1 = np.column_stack([np.ones(n1), x1])

    has_selected = len(S) > 0

    if method == "linear":
        W = att_wt_func(beta_curr, S, x1, beta_ini, treat)

        if has_selected:
            gbar = np.asarray(X1[:, S].T @ W).ravel()
        else:
            # No selected variables: empty moment vector, zero loss.
            gbar = np.array([])

    elif method == "poisson":
        W = att_wt_nl_func(beta_curr, S, x1, beta_ini, treat)

        # For the log link the fitted control mean and b'' coincide.
        pweight = np.exp(X1 @ cov0_coef)

        if has_selected:
            weighted_X = np.column_stack([
                pweight,
                pweight[:, None] * X1[:, S]
            ])
            gbar = np.asarray(weighted_X.T @ W).ravel()
        else:
            # Only the fitted-mean moment remains.
            gbar = np.asarray([pweight @ W]).ravel()

    elif method == "binomial":
        W = att_wt_nl_func(beta_curr, S, x1, beta_ini, treat)

        eta = np.asarray(X1 @ cov0_coef)

        # decay = exp(-|eta|) lies in [0, 1], so nothing below can overflow.
        decay = np.exp(-np.abs(eta))
        # Logistic mean: 1/(1+e^-eta) for eta >= 0, e^eta/(1+e^eta) otherwise.
        pweight1 = np.where(eta >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
        # Logistic variance p(1-p) is symmetric in eta.
        pweight2 = decay / (1.0 + decay) ** 2

        if has_selected:
            weighted_X = np.column_stack([
                pweight1,
                pweight2[:, None] * X1[:, S]
            ])
            gbar = np.asarray(weighted_X.T @ W).ravel()
        else:
            # Only the fitted-probability moment remains.
            gbar = np.asarray([pweight1 @ W]).ravel()

    else:
        raise ValueError(
            f"method '{method}' not supported. "
            f"Choose from: 'linear', 'binomial', 'poisson'"
        )

    # GMM loss: squared Euclidean norm of the moment vector.
    return float(gbar @ gbar)
|