cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,205 @@
1
+ """
2
+ Weight Quality Diagnostics
3
+ ===========================
4
+
5
+ Comprehensive diagnostics for inverse probability weights produced by CBPS
6
+ estimation, including effective sample size (ESS), weight distribution
7
+ summaries, and extreme value detection.
8
+
9
+ The Kish (1965) effective sample size is the primary metric for assessing
10
+ whether extreme weights are degrading estimation precision.
11
+
12
+ References
13
+ ----------
14
+ Kish, L. (1965). Survey Sampling. Wiley, New York.
15
+
16
+ Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
17
+ Journal of the Royal Statistical Society, Series B, 76(1), 243-263.
18
+ """
19
+
20
+ import numpy as np
21
+ from typing import Optional
22
+
23
+
24
def weight_diagnostics(weights, treat=None):
    """Compute comprehensive weight quality diagnostics.

    Summarizes a vector of weights with the Kish (1965) effective sample
    size (ESS) and a handful of distribution indicators.

    Parameters
    ----------
    weights : array-like
        Weights to diagnose. Converted to a flat float array. If any
        weight is negative a ``UserWarning`` is issued and every statistic
        that needs non-negative input is computed on ``abs(weights)``.
    treat : array-like, optional
        Group labels, one per weight. When given (and of matching length)
        per-group diagnostics are added to the result. When the length
        does not match ``weights`` a ``UserWarning`` is issued and the
        group diagnostics are left as ``None``.

    Returns
    -------
    dict
        - ``ess``: Kish effective sample size, ``sum(|w|)**2 / sum(w**2)``
          (0.0 when all weights are zero)
        - ``ess_ratio``: ``ess / n``
        - ``weight_max`` / ``weight_min``: largest / smallest absolute
          weight among the nonzero ones (both 0.0 when none is nonzero)
        - ``weight_ratio``: ``weight_max / weight_min`` (NaN when no
          nonzero weight exists)
        - ``cv``: std / mean of the absolute weights (NaN when the mean is
          not positive)
        - ``n_extreme``: number of weights with
          ``abs(w) > 10 * median(abs(w))``; when the median is not
          positive, every nonzero weight is counted
        - ``n_negative``: number of negative weights
        - ``warning_level``: ``'ok'``, ``'caution'`` or ``'severe'``
        - ``group_diagnostics``: ``None``, or a dict keyed by each unique
          value of ``treat`` with ``n``, ``ess``, ``ess_ratio``,
          ``weight_mean`` (signed mean), ``weight_max`` and ``n_negative``

    Notes
    -----
    Warning thresholds:

    - ``ess_ratio < 0.5``  -> ``'caution'``
    - ``ess_ratio < 0.2``  -> ``'severe'``
    - ``ess_ratio`` not finite (NaN/inf weights present) -> ``'severe'``

    An empty input returns ``ess == 0.0`` and ``warning_level == 'severe'``.

    For uniform weights ESS equals ``n``; for highly variable weights
    ESS is much smaller than ``n``. The Kish formula has no standard
    interpretation for negative weights, hence the fallback to absolute
    values; ``n_negative`` in the result reports how many were seen.

    References
    ----------
    Kish, L. (1965). Survey Sampling. Wiley, New York. Chapter 11.
    """
    import warnings

    weights = np.asarray(weights, dtype=float).ravel()
    n = len(weights)

    # Degenerate case: nothing to diagnose.
    if n == 0:
        return {
            'ess': 0.0,
            'ess_ratio': 0.0,
            'weight_max': np.nan,
            'weight_min': np.nan,
            'weight_ratio': np.nan,
            'cv': np.nan,
            'n_extreme': 0,
            'n_negative': 0,
            'warning_level': 'severe',
            'group_diagnostics': None,
        }

    n_negative = int(np.sum(weights < 0))
    if n_negative > 0:
        warnings.warn(
            f"Kish ESS is defined for non-negative weights. "
            f"{n_negative} negative weight(s) detected; "
            f"ESS is computed on |weights| as an approximation. "
            f"Consider using only the final IPW weights (not balance weights) "
            f"for this diagnostic.",
            UserWarning,
            stacklevel=2
        )

    # abs() is a no-op for non-negative input, so one array serves both the
    # "negatives present" and the ordinary case.
    abs_weights = np.abs(weights)

    sum_w = np.sum(abs_weights)
    sum_w2 = np.sum(abs_weights ** 2)

    # ESS computation (Kish 1965); all-zero weights carry no information.
    if sum_w2 == 0:
        ess = 0.0
        ess_ratio = 0.0
    else:
        ess = (sum_w ** 2) / sum_w2
        ess_ratio = ess / n

    # Range over nonzero absolute weights (NaN entries fail the > 0 test
    # and are therefore excluded here).
    nonzero_mask = abs_weights > 0
    if np.any(nonzero_mask):
        weight_min = float(np.min(abs_weights[nonzero_mask]))
        weight_max = float(np.max(abs_weights[nonzero_mask]))
    else:
        weight_min = 0.0
        weight_max = 0.0

    if weight_min > 0:
        weight_ratio = weight_max / weight_min
    else:
        weight_ratio = np.inf if weight_max > 0 else np.nan

    # Coefficient of variation of the absolute weights.
    w_mean = np.mean(abs_weights)
    if w_mean > 0:
        cv = float(np.std(abs_weights) / w_mean)
    else:
        cv = np.nan

    # Extreme weight count: abs(w) > 10 * median(abs(w)).
    median_abs_w = np.median(abs_weights)
    if median_abs_w > 0:
        n_extreme = int(np.sum(abs_weights > 10 * median_abs_w))
    else:
        # With a non-positive median, any nonzero weight is an outlier.
        n_extreme = int(np.sum(abs_weights > 0))

    # A NaN ratio compares False against every threshold, which previously
    # let NaN/inf weights fall through to 'ok'. Flag it explicitly.
    if not np.isfinite(ess_ratio) or ess_ratio < 0.2:
        warning_level = 'severe'
    elif ess_ratio < 0.5:
        warning_level = 'caution'
    else:
        warning_level = 'ok'

    result = {
        'ess': float(ess),
        'ess_ratio': float(ess_ratio),
        'weight_max': float(weight_max),
        'weight_min': float(weight_min),
        'weight_ratio': float(weight_ratio) if np.isfinite(weight_ratio) else weight_ratio,
        'cv': float(cv) if np.isfinite(cv) else cv,
        'n_extreme': n_extreme,
        'n_negative': n_negative,
        'warning_level': warning_level,
        'group_diagnostics': None,
    }

    # Group-specific diagnostics
    if treat is not None:
        treat = np.asarray(treat).ravel()
        if len(treat) == n:
            group_diag = {}
            for level in np.unique(treat):
                g_weights = weights[treat == level]
                g_n = len(g_weights)

                g_abs_w = np.abs(g_weights)
                g_sum_w = np.sum(g_abs_w)
                g_sum_w2 = np.sum(g_abs_w ** 2)

                if g_sum_w2 > 0:
                    g_ess = (g_sum_w ** 2) / g_sum_w2
                else:
                    g_ess = 0.0

                group_diag[level] = {
                    'n': g_n,
                    'ess': float(g_ess),
                    'ess_ratio': float(g_ess / g_n) if g_n > 0 else 0.0,
                    'weight_mean': float(np.mean(g_weights)),
                    'weight_max': float(np.max(g_abs_w)) if g_n > 0 else np.nan,
                    'n_negative': int(np.sum(g_weights < 0)),
                }
            result['group_diagnostics'] = group_diag
        else:
            # Previously ignored without notice; keep the result unchanged
            # but tell the caller why the group breakdown is missing.
            warnings.warn(
                f"treat has length {len(treat)} but weights has length {n}; "
                f"group diagnostics were skipped.",
                UserWarning,
                stacklevel=2
            )

    return result
@@ -0,0 +1,84 @@
1
+ """
2
+ High-Dimensional Covariate Balancing Propensity Score (hdCBPS)
3
+ ==============================================================
4
+
5
+ This module implements the High-Dimensional Covariate Balancing Propensity
6
+ Score (hdCBPS) methodology for robust causal inference in settings where
7
+ the number of covariates may exceed the sample size (p >> n).
8
+
9
+ Algorithm Overview
10
+ ------------------
11
+ The hdCBPS algorithm proceeds in four steps as described in Ning et al. (2020):
12
+
13
+ 1. **Propensity Score Estimation** (Equation 5): Fit penalized logistic
14
+ regression (LASSO) to obtain initial propensity score coefficients.
15
+
16
+ 2. **Outcome Model Estimation** (Equation 6): Fit penalized regression
17
+ (LASSO) to estimate outcome model coefficients separately for treatment
18
+ and control groups. This implementation uses unweighted LASSO (w_2=1).
19
+
20
+ 3. **Covariate Balancing** (Equation 7): Calibrate the propensity score by
21
+ minimizing the GMM objective to balance covariates selected in Step 2.
22
+ This achieves the weak covariate balancing property (Equation 9).
23
+
24
+ 4. **Treatment Effect Estimation**: Compute ATE/ATT using the Horvitz-Thompson
25
+ estimator with calibrated propensity scores. Standard errors are computed
26
+ using the sandwich variance estimator (Equation 11).
27
+
28
+ Key Features
29
+ ------------
30
+ - **Double Robustness**: Consistent and asymptotically normal when either
31
+ the propensity score model or outcome model is correctly specified.
32
+ - **Sample Boundedness**: Estimated ATE lies within the range of observed
33
+ outcomes, ensuring stable estimates.
34
+ - **Semiparametric Efficiency**: Achieves the efficiency bound when both
35
+ models are correctly specified.
36
+ - **High-Dimensional Support**: Handles p >> n through L1 regularization.
37
+
38
+ Requirements
39
+ ------------
40
+ - **glmnetforpython**: Required for LASSO regularization with Fortran backend.
41
+ - numpy, scipy: Numerical computations.
42
+ - pandas: Data handling.
43
+
44
+ References
45
+ ----------
46
+ Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
47
+ via a high-dimensional covariate balancing propensity score. Biometrika,
48
+ 107(3), 533-554. https://doi.org/10.1093/biomet/asaa020
49
+
50
+ See Also
51
+ --------
52
+ cbps.CBPS : Standard CBPS for low-dimensional settings.
53
+ cbps.CBPSContinuous : CBPS for continuous treatments.
54
+ """
55
+
56
# Public names of the package; filled in below as each import succeeds.
__all__ = []

# The estimator and its LASSO helpers are optional: if importing them raises
# ImportError, an ImportWarning is emitted and those names are simply not
# exported.
try:
    from .hdcbps import hdCBPS, HDCBPSResults
    from .lasso_utils import cv_glmnet, select_variables
except ImportError:
    import warnings
    warnings.warn(
        "hdCBPS requires glmnetforpython. Install with: "
        "pip install glmnetforpython",
        ImportWarning
    )
else:
    __all__ += ['hdCBPS', 'HDCBPSResults', 'cv_glmnet', 'select_variables']

# Weight functions are imported unconditionally (outside the guarded block).
from .weight_funcs import (
    ate_wt_func,
    ate_wt_nl_func,
    att_wt_func,
    att_wt_nl_func
)

__all__ += [
    'ate_wt_func',
    'ate_wt_nl_func',
    'att_wt_func',
    'att_wt_nl_func'
]
@@ -0,0 +1,340 @@
1
+ """
2
+ GMM Loss Functions for High-Dimensional CBPS
3
+ =============================================
4
+
5
+ This module implements the Generalized Method of Moments (GMM) loss functions
6
+ used in Step 3 of the hdCBPS algorithm for covariate balance calibration.
7
+
8
+ The GMM objective minimizes the squared norm of the covariate balancing
9
+ moment conditions (Equation 7 in Ning et al., 2020):
10
+
11
+ .. math::
12
+
13
+ \\tilde{\\gamma} = \\arg\\min_{\\gamma} \\|g_n(\\gamma)\\|_2^2
14
+
15
+ where the moment function is:
16
+
17
+ .. math::
18
+
19
+ g_n(\\gamma) = \\sum_{i=1}^{n}
20
+ \\left( \\frac{T_i}{\\pi(\\gamma^T X_{i\\tilde{S}} +
21
+ \\hat{\\beta}_{\\tilde{S}^c}^T X_{i\\tilde{S}^c})} - 1 \\right) X_{i\\tilde{S}}
22
+
23
+ Note: The paper defines g_n with a 1/n factor, but this implementation uses
24
+ the sum (without 1/n) since minimizing ||g||^2 and ||g/n||^2 yield identical
25
+ solutions. This calibration step removes bias from the penalized estimators
26
+ and achieves the weak covariate balancing property (Equation 9).
27
+
28
+ References
29
+ ----------
30
+ Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
31
+ via a high-dimensional covariate balancing propensity score.
32
+ Biometrika, 107(3), 533-554. https://doi.org/10.1093/biomet/asaa020
33
+ """
34
+
35
+ import numpy as np
36
+ from typing import Tuple
37
+
38
+ from .weight_funcs import ate_wt_func, ate_wt_nl_func, att_wt_func, att_wt_nl_func
39
+
40
+
41
def gmm_func(
    beta_curr: np.ndarray,
    S: np.ndarray,
    tt: int,
    X_gmm: np.ndarray,
    method: str,
    cov1_coef: np.ndarray,
    cov0_coef: np.ndarray,
    treat: np.ndarray,
    beta_ini: np.ndarray
) -> float:
    """
    Evaluate the ATE covariate-balancing GMM loss.

    A column of ones is prepended to ``X_gmm``; the weight vector for the
    current coefficients is obtained from the ATE weight function; the
    moment vector is the weighted column sums of the balanced terms; the
    loss is the squared Euclidean norm of that vector.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients currently being optimized. Forwarded unchanged to the
        weight function.
    S : np.ndarray
        Column indices selecting the terms to balance. They index the
        ones-augmented matrix built here, so index 0 refers to the added
        column of ones. NOTE(review): confirm callers pass indices in this
        convention.
    tt : int
        Arm selector. ``tt == 1`` uses ``cov1_coef``; any other value uses
        ``cov0_coef``. Also forwarded to the weight function.
    X_gmm : np.ndarray
        Covariate matrix with one row per observation. NOTE(review): the
        ported comments and the earlier docstring disagree on whether this
        already carries an intercept column - verify against the caller.
    method : str
        ``'linear'``, ``'binomial'`` or ``'poisson'``.
    cov1_coef, cov0_coef : np.ndarray
        Outcome-model coefficients. Used only for ``'binomial'`` and
        ``'poisson'``, where they are multiplied by the ones-augmented
        matrix and so need one entry per column of it.
    treat : np.ndarray
        Treatment indicator, forwarded to the weight function.
    beta_ini : np.ndarray
        Initial propensity-score coefficients, forwarded to the weight
        function.

    Returns
    -------
    float
        Squared norm of the moment vector. For ``'linear'`` with an empty
        ``S`` the moment vector is empty and the loss is 0.0.

    Raises
    ------
    ValueError
        If ``method`` is not one of the three supported values.

    Notes
    -----
    Terms balanced per method (``eta`` is the outcome linear predictor):

    - ``'linear'``: the selected columns themselves.
    - ``'poisson'``: ``exp(eta)`` followed by ``exp(eta)`` times each
      selected column.
    - ``'binomial'``: ``exp(eta) / (1 + exp(eta))`` followed by
      ``exp(eta) / (1 + exp(eta))**2`` times each selected column.
    """
    covariates = np.asarray(X_gmm)
    n_obs = covariates.shape[0]

    # Mirrors the reference implementation, which always prepends a column
    # of ones before indexing with S.
    design = np.column_stack([np.ones(n_obs), covariates])

    if method == "linear":
        weights = ate_wt_func(beta_curr, S, tt, covariates, beta_ini, treat)
        if len(S) > 0:
            moments = np.asarray(design[:, S].T @ weights).ravel()
        else:
            moments = np.array([])

    elif method in ("poisson", "binomial"):
        weights = ate_wt_nl_func(beta_curr, S, tt, covariates, beta_ini, treat)

        # Outcome-model coefficients of the arm being calibrated.
        outcome_coef = cov1_coef if tt == 1 else cov0_coef
        exp_eta = np.exp(design @ outcome_coef)

        if method == "poisson":
            # The same factor serves as leading term and column scale.
            leading = exp_eta
            scale = exp_eta
        else:
            # Logistic mean and its derivative.
            leading = exp_eta / (1.0 + exp_eta)
            scale = exp_eta / (1.0 + exp_eta)**2

        if len(S) > 0:
            balanced = np.column_stack([
                leading,
                scale[:, None] * design[:, S]
            ])
            moments = np.asarray(balanced.T @ weights).ravel()
        else:
            # Nothing selected: only the leading term is balanced.
            moments = np.asarray([leading @ weights]).ravel()

    else:
        raise ValueError(
            f"method '{method}' not supported. "
            f"Choose from: 'linear', 'binomial', 'poisson'"
        )

    return float(moments @ moments)
204
+
205
+
206
def att_gmm_func(
    beta_curr: np.ndarray,
    S: np.ndarray,
    X_gmm: np.ndarray,
    method: str,
    cov0_coef: np.ndarray,
    treat: np.ndarray,
    beta_ini: np.ndarray
) -> float:
    """
    Evaluate the ATT covariate-balancing GMM loss.

    Same construction as the ATE loss, but with the ATT weight functions
    and only the control-group outcome coefficients: a column of ones is
    prepended to ``X_gmm``, the weight vector is computed, and the loss is
    the squared Euclidean norm of the weighted column sums.

    Parameters
    ----------
    beta_curr : np.ndarray
        Coefficients currently being optimized. Forwarded unchanged to the
        weight function.
    S : np.ndarray
        Column indices selecting the terms to balance. They index the
        ones-augmented matrix built here, so index 0 refers to the added
        column of ones. NOTE(review): confirm callers pass indices in this
        convention.
    X_gmm : np.ndarray
        Covariate matrix with one row per observation. NOTE(review): the
        ported comments and the earlier docstring disagree on whether this
        already carries an intercept column - verify against the caller.
    method : str
        ``'linear'``, ``'binomial'`` or ``'poisson'``.
    cov0_coef : np.ndarray
        Control-group outcome-model coefficients. Used only for
        ``'binomial'`` and ``'poisson'``, where they are multiplied by the
        ones-augmented matrix and so need one entry per column of it.
    treat : np.ndarray
        Treatment indicator, forwarded to the weight function.
    beta_ini : np.ndarray
        Initial propensity-score coefficients, forwarded to the weight
        function.

    Returns
    -------
    float
        Squared norm of the moment vector. For ``'linear'`` with an empty
        ``S`` the moment vector is empty and the loss is 0.0.

    Raises
    ------
    ValueError
        If ``method`` is not one of the three supported values.

    Notes
    -----
    Terms balanced per method (``eta`` is the control outcome linear
    predictor):

    - ``'linear'``: the selected columns themselves.
    - ``'poisson'``: ``exp(eta)`` followed by ``exp(eta)`` times each
      selected column.
    - ``'binomial'``: ``exp(eta) / (1 + exp(eta))`` followed by
      ``exp(eta) / (1 + exp(eta))**2`` times each selected column.
    """
    covariates = np.asarray(X_gmm)
    n_obs = covariates.shape[0]

    # Mirrors the reference implementation, which always prepends a column
    # of ones before indexing with S.
    design = np.column_stack([np.ones(n_obs), covariates])

    if method == "linear":
        weights = att_wt_func(beta_curr, S, covariates, beta_ini, treat)
        if len(S) > 0:
            moments = np.asarray(design[:, S].T @ weights).ravel()
        else:
            moments = np.array([])

    elif method in ("poisson", "binomial"):
        weights = att_wt_nl_func(beta_curr, S, covariates, beta_ini, treat)

        # Only the control outcome model enters the ATT moments.
        exp_eta = np.exp(design @ cov0_coef)

        if method == "poisson":
            # The same factor serves as leading term and column scale.
            leading = exp_eta
            scale = exp_eta
        else:
            # Logistic mean and its derivative.
            leading = exp_eta / (1.0 + exp_eta)
            scale = exp_eta / (1.0 + exp_eta)**2

        if len(S) > 0:
            balanced = np.column_stack([
                leading,
                scale[:, None] * design[:, S]
            ])
            moments = np.asarray(balanced.T @ weights).ravel()
        else:
            # Nothing selected: only the leading term is balanced.
            moments = np.asarray([leading @ weights]).ravel()

    else:
        raise ValueError(
            f"method '{method}' not supported. "
            f"Choose from: 'linear', 'binomial', 'poisson'"
        )

    return float(moments @ moments)