diff-diff 2.3.2__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +254 -0
- diff_diff/_backend.py +112 -0
- diff_diff/_rust_backend.cp313-win_amd64.pyd +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1161 -0
- diff_diff/honest_did.py +1511 -0
- diff_diff/imputation.py +2480 -0
- diff_diff/linalg.py +1537 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1241 -0
- diff_diff/prep_dgp.py +777 -0
- diff_diff/pretrends.py +1104 -0
- diff_diff/results.py +794 -0
- diff_diff/staggered.py +1120 -0
- diff_diff/staggered_aggregation.py +492 -0
- diff_diff/staggered_bootstrap.py +753 -0
- diff_diff/staggered_results.py +296 -0
- diff_diff/sun_abraham.py +1227 -0
- diff_diff/synthetic_did.py +858 -0
- diff_diff/triple_diff.py +1322 -0
- diff_diff/trop.py +2904 -0
- diff_diff/twfe.py +428 -0
- diff_diff/utils.py +1845 -0
- diff_diff/visualization.py +1676 -0
- diff_diff-2.3.2.dist-info/METADATA +2646 -0
- diff_diff-2.3.2.dist-info/RECORD +30 -0
- diff_diff-2.3.2.dist-info/WHEEL +4 -0
- diff_diff-2.3.2.dist-info/sboms/diff_diff_rust.cyclonedx.json +5952 -0
diff_diff/imputation.py
ADDED
|
@@ -0,0 +1,2480 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Borusyak-Jaravel-Spiess (2024) Imputation DiD Estimator.
|
|
3
|
+
|
|
4
|
+
Implements the efficient imputation estimator for staggered
|
|
5
|
+
Difference-in-Differences from Borusyak, Jaravel & Spiess (2024),
|
|
6
|
+
"Revisiting Event-Study Designs: Robust and Efficient Estimation",
|
|
7
|
+
Review of Economic Studies.
|
|
8
|
+
|
|
9
|
+
The estimator:
|
|
10
|
+
1. Runs OLS on untreated observations to estimate unit + time fixed effects
|
|
11
|
+
2. Imputes counterfactual Y(0) for treated observations
|
|
12
|
+
3. Aggregates imputed treatment effects with researcher-chosen weights
|
|
13
|
+
|
|
14
|
+
Inference uses the conservative clustered variance estimator (Theorem 3).
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import warnings
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pandas as pd
|
|
23
|
+
from scipy import sparse, stats
|
|
24
|
+
from scipy.sparse.linalg import spsolve
|
|
25
|
+
|
|
26
|
+
from diff_diff.linalg import solve_ols
|
|
27
|
+
from diff_diff.results import _get_significance_stars
|
|
28
|
+
from diff_diff.utils import compute_confidence_interval, compute_p_value
|
|
29
|
+
|
|
30
|
+
# =============================================================================
|
|
31
|
+
# Results Dataclasses
|
|
32
|
+
# =============================================================================
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ImputationBootstrapResults:
    """
    Results from ImputationDiD bootstrap inference.

    Bootstrap is a library extension beyond Borusyak et al. (2024), which
    proposes only analytical inference via the conservative variance estimator.
    Provided for consistency with CallawaySantAnna and SunAbraham.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Type of bootstrap weights (currently "rademacher" only).
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for overall ATT.
    overall_att_ci : tuple
        Bootstrap confidence interval for overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for overall ATT.
    event_study_ses : dict, optional
        Bootstrap SEs for event study effects.
    event_study_cis : dict, optional
        Bootstrap CIs for event study effects.
    event_study_p_values : dict, optional
        Bootstrap p-values for event study effects.
    group_ses : dict, optional
        Bootstrap SEs for group effects.
    group_cis : dict, optional
        Bootstrap CIs for group effects.
    group_p_values : dict, optional
        Bootstrap p-values for group effects.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap distribution of overall ATT.
    """

    # Required fields: overall-ATT inference is always produced.
    n_bootstrap: int
    weight_type: str
    alpha: float
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    # Optional fields: presumably populated only when the corresponding
    # aggregation (event study / group) was requested — confirm against fit().
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None
    group_ses: Optional[Dict[Any, float]] = None
    group_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_p_values: Optional[Dict[Any, float]] = None
    # Potentially large array; excluded from repr via repr=False.
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class ImputationDiDResults:
    """
    Results from Borusyak-Jaravel-Spiess (2024) imputation DiD estimation.

    Attributes
    ----------
    treatment_effects : pd.DataFrame
        Unit-level treatment effects with columns: unit, time, tau_hat, weight.
    overall_att : float
        Overall average treatment effect on the treated.
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    event_study_effects : dict, optional
        Dictionary mapping relative time h to effect dict with keys:
        'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'.
    group_effects : dict, optional
        Dictionary mapping cohort g to effect dict.
    groups : list
        List of treatment cohorts.
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_obs : int
        Number of treated observations (|Omega_1|).
    n_untreated_obs : int
        Number of untreated observations (|Omega_0|).
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of units contributing to Omega_0.
    alpha : float
        Significance level used.
    pretrend_results : dict, optional
        Populated by pretrend_test().
    bootstrap_results : ImputationBootstrapResults, optional
        Bootstrap inference results.
    """

    treatment_effects: pd.DataFrame
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    event_study_effects: Optional[Dict[int, Dict[str, Any]]]
    group_effects: Optional[Dict[Any, Dict[str, Any]]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_obs: int
    n_untreated_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False)
    bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False)
    # Internal: stores data needed for pretrend_test()
    _estimator_ref: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        return (
            f"ImputationDiDResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_treated_obs={self.n_treated_obs})"
        )

    @staticmethod
    def _format_effect_row(label: Any, eff: Dict[str, Any]) -> str:
        """
        Format one effect-table row for summary().

        Shared by the event-study and group (cohort) tables, which previously
        duplicated this formatting inline. Output is byte-identical to the
        original inline code: a NaN row when the effect is NaN, otherwise
        estimate / SE / t-stat / p-value / significance stars.
        """
        if np.isnan(eff["effect"]):
            return f"{label:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}"
        sig = _get_significance_stars(eff["p_value"])
        # t-stat / p-value may be non-finite (e.g. zero SE); print NaN markers.
        t_str = f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}"
        p_str = (
            f"{eff['p_value']:>10.4f}" if np.isfinite(eff["p_value"]) else f"{'NaN':>10}"
        )
        return (
            f"{label:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
            f"{t_str} {p_str} {sig:>6}"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation.

        Returns
        -------
        str
            Formatted summary.
        """
        # Explicit None check (the previous `alpha or self.alpha` silently
        # discarded any falsy caller-supplied alpha).
        alpha = self.alpha if alpha is None else alpha
        conf_level = int((1 - alpha) * 100)

        lines = [
            "=" * 85,
            "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated observations:':<30} {self.n_treated_obs:>10}",
            f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            "",
        ]

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
        )

        t_str = (
            f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}"
        )
        p_str = (
            f"{self.overall_p_value:>10.4f}"
            if np.isfinite(self.overall_p_value)
            else f"{'NaN':>10}"
        )
        sig = _get_significance_stars(self.overall_p_value)

        lines.extend(
            [
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{t_str} {p_str} {sig:>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
                "",
            ]
        )

        # Event study effects
        if self.event_study_effects:
            lines.extend(
                [
                    "-" * 85,
                    "Event Study (Dynamic) Effects".center(85),
                    "-" * 85,
                    f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                    f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                    "-" * 85,
                ]
            )

            for h in sorted(self.event_study_effects.keys()):
                eff = self.event_study_effects[h]
                if eff.get("n_obs", 1) == 0:
                    # Reference period marker (n_obs == 0 flags the omitted period)
                    lines.append(
                        f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}"
                    )
                else:
                    lines.append(self._format_effect_row(h, eff))

            lines.extend(["-" * 85, ""])

        # Group effects
        if self.group_effects:
            lines.extend(
                [
                    "-" * 85,
                    "Group (Cohort) Effects".center(85),
                    "-" * 85,
                    f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} "
                    f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                    "-" * 85,
                ]
            )

            for g in sorted(self.group_effects.keys()):
                lines.append(self._format_effect_row(g, self.group_effects[g]))

            lines.extend(["-" * 85, ""])

        # Pre-trend test
        if self.pretrend_results is not None:
            pt = self.pretrend_results
            lines.extend(
                [
                    "-" * 85,
                    "Pre-Trend Test (Equation 9)".center(85),
                    "-" * 85,
                    f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}",
                    f"{'P-value:':<30} {pt['p_value']:>10.4f}",
                    f"{'Degrees of freedom:':<30} {pt['df']:>10}",
                    f"{'Number of leads:':<30} {pt['n_leads']:>10}",
                    "-" * 85,
                    "",
                ]
            )

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "observation") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="observation"
            Level of aggregation:
            - "observation": Unit-level treatment effects
            - "event_study": Event study effects by relative time
            - "group": Group (cohort) effects

        Returns
        -------
        pd.DataFrame
            Results as DataFrame.

        Raises
        ------
        ValueError
            If `level` is unknown, or the requested aggregation was not computed.
        """
        if level == "observation":
            return self.treatment_effects.copy()

        elif level == "event_study":
            if self.event_study_effects is None:
                raise ValueError(
                    "Event study effects not computed. "
                    "Use aggregate='event_study' or aggregate='all'."
                )
            rows = []
            for h, data in sorted(self.event_study_effects.items()):
                rows.append(
                    {
                        "relative_period": h,
                        "effect": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                        "n_obs": data.get("n_obs", np.nan),
                    }
                )
            return pd.DataFrame(rows)

        elif level == "group":
            if self.group_effects is None:
                raise ValueError(
                    "Group effects not computed. " "Use aggregate='group' or aggregate='all'."
                )
            rows = []
            for g, data in sorted(self.group_effects.items()):
                rows.append(
                    {
                        "group": g,
                        "effect": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                        "n_obs": data.get("n_obs", np.nan),
                    }
                )
            return pd.DataFrame(rows)

        else:
            raise ValueError(
                f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'."
            )

    def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
        """
        Run a pre-trend test (Equation 9 of Borusyak et al. 2024).

        Adds pre-treatment lead indicators to the Step 1 OLS and tests
        their joint significance via a cluster-robust Wald F-test.

        Parameters
        ----------
        n_leads : int, optional
            Number of pre-treatment leads to include. If None, uses all
            available pre-treatment periods minus one (for the reference period).

        Returns
        -------
        dict
            Dictionary with keys: 'f_stat', 'p_value', 'df', 'n_leads',
            'lead_coefficients'.

        Raises
        ------
        RuntimeError
            If the internal estimator reference is unavailable (e.g. after
            unpickling a bare results object).
        """
        if self._estimator_ref is None:
            raise RuntimeError(
                "Pre-trend test requires internal estimator reference. "
                "Re-fit the model to use this method."
            )
        result = self._estimator_ref._pretrend_test(n_leads=n_leads)
        # Cache so the next summary() call renders the pre-trend section.
        self.pretrend_results = result
        return result

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# =============================================================================
|
|
443
|
+
# Main Estimator
|
|
444
|
+
# =============================================================================
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class ImputationDiD:
|
|
448
|
+
"""
|
|
449
|
+
Borusyak-Jaravel-Spiess (2024) imputation DiD estimator.
|
|
450
|
+
|
|
451
|
+
This is the efficient estimator for staggered Difference-in-Differences
|
|
452
|
+
under parallel trends. It produces shorter confidence intervals than
|
|
453
|
+
Callaway-Sant'Anna (~50% shorter) and Sun-Abraham (2-3.5x shorter)
|
|
454
|
+
under homogeneous treatment effects.
|
|
455
|
+
|
|
456
|
+
The estimation procedure:
|
|
457
|
+
1. Run OLS on untreated observations to estimate unit + time fixed effects
|
|
458
|
+
2. Impute counterfactual Y(0) for treated observations
|
|
459
|
+
3. Aggregate imputed treatment effects with researcher-chosen weights
|
|
460
|
+
|
|
461
|
+
Inference uses the conservative clustered variance estimator from Theorem 3
|
|
462
|
+
of the paper.
|
|
463
|
+
|
|
464
|
+
Parameters
|
|
465
|
+
----------
|
|
466
|
+
anticipation : int, default=0
|
|
467
|
+
Number of periods before treatment where effects may occur.
|
|
468
|
+
alpha : float, default=0.05
|
|
469
|
+
Significance level for confidence intervals.
|
|
470
|
+
cluster : str, optional
|
|
471
|
+
Column name for cluster-robust standard errors.
|
|
472
|
+
If None, clusters at the unit level by default.
|
|
473
|
+
n_bootstrap : int, default=0
|
|
474
|
+
Number of bootstrap iterations. If 0, uses analytical inference
|
|
475
|
+
(conservative variance from Theorem 3).
|
|
476
|
+
seed : int, optional
|
|
477
|
+
Random seed for reproducibility.
|
|
478
|
+
rank_deficient_action : str, default="warn"
|
|
479
|
+
Action when design matrix is rank-deficient:
|
|
480
|
+
- "warn": Issue warning and drop linearly dependent columns
|
|
481
|
+
- "error": Raise ValueError
|
|
482
|
+
- "silent": Drop columns silently
|
|
483
|
+
horizon_max : int, optional
|
|
484
|
+
Maximum event-study horizon. If set, event study effects are only
|
|
485
|
+
computed for |h| <= horizon_max.
|
|
486
|
+
aux_partition : str, default="cohort_horizon"
|
|
487
|
+
Controls the auxiliary model partition for Theorem 3 variance:
|
|
488
|
+
- "cohort_horizon": Groups by cohort x relative time (tightest SEs)
|
|
489
|
+
- "cohort": Groups by cohort only (more conservative)
|
|
490
|
+
- "horizon": Groups by relative time only (more conservative)
|
|
491
|
+
|
|
492
|
+
Attributes
|
|
493
|
+
----------
|
|
494
|
+
results_ : ImputationDiDResults
|
|
495
|
+
Estimation results after calling fit().
|
|
496
|
+
is_fitted_ : bool
|
|
497
|
+
Whether the model has been fitted.
|
|
498
|
+
|
|
499
|
+
Examples
|
|
500
|
+
--------
|
|
501
|
+
Basic usage:
|
|
502
|
+
|
|
503
|
+
>>> from diff_diff import ImputationDiD, generate_staggered_data
|
|
504
|
+
>>> data = generate_staggered_data(n_units=200, seed=42)
|
|
505
|
+
>>> est = ImputationDiD()
|
|
506
|
+
>>> results = est.fit(data, outcome='outcome', unit='unit',
|
|
507
|
+
... time='time', first_treat='first_treat')
|
|
508
|
+
>>> results.print_summary()
|
|
509
|
+
|
|
510
|
+
With event study:
|
|
511
|
+
|
|
512
|
+
>>> est = ImputationDiD()
|
|
513
|
+
>>> results = est.fit(data, outcome='outcome', unit='unit',
|
|
514
|
+
... time='time', first_treat='first_treat',
|
|
515
|
+
... aggregate='event_study')
|
|
516
|
+
>>> from diff_diff import plot_event_study
|
|
517
|
+
>>> plot_event_study(results)
|
|
518
|
+
|
|
519
|
+
Notes
|
|
520
|
+
-----
|
|
521
|
+
The imputation estimator uses ALL untreated observations (never-treated +
|
|
522
|
+
not-yet-treated periods of eventually-treated units) to estimate the
|
|
523
|
+
counterfactual model. There is no ``control_group`` parameter because this
|
|
524
|
+
is fundamental to the method's efficiency.
|
|
525
|
+
|
|
526
|
+
References
|
|
527
|
+
----------
|
|
528
|
+
Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event-Study
|
|
529
|
+
Designs: Robust and Efficient Estimation. Review of Economic Studies,
|
|
530
|
+
91(6), 3253-3285.
|
|
531
|
+
"""
|
|
532
|
+
|
|
533
|
+
def __init__(
|
|
534
|
+
self,
|
|
535
|
+
anticipation: int = 0,
|
|
536
|
+
alpha: float = 0.05,
|
|
537
|
+
cluster: Optional[str] = None,
|
|
538
|
+
n_bootstrap: int = 0,
|
|
539
|
+
seed: Optional[int] = None,
|
|
540
|
+
rank_deficient_action: str = "warn",
|
|
541
|
+
horizon_max: Optional[int] = None,
|
|
542
|
+
aux_partition: str = "cohort_horizon",
|
|
543
|
+
):
|
|
544
|
+
if rank_deficient_action not in ("warn", "error", "silent"):
|
|
545
|
+
raise ValueError(
|
|
546
|
+
f"rank_deficient_action must be 'warn', 'error', or 'silent', "
|
|
547
|
+
f"got '{rank_deficient_action}'"
|
|
548
|
+
)
|
|
549
|
+
if aux_partition not in ("cohort_horizon", "cohort", "horizon"):
|
|
550
|
+
raise ValueError(
|
|
551
|
+
f"aux_partition must be 'cohort_horizon', 'cohort', or 'horizon', "
|
|
552
|
+
f"got '{aux_partition}'"
|
|
553
|
+
)
|
|
554
|
+
|
|
555
|
+
self.anticipation = anticipation
|
|
556
|
+
self.alpha = alpha
|
|
557
|
+
self.cluster = cluster
|
|
558
|
+
self.n_bootstrap = n_bootstrap
|
|
559
|
+
self.seed = seed
|
|
560
|
+
self.rank_deficient_action = rank_deficient_action
|
|
561
|
+
self.horizon_max = horizon_max
|
|
562
|
+
self.aux_partition = aux_partition
|
|
563
|
+
|
|
564
|
+
self.is_fitted_ = False
|
|
565
|
+
self.results_: Optional[ImputationDiDResults] = None
|
|
566
|
+
|
|
567
|
+
# Internal state preserved for pretrend_test()
|
|
568
|
+
self._fit_data: Optional[Dict[str, Any]] = None
|
|
569
|
+
|
|
570
|
+
    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]] = None,
        aggregate: Optional[str] = None,
        balance_e: Optional[int] = None,
    ) -> ImputationDiDResults:
        """
        Fit the imputation DiD estimator.

        Runs the three-step Borusyak-Jaravel-Spiess procedure: (1) estimate
        unit + time fixed effects (and covariate coefficients) on untreated
        observations, (2) impute counterfactual Y(0) for treated observations,
        (3) aggregate imputed effects into an overall ATT and, optionally,
        event-study / group summaries.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column.
        first_treat : str
            Name of column indicating when unit was first treated.
            Use 0 (or np.inf) for never-treated units.
        covariates : list of str, optional
            List of covariate column names.
        aggregate : str, optional
            Aggregation mode: None/"simple" (overall ATT only),
            "event_study", "group", or "all".
        balance_e : int, optional
            When computing event study, restrict to cohorts observed at all
            relative times in [-balance_e, max_h].

        Returns
        -------
        ImputationDiDResults
            Object containing all estimation results.

        Raises
        ------
        ValueError
            If required columns are missing or data validation fails.
        """
        # Validate inputs
        required_cols = [outcome, unit, time, first_treat]
        if covariates:
            required_cols.extend(covariates)

        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Create working copy (caller's DataFrame is never mutated)
        df = data.copy()

        # Ensure numeric types
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        # Validate absorbing treatment: first_treat must be constant within each unit
        ft_nunique = df.groupby(unit)[first_treat].nunique()
        non_constant = ft_nunique[ft_nunique > 1]
        if len(non_constant) > 0:
            example_unit = non_constant.index[0]
            example_vals = sorted(df.loc[df[unit] == example_unit, first_treat].unique())
            warnings.warn(
                f"{len(non_constant)} unit(s) have non-constant '{first_treat}' "
                f"values (e.g., unit '{example_unit}' has values {example_vals}). "
                f"ImputationDiD assumes treatment is an absorbing state "
                f"(once treated, always treated) with a single treatment onset "
                f"time per unit. Non-constant first_treat violates this assumption "
                f"and may produce unreliable estimates.",
                UserWarning,
                stacklevel=2,
            )

        # Coerce to per-unit value so downstream code
        # (_never_treated, _treated, _rel_time) uses a single
        # consistent first_treat per unit.
        df[first_treat] = df.groupby(unit)[first_treat].transform("first")

        # Identify treatment status: 0 or inf encode "never treated"
        df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)

        # Check for always-treated units (treated in all observed periods)
        min_time = df[time].min()
        always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time)
        n_always_treated = df.loc[always_treated_mask, unit].nunique()
        if n_always_treated > 0:
            warnings.warn(
                f"{n_always_treated} unit(s) are treated in all observed periods "
                f"(first_treat <= {min_time}). These units have no untreated "
                "observations and cannot contribute to the counterfactual model. "
                "Their treatment effects will be imputed but may be unreliable.",
                UserWarning,
                stacklevel=2,
            )

        # Create treatment indicator D_it
        # D_it = 1 if t >= first_treat and first_treat > 0
        # With anticipation: D_it = 1 if t >= first_treat - anticipation
        effective_treat = df[first_treat] - self.anticipation
        df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat)

        # Identify Omega_0 (untreated) and Omega_1 (treated)
        omega_0_mask = ~df["_treated"]
        omega_1_mask = df["_treated"]

        n_omega_0 = int(omega_0_mask.sum())
        n_omega_1 = int(omega_1_mask.sum())

        if n_omega_0 == 0:
            raise ValueError(
                "No untreated observations found. Cannot estimate counterfactual model."
            )
        if n_omega_1 == 0:
            raise ValueError("No treated observations found. Nothing to estimate.")

        # Identify groups and time periods
        time_periods = sorted(df[time].unique())
        treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and g != np.inf])

        if len(treatment_groups) == 0:
            raise ValueError("No treated units found. Check 'first_treat' column.")

        # Unit info (one row per unit: cohort + never-treated flag)
        unit_info = (
            df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index()
        )
        n_treated_units = int((~unit_info["_never_treated"]).sum())
        # Control units = units with at least one untreated observation
        units_in_omega_0 = df.loc[omega_0_mask, unit].unique()
        n_control_units = len(units_in_omega_0)

        # Cluster variable (defaults to the unit identifier)
        cluster_var = self.cluster if self.cluster is not None else unit
        if self.cluster is not None and self.cluster not in df.columns:
            raise ValueError(
                f"Cluster column '{self.cluster}' not found in data. "
                f"Available columns: {list(df.columns)}"
            )

        # Compute relative time (NaN for never-treated units)
        df["_rel_time"] = np.where(
            ~df["_never_treated"],
            df[time] - df[first_treat],
            np.nan,
        )

        # ---- Step 1: OLS on untreated observations ----
        unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model(
            df, outcome, unit, time, covariates, omega_0_mask
        )

        # ---- Rank condition checks ----
        # Check: every treated unit should have >= 1 untreated period (for unit FE)
        treated_unit_ids = df.loc[omega_1_mask, unit].unique()
        units_with_fe = set(unit_fe.keys())
        units_missing_fe = set(treated_unit_ids) - units_with_fe

        # Check: every post-treatment period should have >= 1 untreated unit (for time FE)
        post_period_ids = df.loc[omega_1_mask, time].unique()
        periods_with_fe = set(time_fe.keys())
        periods_missing_fe = set(post_period_ids) - periods_with_fe

        if units_missing_fe or periods_missing_fe:
            parts = []
            if units_missing_fe:
                sorted_missing = sorted(units_missing_fe)
                parts.append(
                    f"{len(units_missing_fe)} treated unit(s) have no untreated "
                    f"periods (units: {sorted_missing[:5]}"
                    f"{'...' if len(units_missing_fe) > 5 else ''})"
                )
            if periods_missing_fe:
                sorted_missing = sorted(periods_missing_fe)
                parts.append(
                    f"{len(periods_missing_fe)} post-treatment period(s) have no "
                    f"untreated units (periods: {sorted_missing[:5]}"
                    f"{'...' if len(periods_missing_fe) > 5 else ''})"
                )
            msg = (
                "Rank condition violated: "
                + "; ".join(parts)
                + ". Affected treatment effects will be NaN."
            )
            if self.rank_deficient_action == "error":
                raise ValueError(msg)
            elif self.rank_deficient_action == "warn":
                warnings.warn(msg, UserWarning, stacklevel=2)
            # "silent": continue without warning

        # ---- Step 2: Impute treatment effects ----
        tau_hat, y_hat_0 = self._impute_treatment_effects(
            df,
            outcome,
            unit,
            time,
            covariates,
            omega_1_mask,
            unit_fe,
            time_fe,
            grand_mean,
            delta_hat,
        )

        # Store tau_hat in dataframe (NaN outside Omega_1)
        df["_tau_hat"] = np.nan
        df.loc[omega_1_mask, "_tau_hat"] = tau_hat

        # ---- Step 3: Aggregate ----
        # Always compute overall ATT (simple aggregation)
        valid_tau = tau_hat[np.isfinite(tau_hat)]

        if len(valid_tau) == 0:
            overall_att = np.nan
        else:
            overall_att = float(np.mean(valid_tau))

        # ---- Conservative variance (Theorem 3) ----
        # Build weights matching the ATT: uniform over finite tau_hat, zero for NaN
        overall_weights = np.zeros(n_omega_1)
        finite_mask = np.isfinite(tau_hat)
        n_valid = int(finite_mask.sum())
        if n_valid > 0:
            overall_weights[finite_mask] = 1.0 / n_valid

        if n_valid == 0:
            overall_se = np.nan
        else:
            overall_se = self._compute_conservative_variance(
                df=df,
                outcome=outcome,
                unit=unit,
                time=time,
                first_treat=first_treat,
                covariates=covariates,
                omega_0_mask=omega_0_mask,
                omega_1_mask=omega_1_mask,
                unit_fe=unit_fe,
                time_fe=time_fe,
                grand_mean=grand_mean,
                delta_hat=delta_hat,
                weights=overall_weights,
                cluster_var=cluster_var,
                kept_cov_mask=kept_cov_mask,
            )

        overall_t = (
            overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
        )
        overall_p = compute_p_value(overall_t)
        overall_ci = (
            compute_confidence_interval(overall_att, overall_se, self.alpha)
            if np.isfinite(overall_se) and overall_se > 0
            else (np.nan, np.nan)
        )

        # Event study and group aggregation (only when requested)
        event_study_effects = None
        group_effects = None

        if aggregate in ("event_study", "all"):
            event_study_effects = self._aggregate_event_study(
                df=df,
                outcome=outcome,
                unit=unit,
                time=time,
                first_treat=first_treat,
                covariates=covariates,
                omega_0_mask=omega_0_mask,
                omega_1_mask=omega_1_mask,
                unit_fe=unit_fe,
                time_fe=time_fe,
                grand_mean=grand_mean,
                delta_hat=delta_hat,
                cluster_var=cluster_var,
                treatment_groups=treatment_groups,
                balance_e=balance_e,
                kept_cov_mask=kept_cov_mask,
            )

        if aggregate in ("group", "all"):
            group_effects = self._aggregate_group(
                df=df,
                outcome=outcome,
                unit=unit,
                time=time,
                first_treat=first_treat,
                covariates=covariates,
                omega_0_mask=omega_0_mask,
                omega_1_mask=omega_1_mask,
                unit_fe=unit_fe,
                time_fe=time_fe,
                grand_mean=grand_mean,
                delta_hat=delta_hat,
                cluster_var=cluster_var,
                treatment_groups=treatment_groups,
                kept_cov_mask=kept_cov_mask,
            )

        # Build treatment effects dataframe (one row per treated observation)
        treated_df = df.loc[omega_1_mask, [unit, time, "_tau_hat", "_rel_time"]].copy()
        treated_df = treated_df.rename(columns={"_tau_hat": "tau_hat", "_rel_time": "rel_time"})
        # Weights consistent with actual ATT: zero for NaN tau_hat, 1/n_valid for finite
        tau_finite = treated_df["tau_hat"].notna()
        n_valid_te = int(tau_finite.sum())
        if n_valid_te > 0:
            treated_df["weight"] = np.where(tau_finite, 1.0 / n_valid_te, 0.0)
        else:
            treated_df["weight"] = 0.0

        # Store fit data for pretrend_test
        self._fit_data = {
            "df": df,
            "outcome": outcome,
            "unit": unit,
            "time": time,
            "first_treat": first_treat,
            "covariates": covariates,
            "omega_0_mask": omega_0_mask,
            "omega_1_mask": omega_1_mask,
            "cluster_var": cluster_var,
            "unit_fe": unit_fe,
            "time_fe": time_fe,
            "grand_mean": grand_mean,
            "delta_hat": delta_hat,
            "kept_cov_mask": kept_cov_mask,
        }

        # Pre-compute cluster psi sums for bootstrap (best-effort: failure
        # degrades to analytic inference rather than aborting the fit)
        psi_data = None
        if self.n_bootstrap > 0 and n_valid > 0:
            try:
                psi_data = self._precompute_bootstrap_psi(
                    df=df,
                    outcome=outcome,
                    unit=unit,
                    time=time,
                    first_treat=first_treat,
                    covariates=covariates,
                    omega_0_mask=omega_0_mask,
                    omega_1_mask=omega_1_mask,
                    unit_fe=unit_fe,
                    time_fe=time_fe,
                    grand_mean=grand_mean,
                    delta_hat=delta_hat,
                    cluster_var=cluster_var,
                    kept_cov_mask=kept_cov_mask,
                    overall_weights=overall_weights,
                    event_study_effects=event_study_effects,
                    group_effects=group_effects,
                    treatment_groups=treatment_groups,
                    tau_hat=tau_hat,
                    balance_e=balance_e,
                )
            except Exception as e:
                warnings.warn(
                    f"Bootstrap pre-computation failed: {e}. " "Skipping bootstrap inference.",
                    UserWarning,
                    stacklevel=2,
                )
                psi_data = None

        # Bootstrap: when it runs, its SEs/CIs/p-values replace the analytic ones
        bootstrap_results = None
        if self.n_bootstrap > 0 and psi_data is not None:
            bootstrap_results = self._run_bootstrap(
                original_att=overall_att,
                original_event_study=event_study_effects,
                original_group=group_effects,
                psi_data=psi_data,
            )

            # Update inference with bootstrap results
            overall_se = bootstrap_results.overall_att_se
            overall_t = (
                overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
            )
            overall_p = bootstrap_results.overall_att_p_value
            overall_ci = bootstrap_results.overall_att_ci

            # Update event study
            if event_study_effects and bootstrap_results.event_study_ses:
                for h in event_study_effects:
                    if (
                        h in bootstrap_results.event_study_ses
                        and event_study_effects[h].get("n_obs", 1) > 0
                    ):
                        event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h]
                        event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[h]
                        event_study_effects[h]["p_value"] = bootstrap_results.event_study_p_values[
                            h
                        ]
                        eff_val = event_study_effects[h]["effect"]
                        se_val = event_study_effects[h]["se"]
                        event_study_effects[h]["t_stat"] = (
                            eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan
                        )

            # Update group effects
            if group_effects and bootstrap_results.group_ses:
                for g in group_effects:
                    if g in bootstrap_results.group_ses:
                        group_effects[g]["se"] = bootstrap_results.group_ses[g]
                        group_effects[g]["conf_int"] = bootstrap_results.group_cis[g]
                        group_effects[g]["p_value"] = bootstrap_results.group_p_values[g]
                        eff_val = group_effects[g]["effect"]
                        se_val = group_effects[g]["se"]
                        group_effects[g]["t_stat"] = (
                            eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan
                        )

        # Construct results
        self.results_ = ImputationDiDResults(
            treatment_effects=treated_df,
            overall_att=overall_att,
            overall_se=overall_se,
            overall_t_stat=overall_t,
            overall_p_value=overall_p,
            overall_conf_int=overall_ci,
            event_study_effects=event_study_effects,
            group_effects=group_effects,
            groups=treatment_groups,
            time_periods=time_periods,
            n_obs=len(df),
            n_treated_obs=n_omega_1,
            n_untreated_obs=n_omega_0,
            n_treated_units=n_treated_units,
            n_control_units=n_control_units,
            alpha=self.alpha,
            bootstrap_results=bootstrap_results,
            _estimator_ref=self,
        )

        self.is_fitted_ = True
        return self.results_
1011
|
+
# =========================================================================
|
|
1012
|
+
# Step 1: OLS on untreated observations
|
|
1013
|
+
# =========================================================================
|
|
1014
|
+
|
|
1015
|
+
def _iterative_fe(
|
|
1016
|
+
self,
|
|
1017
|
+
y: np.ndarray,
|
|
1018
|
+
unit_vals: np.ndarray,
|
|
1019
|
+
time_vals: np.ndarray,
|
|
1020
|
+
idx: pd.Index,
|
|
1021
|
+
max_iter: int = 100,
|
|
1022
|
+
tol: float = 1e-10,
|
|
1023
|
+
) -> Tuple[Dict[Any, float], Dict[Any, float]]:
|
|
1024
|
+
"""
|
|
1025
|
+
Estimate unit and time FE via iterative alternating projection (Gauss-Seidel).
|
|
1026
|
+
|
|
1027
|
+
Converges to the exact OLS solution for both balanced and unbalanced panels.
|
|
1028
|
+
For balanced panels, converges in 1-2 iterations (identical to one-pass).
|
|
1029
|
+
For unbalanced panels, typically 5-20 iterations.
|
|
1030
|
+
|
|
1031
|
+
Returns
|
|
1032
|
+
-------
|
|
1033
|
+
unit_fe : dict
|
|
1034
|
+
Mapping from unit -> unit fixed effect.
|
|
1035
|
+
time_fe : dict
|
|
1036
|
+
Mapping from time -> time fixed effect.
|
|
1037
|
+
"""
|
|
1038
|
+
n = len(y)
|
|
1039
|
+
alpha = np.zeros(n) # unit FE broadcast to obs level
|
|
1040
|
+
beta = np.zeros(n) # time FE broadcast to obs level
|
|
1041
|
+
|
|
1042
|
+
with np.errstate(invalid="ignore", divide="ignore"):
|
|
1043
|
+
for iteration in range(max_iter):
|
|
1044
|
+
# Update time FE: beta_t = mean_i(y_it - alpha_i)
|
|
1045
|
+
resid_after_alpha = y - alpha
|
|
1046
|
+
beta_new = (
|
|
1047
|
+
pd.Series(resid_after_alpha, index=idx)
|
|
1048
|
+
.groupby(time_vals)
|
|
1049
|
+
.transform("mean")
|
|
1050
|
+
.values
|
|
1051
|
+
)
|
|
1052
|
+
|
|
1053
|
+
# Update unit FE: alpha_i = mean_t(y_it - beta_t)
|
|
1054
|
+
resid_after_beta = y - beta_new
|
|
1055
|
+
alpha_new = (
|
|
1056
|
+
pd.Series(resid_after_beta, index=idx)
|
|
1057
|
+
.groupby(unit_vals)
|
|
1058
|
+
.transform("mean")
|
|
1059
|
+
.values
|
|
1060
|
+
)
|
|
1061
|
+
|
|
1062
|
+
# Check convergence on FE changes
|
|
1063
|
+
max_change = max(
|
|
1064
|
+
np.max(np.abs(alpha_new - alpha)),
|
|
1065
|
+
np.max(np.abs(beta_new - beta)),
|
|
1066
|
+
)
|
|
1067
|
+
alpha = alpha_new
|
|
1068
|
+
beta = beta_new
|
|
1069
|
+
if max_change < tol:
|
|
1070
|
+
break
|
|
1071
|
+
|
|
1072
|
+
unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
|
|
1073
|
+
time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
|
|
1074
|
+
return unit_fe, time_fe
|
|
1075
|
+
|
|
1076
|
+
@staticmethod
|
|
1077
|
+
def _iterative_demean(
|
|
1078
|
+
vals: np.ndarray,
|
|
1079
|
+
unit_vals: np.ndarray,
|
|
1080
|
+
time_vals: np.ndarray,
|
|
1081
|
+
idx: pd.Index,
|
|
1082
|
+
max_iter: int = 100,
|
|
1083
|
+
tol: float = 1e-10,
|
|
1084
|
+
) -> np.ndarray:
|
|
1085
|
+
"""Demean a vector by iterative alternating projection (unit + time FE removal).
|
|
1086
|
+
|
|
1087
|
+
Converges to the exact within-transformation for both balanced and
|
|
1088
|
+
unbalanced panels. For balanced panels, converges in 1-2 iterations.
|
|
1089
|
+
"""
|
|
1090
|
+
result = vals.copy()
|
|
1091
|
+
with np.errstate(invalid="ignore", divide="ignore"):
|
|
1092
|
+
for _ in range(max_iter):
|
|
1093
|
+
time_means = (
|
|
1094
|
+
pd.Series(result, index=idx).groupby(time_vals).transform("mean").values
|
|
1095
|
+
)
|
|
1096
|
+
result_after_time = result - time_means
|
|
1097
|
+
unit_means = (
|
|
1098
|
+
pd.Series(result_after_time, index=idx)
|
|
1099
|
+
.groupby(unit_vals)
|
|
1100
|
+
.transform("mean")
|
|
1101
|
+
.values
|
|
1102
|
+
)
|
|
1103
|
+
result_new = result_after_time - unit_means
|
|
1104
|
+
if np.max(np.abs(result_new - result)) < tol:
|
|
1105
|
+
result = result_new
|
|
1106
|
+
break
|
|
1107
|
+
result = result_new
|
|
1108
|
+
return result
|
|
1109
|
+
|
|
1110
|
+
@staticmethod
|
|
1111
|
+
def _compute_balanced_cohort_mask(
|
|
1112
|
+
df_treated: pd.DataFrame,
|
|
1113
|
+
first_treat: str,
|
|
1114
|
+
all_horizons: List[int],
|
|
1115
|
+
balance_e: int,
|
|
1116
|
+
cohort_rel_times: Dict[Any, Set[int]],
|
|
1117
|
+
) -> np.ndarray:
|
|
1118
|
+
"""Compute boolean mask selecting treated obs from balanced cohorts.
|
|
1119
|
+
|
|
1120
|
+
A cohort is 'balanced' if it has observations at every relative time
|
|
1121
|
+
in [-balance_e, max(all_horizons)].
|
|
1122
|
+
|
|
1123
|
+
Parameters
|
|
1124
|
+
----------
|
|
1125
|
+
df_treated : pd.DataFrame
|
|
1126
|
+
Post-treatment observations (Omega_1).
|
|
1127
|
+
first_treat : str
|
|
1128
|
+
Column name for cohort identifier.
|
|
1129
|
+
all_horizons : list of int
|
|
1130
|
+
Post-treatment horizons in the event study.
|
|
1131
|
+
balance_e : int
|
|
1132
|
+
Number of pre-treatment periods to require.
|
|
1133
|
+
cohort_rel_times : dict
|
|
1134
|
+
Maps each cohort value to the set of all observed relative times
|
|
1135
|
+
(including pre-treatment) from the full panel. Built by
|
|
1136
|
+
_build_cohort_rel_times().
|
|
1137
|
+
"""
|
|
1138
|
+
if not all_horizons:
|
|
1139
|
+
return np.ones(len(df_treated), dtype=bool)
|
|
1140
|
+
|
|
1141
|
+
max_h = max(all_horizons)
|
|
1142
|
+
required_range = set(range(-balance_e, max_h + 1))
|
|
1143
|
+
|
|
1144
|
+
balanced_cohorts = set()
|
|
1145
|
+
for g, horizons in cohort_rel_times.items():
|
|
1146
|
+
if required_range.issubset(horizons):
|
|
1147
|
+
balanced_cohorts.add(g)
|
|
1148
|
+
|
|
1149
|
+
return df_treated[first_treat].isin(balanced_cohorts).values
|
|
1150
|
+
|
|
1151
|
+
@staticmethod
|
|
1152
|
+
def _build_cohort_rel_times(
|
|
1153
|
+
df: pd.DataFrame,
|
|
1154
|
+
first_treat: str,
|
|
1155
|
+
) -> Dict[Any, Set[int]]:
|
|
1156
|
+
"""Build mapping of cohort -> set of observed relative times from full panel.
|
|
1157
|
+
|
|
1158
|
+
Precondition: df must have '_never_treated' and '_rel_time' columns
|
|
1159
|
+
(set by fit() before any aggregation calls).
|
|
1160
|
+
"""
|
|
1161
|
+
treated_mask = ~df["_never_treated"]
|
|
1162
|
+
treated_df = df.loc[treated_mask]
|
|
1163
|
+
result: Dict[Any, Set[int]] = {}
|
|
1164
|
+
ft_vals = treated_df[first_treat].values
|
|
1165
|
+
rt_vals = treated_df["_rel_time"].values
|
|
1166
|
+
for i in range(len(treated_df)):
|
|
1167
|
+
h = rt_vals[i]
|
|
1168
|
+
if np.isfinite(h):
|
|
1169
|
+
result.setdefault(ft_vals[i], set()).add(int(h))
|
|
1170
|
+
return result
|
|
1171
|
+
|
|
1172
|
+
    def _fit_untreated_model(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
    ) -> Tuple[
        Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray]
    ]:
        """
        Step 1: Estimate unit + time FE on untreated observations.

        Uses iterative alternating projection (Gauss-Seidel) to compute exact
        OLS fixed effects for both balanced and unbalanced panels. For balanced
        panels, converges in 1-2 iterations (identical to one-pass demeaning).

        With covariates, follows Frisch-Waugh-Lovell: demean Y and X by both
        FE dimensions, run OLS on the demeaned data for the covariate
        coefficients, then recover the FE from the covariate-adjusted outcome.

        Returns
        -------
        unit_fe : dict
            Unit fixed effects {unit_id: alpha_i}.
        time_fe : dict
            Time fixed effects {time_period: beta_t}.
        grand_mean : float
            Grand mean (0.0 — absorbed into iterative FE).
        delta_hat : np.ndarray or None
            Covariate coefficients (if covariates provided).
        kept_cov_mask : np.ndarray or None
            Boolean mask of shape (n_covariates,) indicating which covariates
            have finite coefficients. None if no covariates.
        """
        # Restrict the estimation sample to untreated observations (Omega_0).
        df_0 = df.loc[omega_0_mask]

        if covariates is None or len(covariates) == 0:
            # No covariates: estimate FE via iterative alternating projection
            # (exact OLS for both balanced and unbalanced panels)
            y = df_0[outcome].values.copy()
            unit_fe, time_fe = self._iterative_fe(
                y, df_0[unit].values, df_0[time].values, df_0.index
            )
            # grand_mean = 0: iterative FE absorb the intercept
            return unit_fe, time_fe, 0.0, None, None

        else:
            # With covariates: iteratively demean Y and X, OLS for delta,
            # then recover FE from covariate-adjusted outcome
            y = df_0[outcome].values.copy()
            X_raw = df_0[covariates].values.copy()
            units = df_0[unit].values
            times = df_0[time].values
            n_cov = len(covariates)

            # Step A: Iteratively demean Y and all X columns to remove unit+time FE
            y_dm = self._iterative_demean(y, units, times, df_0.index)
            X_dm = np.column_stack(
                [
                    self._iterative_demean(X_raw[:, j], units, times, df_0.index)
                    for j in range(n_cov)
                ]
            )

            # Step B: OLS for covariate coefficients on demeaned data
            result = solve_ols(
                X_dm,
                y_dm,
                return_vcov=False,
                rank_deficient_action=self.rank_deficient_action,
                column_names=covariates,
            )
            delta_hat = result[0]

            # Mask of covariates with finite coefficients (before cleaning)
            # Used to exclude rank-deficient covariates from variance design matrices
            kept_cov_mask = np.isfinite(delta_hat)

            # Replace NaN coefficients with 0 for adjustment
            # (rank-deficient covariates are dropped)
            delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0)

            # Step C: Recover FE from covariate-adjusted outcome using iterative FE
            y_adj = y - X_raw @ delta_hat_clean
            unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index)

            # grand_mean = 0: iterative FE absorb the intercept
            return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask
|
+
|
|
1259
|
+
# =========================================================================
|
|
1260
|
+
# Step 2: Impute counterfactuals
|
|
1261
|
+
# =========================================================================
|
|
1262
|
+
|
|
1263
|
+
def _impute_treatment_effects(
|
|
1264
|
+
self,
|
|
1265
|
+
df: pd.DataFrame,
|
|
1266
|
+
outcome: str,
|
|
1267
|
+
unit: str,
|
|
1268
|
+
time: str,
|
|
1269
|
+
covariates: Optional[List[str]],
|
|
1270
|
+
omega_1_mask: pd.Series,
|
|
1271
|
+
unit_fe: Dict[Any, float],
|
|
1272
|
+
time_fe: Dict[Any, float],
|
|
1273
|
+
grand_mean: float,
|
|
1274
|
+
delta_hat: Optional[np.ndarray],
|
|
1275
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
1276
|
+
"""
|
|
1277
|
+
Step 2: Impute Y(0) for treated observations and compute tau_hat.
|
|
1278
|
+
|
|
1279
|
+
Returns
|
|
1280
|
+
-------
|
|
1281
|
+
tau_hat : np.ndarray
|
|
1282
|
+
Imputed treatment effects for each treated observation.
|
|
1283
|
+
y_hat_0 : np.ndarray
|
|
1284
|
+
Imputed counterfactual Y(0).
|
|
1285
|
+
"""
|
|
1286
|
+
df_1 = df.loc[omega_1_mask]
|
|
1287
|
+
n_1 = len(df_1)
|
|
1288
|
+
|
|
1289
|
+
# Look up unit and time FE
|
|
1290
|
+
alpha_i = df_1[unit].map(unit_fe).values
|
|
1291
|
+
beta_t = df_1[time].map(time_fe).values
|
|
1292
|
+
|
|
1293
|
+
# Handle missing FE (set to NaN)
|
|
1294
|
+
alpha_i = np.where(pd.isna(alpha_i), np.nan, alpha_i).astype(float)
|
|
1295
|
+
beta_t = np.where(pd.isna(beta_t), np.nan, beta_t).astype(float)
|
|
1296
|
+
|
|
1297
|
+
y_hat_0 = grand_mean + alpha_i + beta_t
|
|
1298
|
+
|
|
1299
|
+
if delta_hat is not None and covariates:
|
|
1300
|
+
X_1 = df_1[covariates].values
|
|
1301
|
+
y_hat_0 = y_hat_0 + X_1 @ delta_hat
|
|
1302
|
+
|
|
1303
|
+
tau_hat = df_1[outcome].values - y_hat_0
|
|
1304
|
+
|
|
1305
|
+
return tau_hat, y_hat_0
|
|
1306
|
+
|
|
1307
|
+
# =========================================================================
|
|
1308
|
+
# Conservative Variance (Theorem 3)
|
|
1309
|
+
# =========================================================================
|
|
1310
|
+
|
|
1311
|
+
def _compute_cluster_psi_sums(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    weights: np.ndarray,
    cluster_var: str,
    kept_cov_mask: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute cluster-level influence function sums (Theorem 3).

    psi_i = sum_t v_it * epsilon_tilde_it, summed within each cluster.

    v_it weights: treated observations carry the user-supplied aggregation
    weights directly; untreated observations carry the negative projection
    of those weights onto the Step 1 design (closed form without
    covariates, sparse projection with them).

    Parameters
    ----------
    weights : np.ndarray
        Aggregation weights w_it over treated observations, positionally
        aligned with ``df.loc[omega_1_mask]``.
    kept_cov_mask : np.ndarray, optional
        Boolean mask of covariates retained by Step 1 rank checks; passed
        through to the covariate projection.

    Returns
    -------
    cluster_psi_sums : np.ndarray
        Array of cluster-level psi sums.
    cluster_ids_unique : np.ndarray
        Unique cluster identifiers (matching order of psi sums).
    """
    df_0 = df.loc[omega_0_mask]
    df_1 = df.loc[omega_1_mask]
    n_0 = len(df_0)
    n_1 = len(df_1)

    # ---- Compute v_it for treated observations ----
    # Treated v_it are exactly the aggregation weights (copied so the
    # caller's array is never mutated downstream).
    v_treated = weights.copy()

    # ---- Compute v_it for untreated observations ----
    if covariates is None or len(covariates) == 0:
        # FE-only case: closed-form
        # v_it = -(w_i / n0_i + w_t / n0_t - w_total / n_0) where w_i / w_t
        # are the total treated weight on unit i / period t and n0_* are the
        # untreated observation counts in that unit / period.
        treated_units = df_1[unit].values
        treated_times = df_1[time].values

        # Total treated weight per unit.
        w_by_unit: Dict[Any, float] = {}
        for i_idx in range(n_1):
            u = treated_units[i_idx]
            w_by_unit[u] = w_by_unit.get(u, 0.0) + weights[i_idx]

        # Total treated weight per time period.
        w_by_time: Dict[Any, float] = {}
        for i_idx in range(n_1):
            t = treated_times[i_idx]
            w_by_time[t] = w_by_time.get(t, 0.0) + weights[i_idx]

        w_total = float(np.sum(weights))

        # Untreated observation counts per unit and per period.
        n0_by_unit = df_0.groupby(unit).size().to_dict()
        n0_by_time = df_0.groupby(time).size().to_dict()

        untreated_units = df_0[unit].values
        untreated_times = df_0[time].values
        v_untreated = np.zeros(n_0)

        for j in range(n_0):
            u = untreated_units[j]
            t = untreated_times[j]
            w_i = w_by_unit.get(u, 0.0)
            w_t = w_by_time.get(t, 0.0)
            # The defaults of 1 are defensive only: u and t come from df_0
            # itself, so they always appear in the groupby count dicts.
            n0_i = n0_by_unit.get(u, 1)
            n0_t = n0_by_time.get(t, 1)
            v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n_0)
    else:
        # With covariates the projection has no closed form; solve it
        # via the sparse normal equations.
        v_untreated = self._compute_v_untreated_with_covariates(
            df_0,
            df_1,
            unit,
            time,
            covariates,
            weights,
            delta_hat,
            kept_cov_mask=kept_cov_mask,
        )

    # ---- Compute auxiliary model residuals (Equation 8) ----
    epsilon_treated = self._compute_auxiliary_residuals_treated(
        df_1,
        outcome,
        unit,
        time,
        first_treat,
        covariates,
        unit_fe,
        time_fe,
        grand_mean,
        delta_hat,
        v_treated,
    )
    epsilon_untreated = self._compute_residuals_untreated(
        df_0, outcome, unit, time, covariates, unit_fe, time_fe, grand_mean, delta_hat
    )

    # ---- psi_it = v_it * epsilon_tilde_it ----
    # NOTE(review): scatter-by-mask into np.empty assumes omega_0_mask and
    # omega_1_mask partition df; rows outside both would keep uninitialized
    # values — confirm upstream guarantees the partition.
    v_all = np.empty(len(df))
    v_all[omega_1_mask.values] = v_treated
    v_all[omega_0_mask.values] = v_untreated

    eps_all = np.empty(len(df))
    eps_all[omega_1_mask.values] = epsilon_treated
    eps_all[omega_0_mask.values] = epsilon_untreated

    ve_product = v_all * eps_all
    # NaN eps from missing FE (rank condition violation). Zero their variance
    # contribution — matches R's did_imputation which drops unimputable obs.
    np.nan_to_num(ve_product, copy=False, nan=0.0)

    # Sum within clusters
    cluster_ids = df[cluster_var].values
    ve_series = pd.Series(ve_product, index=df.index)
    cluster_sums = ve_series.groupby(cluster_ids).sum()

    return cluster_sums.values, cluster_sums.index.values
|
|
1432
|
+
|
|
1433
|
+
def _compute_conservative_variance(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    weights: np.ndarray,
    cluster_var: str,
    kept_cov_mask: Optional[np.ndarray] = None,
) -> float:
    """
    Conservative clustered standard error (Theorem 3, Equation 7).

    Thin wrapper: obtains the cluster-level influence-function sums and
    returns sqrt(sum_c psi_c^2).

    Parameters
    ----------
    weights : np.ndarray
        Aggregation weights w_it for treated observations.
        Shape: (n_treated,), must sum to 1.

    Returns
    -------
    float
        Standard error.
    """
    psi_by_cluster, _ = self._compute_cluster_psi_sums(
        df=df,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        omega_1_mask=omega_1_mask,
        unit_fe=unit_fe,
        time_fe=time_fe,
        grand_mean=grand_mean,
        delta_hat=delta_hat,
        weights=weights,
        cluster_var=cluster_var,
        kept_cov_mask=kept_cov_mask,
    )
    variance = float((psi_by_cluster**2).sum())
    # Clamp at zero before the sqrt to guard against negative round-off.
    return np.sqrt(max(variance, 0.0))
|
|
1484
|
+
|
|
1485
|
+
def _compute_v_untreated_with_covariates(
|
|
1486
|
+
self,
|
|
1487
|
+
df_0: pd.DataFrame,
|
|
1488
|
+
df_1: pd.DataFrame,
|
|
1489
|
+
unit: str,
|
|
1490
|
+
time: str,
|
|
1491
|
+
covariates: List[str],
|
|
1492
|
+
weights: np.ndarray,
|
|
1493
|
+
delta_hat: Optional[np.ndarray],
|
|
1494
|
+
kept_cov_mask: Optional[np.ndarray] = None,
|
|
1495
|
+
) -> np.ndarray:
|
|
1496
|
+
"""
|
|
1497
|
+
Compute v_it for untreated observations with covariates.
|
|
1498
|
+
|
|
1499
|
+
Uses the projection: v_untreated = -A_0 (A_0'A_0)^{-1} A_1' w_treated
|
|
1500
|
+
|
|
1501
|
+
Uses scipy.sparse for FE dummy columns to reduce memory from O(N*(U+T))
|
|
1502
|
+
to O(N) for the FE portion.
|
|
1503
|
+
"""
|
|
1504
|
+
# Exclude rank-deficient covariates from design matrices
|
|
1505
|
+
if kept_cov_mask is not None and not np.all(kept_cov_mask):
|
|
1506
|
+
covariates = [c for c, k in zip(covariates, kept_cov_mask) if k]
|
|
1507
|
+
|
|
1508
|
+
units_0 = df_0[unit].values
|
|
1509
|
+
times_0 = df_0[time].values
|
|
1510
|
+
units_1 = df_1[unit].values
|
|
1511
|
+
times_1 = df_1[time].values
|
|
1512
|
+
|
|
1513
|
+
all_units = np.unique(np.concatenate([units_0, units_1]))
|
|
1514
|
+
all_times = np.unique(np.concatenate([times_0, times_1]))
|
|
1515
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
1516
|
+
time_to_idx = {t: i for i, t in enumerate(all_times)}
|
|
1517
|
+
n_units = len(all_units)
|
|
1518
|
+
n_times = len(all_times)
|
|
1519
|
+
n_cov = len(covariates)
|
|
1520
|
+
n_fe_cols = (n_units - 1) + (n_times - 1)
|
|
1521
|
+
|
|
1522
|
+
def _build_A_sparse(df_sub, unit_vals, time_vals):
|
|
1523
|
+
n = len(df_sub)
|
|
1524
|
+
|
|
1525
|
+
# Unit dummies (drop first) — vectorized
|
|
1526
|
+
u_indices = np.array([unit_to_idx[u] for u in unit_vals])
|
|
1527
|
+
u_mask = u_indices > 0 # skip first unit (dropped)
|
|
1528
|
+
u_rows = np.arange(n)[u_mask]
|
|
1529
|
+
u_cols = u_indices[u_mask] - 1
|
|
1530
|
+
|
|
1531
|
+
# Time dummies (drop first) — vectorized
|
|
1532
|
+
t_indices = np.array([time_to_idx[t] for t in time_vals])
|
|
1533
|
+
t_mask = t_indices > 0
|
|
1534
|
+
t_rows = np.arange(n)[t_mask]
|
|
1535
|
+
t_cols = (n_units - 1) + t_indices[t_mask] - 1
|
|
1536
|
+
|
|
1537
|
+
rows = np.concatenate([u_rows, t_rows])
|
|
1538
|
+
cols = np.concatenate([u_cols, t_cols])
|
|
1539
|
+
data = np.ones(len(rows))
|
|
1540
|
+
|
|
1541
|
+
A_fe = sparse.csr_matrix((data, (rows, cols)), shape=(n, n_fe_cols))
|
|
1542
|
+
|
|
1543
|
+
# Covariates (dense, typically few columns)
|
|
1544
|
+
if n_cov > 0:
|
|
1545
|
+
A_cov = sparse.csr_matrix(df_sub[covariates].values)
|
|
1546
|
+
A = sparse.hstack([A_fe, A_cov], format="csr")
|
|
1547
|
+
else:
|
|
1548
|
+
A = A_fe
|
|
1549
|
+
|
|
1550
|
+
return A
|
|
1551
|
+
|
|
1552
|
+
A_0 = _build_A_sparse(df_0, units_0, times_0)
|
|
1553
|
+
A_1 = _build_A_sparse(df_1, units_1, times_1)
|
|
1554
|
+
|
|
1555
|
+
# Compute A_1' w (sparse.T @ dense -> dense)
|
|
1556
|
+
A1_w = A_1.T @ weights # shape (p,)
|
|
1557
|
+
|
|
1558
|
+
# Solve (A_0'A_0) z = A_1' w using sparse direct solver
|
|
1559
|
+
A0tA0_sparse = A_0.T @ A_0 # stays sparse
|
|
1560
|
+
try:
|
|
1561
|
+
z = spsolve(A0tA0_sparse.tocsc(), A1_w)
|
|
1562
|
+
except Exception:
|
|
1563
|
+
# Fallback to dense lstsq if sparse solver fails (e.g., singular matrix)
|
|
1564
|
+
A0tA0_dense = A0tA0_sparse.toarray()
|
|
1565
|
+
z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None)
|
|
1566
|
+
|
|
1567
|
+
# v_untreated = -A_0 z (sparse @ dense -> dense)
|
|
1568
|
+
v_untreated = -(A_0 @ z)
|
|
1569
|
+
return v_untreated
|
|
1570
|
+
|
|
1571
|
+
def _compute_auxiliary_residuals_treated(
|
|
1572
|
+
self,
|
|
1573
|
+
df_1: pd.DataFrame,
|
|
1574
|
+
outcome: str,
|
|
1575
|
+
unit: str,
|
|
1576
|
+
time: str,
|
|
1577
|
+
first_treat: str,
|
|
1578
|
+
covariates: Optional[List[str]],
|
|
1579
|
+
unit_fe: Dict[Any, float],
|
|
1580
|
+
time_fe: Dict[Any, float],
|
|
1581
|
+
grand_mean: float,
|
|
1582
|
+
delta_hat: Optional[np.ndarray],
|
|
1583
|
+
v_treated: np.ndarray,
|
|
1584
|
+
) -> np.ndarray:
|
|
1585
|
+
"""
|
|
1586
|
+
Compute v_it-weighted auxiliary residuals for treated obs (Equation 8).
|
|
1587
|
+
|
|
1588
|
+
Computes v_it-weighted tau_tilde_g per Equation 8 of Borusyak et al. (2024):
|
|
1589
|
+
tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g.
|
|
1590
|
+
|
|
1591
|
+
epsilon_tilde_it = Y_it - alpha_i - beta_t [- X'delta] - tau_tilde_g
|
|
1592
|
+
"""
|
|
1593
|
+
n_1 = len(df_1)
|
|
1594
|
+
|
|
1595
|
+
# Compute base residuals (Y - Y_hat(0) = tau_hat)
|
|
1596
|
+
# NaN for missing FE (consistent with _impute_treatment_effects)
|
|
1597
|
+
alpha_i = df_1[unit].map(unit_fe).values.astype(float) # NaN for missing
|
|
1598
|
+
beta_t = df_1[time].map(time_fe).values.astype(float) # NaN for missing
|
|
1599
|
+
y_hat_0 = grand_mean + alpha_i + beta_t
|
|
1600
|
+
|
|
1601
|
+
if delta_hat is not None and covariates:
|
|
1602
|
+
y_hat_0 = y_hat_0 + df_1[covariates].values @ delta_hat
|
|
1603
|
+
|
|
1604
|
+
tau_hat = df_1[outcome].values - y_hat_0
|
|
1605
|
+
|
|
1606
|
+
# Partition Omega_1 and compute tau_tilde for each group
|
|
1607
|
+
if self.aux_partition == "cohort_horizon":
|
|
1608
|
+
group_keys = list(zip(df_1[first_treat].values, df_1["_rel_time"].values))
|
|
1609
|
+
elif self.aux_partition == "cohort":
|
|
1610
|
+
group_keys = list(df_1[first_treat].values)
|
|
1611
|
+
elif self.aux_partition == "horizon":
|
|
1612
|
+
group_keys = list(df_1["_rel_time"].values)
|
|
1613
|
+
else:
|
|
1614
|
+
group_keys = list(range(n_1)) # each obs is its own group
|
|
1615
|
+
|
|
1616
|
+
# Compute v_it-weighted average tau within each partition group (Equation 8)
|
|
1617
|
+
# tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g
|
|
1618
|
+
group_series = pd.Series(group_keys, index=df_1.index)
|
|
1619
|
+
tau_series = pd.Series(tau_hat, index=df_1.index)
|
|
1620
|
+
v_series = pd.Series(v_treated, index=df_1.index)
|
|
1621
|
+
|
|
1622
|
+
weighted_tau_sum = (v_series * tau_series).groupby(group_series).sum()
|
|
1623
|
+
weight_sum = v_series.groupby(group_series).sum()
|
|
1624
|
+
|
|
1625
|
+
# Guard: zero-weight groups -> their tau_tilde doesn't affect variance
|
|
1626
|
+
# (v_it ~ 0 means these obs contribute nothing to the estimand)
|
|
1627
|
+
# Use simple mean as fallback. This is common for event-study SE computation
|
|
1628
|
+
# where weights target a specific horizon, making other partition groups zero.
|
|
1629
|
+
zero_weight_groups = weight_sum.abs() < 1e-15
|
|
1630
|
+
if zero_weight_groups.any():
|
|
1631
|
+
simple_means = tau_series.groupby(group_series).mean()
|
|
1632
|
+
tau_tilde_map = weighted_tau_sum / weight_sum
|
|
1633
|
+
tau_tilde_map = tau_tilde_map.where(~zero_weight_groups, simple_means)
|
|
1634
|
+
else:
|
|
1635
|
+
tau_tilde_map = weighted_tau_sum / weight_sum
|
|
1636
|
+
|
|
1637
|
+
tau_tilde = group_series.map(tau_tilde_map).values
|
|
1638
|
+
|
|
1639
|
+
# Auxiliary residuals
|
|
1640
|
+
epsilon_treated = tau_hat - tau_tilde
|
|
1641
|
+
|
|
1642
|
+
return epsilon_treated
|
|
1643
|
+
|
|
1644
|
+
def _compute_residuals_untreated(
|
|
1645
|
+
self,
|
|
1646
|
+
df_0: pd.DataFrame,
|
|
1647
|
+
outcome: str,
|
|
1648
|
+
unit: str,
|
|
1649
|
+
time: str,
|
|
1650
|
+
covariates: Optional[List[str]],
|
|
1651
|
+
unit_fe: Dict[Any, float],
|
|
1652
|
+
time_fe: Dict[Any, float],
|
|
1653
|
+
grand_mean: float,
|
|
1654
|
+
delta_hat: Optional[np.ndarray],
|
|
1655
|
+
) -> np.ndarray:
|
|
1656
|
+
"""Compute Step 1 residuals for untreated observations."""
|
|
1657
|
+
alpha_i = df_0[unit].map(unit_fe).fillna(0.0).values
|
|
1658
|
+
beta_t = df_0[time].map(time_fe).fillna(0.0).values
|
|
1659
|
+
y_hat = grand_mean + alpha_i + beta_t
|
|
1660
|
+
|
|
1661
|
+
if delta_hat is not None and covariates:
|
|
1662
|
+
y_hat = y_hat + df_0[covariates].values @ delta_hat
|
|
1663
|
+
|
|
1664
|
+
return df_0[outcome].values - y_hat
|
|
1665
|
+
|
|
1666
|
+
# =========================================================================
|
|
1667
|
+
# Aggregation
|
|
1668
|
+
# =========================================================================
|
|
1669
|
+
|
|
1670
|
+
def _aggregate_event_study(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    cluster_var: str,
    treatment_groups: List[Any],
    balance_e: Optional[int] = None,
    kept_cov_mask: Optional[np.ndarray] = None,
) -> Dict[int, Dict[str, Any]]:
    """Aggregate treatment effects by event-study horizon.

    For each relative-time horizon h, the effect is the simple mean of
    the finite imputed tau_hat at that horizon (optionally restricted by
    ``balance_e`` / ``self.horizon_max``); the SE is the conservative
    clustered variance under equal horizon-specific weights.

    Returns
    -------
    dict
        Maps horizon -> {"effect", "se", "t_stat", "p_value", "conf_int",
        "n_obs"}. Always includes a zero-effect marker at the reference
        period; unidentified horizons (Proposition 5) are NaN.
    """
    df_1 = df.loc[omega_1_mask]
    tau_hat = df["_tau_hat"].loc[omega_1_mask].values
    rel_times = df_1["_rel_time"].values

    # Get all horizons (finite relative times observed among treated obs)
    all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h)))

    # Apply horizon_max filter
    if self.horizon_max is not None:
        all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]

    # Apply balance_e filter: restrict to cohorts observed through
    # horizon balance_e so composition is constant across horizons.
    if balance_e is not None:
        cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
        balanced_mask = pd.Series(
            self._compute_balanced_cohort_mask(
                df_1, first_treat, all_horizons, balance_e, cohort_rel_times
            ),
            index=df_1.index,
        )
    else:
        balanced_mask = pd.Series(True, index=df_1.index)

    # Check Proposition 5: no never-treated units.
    # h_bar is the identification bound: with only ever-treated cohorts,
    # horizons at or beyond the cohort span are not identified.
    has_never_treated = df["_never_treated"].any()
    h_bar = np.inf
    if not has_never_treated and len(treatment_groups) > 1:
        h_bar = max(treatment_groups) - min(treatment_groups)

    # Reference period (last pre-treatment period, shifted by anticipation)
    ref_period = -1 - self.anticipation

    event_study_effects: Dict[int, Dict[str, Any]] = {}

    # Add reference period marker (normalized to zero by construction)
    event_study_effects[ref_period] = {
        "effect": 0.0,
        "se": 0.0,
        "t_stat": np.nan,
        "p_value": np.nan,
        "conf_int": (0.0, 0.0),
        "n_obs": 0,
    }

    # Collect horizons with Proposition 5 violations
    prop5_horizons = []

    for h in all_horizons:
        if h == ref_period:
            continue

        # Select treated obs at this horizon from balanced cohorts
        h_mask = (rel_times == h) & balanced_mask.values
        n_h = int(h_mask.sum())

        if n_h == 0:
            continue

        # Proposition 5 check: unidentified horizon -> record NaNs
        if not has_never_treated and h >= h_bar:
            prop5_horizons.append(h)
            event_study_effects[h] = {
                "effect": np.nan,
                "se": np.nan,
                "t_stat": np.nan,
                "p_value": np.nan,
                "conf_int": (np.nan, np.nan),
                "n_obs": n_h,
            }
            continue

        tau_h = tau_hat[h_mask]
        valid_tau = tau_h[np.isfinite(tau_h)]

        # All tau_hat at this horizon unimputable (missing FEs) -> NaNs
        if len(valid_tau) == 0:
            event_study_effects[h] = {
                "effect": np.nan,
                "se": np.nan,
                "t_stat": np.nan,
                "p_value": np.nan,
                "conf_int": (np.nan, np.nan),
                "n_obs": n_h,
            }
            continue

        effect = float(np.mean(valid_tau))

        # Compute SE via conservative variance with horizon-specific weights:
        # equal weight 1/n_valid on the finite tau_hat obs at this horizon.
        weights_h = np.zeros(int(omega_1_mask.sum()))
        # Map h_mask (relative to df_1) to weights array
        h_indices_in_omega1 = np.where(h_mask)[0]
        n_valid = len(valid_tau)
        # Only weight valid (finite) observations
        finite_mask = np.isfinite(tau_hat[h_mask])
        valid_h_indices = h_indices_in_omega1[finite_mask]
        for idx in valid_h_indices:
            weights_h[idx] = 1.0 / n_valid

        se = self._compute_conservative_variance(
            df=df,
            outcome=outcome,
            unit=unit,
            time=time,
            first_treat=first_treat,
            covariates=covariates,
            omega_0_mask=omega_0_mask,
            omega_1_mask=omega_1_mask,
            unit_fe=unit_fe,
            time_fe=time_fe,
            grand_mean=grand_mean,
            delta_hat=delta_hat,
            weights=weights_h,
            cluster_var=cluster_var,
            kept_cov_mask=kept_cov_mask,
        )

        t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan
        p_value = compute_p_value(t_stat)
        conf_int = (
            compute_confidence_interval(effect, se, self.alpha)
            if np.isfinite(se) and se > 0
            else (np.nan, np.nan)
        )

        event_study_effects[h] = {
            "effect": effect,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_value,
            "conf_int": conf_int,
            "n_obs": n_h,
        }

    # Proposition 5 warning
    if prop5_horizons:
        warnings.warn(
            f"Horizons {prop5_horizons} are not identified without "
            f"never-treated units (Proposition 5). Set to NaN.",
            UserWarning,
            stacklevel=3,
        )

    # Check for empty result set after filtering
    real_effects = [
        h for h, v in event_study_effects.items() if h != ref_period and v.get("n_obs", 0) > 0
    ]
    if len(real_effects) == 0:
        filter_info = []
        if balance_e is not None:
            filter_info.append(f"balance_e={balance_e}")
        if self.horizon_max is not None:
            filter_info.append(f"horizon_max={self.horizon_max}")
        filter_str = " and ".join(filter_info) if filter_info else "filters"
        warnings.warn(
            f"Event study aggregation produced no horizons with observations "
            f"after applying {filter_str}. The result contains only the "
            f"reference period marker. Consider relaxing filter parameters.",
            UserWarning,
            stacklevel=3,
        )

    return event_study_effects
|
|
1852
|
+
|
|
1853
|
+
def _aggregate_group(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    cluster_var: str,
    treatment_groups: List[Any],
    kept_cov_mask: Optional[np.ndarray] = None,
) -> Dict[Any, Dict[str, Any]]:
    """Aggregate treatment effects by cohort (first-treatment group).

    For each cohort g, the effect is the simple mean of the finite
    imputed tau_hat over that cohort's treated observations; the SE is
    the conservative clustered variance under equal cohort weights.

    Returns
    -------
    dict
        Maps cohort -> {"effect", "se", "t_stat", "p_value", "conf_int",
        "n_obs"}. Cohorts with no treated observations are omitted;
        cohorts whose tau_hat are all non-finite get NaN entries.
    """
    df_1 = df.loc[omega_1_mask]
    tau_hat = df["_tau_hat"].loc[omega_1_mask].values
    cohorts = df_1[first_treat].values

    group_effects: Dict[Any, Dict[str, Any]] = {}

    for g in treatment_groups:
        g_mask = cohorts == g
        n_g = int(g_mask.sum())

        if n_g == 0:
            continue

        tau_g = tau_hat[g_mask]
        valid_tau = tau_g[np.isfinite(tau_g)]

        # All tau_hat in this cohort unimputable (missing FEs) -> NaNs
        if len(valid_tau) == 0:
            group_effects[g] = {
                "effect": np.nan,
                "se": np.nan,
                "t_stat": np.nan,
                "p_value": np.nan,
                "conf_int": (np.nan, np.nan),
                "n_obs": n_g,
            }
            continue

        effect = float(np.mean(valid_tau))

        # SE with group-specific weights: equal weight 1/n_valid on the
        # finite tau_hat obs of this cohort (vectorized mask assignment
        # instead of the former per-index Python loop).
        weights_g = np.zeros(int(omega_1_mask.sum()))
        weights_g[np.isfinite(tau_hat) & g_mask] = 1.0 / len(valid_tau)

        se = self._compute_conservative_variance(
            df=df,
            outcome=outcome,
            unit=unit,
            time=time,
            first_treat=first_treat,
            covariates=covariates,
            omega_0_mask=omega_0_mask,
            omega_1_mask=omega_1_mask,
            unit_fe=unit_fe,
            time_fe=time_fe,
            grand_mean=grand_mean,
            delta_hat=delta_hat,
            weights=weights_g,
            cluster_var=cluster_var,
            kept_cov_mask=kept_cov_mask,
        )

        t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan
        p_value = compute_p_value(t_stat)
        conf_int = (
            compute_confidence_interval(effect, se, self.alpha)
            if np.isfinite(se) and se > 0
            else (np.nan, np.nan)
        )

        group_effects[g] = {
            "effect": effect,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_value,
            "conf_int": conf_int,
            "n_obs": n_g,
        }

    return group_effects
|
|
1945
|
+
|
|
1946
|
+
# =========================================================================
|
|
1947
|
+
# Pre-trend test (Equation 9)
|
|
1948
|
+
# =========================================================================
|
|
1949
|
+
|
|
1950
|
+
def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
    """
    Run pre-trend test (Equation 9).

    Adds pre-treatment lead indicators to the Step 1 OLS on Omega_0
    and tests their joint significance via cluster-robust Wald F-test.

    Parameters
    ----------
    n_leads : int, optional
        If given, keep only the n_leads pre-treatment periods closest
        to treatment (the reference period is always excluded).

    Returns
    -------
    dict
        Keys: ``f_stat``, ``p_value``, ``df`` (numerator d.o.f.),
        ``n_leads``, ``lead_coefficients`` (horizon -> coefficient).
        All-NaN/empty result when no testable pre-periods exist.
    """
    if self._fit_data is None:
        raise RuntimeError("Must call fit() before pretrend_test().")

    fd = self._fit_data
    df = fd["df"]
    outcome = fd["outcome"]
    unit = fd["unit"]
    time = fd["time"]
    first_treat = fd["first_treat"]
    covariates = fd["covariates"]
    omega_0_mask = fd["omega_0_mask"]
    cluster_var = fd["cluster_var"]

    df_0 = df.loc[omega_0_mask].copy()

    # Compute relative time for untreated obs
    # For not-yet-treated units in their pre-treatment periods;
    # never-treated units get NaN (they have no event date).
    rel_time_0 = np.where(
        ~df_0["_never_treated"],
        df_0[time] - df_0[first_treat],
        np.nan,
    )

    # Get available pre-treatment relative times (negative values),
    # excluding the anticipation window.
    pre_rel_times = sorted(
        set(int(h) for h in rel_time_0 if np.isfinite(h) and h < -self.anticipation)
    )

    if len(pre_rel_times) == 0:
        return {
            "f_stat": np.nan,
            "p_value": np.nan,
            "df": 0,
            "n_leads": 0,
            "lead_coefficients": {},
        }

    # Exclude the reference period (last pre-treatment period)
    ref = -1 - self.anticipation
    pre_rel_times = [h for h in pre_rel_times if h != ref]

    if n_leads is not None:
        # Take the n_leads periods closest to treatment
        pre_rel_times = sorted(pre_rel_times, reverse=True)[:n_leads]
        pre_rel_times = sorted(pre_rel_times)

    if len(pre_rel_times) == 0:
        return {
            "f_stat": np.nan,
            "p_value": np.nan,
            "df": 0,
            "n_leads": 0,
            "lead_coefficients": {},
        }

    # Build lead indicators. Note: rel_time_0 is a positional numpy
    # array of the same length as df_0, so this column assignment
    # relies on positional alignment with df_0's rows.
    lead_cols = []
    for h in pre_rel_times:
        col_name = f"_lead_{h}"
        df_0[col_name] = ((rel_time_0 == h)).astype(float)
        lead_cols.append(col_name)

    # Within-transform via iterative demeaning (exact for unbalanced panels)
    y_dm = self._iterative_demean(
        df_0[outcome].values, df_0[unit].values, df_0[time].values, df_0.index
    )

    # Leads first, then covariates — the lead coefficients are sliced
    # off the front of the coefficient vector below.
    all_x_cols = lead_cols[:]
    if covariates:
        all_x_cols.extend(covariates)

    X_dm = np.column_stack(
        [
            self._iterative_demean(
                df_0[col].values, df_0[unit].values, df_0[time].values, df_0.index
            )
            for col in all_x_cols
        ]
    )

    # OLS with cluster-robust SEs
    cluster_ids = df_0[cluster_var].values
    result = solve_ols(
        X_dm,
        y_dm,
        cluster_ids=cluster_ids,
        return_vcov=True,
        rank_deficient_action=self.rank_deficient_action,
        column_names=all_x_cols,
    )
    # NOTE(review): assumes solve_ols returns (coefficients, _, vcov, ...)
    # when return_vcov=True — confirm against its definition in linalg.
    coefficients = result[0]
    vcov = result[2]

    # Extract lead coefficients and their sub-VCV (leads occupy the
    # first n_leads_actual columns of the design by construction).
    n_leads_actual = len(lead_cols)
    gamma = coefficients[:n_leads_actual]
    V_gamma = vcov[:n_leads_actual, :n_leads_actual]

    # Wald F-test: F = (gamma' V^{-1} gamma) / n_leads
    try:
        V_inv_gamma = np.linalg.solve(V_gamma, gamma)
        wald_stat = float(gamma @ V_inv_gamma)
        f_stat = wald_stat / n_leads_actual
    except np.linalg.LinAlgError:
        # Singular sub-VCV: report NaN rather than fail.
        f_stat = np.nan

    # P-value from F distribution with (n_leads, n_clusters - 1) d.o.f.
    if np.isfinite(f_stat) and f_stat >= 0:
        n_clusters = len(np.unique(cluster_ids))
        df_denom = max(n_clusters - 1, 1)
        p_value = float(stats.f.sf(f_stat, n_leads_actual, df_denom))
    else:
        p_value = np.nan

    # Store lead coefficients keyed by relative time
    lead_coefficients = {}
    for j, h in enumerate(pre_rel_times):
        lead_coefficients[h] = float(gamma[j])

    return {
        "f_stat": f_stat,
        "p_value": p_value,
        "df": n_leads_actual,
        "n_leads": n_leads_actual,
        "lead_coefficients": lead_coefficients,
    }
|
|
2083
|
+
|
|
2084
|
+
# =========================================================================
|
|
2085
|
+
# Bootstrap
|
|
2086
|
+
# =========================================================================
|
|
2087
|
+
|
|
2088
|
+
def _compute_percentile_ci(
|
|
2089
|
+
self,
|
|
2090
|
+
boot_dist: np.ndarray,
|
|
2091
|
+
alpha: float,
|
|
2092
|
+
) -> Tuple[float, float]:
|
|
2093
|
+
"""Compute percentile confidence interval from bootstrap distribution."""
|
|
2094
|
+
lower = float(np.percentile(boot_dist, alpha / 2 * 100))
|
|
2095
|
+
upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
|
|
2096
|
+
return (lower, upper)
|
|
2097
|
+
|
|
2098
|
+
def _compute_bootstrap_pvalue(
|
|
2099
|
+
self,
|
|
2100
|
+
original_effect: float,
|
|
2101
|
+
boot_dist: np.ndarray,
|
|
2102
|
+
n_valid: Optional[int] = None,
|
|
2103
|
+
) -> float:
|
|
2104
|
+
"""
|
|
2105
|
+
Compute two-sided bootstrap p-value.
|
|
2106
|
+
|
|
2107
|
+
Uses the percentile method: p-value is the proportion of bootstrap
|
|
2108
|
+
estimates on the opposite side of zero from the original estimate,
|
|
2109
|
+
doubled for two-sided test.
|
|
2110
|
+
|
|
2111
|
+
Parameters
|
|
2112
|
+
----------
|
|
2113
|
+
original_effect : float
|
|
2114
|
+
Original point estimate.
|
|
2115
|
+
boot_dist : np.ndarray
|
|
2116
|
+
Bootstrap distribution of the effect.
|
|
2117
|
+
n_valid : int, optional
|
|
2118
|
+
Number of valid bootstrap samples. If None, uses self.n_bootstrap.
|
|
2119
|
+
"""
|
|
2120
|
+
if original_effect >= 0:
|
|
2121
|
+
p_one_sided = float(np.mean(boot_dist <= 0))
|
|
2122
|
+
else:
|
|
2123
|
+
p_one_sided = float(np.mean(boot_dist >= 0))
|
|
2124
|
+
p_value = min(2 * p_one_sided, 1.0)
|
|
2125
|
+
n_for_floor = n_valid if n_valid is not None else self.n_bootstrap
|
|
2126
|
+
p_value = max(p_value, 1 / (n_for_floor + 1))
|
|
2127
|
+
return p_value
|
|
2128
|
+
|
|
2129
|
+
def _precompute_bootstrap_psi(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    cluster_var: str,
    kept_cov_mask: Optional[np.ndarray],
    overall_weights: np.ndarray,
    event_study_effects: Optional[Dict[int, Dict[str, Any]]],
    group_effects: Optional[Dict[Any, Dict[str, Any]]],
    treatment_groups: List[Any],
    tau_hat: np.ndarray,
    balance_e: Optional[int],
) -> Dict[str, Any]:
    """
    Pre-compute cluster-level influence function sums for each bootstrap target.

    For each aggregation target (overall, per-horizon, per-group), computes
    psi_i = sum_t v_it * epsilon_tilde_it for each cluster. The multiplier
    bootstrap then perturbs these psi sums with Rademacher weights.

    Computational cost scales with the number of aggregation targets, since
    each target requires its own v_untreated computation (weight-dependent).

    Returns
    -------
    dict
        "overall"     -> (psi array, cluster ids) tuple;
        "event_study" -> {horizon: psi array} (only when event_study_effects
                         is non-empty);
        "group"       -> {cohort: psi array} (only when group_effects is
                         non-empty).
    """
    # NOTE(review): treatment_groups is accepted but never read in this body —
    # confirm whether it can be dropped from the signature's callers.
    result: Dict[str, Any] = {}

    # Arguments shared by every _compute_cluster_psi_sums call; only the
    # `weights` vector differs per aggregation target.
    common = dict(
        df=df,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        omega_1_mask=omega_1_mask,
        unit_fe=unit_fe,
        time_fe=time_fe,
        grand_mean=grand_mean,
        delta_hat=delta_hat,
        cluster_var=cluster_var,
        kept_cov_mask=kept_cov_mask,
    )

    # Overall ATT
    overall_psi, cluster_ids = self._compute_cluster_psi_sums(**common, weights=overall_weights)
    # cluster_ids stored once here; per-horizon/per-group targets reuse the
    # same cluster ordering and only keep their psi arrays.
    result["overall"] = (overall_psi, cluster_ids)

    # Event study: per-horizon weights
    # NOTE: weight logic duplicated from _aggregate_event_study.
    # If weight scheme changes there, update here too.
    if event_study_effects:
        result["event_study"] = {}
        df_1 = df.loc[omega_1_mask]
        rel_times = df_1["_rel_time"].values
        n_omega_1 = int(omega_1_mask.sum())

        # Balanced cohort mask (same logic as _aggregate_event_study)
        balanced_mask = None
        if balance_e is not None:
            all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h)))
            if self.horizon_max is not None:
                all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]
            cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
            balanced_mask = self._compute_balanced_cohort_mask(
                df_1, first_treat, all_horizons, balance_e, cohort_rel_times
            )

        # Reference period (normalized to zero) gets no psi target.
        ref_period = -1 - self.anticipation
        for h in event_study_effects:
            # Skip horizons with no observations, the reference period, and
            # horizons whose point estimate is not finite.
            if event_study_effects[h].get("n_obs", 0) == 0:
                continue
            if h == ref_period:
                continue
            if not np.isfinite(event_study_effects[h].get("effect", np.nan)):
                continue
            h_mask = rel_times == h
            if balanced_mask is not None:
                h_mask = h_mask & balanced_mask
            # Equal weights over treated observations at horizon h with a
            # finite imputed effect; everything else gets weight 0.
            weights_h = np.zeros(n_omega_1)
            finite_h = np.isfinite(tau_hat) & h_mask
            n_valid_h = int(finite_h.sum())
            if n_valid_h == 0:
                continue
            weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h

            psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h)
            result["event_study"][h] = psi_h

    # Group effects: per-group weights
    # NOTE: weight logic duplicated from _aggregate_group.
    # If weight scheme changes there, update here too.
    if group_effects:
        result["group"] = {}
        df_1 = df.loc[omega_1_mask]
        cohorts = df_1[first_treat].values
        n_omega_1 = int(omega_1_mask.sum())

        for g in group_effects:
            # Same skip rules as the event-study loop: empty or non-finite
            # targets produce no psi entry.
            if group_effects[g].get("n_obs", 0) == 0:
                continue
            if not np.isfinite(group_effects[g].get("effect", np.nan)):
                continue
            g_mask = cohorts == g
            # Equal weights over cohort-g treated observations with finite tau.
            weights_g = np.zeros(n_omega_1)
            finite_g = np.isfinite(tau_hat) & g_mask
            n_valid_g = int(finite_g.sum())
            if n_valid_g == 0:
                continue
            weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g

            psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g)
            result["group"][g] = psi_g

    return result
|
|
2252
|
+
|
|
2253
|
+
def _run_bootstrap(
    self,
    original_att: float,
    original_event_study: Optional[Dict[int, Dict[str, Any]]],
    original_group: Optional[Dict[Any, Dict[str, Any]]],
    psi_data: Dict[str, Any],
) -> ImputationBootstrapResults:
    """
    Run multiplier bootstrap on pre-computed influence function sums.

    Uses T_b = sum_i w_b_i * psi_i where w_b_i are Rademacher weights
    and psi_i are cluster-level influence function sums from Theorem 3.
    SE = std(T_b, ddof=1).

    Parameters
    ----------
    original_att : float
        Overall ATT point estimate; bootstrap draws are re-centered on it.
    original_event_study : dict, optional
        Per-horizon result dicts; each must carry an "effect" entry.
    original_group : dict, optional
        Per-cohort result dicts; each must carry an "effect" entry.
    psi_data : dict
        Output of _precompute_bootstrap_psi: "overall" -> (psi, cluster_ids),
        plus optional "event_study" / "group" psi mappings.
    """
    if self.n_bootstrap < 50:
        warnings.warn(
            f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
            "for reliable inference.",
            UserWarning,
            stacklevel=3,
        )

    # Seeded generator -> reproducible bootstrap draws for a fixed seed.
    rng = np.random.default_rng(self.seed)

    # Local import — presumably to avoid a circular import at module load;
    # TODO confirm against diff_diff.staggered_bootstrap's imports.
    from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch

    overall_psi, cluster_ids = psi_data["overall"]
    n_clusters = len(cluster_ids)

    # Generate ALL weights upfront: shape (n_bootstrap, n_clusters)
    all_weights = _generate_bootstrap_weights_batch(
        self.n_bootstrap, n_clusters, "rademacher", rng
    )

    # Overall ATT bootstrap draws
    boot_overall = all_weights @ overall_psi  # (n_bootstrap,)

    # Event study: loop over horizons
    # The single weight matrix is reused across all targets, so draws are
    # perfectly coupled across horizons/groups (same Rademacher signs).
    boot_event_study: Optional[Dict[int, np.ndarray]] = None
    if original_event_study and "event_study" in psi_data:
        boot_event_study = {}
        for h, psi_h in psi_data["event_study"].items():
            boot_event_study[h] = all_weights @ psi_h

    # Group effects: loop over groups
    boot_group: Optional[Dict[Any, np.ndarray]] = None
    if original_group and "group" in psi_data:
        boot_group = {}
        for g, psi_g in psi_data["group"].items():
            boot_group[g] = all_weights @ psi_g

    # --- Inference (percentile bootstrap, matching CS/SA convention) ---
    # Shift perturbation-centered draws to effect-centered draws.
    # The multiplier bootstrap produces T_b = sum w_b_i * psi_i centered at 0.
    # CS adds the original effect back (L411 of staggered_bootstrap.py).
    # We do the same here so percentile CIs and empirical p-values work correctly.
    boot_overall_shifted = boot_overall + original_att

    # SE from the un-shifted draws (shifting by a constant would not change
    # it, but this keeps the centering explicit).
    overall_se = float(np.std(boot_overall, ddof=1))
    # Degenerate (zero-variance) draws -> NaN CI / p-value rather than a
    # misleading zero-width interval.
    overall_ci = (
        self._compute_percentile_ci(boot_overall_shifted, self.alpha)
        if overall_se > 0
        else (np.nan, np.nan)
    )
    overall_p = (
        self._compute_bootstrap_pvalue(original_att, boot_overall_shifted)
        if overall_se > 0
        else np.nan
    )

    # Per-horizon inference: same recipe as the overall ATT, re-centered on
    # each horizon's own point estimate.
    event_study_ses = None
    event_study_cis = None
    event_study_p_values = None
    if boot_event_study and original_event_study:
        event_study_ses = {}
        event_study_cis = {}
        event_study_p_values = {}
        for h in boot_event_study:
            se_h = float(np.std(boot_event_study[h], ddof=1))
            event_study_ses[h] = se_h
            orig_eff = original_event_study[h]["effect"]
            if se_h > 0 and np.isfinite(orig_eff):
                shifted_h = boot_event_study[h] + orig_eff
                event_study_p_values[h] = self._compute_bootstrap_pvalue(orig_eff, shifted_h)
                event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha)
            else:
                event_study_p_values[h] = np.nan
                event_study_cis[h] = (np.nan, np.nan)

    # Per-group inference: identical structure, keyed by cohort.
    group_ses = None
    group_cis = None
    group_p_values = None
    if boot_group and original_group:
        group_ses = {}
        group_cis = {}
        group_p_values = {}
        for g in boot_group:
            se_g = float(np.std(boot_group[g], ddof=1))
            group_ses[g] = se_g
            orig_eff = original_group[g]["effect"]
            if se_g > 0 and np.isfinite(orig_eff):
                shifted_g = boot_group[g] + orig_eff
                group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g)
                group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha)
            else:
                group_p_values[g] = np.nan
                group_cis[g] = (np.nan, np.nan)

    return ImputationBootstrapResults(
        n_bootstrap=self.n_bootstrap,
        weight_type="rademacher",
        alpha=self.alpha,
        overall_att_se=overall_se,
        overall_att_ci=overall_ci,
        overall_att_p_value=overall_p,
        event_study_ses=event_study_ses,
        event_study_cis=event_study_cis,
        event_study_p_values=event_study_p_values,
        group_ses=group_ses,
        group_cis=group_cis,
        group_p_values=group_p_values,
        bootstrap_distribution=boot_overall_shifted,
    )
|
|
2376
|
+
|
|
2377
|
+
# =========================================================================
|
|
2378
|
+
# sklearn-compatible interface
|
|
2379
|
+
# =========================================================================
|
|
2380
|
+
|
|
2381
|
+
def get_params(self) -> Dict[str, Any]:
    """Get estimator parameters (sklearn-compatible)."""
    # Insertion order matches the estimator's documented parameter order.
    param_names = (
        "anticipation",
        "alpha",
        "cluster",
        "n_bootstrap",
        "seed",
        "rank_deficient_action",
        "horizon_max",
        "aux_partition",
    )
    return {name: getattr(self, name) for name in param_names}
|
|
2393
|
+
|
|
2394
|
+
def set_params(self, **params) -> "ImputationDiD":
    """Set estimator parameters (sklearn-compatible).

    Raises ValueError on the first unknown parameter name; parameters
    seen before it have already been applied at that point.
    """
    for name, value in params.items():
        if not hasattr(self, name):
            raise ValueError(f"Unknown parameter: {name}")
        setattr(self, name, value)
    return self
|
|
2402
|
+
|
|
2403
|
+
def summary(self) -> str:
    """Get summary of estimation results."""
    if self.is_fitted_:
        assert self.results_ is not None
        return self.results_.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
|
|
2409
|
+
|
|
2410
|
+
def print_summary(self) -> None:
    """Print summary to stdout."""
    text = self.summary()
    print(text)
|
|
2413
|
+
|
|
2414
|
+
|
|
2415
|
+
# =============================================================================
|
|
2416
|
+
# Convenience function
|
|
2417
|
+
# =============================================================================
|
|
2418
|
+
|
|
2419
|
+
|
|
2420
|
+
def imputation_did(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    **kwargs,
) -> ImputationDiDResults:
    """
    Convenience function for imputation DiD estimation.

    Equivalent to constructing an ImputationDiD estimator with ``**kwargs``
    and calling its fit() method.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    first_treat : str
        Column indicating first treatment period (0 for never-treated).
    covariates : list of str, optional
        Covariate column names.
    aggregate : str, optional
        Aggregation mode: None, "simple", "event_study", "group", "all".
    balance_e : int, optional
        Balance event study to cohorts observed at all relative times.
    **kwargs
        Additional keyword arguments passed to ImputationDiD constructor.

    Returns
    -------
    ImputationDiDResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import imputation_did, generate_staggered_data
    >>> data = generate_staggered_data(seed=42)
    >>> results = imputation_did(data, 'outcome', 'unit', 'time', 'first_treat',
    ...                          aggregate='event_study')
    >>> results.print_summary()
    """
    estimator = ImputationDiD(**kwargs)
    fit_kwargs = dict(
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        aggregate=aggregate,
        balance_e=balance_e,
    )
    return estimator.fit(data, **fit_kwargs)
|