diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2480 @@
1
+ """
2
+ Borusyak-Jaravel-Spiess (2024) Imputation DiD Estimator.
3
+
4
+ Implements the efficient imputation estimator for staggered
5
+ Difference-in-Differences from Borusyak, Jaravel & Spiess (2024),
6
+ "Revisiting Event-Study Designs: Robust and Efficient Estimation",
7
+ Review of Economic Studies.
8
+
9
+ The estimator:
10
+ 1. Runs OLS on untreated observations to estimate unit + time fixed effects
11
+ 2. Imputes counterfactual Y(0) for treated observations
12
+ 3. Aggregates imputed treatment effects with researcher-chosen weights
13
+
14
+ Inference uses the conservative clustered variance estimator (Theorem 3).
15
+ """
16
+
17
+ import warnings
18
+ from dataclasses import dataclass, field
19
+ from typing import Any, Dict, List, Optional, Set, Tuple
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ from scipy import sparse, stats
24
+ from scipy.sparse.linalg import spsolve
25
+
26
+ from diff_diff.linalg import solve_ols
27
+ from diff_diff.results import _get_significance_stars
28
+ from diff_diff.utils import compute_confidence_interval, compute_p_value
29
+
30
+ # =============================================================================
31
+ # Results Dataclasses
32
+ # =============================================================================
33
+
34
+
35
@dataclass
class ImputationBootstrapResults:
    """
    Bootstrap inference results for ImputationDiD.

    Bootstrap is a library extension beyond Borusyak et al. (2024), which
    proposes only analytical inference via the conservative variance
    estimator. It is offered for API consistency with CallawaySantAnna
    and SunAbraham.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations performed.
    weight_type : str
        Bootstrap weight scheme (currently "rademacher" only).
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for the overall ATT.
    overall_att_ci : tuple of float
        Bootstrap confidence interval for the overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for the overall ATT.
    event_study_ses : dict, optional
        Bootstrap SEs keyed by relative period.
    event_study_cis : dict, optional
        Bootstrap CIs keyed by relative period.
    event_study_p_values : dict, optional
        Bootstrap p-values keyed by relative period.
    group_ses : dict, optional
        Bootstrap SEs keyed by cohort.
    group_cis : dict, optional
        Bootstrap CIs keyed by cohort.
    group_p_values : dict, optional
        Bootstrap p-values keyed by cohort.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap distribution of the overall ATT
        (excluded from repr to keep output readable).
    """

    n_bootstrap: int
    weight_type: str
    alpha: float
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None
    group_ses: Optional[Dict[Any, float]] = None
    group_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_p_values: Optional[Dict[Any, float]] = None
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
87
+
88
+
89
@dataclass
class ImputationDiDResults:
    """
    Results container for Borusyak-Jaravel-Spiess (2024) imputation DiD.

    Attributes
    ----------
    treatment_effects : pd.DataFrame
        Unit-level treatment effects with columns: unit, time, tau_hat, weight.
    overall_att : float
        Overall average treatment effect on the treated.
    overall_se : float
        Standard error of the overall ATT.
    overall_t_stat : float
        T-statistic for the overall ATT.
    overall_p_value : float
        P-value for the overall ATT.
    overall_conf_int : tuple
        Confidence interval for the overall ATT.
    event_study_effects : dict, optional
        Maps relative time h to a dict with keys 'effect', 'se', 't_stat',
        'p_value', 'conf_int', 'n_obs'.
    group_effects : dict, optional
        Maps cohort g to a dict with the same keys.
    groups : list
        Treatment cohorts.
    time_periods : list
        All observed time periods.
    n_obs : int
        Total number of observations.
    n_treated_obs : int
        Number of treated observations (|Omega_1|).
    n_untreated_obs : int
        Number of untreated observations (|Omega_0|).
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of units contributing to Omega_0.
    alpha : float
        Significance level used during estimation.
    pretrend_results : dict, optional
        Populated by pretrend_test().
    bootstrap_results : ImputationBootstrapResults, optional
        Bootstrap inference results.
    """

    treatment_effects: pd.DataFrame
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    event_study_effects: Optional[Dict[int, Dict[str, Any]]]
    group_effects: Optional[Dict[Any, Dict[str, Any]]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_obs: int
    n_untreated_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False)
    bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False)
    # Internal: back-reference used by pretrend_test(); hidden from repr.
    _estimator_ref: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Compact one-line representation with significance stars."""
        stars = _get_significance_stars(self.overall_p_value)
        return (
            f"ImputationDiDResults(ATT={self.overall_att:.4f}{stars}, "
            f"SE={self.overall_se:.4f}, n_groups={len(self.groups)}, "
            f"n_treated_obs={self.n_treated_obs})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Build a formatted, human-readable summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level for the reported confidence level. Falls back
            to the alpha used during estimation.

        Returns
        -------
        str
            Multi-line formatted summary.
        """
        alpha = alpha or self.alpha
        conf_level = int((1 - alpha) * 100)

        def _stat_cell(value: float, spec: str) -> str:
            # Render a numeric table cell; non-finite values print as "NaN".
            return f"{value:{spec}}" if np.isfinite(value) else f"{'NaN':>10}"

        def _effect_table(
            title: str,
            label: str,
            effects: Dict[Any, Dict[str, Any]],
            mark_reference: bool,
        ) -> List[str]:
            # Shared renderer for the event-study and cohort tables.
            rows = [
                "-" * 85,
                title.center(85),
                "-" * 85,
                f"{label:<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
            for key in sorted(effects.keys()):
                entry = effects[key]
                if mark_reference and entry.get("n_obs", 1) == 0:
                    # Reference period marker (event-study tables only).
                    rows.append(
                        f"[ref: {key}]{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}"
                    )
                elif np.isnan(entry["effect"]):
                    rows.append(
                        f"{key:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}"
                    )
                else:
                    rows.append(
                        f"{key:<15} {entry['effect']:>12.4f} {entry['se']:>12.4f} "
                        f"{_stat_cell(entry['t_stat'], '>10.3f')} "
                        f"{_stat_cell(entry['p_value'], '>10.4f')} "
                        f"{_get_significance_stars(entry['p_value']):>6}"
                    )
            rows.extend(["-" * 85, ""])
            return rows

        out = [
            "=" * 85,
            "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated observations:':<30} {self.n_treated_obs:>10}",
            f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            "",
        ]

        # Overall ATT table.
        out.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{_stat_cell(self.overall_t_stat, '>10.3f')} "
                f"{_stat_cell(self.overall_p_value, '>10.4f')} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
                "",
            ]
        )

        # Dynamic (event-study) effects, if computed.
        if self.event_study_effects:
            out.extend(
                _effect_table(
                    "Event Study (Dynamic) Effects",
                    "Rel. Period",
                    self.event_study_effects,
                    mark_reference=True,
                )
            )

        # Cohort effects, if computed.
        if self.group_effects:
            out.extend(
                _effect_table(
                    "Group (Cohort) Effects",
                    "Cohort",
                    self.group_effects,
                    mark_reference=False,
                )
            )

        # Pre-trend test section, if pretrend_test() has been run.
        if self.pretrend_results is not None:
            pt = self.pretrend_results
            out.extend(
                [
                    "-" * 85,
                    "Pre-Trend Test (Equation 9)".center(85),
                    "-" * 85,
                    f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}",
                    f"{'P-value:':<30} {pt['p_value']:>10.4f}",
                    f"{'Degrees of freedom:':<30} {pt['df']:>10}",
                    f"{'Number of leads:':<30} {pt['n_leads']:>10}",
                    "-" * 85,
                    "",
                ]
            )

        out.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(out)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the formatted summary to stdout."""
        print(self.summary(alpha))

    @staticmethod
    def _effects_to_frame(effects: Dict[Any, Dict[str, Any]], key_column: str) -> pd.DataFrame:
        # One row per aggregation key, emitted in sorted key order.
        records = [
            {
                key_column: key,
                "effect": info["effect"],
                "se": info["se"],
                "t_stat": info["t_stat"],
                "p_value": info["p_value"],
                "conf_int_lower": info["conf_int"][0],
                "conf_int_upper": info["conf_int"][1],
                "n_obs": info.get("n_obs", np.nan),
            }
            for key, info in sorted(effects.items())
        ]
        return pd.DataFrame(records)

    def to_dataframe(self, level: str = "observation") -> pd.DataFrame:
        """
        Convert results to a DataFrame.

        Parameters
        ----------
        level : str, default="observation"
            Aggregation level:
            - "observation": unit-level treatment effects
            - "event_study": effects by relative time
            - "group": effects by cohort

        Returns
        -------
        pd.DataFrame
            Results at the requested level.

        Raises
        ------
        ValueError
            If `level` is unknown, or the requested aggregation was not
            computed during fit().
        """
        if level == "observation":
            return self.treatment_effects.copy()

        if level == "event_study":
            if self.event_study_effects is None:
                raise ValueError(
                    "Event study effects not computed. "
                    "Use aggregate='event_study' or aggregate='all'."
                )
            return self._effects_to_frame(self.event_study_effects, "relative_period")

        if level == "group":
            if self.group_effects is None:
                raise ValueError(
                    "Group effects not computed. Use aggregate='group' or aggregate='all'."
                )
            return self._effects_to_frame(self.group_effects, "group")

        raise ValueError(
            f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'."
        )

    def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
        """
        Run a pre-trend test (Equation 9 of Borusyak et al. 2024).

        Delegates to the fitted estimator, which augments the Step 1 OLS
        with pre-treatment lead indicators and tests their joint
        significance via a cluster-robust Wald F-test.

        Parameters
        ----------
        n_leads : int, optional
            Number of pre-treatment leads. If None, the estimator uses all
            available pre-treatment periods minus one (reference period).

        Returns
        -------
        dict
            Keys: 'f_stat', 'p_value', 'df', 'n_leads', 'lead_coefficients'.

        Raises
        ------
        RuntimeError
            If the internal estimator reference is missing (model not fitted
            through the normal path).
        """
        if self._estimator_ref is None:
            raise RuntimeError(
                "Pre-trend test requires internal estimator reference. "
                "Re-fit the model to use this method."
            )
        outcome = self._estimator_ref._pretrend_test(n_leads=n_leads)
        # Cache so that summary() can render the pre-trend section.
        self.pretrend_results = outcome
        return outcome

    @property
    def is_significant(self) -> bool:
        """True when the overall ATT p-value is below the stored alpha."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for the overall ATT."""
        return _get_significance_stars(self.overall_p_value)
440
+
441
+
442
+ # =============================================================================
443
+ # Main Estimator
444
+ # =============================================================================
445
+
446
+
447
+ class ImputationDiD:
448
+ """
449
+ Borusyak-Jaravel-Spiess (2024) imputation DiD estimator.
450
+
451
+ This is the efficient estimator for staggered Difference-in-Differences
452
+ under parallel trends. It produces shorter confidence intervals than
453
+ Callaway-Sant'Anna (~50% shorter) and Sun-Abraham (2-3.5x shorter)
454
+ under homogeneous treatment effects.
455
+
456
+ The estimation procedure:
457
+ 1. Run OLS on untreated observations to estimate unit + time fixed effects
458
+ 2. Impute counterfactual Y(0) for treated observations
459
+ 3. Aggregate imputed treatment effects with researcher-chosen weights
460
+
461
+ Inference uses the conservative clustered variance estimator from Theorem 3
462
+ of the paper.
463
+
464
+ Parameters
465
+ ----------
466
+ anticipation : int, default=0
467
+ Number of periods before treatment where effects may occur.
468
+ alpha : float, default=0.05
469
+ Significance level for confidence intervals.
470
+ cluster : str, optional
471
+ Column name for cluster-robust standard errors.
472
+ If None, clusters at the unit level by default.
473
+ n_bootstrap : int, default=0
474
+ Number of bootstrap iterations. If 0, uses analytical inference
475
+ (conservative variance from Theorem 3).
476
+ seed : int, optional
477
+ Random seed for reproducibility.
478
+ rank_deficient_action : str, default="warn"
479
+ Action when design matrix is rank-deficient:
480
+ - "warn": Issue warning and drop linearly dependent columns
481
+ - "error": Raise ValueError
482
+ - "silent": Drop columns silently
483
+ horizon_max : int, optional
484
+ Maximum event-study horizon. If set, event study effects are only
485
+ computed for |h| <= horizon_max.
486
+ aux_partition : str, default="cohort_horizon"
487
+ Controls the auxiliary model partition for Theorem 3 variance:
488
+ - "cohort_horizon": Groups by cohort x relative time (tightest SEs)
489
+ - "cohort": Groups by cohort only (more conservative)
490
+ - "horizon": Groups by relative time only (more conservative)
491
+
492
+ Attributes
493
+ ----------
494
+ results_ : ImputationDiDResults
495
+ Estimation results after calling fit().
496
+ is_fitted_ : bool
497
+ Whether the model has been fitted.
498
+
499
+ Examples
500
+ --------
501
+ Basic usage:
502
+
503
+ >>> from diff_diff import ImputationDiD, generate_staggered_data
504
+ >>> data = generate_staggered_data(n_units=200, seed=42)
505
+ >>> est = ImputationDiD()
506
+ >>> results = est.fit(data, outcome='outcome', unit='unit',
507
+ ... time='time', first_treat='first_treat')
508
+ >>> results.print_summary()
509
+
510
+ With event study:
511
+
512
+ >>> est = ImputationDiD()
513
+ >>> results = est.fit(data, outcome='outcome', unit='unit',
514
+ ... time='time', first_treat='first_treat',
515
+ ... aggregate='event_study')
516
+ >>> from diff_diff import plot_event_study
517
+ >>> plot_event_study(results)
518
+
519
+ Notes
520
+ -----
521
+ The imputation estimator uses ALL untreated observations (never-treated +
522
+ not-yet-treated periods of eventually-treated units) to estimate the
523
+ counterfactual model. There is no ``control_group`` parameter because this
524
+ is fundamental to the method's efficiency.
525
+
526
+ References
527
+ ----------
528
+ Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event-Study
529
+ Designs: Robust and Efficient Estimation. Review of Economic Studies,
530
+ 91(6), 3253-3285.
531
+ """
532
+
533
+ def __init__(
534
+ self,
535
+ anticipation: int = 0,
536
+ alpha: float = 0.05,
537
+ cluster: Optional[str] = None,
538
+ n_bootstrap: int = 0,
539
+ seed: Optional[int] = None,
540
+ rank_deficient_action: str = "warn",
541
+ horizon_max: Optional[int] = None,
542
+ aux_partition: str = "cohort_horizon",
543
+ ):
544
+ if rank_deficient_action not in ("warn", "error", "silent"):
545
+ raise ValueError(
546
+ f"rank_deficient_action must be 'warn', 'error', or 'silent', "
547
+ f"got '{rank_deficient_action}'"
548
+ )
549
+ if aux_partition not in ("cohort_horizon", "cohort", "horizon"):
550
+ raise ValueError(
551
+ f"aux_partition must be 'cohort_horizon', 'cohort', or 'horizon', "
552
+ f"got '{aux_partition}'"
553
+ )
554
+
555
+ self.anticipation = anticipation
556
+ self.alpha = alpha
557
+ self.cluster = cluster
558
+ self.n_bootstrap = n_bootstrap
559
+ self.seed = seed
560
+ self.rank_deficient_action = rank_deficient_action
561
+ self.horizon_max = horizon_max
562
+ self.aux_partition = aux_partition
563
+
564
+ self.is_fitted_ = False
565
+ self.results_: Optional[ImputationDiDResults] = None
566
+
567
+ # Internal state preserved for pretrend_test()
568
+ self._fit_data: Optional[Dict[str, Any]] = None
569
+
570
+ def fit(
571
+ self,
572
+ data: pd.DataFrame,
573
+ outcome: str,
574
+ unit: str,
575
+ time: str,
576
+ first_treat: str,
577
+ covariates: Optional[List[str]] = None,
578
+ aggregate: Optional[str] = None,
579
+ balance_e: Optional[int] = None,
580
+ ) -> ImputationDiDResults:
581
+ """
582
+ Fit the imputation DiD estimator.
583
+
584
+ Parameters
585
+ ----------
586
+ data : pd.DataFrame
587
+ Panel data with unit and time identifiers.
588
+ outcome : str
589
+ Name of outcome variable column.
590
+ unit : str
591
+ Name of unit identifier column.
592
+ time : str
593
+ Name of time period column.
594
+ first_treat : str
595
+ Name of column indicating when unit was first treated.
596
+ Use 0 (or np.inf) for never-treated units.
597
+ covariates : list of str, optional
598
+ List of covariate column names.
599
+ aggregate : str, optional
600
+ Aggregation mode: None/"simple" (overall ATT only),
601
+ "event_study", "group", or "all".
602
+ balance_e : int, optional
603
+ When computing event study, restrict to cohorts observed at all
604
+ relative times in [-balance_e, max_h].
605
+
606
+ Returns
607
+ -------
608
+ ImputationDiDResults
609
+ Object containing all estimation results.
610
+
611
+ Raises
612
+ ------
613
+ ValueError
614
+ If required columns are missing or data validation fails.
615
+ """
616
+ # Validate inputs
617
+ required_cols = [outcome, unit, time, first_treat]
618
+ if covariates:
619
+ required_cols.extend(covariates)
620
+
621
+ missing = [c for c in required_cols if c not in data.columns]
622
+ if missing:
623
+ raise ValueError(f"Missing columns: {missing}")
624
+
625
+ # Create working copy
626
+ df = data.copy()
627
+
628
+ # Ensure numeric types
629
+ df[time] = pd.to_numeric(df[time])
630
+ df[first_treat] = pd.to_numeric(df[first_treat])
631
+
632
+ # Validate absorbing treatment: first_treat must be constant within each unit
633
+ ft_nunique = df.groupby(unit)[first_treat].nunique()
634
+ non_constant = ft_nunique[ft_nunique > 1]
635
+ if len(non_constant) > 0:
636
+ example_unit = non_constant.index[0]
637
+ example_vals = sorted(df.loc[df[unit] == example_unit, first_treat].unique())
638
+ warnings.warn(
639
+ f"{len(non_constant)} unit(s) have non-constant '{first_treat}' "
640
+ f"values (e.g., unit '{example_unit}' has values {example_vals}). "
641
+ f"ImputationDiD assumes treatment is an absorbing state "
642
+ f"(once treated, always treated) with a single treatment onset "
643
+ f"time per unit. Non-constant first_treat violates this assumption "
644
+ f"and may produce unreliable estimates.",
645
+ UserWarning,
646
+ stacklevel=2,
647
+ )
648
+
649
+ # Coerce to per-unit value so downstream code
650
+ # (_never_treated, _treated, _rel_time) uses a single
651
+ # consistent first_treat per unit.
652
+ df[first_treat] = df.groupby(unit)[first_treat].transform("first")
653
+
654
+ # Identify treatment status
655
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
656
+
657
+ # Check for always-treated units (treated in all observed periods)
658
+ min_time = df[time].min()
659
+ always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time)
660
+ n_always_treated = df.loc[always_treated_mask, unit].nunique()
661
+ if n_always_treated > 0:
662
+ warnings.warn(
663
+ f"{n_always_treated} unit(s) are treated in all observed periods "
664
+ f"(first_treat <= {min_time}). These units have no untreated "
665
+ "observations and cannot contribute to the counterfactual model. "
666
+ "Their treatment effects will be imputed but may be unreliable.",
667
+ UserWarning,
668
+ stacklevel=2,
669
+ )
670
+
671
+ # Create treatment indicator D_it
672
+ # D_it = 1 if t >= first_treat and first_treat > 0
673
+ # With anticipation: D_it = 1 if t >= first_treat - anticipation
674
+ effective_treat = df[first_treat] - self.anticipation
675
+ df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat)
676
+
677
+ # Identify Omega_0 (untreated) and Omega_1 (treated)
678
+ omega_0_mask = ~df["_treated"]
679
+ omega_1_mask = df["_treated"]
680
+
681
+ n_omega_0 = int(omega_0_mask.sum())
682
+ n_omega_1 = int(omega_1_mask.sum())
683
+
684
+ if n_omega_0 == 0:
685
+ raise ValueError(
686
+ "No untreated observations found. Cannot estimate counterfactual model."
687
+ )
688
+ if n_omega_1 == 0:
689
+ raise ValueError("No treated observations found. Nothing to estimate.")
690
+
691
+ # Identify groups and time periods
692
+ time_periods = sorted(df[time].unique())
693
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and g != np.inf])
694
+
695
+ if len(treatment_groups) == 0:
696
+ raise ValueError("No treated units found. Check 'first_treat' column.")
697
+
698
+ # Unit info
699
+ unit_info = (
700
+ df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index()
701
+ )
702
+ n_treated_units = int((~unit_info["_never_treated"]).sum())
703
+ # Control units = units with at least one untreated observation
704
+ units_in_omega_0 = df.loc[omega_0_mask, unit].unique()
705
+ n_control_units = len(units_in_omega_0)
706
+
707
+ # Cluster variable
708
+ cluster_var = self.cluster if self.cluster is not None else unit
709
+ if self.cluster is not None and self.cluster not in df.columns:
710
+ raise ValueError(
711
+ f"Cluster column '{self.cluster}' not found in data. "
712
+ f"Available columns: {list(df.columns)}"
713
+ )
714
+
715
+ # Compute relative time
716
+ df["_rel_time"] = np.where(
717
+ ~df["_never_treated"],
718
+ df[time] - df[first_treat],
719
+ np.nan,
720
+ )
721
+
722
+ # ---- Step 1: OLS on untreated observations ----
723
+ unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model(
724
+ df, outcome, unit, time, covariates, omega_0_mask
725
+ )
726
+
727
+ # ---- Rank condition checks ----
728
+ # Check: every treated unit should have >= 1 untreated period (for unit FE)
729
+ treated_unit_ids = df.loc[omega_1_mask, unit].unique()
730
+ units_with_fe = set(unit_fe.keys())
731
+ units_missing_fe = set(treated_unit_ids) - units_with_fe
732
+
733
+ # Check: every post-treatment period should have >= 1 untreated unit (for time FE)
734
+ post_period_ids = df.loc[omega_1_mask, time].unique()
735
+ periods_with_fe = set(time_fe.keys())
736
+ periods_missing_fe = set(post_period_ids) - periods_with_fe
737
+
738
+ if units_missing_fe or periods_missing_fe:
739
+ parts = []
740
+ if units_missing_fe:
741
+ sorted_missing = sorted(units_missing_fe)
742
+ parts.append(
743
+ f"{len(units_missing_fe)} treated unit(s) have no untreated "
744
+ f"periods (units: {sorted_missing[:5]}"
745
+ f"{'...' if len(units_missing_fe) > 5 else ''})"
746
+ )
747
+ if periods_missing_fe:
748
+ sorted_missing = sorted(periods_missing_fe)
749
+ parts.append(
750
+ f"{len(periods_missing_fe)} post-treatment period(s) have no "
751
+ f"untreated units (periods: {sorted_missing[:5]}"
752
+ f"{'...' if len(periods_missing_fe) > 5 else ''})"
753
+ )
754
+ msg = (
755
+ "Rank condition violated: "
756
+ + "; ".join(parts)
757
+ + ". Affected treatment effects will be NaN."
758
+ )
759
+ if self.rank_deficient_action == "error":
760
+ raise ValueError(msg)
761
+ elif self.rank_deficient_action == "warn":
762
+ warnings.warn(msg, UserWarning, stacklevel=2)
763
+ # "silent": continue without warning
764
+
765
+ # ---- Step 2: Impute treatment effects ----
766
+ tau_hat, y_hat_0 = self._impute_treatment_effects(
767
+ df,
768
+ outcome,
769
+ unit,
770
+ time,
771
+ covariates,
772
+ omega_1_mask,
773
+ unit_fe,
774
+ time_fe,
775
+ grand_mean,
776
+ delta_hat,
777
+ )
778
+
779
+ # Store tau_hat in dataframe
780
+ df["_tau_hat"] = np.nan
781
+ df.loc[omega_1_mask, "_tau_hat"] = tau_hat
782
+
783
+ # ---- Step 3: Aggregate ----
784
+ # Always compute overall ATT (simple aggregation)
785
+ valid_tau = tau_hat[np.isfinite(tau_hat)]
786
+
787
+ if len(valid_tau) == 0:
788
+ overall_att = np.nan
789
+ else:
790
+ overall_att = float(np.mean(valid_tau))
791
+
792
+ # ---- Conservative variance (Theorem 3) ----
793
+ # Build weights matching the ATT: uniform over finite tau_hat, zero for NaN
794
+ overall_weights = np.zeros(n_omega_1)
795
+ finite_mask = np.isfinite(tau_hat)
796
+ n_valid = int(finite_mask.sum())
797
+ if n_valid > 0:
798
+ overall_weights[finite_mask] = 1.0 / n_valid
799
+
800
+ if n_valid == 0:
801
+ overall_se = np.nan
802
+ else:
803
+ overall_se = self._compute_conservative_variance(
804
+ df=df,
805
+ outcome=outcome,
806
+ unit=unit,
807
+ time=time,
808
+ first_treat=first_treat,
809
+ covariates=covariates,
810
+ omega_0_mask=omega_0_mask,
811
+ omega_1_mask=omega_1_mask,
812
+ unit_fe=unit_fe,
813
+ time_fe=time_fe,
814
+ grand_mean=grand_mean,
815
+ delta_hat=delta_hat,
816
+ weights=overall_weights,
817
+ cluster_var=cluster_var,
818
+ kept_cov_mask=kept_cov_mask,
819
+ )
820
+
821
+ overall_t = (
822
+ overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
823
+ )
824
+ overall_p = compute_p_value(overall_t)
825
+ overall_ci = (
826
+ compute_confidence_interval(overall_att, overall_se, self.alpha)
827
+ if np.isfinite(overall_se) and overall_se > 0
828
+ else (np.nan, np.nan)
829
+ )
830
+
831
+ # Event study and group aggregation
832
+ event_study_effects = None
833
+ group_effects = None
834
+
835
+ if aggregate in ("event_study", "all"):
836
+ event_study_effects = self._aggregate_event_study(
837
+ df=df,
838
+ outcome=outcome,
839
+ unit=unit,
840
+ time=time,
841
+ first_treat=first_treat,
842
+ covariates=covariates,
843
+ omega_0_mask=omega_0_mask,
844
+ omega_1_mask=omega_1_mask,
845
+ unit_fe=unit_fe,
846
+ time_fe=time_fe,
847
+ grand_mean=grand_mean,
848
+ delta_hat=delta_hat,
849
+ cluster_var=cluster_var,
850
+ treatment_groups=treatment_groups,
851
+ balance_e=balance_e,
852
+ kept_cov_mask=kept_cov_mask,
853
+ )
854
+
855
+ if aggregate in ("group", "all"):
856
+ group_effects = self._aggregate_group(
857
+ df=df,
858
+ outcome=outcome,
859
+ unit=unit,
860
+ time=time,
861
+ first_treat=first_treat,
862
+ covariates=covariates,
863
+ omega_0_mask=omega_0_mask,
864
+ omega_1_mask=omega_1_mask,
865
+ unit_fe=unit_fe,
866
+ time_fe=time_fe,
867
+ grand_mean=grand_mean,
868
+ delta_hat=delta_hat,
869
+ cluster_var=cluster_var,
870
+ treatment_groups=treatment_groups,
871
+ kept_cov_mask=kept_cov_mask,
872
+ )
873
+
874
+ # Build treatment effects dataframe
875
+ treated_df = df.loc[omega_1_mask, [unit, time, "_tau_hat", "_rel_time"]].copy()
876
+ treated_df = treated_df.rename(columns={"_tau_hat": "tau_hat", "_rel_time": "rel_time"})
877
+ # Weights consistent with actual ATT: zero for NaN tau_hat, 1/n_valid for finite
878
+ tau_finite = treated_df["tau_hat"].notna()
879
+ n_valid_te = int(tau_finite.sum())
880
+ if n_valid_te > 0:
881
+ treated_df["weight"] = np.where(tau_finite, 1.0 / n_valid_te, 0.0)
882
+ else:
883
+ treated_df["weight"] = 0.0
884
+
885
+ # Store fit data for pretrend_test
886
+ self._fit_data = {
887
+ "df": df,
888
+ "outcome": outcome,
889
+ "unit": unit,
890
+ "time": time,
891
+ "first_treat": first_treat,
892
+ "covariates": covariates,
893
+ "omega_0_mask": omega_0_mask,
894
+ "omega_1_mask": omega_1_mask,
895
+ "cluster_var": cluster_var,
896
+ "unit_fe": unit_fe,
897
+ "time_fe": time_fe,
898
+ "grand_mean": grand_mean,
899
+ "delta_hat": delta_hat,
900
+ "kept_cov_mask": kept_cov_mask,
901
+ }
902
+
903
+ # Pre-compute cluster psi sums for bootstrap
904
+ psi_data = None
905
+ if self.n_bootstrap > 0 and n_valid > 0:
906
+ try:
907
+ psi_data = self._precompute_bootstrap_psi(
908
+ df=df,
909
+ outcome=outcome,
910
+ unit=unit,
911
+ time=time,
912
+ first_treat=first_treat,
913
+ covariates=covariates,
914
+ omega_0_mask=omega_0_mask,
915
+ omega_1_mask=omega_1_mask,
916
+ unit_fe=unit_fe,
917
+ time_fe=time_fe,
918
+ grand_mean=grand_mean,
919
+ delta_hat=delta_hat,
920
+ cluster_var=cluster_var,
921
+ kept_cov_mask=kept_cov_mask,
922
+ overall_weights=overall_weights,
923
+ event_study_effects=event_study_effects,
924
+ group_effects=group_effects,
925
+ treatment_groups=treatment_groups,
926
+ tau_hat=tau_hat,
927
+ balance_e=balance_e,
928
+ )
929
+ except Exception as e:
930
+ warnings.warn(
931
+ f"Bootstrap pre-computation failed: {e}. " "Skipping bootstrap inference.",
932
+ UserWarning,
933
+ stacklevel=2,
934
+ )
935
+ psi_data = None
936
+
937
+ # Bootstrap
938
+ bootstrap_results = None
939
+ if self.n_bootstrap > 0 and psi_data is not None:
940
+ bootstrap_results = self._run_bootstrap(
941
+ original_att=overall_att,
942
+ original_event_study=event_study_effects,
943
+ original_group=group_effects,
944
+ psi_data=psi_data,
945
+ )
946
+
947
+ # Update inference with bootstrap results
948
+ overall_se = bootstrap_results.overall_att_se
949
+ overall_t = (
950
+ overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
951
+ )
952
+ overall_p = bootstrap_results.overall_att_p_value
953
+ overall_ci = bootstrap_results.overall_att_ci
954
+
955
+ # Update event study
956
+ if event_study_effects and bootstrap_results.event_study_ses:
957
+ for h in event_study_effects:
958
+ if (
959
+ h in bootstrap_results.event_study_ses
960
+ and event_study_effects[h].get("n_obs", 1) > 0
961
+ ):
962
+ event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h]
963
+ event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[h]
964
+ event_study_effects[h]["p_value"] = bootstrap_results.event_study_p_values[
965
+ h
966
+ ]
967
+ eff_val = event_study_effects[h]["effect"]
968
+ se_val = event_study_effects[h]["se"]
969
+ event_study_effects[h]["t_stat"] = (
970
+ eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan
971
+ )
972
+
973
+ # Update group effects
974
+ if group_effects and bootstrap_results.group_ses:
975
+ for g in group_effects:
976
+ if g in bootstrap_results.group_ses:
977
+ group_effects[g]["se"] = bootstrap_results.group_ses[g]
978
+ group_effects[g]["conf_int"] = bootstrap_results.group_cis[g]
979
+ group_effects[g]["p_value"] = bootstrap_results.group_p_values[g]
980
+ eff_val = group_effects[g]["effect"]
981
+ se_val = group_effects[g]["se"]
982
+ group_effects[g]["t_stat"] = (
983
+ eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan
984
+ )
985
+
986
+ # Construct results
987
+ self.results_ = ImputationDiDResults(
988
+ treatment_effects=treated_df,
989
+ overall_att=overall_att,
990
+ overall_se=overall_se,
991
+ overall_t_stat=overall_t,
992
+ overall_p_value=overall_p,
993
+ overall_conf_int=overall_ci,
994
+ event_study_effects=event_study_effects,
995
+ group_effects=group_effects,
996
+ groups=treatment_groups,
997
+ time_periods=time_periods,
998
+ n_obs=len(df),
999
+ n_treated_obs=n_omega_1,
1000
+ n_untreated_obs=n_omega_0,
1001
+ n_treated_units=n_treated_units,
1002
+ n_control_units=n_control_units,
1003
+ alpha=self.alpha,
1004
+ bootstrap_results=bootstrap_results,
1005
+ _estimator_ref=self,
1006
+ )
1007
+
1008
+ self.is_fitted_ = True
1009
+ return self.results_
1010
+
1011
+ # =========================================================================
1012
+ # Step 1: OLS on untreated observations
1013
+ # =========================================================================
1014
+
1015
+ def _iterative_fe(
1016
+ self,
1017
+ y: np.ndarray,
1018
+ unit_vals: np.ndarray,
1019
+ time_vals: np.ndarray,
1020
+ idx: pd.Index,
1021
+ max_iter: int = 100,
1022
+ tol: float = 1e-10,
1023
+ ) -> Tuple[Dict[Any, float], Dict[Any, float]]:
1024
+ """
1025
+ Estimate unit and time FE via iterative alternating projection (Gauss-Seidel).
1026
+
1027
+ Converges to the exact OLS solution for both balanced and unbalanced panels.
1028
+ For balanced panels, converges in 1-2 iterations (identical to one-pass).
1029
+ For unbalanced panels, typically 5-20 iterations.
1030
+
1031
+ Returns
1032
+ -------
1033
+ unit_fe : dict
1034
+ Mapping from unit -> unit fixed effect.
1035
+ time_fe : dict
1036
+ Mapping from time -> time fixed effect.
1037
+ """
1038
+ n = len(y)
1039
+ alpha = np.zeros(n) # unit FE broadcast to obs level
1040
+ beta = np.zeros(n) # time FE broadcast to obs level
1041
+
1042
+ with np.errstate(invalid="ignore", divide="ignore"):
1043
+ for iteration in range(max_iter):
1044
+ # Update time FE: beta_t = mean_i(y_it - alpha_i)
1045
+ resid_after_alpha = y - alpha
1046
+ beta_new = (
1047
+ pd.Series(resid_after_alpha, index=idx)
1048
+ .groupby(time_vals)
1049
+ .transform("mean")
1050
+ .values
1051
+ )
1052
+
1053
+ # Update unit FE: alpha_i = mean_t(y_it - beta_t)
1054
+ resid_after_beta = y - beta_new
1055
+ alpha_new = (
1056
+ pd.Series(resid_after_beta, index=idx)
1057
+ .groupby(unit_vals)
1058
+ .transform("mean")
1059
+ .values
1060
+ )
1061
+
1062
+ # Check convergence on FE changes
1063
+ max_change = max(
1064
+ np.max(np.abs(alpha_new - alpha)),
1065
+ np.max(np.abs(beta_new - beta)),
1066
+ )
1067
+ alpha = alpha_new
1068
+ beta = beta_new
1069
+ if max_change < tol:
1070
+ break
1071
+
1072
+ unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
1073
+ time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
1074
+ return unit_fe, time_fe
1075
+
1076
+ @staticmethod
1077
+ def _iterative_demean(
1078
+ vals: np.ndarray,
1079
+ unit_vals: np.ndarray,
1080
+ time_vals: np.ndarray,
1081
+ idx: pd.Index,
1082
+ max_iter: int = 100,
1083
+ tol: float = 1e-10,
1084
+ ) -> np.ndarray:
1085
+ """Demean a vector by iterative alternating projection (unit + time FE removal).
1086
+
1087
+ Converges to the exact within-transformation for both balanced and
1088
+ unbalanced panels. For balanced panels, converges in 1-2 iterations.
1089
+ """
1090
+ result = vals.copy()
1091
+ with np.errstate(invalid="ignore", divide="ignore"):
1092
+ for _ in range(max_iter):
1093
+ time_means = (
1094
+ pd.Series(result, index=idx).groupby(time_vals).transform("mean").values
1095
+ )
1096
+ result_after_time = result - time_means
1097
+ unit_means = (
1098
+ pd.Series(result_after_time, index=idx)
1099
+ .groupby(unit_vals)
1100
+ .transform("mean")
1101
+ .values
1102
+ )
1103
+ result_new = result_after_time - unit_means
1104
+ if np.max(np.abs(result_new - result)) < tol:
1105
+ result = result_new
1106
+ break
1107
+ result = result_new
1108
+ return result
1109
+
1110
+ @staticmethod
1111
+ def _compute_balanced_cohort_mask(
1112
+ df_treated: pd.DataFrame,
1113
+ first_treat: str,
1114
+ all_horizons: List[int],
1115
+ balance_e: int,
1116
+ cohort_rel_times: Dict[Any, Set[int]],
1117
+ ) -> np.ndarray:
1118
+ """Compute boolean mask selecting treated obs from balanced cohorts.
1119
+
1120
+ A cohort is 'balanced' if it has observations at every relative time
1121
+ in [-balance_e, max(all_horizons)].
1122
+
1123
+ Parameters
1124
+ ----------
1125
+ df_treated : pd.DataFrame
1126
+ Post-treatment observations (Omega_1).
1127
+ first_treat : str
1128
+ Column name for cohort identifier.
1129
+ all_horizons : list of int
1130
+ Post-treatment horizons in the event study.
1131
+ balance_e : int
1132
+ Number of pre-treatment periods to require.
1133
+ cohort_rel_times : dict
1134
+ Maps each cohort value to the set of all observed relative times
1135
+ (including pre-treatment) from the full panel. Built by
1136
+ _build_cohort_rel_times().
1137
+ """
1138
+ if not all_horizons:
1139
+ return np.ones(len(df_treated), dtype=bool)
1140
+
1141
+ max_h = max(all_horizons)
1142
+ required_range = set(range(-balance_e, max_h + 1))
1143
+
1144
+ balanced_cohorts = set()
1145
+ for g, horizons in cohort_rel_times.items():
1146
+ if required_range.issubset(horizons):
1147
+ balanced_cohorts.add(g)
1148
+
1149
+ return df_treated[first_treat].isin(balanced_cohorts).values
1150
+
1151
+ @staticmethod
1152
+ def _build_cohort_rel_times(
1153
+ df: pd.DataFrame,
1154
+ first_treat: str,
1155
+ ) -> Dict[Any, Set[int]]:
1156
+ """Build mapping of cohort -> set of observed relative times from full panel.
1157
+
1158
+ Precondition: df must have '_never_treated' and '_rel_time' columns
1159
+ (set by fit() before any aggregation calls).
1160
+ """
1161
+ treated_mask = ~df["_never_treated"]
1162
+ treated_df = df.loc[treated_mask]
1163
+ result: Dict[Any, Set[int]] = {}
1164
+ ft_vals = treated_df[first_treat].values
1165
+ rt_vals = treated_df["_rel_time"].values
1166
+ for i in range(len(treated_df)):
1167
+ h = rt_vals[i]
1168
+ if np.isfinite(h):
1169
+ result.setdefault(ft_vals[i], set()).add(int(h))
1170
+ return result
1171
+
1172
    def _fit_untreated_model(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
    ) -> Tuple[
        Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray]
    ]:
        """
        Step 1: Estimate unit + time FE (and covariate coefficients) on
        untreated observations only.

        Uses iterative alternating projection (Gauss-Seidel) to compute exact
        OLS fixed effects for both balanced and unbalanced panels. For balanced
        panels, converges in 1-2 iterations (identical to one-pass demeaning).
        With covariates, follows a partialling-out scheme: demean Y and X by
        the two-way FE, run OLS of demeaned Y on demeaned X for delta, then
        recover the FE from the covariate-adjusted outcome.

        Returns
        -------
        unit_fe : dict
            Unit fixed effects {unit_id: alpha_i}.
        time_fe : dict
            Time fixed effects {time_period: beta_t}.
        grand_mean : float
            Grand mean (always 0.0 — absorbed into iterative FE).
        delta_hat : np.ndarray or None
            Covariate coefficients (if covariates provided), with NaNs from
            rank deficiency already replaced by 0.
        kept_cov_mask : np.ndarray or None
            Boolean mask of shape (n_covariates,) indicating which covariates
            have finite coefficients. None if no covariates.
        """
        # Restrict to the untreated sample Omega_0 for all Step-1 estimation.
        df_0 = df.loc[omega_0_mask]

        if covariates is None or len(covariates) == 0:
            # No covariates: estimate FE via iterative alternating projection
            # (exact OLS for both balanced and unbalanced panels)
            y = df_0[outcome].values.copy()
            unit_fe, time_fe = self._iterative_fe(
                y, df_0[unit].values, df_0[time].values, df_0.index
            )
            # grand_mean = 0: iterative FE absorb the intercept
            return unit_fe, time_fe, 0.0, None, None

        else:
            # With covariates: iteratively demean Y and X, OLS for delta,
            # then recover FE from covariate-adjusted outcome
            y = df_0[outcome].values.copy()
            X_raw = df_0[covariates].values.copy()
            units = df_0[unit].values
            times = df_0[time].values
            n_cov = len(covariates)

            # Step A: Iteratively demean Y and all X columns to remove unit+time FE
            y_dm = self._iterative_demean(y, units, times, df_0.index)
            X_dm = np.column_stack(
                [
                    self._iterative_demean(X_raw[:, j], units, times, df_0.index)
                    for j in range(n_cov)
                ]
            )

            # Step B: OLS for covariate coefficients on demeaned data.
            # solve_ols handles rank deficiency per self.rank_deficient_action;
            # dropped columns come back as NaN coefficients.
            result = solve_ols(
                X_dm,
                y_dm,
                return_vcov=False,
                rank_deficient_action=self.rank_deficient_action,
                column_names=covariates,
            )
            delta_hat = result[0]

            # Mask of covariates with finite coefficients (before cleaning)
            # Used to exclude rank-deficient covariates from variance design matrices
            kept_cov_mask = np.isfinite(delta_hat)

            # Replace NaN coefficients with 0 for adjustment
            # (rank-deficient covariates are dropped)
            delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0)

            # Step C: Recover FE from covariate-adjusted outcome using iterative FE
            y_adj = y - X_raw @ delta_hat_clean
            unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index)

            # grand_mean = 0: iterative FE absorb the intercept
            return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask
1258
+
1259
+ # =========================================================================
1260
+ # Step 2: Impute counterfactuals
1261
+ # =========================================================================
1262
+
1263
+ def _impute_treatment_effects(
1264
+ self,
1265
+ df: pd.DataFrame,
1266
+ outcome: str,
1267
+ unit: str,
1268
+ time: str,
1269
+ covariates: Optional[List[str]],
1270
+ omega_1_mask: pd.Series,
1271
+ unit_fe: Dict[Any, float],
1272
+ time_fe: Dict[Any, float],
1273
+ grand_mean: float,
1274
+ delta_hat: Optional[np.ndarray],
1275
+ ) -> Tuple[np.ndarray, np.ndarray]:
1276
+ """
1277
+ Step 2: Impute Y(0) for treated observations and compute tau_hat.
1278
+
1279
+ Returns
1280
+ -------
1281
+ tau_hat : np.ndarray
1282
+ Imputed treatment effects for each treated observation.
1283
+ y_hat_0 : np.ndarray
1284
+ Imputed counterfactual Y(0).
1285
+ """
1286
+ df_1 = df.loc[omega_1_mask]
1287
+ n_1 = len(df_1)
1288
+
1289
+ # Look up unit and time FE
1290
+ alpha_i = df_1[unit].map(unit_fe).values
1291
+ beta_t = df_1[time].map(time_fe).values
1292
+
1293
+ # Handle missing FE (set to NaN)
1294
+ alpha_i = np.where(pd.isna(alpha_i), np.nan, alpha_i).astype(float)
1295
+ beta_t = np.where(pd.isna(beta_t), np.nan, beta_t).astype(float)
1296
+
1297
+ y_hat_0 = grand_mean + alpha_i + beta_t
1298
+
1299
+ if delta_hat is not None and covariates:
1300
+ X_1 = df_1[covariates].values
1301
+ y_hat_0 = y_hat_0 + X_1 @ delta_hat
1302
+
1303
+ tau_hat = df_1[outcome].values - y_hat_0
1304
+
1305
+ return tau_hat, y_hat_0
1306
+
1307
+ # =========================================================================
1308
+ # Conservative Variance (Theorem 3)
1309
+ # =========================================================================
1310
+
1311
    def _compute_cluster_psi_sums(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
        omega_1_mask: pd.Series,
        unit_fe: Dict[Any, float],
        time_fe: Dict[Any, float],
        grand_mean: float,
        delta_hat: Optional[np.ndarray],
        weights: np.ndarray,
        cluster_var: str,
        kept_cov_mask: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute cluster-level influence function sums (Theorem 3).

        psi_i = sum_t v_it * epsilon_tilde_it, summed within each cluster.

        Treated obs get v_it = w_it (the aggregation weights); untreated obs
        get the weight the FE regression implicitly places on them when
        imputing the weighted counterfactual (closed form without covariates,
        sparse projection with covariates).

        Returns
        -------
        cluster_psi_sums : np.ndarray
            Array of cluster-level psi sums.
        cluster_ids_unique : np.ndarray
            Unique cluster identifiers (matching order of psi sums).
        """
        df_0 = df.loc[omega_0_mask]
        df_1 = df.loc[omega_1_mask]
        n_0 = len(df_0)
        n_1 = len(df_1)

        # ---- Compute v_it for treated observations ----
        # For treated obs, v_it is just the researcher-chosen weight w_it.
        v_treated = weights.copy()

        # ---- Compute v_it for untreated observations ----
        if covariates is None or len(covariates) == 0:
            # FE-only case: closed-form expression built from the total
            # treated weight falling on each unit and each time period.
            treated_units = df_1[unit].values
            treated_times = df_1[time].values

            # Total treated weight per unit.
            w_by_unit: Dict[Any, float] = {}
            for i_idx in range(n_1):
                u = treated_units[i_idx]
                w_by_unit[u] = w_by_unit.get(u, 0.0) + weights[i_idx]

            # Total treated weight per time period.
            w_by_time: Dict[Any, float] = {}
            for i_idx in range(n_1):
                t = treated_times[i_idx]
                w_by_time[t] = w_by_time.get(t, 0.0) + weights[i_idx]

            w_total = float(np.sum(weights))

            # Untreated cell counts per unit / per time period.
            n0_by_unit = df_0.groupby(unit).size().to_dict()
            n0_by_time = df_0.groupby(time).size().to_dict()

            untreated_units = df_0[unit].values
            untreated_times = df_0[time].values
            v_untreated = np.zeros(n_0)

            # Each untreated obs contributes (negatively) through the unit
            # mean, the time mean, and the grand mean of the FE regression.
            for j in range(n_0):
                u = untreated_units[j]
                t = untreated_times[j]
                w_i = w_by_unit.get(u, 0.0)
                w_t = w_by_time.get(t, 0.0)
                n0_i = n0_by_unit.get(u, 1)
                n0_t = n0_by_time.get(t, 1)
                v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n_0)
        else:
            # Covariate case: no closed form; project via sparse design matrices.
            v_untreated = self._compute_v_untreated_with_covariates(
                df_0,
                df_1,
                unit,
                time,
                covariates,
                weights,
                delta_hat,
                kept_cov_mask=kept_cov_mask,
            )

        # ---- Compute auxiliary model residuals (Equation 8) ----
        epsilon_treated = self._compute_auxiliary_residuals_treated(
            df_1,
            outcome,
            unit,
            time,
            first_treat,
            covariates,
            unit_fe,
            time_fe,
            grand_mean,
            delta_hat,
            v_treated,
        )
        epsilon_untreated = self._compute_residuals_untreated(
            df_0, outcome, unit, time, covariates, unit_fe, time_fe, grand_mean, delta_hat
        )

        # ---- psi_it = v_it * epsilon_tilde_it ----
        # NOTE(review): np.empty is only fully initialized if omega_0_mask and
        # omega_1_mask together cover every row of df — confirm fit()
        # guarantees the two masks partition the sample.
        v_all = np.empty(len(df))
        v_all[omega_1_mask.values] = v_treated
        v_all[omega_0_mask.values] = v_untreated

        eps_all = np.empty(len(df))
        eps_all[omega_1_mask.values] = epsilon_treated
        eps_all[omega_0_mask.values] = epsilon_untreated

        ve_product = v_all * eps_all
        # NaN eps from missing FE (rank condition violation). Zero their variance
        # contribution — matches R's did_imputation which drops unimputable obs.
        np.nan_to_num(ve_product, copy=False, nan=0.0)

        # Sum within clusters
        cluster_ids = df[cluster_var].values
        ve_series = pd.Series(ve_product, index=df.index)
        cluster_sums = ve_series.groupby(cluster_ids).sum()

        return cluster_sums.values, cluster_sums.index.values
1432
+
1433
+ def _compute_conservative_variance(
1434
+ self,
1435
+ df: pd.DataFrame,
1436
+ outcome: str,
1437
+ unit: str,
1438
+ time: str,
1439
+ first_treat: str,
1440
+ covariates: Optional[List[str]],
1441
+ omega_0_mask: pd.Series,
1442
+ omega_1_mask: pd.Series,
1443
+ unit_fe: Dict[Any, float],
1444
+ time_fe: Dict[Any, float],
1445
+ grand_mean: float,
1446
+ delta_hat: Optional[np.ndarray],
1447
+ weights: np.ndarray,
1448
+ cluster_var: str,
1449
+ kept_cov_mask: Optional[np.ndarray] = None,
1450
+ ) -> float:
1451
+ """
1452
+ Compute conservative clustered variance (Theorem 3, Equation 7).
1453
+
1454
+ Parameters
1455
+ ----------
1456
+ weights : np.ndarray
1457
+ Aggregation weights w_it for treated observations.
1458
+ Shape: (n_treated,), must sum to 1.
1459
+
1460
+ Returns
1461
+ -------
1462
+ float
1463
+ Standard error.
1464
+ """
1465
+ cluster_psi_sums, _ = self._compute_cluster_psi_sums(
1466
+ df=df,
1467
+ outcome=outcome,
1468
+ unit=unit,
1469
+ time=time,
1470
+ first_treat=first_treat,
1471
+ covariates=covariates,
1472
+ omega_0_mask=omega_0_mask,
1473
+ omega_1_mask=omega_1_mask,
1474
+ unit_fe=unit_fe,
1475
+ time_fe=time_fe,
1476
+ grand_mean=grand_mean,
1477
+ delta_hat=delta_hat,
1478
+ weights=weights,
1479
+ cluster_var=cluster_var,
1480
+ kept_cov_mask=kept_cov_mask,
1481
+ )
1482
+ sigma_sq = float((cluster_psi_sums**2).sum())
1483
+ return np.sqrt(max(sigma_sq, 0.0))
1484
+
1485
    def _compute_v_untreated_with_covariates(
        self,
        df_0: pd.DataFrame,
        df_1: pd.DataFrame,
        unit: str,
        time: str,
        covariates: List[str],
        weights: np.ndarray,
        delta_hat: Optional[np.ndarray],
        kept_cov_mask: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Compute v_it for untreated observations with covariates.

        Uses the projection: v_untreated = -A_0 (A_0'A_0)^{-1} A_1' w_treated
        where A_0 / A_1 are the untreated / treated design matrices
        (unit dummies, time dummies, covariates; first dummy of each set
        dropped for identification).

        Uses scipy.sparse for FE dummy columns to reduce memory from O(N*(U+T))
        to O(N) for the FE portion.
        """
        # Exclude rank-deficient covariates from design matrices so the
        # normal equations below match the coefficients actually estimated.
        if kept_cov_mask is not None and not np.all(kept_cov_mask):
            covariates = [c for c, k in zip(covariates, kept_cov_mask) if k]

        units_0 = df_0[unit].values
        times_0 = df_0[time].values
        units_1 = df_1[unit].values
        times_1 = df_1[time].values

        # Shared column layout for A_0 and A_1: indices over the union of
        # units/times seen in either sample.
        all_units = np.unique(np.concatenate([units_0, units_1]))
        all_times = np.unique(np.concatenate([times_0, times_1]))
        unit_to_idx = {u: i for i, u in enumerate(all_units)}
        time_to_idx = {t: i for i, t in enumerate(all_times)}
        n_units = len(all_units)
        n_times = len(all_times)
        n_cov = len(covariates)
        n_fe_cols = (n_units - 1) + (n_times - 1)

        def _build_A_sparse(df_sub, unit_vals, time_vals):
            # Assemble [unit dummies | time dummies | covariates] as CSR.
            n = len(df_sub)

            # Unit dummies (drop first) — vectorized
            u_indices = np.array([unit_to_idx[u] for u in unit_vals])
            u_mask = u_indices > 0  # skip first unit (dropped)
            u_rows = np.arange(n)[u_mask]
            u_cols = u_indices[u_mask] - 1

            # Time dummies (drop first) — vectorized
            t_indices = np.array([time_to_idx[t] for t in time_vals])
            t_mask = t_indices > 0
            t_rows = np.arange(n)[t_mask]
            t_cols = (n_units - 1) + t_indices[t_mask] - 1

            rows = np.concatenate([u_rows, t_rows])
            cols = np.concatenate([u_cols, t_cols])
            data = np.ones(len(rows))

            A_fe = sparse.csr_matrix((data, (rows, cols)), shape=(n, n_fe_cols))

            # Covariates (dense, typically few columns)
            if n_cov > 0:
                A_cov = sparse.csr_matrix(df_sub[covariates].values)
                A = sparse.hstack([A_fe, A_cov], format="csr")
            else:
                A = A_fe

            return A

        A_0 = _build_A_sparse(df_0, units_0, times_0)
        A_1 = _build_A_sparse(df_1, units_1, times_1)

        # Compute A_1' w (sparse.T @ dense -> dense)
        A1_w = A_1.T @ weights  # shape (p,)

        # Solve (A_0'A_0) z = A_1' w using sparse direct solver
        A0tA0_sparse = A_0.T @ A_0  # stays sparse
        try:
            z = spsolve(A0tA0_sparse.tocsc(), A1_w)
        except Exception:
            # Fallback to dense lstsq if sparse solver fails (e.g., singular matrix)
            A0tA0_dense = A0tA0_sparse.toarray()
            z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None)

        # v_untreated = -A_0 z (sparse @ dense -> dense)
        v_untreated = -(A_0 @ z)
        return v_untreated
1570
+
1571
    def _compute_auxiliary_residuals_treated(
        self,
        df_1: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]],
        unit_fe: Dict[Any, float],
        time_fe: Dict[Any, float],
        grand_mean: float,
        delta_hat: Optional[np.ndarray],
        v_treated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute v_it-weighted auxiliary residuals for treated obs (Equation 8).

        Computes v_it-weighted tau_tilde_g per Equation 8 of Borusyak et al. (2024):
            tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g.

        epsilon_tilde_it = Y_it - alpha_i - beta_t [- X'delta] - tau_tilde_g

        The partition of Omega_1 into groups g is controlled by
        self.aux_partition ('cohort_horizon', 'cohort', 'horizon', or
        anything else => each observation is its own group).
        """
        n_1 = len(df_1)

        # Compute base residuals (Y - Y_hat(0) = tau_hat)
        # NaN for missing FE (consistent with _impute_treatment_effects)
        alpha_i = df_1[unit].map(unit_fe).values.astype(float)  # NaN for missing
        beta_t = df_1[time].map(time_fe).values.astype(float)  # NaN for missing
        y_hat_0 = grand_mean + alpha_i + beta_t

        if delta_hat is not None and covariates:
            y_hat_0 = y_hat_0 + df_1[covariates].values @ delta_hat

        tau_hat = df_1[outcome].values - y_hat_0

        # Partition Omega_1 and compute tau_tilde for each group
        if self.aux_partition == "cohort_horizon":
            group_keys = list(zip(df_1[first_treat].values, df_1["_rel_time"].values))
        elif self.aux_partition == "cohort":
            group_keys = list(df_1[first_treat].values)
        elif self.aux_partition == "horizon":
            group_keys = list(df_1["_rel_time"].values)
        else:
            group_keys = list(range(n_1))  # each obs is its own group

        # Compute v_it-weighted average tau within each partition group (Equation 8)
        # tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g
        group_series = pd.Series(group_keys, index=df_1.index)
        tau_series = pd.Series(tau_hat, index=df_1.index)
        v_series = pd.Series(v_treated, index=df_1.index)

        weighted_tau_sum = (v_series * tau_series).groupby(group_series).sum()
        weight_sum = v_series.groupby(group_series).sum()

        # Guard: zero-weight groups -> their tau_tilde doesn't affect variance
        # (v_it ~ 0 means these obs contribute nothing to the estimand)
        # Use simple mean as fallback. This is common for event-study SE computation
        # where weights target a specific horizon, making other partition groups zero.
        zero_weight_groups = weight_sum.abs() < 1e-15
        if zero_weight_groups.any():
            simple_means = tau_series.groupby(group_series).mean()
            tau_tilde_map = weighted_tau_sum / weight_sum
            # Overwrite the 0/0 (or near-zero) entries with the unweighted mean.
            tau_tilde_map = tau_tilde_map.where(~zero_weight_groups, simple_means)
        else:
            tau_tilde_map = weighted_tau_sum / weight_sum

        # Broadcast each group's tau_tilde back to observation level.
        tau_tilde = group_series.map(tau_tilde_map).values

        # Auxiliary residuals
        epsilon_treated = tau_hat - tau_tilde

        return epsilon_treated
1643
+
1644
+ def _compute_residuals_untreated(
1645
+ self,
1646
+ df_0: pd.DataFrame,
1647
+ outcome: str,
1648
+ unit: str,
1649
+ time: str,
1650
+ covariates: Optional[List[str]],
1651
+ unit_fe: Dict[Any, float],
1652
+ time_fe: Dict[Any, float],
1653
+ grand_mean: float,
1654
+ delta_hat: Optional[np.ndarray],
1655
+ ) -> np.ndarray:
1656
+ """Compute Step 1 residuals for untreated observations."""
1657
+ alpha_i = df_0[unit].map(unit_fe).fillna(0.0).values
1658
+ beta_t = df_0[time].map(time_fe).fillna(0.0).values
1659
+ y_hat = grand_mean + alpha_i + beta_t
1660
+
1661
+ if delta_hat is not None and covariates:
1662
+ y_hat = y_hat + df_0[covariates].values @ delta_hat
1663
+
1664
+ return df_0[outcome].values - y_hat
1665
+
1666
+ # =========================================================================
1667
+ # Aggregation
1668
+ # =========================================================================
1669
+
1670
    def _aggregate_event_study(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
        omega_1_mask: pd.Series,
        unit_fe: Dict[Any, float],
        time_fe: Dict[Any, float],
        grand_mean: float,
        delta_hat: Optional[np.ndarray],
        cluster_var: str,
        treatment_groups: List[Any],
        balance_e: Optional[int] = None,
        kept_cov_mask: Optional[np.ndarray] = None,
    ) -> Dict[int, Dict[str, Any]]:
        """Aggregate treatment effects by event-study horizon.

        For each horizon h, the effect is the simple average of the finite
        imputed tau_hat at that horizon (optionally restricted to balanced
        cohorts via balance_e, and to |h| <= self.horizon_max). SEs come
        from the conservative variance estimator with weights 1/n_valid on
        the valid observations at that horizon. Horizons not identified
        without never-treated units (Proposition 5) are set to NaN.

        Returns
        -------
        dict
            {horizon: {"effect", "se", "t_stat", "p_value", "conf_int",
            "n_obs"}}, including a zero-effect marker at the reference
            period -1 - self.anticipation.
        """
        df_1 = df.loc[omega_1_mask]
        tau_hat = df["_tau_hat"].loc[omega_1_mask].values
        rel_times = df_1["_rel_time"].values

        # Get all horizons
        all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h)))

        # Apply horizon_max filter
        if self.horizon_max is not None:
            all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]

        # Apply balance_e filter: keep only cohorts observed over the full
        # [-balance_e, max horizon] window.
        if balance_e is not None:
            cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
            balanced_mask = pd.Series(
                self._compute_balanced_cohort_mask(
                    df_1, first_treat, all_horizons, balance_e, cohort_rel_times
                ),
                index=df_1.index,
            )
        else:
            balanced_mask = pd.Series(True, index=df_1.index)

        # Check Proposition 5: no never-treated units
        has_never_treated = df["_never_treated"].any()
        h_bar = np.inf
        if not has_never_treated and len(treatment_groups) > 1:
            # Horizons at or beyond the cohort span are not identified.
            h_bar = max(treatment_groups) - min(treatment_groups)

        # Reference period
        ref_period = -1 - self.anticipation

        event_study_effects: Dict[int, Dict[str, Any]] = {}

        # Add reference period marker
        event_study_effects[ref_period] = {
            "effect": 0.0,
            "se": 0.0,
            "t_stat": np.nan,
            "p_value": np.nan,
            "conf_int": (0.0, 0.0),
            "n_obs": 0,
        }

        # Collect horizons with Proposition 5 violations
        prop5_horizons = []

        for h in all_horizons:
            if h == ref_period:
                continue

            # Select treated obs at this horizon from balanced cohorts
            h_mask = (rel_times == h) & balanced_mask.values
            n_h = int(h_mask.sum())

            if n_h == 0:
                continue

            # Proposition 5 check
            if not has_never_treated and h >= h_bar:
                prop5_horizons.append(h)
                event_study_effects[h] = {
                    "effect": np.nan,
                    "se": np.nan,
                    "t_stat": np.nan,
                    "p_value": np.nan,
                    "conf_int": (np.nan, np.nan),
                    "n_obs": n_h,
                }
                continue

            tau_h = tau_hat[h_mask]
            valid_tau = tau_h[np.isfinite(tau_h)]

            if len(valid_tau) == 0:
                # Observations exist at h but none were imputable.
                event_study_effects[h] = {
                    "effect": np.nan,
                    "se": np.nan,
                    "t_stat": np.nan,
                    "p_value": np.nan,
                    "conf_int": (np.nan, np.nan),
                    "n_obs": n_h,
                }
                continue

            effect = float(np.mean(valid_tau))

            # Compute SE via conservative variance with horizon-specific weights.
            # weights_h spans all of Omega_1; nonzero (1/n_valid) only at this
            # horizon's finite-tau observations, so the weights sum to 1.
            weights_h = np.zeros(int(omega_1_mask.sum()))
            # Map h_mask (relative to df_1) to weights array
            h_indices_in_omega1 = np.where(h_mask)[0]
            n_valid = len(valid_tau)
            # Only weight valid (finite) observations
            finite_mask = np.isfinite(tau_hat[h_mask])
            valid_h_indices = h_indices_in_omega1[finite_mask]
            for idx in valid_h_indices:
                weights_h[idx] = 1.0 / n_valid

            se = self._compute_conservative_variance(
                df=df,
                outcome=outcome,
                unit=unit,
                time=time,
                first_treat=first_treat,
                covariates=covariates,
                omega_0_mask=omega_0_mask,
                omega_1_mask=omega_1_mask,
                unit_fe=unit_fe,
                time_fe=time_fe,
                grand_mean=grand_mean,
                delta_hat=delta_hat,
                weights=weights_h,
                cluster_var=cluster_var,
                kept_cov_mask=kept_cov_mask,
            )

            t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan
            p_value = compute_p_value(t_stat)
            conf_int = (
                compute_confidence_interval(effect, se, self.alpha)
                if np.isfinite(se) and se > 0
                else (np.nan, np.nan)
            )

            event_study_effects[h] = {
                "effect": effect,
                "se": se,
                "t_stat": t_stat,
                "p_value": p_value,
                "conf_int": conf_int,
                "n_obs": n_h,
            }

        # Proposition 5 warning
        if prop5_horizons:
            warnings.warn(
                f"Horizons {prop5_horizons} are not identified without "
                f"never-treated units (Proposition 5). Set to NaN.",
                UserWarning,
                stacklevel=3,
            )

        # Check for empty result set after filtering
        real_effects = [
            h for h, v in event_study_effects.items() if h != ref_period and v.get("n_obs", 0) > 0
        ]
        if len(real_effects) == 0:
            filter_info = []
            if balance_e is not None:
                filter_info.append(f"balance_e={balance_e}")
            if self.horizon_max is not None:
                filter_info.append(f"horizon_max={self.horizon_max}")
            filter_str = " and ".join(filter_info) if filter_info else "filters"
            warnings.warn(
                f"Event study aggregation produced no horizons with observations "
                f"after applying {filter_str}. The result contains only the "
                f"reference period marker. Consider relaxing filter parameters.",
                UserWarning,
                stacklevel=3,
            )

        return event_study_effects
1852
+
1853
+ def _aggregate_group(
1854
+ self,
1855
+ df: pd.DataFrame,
1856
+ outcome: str,
1857
+ unit: str,
1858
+ time: str,
1859
+ first_treat: str,
1860
+ covariates: Optional[List[str]],
1861
+ omega_0_mask: pd.Series,
1862
+ omega_1_mask: pd.Series,
1863
+ unit_fe: Dict[Any, float],
1864
+ time_fe: Dict[Any, float],
1865
+ grand_mean: float,
1866
+ delta_hat: Optional[np.ndarray],
1867
+ cluster_var: str,
1868
+ treatment_groups: List[Any],
1869
+ kept_cov_mask: Optional[np.ndarray] = None,
1870
+ ) -> Dict[Any, Dict[str, Any]]:
1871
+ """Aggregate treatment effects by cohort."""
1872
+ df_1 = df.loc[omega_1_mask]
1873
+ tau_hat = df["_tau_hat"].loc[omega_1_mask].values
1874
+ cohorts = df_1[first_treat].values
1875
+
1876
+ group_effects: Dict[Any, Dict[str, Any]] = {}
1877
+
1878
+ for g in treatment_groups:
1879
+ g_mask = cohorts == g
1880
+ n_g = int(g_mask.sum())
1881
+
1882
+ if n_g == 0:
1883
+ continue
1884
+
1885
+ tau_g = tau_hat[g_mask]
1886
+ valid_tau = tau_g[np.isfinite(tau_g)]
1887
+
1888
+ if len(valid_tau) == 0:
1889
+ group_effects[g] = {
1890
+ "effect": np.nan,
1891
+ "se": np.nan,
1892
+ "t_stat": np.nan,
1893
+ "p_value": np.nan,
1894
+ "conf_int": (np.nan, np.nan),
1895
+ "n_obs": n_g,
1896
+ }
1897
+ continue
1898
+
1899
+ effect = float(np.mean(valid_tau))
1900
+
1901
+ # Compute SE with group-specific weights
1902
+ weights_g = np.zeros(int(omega_1_mask.sum()))
1903
+ finite_mask = np.isfinite(tau_hat) & g_mask
1904
+ g_indices = np.where(finite_mask)[0]
1905
+ n_valid = len(valid_tau)
1906
+ for idx in g_indices:
1907
+ weights_g[idx] = 1.0 / n_valid
1908
+
1909
+ se = self._compute_conservative_variance(
1910
+ df=df,
1911
+ outcome=outcome,
1912
+ unit=unit,
1913
+ time=time,
1914
+ first_treat=first_treat,
1915
+ covariates=covariates,
1916
+ omega_0_mask=omega_0_mask,
1917
+ omega_1_mask=omega_1_mask,
1918
+ unit_fe=unit_fe,
1919
+ time_fe=time_fe,
1920
+ grand_mean=grand_mean,
1921
+ delta_hat=delta_hat,
1922
+ weights=weights_g,
1923
+ cluster_var=cluster_var,
1924
+ kept_cov_mask=kept_cov_mask,
1925
+ )
1926
+
1927
+ t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan
1928
+ p_value = compute_p_value(t_stat)
1929
+ conf_int = (
1930
+ compute_confidence_interval(effect, se, self.alpha)
1931
+ if np.isfinite(se) and se > 0
1932
+ else (np.nan, np.nan)
1933
+ )
1934
+
1935
+ group_effects[g] = {
1936
+ "effect": effect,
1937
+ "se": se,
1938
+ "t_stat": t_stat,
1939
+ "p_value": p_value,
1940
+ "conf_int": conf_int,
1941
+ "n_obs": n_g,
1942
+ }
1943
+
1944
+ return group_effects
1945
+
1946
+ # =========================================================================
1947
+ # Pre-trend test (Equation 9)
1948
+ # =========================================================================
1949
+
1950
    def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
        """
        Run pre-trend test (Equation 9).

        Adds pre-treatment lead indicators to the Step 1 OLS on Omega_0
        and tests their joint significance via cluster-robust Wald F-test.

        Parameters
        ----------
        n_leads : int, optional
            If given, use only the ``n_leads`` pre-treatment periods
            closest to treatment; otherwise all available leads.

        Returns
        -------
        dict
            Keys ``f_stat``, ``p_value``, ``df``, ``n_leads``,
            ``lead_coefficients`` (mapping relative time -> coefficient).
            All-NaN/empty result when no testable leads exist.

        Raises
        ------
        RuntimeError
            If called before ``fit()``.
        """
        if self._fit_data is None:
            raise RuntimeError("Must call fit() before pretrend_test().")

        # Unpack the cached fit state produced by fit().
        fd = self._fit_data
        df = fd["df"]
        outcome = fd["outcome"]
        unit = fd["unit"]
        time = fd["time"]
        first_treat = fd["first_treat"]
        covariates = fd["covariates"]
        omega_0_mask = fd["omega_0_mask"]
        cluster_var = fd["cluster_var"]

        df_0 = df.loc[omega_0_mask].copy()

        # Compute relative time for untreated obs
        # For not-yet-treated units in their pre-treatment periods;
        # never-treated units get NaN (they have no event time).
        rel_time_0 = np.where(
            ~df_0["_never_treated"],
            df_0[time] - df_0[first_treat],
            np.nan,
        )

        # Get available pre-treatment relative times (negative values)
        # strictly before the anticipation window.
        pre_rel_times = sorted(
            set(int(h) for h in rel_time_0 if np.isfinite(h) and h < -self.anticipation)
        )

        if len(pre_rel_times) == 0:
            return {
                "f_stat": np.nan,
                "p_value": np.nan,
                "df": 0,
                "n_leads": 0,
                "lead_coefficients": {},
            }

        # Exclude the reference period (last pre-treatment period)
        ref = -1 - self.anticipation
        pre_rel_times = [h for h in pre_rel_times if h != ref]

        if n_leads is not None:
            # Take the n_leads periods closest to treatment
            pre_rel_times = sorted(pre_rel_times, reverse=True)[:n_leads]
            pre_rel_times = sorted(pre_rel_times)

        # Filtering above may have removed every candidate lead.
        if len(pre_rel_times) == 0:
            return {
                "f_stat": np.nan,
                "p_value": np.nan,
                "df": 0,
                "n_leads": 0,
                "lead_coefficients": {},
            }

        # Build lead indicators (one 0/1 column per pre-treatment horizon).
        lead_cols = []
        for h in pre_rel_times:
            col_name = f"_lead_{h}"
            df_0[col_name] = ((rel_time_0 == h)).astype(float)
            lead_cols.append(col_name)

        # Within-transform via iterative demeaning (exact for unbalanced panels)
        y_dm = self._iterative_demean(
            df_0[outcome].values, df_0[unit].values, df_0[time].values, df_0.index
        )

        # Leads first, covariates appended after — coefficient ordering below
        # relies on the leads occupying the first n_leads_actual slots.
        all_x_cols = lead_cols[:]
        if covariates:
            all_x_cols.extend(covariates)

        X_dm = np.column_stack(
            [
                self._iterative_demean(
                    df_0[col].values, df_0[unit].values, df_0[time].values, df_0.index
                )
                for col in all_x_cols
            ]
        )

        # OLS with cluster-robust SEs
        cluster_ids = df_0[cluster_var].values
        result = solve_ols(
            X_dm,
            y_dm,
            cluster_ids=cluster_ids,
            return_vcov=True,
            rank_deficient_action=self.rank_deficient_action,
            column_names=all_x_cols,
        )
        coefficients = result[0]
        vcov = result[2]

        # Extract lead coefficients and their sub-VCV
        n_leads_actual = len(lead_cols)
        gamma = coefficients[:n_leads_actual]
        V_gamma = vcov[:n_leads_actual, :n_leads_actual]

        # Wald F-test: F = (gamma' V^{-1} gamma) / n_leads
        try:
            V_inv_gamma = np.linalg.solve(V_gamma, gamma)
            wald_stat = float(gamma @ V_inv_gamma)
            f_stat = wald_stat / n_leads_actual
        except np.linalg.LinAlgError:
            # Singular sub-VCV: test is undefined; report NaN.
            f_stat = np.nan

        # P-value from F distribution with (n_leads, n_clusters - 1) dof.
        if np.isfinite(f_stat) and f_stat >= 0:
            n_clusters = len(np.unique(cluster_ids))
            df_denom = max(n_clusters - 1, 1)
            p_value = float(stats.f.sf(f_stat, n_leads_actual, df_denom))
        else:
            p_value = np.nan

        # Store lead coefficients keyed by relative time.
        lead_coefficients = {}
        for j, h in enumerate(pre_rel_times):
            lead_coefficients[h] = float(gamma[j])

        return {
            "f_stat": f_stat,
            "p_value": p_value,
            "df": n_leads_actual,
            "n_leads": n_leads_actual,
            "lead_coefficients": lead_coefficients,
        }
2083
+
2084
+ # =========================================================================
2085
+ # Bootstrap
2086
+ # =========================================================================
2087
+
2088
+ def _compute_percentile_ci(
2089
+ self,
2090
+ boot_dist: np.ndarray,
2091
+ alpha: float,
2092
+ ) -> Tuple[float, float]:
2093
+ """Compute percentile confidence interval from bootstrap distribution."""
2094
+ lower = float(np.percentile(boot_dist, alpha / 2 * 100))
2095
+ upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
2096
+ return (lower, upper)
2097
+
2098
+ def _compute_bootstrap_pvalue(
2099
+ self,
2100
+ original_effect: float,
2101
+ boot_dist: np.ndarray,
2102
+ n_valid: Optional[int] = None,
2103
+ ) -> float:
2104
+ """
2105
+ Compute two-sided bootstrap p-value.
2106
+
2107
+ Uses the percentile method: p-value is the proportion of bootstrap
2108
+ estimates on the opposite side of zero from the original estimate,
2109
+ doubled for two-sided test.
2110
+
2111
+ Parameters
2112
+ ----------
2113
+ original_effect : float
2114
+ Original point estimate.
2115
+ boot_dist : np.ndarray
2116
+ Bootstrap distribution of the effect.
2117
+ n_valid : int, optional
2118
+ Number of valid bootstrap samples. If None, uses self.n_bootstrap.
2119
+ """
2120
+ if original_effect >= 0:
2121
+ p_one_sided = float(np.mean(boot_dist <= 0))
2122
+ else:
2123
+ p_one_sided = float(np.mean(boot_dist >= 0))
2124
+ p_value = min(2 * p_one_sided, 1.0)
2125
+ n_for_floor = n_valid if n_valid is not None else self.n_bootstrap
2126
+ p_value = max(p_value, 1 / (n_for_floor + 1))
2127
+ return p_value
2128
+
2129
    def _precompute_bootstrap_psi(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
        omega_1_mask: pd.Series,
        unit_fe: Dict[Any, float],
        time_fe: Dict[Any, float],
        grand_mean: float,
        delta_hat: Optional[np.ndarray],
        cluster_var: str,
        kept_cov_mask: Optional[np.ndarray],
        overall_weights: np.ndarray,
        event_study_effects: Optional[Dict[int, Dict[str, Any]]],
        group_effects: Optional[Dict[Any, Dict[str, Any]]],
        treatment_groups: List[Any],
        tau_hat: np.ndarray,
        balance_e: Optional[int],
    ) -> Dict[str, Any]:
        """
        Pre-compute cluster-level influence function sums for each bootstrap target.

        For each aggregation target (overall, per-horizon, per-group), computes
        psi_i = sum_t v_it * epsilon_tilde_it for each cluster. The multiplier
        bootstrap then perturbs these psi sums with Rademacher weights.

        Computational cost scales with the number of aggregation targets, since
        each target requires its own v_untreated computation (weight-dependent).

        Returns
        -------
        dict
            "overall" -> (psi array, cluster ids); optionally
            "event_study" -> {horizon: psi array} and
            "group" -> {cohort: psi array} for targets with finite effects.
        """
        result: Dict[str, Any] = {}

        # Arguments shared by every _compute_cluster_psi_sums call;
        # only `weights` varies per target.
        common = dict(
            df=df,
            outcome=outcome,
            unit=unit,
            time=time,
            first_treat=first_treat,
            covariates=covariates,
            omega_0_mask=omega_0_mask,
            omega_1_mask=omega_1_mask,
            unit_fe=unit_fe,
            time_fe=time_fe,
            grand_mean=grand_mean,
            delta_hat=delta_hat,
            cluster_var=cluster_var,
            kept_cov_mask=kept_cov_mask,
        )

        # Overall ATT
        overall_psi, cluster_ids = self._compute_cluster_psi_sums(**common, weights=overall_weights)
        result["overall"] = (overall_psi, cluster_ids)

        # Event study: per-horizon weights
        # NOTE: weight logic duplicated from _aggregate_event_study.
        # If weight scheme changes there, update here too.
        if event_study_effects:
            result["event_study"] = {}
            df_1 = df.loc[omega_1_mask]
            rel_times = df_1["_rel_time"].values
            n_omega_1 = int(omega_1_mask.sum())

            # Balanced cohort mask (same logic as _aggregate_event_study)
            balanced_mask = None
            if balance_e is not None:
                all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h)))
                if self.horizon_max is not None:
                    all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]
                cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
                balanced_mask = self._compute_balanced_cohort_mask(
                    df_1, first_treat, all_horizons, balance_e, cohort_rel_times
                )

            # Skip the reference period and any horizon without a usable
            # (finite, populated) point estimate.
            ref_period = -1 - self.anticipation
            for h in event_study_effects:
                if event_study_effects[h].get("n_obs", 0) == 0:
                    continue
                if h == ref_period:
                    continue
                if not np.isfinite(event_study_effects[h].get("effect", np.nan)):
                    continue
                h_mask = rel_times == h
                if balanced_mask is not None:
                    h_mask = h_mask & balanced_mask
                # Equal weights on finite effects at this horizon.
                weights_h = np.zeros(n_omega_1)
                finite_h = np.isfinite(tau_hat) & h_mask
                n_valid_h = int(finite_h.sum())
                if n_valid_h == 0:
                    continue
                weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h

                psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h)
                result["event_study"][h] = psi_h

        # Group effects: per-group weights
        # NOTE: weight logic duplicated from _aggregate_group.
        # If weight scheme changes there, update here too.
        if group_effects:
            result["group"] = {}
            df_1 = df.loc[omega_1_mask]
            cohorts = df_1[first_treat].values
            n_omega_1 = int(omega_1_mask.sum())

            for g in group_effects:
                if group_effects[g].get("n_obs", 0) == 0:
                    continue
                if not np.isfinite(group_effects[g].get("effect", np.nan)):
                    continue
                g_mask = cohorts == g
                # Equal weights on this cohort's finite effects.
                weights_g = np.zeros(n_omega_1)
                finite_g = np.isfinite(tau_hat) & g_mask
                n_valid_g = int(finite_g.sum())
                if n_valid_g == 0:
                    continue
                weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g

                psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g)
                result["group"][g] = psi_g

        return result
2252
+
2253
    def _run_bootstrap(
        self,
        original_att: float,
        original_event_study: Optional[Dict[int, Dict[str, Any]]],
        original_group: Optional[Dict[Any, Dict[str, Any]]],
        psi_data: Dict[str, Any],
    ) -> ImputationBootstrapResults:
        """
        Run multiplier bootstrap on pre-computed influence function sums.

        Uses T_b = sum_i w_b_i * psi_i where w_b_i are Rademacher weights
        and psi_i are cluster-level influence function sums from Theorem 3.
        SE = std(T_b, ddof=1).

        Parameters
        ----------
        original_att : float
            Point estimate of the overall ATT (used to center the draws).
        original_event_study, original_group : dict, optional
            Analytical per-horizon / per-cohort results; bootstrap stats
            are produced only for targets present in both these dicts and
            ``psi_data``.
        psi_data : dict
            Output of ``_precompute_bootstrap_psi``.
        """
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)

        # Local import mirrors the source layout; keeps module load light.
        from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch

        overall_psi, cluster_ids = psi_data["overall"]
        n_clusters = len(cluster_ids)

        # Generate ALL weights upfront: shape (n_bootstrap, n_clusters)
        all_weights = _generate_bootstrap_weights_batch(
            self.n_bootstrap, n_clusters, "rademacher", rng
        )

        # Overall ATT bootstrap draws
        boot_overall = all_weights @ overall_psi  # (n_bootstrap,)

        # Event study: loop over horizons
        boot_event_study: Optional[Dict[int, np.ndarray]] = None
        if original_event_study and "event_study" in psi_data:
            boot_event_study = {}
            for h, psi_h in psi_data["event_study"].items():
                boot_event_study[h] = all_weights @ psi_h

        # Group effects: loop over groups
        boot_group: Optional[Dict[Any, np.ndarray]] = None
        if original_group and "group" in psi_data:
            boot_group = {}
            for g, psi_g in psi_data["group"].items():
                boot_group[g] = all_weights @ psi_g

        # --- Inference (percentile bootstrap, matching CS/SA convention) ---
        # Shift perturbation-centered draws to effect-centered draws.
        # The multiplier bootstrap produces T_b = sum w_b_i * psi_i centered at 0.
        # CS adds the original effect back (L411 of staggered_bootstrap.py).
        # We do the same here so percentile CIs and empirical p-values work correctly.
        boot_overall_shifted = boot_overall + original_att

        # SE is computed on the unshifted draws (shift is a constant offset,
        # so std is the same either way); CI/p-value use the shifted draws.
        overall_se = float(np.std(boot_overall, ddof=1))
        overall_ci = (
            self._compute_percentile_ci(boot_overall_shifted, self.alpha)
            if overall_se > 0
            else (np.nan, np.nan)
        )
        overall_p = (
            self._compute_bootstrap_pvalue(original_att, boot_overall_shifted)
            if overall_se > 0
            else np.nan
        )

        # Per-horizon SE/CI/p-value; degenerate or non-finite targets get NaNs.
        event_study_ses = None
        event_study_cis = None
        event_study_p_values = None
        if boot_event_study and original_event_study:
            event_study_ses = {}
            event_study_cis = {}
            event_study_p_values = {}
            for h in boot_event_study:
                se_h = float(np.std(boot_event_study[h], ddof=1))
                event_study_ses[h] = se_h
                orig_eff = original_event_study[h]["effect"]
                if se_h > 0 and np.isfinite(orig_eff):
                    shifted_h = boot_event_study[h] + orig_eff
                    event_study_p_values[h] = self._compute_bootstrap_pvalue(orig_eff, shifted_h)
                    event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha)
                else:
                    event_study_p_values[h] = np.nan
                    event_study_cis[h] = (np.nan, np.nan)

        # Per-cohort SE/CI/p-value, same convention as the event study.
        group_ses = None
        group_cis = None
        group_p_values = None
        if boot_group and original_group:
            group_ses = {}
            group_cis = {}
            group_p_values = {}
            for g in boot_group:
                se_g = float(np.std(boot_group[g], ddof=1))
                group_ses[g] = se_g
                orig_eff = original_group[g]["effect"]
                if se_g > 0 and np.isfinite(orig_eff):
                    shifted_g = boot_group[g] + orig_eff
                    group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g)
                    group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha)
                else:
                    group_p_values[g] = np.nan
                    group_cis[g] = (np.nan, np.nan)

        return ImputationBootstrapResults(
            n_bootstrap=self.n_bootstrap,
            weight_type="rademacher",
            alpha=self.alpha,
            overall_att_se=overall_se,
            overall_att_ci=overall_ci,
            overall_att_p_value=overall_p,
            event_study_ses=event_study_ses,
            event_study_cis=event_study_cis,
            event_study_p_values=event_study_p_values,
            group_ses=group_ses,
            group_cis=group_cis,
            group_p_values=group_p_values,
            bootstrap_distribution=boot_overall_shifted,
        )
2376
+
2377
+ # =========================================================================
2378
+ # sklearn-compatible interface
2379
+ # =========================================================================
2380
+
2381
+ def get_params(self) -> Dict[str, Any]:
2382
+ """Get estimator parameters (sklearn-compatible)."""
2383
+ return {
2384
+ "anticipation": self.anticipation,
2385
+ "alpha": self.alpha,
2386
+ "cluster": self.cluster,
2387
+ "n_bootstrap": self.n_bootstrap,
2388
+ "seed": self.seed,
2389
+ "rank_deficient_action": self.rank_deficient_action,
2390
+ "horizon_max": self.horizon_max,
2391
+ "aux_partition": self.aux_partition,
2392
+ }
2393
+
2394
+ def set_params(self, **params) -> "ImputationDiD":
2395
+ """Set estimator parameters (sklearn-compatible)."""
2396
+ for key, value in params.items():
2397
+ if hasattr(self, key):
2398
+ setattr(self, key, value)
2399
+ else:
2400
+ raise ValueError(f"Unknown parameter: {key}")
2401
+ return self
2402
+
2403
+ def summary(self) -> str:
2404
+ """Get summary of estimation results."""
2405
+ if not self.is_fitted_:
2406
+ raise RuntimeError("Model must be fitted before calling summary()")
2407
+ assert self.results_ is not None
2408
+ return self.results_.summary()
2409
+
2410
+ def print_summary(self) -> None:
2411
+ """Print summary to stdout."""
2412
+ print(self.summary())
2413
+
2414
+
2415
+ # =============================================================================
2416
+ # Convenience function
2417
+ # =============================================================================
2418
+
2419
+
2420
def imputation_did(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    **kwargs,
) -> ImputationDiDResults:
    """
    Estimate a staggered DiD via the imputation estimator in one call.

    Equivalent to constructing :class:`ImputationDiD` with ``**kwargs``
    and calling its ``fit()`` method.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    first_treat : str
        Column indicating first treatment period (0 for never-treated).
    covariates : list of str, optional
        Covariate column names.
    aggregate : str, optional
        Aggregation mode: None, "simple", "event_study", "group", "all".
    balance_e : int, optional
        Balance event study to cohorts observed at all relative times.
    **kwargs
        Additional keyword arguments passed to ImputationDiD constructor.

    Returns
    -------
    ImputationDiDResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import imputation_did, generate_staggered_data
    >>> data = generate_staggered_data(seed=42)
    >>> results = imputation_did(data, 'outcome', 'unit', 'time', 'first_treat',
    ...                          aggregate='event_study')
    >>> results.print_summary()
    """
    estimator = ImputationDiD(**kwargs)
    fit_kwargs = dict(
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        aggregate=aggregate,
        balance_e=balance_e,
    )
    return estimator.fit(data, **fit_kwargs)