diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1685 @@
1
+ """
2
+ Sun-Abraham Interaction-Weighted Estimator for staggered DiD.
3
+
4
+ Implements the estimator from Sun & Abraham (2021), "Estimating dynamic
5
+ treatment effects in event studies with heterogeneous treatment effects",
6
+ Journal of Econometrics.
7
+
8
+ This provides an alternative to Callaway-Sant'Anna using a saturated
9
+ regression with cohort × relative-time interactions.
10
+ """
11
+
12
+ import warnings
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from diff_diff.bootstrap_utils import compute_effect_bootstrap_stats
20
+ from diff_diff.linalg import LinearRegression
21
+ from diff_diff.results import _format_survey_block, _get_significance_stars
22
+ from diff_diff.utils import (
23
+ safe_inference,
24
+ )
25
+ from diff_diff.utils import (
26
+ within_transform as _within_transform_util,
27
+ )
28
+
29
+
30
@dataclass
class SunAbrahamResults:
    """
    Results from Sun-Abraham (2021) interaction-weighted estimation.

    Attributes
    ----------
    event_study_effects : dict
        Dictionary mapping relative time to effect dictionaries. Each effect
        dictionary carries at least the keys 'effect', 'se', 't_stat',
        'p_value' and 'conf_int' (a ``(lower, upper)`` pair); these are the
        keys read by :meth:`summary` and :meth:`to_dataframe`.
    overall_att : float
        Overall average treatment effect (weighted average of post-treatment effects).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT, at level ``1 - alpha``.
    cohort_weights : dict
        Dictionary mapping relative time to cohort weight dictionaries.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units.
    alpha : float
        Significance level used for confidence intervals.
    control_group : str
        Type of control group used.
    bootstrap_results : SABootstrapResults, optional
        Bootstrap inference results, if bootstrap inference was run.
    cohort_effects : dict, optional
        Mapping from ``(cohort, relative_period)`` to a dictionary with the
        keys 'effect' and 'se' (and optionally 'weight'). Required by
        ``to_dataframe(level="cohort")``.
    survey_metadata : object, optional
        Survey design metadata (a ``SurveyMetadata`` instance from
        ``diff_diff.survey``); when present, :meth:`summary` prints a
        survey-design block.
    """

    event_study_effects: Dict[int, Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    cohort_weights: Dict[int, Dict[Any, float]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    bootstrap_results: Optional["SABootstrapResults"] = field(default=None, repr=False)
    cohort_effects: Optional[Dict[Tuple[Any, int], Dict[str, Any]]] = field(
        default=None, repr=False
    )
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        n_rel_periods = len(self.event_study_effects)
        return (
            f"SunAbrahamResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_rel_periods={n_rel_periods})"
        )

    @property
    def coef_var(self) -> float:
        """
        Coefficient of variation: SE / |overall ATT|.

        NaN when the SE is non-finite or negative, or when the ATT is
        non-finite or exactly 0.
        """
        if not (np.isfinite(self.overall_se) and self.overall_se >= 0):
            return np.nan
        if not np.isfinite(self.overall_att) or self.overall_att == 0:
            return np.nan
        return self.overall_se / abs(self.overall_att)

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Requested significance level. The stored confidence intervals
            were computed at the estimation ``alpha`` and cannot be rebuilt
            here, so the interval label always reflects ``self.alpha``. A
            different value triggers a ``UserWarning`` rather than silently
            mislabelling the interval.

        Returns
        -------
        str
            Formatted summary.
        """
        if alpha is not None and alpha != self.alpha:
            warnings.warn(
                f"summary(alpha={alpha}) cannot recompute confidence intervals; "
                f"showing intervals computed at the estimation alpha={self.alpha}.",
                UserWarning,
                stacklevel=2,
            )
        # ``:g`` rounds to 6 significant digits, so alpha=0.001 prints
        # "99.9" instead of being truncated to "99" by int(); 0.05 -> "95".
        conf_level = f"{(1 - self.alpha) * 100:g}"

        lines = [
            "=" * 85,
            "Sun-Abraham Interaction-Weighted Estimator Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            "",
        ]

        # Add survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, 85))

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
            ]
        )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")

        lines.append("")

        # Event study effects
        lines.extend(
            [
                "-" * 85,
                "Event Study (Dynamic) Effects".center(85),
                "-" * 85,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
        )

        for rel_t in sorted(self.event_study_effects.keys()):
            eff = self.event_study_effects[rel_t]
            sig = _get_significance_stars(eff["p_value"])
            lines.append(
                f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
            )

        lines.extend(["-" * 85, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "event_study") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="event_study"
            Level of aggregation: "event_study" or "cohort".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame, one row per relative period (or per
            cohort/relative-period pair), sorted by key. The column set is
            fixed, so an empty result still has the expected columns.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or ``level="cohort"`` is requested but
            ``cohort_effects`` is None.
        """
        if level == "event_study":
            columns = [
                "relative_period",
                "effect",
                "se",
                "t_stat",
                "p_value",
                "conf_int_lower",
                "conf_int_upper",
            ]
            rows = [
                {
                    "relative_period": rel_t,
                    "effect": data["effect"],
                    "se": data["se"],
                    "t_stat": data["t_stat"],
                    "p_value": data["p_value"],
                    "conf_int_lower": data["conf_int"][0],
                    "conf_int_upper": data["conf_int"][1],
                }
                for rel_t, data in sorted(self.event_study_effects.items())
            ]
            # Explicit columns keep the schema even when there are no rows.
            return pd.DataFrame(rows, columns=columns)

        elif level == "cohort":
            if self.cohort_effects is None:
                raise ValueError(
                    "Cohort-level effects not available. "
                    "They are computed internally but not stored by default."
                )
            columns = ["cohort", "relative_period", "effect", "se", "weight"]
            rows = [
                {
                    "cohort": cohort,
                    "relative_period": rel_t,
                    "effect": data["effect"],
                    "se": data["se"],
                    "weight": data.get("weight", np.nan),
                }
                for (cohort, rel_t), data in sorted(self.cohort_effects.items())
            ]
            return pd.DataFrame(rows, columns=columns)

        else:
            raise ValueError(f"Unknown level: {level}. Use 'event_study' or 'cohort'.")

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
266
+
267
+
268
@dataclass()
class SABootstrapResults:
    """
    Container for bootstrap-based inference produced by the Sun-Abraham estimator.

    Attributes
    ----------
    n_bootstrap : int
        How many bootstrap replications were drawn.
    weight_type : str
        Label identifying the bootstrap scheme (e.g. "pairs").
    alpha : float
        Significance level the bootstrap confidence intervals target.
    overall_att_se : float
        Bootstrap standard error of the overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap ``(lower, upper)`` confidence interval of the overall ATT.
    overall_att_p_value : float
        Bootstrap p-value of the overall ATT.
    event_study_ses : Dict[int, float]
        Relative period -> bootstrap standard error.
    event_study_cis : Dict[int, Tuple[float, float]]
        Relative period -> bootstrap confidence interval.
    event_study_p_values : Dict[int, float]
        Relative period -> bootstrap p-value.
    bootstrap_distribution : Optional[np.ndarray]
        Replicated overall-ATT draws; omitted from ``repr`` because it may be large.
    """

    # --- run configuration -------------------------------------------------
    n_bootstrap: int
    weight_type: str
    alpha: float

    # --- overall ATT inference --------------------------------------------
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float

    # --- per-relative-period inference -------------------------------------
    event_study_ses: Dict[int, float]
    event_study_cis: Dict[int, Tuple[float, float]]
    event_study_p_values: Dict[int, float]

    # --- raw draws (optional, hidden from repr) ----------------------------
    bootstrap_distribution: Optional[np.ndarray] = field(repr=False, default=None)
307
+
308
+
309
+ class SunAbraham:
310
+ """
311
+ Sun-Abraham (2021) interaction-weighted estimator for staggered DiD.
312
+
313
+ This estimator provides event-study coefficients using a saturated
314
+ TWFE regression with cohort × relative-time interactions, following
315
+ the methodology in Sun & Abraham (2021).
316
+
317
+ The estimation procedure follows three steps:
318
+ 1. Run a saturated TWFE regression with cohort × relative-time dummies
319
+ 2. Compute cohort shares (weights) at each relative time
320
+ 3. Aggregate cohort-specific effects using interaction weights
321
+
322
+ This avoids the negative weighting problem of standard TWFE and provides
323
+ consistent event-study estimates under treatment effect heterogeneity.
324
+
325
+ Parameters
326
+ ----------
327
+ control_group : str, default="never_treated"
328
+ Which units to use as controls:
329
+ - "never_treated": Use only never-treated units (recommended)
330
+ - "not_yet_treated": Use never-treated and not-yet-treated units
331
+ anticipation : int, default=0
332
+ Number of periods before treatment where effects may occur.
333
+ alpha : float, default=0.05
334
+ Significance level for confidence intervals.
335
+ cluster : str, optional
336
+ Column name for cluster-robust standard errors.
337
+ If None, clusters at the unit level by default.
338
+ n_bootstrap : int, default=0
339
+ Number of bootstrap iterations for inference.
340
+ If 0, uses analytical cluster-robust standard errors.
341
+ seed : int, optional
342
+ Random seed for reproducibility.
343
+ rank_deficient_action : str, default="warn"
344
+ Action when design matrix is rank-deficient (linearly dependent columns):
345
+ - "warn": Issue warning and drop linearly dependent columns (default)
346
+ - "error": Raise ValueError
347
+ - "silent": Drop columns silently without warning
348
+
349
+ Attributes
350
+ ----------
351
+ results_ : SunAbrahamResults
352
+ Estimation results after calling fit().
353
+ is_fitted_ : bool
354
+ Whether the model has been fitted.
355
+
356
+ Examples
357
+ --------
358
+ Basic usage:
359
+
360
+ >>> import pandas as pd
361
+ >>> from diff_diff import SunAbraham
362
+ >>>
363
+ >>> # Panel data with staggered treatment
364
+ >>> data = pd.DataFrame({
365
+ ... 'unit': [...],
366
+ ... 'time': [...],
367
+ ... 'outcome': [...],
368
+ ... 'first_treat': [...] # 0 for never-treated
369
+ ... })
370
+ >>>
371
+ >>> sa = SunAbraham()
372
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
373
+ ... time='time', first_treat='first_treat')
374
+ >>> results.print_summary()
375
+
376
+ With covariates:
377
+
378
+ >>> sa = SunAbraham()
379
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
380
+ ... time='time', first_treat='first_treat',
381
+ ... covariates=['age', 'income'])
382
+
383
+ Notes
384
+ -----
385
+ The Sun-Abraham estimator uses a saturated regression approach:
386
+
387
+ Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × 1(G_i=g) × D_{it}^e] + X'γ + ε_it
388
+
389
+ where:
390
+ - α_i = unit fixed effects
391
+ - λ_t = time fixed effects
392
+ - G_i = unit i's treatment cohort (first treatment period)
393
+ - D_{it}^e = indicator for being e periods from treatment
394
+ - δ_{g,e} = cohort-specific effect (CATT) at relative time e
395
+
396
+ The event-study coefficients are then computed as:
397
+
398
+ β_e = Σ_g w_{g,e} × δ_{g,e}
399
+
400
+ where w_{g,e} is the share of cohort g in the treated population at
401
+ relative time e (interaction weights).
402
+
403
+ Compared to Callaway-Sant'Anna:
404
+ - SA uses saturated regression; CS uses 2x2 DiD comparisons
405
+ - SA can be more efficient when model is correctly specified
406
+ - Both are consistent under heterogeneous treatment effects
407
+ - Running both provides a useful robustness check
408
+
409
+ References
410
+ ----------
411
+ Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in
412
+ event studies with heterogeneous treatment effects. Journal of
413
+ Econometrics, 225(2), 175-199.
414
+ """
415
+
416
def __init__(
    self,
    control_group: str = "never_treated",
    anticipation: int = 0,
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
):
    """
    Validate and store the estimator configuration.

    Raises
    ------
    ValueError
        If ``control_group`` or ``rank_deficient_action`` is not a supported
        option, ``anticipation`` or ``n_bootstrap`` is negative, or ``alpha``
        is not strictly between 0 and 1.
    """
    if control_group not in ["never_treated", "not_yet_treated"]:
        raise ValueError(
            f"control_group must be 'never_treated' or 'not_yet_treated', "
            f"got '{control_group}'"
        )

    if rank_deficient_action not in ["warn", "error", "silent"]:
        raise ValueError(
            f"rank_deficient_action must be 'warn', 'error', or 'silent', "
            f"got '{rank_deficient_action}'"
        )

    # fit() uses reference period -1 - anticipation; a negative value would
    # push the reference onto e >= 0, i.e. a post-treatment period.
    if anticipation < 0:
        raise ValueError(f"anticipation must be non-negative, got {anticipation}")

    # Also rejects NaN, since every comparison with NaN is False.
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha}")

    if n_bootstrap < 0:
        raise ValueError(f"n_bootstrap must be non-negative, got {n_bootstrap}")

    self.control_group = control_group
    self.anticipation = anticipation
    self.alpha = alpha
    self.cluster = cluster
    self.n_bootstrap = n_bootstrap
    self.seed = seed
    self.rank_deficient_action = rank_deficient_action

    self.is_fitted_ = False
    self.results_: Optional[SunAbrahamResults] = None
    self._reference_period = -1  # Will be set during fit
449
+
450
+ def fit(
451
+ self,
452
+ data: pd.DataFrame,
453
+ outcome: str,
454
+ unit: str,
455
+ time: str,
456
+ first_treat: str,
457
+ covariates: Optional[List[str]] = None,
458
+ survey_design: object = None,
459
+ ) -> SunAbrahamResults:
460
+ """
461
+ Fit the Sun-Abraham estimator using saturated regression.
462
+
463
+ Parameters
464
+ ----------
465
+ data : pd.DataFrame
466
+ Panel data with unit and time identifiers.
467
+ outcome : str
468
+ Name of outcome variable column.
469
+ unit : str
470
+ Name of unit identifier column.
471
+ time : str
472
+ Name of time period column.
473
+ first_treat : str
474
+ Name of column indicating when unit was first treated.
475
+ Use 0 (or np.inf) for never-treated units.
476
+ covariates : list, optional
477
+ List of covariate column names to include in regression.
478
+ survey_design : SurveyDesign, optional
479
+ Survey design specification for design-based inference.
480
+ Supports weighted estimation and Taylor series linearization
481
+ variance with strata, PSU, and FPC.
482
+
483
+ Returns
484
+ -------
485
+ SunAbrahamResults
486
+ Object containing all estimation results.
487
+
488
+ Raises
489
+ ------
490
+ ValueError
491
+ If required columns are missing or data validation fails.
492
+ """
493
+ # Validate inputs
494
+ required_cols = [outcome, unit, time, first_treat]
495
+ if covariates:
496
+ required_cols.extend(covariates)
497
+
498
+ missing = [c for c in required_cols if c not in data.columns]
499
+ if missing:
500
+ raise ValueError(f"Missing columns: {missing}")
501
+
502
+ # Resolve survey design if provided
503
+ from diff_diff.survey import (
504
+ _resolve_effective_cluster,
505
+ _resolve_survey_for_fit,
506
+ _validate_unit_constant_survey,
507
+ )
508
+
509
+ resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
510
+ _resolve_survey_for_fit(survey_design, data, "analytical")
511
+ )
512
+
513
+ # Validate survey columns are constant within units (required for
514
+ # unit-level collapse in Rao-Wu bootstrap)
515
+ if resolved_survey is not None:
516
+ _validate_unit_constant_survey(data, unit, survey_design)
517
+
518
+ _uses_replicate_sa = resolved_survey is not None and resolved_survey.uses_replicate_variance
519
+ if _uses_replicate_sa and self.n_bootstrap > 0:
520
+ raise ValueError(
521
+ "Cannot use n_bootstrap > 0 with replicate-weight survey designs. "
522
+ "Replicate weights provide their own variance estimation."
523
+ )
524
+
525
+ # Bootstrap + survey supported via Rao-Wu rescaled bootstrap.
526
+ # Determine Rao-Wu eligibility from the *original* survey_design
527
+ # (before cluster-as-PSU injection which adds PSU to weights-only designs).
528
+ _use_rao_wu = False
529
+ if survey_design is not None and resolved_survey is not None:
530
+ _has_explicit_strata = getattr(survey_design, "strata", None) is not None
531
+ _has_explicit_psu = getattr(survey_design, "psu", None) is not None
532
+ _has_explicit_fpc = getattr(survey_design, "fpc", None) is not None
533
+ if _has_explicit_strata or _has_explicit_psu or _has_explicit_fpc:
534
+ _use_rao_wu = True
535
+
536
+ # Create working copy
537
+ df = data.copy()
538
+
539
+ # Ensure numeric types
540
+ df[time] = pd.to_numeric(df[time])
541
+ df[first_treat] = pd.to_numeric(df[first_treat])
542
+
543
+ # Never-treated indicator (must precede treatment_groups to exclude np.inf)
544
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
545
+ # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
546
+ df.loc[df[first_treat] == np.inf, first_treat] = 0
547
+
548
+ # Identify groups and time periods
549
+ time_periods = sorted(df[time].unique())
550
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
551
+
552
+ # Get unique units
553
+ unit_info = (
554
+ df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index()
555
+ )
556
+
557
+ n_treated_units = int((unit_info[first_treat] > 0).sum())
558
+ n_control_units = int((unit_info["_never_treated"]).sum())
559
+
560
+ if n_control_units == 0:
561
+ raise ValueError("No never-treated units found. Check 'first_treat' column.")
562
+
563
+ if len(treatment_groups) == 0:
564
+ raise ValueError("No treated units found. Check 'first_treat' column.")
565
+
566
+ # Compute relative time for each observation (vectorized)
567
+ df["_rel_time"] = np.where(df[first_treat] > 0, df[time] - df[first_treat], np.nan)
568
+
569
+ # Identify the range of relative time periods to estimate
570
+ rel_times_by_cohort = {}
571
+ for g in treatment_groups:
572
+ g_times = df[df[first_treat] == g][time].unique()
573
+ rel_times_by_cohort[g] = sorted([t - g for t in g_times])
574
+
575
+ # Find all relative time values
576
+ all_rel_times: set = set()
577
+ for g, rel_times in rel_times_by_cohort.items():
578
+ all_rel_times.update(rel_times)
579
+
580
+ all_rel_times_sorted = sorted(all_rel_times)
581
+
582
+ # Use full range of relative times (no artificial truncation, matches R's fixest::sunab())
583
+ min_rel = min(all_rel_times_sorted)
584
+ max_rel = max(all_rel_times_sorted)
585
+
586
+ # Reference period: last pre-treatment period (typically -1)
587
+ self._reference_period = -1 - self.anticipation
588
+
589
+ # Get relative periods to estimate (excluding reference)
590
+ rel_periods_to_estimate = [
591
+ e
592
+ for e in all_rel_times_sorted
593
+ if min_rel <= e <= max_rel and e != self._reference_period
594
+ ]
595
+
596
+ # Determine cluster variable
597
+ cluster_var = self.cluster if self.cluster is not None else unit
598
+
599
+ # Filter data based on control_group setting
600
+ if self.control_group == "never_treated":
601
+ # Only keep never-treated as controls
602
+ df_reg = df[df["_never_treated"] | (df[first_treat] > 0)].copy()
603
+ else:
604
+ # Keep all units (not_yet_treated will be handled by the regression)
605
+ df_reg = df.copy()
606
+
607
+ # Resolve effective cluster and inject cluster-as-PSU
608
+ cluster_ids_raw = df_reg[cluster_var].values if cluster_var in df_reg.columns else None
609
+ effective_cluster_ids = _resolve_effective_cluster(
610
+ resolved_survey, cluster_ids_raw, cluster_var if self.cluster is not None else None
611
+ )
612
+ if resolved_survey is not None and effective_cluster_ids is not None:
613
+ from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata
614
+
615
+ resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
616
+ if resolved_survey.psu is not None and survey_metadata is not None:
617
+ raw_w = (
618
+ data[survey_design.weights].values.astype(np.float64)
619
+ if survey_design.weights
620
+ else np.ones(len(data), dtype=np.float64)
621
+ )
622
+ survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
623
+
624
+ # Fit saturated regression
625
+ (
626
+ cohort_effects,
627
+ cohort_ses,
628
+ vcov_cohort,
629
+ coef_index_map,
630
+ ) = self._fit_saturated_regression(
631
+ df_reg,
632
+ outcome,
633
+ unit,
634
+ time,
635
+ first_treat,
636
+ treatment_groups,
637
+ rel_periods_to_estimate,
638
+ covariates,
639
+ cluster_var,
640
+ survey_weights=survey_weights,
641
+ survey_weight_type=survey_weight_type,
642
+ # For replicate designs: pass None to prevent LinearRegression from
643
+ # computing bogus replicate vcov on already-demeaned data. We
644
+ # override vcov_cohort below with the correct estimator-level refit.
645
+ resolved_survey=None if _uses_replicate_sa else resolved_survey,
646
+ )
647
+
648
+ # Replicate variance override: fully refit the IW estimator per
649
+ # replicate, including recomputing cohort-share aggregation weights
650
+ # from w_r, so replicate SEs reflect the complete estimator.
651
+ _n_valid_rep_sa = None
652
+ if _uses_replicate_sa:
653
+ from diff_diff.survey import compute_replicate_refit_variance
654
+
655
+ # The refit returns [overall_att, es_e0, es_e1, ...] after
656
+ # full re-aggregation with replicate-weighted cohort shares.
657
+ _sa_rel_periods = list(rel_periods_to_estimate)
658
+
659
+ def _refit_sa(w_r):
660
+ # Drop zero-weight obs for within-transform safety
661
+ nz = w_r > 0
662
+ df_reg_nz = df_reg[nz] if not np.all(nz) else df_reg
663
+ w_nz = w_r[nz] if not np.all(nz) else w_r
664
+ ce_r, _, vcov_r, cim_r = self._fit_saturated_regression(
665
+ df_reg_nz,
666
+ outcome,
667
+ unit,
668
+ time,
669
+ first_treat,
670
+ treatment_groups,
671
+ _sa_rel_periods,
672
+ covariates,
673
+ cluster_var,
674
+ survey_weights=w_nz,
675
+ survey_weight_type=survey_weight_type,
676
+ resolved_survey=None,
677
+ )
678
+ # Create temp weight column for IW aggregation with w_r
679
+ # Use full w_r (including zeros) for correct mass computation
680
+ _wt_col = "_rep_wt"
681
+ df[_wt_col] = w_r
682
+ es_r, _ = self._compute_iw_effects(
683
+ df,
684
+ unit,
685
+ first_treat,
686
+ treatment_groups,
687
+ _sa_rel_periods,
688
+ ce_r,
689
+ {},
690
+ vcov_r,
691
+ cim_r,
692
+ survey_weight_col=_wt_col,
693
+ )
694
+ att_r, _ = self._compute_overall_att(
695
+ df,
696
+ first_treat,
697
+ es_r,
698
+ ce_r,
699
+ _,
700
+ vcov_r,
701
+ cim_r,
702
+ survey_weight_col=_wt_col,
703
+ )
704
+ results = [att_r]
705
+ for e in _sa_rel_periods:
706
+ results.append(es_r[e]["effect"] if e in es_r else np.nan)
707
+ return np.array(results)
708
+
709
+ # Resolve survey weight column name for cohort aggregation
710
+ survey_weight_col = (
711
+ survey_design.weights
712
+ if survey_design is not None
713
+ and hasattr(survey_design, "weights")
714
+ and survey_design.weights
715
+ else None
716
+ )
717
+
718
+ # Survey degrees of freedom for t-distribution inference
719
+ _sa_survey_df = (
720
+ max(survey_metadata.df_survey, 1)
721
+ if survey_metadata is not None and survey_metadata.df_survey is not None
722
+ else None
723
+ )
724
+ # Replicate df: rank-deficient → NaN inference (dropped-replicate
725
+ # override happens after replicate refit below)
726
+ if _uses_replicate_sa and _sa_survey_df is None:
727
+ _sa_survey_df = 0 # rank-deficient replicate → NaN inference
728
+
729
+ # Compute interaction-weighted event study effects
730
+ event_study_effects, cohort_weights = self._compute_iw_effects(
731
+ df,
732
+ unit,
733
+ first_treat,
734
+ treatment_groups,
735
+ rel_periods_to_estimate,
736
+ cohort_effects,
737
+ cohort_ses,
738
+ vcov_cohort,
739
+ coef_index_map,
740
+ survey_weight_col=survey_weight_col,
741
+ survey_df=_sa_survey_df,
742
+ )
743
+
744
+ # Compute overall ATT (average of post-treatment effects)
745
+ overall_att, overall_se = self._compute_overall_att(
746
+ df,
747
+ first_treat,
748
+ event_study_effects,
749
+ cohort_effects,
750
+ cohort_weights,
751
+ vcov_cohort,
752
+ coef_index_map,
753
+ survey_weight_col=survey_weight_col,
754
+ )
755
+
756
+ overall_t, overall_p, overall_ci = safe_inference(
757
+ overall_att, overall_se, alpha=self.alpha, df=_sa_survey_df
758
+ )
759
+
760
+ # Replicate variance override: refit fully re-aggregated estimates
761
+ if _uses_replicate_sa:
762
+ # Build full-sample estimate vector from actual outputs
763
+ _full_est_sa = [overall_att]
764
+ for e in _sa_rel_periods:
765
+ _full_est_sa.append(
766
+ event_study_effects[e]["effect"] if e in event_study_effects else np.nan
767
+ )
768
+
769
+ _vcov_sa, _n_valid_rep_sa = compute_replicate_refit_variance(
770
+ _refit_sa, np.array(_full_est_sa), resolved_survey
771
+ )
772
+
773
+ # Override df if replicates dropped
774
+ if _n_valid_rep_sa < resolved_survey.n_replicates:
775
+ _sa_survey_df = _n_valid_rep_sa - 1 if _n_valid_rep_sa > 1 else 0
776
+ if survey_metadata is not None:
777
+ survey_metadata.df_survey = (
778
+ _sa_survey_df if _sa_survey_df and _sa_survey_df > 0 else None
779
+ )
780
+
781
+ # Override overall ATT SE
782
+ overall_se = float(np.sqrt(max(_vcov_sa[0, 0], 0.0)))
783
+ overall_t, overall_p, overall_ci = safe_inference(
784
+ overall_att, overall_se, alpha=self.alpha, df=_sa_survey_df
785
+ )
786
+
787
+ # Override event-study SEs
788
+ for i, e in enumerate(_sa_rel_periods):
789
+ if e in event_study_effects and np.isfinite(event_study_effects[e]["effect"]):
790
+ se_e = float(np.sqrt(max(_vcov_sa[1 + i, 1 + i], 0.0)))
791
+ eff_e = event_study_effects[e]["effect"]
792
+ t_e, p_e, ci_e = safe_inference(eff_e, se_e, alpha=self.alpha, df=_sa_survey_df)
793
+ event_study_effects[e]["se"] = se_e
794
+ event_study_effects[e]["t_stat"] = t_e
795
+ event_study_effects[e]["p_value"] = p_e
796
+ event_study_effects[e]["conf_int"] = ci_e
797
+
798
+ # Cohort-level replicate SEs: second refit for raw (g,e) coefficients
799
+ _keys_ordered = sorted(coef_index_map.keys(), key=lambda k: coef_index_map[k])
800
+ _full_cohort_vec = np.array([cohort_effects.get(k, np.nan) for k in _keys_ordered])
801
+
802
+ def _refit_sa_cohort(w_r):
803
+ nz = w_r > 0
804
+ df_reg_nz = df_reg[nz] if not np.all(nz) else df_reg
805
+ w_nz = w_r[nz] if not np.all(nz) else w_r
806
+ ce_r, _, _, _ = self._fit_saturated_regression(
807
+ df_reg_nz,
808
+ outcome,
809
+ unit,
810
+ time,
811
+ first_treat,
812
+ treatment_groups,
813
+ _sa_rel_periods,
814
+ covariates,
815
+ cluster_var,
816
+ survey_weights=w_nz,
817
+ survey_weight_type=survey_weight_type,
818
+ resolved_survey=None,
819
+ )
820
+ return np.array([ce_r.get(k, np.nan) for k in _keys_ordered])
821
+
822
+ _vcov_cohort_rep, _ = compute_replicate_refit_variance(
823
+ _refit_sa_cohort, _full_cohort_vec, resolved_survey
824
+ )
825
+ for key in _keys_ordered:
826
+ idx = coef_index_map[key]
827
+ cohort_ses[key] = float(np.sqrt(max(_vcov_cohort_rep[idx, idx], 0.0)))
828
+
829
+ # Run bootstrap if requested
830
+ bootstrap_results = None
831
+ if self.n_bootstrap > 0:
832
+ bootstrap_results = self._run_bootstrap(
833
+ df=df_reg,
834
+ outcome=outcome,
835
+ unit=unit,
836
+ time=time,
837
+ first_treat=first_treat,
838
+ treatment_groups=treatment_groups,
839
+ rel_periods_to_estimate=rel_periods_to_estimate,
840
+ covariates=covariates,
841
+ cluster_var=cluster_var,
842
+ original_event_study=event_study_effects,
843
+ original_overall_att=overall_att,
844
+ resolved_survey=resolved_survey,
845
+ survey_weights=survey_weights,
846
+ survey_weight_type=survey_weight_type,
847
+ survey_weight_col=survey_weight_col,
848
+ use_rao_wu=_use_rao_wu,
849
+ )
850
+
851
+ # Update results with bootstrap inference
852
+ overall_se = bootstrap_results.overall_att_se
853
+ overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
854
+ overall_p = bootstrap_results.overall_att_p_value
855
+ overall_ci = bootstrap_results.overall_att_ci
856
+
857
+ # Update event study effects
858
+ for e in event_study_effects:
859
+ if e in bootstrap_results.event_study_ses:
860
+ event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
861
+ event_study_effects[e]["conf_int"] = bootstrap_results.event_study_cis[e]
862
+ event_study_effects[e]["p_value"] = bootstrap_results.event_study_p_values[e]
863
+ eff_val = event_study_effects[e]["effect"]
864
+ se_val = event_study_effects[e]["se"]
865
+ event_study_effects[e]["t_stat"] = safe_inference(
866
+ eff_val, se_val, alpha=self.alpha
867
+ )[0]
868
+
869
+ # Convert cohort effects to storage format
870
+ cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {}
871
+ for (g, e), effect in cohort_effects.items():
872
+ weight = cohort_weights.get(e, {}).get(g, 0.0)
873
+ se = cohort_ses.get((g, e), 0.0)
874
+ cohort_effects_storage[(g, e)] = {
875
+ "effect": effect,
876
+ "se": se,
877
+ "weight": weight,
878
+ }
879
+
880
+ # Store results
881
+ self.results_ = SunAbrahamResults(
882
+ event_study_effects=event_study_effects,
883
+ overall_att=overall_att,
884
+ overall_se=overall_se,
885
+ overall_t_stat=overall_t,
886
+ overall_p_value=overall_p,
887
+ overall_conf_int=overall_ci,
888
+ cohort_weights=cohort_weights,
889
+ groups=treatment_groups,
890
+ time_periods=time_periods,
891
+ n_obs=len(df),
892
+ n_treated_units=n_treated_units,
893
+ n_control_units=n_control_units,
894
+ alpha=self.alpha,
895
+ control_group=self.control_group,
896
+ bootstrap_results=bootstrap_results,
897
+ cohort_effects=cohort_effects_storage,
898
+ survey_metadata=survey_metadata,
899
+ )
900
+
901
+ self.is_fitted_ = True
902
+ return self.results_
903
+
904
+ def _fit_saturated_regression(
905
+ self,
906
+ df: pd.DataFrame,
907
+ outcome: str,
908
+ unit: str,
909
+ time: str,
910
+ first_treat: str,
911
+ treatment_groups: List[Any],
912
+ rel_periods: List[int],
913
+ covariates: Optional[List[str]],
914
+ cluster_var: str,
915
+ survey_weights: Optional[np.ndarray] = None,
916
+ survey_weight_type: str = "pweight",
917
+ resolved_survey: object = None,
918
+ ) -> Tuple[
919
+ Dict[Tuple[Any, int], float],
920
+ Dict[Tuple[Any, int], float],
921
+ np.ndarray,
922
+ Dict[Tuple[Any, int], int],
923
+ ]:
924
+ """
925
+ Fit saturated TWFE regression with cohort × relative-time interactions.
926
+
927
+ Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε
928
+
929
+ Uses within-transformation for unit fixed effects and time dummies.
930
+
931
+ Returns
932
+ -------
933
+ cohort_effects : dict
934
+ Mapping (cohort, rel_period) -> effect estimate δ_{g,e}
935
+ cohort_ses : dict
936
+ Mapping (cohort, rel_period) -> standard error
937
+ vcov : np.ndarray
938
+ Variance-covariance matrix for cohort effects
939
+ coef_index_map : dict
940
+ Mapping (cohort, rel_period) -> index in coefficient vector
941
+ """
942
+ df = df.copy()
943
+
944
+ # Create cohort × relative-time interaction dummies
945
+ # Exclude reference period
946
+ # Build all columns at once to avoid fragmentation
947
+ interaction_data = {}
948
+ coef_index_map: Dict[Tuple[Any, int], int] = {}
949
+ idx = 0
950
+
951
+ for g in treatment_groups:
952
+ for e in rel_periods:
953
+ col_name = f"_D_{g}_{e}"
954
+ # Indicator: unit is in cohort g AND at relative time e
955
+ indicator = ((df[first_treat] == g) & (df["_rel_time"] == e)).astype(float)
956
+
957
+ # Only include if there are observations
958
+ if indicator.sum() > 0:
959
+ interaction_data[col_name] = indicator.values
960
+ coef_index_map[(g, e)] = idx
961
+ idx += 1
962
+
963
+ # Add all interaction columns at once
964
+ interaction_cols = list(interaction_data.keys())
965
+ if interaction_data:
966
+ interaction_df = pd.DataFrame(interaction_data, index=df.index)
967
+ df = pd.concat([df, interaction_df], axis=1)
968
+
969
+ if len(interaction_cols) == 0:
970
+ raise ValueError(
971
+ "No valid cohort × relative-time interactions found. " "Check your data structure."
972
+ )
973
+
974
+ # Apply within-transformation for unit and time fixed effects
975
+ variables_to_demean = [outcome] + interaction_cols
976
+ if covariates:
977
+ variables_to_demean.extend(covariates)
978
+
979
+ df_demeaned = _within_transform_util(
980
+ df, variables_to_demean, unit, time, suffix="_dm", weights=survey_weights
981
+ )
982
+
983
+ # Build design matrix
984
+ X_cols = [f"{col}_dm" for col in interaction_cols]
985
+ if covariates:
986
+ X_cols.extend([f"{cov}_dm" for cov in covariates])
987
+
988
+ X = df_demeaned[X_cols].values
989
+ y = df_demeaned[f"{outcome}_dm"].values
990
+
991
+ # Fit OLS using LinearRegression helper (more stable than manual X'X inverse)
992
+ cluster_ids = df_demeaned[cluster_var].values
993
+
994
+ # Degrees of freedom adjustment for absorbed unit and time fixed effects
995
+ n_units_fe = df[unit].nunique()
996
+ n_times_fe = df[time].nunique()
997
+ df_adj = n_units_fe + n_times_fe - 1
998
+
999
+ reg = LinearRegression(
1000
+ include_intercept=False, # Already demeaned, no intercept needed
1001
+ robust=True,
1002
+ cluster_ids=cluster_ids,
1003
+ rank_deficient_action=self.rank_deficient_action,
1004
+ weights=survey_weights,
1005
+ weight_type=survey_weight_type,
1006
+ survey_design=resolved_survey,
1007
+ ).fit(X, y, df_adjustment=df_adj)
1008
+
1009
+ vcov = reg.vcov_
1010
+
1011
+ # Extract cohort effects and standard errors using get_inference
1012
+ cohort_effects: Dict[Tuple[Any, int], float] = {}
1013
+ cohort_ses: Dict[Tuple[Any, int], float] = {}
1014
+
1015
+ n_interactions = len(interaction_cols)
1016
+ for (g, e), coef_idx in coef_index_map.items():
1017
+ inference = reg.get_inference(coef_idx)
1018
+ cohort_effects[(g, e)] = inference.coefficient
1019
+ cohort_ses[(g, e)] = inference.se
1020
+
1021
+ # Extract just the vcov for cohort effects (excluding covariates)
1022
+ assert vcov is not None
1023
+ vcov_cohort = vcov[:n_interactions, :n_interactions]
1024
+
1025
+ return cohort_effects, cohort_ses, vcov_cohort, coef_index_map
1026
+
1027
+ def _within_transform(
1028
+ self,
1029
+ df: pd.DataFrame,
1030
+ variables: List[str],
1031
+ unit: str,
1032
+ time: str,
1033
+ ) -> pd.DataFrame:
1034
+ """
1035
+ Apply two-way within transformation to remove unit and time fixed effects.
1036
+
1037
+ y_it - y_i. - y_.t + y_..
1038
+ """
1039
+ return _within_transform_util(df, variables, unit, time, suffix="_dm")
1040
+
1041
+ def _compute_iw_effects(
1042
+ self,
1043
+ df: pd.DataFrame,
1044
+ unit: str,
1045
+ first_treat: str,
1046
+ treatment_groups: List[Any],
1047
+ rel_periods: List[int],
1048
+ cohort_effects: Dict[Tuple[Any, int], float],
1049
+ cohort_ses: Dict[Tuple[Any, int], float],
1050
+ vcov_cohort: np.ndarray,
1051
+ coef_index_map: Dict[Tuple[Any, int], int],
1052
+ survey_weight_col: Optional[str] = None,
1053
+ survey_df: Optional[int] = None,
1054
+ ) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
1055
+ """
1056
+ Compute interaction-weighted event study effects.
1057
+
1058
+ β_e = Σ_g w_{g,e} × δ_{g,e}
1059
+
1060
+ where w_{g,e} = n_{g,e} / Σ_g n_{g,e} is the share of observations from cohort g
1061
+ at event-time e among all treated observations at that event-time.
1062
+
1063
+ When survey weights are provided, n_{g,e} is the survey-weighted mass
1064
+ (sum of weights) rather than raw observation counts, so the estimand
1065
+ reflects the survey-weighted cohort composition.
1066
+
1067
+ Returns
1068
+ -------
1069
+ event_study_effects : dict
1070
+ Dictionary mapping relative period to aggregated effect info.
1071
+ cohort_weights : dict
1072
+ Dictionary mapping relative period to cohort weight dictionary.
1073
+ """
1074
+ event_study_effects: Dict[int, Dict[str, Any]] = {}
1075
+ cohort_weights: Dict[int, Dict[Any, float]] = {}
1076
+
1077
+ # Pre-compute per-event-time observation mass: n_{g,e}
1078
+ # With survey weights, use weighted sum; otherwise raw counts.
1079
+ treated_mask = df[first_treat] > 0
1080
+ if survey_weight_col is not None and survey_weight_col in df.columns:
1081
+ event_time_counts = (
1082
+ df[treated_mask].groupby([first_treat, "_rel_time"])[survey_weight_col].sum()
1083
+ )
1084
+ else:
1085
+ event_time_counts = df[treated_mask].groupby([first_treat, "_rel_time"]).size()
1086
+
1087
+ for e in rel_periods:
1088
+ # Get cohorts that have observations at this relative time
1089
+ cohorts_at_e = [g for g in treatment_groups if (g, e) in cohort_effects]
1090
+
1091
+ if not cohorts_at_e:
1092
+ continue
1093
+
1094
+ # Compute IW weights: n_{g,e} / Σ_g n_{g,e}
1095
+ weights = {}
1096
+ total_size = 0
1097
+ for g in cohorts_at_e:
1098
+ n_g_e = event_time_counts.get((g, e), 0)
1099
+ weights[g] = n_g_e
1100
+ total_size += n_g_e
1101
+
1102
+ if total_size == 0:
1103
+ continue
1104
+
1105
+ # Normalize weights
1106
+ for g in weights:
1107
+ weights[g] = weights[g] / total_size
1108
+
1109
+ cohort_weights[e] = weights
1110
+
1111
+ # Compute weighted average effect
1112
+ agg_effect = 0.0
1113
+ for g in cohorts_at_e:
1114
+ w = weights[g]
1115
+ agg_effect += w * cohort_effects[(g, e)]
1116
+
1117
+ # Compute SE using delta method with vcov
1118
+ # Var(β_e) = w' Σ w where w is weight vector and Σ is vcov submatrix
1119
+ indices = [coef_index_map[(g, e)] for g in cohorts_at_e]
1120
+ weight_vec = np.array([weights[g] for g in cohorts_at_e])
1121
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
1122
+ agg_var = float(weight_vec @ vcov_subset @ weight_vec)
1123
+ agg_se = np.sqrt(max(agg_var, 0))
1124
+
1125
+ t_stat, p_val, ci = safe_inference(agg_effect, agg_se, alpha=self.alpha, df=survey_df)
1126
+
1127
+ event_study_effects[e] = {
1128
+ "effect": agg_effect,
1129
+ "se": agg_se,
1130
+ "t_stat": t_stat,
1131
+ "p_value": p_val,
1132
+ "conf_int": ci,
1133
+ "n_groups": len(cohorts_at_e),
1134
+ }
1135
+
1136
+ return event_study_effects, cohort_weights
1137
+
1138
+ def _compute_overall_att(
1139
+ self,
1140
+ df: pd.DataFrame,
1141
+ first_treat: str,
1142
+ event_study_effects: Dict[int, Dict[str, Any]],
1143
+ cohort_effects: Dict[Tuple[Any, int], float],
1144
+ cohort_weights: Dict[int, Dict[Any, float]],
1145
+ vcov_cohort: np.ndarray,
1146
+ coef_index_map: Dict[Tuple[Any, int], int],
1147
+ survey_weight_col: Optional[str] = None,
1148
+ ) -> Tuple[float, float]:
1149
+ """
1150
+ Compute overall ATT as weighted average of post-treatment effects.
1151
+
1152
+ When survey weights are provided, the per-period weights use
1153
+ survey-weighted mass rather than raw observation counts.
1154
+
1155
+ Returns (att, se) tuple.
1156
+ """
1157
+ post_effects = [(e, eff) for e, eff in event_study_effects.items() if e >= 0]
1158
+
1159
+ if not post_effects:
1160
+ return np.nan, np.nan
1161
+
1162
+ # Weight by (survey-weighted) mass of treated observations at each relative time
1163
+ post_weights = []
1164
+ post_estimates = []
1165
+
1166
+ for e, eff in post_effects:
1167
+ mask = (df["_rel_time"] == e) & (df[first_treat] > 0)
1168
+ if survey_weight_col is not None and survey_weight_col in df.columns:
1169
+ # No floor for survey weights — valid masses can be < 1
1170
+ n_at_e = df.loc[mask, survey_weight_col].sum()
1171
+ post_weights.append(n_at_e if n_at_e > 0 else 0.0)
1172
+ else:
1173
+ n_at_e = len(df[mask])
1174
+ post_weights.append(max(n_at_e, 1))
1175
+ post_estimates.append(eff["effect"])
1176
+
1177
+ post_weights_arr = np.array(post_weights, dtype=float)
1178
+ post_weights_arr = post_weights_arr / post_weights_arr.sum()
1179
+
1180
+ overall_att = float(np.sum(post_weights_arr * np.array(post_estimates)))
1181
+
1182
+ # Compute SE using delta method
1183
+ # Need to trace back through the full weighting scheme
1184
+ # ATT = Σ_e w_e × β_e = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e}
1185
+ # Collect all (g, e) pairs and their overall weights
1186
+ overall_weights_by_coef: Dict[Tuple[Any, int], float] = {}
1187
+
1188
+ for i, (e, _) in enumerate(post_effects):
1189
+ period_weight = post_weights_arr[i]
1190
+ if e in cohort_weights:
1191
+ for g, cw in cohort_weights[e].items():
1192
+ key = (g, e)
1193
+ if key in coef_index_map:
1194
+ if key not in overall_weights_by_coef:
1195
+ overall_weights_by_coef[key] = 0.0
1196
+ overall_weights_by_coef[key] += period_weight * cw
1197
+
1198
+ if not overall_weights_by_coef:
1199
+ # Fallback to simplified variance that ignores covariances between periods
1200
+ warnings.warn(
1201
+ "Could not construct full weight vector for overall ATT SE. "
1202
+ "Using simplified variance that ignores covariances between periods.",
1203
+ UserWarning,
1204
+ stacklevel=2,
1205
+ )
1206
+ overall_var = float(
1207
+ np.sum(
1208
+ (post_weights_arr**2) * np.array([eff["se"] ** 2 for _, eff in post_effects])
1209
+ )
1210
+ )
1211
+ return overall_att, np.sqrt(overall_var)
1212
+
1213
+ # Build full weight vector and compute variance
1214
+ indices = [coef_index_map[key] for key in overall_weights_by_coef.keys()]
1215
+ weight_vec = np.array(list(overall_weights_by_coef.values()))
1216
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
1217
+ overall_var = float(weight_vec @ vcov_subset @ weight_vec)
1218
+ overall_se = np.sqrt(max(overall_var, 0))
1219
+
1220
+ return overall_att, overall_se
1221
+
1222
+ def _run_bootstrap(
1223
+ self,
1224
+ df: pd.DataFrame,
1225
+ outcome: str,
1226
+ unit: str,
1227
+ time: str,
1228
+ first_treat: str,
1229
+ treatment_groups: List[Any],
1230
+ rel_periods_to_estimate: List[int],
1231
+ covariates: Optional[List[str]],
1232
+ cluster_var: str,
1233
+ original_event_study: Dict[int, Dict[str, Any]],
1234
+ original_overall_att: float,
1235
+ resolved_survey: object = None,
1236
+ survey_weights: Optional[np.ndarray] = None,
1237
+ survey_weight_type: str = "pweight",
1238
+ survey_weight_col: Optional[str] = None,
1239
+ use_rao_wu: bool = False,
1240
+ ) -> SABootstrapResults:
1241
+ """
1242
+ Run bootstrap for inference.
1243
+
1244
+ When use_rao_wu is True (survey design with explicit strata/PSU/FPC),
1245
+ uses Rao-Wu rescaled bootstrap (weight perturbation). Otherwise, uses
1246
+ pairs bootstrap (resampling units with replacement).
1247
+ """
1248
+ if self.n_bootstrap < 50:
1249
+ warnings.warn(
1250
+ f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
1251
+ "for reliable inference.",
1252
+ UserWarning,
1253
+ stacklevel=3,
1254
+ )
1255
+
1256
+ rng = np.random.default_rng(self.seed)
1257
+
1258
+ if use_rao_wu:
1259
+ return self._run_rao_wu_bootstrap(
1260
+ df=df,
1261
+ outcome=outcome,
1262
+ unit=unit,
1263
+ time=time,
1264
+ first_treat=first_treat,
1265
+ treatment_groups=treatment_groups,
1266
+ rel_periods_to_estimate=rel_periods_to_estimate,
1267
+ covariates=covariates,
1268
+ cluster_var=cluster_var,
1269
+ original_event_study=original_event_study,
1270
+ original_overall_att=original_overall_att,
1271
+ resolved_survey=resolved_survey,
1272
+ survey_weight_type=survey_weight_type,
1273
+ survey_weight_col=survey_weight_col,
1274
+ rng=rng,
1275
+ )
1276
+
1277
+ # --- Pairs bootstrap (non-survey or weights-only survey) ---
1278
+
1279
+ # Get unique units
1280
+ all_units = df[unit].unique()
1281
+ n_units = len(all_units)
1282
+
1283
+ # Pre-compute unit -> row indices mapping (avoids repeated boolean scans)
1284
+ unit_row_indices = {u: df.index[df[unit] == u].values for u in all_units}
1285
+ unit_row_counts = {u: len(idx) for u, idx in unit_row_indices.items()}
1286
+
1287
+ # Store bootstrap samples
1288
+ rel_periods = sorted(original_event_study.keys())
1289
+ bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
1290
+ bootstrap_overall = np.zeros(self.n_bootstrap)
1291
+
1292
+ for b in range(self.n_bootstrap):
1293
+ # Resample units with replacement (pairs bootstrap)
1294
+ boot_units = rng.choice(all_units, size=n_units, replace=True)
1295
+
1296
+ # Create bootstrap sample using pre-computed index mapping
1297
+ boot_indices = np.concatenate([unit_row_indices[u] for u in boot_units])
1298
+ df_b = df.iloc[boot_indices].copy()
1299
+
1300
+ # Reassign unique unit IDs (vectorized)
1301
+ rows_per_unit = np.array([unit_row_counts[u] for u in boot_units])
1302
+ df_b[unit] = np.repeat(np.arange(n_units), rows_per_unit)
1303
+
1304
+ # Recompute relative time (vectorized)
1305
+ df_b["_rel_time"] = np.where(
1306
+ df_b[first_treat] > 0, df_b[time] - df_b[first_treat], np.nan
1307
+ )
1308
+ # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
1309
+ df_b["_never_treated"] = (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)
1310
+
1311
+ try:
1312
+ # Extract survey weights from resampled data if present
1313
+ boot_survey_weights = None
1314
+ if survey_weight_col is not None and survey_weight_col in df_b.columns:
1315
+ boot_survey_weights = df_b[survey_weight_col].values
1316
+
1317
+ # Re-estimate saturated regression
1318
+ (
1319
+ cohort_effects_b,
1320
+ cohort_ses_b,
1321
+ vcov_b,
1322
+ coef_map_b,
1323
+ ) = self._fit_saturated_regression(
1324
+ df_b,
1325
+ outcome,
1326
+ unit,
1327
+ time,
1328
+ first_treat,
1329
+ treatment_groups,
1330
+ rel_periods_to_estimate,
1331
+ covariates,
1332
+ cluster_var,
1333
+ survey_weights=boot_survey_weights,
1334
+ survey_weight_type=survey_weight_type,
1335
+ resolved_survey=None, # Use explicit weights, not stale design
1336
+ )
1337
+
1338
+ # Compute IW effects for this bootstrap sample
1339
+ event_study_b, cohort_weights_b = self._compute_iw_effects(
1340
+ df_b,
1341
+ unit,
1342
+ first_treat,
1343
+ treatment_groups,
1344
+ rel_periods_to_estimate,
1345
+ cohort_effects_b,
1346
+ cohort_ses_b,
1347
+ vcov_b,
1348
+ coef_map_b,
1349
+ survey_weight_col=survey_weight_col,
1350
+ )
1351
+
1352
+ # Store bootstrap estimates
1353
+ for e in rel_periods:
1354
+ if e in event_study_b:
1355
+ bootstrap_effects[e][b] = event_study_b[e]["effect"]
1356
+ else:
1357
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1358
+
1359
+ # Compute overall ATT for this bootstrap sample
1360
+ overall_b, _ = self._compute_overall_att(
1361
+ df_b,
1362
+ first_treat,
1363
+ event_study_b,
1364
+ cohort_effects_b,
1365
+ cohort_weights_b,
1366
+ vcov_b,
1367
+ coef_map_b,
1368
+ survey_weight_col=survey_weight_col,
1369
+ )
1370
+ bootstrap_overall[b] = overall_b
1371
+
1372
+ except (ValueError, np.linalg.LinAlgError) as exc:
1373
+ # If bootstrap iteration fails, use original
1374
+ warnings.warn(
1375
+ f"Bootstrap iteration {b} failed: {exc}. Using original estimate.",
1376
+ UserWarning,
1377
+ stacklevel=2,
1378
+ )
1379
+ for e in rel_periods:
1380
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1381
+ bootstrap_overall[b] = original_overall_att
1382
+
1383
+ # Compute bootstrap statistics
1384
+ event_study_ses = {}
1385
+ event_study_cis = {}
1386
+ event_study_p_values = {}
1387
+
1388
+ for e in rel_periods:
1389
+ boot_dist = bootstrap_effects[e]
1390
+ original_effect = original_event_study[e]["effect"]
1391
+ se, ci, p_value = compute_effect_bootstrap_stats(
1392
+ original_effect,
1393
+ boot_dist,
1394
+ alpha=self.alpha,
1395
+ context=f"event study e={e}",
1396
+ )
1397
+ event_study_ses[e] = se
1398
+ event_study_cis[e] = ci
1399
+ event_study_p_values[e] = p_value
1400
+
1401
+ # Overall ATT statistics
1402
+ overall_se, overall_ci, overall_p = compute_effect_bootstrap_stats(
1403
+ original_overall_att,
1404
+ bootstrap_overall,
1405
+ alpha=self.alpha,
1406
+ context="overall ATT",
1407
+ )
1408
+
1409
+ return SABootstrapResults(
1410
+ n_bootstrap=self.n_bootstrap,
1411
+ weight_type="pairs",
1412
+ alpha=self.alpha,
1413
+ overall_att_se=overall_se,
1414
+ overall_att_ci=overall_ci,
1415
+ overall_att_p_value=overall_p,
1416
+ event_study_ses=event_study_ses,
1417
+ event_study_cis=event_study_cis,
1418
+ event_study_p_values=event_study_p_values,
1419
+ bootstrap_distribution=bootstrap_overall,
1420
+ )
1421
+
1422
+ def _run_rao_wu_bootstrap(
1423
+ self,
1424
+ df: pd.DataFrame,
1425
+ outcome: str,
1426
+ unit: str,
1427
+ time: str,
1428
+ first_treat: str,
1429
+ treatment_groups: List[Any],
1430
+ rel_periods_to_estimate: List[int],
1431
+ covariates: Optional[List[str]],
1432
+ cluster_var: str,
1433
+ original_event_study: Dict[int, Dict[str, Any]],
1434
+ original_overall_att: float,
1435
+ resolved_survey: object,
1436
+ survey_weight_type: str,
1437
+ survey_weight_col: Optional[str],
1438
+ rng: np.random.Generator,
1439
+ ) -> SABootstrapResults:
1440
+ """
1441
+ Run Rao-Wu rescaled bootstrap for survey-aware inference.
1442
+
1443
+ Instead of physically resampling units, each iteration generates
1444
+ rescaled observation weights via Rao-Wu (1988) weight perturbation.
1445
+ The rescaled weights feed into the existing WLS regression path.
1446
+ """
1447
+ from diff_diff.bootstrap_utils import generate_rao_wu_weights
1448
+ from diff_diff.survey import ResolvedSurveyDesign
1449
+
1450
+ # Column name for rescaled weights in the bootstrap DataFrame
1451
+ _rw_col = "__rw_boot_weight"
1452
+
1453
+ # Collapse survey design to unit level so Rao-Wu respects panel
1454
+ # structure: each unit gets one set of weights regardless of how
1455
+ # many time periods it has. Without this, when there is no
1456
+ # explicit PSU, generate_rao_wu_weights treats each observation as
1457
+ # its own PSU and different obs of the same unit can get different
1458
+ # weights, breaking panel semantics.
1459
+ all_units = df[unit].unique()
1460
+
1461
+ weights_unit = (
1462
+ pd.Series(resolved_survey.weights, index=df.index)
1463
+ .groupby(df[unit])
1464
+ .first()
1465
+ .reindex(all_units)
1466
+ .values.astype(np.float64)
1467
+ )
1468
+
1469
+ strata_unit = None
1470
+ if resolved_survey.strata is not None:
1471
+ strata_unit = (
1472
+ pd.Series(resolved_survey.strata, index=df.index)
1473
+ .groupby(df[unit])
1474
+ .first()
1475
+ .reindex(all_units)
1476
+ .values
1477
+ )
1478
+
1479
+ psu_unit = None
1480
+ if resolved_survey.psu is not None:
1481
+ psu_unit = (
1482
+ pd.Series(resolved_survey.psu, index=df.index)
1483
+ .groupby(df[unit])
1484
+ .first()
1485
+ .reindex(all_units)
1486
+ .values
1487
+ )
1488
+
1489
+ fpc_unit = None
1490
+ if resolved_survey.fpc is not None:
1491
+ fpc_unit = (
1492
+ pd.Series(resolved_survey.fpc, index=df.index)
1493
+ .groupby(df[unit])
1494
+ .first()
1495
+ .reindex(all_units)
1496
+ .values
1497
+ )
1498
+
1499
+ unit_resolved = ResolvedSurveyDesign(
1500
+ weights=weights_unit,
1501
+ weight_type=resolved_survey.weight_type,
1502
+ strata=strata_unit,
1503
+ psu=psu_unit,
1504
+ fpc=fpc_unit,
1505
+ n_strata=resolved_survey.n_strata,
1506
+ n_psu=resolved_survey.n_psu,
1507
+ lonely_psu=resolved_survey.lonely_psu,
1508
+ )
1509
+
1510
+ # Build unit -> row indices mapping for expanding unit-level weights
1511
+ unit_to_rows = {u: df.index[df[unit] == u].values for u in all_units}
1512
+ unit_order = {u: i for i, u in enumerate(all_units)}
1513
+
1514
+ # Store bootstrap samples
1515
+ rel_periods = sorted(original_event_study.keys())
1516
+ bootstrap_effects = {e: np.full(self.n_bootstrap, np.nan) for e in rel_periods}
1517
+ bootstrap_overall = np.full(self.n_bootstrap, np.nan)
1518
+
1519
+ for b in range(self.n_bootstrap):
1520
+ try:
1521
+ # Generate Rao-Wu rescaled weights at unit level
1522
+ unit_boot_weights = generate_rao_wu_weights(unit_resolved, rng)
1523
+
1524
+ # Expand unit-level weights to observation level
1525
+ boot_weights = np.empty(len(df), dtype=np.float64)
1526
+ for u, idx in unit_to_rows.items():
1527
+ boot_weights[idx] = unit_boot_weights[unit_order[u]]
1528
+
1529
+ # Drop observations with zero weight (PSUs not drawn in this
1530
+ # iteration) to avoid NaN/Inf in within-transformation.
1531
+ positive_mask = boot_weights > 0
1532
+ if positive_mask.sum() < 2:
1533
+ # Too few observations with positive weight
1534
+ raise ValueError("Rao-Wu iteration produced < 2 positive weights")
1535
+
1536
+ df_b = df[positive_mask].reset_index(drop=True)
1537
+ boot_weights_b = boot_weights[positive_mask]
1538
+ df_b[_rw_col] = boot_weights_b
1539
+
1540
+ # Verify we still have both treated and control observations
1541
+ has_treated = (df_b[first_treat] > 0).any()
1542
+ has_control = ((df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)).any()
1543
+ if not has_treated or not has_control:
1544
+ raise ValueError("Rao-Wu iteration dropped all treated or control units")
1545
+
1546
+ # Re-estimate saturated regression with rescaled weights.
1547
+ # Pass resolved_survey=None since inference comes from the
1548
+ # bootstrap distribution, not from within-iteration vcov.
1549
+ (
1550
+ cohort_effects_b,
1551
+ cohort_ses_b,
1552
+ vcov_b,
1553
+ coef_map_b,
1554
+ ) = self._fit_saturated_regression(
1555
+ df_b,
1556
+ outcome,
1557
+ unit,
1558
+ time,
1559
+ first_treat,
1560
+ treatment_groups,
1561
+ rel_periods_to_estimate,
1562
+ covariates,
1563
+ cluster_var,
1564
+ survey_weights=boot_weights_b,
1565
+ survey_weight_type=survey_weight_type,
1566
+ resolved_survey=None,
1567
+ )
1568
+
1569
+ # Compute IW effects using rescaled weights for cohort shares
1570
+ event_study_b, cohort_weights_b = self._compute_iw_effects(
1571
+ df_b,
1572
+ unit,
1573
+ first_treat,
1574
+ treatment_groups,
1575
+ rel_periods_to_estimate,
1576
+ cohort_effects_b,
1577
+ cohort_ses_b,
1578
+ vcov_b,
1579
+ coef_map_b,
1580
+ survey_weight_col=_rw_col,
1581
+ )
1582
+
1583
+ # Store bootstrap estimates
1584
+ for e in rel_periods:
1585
+ if e in event_study_b:
1586
+ bootstrap_effects[e][b] = event_study_b[e]["effect"]
1587
+ else:
1588
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1589
+
1590
+ # Compute overall ATT using rescaled weights
1591
+ overall_b, _ = self._compute_overall_att(
1592
+ df_b,
1593
+ first_treat,
1594
+ event_study_b,
1595
+ cohort_effects_b,
1596
+ cohort_weights_b,
1597
+ vcov_b,
1598
+ coef_map_b,
1599
+ survey_weight_col=_rw_col,
1600
+ )
1601
+ bootstrap_overall[b] = overall_b
1602
+
1603
+ except (ValueError, np.linalg.LinAlgError) as exc:
1604
+ # Failed draws stored as NaN (not original estimate) to avoid
1605
+ # shrinking bootstrap dispersion. compute_effect_bootstrap_stats
1606
+ # handles NaN draws via nanstd.
1607
+ warnings.warn(
1608
+ f"Bootstrap iteration {b} failed: {exc}. Storing NaN.",
1609
+ UserWarning,
1610
+ stacklevel=2,
1611
+ )
1612
+ for e in rel_periods:
1613
+ bootstrap_effects[e][b] = np.nan
1614
+ bootstrap_overall[b] = np.nan
1615
+
1616
+ # Compute bootstrap statistics
1617
+ event_study_ses = {}
1618
+ event_study_cis = {}
1619
+ event_study_p_values = {}
1620
+
1621
+ for e in rel_periods:
1622
+ boot_dist = bootstrap_effects[e]
1623
+ original_effect = original_event_study[e]["effect"]
1624
+ se, ci, p_value = compute_effect_bootstrap_stats(
1625
+ original_effect,
1626
+ boot_dist,
1627
+ alpha=self.alpha,
1628
+ context=f"event study e={e}",
1629
+ )
1630
+ event_study_ses[e] = se
1631
+ event_study_cis[e] = ci
1632
+ event_study_p_values[e] = p_value
1633
+
1634
+ # Overall ATT statistics
1635
+ overall_se, overall_ci, overall_p = compute_effect_bootstrap_stats(
1636
+ original_overall_att,
1637
+ bootstrap_overall,
1638
+ alpha=self.alpha,
1639
+ context="overall ATT",
1640
+ )
1641
+
1642
+ return SABootstrapResults(
1643
+ n_bootstrap=self.n_bootstrap,
1644
+ weight_type="rao_wu",
1645
+ alpha=self.alpha,
1646
+ overall_att_se=overall_se,
1647
+ overall_att_ci=overall_ci,
1648
+ overall_att_p_value=overall_p,
1649
+ event_study_ses=event_study_ses,
1650
+ event_study_cis=event_study_cis,
1651
+ event_study_p_values=event_study_p_values,
1652
+ bootstrap_distribution=bootstrap_overall,
1653
+ )
1654
+
1655
+ def get_params(self) -> Dict[str, Any]:
1656
+ """Get estimator parameters (sklearn-compatible)."""
1657
+ return {
1658
+ "control_group": self.control_group,
1659
+ "anticipation": self.anticipation,
1660
+ "alpha": self.alpha,
1661
+ "cluster": self.cluster,
1662
+ "n_bootstrap": self.n_bootstrap,
1663
+ "seed": self.seed,
1664
+ "rank_deficient_action": self.rank_deficient_action,
1665
+ }
1666
+
1667
+ def set_params(self, **params) -> "SunAbraham":
1668
+ """Set estimator parameters (sklearn-compatible)."""
1669
+ for key, value in params.items():
1670
+ if hasattr(self, key):
1671
+ setattr(self, key, value)
1672
+ else:
1673
+ raise ValueError(f"Unknown parameter: {key}")
1674
+ return self
1675
+
1676
+ def summary(self) -> str:
1677
+ """Get summary of estimation results."""
1678
+ if not self.is_fitted_:
1679
+ raise RuntimeError("Model must be fitted before calling summary()")
1680
+ assert self.results_ is not None
1681
+ return self.results_.summary()
1682
+
1683
+ def print_summary(self) -> None:
1684
+ """Print summary to stdout."""
1685
+ print(self.summary())