diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1227 @@
1
+ """
2
+ Sun-Abraham Interaction-Weighted Estimator for staggered DiD.
3
+
4
+ Implements the estimator from Sun & Abraham (2021), "Estimating dynamic
5
+ treatment effects in event studies with heterogeneous treatment effects",
6
+ Journal of Econometrics.
7
+
8
+ This provides an alternative to Callaway-Sant'Anna using a saturated
9
+ regression with cohort × relative-time interactions.
10
+ """
11
+
12
+ import warnings
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from diff_diff.linalg import LinearRegression, compute_robust_vcov
20
+ from diff_diff.results import _get_significance_stars
21
+ from diff_diff.utils import (
22
+ compute_confidence_interval,
23
+ compute_p_value,
24
+ within_transform as _within_transform_util,
25
+ )
26
+
27
+
28
@dataclass
class SunAbrahamResults:
    """
    Results from Sun-Abraham (2021) interaction-weighted estimation.

    Attributes
    ----------
    event_study_effects : dict
        Dictionary mapping relative time to effect dictionaries with keys:
        'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_groups'.
    overall_att : float
        Overall average treatment effect (weighted average of post-treatment effects).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    cohort_weights : dict
        Dictionary mapping relative time to cohort weight dictionaries.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units.
    alpha : float
        Significance level used for confidence intervals.
    control_group : str
        Type of control group used.
    bootstrap_results : SABootstrapResults, optional
        Bootstrap inference results when bootstrap was requested (repr-hidden).
    cohort_effects : dict, optional
        Mapping (cohort, relative_period) -> {'effect', 'se', 'weight'}
        (repr-hidden).
    """

    event_study_effects: Dict[int, Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    cohort_weights: Dict[int, Dict[Any, float]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    bootstrap_results: Optional["SABootstrapResults"] = field(default=None, repr=False)
    cohort_effects: Optional[Dict[Tuple[Any, int], Dict[str, Any]]] = field(
        default=None, repr=False
    )

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        n_rel_periods = len(self.event_study_effects)
        return (
            f"SunAbrahamResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_rel_periods={n_rel_periods})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation.

        Returns
        -------
        str
            Formatted summary.
        """
        # NOTE(review): a caller-supplied alpha only changes the confidence
        # level printed in the header below; the stored interval itself was
        # computed at the estimation-time alpha and is not recomputed here.
        # Confirm this relabeling is intended.
        alpha = alpha or self.alpha
        conf_level = int((1 - alpha) * 100)

        lines = [
            "=" * 85,
            "Sun-Abraham Interaction-Weighted Estimator Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            "",
        ]

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
                "",
            ]
        )

        # Event study effects
        lines.extend(
            [
                "-" * 85,
                "Event Study (Dynamic) Effects".center(85),
                "-" * 85,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
        )

        # One row per relative period, in ascending event-time order.
        for rel_t in sorted(self.event_study_effects.keys()):
            eff = self.event_study_effects[rel_t]
            sig = _get_significance_stars(eff["p_value"])
            lines.append(
                f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
            )

        lines.extend(["-" * 85, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "event_study") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="event_study"
            Level of aggregation: "event_study" or "cohort".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or ``level="cohort"`` and cohort-level
            effects were not stored.
        """
        if level == "event_study":
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                rows.append(
                    {
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                    }
                )
            return pd.DataFrame(rows)

        elif level == "cohort":
            # NOTE(review): fit() appears to always populate cohort_effects,
            # so the "not stored by default" wording below may be stale —
            # verify against fit() before relying on this branch.
            if self.cohort_effects is None:
                raise ValueError(
                    "Cohort-level effects not available. "
                    "They are computed internally but not stored by default."
                )
            rows = []
            for (cohort, rel_t), data in sorted(self.cohort_effects.items()):
                rows.append(
                    {
                        "cohort": cohort,
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        "weight": data.get("weight", np.nan),
                    }
                )
            return pd.DataFrame(rows)

        else:
            raise ValueError(
                f"Unknown level: {level}. Use 'event_study' or 'cohort'."
            )

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
+
246
+
247
@dataclass
class SABootstrapResults:
    """
    Pairs-bootstrap inference for a Sun-Abraham fit.

    Holds bootstrap-based standard errors, confidence intervals, and
    p-values for both the overall ATT and each event-study coefficient,
    plus (optionally) the raw bootstrap draws of the overall ATT.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap replications performed.
    weight_type : str
        Bootstrap scheme identifier ("pairs" for the pairs bootstrap).
    alpha : float
        Significance level used when forming confidence intervals.
    overall_att_se : float
        Bootstrap standard error of the overall ATT.
    overall_att_ci : tuple of float
        Bootstrap confidence interval for the overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for the overall ATT.
    event_study_ses : dict
        Relative period -> bootstrap standard error.
    event_study_cis : dict
        Relative period -> bootstrap confidence interval.
    event_study_p_values : dict
        Relative period -> bootstrap p-value.
    bootstrap_distribution : np.ndarray, optional
        Bootstrap draws of the overall ATT (excluded from repr).
    """

    n_bootstrap: int
    weight_type: str
    alpha: float
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    event_study_ses: Dict[int, float]
    event_study_cis: Dict[int, Tuple[float, float]]
    event_study_p_values: Dict[int, float]
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
286
+
287
+
288
+ class SunAbraham:
289
+ """
290
+ Sun-Abraham (2021) interaction-weighted estimator for staggered DiD.
291
+
292
+ This estimator provides event-study coefficients using a saturated
293
+ TWFE regression with cohort × relative-time interactions, following
294
+ the methodology in Sun & Abraham (2021).
295
+
296
+ The estimation procedure follows three steps:
297
+ 1. Run a saturated TWFE regression with cohort × relative-time dummies
298
+ 2. Compute cohort shares (weights) at each relative time
299
+ 3. Aggregate cohort-specific effects using interaction weights
300
+
301
+ This avoids the negative weighting problem of standard TWFE and provides
302
+ consistent event-study estimates under treatment effect heterogeneity.
303
+
304
+ Parameters
305
+ ----------
306
+ control_group : str, default="never_treated"
307
+ Which units to use as controls:
308
+ - "never_treated": Use only never-treated units (recommended)
309
+ - "not_yet_treated": Use never-treated and not-yet-treated units
310
+ anticipation : int, default=0
311
+ Number of periods before treatment where effects may occur.
312
+ alpha : float, default=0.05
313
+ Significance level for confidence intervals.
314
+ cluster : str, optional
315
+ Column name for cluster-robust standard errors.
316
+ If None, clusters at the unit level by default.
317
+ n_bootstrap : int, default=0
318
+ Number of bootstrap iterations for inference.
319
+ If 0, uses analytical cluster-robust standard errors.
320
+ seed : int, optional
321
+ Random seed for reproducibility.
322
+ rank_deficient_action : str, default="warn"
323
+ Action when design matrix is rank-deficient (linearly dependent columns):
324
+ - "warn": Issue warning and drop linearly dependent columns (default)
325
+ - "error": Raise ValueError
326
+ - "silent": Drop columns silently without warning
327
+
328
+ Attributes
329
+ ----------
330
+ results_ : SunAbrahamResults
331
+ Estimation results after calling fit().
332
+ is_fitted_ : bool
333
+ Whether the model has been fitted.
334
+
335
+ Examples
336
+ --------
337
+ Basic usage:
338
+
339
+ >>> import pandas as pd
340
+ >>> from diff_diff import SunAbraham
341
+ >>>
342
+ >>> # Panel data with staggered treatment
343
+ >>> data = pd.DataFrame({
344
+ ... 'unit': [...],
345
+ ... 'time': [...],
346
+ ... 'outcome': [...],
347
+ ... 'first_treat': [...] # 0 for never-treated
348
+ ... })
349
+ >>>
350
+ >>> sa = SunAbraham()
351
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
352
+ ... time='time', first_treat='first_treat')
353
+ >>> results.print_summary()
354
+
355
+ With covariates:
356
+
357
+ >>> sa = SunAbraham()
358
+ >>> results = sa.fit(data, outcome='outcome', unit='unit',
359
+ ... time='time', first_treat='first_treat',
360
+ ... covariates=['age', 'income'])
361
+
362
+ Notes
363
+ -----
364
+ The Sun-Abraham estimator uses a saturated regression approach:
365
+
366
+ Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × 1(G_i=g) × D_{it}^e] + X'γ + ε_it
367
+
368
+ where:
369
+ - α_i = unit fixed effects
370
+ - λ_t = time fixed effects
371
+ - G_i = unit i's treatment cohort (first treatment period)
372
+ - D_{it}^e = indicator for being e periods from treatment
373
+ - δ_{g,e} = cohort-specific effect (CATT) at relative time e
374
+
375
+ The event-study coefficients are then computed as:
376
+
377
+ β_e = Σ_g w_{g,e} × δ_{g,e}
378
+
379
+ where w_{g,e} is the share of cohort g in the treated population at
380
+ relative time e (interaction weights).
381
+
382
+ Compared to Callaway-Sant'Anna:
383
+ - SA uses saturated regression; CS uses 2x2 DiD comparisons
384
+ - SA can be more efficient when model is correctly specified
385
+ - Both are consistent under heterogeneous treatment effects
386
+ - Running both provides a useful robustness check
387
+
388
+ References
389
+ ----------
390
+ Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in
391
+ event studies with heterogeneous treatment effects. Journal of
392
+ Econometrics, 225(2), 175-199.
393
+ """
394
+
395
+ def __init__(
396
+ self,
397
+ control_group: str = "never_treated",
398
+ anticipation: int = 0,
399
+ alpha: float = 0.05,
400
+ cluster: Optional[str] = None,
401
+ n_bootstrap: int = 0,
402
+ seed: Optional[int] = None,
403
+ rank_deficient_action: str = "warn",
404
+ ):
405
+ if control_group not in ["never_treated", "not_yet_treated"]:
406
+ raise ValueError(
407
+ f"control_group must be 'never_treated' or 'not_yet_treated', "
408
+ f"got '{control_group}'"
409
+ )
410
+
411
+ if rank_deficient_action not in ["warn", "error", "silent"]:
412
+ raise ValueError(
413
+ f"rank_deficient_action must be 'warn', 'error', or 'silent', "
414
+ f"got '{rank_deficient_action}'"
415
+ )
416
+
417
+ self.control_group = control_group
418
+ self.anticipation = anticipation
419
+ self.alpha = alpha
420
+ self.cluster = cluster
421
+ self.n_bootstrap = n_bootstrap
422
+ self.seed = seed
423
+ self.rank_deficient_action = rank_deficient_action
424
+
425
+ self.is_fitted_ = False
426
+ self.results_: Optional[SunAbrahamResults] = None
427
+ self._reference_period = -1 # Will be set during fit
428
+
429
    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]] = None,
        min_pre_periods: int = 1,
        min_post_periods: int = 1,
    ) -> SunAbrahamResults:
        """
        Fit the Sun-Abraham estimator using saturated regression.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column.
        first_treat : str
            Name of column indicating when unit was first treated.
            Use 0 (or np.inf) for never-treated units.
        covariates : list, optional
            List of covariate column names to include in regression.
        min_pre_periods : int, default=1
            **Deprecated**: Accepted but ignored. Will be removed in a future version.
        min_post_periods : int, default=1
            **Deprecated**: Accepted but ignored. Will be removed in a future version.

        Returns
        -------
        SunAbrahamResults
            Object containing all estimation results.

        Raises
        ------
        ValueError
            If required columns are missing or data validation fails.
        """
        # Deprecation warnings for unimplemented parameters
        if min_pre_periods != 1:
            warnings.warn(
                "min_pre_periods is not yet implemented and will be ignored. "
                "This parameter will be removed in a future version.",
                FutureWarning,
                stacklevel=2,
            )
        if min_post_periods != 1:
            warnings.warn(
                "min_post_periods is not yet implemented and will be ignored. "
                "This parameter will be removed in a future version.",
                FutureWarning,
                stacklevel=2,
            )

        # Validate inputs
        required_cols = [outcome, unit, time, first_treat]
        if covariates:
            required_cols.extend(covariates)

        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Create working copy
        df = data.copy()

        # Ensure numeric types
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        # Never-treated indicator (must precede treatment_groups to exclude np.inf)
        df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
        # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
        df.loc[df[first_treat] == np.inf, first_treat] = 0

        # Identify groups and time periods
        time_periods = sorted(df[time].unique())
        treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

        # Get unique units (one row per unit; first_treat is constant per unit)
        unit_info = (
            df.groupby(unit)
            .agg({first_treat: "first", "_never_treated": "first"})
            .reset_index()
        )

        n_treated_units = int((unit_info[first_treat] > 0).sum())
        n_control_units = int((unit_info["_never_treated"]).sum())

        if n_control_units == 0:
            raise ValueError(
                "No never-treated units found. Check 'first_treat' column."
            )

        if len(treatment_groups) == 0:
            raise ValueError(
                "No treated units found. Check 'first_treat' column."
            )

        # Compute relative time for each observation (vectorized);
        # NaN for never-treated units so they match no event-time dummy.
        df["_rel_time"] = np.where(
            df[first_treat] > 0,
            df[time] - df[first_treat],
            np.nan
        )

        # Identify the range of relative time periods to estimate
        rel_times_by_cohort = {}
        for g in treatment_groups:
            g_times = df[df[first_treat] == g][time].unique()
            rel_times_by_cohort[g] = sorted([t - g for t in g_times])

        # Find all relative time values
        all_rel_times: set = set()
        for g, rel_times in rel_times_by_cohort.items():
            all_rel_times.update(rel_times)

        all_rel_times_sorted = sorted(all_rel_times)

        # Use full range of relative times (no artificial truncation, matches R's fixest::sunab())
        min_rel = min(all_rel_times_sorted)
        max_rel = max(all_rel_times_sorted)

        # Reference period: last pre-treatment period (typically -1)
        self._reference_period = -1 - self.anticipation

        # Get relative periods to estimate (excluding reference)
        # NOTE(review): the `min_rel <= e <= max_rel` clause is always true for
        # elements of all_rel_times_sorted; only the reference-period exclusion
        # has any effect here.
        rel_periods_to_estimate = [
            e
            for e in all_rel_times_sorted
            if min_rel <= e <= max_rel and e != self._reference_period
        ]

        # Determine cluster variable (defaults to unit-level clustering)
        cluster_var = self.cluster if self.cluster is not None else unit

        # Filter data based on control_group setting
        if self.control_group == "never_treated":
            # Only keep never-treated as controls
            df_reg = df[df["_never_treated"] | (df[first_treat] > 0)].copy()
        else:
            # Keep all units (not_yet_treated will be handled by the regression)
            df_reg = df.copy()

        # Step 1: fit saturated regression for cohort-specific effects δ_{g,e}
        (
            cohort_effects,
            cohort_ses,
            vcov_cohort,
            coef_index_map,
        ) = self._fit_saturated_regression(
            df_reg,
            outcome,
            unit,
            time,
            first_treat,
            treatment_groups,
            rel_periods_to_estimate,
            covariates,
            cluster_var,
        )

        # Step 2+3: compute interaction weights and aggregate to event-study effects
        event_study_effects, cohort_weights = self._compute_iw_effects(
            df,
            unit,
            first_treat,
            treatment_groups,
            rel_periods_to_estimate,
            cohort_effects,
            cohort_ses,
            vcov_cohort,
            coef_index_map,
        )

        # Compute overall ATT (average of post-treatment effects)
        overall_att, overall_se = self._compute_overall_att(
            df,
            first_treat,
            event_study_effects,
            cohort_effects,
            cohort_weights,
            vcov_cohort,
            coef_index_map,
        )

        # Degenerate SE (0, NaN, inf) yields NaN t-stat / CI rather than raising.
        overall_t = overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
        overall_p = compute_p_value(overall_t)
        overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha) if np.isfinite(overall_se) and overall_se > 0 else (np.nan, np.nan)

        # Run bootstrap if requested
        bootstrap_results = None
        if self.n_bootstrap > 0:
            bootstrap_results = self._run_bootstrap(
                df=df_reg,
                outcome=outcome,
                unit=unit,
                time=time,
                first_treat=first_treat,
                treatment_groups=treatment_groups,
                rel_periods_to_estimate=rel_periods_to_estimate,
                covariates=covariates,
                cluster_var=cluster_var,
                original_event_study=event_study_effects,
                original_overall_att=overall_att,
            )

            # Update results with bootstrap inference (point estimates are kept;
            # only SEs, CIs, p-values and t-stats are replaced)
            overall_se = bootstrap_results.overall_att_se
            overall_t = overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
            overall_p = bootstrap_results.overall_att_p_value
            overall_ci = bootstrap_results.overall_att_ci

            # Update event study effects in place
            for e in event_study_effects:
                if e in bootstrap_results.event_study_ses:
                    event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
                    event_study_effects[e]["conf_int"] = (
                        bootstrap_results.event_study_cis[e]
                    )
                    event_study_effects[e]["p_value"] = (
                        bootstrap_results.event_study_p_values[e]
                    )
                    eff_val = event_study_effects[e]["effect"]
                    se_val = event_study_effects[e]["se"]
                    event_study_effects[e]["t_stat"] = (
                        eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan
                    )

        # Convert cohort effects to storage format
        cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {}
        for (g, e), effect in cohort_effects.items():
            weight = cohort_weights.get(e, {}).get(g, 0.0)
            se = cohort_ses.get((g, e), 0.0)
            cohort_effects_storage[(g, e)] = {
                "effect": effect,
                "se": se,
                "weight": weight,
            }

        # Store results
        self.results_ = SunAbrahamResults(
            event_study_effects=event_study_effects,
            overall_att=overall_att,
            overall_se=overall_se,
            overall_t_stat=overall_t,
            overall_p_value=overall_p,
            overall_conf_int=overall_ci,
            cohort_weights=cohort_weights,
            groups=treatment_groups,
            time_periods=time_periods,
            n_obs=len(df),
            n_treated_units=n_treated_units,
            n_control_units=n_control_units,
            alpha=self.alpha,
            control_group=self.control_group,
            bootstrap_results=bootstrap_results,
            cohort_effects=cohort_effects_storage,
        )

        self.is_fitted_ = True
        return self.results_
697
+
698
    def _fit_saturated_regression(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treatment_groups: List[Any],
        rel_periods: List[int],
        covariates: Optional[List[str]],
        cluster_var: str,
    ) -> Tuple[
        Dict[Tuple[Any, int], float],
        Dict[Tuple[Any, int], float],
        np.ndarray,
        Dict[Tuple[Any, int], int],
    ]:
        """
        Fit saturated TWFE regression with cohort × relative-time interactions.

        Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε

        Uses within-transformation for unit fixed effects and time dummies.

        Returns
        -------
        cohort_effects : dict
            Mapping (cohort, rel_period) -> effect estimate δ_{g,e}
        cohort_ses : dict
            Mapping (cohort, rel_period) -> standard error
        vcov : np.ndarray
            Variance-covariance matrix for cohort effects
        coef_index_map : dict
            Mapping (cohort, rel_period) -> index in coefficient vector
        """
        # Work on a local copy so the caller's frame is untouched.
        df = df.copy()

        # Create cohort × relative-time interaction dummies
        # Exclude reference period
        # Build all columns at once to avoid fragmentation
        interaction_data = {}
        coef_index_map: Dict[Tuple[Any, int], int] = {}
        idx = 0

        for g in treatment_groups:
            for e in rel_periods:
                col_name = f"_D_{g}_{e}"
                # Indicator: unit is in cohort g AND at relative time e
                indicator = (
                    (df[first_treat] == g) &
                    (df["_rel_time"] == e)
                ).astype(float)

                # Only include if there are observations; empty dummies would
                # be dropped as rank-deficient columns anyway.
                if indicator.sum() > 0:
                    interaction_data[col_name] = indicator.values
                    coef_index_map[(g, e)] = idx
                    idx += 1

        # Add all interaction columns at once
        interaction_cols = list(interaction_data.keys())
        if interaction_data:
            interaction_df = pd.DataFrame(interaction_data, index=df.index)
            df = pd.concat([df, interaction_df], axis=1)

        if len(interaction_cols) == 0:
            raise ValueError(
                "No valid cohort × relative-time interactions found. "
                "Check your data structure."
            )

        # Apply within-transformation for unit and time fixed effects
        variables_to_demean = [outcome] + interaction_cols
        if covariates:
            variables_to_demean.extend(covariates)

        df_demeaned = self._within_transform(df, variables_to_demean, unit, time)

        # Build design matrix: interaction dummies first, then covariates —
        # this ordering is relied on below when slicing vcov_cohort.
        X_cols = [f"{col}_dm" for col in interaction_cols]
        if covariates:
            X_cols.extend([f"{cov}_dm" for cov in covariates])

        X = df_demeaned[X_cols].values
        y = df_demeaned[f"{outcome}_dm"].values

        # Fit OLS using LinearRegression helper (more stable than manual X'X inverse)
        cluster_ids = df_demeaned[cluster_var].values

        # Degrees of freedom adjustment for absorbed unit and time fixed effects
        n_units_fe = df[unit].nunique()
        n_times_fe = df[time].nunique()
        df_adj = n_units_fe + n_times_fe - 1

        reg = LinearRegression(
            include_intercept=False,  # Already demeaned, no intercept needed
            robust=True,
            cluster_ids=cluster_ids,
            rank_deficient_action=self.rank_deficient_action,
        ).fit(X, y, df_adjustment=df_adj)

        # NOTE(review): `coefficients` is assigned but never used below
        # (estimates come from reg.get_inference); kept as-is.
        coefficients = reg.coefficients_
        vcov = reg.vcov_

        # Extract cohort effects and standard errors using get_inference
        cohort_effects: Dict[Tuple[Any, int], float] = {}
        cohort_ses: Dict[Tuple[Any, int], float] = {}

        n_interactions = len(interaction_cols)
        for (g, e), coef_idx in coef_index_map.items():
            inference = reg.get_inference(coef_idx)
            cohort_effects[(g, e)] = inference.coefficient
            cohort_ses[(g, e)] = inference.se

        # Extract just the vcov for cohort effects (excluding covariates)
        vcov_cohort = vcov[:n_interactions, :n_interactions]

        return cohort_effects, cohort_ses, vcov_cohort, coef_index_map
816
+
817
    def _within_transform(
        self,
        df: pd.DataFrame,
        variables: List[str],
        unit: str,
        time: str,
    ) -> pd.DataFrame:
        """
        Apply two-way within transformation to remove unit and time fixed effects.

        y_it - y_i. - y_.t + y_..

        Parameters
        ----------
        df : pd.DataFrame
            Data containing `variables`, `unit` and `time` columns.
        variables : list of str
            Columns to demean.
        unit, time : str
            Panel identifier columns.

        Returns
        -------
        pd.DataFrame
            Data with demeaned columns added under a "_dm" suffix
            (delegates to the shared utility in diff_diff.utils).
        """
        return _within_transform_util(df, variables, unit, time, suffix="_dm")
830
+
831
    def _compute_iw_effects(
        self,
        df: pd.DataFrame,
        unit: str,
        first_treat: str,
        treatment_groups: List[Any],
        rel_periods: List[int],
        cohort_effects: Dict[Tuple[Any, int], float],
        cohort_ses: Dict[Tuple[Any, int], float],
        vcov_cohort: np.ndarray,
        coef_index_map: Dict[Tuple[Any, int], int],
    ) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
        """
        Compute interaction-weighted event study effects.

        β_e = Σ_g w_{g,e} × δ_{g,e}

        where w_{g,e} = n_{g,e} / Σ_g n_{g,e} is the share of observations from cohort g
        at event-time e among all treated observations at that event-time.

        Returns
        -------
        event_study_effects : dict
            Dictionary mapping relative period to aggregated effect info.
        cohort_weights : dict
            Dictionary mapping relative period to cohort weight dictionary.
        """
        event_study_effects: Dict[int, Dict[str, Any]] = {}
        cohort_weights: Dict[int, Dict[Any, float]] = {}

        # Pre-compute per-event-time observation counts: n_{g,e}
        # (Series indexed by (first_treat, _rel_time); tuple-key .get below)
        event_time_counts = df[df[first_treat] > 0].groupby([first_treat, "_rel_time"]).size()

        for e in rel_periods:
            # Get cohorts that have observations at this relative time
            # (i.e. that received an estimated δ_{g,e} in the regression)
            cohorts_at_e = [
                g for g in treatment_groups
                if (g, e) in cohort_effects
            ]

            if not cohorts_at_e:
                continue

            # Compute IW weights: n_{g,e} / Σ_g n_{g,e}
            weights = {}
            total_size = 0
            for g in cohorts_at_e:
                n_g_e = event_time_counts.get((g, e), 0)
                weights[g] = n_g_e
                total_size += n_g_e

            if total_size == 0:
                continue

            # Normalize weights so they sum to 1 across cohorts at this e
            for g in weights:
                weights[g] = weights[g] / total_size

            cohort_weights[e] = weights

            # Compute weighted average effect
            agg_effect = 0.0
            for g in cohorts_at_e:
                w = weights[g]
                agg_effect += w * cohort_effects[(g, e)]

            # Compute SE using delta method with vcov
            # Var(β_e) = w' Σ w where w is weight vector and Σ is vcov submatrix
            indices = [coef_index_map[(g, e)] for g in cohorts_at_e]
            weight_vec = np.array([weights[g] for g in cohorts_at_e])
            vcov_subset = vcov_cohort[np.ix_(indices, indices)]
            agg_var = float(weight_vec @ vcov_subset @ weight_vec)
            # Clamp tiny negative variances from numerical noise to zero
            agg_se = np.sqrt(max(agg_var, 0))

            # Degenerate SE (0 or non-finite) yields NaN t-stat / CI
            t_stat = agg_effect / agg_se if np.isfinite(agg_se) and agg_se > 0 else np.nan
            p_val = compute_p_value(t_stat)
            ci = compute_confidence_interval(agg_effect, agg_se, self.alpha) if np.isfinite(agg_se) and agg_se > 0 else (np.nan, np.nan)

            event_study_effects[e] = {
                "effect": agg_effect,
                "se": agg_se,
                "t_stat": t_stat,
                "p_value": p_val,
                "conf_int": ci,
                "n_groups": len(cohorts_at_e),
            }

        return event_study_effects, cohort_weights
919
+
920
+ def _compute_overall_att(
921
+ self,
922
+ df: pd.DataFrame,
923
+ first_treat: str,
924
+ event_study_effects: Dict[int, Dict[str, Any]],
925
+ cohort_effects: Dict[Tuple[Any, int], float],
926
+ cohort_weights: Dict[int, Dict[Any, float]],
927
+ vcov_cohort: np.ndarray,
928
+ coef_index_map: Dict[Tuple[Any, int], int],
929
+ ) -> Tuple[float, float]:
930
+ """
931
+ Compute overall ATT as weighted average of post-treatment effects.
932
+
933
+ Returns (att, se) tuple.
934
+ """
935
+ post_effects = [
936
+ (e, eff)
937
+ for e, eff in event_study_effects.items()
938
+ if e >= 0
939
+ ]
940
+
941
+ if not post_effects:
942
+ return np.nan, np.nan
943
+
944
+ # Weight by number of treated observations at each relative time
945
+ post_weights = []
946
+ post_estimates = []
947
+
948
+ for e, eff in post_effects:
949
+ n_at_e = len(df[(df["_rel_time"] == e) & (df[first_treat] > 0)])
950
+ post_weights.append(max(n_at_e, 1))
951
+ post_estimates.append(eff["effect"])
952
+
953
+ post_weights = np.array(post_weights, dtype=float)
954
+ post_weights = post_weights / post_weights.sum()
955
+
956
+ overall_att = float(np.sum(post_weights * np.array(post_estimates)))
957
+
958
+ # Compute SE using delta method
959
+ # Need to trace back through the full weighting scheme
960
+ # ATT = Σ_e w_e × β_e = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e}
961
+ # Collect all (g, e) pairs and their overall weights
962
+ overall_weights_by_coef: Dict[Tuple[Any, int], float] = {}
963
+
964
+ for i, (e, _) in enumerate(post_effects):
965
+ period_weight = post_weights[i]
966
+ if e in cohort_weights:
967
+ for g, cw in cohort_weights[e].items():
968
+ key = (g, e)
969
+ if key in coef_index_map:
970
+ if key not in overall_weights_by_coef:
971
+ overall_weights_by_coef[key] = 0.0
972
+ overall_weights_by_coef[key] += period_weight * cw
973
+
974
+ if not overall_weights_by_coef:
975
+ # Fallback to simplified variance that ignores covariances between periods
976
+ warnings.warn(
977
+ "Could not construct full weight vector for overall ATT SE. "
978
+ "Using simplified variance that ignores covariances between periods.",
979
+ UserWarning,
980
+ stacklevel=2,
981
+ )
982
+ overall_var = float(
983
+ np.sum((post_weights ** 2) * np.array([eff["se"] ** 2 for _, eff in post_effects]))
984
+ )
985
+ return overall_att, np.sqrt(overall_var)
986
+
987
+ # Build full weight vector and compute variance
988
+ indices = [coef_index_map[key] for key in overall_weights_by_coef.keys()]
989
+ weight_vec = np.array(list(overall_weights_by_coef.values()))
990
+ vcov_subset = vcov_cohort[np.ix_(indices, indices)]
991
+ overall_var = float(weight_vec @ vcov_subset @ weight_vec)
992
+ overall_se = np.sqrt(max(overall_var, 0))
993
+
994
+ return overall_att, overall_se
995
+
996
+ def _run_bootstrap(
997
+ self,
998
+ df: pd.DataFrame,
999
+ outcome: str,
1000
+ unit: str,
1001
+ time: str,
1002
+ first_treat: str,
1003
+ treatment_groups: List[Any],
1004
+ rel_periods_to_estimate: List[int],
1005
+ covariates: Optional[List[str]],
1006
+ cluster_var: str,
1007
+ original_event_study: Dict[int, Dict[str, Any]],
1008
+ original_overall_att: float,
1009
+ ) -> SABootstrapResults:
1010
+ """
1011
+ Run pairs bootstrap for inference.
1012
+
1013
+ Resamples units with replacement and re-estimates the full model.
1014
+ """
1015
+ if self.n_bootstrap < 50:
1016
+ warnings.warn(
1017
+ f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
1018
+ "for reliable inference.",
1019
+ UserWarning,
1020
+ stacklevel=3,
1021
+ )
1022
+
1023
+ rng = np.random.default_rng(self.seed)
1024
+
1025
+ # Get unique units
1026
+ all_units = df[unit].unique()
1027
+ n_units = len(all_units)
1028
+
1029
+ # Store bootstrap samples
1030
+ rel_periods = sorted(original_event_study.keys())
1031
+ bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
1032
+ bootstrap_overall = np.zeros(self.n_bootstrap)
1033
+
1034
+ for b in range(self.n_bootstrap):
1035
+ # Resample units with replacement (pairs bootstrap)
1036
+ boot_units = rng.choice(all_units, size=n_units, replace=True)
1037
+
1038
+ # Create bootstrap sample efficiently
1039
+ # Build index array for all selected units
1040
+ boot_indices = np.concatenate([
1041
+ df.index[df[unit] == u].values for u in boot_units
1042
+ ])
1043
+ df_b = df.iloc[boot_indices].copy()
1044
+
1045
+ # Reassign unique unit IDs for bootstrap sample
1046
+ # Each resampled unit gets a unique ID
1047
+ new_unit_ids = []
1048
+ current_id = 0
1049
+ for u in boot_units:
1050
+ unit_rows = df[df[unit] == u]
1051
+ for _ in range(len(unit_rows)):
1052
+ new_unit_ids.append(current_id)
1053
+ current_id += 1
1054
+ df_b[unit] = new_unit_ids[:len(df_b)]
1055
+
1056
+ # Recompute relative time (vectorized)
1057
+ df_b["_rel_time"] = np.where(
1058
+ df_b[first_treat] > 0,
1059
+ df_b[time] - df_b[first_treat],
1060
+ np.nan
1061
+ )
1062
+ # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
1063
+ df_b["_never_treated"] = (
1064
+ (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)
1065
+ )
1066
+
1067
+ try:
1068
+ # Re-estimate saturated regression
1069
+ (
1070
+ cohort_effects_b,
1071
+ cohort_ses_b,
1072
+ vcov_b,
1073
+ coef_map_b,
1074
+ ) = self._fit_saturated_regression(
1075
+ df_b,
1076
+ outcome,
1077
+ unit,
1078
+ time,
1079
+ first_treat,
1080
+ treatment_groups,
1081
+ rel_periods_to_estimate,
1082
+ covariates,
1083
+ cluster_var,
1084
+ )
1085
+
1086
+ # Compute IW effects for this bootstrap sample
1087
+ event_study_b, cohort_weights_b = self._compute_iw_effects(
1088
+ df_b,
1089
+ unit,
1090
+ first_treat,
1091
+ treatment_groups,
1092
+ rel_periods_to_estimate,
1093
+ cohort_effects_b,
1094
+ cohort_ses_b,
1095
+ vcov_b,
1096
+ coef_map_b,
1097
+ )
1098
+
1099
+ # Store bootstrap estimates
1100
+ for e in rel_periods:
1101
+ if e in event_study_b:
1102
+ bootstrap_effects[e][b] = event_study_b[e]["effect"]
1103
+ else:
1104
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1105
+
1106
+ # Compute overall ATT for this bootstrap sample
1107
+ overall_b, _ = self._compute_overall_att(
1108
+ df_b,
1109
+ first_treat,
1110
+ event_study_b,
1111
+ cohort_effects_b,
1112
+ cohort_weights_b,
1113
+ vcov_b,
1114
+ coef_map_b,
1115
+ )
1116
+ bootstrap_overall[b] = overall_b
1117
+
1118
+ except (ValueError, np.linalg.LinAlgError) as exc:
1119
+ # If bootstrap iteration fails, use original
1120
+ warnings.warn(
1121
+ f"Bootstrap iteration {b} failed: {exc}. Using original estimate.",
1122
+ UserWarning,
1123
+ stacklevel=2,
1124
+ )
1125
+ for e in rel_periods:
1126
+ bootstrap_effects[e][b] = original_event_study[e]["effect"]
1127
+ bootstrap_overall[b] = original_overall_att
1128
+
1129
+ # Compute bootstrap statistics
1130
+ event_study_ses = {}
1131
+ event_study_cis = {}
1132
+ event_study_p_values = {}
1133
+
1134
+ for e in rel_periods:
1135
+ boot_dist = bootstrap_effects[e]
1136
+ original_effect = original_event_study[e]["effect"]
1137
+
1138
+ se = float(np.std(boot_dist, ddof=1))
1139
+ ci = self._compute_percentile_ci(boot_dist, self.alpha)
1140
+ p_value = self._compute_bootstrap_pvalue(original_effect, boot_dist)
1141
+
1142
+ event_study_ses[e] = se
1143
+ event_study_cis[e] = ci
1144
+ event_study_p_values[e] = p_value
1145
+
1146
+ # Overall ATT statistics
1147
+ if not np.isfinite(original_overall_att):
1148
+ overall_se = np.nan
1149
+ overall_ci = (np.nan, np.nan)
1150
+ overall_p = np.nan
1151
+ else:
1152
+ overall_se = float(np.std(bootstrap_overall, ddof=1))
1153
+ overall_ci = self._compute_percentile_ci(bootstrap_overall, self.alpha)
1154
+ overall_p = self._compute_bootstrap_pvalue(
1155
+ original_overall_att, bootstrap_overall
1156
+ )
1157
+
1158
+ return SABootstrapResults(
1159
+ n_bootstrap=self.n_bootstrap,
1160
+ weight_type="pairs",
1161
+ alpha=self.alpha,
1162
+ overall_att_se=overall_se,
1163
+ overall_att_ci=overall_ci,
1164
+ overall_att_p_value=overall_p,
1165
+ event_study_ses=event_study_ses,
1166
+ event_study_cis=event_study_cis,
1167
+ event_study_p_values=event_study_p_values,
1168
+ bootstrap_distribution=bootstrap_overall,
1169
+ )
1170
+
1171
+ def _compute_percentile_ci(
1172
+ self,
1173
+ boot_dist: np.ndarray,
1174
+ alpha: float,
1175
+ ) -> Tuple[float, float]:
1176
+ """Compute percentile confidence interval."""
1177
+ lower = float(np.percentile(boot_dist, alpha / 2 * 100))
1178
+ upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
1179
+ return (lower, upper)
1180
+
1181
+ def _compute_bootstrap_pvalue(
1182
+ self,
1183
+ original_effect: float,
1184
+ boot_dist: np.ndarray,
1185
+ ) -> float:
1186
+ """Compute two-sided bootstrap p-value."""
1187
+ if original_effect >= 0:
1188
+ p_one_sided = float(np.mean(boot_dist <= 0))
1189
+ else:
1190
+ p_one_sided = float(np.mean(boot_dist >= 0))
1191
+
1192
+ p_value = min(2 * p_one_sided, 1.0)
1193
+ p_value = max(p_value, 1 / (self.n_bootstrap + 1))
1194
+
1195
+ return p_value
1196
+
1197
+ def get_params(self) -> Dict[str, Any]:
1198
+ """Get estimator parameters (sklearn-compatible)."""
1199
+ return {
1200
+ "control_group": self.control_group,
1201
+ "anticipation": self.anticipation,
1202
+ "alpha": self.alpha,
1203
+ "cluster": self.cluster,
1204
+ "n_bootstrap": self.n_bootstrap,
1205
+ "seed": self.seed,
1206
+ "rank_deficient_action": self.rank_deficient_action,
1207
+ }
1208
+
1209
+ def set_params(self, **params) -> "SunAbraham":
1210
+ """Set estimator parameters (sklearn-compatible)."""
1211
+ for key, value in params.items():
1212
+ if hasattr(self, key):
1213
+ setattr(self, key, value)
1214
+ else:
1215
+ raise ValueError(f"Unknown parameter: {key}")
1216
+ return self
1217
+
1218
+ def summary(self) -> str:
1219
+ """Get summary of estimation results."""
1220
+ if not self.is_fitted_:
1221
+ raise RuntimeError("Model must be fitted before calling summary()")
1222
+ assert self.results_ is not None
1223
+ return self.results_.summary()
1224
+
1225
+ def print_summary(self) -> None:
1226
+ """Print summary to stdout."""
1227
+ print(self.summary())