cbps 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbps/__init__.py +3462 -0
- cbps/constants.py +46 -0
- cbps/core/__init__.py +93 -0
- cbps/core/cbps_binary.py +1943 -0
- cbps/core/cbps_continuous.py +945 -0
- cbps/core/cbps_multitreat.py +1123 -0
- cbps/core/cbps_optimal.py +507 -0
- cbps/core/results.py +1447 -0
- cbps/data/Blackwell.csv +571 -0
- cbps/data/LaLonde.csv +3213 -0
- cbps/data/npcbps_continuous_sim.csv +501 -0
- cbps/data/nsw.csv +723 -0
- cbps/data/nsw_dw.csv +446 -0
- cbps/data/political_ads_urban_niebler.csv +16266 -0
- cbps/data/psid_controls.csv +2491 -0
- cbps/data/psid_controls2.csv +254 -0
- cbps/data/psid_controls3.csv +129 -0
- cbps/data/simulation_dgp1_seed12345.csv +201 -0
- cbps/data/simulation_dgp2_seed12345.csv +201 -0
- cbps/data/simulation_dgp3_seed12345.csv +201 -0
- cbps/data/simulation_dgp4_seed12345.csv +201 -0
- cbps/datasets/__init__.py +78 -0
- cbps/datasets/blackwell.py +112 -0
- cbps/datasets/continuous.py +223 -0
- cbps/datasets/lalonde.py +272 -0
- cbps/datasets/npcbps_sim.py +101 -0
- cbps/diagnostics/__init__.py +101 -0
- cbps/diagnostics/balance.py +760 -0
- cbps/diagnostics/balance_cbmsm_addon.py +162 -0
- cbps/diagnostics/continuous_diagnostics.py +259 -0
- cbps/diagnostics/normality.py +173 -0
- cbps/diagnostics/ocbps_conditions.py +197 -0
- cbps/diagnostics/overlap.py +198 -0
- cbps/diagnostics/plots.py +1193 -0
- cbps/diagnostics/weights_diag.py +205 -0
- cbps/highdim/__init__.py +84 -0
- cbps/highdim/gmm_loss.py +340 -0
- cbps/highdim/hdcbps.py +1078 -0
- cbps/highdim/lasso_utils.py +498 -0
- cbps/highdim/weight_funcs.py +298 -0
- cbps/inference/__init__.py +42 -0
- cbps/inference/asyvar.py +621 -0
- cbps/inference/vcov_outcome.py +217 -0
- cbps/iv/__init__.py +48 -0
- cbps/iv/cbiv.py +2603 -0
- cbps/logging_config.py +45 -0
- cbps/msm/__init__.py +45 -0
- cbps/msm/cbmsm.py +1871 -0
- cbps/msm/rank_diagnostics.py +112 -0
- cbps/nonparametric/__init__.py +58 -0
- cbps/nonparametric/cholesky_whitening.py +232 -0
- cbps/nonparametric/empirical_likelihood.py +339 -0
- cbps/nonparametric/npcbps.py +1036 -0
- cbps/nonparametric/taylor_approx.py +207 -0
- cbps/py.typed +0 -0
- cbps/sklearn/__init__.py +42 -0
- cbps/sklearn/estimator.py +378 -0
- cbps/utils/__init__.py +82 -0
- cbps/utils/formula.py +415 -0
- cbps/utils/helpers.py +378 -0
- cbps/utils/numerics.py +438 -0
- cbps/utils/r_compat.py +109 -0
- cbps/utils/validation.py +224 -0
- cbps/utils/variance_transform.py +483 -0
- cbps/utils/weights.py +586 -0
- cbps-0.2.0.dist-info/METADATA +1090 -0
- cbps-0.2.0.dist-info/RECORD +70 -0
- cbps-0.2.0.dist-info/WHEEL +5 -0
- cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
- cbps-0.2.0.dist-info/top_level.txt +1 -0
cbps/__init__.py
ADDED
|
@@ -0,0 +1,3462 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Covariate Balancing Propensity Score (CBPS)
|
|
3
|
+
===========================================
|
|
4
|
+
|
|
5
|
+
A comprehensive Python implementation of the covariate balancing propensity score
|
|
6
|
+
methodology for causal inference from observational studies.
|
|
7
|
+
|
|
8
|
+
The CBPS approach improves propensity score estimation by directly incorporating
|
|
9
|
+
covariate balance conditions into the estimation procedure [1]_. Unlike traditional
|
|
10
|
+
propensity score methods that solely maximize the likelihood of treatment assignment,
|
|
11
|
+
CBPS estimates propensity scores by solving moment conditions that simultaneously
|
|
12
|
+
optimize covariate balance between treatment groups while maintaining predictive power.
|
|
13
|
+
|
|
14
|
+
This approach is implemented through the generalized method of moments
|
|
15
|
+
(GMM) framework, where the objective function combines the score function
|
|
16
|
+
for treatment prediction with moment conditions ensuring covariate balance. The resulting
|
|
17
|
+
estimator typically improves finite-sample covariate balance while preserving the
|
|
18
|
+
asymptotic properties of conventional propensity score methods.
|
|
19
|
+
|
|
20
|
+
Methodological Framework
|
|
21
|
+
------------------------
|
|
22
|
+
|
|
23
|
+
For a binary treatment :math:`T \\in \\{0,1\\}` and covariates :math:`X`, the CBPS
|
|
24
|
+
estimator :math:`\\hat{\\beta}` solves the following moment conditions:
|
|
25
|
+
|
|
26
|
+
.. math::
|
|
27
|
+
\\frac{1}{n} \\sum_{i=1}^n \\psi_i(\\beta) = 0
|
|
28
|
+
|
|
29
|
+
where the moment function :math:`\\psi_i(\\beta)` combines:
|
|
30
|
+
|
|
31
|
+
1. **Score function**: :math:`\\psi_i^{(1)}(\\beta) = \\{T_i - e(X_i,\\beta)\\} X_i`
|
|
32
|
+
2. **Balance conditions** (ATE): :math:`\\psi_i^{(2)}(\\beta) = \\frac{T_i X_i}{e(X_i,\\beta)} - \\frac{(1 - T_i) X_i}{1 - e(X_i,\\beta)}`
|
|
33
|
+
|
|
34
|
+
with :math:`e(X_i,\\beta)` denoting the propensity score model.
|
|
35
|
+
|
|
36
|
+
Key Features
|
|
37
|
+
------------
|
|
38
|
+
|
|
39
|
+
* **Binary Treatments**: Robust estimation of average treatment effects (ATE) and
|
|
40
|
+
average treatment effects on the treated (ATT) using logistic models [1]_
|
|
41
|
+
|
|
42
|
+
* **Multi-valued Treatments**: Seamless extension to categorical treatments via
|
|
43
|
+
multinomial logistic regression supporting treatments with three or four levels
|
|
44
|
+
|
|
45
|
+
* **Continuous Treatments**: Generalized propensity scores for continuous
|
|
46
|
+
treatment variables using flexible parametric distributions [2]_
|
|
47
|
+
|
|
48
|
+
* **High-dimensional Settings**: State-of-the-art regularization through LASSO
|
|
49
|
+
when the number of covariates exceeds the sample size, with automatic variable
|
|
50
|
+
selection and valid post-selection inference [3]_
|
|
51
|
+
|
|
52
|
+
* **Nonparametric Estimation**: Empirical likelihood methods that completely
|
|
53
|
+
avoid parametric modeling assumptions about the propensity score [4]_
|
|
54
|
+
|
|
55
|
+
* **Longitudinal Data**: Marginal structural models for time-varying treatments
|
|
56
|
+
with time-dependent confounding, extending causal inference to complex study designs [5]_
|
|
57
|
+
|
|
58
|
+
* **Instrumental Variables**: Comprehensive support for treatment noncompliance
|
|
59
|
+
and instrumental variable assignment scenarios [6]_
|
|
60
|
+
|
|
61
|
+
Implementation Highlights
|
|
62
|
+
--------------------------
|
|
63
|
+
|
|
64
|
+
- **Automatic Treatment Detection**: Intelligent recognition of binary, multi-valued,
|
|
65
|
+
and continuous treatments based on data characteristics
|
|
66
|
+
- **Dual Interface Design**: Both intuitive patsy formula interface and efficient
|
|
67
|
+
NumPy array interface for different usage patterns
|
|
68
|
+
- **Advanced GMM Options**: Two-step and continuous updating GMM estimators for
|
|
69
|
+
different precision and speed requirements
|
|
70
|
+
- **Numerical Stability**: Robust optimization with enhanced convergence diagnostics
|
|
71
|
+
and graceful failure handling
|
|
72
|
+
- **High Precision**: Maintains ±1e-6 numerical accuracy for core algorithms,
|
|
73
|
+
ensuring reproducible research results
|
|
74
|
+
|
|
75
|
+
References
|
|
76
|
+
----------
|
|
77
|
+
.. [1] Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
|
|
78
|
+
Journal of the Royal Statistical Society, Series B 76(1), 243-263.
|
|
79
|
+
https://doi.org/10.1111/rssb.12027
|
|
80
|
+
|
|
81
|
+
.. [2] Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
|
|
82
|
+
score for a continuous treatment: Application to the efficacy of political
|
|
83
|
+
advertisements. The Annals of Applied Statistics 12(1), 156-177.
|
|
84
|
+
https://doi.org/10.1214/17-AOAS1101
|
|
85
|
+
|
|
86
|
+
.. [3] Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
|
|
87
|
+
via a high-dimensional covariate balancing propensity score. Biometrika 107(3),
|
|
88
|
+
533-554. https://doi.org/10.1093/biomet/asaa020
|
|
89
|
+
|
|
90
|
+
.. [4] Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
|
|
91
|
+
score for general treatment regimes. Journal of the American Statistical
|
|
92
|
+
Association 113(523), 1316-1329. https://doi.org/10.1080/01621459.2017.1385465
|
|
93
|
+
|
|
94
|
+
.. [5] Imai, K. and Ratkovic, M. (2015). Robust estimation of inverse probability
|
|
95
|
+
weights for marginal structural models. Journal of the American Statistical
|
|
96
|
+
Association 110(511), 1013-1023. https://doi.org/10.1080/01621459.2014.956872
|
|
97
|
+
|
|
98
|
+
.. [6] Fong, C. (2018). Robust and efficient estimation of causal effects with
|
|
99
|
+
calibrated covariate balance. Unpublished manuscript.
|
|
100
|
+
|
|
101
|
+
License
|
|
102
|
+
-------
|
|
103
|
+
AGPL-3.0
|
|
104
|
+
|
|
105
|
+
Copyright (c) 2025-2026 Cai Xuanyu, Xu Wenli
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
from typing import Any, Optional, Union, Dict
|
|
109
|
+
import warnings
|
|
110
|
+
import pandas as pd
|
|
111
|
+
import numpy as np
|
|
112
|
+
|
|
113
|
+
# Package version; must match the distribution metadata (cbps-0.2.0.dist-info).
__version__ = "0.2.0"
|
|
114
|
+
|
|
115
|
+
from cbps.core.results import CBPSResults, CBPSSummary
|
|
116
|
+
from cbps.core.cbps_binary import cbps_binary_fit
|
|
117
|
+
from cbps.logging_config import set_verbosity, logger
|
|
118
|
+
|
|
119
|
+
# Names exported by ``from cbps import *`` (order is significant only for
# readability; it mirrors the original listing).
__all__ = [
    # Estimators and their functional fitting counterparts
    "CBPS", "cbps_fit",
    "CBMSM", "cbmsm_fit",
    "npCBPS", "npCBPS_fit",
    "hdCBPS",
    "CBIV",
    # Inference and diagnostics
    "AsyVar", "balance", "vcov_outcome",
    # Plotting helpers
    "plot_cbps", "plot_cbps_continuous", "plot_cbmsm", "plot_npcbps",
    # Miscellaneous utilities
    "set_verbosity", "fit_multiple",
]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _format_level_label(level: Any) -> str:
    """Render one treatment level as a column-label prefix.

    Integers, and floats holding an integral value, are printed without a
    decimal point (``0.0`` -> ``"0"``); every other level uses ``str(level)``.
    """
    if isinstance(level, (int, np.integer)):
        return str(int(level))
    # ``float.is_integer`` is simply False for inf/NaN, whereas the previous
    # ``level == int(level)`` test raised OverflowError/ValueError on them.
    if isinstance(level, (float, np.floating)) and float(level).is_integer():
        return str(int(level))
    return str(level)


def _add_balance_labels(balance_result: Dict[str, np.ndarray], cbps_dict: Dict[str, Any],
                        coef_names: Optional[list], is_continuous: bool) -> Dict[str, pd.DataFrame]:
    """
    Attach covariate labels to balance assessment statistics.

    Converts the weighted and unweighted balance arrays into pandas
    DataFrames whose rows are covariates and whose columns are labelled
    according to the treatment type.

    Parameters
    ----------
    balance_result : Dict[str, np.ndarray]
        Balance statistics. Must contain ``'balanced'`` (weighted statistics)
        and the unweighted statistics under ``'original'`` (discrete
        treatments) or ``'unweighted'`` (continuous treatments). Each array
        has one row per covariate.
    cbps_dict : Dict[str, Any]
        Mapping describing the fitted model. Only ``cbps_dict['y']`` (the
        treatment values) is read, and only for discrete treatments, to
        determine the treatment levels.
    coef_names : list or None
        Covariate names (excluding the intercept). Used as row labels only
        when its length equals the number of rows; otherwise rows are
        labelled ``X1``, ``X2``, ...
    is_continuous : bool
        Whether the treatment is continuous; selects the column convention.

    Returns
    -------
    Dict[str, pd.DataFrame]
        Exactly two entries: ``'balanced'`` and the unweighted key
        (``'original'`` or ``'unweighted'``). Any other keys present in
        ``balance_result`` are not carried over.

    Raises
    ------
    KeyError
        If a required key is missing from ``balance_result`` (or ``'y'``
        from ``cbps_dict`` for discrete treatments).

    Notes
    -----
    * **Discrete treatments**: columns are ``"<level>.mean"`` for every
      sorted treatment level, followed by ``"<level>.std.mean"`` for every
      level.
    * **Continuous treatments**: a single ``"corr"`` column.
    """
    original_key = 'unweighted' if is_continuous else 'original'
    balanced_array = balance_result['balanced']
    original_array = balance_result[original_key]

    # Row labels: use the supplied names only if they line up with the rows.
    n_covars = balanced_array.shape[0]
    if coef_names is not None and len(coef_names) == n_covars:
        row_names = coef_names
    else:
        row_names = [f"X{i + 1}" for i in range(n_covars)]

    if is_continuous:
        col_names = ['corr']
    else:
        # Categorical categories are the sorted unique treatment levels.
        treat_levels = pd.Categorical(cbps_dict['y']).categories
        labels = [_format_level_label(level) for level in treat_levels]
        # All mean columns first, then all standardized-mean columns.
        col_names = ([f"{label}.mean" for label in labels]
                     + [f"{label}.std.mean" for label in labels])

    return {
        'balanced': pd.DataFrame(balanced_array, columns=col_names, index=row_names),
        original_key: pd.DataFrame(original_array, columns=col_names, index=row_names),
    }
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _check_overlap_violation(
    cbps_result: Any,
    is_continuous: bool,
    threshold: float = 0.05,
    extreme_low: float = 0.01,
    extreme_high: float = 0.99,
) -> None:
    """
    Assess potential violations of the overlap assumption in propensity scores.

    The overlap (common support) assumption requires every unit to have a
    non-zero probability of receiving each treatment level. This diagnostic
    counts units whose estimated propensity scores are extreme and emits a
    warning when they make up too large a share of the sample, which can
    signal perfect or quasi-complete separation or poor overlap.

    Parameters
    ----------
    cbps_result : CBPSResults
        Fitted result object. Only ``fitted_values`` (1-D for binary
        treatments, 2-D with one column per level for multi-valued
        treatments) and ``y`` (used for the sample size) are read.
    is_continuous : bool
        Whether the treatment is continuous. The check is skipped in that case.
    threshold : float, default=0.05
        A warning is issued when the share of units with an extreme score is
        strictly greater than this value (0.05 = 5% of the sample).
    extreme_low : float, default=0.01
        Scores strictly below this value count as extreme.
    extreme_high : float, default=0.99
        Scores strictly above this value count as extreme.

    Warns
    -----
    UserWarning
        When the proportion of units with an extreme score exceeds
        ``threshold``.

    Notes
    -----
    Formally, overlap requires :math:`0 < \\Pr(T = t | X) < 1` for all
    covariate values :math:`X` and treatment levels :math:`t`.

    * **Discrete treatments**: a unit is flagged when its score (binary) or
      any of its per-level scores (multi-valued) lies outside
      ``[extreme_low, extreme_high]``.
    * **Continuous treatments**: skipped, because fitted values are
      densities rather than probabilities in [0, 1].

    An empty sample produces no warning.

    Violations of overlap can lead to unstable coefficient estimates, large
    variance in treatment effect estimates, and dependence on extrapolation
    beyond the data support.

    References
    ----------
    .. [1] King, G. and Zeng, L. (2001). Logistic regression in rare events data.
       Political Analysis, 9(2), 137-163.
    .. [2] Firth, D. (1993). Bias reduction of maximum likelihood estimates.
       Biometrika, 80(1), 27-38.
    .. [3] Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
       Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    """
    if is_continuous:
        # Continuous-treatment fitted values are densities, not probabilities
        # in [0, 1], so the probability bounds below do not apply.
        return

    # np.asarray also accepts plain lists / pandas objects, which lack .ndim
    # or behave differently under boolean operators.
    probs = np.asarray(cbps_result.fitted_values)
    if probs.ndim != 2:
        probs = probs.ravel()

    is_extreme = (probs < extreme_low) | (probs > extreme_high)
    if is_extreme.ndim == 2:
        # Multi-valued treatment: a unit counts once if any of its per-level
        # probabilities is extreme.
        is_extreme = is_extreme.any(axis=1)
    n_extreme = int(np.count_nonzero(is_extreme))

    n_total = len(cbps_result.y)
    if n_total == 0:
        # Nothing to assess; also avoids a 0/0 division yielding NaN.
        return
    extreme_ratio = n_extreme / n_total

    if extreme_ratio > threshold:
        warnings.warn(
            f"Potential overlap violation detected: {extreme_ratio:.1%} of observations "
            f"have extreme propensity scores (< {extreme_low} or > {extreme_high}). "
            "This may indicate:\n"
            " - Perfect or quasi-complete separation\n"
            " - Severe violation of the overlap assumption\n"
            " - Possible numerical instability in coefficient estimates\n\n"
            "Recommendations:\n"
            " - Check covariate balance diagnostics\n"
            " - Consider removing or combining problematic covariates\n"
            " - Use regularization methods (e.g., hdCBPS) if appropriate\n"
            " - Verify that treatment groups have sufficient covariate overlap\n\n"
            "Theory: CBPS assumes 0 < Pr(T|X) < 1 for all X (Imai & Ratkovic 2014, Assumption 1).",
            UserWarning,
            stacklevel=3
        )
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _validate_finite_inputs(
    treat: np.ndarray,
    X: np.ndarray,
    func_name: str = "CBPS"
) -> None:
    """
    Validate input data for numerical finiteness.

    Ensures that a numeric treatment variable and the covariate matrix contain
    no NaN or infinite values, which would break the numerical optimization.
    Categorical and string-like treatments are not checked.

    Parameters
    ----------
    treat : np.ndarray
        Treatment assignment variable of shape (n,). Numeric for binary or
        continuous treatments, or categorical/string for multi-valued ones.
    X : np.ndarray
        Covariate matrix of shape (n, k). A 1-D array is treated as a single
        column when reporting offending columns.
    func_name : str, default="CBPS"
        Name of the calling function, used as a prefix in error messages.

    Raises
    ------
    ValueError
        If a numeric treatment or the covariates contain NaN or Inf values.
        The message reports the counts (and, for covariates, the offending
        column indices).

    Notes
    -----
    The treatment check is skipped when ``treat`` exposes ``categories``
    (pandas Categorical), has a categorical dtype, has a dtype of kind
    ``'U'``, ``'O'`` or ``'S'`` (note: this includes object arrays holding
    numbers), or cannot be converted to ``float64``.
    """
    treat_is_categorical = (
        hasattr(treat, 'categories')
        or (hasattr(treat, 'dtype') and hasattr(treat.dtype, 'categories'))
    )
    treat_is_string = (
        hasattr(treat, 'dtype')
        and treat.dtype.kind in ('U', 'O', 'S')  # U=unicode, O=object, S=bytes
    )

    if not (treat_is_string or treat_is_categorical):
        # Keep the try narrow: previously the finiteness ValueError below was
        # raised inside this try and silently swallowed by the except clause.
        try:
            treat_numeric = np.asarray(treat, dtype=np.float64)
        except (ValueError, TypeError):
            # Not convertible to a numeric type (e.g. strings): nothing to check.
            treat_numeric = None

        if treat_numeric is not None and not np.all(np.isfinite(treat_numeric)):
            n_inf = np.isinf(treat_numeric).sum()
            n_nan = np.isnan(treat_numeric).sum()
            raise ValueError(
                f"{func_name}: Treatment variable contains {n_nan} NaN and {n_inf} Inf value(s). "
                f"Inf values typically indicate data errors (division by zero, numerical overflow, "
                f"or incorrect feature engineering). Please clean your data before calling {func_name}. "
                f"Consider: data.dropna() or data[np.isfinite(data).all(axis=1)]"
            )

    X_arr = np.asarray(X)
    finite_mask = np.isfinite(X_arr)
    if not finite_mask.all():
        n_inf = np.isinf(X_arr).sum()
        n_nan = np.isnan(X_arr).sum()
        # A 1-D covariate vector is one column; this also avoids calling
        # np.where on a 0-d array.
        mask_2d = finite_mask[:, None] if finite_mask.ndim == 1 else finite_mask
        bad_cols = np.where(~mask_2d.all(axis=0))[0]
        raise ValueError(
            f"{func_name}: Covariates contain {n_nan} NaN and {n_inf} Inf value(s) "
            f"in column(s) {bad_cols.tolist()}. "
            f"Inf values typically indicate data errors (e.g., log(0), division by zero). "
            f"Please clean your data before calling {func_name}."
        )
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _has_intercept(X: np.ndarray) -> bool:
    """
    Report whether the first column of ``X`` is an intercept (all ones).

    The formula interface adds an intercept automatically, whereas array
    input may or may not include one; this helper lets callers tell the two
    apart and avoid adding a duplicate intercept.

    Parameters
    ----------
    X : np.ndarray, shape (n, k)
        Design matrix, possibly with a leading column of ones.

    Returns
    -------
    bool
        ``False`` when ``X`` has no columns. Otherwise, whether every entry
        of the first column equals 1.0 within ``np.allclose``'s default
        tolerances, so values such as 1.0000001 or 0.9999999 still count.

    Notes
    -----
    NOTE(review): with zero rows, ``np.allclose`` on the empty first column
    returns True, so an (0, k) matrix with k > 0 is reported as having an
    intercept.

    Examples
    --------
    >>> import numpy as np
    >>> _has_intercept(np.column_stack([np.ones(4), np.arange(4.0)]))
    True
    >>> _has_intercept(np.arange(8.0).reshape(4, 2))
    False
    """
    n_columns = X.shape[1]
    if n_columns > 0:
        leading_column = X[:, 0]
        return np.allclose(leading_column, 1.0)
    return False
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def _apply_svd_preprocessing(X: np.ndarray) -> tuple:
    """
    Apply SVD preprocessing to a covariate matrix for numerical stability.

    Every column except the first (assumed to be the intercept) is centered
    and scaled by its sample standard deviation (``ddof=1``). The
    standardized matrix is then decomposed with ``np.linalg.svd`` and
    replaced by the leading columns of ``U``. This preprocessing is intended
    for multi-valued treatment models.

    Parameters
    ----------
    X : np.ndarray, shape (n, k)
        Covariate matrix with the intercept in the first column. It is not
        modified; any numeric dtype (including integers) is accepted.

    Returns
    -------
    X_svd : np.ndarray, shape (n, min(n, k))
        First ``k`` columns of ``U`` (orthonormal columns).
    svd_info : dict
        Information needed for the inverse transform:

        - ``'V'``: V matrix from the SVD (transpose of NumPy's ``Vt``)
        - ``'d'``: singular values
        - ``'x_sd'``: per-column standard deviations used for scaling
        - ``'x_mean'``: per-column means used for centering
        - ``'U'``: complete U matrix, shape (n, n)
        - ``'X_standardized'``: the standardized matrix that was decomposed

    Notes
    -----
    A constant non-intercept column has zero standard deviation, which makes
    the standardized matrix non-finite; drop such columns before calling.

    NOTE(review): ``full_matrices=True`` materialises an (n, n) ``U``
    (O(n^2) memory). Only the first ``k`` columns are used here, but the
    full ``U`` is returned in ``svd_info['U']``, so callers must be checked
    before switching to ``full_matrices=False``.
    """
    # Work on a float64 copy. With integer input, writing the standardized
    # values back into ``X.copy()`` would silently truncate them to integers.
    X_work = np.array(X, dtype=np.float64)
    n_cols = X_work.shape[1]

    # Standardize all columns except the intercept.
    covariates = X_work[:, 1:]
    x_sd = covariates.std(axis=0, ddof=1)
    x_mean = covariates.mean(axis=0)
    X_work[:, 1:] = (covariates - x_mean) / x_sd

    U, s, Vt = np.linalg.svd(X_work, full_matrices=True)

    svd_info = {
        'V': Vt.T,  # NumPy returns Vt, R returns V
        'd': s,
        'x_sd': x_sd,
        'x_mean': x_mean,
        'U': U,
        # X_work is a private copy no longer modified, so no extra copy needed.
        'X_standardized': X_work,
    }

    X_svd = U[:, :n_cols]
    return X_svd, svd_info
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
def _apply_svd_inverse_transform(beta_svd: np.ndarray, svd_info: dict,
                                 tol: float = 1e-5) -> np.ndarray:
    """
    Apply the inverse SVD transform to a coefficient matrix.

    Maps coefficients estimated on the SVD-orthogonalized covariates back to
    the original (unstandardized) covariate space.

    Parameters
    ----------
    beta_svd : np.ndarray, shape (k, K-1) or (k,)
        Coefficients in SVD space. A 1-D vector is treated as a single column
        and a 1-D result is returned.
    svd_info : dict
        SVD information as produced by the preprocessing step. Keys read:
        ``'V'``, ``'d'``, ``'x_sd'``, ``'x_mean'``.
    tol : float, default=1e-5
        Singular values ``<= tol`` are treated as zero (their reciprocal is
        set to 0) instead of being inverted.

    Returns
    -------
    beta_transformed : np.ndarray, shape (k, K-1) or (k,)
        Coefficients in the original covariate space. The inputs are not
        modified.

    Notes
    -----
    Transformation steps:

    1. SVD inverse transform: ``beta = V @ diag(d_inv) @ beta_svd``
    2. Reverse standardization (except intercept): ``beta[1:, :] /= x_sd``
    3. Adjust intercept: ``beta[0, :] -= x_mean @ beta[1:, :]``
    """
    d = np.asarray(svd_info['d'], dtype=np.float64)
    # Decide which singular values to keep *before* inverting. Re-testing
    # the already-inverted values against tol would also zero 1/d for every
    # large singular value d >= 1/tol.
    keep = d > tol
    d_inv = np.zeros_like(d)
    d_inv[keep] = 1.0 / d[keep]

    beta = np.asarray(beta_svd, dtype=np.float64)
    is_vector = beta.ndim == 1
    if is_vector:
        beta = beta[:, None]

    # V @ diag(d_inv) is V with its columns scaled by d_inv.
    beta_transformed = (svd_info['V'] * d_inv) @ beta

    # Undo the column scaling (the intercept row was not scaled).
    beta_transformed[1:, :] /= np.asarray(svd_info['x_sd'])[:, None]

    # Undo the centering by folding the means into the intercept.
    beta_transformed[0, :] -= np.asarray(svd_info['x_mean']) @ beta_transformed[1:, :]

    return beta_transformed[:, 0] if is_vector else beta_transformed
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
# Extra keyword arguments that are permitted to be forwarded to the
# underlying fitting functions (a whitelist; other keys are not listed here).
_SCIPY_ALLOWED_KWARGS = {
    # Generic optimizer controls
    'callback',       # callable invoked by the optimizer during iterations
    'tol',            # termination tolerance
    'options',        # optimizer options dictionary
    # Gradient tolerances, chosen to match the R implementation
    'bal_gtol',       # balance-only optimization
    'gmm_gtol',       # GMM optimization
    # Start values and user feedback
    'init_params',    # warm start: initial parameter values
    'show_progress',  # display a tqdm progress bar while optimizing
}
|
|
600
|
+
|
|
601
|
+
def _detect_treatment_type(
|
|
602
|
+
treat: np.ndarray,
|
|
603
|
+
formula: Optional[str] = None,
|
|
604
|
+
data: Optional[pd.DataFrame] = None,
|
|
605
|
+
treat_col_name: Optional[str] = None
|
|
606
|
+
) -> tuple[bool, bool, bool]:
|
|
607
|
+
"""
|
|
608
|
+
Detect the type of treatment variable for parameter validation and routing.
|
|
609
|
+
|
|
610
|
+
Parameters
|
|
611
|
+
----------
|
|
612
|
+
treat : np.ndarray
|
|
613
|
+
Treatment variable array.
|
|
614
|
+
formula : str, optional
|
|
615
|
+
Formula string (used for column name extraction).
|
|
616
|
+
data : pd.DataFrame, optional
|
|
617
|
+
Data frame (used for checking categorical types).
|
|
618
|
+
treat_col_name : str, optional
|
|
619
|
+
Treatment column name (if known).
|
|
620
|
+
|
|
621
|
+
Returns
|
|
622
|
+
-------
|
|
623
|
+
tuple of bool
|
|
624
|
+
(is_categorical, is_binary_01, is_continuous) where:
|
|
625
|
+
- is_categorical: True if pandas Categorical type
|
|
626
|
+
- is_binary_01: True if binary 0/1 numeric values
|
|
627
|
+
- is_continuous: True if continuous (non-binary, non-categorical)
|
|
628
|
+
|
|
629
|
+
Notes
|
|
630
|
+
-----
|
|
631
|
+
Detection logic:
|
|
632
|
+
1. If pandas Categorical → (True, False, False)
|
|
633
|
+
2. If unique values are {0, 1} → (False, True, False)
|
|
634
|
+
3. Otherwise → (False, False, True)
|
|
635
|
+
|
|
636
|
+
Examples
|
|
637
|
+
--------
|
|
638
|
+
>>> import numpy as np
|
|
639
|
+
>>> treat = np.array([0, 1, 0, 1])
|
|
640
|
+
>>> is_cat, is_bin, is_cont = _detect_treatment_type(treat)
|
|
641
|
+
>>> print(is_bin, is_cont)
|
|
642
|
+
True False
|
|
643
|
+
"""
|
|
644
|
+
# Ensure treat is a numpy array
|
|
645
|
+
treat_array = np.asarray(treat).ravel()
|
|
646
|
+
|
|
647
|
+
# Step 1: Check if pandas Categorical type
|
|
648
|
+
is_categorical = False
|
|
649
|
+
|
|
650
|
+
if data is not None and treat_col_name is not None:
|
|
651
|
+
# Check original column type from data
|
|
652
|
+
if treat_col_name in data.columns:
|
|
653
|
+
is_categorical = (
|
|
654
|
+
isinstance(data[treat_col_name].dtype, pd.CategoricalDtype) or
|
|
655
|
+
isinstance(data[treat_col_name], pd.Categorical)
|
|
656
|
+
)
|
|
657
|
+
elif hasattr(treat, 'cat'):
|
|
658
|
+
# Directly passed Series might have .cat attribute
|
|
659
|
+
is_categorical = True
|
|
660
|
+
elif isinstance(treat, pd.Categorical):
|
|
661
|
+
is_categorical = True
|
|
662
|
+
|
|
663
|
+
# Step 2: If not categorical, check if binary 0/1
|
|
664
|
+
is_binary_01 = False
|
|
665
|
+
if not is_categorical:
|
|
666
|
+
treat_unique = np.unique(treat_array)
|
|
667
|
+
is_binary_01 = (
|
|
668
|
+
len(treat_unique) == 2 and
|
|
669
|
+
set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
# Step 3: Determine if continuous treatment
|
|
673
|
+
is_continuous = (
|
|
674
|
+
not is_categorical and
|
|
675
|
+
not is_binary_01 and
|
|
676
|
+
np.issubdtype(treat_array.dtype, np.number)
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
return is_categorical, is_binary_01, is_continuous
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def CBPS(
|
|
683
|
+
formula: Optional[str] = None,
|
|
684
|
+
data: Optional[pd.DataFrame] = None,
|
|
685
|
+
treatment: Optional[np.ndarray] = None,
|
|
686
|
+
covariates: Optional[np.ndarray] = None,
|
|
687
|
+
att: int = 1,
|
|
688
|
+
method: str = 'over',
|
|
689
|
+
two_step: bool = True,
|
|
690
|
+
standardize: bool = True,
|
|
691
|
+
sample_weights: Optional[np.ndarray] = None,
|
|
692
|
+
baseline_formula: Optional[str] = None,
|
|
693
|
+
diff_formula: Optional[str] = None,
|
|
694
|
+
iterations: int = 1000,
|
|
695
|
+
theoretical_exact: bool = False,
|
|
696
|
+
na_action: str = 'warn',
|
|
697
|
+
verbose: int = 0,
|
|
698
|
+
ATT: Optional[int] = None,
|
|
699
|
+
twostep: Optional[bool] = None,
|
|
700
|
+
**kwargs
|
|
701
|
+
) -> CBPSResults:
|
|
702
|
+
"""
|
|
703
|
+
Covariate Balancing Propensity Score (CBPS) Estimation
|
|
704
|
+
|
|
705
|
+
Estimates propensity scores such that both covariate balance and prediction
|
|
706
|
+
of treatment assignment are simultaneously maximized. The method avoids
|
|
707
|
+
the iterative process between model fitting and balance checking by
|
|
708
|
+
implementing both objectives simultaneously.
|
|
709
|
+
|
|
710
|
+
Supports binary, multi-valued (3-4 levels), and continuous treatments.
|
|
711
|
+
|
|
712
|
+
Parameters
|
|
713
|
+
----------
|
|
714
|
+
formula : str, optional
|
|
715
|
+
A symbolic description of the model to be fitted. The formula should
|
|
716
|
+
be of the form ``treatment ~ covariate1 + covariate2 + ...``.
|
|
717
|
+
Either ``formula`` and ``data`` or ``treatment`` and ``covariates``
|
|
718
|
+
must be provided.
|
|
719
|
+
data : pd.DataFrame, optional
|
|
720
|
+
A data frame containing the variables in the model. Required when
|
|
721
|
+
using the formula interface.
|
|
722
|
+
treatment : np.ndarray, optional
|
|
723
|
+
Treatment vector. Required when using the array interface instead of
|
|
724
|
+
the formula interface.
|
|
725
|
+
covariates : np.ndarray, optional
|
|
726
|
+
Covariate matrix. Required when using the array interface. Should not
|
|
727
|
+
include an intercept column (it will be added automatically).
|
|
728
|
+
att : int, default 1
|
|
729
|
+
Target estimand. 0 for ATE (average treatment effect), 1 for ATT
|
|
730
|
+
with the second level as treated, 2 for ATT with the first level as
|
|
731
|
+
treated. For non-binary treatments, only ATE is available.
|
|
732
|
+
ATT : int, optional
|
|
733
|
+
Deprecated. Use lowercase ``att`` instead.
|
|
734
|
+
method : {'over', 'exact'}, default 'over'
|
|
735
|
+
Estimation method. 'over' for over-identified GMM (combines propensity
|
|
736
|
+
score likelihood and covariate balancing conditions), 'exact' for
|
|
737
|
+
exactly-identified GMM (covariate balancing conditions only).
|
|
738
|
+
two_step : bool, default True
|
|
739
|
+
If True, uses the two-step GMM estimator (faster). If False, uses
|
|
740
|
+
the continuous-updating GMM estimator (better finite sample properties).
|
|
741
|
+
twostep : bool, optional
|
|
742
|
+
Alias for ``two_step`` parameter. Use ``two_step`` for consistency
|
|
743
|
+
with Python naming conventions.
|
|
744
|
+
standardize : bool, default True
|
|
745
|
+
If True, normalizes weights to sum to 1 within each treatment group
|
|
746
|
+
(or to 1 for the entire sample with continuous treatments). If False,
|
|
747
|
+
returns Horvitz-Thompson weights.
|
|
748
|
+
sample_weights : np.ndarray, optional
|
|
749
|
+
Survey sampling weights for the observations. If None, defaults to
|
|
750
|
+
equal weights of 1 for each observation.
|
|
751
|
+
baseline_formula : str, optional
|
|
752
|
+
Formula for covariates in the baseline outcome model E(Y(0)|X). Used only
|
|
753
|
+
for optimal CBPS (iCBPS) with binary treatments.
|
|
754
|
+
diff_formula : str, optional
|
|
755
|
+
Formula for covariates in the treatment effect difference model
|
|
756
|
+
E(Y(1)-Y(0)|X). Used only for optimal CBPS (iCBPS) with binary treatments.
|
|
757
|
+
iterations : int, default 1000
|
|
758
|
+
Maximum number of iterations for the optimization algorithm.
|
|
759
|
+
theoretical_exact : bool, default False
|
|
760
|
+
When method='exact', uses direct equation solver for exact GMM solution.
|
|
761
|
+
If False, uses balance loss optimization (default behavior).
|
|
762
|
+
na_action : {'warn', 'fail', 'ignore'}, default 'warn'
|
|
763
|
+
How to handle missing values. 'warn' removes observations with missing
|
|
764
|
+
values and issues a warning, 'fail' raises an error, 'ignore' uses
|
|
765
|
+
patsy's default behavior.
|
|
766
|
+
verbose : int, default 0
|
|
767
|
+
Verbosity level. 0 for silent output, 1 for basic progress, 2 for
|
|
768
|
+
detailed iteration information.
|
|
769
|
+
**kwargs
|
|
770
|
+
Additional parameters passed to the optimization routine.
|
|
771
|
+
|
|
772
|
+
Returns
|
|
773
|
+
-------
|
|
774
|
+
CBPSResults
|
|
775
|
+
A fitted CBPS object containing:
|
|
776
|
+
- coefficients: estimated propensity score coefficients
|
|
777
|
+
- fitted.values: estimated propensity scores
|
|
778
|
+
- weights: covariate balancing weights
|
|
779
|
+
- converged: convergence status
|
|
780
|
+
- j_statistic: J-statistic for overidentification test
|
|
781
|
+
|
|
782
|
+
Raises
|
|
783
|
+
------
|
|
784
|
+
ValueError
|
|
785
|
+
If required inputs are missing or invalid, or if the model cannot be
|
|
786
|
+
estimated (e.g., perfect collinearity, insufficient sample size).
|
|
787
|
+
|
|
788
|
+
Notes
|
|
789
|
+
-----
|
|
790
|
+
**Treatment Type Detection**
|
|
791
|
+
|
|
792
|
+
- Binary treatments: Automatically detected for integer arrays with ≤4 unique values
|
|
793
|
+
- Multi-valued treatments: Must be converted to ``pd.Categorical`` before fitting
|
|
794
|
+
- Continuous treatments: Automatically detected for floating-point arrays or >4 unique values
|
|
795
|
+
|
|
796
|
+
**Estimation Methods**
|
|
797
|
+
|
|
798
|
+
- The 'over' method combines likelihood-based score functions with covariate
|
|
799
|
+
balance constraints in an over-identified GMM framework
|
|
800
|
+
- The 'exact' method uses only covariate balancing conditions (exactly-identified)
|
|
801
|
+
|
|
802
|
+
**Weight Standardization**
|
|
803
|
+
|
|
804
|
+
- When standardize=True, weights sum to 1 within each treatment group
|
|
805
|
+
- When standardize=False, returns Horvitz-Thompson weights
|
|
806
|
+
|
|
807
|
+
References
|
|
808
|
+
----------
|
|
809
|
+
Imai, K. and Ratkovic, M. (2014). Covariate Balancing Propensity Score.
|
|
810
|
+
Journal of the Royal Statistical Society, Series B 76(1), 243-263.
|
|
811
|
+
https://doi.org/10.1111/rssb.12027
|
|
812
|
+
|
|
813
|
+
Fan, J., Imai, K., Lee, I., Liu, H., Ning, Y., and Yang, X. (2022).
|
|
814
|
+
Optimal Covariate Balancing Conditions in Propensity Score Estimation.
|
|
815
|
+
Journal of Business & Economic Statistics, 41(1), 97-110.
|
|
816
|
+
|
|
817
|
+
Examples
|
|
818
|
+
--------
|
|
819
|
+
>>> import cbps
|
|
820
|
+
>>> from cbps.datasets import load_lalonde
|
|
821
|
+
>>> # Load LaLonde job training data
|
|
822
|
+
>>> data = load_lalonde(dehejia_wahba_only=True)
|
|
823
|
+
>>> # Estimate CBPS for ATT
|
|
824
|
+
>>> fit = cbps.CBPS('treat ~ age + educ + black + hisp', data=data, att=1)
|
|
825
|
+
>>> print(fit.summary())
|
|
826
|
+
>>> # Access weights for downstream analysis
|
|
827
|
+
>>> weights = fit.weights
|
|
828
|
+
|
|
829
|
+
"""
|
|
830
|
+
# Handle twostep parameter alias for compatibility
|
|
831
|
+
if twostep is not None:
|
|
832
|
+
# Use twostep value if provided (overrides two_step)
|
|
833
|
+
two_step = twostep
|
|
834
|
+
|
|
835
|
+
# Parameter validation
|
|
836
|
+
# att must be 0, 1, or 2
|
|
837
|
+
# att=0: ATE, att=1: ATT (T=1 as treated), att=2: ATT (T=0 as treated)
|
|
838
|
+
# Check type first, then value range (TypeError before ValueError)
|
|
839
|
+
if not isinstance(att, (int, np.integer)):
|
|
840
|
+
raise TypeError(
|
|
841
|
+
f"att must be an integer (0, 1, or 2), got type {type(att).__name__}: {att}"
|
|
842
|
+
)
|
|
843
|
+
if att not in [0, 1, 2]:
|
|
844
|
+
raise ValueError(
|
|
845
|
+
f"Invalid att parameter: {att}\n\n"
|
|
846
|
+
f"att must be 0, 1, or 2:\n"
|
|
847
|
+
f" att=0: ATE (Average Treatment Effect) for entire population\n"
|
|
848
|
+
f" att=1: ATT (Average Treatment effect on the Treated, T=1 as treated)\n"
|
|
849
|
+
f" att=2: ATT (Average Treatment effect on the Treated, T=0 as treated)\n\n"
|
|
850
|
+
f"You provided: att={att}"
|
|
851
|
+
)
|
|
852
|
+
|
|
853
|
+
# Handle legacy uppercase ATT parameter for backward compatibility
|
|
854
|
+
if ATT is not None:
|
|
855
|
+
# Validate ATT parameter
|
|
856
|
+
if not isinstance(ATT, (int, np.integer)) or ATT not in [0, 1]:
|
|
857
|
+
raise ValueError(
|
|
858
|
+
f"Invalid ATT parameter: {ATT}\n\n"
|
|
859
|
+
f"ATT must be either 0 or 1:\n"
|
|
860
|
+
f" ATT=0: ATE (Average Treatment Effect)\n"
|
|
861
|
+
f" ATT=1: ATT (Average Treatment effect on the Treated)\n\n"
|
|
862
|
+
f"You provided: ATT={ATT} (type: {type(ATT).__name__})"
|
|
863
|
+
)
|
|
864
|
+
|
|
865
|
+
if att == 1: # att is default value, user didn't explicitly set it
|
|
866
|
+
att = ATT
|
|
867
|
+
warnings.warn(
|
|
868
|
+
f"Using deprecated parameter name 'ATT={ATT}'. "
|
|
869
|
+
f"Please use lowercase 'att={ATT}' instead for consistency with Python naming conventions.",
|
|
870
|
+
DeprecationWarning,
|
|
871
|
+
stacklevel=2
|
|
872
|
+
)
|
|
873
|
+
else:
|
|
874
|
+
# User set both att and ATT with different values
|
|
875
|
+
warnings.warn(
|
|
876
|
+
f"Both 'att={att}' and 'ATT={ATT}' were specified. Using 'att={att}'. "
|
|
877
|
+
f"Please use only 'att' parameter (lowercase) to avoid confusion.",
|
|
878
|
+
UserWarning
|
|
879
|
+
)
|
|
880
|
+
|
|
881
|
+
# Validate kwargs to prevent confusing scipy errors
|
|
882
|
+
if kwargs:
|
|
883
|
+
invalid_kwargs = set(kwargs.keys()) - _SCIPY_ALLOWED_KWARGS
|
|
884
|
+
if invalid_kwargs:
|
|
885
|
+
# Check if this is a common error (uppercase parameter names)
|
|
886
|
+
suggestions = []
|
|
887
|
+
for invalid_key in invalid_kwargs:
|
|
888
|
+
# Check for case confusion
|
|
889
|
+
if invalid_key.upper() == 'ATT':
|
|
890
|
+
suggestions.append(f" - Did you mean 'att' (lowercase) instead of '{invalid_key}'?")
|
|
891
|
+
elif invalid_key.lower() == 'standardize':
|
|
892
|
+
suggestions.append(f" - Did you mean 'standardize' (correct spelling) instead of '{invalid_key}'?")
|
|
893
|
+
elif invalid_key.lower() == 'method':
|
|
894
|
+
suggestions.append(f" - Did you mean 'method' (lowercase) instead of '{invalid_key}'?")
|
|
895
|
+
|
|
896
|
+
error_msg = (
|
|
897
|
+
f"CBPS() got unexpected keyword argument(s): {sorted(invalid_kwargs)}\n\n"
|
|
898
|
+
f"Valid scipy.optimize parameters are: {sorted(_SCIPY_ALLOWED_KWARGS)}\n"
|
|
899
|
+
)
|
|
900
|
+
if suggestions:
|
|
901
|
+
error_msg += "\nCommon mistakes:\n" + "\n".join(suggestions)
|
|
902
|
+
else:
|
|
903
|
+
error_msg += (
|
|
904
|
+
"\nNote: CBPS parameters (att, method, standardize, etc.) should be "
|
|
905
|
+
"specified as named arguments, not in **kwargs."
|
|
906
|
+
)
|
|
907
|
+
|
|
908
|
+
raise TypeError(error_msg)
|
|
909
|
+
|
|
910
|
+
# Mutual exclusivity check: formula and treatment cannot both be specified
|
|
911
|
+
if formula is not None and treatment is not None:
|
|
912
|
+
raise ValueError(
|
|
913
|
+
"Cannot specify both 'formula' and 'treatment' parameters. "
|
|
914
|
+
"Please use either:\n"
|
|
915
|
+
" 1. Formula interface: CBPS(formula='treat ~ X1 + X2', data=df)\n"
|
|
916
|
+
" 2. Array interface: CBPS(treatment=treat_array, covariates=X_array)\n"
|
|
917
|
+
f"\nReceived:\n"
|
|
918
|
+
f" formula = {repr(formula)}\n"
|
|
919
|
+
f" treatment = {'<array>' if treatment is not None else 'None'}"
|
|
920
|
+
)
|
|
921
|
+
|
|
922
|
+
# Validate iterations parameter
|
|
923
|
+
if not isinstance(iterations, (int, np.integer)):
|
|
924
|
+
raise TypeError(
|
|
925
|
+
f"iterations must be an integer, got {type(iterations).__name__}. "
|
|
926
|
+
f"Received: iterations={iterations}"
|
|
927
|
+
)
|
|
928
|
+
if iterations < 1:
|
|
929
|
+
raise ValueError(
|
|
930
|
+
f"iterations must be ≥1 (at least one optimization step required). "
|
|
931
|
+
f"Received: iterations={iterations}"
|
|
932
|
+
)
|
|
933
|
+
if iterations > 100000:
|
|
934
|
+
warnings.warn(
|
|
935
|
+
f"iterations={iterations} is very large and may take a long time. "
|
|
936
|
+
f"Consider using a smaller value (default is 1000).",
|
|
937
|
+
UserWarning
|
|
938
|
+
)
|
|
939
|
+
|
|
940
|
+
# Validate att parameter
|
|
941
|
+
if not isinstance(att, (int, np.integer)):
|
|
942
|
+
raise TypeError(
|
|
943
|
+
f"att must be an integer (0, 1, or 2), got {type(att).__name__}. "
|
|
944
|
+
f"Received: att={att}"
|
|
945
|
+
)
|
|
946
|
+
if att not in (0, 1, 2):
|
|
947
|
+
raise ValueError(
|
|
948
|
+
f"att must be 0 (ATE), 1 (ATT treated=level2), or 2 (ATT treated=level1). "
|
|
949
|
+
f"Received: att={att}\n\n"
|
|
950
|
+
f"Explanation:\n"
|
|
951
|
+
f" att=0: Average Treatment Effect (ATE) for entire population\n"
|
|
952
|
+
f" att=1: Average Treatment effect on the Treated (ATT), second level as treated\n"
|
|
953
|
+
f" att=2: ATT with first level as treated"
|
|
954
|
+
)
|
|
955
|
+
|
|
956
|
+
# Validate method parameter
|
|
957
|
+
valid_methods = {'over', 'exact'}
|
|
958
|
+
if not isinstance(method, str):
|
|
959
|
+
raise TypeError(
|
|
960
|
+
f"method must be a string, got {type(method).__name__}. "
|
|
961
|
+
f"Received: method={method}"
|
|
962
|
+
)
|
|
963
|
+
if method not in valid_methods:
|
|
964
|
+
raise ValueError(
|
|
965
|
+
f"method must be one of {valid_methods}. "
|
|
966
|
+
f"Received: method='{method}'\n\n"
|
|
967
|
+
f"Explanation:\n"
|
|
968
|
+
f" method='over': Over-identified GMM (score + balance conditions, recommended)\n"
|
|
969
|
+
f" method='exact': Exactly-identified GMM (balance conditions only)\n\n"
|
|
970
|
+
f"Note: method is case-sensitive, use lowercase only."
|
|
971
|
+
)
|
|
972
|
+
|
|
973
|
+
# Validate theoretical_exact compatibility with method parameter
|
|
974
|
+
if theoretical_exact and method != 'exact':
|
|
975
|
+
warnings.warn(
|
|
976
|
+
f"theoretical_exact=True only works with method='exact'. "
|
|
977
|
+
f"Current method='{method}' does not use this parameter. "
|
|
978
|
+
f"The theoretical_exact parameter will be ignored.\n\n"
|
|
979
|
+
f"To use theoretical_exact, set method='exact'.",
|
|
980
|
+
UserWarning
|
|
981
|
+
)
|
|
982
|
+
|
|
983
|
+
# Validate verbose parameter
|
|
984
|
+
if not isinstance(verbose, (int, np.integer)):
|
|
985
|
+
raise TypeError(
|
|
986
|
+
f"verbose must be an integer (0, 1, or 2), got {type(verbose).__name__}. "
|
|
987
|
+
f"Received: verbose={verbose}"
|
|
988
|
+
)
|
|
989
|
+
if verbose not in (0, 1, 2):
|
|
990
|
+
raise ValueError(
|
|
991
|
+
f"verbose must be 0 (silent), 1 (basic), or 2 (detailed). "
|
|
992
|
+
f"Received: verbose={verbose}"
|
|
993
|
+
)
|
|
994
|
+
|
|
995
|
+
# Validate two_step parameter
|
|
996
|
+
if not isinstance(two_step, bool):
|
|
997
|
+
raise TypeError(
|
|
998
|
+
f"two_step must be a boolean (True or False), got {type(two_step).__name__}. "
|
|
999
|
+
f"Received: two_step={two_step}\n\n"
|
|
1000
|
+
f"Hint: Use True or False, not 1 or 0."
|
|
1001
|
+
)
|
|
1002
|
+
|
|
1003
|
+
# Note: method='exact' and two_step=True is a valid combination.
|
|
1004
|
+
# In R's CBPS package, method='exact' sets bal.only=TRUE (only balance
|
|
1005
|
+
# conditions used for optimization), while twostep independently controls
|
|
1006
|
+
# whether analytical gradient is used in balance optimization.
|
|
1007
|
+
# twostep=TRUE → analytical gradient; twostep=FALSE → numerical gradient.
|
|
1008
|
+
# These two parameters are orthogonal and should NOT override each other.
|
|
1009
|
+
|
|
1010
|
+
# Validate standardize parameter
|
|
1011
|
+
if not isinstance(standardize, bool):
|
|
1012
|
+
raise TypeError(
|
|
1013
|
+
f"standardize must be a boolean (True or False), got {type(standardize).__name__}. "
|
|
1014
|
+
f"Received: standardize={standardize}\n\n"
|
|
1015
|
+
f"Hint: Use True or False, not 1 or 0."
|
|
1016
|
+
)
|
|
1017
|
+
|
|
1018
|
+
# Step 1: Formula path vs array path
|
|
1019
|
+
na_action_info = None # Track missing value handling info
|
|
1020
|
+
|
|
1021
|
+
# Initialize metadata variables (needed for all code paths)
|
|
1022
|
+
data_original = None
|
|
1023
|
+
terms_obj = None
|
|
1024
|
+
model_frame = None
|
|
1025
|
+
xlevels_obj = None
|
|
1026
|
+
|
|
1027
|
+
if formula is not None:
|
|
1028
|
+
# Formula interface path
|
|
1029
|
+
|
|
1030
|
+
# Validate data parameter type
|
|
1031
|
+
if data is None:
|
|
1032
|
+
raise ValueError(
|
|
1033
|
+
"data parameter is required when using formula interface. "
|
|
1034
|
+
"Please provide a pandas DataFrame containing the variables in your formula."
|
|
1035
|
+
)
|
|
1036
|
+
if not isinstance(data, pd.DataFrame):
|
|
1037
|
+
raise TypeError(
|
|
1038
|
+
f"data must be a pandas DataFrame when using formula interface. "
|
|
1039
|
+
f"Got: {type(data).__name__}. "
|
|
1040
|
+
f"If you have a dict, convert it: pd.DataFrame(your_dict). "
|
|
1041
|
+
f"Or use the array interface: CBPS(treatment=..., covariates=...)"
|
|
1042
|
+
)
|
|
1043
|
+
|
|
1044
|
+
# Validate formula type
|
|
1045
|
+
if not isinstance(formula, str):
|
|
1046
|
+
raise TypeError(
|
|
1047
|
+
f"formula must be a string, got {type(formula).__name__}. "
|
|
1048
|
+
f"Received: formula={formula}\n\n"
|
|
1049
|
+
f"Example of correct formula: 'treat ~ age + educ + black'"
|
|
1050
|
+
)
|
|
1051
|
+
|
|
1052
|
+
# Validate formula format
|
|
1053
|
+
if '~' not in formula:
|
|
1054
|
+
raise ValueError(
|
|
1055
|
+
f"Formula must contain '~' to separate treatment from covariates. "
|
|
1056
|
+
f"Got: '{formula}'. "
|
|
1057
|
+
f"Example: 'treat ~ age + educ + black'"
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
# Step 1.1: Handle missing values
|
|
1061
|
+
# Extract columns involved in formula
|
|
1062
|
+
treat_col = formula.split('~')[0].strip()
|
|
1063
|
+
covar_cols = [col.strip() for col in formula.split('~')[1].split('+')]
|
|
1064
|
+
|
|
1065
|
+
# Use exact column matching (avoid substring matching issues)
|
|
1066
|
+
relevant_cols = [treat_col] + covar_cols
|
|
1067
|
+
# Filter out columns not in data (handles I() and other functions)
|
|
1068
|
+
relevant_cols = [col for col in relevant_cols if col in data.columns]
|
|
1069
|
+
|
|
1070
|
+
# Validate na_action parameter value
|
|
1071
|
+
valid_na_actions = {'warn', 'fail', 'ignore', 'omit'}
|
|
1072
|
+
if na_action not in valid_na_actions:
|
|
1073
|
+
raise ValueError(
|
|
1074
|
+
f"Invalid na_action='{na_action}'. "
|
|
1075
|
+
f"Valid options are: {', '.join(repr(x) for x in sorted(valid_na_actions))}. "
|
|
1076
|
+
f"Note: 'omit' is an alias for 'warn'."
|
|
1077
|
+
)
|
|
1078
|
+
|
|
1079
|
+
# Alias mapping: 'omit' maps to 'warn'
|
|
1080
|
+
if na_action == 'omit':
|
|
1081
|
+
na_action = 'warn'
|
|
1082
|
+
|
|
1083
|
+
# Check for missing values
|
|
1084
|
+
n_missing = data[relevant_cols].isna().any(axis=1).sum()
|
|
1085
|
+
if n_missing > 0:
|
|
1086
|
+
if na_action == 'fail':
|
|
1087
|
+
raise ValueError(
|
|
1088
|
+
f"Missing values detected in {n_missing} observations. "
|
|
1089
|
+
f"Set na_action='warn' to remove them, or handle missing values before calling CBPS()."
|
|
1090
|
+
)
|
|
1091
|
+
elif na_action == 'warn':
|
|
1092
|
+
from cbps.utils.helpers import handle_missing
|
|
1093
|
+
data_clean, n_dropped = handle_missing(data, relevant_cols)
|
|
1094
|
+
data = data_clean
|
|
1095
|
+
na_action_info = {'method': 'omit', 'n_dropped': n_dropped}
|
|
1096
|
+
elif na_action == 'ignore':
|
|
1097
|
+
# Ignore mode: silently remove missing values, still record info
|
|
1098
|
+
data_clean = data.dropna(subset=relevant_cols)
|
|
1099
|
+
n_dropped = len(data) - len(data_clean)
|
|
1100
|
+
data = data_clean
|
|
1101
|
+
na_action_info = {'method': 'ignore', 'n_dropped': n_dropped}
|
|
1102
|
+
|
|
1103
|
+
from patsy import dmatrices, PatsyError
|
|
1104
|
+
from cbps.utils.formula import _convert_r_formula_to_patsy
|
|
1105
|
+
|
|
1106
|
+
# Support dot formula (treat ~ .)
|
|
1107
|
+
# Expands 'y ~ .' to 'y ~ x1 + x2 + ...' since Patsy doesn't support dot syntax
|
|
1108
|
+
if isinstance(formula, str) and '~' in formula:
|
|
1109
|
+
parts = formula.split('~')
|
|
1110
|
+
if len(parts) == 2 and parts[1].strip() == '.':
|
|
1111
|
+
if data is None:
|
|
1112
|
+
raise ValueError("Data must be provided when using dot formula ('~ .')")
|
|
1113
|
+
|
|
1114
|
+
# Parse treatment variable name
|
|
1115
|
+
treat_part = parts[0].strip()
|
|
1116
|
+
|
|
1117
|
+
# Extract real column name (handle C() or factor())
|
|
1118
|
+
import re
|
|
1119
|
+
real_treat_col = treat_part
|
|
1120
|
+
c_match = re.match(r'C\(([^)]+)\)', treat_part)
|
|
1121
|
+
factor_match = re.match(r'factor\(([^)]+)\)', treat_part)
|
|
1122
|
+
if c_match:
|
|
1123
|
+
real_treat_col = c_match.group(1).strip()
|
|
1124
|
+
elif factor_match:
|
|
1125
|
+
real_treat_col = factor_match.group(1).strip()
|
|
1126
|
+
|
|
1127
|
+
# Get all other columns
|
|
1128
|
+
other_cols = [c for c in data.columns if c != real_treat_col]
|
|
1129
|
+
if not other_cols:
|
|
1130
|
+
raise ValueError("No covariates found in data (only treatment column exists)")
|
|
1131
|
+
|
|
1132
|
+
# Rebuild formula
|
|
1133
|
+
# Quote column names with spaces or special characters using Q()
|
|
1134
|
+
def _quote_if_needed(col):
|
|
1135
|
+
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', col):
|
|
1136
|
+
return f"Q('{col}')"
|
|
1137
|
+
return col
|
|
1138
|
+
|
|
1139
|
+
rhs = ' + '.join([_quote_if_needed(c) for c in other_cols])
|
|
1140
|
+
formula = f"{treat_part} ~ {rhs}"
|
|
1141
|
+
|
|
1142
|
+
# Convert R formula syntax to patsy
|
|
1143
|
+
formula = _convert_r_formula_to_patsy(formula)
|
|
1144
|
+
|
|
1145
|
+
# Extract treatment variable from original data to avoid Patsy's one-hot encoding
|
|
1146
|
+
treat_col_name = formula.split('~')[0].strip()
|
|
1147
|
+
|
|
1148
|
+
# Support C() and factor() syntax for explicit categorical specification
|
|
1149
|
+
import re
|
|
1150
|
+
|
|
1151
|
+
# Detect C() or factor() wrapper
|
|
1152
|
+
categorical_from_formula = False
|
|
1153
|
+
c_match = re.match(r'C\(([^)]+)\)', treat_col_name)
|
|
1154
|
+
factor_match = re.match(r'factor\(([^)]+)\)', treat_col_name)
|
|
1155
|
+
|
|
1156
|
+
if c_match:
|
|
1157
|
+
# 'C(treat)' -> 'treat'
|
|
1158
|
+
real_treat_col = c_match.group(1).strip()
|
|
1159
|
+
categorical_from_formula = True
|
|
1160
|
+
elif factor_match:
|
|
1161
|
+
# 'factor(treat)' -> 'treat'
|
|
1162
|
+
real_treat_col = factor_match.group(1).strip()
|
|
1163
|
+
categorical_from_formula = True
|
|
1164
|
+
else:
|
|
1165
|
+
# Plain column name
|
|
1166
|
+
real_treat_col = treat_col_name
|
|
1167
|
+
|
|
1168
|
+
# Save treatment category names for summary display
|
|
1169
|
+
treat_categories_from_formula = None
|
|
1170
|
+
if real_treat_col in data.columns:
|
|
1171
|
+
# Extract treatment variable from original data (preserve Categorical type)
|
|
1172
|
+
treat_orig_series = data[real_treat_col]
|
|
1173
|
+
|
|
1174
|
+
# If formula uses C() or factor(), force categorical flag
|
|
1175
|
+
if categorical_from_formula:
|
|
1176
|
+
is_treat_categorical = True
|
|
1177
|
+
# Convert to categorical if not already
|
|
1178
|
+
if not isinstance(treat_orig_series.dtype, pd.CategoricalDtype):
|
|
1179
|
+
treat_orig_series = pd.Categorical(treat_orig_series)
|
|
1180
|
+
treat_categories_from_formula = list(treat_orig_series.categories)
|
|
1181
|
+
# Extract numeric codes
|
|
1182
|
+
treat = treat_orig_series.codes if hasattr(treat_orig_series, 'codes') else treat_orig_series.cat.codes.to_numpy()
|
|
1183
|
+
if treat_categories_from_formula is None:
|
|
1184
|
+
treat_categories_from_formula = list(treat_orig_series.categories if isinstance(treat_orig_series, pd.Categorical) else treat_orig_series.cat.categories)
|
|
1185
|
+
else:
|
|
1186
|
+
# Detect if categorical (check priority to avoid duplicate conversion)
|
|
1187
|
+
is_treat_categorical = (
|
|
1188
|
+
isinstance(treat_orig_series.dtype, pd.CategoricalDtype) or
|
|
1189
|
+
isinstance(treat_orig_series, pd.Categorical)
|
|
1190
|
+
)
|
|
1191
|
+
|
|
1192
|
+
# Auto-convert string treatment to categorical
|
|
1193
|
+
if not is_treat_categorical:
|
|
1194
|
+
treat = treat_orig_series.to_numpy() # Convert to numpy array
|
|
1195
|
+
|
|
1196
|
+
# Detect string/object type
|
|
1197
|
+
if treat.dtype == object or pd.api.types.is_string_dtype(treat):
|
|
1198
|
+
# Auto-convert to categorical
|
|
1199
|
+
treat_orig_series = pd.Categorical(treat_orig_series)
|
|
1200
|
+
treat = treat_orig_series.codes # Convert to numeric codes
|
|
1201
|
+
is_treat_categorical = True
|
|
1202
|
+
# Save original category names
|
|
1203
|
+
treat_categories_from_formula = list(treat_orig_series.categories)
|
|
1204
|
+
warnings.warn(
|
|
1205
|
+
f"Treatment variable '{real_treat_col}' is string/object type. "
|
|
1206
|
+
f"Automatically converting to categorical with levels: {treat_categories_from_formula}.",
|
|
1207
|
+
UserWarning
|
|
1208
|
+
)
|
|
1209
|
+
else:
|
|
1210
|
+
# Already categorical, extract numeric codes
|
|
1211
|
+
# Note: Categorical Series to_numpy() returns original category values
|
|
1212
|
+
# We need numeric codes instead
|
|
1213
|
+
if hasattr(treat_orig_series, 'cat'):
|
|
1214
|
+
treat = treat_orig_series.cat.codes.to_numpy()
|
|
1215
|
+
# Save original category names
|
|
1216
|
+
treat_categories_from_formula = list(treat_orig_series.cat.categories)
|
|
1217
|
+
elif isinstance(treat_orig_series, pd.Categorical):
|
|
1218
|
+
treat = treat_orig_series.codes
|
|
1219
|
+
# Save original category names
|
|
1220
|
+
treat_categories_from_formula = list(treat_orig_series.categories)
|
|
1221
|
+
else:
|
|
1222
|
+
treat = treat_orig_series.to_numpy()
|
|
1223
|
+
|
|
1224
|
+
# Treatment type detection rules:
|
|
1225
|
+
# - categorical/factor → discrete CBPS
|
|
1226
|
+
# - numeric → continuous CBPS
|
|
1227
|
+
# - only 0/1 binary values are auto-converted to factor
|
|
1228
|
+
if not is_treat_categorical:
|
|
1229
|
+
# Check for 0/1 binary values
|
|
1230
|
+
treat_unique = np.unique(treat)
|
|
1231
|
+
n_unique = len(treat_unique)
|
|
1232
|
+
is_binary_01 = (
|
|
1233
|
+
n_unique == 2 and
|
|
1234
|
+
set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}
|
|
1235
|
+
)
|
|
1236
|
+
# Warn for float type binary treatment
|
|
1237
|
+
if is_binary_01 and np.issubdtype(treat.dtype, np.floating):
|
|
1238
|
+
warnings.warn(
|
|
1239
|
+
"Treatment variable is numeric (float) with only 2 unique values. "
|
|
1240
|
+
"Interpreting as binary treatment. "
|
|
1241
|
+
"Consider using int or Categorical type for clarity.",
|
|
1242
|
+
UserWarning
|
|
1243
|
+
)
|
|
1244
|
+
if is_binary_01:
|
|
1245
|
+
is_treat_categorical = True
|
|
1246
|
+
else:
|
|
1247
|
+
raise ValueError(
|
|
1248
|
+
f"Treatment column '{real_treat_col}' not found in data.\n"
|
|
1249
|
+
f"Original formula: {formula}\n"
|
|
1250
|
+
f"Available columns: {list(data.columns)}"
|
|
1251
|
+
)
|
|
1252
|
+
|
|
1253
|
+
# Handle C() or factor() on left-hand side of formula
|
|
1254
|
+
# Patsy encodes C(treat) as multiple dummy columns, but CBPS expects single vector
|
|
1255
|
+
# Solution: only use patsy for RHS, extract y from original data
|
|
1256
|
+
if categorical_from_formula:
|
|
1257
|
+
# Already extracted treat_orig_series from data
|
|
1258
|
+
# Construct RHS-only formula for patsy
|
|
1259
|
+
from patsy import dmatrix
|
|
1260
|
+
formula_rhs = '~' + formula.split('~')[1]
|
|
1261
|
+
try:
|
|
1262
|
+
X_design = dmatrix(formula_rhs, data, return_type='dataframe')
|
|
1263
|
+
except Exception as e:
|
|
1264
|
+
raise ValueError(
|
|
1265
|
+
f"Failed to parse formula right-hand side: '{formula_rhs}'\n"
|
|
1266
|
+
f"Error: {type(e).__name__}: {str(e)[:200]}"
|
|
1267
|
+
) from e
|
|
1268
|
+
else:
|
|
1269
|
+
# Standard formula, use dmatrices to parse both sides
|
|
1270
|
+
# Wrap patsy errors with user-friendly messages
|
|
1271
|
+
try:
|
|
1272
|
+
_, X_design = dmatrices(formula, data, return_type='dataframe')
|
|
1273
|
+
except PatsyError as e:
|
|
1274
|
+
# Convert patsy-specific errors to friendlier messages
|
|
1275
|
+
raise ValueError(
|
|
1276
|
+
f"Invalid formula syntax: '{formula}'\n"
|
|
1277
|
+
f"Patsy error: {str(e)[:200]}\n\n"
|
|
1278
|
+
f"Common issues:\n"
|
|
1279
|
+
f" - Undefined variables or functions\n"
|
|
1280
|
+
f" - Syntax errors in I() expressions\n"
|
|
1281
|
+
f" - Missing columns in data\n\n"
|
|
1282
|
+
f"Formula format: 'treatment ~ covariate1 + covariate2 + ...'\n"
|
|
1283
|
+
f"Examples:\n"
|
|
1284
|
+
f" - 'treat ~ age + educ + black'\n"
|
|
1285
|
+
f" - 'treat ~ age + I(age**2) + educ'\n"
|
|
1286
|
+
f" - 'treat ~ C(country) + income'"
|
|
1287
|
+
) from e
|
|
1288
|
+
except NameError as e:
|
|
1289
|
+
# Function or variable undefined
|
|
1290
|
+
raise ValueError(
|
|
1291
|
+
f"Invalid formula: '{formula}'\n"
|
|
1292
|
+
f"Error: {str(e)}\n\n"
|
|
1293
|
+
f"Make sure all variables exist in your data and all functions are defined.\n"
|
|
1294
|
+
f"Available columns: {list(data.columns)}"
|
|
1295
|
+
) from e
|
|
1296
|
+
except KeyError as e:
|
|
1297
|
+
# Column does not exist
|
|
1298
|
+
raise ValueError(
|
|
1299
|
+
f"Invalid formula: '{formula}'\n"
|
|
1300
|
+
f"Column not found in data: {str(e)}\n\n"
|
|
1301
|
+
f"Available columns: {list(data.columns)}"
|
|
1302
|
+
) from e
|
|
1303
|
+
except Exception as e:
|
|
1304
|
+
# Other unexpected errors
|
|
1305
|
+
raise ValueError(
|
|
1306
|
+
f"Failed to parse formula: '{formula}'\n"
|
|
1307
|
+
f"Error: {type(e).__name__}: {str(e)[:200]}\n\n"
|
|
1308
|
+
f"Please check your formula syntax and data."
|
|
1309
|
+
) from e
|
|
1310
|
+
|
|
1311
|
+
X = X_design.values
|
|
1312
|
+
|
|
1313
|
+
# Save terms object for predict() and update() methods
|
|
1314
|
+
terms_obj = X_design.design_info # Patsy's DesignInfo object
|
|
1315
|
+
|
|
1316
|
+
# Extract factor levels for predict() validation
|
|
1317
|
+
xlevels_dict = {}
|
|
1318
|
+
if hasattr(X_design, 'design_info') and hasattr(X_design.design_info, 'factor_infos'):
|
|
1319
|
+
for factor_name, factor_info in X_design.design_info.factor_infos.items():
|
|
1320
|
+
# Check if categorical variable
|
|
1321
|
+
if factor_info.type == 'categorical' and hasattr(factor_info, 'categories'):
|
|
1322
|
+
# Extract variable name (remove EvalFactor wrapper)
|
|
1323
|
+
var_name_str = str(factor_name)
|
|
1324
|
+
# Handle C(var_name) format
|
|
1325
|
+
if 'C(' in var_name_str and ')' in var_name_str:
|
|
1326
|
+
var_name = var_name_str.split('C(')[1].split(')')[0]
|
|
1327
|
+
else:
|
|
1328
|
+
var_name = var_name_str
|
|
1329
|
+
xlevels_dict[var_name] = list(factor_info.categories)
|
|
1330
|
+
xlevels_obj = xlevels_dict if xlevels_dict else None
|
|
1331
|
+
|
|
1332
|
+
# Save original data
|
|
1333
|
+
data_original = data.copy()
|
|
1334
|
+
|
|
1335
|
+
# Reorder columns: Intercept → regular vars (formula order) → I() function cols
|
|
1336
|
+
|
|
1337
|
+
all_cols = list(X_design.columns)
|
|
1338
|
+
intercept_cols = [c for c in all_cols if c == 'Intercept']
|
|
1339
|
+
i_func_cols = [c for c in all_cols if c.startswith('I(') and c != 'Intercept']
|
|
1340
|
+
regular_cols = [c for c in all_cols if c not in intercept_cols and c not in i_func_cols]
|
|
1341
|
+
|
|
1342
|
+
# Construct model frame containing all formula variables (after NA removal)
|
|
1343
|
+
model_cols = [real_treat_col]
|
|
1344
|
+
for col in regular_cols + i_func_cols:
|
|
1345
|
+
# Only include columns that exist in data (exclude I() expressions etc.)
|
|
1346
|
+
if col in data.columns:
|
|
1347
|
+
model_cols.append(col)
|
|
1348
|
+
model_frame = data[model_cols].copy() if len(model_cols) > 0 else data.copy()
|
|
1349
|
+
|
|
1350
|
+
# Standard ordering: Intercept, regular vars, I() functions
|
|
1351
|
+
ordered_cols = intercept_cols + regular_cols + i_func_cols
|
|
1352
|
+
|
|
1353
|
+
# Get column indices and reorder X and coef_names
|
|
1354
|
+
col_indices = [all_cols.index(c) for c in ordered_cols]
|
|
1355
|
+
X = X[:, col_indices]
|
|
1356
|
+
|
|
1357
|
+
# Standardize column names to standard statistical modeling conventions
|
|
1358
|
+
# Format: "(Intercept)", "age", "I(re75 == 0)TRUE" (convert patsy's [T.True] suffix)
|
|
1359
|
+
coef_names = []
|
|
1360
|
+
for name in ordered_cols:
|
|
1361
|
+
if name == 'Intercept':
|
|
1362
|
+
coef_names.append('(Intercept)') # Standard intercept notation
|
|
1363
|
+
elif '[T.True]' in name:
|
|
1364
|
+
# Remove patsy's [T.True] suffix, replace with TRUE
|
|
1365
|
+
coef_names.append(name.replace('[T.True]', 'TRUE'))
|
|
1366
|
+
elif '[T.False]' in name:
|
|
1367
|
+
coef_names.append(name.replace('[T.False]', 'FALSE'))
|
|
1368
|
+
else:
|
|
1369
|
+
coef_names.append(name)
|
|
1370
|
+
|
|
1371
|
+
# Sync sample_weights dimensions when na_action removes rows
|
|
1372
|
+
if sample_weights is not None:
|
|
1373
|
+
original_sample_weights = sample_weights
|
|
1374
|
+
# If sample_weights is Series/DataFrame, use data index to select rows
|
|
1375
|
+
if isinstance(sample_weights, (pd.Series, pd.DataFrame)):
|
|
1376
|
+
if isinstance(sample_weights, pd.DataFrame):
|
|
1377
|
+
sample_weights = sample_weights.iloc[:, 0].values
|
|
1378
|
+
else:
|
|
1379
|
+
sample_weights = sample_weights.loc[data.index].values
|
|
1380
|
+
else:
|
|
1381
|
+
# If numpy array, check dimension match
|
|
1382
|
+
sample_weights = np.asarray(sample_weights)
|
|
1383
|
+
if len(sample_weights) != len(treat):
|
|
1384
|
+
# Dimension mismatch with array type, cannot auto-sync
|
|
1385
|
+
warnings.warn(
|
|
1386
|
+
f"sample_weights length ({len(original_sample_weights)}) does not match "
|
|
1387
|
+
f"the number of valid observations after removing missing values ({len(treat)}). "
|
|
1388
|
+
f"Setting sample_weights to None (equal weights). "
|
|
1389
|
+
f"To avoid this, provide sample_weights as a pandas Series with matching index, "
|
|
1390
|
+
f"or handle missing values before calling CBPS().",
|
|
1391
|
+
UserWarning
|
|
1392
|
+
)
|
|
1393
|
+
sample_weights = None
|
|
1394
|
+
elif treatment is not None and covariates is not None:
|
|
1395
|
+
# Array interface path
|
|
1396
|
+
treat_original = treatment # Save for type detection
|
|
1397
|
+
|
|
1398
|
+
# Convert to numpy array (required by core algorithms)
|
|
1399
|
+
if isinstance(treatment, (pd.Series, pd.Categorical)):
|
|
1400
|
+
treat = np.asarray(treatment).ravel()
|
|
1401
|
+
else:
|
|
1402
|
+
treat = np.asarray(treatment).ravel()
|
|
1403
|
+
X = np.asarray(covariates)
|
|
1404
|
+
|
|
1405
|
+
# Validate covariates dimensions (must be 2D)
|
|
1406
|
+
if X.ndim == 0:
|
|
1407
|
+
raise ValueError(
|
|
1408
|
+
f"covariates must be a 2D array with shape (n_samples, n_features). "
|
|
1409
|
+
f"Got a scalar (0-dimensional array).\n"
|
|
1410
|
+
f"Expected shape: ({len(treat)}, k) where k >= 1.\n"
|
|
1411
|
+
f"If you have a single covariate, reshape it: X.reshape(-1, 1)"
|
|
1412
|
+
)
|
|
1413
|
+
elif X.ndim == 1:
|
|
1414
|
+
raise ValueError(
|
|
1415
|
+
f"covariates must be a 2D array with shape (n_samples, n_features). "
|
|
1416
|
+
f"Got a 1D array with shape {X.shape}.\n"
|
|
1417
|
+
f"Expected shape: ({len(treat)}, k) where k >= 1.\n\n"
|
|
1418
|
+
f"To fix this:\n"
|
|
1419
|
+
f" - If you have a single covariate: X.reshape(-1, 1)\n"
|
|
1420
|
+
f" - If you passed the transposed matrix: X.T\n\n"
|
|
1421
|
+
f"Current shapes:\n"
|
|
1422
|
+
f" treatment: {treat.shape}\n"
|
|
1423
|
+
f" covariates: {X.shape}"
|
|
1424
|
+
)
|
|
1425
|
+
elif X.ndim > 2:
|
|
1426
|
+
raise ValueError(
|
|
1427
|
+
f"covariates must be a 2D array with shape (n_samples, n_features). "
|
|
1428
|
+
f"Got a {X.ndim}-dimensional array with shape {X.shape}.\n"
|
|
1429
|
+
f"Expected shape: ({len(treat)}, k) where k >= 1."
|
|
1430
|
+
)
|
|
1431
|
+
|
|
1432
|
+
# Validate treatment and covariates have matching lengths
|
|
1433
|
+
if len(treat) != X.shape[0]:
|
|
1434
|
+
raise ValueError(
|
|
1435
|
+
f"Treatment and covariates must have the same number of samples.\n"
|
|
1436
|
+
f" treatment length: {len(treat)}\n"
|
|
1437
|
+
f" covariates rows: {X.shape[0]}\n\n"
|
|
1438
|
+
f"Please ensure treatment and covariates come from the same dataset."
|
|
1439
|
+
)
|
|
1440
|
+
|
|
1441
|
+
# Auto-add intercept column if not present
|
|
1442
|
+
if not _has_intercept(X):
|
|
1443
|
+
if verbose > 0:
|
|
1444
|
+
warnings.warn(
|
|
1445
|
+
"Intercept column not detected. Adding intercept to covariates matrix. "
|
|
1446
|
+
"To suppress this warning, manually add intercept: "
|
|
1447
|
+
"np.column_stack([np.ones(n), X])",
|
|
1448
|
+
UserWarning
|
|
1449
|
+
)
|
|
1450
|
+
X = np.column_stack([np.ones(len(treat)), X])
|
|
1451
|
+
|
|
1452
|
+
# Generate default column names
|
|
1453
|
+
if isinstance(covariates, pd.DataFrame):
|
|
1454
|
+
coef_names = covariates.columns.tolist()
|
|
1455
|
+
# If intercept was added, prepend "Intercept" to column names
|
|
1456
|
+
if not _has_intercept(np.asarray(covariates)):
|
|
1457
|
+
coef_names = ["Intercept"] + coef_names
|
|
1458
|
+
else:
|
|
1459
|
+
k = X.shape[1]
|
|
1460
|
+
coef_names = ["Intercept"] + [f"X{i}" for i in range(1, k)]
|
|
1461
|
+
else:
|
|
1462
|
+
raise ValueError(
|
|
1463
|
+
"Must provide either 'formula' and 'data', or 'treatment' and 'covariates'"
|
|
1464
|
+
)
|
|
1465
|
+
|
|
1466
|
+
# Step 1.5: Dual formula parsing (oCBPS path)
|
|
1467
|
+
baseline_X = None
|
|
1468
|
+
diff_X = None
|
|
1469
|
+
|
|
1470
|
+
# Check if baseline/diff formula is provided
|
|
1471
|
+
has_baseline_or_diff = (baseline_formula is not None or diff_formula is not None)
|
|
1472
|
+
|
|
1473
|
+
if has_baseline_or_diff:
|
|
1474
|
+
# Check data parameter first
|
|
1475
|
+
if data is None:
|
|
1476
|
+
raise ValueError(
|
|
1477
|
+
"The data parameter is required when using baseline_formula or diff_formula.\n"
|
|
1478
|
+
"These parameters require access to the original DataFrame to parse formulas."
|
|
1479
|
+
)
|
|
1480
|
+
|
|
1481
|
+
# Extract treatment variable and detect type
|
|
1482
|
+
treat_for_check = None
|
|
1483
|
+
treat_col_name_for_check = None
|
|
1484
|
+
|
|
1485
|
+
if formula is not None:
|
|
1486
|
+
# Formula path: extract treatment from data
|
|
1487
|
+
treat_col_name_raw = formula.split('~')[0].strip()
|
|
1488
|
+
|
|
1489
|
+
# Handle C() and factor() syntax
|
|
1490
|
+
import re
|
|
1491
|
+
c_match = re.match(r'C\(([^)]+)\)', treat_col_name_raw)
|
|
1492
|
+
factor_match = re.match(r'factor\(([^)]+)\)', treat_col_name_raw)
|
|
1493
|
+
|
|
1494
|
+
if c_match:
|
|
1495
|
+
treat_col_name_for_check = c_match.group(1).strip()
|
|
1496
|
+
elif factor_match:
|
|
1497
|
+
treat_col_name_for_check = factor_match.group(1).strip()
|
|
1498
|
+
else:
|
|
1499
|
+
treat_col_name_for_check = treat_col_name_raw
|
|
1500
|
+
|
|
1501
|
+
if treat_col_name_for_check in data.columns:
|
|
1502
|
+
treat_for_check = data[treat_col_name_for_check].to_numpy()
|
|
1503
|
+
elif treatment is not None:
|
|
1504
|
+
# Array path: use treatment directly
|
|
1505
|
+
treat_for_check = treatment
|
|
1506
|
+
|
|
1507
|
+
# Call unified treatment type detection function
|
|
1508
|
+
if treat_for_check is not None:
|
|
1509
|
+
is_cat, is_bin, is_cont = _detect_treatment_type(
|
|
1510
|
+
treat_for_check,
|
|
1511
|
+
formula=formula,
|
|
1512
|
+
data=data,
|
|
1513
|
+
treat_col_name=treat_col_name_for_check
|
|
1514
|
+
)
|
|
1515
|
+
|
|
1516
|
+
# Reject continuous treatment immediately (takes priority over XOR check)
|
|
1517
|
+
if is_cont:
|
|
1518
|
+
raise ValueError(
|
|
1519
|
+
"baseline_formula and diff_formula are only supported for binary treatments.\n"
|
|
1520
|
+
"Optimal CBPS is not defined for continuous treatments.\n"
|
|
1521
|
+
"\n"
|
|
1522
|
+
"Reference:\n"
|
|
1523
|
+
" Fan, J., Imai, K., Lee, I., Liu, H., Ning, Y., & Yang, X. (2022).\n"
|
|
1524
|
+
" Optimal Covariate Balancing Conditions in Propensity Score Estimation.\n"
|
|
1525
|
+
" Journal of Business & Economic Statistics, 41(1), 97-110.\n"
|
|
1526
|
+
"\n"
|
|
1527
|
+
"For continuous treatments, use the standard CBPS without baseline/diff formulas."
|
|
1528
|
+
)
|
|
1529
|
+
|
|
1530
|
+
# Passed continuous treatment check, now check XOR (binary treatment only)
|
|
1531
|
+
if (baseline_formula is None) != (diff_formula is None):
|
|
1532
|
+
raise ValueError(
|
|
1533
|
+
"Both baseline_formula and diff_formula must be specified together, or neither.\n"
|
|
1534
|
+
f"Currently: baseline_formula={'provided' if baseline_formula else 'None'}, "
|
|
1535
|
+
f"diff_formula={'provided' if diff_formula else 'None'}.\n"
|
|
1536
|
+
"\n"
|
|
1537
|
+
"Either specify both formulas to use iCBPS (Optimal CBPS), or leave both as None."
|
|
1538
|
+
)
|
|
1539
|
+
|
|
1540
|
+
# Dual formula parsing
|
|
1541
|
+
from patsy import dmatrix
|
|
1542
|
+
|
|
1543
|
+
# Parse baseline formula
|
|
1544
|
+
baseline_X_raw = dmatrix(baseline_formula, data, return_type='dataframe').values
|
|
1545
|
+
# Filter zero-variance columns (intercept with sd=0 will be removed)
|
|
1546
|
+
baseline_X = baseline_X_raw[:, baseline_X_raw.std(axis=0, ddof=1) > 0]
|
|
1547
|
+
|
|
1548
|
+
# Parse diff formula
|
|
1549
|
+
diff_X_raw = dmatrix(diff_formula, data, return_type='dataframe').values
|
|
1550
|
+
# Filter zero-variance columns
|
|
1551
|
+
diff_X = diff_X_raw[:, diff_X_raw.std(axis=0, ddof=1) > 0]
|
|
1552
|
+
|
|
1553
|
+
# Step 1.5.5a: Basic dimension and sample size checks (must execute first)
|
|
1554
|
+
n = len(treat)
|
|
1555
|
+
|
|
1556
|
+
# Handle empty array (n=0)
|
|
1557
|
+
if n == 0:
|
|
1558
|
+
raise ValueError(
|
|
1559
|
+
"Treatment array is empty (n=0). "
|
|
1560
|
+
"CBPS requires at least 10 observations for valid inference."
|
|
1561
|
+
)
|
|
1562
|
+
|
|
1563
|
+
# Zero variance check takes priority over sample size check
|
|
1564
|
+
# Check if treatment variable has variance (all values identical)
|
|
1565
|
+
if n > 1: # Only check variance when n > 1
|
|
1566
|
+
# Get unique value count (works for all types)
|
|
1567
|
+
unique_vals = np.unique(treat)
|
|
1568
|
+
n_unique = len(unique_vals)
|
|
1569
|
+
|
|
1570
|
+
if n_unique == 1:
|
|
1571
|
+
# All values identical, cannot estimate propensity score
|
|
1572
|
+
raise ValueError(
|
|
1573
|
+
f"Treatment variable has zero variance. "
|
|
1574
|
+
f"All {n} observations have the same treatment value (treat={unique_vals[0]}). "
|
|
1575
|
+
f"CBPS requires variation in the treatment variable to estimate propensity scores. "
|
|
1576
|
+
f"Please check your data for errors or use a different subset with treatment variation."
|
|
1577
|
+
)
|
|
1578
|
+
|
|
1579
|
+
# For numeric types, also check if std is too small (near-constant)
|
|
1580
|
+
# Skip Categorical/string types (cannot compute std)
|
|
1581
|
+
is_categorical = hasattr(treat, 'categories') or (
|
|
1582
|
+
hasattr(treat, 'dtype') and hasattr(treat.dtype, 'categories')
|
|
1583
|
+
)
|
|
1584
|
+
is_string_dtype = (
|
|
1585
|
+
hasattr(treat, 'dtype') and
|
|
1586
|
+
(treat.dtype.kind == 'U' or treat.dtype.kind == 'O' or treat.dtype.kind == 'S')
|
|
1587
|
+
)
|
|
1588
|
+
|
|
1589
|
+
if not is_categorical and not is_string_dtype and n_unique > 1:
|
|
1590
|
+
try:
|
|
1591
|
+
treat_numeric = np.asarray(treat, dtype=np.float64)
|
|
1592
|
+
treat_std = np.std(treat_numeric, ddof=1)
|
|
1593
|
+
if treat_std == 0 or np.isclose(treat_std, 0):
|
|
1594
|
+
# Numeric treatment with zero std but multiple unique values (rare)
|
|
1595
|
+
raise ValueError(
|
|
1596
|
+
f"Treatment variable has zero or near-zero variance (std={treat_std:.2e}). "
|
|
1597
|
+
f"CBPS requires sufficient treatment variation for stable estimation."
|
|
1598
|
+
)
|
|
1599
|
+
except (ValueError, TypeError):
|
|
1600
|
+
# Cannot convert to numeric, skip std check (handled in type detection)
|
|
1601
|
+
pass
|
|
1602
|
+
|
|
1603
|
+
# Reject n<10 (statistically meaningless)
|
|
1604
|
+
if n < 10:
|
|
1605
|
+
raise ValueError(
|
|
1606
|
+
f"Sample size (n={n}) too small for CBPS (minimum: n ≥ 10). "
|
|
1607
|
+
f"CBPS relies on asymptotic (large-sample) theory for valid inference. "
|
|
1608
|
+
f"With n<10, standard errors and confidence intervals are completely invalid. "
|
|
1609
|
+
f"Current sample provides insufficient degrees of freedom for reliable estimation."
|
|
1610
|
+
)
|
|
1611
|
+
|
|
1612
|
+
# Step 1.5.5b: Input validation - detect inf/nan values
|
|
1613
|
+
try:
|
|
1614
|
+
_validate_finite_inputs(treat, X, func_name="CBPS")
|
|
1615
|
+
except ValueError as e:
|
|
1616
|
+
# Provide friendlier error message for formula interface
|
|
1617
|
+
if formula is not None:
|
|
1618
|
+
raise ValueError(
|
|
1619
|
+
f"{e}\n"
|
|
1620
|
+
f"Formula used: '{formula}'\n"
|
|
1621
|
+
f"Hint: Check your data for log(0), division by zero, or missing values."
|
|
1622
|
+
) from e
|
|
1623
|
+
else:
|
|
1624
|
+
raise
|
|
1625
|
+
|
|
1626
|
+
# Step 1.6: Zero-variance covariate filtering
|
|
1627
|
+
# Auto-drop zero-variance columns (except intercept) for numerical stability
|
|
1628
|
+
if X.shape[1] > 1: # If there are columns besides intercept
|
|
1629
|
+
# Compute std for all columns except intercept
|
|
1630
|
+
x_sd = X[:, 1:].std(axis=0, ddof=1)
|
|
1631
|
+
const_threshold = 1e-10
|
|
1632
|
+
non_const_mask = x_sd > const_threshold
|
|
1633
|
+
|
|
1634
|
+
# Check if any constant columns need to be dropped
|
|
1635
|
+
n_const_cols = np.sum(~non_const_mask)
|
|
1636
|
+
if n_const_cols > 0:
|
|
1637
|
+
# Record dropped column names if available
|
|
1638
|
+
if 'coef_names' in locals() and len(coef_names) == X.shape[1]:
|
|
1639
|
+
const_col_names = [coef_names[i+1] for i, is_const in enumerate(~non_const_mask) if is_const]
|
|
1640
|
+
warnings.warn(
|
|
1641
|
+
f"Dropping {n_const_cols} constant covariate(s) with zero variance: "
|
|
1642
|
+
f"{const_col_names}.",
|
|
1643
|
+
UserWarning
|
|
1644
|
+
)
|
|
1645
|
+
else:
|
|
1646
|
+
warnings.warn(
|
|
1647
|
+
f"Dropping {n_const_cols} constant covariate(s) with zero variance.",
|
|
1648
|
+
UserWarning
|
|
1649
|
+
)
|
|
1650
|
+
|
|
1651
|
+
# Keep intercept + non-constant columns
|
|
1652
|
+
X = np.column_stack([X[:, 0], X[:, 1:][:, non_const_mask]])
|
|
1653
|
+
|
|
1654
|
+
# Update column names accordingly
|
|
1655
|
+
if 'coef_names' in locals() and len(coef_names) == X.shape[1] + n_const_cols:
|
|
1656
|
+
coef_names = [coef_names[0]] + [coef_names[i+1] for i, is_non_const in enumerate(non_const_mask) if is_non_const]
|
|
1657
|
+
|
|
1658
|
+
# Reject intercept-only model (CBPS requires covariates to balance)
|
|
1659
|
+
if X.shape[1] <= 1:
|
|
1660
|
+
raise ValueError(
|
|
1661
|
+
f"CBPS requires at least one covariate (non-intercept) for covariate balancing.\n"
|
|
1662
|
+
f"Formula '{formula if formula else 'array input'}' resulted in design matrix with only intercept.\n\n"
|
|
1663
|
+
f"Explanation:\n"
|
|
1664
|
+
f" CBPS = Covariate Balancing Propensity Score\n"
|
|
1665
|
+
f" Without covariates, there is nothing to balance.\n\n"
|
|
1666
|
+
f"Theoretical reference:\n"
|
|
1667
|
+
f" Imai & Ratkovic (2014) Equation 8 requires covariates X_i for balance conditions.\n\n"
|
|
1668
|
+
f"Please add covariates to your formula, for example:\n"
|
|
1669
|
+
f" 'treat ~ age + education + income'\n"
|
|
1670
|
+
f" 'treat ~ x1 + x2 + I(x1**2)'\n\n"
|
|
1671
|
+
f"Current design matrix shape: {X.shape}"
|
|
1672
|
+
)
|
|
1673
|
+
|
|
1674
|
+
# Step 1.7: Rank check for collinearity detection
|
|
1675
|
+
rank_X = np.linalg.matrix_rank(X)
|
|
1676
|
+
k = X.shape[1]
|
|
1677
|
+
|
|
1678
|
+
if rank_X < k:
|
|
1679
|
+
# Provide detailed error message to help diagnose the issue
|
|
1680
|
+
raise ValueError(
|
|
1681
|
+
f"Covariate matrix X is not full rank (rank={rank_X} < {k}). "
|
|
1682
|
+
f"This indicates perfect collinearity among covariates. "
|
|
1683
|
+
f"Possible causes:\n"
|
|
1684
|
+
f" - Linear combinations (e.g., X2 = 2*X1 + 3)\n"
|
|
1685
|
+
f" - Duplicate columns (e.g., X2 = X1)\n"
|
|
1686
|
+
f" - Redundant interactions or polynomial terms\n"
|
|
1687
|
+
f"Please remove or combine collinear covariates. "
|
|
1688
|
+
f"Use variance inflation factor (VIF) or correlation matrix to diagnose."
|
|
1689
|
+
)
|
|
1690
|
+
|
|
1691
|
+
# Optional: Condition number warning for near-collinearity
|
|
1692
|
+
# High condition number indicates X'X is near-singular
|
|
1693
|
+
cond_num = np.linalg.cond(X)
|
|
1694
|
+
if cond_num > 1e10:
|
|
1695
|
+
warnings.warn(
|
|
1696
|
+
f"Covariate matrix X has very high condition number ({cond_num:.2e}). "
|
|
1697
|
+
f"This suggests near-collinearity, which may cause numerical instability. "
|
|
1698
|
+
f"Consider:\n"
|
|
1699
|
+
f" - Removing highly correlated covariates (check correlation matrix)\n"
|
|
1700
|
+
f" - Centering and scaling variables\n"
|
|
1701
|
+
f" - Using regularization (hdCBPS for high-dimensional settings)",
|
|
1702
|
+
UserWarning
|
|
1703
|
+
)
|
|
1704
|
+
|
|
1705
|
+
# Step 1.8: Relative sample size check
|
|
1706
|
+
k = X.shape[1]
|
|
1707
|
+
|
|
1708
|
+
# Warn for small samples (10 ≤ n < 30)
|
|
1709
|
+
if n < 30:
|
|
1710
|
+
warnings.warn(
|
|
1711
|
+
f"Small sample size (n={n}, recommended minimum: n ≥ 30). "
|
|
1712
|
+
f"CBPS standard errors rely on asymptotic normality which may not hold well for small samples. "
|
|
1713
|
+
f"Consider:\n"
|
|
1714
|
+
f" - Using bootstrap for more reliable confidence intervals\n"
|
|
1715
|
+
f" - Reporting results with appropriate caution\n"
|
|
1716
|
+
f" - Collecting more data if possible",
|
|
1717
|
+
UserWarning
|
|
1718
|
+
)
|
|
1719
|
+
|
|
1720
|
+
# Warn for low n/k ratio (insufficient relative sample size)
|
|
1721
|
+
if n <= k + 5:
|
|
1722
|
+
warnings.warn(
|
|
1723
|
+
f"Sample size (n={n}) small relative to number of parameters (k={k}). "
|
|
1724
|
+
f"Ratio n/k={n/k:.2f} is low (recommended: n/k ≥ 5). "
|
|
1725
|
+
f"Consider reducing the number of covariates for more stable estimates.",
|
|
1726
|
+
UserWarning
|
|
1727
|
+
)
|
|
1728
|
+
|
|
1729
|
+
# Step 2: Construct call_info
|
|
1730
|
+
if formula is not None:
|
|
1731
|
+
call_info = (f"CBPS(formula='{formula}', data=<DataFrame>, "
|
|
1732
|
+
f"att={att}, method='{method}', two_step={two_step})")
|
|
1733
|
+
else:
|
|
1734
|
+
call_info = (f"CBPS(treatment=<array>, covariates=<array>, "
|
|
1735
|
+
f"att={att}, method='{method}')")
|
|
1736
|
+
|
|
1737
|
+
# Step 3: Treatment type detection and routing
|
|
1738
|
+
# Important: detect factor/categorical before numeric
|
|
1739
|
+
# (pd.Categorical.dtype can be int64, causing misclassification)
|
|
1740
|
+
|
|
1741
|
+
# Debug output for treatment type detection
|
|
1742
|
+
if verbose > 1:
|
|
1743
|
+
print(f"DEBUG: Treatment type detection")
|
|
1744
|
+
if formula is not None:
|
|
1745
|
+
print(f" Formula path: is_treat_categorical={is_treat_categorical}")
|
|
1746
|
+
print(f" treat unique values: {np.unique(treat)}")
|
|
1747
|
+
print(f" treat dtype: {treat.dtype}")
|
|
1748
|
+
|
|
1749
|
+
# Discrete treatment detection (factor/categorical takes priority)
|
|
1750
|
+
if formula is not None:
|
|
1751
|
+
# Formula path: use saved is_treat_categorical
|
|
1752
|
+
is_factor = is_treat_categorical
|
|
1753
|
+
if verbose > 1:
|
|
1754
|
+
print(f" Formula path: is_factor={is_factor}")
|
|
1755
|
+
else:
|
|
1756
|
+
# Array path: detect Categorical or 0/1 binary values
|
|
1757
|
+
# Only 0/1 binary is auto-converted to factor (other numeric stays continuous)
|
|
1758
|
+
treat_unique = np.unique(treat)
|
|
1759
|
+
n_unique = len(treat_unique)
|
|
1760
|
+
|
|
1761
|
+
# Check for 0/1 binary
|
|
1762
|
+
is_binary_01 = (
|
|
1763
|
+
n_unique == 2 and
|
|
1764
|
+
set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}
|
|
1765
|
+
)
|
|
1766
|
+
|
|
1767
|
+
# Warn for float-type binary treatment
|
|
1768
|
+
if is_binary_01 and np.issubdtype(treat_original.dtype, np.floating):
|
|
1769
|
+
warnings.warn(
|
|
1770
|
+
"Treatment variable is numeric (float) with only 2 unique values. "
|
|
1771
|
+
"Interpreting as binary treatment. "
|
|
1772
|
+
"Consider using int or Categorical type for clarity.",
|
|
1773
|
+
UserWarning
|
|
1774
|
+
)
|
|
1775
|
+
|
|
1776
|
+
is_factor = (
|
|
1777
|
+
isinstance(treat_original, pd.Categorical) or
|
|
1778
|
+
hasattr(treat_original, 'cat') or
|
|
1779
|
+
is_binary_01 # Only 0/1 binary auto-detected as discrete
|
|
1780
|
+
)
|
|
1781
|
+
|
|
1782
|
+
# Continuous treatment detection
|
|
1783
|
+
# Numeric and not factor = continuous (regardless of unique value count)
|
|
1784
|
+
is_continuous = (
|
|
1785
|
+
not is_factor and
|
|
1786
|
+
np.issubdtype(treat.dtype, np.number)
|
|
1787
|
+
)
|
|
1788
|
+
|
|
1789
|
+
if is_continuous:
|
|
1790
|
+
# Warn if numeric treatment has few unique values (may be discrete)
|
|
1791
|
+
n_unique = len(np.unique(treat))
|
|
1792
|
+
if n_unique <= 4:
|
|
1793
|
+
warnings.warn(
|
|
1794
|
+
f"Treatment vector is numeric with {n_unique} unique values. "
|
|
1795
|
+
f"Interpreting as a continuous treatment. "
|
|
1796
|
+
f"To solve for a binary or multi-valued treatment, convert treat to categorical "
|
|
1797
|
+
f"(e.g., pd.Categorical(treat) or treat.astype('category')).",
|
|
1798
|
+
UserWarning
|
|
1799
|
+
)
|
|
1800
|
+
|
|
1801
|
+
# Continuous treatment does not support ATT, warn and ignore
|
|
1802
|
+
if att != 0:
|
|
1803
|
+
warnings.warn(
|
|
1804
|
+
f"ATT parameter (att={att}) is not supported for continuous treatments. "
|
|
1805
|
+
f"Continuous CBPS only estimates the Average Treatment Effect (ATE). "
|
|
1806
|
+
f"The att parameter will be ignored. "
|
|
1807
|
+
f"\n\nReason: ATT (Average Treatment Effect on the Treated) requires a binary "
|
|
1808
|
+
f"distinction between 'treated' and 'control' groups, which does not exist "
|
|
1809
|
+
f"for continuous treatments. "
|
|
1810
|
+
f"\n\nTheoretical reference: Fong, Hazlett & Imai (2018, Annals of Applied Statistics) "
|
|
1811
|
+
f"define stabilized weights for continuous treatments that estimate ATE only. "
|
|
1812
|
+
f"\n\nNote: For non-binary treatments, only the ATE is available.",
|
|
1813
|
+
UserWarning
|
|
1814
|
+
)
|
|
1815
|
+
|
|
1816
|
+
# Call continuous CBPS
|
|
1817
|
+
from cbps.core.cbps_continuous import cbps_continuous_fit
|
|
1818
|
+
|
|
1819
|
+
# Apply SVD preprocessing
|
|
1820
|
+
X_orig = X.copy() # Save original X for inverse transform
|
|
1821
|
+
X_svd, svd_info = _apply_svd_preprocessing(X)
|
|
1822
|
+
|
|
1823
|
+
# Compute rank and XprimeX_inv in SVD space
|
|
1824
|
+
k = np.linalg.matrix_rank(X_svd)
|
|
1825
|
+
if k < X_svd.shape[1]:
|
|
1826
|
+
raise ValueError("X is not full rank")
|
|
1827
|
+
|
|
1828
|
+
# Compute XprimeX_inv in SVD space
|
|
1829
|
+
if sample_weights is None:
|
|
1830
|
+
sample_weights_norm = np.ones(len(treat))
|
|
1831
|
+
else:
|
|
1832
|
+
sample_weights_norm = sample_weights / np.mean(sample_weights)
|
|
1833
|
+
|
|
1834
|
+
sw_sqrt_X = np.sqrt(sample_weights_norm)[:, None] * X_svd
|
|
1835
|
+
XprimeX = sw_sqrt_X.T @ sw_sqrt_X
|
|
1836
|
+
from cbps.core.cbps_binary import _r_ginv
|
|
1837
|
+
XprimeX_inv = _r_ginv(XprimeX)
|
|
1838
|
+
|
|
1839
|
+
result_dict = cbps_continuous_fit(
|
|
1840
|
+
treat, X_svd, # Pass SVD-preprocessed X
|
|
1841
|
+
method=method,
|
|
1842
|
+
two_step=two_step,
|
|
1843
|
+
iterations=iterations,
|
|
1844
|
+
standardize=standardize,
|
|
1845
|
+
sample_weights=sample_weights,
|
|
1846
|
+
verbose=verbose
|
|
1847
|
+
)
|
|
1848
|
+
|
|
1849
|
+
# SVD inverse transform
|
|
1850
|
+
beta_svd = result_dict['coefficients']
|
|
1851
|
+
if beta_svd.ndim == 1:
|
|
1852
|
+
beta_svd = beta_svd.reshape(-1, 1)
|
|
1853
|
+
beta_transformed = _apply_svd_inverse_transform(beta_svd, svd_info)
|
|
1854
|
+
result_dict['coefficients'] = beta_transformed
|
|
1855
|
+
|
|
1856
|
+
# Update x to original X
|
|
1857
|
+
result_dict['x'] = X_orig
|
|
1858
|
+
|
|
1859
|
+
# Remove keys not accepted by CBPSResults
|
|
1860
|
+
result_dict.pop('normality_diagnostics', None)
|
|
1861
|
+
|
|
1862
|
+
# Wrap in CBPSResults
|
|
1863
|
+
result = CBPSResults(
|
|
1864
|
+
**result_dict,
|
|
1865
|
+
coef_names=coef_names,
|
|
1866
|
+
call_info=call_info,
|
|
1867
|
+
formula=formula,
|
|
1868
|
+
data=data_original if formula is not None else None,
|
|
1869
|
+
terms=terms_obj if formula is not None else None,
|
|
1870
|
+
model=model_frame if formula is not None else None,
|
|
1871
|
+
xlevels=xlevels_obj if formula is not None else None,
|
|
1872
|
+
att=att,
|
|
1873
|
+
method=method,
|
|
1874
|
+
standardize=standardize,
|
|
1875
|
+
two_step=two_step
|
|
1876
|
+
)
|
|
1877
|
+
|
|
1878
|
+
return result
|
|
1879
|
+
|
|
1880
|
+
# Discrete treatment routing
|
|
1881
|
+
# Detect treatment levels (prioritize saved category names from formula interface)
|
|
1882
|
+
if formula is not None and 'treat_categories_from_formula' in locals() and treat_categories_from_formula is not None:
|
|
1883
|
+
treat_levels = np.array(treat_categories_from_formula)
|
|
1884
|
+
elif isinstance(treat, pd.Categorical):
|
|
1885
|
+
treat_levels = treat.categories.values
|
|
1886
|
+
elif hasattr(treat, 'cat'): # pandas Series with categorical dtype
|
|
1887
|
+
treat_levels = treat.cat.categories.values
|
|
1888
|
+
else:
|
|
1889
|
+
treat_levels = np.unique(treat)
|
|
1890
|
+
|
|
1891
|
+
# Sort treat_levels for consistent baseline (MNLogit uses treat_levels[0] as baseline)
|
|
1892
|
+
treat_levels = np.sort(treat_levels)
|
|
1893
|
+
|
|
1894
|
+
# Re-encode if treat uses categorical codes to align with sorted levels
|
|
1895
|
+
if formula is not None and ('treat_orig_series' in locals()):
|
|
1896
|
+
if hasattr(treat_orig_series, 'cat') or isinstance(treat_orig_series, pd.Categorical):
|
|
1897
|
+
# Re-encode: map original values to sorted indices
|
|
1898
|
+
if isinstance(treat_orig_series, pd.Categorical):
|
|
1899
|
+
treat_original_values = np.asarray(treat_orig_series)
|
|
1900
|
+
else:
|
|
1901
|
+
treat_original_values = treat_orig_series.to_numpy()
|
|
1902
|
+
value_to_sorted_index = {val: i for i, val in enumerate(treat_levels)}
|
|
1903
|
+
treat = np.array([value_to_sorted_index[val] for val in treat_original_values])
|
|
1904
|
+
|
|
1905
|
+
no_treats = len(treat_levels)
|
|
1906
|
+
|
|
1907
|
+
# Validate treatment level count
|
|
1908
|
+
if no_treats > 4:
|
|
1909
|
+
raise ValueError(
|
|
1910
|
+
"Parametric CBPS is not implemented for more than 4 treatment values. "
|
|
1911
|
+
"Consider using a continuous value."
|
|
1912
|
+
)
|
|
1913
|
+
if no_treats < 2:
|
|
1914
|
+
raise ValueError("Treatment must take more than one value")
|
|
1915
|
+
|
|
1916
|
+
# theoretical_exact not supported for multi-valued treatments
|
|
1917
|
+
if no_treats >= 3 and theoretical_exact:
|
|
1918
|
+
raise ValueError(
|
|
1919
|
+
f"theoretical_exact=True is not supported for multi-valued treatments ({no_treats} levels). "
|
|
1920
|
+
f"theoretical_exact is an experimental feature for binary treatments only.\n\n"
|
|
1921
|
+
f"Please set theoretical_exact=False (default) or use binary treatment."
|
|
1922
|
+
)
|
|
1923
|
+
|
|
1924
|
+
# Multi-valued treatment ATT handling (only ATE supported for 3+ levels)
|
|
1925
|
+
if no_treats >= 3 and att != 0:
|
|
1926
|
+
warnings.warn(
|
|
1927
|
+
f"Multi-valued treatment ({no_treats} levels) only supports att=0 (ATE). "
|
|
1928
|
+
f"ATT parameter (att={att}) will be overridden to att=0.\n\n"
|
|
1929
|
+
f"Reason: ATT requires a binary distinction between 'treated' and 'control'. "
|
|
1930
|
+
f"With {no_treats} levels, there is no single 'treated' group.\n\n"
|
|
1931
|
+
f"Reference: Imai & Ratkovic (2014), JRSS-B, Section 4.1.",
|
|
1932
|
+
UserWarning
|
|
1933
|
+
)
|
|
1934
|
+
att = 0 # Force ATE
|
|
1935
|
+
|
|
1936
|
+
# Binary treatment routing
|
|
1937
|
+
if no_treats == 2:
|
|
1938
|
+
# Handle att=2 encoding reversal
|
|
1939
|
+
from cbps.utils.helpers import encode_treatment_factor
|
|
1940
|
+
|
|
1941
|
+
# Save original treat for result object
|
|
1942
|
+
treat_original_for_results = treat.copy() if isinstance(treat, np.ndarray) else treat
|
|
1943
|
+
|
|
1944
|
+
# oCBPS path check - must be done BEFORE encoding to prevent att=2 reversal
|
|
1945
|
+
is_ocbps_path = baseline_X is not None and diff_X is not None
|
|
1946
|
+
|
|
1947
|
+
# For oCBPS, force att=0 BEFORE encoding to match R behavior
|
|
1948
|
+
att_for_encoding = att
|
|
1949
|
+
if is_ocbps_path and att != 0:
|
|
1950
|
+
warnings.warn(
|
|
1951
|
+
f"CBPSOptimal only supports att=0 (ATE). "
|
|
1952
|
+
f"Received att={att}, forcing to att=0. "
|
|
1953
|
+
f"Treatment encoding will NOT be reversed.",
|
|
1954
|
+
UserWarning
|
|
1955
|
+
)
|
|
1956
|
+
att_for_encoding = 0 # Force ATE encoding for oCBPS
|
|
1957
|
+
|
|
1958
|
+
# Apply ATT encoding logic for binary treatment
|
|
1959
|
+
if formula is not None and 'treat_orig_series' in locals() and is_treat_categorical:
|
|
1960
|
+
# Formula path: use original categorical series
|
|
1961
|
+
treat_encoded, treat_levels_ordered, treat_orig = encode_treatment_factor(treat_orig_series, att_for_encoding, verbose=verbose)
|
|
1962
|
+
else:
|
|
1963
|
+
# Array path or treat is already numeric
|
|
1964
|
+
treat_encoded, treat_levels_ordered, treat_orig = encode_treatment_factor(treat, att_for_encoding, verbose=verbose)
|
|
1965
|
+
|
|
1966
|
+
# Update treat to encoded 0/1 array
|
|
1967
|
+
treat = treat_encoded
|
|
1968
|
+
|
|
1969
|
+
# Normalize att to 0 or 1 (encoding already handles att=2 reversal)
|
|
1970
|
+
# att=0 → 0 (ATE), att=1 → 1 (ATT), att=2 → 1 (ATT with reversed encoding)
|
|
1971
|
+
# For oCBPS, att_for_encoding is always 0, so att_normalized will be 0
|
|
1972
|
+
att_normalized = 0 if att_for_encoding == 0 else 1
|
|
1973
|
+
|
|
1974
|
+
# oCBPS routing
|
|
1975
|
+
if is_ocbps_path:
|
|
1976
|
+
# oCBPS path - only supports ATE (att=0)
|
|
1977
|
+
# Warning already issued above if att != 0
|
|
1978
|
+
|
|
1979
|
+
# Force ATT=0 for oCBPS
|
|
1980
|
+
from cbps.core.cbps_optimal import cbps_optimal_2treat
|
|
1981
|
+
result_dict = cbps_optimal_2treat(
|
|
1982
|
+
treat, X, baseline_X, diff_X,
|
|
1983
|
+
iterations=iterations,
|
|
1984
|
+
att=0, # Force to 0
|
|
1985
|
+
standardize=standardize
|
|
1986
|
+
)
|
|
1987
|
+
elif baseline_X is not None or diff_X is not None:
|
|
1988
|
+
# Only one of baseline_X/diff_X provided - invalid for oCBPS
|
|
1989
|
+
raise ValueError(
|
|
1990
|
+
"For oCBPS (optimal CBPS), both baseline_formula and diff_formula "
|
|
1991
|
+
"(or baseline_X and diff_X) must be provided. "
|
|
1992
|
+
f"Received: baseline={'provided' if baseline_X is not None else 'None'}, "
|
|
1993
|
+
f"diff={'provided' if diff_X is not None else 'None'}. "
|
|
1994
|
+
"Either provide both for oCBPS, or neither for standard CBPS."
|
|
1995
|
+
)
|
|
1996
|
+
else:
|
|
1997
|
+
# Standard CBPS path
|
|
1998
|
+
# Apply SVD preprocessing (matching R package CBPSMain.R lines 307-314)
|
|
1999
|
+
X_orig_binary = X.copy()
|
|
2000
|
+
X_svd_binary, svd_info_binary = _apply_svd_preprocessing(X)
|
|
2001
|
+
|
|
2002
|
+
# Compute rank check in SVD space
|
|
2003
|
+
k_binary = np.linalg.matrix_rank(X_svd_binary)
|
|
2004
|
+
if k_binary < X_svd_binary.shape[1]:
|
|
2005
|
+
raise ValueError("X is not full rank")
|
|
2006
|
+
|
|
2007
|
+
# Compute XprimeX_inv in SVD space
|
|
2008
|
+
if sample_weights is None:
|
|
2009
|
+
sw_norm_binary = np.ones(len(treat))
|
|
2010
|
+
else:
|
|
2011
|
+
sw_norm_binary = sample_weights / np.mean(sample_weights)
|
|
2012
|
+
sw_sqrt_X_binary = np.sqrt(sw_norm_binary)[:, None] * X_svd_binary
|
|
2013
|
+
XprimeX_binary = sw_sqrt_X_binary.T @ sw_sqrt_X_binary
|
|
2014
|
+
from cbps.core.cbps_binary import _r_ginv
|
|
2015
|
+
XprimeX_inv_binary = _r_ginv(XprimeX_binary)
|
|
2016
|
+
|
|
2017
|
+
result_dict = cbps_binary_fit(
|
|
2018
|
+
treat, X_svd_binary, # Pass SVD-transformed X
|
|
2019
|
+
att=att_normalized,
|
|
2020
|
+
method=method,
|
|
2021
|
+
two_step=two_step,
|
|
2022
|
+
standardize=standardize,
|
|
2023
|
+
sample_weights=sample_weights,
|
|
2024
|
+
iterations=iterations,
|
|
2025
|
+
XprimeX_inv=XprimeX_inv_binary,
|
|
2026
|
+
theoretical_exact=theoretical_exact,
|
|
2027
|
+
verbose=verbose,
|
|
2028
|
+
# R-matching optimizer tolerances (only set if user hasn't specified)
|
|
2029
|
+
bal_gtol=kwargs.pop('bal_gtol', 1e-6),
|
|
2030
|
+
gmm_gtol=kwargs.pop('gmm_gtol', 1e-10),
|
|
2031
|
+
**kwargs
|
|
2032
|
+
)
|
|
2033
|
+
|
|
2034
|
+
# SVD inverse transform for coefficients
|
|
2035
|
+
# R: beta.opt = V %*% diag(d.inv) %*% coef(output)
|
|
2036
|
+
# R: beta.opt[-1,] = beta.opt[-1,] / x.sd
|
|
2037
|
+
# R: beta.opt[1,] = beta.opt[1,] - x.mean %*% beta.opt[-1,]
|
|
2038
|
+
beta_svd_binary = result_dict['coefficients'] # (k, 1)
|
|
2039
|
+
beta_transformed_binary = _apply_svd_inverse_transform(
|
|
2040
|
+
beta_svd_binary, svd_info_binary
|
|
2041
|
+
)
|
|
2042
|
+
result_dict['coefficients'] = beta_transformed_binary
|
|
2043
|
+
|
|
2044
|
+
# SVD inverse transform for variance-covariance matrix
|
|
2045
|
+
# R: Dx.inv %*% ginv(X.orig'X.orig) %*% X.orig' %*% X_svd %*% V %*%
|
|
2046
|
+
# ginv(diag(d)) %*% var %*% ginv(diag(d)) %*% V' %*% X_svd' %*%
|
|
2047
|
+
# X.orig %*% ginv(X.orig'X.orig) %*% Dx.inv
|
|
2048
|
+
variance_svd = result_dict['var']
|
|
2049
|
+
x_sd = svd_info_binary['x_sd']
|
|
2050
|
+
x_mean = svd_info_binary['x_mean']
|
|
2051
|
+
V_mat = svd_info_binary['V']
|
|
2052
|
+
d_vals = svd_info_binary['d']
|
|
2053
|
+
X_svd_mat = X_svd_binary # U matrix
|
|
2054
|
+
|
|
2055
|
+
# Dx_inv in R is diag(c(1, x.sd)) — note: R's naming is misleading
|
|
2056
|
+
Dx = np.diag(np.concatenate([[1.0], x_sd]))
|
|
2057
|
+
|
|
2058
|
+
# d_inv for variance transform
|
|
2059
|
+
d_inv_var = d_vals.copy()
|
|
2060
|
+
d_inv_var[d_inv_var > 1e-5] = 1.0 / d_inv_var[d_inv_var > 1e-5]
|
|
2061
|
+
d_inv_var[d_inv_var <= 1e-5] = 0.0
|
|
2062
|
+
|
|
2063
|
+
# Build transform matrix A:
|
|
2064
|
+
# A = Dx %*% ginv(X.orig'X.orig) %*% X.orig' %*% X_svd %*% V %*% diag(d_inv)
|
|
2065
|
+
XoXo_inv = _r_ginv(X_orig_binary.T @ X_orig_binary)
|
|
2066
|
+
A = (Dx @ XoXo_inv @ X_orig_binary.T @ X_svd_mat
|
|
2067
|
+
@ V_mat @ np.diag(d_inv_var))
|
|
2068
|
+
|
|
2069
|
+
# var_transformed = A %*% variance %*% A'
|
|
2070
|
+
result_dict['var'] = A @ variance_svd @ A.T
|
|
2071
|
+
|
|
2072
|
+
# Restore original X (fitted_values and weights are preserved by SVD)
|
|
2073
|
+
result_dict['x'] = X_orig_binary
|
|
2074
|
+
|
|
2075
|
+
# 3-level treatment routing
|
|
2076
|
+
elif no_treats == 3:
|
|
2077
|
+
from cbps.core.cbps_multitreat import cbps_3treat_fit
|
|
2078
|
+
|
|
2079
|
+
# Convert method to bal_only flag
|
|
2080
|
+
bal_only = (method == 'exact')
|
|
2081
|
+
|
|
2082
|
+
# Apply SVD preprocessing
|
|
2083
|
+
X_orig = X.copy() # Save original X
|
|
2084
|
+
X_svd, svd_info = _apply_svd_preprocessing(X)
|
|
2085
|
+
|
|
2086
|
+
# Compute rank and XprimeX_inv
|
|
2087
|
+
k = np.linalg.matrix_rank(X_svd)
|
|
2088
|
+
if k < X_svd.shape[1]:
|
|
2089
|
+
raise ValueError("X is not full rank")
|
|
2090
|
+
|
|
2091
|
+
# Compute XprimeX_inv in SVD space
|
|
2092
|
+
if sample_weights is None:
|
|
2093
|
+
sample_weights_norm = np.ones(len(treat))
|
|
2094
|
+
else:
|
|
2095
|
+
sample_weights_norm = sample_weights / np.mean(sample_weights)
|
|
2096
|
+
|
|
2097
|
+
sw_sqrt_X = np.sqrt(sample_weights_norm)[:, None] * X_svd
|
|
2098
|
+
XprimeX = sw_sqrt_X.T @ sw_sqrt_X
|
|
2099
|
+
from cbps.core.cbps_binary import _r_ginv
|
|
2100
|
+
XprimeX_inv = _r_ginv(XprimeX)
|
|
2101
|
+
|
|
2102
|
+
# Call 3-level fit in SVD space
|
|
2103
|
+
result_dict = cbps_3treat_fit(
|
|
2104
|
+
treat=treat,
|
|
2105
|
+
X=X_svd, # SVD-orthogonalized matrix
|
|
2106
|
+
method=method,
|
|
2107
|
+
k=k,
|
|
2108
|
+
XprimeX_inv=XprimeX_inv,
|
|
2109
|
+
bal_only=bal_only,
|
|
2110
|
+
iterations=iterations,
|
|
2111
|
+
standardize=standardize,
|
|
2112
|
+
two_step=two_step,
|
|
2113
|
+
sample_weights=sample_weights,
|
|
2114
|
+
treat_levels=treat_levels,
|
|
2115
|
+
verbose=verbose
|
|
2116
|
+
)
|
|
2117
|
+
|
|
2118
|
+
# SVD inverse transform
|
|
2119
|
+
beta_svd = result_dict['coefficients'] # (k, 2)
|
|
2120
|
+
beta_transformed = _apply_svd_inverse_transform(beta_svd, svd_info)
|
|
2121
|
+
|
|
2122
|
+
# Update coefficients in result_dict
|
|
2123
|
+
result_dict['coefficients'] = beta_transformed
|
|
2124
|
+
result_dict['x'] = X_orig # Restore original X
|
|
2125
|
+
|
|
2126
|
+
# Recompute fitted_values and linear_predictor with original X and transformed beta
|
|
2127
|
+
theta_transformed = X_orig @ beta_transformed # (n, 2)
|
|
2128
|
+
|
|
2129
|
+
# Recompute softmax probabilities (numerically stable)
|
|
2130
|
+
from cbps.core.cbps_multitreat import PROBS_MIN, _compute_softmax_probs_3treat
|
|
2131
|
+
probs_transformed = _compute_softmax_probs_3treat(theta_transformed, PROBS_MIN)
|
|
2132
|
+
|
|
2133
|
+
# Update result_dict
|
|
2134
|
+
result_dict['fitted_values'] = probs_transformed
|
|
2135
|
+
result_dict['linear_predictor'] = theta_transformed
|
|
2136
|
+
|
|
2137
|
+
# Add treat_names for result object
|
|
2138
|
+
treat_names = [str(level) for level in treat_levels]
|
|
2139
|
+
|
|
2140
|
+
# 4-level treatment routing
|
|
2141
|
+
elif no_treats == 4:
|
|
2142
|
+
from cbps.core.cbps_multitreat import cbps_4treat_fit
|
|
2143
|
+
|
|
2144
|
+
bal_only = (method == 'exact')
|
|
2145
|
+
|
|
2146
|
+
# Apply SVD preprocessing
|
|
2147
|
+
X_orig = X.copy() # Save original X
|
|
2148
|
+
X_svd, svd_info = _apply_svd_preprocessing(X)
|
|
2149
|
+
|
|
2150
|
+
# Compute rank and XprimeX_inv
|
|
2151
|
+
k = np.linalg.matrix_rank(X_svd)
|
|
2152
|
+
if k < X_svd.shape[1]:
|
|
2153
|
+
raise ValueError("X is not full rank")
|
|
2154
|
+
|
|
2155
|
+
if sample_weights is None:
|
|
2156
|
+
sample_weights_norm = np.ones(len(treat))
|
|
2157
|
+
else:
|
|
2158
|
+
sample_weights_norm = sample_weights / np.mean(sample_weights)
|
|
2159
|
+
|
|
2160
|
+
sw_sqrt_X = np.sqrt(sample_weights_norm)[:, None] * X_svd
|
|
2161
|
+
XprimeX = sw_sqrt_X.T @ sw_sqrt_X
|
|
2162
|
+
from cbps.core.cbps_binary import _r_ginv
|
|
2163
|
+
XprimeX_inv = _r_ginv(XprimeX)
|
|
2164
|
+
|
|
2165
|
+
# Call 4-level fit in SVD space
|
|
2166
|
+
result_dict = cbps_4treat_fit(
|
|
2167
|
+
treat=treat,
|
|
2168
|
+
X=X_svd, # SVD-orthogonalized matrix
|
|
2169
|
+
method=method,
|
|
2170
|
+
k=k,
|
|
2171
|
+
XprimeX_inv=XprimeX_inv,
|
|
2172
|
+
bal_only=bal_only,
|
|
2173
|
+
iterations=iterations,
|
|
2174
|
+
standardize=standardize,
|
|
2175
|
+
two_step=two_step,
|
|
2176
|
+
sample_weights=sample_weights,
|
|
2177
|
+
treat_levels=treat_levels,
|
|
2178
|
+
verbose=verbose
|
|
2179
|
+
)
|
|
2180
|
+
|
|
2181
|
+
# SVD inverse transform
|
|
2182
|
+
beta_svd = result_dict['coefficients'] # (k, 3)
|
|
2183
|
+
beta_transformed = _apply_svd_inverse_transform(beta_svd, svd_info)
|
|
2184
|
+
|
|
2185
|
+
# Update result_dict
|
|
2186
|
+
result_dict['coefficients'] = beta_transformed
|
|
2187
|
+
result_dict['x'] = X_orig
|
|
2188
|
+
|
|
2189
|
+
# Recompute fitted_values and linear_predictor with original X and transformed beta
|
|
2190
|
+
theta_transformed = X_orig @ beta_transformed # (n, 3)
|
|
2191
|
+
|
|
2192
|
+
# Recompute softmax probabilities (numerically stable)
|
|
2193
|
+
from cbps.core.cbps_multitreat import PROBS_MIN, _compute_softmax_probs_4treat
|
|
2194
|
+
probs_transformed = _compute_softmax_probs_4treat(theta_transformed, PROBS_MIN)
|
|
2195
|
+
|
|
2196
|
+
# Update result_dict
|
|
2197
|
+
result_dict['fitted_values'] = probs_transformed
|
|
2198
|
+
result_dict['linear_predictor'] = theta_transformed
|
|
2199
|
+
|
|
2200
|
+
# Add treat_names for result object
|
|
2201
|
+
treat_names = [str(level) for level in treat_levels]
|
|
2202
|
+
|
|
2203
|
+
# Step 4: Wrap in CBPSResults object
|
|
2204
|
+
# Remove keys not accepted by CBPSResults
|
|
2205
|
+
result_dict.pop('ocbps_conditions', None)
|
|
2206
|
+
result_dict.pop('normality_diagnostics', None)
|
|
2207
|
+
if no_treats in [3, 4]:
|
|
2208
|
+
result = CBPSResults(
|
|
2209
|
+
**result_dict,
|
|
2210
|
+
coef_names=coef_names,
|
|
2211
|
+
call_info=call_info,
|
|
2212
|
+
formula=formula,
|
|
2213
|
+
na_action=na_action_info,
|
|
2214
|
+
data=data_original,
|
|
2215
|
+
terms=terms_obj,
|
|
2216
|
+
model=model_frame,
|
|
2217
|
+
xlevels=xlevels_obj,
|
|
2218
|
+
treat_names=treat_names,
|
|
2219
|
+
att=att,
|
|
2220
|
+
method=method,
|
|
2221
|
+
standardize=standardize,
|
|
2222
|
+
two_step=two_step
|
|
2223
|
+
)
|
|
2224
|
+
else:
|
|
2225
|
+
result = CBPSResults(
|
|
2226
|
+
**result_dict,
|
|
2227
|
+
coef_names=coef_names,
|
|
2228
|
+
call_info=call_info,
|
|
2229
|
+
formula=formula,
|
|
2230
|
+
na_action=na_action_info,
|
|
2231
|
+
data=data_original if formula is not None else None,
|
|
2232
|
+
terms=terms_obj if formula is not None else None,
|
|
2233
|
+
model=model_frame if formula is not None else None,
|
|
2234
|
+
xlevels=xlevels_obj if formula is not None else None,
|
|
2235
|
+
att=att,
|
|
2236
|
+
method=method,
|
|
2237
|
+
standardize=standardize,
|
|
2238
|
+
two_step=two_step
|
|
2239
|
+
)
|
|
2240
|
+
|
|
2241
|
+
# Check for overlap violation
|
|
2242
|
+
_check_overlap_violation(result, is_continuous)
|
|
2243
|
+
|
|
2244
|
+
return result
|
|
2245
|
+
|
|
2246
|
+
|
|
2247
|
+
def cbps_fit(
|
|
2248
|
+
treat: Union[np.ndarray, pd.Series, pd.Categorical],
|
|
2249
|
+
X: np.ndarray,
|
|
2250
|
+
method: str = 'over',
|
|
2251
|
+
att: int = 1,
|
|
2252
|
+
two_step: bool = True,
|
|
2253
|
+
standardize: bool = True,
|
|
2254
|
+
iterations: int = 1000,
|
|
2255
|
+
sample_weights: Optional[np.ndarray] = None,
|
|
2256
|
+
baseline_X: Optional[np.ndarray] = None,
|
|
2257
|
+
diff_X: Optional[np.ndarray] = None,
|
|
2258
|
+
verbose: int = 0,
|
|
2259
|
+
**kwargs
|
|
2260
|
+
) -> Dict[str, Any]:
|
|
2261
|
+
"""
|
|
2262
|
+
Low-level CBPS fitting function (type detection and routing).
|
|
2263
|
+
|
|
2264
|
+
Performs treatment type detection, SVD preprocessing, routes to specific
|
|
2265
|
+
algorithm, applies SVD inverse transform, and returns raw dict (not wrapped
|
|
2266
|
+
in CBPSResults object).
|
|
2267
|
+
|
|
2268
|
+
Parameters
|
|
2269
|
+
----------
|
|
2270
|
+
treat : np.ndarray or pd.Series or pd.Categorical, shape (n,)
|
|
2271
|
+
Treatment variable.
|
|
2272
|
+
- pd.Categorical or pd.Series with categorical dtype: discrete treatment
|
|
2273
|
+
- np.ndarray (int/float): numeric treatment (0/1 auto-converted to factor)
|
|
2274
|
+
X : np.ndarray, shape (n, k)
|
|
2275
|
+
Covariate matrix, first column is intercept (all ones).
|
|
2276
|
+
method : {'over', 'exact'}, default='over'
|
|
2277
|
+
'over': over-identified GMM (default)
|
|
2278
|
+
'exact': exactly identified
|
|
2279
|
+
att : int or str, {0, 1, 2, 'ate', 'att', 'atc'}, default=1
|
|
2280
|
+
Target estimand for causal effect estimation:
|
|
2281
|
+
- 0 or 'ate': ATE (Average Treatment Effect)
|
|
2282
|
+
- 1 or 'att': ATT (Average Treatment Effect on Treated)
|
|
2283
|
+
- 2 or 'atc': ATC (Average Treatment Effect on Controls)
|
|
2284
|
+
String values are case-insensitive.
|
|
2285
|
+
two_step : bool, default=True
|
|
2286
|
+
Whether to use two-step estimation.
|
|
2287
|
+
standardize : bool, default=True
|
|
2288
|
+
Whether to standardize.
|
|
2289
|
+
iterations : int, default=1000
|
|
2290
|
+
Maximum iterations.
|
|
2291
|
+
sample_weights : np.ndarray, optional
|
|
2292
|
+
Sample weights (observation-level).
|
|
2293
|
+
baseline_X : np.ndarray, optional
|
|
2294
|
+
Baseline outcome covariate matrix for oCBPS.
|
|
2295
|
+
diff_X : np.ndarray, optional
|
|
2296
|
+
Treatment effect covariate matrix for oCBPS.
|
|
2297
|
+
verbose : int, default=0
|
|
2298
|
+
Verbosity level (0=silent, 1=basic, 2=detailed).
|
|
2299
|
+
**kwargs
|
|
2300
|
+
Additional arguments passed to underlying algorithm.
|
|
2301
|
+
Notable pass-through parameters for binary treatment:
|
|
2302
|
+
|
|
2303
|
+
- ``init_params`` (np.ndarray): Initial parameter values for warm
|
|
2304
|
+
start. Skips GLM initialization and uses these values directly.
|
|
2305
|
+
Length must equal the number of columns in X.
|
|
2306
|
+
|
|
2307
|
+
Returns
|
|
2308
|
+
-------
|
|
2309
|
+
dict
|
|
2310
|
+
Dictionary containing all fitting results:
|
|
2311
|
+
- 'coefficients': coefficient matrix
|
|
2312
|
+
- 'fitted_values': fitted propensity scores
|
|
2313
|
+
- 'weights': inverse probability weights
|
|
2314
|
+
- 'y': treatment variable
|
|
2315
|
+
- 'x': covariate matrix (original space)
|
|
2316
|
+
- 'converged': convergence status
|
|
2317
|
+
- 'J': J statistic
|
|
2318
|
+
- 'var': variance-covariance matrix
|
|
2319
|
+
- other fields vary by treatment type
|
|
2320
|
+
|
|
2321
|
+
Notes
|
|
2322
|
+
-----
|
|
2323
|
+
Difference from CBPS main function:
|
|
2324
|
+
- cbps_fit is low-level API, accepts numpy arrays instead of formulas
|
|
2325
|
+
- Returns dict instead of CBPSResults object
|
|
2326
|
+
- Handles SVD preprocessing and inverse transform
|
|
2327
|
+
- More flexible, suitable for advanced users
|
|
2328
|
+
|
|
2329
|
+
SVD preprocessing workflow:
|
|
2330
|
+
1. Standardize X (except intercept)
|
|
2331
|
+
2. SVD decomposition: X = U·D·V'
|
|
2332
|
+
3. Use orthogonalized U as new X
|
|
2333
|
+
4. Call underlying algorithm (in SVD space)
|
|
2334
|
+
5. Inverse transform coefficients and variance back to original space
|
|
2335
|
+
|
|
2336
|
+
Examples
|
|
2337
|
+
--------
|
|
2338
|
+
>>> import numpy as np
|
|
2339
|
+
>>> from cbps import cbps_fit
|
|
2340
|
+
>>>
|
|
2341
|
+
>>> # Prepare data
|
|
2342
|
+
>>> n = 100
|
|
2343
|
+
>>> treat = np.array([0, 1] * 50)
|
|
2344
|
+
>>> X = np.column_stack([np.ones(n), np.random.randn(n, 2)])
|
|
2345
|
+
>>>
|
|
2346
|
+
>>> # Call low-level API
|
|
2347
|
+
>>> result = cbps_fit(treat, X, method='over', att=1)
|
|
2348
|
+
>>> print(result['coefficients'])
|
|
2349
|
+
>>> print(result['weights'])
|
|
2350
|
+
|
|
2351
|
+
References
|
|
2352
|
+
----------
|
|
2353
|
+
Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
|
|
2354
|
+
Journal of the Royal Statistical Society, Series B 76(1), 243-263.
|
|
2355
|
+
https://doi.org/10.1111/rssb.12027
|
|
2356
|
+
"""
|
|
2357
|
+
from cbps.core.cbps_binary import cbps_binary_fit, _r_ginv
|
|
2358
|
+
from cbps.core.cbps_multitreat import cbps_3treat_fit, cbps_4treat_fit
|
|
2359
|
+
from cbps.core.cbps_continuous import cbps_continuous_fit
|
|
2360
|
+
from cbps.core.cbps_optimal import cbps_optimal_2treat
|
|
2361
|
+
|
|
2362
|
+
# Step 1: 0/1 binary special handling
|
|
2363
|
+
# Numeric 0/1 auto-converted to factor
|
|
2364
|
+
is_factor = False
|
|
2365
|
+
treat_array = treat
|
|
2366
|
+
|
|
2367
|
+
if isinstance(treat, pd.Categorical):
|
|
2368
|
+
is_factor = True
|
|
2369
|
+
treat_array = treat.codes # Numeric codes
|
|
2370
|
+
treat_categories = treat.categories
|
|
2371
|
+
elif hasattr(treat, 'cat'): # pd.Series with categorical dtype
|
|
2372
|
+
is_factor = True
|
|
2373
|
+
treat_array = treat.cat.codes.to_numpy()
|
|
2374
|
+
treat_categories = treat.cat.categories
|
|
2375
|
+
elif isinstance(treat, pd.Series):
|
|
2376
|
+
treat_array = treat.to_numpy()
|
|
2377
|
+
else:
|
|
2378
|
+
treat_array = np.asarray(treat)
|
|
2379
|
+
|
|
2380
|
+
# Detect 0/1 binary
|
|
2381
|
+
if not is_factor and np.issubdtype(treat_array.dtype, np.number):
|
|
2382
|
+
treat_unique = np.unique(treat_array)
|
|
2383
|
+
if len(treat_unique) == 2 and set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}:
|
|
2384
|
+
# Auto-convert to factor
|
|
2385
|
+
treat = pd.Categorical(treat_array)
|
|
2386
|
+
is_factor = True
|
|
2387
|
+
treat_array = treat.codes
|
|
2388
|
+
treat_categories = treat.categories
|
|
2389
|
+
|
|
2390
|
+
# Step 2: Method parameter conversion
|
|
2391
|
+
bal_only = (method == 'exact')
|
|
2392
|
+
|
|
2393
|
+
# Step 3: Variable name handling
|
|
2394
|
+
# Column names for X (for result output)
|
|
2395
|
+
names_X = [f"X{i}" if i > 0 else "(Intercept)" for i in range(X.shape[1])]
|
|
2396
|
+
# Mark zero-variance columns as "(Intercept)"
|
|
2397
|
+
x_sd_check = X.std(axis=0, ddof=1)
|
|
2398
|
+
for i in range(X.shape[1]):
|
|
2399
|
+
if x_sd_check[i] < 1e-10:
|
|
2400
|
+
names_X[i] = "(Intercept)"
|
|
2401
|
+
|
|
2402
|
+
# Step 4: SVD preprocessing (non-oCBPS path only)
|
|
2403
|
+
# oCBPS requires both baseline_X and diff_X; if only one is provided,
|
|
2404
|
+
# we'll raise an error later in the routing logic
|
|
2405
|
+
X_orig = X.copy()
|
|
2406
|
+
svd_info = None
|
|
2407
|
+
|
|
2408
|
+
if baseline_X is None and diff_X is None: # Non-oCBPS path
|
|
2409
|
+
# Apply SVD preprocessing
|
|
2410
|
+
X_svd, svd_info = _apply_svd_preprocessing(X)
|
|
2411
|
+
X_for_algo = X_svd
|
|
2412
|
+
else:
|
|
2413
|
+
# oCBPS path (or partial - will be validated later): no SVD preprocessing
|
|
2414
|
+
X_for_algo = X
|
|
2415
|
+
|
|
2416
|
+
# Step 5: Rank check and XprimeX_inv
|
|
2417
|
+
k = np.linalg.matrix_rank(X_for_algo)
|
|
2418
|
+
if k < X_for_algo.shape[1]:
|
|
2419
|
+
raise ValueError("X is not full rank")
|
|
2420
|
+
|
|
2421
|
+
# Compute weighted XprimeX_inv
|
|
2422
|
+
if sample_weights is None:
|
|
2423
|
+
sample_weights = np.ones(len(treat_array))
|
|
2424
|
+
|
|
2425
|
+
w_sqrt = np.sqrt(sample_weights)
|
|
2426
|
+
X_weighted = w_sqrt[:, None] * X_for_algo
|
|
2427
|
+
XprimeX_inv = _r_ginv(X_weighted.T @ X_weighted)
|
|
2428
|
+
|
|
2429
|
+
# Step 6: Treatment type detection and routing
|
|
2430
|
+
output = None
|
|
2431
|
+
|
|
2432
|
+
if is_factor:
|
|
2433
|
+
# Discrete treatment path
|
|
2434
|
+
no_treats = len(treat_categories)
|
|
2435
|
+
|
|
2436
|
+
# Validate treatment count
|
|
2437
|
+
if no_treats > 4:
|
|
2438
|
+
raise ValueError(
|
|
2439
|
+
"Parametric CBPS is not implemented for more than 4 treatment values. "
|
|
2440
|
+
"Consider using a continuous treatment."
|
|
2441
|
+
)
|
|
2442
|
+
if no_treats < 2:
|
|
2443
|
+
raise ValueError("Treatment must take more than one value")
|
|
2444
|
+
|
|
2445
|
+
# Route to appropriate algorithm
|
|
2446
|
+
if no_treats == 2:
|
|
2447
|
+
# Binary treatment
|
|
2448
|
+
if baseline_X is not None and diff_X is not None:
|
|
2449
|
+
# oCBPS path
|
|
2450
|
+
if att != 0:
|
|
2451
|
+
warnings.warn(
|
|
2452
|
+
f"CBPSOptimal only supports att=0 (ATE). "
|
|
2453
|
+
f"Received att={att}, forcing to att=0.",
|
|
2454
|
+
UserWarning
|
|
2455
|
+
)
|
|
2456
|
+
output = cbps_optimal_2treat(
|
|
2457
|
+
treat=treat_array,
|
|
2458
|
+
X=X_for_algo, # oCBPS uses original X
|
|
2459
|
+
baseline_X=baseline_X,
|
|
2460
|
+
diff_X=diff_X,
|
|
2461
|
+
iterations=iterations,
|
|
2462
|
+
att=0, # oCBPS forces att=0 (ATE only)
|
|
2463
|
+
standardize=standardize
|
|
2464
|
+
)
|
|
2465
|
+
elif baseline_X is not None or diff_X is not None:
|
|
2466
|
+
# Only one of baseline_X/diff_X provided - invalid for oCBPS
|
|
2467
|
+
raise ValueError(
|
|
2468
|
+
"For oCBPS (optimal CBPS), both baseline_X and diff_X must be provided. "
|
|
2469
|
+
f"Received: baseline_X={'provided' if baseline_X is not None else 'None'}, "
|
|
2470
|
+
f"diff_X={'provided' if diff_X is not None else 'None'}. "
|
|
2471
|
+
"Either provide both for oCBPS, or neither for standard CBPS."
|
|
2472
|
+
)
|
|
2473
|
+
else:
|
|
2474
|
+
# Standard binary CBPS
|
|
2475
|
+
output = cbps_binary_fit(
|
|
2476
|
+
treat=treat_array,
|
|
2477
|
+
X=X_for_algo, # SVD space X
|
|
2478
|
+
att=att,
|
|
2479
|
+
method=method,
|
|
2480
|
+
two_step=two_step,
|
|
2481
|
+
iterations=iterations,
|
|
2482
|
+
standardize=standardize,
|
|
2483
|
+
sample_weights=sample_weights,
|
|
2484
|
+
XprimeX_inv=XprimeX_inv,
|
|
2485
|
+
|
|
2486
|
+
verbose=verbose
|
|
2487
|
+
)
|
|
2488
|
+
|
|
2489
|
+
elif no_treats == 3:
|
|
2490
|
+
# 3-level treatment
|
|
2491
|
+
output = cbps_3treat_fit(
|
|
2492
|
+
treat=treat_array,
|
|
2493
|
+
X=X_for_algo,
|
|
2494
|
+
method=method,
|
|
2495
|
+
k=k,
|
|
2496
|
+
XprimeX_inv=XprimeX_inv,
|
|
2497
|
+
bal_only=bal_only,
|
|
2498
|
+
iterations=iterations,
|
|
2499
|
+
standardize=standardize,
|
|
2500
|
+
two_step=two_step,
|
|
2501
|
+
sample_weights=sample_weights,
|
|
2502
|
+
treat_levels=treat_categories.to_numpy() if hasattr(treat_categories, 'to_numpy') else np.array(list(treat_categories)),
|
|
2503
|
+
verbose=verbose
|
|
2504
|
+
)
|
|
2505
|
+
|
|
2506
|
+
elif no_treats == 4:
|
|
2507
|
+
# 4-level treatment
|
|
2508
|
+
output = cbps_4treat_fit(
|
|
2509
|
+
treat=treat_array,
|
|
2510
|
+
X=X_for_algo,
|
|
2511
|
+
method=method,
|
|
2512
|
+
k=k,
|
|
2513
|
+
XprimeX_inv=XprimeX_inv,
|
|
2514
|
+
bal_only=bal_only,
|
|
2515
|
+
iterations=iterations,
|
|
2516
|
+
standardize=standardize,
|
|
2517
|
+
two_step=two_step,
|
|
2518
|
+
sample_weights=sample_weights,
|
|
2519
|
+
treat_levels=treat_categories.to_numpy() if hasattr(treat_categories, 'to_numpy') else np.array(list(treat_categories)),
|
|
2520
|
+
verbose=verbose
|
|
2521
|
+
)
|
|
2522
|
+
|
|
2523
|
+
elif np.issubdtype(treat_array.dtype, np.number):
|
|
2524
|
+
# Continuous treatment path
|
|
2525
|
+
# Warn if ≤4 unique values (may be discrete)
|
|
2526
|
+
n_unique = len(np.unique(treat_array))
|
|
2527
|
+
if n_unique <= 4:
|
|
2528
|
+
warnings.warn(
|
|
2529
|
+
f"Treatment vector is numeric with {n_unique} unique values. "
|
|
2530
|
+
f"Interpreting as a continuous treatment. "
|
|
2531
|
+
f"To solve for a binary or multi-valued treatment, make treat a factor.",
|
|
2532
|
+
UserWarning
|
|
2533
|
+
)
|
|
2534
|
+
|
|
2535
|
+
output = cbps_continuous_fit(
|
|
2536
|
+
treat=treat_array,
|
|
2537
|
+
X=X_for_algo,
|
|
2538
|
+
method=method,
|
|
2539
|
+
two_step=two_step,
|
|
2540
|
+
iterations=iterations,
|
|
2541
|
+
standardize=standardize,
|
|
2542
|
+
sample_weights=sample_weights,
|
|
2543
|
+
verbose=verbose
|
|
2544
|
+
)
|
|
2545
|
+
|
|
2546
|
+
else:
|
|
2547
|
+
raise ValueError("Treatment must be either a factor or numeric")
|
|
2548
|
+
|
|
2549
|
+
# Step 7: SVD inverse transform (non-oCBPS path only)
|
|
2550
|
+
if svd_info is not None:
|
|
2551
|
+
# Inverse transform coefficients
|
|
2552
|
+
beta_svd = output['coefficients']
|
|
2553
|
+
beta_orig = _apply_svd_inverse_transform(beta_svd, svd_info)
|
|
2554
|
+
|
|
2555
|
+
# Update output
|
|
2556
|
+
output['coefficients'] = beta_orig
|
|
2557
|
+
output['x'] = X_orig # Replace with original X
|
|
2558
|
+
|
|
2559
|
+
# Variance inverse transform
|
|
2560
|
+
from cbps.utils.variance_transform import apply_variance_svd_inverse_transform
|
|
2561
|
+
|
|
2562
|
+
# Infer treatment type from coefficients shape
|
|
2563
|
+
k = X_orig.shape[1]
|
|
2564
|
+
coef_shape = beta_orig.shape
|
|
2565
|
+
|
|
2566
|
+
# Determine is_factor and no_treats
|
|
2567
|
+
# If coefficients is (k, K-1) shape, it's K-level treatment
|
|
2568
|
+
if len(coef_shape) == 2 and coef_shape[1] > 1:
|
|
2569
|
+
is_factor_inferred = True
|
|
2570
|
+
no_treats_inferred = coef_shape[1] + 1 # K-1 cols → K-level treatment
|
|
2571
|
+
elif len(coef_shape) == 2 and coef_shape[1] == 1:
|
|
2572
|
+
# (k, 1) may be binary or continuous
|
|
2573
|
+
is_factor_inferred = is_factor if 'is_factor' in locals() else False
|
|
2574
|
+
no_treats_inferred = 2 if is_factor_inferred else None
|
|
2575
|
+
else:
|
|
2576
|
+
# (k,) shape, may be binary or continuous
|
|
2577
|
+
is_factor_inferred = is_factor if 'is_factor' in locals() else False
|
|
2578
|
+
no_treats_inferred = 2 if is_factor_inferred else None
|
|
2579
|
+
|
|
2580
|
+
variance_svd = output['var']
|
|
2581
|
+
variance_orig = apply_variance_svd_inverse_transform(
|
|
2582
|
+
variance_svd=variance_svd,
|
|
2583
|
+
svd_info=svd_info,
|
|
2584
|
+
X_orig=X_orig,
|
|
2585
|
+
X_svd=X_for_algo,
|
|
2586
|
+
is_factor=is_factor_inferred,
|
|
2587
|
+
no_treats=no_treats_inferred
|
|
2588
|
+
)
|
|
2589
|
+
output['var'] = variance_orig
|
|
2590
|
+
|
|
2591
|
+
if verbose > 0:
|
|
2592
|
+
print(f"cbps_fit: SVD inverse transform done, coef shape={beta_orig.shape}, var shape={variance_orig.shape}")
|
|
2593
|
+
|
|
2594
|
+
# Add method field
|
|
2595
|
+
output['method'] = method
|
|
2596
|
+
|
|
2597
|
+
return output
|
|
2598
|
+
|
|
2599
|
+
|
|
2600
|
+
def cbmsm_fit(
|
|
2601
|
+
treat: np.ndarray,
|
|
2602
|
+
X: np.ndarray,
|
|
2603
|
+
id: np.ndarray,
|
|
2604
|
+
time: np.ndarray,
|
|
2605
|
+
type: str = "MSM",
|
|
2606
|
+
twostep: bool = True,
|
|
2607
|
+
msm_variance: str = "approx",
|
|
2608
|
+
time_vary: bool = False,
|
|
2609
|
+
init: str = "opt",
|
|
2610
|
+
sample_weights: Optional[np.ndarray] = None,
|
|
2611
|
+
iterations: Optional[int] = None,
|
|
2612
|
+
**kwargs: Any
|
|
2613
|
+
) -> 'CBMSMResults':
|
|
2614
|
+
"""
|
|
2615
|
+
CBMSM Matrix Interface (Low-Level Fitting Function)
|
|
2616
|
+
|
|
2617
|
+
This is the low-level matrix interface for CBMSM, accepting preprocessed
|
|
2618
|
+
matrix inputs. For most users, the formula interface CBMSM() is recommended.
|
|
2619
|
+
|
|
2620
|
+
Parameters
|
|
2621
|
+
----------
|
|
2622
|
+
treat : np.ndarray, shape (N*T,)
|
|
2623
|
+
Treatment vector for N units over T periods.
|
|
2624
|
+
X : np.ndarray, shape (N*T, p)
|
|
2625
|
+
Covariate matrix (including intercept column).
|
|
2626
|
+
id : np.ndarray, shape (N*T,)
|
|
2627
|
+
Unit identifiers.
|
|
2628
|
+
time : np.ndarray, shape (N*T,)
|
|
2629
|
+
Time period identifiers.
|
|
2630
|
+
type : str, default="MSM"
|
|
2631
|
+
Weight type ('MSM' or 'MultiBin').
|
|
2632
|
+
twostep : bool, default=True
|
|
2633
|
+
Whether to use two-step estimation.
|
|
2634
|
+
msm_variance : str, default="approx"
|
|
2635
|
+
Variance estimation method ('approx' or 'full').
|
|
2636
|
+
time_vary : bool, default=False
|
|
2637
|
+
Whether coefficients vary with time.
|
|
2638
|
+
init : str, default="opt"
|
|
2639
|
+
Initialization method ('opt', 'glm', 'CBPS').
|
|
2640
|
+
sample_weights : np.ndarray, optional
|
|
2641
|
+
Observation weights.
|
|
2642
|
+
iterations : int, optional
|
|
2643
|
+
Maximum iterations.
|
|
2644
|
+
**kwargs
|
|
2645
|
+
Additional arguments.
|
|
2646
|
+
|
|
2647
|
+
Returns
|
|
2648
|
+
-------
|
|
2649
|
+
CBMSMResults
|
|
2650
|
+
CBMSM fitting result object.
|
|
2651
|
+
|
|
2652
|
+
See Also
|
|
2653
|
+
--------
|
|
2654
|
+
CBMSM : Formula interface (recommended)
|
|
2655
|
+
|
|
2656
|
+
Examples
|
|
2657
|
+
--------
|
|
2658
|
+
>>> from cbps import cbmsm_fit
|
|
2659
|
+
>>> import numpy as np
|
|
2660
|
+
>>> # Prepare matrix data
|
|
2661
|
+
>>> treat = np.array([0, 1, 0, 1, 0, 1])
|
|
2662
|
+
>>> X = np.column_stack([np.ones(6), np.random.randn(6, 2)])
|
|
2663
|
+
>>> id_vec = np.array([1, 2, 3, 1, 2, 3])
|
|
2664
|
+
>>> time_vec = np.array([1, 1, 1, 2, 2, 2])
|
|
2665
|
+
>>> result = cbmsm_fit(treat, X, id_vec, time_vec)
|
|
2666
|
+
"""
|
|
2667
|
+
from cbps.msm.cbmsm import cbmsm_fit as _cbmsm_fit
|
|
2668
|
+
return _cbmsm_fit(
|
|
2669
|
+
treat=treat, X=X, id=id, time=time,
|
|
2670
|
+
type=type, twostep=twostep, msm_variance=msm_variance,
|
|
2671
|
+
time_vary=time_vary, init=init, sample_weights=sample_weights,
|
|
2672
|
+
iterations=iterations, **kwargs
|
|
2673
|
+
)
|
|
2674
|
+
|
|
2675
|
+
|
|
2676
|
+
def CBMSM(
    formula: str,
    id: Union[str, pd.Series, np.ndarray],
    time: Union[str, pd.Series, np.ndarray],
    data: pd.DataFrame,
    type: str = "MSM",
    twostep: bool = True,
    msm_variance: str = "approx",
    time_vary: bool = False,
    init: str = "opt",
    iterations: Optional[int] = None,
    **kwargs: Any
) -> 'CBMSMResults':
    """
    Covariate Balancing Propensity Score for Marginal Structural Models.

    Estimates inverse probability of treatment weights for longitudinal data
    with time-varying treatments and confounders. Designed for panel data where
    treatment effects unfold over multiple time periods.

    Parameters
    ----------
    formula : str
        Treatment model formula (e.g., 'treat ~ x1 + x2 + x3').
        The same covariates are used for all time periods. Data should be
        sorted by time within each unit.
    id : str or array-like
        Unit identifier column name (str) or ID array identifying individuals
        in the panel data.
    time : str or array-like
        Time column name (str) or time array identifying the temporal ordering
        of observations.
    data : pd.DataFrame
        DataFrame containing treatment, covariates, ID, and time variables.
    type : {'MSM', 'MultiBin'}, default='MSM'
        Weight type:
        - 'MSM': Marginal structural model weights (default)
        - 'MultiBin': Multiple binary treatment weights
    twostep : bool, default=True
        Whether to use two-step estimation (faster with MLE initialization).
        - True: Estimate parameters for each period separately, then combine
        - False: Estimate all parameters simultaneously (single-step)
    msm_variance : {'approx', 'full', None}, default='approx'
        Variance estimation method:
        - 'approx': Approximate variance (fast, recommended)
        - 'full': Full sandwich variance (accurate but slower)
        - None: Do not compute variance
    time_vary : bool, default=False
        Whether treatment model coefficients vary across time:
        - False: Time-invariant model (shared coefficients across periods)
        - True: Time-varying model (independent coefficients per period)
    init : {'opt', 'glm'}, default='opt'
        Initialization method:
        - 'opt': Use both CBPS and GLM starting values, select best balance
        - 'glm': Use only GLM starting values
    iterations : int, optional
        Maximum number of optimization iterations.
    **kwargs
        Additional parameters passed to the underlying implementation.
        ``two_step`` is accepted as an alias for ``twostep``. The alias is
        consumed here and never forwarded. When ``twostep`` is left at its
        default (True), the alias value is used. Otherwise the explicit
        ``twostep`` wins, and a ``UserWarning`` is emitted if the two differ.

    Returns
    -------
    CBMSMResults
        CBMSM fitted result object containing:
        - weights: MSM weight array (unit-level)
        - fitted_values: Propensity scores for each period
        - converged: Convergence status
        - coefficients: Estimated model coefficients

    Examples
    --------
    Estimate MSM weights using panel data:

    >>> from cbps import CBMSM
    >>> from cbps.datasets import load_blackwell
    >>> data = load_blackwell()
    >>> fit = CBMSM('d.gone.neg ~ d.gone.neg.l1 + camp.length',
    ...             id='demName', time='time', data=data, type='MSM')
    >>> print(f"Weights shape: {fit.weights.shape}")

    Notes
    -----
    **Data Requirements**: Must be a balanced panel where each id appears
    exactly once at each time period.

    References
    ----------
    Imai, K. and Ratkovic, M. (2015). Robust Estimation of Inverse Probability
    Weights for Marginal Structural Models. Journal of the American Statistical
    Association, 110(511), 1013-1023. https://doi.org/10.1080/01621459.2014.956872

    See Also
    --------
    CBPS : Covariate balancing propensity score for cross-sectional data
    """
    from cbps.msm.cbmsm import CBMSM as _CBMSM

    # Always consume the ``two_step`` alias. The implementation receives
    # ``twostep=`` explicitly below, so also forwarding the alias through
    # ``**kwargs`` would hand it two spellings of the same option.
    missing = object()
    two_step_alias = kwargs.pop('two_step', missing)
    if two_step_alias is not missing:
        if twostep is True:
            # ``twostep`` is at its default: honour the alias.
            twostep = two_step_alias
        elif two_step_alias != twostep:
            import warnings
            warnings.warn(
                "Both 'twostep' and 'two_step' were supplied with different "
                f"values; using twostep={twostep!r}.",
                UserWarning,
                stacklevel=2,
            )

    return _CBMSM(
        formula=formula, id=id, time=time, data=data,
        type=type, twostep=twostep, msm_variance=msm_variance,
        time_vary=time_vary, init=init, iterations=iterations,
        **kwargs
    )
|
|
2782
|
+
|
|
2783
|
+
|
|
2784
|
+
def npCBPS(
    formula: str,
    data: pd.DataFrame,
    na_action: Optional[str] = None,
    corprior: Optional[float] = None,
    print_level: int = 0,
    seed: Optional[int] = None,
    verbose: int = 0,
    **kwargs: Any
) -> 'NPCBPSResults':
    """
    Nonparametric Covariate Balancing Propensity Score.

    Estimates weights directly using the empirical likelihood framework,
    without requiring a parametric propensity score model specification.

    Parameters
    ----------
    formula : str
        Model formula specifying treatment and covariates (e.g., 'treat ~ age + educ').
    data : pd.DataFrame
        DataFrame containing the treatment and covariate variables.
    na_action : str, optional
        Missing-value handling option. It is forwarded unchanged to the
        underlying npCBPS implementation, which defines the accepted values.
    corprior : float, default=None
        Prior standard deviation σ controlling the weighted correlation between
        covariates and treatment, where η ~ N(0, σ²I).
        Note: corprior is the standard deviation σ, not the variance σ².

        Default (None): Automatically set to 0.1/n (sample-size adaptive).
        - Small sample (n=10): corprior ≈ 0.01
        - Medium sample (n=100): corprior ≈ 0.001
        - Large sample (n=1000): corprior ≈ 0.0001

        Reference: Fong, Hazlett & Imai (2018) Section 3.3.4
    print_level : int, default=0
        Diagnostic output verbosity level.
    seed : int, optional
        Random seed for reproducibility.
    verbose : int, default=0
        Accepted for API consistency with the other estimators but currently
        ignored; use ``print_level`` to control output.
    **kwargs : Any
        Additional parameters passed to the underlying optimization routine.

    Returns
    -------
    NPCBPSResults
        Fitted result object containing:
        - weights: Estimated empirical likelihood weights
        - eta: Weighted correlations (balance diagnostics)
        - sumw0: Sum of weights (should be ≈ 1, tolerance ±5%)
        - log_el, log_p_eta: Log empirical likelihood and prior density

    Notes
    -----
    The empirical likelihood optimization is non-convex, which may lead to
    different local optima across implementations. Convergence quality should
    be verified by checking that sumw0 ≈ 1.0 (within 5% tolerance).

    References
    ----------
    Fong, C., Hazlett, C., and Imai, K. (2018). Covariate Balancing
    Propensity Score for a Continuous Treatment. The Annals of Applied
    Statistics 12(1), 156-177. https://doi.org/10.1214/17-AOAS1101

    Examples
    --------
    >>> from cbps import npCBPS
    >>> from cbps.datasets import load_lalonde
    >>> df = load_lalonde(dehejia_wahba_only=True)
    >>> fit = npCBPS('treat ~ age + educ', data=df, corprior=0.01)
    >>> # Verify convergence
    >>> assert abs(fit.sumw0 - 1.0) < 0.05, "Weight sum should be close to 1"
    """
    from cbps.nonparametric.npcbps import npCBPS as _npCBPS

    # ``verbose`` is kept in the signature for API consistency but is not
    # forwarded: the implementation controls its output through
    # ``print_level``.
    del verbose

    return _npCBPS(
        formula=formula, data=data, na_action=na_action,
        corprior=corprior, print_level=print_level, seed=seed,
        **kwargs
    )
|
|
2865
|
+
|
|
2866
|
+
|
|
2867
|
+
def hdCBPS(
    formula: str,
    data: pd.DataFrame,
    y: Union[str, np.ndarray],
    ATT: int = 0,
    iterations: int = 1000,
    method: str = 'linear',
    seed: Optional[int] = None,
    na_action: Optional[str] = None,
    verbose: int = 0
) -> 'HDCBPSResults':
    """
    High-Dimensional Covariate Balancing Propensity Score estimation.

    Applies covariate balancing propensity score estimation in settings where
    the covariate dimension can greatly exceed the sample size (d >> n).
    LASSO-based variable selection is combined with covariate balancing
    constraints so that causal effects can still be estimated validly.

    Parameters
    ----------
    formula : str
        Model formula naming the treatment and the covariates, e.g.
        'treat ~ age + educ + black + hisp + married + nodegr + re74 + re75'.
    data : pd.DataFrame
        Dataset holding every variable named in ``formula``.
    y : str or np.ndarray
        Outcome column name or outcome array, used for variable selection
        in the high-dimensional framework.
    ATT : int, default 0
        Target estimand: 0 selects the ATE (average treatment effect),
        1 selects the ATT (average treatment effect on the treated).
    iterations : int, default 1000
        Cap on optimizer iterations.
    method : {'linear', 'binomial', 'poisson'}, default 'linear'
        Outcome model family used for variable selection:
        - 'linear': linear regression
        - 'binomial': logistic regression
        - 'poisson': Poisson regression
    seed : int, optional
        Random seed. The current implementation uses a deterministic LASSO,
        so this value does not change the results.
    na_action : {None, 'warn', 'drop', 'fail'}, optional
        Missing-value policy:
        - None or 'warn': drop incomplete rows and warn
        - 'drop': drop incomplete rows silently
        - 'fail': raise on missing values
    verbose : int, default 0
        Output verbosity: 0 silent, 1 basic iteration info, 2 detailed
        debugging output.

    Returns
    -------
    HDCBPSResults
        Result object exposing:
        - ATE: estimated average treatment effect
        - ATT: estimated average treatment effect on the treated
        - s: selected variables
        - fitted_values: estimated propensity scores
        - coefficients0: LASSO coefficients for the control group (T=0)
        - coefficients1: LASSO coefficients for the treatment group (T=1)
        - coefficients: alias of coefficients0 (for API consistency)

    Notes
    -----
    High-dimensional CBPS extends the original CBPS approach to many
    covariates by adding variable selection. It keeps covariates that predict
    both treatment and outcome while maintaining covariate balance.

    Unlike standard CBPS, which has a single coefficient vector, hdCBPS fits
    two LASSO models (one per treatment level) to perform variable selection.

    References
    ----------
    Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
    via a high-dimensional covariate balancing propensity score. Biometrika
    107(3), 533-554. https://doi.org/10.1093/biomet/asaa020

    Examples
    --------
    >>> from cbps import hdCBPS
    >>> from cbps.datasets import load_lalonde
    >>> df = load_lalonde(dehejia_wahba_only=True)
    >>> result = hdCBPS(
    ...     formula='treat ~ age + educ + black + hisp + married + nodegr + re74 + re75',
    ...     data=df,
    ...     y='re78',        # outcome variable
    ...     ATT=0,           # estimate the ATE
    ...     method='linear'
    ... )
    >>> print(f"ATE: {result.ATE:.4f}")
    >>> print(f"Selected variables: {len(result.s)}")
    >>> print(f"Converged: {result.converged}")
    """
    from cbps.highdim.hdcbps import hdCBPS as _hdcbps_impl

    # The implementation is called positionally; this tuple fixes the order.
    call_args = (
        formula,
        data,
        y,
        ATT,
        iterations,
        method,
        seed,
        na_action,
        verbose,
    )
    return _hdcbps_impl(*call_args)
|
|
2972
|
+
|
|
2973
|
+
|
|
2974
|
+
def CBIV(
    formula: Optional[str] = None,
    data: Optional[pd.DataFrame] = None,
    Tr: Optional[np.ndarray] = None,
    Z: Optional[np.ndarray] = None,
    X: Optional[np.ndarray] = None,
    iterations: int = 1000,
    method: str = "over",
    twostep: bool = True,
    twosided: bool = True,
    probs_min: float = 1e-6,
    warn_clipping: bool = True,
    clipping_warn_threshold: float = 0.05,
    verbose: int = 0,
    **kwargs: Any
) -> 'CBIVResults':
    """
    Covariate Balancing Propensity Score for Instrumental Variables.

    Fits complier propensity scores in instrumental-variable settings with
    treatment noncompliance. Intended for encouragement designs, where a
    randomized assignment (the instrument) shifts treatment uptake but does
    not guarantee compliance.

    Parameters
    ----------
    formula : str, optional
        IV formula of the form "treatment ~ covariates | instrument", e.g.
        "treat ~ x1 + x2 | z". An intercept is added automatically.
    data : pd.DataFrame, optional
        DataFrame containing the variables named in ``formula``. Required
        for the formula interface.
    Tr : np.ndarray, shape (n,), optional
        Binary (0/1) treatment variable. Required for the matrix interface.
    Z : np.ndarray, shape (n,), optional
        Binary (0/1) instrument. Required for the matrix interface.
    X : np.ndarray, shape (n, p), optional
        Pre-treatment covariates, without an intercept column. Required for
        the matrix interface.
    iterations : int, default=1000
        Maximum number of optimization iterations.
    method : str, default="over"
        Estimation method:

        - 'over': over-identified GMM (propensity score + balance conditions)
        - 'exact': just-identified GMM (balance conditions only)
        - 'mle': maximum likelihood (propensity score only)
    twostep : bool, default=True
        Use two-step GMM. When False, continuously updating GMM is used,
        which has better finite-sample properties but is slower.
    twosided : bool, default=True
        Noncompliance structure:

        - True: compliers, always-takers, and never-takers
        - False: one-sided (compliers and never-takers only)
    probs_min : float, default=1e-6
        Compliance probabilities are clipped to [probs_min, 1 - probs_min].
    warn_clipping : bool, default=True
        Emit a warning when the share of clipped compliance probabilities
        exceeds ``clipping_warn_threshold``.
    clipping_warn_threshold : float, default=0.05
        Clipping proportion (between 0 and 1) above which a warning is issued.
    verbose : int, default=0
        0 = silent, 1 = basic info, 2 = detailed diagnostics.
    **kwargs : Any
        Extra keyword arguments forwarded unchanged to the implementation.

    Returns
    -------
    CBIVResults
        Result object with coefficients, fitted values, weights, and
        diagnostic statistics.

    Notes
    -----
    Principal stratification distinguishes three compliance types:

    - **Compliers**: take treatment when encouraged (Z=1) and do not take it
      when not encouraged (Z=0)
    - **Always-takers**: take treatment regardless of Z
    - **Never-takers**: never take treatment regardless of Z

    The Complier Average Causal Effect (CACE) is identified under the usual IV
    assumptions (exclusion restriction, monotonicity, non-zero first stage).

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society: Series B, 76(1), 243-263.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from cbps import CBIV
    >>> # Formula interface
    >>> df = pd.DataFrame({
    ...     'treat': np.random.binomial(1, 0.5, 100),
    ...     'z': np.random.binomial(1, 0.5, 100),
    ...     'x1': np.random.randn(100),
    ...     'x2': np.random.randn(100)
    ... })
    >>> fit = CBIV(formula="treat ~ x1 + x2 | z", data=df)
    >>> print(fit.coefficients.shape)
    >>> # Matrix interface
    >>> Tr = np.random.binomial(1, 0.5, 100)
    >>> Z = np.random.binomial(1, 0.5, 100)
    >>> X = np.random.randn(100, 2)
    >>> fit = CBIV(Tr=Tr, Z=Z, X=X, method='over', twosided=True)
    >>> print(fit.fitted_values.shape)
    """
    from cbps.iv.cbiv import CBIV as _cbiv_impl

    # Gather the explicit options. ``kwargs`` cannot repeat any of these
    # names, because each one is bound by this function's own signature.
    options = dict(
        formula=formula,
        data=data,
        Tr=Tr,
        Z=Z,
        X=X,
        iterations=iterations,
        method=method,
        twostep=twostep,
        twosided=twosided,
        probs_min=probs_min,
        warn_clipping=warn_clipping,
        clipping_warn_threshold=clipping_warn_threshold,
        verbose=verbose,
    )
    options.update(kwargs)
    return _cbiv_impl(**options)
|
|
3093
|
+
|
|
3094
|
+
|
|
3095
|
+
def AsyVar(
    Y: np.ndarray,
    Y_1_hat: Optional[np.ndarray] = None,
    Y_0_hat: Optional[np.ndarray] = None,
    CBPS_obj: Optional[Union[Dict[str, Any], 'CBPSResults']] = None,
    method: str = "CBPS",
    X: Optional[np.ndarray] = None,
    TL: Optional[np.ndarray] = None,
    pi: Optional[np.ndarray] = None,
    mu: Optional[float] = None,
    CI: float = 0.95,
    use_observed_y: bool = False,
    **kwargs: Any
) -> Dict[str, Any]:
    """
    Asymptotic Variance and Confidence Intervals for ATE.

    Estimates the asymptotic variance of the average treatment effect obtained
    using CBPS or optimal CBPS (oCBPS) methods. This function computes valid
    confidence intervals that properly account for the uncertainty in propensity
    score estimation.

    Parameters
    ----------
    Y : np.ndarray
        Observed outcome values.
    Y_1_hat : np.ndarray, optional
        Predicted outcomes under treatment. If None, will be automatically fitted.
    Y_0_hat : np.ndarray, optional
        Predicted outcomes under control. If None, will be automatically fitted.
    CBPS_obj : dict or CBPSResults, optional
        Fitted CBPS object. Required for the CBPS variance estimation path.
        A results object (anything with a ``fitted_values`` attribute) is
        converted to a dict with keys 'x', 'y', 'fitted_values',
        'coefficients', and 'residuals' when that attribute exists.
        A dict is passed through unchanged.
    method : str, default="CBPS"
        Variance estimation method: 'CBPS' (standard) or 'oCBPS' (optimal).
    X : np.ndarray, optional
        Covariate matrix (first column must be intercept).
    TL : np.ndarray, optional
        Treatment indicator variable (1=treated, 0=control).
    pi : np.ndarray, optional
        Propensity score vector.
    mu : float, optional
        Average treatment effect estimate.
    CI : float, default=0.95
        Confidence level for the confidence interval.
    use_observed_y : bool, default=False
        Sigma_mu computation method:

        - False (default): Use predicted values Y_1_hat, Y_0_hat.
          This matches R CBPS package behavior and is recommended.
        - True: Use observed Y values. This is an experimental option
          not implemented in the R package.
    **kwargs : Any
        Forwarded unchanged to ``cbps.inference.asyvar.asy_var``.

    Returns
    -------
    dict
        Dictionary containing (snake_case keys are preferred):

        - 'mu_hat' (or 'mu.hat'): ATE estimate
        - 'asy_var' (or 'asy.var'): Asymptotic variance of sqrt(n) * (mu_hat - mu)
        - 'var': Finite-sample variance = asy_var / n
        - 'std_err' (or 'std.err'): Standard error = sqrt(var)
        - 'ci_mu_hat' (or 'CI.mu.hat'): Confidence interval [lower, upper]

        R-style dot-separated keys (e.g., 'mu.hat') are retained as
        backward-compatible aliases and point to the same value objects.

    References
    ----------
    Fan, J., Imai, K., Lee, I., Liu, H., Ning, Y., and Yang, X. (2022).
    Optimal covariate balancing conditions in propensity score estimation.
    Journal of Business & Economic Statistics, 41(1), 97-110.
    https://doi.org/10.1080/07350015.2021.2002159

    Examples
    --------
    >>> from cbps import CBPS, AsyVar
    >>> from cbps.datasets import load_lalonde
    >>> data = load_lalonde()
    >>> fit = CBPS('treat ~ age + educ + black + hisp', data=data, att=0)
    >>> result = AsyVar(Y=data['re78'].values, CBPS_obj=fit, method="oCBPS")
    >>> print(f"ATE: {result['mu.hat']:.3f} (SE: {result['std.err']:.3f})")
    """
    from cbps.inference.asyvar import asy_var

    # NOTE: ``CBPS_obj`` is an explicit keyword parameter, so it can never
    # land in ``kwargs``. The former "look it up in kwargs" fallback was
    # unreachable and has been removed.

    # Convert a fitted results object to the dict layout ``asy_var`` expects.
    # Plain dicts have no ``fitted_values`` attribute and pass through as-is.
    if CBPS_obj is not None and hasattr(CBPS_obj, 'fitted_values'):
        cbps_dict = {
            'x': CBPS_obj.x,
            'y': CBPS_obj.y,
            'fitted_values': CBPS_obj.fitted_values,
            'coefficients': CBPS_obj.coefficients,
        }
        # Residuals are optional on results objects.
        if hasattr(CBPS_obj, 'residuals'):
            cbps_dict['residuals'] = CBPS_obj.residuals
        CBPS_obj = cbps_dict

    result = asy_var(
        Y=Y, Y_1_hat=Y_1_hat, Y_0_hat=Y_0_hat, CBPS_obj=CBPS_obj,
        method=method, X=X, TL=TL, pi=pi, mu=mu, CI=CI,
        use_observed_y=use_observed_y, **kwargs
    )

    # Add snake_case aliases alongside the original R-style keys; both names
    # reference the same value object.
    r_to_snake = (
        ('mu.hat', 'mu_hat'),
        ('asy.var', 'asy_var'),
        ('CI.mu.hat', 'ci_mu_hat'),
        ('std.err', 'std_err'),
    )
    for r_key, snake_key in r_to_snake:
        if r_key in result:
            result[snake_key] = result[r_key]

    return result
|
|
3214
|
+
|
|
3215
|
+
|
|
3216
|
+
def balance(cbps_obj, enhanced: bool = False, threshold: float = 0.1,
            covariate_names: Optional[list] = None, *args: Any, **kwargs: Any):
    """
    Assess covariate balance before and after CBPS weighting.

    Computes balance statistics to evaluate the effectiveness of propensity score
    estimation in achieving covariate balance between treatment groups. This
    is a fundamental diagnostic tool for causal inference analyses.

    Parameters
    ----------
    cbps_obj : dict or CBPSResults or NPCBPSResults or CBMSMResults
        Fitted CBPS object containing the estimation results. Must include:
        - weights: final CBPS weights
        - x: covariate matrix
        - y: treatment variable
        Supports CBPS, CBPSContinuous, npCBPS and CBMSM objects.
    enhanced : bool, default False
        If False, returns basic balance statistics format.
        If True, returns enhanced diagnostics including:
        - Improvement percentages
        - Summary statistics
        - Text-based diagnostic report
        Ignored for CBMSM results.
    threshold : float, default 0.1
        Threshold for determining covariate imbalance (used when enhanced=True).
        Standard threshold: SMD < 0.1 indicates excellent balance (Stuart 2010).
    covariate_names : list, optional
        List of covariate names for generating detailed reports. Used when enhanced=True.
    *args, **kwargs
        Forwarded to the basic (non-enhanced) balance functions only.

    Returns
    -------
    dict
        If enhanced=False (default):
        - balanced: balance statistics after weighting
        - original/unweighted: baseline unweighted statistics

        If enhanced=True (enhanced diagnostics):
        Contains above keys plus:
        - smd_weighted/abs_corr_weighted: weighted SMDs or correlations
        - smd_unweighted/abs_corr_unweighted: unweighted SMDs or correlations
        - improvement_pct: percentage improvement in balance
        - n_imbalanced_before/after: number of imbalanced covariates
        - summary: dictionary with summary statistics
        - report: text-based diagnostic report

        For CBMSM results the output of ``balance_cbmsm`` is returned as-is;
        its format differs from the CBPS format.

    Notes
    -----
    **Balance Metrics:**
    - Binary/multi-valued treatments: Standardized mean differences (SMDs)
    - Continuous treatments: Absolute Pearson correlations
    - For npCBPS, routes to the appropriate function based on treatment type:
      a non-categorical numeric treatment with more than 4 distinct values
      is treated as continuous.

    **Interpretation Guidelines:**
    - SMD < 0.1: Excellent balance
    - SMD 0.1-0.25: Moderate imbalance
    - SMD > 0.25: Severe imbalance
    - For correlations: closer to 0 indicates better balance

    The enhanced diagnostic mode provides comprehensive assessment following
    best practices in the causal inference literature.

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate Balancing Propensity Score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    https://doi.org/10.1111/rssb.12027

    Stuart, E.A. (2010). "Matching methods for causal inference: A review and
    a look forward." Statistical Science 25(1), 1-21.

    Austin, P.C. (2009). "Some methods of propensity-score matching resulted
    in substantial bias in examining the effects of medical interventions."
    Statistics in Medicine 28(25), 3083-3107.

    Examples
    --------
    >>> import cbps
    >>> # Fit CBPS model
    >>> fit = cbps.CBPS('treat ~ age + education + income', data=df)
    >>>
    >>> # Basic balance assessment (R-compatible)
    >>> bal = cbps.balance(fit)
    >>> print("Balance after weighting:", bal['balanced'])
    >>> print("Balance before weighting:", bal['original'])
    >>>
    >>> # Enhanced diagnostics with detailed report
    >>> bal_enh = cbps.balance(fit, enhanced=True, threshold=0.1)
    >>> print(bal_enh['report'])
    >>> print(f"Mean SMD after: {bal_enh['summary']['mean_smd_after']:.3f}")
    >>> print(f"Imbalanced covariates: {bal_enh['n_imbalanced_after']}")
    """
    from cbps.diagnostics.balance import (
        balance_cbps, balance_cbps_continuous,
        balance_cbps_enhanced, balance_cbps_continuous_enhanced
    )
    from cbps.nonparametric.npcbps import NPCBPSResults

    intercept_names = ('(Intercept)', 'Intercept')

    def _compute(cbps_dict, coef_names, is_continuous):
        # Pick the enhanced or basic balance routine for the treatment type,
        # then attach row/column labels.
        if is_continuous:
            if enhanced:
                result = balance_cbps_continuous_enhanced(cbps_dict, threshold, covariate_names)
            else:
                result = balance_cbps_continuous(cbps_dict, *args, **kwargs)
        elif enhanced:
            result = balance_cbps_enhanced(cbps_dict, threshold, covariate_names)
        else:
            result = balance_cbps(cbps_dict, *args, **kwargs)
        return _add_balance_labels(result, cbps_dict, coef_names, is_continuous=is_continuous)

    # Covariate names used to label the output; the intercept is skipped.
    coef_names_for_balance = None

    if isinstance(cbps_obj, CBPSResults):
        if getattr(cbps_obj, 'coef_names', None) is not None:
            coef_names_for_balance = [
                name for name in cbps_obj.coef_names if name not in intercept_names
            ]
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'fitted_values': cbps_obj.fitted_values,
        }
        # Continuous vs discrete is decided by the fitted_values check below.

    elif isinstance(cbps_obj, NPCBPSResults):
        # Covariate names come from NPCBPSResults.terms (patsy DesignInfo).
        if getattr(cbps_obj, 'terms', None) is not None:
            try:
                coef_names_for_balance = [
                    name for name in cbps_obj.terms.column_names
                    if name not in intercept_names
                ]
            except AttributeError:
                pass

        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'log_el': cbps_obj.log_el,  # Include log_el to identify npCBPS
        }

        # Detect a continuous treatment. A categorical dtype is always
        # treated as discrete. Otherwise normalize with np.asarray, so that
        # inputs without a ``.dtype`` attribute (e.g. plain sequences)
        # classify cleanly instead of raising AttributeError.
        y_array = cbps_obj.y
        is_categorical = (
            hasattr(y_array, 'dtype')
            and hasattr(y_array.dtype, 'name')
            and 'category' in str(y_array.dtype).lower()
        )
        is_continuous = False
        if not is_categorical:
            try:
                y_values = np.asarray(y_array)
                is_continuous = (
                    np.issubdtype(y_values.dtype, np.number)
                    and len(np.unique(y_values)) > 4
                )
            except TypeError:
                # If the dtype check or uniqueness test fails, treat as discrete.
                is_continuous = False

        return _compute(cbps_dict, coef_names_for_balance, bool(is_continuous))

    elif hasattr(cbps_obj, '__class__') and cbps_obj.__class__.__name__ == 'CBMSMResults':
        # CBMSM results use a dedicated balance routine. The class is matched
        # by name rather than by isinstance.
        from cbps.diagnostics.balance_cbmsm_addon import balance_cbmsm

        cbmsm_dict = {
            'y': cbps_obj.y,
            'x': cbps_obj.x,
            'weights': cbps_obj.weights,
            'glm_weights': cbps_obj.glm_weights,
            'id': cbps_obj.id,
            'time': cbps_obj.time,
        }
        # Note: CBMSM return format differs from CBPS (includes StatBal).
        return balance_cbmsm(cbmsm_dict)

    else:
        cbps_dict = cbps_obj

    # Continuous treatment: 1-D fitted_values and more than 4 distinct
    # treatment values. Otherwise (2-D fitted_values, few levels, or no
    # fitted_values) the discrete path is used.
    if 'fitted_values' in cbps_dict:
        fv = cbps_dict['fitted_values']
        if isinstance(fv, np.ndarray) and fv.ndim == 1 and len(np.unique(cbps_dict['y'])) > 4:
            return _compute(cbps_dict, coef_names_for_balance, True)

    return _compute(cbps_dict, coef_names_for_balance, False)
|
|
3418
|
+
|
|
3419
|
+
|
|
3420
|
+
# Import vcov_outcome
|
|
3421
|
+
from cbps.inference.vcov_outcome import vcov_outcome
|
|
3422
|
+
|
|
3423
|
+
# Import plot functions
|
|
3424
|
+
from cbps.diagnostics.plots import plot_cbps, plot_cbps_continuous, plot_cbmsm, plot_npcbps
|
|
3425
|
+
|
|
3426
|
+
# Import npCBPS_fit low-level interface
|
|
3427
|
+
from cbps.nonparametric.npcbps import npCBPS_fit
|
|
3428
|
+
|
|
3429
|
+
|
|
3430
|
+
def fit_multiple(formula, datasets, **kwargs):
|
|
3431
|
+
"""Fit CBPS on multiple datasets.
|
|
3432
|
+
|
|
3433
|
+
Useful for simulation studies and specification comparisons.
|
|
3434
|
+
|
|
3435
|
+
Parameters
|
|
3436
|
+
----------
|
|
3437
|
+
formula : str
|
|
3438
|
+
R-style formula (same for all datasets).
|
|
3439
|
+
datasets : list of pd.DataFrame
|
|
3440
|
+
Multiple datasets to estimate on.
|
|
3441
|
+
**kwargs :
|
|
3442
|
+
Additional arguments passed to CBPS().
|
|
3443
|
+
|
|
3444
|
+
Returns
|
|
3445
|
+
-------
|
|
3446
|
+
list
|
|
3447
|
+
List of CBPSResults objects. If a fit fails for a dataset, the
|
|
3448
|
+
corresponding entry is a dict with keys 'error' and 'dataset_index'.
|
|
3449
|
+
|
|
3450
|
+
Examples
|
|
3451
|
+
--------
|
|
3452
|
+
>>> results = fit_multiple('treat ~ age + educ', [df1, df2, df3], att=0)
|
|
3453
|
+
>>> successful = [r for r in results if isinstance(r, CBPSResults)]
|
|
3454
|
+
"""
|
|
3455
|
+
results = []
|
|
3456
|
+
for i, data in enumerate(datasets):
|
|
3457
|
+
try:
|
|
3458
|
+
result = CBPS(formula, data, **kwargs)
|
|
3459
|
+
results.append(result)
|
|
3460
|
+
except Exception as e:
|
|
3461
|
+
results.append({'error': str(e), 'dataset_index': i})
|
|
3462
|
+
return results
|