diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/staggered.py
ADDED
|
@@ -0,0 +1,3895 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Staggered Difference-in-Differences estimators.
|
|
3
|
+
|
|
4
|
+
Implements modern methods for DiD with variation in treatment timing,
|
|
5
|
+
including the Callaway-Sant'Anna (2021) estimator.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from scipy import linalg as scipy_linalg
|
|
14
|
+
|
|
15
|
+
from diff_diff.linalg import (
|
|
16
|
+
_check_propensity_diagnostics,
|
|
17
|
+
_detect_rank_deficiency,
|
|
18
|
+
_format_dropped_columns,
|
|
19
|
+
solve_logit,
|
|
20
|
+
solve_ols,
|
|
21
|
+
)
|
|
22
|
+
from diff_diff.staggered_aggregation import (
|
|
23
|
+
CallawaySantAnnaAggregationMixin,
|
|
24
|
+
)
|
|
25
|
+
from diff_diff.staggered_bootstrap import (
|
|
26
|
+
CallawaySantAnnaBootstrapMixin,
|
|
27
|
+
CSBootstrapResults,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Import from split modules
|
|
31
|
+
from diff_diff.staggered_results import (
|
|
32
|
+
CallawaySantAnnaResults,
|
|
33
|
+
GroupTimeEffect,
|
|
34
|
+
)
|
|
35
|
+
from diff_diff.utils import safe_inference, safe_inference_batch
|
|
36
|
+
|
|
37
|
+
# Public names re-exported from this module so that existing imports of
# them from ``diff_diff.staggered`` keep working.
__all__ = ["CallawaySantAnna", "CallawaySantAnnaResults", "CSBootstrapResults", "GroupTimeEffect"]

# Alias for the dictionary returned by CallawaySantAnna._precompute_structures.
PrecomputedData = Dict[str, Any]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Regress ``y`` on ``X`` plus a constant using the shared ``solve_ols`` backend.

    Parameters
    ----------
    X : np.ndarray
        Regressor matrix with one row per observation. A column of ones is
        prepended here, so callers must not include an intercept themselves.
    y : np.ndarray
        Outcome variable.
    rank_deficient_action : str, default "warn"
        Forwarded to ``solve_ols``:
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning
    weights : np.ndarray, optional
        Observation weights forwarded to ``solve_ols`` (WLS when given,
        OLS when None).

    Returns
    -------
    beta : np.ndarray
        Coefficients for the design ``[1, X]`` (intercept column first).
    residuals : np.ndarray
        Residuals from the fit.
    """
    n_obs = X.shape[0]
    design = np.column_stack([np.ones(n_obs), X])

    # Only coefficients and residuals are used; the third returned slot is
    # discarded because no variance matrix is requested (return_vcov=False).
    coef, resid, _unused = solve_ols(
        design,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
        weights=weights,
    )
    return coef, resid
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _safe_inv(A: np.ndarray) -> np.ndarray:
    """
    Return the inverse of square matrix ``A``.

    The inverse is obtained by solving ``A @ X = I``. If LAPACK reports the
    matrix as singular (``np.linalg.LinAlgError``), the minimum-norm
    least-squares solution of the same system is returned instead. Matrices
    that are merely ill-conditioned do not trigger the fallback.
    """
    identity = np.eye(A.shape[0])
    try:
        inverse = np.linalg.solve(A, identity)
    except np.linalg.LinAlgError:
        # Exactly singular: use the least-squares (pseudo-inverse-like) solution.
        inverse = np.linalg.lstsq(A, identity, rcond=None)[0]
    return inverse
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class CallawaySantAnna(
|
|
104
|
+
CallawaySantAnnaBootstrapMixin,
|
|
105
|
+
CallawaySantAnnaAggregationMixin,
|
|
106
|
+
):
|
|
107
|
+
"""
|
|
108
|
+
Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
|
|
109
|
+
|
|
110
|
+
This estimator handles DiD designs with variation in treatment timing
|
|
111
|
+
(staggered adoption) and heterogeneous treatment effects. It avoids the
|
|
112
|
+
bias of traditional two-way fixed effects (TWFE) estimators by:
|
|
113
|
+
|
|
114
|
+
1. Computing group-time average treatment effects ATT(g,t) for each
|
|
115
|
+
cohort g (units first treated in period g) and time t.
|
|
116
|
+
2. Aggregating these to summary measures (overall ATT, event study, etc.)
|
|
117
|
+
using appropriate weights.
|
|
118
|
+
|
|
119
|
+
Parameters
|
|
120
|
+
----------
|
|
121
|
+
control_group : str, default="never_treated"
|
|
122
|
+
Which units to use as controls:
|
|
123
|
+
- "never_treated": Use only never-treated units (recommended)
|
|
124
|
+
- "not_yet_treated": Use never-treated and not-yet-treated units
|
|
125
|
+
anticipation : int, default=0
|
|
126
|
+
Number of periods before treatment where effects may occur.
|
|
127
|
+
Set to > 0 if treatment effects can begin before the official
|
|
128
|
+
treatment date.
|
|
129
|
+
estimation_method : str, default="dr"
|
|
130
|
+
Estimation method:
|
|
131
|
+
- "dr": Doubly robust (recommended)
|
|
132
|
+
- "ipw": Inverse probability weighting
|
|
133
|
+
- "reg": Outcome regression
|
|
134
|
+
alpha : float, default=0.05
|
|
135
|
+
Significance level for confidence intervals.
|
|
136
|
+
cluster : str, optional
|
|
137
|
+
Column name for cluster-robust standard errors.
|
|
138
|
+
Defaults to unit-level clustering.
|
|
139
|
+
n_bootstrap : int, default=0
|
|
140
|
+
Number of bootstrap iterations for inference.
|
|
141
|
+
If 0, uses analytical standard errors.
|
|
142
|
+
Recommended: 999 or more for reliable inference.
|
|
143
|
+
|
|
144
|
+
.. note:: Memory Usage
|
|
145
|
+
The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
|
|
146
|
+
float64 array. For large datasets, this can be significant:
|
|
147
|
+
- 1K bootstrap × 10K units = ~80 MB
|
|
148
|
+
- 10K bootstrap × 100K units = ~8 GB
|
|
149
|
+
Consider reducing n_bootstrap if memory is constrained.
|
|
150
|
+
|
|
151
|
+
bootstrap_weights : str, default="rademacher"
|
|
152
|
+
Type of weights for multiplier bootstrap:
|
|
153
|
+
- "rademacher": +1/-1 with equal probability (standard choice)
|
|
154
|
+
- "mammen": Two-point distribution (asymptotically valid, matches skewness)
|
|
155
|
+
- "webb": Six-point distribution (recommended when n_clusters < 20)
|
|
156
|
+
seed : int, optional
|
|
157
|
+
Random seed for reproducibility.
|
|
158
|
+
rank_deficient_action : str, default="warn"
|
|
159
|
+
Action when design matrix is rank-deficient (linearly dependent columns):
|
|
160
|
+
- "warn": Issue warning and drop linearly dependent columns (default)
|
|
161
|
+
- "error": Raise ValueError
|
|
162
|
+
- "silent": Drop columns silently without warning
|
|
163
|
+
base_period : str, default="varying"
|
|
164
|
+
Method for selecting the base (reference) period for computing
|
|
165
|
+
ATT(g,t). Options:
|
|
166
|
+
- "varying": For pre-treatment periods (t < g - anticipation), use
|
|
167
|
+
t-1 as base (consecutive comparisons). For post-treatment, use
|
|
168
|
+
g-1-anticipation. Requires t-1 to exist in data.
|
|
169
|
+
- "universal": Always use g-1-anticipation as base period.
|
|
170
|
+
Both produce identical post-treatment effects. Matches R's
|
|
171
|
+
did::att_gt() base_period parameter.
|
|
172
|
+
cband : bool, default=True
|
|
173
|
+
Whether to compute simultaneous confidence bands (sup-t) for
|
|
174
|
+
event study aggregation. Requires ``n_bootstrap > 0``.
|
|
175
|
+
When True, results include ``cband_crit_value`` and per-event-time
|
|
176
|
+
``cband_conf_int`` entries controlling family-wise error rate.
|
|
177
|
+
pscore_trim : float, default=0.01
|
|
178
|
+
Trimming bound for propensity scores. Scores are clipped to
|
|
179
|
+
``[pscore_trim, 1 - pscore_trim]`` before weight computation
|
|
180
|
+
in IPW and DR estimation. Must be in ``(0, 0.5)``.
|
|
181
|
+
panel : bool, default=True
|
|
182
|
+
Whether the data is a balanced/unbalanced panel (units observed
|
|
183
|
+
across multiple time periods). Set to ``False`` for stationary
|
|
184
|
+
repeated cross-sections where each observation has a unique unit
|
|
185
|
+
ID and units do not repeat across periods. Requires that the
|
|
186
|
+
cross-sectional samples are drawn from the same population in
|
|
187
|
+
each period (stationarity). Uses cross-sectional DRDID
|
|
188
|
+
(Sant'Anna & Zhao 2020, Section 4) with per-observation influence
|
|
189
|
+
functions.
|
|
190
|
+
epv_threshold : float, default=10
|
|
191
|
+
Events Per Variable threshold for propensity score logit.
|
|
192
|
+
When the ratio of minority-class observations to predictor
|
|
193
|
+
variables (excluding intercept) falls below this value, a
|
|
194
|
+
warning is emitted (or ``ValueError`` raised if
|
|
195
|
+
``rank_deficient_action="error"``). Based on Peduzzi et al.
|
|
196
|
+
(1996). Only applies to IPW and DR estimation methods.
|
|
197
|
+
Use ``diagnose_propensity()`` for a pre-estimation check across
|
|
198
|
+
all cohorts.
|
|
199
|
+
pscore_fallback : str, default="error"
|
|
200
|
+
Action when propensity score estimation fails entirely
|
|
201
|
+
(``LinAlgError`` or ``ValueError`` from IRLS):
|
|
202
|
+
- "error": Raise the exception (default). Ensures the user is
|
|
203
|
+
aware of estimation failures.
|
|
204
|
+
- "unconditional": Fall back to unconditional propensity
|
|
205
|
+
with a warning. For IPW, this drops all covariates. For DR,
|
|
206
|
+
the propensity model becomes unconditional but outcome
|
|
207
|
+
regression still uses covariates.
|
|
208
|
+
When ``rank_deficient_action="error"``, errors are always
|
|
209
|
+
re-raised regardless of this setting.
|
|
210
|
+
|
|
211
|
+
Attributes
|
|
212
|
+
----------
|
|
213
|
+
results_ : CallawaySantAnnaResults
|
|
214
|
+
Estimation results after calling fit().
|
|
215
|
+
is_fitted_ : bool
|
|
216
|
+
Whether the model has been fitted.
|
|
217
|
+
|
|
218
|
+
Examples
|
|
219
|
+
--------
|
|
220
|
+
Basic usage:
|
|
221
|
+
|
|
222
|
+
>>> import pandas as pd
|
|
223
|
+
>>> from diff_diff import CallawaySantAnna
|
|
224
|
+
>>>
|
|
225
|
+
>>> # Panel data with staggered treatment
|
|
226
|
+
>>> # 'first_treat' = period when unit was first treated (0 if never treated)
|
|
227
|
+
>>> data = pd.DataFrame({
|
|
228
|
+
... 'unit': [...],
|
|
229
|
+
... 'time': [...],
|
|
230
|
+
... 'outcome': [...],
|
|
231
|
+
... 'first_treat': [...] # 0 for never-treated, else first treatment period
|
|
232
|
+
... })
|
|
233
|
+
>>>
|
|
234
|
+
>>> cs = CallawaySantAnna()
|
|
235
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
236
|
+
... time='time', first_treat='first_treat')
|
|
237
|
+
>>>
|
|
238
|
+
>>> results.print_summary()
|
|
239
|
+
|
|
240
|
+
With event study aggregation:
|
|
241
|
+
|
|
242
|
+
>>> cs = CallawaySantAnna()
|
|
243
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
244
|
+
... time='time', first_treat='first_treat',
|
|
245
|
+
... aggregate='event_study')
|
|
246
|
+
>>>
|
|
247
|
+
>>> # Plot event study
|
|
248
|
+
>>> from diff_diff import plot_event_study
|
|
249
|
+
>>> plot_event_study(results)
|
|
250
|
+
|
|
251
|
+
With covariate adjustment (conditional parallel trends):
|
|
252
|
+
|
|
253
|
+
>>> # When parallel trends only holds conditional on covariates
|
|
254
|
+
>>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
|
|
255
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
256
|
+
... time='time', first_treat='first_treat',
|
|
257
|
+
... covariates=['age', 'income'])
|
|
258
|
+
>>>
|
|
259
|
+
>>> # DR is recommended: consistent if either outcome model
|
|
260
|
+
>>> # or propensity model is correctly specified
|
|
261
|
+
|
|
262
|
+
Notes
|
|
263
|
+
-----
|
|
264
|
+
The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
|
|
265
|
+
approach: instead of estimating a single treatment effect, they estimate
|
|
266
|
+
ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
|
|
267
|
+
problem where already-treated units act as controls.
|
|
268
|
+
|
|
269
|
+
The ATT(g,t) is identified under parallel trends conditional on covariates:
|
|
270
|
+
|
|
271
|
+
E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]
|
|
272
|
+
|
|
273
|
+
where G=g indicates treatment cohort g and C=1 indicates control units.
|
|
274
|
+
This uses g-1 as the base period, which applies to post-treatment (t >= g).
|
|
275
|
+
With base_period="varying" (default), pre-treatment uses t-1 as base for
|
|
276
|
+
consecutive comparisons useful in parallel trends diagnostics.
|
|
277
|
+
|
|
278
|
+
References
|
|
279
|
+
----------
|
|
280
|
+
Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
|
|
281
|
+
multiple time periods. Journal of Econometrics, 225(2), 200-230.
|
|
282
|
+
"""
|
|
283
|
+
|
|
284
|
+
def __init__(
    self,
    control_group: str = "never_treated",
    anticipation: int = 0,
    estimation_method: str = "dr",
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    bootstrap_weights: Optional[str] = None,
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
    base_period: str = "varying",
    cband: bool = True,
    pscore_trim: float = 0.01,
    panel: bool = True,
    epv_threshold: float = 10,
    pscore_fallback: str = "error",
):
    """
    Validate and store the estimator configuration.

    All parameters are described in the class docstring. Every option is
    checked here so that an invalid configuration is reported at
    construction time.

    Raises
    ------
    ValueError
        If a string option is not one of its allowed values, or a numeric
        option is out of range (``anticipation < 0``, ``alpha`` not in
        ``(0, 1)``, ``n_bootstrap < 0``, ``pscore_trim`` not in ``(0, 0.5)``,
        ``epv_threshold <= 0``).
    """
    if control_group not in ["never_treated", "not_yet_treated"]:
        raise ValueError(
            f"control_group must be 'never_treated' or 'not_yet_treated', "
            f"got '{control_group}'"
        )
    if estimation_method not in ["dr", "ipw", "reg"]:
        raise ValueError(
            f"estimation_method must be 'dr', 'ipw', or 'reg', " f"got '{estimation_method}'"
        )
    if not (0 < pscore_trim < 0.5):
        raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}")
    if epv_threshold <= 0:
        raise ValueError(f"epv_threshold must be > 0, got {epv_threshold}")
    if pscore_fallback not in ["error", "unconditional"]:
        raise ValueError(
            f"pscore_fallback must be 'error' or 'unconditional', " f"got '{pscore_fallback}'"
        )

    # Default to rademacher if not specified
    if bootstrap_weights is None:
        bootstrap_weights = "rademacher"

    if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
        raise ValueError(
            f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
            f"got '{bootstrap_weights}'"
        )

    if rank_deficient_action not in ["warn", "error", "silent"]:
        raise ValueError(
            f"rank_deficient_action must be 'warn', 'error', or 'silent', "
            f"got '{rank_deficient_action}'"
        )

    if base_period not in ["varying", "universal"]:
        raise ValueError(
            f"base_period must be 'varying' or 'universal', " f"got '{base_period}'"
        )

    # Numeric range checks. A negative anticipation would shift the base
    # period g - 1 - anticipation forward, a negative bootstrap count is
    # meaningless, and alpha outside (0, 1) gives no valid confidence level.
    if anticipation < 0:
        raise ValueError(f"anticipation must be >= 0, got {anticipation}")
    if not (0 < alpha < 1):
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")
    if n_bootstrap < 0:
        raise ValueError(f"n_bootstrap must be >= 0, got {n_bootstrap}")

    self.control_group = control_group
    self.anticipation = anticipation
    self.estimation_method = estimation_method
    self.alpha = alpha
    self.cluster = cluster
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed
    self.rank_deficient_action = rank_deficient_action
    self.base_period = base_period

    self.cband = cband
    self.pscore_trim = pscore_trim
    self.panel = panel
    self.epv_threshold = epv_threshold
    self.pscore_fallback = pscore_fallback

    # Fitted state; populated by fit().
    self.is_fitted_ = False
    self.results_: Optional[CallawaySantAnnaResults] = None
|
|
362
|
+
|
|
363
|
+
def diagnose_propensity(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Check Events Per Variable (EPV) across all cohorts without estimation.

    Examines the data to identify cohorts where propensity score logit may
    be unreliable due to too few events per covariate. Based on Peduzzi
    et al. (1996).

    This is a raw-count heuristic: it uses total cohort/control unit
    counts without filtering for missing outcomes, zero survey weights,
    or period-specific validity. The actual fit-time EPV (stored in
    ``results.epv_diagnostics``) may be lower because ``fit()`` operates
    on the valid base/post outcome pair and the positive-weight effective
    sample. Use this method as a quick pre-check; rely on
    ``results.epv_diagnostics`` for authoritative per-cell EPV.

    Parameters
    ----------
    df, outcome, unit, time, first_treat, covariates
        Same arguments as ``fit()``.

    Returns
    -------
    pd.DataFrame
        Per-cohort EPV diagnostics with columns: group, n_treated,
        n_control, n_covariates, n_params, epv, status. The same columns
        are present when no rows are produced (``estimation_method="reg"``,
        no covariates, or no treated cohorts in ``df``).

    Raises
    ------
    NotImplementedError
        If ``panel=False`` or ``control_group="not_yet_treated"``.
    """
    columns = [
        "group",
        "n_treated",
        "n_control",
        "n_covariates",
        "n_params",
        "epv",
        "status",
    ]

    if not self.panel:
        raise NotImplementedError(
            "diagnose_propensity() is not yet supported for repeated "
            "cross-section data (panel=False). Use fit() with covariates "
            "and check results.epv_diagnostics instead."
        )
    if self.control_group == "not_yet_treated":
        raise NotImplementedError(
            "diagnose_propensity() is not yet supported for "
            "control_group='not_yet_treated' because the control set "
            "varies per (g, t) cell. Use fit() with covariates and "
            "check results.epv_diagnostics instead."
        )
    # EPV concerns the covariate propensity logit, which per the class docs
    # only applies to IPW/DR with covariates: nothing to report otherwise.
    if self.estimation_method == "reg" or not covariates:
        return pd.DataFrame(columns=columns)

    # Normalize np.inf → 0 for never-treated encoding (same as fit())
    df = df.copy()
    inf_mask = df[first_treat].isin([np.inf])
    if inf_mask.any():
        n_inf_units = df.loc[inf_mask, unit].nunique()
        warnings.warn(
            f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 "
            f"(never-treated). Use first_treat=0 to suppress this warning.",
            UserWarning,
            stacklevel=2,
        )
        df[first_treat] = df[first_treat].replace(np.inf, 0)

    # Compute time_periods and treatment_groups (same logic as fit())
    time_periods = sorted(df[time].unique())
    treatment_groups = sorted(g for g in df[first_treat].unique() if g > 0)
    precomputed = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        first_treat,
        covariates,
        time_periods=time_periods,
        treatment_groups=treatment_groups,
    )
    cohort_masks = precomputed["cohort_masks"]

    # Only control_group="never_treated" reaches this point (not_yet_treated
    # raised above), so the control count is identical for every cohort.
    n_control = int(np.sum(precomputed["never_treated_mask"]))
    n_covariates = len(covariates)
    n_params = n_covariates  # predictor count, excluding intercept (Peduzzi convention)

    rows = []
    for g in sorted(cohort_masks):
        n_treated = int(np.sum(cohort_masks[g]))
        # The minority class bounds the number of "events" for the logit.
        n_events = min(n_treated, n_control)
        epv = n_events / n_params if n_params > 0 else float("inf")

        if epv >= self.epv_threshold:
            status = "ok"
        elif epv >= 2:
            status = "low"
        else:
            status = "critical"

        rows.append(
            {
                "group": g,
                "n_treated": n_treated,
                "n_control": n_control,
                "n_covariates": n_covariates,
                "n_params": n_params,
                "epv": round(epv, 1),
                "status": status,
            }
        )

    # Pass columns explicitly so an empty result still carries the schema.
    return pd.DataFrame(rows, columns=columns)
|
|
505
|
+
|
|
506
|
+
@staticmethod
def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
    """Build a unit-level ResolvedSurveyDesign for panel influence-function variance.

    Survey design columns are assumed constant within each unit (validated
    upstream). The work is delegated to
    ``diff_diff.survey.collapse_survey_to_unit_level``, which yields one
    entry per unit in the order of ``all_units``.
    """
    # Imported at call time rather than at module import.
    from diff_diff.survey import collapse_survey_to_unit_level as _collapse

    return _collapse(resolved_survey, df, unit_col, all_units)
|
|
516
|
+
|
|
517
|
+
def _precompute_structures(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    time_periods: List[Any],
    treatment_groups: List[Any],
    resolved_survey=None,
) -> PrecomputedData:
    """
    Build the reusable arrays and lookups used for ATT(g,t) computation.

    The long panel is pivoted to wide form once so that later per-cell work
    can use numpy indexing instead of repeated pandas filtering.

    Returns
    -------
    PrecomputedData
        Dictionary with keys ``all_units``, ``unit_to_idx``, ``unit_cohorts``,
        ``outcome_matrix`` (units x observed periods), ``period_to_col``,
        ``cohort_masks`` (one boolean mask per treatment group),
        ``never_treated_mask``, ``covariate_by_period`` (period -> units x
        covariates array, or None without covariates), ``time_periods``,
        ``is_balanced``, ``is_panel``, ``canonical_size``, ``survey_weights``,
        ``resolved_survey``, ``resolved_survey_unit`` and ``df_survey``.
    """
    # One entry per unit; the first first_treat value seen defines its cohort.
    cohort_by_unit = df.groupby(unit)[first_treat].first()
    all_units = cohort_by_unit.index.values
    unit_cohorts = cohort_by_unit.values
    unit_to_idx = dict(zip(all_units, range(len(all_units))))

    # Wide outcome table aligned row-for-row with all_units.
    outcome_wide = df.pivot(index=unit, columns=time, values=outcome).reindex(all_units)
    outcome_matrix = outcome_wide.values
    period_to_col = {period: col for col, period in enumerate(outcome_wide.columns)}

    cohort_masks = {g: unit_cohorts == g for g in treatment_groups}

    # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
    never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)

    # Covariates per period, each aligned to all_units (later read from the
    # base period of each comparison).
    if covariates:
        covariate_by_period = {
            period: df[df[time] == period].set_index(unit).reindex(all_units)[covariates].values
            for period in time_periods
        }
    else:
        covariate_by_period = None

    # Any NaN in the wide table means some unit is missing some period.
    is_balanced = not np.isnan(outcome_matrix).any()

    if resolved_survey is None:
        survey_weights_arr = None
        resolved_survey_unit = None
        df_survey = None
    else:
        # One survey weight per unit (first observation), ordered like all_units.
        weights_by_unit = (
            pd.Series(resolved_survey.weights, index=df.index).groupby(df[unit]).first()
        )
        survey_weights_arr = weights_by_unit.reindex(all_units).values
        resolved_survey_unit = self._collapse_survey_to_unit_level(
            resolved_survey, df, unit, all_units
        )
        df_survey = (
            resolved_survey_unit.df_survey if resolved_survey_unit is not None else None
        )

    return {
        "all_units": all_units,
        "unit_to_idx": unit_to_idx,
        "unit_cohorts": unit_cohorts,
        "outcome_matrix": outcome_matrix,
        "period_to_col": period_to_col,
        "cohort_masks": cohort_masks,
        "never_treated_mask": never_treated_mask,
        "covariate_by_period": covariate_by_period,
        "time_periods": time_periods,
        "is_balanced": is_balanced,
        "is_panel": True,
        "canonical_size": len(all_units),
        "survey_weights": survey_weights_arr,
        "resolved_survey": resolved_survey,
        "resolved_survey_unit": resolved_survey_unit,
        "df_survey": df_survey,
    }
|
|
614
|
+
|
|
615
|
+
def _compute_att_gt_fast(
    self,
    precomputed: PrecomputedData,
    g: Any,
    t: Any,
    covariates: Optional[List[str]],
    pscore_cache: Optional[Dict] = None,
    cho_cache: Optional[Dict] = None,
    epv_diagnostics: Optional[Dict] = None,
) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]:
    """
    Compute ATT(g,t) using pre-computed data structures (fast version).

    Uses vectorized numpy operations on pre-pivoted outcome matrix
    instead of repeated pandas filtering.

    Parameters
    ----------
    precomputed : PrecomputedData
        Pre-computed panel structures (outcome matrix, period-to-column map,
        cohort masks, covariates by period, optional survey weights).
    g : Any
        Treatment cohort (first-treatment period) of the treated group.
    t : Any
        Period at which the group-time effect is evaluated.
    covariates : list of str or None
        Covariate names. When falsy, estimation is unconditional.
    pscore_cache : dict, optional
        Propensity-score cache forwarded to the IPW / DR estimators.
    cho_cache : dict, optional
        Cholesky-factorization cache forwarded to the DR estimator.
    epv_diagnostics : dict, optional
        When given, non-empty EPV diagnostics reported by the IPW / DR
        estimators are stored under the key ``(g, t)``.

    Returns
    -------
    att_gt : float or None
        None when the cell cannot be estimated: base or post period absent,
        no treated or no control units with valid outcomes in both periods,
        or zero survey-weight mass in either group.
    se_gt : float
    n_treated : int
    n_control : int
    inf_func_info : dict or None
        Positions into ``precomputed['all_units']``, the matching unit labels,
        and the treated / control slices of the influence function.
    survey_weight_sum : float or None
        Sum of survey weights for treated units (for aggregation weighting).
    """
    not_estimable = (None, 0.0, 0, 0, None, None)

    period_to_col = precomputed["period_to_col"]
    outcome_matrix = precomputed["outcome_matrix"]
    cohort_masks = precomputed["cohort_masks"]
    never_treated_mask = precomputed["never_treated_mask"]
    unit_cohorts = precomputed["unit_cohorts"]
    covariate_by_period = precomputed["covariate_by_period"]

    # Base period selection based on mode
    if self.base_period == "universal":
        # Universal: always use g - 1 - anticipation
        base_period_val = g - 1 - self.anticipation
    elif t < g - self.anticipation:
        # Varying, pre-treatment: use t - 1 (consecutive comparison)
        base_period_val = t - 1
    else:
        # Varying, post-treatment: use g - 1 - anticipation
        base_period_val = g - 1 - self.anticipation

    # Both periods must exist in the data. There is deliberately no fallback
    # to another base period, to maintain methodological consistency.
    if base_period_val not in period_to_col or t not in period_to_col:
        return not_estimable

    base_col = period_to_col[base_period_val]
    post_col = period_to_col[t]

    # Treated units: cohort g
    treated_mask = cohort_masks[g]

    # Control units
    if self.control_group == "never_treated":
        control_mask = never_treated_mask
    else:  # not_yet_treated
        # Controls must be untreated at both t and the base period (i.e. at
        # whichever is later); otherwise their base-period outcome would be
        # contaminated by treatment.
        nyt_threshold = max(t, base_period_val) + self.anticipation
        control_mask = never_treated_mask | (
            (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
        )

    # Outcome changes between base and post periods (vectorized)
    y_base = outcome_matrix[:, base_col]
    y_post = outcome_matrix[:, post_col]
    outcome_change = y_post - y_base

    # Keep only units with an observed outcome in both periods
    valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
    treated_valid = treated_mask & valid_mask
    control_valid = control_mask & valid_mask

    n_treated = int(np.sum(treated_valid))
    n_control = int(np.sum(control_valid))
    if n_treated == 0 or n_control == 0:
        return not_estimable

    treated_change = outcome_change[treated_valid]
    control_change = outcome_change[control_valid]

    # Survey weights for treated and control (one weight per unit)
    survey_w = precomputed.get("survey_weights")
    if survey_w is not None:
        sw_treated = survey_w[treated_valid]
        sw_control = survey_w[control_valid]
        # Guard against zero effective mass after subpopulation filtering
        if np.sum(sw_treated) <= 0 or np.sum(sw_control) <= 0:
            return not_estimable
    else:
        sw_treated = None
        sw_control = None

    # Covariates, if specified, are taken from the base period
    X_treated = None
    X_control = None
    if covariates and covariate_by_period is not None:
        cov_matrix = covariate_by_period[base_period_val]
        X_treated = cov_matrix[treated_valid]
        X_control = cov_matrix[control_valid]

        if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
            warnings.warn(
                f"Missing values in covariates for group {g}, time {t}. "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=3,
            )
            X_treated = None
            X_control = None

    # With a balanced panel and never-treated controls, the control rows depend
    # only on the base period, so cached fits can be shared across periods t.
    shares_control_design = (
        precomputed.get("is_balanced", False) and self.control_group == "never_treated"
    )

    # Cache key for propensity score reuse
    pscore_key = None
    if pscore_cache is not None and X_treated is not None:
        pscore_key = (g, base_period_val) if shares_control_design else (g, base_period_val, t)

    # Cache key for Cholesky reuse (DR outcome regression)
    cho_key = None
    if cho_cache is not None and X_control is not None:
        cho_key = base_period_val if shares_control_design else (g, base_period_val, t)

    # Estimation method
    if self.estimation_method == "reg":
        att_gt, se_gt, inf_func = self._outcome_regression(
            treated_change,
            control_change,
            X_treated,
            X_control,
            sw_treated=sw_treated,
            sw_control=sw_control,
        )
    else:
        sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None
        epv_diag: dict = {}
        if self.estimation_method == "ipw":
            att_gt, se_gt, inf_func = self._ipw_estimation(
                treated_change,
                control_change,
                n_treated,
                n_control,
                X_treated,
                X_control,
                pscore_cache=pscore_cache,
                pscore_key=pscore_key,
                sw_treated=sw_treated,
                sw_control=sw_control,
                sw_all=sw_all,
                context_label=f"cohort g={g}",
                epv_diagnostics_out=epv_diag,
            )
        else:  # doubly robust
            att_gt, se_gt, inf_func = self._doubly_robust(
                treated_change,
                control_change,
                X_treated,
                X_control,
                pscore_cache=pscore_cache,
                pscore_key=pscore_key,
                cho_cache=cho_cache,
                cho_key=cho_key,
                sw_treated=sw_treated,
                sw_control=sw_control,
                sw_all=sw_all,
                context_label=f"cohort g={g}",
                epv_diagnostics_out=epv_diag,
            )
        if epv_diagnostics is not None and epv_diag:
            epv_diagnostics[(g, t)] = epv_diag

    # Package influence function info with index arrays (positions into
    # precomputed['all_units']) for O(1) downstream lookups instead of
    # O(n) Python dict lookups. The estimators return the influence function
    # with treated units first, then controls.
    all_units = precomputed["all_units"]
    treated_positions = np.where(treated_valid)[0]
    control_positions = np.where(control_valid)[0]
    inf_func_info = {
        "treated_idx": treated_positions,
        "control_idx": control_positions,
        "treated_units": all_units[treated_positions],
        "control_units": all_units[control_positions],
        "treated_inf": inf_func[:n_treated],
        "control_inf": inf_func[n_treated:],
    }

    sw_sum = float(np.sum(sw_treated)) if sw_treated is not None else None
    return att_gt, se_gt, n_treated, n_control, inf_func_info, sw_sum
|
|
827
|
+
|
|
828
|
+
def _compute_all_att_gt_vectorized(
    self,
    precomputed: PrecomputedData,
    treatment_groups: List[Any],
    time_periods: List[Any],
    min_period: Any,
) -> Tuple[Dict, Dict, Dict]:
    """
    Vectorized computation of all ATT(g,t) for the no-covariates regression case.

    This inlines the simple difference-in-means path from _outcome_regression()
    and eliminates per-(g,t) Python function call overhead. When survey weights
    are present in ``precomputed``, group means are survey-weighted and the SE
    is derived from the influence function (sqrt of the sum of squared IF
    values); otherwise the SE uses the two-sample variance formula.

    Parameters
    ----------
    precomputed : PrecomputedData
        Pre-computed panel structures (outcome matrix, cohort masks, ...).
    treatment_groups : list
        Treatment cohorts g to evaluate.
    time_periods : list
        Candidate periods t.
    min_period : Any
        Earliest period; under a "varying" base period, pre-treatment periods
        must be strictly later than this (they need a t - 1 base).

    Returns
    -------
    group_time_effects : dict
        Mapping (g, t) -> effect dict (effect, se, t_stat, p_value, conf_int,
        n_treated, n_control and, with survey weights, survey_weight_sum).
    influence_func_info : dict
        Mapping (g, t) -> influence function info dict.
    skip_info : dict
        ``{"missing_period": [...], "empty_cell": [...]}``: (g, t) pairs
        skipped because the base or post period is absent from the data, or
        because the cell has no treated / control units (or zero survey
        weight mass).
    """
    period_to_col = precomputed["period_to_col"]
    outcome_matrix = precomputed["outcome_matrix"]
    cohort_masks = precomputed["cohort_masks"]
    never_treated_mask = precomputed["never_treated_mask"]
    unit_cohorts = precomputed["unit_cohorts"]
    all_units = precomputed["all_units"]
    survey_w = precomputed.get("survey_weights")

    group_time_effects = {}
    influence_func_info = {}
    skipped_missing_period: List[Tuple] = []
    skipped_empty_cell: List[Tuple] = []

    # Collect all valid (g, t, base_col, post_col, base_period_val) tuples
    tasks = []
    for g in treatment_groups:
        if self.base_period == "universal":
            universal_base = g - 1 - self.anticipation
            valid_periods = [t for t in time_periods if t != universal_base]
        else:
            valid_periods = [
                t for t in time_periods if t >= g - self.anticipation or t > min_period
            ]

        for t in valid_periods:
            # Base period selection
            if self.base_period == "universal":
                base_period_val = g - 1 - self.anticipation
            elif t < g - self.anticipation:
                base_period_val = t - 1
            else:
                base_period_val = g - 1 - self.anticipation

            if base_period_val not in period_to_col or t not in period_to_col:
                skipped_missing_period.append((g, t))
                continue

            tasks.append(
                (g, t, period_to_col[base_period_val], period_to_col[t], base_period_val)
            )

    # Process all tasks
    atts = []
    ses = []
    task_keys = []

    for g, t, base_col, post_col, base_period_val in tasks:
        treated_mask = cohort_masks[g]

        if self.control_group == "never_treated":
            control_mask = never_treated_mask
        else:
            # Controls must be untreated at both t and base_period_val
            nyt_threshold = max(t, base_period_val) + self.anticipation
            control_mask = never_treated_mask | (
                (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
            )

        y_base = outcome_matrix[:, base_col]
        y_post = outcome_matrix[:, post_col]
        outcome_change = y_post - y_base
        valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))

        treated_valid = treated_mask & valid_mask
        control_valid = control_mask & valid_mask

        n_t = int(np.sum(treated_valid))
        n_c = int(np.sum(control_valid))

        if n_t == 0 or n_c == 0:
            skipped_empty_cell.append((g, t))
            continue

        treated_change = outcome_change[treated_valid]
        control_change = outcome_change[control_valid]

        # Inline no-covariates regression (difference in means).
        # Both groups are non-empty from here on.
        if survey_w is not None:
            sw_t = survey_w[treated_valid]
            sw_c = survey_w[control_valid]
            sw_t_total = np.sum(sw_t)
            sw_c_total = np.sum(sw_c)
            # Guard against zero effective mass
            if sw_t_total <= 0 or sw_c_total <= 0:
                skipped_empty_cell.append((g, t))
                continue
            sw_t_norm = sw_t / sw_t_total
            sw_c_norm = sw_c / sw_c_total
            mu_t = float(np.sum(sw_t_norm * treated_change))
            mu_c = float(np.sum(sw_c_norm * control_change))
            att = mu_t - mu_c

            # Influence function (survey-weighted); SE = sqrt(sum(IF_i^2))
            inf_treated = sw_t_norm * (treated_change - mu_t)
            inf_control = -sw_c_norm * (control_change - mu_c)
            se = float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
            sw_sum = float(sw_t_total)
        else:
            mean_t = np.mean(treated_change)
            mean_c = np.mean(control_change)
            att = float(mean_t - mean_c)

            var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0
            var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0
            se = float(np.sqrt(var_t / n_t + var_c / n_c))

            # Influence function
            inf_treated = (treated_change - mean_t) / n_t
            inf_control = -(control_change - mean_c) / n_c
            sw_sum = None

        gte_entry = {
            "effect": att,
            "se": se,
            # t_stat, p_value, conf_int filled by batch inference below
            "t_stat": np.nan,
            "p_value": np.nan,
            "conf_int": (np.nan, np.nan),
            "n_treated": n_t,
            "n_control": n_c,
        }
        if sw_sum is not None:
            gte_entry["survey_weight_sum"] = sw_sum
        group_time_effects[(g, t)] = gte_entry

        treated_positions = np.where(treated_valid)[0]
        control_positions = np.where(control_valid)[0]
        influence_func_info[(g, t)] = {
            "treated_idx": treated_positions,
            "control_idx": control_positions,
            "treated_units": all_units[treated_positions],
            "control_units": all_units[control_positions],
            "treated_inf": inf_treated,
            "control_inf": inf_control,
        }

        atts.append(att)
        ses.append(se)
        task_keys.append((g, t))

    # Batch inference for all (g,t) pairs at once
    if task_keys:
        df_survey_val = precomputed.get("df_survey")
        resolved_unit = precomputed.get("resolved_survey_unit")
        # Guard: replicate design with undefined df -> NaN inference
        if df_survey_val is None and getattr(resolved_unit, "uses_replicate_variance", False):
            df_survey_val = 0
        t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
            np.array(atts),
            np.array(ses),
            alpha=self.alpha,
            df=df_survey_val,
        )
        for idx, key in enumerate(task_keys):
            group_time_effects[key]["t_stat"] = float(t_stats[idx])
            group_time_effects[key]["p_value"] = float(p_values[idx])
            group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx]))

    skip_info = {
        "missing_period": skipped_missing_period,
        "empty_cell": skipped_empty_cell,
    }
    return group_time_effects, influence_func_info, skip_info
|
|
1020
|
+
|
|
1021
|
+
def _compute_all_att_gt_covariate_reg(
    self,
    precomputed: PrecomputedData,
    treatment_groups: List[Any],
    time_periods: List[Any],
    min_period: Any,
) -> Tuple[Dict, Dict, Dict]:
    """
    Optimized computation of all ATT(g,t) for the covariate regression case.

    Groups (g,t) pairs by their control regression key to reuse Cholesky
    factorizations of X^T X across pairs that share the same control design
    matrix (balanced panel with never-treated controls). Covariates are taken
    from each pair's base period.

    A pair whose covariates contain NaN falls back to an unconditional
    difference in means (with a warning). A pair whose regression yields
    non-finite results is kept with NaN effect / SE, and one aggregate warning
    is emitted. Rank-deficient control designs are solved with least squares
    on the retained columns.

    NOTE(review): ``precomputed["survey_weights"]`` is not read here; the
    regression is unweighted. Presumably callers route survey-weighted fits
    elsewhere — confirm.

    Parameters
    ----------
    precomputed : PrecomputedData
        Pre-computed panel structures (outcome matrix, cohort masks,
        covariates by period, balance flag, ...).
    treatment_groups : list
        Treatment cohorts g to evaluate.
    time_periods : list
        Candidate periods t.
    min_period : Any
        Earliest period; under a "varying" base period, pre-treatment periods
        must be strictly later than this (they need a t - 1 base).

    Returns
    -------
    group_time_effects : dict
        Mapping (g, t) -> effect dict.
    influence_func_info : dict
        Mapping (g, t) -> influence function info dict.
    skip_info : dict
        ``{"missing_period": [...], "empty_cell": [...]}``: (g, t) pairs
        skipped because the base or post period is absent from the data, or
        because the cell has no treated / control units.
    """
    period_to_col = precomputed["period_to_col"]
    outcome_matrix = precomputed["outcome_matrix"]
    cohort_masks = precomputed["cohort_masks"]
    never_treated_mask = precomputed["never_treated_mask"]
    unit_cohorts = precomputed["unit_cohorts"]
    covariate_by_period = precomputed["covariate_by_period"]
    is_balanced = precomputed["is_balanced"]
    all_units = precomputed["all_units"]

    # Balanced panel + never-treated controls: the control rows (and hence the
    # control design matrix) depend only on the base period, so one design and
    # one Cholesky factorization can serve every pair with that base period.
    shares_control_design = is_balanced and self.control_group == "never_treated"

    group_time_effects = {}
    influence_func_info = {}
    atts = []
    ses = []
    task_keys = []
    n_nan_cells = 0
    skipped_missing_period: List[Tuple] = []
    skipped_empty_cell: List[Tuple] = []

    # Collect all valid (g, t) tasks with their base periods.
    # control_key -> list of (g, t, base_period_val, base_col, post_col)
    tasks_by_group: Dict[Any, List[Tuple]] = {}
    for g in treatment_groups:
        if self.base_period == "universal":
            universal_base = g - 1 - self.anticipation
            valid_periods = [t for t in time_periods if t != universal_base]
        else:
            valid_periods = [
                t for t in time_periods if t >= g - self.anticipation or t > min_period
            ]

        for t in valid_periods:
            if self.base_period == "universal":
                base_period_val = g - 1 - self.anticipation
            elif t < g - self.anticipation:
                base_period_val = t - 1
            else:
                base_period_val = g - 1 - self.anticipation

            if base_period_val not in period_to_col or t not in period_to_col:
                skipped_missing_period.append((g, t))
                continue

            # Shared designs group by base period only; otherwise (not-yet-
            # treated controls exclude cohort g, unbalanced panels drop rows
            # per pair) each pair gets its own key.
            control_key = base_period_val if shares_control_design else (g, base_period_val, t)

            tasks_by_group.setdefault(control_key, []).append(
                (g, t, base_period_val, period_to_col[base_period_val], period_to_col[t])
            )

    # Process each group of tasks sharing the same control regression
    for control_key, tasks in tasks_by_group.items():
        # Use the first task to build the control design (same for all in the group)
        first_g, first_t, base_period_val, first_base_col, first_post_col = tasks[0]

        cov_matrix = covariate_by_period[base_period_val]

        # Build control mask (same for all tasks in this group)
        if self.control_group == "never_treated":
            control_mask = never_treated_mask
        else:
            # Controls must be untreated at both t and base_period_val
            nyt_threshold = max(first_t, base_period_val) + self.anticipation
            control_mask = never_treated_mask | (
                (unit_cohorts > nyt_threshold) & (unit_cohorts != first_g)
            )

        # For balanced panels every outcome is observed, so control_valid = control_mask
        if is_balanced:
            control_valid_base = control_mask
        else:
            y_base_first = outcome_matrix[:, first_base_col]
            y_post_first = outcome_matrix[:, first_post_col]
            valid_first = ~(np.isnan(y_base_first) | np.isnan(y_post_first))
            control_valid_base = control_mask & valid_first

        n_c_base = int(np.sum(control_valid_base))
        if n_c_base == 0:
            skipped_empty_cell.extend((task[0], task[1]) for task in tasks)
            continue

        X_ctrl_raw = cov_matrix[control_valid_base]
        ctrl_has_nan = bool(np.any(np.isnan(X_ctrl_raw)))

        X_ctrl = None
        cho = None  # only ever set for shared control designs
        kept_cols = None  # set only when the control design is rank-deficient
        if not ctrl_has_nan:
            # Control design with intercept
            X_ctrl = np.column_stack([np.ones(n_c_base), X_ctrl_raw])

            # One-time rank check for this control group
            rank, dropped_cols, _ = _detect_rank_deficiency(X_ctrl)

            if len(dropped_cols) > 0:
                # Rank-deficient: force lstsq for both "warn" and "silent".
                # Cholesky on near-singular XtX could yield unstable coefficients.
                if self.rank_deficient_action == "warn":
                    col_info = _format_dropped_columns(dropped_cols)
                    warnings.warn(
                        f"Rank-deficient covariate design (control_key={control_key}): "
                        f"dropped columns {col_info}. Rank {rank} < {X_ctrl.shape[1]}. "
                        "Using minimum-norm least-squares solution.",
                        UserWarning,
                        stacklevel=2,
                    )
                kept_cols = np.array(
                    [i for i in range(X_ctrl.shape[1]) if i not in dropped_cols]
                )
            elif shares_control_design:
                # Full rank: factorize once. Non-shared groups would never use
                # this factor (each pair factorizes its own design), so skip it.
                with np.errstate(all="ignore"):
                    XtX = X_ctrl.T @ X_ctrl
                    try:
                        cho = scipy_linalg.cho_factor(XtX)
                    except np.linalg.LinAlgError:
                        cho = None

        # Process each (g, t) pair in this group
        for g, t, bp_val, base_col, post_col in tasks:
            treated_mask = cohort_masks[g]

            # Recompute control mask for not_yet_treated (varies by g, t)
            if self.control_group == "not_yet_treated":
                # Controls must be untreated at both t and base period
                nyt_threshold = max(t, bp_val) + self.anticipation
                control_mask = never_treated_mask | (
                    (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
                )

            y_base = outcome_matrix[:, base_col]
            y_post = outcome_matrix[:, post_col]
            outcome_change = y_post - y_base

            if is_balanced:
                valid_mask_pair = np.ones(len(y_base), dtype=bool)
            else:
                valid_mask_pair = ~(np.isnan(y_base) | np.isnan(y_post))

            treated_valid = treated_mask & valid_mask_pair
            if shares_control_design:
                control_valid = control_valid_base
            else:
                control_valid = control_mask & valid_mask_pair

            n_t = int(np.sum(treated_valid))
            n_c = int(np.sum(control_valid))

            if n_t == 0 or n_c == 0:
                skipped_empty_cell.append((g, t))
                continue

            treated_change = outcome_change[treated_valid]
            control_change = outcome_change[control_valid]

            X_treated_pair = cov_matrix[treated_valid]
            X_control_pair = cov_matrix[control_valid]

            if np.any(np.isnan(X_treated_pair)) or np.any(np.isnan(X_control_pair)):
                # Fall back to unconditional (difference in means)
                warnings.warn(
                    f"Missing values in covariates for group {g}, time {t}. "
                    "Falling back to unconditional estimation.",
                    UserWarning,
                    stacklevel=3,
                )
                mean_t = np.mean(treated_change)
                mean_c = np.mean(control_change)
                att = float(mean_t - mean_c)
                var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0
                var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0
                se = float(np.sqrt(var_t / n_t + var_c / n_c))
                inf_treated = (treated_change - mean_t) / n_t
                inf_control = -(control_change - mean_c) / n_c
            else:
                # Reuse the group's control design when it is shared
                if shares_control_design and X_ctrl is not None:
                    pair_X_ctrl = X_ctrl
                    pair_n_c = n_c_base
                else:
                    pair_X_ctrl = np.column_stack([np.ones(n_c), X_control_pair])
                    pair_n_c = n_c

                # Solve for beta: cached Cholesky, else per-pair Cholesky
                # (full-rank only), else least squares.
                beta = None
                with np.errstate(all="ignore"):
                    if cho is not None:
                        Xty = pair_X_ctrl.T @ control_change
                        beta = scipy_linalg.cho_solve(cho, Xty)
                    elif kept_cols is None:
                        pair_XtX = pair_X_ctrl.T @ pair_X_ctrl
                        try:
                            pair_cho = scipy_linalg.cho_factor(pair_XtX)
                            Xty = pair_X_ctrl.T @ control_change
                            beta = scipy_linalg.cho_solve(pair_cho, Xty)
                        except np.linalg.LinAlgError:
                            pass

                    if beta is None or np.any(~np.isfinite(beta)):
                        if kept_cols is not None:
                            # Reduced solve for rank-deficient design; dropped
                            # columns get a zero coefficient.
                            result = scipy_linalg.lstsq(
                                pair_X_ctrl[:, kept_cols],
                                control_change,
                                cond=1e-07,
                            )
                            beta = np.zeros(pair_X_ctrl.shape[1])
                            beta[kept_cols] = result[0]
                        else:
                            # Full-rank lstsq fallback (Cholesky numerical failure)
                            result = scipy_linalg.lstsq(
                                pair_X_ctrl,
                                control_change,
                                cond=1e-07,
                            )
                            beta = result[0]

                # A cell is flagged (and counted once) at the first non-finite stage.
                nan_cell = beta is None or bool(np.any(~np.isfinite(beta)))

                if not nan_cell:
                    X_treated_w_intercept = np.column_stack([np.ones(n_t), X_treated_pair])
                    with np.errstate(all="ignore"):
                        predicted_control = X_treated_w_intercept @ beta
                        treated_residuals = treated_change - predicted_control
                    nan_cell = bool(np.any(~np.isfinite(predicted_control)))

                if not nan_cell:
                    att = float(np.mean(treated_residuals))
                    with np.errstate(all="ignore"):
                        residuals = control_change - pair_X_ctrl @ beta
                    nan_cell = bool(np.any(~np.isfinite(residuals)))

                if nan_cell:
                    n_nan_cells += 1
                    att = np.nan
                    se = np.nan
                    inf_treated = np.zeros(n_t)
                    inf_control = np.zeros(n_c)
                else:
                    var_t = float(np.var(treated_residuals, ddof=1)) if n_t > 1 else 0.0
                    var_c = float(np.var(residuals, ddof=1)) if pair_n_c > 1 else 0.0
                    se = float(np.sqrt(var_t / n_t + var_c / pair_n_c))
                    inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
                    inf_control = -residuals / pair_n_c

            group_time_effects[(g, t)] = {
                "effect": att,
                "se": se,
                # t_stat, p_value, conf_int filled by batch inference below
                "t_stat": np.nan,
                "p_value": np.nan,
                "conf_int": (np.nan, np.nan),
                "n_treated": n_t,
                "n_control": n_c,
            }

            treated_positions = np.where(treated_valid)[0]
            control_positions = np.where(control_valid)[0]
            influence_func_info[(g, t)] = {
                "treated_idx": treated_positions,
                "control_idx": control_positions,
                "treated_units": all_units[treated_positions],
                "control_units": all_units[control_positions],
                "treated_inf": inf_treated,
                "control_inf": inf_control,
            }

            atts.append(att)
            ses.append(se)
            task_keys.append((g, t))

    if n_nan_cells > 0:
        warnings.warn(
            f"{n_nan_cells} group-time cell(s) have non-finite regression results "
            "(near-singular covariates). These cells are preserved with NaN inference.",
            UserWarning,
            stacklevel=2,
        )

    # Batch inference
    if task_keys:
        # Use survey df for replicate designs (propagated from precomputed)
        df_survey_val = precomputed.get("df_survey")
        resolved_unit = precomputed.get("resolved_survey_unit")
        # Guard: replicate design with undefined df -> NaN inference
        if df_survey_val is None and getattr(resolved_unit, "uses_replicate_variance", False):
            df_survey_val = 0
        t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
            np.array(atts), np.array(ses), alpha=self.alpha, df=df_survey_val
        )
        for idx, key in enumerate(task_keys):
            group_time_effects[key]["t_stat"] = float(t_stats[idx])
            group_time_effects[key]["p_value"] = float(p_values[idx])
            group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx]))

    skip_info = {
        "missing_period": skipped_missing_period,
        "empty_cell": skipped_empty_cell,
    }
    return group_time_effects, influence_func_info, skip_info
|
|
1371
|
+
|
|
1372
|
+
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    survey_design: object = None,
) -> CallawaySantAnnaResults:
    """
    Fit the Callaway-Sant'Anna estimator.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers. For repeated
        cross-sections (``panel=False``), each observation should
        have a unique unit ID — units do not repeat across periods.
    outcome : str
        Name of outcome variable column.
    unit : str
        Name of unit identifier column.
    time : str
        Name of time period column.
    first_treat : str
        Name of column indicating when unit was first treated.
        Use 0 (or np.inf) for never-treated units; np.inf is recoded
        to 0 with a warning.
    covariates : list, optional
        List of covariate column names for conditional parallel trends.
        An empty list is treated the same as None.
    aggregate : str, optional
        How to aggregate group-time effects:
        - None: Only compute ATT(g,t) (default)
        - "simple": Simple weighted average (overall ATT)
        - "event_study": Aggregate by relative time (event study)
        - "group": Aggregate by treatment cohort
        - "all": Compute all aggregations
        The overall (simple) ATT is computed in every case.
    balance_e : int, optional
        For event study, balance the panel at relative time e.
        Ensures all groups contribute to each relative period.
    survey_design : SurveyDesign, optional
        Survey design specification. Supports pweight with strata/PSU/FPC.
        Aggregated SEs (overall, event study, group) use design-based
        variance via compute_survey_if_variance(). All estimation methods
        (reg, ipw, dr) support covariates + survey. For repeated
        cross-sections (``panel=False``), survey weights are
        per-observation (no unit-level collapse).

    Returns
    -------
    CallawaySantAnnaResults
        Object containing all estimation results. The same object is
        stored on ``self.results_`` and ``self.is_fitted_`` is set to True.

    Raises
    ------
    ValueError
        If ``pscore_trim`` is outside (0, 0.5); required columns are
        missing; ``panel=False`` data contains duplicate unit IDs; the
        survey design's ``weight_type`` is not ``'pweight'``; there are no
        never-treated units with ``control_group='never_treated'``; there
        are no never-treated units and fewer than 2 treatment cohorts with
        ``control_group='not_yet_treated'``; or no group-time effect could
        be estimated.
    NotImplementedError
        If ``n_bootstrap > 0`` is combined with a replicate-weight survey
        design.
    """
    # Validate pscore_trim (may have been changed via set_params)
    if not (0 < self.pscore_trim < 0.5):
        raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}")

    # Reset stale state from prior fit (prevents leaking event-study VCV)
    self._event_study_vcov = None

    # Normalize empty covariates list to None
    if covariates is not None and len(covariates) == 0:
        covariates = None

    # Validate required columns before anything below indexes ``data`` by
    # column (the duplicate-ID and survey-constancy checks read data[unit]),
    # so a missing column surfaces as the documented ValueError instead of
    # a bare KeyError.
    required_cols = [outcome, unit, time, first_treat]
    if covariates:
        required_cols.extend(covariates)
    missing = [c for c in required_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    if not self.panel:
        warnings.warn(
            "panel=False uses repeated cross-section DRDID estimators "
            "(Sant'Anna & Zhao 2020, Section 4) which assume stationary "
            "cross-sectional sampling: the population distribution of "
            "(Y, X, G) must be stable across periods. This assumption "
            "is not data-checkable.",
            UserWarning,
            stacklevel=2,
        )
        # Repeated cross-sections: each observation carries its own unit ID
        if data[unit].duplicated().any():
            raise ValueError(
                "panel=False requires unique unit IDs (one observation per unit). "
                "Found duplicate unit IDs. If your data is a panel, use panel=True."
            )

    # Resolve survey design if provided
    from diff_diff.survey import (
        _resolve_survey_for_fit,
        _validate_unit_constant_survey,
    )

    resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
        _resolve_survey_for_fit(survey_design, data, "analytical")
    )
    has_survey = resolved_survey is not None
    # Replicate-weight designs provide their own analytical variance and may
    # leave the survey df undefined; evaluated once and reused below.
    uses_replicate = has_survey and bool(
        getattr(resolved_survey, "uses_replicate_variance", False)
    )

    if has_survey:
        # Validate within-unit constancy for panel survey designs
        if self.panel:
            _validate_unit_constant_survey(data, unit, survey_design)
        if resolved_survey.weight_type != "pweight":
            raise ValueError(
                f"CallawaySantAnna survey support requires weight_type='pweight', "
                f"got '{resolved_survey.weight_type}'. The survey variance math "
                f"assumes probability weights (pweight)."
            )
    # Note: strata/PSU/FPC are supported — aggregated SEs use
    # compute_survey_if_variance() for design-based inference, and
    # bootstrap + survey uses a PSU-level multiplier bootstrap.

    # Replicate variance is an analytical alternative to the multiplier
    # bootstrap; reject the combination before any estimation work is done.
    if self.n_bootstrap > 0 and uses_replicate:
        raise NotImplementedError(
            "CallawaySantAnna bootstrap (n_bootstrap > 0) is not supported "
            "with replicate-weight survey designs. Replicate weights provide "
            "analytical variance; use n_bootstrap=0 instead."
        )

    def _nan_df_if_replicate(df_value):
        """Map an undefined df to 0 for replicate designs (forces NaN inference)."""
        if df_value is None and uses_replicate:
            return 0
        return df_value

    # Create working copy
    df = data.copy()

    # Ensure numeric types
    df[time] = pd.to_numeric(df[time])
    df[first_treat] = pd.to_numeric(df[first_treat])

    # Never-treated indicator (must precede treatment_groups to exclude np.inf)
    df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)

    # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
    _inf_mask = df[first_treat] == np.inf
    if _inf_mask.any():
        n_inf_units = df.loc[_inf_mask, unit].nunique()
        warnings.warn(
            f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 "
            f"(never-treated). Use first_treat=0 to suppress this warning.",
            UserWarning,
            stacklevel=2,
        )
        df.loc[_inf_mask, first_treat] = 0

    # Standardize the first_treat column name for internal use (avoids
    # hardcoding column names in internal methods). This copy must come
    # AFTER the inf → 0 recode: copying earlier leaves np.inf in the
    # standardized column whenever the user's column is not literally named
    # "first_treat", and inf > 0 would count never-treated units as treated.
    df["first_treat"] = df[first_treat]

    # Identify groups and time periods
    time_periods = sorted(df[time].unique())
    treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

    if self.panel:
        # Panel: count unique units
        unit_info = (
            df.groupby(unit)
            .agg({first_treat: "first", "_never_treated": "first"})
            .reset_index()
        )
        n_treated_units = int((unit_info[first_treat] > 0).sum())
        n_control_units = int(unit_info["_never_treated"].sum())
    else:
        # RCS: count observations per cohort (no unit tracking)
        n_treated_units = int((df[first_treat] > 0).sum())
        n_control_units = int(df["_never_treated"].sum())

    if n_control_units == 0 and self.control_group == "never_treated":
        raise ValueError(
            "No never-treated units found. Check 'first_treat' column. "
            "Use control_group='not_yet_treated' if all units are eventually treated."
        )
    if n_control_units == 0 and self.control_group == "not_yet_treated":
        # With not_yet_treated, controls are units not yet treated at each
        # (g, t) pair — never-treated units are not required.
        if len(treatment_groups) < 2:
            raise ValueError(
                "not_yet_treated control group requires at least 2 treatment "
                "cohorts when there are no never-treated units."
            )

    # Per-cell SEs use IF-based variance; aggregated SEs use design-based
    # variance via compute_survey_if_variance() or PSU-level bootstrap.
    # Pre-compute data structures for efficient ATT(g,t) computation.
    precompute = self._precompute_structures if self.panel else self._precompute_structures_rc
    precomputed = precompute(
        df,
        outcome,
        unit,
        time,
        first_treat,
        covariates,
        time_periods,
        treatment_groups,
        resolved_survey=resolved_survey,
    )

    # Recompute survey metadata from the unit-level resolved survey so
    # that n_psu and df_survey reflect the actual survey design (explicit
    # PSU/strata) rather than hard-coding n_units.
    if has_survey and survey_metadata is not None:
        resolved_survey_unit = precomputed.get("resolved_survey_unit")
        if resolved_survey_unit is not None:
            from diff_diff.survey import compute_survey_metadata

            survey_metadata = compute_survey_metadata(
                resolved_survey_unit, resolved_survey_unit.weights
            )

    # Survey df for safe_inference calls — the unit-level resolved survey df
    # computed during precomputation, with the replicate/undefined-df guard.
    df_survey = _nan_df_if_replicate(precomputed.get("df_survey"))

    # Compute ATT(g,t) for each group-time combination
    min_period = min(time_periods)
    _skip_info = {"missing_period": [], "empty_cell": []}
    _n_skipped_other = 0
    group_time_effects = {}
    influence_func_info = {}
    wants_pscore = bool(covariates) and self.estimation_method in ("ipw", "dr")

    def _candidate_periods(g):
        """Periods t at which ATT(g, t) is attempted for cohort g."""
        if self.base_period == "universal":
            universal_base = g - 1 - self.anticipation
            return [t for t in time_periods if t != universal_base]
        return [t for t in time_periods if t >= g - self.anticipation or t > min_period]

    def _store_cell(g, t, att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum, extra=None):
        """Record one (g, t) estimate with per-cell inference; False if not estimable."""
        if att_gt is None:
            return False
        t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha, df=df_survey)
        entry = {
            "effect": att_gt,
            "se": se_gt,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_treated": n_treat,
            "n_control": n_ctrl,
        }
        if extra:
            entry.update(extra)
        if sw_sum is not None:
            entry["survey_weight_sum"] = sw_sum
        group_time_effects[(g, t)] = entry
        if inf_info is not None:
            influence_func_info[(g, t)] = inf_info
        return True

    if not self.panel:
        # --- Repeated cross-section path ---
        # No vectorized/Cholesky fast paths (panel-only optimizations).
        epv_diagnostics = {} if wants_pscore else None
        for g in treatment_groups:
            for t in _candidate_periods(g):
                rc_result = self._compute_att_gt_rc(
                    precomputed,
                    g,
                    t,
                    covariates,
                    epv_diagnostics=epv_diagnostics,
                )
                att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = rc_result[:6]
                # Optional 7th element carries the aggregation weight
                agg_w = rc_result[6] if len(rc_result) > 6 else n_treat
                if not _store_cell(
                    g, t, att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum,
                    extra={"agg_weight": agg_w},
                ):
                    _n_skipped_other += 1

    elif covariates is None and self.estimation_method == "reg":
        # Fast vectorized path for the common no-covariates regression case
        group_time_effects, influence_func_info, _skip_info = (
            self._compute_all_att_gt_vectorized(
                precomputed, treatment_groups, time_periods, min_period
            )
        )
        epv_diagnostics = None  # No logit in this path
    elif (
        covariates is not None
        and self.estimation_method == "reg"
        and self.rank_deficient_action != "error"
        and not has_survey  # Cholesky cache uses X'X; survey needs X'WX
    ):
        # Optimized covariate regression path with Cholesky caching
        group_time_effects, influence_func_info, _skip_info = (
            self._compute_all_att_gt_covariate_reg(
                precomputed, treatment_groups, time_periods, min_period
            )
        )
        epv_diagnostics = None  # No logit in this path
    else:
        # General path: IPW, DR, rank_deficient_action="error", or edge cases
        # Propensity score cache for IPW/DR with covariates
        pscore_cache = {} if wants_pscore else None
        # Cholesky cache for DR outcome regression component
        # Skip cache when survey weights present (X'WX differs from X'X)
        use_cho_cache = (
            covariates
            and self.estimation_method == "dr"
            and self.rank_deficient_action != "error"
            and not has_survey
        )
        cho_cache = {} if use_cho_cache else None
        epv_diagnostics = {} if wants_pscore else None

        for g in treatment_groups:
            for t in _candidate_periods(g):
                att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_fast(
                    precomputed,
                    g,
                    t,
                    covariates,
                    pscore_cache=pscore_cache,
                    cho_cache=cho_cache,
                    epv_diagnostics=epv_diagnostics,
                )
                if not _store_cell(g, t, att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum):
                    _n_skipped_other += 1

    if not group_time_effects:
        raise ValueError(
            "Could not estimate any group-time effects. "
            "Check that data has sufficient observations."
        )

    # Consolidated EPV summary warning
    if epv_diagnostics:
        low_epv = {k: v for k, v in epv_diagnostics.items() if v.get("is_low")}
        if low_epv:
            # Single pass: first cell with the minimum EPV (same tie-breaking
            # as taking min over keys and values separately)
            min_key, min_entry = min(low_epv.items(), key=lambda kv: kv[1]["epv"])
            warnings.warn(
                f"Low Events Per Variable (EPV) detected in propensity "
                f"score estimation for {len(low_epv)} of {len(epv_diagnostics)} cell(s). "
                f"Minimum EPV = {min_entry['epv']:.1f} "
                f"(cohort g={min_key[0]}). "
                f"Consider estimation_method='reg' (avoids propensity "
                f"scores) or reducing the number of covariates. "
                f"See results.epv_summary() for details.",
                UserWarning,
                stacklevel=2,
            )

    # Consolidated (g,t) cell skip warning (all paths)
    _n_missing = len(_skip_info.get("missing_period", []))
    _n_empty = len(_skip_info.get("empty_cell", []))
    _n_total_skipped = _n_missing + _n_empty + _n_skipped_other
    if _n_total_skipped > 0:
        _parts = []
        if _n_missing:
            _parts.append(f"{_n_missing} due to missing base/post period in panel structure")
        if _n_empty:
            _parts.append(f"{_n_empty} due to zero treated or control observations")
        if _n_skipped_other:
            _parts.append(
                f"{_n_skipped_other} due to insufficient data or non-estimable cells"
            )
        warnings.warn(
            f"{_n_total_skipped} (group, time) cell(s) could not be "
            f"estimated: {'; '.join(_parts)}.",
            UserWarning,
            stacklevel=2,
        )

    # Compute overall ATT (simple aggregation)
    overall_att, overall_se, overall_effective_df = self._aggregate_simple(
        group_time_effects, influence_func_info, df, unit, precomputed
    )
    # Use per-statistic effective df from replicate aggregation if available;
    # otherwise fall back to the original df from the survey design.
    if overall_effective_df is not None:
        df_survey = overall_effective_df
        # Propagate to survey_metadata for display consistency
        if survey_metadata is not None:
            survey_metadata.df_survey = df_survey
    df_survey = _nan_df_if_replicate(df_survey)
    overall_t, overall_p, overall_ci = safe_inference(
        overall_att,
        overall_se,
        alpha=self.alpha,
        df=df_survey,
    )

    # Compute additional aggregations if requested
    event_study_effects = None
    group_effects = None

    if aggregate in ["event_study", "all"]:
        event_study_effects = self._aggregate_event_study(
            group_time_effects,
            influence_func_info,
            treatment_groups,
            time_periods,
            balance_e,
            df,
            unit,
            precomputed,
        )

    if aggregate in ["group", "all"]:
        group_effects = self._aggregate_by_group(
            group_time_effects,
            influence_func_info,
            treatment_groups,
            precomputed=precomputed,
            df=df,
            unit=unit,
        )

    def _apply_bootstrap(effects, boot_ses, boot_cis, boot_pvals):
        """Overwrite se/conf_int/p_value/t_stat of ``effects`` with bootstrap values."""
        keys = [k for k in effects if k in boot_ses]
        if not keys:
            return
        effects_arr = np.array([float(effects[k]["effect"]) for k in keys])
        ses_arr = np.array([float(boot_ses[k]) for k in keys])
        t_stats, _, _, _ = safe_inference_batch(effects_arr, ses_arr, alpha=self.alpha)
        for idx, k in enumerate(keys):
            effects[k]["se"] = boot_ses[k]
            effects[k]["conf_int"] = boot_cis[k]
            effects[k]["p_value"] = boot_pvals[k]
            effects[k]["t_stat"] = float(t_stats[idx])

    # Run bootstrap inference if requested
    bootstrap_results = None
    if self.n_bootstrap > 0 and influence_func_info:
        bootstrap_results = self._run_multiplier_bootstrap(
            group_time_effects=group_time_effects,
            influence_func_info=influence_func_info,
            aggregate=aggregate,
            balance_e=balance_e,
            treatment_groups=treatment_groups,
            time_periods=time_periods,
            df=df,
            unit=unit,
            precomputed=precomputed,
            cband=self.cband,
        )

        # Update estimates with bootstrap inference
        overall_se = bootstrap_results.overall_att_se
        overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
        overall_p = bootstrap_results.overall_att_p_value
        overall_ci = bootstrap_results.overall_att_ci

        _apply_bootstrap(
            group_time_effects,
            bootstrap_results.group_time_ses,
            bootstrap_results.group_time_cis,
            bootstrap_results.group_time_p_values,
        )

        if (
            event_study_effects is not None
            and bootstrap_results.event_study_ses is not None
            and bootstrap_results.event_study_cis is not None
            and bootstrap_results.event_study_p_values is not None
        ):
            _apply_bootstrap(
                event_study_effects,
                bootstrap_results.event_study_ses,
                bootstrap_results.event_study_cis,
                bootstrap_results.event_study_p_values,
            )

        if (
            group_effects is not None
            and bootstrap_results.group_effect_ses is not None
            and bootstrap_results.group_effect_cis is not None
            and bootstrap_results.group_effect_p_values is not None
        ):
            _apply_bootstrap(
                group_effects,
                bootstrap_results.group_effect_ses,
                bootstrap_results.group_effect_cis,
                bootstrap_results.group_effect_p_values,
            )

    # Compute simultaneous confidence band CIs if cband is available
    cband_crit_value = None
    if bootstrap_results is not None:
        cband_crit_value = bootstrap_results.cband_crit_value

    if cband_crit_value is not None and event_study_effects is not None:
        for eff_data in event_study_effects.values():
            se_val = eff_data["se"]
            if np.isfinite(se_val) and se_val > 0:
                eff_data["cband_conf_int"] = (
                    eff_data["effect"] - cband_crit_value * se_val,
                    eff_data["effect"] + cband_crit_value * se_val,
                )

    # Retrieve event-study VCV from aggregation mixin (Phase 7d).
    # Drop it when bootstrap overwrites event-study SEs, to prevent HonestDiD
    # from mixing analytical VCV with bootstrap SEs. The index is only
    # meaningful alongside a VCV from this fit, so it is dropped whenever the
    # VCV is absent (otherwise a stale index from a prior fit could leak).
    event_study_vcov = getattr(self, "_event_study_vcov", None)
    event_study_vcov_index = getattr(self, "_event_study_vcov_index", None)
    if bootstrap_results is not None or event_study_vcov is None:
        event_study_vcov = None
        event_study_vcov_index = None

    # Store results
    self.results_ = CallawaySantAnnaResults(
        group_time_effects=group_time_effects,
        overall_att=overall_att,
        overall_se=overall_se,
        overall_t_stat=overall_t,
        overall_p_value=overall_p,
        overall_conf_int=overall_ci,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=n_treated_units,
        n_control_units=n_control_units,
        alpha=self.alpha,
        control_group=self.control_group,
        base_period=self.base_period,
        event_study_effects=event_study_effects,
        group_effects=group_effects,
        bootstrap_results=bootstrap_results,
        cband_crit_value=cband_crit_value,
        pscore_trim=self.pscore_trim,
        survey_metadata=survey_metadata,
        event_study_vcov=event_study_vcov,
        event_study_vcov_index=event_study_vcov_index,
        panel=self.panel,
        epv_diagnostics=epv_diagnostics if epv_diagnostics else None,
        epv_threshold=self.epv_threshold,
        pscore_fallback=self.pscore_fallback,
    )

    self.is_fitted_ = True
    return self.results_
|
|
2020
|
+
|
|
2021
|
+
def _outcome_regression(
|
|
2022
|
+
self,
|
|
2023
|
+
treated_change: np.ndarray,
|
|
2024
|
+
control_change: np.ndarray,
|
|
2025
|
+
X_treated: Optional[np.ndarray] = None,
|
|
2026
|
+
X_control: Optional[np.ndarray] = None,
|
|
2027
|
+
sw_treated: Optional[np.ndarray] = None,
|
|
2028
|
+
sw_control: Optional[np.ndarray] = None,
|
|
2029
|
+
) -> Tuple[float, float, np.ndarray]:
|
|
2030
|
+
"""
|
|
2031
|
+
Estimate ATT using outcome regression.
|
|
2032
|
+
|
|
2033
|
+
With covariates:
|
|
2034
|
+
1. Regress outcome changes on covariates for control group
|
|
2035
|
+
2. Predict counterfactual for treated using their covariates
|
|
2036
|
+
3. ATT = mean(treated_change) - mean(predicted_counterfactual)
|
|
2037
|
+
|
|
2038
|
+
Without covariates:
|
|
2039
|
+
Simple difference in means.
|
|
2040
|
+
|
|
2041
|
+
Parameters
|
|
2042
|
+
----------
|
|
2043
|
+
sw_treated, sw_control : np.ndarray, optional
|
|
2044
|
+
Survey weights for treated and control units.
|
|
2045
|
+
"""
|
|
2046
|
+
n_t = len(treated_change)
|
|
2047
|
+
n_c = len(control_change)
|
|
2048
|
+
|
|
2049
|
+
if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
|
|
2050
|
+
# Covariate-adjusted outcome regression
|
|
2051
|
+
# Fit regression on control units: E[Delta Y | X, D=0]
|
|
2052
|
+
beta, residuals = _linear_regression(
|
|
2053
|
+
X_control,
|
|
2054
|
+
control_change,
|
|
2055
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
2056
|
+
weights=sw_control,
|
|
2057
|
+
)
|
|
2058
|
+
|
|
2059
|
+
# Zero NaN coefficients for prediction (dropped rank-deficient columns
|
|
2060
|
+
# contribute 0 to the column space projection, matching DR path convention)
|
|
2061
|
+
beta = np.where(np.isfinite(beta), beta, 0.0)
|
|
2062
|
+
|
|
2063
|
+
# Predict counterfactual for treated units
|
|
2064
|
+
X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
|
|
2065
|
+
predicted_control = np.dot(X_treated_with_intercept, beta)
|
|
2066
|
+
|
|
2067
|
+
# ATT: survey-weighted mean of treated residuals
|
|
2068
|
+
treated_residuals = treated_change - predicted_control
|
|
2069
|
+
|
|
2070
|
+
if sw_treated is not None:
|
|
2071
|
+
sw_t_sum = float(np.sum(sw_treated))
|
|
2072
|
+
sw_c_sum = float(np.sum(sw_control))
|
|
2073
|
+
sw_t_norm = sw_treated / sw_t_sum
|
|
2074
|
+
sw_c_norm = sw_control / sw_c_sum
|
|
2075
|
+
att = float(np.sum(sw_t_norm * treated_residuals))
|
|
2076
|
+
|
|
2077
|
+
# Survey-weighted OR influence function.
|
|
2078
|
+
# Mirrors unweighted: inf_treated = (resid-ATT)/n_t,
|
|
2079
|
+
# inf_control = -resid/n_c. Survey: w_i/sum(w_group).
|
|
2080
|
+
# WLS residuals are orthogonal to W*X by construction.
|
|
2081
|
+
X_c_int = np.column_stack([np.ones(n_c), X_control])
|
|
2082
|
+
resid_c = control_change - np.dot(X_c_int, beta)
|
|
2083
|
+
|
|
2084
|
+
inf_treated = sw_t_norm * (treated_residuals - att)
|
|
2085
|
+
inf_control = -sw_c_norm * resid_c
|
|
2086
|
+
inf_func = np.concatenate([inf_treated, inf_control])
|
|
2087
|
+
|
|
2088
|
+
# SE: survey-weighted variance matching unweighted var_t/n_t + var_c/n_c
|
|
2089
|
+
var_t = float(np.sum(sw_t_norm * (treated_residuals - att) ** 2))
|
|
2090
|
+
var_c = float(np.sum(sw_c_norm * resid_c**2))
|
|
2091
|
+
se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0
|
|
2092
|
+
else:
|
|
2093
|
+
att = float(np.mean(treated_residuals))
|
|
2094
|
+
|
|
2095
|
+
# Standard error using sandwich estimator
|
|
2096
|
+
var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
|
|
2097
|
+
var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0
|
|
2098
|
+
se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0
|
|
2099
|
+
|
|
2100
|
+
# Influence function
|
|
2101
|
+
inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
|
|
2102
|
+
inf_control = -residuals / n_c
|
|
2103
|
+
inf_func = np.concatenate([inf_treated, inf_control])
|
|
2104
|
+
else:
|
|
2105
|
+
# Simple difference in means (no covariates)
|
|
2106
|
+
if sw_treated is not None:
|
|
2107
|
+
sw_t_norm = sw_treated / np.sum(sw_treated)
|
|
2108
|
+
sw_c_norm = sw_control / np.sum(sw_control)
|
|
2109
|
+
mu_t = float(np.sum(sw_t_norm * treated_change))
|
|
2110
|
+
mu_c = float(np.sum(sw_c_norm * control_change))
|
|
2111
|
+
att = mu_t - mu_c
|
|
2112
|
+
|
|
2113
|
+
# Influence function (survey-weighted)
|
|
2114
|
+
inf_treated = sw_t_norm * (treated_change - mu_t)
|
|
2115
|
+
inf_control = -sw_c_norm * (control_change - mu_c)
|
|
2116
|
+
inf_func = np.concatenate([inf_treated, inf_control])
|
|
2117
|
+
|
|
2118
|
+
# SE from influence function variance
|
|
2119
|
+
se = (
|
|
2120
|
+
float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
|
|
2121
|
+
if (n_t > 0 and n_c > 0)
|
|
2122
|
+
else 0.0
|
|
2123
|
+
)
|
|
2124
|
+
else:
|
|
2125
|
+
att = float(np.mean(treated_change) - np.mean(control_change))
|
|
2126
|
+
|
|
2127
|
+
var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
|
|
2128
|
+
var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
|
|
2129
|
+
se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0
|
|
2130
|
+
|
|
2131
|
+
# Influence function (for aggregation)
|
|
2132
|
+
inf_treated = treated_change - np.mean(treated_change)
|
|
2133
|
+
inf_control = control_change - np.mean(control_change)
|
|
2134
|
+
inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])
|
|
2135
|
+
|
|
2136
|
+
return att, se, inf_func
|
|
2137
|
+
|
|
2138
|
+
def _ipw_estimation(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    n_treated: int,
    n_control: int,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
    pscore_cache: Optional[Dict] = None,
    pscore_key: Optional[Any] = None,
    sw_treated: Optional[np.ndarray] = None,
    sw_control: Optional[np.ndarray] = None,
    sw_all: Optional[np.ndarray] = None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using inverse probability weighting.

    With covariates:
        1. Estimate the propensity score P(D=1|X) by logistic regression
           (or rebuild it from coefficients cached in ``pscore_cache``).
        2. Reweight control units by p(X) / (1 - p(X)), multiplied by the
           control survey weights when survey weights are supplied.
        3. ATT = (survey-weighted) mean(treated) - weighted_mean(control).
        The standard error is sqrt(sum(inf_func**2)). Unless the
        unconditional-propensity fallback was used, ``inf_func`` includes
        a correction for propensity-score estimation uncertainty. The same
        formula is used with and without survey weights.

    Without covariates:
        (Survey-weighted) difference in means. Without survey weights the
        standard error is the two-sample sqrt(var_t / n_t + var_c / n_c),
        which matches the influence function returned.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes for treated and control units.
    n_treated, n_control : int
        Retained for interface compatibility. Group sizes are taken from
        ``len(treated_change)`` and ``len(control_change)``.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices for treated and control units. An intercept
        column is added internally.
    pscore_cache, pscore_key : optional
        Cache mapping ``pscore_key`` to ``(beta, epv_diagnostics)``.
    sw_treated, sw_control, sw_all : np.ndarray, optional
        Survey weights for treated, control, and all units (treated first).
    context_label : str
        Label used in warnings and passed to the logistic solver.
    epv_diagnostics_out : dict, optional
        Updated in place with propensity-model diagnostics, when available.

    Returns
    -------
    att : float
    se : float
    inf_func : np.ndarray
        Influence-function contributions, treated units first.

    Raises
    ------
    np.linalg.LinAlgError, ValueError
        Re-raised from the logistic fit when ``self.pscore_fallback`` or
        ``self.rank_deficient_action`` is ``"error"``.
    """
    n_t = len(treated_change)
    n_c = len(control_change)

    has_covariates = (
        X_treated is not None and X_control is not None and X_treated.shape[1] > 0
    )

    if has_covariates:
        ps_fallback_used = False
        X_all = np.vstack([X_treated, X_control])
        X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])

        # --- Propensity scores (cached coefficients or fresh logistic fit) ---
        cached_pscore = None
        if pscore_cache is not None and pscore_key is not None:
            cached_pscore = pscore_cache.get(pscore_key)

        if cached_pscore is not None:
            beta_logistic, cached_diag = cached_pscore
            z = np.clip(np.dot(X_all_int, beta_logistic), -500, 500)
            pscore = 1 / (1 + np.exp(-z))
            if epv_diagnostics_out is not None and cached_diag:
                epv_diagnostics_out.update(cached_diag)
        else:
            D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
            diag = {}
            try:
                beta_logistic, pscore = solve_logit(
                    X_all,
                    D,
                    rank_deficient_action=self.rank_deficient_action,
                    weights=sw_all,
                    epv_threshold=self.epv_threshold,
                    context_label=context_label,
                    diagnostics_out=diag,
                )
                _check_propensity_diagnostics(pscore, self.pscore_trim)
                # Cache coefficients with NaN (dropped rank-deficient
                # columns) zero-filled so reuse cannot propagate NaN.
                if pscore_cache is not None and pscore_key is not None:
                    beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0)
                    pscore_cache[pscore_key] = (beta_clean, diag)
            except (np.linalg.LinAlgError, ValueError):
                if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
                    raise
                ctx = f" for {context_label}" if context_label else ""
                warnings.warn(
                    f"Propensity score estimation failed{ctx}. "
                    f"Falling back to unconditional propensity "
                    f"(all covariates dropped for this cell). "
                    f"Consider estimation_method='reg' to avoid "
                    f"propensity scores entirely.",
                    UserWarning,
                    stacklevel=4,
                )
                if sw_all is not None:
                    pos = sw_all > 0
                    p_uc = float(np.average(D[pos], weights=sw_all[pos]))
                else:
                    p_uc = n_t / (n_t + n_c)
                pscore = np.full(len(D), p_uc)
                ps_fallback_used = True
            if epv_diagnostics_out is not None and diag:
                epv_diagnostics_out.update(diag)

        # Clip propensity scores to avoid extreme weights
        pscore_treated = np.clip(pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim)
        pscore_control = np.clip(pscore[n_t:], self.pscore_trim, 1 - self.pscore_trim)

        # --- ATT and plug-in influence function ---
        if sw_treated is not None:
            # IPW weights compose with survey weights: w_i = sw_i * p / (1 - p)
            sw_t_norm = sw_treated / np.sum(sw_treated)
            mu_t = float(np.sum(sw_t_norm * treated_change))
            inf_treated = sw_t_norm * (treated_change - mu_t)
            weights_control = sw_control * pscore_control / (1 - pscore_control)
        else:
            mu_t = float(np.mean(treated_change))
            inf_treated = (treated_change - mu_t) / n_t
            weights_control = pscore_control / (1 - pscore_control)
        weights_control_norm = weights_control / np.sum(weights_control)
        mu_c = float(np.sum(weights_control_norm * control_change))
        att = mu_t - mu_c

        inf_control = -weights_control_norm * (control_change - mu_c)
        inf_func = np.concatenate([inf_treated, inf_control])

        if not ps_fallback_used:
            # Propensity score IF correction (R's psi convention, converted
            # to phi scale). Accounts for logistic-coefficient uncertainty.
            n_all_panel = n_t + n_c
            pscore_all = np.concatenate([pscore_treated, pscore_control])
            D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])

            W_ps = pscore_all * (1 - pscore_all)
            score_ps = (D_all - pscore_all)[:, None] * X_all_int
            if sw_all is not None:
                W_ps = W_ps * sw_all
                score_ps = score_ps * sw_all[:, None]
            # R: Hessian.ps = crossprod(X * sqrt(W)) / n
            H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
            # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale)
            asy_lin_rep_psi = score_ps @ _safe_inv(H_psi)

            # R: M2 = colMeans(w.cont * (y - att) * X) / mean(w.cont); with
            # normalized weights the subset sum is the equivalent quantity.
            M2 = np.sum(
                (weights_control_norm * (control_change - mu_c))[:, None]
                * X_all_int[n_t:],
                axis=0,
            )
            # Control enters the ATT with a minus sign, hence subtraction.
            inf_func = inf_func - (asy_lin_rep_psi @ M2) / n_all_panel

        # SE from influence function variance (same rule with/without
        # survey weights; the former non-survey formula omitted the 1/n_c
        # scaling on the control variance).
        var_psi = np.sum(inf_func**2)
        se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
    elif sw_treated is not None:
        # Unconditional IPW reduces to a survey-weighted difference in means
        sw_t_norm = sw_treated / np.sum(sw_treated)
        sw_c_norm = sw_control / np.sum(sw_control)
        mu_t = float(np.sum(sw_t_norm * treated_change))
        mu_c = float(np.sum(sw_c_norm * control_change))
        att = mu_t - mu_c

        inf_treated = sw_t_norm * (treated_change - mu_t)
        inf_control = -sw_c_norm * (control_change - mu_c)
        inf_func = np.concatenate([inf_treated, inf_control])

        se = (
            float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
            if (n_t > 0 and n_c > 0)
            else 0.0
        )
    else:
        # Unconditional IPW is exactly a difference in means, so its SE is
        # the two-sample SE. The previous var_c * (1 - p) / (n_c * p) with
        # p = n_t / n collapsed to var_c / n_t, which is wrong when n_t != n_c.
        att = float(np.mean(treated_change) - np.mean(control_change))

        var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
        var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
        se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0

        # Influence function (for aggregation)
        inf_treated = (treated_change - np.mean(treated_change)) / n_t
        inf_control = (control_change - np.mean(control_change)) / n_c
        inf_func = np.concatenate([inf_treated, -inf_control])

    return att, se, inf_func
|
|
2370
|
+
|
|
2371
|
+
def _doubly_robust(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
    pscore_cache: Optional[Dict] = None,
    pscore_key: Optional[Any] = None,
    cho_cache: Optional[Dict] = None,
    cho_key: Optional[Any] = None,
    sw_treated: Optional[np.ndarray] = None,
    sw_control: Optional[np.ndarray] = None,
    sw_all: Optional[np.ndarray] = None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using doubly robust estimation.

    With covariates:
        Combines outcome regression and IPW for double robustness.
        The estimator is consistent if either the outcome model OR
        the propensity model is correctly specified.

        ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
                 + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]

        where m(X) is the outcome model fitted on control units and
        w_i = p(X_i) / (1 - p(X_i)) are IPW odds weights. With survey
        weights, n_t is replaced by the treated survey-weight total. Treated
        terms and w_i are then multiplied by each unit's survey weight.
        The influence function includes a propensity-score correction
        (skipped after the unconditional fallback) and an outcome-regression
        correction; se = sqrt(sum(inf_func**2)).

    Without covariates:
        Reduces to a (survey-weighted) difference in means.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes for treated and control units.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices. An intercept column is added internally.
    pscore_cache, pscore_key : optional
        Cache mapping ``pscore_key`` to ``(beta, epv_diagnostics)``.
    cho_cache, cho_key : optional
        Cache of Cholesky factors of the *unweighted* control Gram matrix
        (``False`` marks a rank-deficient design). Only consulted when
        ``sw_control`` is None, because the factor is invalid for a
        survey-weighted regression.
    sw_treated, sw_control, sw_all : np.ndarray, optional
        Survey weights for treated, control, and all units (treated first).
    context_label : str
        Label used in warnings and passed to the logistic solver.
    epv_diagnostics_out : dict, optional
        Updated in place with propensity-model diagnostics, when available.

    Returns
    -------
    att : float
    se : float
    inf_func : np.ndarray
        Influence-function contributions, treated units first.
    """
    n_t = len(treated_change)
    n_c = len(control_change)

    has_covariates = (
        X_treated is not None and X_control is not None and X_treated.shape[1] > 0
    )

    if not has_covariates:
        # Without covariates, DR simplifies to difference in means
        if sw_treated is not None:
            sw_t_norm = sw_treated / np.sum(sw_treated)
            sw_c_norm = sw_control / np.sum(sw_control)
            mu_t = float(np.sum(sw_t_norm * treated_change))
            mu_c = float(np.sum(sw_c_norm * control_change))
            att = mu_t - mu_c

            inf_treated = sw_t_norm * (treated_change - mu_t)
            inf_control = -sw_c_norm * (control_change - mu_c)
            inf_func = np.concatenate([inf_treated, inf_control])

            se = (
                float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
                if (n_t > 0 and n_c > 0)
                else 0.0
            )
        else:
            att = float(np.mean(treated_change) - np.mean(control_change))

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
            se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0

            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = (control_change - np.mean(control_change)) / n_c
            inf_func = np.concatenate([inf_treated, -inf_control])
        return att, se, inf_func

    ps_fallback_used = False
    X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
    X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])

    # --- Step 1: outcome regression E[Delta Y | X] on controls ---
    beta = None
    if cho_cache is not None and cho_key is not None and sw_control is None:
        cho = cho_cache.get(cho_key)
        if cho is None:
            # First visit for this key: only factor full-rank designs and
            # store False as a rank-deficient sentinel otherwise.
            rank_info = _detect_rank_deficiency(X_control_with_intercept)
            if len(rank_info[1]) > 0:
                cho_cache[cho_key] = False
                cho = False
            else:
                XtX = X_control_with_intercept.T @ X_control_with_intercept
                try:
                    cho = scipy_linalg.cho_factor(XtX)
                    cho_cache[cho_key] = cho
                except np.linalg.LinAlgError:
                    cho = None
        if cho is not None and cho is not False:
            Xty = X_control_with_intercept.T @ control_change
            beta = scipy_linalg.cho_solve(cho, Xty)
            if np.any(~np.isfinite(beta)):
                beta = None

    if beta is None:
        beta, _ = _linear_regression(
            X_control,
            control_change,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_control,
        )
        # Zero NaN coefficients for prediction only: dropped columns then
        # contribute 0 to the fitted values m_treated / m_control.
        beta = np.where(np.isfinite(beta), beta, 0.0)

    m_treated = np.dot(X_treated_with_intercept, beta)
    m_control = np.dot(X_control_with_intercept, beta)

    # --- Step 2: propensity scores (cached coefficients or fresh fit) ---
    X_all = np.vstack([X_treated, X_control])
    X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])

    cached_pscore = None
    if pscore_cache is not None and pscore_key is not None:
        cached_pscore = pscore_cache.get(pscore_key)

    if cached_pscore is not None:
        beta_logistic, cached_diag = cached_pscore
        z = np.clip(np.dot(X_all_int, beta_logistic), -500, 500)
        pscore = 1 / (1 + np.exp(-z))
        if epv_diagnostics_out is not None and cached_diag:
            epv_diagnostics_out.update(cached_diag)
    else:
        D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
        diag = {}
        try:
            beta_logistic, pscore = solve_logit(
                X_all,
                D,
                rank_deficient_action=self.rank_deficient_action,
                weights=sw_all,
                epv_threshold=self.epv_threshold,
                context_label=context_label,
                diagnostics_out=diag,
            )
            _check_propensity_diagnostics(pscore, self.pscore_trim)
            if pscore_cache is not None and pscore_key is not None:
                beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0)
                pscore_cache[pscore_key] = (beta_clean, diag)
        except (np.linalg.LinAlgError, ValueError):
            if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
                raise
            ctx = f" for {context_label}" if context_label else ""
            warnings.warn(
                f"Propensity score estimation failed{ctx}. "
                f"Falling back to unconditional propensity "
                f"(propensity model ignores covariates; outcome "
                f"regression still uses them). "
                f"Consider estimation_method='reg' to avoid "
                f"propensity scores entirely.",
                UserWarning,
                stacklevel=4,
            )
            if sw_all is not None:
                pos = sw_all > 0
                p_uc = float(np.average(D[pos], weights=sw_all[pos]))
            else:
                p_uc = n_t / (n_t + n_c)
            pscore = np.full(len(D), p_uc)
            ps_fallback_used = True
        if epv_diagnostics_out is not None and diag:
            epv_diagnostics_out.update(diag)

    pscore_control = np.clip(pscore[n_t:], self.pscore_trim, 1 - self.pscore_trim)

    # --- Step 3: DR ATT. Unweighted case == unit survey weights. ---
    if sw_treated is not None:
        weights_control = sw_control * pscore_control / (1 - pscore_control)
        treated_w = sw_treated
        treated_total = np.sum(sw_treated)
        att_treated_part = float(
            np.sum(sw_treated * (treated_change - m_treated)) / treated_total
        )
    else:
        weights_control = pscore_control / (1 - pscore_control)
        treated_w = np.ones(n_t)
        treated_total = n_t
        att_treated_part = float(np.mean(treated_change - m_treated))
    augmentation = float(
        np.sum(weights_control * (m_control - control_change)) / treated_total
    )
    att = att_treated_part + augmentation

    # --- Step 4: influence function (Sant'Anna & Zhao 2020, Theorem 3.1) ---
    psi_treated = (treated_w / treated_total) * (treated_change - m_treated - att)
    psi_control = (weights_control / treated_total) * (m_control - control_change)
    inf_func = np.concatenate([psi_treated, psi_control])

    if not ps_fallback_used:
        # PS IF correction (R's psi convention, converted to phi scale)
        n_all_panel = n_t + n_c
        pscore_treated_clipped = np.clip(
            pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim
        )
        pscore_all = np.concatenate([pscore_treated_clipped, pscore_control])
        D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])

        W_ps = pscore_all * (1 - pscore_all)
        score_ps = (D_all - pscore_all)[:, None] * X_all_int
        if sw_all is not None:
            W_ps = W_ps * sw_all
            score_ps = score_ps * sw_all[:, None]
        H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
        # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale)
        asy_lin_rep_psi = score_ps @ _safe_inv(H_psi)

        M2_dr = np.sum(
            ((weights_control / treated_total) * (m_control - control_change))[:, None]
            * X_all_int[n_t:],
            axis=0,
        )
        inf_func = inf_func + (asy_lin_rep_psi @ M2_dr) / n_all_panel

    # OR IF correction: accounts for outcome-regression estimation
    X_c_int = X_control_with_intercept
    X_t_int = X_treated_with_intercept
    W_diag = sw_control if sw_control is not None else np.ones(n_c)
    bread = _safe_inv(X_c_int.T @ (W_diag[:, None] * X_c_int))

    # M1: dATT/dbeta, gradient of the DR ATT w.r.t. OR parameters
    M1 = (
        -np.sum(treated_w[:, None] * X_t_int, axis=0)
        + np.sum(weights_control[:, None] * X_c_int, axis=0)
    ) / treated_total

    # OR asymptotic linear representation (controls only; treated get 0)
    resid_c = control_change - m_control
    asy_lin_rep_or = ((W_diag * resid_c)[:, None] * X_c_int) @ bread
    inf_func[n_t:] += asy_lin_rep_or @ M1

    var_psi = np.sum(inf_func**2)
    se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0

    return att, se, inf_func
|
|
2695
|
+
|
|
2696
|
+
# =========================================================================
|
|
2697
|
+
# Repeated Cross-Section (RCS) methods
|
|
2698
|
+
# =========================================================================
|
|
2699
|
+
|
|
2700
|
+
def _precompute_structures_rc(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    time_periods: List[Any],
    treatment_groups: List[Any],
    resolved_survey=None,
) -> PrecomputedData:
    """
    Build observation-level lookup structures for repeated cross-sections.

    The RCS path never reshapes to wide format: each row of ``df`` is its
    own observation, so every array below is aligned with ``df``'s rows.

    Returns
    -------
    PrecomputedData
        Dictionary of observation-level structures. ``"all_units"`` holds
        integer row positions (kept under that key so the aggregation code
        can treat observations like units), and ``"rcs_cohort_masses"`` holds
        the number of rows per treatment cohort across all periods.
    """
    cohort_of_obs = df[first_treat].values
    n_rows = len(df)

    # One boolean row mask per treatment cohort
    masks_by_cohort = {grp: cohort_of_obs == grp for grp in treatment_groups}

    # Cohort sizes (rows over all periods) act as fixed aggregation weights
    # so n_treated stays consistent with the WIF.
    masses_by_cohort = {
        grp: int(np.sum(masks_by_cohort[grp])) for grp in treatment_groups
    }

    # Position of each period in sorted order (used for base-period checks)
    col_of_period = {period: pos for pos, period in enumerate(sorted(time_periods))}

    # For RCS the resolved survey design is already per observation
    weights_copy = resolved_survey.weights.copy() if resolved_survey is not None else None

    return {
        "all_units": np.arange(n_rows),
        "unit_to_idx": None,  # RCS: obs indices are positions
        "unit_cohorts": cohort_of_obs,
        "canonical_size": n_rows,
        "is_panel": False,
        "obs_time": df[time].values,
        "obs_outcome": df[outcome].values,
        "obs_covariates": df[covariates].values if covariates else None,
        "cohort_masks": masks_by_cohort,
        "never_treated_mask": (cohort_of_obs == 0) | (cohort_of_obs == np.inf),
        "time_periods": time_periods,
        "period_to_col": col_of_period,
        "is_balanced": False,
        "survey_weights": weights_copy,
        "resolved_survey": resolved_survey,
        "resolved_survey_unit": resolved_survey,
        "df_survey": getattr(resolved_survey, "df_survey", None),
        "rcs_cohort_masses": masses_by_cohort,
    }
|
|
2789
|
+
|
|
2790
|
+
def _compute_att_gt_rc(
|
|
2791
|
+
self,
|
|
2792
|
+
precomputed: PrecomputedData,
|
|
2793
|
+
g: Any,
|
|
2794
|
+
t: Any,
|
|
2795
|
+
covariates: Optional[List[str]],
|
|
2796
|
+
epv_diagnostics: Optional[Dict] = None,
|
|
2797
|
+
) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]:
|
|
2798
|
+
"""
|
|
2799
|
+
Compute ATT(g,t) for repeated cross-section data.
|
|
2800
|
+
|
|
2801
|
+
For RCS, the 2x2 DiD compares outcomes across two independent
|
|
2802
|
+
cross-sections (periods t and base period s) rather than
|
|
2803
|
+
within-unit changes.
|
|
2804
|
+
|
|
2805
|
+
Returns
|
|
2806
|
+
-------
|
|
2807
|
+
att_gt : float or None
|
|
2808
|
+
se_gt : float
|
|
2809
|
+
n_treated : int (treated obs at period t)
|
|
2810
|
+
n_control : int (control obs at period t)
|
|
2811
|
+
inf_func_info : dict or None
|
|
2812
|
+
survey_weight_sum : float or None
|
|
2813
|
+
"""
|
|
2814
|
+
cohort_masks = precomputed["cohort_masks"]
|
|
2815
|
+
never_treated_mask = precomputed["never_treated_mask"]
|
|
2816
|
+
unit_cohorts = precomputed["unit_cohorts"]
|
|
2817
|
+
obs_time = precomputed["obs_time"]
|
|
2818
|
+
obs_outcome = precomputed["obs_outcome"]
|
|
2819
|
+
period_to_col = precomputed["period_to_col"]
|
|
2820
|
+
|
|
2821
|
+
# Base period selection (same logic as panel)
|
|
2822
|
+
if self.base_period == "universal":
|
|
2823
|
+
base_period_val = g - 1 - self.anticipation
|
|
2824
|
+
else: # varying
|
|
2825
|
+
if t < g - self.anticipation:
|
|
2826
|
+
base_period_val = t - 1
|
|
2827
|
+
else:
|
|
2828
|
+
base_period_val = g - 1 - self.anticipation
|
|
2829
|
+
|
|
2830
|
+
if base_period_val not in period_to_col or t not in period_to_col:
|
|
2831
|
+
return None, 0.0, 0, 0, None, None
|
|
2832
|
+
|
|
2833
|
+
# Treated mask = cohort g
|
|
2834
|
+
treated_mask = cohort_masks[g]
|
|
2835
|
+
|
|
2836
|
+
# Control mask (same logic as panel)
|
|
2837
|
+
if self.control_group == "never_treated":
|
|
2838
|
+
control_mask = never_treated_mask
|
|
2839
|
+
else: # not_yet_treated
|
|
2840
|
+
nyt_threshold = max(t, base_period_val) + self.anticipation
|
|
2841
|
+
control_mask = never_treated_mask | (
|
|
2842
|
+
(unit_cohorts > nyt_threshold) & (unit_cohorts != g)
|
|
2843
|
+
)
|
|
2844
|
+
|
|
2845
|
+
# Period masks
|
|
2846
|
+
at_t = obs_time == t
|
|
2847
|
+
at_s = obs_time == base_period_val
|
|
2848
|
+
|
|
2849
|
+
# 4 groups of observations
|
|
2850
|
+
treated_t = treated_mask & at_t
|
|
2851
|
+
treated_s = treated_mask & at_s
|
|
2852
|
+
control_t = control_mask & at_t
|
|
2853
|
+
control_s = control_mask & at_s
|
|
2854
|
+
|
|
2855
|
+
n_gt = int(np.sum(treated_t))
|
|
2856
|
+
n_gs = int(np.sum(treated_s))
|
|
2857
|
+
n_ct = int(np.sum(control_t))
|
|
2858
|
+
n_cs = int(np.sum(control_s))
|
|
2859
|
+
|
|
2860
|
+
if n_gt == 0 or n_ct == 0 or n_gs == 0 or n_cs == 0:
|
|
2861
|
+
return None, 0.0, 0, 0, None, None
|
|
2862
|
+
|
|
2863
|
+
# Extract outcomes for each group
|
|
2864
|
+
y_gt = obs_outcome[treated_t]
|
|
2865
|
+
y_gs = obs_outcome[treated_s]
|
|
2866
|
+
y_ct = obs_outcome[control_t]
|
|
2867
|
+
y_cs = obs_outcome[control_s]
|
|
2868
|
+
|
|
2869
|
+
# Survey weights
|
|
2870
|
+
survey_w = precomputed.get("survey_weights")
|
|
2871
|
+
sw_gt = survey_w[treated_t] if survey_w is not None else None
|
|
2872
|
+
sw_gs = survey_w[treated_s] if survey_w is not None else None
|
|
2873
|
+
sw_ct = survey_w[control_t] if survey_w is not None else None
|
|
2874
|
+
sw_cs = survey_w[control_s] if survey_w is not None else None
|
|
2875
|
+
|
|
2876
|
+
# Guard against zero effective mass
|
|
2877
|
+
if sw_gt is not None:
|
|
2878
|
+
if np.sum(sw_gt) <= 0 or np.sum(sw_gs) <= 0:
|
|
2879
|
+
return None, 0.0, 0, 0, None, None
|
|
2880
|
+
if np.sum(sw_ct) <= 0 or np.sum(sw_cs) <= 0:
|
|
2881
|
+
return None, 0.0, 0, 0, None, None
|
|
2882
|
+
|
|
2883
|
+
# Get covariates if specified
|
|
2884
|
+
obs_covariates = precomputed.get("obs_covariates")
|
|
2885
|
+
has_covariates = covariates is not None and obs_covariates is not None
|
|
2886
|
+
|
|
2887
|
+
if has_covariates:
|
|
2888
|
+
X_gt = obs_covariates[treated_t]
|
|
2889
|
+
X_gs = obs_covariates[treated_s]
|
|
2890
|
+
X_ct = obs_covariates[control_t]
|
|
2891
|
+
X_cs = obs_covariates[control_s]
|
|
2892
|
+
|
|
2893
|
+
# Check for NaN in covariates
|
|
2894
|
+
if (
|
|
2895
|
+
np.any(np.isnan(X_gt))
|
|
2896
|
+
or np.any(np.isnan(X_gs))
|
|
2897
|
+
or np.any(np.isnan(X_ct))
|
|
2898
|
+
or np.any(np.isnan(X_cs))
|
|
2899
|
+
):
|
|
2900
|
+
warnings.warn(
|
|
2901
|
+
f"Missing values in covariates for group {g}, time {t} (RCS). "
|
|
2902
|
+
"Falling back to unconditional estimation.",
|
|
2903
|
+
UserWarning,
|
|
2904
|
+
stacklevel=3,
|
|
2905
|
+
)
|
|
2906
|
+
has_covariates = False
|
|
2907
|
+
|
|
2908
|
+
if has_covariates and self.estimation_method == "reg":
|
|
2909
|
+
att, se, inf_func_all, idx_all = self._outcome_regression_rc(
|
|
2910
|
+
y_gt,
|
|
2911
|
+
y_gs,
|
|
2912
|
+
y_ct,
|
|
2913
|
+
y_cs,
|
|
2914
|
+
X_gt,
|
|
2915
|
+
X_gs,
|
|
2916
|
+
X_ct,
|
|
2917
|
+
X_cs,
|
|
2918
|
+
sw_gt=sw_gt,
|
|
2919
|
+
sw_gs=sw_gs,
|
|
2920
|
+
sw_ct=sw_ct,
|
|
2921
|
+
sw_cs=sw_cs,
|
|
2922
|
+
)
|
|
2923
|
+
elif has_covariates and self.estimation_method == "ipw":
|
|
2924
|
+
epv_diag: dict = {}
|
|
2925
|
+
att, se, inf_func_all, idx_all = self._ipw_estimation_rc(
|
|
2926
|
+
y_gt,
|
|
2927
|
+
y_gs,
|
|
2928
|
+
y_ct,
|
|
2929
|
+
y_cs,
|
|
2930
|
+
X_gt,
|
|
2931
|
+
X_gs,
|
|
2932
|
+
X_ct,
|
|
2933
|
+
X_cs,
|
|
2934
|
+
sw_gt=sw_gt,
|
|
2935
|
+
sw_gs=sw_gs,
|
|
2936
|
+
sw_ct=sw_ct,
|
|
2937
|
+
sw_cs=sw_cs,
|
|
2938
|
+
context_label=f"cohort g={g}",
|
|
2939
|
+
epv_diagnostics_out=epv_diag,
|
|
2940
|
+
)
|
|
2941
|
+
if epv_diagnostics is not None and epv_diag:
|
|
2942
|
+
epv_diagnostics[(g, t)] = epv_diag
|
|
2943
|
+
elif has_covariates and self.estimation_method == "dr":
|
|
2944
|
+
epv_diag = {}
|
|
2945
|
+
att, se, inf_func_all, idx_all = self._doubly_robust_rc(
|
|
2946
|
+
y_gt,
|
|
2947
|
+
y_gs,
|
|
2948
|
+
y_ct,
|
|
2949
|
+
y_cs,
|
|
2950
|
+
X_gt,
|
|
2951
|
+
X_gs,
|
|
2952
|
+
X_ct,
|
|
2953
|
+
X_cs,
|
|
2954
|
+
sw_gt=sw_gt,
|
|
2955
|
+
sw_gs=sw_gs,
|
|
2956
|
+
sw_ct=sw_ct,
|
|
2957
|
+
sw_cs=sw_cs,
|
|
2958
|
+
context_label=f"cohort g={g}",
|
|
2959
|
+
epv_diagnostics_out=epv_diag,
|
|
2960
|
+
)
|
|
2961
|
+
if epv_diagnostics is not None and epv_diag:
|
|
2962
|
+
epv_diagnostics[(g, t)] = epv_diag
|
|
2963
|
+
else:
|
|
2964
|
+
# No-covariates 2x2 DiD (all methods reduce to same)
|
|
2965
|
+
att, se, inf_func_all, idx_all = self._rc_2x2_did(
|
|
2966
|
+
y_gt,
|
|
2967
|
+
y_gs,
|
|
2968
|
+
y_ct,
|
|
2969
|
+
y_cs,
|
|
2970
|
+
treated_t,
|
|
2971
|
+
treated_s,
|
|
2972
|
+
control_t,
|
|
2973
|
+
control_s,
|
|
2974
|
+
sw_gt=sw_gt,
|
|
2975
|
+
sw_gs=sw_gs,
|
|
2976
|
+
sw_ct=sw_ct,
|
|
2977
|
+
sw_cs=sw_cs,
|
|
2978
|
+
)
|
|
2979
|
+
|
|
2980
|
+
# Build influence function info
|
|
2981
|
+
# For RCS, treated_idx/control_idx combine obs from BOTH periods
|
|
2982
|
+
treated_idx = np.concatenate([np.where(treated_t)[0], np.where(treated_s)[0]])
|
|
2983
|
+
control_idx = np.concatenate([np.where(control_t)[0], np.where(control_s)[0]])
|
|
2984
|
+
|
|
2985
|
+
n_treated_combined = len(treated_idx)
|
|
2986
|
+
inf_func_info = {
|
|
2987
|
+
"treated_idx": treated_idx,
|
|
2988
|
+
"control_idx": control_idx,
|
|
2989
|
+
"treated_units": treated_idx, # For RCS, obs indices = "units"
|
|
2990
|
+
"control_units": control_idx,
|
|
2991
|
+
"treated_inf": inf_func_all[:n_treated_combined],
|
|
2992
|
+
"control_inf": inf_func_all[n_treated_combined:],
|
|
2993
|
+
}
|
|
2994
|
+
|
|
2995
|
+
sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None
|
|
2996
|
+
# n_treated = per-cell treated count at period t (for display).
|
|
2997
|
+
# cohort_mass = total treated across all periods (for aggregation weights).
|
|
2998
|
+
cohort_mass = precomputed.get("rcs_cohort_masses", {}).get(g, n_gt)
|
|
2999
|
+
return att, se, n_gt, n_ct, inf_func_info, sw_sum, cohort_mass
|
|
3000
|
+
|
|
3001
|
+
def _rc_2x2_did(
|
|
3002
|
+
self,
|
|
3003
|
+
y_gt,
|
|
3004
|
+
y_gs,
|
|
3005
|
+
y_ct,
|
|
3006
|
+
y_cs,
|
|
3007
|
+
mask_gt,
|
|
3008
|
+
mask_gs,
|
|
3009
|
+
mask_ct,
|
|
3010
|
+
mask_cs,
|
|
3011
|
+
sw_gt=None,
|
|
3012
|
+
sw_gs=None,
|
|
3013
|
+
sw_ct=None,
|
|
3014
|
+
sw_cs=None,
|
|
3015
|
+
):
|
|
3016
|
+
"""
|
|
3017
|
+
Compute the basic 2x2 DiD for RCS (no covariates).
|
|
3018
|
+
|
|
3019
|
+
ATT = (mean(Y_treated_t) - mean(Y_control_t))
|
|
3020
|
+
- (mean(Y_treated_s) - mean(Y_control_s))
|
|
3021
|
+
|
|
3022
|
+
Returns (att, se, inf_func_concat, idx_concat) where inf_func_concat
|
|
3023
|
+
has treated obs (both periods) first, then control obs (both periods).
|
|
3024
|
+
"""
|
|
3025
|
+
n_gt = len(y_gt)
|
|
3026
|
+
n_gs = len(y_gs)
|
|
3027
|
+
n_ct = len(y_ct)
|
|
3028
|
+
n_cs = len(y_cs)
|
|
3029
|
+
|
|
3030
|
+
if sw_gt is not None:
|
|
3031
|
+
sw_gt_norm = sw_gt / np.sum(sw_gt)
|
|
3032
|
+
sw_gs_norm = sw_gs / np.sum(sw_gs)
|
|
3033
|
+
sw_ct_norm = sw_ct / np.sum(sw_ct)
|
|
3034
|
+
sw_cs_norm = sw_cs / np.sum(sw_cs)
|
|
3035
|
+
|
|
3036
|
+
mu_gt = float(np.sum(sw_gt_norm * y_gt))
|
|
3037
|
+
mu_gs = float(np.sum(sw_gs_norm * y_gs))
|
|
3038
|
+
mu_ct = float(np.sum(sw_ct_norm * y_ct))
|
|
3039
|
+
mu_cs = float(np.sum(sw_cs_norm * y_cs))
|
|
3040
|
+
|
|
3041
|
+
att = (mu_gt - mu_ct) - (mu_gs - mu_cs)
|
|
3042
|
+
|
|
3043
|
+
# Influence function for 4 groups (survey-weighted)
|
|
3044
|
+
inf_gt = sw_gt_norm * (y_gt - mu_gt)
|
|
3045
|
+
inf_ct = -sw_ct_norm * (y_ct - mu_ct)
|
|
3046
|
+
inf_gs = -sw_gs_norm * (y_gs - mu_gs)
|
|
3047
|
+
inf_cs = sw_cs_norm * (y_cs - mu_cs)
|
|
3048
|
+
else:
|
|
3049
|
+
mu_gt = float(np.mean(y_gt))
|
|
3050
|
+
mu_gs = float(np.mean(y_gs))
|
|
3051
|
+
mu_ct = float(np.mean(y_ct))
|
|
3052
|
+
mu_cs = float(np.mean(y_cs))
|
|
3053
|
+
|
|
3054
|
+
att = (mu_gt - mu_ct) - (mu_gs - mu_cs)
|
|
3055
|
+
|
|
3056
|
+
# Influence function for 4 groups
|
|
3057
|
+
inf_gt = (y_gt - mu_gt) / n_gt
|
|
3058
|
+
inf_ct = -(y_ct - mu_ct) / n_ct
|
|
3059
|
+
inf_gs = -(y_gs - mu_gs) / n_gs
|
|
3060
|
+
inf_cs = (y_cs - mu_cs) / n_cs
|
|
3061
|
+
|
|
3062
|
+
# Concatenate: treated (t then s), control (t then s)
|
|
3063
|
+
inf_treated = np.concatenate([inf_gt, inf_gs])
|
|
3064
|
+
inf_control = np.concatenate([inf_ct, inf_cs])
|
|
3065
|
+
inf_all = np.concatenate([inf_treated, inf_control])
|
|
3066
|
+
|
|
3067
|
+
# SE from influence function
|
|
3068
|
+
se = float(np.sqrt(np.sum(inf_all**2)))
|
|
3069
|
+
|
|
3070
|
+
idx_all = np.concatenate(
|
|
3071
|
+
[
|
|
3072
|
+
np.where(mask_gt)[0],
|
|
3073
|
+
np.where(mask_gs)[0],
|
|
3074
|
+
np.where(mask_ct)[0],
|
|
3075
|
+
np.where(mask_cs)[0],
|
|
3076
|
+
]
|
|
3077
|
+
)
|
|
3078
|
+
|
|
3079
|
+
return att, se, inf_all, idx_all
|
|
3080
|
+
|
|
3081
|
+
def _outcome_regression_rc(
    self,
    y_gt,
    y_gs,
    y_ct,
    y_cs,
    X_gt,
    X_gs,
    X_ct,
    X_cs,
    sw_gt=None,
    sw_gs=None,
    sw_ct=None,
    sw_cs=None,
):
    """
    Cross-sectional outcome regression for ATT(g,t).

    Matches R DRDID::reg_did_rc (Sant'Anna & Zhao 2020, Eq 2.2).

    Two OLS models fit on controls (period t and base period s).
    Predictions made for ALL treated (both periods).
    OR correction pools ALL treated observations across both periods.

    IF convention
    -------------
    Intermediate terms use R's unnormalized psi_i convention throughout.
    R computes SE as ``sd(psi) / sqrt(n)``; with mean(psi) approx 0 this
    equals ``sqrt(sum(psi^2)) / n``. At the end we convert to the
    library's pre-scaled phi_i = psi_i / n convention where
    ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer.

    Returns (att, se, inf_func_concat, idx_concat). ``inf_func_concat`` is
    ordered [treated t, treated s, control t, control s]; ``idx_concat`` is
    None (the caller builds indices from its masks).
    """
    n_gt = len(y_gt)
    n_gs = len(y_gs)
    n_ct = len(y_ct)
    n_cs = len(y_cs)
    n_all = n_gt + n_gs + n_ct + n_cs

    # --- Fit 2 OLS on control groups (period t and s separately) ---
    # Non-finite coefficients are zeroed so they drop out of predictions.
    beta_t, resid_ct = _linear_regression(
        X_ct,
        y_ct,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_ct,
    )
    beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0)

    beta_s, resid_cs = _linear_regression(
        X_cs,
        y_cs,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_cs,
    )
    beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0)

    # --- Predict counterfactual for ALL treated (both periods) ---
    X_gt_int = np.column_stack([np.ones(n_gt), X_gt])
    X_gs_int = np.column_stack([np.ones(n_gs), X_gs])
    X_ct_int = np.column_stack([np.ones(n_ct), X_ct])
    X_cs_int = np.column_stack([np.ones(n_cs), X_cs])

    # mu_hat_{0,t}(X) and mu_hat_{0,s}(X) for each treated obs
    mu_post_gt = X_gt_int @ beta_t  # treated-post predicted at post model
    mu_pre_gt = X_gt_int @ beta_s  # treated-post predicted at pre model
    mu_post_gs = X_gs_int @ beta_t  # treated-pre predicted at post model
    mu_pre_gs = X_gs_int @ beta_s  # treated-pre predicted at pre model

    # --- Group weights (R: w.treat.pre, w.treat.post, w.cont = w.D) ---
    if sw_gt is not None:
        w_treat_post = sw_gt  # treated at t
        w_treat_pre = sw_gs  # treated at s
    else:
        w_treat_post = np.ones(n_gt)
        w_treat_pre = np.ones(n_gs)
    # R: w.cont = i.weights * D -- the same treated weights, pooled over
    # both periods.
    w_D_gt = w_treat_post
    w_D_gs = w_treat_pre

    sum_w_treat_post = np.sum(w_treat_post)
    sum_w_treat_pre = np.sum(w_treat_pre)
    sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs)  # pool ALL treated

    # R: mean(w.treat.post), mean(w.treat.pre), mean(w.cont)
    mean_w_treat_post = sum_w_treat_post / n_all
    mean_w_treat_pre = sum_w_treat_pre / n_all
    mean_w_D = sum_w_D / n_all

    # --- Treated means (period-specific Hajek means) ---
    eta_treat_post = np.sum(w_treat_post * y_gt) / sum_w_treat_post
    eta_treat_pre = np.sum(w_treat_pre * y_gs) / sum_w_treat_pre

    # --- OR correction: pools ALL treated ---
    # R: out.y.post - out.y.pre for each treated obs
    or_diff_gt = mu_post_gt - mu_pre_gt  # treated at t
    or_diff_gs = mu_post_gs - mu_pre_gs  # treated at s
    eta_cont = (np.sum(w_D_gt * or_diff_gt) + np.sum(w_D_gs * or_diff_gs)) / sum_w_D

    # --- Point estimate ---
    att = float(eta_treat_post - eta_treat_pre - eta_cont)

    # =================================================================
    # Influence function in R's unnormalized psi convention
    # (R: reg_did_rc.R, psi = n * phi)
    # =================================================================

    # --- Treated psi (R: inf.treat.post, inf.treat.pre) ---
    psi_treat_post = w_treat_post * (y_gt - eta_treat_post) / mean_w_treat_post
    psi_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / mean_w_treat_pre

    # --- Control psi: leading term (R: inf.cont.1, before /mean(w.cont)) ---
    psi_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont)
    psi_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont)

    # --- Control psi: OLS estimation effect (R: inf.cont.2) ---
    # R: XpX.inv = solve(crossprod(wols.x, int.cov) / n), i.e. the inverse
    # of the *averaged* weighted Gram matrix (= n * (X'WX)^{-1}). That keeps
    # each row of asy.lin.rep at O(1), the same psi scale as every other
    # term here. Inverting the unscaled X'WX instead would shrink this term
    # by a factor of n_all and effectively drop the control-regression
    # variance from the SE.
    W_ct = sw_ct if sw_ct is not None else np.ones(n_ct)
    W_cs = sw_cs if sw_cs is not None else np.ones(n_cs)
    bread_t = _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int) / n_all)
    bread_s = _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int) / n_all)

    # R: M1 = colMeans(w.cont * int.cov) = sum(w_D * X) / n_all
    M1 = (
        np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
    ) / n_all

    # R: asy.lin.rep.ols = (w * resid * X) %*% XpX.inv
    asy_lin_rep_ols_t = ((W_ct * resid_ct)[:, None] * X_ct_int) @ bread_t
    asy_lin_rep_ols_s = ((W_cs * resid_cs)[:, None] * X_cs_int) @ bread_s

    # R: inf.cont.2.post / inf.cont.2.pre = asy.lin.rep.ols %*% M1
    psi_cont_2_ct = asy_lin_rep_ols_t @ M1  # (n_ct,)
    psi_cont_2_cs = asy_lin_rep_ols_s @ M1  # (n_cs,)

    # --- Assemble per-group psi ---
    # R: inf.treat = inf.treat.post - inf.treat.pre
    # R: inf.cont = (inf.cont.1 + inf.cont.2.post - inf.cont.2.pre) / mean(w.cont)
    # R: att.inf.func = inf.treat - inf.cont
    psi_gt = psi_treat_post - psi_cont_1_gt / mean_w_D
    psi_gs = -psi_treat_pre - psi_cont_1_gs / mean_w_D
    psi_ct = -psi_cont_2_ct / mean_w_D
    psi_cs = psi_cont_2_cs / mean_w_D

    psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs])

    # =================================================================
    # Convert to library convention: phi = psi / n_all
    # se = sqrt(sum(phi^2)) == sqrt(sum(psi^2)) / n_all
    # =================================================================
    inf_all = psi_all / n_all
    se = float(np.sqrt(np.sum(inf_all**2)))

    idx_all = None  # caller builds idx from masks
    return att, se, inf_all, idx_all
|
|
3244
|
+
|
|
3245
|
+
def _ipw_estimation_rc(
    self,
    y_gt,
    y_gs,
    y_ct,
    y_cs,
    X_gt,
    X_gs,
    X_ct,
    X_cs,
    sw_gt=None,
    sw_gs=None,
    sw_ct=None,
    sw_cs=None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
):
    """
    Inverse-probability-weighted ATT(g,t) for repeated cross-sections.

    A logit for cohort membership is fit on the stacked observations
    ``[treated t, treated s, control t, control s]`` (both periods pooled).
    Control observations in each period are then reweighted by the odds
    ``p / (1 - p)``, multiplied by survey weights when these are given.
    If the logit fails and fallback is allowed, a constant (unconditional)
    propensity is used and the propensity-score correction is skipped.

    IF convention
    -------------
    Intermediate influence terms follow R's unnormalized psi_i convention
    (R: ``ipw_did_rc``; SE = ``sd(psi) / sqrt(n)``). The returned influence
    function is the library's pre-scaled phi_i = psi_i / n, so that
    ``se = sqrt(sum(phi^2))`` for the aggregation/bootstrap layer.

    Returns (att, se, inf_func_concat, idx_concat); ``idx_concat`` is None.
    """
    n_gt, n_gs, n_ct, n_cs = len(y_gt), len(y_gs), len(y_ct), len(y_cs)
    n_treated = n_gt + n_gs
    n_all = n_treated + n_ct + n_cs
    # Positions of the two control cells inside the stacked arrays.
    ct_slice = slice(n_treated, n_treated + n_ct)
    cs_slice = slice(n_treated + n_ct, None)
    weighted = sw_gt is not None

    X_all = np.vstack([X_gt, X_gs, X_ct, X_cs])
    D_all = np.concatenate([np.ones(n_treated), np.zeros(n_ct + n_cs)])
    sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) if weighted else None

    diag = {}
    ps_fallback_used = False
    try:
        _, pscore = solve_logit(
            X_all,
            D_all,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_all,
            epv_threshold=self.epv_threshold,
            context_label=context_label,
            diagnostics_out=diag,
        )
        _check_propensity_diagnostics(pscore, self.pscore_trim)
    except (np.linalg.LinAlgError, ValueError):
        if "error" in (self.pscore_fallback, self.rank_deficient_action):
            raise
        ctx = f" for {context_label}" if context_label else ""
        warnings.warn(
            f"Propensity score estimation failed{ctx} (RCS IPW). "
            f"Falling back to unconditional propensity "
            f"(all covariates dropped for this cell). "
            f"Consider estimation_method='reg' to avoid "
            f"propensity scores entirely.",
            UserWarning,
            stacklevel=4,
        )
        if weighted:
            positive = sw_all > 0
            p_treat = float(np.average(D_all[positive], weights=sw_all[positive]))
        else:
            p_treat = n_treated / len(D_all)
        pscore = np.full(len(D_all), p_treat)
        ps_fallback_used = True
    if epv_diagnostics_out is not None and diag:
        epv_diagnostics_out.update(diag)

    pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)

    # Odds weights for the controls (R: w1.x = ps / (1 - ps)); the treated
    # propensity scores themselves are not needed.
    ps_ct = pscore[ct_slice]
    ps_cs = pscore[cs_slice]
    w_ct = ps_ct / (1 - ps_ct)
    w_cs = ps_cs / (1 - ps_cs)
    if weighted:
        w_ct = sw_ct * w_ct
        w_cs = sw_cs * w_cs
        sum_w_treat_post = np.sum(sw_gt)
        sum_w_treat_pre = np.sum(sw_gs)
    else:
        sum_w_treat_post = float(n_gt)
        sum_w_treat_pre = float(n_gs)

    # R: mean(w.treat.post), mean(w.treat.pre), mean(w.ipw.ct), mean(w.ipw.cs)
    mean_w_treat_post = sum_w_treat_post / n_all
    mean_w_treat_pre = sum_w_treat_pre / n_all
    sum_w_ct = np.sum(w_ct)
    sum_w_cs = np.sum(w_cs)
    mean_w_ct = sum_w_ct / n_all
    mean_w_cs = sum_w_cs / n_all

    # Hajek-normalized control weights (left as-is if the mass is zero).
    w_ct_norm = w_ct / sum_w_ct if sum_w_ct > 0 else w_ct
    w_cs_norm = w_cs / sum_w_cs if sum_w_cs > 0 else w_cs

    if weighted:
        mu_gt = float(np.sum((sw_gt / sum_w_treat_post) * y_gt))
        mu_gs = float(np.sum((sw_gs / sum_w_treat_pre) * y_gs))
    else:
        mu_gt = float(np.mean(y_gt))
        mu_gs = float(np.mean(y_gs))
    mu_ct_ipw = float(np.sum(w_ct_norm * y_ct))
    mu_cs_ipw = float(np.sum(w_cs_norm * y_cs))

    att = (mu_gt - mu_ct_ipw) - (mu_gs - mu_cs_ipw)

    # --- Leading psi terms (R: ipw_did_rc, psi = n * phi) ---
    if weighted:
        psi_gt = sw_gt * (y_gt - mu_gt) / mean_w_treat_post
        psi_gs = -sw_gs * (y_gs - mu_gs) / mean_w_treat_pre
    else:
        psi_gt = (y_gt - mu_gt) / mean_w_treat_post
        psi_gs = -(y_gs - mu_gs) / mean_w_treat_pre

    # R: w.ipw * (y - eta.cont) / mean(w.ipw), signed post(-) / pre(+)
    if mean_w_ct > 0:
        psi_ct = -w_ct * (y_ct - mu_ct_ipw) / mean_w_ct
    else:
        psi_ct = np.zeros(n_ct)
    if mean_w_cs > 0:
        psi_cs = w_cs * (y_cs - mu_cs_ipw) / mean_w_cs
    else:
        psi_cs = np.zeros(n_cs)

    # psi -> phi
    inf_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) / n_all

    if not ps_fallback_used:
        # --- Propensity-score estimation effect ---
        X_all_int = np.column_stack([np.ones(n_all), X_all])
        W_ps = pscore * (1 - pscore)
        score_ps = (D_all - pscore)[:, None] * X_all_int
        if weighted:
            W_ps = W_ps * sw_all
            score_ps = score_ps * sw_all[:, None]
        # R: Hessian.ps = vcov(ps.fit) * n, the inverse of the averaged
        # information matrix; asy.lin.rep.ps is then O(1) per obs (psi scale).
        H_psi_inv = _safe_inv(X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all)
        asy_lin_rep_psi = score_ps @ H_psi_inv

        # R: M2.post - M2.pre with M2 = colMeans(w_ipw * (y - mu) * X) / mean(w_ipw)
        M2 = np.sum(
            (w_ct_norm * (y_ct - mu_ct_ipw))[:, None] * X_all_int[ct_slice], axis=0
        ) - np.sum(
            (w_cs_norm * (y_cs - mu_cs_ipw))[:, None] * X_all_int[cs_slice], axis=0
        )

        # R adds this to inf.control and att = treat - control, so subtract.
        inf_all = inf_all - (asy_lin_rep_psi @ M2) / n_all

    # SE from phi; equals R's sqrt(sum(psi^2)) / n when mean(psi) approx 0.
    se = float(np.sqrt(np.sum(inf_all**2)))
    return att, se, inf_all, None
|
|
3442
|
+
|
|
3443
|
+
def _doubly_robust_rc(
|
|
3444
|
+
self,
|
|
3445
|
+
y_gt,
|
|
3446
|
+
y_gs,
|
|
3447
|
+
y_ct,
|
|
3448
|
+
y_cs,
|
|
3449
|
+
X_gt,
|
|
3450
|
+
X_gs,
|
|
3451
|
+
X_ct,
|
|
3452
|
+
X_cs,
|
|
3453
|
+
sw_gt=None,
|
|
3454
|
+
sw_gs=None,
|
|
3455
|
+
sw_ct=None,
|
|
3456
|
+
sw_cs=None,
|
|
3457
|
+
context_label: str = "",
|
|
3458
|
+
epv_diagnostics_out: Optional[dict] = None,
|
|
3459
|
+
):
|
|
3460
|
+
"""
|
|
3461
|
+
Cross-sectional doubly robust estimation for ATT(g,t).
|
|
3462
|
+
|
|
3463
|
+
Matches R DRDID::drdid_rc (Sant'Anna & Zhao 2020, Eq 3.1).
|
|
3464
|
+
Locally efficient DR estimator with 4 OLS fits (control pre/post,
|
|
3465
|
+
treated pre/post) plus propensity score.
|
|
3466
|
+
|
|
3467
|
+
IF convention
|
|
3468
|
+
-------------
|
|
3469
|
+
Intermediate terms use R's unnormalized psi_i convention throughout
|
|
3470
|
+
(R: ``drdid_rc``). R computes SE as ``sd(psi) / sqrt(n)``.
|
|
3471
|
+
At the end we convert to the library's pre-scaled phi_i = psi_i / n
|
|
3472
|
+
convention where ``se = sqrt(sum(phi^2))``, used by the
|
|
3473
|
+
aggregation/bootstrap layer.
|
|
3474
|
+
|
|
3475
|
+
Returns (att, se, inf_func_concat, idx_concat).
|
|
3476
|
+
"""
|
|
3477
|
+
n_gt = len(y_gt)
|
|
3478
|
+
n_gs = len(y_gs)
|
|
3479
|
+
n_ct = len(y_ct)
|
|
3480
|
+
n_cs = len(y_cs)
|
|
3481
|
+
n_all = n_gt + n_gs + n_ct + n_cs
|
|
3482
|
+
|
|
3483
|
+
# =====================================================================
|
|
3484
|
+
# 1. Outcome regression: 4 OLS fits
|
|
3485
|
+
# =====================================================================
|
|
3486
|
+
# Control OLS: E[Y|X, D=0, T=t] and E[Y|X, D=0, T=s]
|
|
3487
|
+
beta_ct, resid_ct = _linear_regression(
|
|
3488
|
+
X_ct,
|
|
3489
|
+
y_ct,
|
|
3490
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
3491
|
+
weights=sw_ct,
|
|
3492
|
+
)
|
|
3493
|
+
beta_ct = np.where(np.isfinite(beta_ct), beta_ct, 0.0)
|
|
3494
|
+
|
|
3495
|
+
beta_cs, resid_cs = _linear_regression(
|
|
3496
|
+
X_cs,
|
|
3497
|
+
y_cs,
|
|
3498
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
3499
|
+
weights=sw_cs,
|
|
3500
|
+
)
|
|
3501
|
+
beta_cs = np.where(np.isfinite(beta_cs), beta_cs, 0.0)
|
|
3502
|
+
|
|
3503
|
+
# Treated OLS: E[Y|X, D=1, T=t] and E[Y|X, D=1, T=s]
|
|
3504
|
+
beta_gt, resid_gt = _linear_regression(
|
|
3505
|
+
X_gt,
|
|
3506
|
+
y_gt,
|
|
3507
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
3508
|
+
weights=sw_gt,
|
|
3509
|
+
)
|
|
3510
|
+
beta_gt = np.where(np.isfinite(beta_gt), beta_gt, 0.0)
|
|
3511
|
+
|
|
3512
|
+
beta_gs, resid_gs = _linear_regression(
|
|
3513
|
+
X_gs,
|
|
3514
|
+
y_gs,
|
|
3515
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
3516
|
+
weights=sw_gs,
|
|
3517
|
+
)
|
|
3518
|
+
beta_gs = np.where(np.isfinite(beta_gs), beta_gs, 0.0)
|
|
3519
|
+
|
|
3520
|
+
# Intercept-augmented design matrices
|
|
3521
|
+
X_gt_int = np.column_stack([np.ones(n_gt), X_gt])
|
|
3522
|
+
X_gs_int = np.column_stack([np.ones(n_gs), X_gs])
|
|
3523
|
+
X_ct_int = np.column_stack([np.ones(n_ct), X_ct])
|
|
3524
|
+
X_cs_int = np.column_stack([np.ones(n_cs), X_cs])
|
|
3525
|
+
|
|
3526
|
+
# Control OR predictions for all groups
|
|
3527
|
+
mu0_post_gt = X_gt_int @ beta_ct # mu_{0,1}(X) for treated-post
|
|
3528
|
+
mu0_pre_gt = X_gt_int @ beta_cs # mu_{0,0}(X) for treated-post
|
|
3529
|
+
mu0_post_gs = X_gs_int @ beta_ct # mu_{0,1}(X) for treated-pre
|
|
3530
|
+
mu0_pre_gs = X_gs_int @ beta_cs # mu_{0,0}(X) for treated-pre
|
|
3531
|
+
mu0_post_ct = X_ct_int @ beta_ct # mu_{0,1}(X) for control-post
|
|
3532
|
+
mu0_pre_ct = X_ct_int @ beta_cs # mu_{0,0}(X) for control-post
|
|
3533
|
+
mu0_post_cs = X_cs_int @ beta_ct # mu_{0,1}(X) for control-pre
|
|
3534
|
+
mu0_pre_cs = X_cs_int @ beta_cs # mu_{0,0}(X) for control-pre
|
|
3535
|
+
|
|
3536
|
+
# Treated OR predictions for all groups (for local efficiency adjustment)
|
|
3537
|
+
mu1_post_gt = X_gt_int @ beta_gt # mu_{1,1}(X) for treated-post
|
|
3538
|
+
mu1_pre_gt = X_gt_int @ beta_gs # mu_{1,0}(X) for treated-post
|
|
3539
|
+
mu1_post_gs = X_gs_int @ beta_gt # mu_{1,1}(X) for treated-pre
|
|
3540
|
+
mu1_pre_gs = X_gs_int @ beta_gs # mu_{1,0}(X) for treated-pre
|
|
3541
|
+
|
|
3542
|
+
# mu_{0,Y}(T_i, X_i): control OR evaluated at own period
|
|
3543
|
+
mu0Y_gt = mu0_post_gt # treated-post: use post control model
|
|
3544
|
+
mu0Y_gs = mu0_pre_gs # treated-pre: use pre control model
|
|
3545
|
+
mu0Y_ct = mu0_post_ct # control-post: use post control model
|
|
3546
|
+
mu0Y_cs = mu0_pre_cs # control-pre: use pre control model
|
|
3547
|
+
|
|
3548
|
+
# =====================================================================
|
|
3549
|
+
# 2. Propensity score
|
|
3550
|
+
# =====================================================================
|
|
3551
|
+
X_all = np.vstack([X_gt, X_gs, X_ct, X_cs])
|
|
3552
|
+
D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)])
|
|
3553
|
+
sw_all = None
|
|
3554
|
+
if sw_gt is not None:
|
|
3555
|
+
sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs])
|
|
3556
|
+
|
|
3557
|
+
ps_fallback_used = False
|
|
3558
|
+
diag = {}
|
|
3559
|
+
try:
|
|
3560
|
+
beta_logistic, pscore = solve_logit(
|
|
3561
|
+
X_all,
|
|
3562
|
+
D_all,
|
|
3563
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
3564
|
+
weights=sw_all,
|
|
3565
|
+
epv_threshold=self.epv_threshold,
|
|
3566
|
+
context_label=context_label,
|
|
3567
|
+
diagnostics_out=diag,
|
|
3568
|
+
)
|
|
3569
|
+
_check_propensity_diagnostics(pscore, self.pscore_trim)
|
|
3570
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
3571
|
+
if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
|
|
3572
|
+
raise
|
|
3573
|
+
ctx = f" for {context_label}" if context_label else ""
|
|
3574
|
+
warnings.warn(
|
|
3575
|
+
f"Propensity score estimation failed{ctx} (RCS DR). "
|
|
3576
|
+
f"Falling back to unconditional propensity "
|
|
3577
|
+
f"(propensity model ignores covariates; outcome "
|
|
3578
|
+
f"regression still uses them). "
|
|
3579
|
+
f"Consider estimation_method='reg' to avoid "
|
|
3580
|
+
f"propensity scores entirely.",
|
|
3581
|
+
UserWarning,
|
|
3582
|
+
stacklevel=4,
|
|
3583
|
+
)
|
|
3584
|
+
if sw_all is not None:
|
|
3585
|
+
pos = sw_all > 0
|
|
3586
|
+
p_treat = float(np.average(D_all[pos], weights=sw_all[pos]))
|
|
3587
|
+
else:
|
|
3588
|
+
p_treat = (n_gt + n_gs) / len(D_all)
|
|
3589
|
+
pscore = np.full(len(D_all), p_treat)
|
|
3590
|
+
ps_fallback_used = True
|
|
3591
|
+
if epv_diagnostics_out is not None and diag:
|
|
3592
|
+
epv_diagnostics_out.update(diag)
|
|
3593
|
+
|
|
3594
|
+
pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
|
|
3595
|
+
|
|
3596
|
+
# Split propensity scores per group
|
|
3597
|
+
ps_gt = pscore[:n_gt]
|
|
3598
|
+
ps_gs = pscore[n_gt : n_gt + n_gs]
|
|
3599
|
+
ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct]
|
|
3600
|
+
ps_cs = pscore[n_gt + n_gs + n_ct :]
|
|
3601
|
+
|
|
3602
|
+
# =====================================================================
|
|
3603
|
+
# 3. Group weights and R-convention means
|
|
3604
|
+
# =====================================================================
|
|
3605
|
+
if sw_gt is not None:
|
|
3606
|
+
w_treat_post = sw_gt
|
|
3607
|
+
w_treat_pre = sw_gs
|
|
3608
|
+
w_D_gt = sw_gt
|
|
3609
|
+
w_D_gs = sw_gs
|
|
3610
|
+
else:
|
|
3611
|
+
w_treat_post = np.ones(n_gt)
|
|
3612
|
+
w_treat_pre = np.ones(n_gs)
|
|
3613
|
+
w_D_gt = np.ones(n_gt)
|
|
3614
|
+
w_D_gs = np.ones(n_gs)
|
|
3615
|
+
|
|
3616
|
+
sum_w_treat_post = np.sum(w_treat_post)
|
|
3617
|
+
sum_w_treat_pre = np.sum(w_treat_pre)
|
|
3618
|
+
sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs)
|
|
3619
|
+
|
|
3620
|
+
# R: mean(w) = sum(w) / n -- used in psi normalizers
|
|
3621
|
+
mean_w_treat_post = sum_w_treat_post / n_all
|
|
3622
|
+
mean_w_treat_pre = sum_w_treat_pre / n_all
|
|
3623
|
+
mean_w_D = sum_w_D / n_all
|
|
3624
|
+
|
|
3625
|
+
# IPW control weights: sw * ps/(1-ps) for controls
|
|
3626
|
+
w_ipw_ct = ps_ct / (1 - ps_ct)
|
|
3627
|
+
w_ipw_cs = ps_cs / (1 - ps_cs)
|
|
3628
|
+
if sw_ct is not None:
|
|
3629
|
+
w_ipw_ct = sw_ct * w_ipw_ct
|
|
3630
|
+
w_ipw_cs = sw_cs * w_ipw_cs
|
|
3631
|
+
|
|
3632
|
+
sum_w_ipw_ct = np.sum(w_ipw_ct)
|
|
3633
|
+
sum_w_ipw_cs = np.sum(w_ipw_cs)
|
|
3634
|
+
mean_w_ipw_ct = sum_w_ipw_ct / n_all
|
|
3635
|
+
mean_w_ipw_cs = sum_w_ipw_cs / n_all
|
|
3636
|
+
|
|
3637
|
+
# =====================================================================
|
|
3638
|
+
# 4. Point estimate: tau_1 (AIPW using control ORs)
|
|
3639
|
+
# =====================================================================
|
|
3640
|
+
# Hajek-normalized means of (y - mu0Y) per group
|
|
3641
|
+
eta_treat_post = np.sum(w_treat_post * (y_gt - mu0Y_gt)) / sum_w_treat_post
|
|
3642
|
+
eta_treat_pre = np.sum(w_treat_pre * (y_gs - mu0Y_gs)) / sum_w_treat_pre
|
|
3643
|
+
|
|
3644
|
+
eta_cont_post = (
|
|
3645
|
+
np.sum(w_ipw_ct * (y_ct - mu0Y_ct)) / sum_w_ipw_ct if sum_w_ipw_ct > 0 else 0.0
|
|
3646
|
+
)
|
|
3647
|
+
eta_cont_pre = (
|
|
3648
|
+
np.sum(w_ipw_cs * (y_cs - mu0Y_cs)) / sum_w_ipw_cs if sum_w_ipw_cs > 0 else 0.0
|
|
3649
|
+
)
|
|
3650
|
+
|
|
3651
|
+
tau_1 = (eta_treat_post - eta_cont_post) - (eta_treat_pre - eta_cont_pre)
|
|
3652
|
+
|
|
3653
|
+
# =====================================================================
|
|
3654
|
+
# 5. Point estimate: local efficiency adjustment (tau_2)
|
|
3655
|
+
# =====================================================================
|
|
3656
|
+
# Differences mu_{1,t}(X) - mu_{0,t}(X) for treated obs
|
|
3657
|
+
or_diff_post_gt = mu1_post_gt - mu0_post_gt # at treated-post
|
|
3658
|
+
or_diff_post_gs = mu1_post_gs - mu0_post_gs # at treated-pre
|
|
3659
|
+
or_diff_pre_gt = mu1_pre_gt - mu0_pre_gt # at treated-post
|
|
3660
|
+
or_diff_pre_gs = mu1_pre_gs - mu0_pre_gs # at treated-pre
|
|
3661
|
+
|
|
3662
|
+
# att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) -- all treated
|
|
3663
|
+
att_d_post = (np.sum(w_D_gt * or_diff_post_gt) + np.sum(w_D_gs * or_diff_post_gs)) / sum_w_D
|
|
3664
|
+
# att_dt1_post -- treated-post only
|
|
3665
|
+
att_dt1_post = np.sum(w_treat_post * or_diff_post_gt) / sum_w_treat_post
|
|
3666
|
+
# att_d_pre -- all treated
|
|
3667
|
+
att_d_pre = (np.sum(w_D_gt * or_diff_pre_gt) + np.sum(w_D_gs * or_diff_pre_gs)) / sum_w_D
|
|
3668
|
+
# att_dt0_pre -- treated-pre only
|
|
3669
|
+
att_dt0_pre = np.sum(w_treat_pre * or_diff_pre_gs) / sum_w_treat_pre
|
|
3670
|
+
|
|
3671
|
+
tau_2 = (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre)
|
|
3672
|
+
|
|
3673
|
+
att = float(tau_1 + tau_2)
|
|
3674
|
+
|
|
3675
|
+
# =====================================================================
|
|
3676
|
+
# 6. Influence function in R's unnormalized psi convention
|
|
3677
|
+
# (R: drdid_rc.R, psi = n * phi)
|
|
3678
|
+
# =====================================================================
|
|
3679
|
+
|
|
3680
|
+
# --- tau_1: treated psi (R: eta.treat.post / mean(w.treat.post)) ---
|
|
3681
|
+
# R: w.treat.post * (y - mu0Y - eta.treat.post) / mean(w.treat.post)
|
|
3682
|
+
psi_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / mean_w_treat_post
|
|
3683
|
+
psi_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / mean_w_treat_pre
|
|
3684
|
+
|
|
3685
|
+
# --- tau_1: control psi (R: eta.cont.post / mean(w.ipw)) ---
|
|
3686
|
+
# R: w.ipw * (y - mu0Y - eta.cont) / mean(w.ipw)
|
|
3687
|
+
psi_cont_post_ct = (
|
|
3688
|
+
w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / mean_w_ipw_ct
|
|
3689
|
+
if mean_w_ipw_ct > 0
|
|
3690
|
+
else np.zeros(n_ct)
|
|
3691
|
+
)
|
|
3692
|
+
psi_cont_pre_cs = (
|
|
3693
|
+
w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / mean_w_ipw_cs
|
|
3694
|
+
if mean_w_ipw_cs > 0
|
|
3695
|
+
else np.zeros(n_cs)
|
|
3696
|
+
)
|
|
3697
|
+
|
|
3698
|
+
# tau_1 psi per group
|
|
3699
|
+
psi_gt_tau1 = psi_treat_post
|
|
3700
|
+
psi_gs_tau1 = -psi_treat_pre
|
|
3701
|
+
psi_ct_tau1 = -psi_cont_post_ct
|
|
3702
|
+
psi_cs_tau1 = psi_cont_pre_cs
|
|
3703
|
+
|
|
3704
|
+
# =====================================================================
|
|
3705
|
+
# 7. tau_2 leading terms (R: att.d.post, att.dt1.post, etc.)
|
|
3706
|
+
# =====================================================================
|
|
3707
|
+
# R: w.D * (or_diff - att.d.post) / mean(w.D)
|
|
3708
|
+
psi_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / mean_w_D
|
|
3709
|
+
psi_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / mean_w_D
|
|
3710
|
+
# R: w.treat.post * (or_diff - att.dt1.post) / mean(w.treat.post)
|
|
3711
|
+
psi_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / mean_w_treat_post
|
|
3712
|
+
# R: w.D * (or_diff_pre - att.d.pre) / mean(w.D)
|
|
3713
|
+
psi_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / mean_w_D
|
|
3714
|
+
psi_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / mean_w_D
|
|
3715
|
+
# R: w.treat.pre * (or_diff_pre - att.dt0.pre) / mean(w.treat.pre)
|
|
3716
|
+
psi_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / mean_w_treat_pre
|
|
3717
|
+
|
|
3718
|
+
# tau_2 psi per group (controls contribute zero)
|
|
3719
|
+
psi_gt_tau2 = (psi_d_post_gt - psi_dt1_post) - psi_d_pre_gt
|
|
3720
|
+
psi_gs_tau2 = psi_d_post_gs - (-psi_dt0_pre + psi_d_pre_gs)
|
|
3721
|
+
|
|
3722
|
+
# =====================================================================
|
|
3723
|
+
# 8. Combined plug-in psi (before nuisance corrections)
|
|
3724
|
+
# =====================================================================
|
|
3725
|
+
psi_gt = psi_gt_tau1 + psi_gt_tau2
|
|
3726
|
+
psi_gs = psi_gs_tau1 + psi_gs_tau2
|
|
3727
|
+
psi_ct = psi_ct_tau1
|
|
3728
|
+
psi_cs = psi_cs_tau1
|
|
3729
|
+
|
|
3730
|
+
psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs])
|
|
3731
|
+
|
|
3732
|
+
# =================================================================
|
|
3733
|
+
# Convert leading psi to library phi convention: phi = psi / n_all
|
|
3734
|
+
# =================================================================
|
|
3735
|
+
inf_all = psi_all / n_all
|
|
3736
|
+
|
|
3737
|
+
# =====================================================================
|
|
3738
|
+
# 9. PS nuisance correction — psi convention, convert to phi
|
|
3739
|
+
# =====================================================================
|
|
3740
|
+
X_all_int = np.column_stack([np.ones(n_all), X_all])
|
|
3741
|
+
if not ps_fallback_used:
|
|
3742
|
+
W_ps = pscore * (1 - pscore)
|
|
3743
|
+
if sw_all is not None:
|
|
3744
|
+
W_ps = W_ps * sw_all
|
|
3745
|
+
# R: Hessian.ps = crossprod(X * sqrt(W)) / n
|
|
3746
|
+
H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all
|
|
3747
|
+
H_psi_inv = _safe_inv(H_psi)
|
|
3748
|
+
|
|
3749
|
+
score_ps = (D_all - pscore)[:, None] * X_all_int
|
|
3750
|
+
if sw_all is not None:
|
|
3751
|
+
score_ps = score_ps * sw_all[:, None]
|
|
3752
|
+
# R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs)
|
|
3753
|
+
asy_lin_rep_psi = score_ps @ H_psi_inv
|
|
3754
|
+
|
|
3755
|
+
# R: M2 = colMeans(w_ipw * dr_resid / mean(w_ipw) * X)
|
|
3756
|
+
ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct)
|
|
3757
|
+
cs_slice = slice(n_gt + n_gs + n_ct, None)
|
|
3758
|
+
|
|
3759
|
+
dr_resid_ct = y_ct - mu0Y_ct - eta_cont_post
|
|
3760
|
+
dr_resid_cs = y_cs - mu0Y_cs - eta_cont_pre
|
|
3761
|
+
|
|
3762
|
+
M2 = np.zeros(X_all_int.shape[1])
|
|
3763
|
+
if sum_w_ipw_ct > 0:
|
|
3764
|
+
M2 -= np.sum(
|
|
3765
|
+
((w_ipw_ct * dr_resid_ct / sum_w_ipw_ct)[:, None] * X_all_int[ct_slice]),
|
|
3766
|
+
axis=0,
|
|
3767
|
+
)
|
|
3768
|
+
if sum_w_ipw_cs > 0:
|
|
3769
|
+
M2 += np.sum(
|
|
3770
|
+
((w_ipw_cs * dr_resid_cs / sum_w_ipw_cs)[:, None] * X_all_int[cs_slice]),
|
|
3771
|
+
axis=0,
|
|
3772
|
+
)
|
|
3773
|
+
|
|
3774
|
+
# psi-scale correction, convert to phi
|
|
3775
|
+
inf_all = inf_all + (asy_lin_rep_psi @ M2) / n_all
|
|
3776
|
+
|
|
3777
|
+
# =====================================================================
|
|
3778
|
+
# 10. Control OR nuisance corrections (phi-scale)
|
|
3779
|
+
# =====================================================================
|
|
3780
|
+
W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct)
|
|
3781
|
+
W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs)
|
|
3782
|
+
bread_ct = _safe_inv(X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int))
|
|
3783
|
+
bread_cs = _safe_inv(X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int))
|
|
3784
|
+
|
|
3785
|
+
# R: asy.lin.rep.ols (per-obs OLS score * bread)
|
|
3786
|
+
asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct
|
|
3787
|
+
asy_lin_rep_cs = (W_cs_vals * resid_cs)[:, None] * X_cs_int @ bread_cs
|
|
3788
|
+
|
|
3789
|
+
# M1 for control-post model (beta_ct): gradient from tau_1 + tau_2
|
|
3790
|
+
# tau_1: -w_treat_post*X/sum_w_treat_post (eta_treat_post via mu0Y_gt)
|
|
3791
|
+
# +w_ipw_ct*X/sum_w_ipw_ct (eta_cont_post via mu0Y_ct)
|
|
3792
|
+
# tau_2: -w_D*X/sum_w_D (att_d_post via mu0_post at all treated)
|
|
3793
|
+
# +w_treat_post*X/sum_w_treat_post (att_dt1_post via mu0_post)
|
|
3794
|
+
M1_ct = np.zeros(X_all_int.shape[1])
|
|
3795
|
+
M1_ct -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
|
|
3796
|
+
if sum_w_ipw_ct > 0:
|
|
3797
|
+
M1_ct += np.sum(w_ipw_ct[:, None] * X_ct_int, axis=0) / sum_w_ipw_ct
|
|
3798
|
+
M1_ct -= (
|
|
3799
|
+
np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
|
|
3800
|
+
) / sum_w_D
|
|
3801
|
+
M1_ct += np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
|
|
3802
|
+
|
|
3803
|
+
# M1 for control-pre model (beta_cs)
|
|
3804
|
+
M1_cs = np.zeros(X_all_int.shape[1])
|
|
3805
|
+
M1_cs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
|
|
3806
|
+
if sum_w_ipw_cs > 0:
|
|
3807
|
+
M1_cs -= np.sum(w_ipw_cs[:, None] * X_cs_int, axis=0) / sum_w_ipw_cs
|
|
3808
|
+
M1_cs += (
|
|
3809
|
+
np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
|
|
3810
|
+
) / sum_w_D
|
|
3811
|
+
M1_cs -= np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
|
|
3812
|
+
|
|
3813
|
+
inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_ct @ M1_ct
|
|
3814
|
+
inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_cs @ M1_cs
|
|
3815
|
+
|
|
3816
|
+
# =====================================================================
|
|
3817
|
+
# 11. Treated OR nuisance corrections (phi-scale)
|
|
3818
|
+
# =====================================================================
|
|
3819
|
+
W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
|
|
3820
|
+
W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs)
|
|
3821
|
+
bread_gt = _safe_inv(X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int))
|
|
3822
|
+
bread_gs = _safe_inv(X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int))
|
|
3823
|
+
|
|
3824
|
+
asy_lin_rep_gt = (W_gt_vals * resid_gt)[:, None] * X_gt_int @ bread_gt
|
|
3825
|
+
asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs
|
|
3826
|
+
|
|
3827
|
+
# M1 for treated-post model (beta_gt): mu_{1,1}(X)
|
|
3828
|
+
# From att_d_post: +w_D*X/sum_w_D (all treated)
|
|
3829
|
+
# From att_dt1_post: -w_treat_post*X/sum_w_treat_post (treated-post)
|
|
3830
|
+
M1_gt = np.zeros(X_all_int.shape[1])
|
|
3831
|
+
M1_gt += (
|
|
3832
|
+
np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
|
|
3833
|
+
) / sum_w_D
|
|
3834
|
+
M1_gt -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
|
|
3835
|
+
|
|
3836
|
+
# M1 for treated-pre model (beta_gs): mu_{1,0}(X)
|
|
3837
|
+
# From att_d_pre: -w_D*X/sum_w_D
|
|
3838
|
+
# From att_dt0_pre: +w_treat_pre*X/sum_w_treat_pre
|
|
3839
|
+
M1_gs = np.zeros(X_all_int.shape[1])
|
|
3840
|
+
M1_gs -= (
|
|
3841
|
+
np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
|
|
3842
|
+
) / sum_w_D
|
|
3843
|
+
M1_gs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
|
|
3844
|
+
|
|
3845
|
+
inf_all[:n_gt] += asy_lin_rep_gt @ M1_gt
|
|
3846
|
+
inf_all[n_gt : n_gt + n_gs] += asy_lin_rep_gs @ M1_gs
|
|
3847
|
+
|
|
3848
|
+
# =================================================================
|
|
3849
|
+
# SE from phi: se = sqrt(sum(phi^2))
|
|
3850
|
+
# Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0.
|
|
3851
|
+
# =================================================================
|
|
3852
|
+
se = float(np.sqrt(np.sum(inf_all**2)))
|
|
3853
|
+
|
|
3854
|
+
idx_all = None
|
|
3855
|
+
return att, se, inf_all, idx_all
|
|
3856
|
+
|
|
3857
|
+
def get_params(self) -> Dict[str, Any]:
|
|
3858
|
+
"""Get estimator parameters (sklearn-compatible)."""
|
|
3859
|
+
return {
|
|
3860
|
+
"control_group": self.control_group,
|
|
3861
|
+
"anticipation": self.anticipation,
|
|
3862
|
+
"estimation_method": self.estimation_method,
|
|
3863
|
+
"alpha": self.alpha,
|
|
3864
|
+
"cluster": self.cluster,
|
|
3865
|
+
"n_bootstrap": self.n_bootstrap,
|
|
3866
|
+
"bootstrap_weights": self.bootstrap_weights,
|
|
3867
|
+
"seed": self.seed,
|
|
3868
|
+
"rank_deficient_action": self.rank_deficient_action,
|
|
3869
|
+
"base_period": self.base_period,
|
|
3870
|
+
"cband": self.cband,
|
|
3871
|
+
"pscore_trim": self.pscore_trim,
|
|
3872
|
+
"panel": self.panel,
|
|
3873
|
+
"epv_threshold": self.epv_threshold,
|
|
3874
|
+
"pscore_fallback": self.pscore_fallback,
|
|
3875
|
+
}
|
|
3876
|
+
|
|
3877
|
+
def set_params(self, **params) -> "CallawaySantAnna":
|
|
3878
|
+
"""Set estimator parameters (sklearn-compatible)."""
|
|
3879
|
+
for key, value in params.items():
|
|
3880
|
+
if hasattr(self, key):
|
|
3881
|
+
setattr(self, key, value)
|
|
3882
|
+
else:
|
|
3883
|
+
raise ValueError(f"Unknown parameter: {key}")
|
|
3884
|
+
return self
|
|
3885
|
+
|
|
3886
|
+
def summary(self) -> str:
|
|
3887
|
+
"""Get summary of estimation results."""
|
|
3888
|
+
if not self.is_fitted_:
|
|
3889
|
+
raise RuntimeError("Model must be fitted before calling summary()")
|
|
3890
|
+
assert self.results_ is not None
|
|
3891
|
+
return self.results_.summary()
|
|
3892
|
+
|
|
3893
|
+
def print_summary(self) -> None:
|
|
3894
|
+
"""Print summary to stdout."""
|
|
3895
|
+
print(self.summary())
|