diff-diff 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1198 @@
1
+ """
2
+ Sun-Abraham Interaction-Weighted Estimator for staggered DiD.
3
+
4
+ Implements the estimator from Sun & Abraham (2021), "Estimating dynamic
5
+ treatment effects in event studies with heterogeneous treatment effects",
6
+ Journal of Econometrics.
7
+
8
+ This provides an alternative to Callaway-Sant'Anna using a saturated
9
+ regression with cohort × relative-time interactions.
10
+ """
11
+
12
+ import warnings
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from diff_diff.results import _get_significance_stars
20
+ from diff_diff.utils import (
21
+ compute_confidence_interval,
22
+ compute_p_value,
23
+ compute_robust_se,
24
+ )
25
+
26
+
27
@dataclass
class SunAbrahamResults:
    """
    Results from Sun-Abraham (2021) interaction-weighted estimation.

    Attributes
    ----------
    event_study_effects : dict
        Dictionary mapping relative time to effect dictionaries with keys:
        'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_groups'.
    overall_att : float
        Overall average treatment effect (weighted average of post-treatment effects).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    cohort_weights : dict
        Dictionary mapping relative time to cohort weight dictionaries.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units.
    alpha : float
        Significance level used for confidence intervals.
    control_group : str
        Type of control group used.
    """

    event_study_effects: Dict[int, Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    cohort_weights: Dict[int, Dict[Any, float]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    bootstrap_results: Optional["SABootstrapResults"] = field(default=None, repr=False)
    cohort_effects: Optional[Dict[Tuple[Any, int], Dict[str, Any]]] = field(
        default=None, repr=False
    )

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        n_rel_periods = len(self.event_study_effects)
        return (
            f"SunAbrahamResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_rel_periods={n_rel_periods})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation.

        Returns
        -------
        str
            Formatted summary.
        """
        # BUG FIX: `alpha or self.alpha` treated any falsy value (e.g. an
        # explicit 0.0) as "not provided"; compare against None instead.
        alpha = self.alpha if alpha is None else alpha
        # BUG FIX: int() truncates toward zero, so e.g. alpha=0.32 printed
        # "67%" because (1 - 0.32) * 100 == 67.999...; round instead.
        conf_level = round((1 - alpha) * 100)

        lines = [
            "=" * 85,
            "Sun-Abraham Interaction-Weighted Estimator Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            "",
        ]

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
                "",
            ]
        )

        # Event study effects
        lines.extend(
            [
                "-" * 85,
                "Event Study (Dynamic) Effects".center(85),
                "-" * 85,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
        )

        for rel_t in sorted(self.event_study_effects.keys()):
            eff = self.event_study_effects[rel_t]
            sig = _get_significance_stars(eff["p_value"])
            lines.append(
                f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
            )

        lines.extend(["-" * 85, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "event_study") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="event_study"
            Level of aggregation: "event_study" or "cohort".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame.

        Raises
        ------
        ValueError
            If `level` is unknown, or "cohort" is requested but
            cohort-level effects were not stored.
        """
        if level == "event_study":
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                rows.append(
                    {
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                    }
                )
            return pd.DataFrame(rows)

        elif level == "cohort":
            if self.cohort_effects is None:
                raise ValueError(
                    "Cohort-level effects not available. "
                    "They are computed internally but not stored by default."
                )
            rows = []
            for (cohort, rel_t), data in sorted(self.cohort_effects.items()):
                rows.append(
                    {
                        "cohort": cohort,
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        "weight": data.get("weight", np.nan),
                    }
                )
            return pd.DataFrame(rows)

        else:
            raise ValueError(
                f"Unknown level: {level}. Use 'event_study' or 'cohort'."
            )

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
244
+
245
+
246
@dataclass
class SABootstrapResults:
    """
    Results from Sun-Abraham bootstrap inference.

    Plain data container: all inference quantities are computed elsewhere
    (in SunAbraham._run_bootstrap) and stored here.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Type of bootstrap used (always "pairs" for pairs bootstrap).
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap confidence interval for overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for overall ATT.
    event_study_ses : Dict[int, float]
        Bootstrap SEs for event study effects.
    event_study_cis : Dict[int, Tuple[float, float]]
        Bootstrap CIs for event study effects.
    event_study_p_values : Dict[int, float]
        Bootstrap p-values for event study effects.
    bootstrap_distribution : Optional[np.ndarray]
        Full bootstrap distribution of overall ATT.
    """

    # Number of bootstrap resamples actually run.
    n_bootstrap: int
    # Bootstrap scheme label; the estimator only produces "pairs".
    weight_type: str
    # Significance level the CIs below were built at.
    alpha: float
    # Overall-ATT inference (SE, CI, p-value) from the bootstrap distribution.
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    # Per-relative-period inference, keyed by relative time.
    event_study_ses: Dict[int, float]
    event_study_cis: Dict[int, Tuple[float, float]]
    event_study_p_values: Dict[int, float]
    # Raw bootstrap draws of the overall ATT; excluded from repr (large).
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
285
+
286
+
287
+ class SunAbraham:
288
+ """
289
+ Sun-Abraham (2021) interaction-weighted estimator for staggered DiD.
290
+
291
+ This estimator provides event-study coefficients using a saturated
292
+ TWFE regression with cohort × relative-time interactions, following
293
+ the methodology in Sun & Abraham (2021).
294
+
295
+ The estimation procedure follows three steps:
296
+ 1. Run a saturated TWFE regression with cohort × relative-time dummies
297
+ 2. Compute cohort shares (weights) at each relative time
298
+ 3. Aggregate cohort-specific effects using interaction weights
299
+
300
+ This avoids the negative weighting problem of standard TWFE and provides
301
+ consistent event-study estimates under treatment effect heterogeneity.
302
+
303
+ Parameters
304
+ ----------
305
+ control_group : str, default="never_treated"
306
+ Which units to use as controls:
307
+ - "never_treated": Use only never-treated units (recommended)
308
+ - "not_yet_treated": Use never-treated and not-yet-treated units
309
+ anticipation : int, default=0
310
+ Number of periods before treatment where effects may occur.
311
+ alpha : float, default=0.05
312
+ Significance level for confidence intervals.
313
+ cluster : str, optional
314
+ Column name for cluster-robust standard errors.
315
+ If None, clusters at the unit level by default.
316
+ n_bootstrap : int, default=0
317
+ Number of bootstrap iterations for inference.
318
+ If 0, uses analytical cluster-robust standard errors.
319
+ seed : int, optional
320
+ Random seed for reproducibility.
321
+
322
+ Attributes
323
+ ----------
324
+ results_ : SunAbrahamResults
325
+ Estimation results after calling fit().
326
+ is_fitted_ : bool
327
+ Whether the model has been fitted.
328
+
329
+ Examples
330
+ --------
331
+ Basic usage:
332
+
333
+ >>> import pandas as pd
334
+ >>> from diff_diff import SunAbraham
335
+ >>>
336
+ >>> # Panel data with staggered treatment
337
+ >>> data = pd.DataFrame({
338
+ ... 'unit': [...],
339
+ ... 'time': [...],
340
+ ... 'outcome': [...],
341
+ ... 'first_treat': [...] # 0 for never-treated
342
+ ... })
343
+ >>>
344
+ >>> sa = SunAbraham()
345
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
346
+ ... time='time', first_treat='first_treat')
347
+ >>> results.print_summary()
348
+
349
+ With covariates:
350
+
351
+ >>> sa = SunAbraham()
352
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
353
+ ... time='time', first_treat='first_treat',
354
+ ... covariates=['age', 'income'])
355
+
356
+ Notes
357
+ -----
358
+ The Sun-Abraham estimator uses a saturated regression approach:
359
+
360
+ Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × 1(G_i=g) × D_{it}^e] + X'γ + ε_it
361
+
362
+ where:
363
+ - α_i = unit fixed effects
364
+ - λ_t = time fixed effects
365
+ - G_i = unit i's treatment cohort (first treatment period)
366
+ - D_{it}^e = indicator for being e periods from treatment
367
+ - δ_{g,e} = cohort-specific effect (CATT) at relative time e
368
+
369
+ The event-study coefficients are then computed as:
370
+
371
+ β_e = Σ_g w_{g,e} × δ_{g,e}
372
+
373
+ where w_{g,e} is the share of cohort g in the treated population at
374
+ relative time e (interaction weights).
375
+
376
+ Compared to Callaway-Sant'Anna:
377
+ - SA uses saturated regression; CS uses 2x2 DiD comparisons
378
+ - SA can be more efficient when model is correctly specified
379
+ - Both are consistent under heterogeneous treatment effects
380
+ - Running both provides a useful robustness check
381
+
382
+ References
383
+ ----------
384
+ Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in
385
+ event studies with heterogeneous treatment effects. Journal of
386
+ Econometrics, 225(2), 175-199.
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ control_group: str = "never_treated",
392
+ anticipation: int = 0,
393
+ alpha: float = 0.05,
394
+ cluster: Optional[str] = None,
395
+ n_bootstrap: int = 0,
396
+ seed: Optional[int] = None,
397
+ ):
398
+ if control_group not in ["never_treated", "not_yet_treated"]:
399
+ raise ValueError(
400
+ f"control_group must be 'never_treated' or 'not_yet_treated', "
401
+ f"got '{control_group}'"
402
+ )
403
+
404
+ self.control_group = control_group
405
+ self.anticipation = anticipation
406
+ self.alpha = alpha
407
+ self.cluster = cluster
408
+ self.n_bootstrap = n_bootstrap
409
+ self.seed = seed
410
+
411
+ self.is_fitted_ = False
412
+ self.results_: Optional[SunAbrahamResults] = None
413
+ self._reference_period = -1 # Will be set during fit
414
+
415
+ def fit(
416
+ self,
417
+ data: pd.DataFrame,
418
+ outcome: str,
419
+ unit: str,
420
+ time: str,
421
+ first_treat: str,
422
+ covariates: Optional[List[str]] = None,
423
+ min_pre_periods: int = 1,
424
+ min_post_periods: int = 1,
425
+ ) -> SunAbrahamResults:
426
+ """
427
+ Fit the Sun-Abraham estimator using saturated regression.
428
+
429
+ Parameters
430
+ ----------
431
+ data : pd.DataFrame
432
+ Panel data with unit and time identifiers.
433
+ outcome : str
434
+ Name of outcome variable column.
435
+ unit : str
436
+ Name of unit identifier column.
437
+ time : str
438
+ Name of time period column.
439
+ first_treat : str
440
+ Name of column indicating when unit was first treated.
441
+ Use 0 (or np.inf) for never-treated units.
442
+ covariates : list, optional
443
+ List of covariate column names to include in regression.
444
+ min_pre_periods : int, default=1
445
+ Minimum number of pre-treatment periods to include in event study.
446
+ min_post_periods : int, default=1
447
+ Minimum number of post-treatment periods to include in event study.
448
+
449
+ Returns
450
+ -------
451
+ SunAbrahamResults
452
+ Object containing all estimation results.
453
+
454
+ Raises
455
+ ------
456
+ ValueError
457
+ If required columns are missing or data validation fails.
458
+ """
459
+ # Validate inputs
460
+ required_cols = [outcome, unit, time, first_treat]
461
+ if covariates:
462
+ required_cols.extend(covariates)
463
+
464
+ missing = [c for c in required_cols if c not in data.columns]
465
+ if missing:
466
+ raise ValueError(f"Missing columns: {missing}")
467
+
468
+ # Create working copy
469
+ df = data.copy()
470
+
471
+ # Ensure numeric types
472
+ df[time] = pd.to_numeric(df[time])
473
+ df[first_treat] = pd.to_numeric(df[first_treat])
474
+
475
+ # Identify groups and time periods
476
+ time_periods = sorted(df[time].unique())
477
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
478
+
479
+ # Never-treated indicator
480
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
481
+
482
+ # Get unique units
483
+ unit_info = (
484
+ df.groupby(unit)
485
+ .agg({first_treat: "first", "_never_treated": "first"})
486
+ .reset_index()
487
+ )
488
+
489
+ n_treated_units = int((unit_info[first_treat] > 0).sum())
490
+ n_control_units = int((unit_info["_never_treated"]).sum())
491
+
492
+ if n_control_units == 0:
493
+ raise ValueError(
494
+ "No never-treated units found. Check 'first_treat' column."
495
+ )
496
+
497
+ if len(treatment_groups) == 0:
498
+ raise ValueError(
499
+ "No treated units found. Check 'first_treat' column."
500
+ )
501
+
502
+ # Compute relative time for each observation (vectorized)
503
+ df["_rel_time"] = np.where(
504
+ df[first_treat] > 0,
505
+ df[time] - df[first_treat],
506
+ np.nan
507
+ )
508
+
509
+ # Identify the range of relative time periods to estimate
510
+ rel_times_by_cohort = {}
511
+ for g in treatment_groups:
512
+ g_times = df[df[first_treat] == g][time].unique()
513
+ rel_times_by_cohort[g] = sorted([t - g for t in g_times])
514
+
515
+ # Find all relative time values
516
+ all_rel_times: set = set()
517
+ for g, rel_times in rel_times_by_cohort.items():
518
+ all_rel_times.update(rel_times)
519
+
520
+ all_rel_times_sorted = sorted(all_rel_times)
521
+
522
+ # Filter to reasonable range
523
+ min_rel = max(min(all_rel_times_sorted), -20) # cap at -20
524
+ max_rel = min(max(all_rel_times_sorted), 20) # cap at +20
525
+
526
+ # Reference period: last pre-treatment period (typically -1)
527
+ self._reference_period = -1 - self.anticipation
528
+
529
+ # Get relative periods to estimate (excluding reference)
530
+ rel_periods_to_estimate = [
531
+ e
532
+ for e in all_rel_times_sorted
533
+ if min_rel <= e <= max_rel and e != self._reference_period
534
+ ]
535
+
536
+ # Determine cluster variable
537
+ cluster_var = self.cluster if self.cluster is not None else unit
538
+
539
+ # Filter data based on control_group setting
540
+ if self.control_group == "never_treated":
541
+ # Only keep never-treated as controls
542
+ df_reg = df[df["_never_treated"] | (df[first_treat] > 0)].copy()
543
+ else:
544
+ # Keep all units (not_yet_treated will be handled by the regression)
545
+ df_reg = df.copy()
546
+
547
+ # Fit saturated regression
548
+ (
549
+ cohort_effects,
550
+ cohort_ses,
551
+ vcov_cohort,
552
+ coef_index_map,
553
+ ) = self._fit_saturated_regression(
554
+ df_reg,
555
+ outcome,
556
+ unit,
557
+ time,
558
+ first_treat,
559
+ treatment_groups,
560
+ rel_periods_to_estimate,
561
+ covariates,
562
+ cluster_var,
563
+ )
564
+
565
+ # Compute interaction-weighted event study effects
566
+ event_study_effects, cohort_weights = self._compute_iw_effects(
567
+ df,
568
+ unit,
569
+ first_treat,
570
+ treatment_groups,
571
+ rel_periods_to_estimate,
572
+ cohort_effects,
573
+ cohort_ses,
574
+ vcov_cohort,
575
+ coef_index_map,
576
+ )
577
+
578
+ # Compute overall ATT (average of post-treatment effects)
579
+ overall_att, overall_se = self._compute_overall_att(
580
+ df,
581
+ first_treat,
582
+ event_study_effects,
583
+ cohort_effects,
584
+ cohort_weights,
585
+ vcov_cohort,
586
+ coef_index_map,
587
+ )
588
+
589
+ overall_t = overall_att / overall_se if overall_se > 0 else 0.0
590
+ overall_p = compute_p_value(overall_t)
591
+ overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)
592
+
593
+ # Run bootstrap if requested
594
+ bootstrap_results = None
595
+ if self.n_bootstrap > 0:
596
+ bootstrap_results = self._run_bootstrap(
597
+ df=df_reg,
598
+ outcome=outcome,
599
+ unit=unit,
600
+ time=time,
601
+ first_treat=first_treat,
602
+ treatment_groups=treatment_groups,
603
+ rel_periods_to_estimate=rel_periods_to_estimate,
604
+ covariates=covariates,
605
+ cluster_var=cluster_var,
606
+ original_event_study=event_study_effects,
607
+ original_overall_att=overall_att,
608
+ )
609
+
610
+ # Update results with bootstrap inference
611
+ overall_se = bootstrap_results.overall_att_se
612
+ overall_t = overall_att / overall_se if overall_se > 0 else 0.0
613
+ overall_p = bootstrap_results.overall_att_p_value
614
+ overall_ci = bootstrap_results.overall_att_ci
615
+
616
+ # Update event study effects
617
+ for e in event_study_effects:
618
+ if e in bootstrap_results.event_study_ses:
619
+ event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
620
+ event_study_effects[e]["conf_int"] = (
621
+ bootstrap_results.event_study_cis[e]
622
+ )
623
+ event_study_effects[e]["p_value"] = (
624
+ bootstrap_results.event_study_p_values[e]
625
+ )
626
+ eff_val = event_study_effects[e]["effect"]
627
+ se_val = event_study_effects[e]["se"]
628
+ event_study_effects[e]["t_stat"] = (
629
+ eff_val / se_val if se_val > 0 else 0.0
630
+ )
631
+
632
+ # Convert cohort effects to storage format
633
+ cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {}
634
+ for (g, e), effect in cohort_effects.items():
635
+ weight = cohort_weights.get(e, {}).get(g, 0.0)
636
+ se = cohort_ses.get((g, e), 0.0)
637
+ cohort_effects_storage[(g, e)] = {
638
+ "effect": effect,
639
+ "se": se,
640
+ "weight": weight,
641
+ }
642
+
643
+ # Store results
644
+ self.results_ = SunAbrahamResults(
645
+ event_study_effects=event_study_effects,
646
+ overall_att=overall_att,
647
+ overall_se=overall_se,
648
+ overall_t_stat=overall_t,
649
+ overall_p_value=overall_p,
650
+ overall_conf_int=overall_ci,
651
+ cohort_weights=cohort_weights,
652
+ groups=treatment_groups,
653
+ time_periods=time_periods,
654
+ n_obs=len(df),
655
+ n_treated_units=n_treated_units,
656
+ n_control_units=n_control_units,
657
+ alpha=self.alpha,
658
+ control_group=self.control_group,
659
+ bootstrap_results=bootstrap_results,
660
+ cohort_effects=cohort_effects_storage,
661
+ )
662
+
663
+ self.is_fitted_ = True
664
+ return self.results_
665
+
666
    def _fit_saturated_regression(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treatment_groups: List[Any],
        rel_periods: List[int],
        covariates: Optional[List[str]],
        cluster_var: str,
    ) -> Tuple[
        Dict[Tuple[Any, int], float],
        Dict[Tuple[Any, int], float],
        np.ndarray,
        Dict[Tuple[Any, int], int],
    ]:
        """
        Fit saturated TWFE regression with cohort × relative-time interactions.

        Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε

        Uses within-transformation for unit fixed effects and time dummies.

        NOTE(review): the one-shot two-way demeaning in _within_transform is
        exact for balanced panels; for unbalanced panels it only approximates
        full two-way fixed effects — confirm this is intended.

        Parameters
        ----------
        df : pd.DataFrame
            Estimation sample (already filtered to the chosen control group);
            must contain the "_rel_time" column computed in fit().
        outcome, unit, time, first_treat : str
            Column names describing the panel structure.
        treatment_groups : list
            Treatment cohorts (first-treatment periods) to interact.
        rel_periods : list of int
            Relative periods to estimate (reference period excluded upstream).
        covariates : list of str, optional
            Additional regressors, demeaned alongside the outcome.
        cluster_var : str
            Column used for cluster-robust standard errors.

        Returns
        -------
        cohort_effects : dict
            Mapping (cohort, rel_period) -> effect estimate δ_{g,e}
        cohort_ses : dict
            Mapping (cohort, rel_period) -> standard error
        vcov : np.ndarray
            Variance-covariance matrix for cohort effects
        coef_index_map : dict
            Mapping (cohort, rel_period) -> index in coefficient vector
        """
        df = df.copy()

        # Create cohort × relative-time interaction dummies
        # Exclude reference period
        # Build all columns at once to avoid fragmentation
        interaction_data = {}
        coef_index_map: Dict[Tuple[Any, int], int] = {}
        idx = 0

        for g in treatment_groups:
            for e in rel_periods:
                col_name = f"_D_{g}_{e}"
                # Indicator: unit is in cohort g AND at relative time e
                indicator = (
                    (df[first_treat] == g) &
                    (df["_rel_time"] == e)
                ).astype(float)

                # Only include if there are observations
                # (an empty cell would add an all-zero, collinear column)
                if indicator.sum() > 0:
                    interaction_data[col_name] = indicator.values
                    coef_index_map[(g, e)] = idx
                    idx += 1

        # Add all interaction columns at once
        interaction_cols = list(interaction_data.keys())
        if interaction_data:
            interaction_df = pd.DataFrame(interaction_data, index=df.index)
            df = pd.concat([df, interaction_df], axis=1)

        if len(interaction_cols) == 0:
            raise ValueError(
                "No valid cohort × relative-time interactions found. "
                "Check your data structure."
            )

        # Apply within-transformation for unit and time fixed effects
        variables_to_demean = [outcome] + interaction_cols
        if covariates:
            variables_to_demean.extend(covariates)

        df_demeaned = self._within_transform(df, variables_to_demean, unit, time)

        # Build design matrix: interaction dummies come first so that the
        # leading block of the coefficient vector/vcov lines up with
        # coef_index_map; covariates (if any) are appended after.
        X_cols = [f"{col}_dm" for col in interaction_cols]
        if covariates:
            X_cols.extend([f"{cov}_dm" for cov in covariates])

        X = df_demeaned[X_cols].values
        y = df_demeaned[f"{outcome}_dm"].values

        # Fit OLS
        try:
            XtX_inv = np.linalg.inv(X.T @ X)
        except np.linalg.LinAlgError:
            # Use pseudo-inverse for singular matrices
            XtX_inv = np.linalg.pinv(X.T @ X)

        coefficients = XtX_inv @ (X.T @ y)
        residuals = y - X @ coefficients

        # Compute cluster-robust standard errors
        cluster_ids = df_demeaned[cluster_var].values
        vcov = compute_robust_se(X, residuals, cluster_ids)

        # Extract cohort effects and standard errors
        cohort_effects: Dict[Tuple[Any, int], float] = {}
        cohort_ses: Dict[Tuple[Any, int], float] = {}

        n_interactions = len(interaction_cols)
        for (g, e), coef_idx in coef_index_map.items():
            cohort_effects[(g, e)] = float(coefficients[coef_idx])
            cohort_ses[(g, e)] = float(np.sqrt(vcov[coef_idx, coef_idx]))

        # Extract just the vcov for cohort effects (excluding covariates)
        vcov_cohort = vcov[:n_interactions, :n_interactions]

        return cohort_effects, cohort_ses, vcov_cohort, coef_index_map
779
+
780
+ def _within_transform(
781
+ self,
782
+ df: pd.DataFrame,
783
+ variables: List[str],
784
+ unit: str,
785
+ time: str,
786
+ ) -> pd.DataFrame:
787
+ """
788
+ Apply two-way within transformation to remove unit and time fixed effects.
789
+
790
+ y_it - y_i. - y_.t + y_..
791
+ """
792
+ df = df.copy()
793
+
794
+ # Build all demeaned columns at once to avoid fragmentation
795
+ demeaned_data = {}
796
+ for var in variables:
797
+ # Unit means
798
+ unit_means = df.groupby(unit)[var].transform("mean")
799
+ # Time means
800
+ time_means = df.groupby(time)[var].transform("mean")
801
+ # Grand mean
802
+ grand_mean = df[var].mean()
803
+
804
+ # Within transformation
805
+ demeaned_data[f"{var}_dm"] = (
806
+ df[var] - unit_means - time_means + grand_mean
807
+ ).values
808
+
809
+ # Add all demeaned columns at once
810
+ demeaned_df = pd.DataFrame(demeaned_data, index=df.index)
811
+ df = pd.concat([df, demeaned_df], axis=1)
812
+
813
+ return df
814
+
815
+ def _compute_iw_effects(
816
+ self,
817
+ df: pd.DataFrame,
818
+ unit: str,
819
+ first_treat: str,
820
+ treatment_groups: List[Any],
821
+ rel_periods: List[int],
822
+ cohort_effects: Dict[Tuple[Any, int], float],
823
+ cohort_ses: Dict[Tuple[Any, int], float],
824
+ vcov_cohort: np.ndarray,
825
+ coef_index_map: Dict[Tuple[Any, int], int],
826
+ ) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
827
+ """
828
+ Compute interaction-weighted event study effects.
829
+
830
+ β_e = Σ_g w_{g,e} × δ_{g,e}
831
+
832
+ where w_{g,e} is the share of cohort g among treated units at relative time e.
833
+
834
+ Returns
835
+ -------
836
+ event_study_effects : dict
837
+ Dictionary mapping relative period to aggregated effect info.
838
+ cohort_weights : dict
839
+ Dictionary mapping relative period to cohort weight dictionary.
840
+ """
841
+ event_study_effects: Dict[int, Dict[str, Any]] = {}
842
+ cohort_weights: Dict[int, Dict[Any, float]] = {}
843
+
844
+ # Get cohort sizes
845
+ unit_cohorts = df.groupby(unit)[first_treat].first()
846
+ cohort_sizes = unit_cohorts[unit_cohorts > 0].value_counts().to_dict()
847
+
848
+ for e in rel_periods:
849
+ # Get cohorts that have observations at this relative time
850
+ cohorts_at_e = [
851
+ g for g in treatment_groups
852
+ if (g, e) in cohort_effects
853
+ ]
854
+
855
+ if not cohorts_at_e:
856
+ continue
857
+
858
+ # Compute IW weights: share of each cohort among those observed at e
859
+ weights = {}
860
+ total_size = 0
861
+ for g in cohorts_at_e:
862
+ n_g = cohort_sizes.get(g, 0)
863
+ weights[g] = n_g
864
+ total_size += n_g
865
+
866
+ if total_size == 0:
867
+ continue
868
+
869
+ # Normalize weights
870
+ for g in weights:
871
+ weights[g] = weights[g] / total_size
872
+
873
+ cohort_weights[e] = weights
874
+
875
+ # Compute weighted average effect
876
+ agg_effect = 0.0
877
+ for g in cohorts_at_e:
878
+ w = weights[g]
879
+ agg_effect += w * cohort_effects[(g, e)]
880
+
881
+ # Compute SE using delta method with vcov
882
+ # Var(β_e) = w' Σ w where w is weight vector and Σ is vcov submatrix
883
+ indices = [coef_index_map[(g, e)] for g in cohorts_at_e]
884
+ weight_vec = np.array([weights[g] for g in cohorts_at_e])
885
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
886
+ agg_var = float(weight_vec @ vcov_subset @ weight_vec)
887
+ agg_se = np.sqrt(max(agg_var, 0))
888
+
889
+ t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
890
+ p_val = compute_p_value(t_stat)
891
+ ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)
892
+
893
+ event_study_effects[e] = {
894
+ "effect": agg_effect,
895
+ "se": agg_se,
896
+ "t_stat": t_stat,
897
+ "p_value": p_val,
898
+ "conf_int": ci,
899
+ "n_groups": len(cohorts_at_e),
900
+ }
901
+
902
+ return event_study_effects, cohort_weights
903
+
904
+ def _compute_overall_att(
905
+ self,
906
+ df: pd.DataFrame,
907
+ first_treat: str,
908
+ event_study_effects: Dict[int, Dict[str, Any]],
909
+ cohort_effects: Dict[Tuple[Any, int], float],
910
+ cohort_weights: Dict[int, Dict[Any, float]],
911
+ vcov_cohort: np.ndarray,
912
+ coef_index_map: Dict[Tuple[Any, int], int],
913
+ ) -> Tuple[float, float]:
914
+ """
915
+ Compute overall ATT as weighted average of post-treatment effects.
916
+
917
+ Returns (att, se) tuple.
918
+ """
919
+ post_effects = [
920
+ (e, eff)
921
+ for e, eff in event_study_effects.items()
922
+ if e >= 0
923
+ ]
924
+
925
+ if not post_effects:
926
+ return 0.0, 0.0
927
+
928
+ # Weight by number of treated observations at each relative time
929
+ post_weights = []
930
+ post_estimates = []
931
+
932
+ for e, eff in post_effects:
933
+ n_at_e = len(df[(df["_rel_time"] == e) & (df[first_treat] > 0)])
934
+ post_weights.append(max(n_at_e, 1))
935
+ post_estimates.append(eff["effect"])
936
+
937
+ post_weights = np.array(post_weights, dtype=float)
938
+ post_weights = post_weights / post_weights.sum()
939
+
940
+ overall_att = float(np.sum(post_weights * np.array(post_estimates)))
941
+
942
+ # Compute SE using delta method
943
+ # Need to trace back through the full weighting scheme
944
+ # ATT = Σ_e w_e × β_e = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e}
945
+ # Collect all (g, e) pairs and their overall weights
946
+ overall_weights_by_coef: Dict[Tuple[Any, int], float] = {}
947
+
948
+ for i, (e, _) in enumerate(post_effects):
949
+ period_weight = post_weights[i]
950
+ if e in cohort_weights:
951
+ for g, cw in cohort_weights[e].items():
952
+ key = (g, e)
953
+ if key in coef_index_map:
954
+ if key not in overall_weights_by_coef:
955
+ overall_weights_by_coef[key] = 0.0
956
+ overall_weights_by_coef[key] += period_weight * cw
957
+
958
+ if not overall_weights_by_coef:
959
+ # Fallback to simple variance calculation
960
+ overall_var = float(
961
+ np.sum((post_weights ** 2) * np.array([eff["se"] ** 2 for _, eff in post_effects]))
962
+ )
963
+ return overall_att, np.sqrt(overall_var)
964
+
965
+ # Build full weight vector and compute variance
966
+ indices = [coef_index_map[key] for key in overall_weights_by_coef.keys()]
967
+ weight_vec = np.array(list(overall_weights_by_coef.values()))
968
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
969
+ overall_var = float(weight_vec @ vcov_subset @ weight_vec)
970
+ overall_se = np.sqrt(max(overall_var, 0))
971
+
972
+ return overall_att, overall_se
973
+
974
+ def _run_bootstrap(
975
+ self,
976
+ df: pd.DataFrame,
977
+ outcome: str,
978
+ unit: str,
979
+ time: str,
980
+ first_treat: str,
981
+ treatment_groups: List[Any],
982
+ rel_periods_to_estimate: List[int],
983
+ covariates: Optional[List[str]],
984
+ cluster_var: str,
985
+ original_event_study: Dict[int, Dict[str, Any]],
986
+ original_overall_att: float,
987
+ ) -> SABootstrapResults:
988
+ """
989
+ Run pairs bootstrap for inference.
990
+
991
+ Resamples units with replacement and re-estimates the full model.
992
+ """
993
+ if self.n_bootstrap < 50:
994
+ warnings.warn(
995
+ f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
996
+ "for reliable inference.",
997
+ UserWarning,
998
+ stacklevel=3,
999
+ )
1000
+
1001
+ rng = np.random.default_rng(self.seed)
1002
+
1003
+ # Get unique units
1004
+ all_units = df[unit].unique()
1005
+ n_units = len(all_units)
1006
+
1007
+ # Store bootstrap samples
1008
+ rel_periods = sorted(original_event_study.keys())
1009
+ bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
1010
+ bootstrap_overall = np.zeros(self.n_bootstrap)
1011
+
1012
+ for b in range(self.n_bootstrap):
1013
+ # Resample units with replacement (pairs bootstrap)
1014
+ boot_units = rng.choice(all_units, size=n_units, replace=True)
1015
+
1016
+ # Create bootstrap sample efficiently
1017
+ # Build index array for all selected units
1018
+ boot_indices = np.concatenate([
1019
+ df.index[df[unit] == u].values for u in boot_units
1020
+ ])
1021
+ df_b = df.iloc[boot_indices].copy()
1022
+
1023
+ # Reassign unique unit IDs for bootstrap sample
1024
+ # Each resampled unit gets a unique ID
1025
+ new_unit_ids = []
1026
+ current_id = 0
1027
+ for u in boot_units:
1028
+ unit_rows = df[df[unit] == u]
1029
+ for _ in range(len(unit_rows)):
1030
+ new_unit_ids.append(current_id)
1031
+ current_id += 1
1032
+ df_b[unit] = new_unit_ids[:len(df_b)]
1033
+
1034
+ # Recompute relative time (vectorized)
1035
+ df_b["_rel_time"] = np.where(
1036
+ df_b[first_treat] > 0,
1037
+ df_b[time] - df_b[first_treat],
1038
+ np.nan
1039
+ )
1040
+ df_b["_never_treated"] = (
1041
+ (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)
1042
+ )
1043
+
1044
+ try:
1045
+ # Re-estimate saturated regression
1046
+ (
1047
+ cohort_effects_b,
1048
+ cohort_ses_b,
1049
+ vcov_b,
1050
+ coef_map_b,
1051
+ ) = self._fit_saturated_regression(
1052
+ df_b,
1053
+ outcome,
1054
+ unit,
1055
+ time,
1056
+ first_treat,
1057
+ treatment_groups,
1058
+ rel_periods_to_estimate,
1059
+ covariates,
1060
+ cluster_var,
1061
+ )
1062
+
1063
+ # Compute IW effects for this bootstrap sample
1064
+ event_study_b, cohort_weights_b = self._compute_iw_effects(
1065
+ df_b,
1066
+ unit,
1067
+ first_treat,
1068
+ treatment_groups,
1069
+ rel_periods_to_estimate,
1070
+ cohort_effects_b,
1071
+ cohort_ses_b,
1072
+ vcov_b,
1073
+ coef_map_b,
1074
+ )
1075
+
1076
+ # Store bootstrap estimates
1077
+ for e in rel_periods:
1078
+ if e in event_study_b:
1079
+ bootstrap_effects[e][b] = event_study_b[e]["effect"]
1080
+ else:
1081
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1082
+
1083
+ # Compute overall ATT for this bootstrap sample
1084
+ overall_b, _ = self._compute_overall_att(
1085
+ df_b,
1086
+ first_treat,
1087
+ event_study_b,
1088
+ cohort_effects_b,
1089
+ cohort_weights_b,
1090
+ vcov_b,
1091
+ coef_map_b,
1092
+ )
1093
+ bootstrap_overall[b] = overall_b
1094
+
1095
+ except (ValueError, np.linalg.LinAlgError) as exc:
1096
+ # If bootstrap iteration fails, use original
1097
+ warnings.warn(
1098
+ f"Bootstrap iteration {b} failed: {exc}. Using original estimate.",
1099
+ UserWarning,
1100
+ stacklevel=2,
1101
+ )
1102
+ for e in rel_periods:
1103
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1104
+ bootstrap_overall[b] = original_overall_att
1105
+
1106
+ # Compute bootstrap statistics
1107
+ event_study_ses = {}
1108
+ event_study_cis = {}
1109
+ event_study_p_values = {}
1110
+
1111
+ for e in rel_periods:
1112
+ boot_dist = bootstrap_effects[e]
1113
+ original_effect = original_event_study[e]["effect"]
1114
+
1115
+ se = float(np.std(boot_dist, ddof=1))
1116
+ ci = self._compute_percentile_ci(boot_dist, self.alpha)
1117
+ p_value = self._compute_bootstrap_pvalue(original_effect, boot_dist)
1118
+
1119
+ event_study_ses[e] = se
1120
+ event_study_cis[e] = ci
1121
+ event_study_p_values[e] = p_value
1122
+
1123
+ # Overall ATT statistics
1124
+ overall_se = float(np.std(bootstrap_overall, ddof=1))
1125
+ overall_ci = self._compute_percentile_ci(bootstrap_overall, self.alpha)
1126
+ overall_p = self._compute_bootstrap_pvalue(
1127
+ original_overall_att, bootstrap_overall
1128
+ )
1129
+
1130
+ return SABootstrapResults(
1131
+ n_bootstrap=self.n_bootstrap,
1132
+ weight_type="pairs",
1133
+ alpha=self.alpha,
1134
+ overall_att_se=overall_se,
1135
+ overall_att_ci=overall_ci,
1136
+ overall_att_p_value=overall_p,
1137
+ event_study_ses=event_study_ses,
1138
+ event_study_cis=event_study_cis,
1139
+ event_study_p_values=event_study_p_values,
1140
+ bootstrap_distribution=bootstrap_overall,
1141
+ )
1142
+
1143
+ def _compute_percentile_ci(
1144
+ self,
1145
+ boot_dist: np.ndarray,
1146
+ alpha: float,
1147
+ ) -> Tuple[float, float]:
1148
+ """Compute percentile confidence interval."""
1149
+ lower = float(np.percentile(boot_dist, alpha / 2 * 100))
1150
+ upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
1151
+ return (lower, upper)
1152
+
1153
+ def _compute_bootstrap_pvalue(
1154
+ self,
1155
+ original_effect: float,
1156
+ boot_dist: np.ndarray,
1157
+ ) -> float:
1158
+ """Compute two-sided bootstrap p-value."""
1159
+ if original_effect >= 0:
1160
+ p_one_sided = float(np.mean(boot_dist <= 0))
1161
+ else:
1162
+ p_one_sided = float(np.mean(boot_dist >= 0))
1163
+
1164
+ p_value = min(2 * p_one_sided, 1.0)
1165
+ p_value = max(p_value, 1 / (self.n_bootstrap + 1))
1166
+
1167
+ return p_value
1168
+
1169
+ def get_params(self) -> Dict[str, Any]:
1170
+ """Get estimator parameters (sklearn-compatible)."""
1171
+ return {
1172
+ "control_group": self.control_group,
1173
+ "anticipation": self.anticipation,
1174
+ "alpha": self.alpha,
1175
+ "cluster": self.cluster,
1176
+ "n_bootstrap": self.n_bootstrap,
1177
+ "seed": self.seed,
1178
+ }
1179
+
1180
+ def set_params(self, **params) -> "SunAbraham":
1181
+ """Set estimator parameters (sklearn-compatible)."""
1182
+ for key, value in params.items():
1183
+ if hasattr(self, key):
1184
+ setattr(self, key, value)
1185
+ else:
1186
+ raise ValueError(f"Unknown parameter: {key}")
1187
+ return self
1188
+
1189
+ def summary(self) -> str:
1190
+ """Get summary of estimation results."""
1191
+ if not self.is_fitted_:
1192
+ raise RuntimeError("Model must be fitted before calling summary()")
1193
+ assert self.results_ is not None
1194
+ return self.results_.summary()
1195
+
1196
+ def print_summary(self) -> None:
1197
+ """Print summary to stdout."""
1198
+ print(self.summary())