diff-diff 2.2.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1191 @@
1
+ """
2
+ Sun-Abraham Interaction-Weighted Estimator for staggered DiD.
3
+
4
+ Implements the estimator from Sun & Abraham (2021), "Estimating dynamic
5
+ treatment effects in event studies with heterogeneous treatment effects",
6
+ Journal of Econometrics.
7
+
8
+ This provides an alternative to Callaway-Sant'Anna using a saturated
9
+ regression with cohort × relative-time interactions.
10
+ """
11
+
12
+ import warnings
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from diff_diff.linalg import LinearRegression, compute_robust_vcov
20
+ from diff_diff.results import _get_significance_stars
21
+ from diff_diff.utils import (
22
+ compute_confidence_interval,
23
+ compute_p_value,
24
+ within_transform as _within_transform_util,
25
+ )
26
+
27
+
28
@dataclass
class SunAbrahamResults:
    """
    Results from Sun-Abraham (2021) interaction-weighted estimation.

    Attributes
    ----------
    event_study_effects : dict
        Dictionary mapping relative time to effect dictionaries with keys:
        'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_groups'.
    overall_att : float
        Overall average treatment effect (weighted average of post-treatment effects).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    cohort_weights : dict
        Dictionary mapping relative time to cohort weight dictionaries.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units.
    alpha : float
        Significance level used for confidence intervals.
    control_group : str
        Type of control group used.
    """

    event_study_effects: Dict[int, Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    cohort_weights: Dict[int, Dict[Any, float]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    bootstrap_results: Optional["SABootstrapResults"] = field(default=None, repr=False)
    cohort_effects: Optional[Dict[Tuple[Any, int], Dict[str, Any]]] = field(
        default=None, repr=False
    )

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        n_rel_periods = len(self.event_study_effects)
        return (
            f"SunAbrahamResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_rel_periods={n_rel_periods})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation.

        Returns
        -------
        str
            Formatted summary.
        """
        # BUGFIX: use an explicit `is None` check instead of `alpha or
        # self.alpha` so a falsy-but-explicit value (e.g. alpha=0.0) is
        # honored rather than silently replaced by the fitted alpha.
        alpha = self.alpha if alpha is None else alpha
        conf_level = int((1 - alpha) * 100)

        lines = [
            "=" * 85,
            "Sun-Abraham Interaction-Weighted Estimator Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            "",
        ]

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
                "",
            ]
        )

        # Event study effects
        lines.extend(
            [
                "-" * 85,
                "Event Study (Dynamic) Effects".center(85),
                "-" * 85,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
        )

        for rel_t in sorted(self.event_study_effects.keys()):
            eff = self.event_study_effects[rel_t]
            sig = _get_significance_stars(eff["p_value"])
            lines.append(
                f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
            )

        lines.extend(["-" * 85, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "event_study") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="event_study"
            Level of aggregation: "event_study" or "cohort".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame.

        Raises
        ------
        ValueError
            If `level` is not recognized, or if level="cohort" and
            cohort-level effects were not stored.
        """
        if level == "event_study":
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                rows.append(
                    {
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                    }
                )
            return pd.DataFrame(rows)

        elif level == "cohort":
            if self.cohort_effects is None:
                raise ValueError(
                    "Cohort-level effects not available. "
                    "They are computed internally but not stored by default."
                )
            rows = []
            for (cohort, rel_t), data in sorted(self.cohort_effects.items()):
                rows.append(
                    {
                        "cohort": cohort,
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        # Weight may be absent for reference-period entries.
                        "weight": data.get("weight", np.nan),
                    }
                )
            return pd.DataFrame(rows)

        else:
            raise ValueError(
                f"Unknown level: {level}. Use 'event_study' or 'cohort'."
            )

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
245
+
246
+
247
@dataclass
class SABootstrapResults:
    """
    Results from Sun-Abraham bootstrap inference.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Type of bootstrap used (always "pairs" for pairs bootstrap).
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap confidence interval for overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for overall ATT.
    event_study_ses : Dict[int, float]
        Bootstrap SEs for event study effects.
    event_study_cis : Dict[int, Tuple[float, float]]
        Bootstrap CIs for event study effects.
    event_study_p_values : Dict[int, float]
        Bootstrap p-values for event study effects.
    bootstrap_distribution : Optional[np.ndarray]
        Full bootstrap distribution of overall ATT.
    """

    n_bootstrap: int
    weight_type: str
    alpha: float
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    # Keys of the three dicts below are relative periods (event time).
    event_study_ses: Dict[int, float]
    event_study_cis: Dict[int, Tuple[float, float]]
    event_study_p_values: Dict[int, float]
    # Excluded from repr: can be a large array (one draw per iteration).
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
286
+
287
+
288
+ class SunAbraham:
289
+ """
290
+ Sun-Abraham (2021) interaction-weighted estimator for staggered DiD.
291
+
292
+ This estimator provides event-study coefficients using a saturated
293
+ TWFE regression with cohort × relative-time interactions, following
294
+ the methodology in Sun & Abraham (2021).
295
+
296
+ The estimation procedure follows three steps:
297
+ 1. Run a saturated TWFE regression with cohort × relative-time dummies
298
+ 2. Compute cohort shares (weights) at each relative time
299
+ 3. Aggregate cohort-specific effects using interaction weights
300
+
301
+ This avoids the negative weighting problem of standard TWFE and provides
302
+ consistent event-study estimates under treatment effect heterogeneity.
303
+
304
+ Parameters
305
+ ----------
306
+ control_group : str, default="never_treated"
307
+ Which units to use as controls:
308
+ - "never_treated": Use only never-treated units (recommended)
309
+ - "not_yet_treated": Use never-treated and not-yet-treated units
310
+ anticipation : int, default=0
311
+ Number of periods before treatment where effects may occur.
312
+ alpha : float, default=0.05
313
+ Significance level for confidence intervals.
314
+ cluster : str, optional
315
+ Column name for cluster-robust standard errors.
316
+ If None, clusters at the unit level by default.
317
+ n_bootstrap : int, default=0
318
+ Number of bootstrap iterations for inference.
319
+ If 0, uses analytical cluster-robust standard errors.
320
+ seed : int, optional
321
+ Random seed for reproducibility.
322
+ rank_deficient_action : str, default="warn"
323
+ Action when design matrix is rank-deficient (linearly dependent columns):
324
+ - "warn": Issue warning and drop linearly dependent columns (default)
325
+ - "error": Raise ValueError
326
+ - "silent": Drop columns silently without warning
327
+
328
+ Attributes
329
+ ----------
330
+ results_ : SunAbrahamResults
331
+ Estimation results after calling fit().
332
+ is_fitted_ : bool
333
+ Whether the model has been fitted.
334
+
335
+ Examples
336
+ --------
337
+ Basic usage:
338
+
339
+ >>> import pandas as pd
340
+ >>> from diff_diff import SunAbraham
341
+ >>>
342
+ >>> # Panel data with staggered treatment
343
+ >>> data = pd.DataFrame({
344
+ ... 'unit': [...],
345
+ ... 'time': [...],
346
+ ... 'outcome': [...],
347
+ ... 'first_treat': [...] # 0 for never-treated
348
+ ... })
349
+ >>>
350
+ >>> sa = SunAbraham()
351
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
352
+ ... time='time', first_treat='first_treat')
353
+ >>> results.print_summary()
354
+
355
+ With covariates:
356
+
357
+ >>> sa = SunAbraham()
358
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
359
+ ... time='time', first_treat='first_treat',
360
+ ... covariates=['age', 'income'])
361
+
362
+ Notes
363
+ -----
364
+ The Sun-Abraham estimator uses a saturated regression approach:
365
+
366
+ Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × 1(G_i=g) × D_{it}^e] + X'γ + ε_it
367
+
368
+ where:
369
+ - α_i = unit fixed effects
370
+ - λ_t = time fixed effects
371
+ - G_i = unit i's treatment cohort (first treatment period)
372
+ - D_{it}^e = indicator for being e periods from treatment
373
+ - δ_{g,e} = cohort-specific effect (CATT) at relative time e
374
+
375
+ The event-study coefficients are then computed as:
376
+
377
+ β_e = Σ_g w_{g,e} × δ_{g,e}
378
+
379
+ where w_{g,e} is the share of cohort g in the treated population at
380
+ relative time e (interaction weights).
381
+
382
+ Compared to Callaway-Sant'Anna:
383
+ - SA uses saturated regression; CS uses 2x2 DiD comparisons
384
+ - SA can be more efficient when model is correctly specified
385
+ - Both are consistent under heterogeneous treatment effects
386
+ - Running both provides a useful robustness check
387
+
388
+ References
389
+ ----------
390
+ Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in
391
+ event studies with heterogeneous treatment effects. Journal of
392
+ Econometrics, 225(2), 175-199.
393
+ """
394
+
395
+ def __init__(
396
+ self,
397
+ control_group: str = "never_treated",
398
+ anticipation: int = 0,
399
+ alpha: float = 0.05,
400
+ cluster: Optional[str] = None,
401
+ n_bootstrap: int = 0,
402
+ seed: Optional[int] = None,
403
+ rank_deficient_action: str = "warn",
404
+ ):
405
+ if control_group not in ["never_treated", "not_yet_treated"]:
406
+ raise ValueError(
407
+ f"control_group must be 'never_treated' or 'not_yet_treated', "
408
+ f"got '{control_group}'"
409
+ )
410
+
411
+ if rank_deficient_action not in ["warn", "error", "silent"]:
412
+ raise ValueError(
413
+ f"rank_deficient_action must be 'warn', 'error', or 'silent', "
414
+ f"got '{rank_deficient_action}'"
415
+ )
416
+
417
+ self.control_group = control_group
418
+ self.anticipation = anticipation
419
+ self.alpha = alpha
420
+ self.cluster = cluster
421
+ self.n_bootstrap = n_bootstrap
422
+ self.seed = seed
423
+ self.rank_deficient_action = rank_deficient_action
424
+
425
+ self.is_fitted_ = False
426
+ self.results_: Optional[SunAbrahamResults] = None
427
+ self._reference_period = -1 # Will be set during fit
428
+
429
    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]] = None,
        min_pre_periods: int = 1,
        min_post_periods: int = 1,
    ) -> SunAbrahamResults:
        """
        Fit the Sun-Abraham estimator using saturated regression.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column.
        first_treat : str
            Name of column indicating when unit was first treated.
            Use 0 (or np.inf) for never-treated units.
        covariates : list, optional
            List of covariate column names to include in regression.
        min_pre_periods : int, default=1
            Minimum number of pre-treatment periods to include in event study.
        min_post_periods : int, default=1
            Minimum number of post-treatment periods to include in event study.

        Returns
        -------
        SunAbrahamResults
            Object containing all estimation results.

        Raises
        ------
        ValueError
            If required columns are missing or data validation fails.
        """
        # NOTE(review): min_pre_periods and min_post_periods are accepted but
        # never referenced in this implementation — confirm whether filtering
        # of short event windows was intended here.

        # Validate inputs
        required_cols = [outcome, unit, time, first_treat]
        if covariates:
            required_cols.extend(covariates)

        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Create working copy (never mutate the caller's frame)
        df = data.copy()

        # Ensure numeric types
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        # Identify groups and time periods; cohorts are positive first-treat
        # values (0 / inf encode never-treated).
        time_periods = sorted(df[time].unique())
        treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

        # Never-treated indicator
        df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)

        # Get unique units (first_treat is constant within unit, so "first"
        # picks the unit-level value)
        unit_info = (
            df.groupby(unit)
            .agg({first_treat: "first", "_never_treated": "first"})
            .reset_index()
        )

        n_treated_units = int((unit_info[first_treat] > 0).sum())
        n_control_units = int((unit_info["_never_treated"]).sum())

        if n_control_units == 0:
            raise ValueError(
                "No never-treated units found. Check 'first_treat' column."
            )

        if len(treatment_groups) == 0:
            raise ValueError(
                "No treated units found. Check 'first_treat' column."
            )

        # Compute relative time for each observation (vectorized);
        # NaN for never-treated rows.
        df["_rel_time"] = np.where(
            df[first_treat] > 0,
            df[time] - df[first_treat],
            np.nan
        )

        # Identify the range of relative time periods to estimate
        rel_times_by_cohort = {}
        for g in treatment_groups:
            g_times = df[df[first_treat] == g][time].unique()
            rel_times_by_cohort[g] = sorted([t - g for t in g_times])

        # Find all relative time values observed for any cohort
        all_rel_times: set = set()
        for g, rel_times in rel_times_by_cohort.items():
            all_rel_times.update(rel_times)

        all_rel_times_sorted = sorted(all_rel_times)

        # Filter to reasonable range to keep the saturated design tractable
        min_rel = max(min(all_rel_times_sorted), -20)  # cap at -20
        max_rel = min(max(all_rel_times_sorted), 20)  # cap at +20

        # Reference period: last pre-treatment period (typically -1),
        # shifted back by the anticipation window.
        self._reference_period = -1 - self.anticipation

        # Get relative periods to estimate (excluding reference, which is
        # the omitted category of the saturated regression)
        rel_periods_to_estimate = [
            e
            for e in all_rel_times_sorted
            if min_rel <= e <= max_rel and e != self._reference_period
        ]

        # Determine cluster variable (defaults to unit-level clustering)
        cluster_var = self.cluster if self.cluster is not None else unit

        # Filter data based on control_group setting
        if self.control_group == "never_treated":
            # Only keep never-treated as controls
            df_reg = df[df["_never_treated"] | (df[first_treat] > 0)].copy()
        else:
            # Keep all units (not_yet_treated will be handled by the regression)
            df_reg = df.copy()

        # Step 1: fit saturated regression -> cohort-specific CATTs δ_{g,e}
        (
            cohort_effects,
            cohort_ses,
            vcov_cohort,
            coef_index_map,
        ) = self._fit_saturated_regression(
            df_reg,
            outcome,
            unit,
            time,
            first_treat,
            treatment_groups,
            rel_periods_to_estimate,
            covariates,
            cluster_var,
        )

        # Step 2+3: compute interaction-weighted event study effects β_e
        event_study_effects, cohort_weights = self._compute_iw_effects(
            df,
            unit,
            first_treat,
            treatment_groups,
            rel_periods_to_estimate,
            cohort_effects,
            cohort_ses,
            vcov_cohort,
            coef_index_map,
        )

        # Compute overall ATT (average of post-treatment effects)
        overall_att, overall_se = self._compute_overall_att(
            df,
            first_treat,
            event_study_effects,
            cohort_effects,
            cohort_weights,
            vcov_cohort,
            coef_index_map,
        )

        overall_t = overall_att / overall_se if overall_se > 0 else 0.0
        overall_p = compute_p_value(overall_t)
        overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)

        # Run bootstrap if requested; bootstrap inference replaces the
        # analytical SEs / CIs / p-values computed above.
        bootstrap_results = None
        if self.n_bootstrap > 0:
            bootstrap_results = self._run_bootstrap(
                df=df_reg,
                outcome=outcome,
                unit=unit,
                time=time,
                first_treat=first_treat,
                treatment_groups=treatment_groups,
                rel_periods_to_estimate=rel_periods_to_estimate,
                covariates=covariates,
                cluster_var=cluster_var,
                original_event_study=event_study_effects,
                original_overall_att=overall_att,
            )

            # Update results with bootstrap inference
            overall_se = bootstrap_results.overall_att_se
            overall_t = overall_att / overall_se if overall_se > 0 else 0.0
            overall_p = bootstrap_results.overall_att_p_value
            overall_ci = bootstrap_results.overall_att_ci

            # Update event study effects in place with bootstrap SE/CI/p;
            # t-stat is recomputed from the point estimate and bootstrap SE.
            for e in event_study_effects:
                if e in bootstrap_results.event_study_ses:
                    event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
                    event_study_effects[e]["conf_int"] = (
                        bootstrap_results.event_study_cis[e]
                    )
                    event_study_effects[e]["p_value"] = (
                        bootstrap_results.event_study_p_values[e]
                    )
                    eff_val = event_study_effects[e]["effect"]
                    se_val = event_study_effects[e]["se"]
                    event_study_effects[e]["t_stat"] = (
                        eff_val / se_val if se_val > 0 else 0.0
                    )

        # Convert cohort effects to storage format for to_dataframe("cohort")
        cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {}
        for (g, e), effect in cohort_effects.items():
            weight = cohort_weights.get(e, {}).get(g, 0.0)
            se = cohort_ses.get((g, e), 0.0)
            cohort_effects_storage[(g, e)] = {
                "effect": effect,
                "se": se,
                "weight": weight,
            }

        # Store results
        self.results_ = SunAbrahamResults(
            event_study_effects=event_study_effects,
            overall_att=overall_att,
            overall_se=overall_se,
            overall_t_stat=overall_t,
            overall_p_value=overall_p,
            overall_conf_int=overall_ci,
            cohort_weights=cohort_weights,
            groups=treatment_groups,
            time_periods=time_periods,
            n_obs=len(df),
            n_treated_units=n_treated_units,
            n_control_units=n_control_units,
            alpha=self.alpha,
            control_group=self.control_group,
            bootstrap_results=bootstrap_results,
            cohort_effects=cohort_effects_storage,
        )

        self.is_fitted_ = True
        return self.results_
679
+
680
    def _fit_saturated_regression(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treatment_groups: List[Any],
        rel_periods: List[int],
        covariates: Optional[List[str]],
        cluster_var: str,
    ) -> Tuple[
        Dict[Tuple[Any, int], float],
        Dict[Tuple[Any, int], float],
        np.ndarray,
        Dict[Tuple[Any, int], int],
    ]:
        """
        Fit saturated TWFE regression with cohort × relative-time interactions.

        Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε

        Uses within-transformation for unit fixed effects and time dummies.

        Returns
        -------
        cohort_effects : dict
            Mapping (cohort, rel_period) -> effect estimate δ_{g,e}
        cohort_ses : dict
            Mapping (cohort, rel_period) -> standard error
        vcov : np.ndarray
            Variance-covariance matrix for cohort effects
        coef_index_map : dict
            Mapping (cohort, rel_period) -> index in coefficient vector
        """
        # Local copy: interaction columns are added below.
        df = df.copy()

        # Create cohort × relative-time interaction dummies
        # Exclude reference period
        # Build all columns at once to avoid fragmentation
        interaction_data = {}
        coef_index_map: Dict[Tuple[Any, int], int] = {}
        idx = 0

        for g in treatment_groups:
            for e in rel_periods:
                col_name = f"_D_{g}_{e}"
                # Indicator: unit is in cohort g AND at relative time e
                indicator = (
                    (df[first_treat] == g) &
                    (df["_rel_time"] == e)
                ).astype(float)

                # Only include if there are observations; (g, e) cells with
                # no support get no column and no coefficient index.
                if indicator.sum() > 0:
                    interaction_data[col_name] = indicator.values
                    coef_index_map[(g, e)] = idx
                    idx += 1

        # Add all interaction columns at once (single concat avoids
        # pandas fragmentation warnings from repeated column insertion)
        interaction_cols = list(interaction_data.keys())
        if interaction_data:
            interaction_df = pd.DataFrame(interaction_data, index=df.index)
            df = pd.concat([df, interaction_df], axis=1)

        if len(interaction_cols) == 0:
            raise ValueError(
                "No valid cohort × relative-time interactions found. "
                "Check your data structure."
            )

        # Apply within-transformation for unit and time fixed effects
        variables_to_demean = [outcome] + interaction_cols
        if covariates:
            variables_to_demean.extend(covariates)

        df_demeaned = self._within_transform(df, variables_to_demean, unit, time)

        # Build design matrix: interactions first, then (optional) covariates.
        # Column order here must match coef_index_map so index i corresponds
        # to the i-th interaction coefficient.
        X_cols = [f"{col}_dm" for col in interaction_cols]
        if covariates:
            X_cols.extend([f"{cov}_dm" for cov in covariates])

        X = df_demeaned[X_cols].values
        y = df_demeaned[f"{outcome}_dm"].values

        # Fit OLS using LinearRegression helper (more stable than manual X'X inverse)
        cluster_ids = df_demeaned[cluster_var].values
        reg = LinearRegression(
            include_intercept=False,  # Already demeaned, no intercept needed
            robust=True,
            cluster_ids=cluster_ids,
            rank_deficient_action=self.rank_deficient_action,
        ).fit(X, y)

        # NOTE(review): `coefficients` is not used below (get_inference reads
        # from `reg` directly) — candidate for removal.
        coefficients = reg.coefficients_
        vcov = reg.vcov_

        # Extract cohort effects and standard errors using get_inference
        cohort_effects: Dict[Tuple[Any, int], float] = {}
        cohort_ses: Dict[Tuple[Any, int], float] = {}

        n_interactions = len(interaction_cols)
        for (g, e), coef_idx in coef_index_map.items():
            inference = reg.get_inference(coef_idx)
            cohort_effects[(g, e)] = inference.coefficient
            cohort_ses[(g, e)] = inference.se

        # Extract just the vcov for cohort effects (excluding covariates,
        # which occupy the trailing columns of the design matrix)
        vcov_cohort = vcov[:n_interactions, :n_interactions]

        return cohort_effects, cohort_ses, vcov_cohort, coef_index_map
792
+
793
    def _within_transform(
        self,
        df: pd.DataFrame,
        variables: List[str],
        unit: str,
        time: str,
    ) -> pd.DataFrame:
        """
        Apply two-way within transformation to remove unit and time fixed effects.

        y_it - y_i. - y_.t + y_..

        Thin delegation to ``diff_diff.utils.within_transform``; demeaned
        columns are written back with a ``_dm`` suffix (callers read
        ``f"{var}_dm"``).

        Parameters
        ----------
        df : pd.DataFrame
            Data containing ``variables`` plus the ``unit`` and ``time`` columns.
        variables : list of str
            Column names to demean.
        unit, time : str
            Panel identifier column names.
        """
        return _within_transform_util(df, variables, unit, time, suffix="_dm")
806
+
807
+ def _compute_iw_effects(
808
+ self,
809
+ df: pd.DataFrame,
810
+ unit: str,
811
+ first_treat: str,
812
+ treatment_groups: List[Any],
813
+ rel_periods: List[int],
814
+ cohort_effects: Dict[Tuple[Any, int], float],
815
+ cohort_ses: Dict[Tuple[Any, int], float],
816
+ vcov_cohort: np.ndarray,
817
+ coef_index_map: Dict[Tuple[Any, int], int],
818
+ ) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
819
+ """
820
+ Compute interaction-weighted event study effects.
821
+
822
+ β_e = Σ_g w_{g,e} × δ_{g,e}
823
+
824
+ where w_{g,e} is the share of cohort g among treated units at relative time e.
825
+
826
+ Returns
827
+ -------
828
+ event_study_effects : dict
829
+ Dictionary mapping relative period to aggregated effect info.
830
+ cohort_weights : dict
831
+ Dictionary mapping relative period to cohort weight dictionary.
832
+ """
833
+ event_study_effects: Dict[int, Dict[str, Any]] = {}
834
+ cohort_weights: Dict[int, Dict[Any, float]] = {}
835
+
836
+ # Get cohort sizes
837
+ unit_cohorts = df.groupby(unit)[first_treat].first()
838
+ cohort_sizes = unit_cohorts[unit_cohorts > 0].value_counts().to_dict()
839
+
840
+ for e in rel_periods:
841
+ # Get cohorts that have observations at this relative time
842
+ cohorts_at_e = [
843
+ g for g in treatment_groups
844
+ if (g, e) in cohort_effects
845
+ ]
846
+
847
+ if not cohorts_at_e:
848
+ continue
849
+
850
+ # Compute IW weights: share of each cohort among those observed at e
851
+ weights = {}
852
+ total_size = 0
853
+ for g in cohorts_at_e:
854
+ n_g = cohort_sizes.get(g, 0)
855
+ weights[g] = n_g
856
+ total_size += n_g
857
+
858
+ if total_size == 0:
859
+ continue
860
+
861
+ # Normalize weights
862
+ for g in weights:
863
+ weights[g] = weights[g] / total_size
864
+
865
+ cohort_weights[e] = weights
866
+
867
+ # Compute weighted average effect
868
+ agg_effect = 0.0
869
+ for g in cohorts_at_e:
870
+ w = weights[g]
871
+ agg_effect += w * cohort_effects[(g, e)]
872
+
873
+ # Compute SE using delta method with vcov
874
+ # Var(β_e) = w' Σ w where w is weight vector and Σ is vcov submatrix
875
+ indices = [coef_index_map[(g, e)] for g in cohorts_at_e]
876
+ weight_vec = np.array([weights[g] for g in cohorts_at_e])
877
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
878
+ agg_var = float(weight_vec @ vcov_subset @ weight_vec)
879
+ agg_se = np.sqrt(max(agg_var, 0))
880
+
881
+ t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
882
+ p_val = compute_p_value(t_stat)
883
+ ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)
884
+
885
+ event_study_effects[e] = {
886
+ "effect": agg_effect,
887
+ "se": agg_se,
888
+ "t_stat": t_stat,
889
+ "p_value": p_val,
890
+ "conf_int": ci,
891
+ "n_groups": len(cohorts_at_e),
892
+ }
893
+
894
+ return event_study_effects, cohort_weights
895
+
896
+ def _compute_overall_att(
897
+ self,
898
+ df: pd.DataFrame,
899
+ first_treat: str,
900
+ event_study_effects: Dict[int, Dict[str, Any]],
901
+ cohort_effects: Dict[Tuple[Any, int], float],
902
+ cohort_weights: Dict[int, Dict[Any, float]],
903
+ vcov_cohort: np.ndarray,
904
+ coef_index_map: Dict[Tuple[Any, int], int],
905
+ ) -> Tuple[float, float]:
906
+ """
907
+ Compute overall ATT as weighted average of post-treatment effects.
908
+
909
+ Returns (att, se) tuple.
910
+ """
911
+ post_effects = [
912
+ (e, eff)
913
+ for e, eff in event_study_effects.items()
914
+ if e >= 0
915
+ ]
916
+
917
+ if not post_effects:
918
+ return 0.0, 0.0
919
+
920
+ # Weight by number of treated observations at each relative time
921
+ post_weights = []
922
+ post_estimates = []
923
+
924
+ for e, eff in post_effects:
925
+ n_at_e = len(df[(df["_rel_time"] == e) & (df[first_treat] > 0)])
926
+ post_weights.append(max(n_at_e, 1))
927
+ post_estimates.append(eff["effect"])
928
+
929
+ post_weights = np.array(post_weights, dtype=float)
930
+ post_weights = post_weights / post_weights.sum()
931
+
932
+ overall_att = float(np.sum(post_weights * np.array(post_estimates)))
933
+
934
+ # Compute SE using delta method
935
+ # Need to trace back through the full weighting scheme
936
+ # ATT = Σ_e w_e × β_e = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e}
937
+ # Collect all (g, e) pairs and their overall weights
938
+ overall_weights_by_coef: Dict[Tuple[Any, int], float] = {}
939
+
940
+ for i, (e, _) in enumerate(post_effects):
941
+ period_weight = post_weights[i]
942
+ if e in cohort_weights:
943
+ for g, cw in cohort_weights[e].items():
944
+ key = (g, e)
945
+ if key in coef_index_map:
946
+ if key not in overall_weights_by_coef:
947
+ overall_weights_by_coef[key] = 0.0
948
+ overall_weights_by_coef[key] += period_weight * cw
949
+
950
+ if not overall_weights_by_coef:
951
+ # Fallback to simple variance calculation
952
+ overall_var = float(
953
+ np.sum((post_weights ** 2) * np.array([eff["se"] ** 2 for _, eff in post_effects]))
954
+ )
955
+ return overall_att, np.sqrt(overall_var)
956
+
957
+ # Build full weight vector and compute variance
958
+ indices = [coef_index_map[key] for key in overall_weights_by_coef.keys()]
959
+ weight_vec = np.array(list(overall_weights_by_coef.values()))
960
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
961
+ overall_var = float(weight_vec @ vcov_subset @ weight_vec)
962
+ overall_se = np.sqrt(max(overall_var, 0))
963
+
964
+ return overall_att, overall_se
965
+
966
+ def _run_bootstrap(
967
+ self,
968
+ df: pd.DataFrame,
969
+ outcome: str,
970
+ unit: str,
971
+ time: str,
972
+ first_treat: str,
973
+ treatment_groups: List[Any],
974
+ rel_periods_to_estimate: List[int],
975
+ covariates: Optional[List[str]],
976
+ cluster_var: str,
977
+ original_event_study: Dict[int, Dict[str, Any]],
978
+ original_overall_att: float,
979
+ ) -> SABootstrapResults:
980
+ """
981
+ Run pairs bootstrap for inference.
982
+
983
+ Resamples units with replacement and re-estimates the full model.
984
+ """
985
+ if self.n_bootstrap < 50:
986
+ warnings.warn(
987
+ f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
988
+ "for reliable inference.",
989
+ UserWarning,
990
+ stacklevel=3,
991
+ )
992
+
993
+ rng = np.random.default_rng(self.seed)
994
+
995
+ # Get unique units
996
+ all_units = df[unit].unique()
997
+ n_units = len(all_units)
998
+
999
+ # Store bootstrap samples
1000
+ rel_periods = sorted(original_event_study.keys())
1001
+ bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
1002
+ bootstrap_overall = np.zeros(self.n_bootstrap)
1003
+
1004
+ for b in range(self.n_bootstrap):
1005
+ # Resample units with replacement (pairs bootstrap)
1006
+ boot_units = rng.choice(all_units, size=n_units, replace=True)
1007
+
1008
+ # Create bootstrap sample efficiently
1009
+ # Build index array for all selected units
1010
+ boot_indices = np.concatenate([
1011
+ df.index[df[unit] == u].values for u in boot_units
1012
+ ])
1013
+ df_b = df.iloc[boot_indices].copy()
1014
+
1015
+ # Reassign unique unit IDs for bootstrap sample
1016
+ # Each resampled unit gets a unique ID
1017
+ new_unit_ids = []
1018
+ current_id = 0
1019
+ for u in boot_units:
1020
+ unit_rows = df[df[unit] == u]
1021
+ for _ in range(len(unit_rows)):
1022
+ new_unit_ids.append(current_id)
1023
+ current_id += 1
1024
+ df_b[unit] = new_unit_ids[:len(df_b)]
1025
+
1026
+ # Recompute relative time (vectorized)
1027
+ df_b["_rel_time"] = np.where(
1028
+ df_b[first_treat] > 0,
1029
+ df_b[time] - df_b[first_treat],
1030
+ np.nan
1031
+ )
1032
+ df_b["_never_treated"] = (
1033
+ (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)
1034
+ )
1035
+
1036
+ try:
1037
+ # Re-estimate saturated regression
1038
+ (
1039
+ cohort_effects_b,
1040
+ cohort_ses_b,
1041
+ vcov_b,
1042
+ coef_map_b,
1043
+ ) = self._fit_saturated_regression(
1044
+ df_b,
1045
+ outcome,
1046
+ unit,
1047
+ time,
1048
+ first_treat,
1049
+ treatment_groups,
1050
+ rel_periods_to_estimate,
1051
+ covariates,
1052
+ cluster_var,
1053
+ )
1054
+
1055
+ # Compute IW effects for this bootstrap sample
1056
+ event_study_b, cohort_weights_b = self._compute_iw_effects(
1057
+ df_b,
1058
+ unit,
1059
+ first_treat,
1060
+ treatment_groups,
1061
+ rel_periods_to_estimate,
1062
+ cohort_effects_b,
1063
+ cohort_ses_b,
1064
+ vcov_b,
1065
+ coef_map_b,
1066
+ )
1067
+
1068
+ # Store bootstrap estimates
1069
+ for e in rel_periods:
1070
+ if e in event_study_b:
1071
+ bootstrap_effects[e][b] = event_study_b[e]["effect"]
1072
+ else:
1073
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1074
+
1075
+ # Compute overall ATT for this bootstrap sample
1076
+ overall_b, _ = self._compute_overall_att(
1077
+ df_b,
1078
+ first_treat,
1079
+ event_study_b,
1080
+ cohort_effects_b,
1081
+ cohort_weights_b,
1082
+ vcov_b,
1083
+ coef_map_b,
1084
+ )
1085
+ bootstrap_overall[b] = overall_b
1086
+
1087
+ except (ValueError, np.linalg.LinAlgError) as exc:
1088
+ # If bootstrap iteration fails, use original
1089
+ warnings.warn(
1090
+ f"Bootstrap iteration {b} failed: {exc}. Using original estimate.",
1091
+ UserWarning,
1092
+ stacklevel=2,
1093
+ )
1094
+ for e in rel_periods:
1095
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1096
+ bootstrap_overall[b] = original_overall_att
1097
+
1098
+ # Compute bootstrap statistics
1099
+ event_study_ses = {}
1100
+ event_study_cis = {}
1101
+ event_study_p_values = {}
1102
+
1103
+ for e in rel_periods:
1104
+ boot_dist = bootstrap_effects[e]
1105
+ original_effect = original_event_study[e]["effect"]
1106
+
1107
+ se = float(np.std(boot_dist, ddof=1))
1108
+ ci = self._compute_percentile_ci(boot_dist, self.alpha)
1109
+ p_value = self._compute_bootstrap_pvalue(original_effect, boot_dist)
1110
+
1111
+ event_study_ses[e] = se
1112
+ event_study_cis[e] = ci
1113
+ event_study_p_values[e] = p_value
1114
+
1115
+ # Overall ATT statistics
1116
+ overall_se = float(np.std(bootstrap_overall, ddof=1))
1117
+ overall_ci = self._compute_percentile_ci(bootstrap_overall, self.alpha)
1118
+ overall_p = self._compute_bootstrap_pvalue(
1119
+ original_overall_att, bootstrap_overall
1120
+ )
1121
+
1122
+ return SABootstrapResults(
1123
+ n_bootstrap=self.n_bootstrap,
1124
+ weight_type="pairs",
1125
+ alpha=self.alpha,
1126
+ overall_att_se=overall_se,
1127
+ overall_att_ci=overall_ci,
1128
+ overall_att_p_value=overall_p,
1129
+ event_study_ses=event_study_ses,
1130
+ event_study_cis=event_study_cis,
1131
+ event_study_p_values=event_study_p_values,
1132
+ bootstrap_distribution=bootstrap_overall,
1133
+ )
1134
+
1135
+ def _compute_percentile_ci(
1136
+ self,
1137
+ boot_dist: np.ndarray,
1138
+ alpha: float,
1139
+ ) -> Tuple[float, float]:
1140
+ """Compute percentile confidence interval."""
1141
+ lower = float(np.percentile(boot_dist, alpha / 2 * 100))
1142
+ upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
1143
+ return (lower, upper)
1144
+
1145
+ def _compute_bootstrap_pvalue(
1146
+ self,
1147
+ original_effect: float,
1148
+ boot_dist: np.ndarray,
1149
+ ) -> float:
1150
+ """Compute two-sided bootstrap p-value."""
1151
+ if original_effect >= 0:
1152
+ p_one_sided = float(np.mean(boot_dist <= 0))
1153
+ else:
1154
+ p_one_sided = float(np.mean(boot_dist >= 0))
1155
+
1156
+ p_value = min(2 * p_one_sided, 1.0)
1157
+ p_value = max(p_value, 1 / (self.n_bootstrap + 1))
1158
+
1159
+ return p_value
1160
+
1161
+ def get_params(self) -> Dict[str, Any]:
1162
+ """Get estimator parameters (sklearn-compatible)."""
1163
+ return {
1164
+ "control_group": self.control_group,
1165
+ "anticipation": self.anticipation,
1166
+ "alpha": self.alpha,
1167
+ "cluster": self.cluster,
1168
+ "n_bootstrap": self.n_bootstrap,
1169
+ "seed": self.seed,
1170
+ "rank_deficient_action": self.rank_deficient_action,
1171
+ }
1172
+
1173
+ def set_params(self, **params) -> "SunAbraham":
1174
+ """Set estimator parameters (sklearn-compatible)."""
1175
+ for key, value in params.items():
1176
+ if hasattr(self, key):
1177
+ setattr(self, key, value)
1178
+ else:
1179
+ raise ValueError(f"Unknown parameter: {key}")
1180
+ return self
1181
+
1182
+ def summary(self) -> str:
1183
+ """Get summary of estimation results."""
1184
+ if not self.is_fitted_:
1185
+ raise RuntimeError("Model must be fitted before calling summary()")
1186
+ assert self.results_ is not None
1187
+ return self.results_.summary()
1188
+
1189
+ def print_summary(self) -> None:
1190
+ """Print summary to stdout."""
1191
+ print(self.summary())