diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1176 @@
+ """
+ Sun-Abraham Interaction-Weighted Estimator for staggered DiD.
+ 
+ Implements the estimator from Sun & Abraham (2021), "Estimating dynamic
+ treatment effects in event studies with heterogeneous treatment effects",
+ Journal of Econometrics.
+ 
+ This provides an alternative to Callaway-Sant'Anna using a saturated
+ regression with cohort × relative-time interactions.
+ """
+ 
+ import warnings
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple
+ 
+ import numpy as np
+ import pandas as pd
+ 
+ from diff_diff.linalg import LinearRegression
+ from diff_diff.results import _get_significance_stars
+ from diff_diff.utils import (
+     compute_confidence_interval,
+     compute_p_value,
+     within_transform as _within_transform_util,
+ )
+ 
+ 
+ @dataclass
+ class SunAbrahamResults:
+     """
+     Results from Sun-Abraham (2021) interaction-weighted estimation.
+ 
+     Attributes
+     ----------
+     event_study_effects : dict
+         Dictionary mapping relative time to effect dictionaries with keys:
+         'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_groups'.
+     overall_att : float
+         Overall average treatment effect (weighted average of post-treatment effects).
+     overall_se : float
+         Standard error of overall ATT.
+     overall_t_stat : float
+         T-statistic for overall ATT.
+     overall_p_value : float
+         P-value for overall ATT.
+     overall_conf_int : tuple
+         Confidence interval for overall ATT.
+     cohort_weights : dict
+         Dictionary mapping relative time to cohort weight dictionaries.
+     groups : list
+         List of treatment cohorts (first treatment periods).
+     time_periods : list
+         List of all time periods.
+     n_obs : int
+         Total number of observations.
+     n_treated_units : int
+         Number of ever-treated units.
+     n_control_units : int
+         Number of never-treated units.
+     alpha : float
+         Significance level used for confidence intervals.
+     control_group : str
+         Type of control group used.
+     """
+ 
+     event_study_effects: Dict[int, Dict[str, Any]]
+     overall_att: float
+     overall_se: float
+     overall_t_stat: float
+     overall_p_value: float
+     overall_conf_int: Tuple[float, float]
+     cohort_weights: Dict[int, Dict[Any, float]]
+     groups: List[Any]
+     time_periods: List[Any]
+     n_obs: int
+     n_treated_units: int
+     n_control_units: int
+     alpha: float = 0.05
+     control_group: str = "never_treated"
+     bootstrap_results: Optional["SABootstrapResults"] = field(default=None, repr=False)
+     cohort_effects: Optional[Dict[Tuple[Any, int], Dict[str, Any]]] = field(
+         default=None, repr=False
+     )
+ 
+     def __repr__(self) -> str:
+         """Concise string representation."""
+         sig = _get_significance_stars(self.overall_p_value)
+         n_rel_periods = len(self.event_study_effects)
+         return (
+             f"SunAbrahamResults(ATT={self.overall_att:.4f}{sig}, "
+             f"SE={self.overall_se:.4f}, "
+             f"n_groups={len(self.groups)}, "
+             f"n_rel_periods={n_rel_periods})"
+         )
+ 
+     def summary(self, alpha: Optional[float] = None) -> str:
+         """
+         Generate formatted summary of estimation results.
+ 
+         Parameters
+         ----------
+         alpha : float, optional
+             Significance level. Defaults to alpha used in estimation.
+ 
+         Returns
+         -------
+         str
+             Formatted summary.
+         """
+         alpha = self.alpha if alpha is None else alpha
+         conf_level = int((1 - alpha) * 100)
+ 
+         lines = [
+             "=" * 85,
+             "Sun-Abraham Interaction-Weighted Estimator Results".center(85),
+             "=" * 85,
+             "",
+             f"{'Total observations:':<30} {self.n_obs:>10}",
+             f"{'Treated units:':<30} {self.n_treated_units:>10}",
+             f"{'Control units:':<30} {self.n_control_units:>10}",
+             f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
+             f"{'Time periods:':<30} {len(self.time_periods):>10}",
+             f"{'Control group:':<30} {self.control_group:>10}",
+             "",
+         ]
+ 
+         # Overall ATT
+         lines.extend(
+             [
+                 "-" * 85,
+                 "Overall Average Treatment Effect on the Treated".center(85),
+                 "-" * 85,
+                 f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
+                 f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
+                 "-" * 85,
+                 f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
+                 f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
+                 f"{_get_significance_stars(self.overall_p_value):>6}",
+                 "-" * 85,
+                 "",
+                 f"{conf_level}% Confidence Interval: "
+                 f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
+                 "",
+             ]
+         )
+ 
+         # Event study effects
+         lines.extend(
+             [
+                 "-" * 85,
+                 "Event Study (Dynamic) Effects".center(85),
+                 "-" * 85,
+                 f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
+                 f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
+                 "-" * 85,
+             ]
+         )
+ 
+         for rel_t in sorted(self.event_study_effects.keys()):
+             eff = self.event_study_effects[rel_t]
+             sig = _get_significance_stars(eff["p_value"])
+             lines.append(
+                 f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
+                 f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
+             )
+ 
+         lines.extend(["-" * 85, ""])
+ 
+         lines.extend(
+             [
+                 "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
+                 "=" * 85,
+             ]
+         )
+ 
+         return "\n".join(lines)
+ 
+     def print_summary(self, alpha: Optional[float] = None) -> None:
+         """Print summary to stdout."""
+         print(self.summary(alpha))
+ 
+     def to_dataframe(self, level: str = "event_study") -> pd.DataFrame:
+         """
+         Convert results to DataFrame.
+ 
+         Parameters
+         ----------
+         level : str, default="event_study"
+             Level of aggregation: "event_study" or "cohort".
+ 
+         Returns
+         -------
+         pd.DataFrame
+             Results as DataFrame.
+         """
+         if level == "event_study":
+             rows = []
+             for rel_t, data in sorted(self.event_study_effects.items()):
+                 rows.append(
+                     {
+                         "relative_period": rel_t,
+                         "effect": data["effect"],
+                         "se": data["se"],
+                         "t_stat": data["t_stat"],
+                         "p_value": data["p_value"],
+                         "conf_int_lower": data["conf_int"][0],
+                         "conf_int_upper": data["conf_int"][1],
+                     }
+                 )
+             return pd.DataFrame(rows)
+ 
+         elif level == "cohort":
+             if self.cohort_effects is None:
+                 raise ValueError(
+                     "Cohort-level effects not available. "
+                     "They are computed internally but not stored by default."
+                 )
+             rows = []
+             for (cohort, rel_t), data in sorted(self.cohort_effects.items()):
+                 rows.append(
+                     {
+                         "cohort": cohort,
+                         "relative_period": rel_t,
+                         "effect": data["effect"],
+                         "se": data["se"],
+                         "weight": data.get("weight", np.nan),
+                     }
+                 )
+             return pd.DataFrame(rows)
+ 
+         else:
+             raise ValueError(
+                 f"Unknown level: {level}. Use 'event_study' or 'cohort'."
+             )
+ 
+     @property
+     def is_significant(self) -> bool:
+         """Check if overall ATT is significant."""
+         return bool(self.overall_p_value < self.alpha)
+ 
+     @property
+     def significance_stars(self) -> str:
+         """Significance stars for overall ATT."""
+         return _get_significance_stars(self.overall_p_value)
+ 
+ 
+ @dataclass
+ class SABootstrapResults:
+     """
+     Results from Sun-Abraham bootstrap inference.
+ 
+     Attributes
+     ----------
+     n_bootstrap : int
+         Number of bootstrap iterations.
+     weight_type : str
+         Type of bootstrap used (always "pairs" for pairs bootstrap).
+     alpha : float
+         Significance level used for confidence intervals.
+     overall_att_se : float
+         Bootstrap standard error for overall ATT.
+     overall_att_ci : Tuple[float, float]
+         Bootstrap confidence interval for overall ATT.
+     overall_att_p_value : float
+         Bootstrap p-value for overall ATT.
+     event_study_ses : Dict[int, float]
+         Bootstrap SEs for event study effects.
+     event_study_cis : Dict[int, Tuple[float, float]]
+         Bootstrap CIs for event study effects.
+     event_study_p_values : Dict[int, float]
+         Bootstrap p-values for event study effects.
+     bootstrap_distribution : Optional[np.ndarray]
+         Full bootstrap distribution of overall ATT.
+     """
+ 
+     n_bootstrap: int
+     weight_type: str
+     alpha: float
+     overall_att_se: float
+     overall_att_ci: Tuple[float, float]
+     overall_att_p_value: float
+     event_study_ses: Dict[int, float]
+     event_study_cis: Dict[int, Tuple[float, float]]
+     event_study_p_values: Dict[int, float]
+     bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
+ 
+ 
+ class SunAbraham:
+     """
+     Sun-Abraham (2021) interaction-weighted estimator for staggered DiD.
+ 
+     This estimator provides event-study coefficients using a saturated
+     TWFE regression with cohort × relative-time interactions, following
+     the methodology in Sun & Abraham (2021).
+ 
+     The estimation procedure follows three steps:
+ 
+     1. Run a saturated TWFE regression with cohort × relative-time dummies
+     2. Compute cohort shares (weights) at each relative time
+     3. Aggregate cohort-specific effects using interaction weights
+ 
+     This avoids the negative weighting problem of standard TWFE and provides
+     consistent event-study estimates under treatment effect heterogeneity.
+ 
+     Parameters
+     ----------
+     control_group : str, default="never_treated"
+         Which units to use as controls:
+ 
+         - "never_treated": Use only never-treated units (recommended)
+         - "not_yet_treated": Use never-treated and not-yet-treated units
+     anticipation : int, default=0
+         Number of periods before treatment where effects may occur.
+     alpha : float, default=0.05
+         Significance level for confidence intervals.
+     cluster : str, optional
+         Column name for cluster-robust standard errors.
+         If None, clusters at the unit level by default.
+     n_bootstrap : int, default=0
+         Number of bootstrap iterations for inference.
+         If 0, uses analytical cluster-robust standard errors.
+     seed : int, optional
+         Random seed for reproducibility.
+ 
+     Attributes
+     ----------
+     results_ : SunAbrahamResults
+         Estimation results after calling fit().
+     is_fitted_ : bool
+         Whether the model has been fitted.
+ 
+     Examples
+     --------
+     Basic usage:
+ 
+     >>> import pandas as pd
+     >>> from diff_diff import SunAbraham
+     >>>
+     >>> # Panel data with staggered treatment
+     >>> data = pd.DataFrame({
+     ...     'unit': [...],
+     ...     'time': [...],
+     ...     'outcome': [...],
+     ...     'first_treat': [...]  # 0 for never-treated
+     ... })
+     >>>
+     >>> sa = SunAbraham()
+     >>> results = sa.fit(data, outcome='outcome', unit='unit',
+     ...                  time='time', first_treat='first_treat')
+     >>> results.print_summary()
+ 
+     With covariates:
+ 
+     >>> sa = SunAbraham()
+     >>> results = sa.fit(data, outcome='outcome', unit='unit',
+     ...                  time='time', first_treat='first_treat',
+     ...                  covariates=['age', 'income'])
+ 
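+     With pairs-bootstrap inference (a usage sketch reusing the same
+     illustrative column names as above):
+ 
+     >>> sa = SunAbraham(n_bootstrap=999, seed=42)
+     >>> results = sa.fit(data, outcome='outcome', unit='unit',
+     ...                  time='time', first_treat='first_treat')
+     >>> results.bootstrap_results.n_bootstrap
+     999
+ 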
+     Notes
+     -----
+     The Sun-Abraham estimator uses a saturated regression approach:
+ 
+         Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × 1(G_i=g) × D_{it}^e] + X'γ + ε_it
+ 
+     where:
+ 
+     - α_i = unit fixed effects
+     - λ_t = time fixed effects
+     - G_i = unit i's treatment cohort (first treatment period)
+     - D_{it}^e = indicator for being e periods from treatment
+     - δ_{g,e} = cohort-specific effect (CATT) at relative time e
+ 
+     The event-study coefficients are then computed as:
+ 
+         β_e = Σ_g w_{g,e} × δ_{g,e}
+ 
+     where w_{g,e} is the share of cohort g in the treated population at
+     relative time e (interaction weights).
+ 
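+     A minimal numeric sketch of this aggregation (hand-picked cohort
+     effects and shares, not output of the estimator):
+ 
+     >>> delta = {(2004, 1): 2.0, (2006, 1): 1.0}   # CATTs at e = 1
+     >>> shares = {2004: 0.75, 2006: 0.25}          # interaction weights
+     >>> sum(shares[g] * delta[(g, e)] for (g, e) in delta)
+     1.75
+ 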
+     Compared to Callaway-Sant'Anna:
+ 
+     - SA uses saturated regression; CS uses 2x2 DiD comparisons
+     - SA can be more efficient when the model is correctly specified
+     - Both are consistent under heterogeneous treatment effects
+     - Running both provides a useful robustness check
+ 
+     References
+     ----------
+     Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in
+     event studies with heterogeneous treatment effects. Journal of
+     Econometrics, 225(2), 175-199.
+     """
+ 
+     def __init__(
+         self,
+         control_group: str = "never_treated",
+         anticipation: int = 0,
+         alpha: float = 0.05,
+         cluster: Optional[str] = None,
+         n_bootstrap: int = 0,
+         seed: Optional[int] = None,
+     ):
+         if control_group not in ["never_treated", "not_yet_treated"]:
+             raise ValueError(
+                 f"control_group must be 'never_treated' or 'not_yet_treated', "
+                 f"got '{control_group}'"
+             )
+ 
+         self.control_group = control_group
+         self.anticipation = anticipation
+         self.alpha = alpha
+         self.cluster = cluster
+         self.n_bootstrap = n_bootstrap
+         self.seed = seed
+ 
+         self.is_fitted_ = False
+         self.results_: Optional[SunAbrahamResults] = None
+         self._reference_period = -1  # Will be set during fit
+ 
+     def fit(
+         self,
+         data: pd.DataFrame,
+         outcome: str,
+         unit: str,
+         time: str,
+         first_treat: str,
+         covariates: Optional[List[str]] = None,
+         min_pre_periods: int = 1,
+         min_post_periods: int = 1,
+     ) -> SunAbrahamResults:
+         """
+         Fit the Sun-Abraham estimator using saturated regression.
+ 
+         Parameters
+         ----------
+         data : pd.DataFrame
+             Panel data with unit and time identifiers.
+         outcome : str
+             Name of outcome variable column.
+         unit : str
+             Name of unit identifier column.
+         time : str
+             Name of time period column.
+         first_treat : str
+             Name of column indicating when unit was first treated.
+             Use 0 (or np.inf) for never-treated units.
+         covariates : list, optional
+             List of covariate column names to include in regression.
+         min_pre_periods : int, default=1
+             Minimum number of pre-treatment periods to include in the
+             event study (not currently enforced).
+         min_post_periods : int, default=1
+             Minimum number of post-treatment periods to include in the
+             event study (not currently enforced).
+ 
+         Returns
+         -------
+         SunAbrahamResults
+             Object containing all estimation results.
+ 
+         Raises
+         ------
+         ValueError
+             If required columns are missing or data validation fails.
+         """
+         # Validate inputs
+         required_cols = [outcome, unit, time, first_treat]
+         if covariates:
+             required_cols.extend(covariates)
+ 
+         missing = [c for c in required_cols if c not in data.columns]
+         if missing:
+             raise ValueError(f"Missing columns: {missing}")
+ 
+         # Create working copy
+         df = data.copy()
+ 
+         # Ensure numeric types
+         df[time] = pd.to_numeric(df[time])
+         df[first_treat] = pd.to_numeric(df[first_treat])
+ 
+         # Identify groups and time periods
+         time_periods = sorted(df[time].unique())
+         # Exclude np.inf so never-treated units coded as inf are not
+         # mistaken for a treatment cohort
+         treatment_groups = sorted(
+             g for g in df[first_treat].unique() if 0 < g < np.inf
+         )
+ 
+         # Never-treated indicator
+         df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
+ 
+         # Get unique units
+         unit_info = (
+             df.groupby(unit)
+             .agg({first_treat: "first", "_never_treated": "first"})
+             .reset_index()
+         )
+ 
+         n_treated_units = int((~unit_info["_never_treated"]).sum())
+         n_control_units = int(unit_info["_never_treated"].sum())
+ 
+         if n_control_units == 0:
+             raise ValueError(
+                 "No never-treated units found. Check 'first_treat' column."
+             )
+ 
+         if len(treatment_groups) == 0:
+             raise ValueError(
+                 "No treated units found. Check 'first_treat' column."
+             )
+ 
+         # Compute relative time for each observation (vectorized)
+         df["_rel_time"] = np.where(
+             df[first_treat] > 0,
+             df[time] - df[first_treat],
+             np.nan
+         )
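+         # Illustrative values: a unit first treated in 2005 and observed
+         # in 2003 has _rel_time = 2003 - 2005 = -2. Never-treated rows get
+         # NaN (or -inf when coded np.inf), so they never match any finite
+         # relative period in the interactions built later.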
+ 
+         # Identify the range of relative time periods to estimate
+         rel_times_by_cohort = {}
+         for g in treatment_groups:
+             g_times = df[df[first_treat] == g][time].unique()
+             rel_times_by_cohort[g] = sorted([t - g for t in g_times])
+ 
+         # Find all relative time values
+         all_rel_times: set = set()
+         for g, rel_times in rel_times_by_cohort.items():
+             all_rel_times.update(rel_times)
+ 
+         all_rel_times_sorted = sorted(all_rel_times)
+ 
+         # Filter to reasonable range
+         min_rel = max(min(all_rel_times_sorted), -20)  # cap at -20
+         max_rel = min(max(all_rel_times_sorted), 20)  # cap at +20
+ 
+         # Reference period: last pre-treatment period (typically -1)
+         self._reference_period = -1 - self.anticipation
+ 
+         # Get relative periods to estimate (excluding reference)
+         rel_periods_to_estimate = [
+             e
+             for e in all_rel_times_sorted
+             if min_rel <= e <= max_rel and e != self._reference_period
+         ]
+ 
+         # Determine cluster variable
+         cluster_var = self.cluster if self.cluster is not None else unit
+ 
+         # Filter data based on control_group setting
+         if self.control_group == "never_treated":
+             # Only keep never-treated as controls
+             df_reg = df[df["_never_treated"] | (df[first_treat] > 0)].copy()
+         else:
+             # Keep all units (not_yet_treated will be handled by the regression)
+             df_reg = df.copy()
+ 
+         # Fit saturated regression
+         (
+             cohort_effects,
+             cohort_ses,
+             vcov_cohort,
+             coef_index_map,
+         ) = self._fit_saturated_regression(
+             df_reg,
+             outcome,
+             unit,
+             time,
+             first_treat,
+             treatment_groups,
+             rel_periods_to_estimate,
+             covariates,
+             cluster_var,
+         )
+ 
+         # Compute interaction-weighted event study effects
+         event_study_effects, cohort_weights = self._compute_iw_effects(
+             df,
+             unit,
+             first_treat,
+             treatment_groups,
+             rel_periods_to_estimate,
+             cohort_effects,
+             cohort_ses,
+             vcov_cohort,
+             coef_index_map,
+         )
+ 
+         # Compute overall ATT (average of post-treatment effects)
+         overall_att, overall_se = self._compute_overall_att(
+             df,
+             first_treat,
+             event_study_effects,
+             cohort_effects,
+             cohort_weights,
+             vcov_cohort,
+             coef_index_map,
+         )
+ 
+         overall_t = overall_att / overall_se if overall_se > 0 else 0.0
+         overall_p = compute_p_value(overall_t)
+         overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)
+ 
+         # Run bootstrap if requested
+         bootstrap_results = None
+         if self.n_bootstrap > 0:
+             bootstrap_results = self._run_bootstrap(
+                 df=df_reg,
+                 outcome=outcome,
+                 unit=unit,
+                 time=time,
+                 first_treat=first_treat,
+                 treatment_groups=treatment_groups,
+                 rel_periods_to_estimate=rel_periods_to_estimate,
+                 covariates=covariates,
+                 cluster_var=cluster_var,
+                 original_event_study=event_study_effects,
+                 original_overall_att=overall_att,
+             )
+ 
+             # Update results with bootstrap inference
+             overall_se = bootstrap_results.overall_att_se
+             overall_t = overall_att / overall_se if overall_se > 0 else 0.0
+             overall_p = bootstrap_results.overall_att_p_value
+             overall_ci = bootstrap_results.overall_att_ci
+ 
+             # Update event study effects
+             for e in event_study_effects:
+                 if e in bootstrap_results.event_study_ses:
+                     event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
+                     event_study_effects[e]["conf_int"] = (
+                         bootstrap_results.event_study_cis[e]
+                     )
+                     event_study_effects[e]["p_value"] = (
+                         bootstrap_results.event_study_p_values[e]
+                     )
+                     eff_val = event_study_effects[e]["effect"]
+                     se_val = event_study_effects[e]["se"]
+                     event_study_effects[e]["t_stat"] = (
+                         eff_val / se_val if se_val > 0 else 0.0
+                     )
+ 
+         # Convert cohort effects to storage format
+         cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {}
+         for (g, e), effect in cohort_effects.items():
+             weight = cohort_weights.get(e, {}).get(g, 0.0)
+             se = cohort_ses.get((g, e), 0.0)
+             cohort_effects_storage[(g, e)] = {
+                 "effect": effect,
+                 "se": se,
+                 "weight": weight,
+             }
+ 
+         # Store results
+         self.results_ = SunAbrahamResults(
+             event_study_effects=event_study_effects,
+             overall_att=overall_att,
+             overall_se=overall_se,
+             overall_t_stat=overall_t,
+             overall_p_value=overall_p,
+             overall_conf_int=overall_ci,
+             cohort_weights=cohort_weights,
+             groups=treatment_groups,
+             time_periods=time_periods,
+             n_obs=len(df),
+             n_treated_units=n_treated_units,
+             n_control_units=n_control_units,
+             alpha=self.alpha,
+             control_group=self.control_group,
+             bootstrap_results=bootstrap_results,
+             cohort_effects=cohort_effects_storage,
+         )
+ 
+         self.is_fitted_ = True
+         return self.results_
+ 
+     def _fit_saturated_regression(
+         self,
+         df: pd.DataFrame,
+         outcome: str,
+         unit: str,
+         time: str,
+         first_treat: str,
+         treatment_groups: List[Any],
+         rel_periods: List[int],
+         covariates: Optional[List[str]],
+         cluster_var: str,
+     ) -> Tuple[
+         Dict[Tuple[Any, int], float],
+         Dict[Tuple[Any, int], float],
+         np.ndarray,
+         Dict[Tuple[Any, int], int],
+     ]:
+         """
+         Fit saturated TWFE regression with cohort × relative-time interactions.
+ 
+             Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε
+ 
+         Uses a two-way within transformation to absorb unit and time
+         fixed effects.
+ 
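+         Dummy construction sketch (illustrative values): for cohort
+         g = 2005 at relative time e = 2, the column ``_D_2005_2`` is 1
+         exactly on rows with first_treat == 2005 and time == 2007.
+ 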
+         Returns
+         -------
+         cohort_effects : dict
+             Mapping (cohort, rel_period) -> effect estimate δ_{g,e}
+         cohort_ses : dict
+             Mapping (cohort, rel_period) -> standard error
+         vcov : np.ndarray
+             Variance-covariance matrix for cohort effects
+         coef_index_map : dict
+             Mapping (cohort, rel_period) -> index in coefficient vector
+         """
+         df = df.copy()
+ 
+         # Create cohort × relative-time interaction dummies, excluding the
+         # reference period. Build all columns at once to avoid fragmentation.
+         interaction_data = {}
+         coef_index_map: Dict[Tuple[Any, int], int] = {}
+         idx = 0
+ 
+         for g in treatment_groups:
+             for e in rel_periods:
+                 col_name = f"_D_{g}_{e}"
+                 # Indicator: unit is in cohort g AND at relative time e
+                 indicator = (
+                     (df[first_treat] == g) &
+                     (df["_rel_time"] == e)
+                 ).astype(float)
+ 
+                 # Only include if there are observations
+                 if indicator.sum() > 0:
+                     interaction_data[col_name] = indicator.values
+                     coef_index_map[(g, e)] = idx
+                     idx += 1
+ 
+         # Add all interaction columns at once
+         interaction_cols = list(interaction_data.keys())
+         if interaction_data:
+             interaction_df = pd.DataFrame(interaction_data, index=df.index)
+             df = pd.concat([df, interaction_df], axis=1)
+ 
+         if len(interaction_cols) == 0:
+             raise ValueError(
+                 "No valid cohort × relative-time interactions found. "
+                 "Check your data structure."
+             )
+ 
+         # Apply within-transformation for unit and time fixed effects
+         variables_to_demean = [outcome] + interaction_cols
+         if covariates:
+             variables_to_demean.extend(covariates)
+ 
+         df_demeaned = self._within_transform(df, variables_to_demean, unit, time)
+ 
+         # Build design matrix
+         X_cols = [f"{col}_dm" for col in interaction_cols]
+         if covariates:
+             X_cols.extend([f"{cov}_dm" for cov in covariates])
+ 
+         X = df_demeaned[X_cols].values
+         y = df_demeaned[f"{outcome}_dm"].values
+ 
+         # Fit OLS using LinearRegression helper (more stable than manual X'X inverse)
+         cluster_ids = df_demeaned[cluster_var].values
+         reg = LinearRegression(
+             include_intercept=False,  # Already demeaned, no intercept needed
+             robust=True,
+             cluster_ids=cluster_ids,
+         ).fit(X, y)
+ 
+         vcov = reg.vcov_
+ 
+         # Extract cohort effects and standard errors using get_inference
+         cohort_effects: Dict[Tuple[Any, int], float] = {}
+         cohort_ses: Dict[Tuple[Any, int], float] = {}
+ 
+         n_interactions = len(interaction_cols)
+         for (g, e), coef_idx in coef_index_map.items():
+             inference = reg.get_inference(coef_idx)
+             cohort_effects[(g, e)] = inference.coefficient
+             cohort_ses[(g, e)] = inference.se
+ 
+         # Extract just the vcov for cohort effects (excluding covariates)
+         vcov_cohort = vcov[:n_interactions, :n_interactions]
+ 
+         return cohort_effects, cohort_ses, vcov_cohort, coef_index_map
+ 
+     def _within_transform(
+         self,
+         df: pd.DataFrame,
+         variables: List[str],
+         unit: str,
+         time: str,
+     ) -> pd.DataFrame:
+         """
+         Apply two-way within transformation to remove unit and time fixed effects.
+ 
+             y_it - y_i. - y_.t + y_..
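+ 
+         As a numeric check (hand-picked values): for a 2x2 panel that is
+         exactly additive in unit and time effects, y = (1, 2, 3, 4), the
+         transform returns all zeros, e.g. 1 - 1.5 - 2 + 2.5 = 0 for the
+         first cell; only variation not explained by the two fixed effects
+         survives.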
+         """
+         return _within_transform_util(df, variables, unit, time, suffix="_dm")
+ 
+     def _compute_iw_effects(
+         self,
+         df: pd.DataFrame,
+         unit: str,
+         first_treat: str,
+         treatment_groups: List[Any],
+         rel_periods: List[int],
+         cohort_effects: Dict[Tuple[Any, int], float],
+         cohort_ses: Dict[Tuple[Any, int], float],
+         vcov_cohort: np.ndarray,
+         coef_index_map: Dict[Tuple[Any, int], int],
+     ) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
+         """
+         Compute interaction-weighted event study effects.
+ 
+             β_e = Σ_g w_{g,e} × δ_{g,e}
+ 
+         where w_{g,e} is the share of cohort g among treated units at relative time e.
+ 
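+         Hedged numeric check of the delta-method SE used below: with
+         weights w = (0.75, 0.25) and vcov diag(0.04, 0.16), the variance
+         is w' Σ w = 0.75^2 × 0.04 + 0.25^2 × 0.16 = 0.0325, so SE ≈ 0.18.
+ 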
+         Returns
+         -------
+         event_study_effects : dict
+             Dictionary mapping relative period to aggregated effect info.
+         cohort_weights : dict
+             Dictionary mapping relative period to cohort weight dictionary.
+         """
+         event_study_effects: Dict[int, Dict[str, Any]] = {}
+         cohort_weights: Dict[int, Dict[Any, float]] = {}
+ 
+         # Get cohort sizes
+         unit_cohorts = df.groupby(unit)[first_treat].first()
+         cohort_sizes = unit_cohorts[unit_cohorts > 0].value_counts().to_dict()
+ 
+         for e in rel_periods:
+             # Get cohorts that have observations at this relative time
+             cohorts_at_e = [
+                 g for g in treatment_groups
+                 if (g, e) in cohort_effects
+             ]
+ 
+             if not cohorts_at_e:
+                 continue
+ 
+             # Compute IW weights: share of each cohort among those observed at e
+             weights = {}
+             total_size = 0
+             for g in cohorts_at_e:
+                 n_g = cohort_sizes.get(g, 0)
+                 weights[g] = n_g
+                 total_size += n_g
+ 
+             if total_size == 0:
+                 continue
+ 
+             # Normalize weights
+             for g in weights:
+                 weights[g] = weights[g] / total_size
+ 
+             cohort_weights[e] = weights
+ 
+             # Compute weighted average effect
+             agg_effect = 0.0
+             for g in cohorts_at_e:
+                 w = weights[g]
+                 agg_effect += w * cohort_effects[(g, e)]
+ 
+             # Compute SE using delta method with vcov
+             # Var(β_e) = w' Σ w where w is weight vector and Σ is vcov submatrix
+             indices = [coef_index_map[(g, e)] for g in cohorts_at_e]
+             weight_vec = np.array([weights[g] for g in cohorts_at_e])
+             vcov_subset = vcov_cohort[np.ix_(indices, indices)]
+             agg_var = float(weight_vec @ vcov_subset @ weight_vec)
+             agg_se = np.sqrt(max(agg_var, 0))
+ 
+             t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
+             p_val = compute_p_value(t_stat)
+             ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)
+ 
+             event_study_effects[e] = {
+                 "effect": agg_effect,
+                 "se": agg_se,
+                 "t_stat": t_stat,
+                 "p_value": p_val,
+                 "conf_int": ci,
+                 "n_groups": len(cohorts_at_e),
+             }
+ 
+         return event_study_effects, cohort_weights
+ 
+     def _compute_overall_att(
+         self,
+         df: pd.DataFrame,
+         first_treat: str,
+         event_study_effects: Dict[int, Dict[str, Any]],
+         cohort_effects: Dict[Tuple[Any, int], float],
+         cohort_weights: Dict[int, Dict[Any, float]],
+         vcov_cohort: np.ndarray,
+         coef_index_map: Dict[Tuple[Any, int], int],
+     ) -> Tuple[float, float]:
+         """
+         Compute overall ATT as weighted average of post-treatment effects.
+ 
+         Returns (att, se) tuple.
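+ 
+         Weighting sketch (hypothetical counts): with 100 treated
+         observations at e = 0 and 50 at e = 1, the period weights are
+         2/3 and 1/3, so ATT = (2/3) × β_0 + (1/3) × β_1.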
+         """
+         post_effects = [
+             (e, eff)
+             for e, eff in event_study_effects.items()
+             if e >= 0
+         ]
+ 
+         if not post_effects:
+             return 0.0, 0.0
+ 
+         # Weight by number of treated observations at each relative time
+         post_weights = []
+         post_estimates = []
+ 
+         for e, eff in post_effects:
+             n_at_e = len(df[(df["_rel_time"] == e) & (df[first_treat] > 0)])
+             post_weights.append(max(n_at_e, 1))
+             post_estimates.append(eff["effect"])
+ 
+         post_weights = np.array(post_weights, dtype=float)
+         post_weights = post_weights / post_weights.sum()
+ 
+         overall_att = float(np.sum(post_weights * np.array(post_estimates)))
+ 
+         # Compute SE using delta method
+         # Need to trace back through the full weighting scheme
+         # ATT = Σ_e w_e × β_e = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e}
+         # Collect all (g, e) pairs and their overall weights
+         overall_weights_by_coef: Dict[Tuple[Any, int], float] = {}
+ 
+         for i, (e, _) in enumerate(post_effects):
+             period_weight = post_weights[i]
+             if e in cohort_weights:
+                 for g, cw in cohort_weights[e].items():
+                     key = (g, e)
+                     if key in coef_index_map:
+                         if key not in overall_weights_by_coef:
+                             overall_weights_by_coef[key] = 0.0
+                         overall_weights_by_coef[key] += period_weight * cw
+ 
+         if not overall_weights_by_coef:
+             # Fallback to simple variance calculation
+             overall_var = float(
+                 np.sum(
+                     (post_weights ** 2)
+                     * np.array([eff["se"] ** 2 for _, eff in post_effects])
+                 )
+             )
+             return overall_att, np.sqrt(overall_var)
+ 
+         # Build full weight vector and compute variance
+         indices = [coef_index_map[key] for key in overall_weights_by_coef.keys()]
+         weight_vec = np.array(list(overall_weights_by_coef.values()))
+         vcov_subset = vcov_cohort[np.ix_(indices, indices)]
+         overall_var = float(weight_vec @ vcov_subset @ weight_vec)
+         overall_se = np.sqrt(max(overall_var, 0))
+ 
+         return overall_att, overall_se
+ 
+     def _run_bootstrap(
+         self,
+         df: pd.DataFrame,
+         outcome: str,
+         unit: str,
+         time: str,
+         first_treat: str,
+         treatment_groups: List[Any],
+         rel_periods_to_estimate: List[int],
+         covariates: Optional[List[str]],
+         cluster_var: str,
+         original_event_study: Dict[int, Dict[str, Any]],
+         original_overall_att: float,
+     ) -> SABootstrapResults:
+         """
+         Run pairs bootstrap for inference.
+ 
+         Resamples units with replacement and re-estimates the full model.
+         """
+         if self.n_bootstrap < 50:
+             warnings.warn(
+                 f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
+                 "for reliable inference.",
+                 UserWarning,
+                 stacklevel=3,
+             )
+ 
+         rng = np.random.default_rng(self.seed)
+ 
+         # Get unique units
+         all_units = df[unit].unique()
+         n_units = len(all_units)
+ 
+         # Store bootstrap samples
+         rel_periods = sorted(original_event_study.keys())
+         bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
+         bootstrap_overall = np.zeros(self.n_bootstrap)
+ 
+         for b in range(self.n_bootstrap):
+             # Resample units with replacement (pairs bootstrap)
+             boot_units = rng.choice(all_units, size=n_units, replace=True)
+ 
+             # Create the bootstrap sample: gather row positions for each
+             # sampled unit (positions rather than index labels, so this is
+             # safe for non-default indexes)
+             unit_vals = df[unit].values
+             boot_positions = np.concatenate([
+                 np.flatnonzero(unit_vals == u) for u in boot_units
+             ])
+             df_b = df.iloc[boot_positions].copy()
+ 
+             # Reassign unit IDs so repeated draws of the same unit enter
+             # as distinct units (and hence distinct clusters); counts line
+             # up with boot_positions by construction
+             df_b[unit] = np.concatenate([
+                 np.full(int((unit_vals == u).sum()), new_id)
+                 for new_id, u in enumerate(boot_units)
+             ])
+ 
+             # Recompute relative time (vectorized)
+             df_b["_rel_time"] = np.where(
+                 df_b[first_treat] > 0,
+                 df_b[time] - df_b[first_treat],
+                 np.nan
+             )
+             df_b["_never_treated"] = (
+                 (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)
+             )
+ 
+             try:
+                 # Re-estimate saturated regression
+                 (
+                     cohort_effects_b,
+                     cohort_ses_b,
+                     vcov_b,
+                     coef_map_b,
+                 ) = self._fit_saturated_regression(
+                     df_b,
+                     outcome,
+                     unit,
+                     time,
+                     first_treat,
+                     treatment_groups,
+                     rel_periods_to_estimate,
+                     covariates,
+                     cluster_var,
+                 )
+ 
+                 # Compute IW effects for this bootstrap sample
+                 event_study_b, cohort_weights_b = self._compute_iw_effects(
+                     df_b,
+                     unit,
+                     first_treat,
+                     treatment_groups,
+                     rel_periods_to_estimate,
+                     cohort_effects_b,
+                     cohort_ses_b,
+                     vcov_b,
+                     coef_map_b,
+                 )
+ 
+                 # Store bootstrap estimates
+                 for e in rel_periods:
+                     if e in event_study_b:
+                         bootstrap_effects[e][b] = event_study_b[e]["effect"]
+                     else:
+                         bootstrap_effects[e][b] = original_event_study[e]["effect"]
+ 
+                 # Compute overall ATT for this bootstrap sample
+                 overall_b, _ = self._compute_overall_att(
+                     df_b,
+                     first_treat,
+                     event_study_b,
+                     cohort_effects_b,
+                     cohort_weights_b,
+                     vcov_b,
+                     coef_map_b,
+                 )
+                 bootstrap_overall[b] = overall_b
+ 
+             except (ValueError, np.linalg.LinAlgError) as exc:
+                 # If bootstrap iteration fails, use original
+                 warnings.warn(
+                     f"Bootstrap iteration {b} failed: {exc}. Using original estimate.",
+                     UserWarning,
+                     stacklevel=2,
+                 )
+                 for e in rel_periods:
+                     bootstrap_effects[e][b] = original_event_study[e]["effect"]
+                 bootstrap_overall[b] = original_overall_att
+ 
+         # Compute bootstrap statistics
+         event_study_ses = {}
+         event_study_cis = {}
+         event_study_p_values = {}
+ 
+         for e in rel_periods:
+             boot_dist = bootstrap_effects[e]
+             original_effect = original_event_study[e]["effect"]
+ 
+             se = float(np.std(boot_dist, ddof=1))
+             ci = self._compute_percentile_ci(boot_dist, self.alpha)
+             p_value = self._compute_bootstrap_pvalue(original_effect, boot_dist)
+ 
+             event_study_ses[e] = se
+             event_study_cis[e] = ci
+             event_study_p_values[e] = p_value
+ 
+         # Overall ATT statistics
+         overall_se = float(np.std(bootstrap_overall, ddof=1))
+         overall_ci = self._compute_percentile_ci(bootstrap_overall, self.alpha)
+         overall_p = self._compute_bootstrap_pvalue(
+             original_overall_att, bootstrap_overall
+         )
+ 
+         return SABootstrapResults(
+             n_bootstrap=self.n_bootstrap,
+             weight_type="pairs",
+             alpha=self.alpha,
+             overall_att_se=overall_se,
+             overall_att_ci=overall_ci,
+             overall_att_p_value=overall_p,
+             event_study_ses=event_study_ses,
+             event_study_cis=event_study_cis,
+             event_study_p_values=event_study_p_values,
+             bootstrap_distribution=bootstrap_overall,
+         )
+ 
+     def _compute_percentile_ci(
+         self,
+         boot_dist: np.ndarray,
+         alpha: float,
+     ) -> Tuple[float, float]:
+         """Compute percentile confidence interval."""
+         lower = float(np.percentile(boot_dist, alpha / 2 * 100))
+         upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
+         return (lower, upper)
+ 
+     def _compute_bootstrap_pvalue(
+         self,
+         original_effect: float,
+         boot_dist: np.ndarray,
+     ) -> float:
+         """Compute two-sided bootstrap p-value."""
+         if original_effect >= 0:
+             p_one_sided = float(np.mean(boot_dist <= 0))
+         else:
+             p_one_sided = float(np.mean(boot_dist >= 0))
+ 
+         p_value = min(2 * p_one_sided, 1.0)
+         p_value = max(p_value, 1 / (self.n_bootstrap + 1))
+ 
+         return p_value
+ 
+     def get_params(self) -> Dict[str, Any]:
+         """Get estimator parameters (sklearn-compatible)."""
+         return {
+             "control_group": self.control_group,
+             "anticipation": self.anticipation,
+             "alpha": self.alpha,
+             "cluster": self.cluster,
+             "n_bootstrap": self.n_bootstrap,
+             "seed": self.seed,
+         }
+ 
+     def set_params(self, **params) -> "SunAbraham":
+         """Set estimator parameters (sklearn-compatible)."""
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Unknown parameter: {key}")
+         return self
+ 
+     def summary(self) -> str:
+         """Get summary of estimation results."""
+         if not self.is_fitted_:
+             raise RuntimeError("Model must be fitted before calling summary()")
+         assert self.results_ is not None
+         return self.results_.summary()
+ 
+     def print_summary(self) -> None:
+         """Print summary to stdout."""
+         print(self.summary())