diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/staggered.py ADDED
@@ -0,0 +1,2297 @@
1
+ """
2
+ Staggered Difference-in-Differences estimators.
3
+
4
+ Implements modern methods for DiD with variation in treatment timing,
5
+ including the Callaway-Sant'Anna (2021) estimator.
6
+ """
7
+
8
+ import warnings
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, Dict, List, Optional, Set, Tuple
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ from scipy import optimize
15
+
16
+ from diff_diff.linalg import solve_ols
17
+ from diff_diff.results import _get_significance_stars
18
+ from diff_diff.utils import (
19
+ compute_confidence_interval,
20
+ compute_p_value,
21
+ )
22
+
23
+ # Import Rust backend if available (from _backend to avoid circular imports)
24
+ from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
25
+
26
+ # Type alias for pre-computed structures
27
+ PrecomputedData = Dict[str, Any]
28
+
29
+ # =============================================================================
30
+ # Bootstrap Weight Generators
31
+ # =============================================================================
32
+
33
+
34
+ def _generate_bootstrap_weights(
35
+ n_units: int,
36
+ weight_type: str,
37
+ rng: np.random.Generator,
38
+ ) -> np.ndarray:
39
+ """
40
+ Generate bootstrap weights for multiplier bootstrap.
41
+
42
+ Parameters
43
+ ----------
44
+ n_units : int
45
+ Number of units (clusters) to generate weights for.
46
+ weight_type : str
47
+ Type of weights: "rademacher", "mammen", or "webb".
48
+ rng : np.random.Generator
49
+ Random number generator.
50
+
51
+ Returns
52
+ -------
53
+ np.ndarray
54
+ Array of bootstrap weights with shape (n_units,).
55
+ """
56
+ if weight_type == "rademacher":
57
+ # Rademacher: +1 or -1 with equal probability
58
+ return rng.choice([-1.0, 1.0], size=n_units)
59
+
60
+ elif weight_type == "mammen":
61
+ # Mammen's two-point distribution
62
+ # E[v] = 0, E[v^2] = 1, E[v^3] = 1
63
+ sqrt5 = np.sqrt(5)
64
+ val1 = -(sqrt5 - 1) / 2 # ≈ -0.618
65
+ val2 = (sqrt5 + 1) / 2 # ≈ 1.618 (golden ratio)
66
+ p1 = (sqrt5 + 1) / (2 * sqrt5) # ≈ 0.724
67
+ return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1])
68
+
69
+ elif weight_type == "webb":
70
+ # Webb's 6-point distribution (recommended for few clusters)
71
+ values = np.array([
72
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
73
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
74
+ ])
75
+ probs = np.full(6, 1 / 6)  # each of Webb's six points is equally likely, so E[w^2] = 1
76
+ return rng.choice(values, size=n_units, p=probs)
77
+
78
+ else:
79
+ raise ValueError(
80
+ f"weight_type must be 'rademacher', 'mammen', or 'webb', "
81
+ f"got '{weight_type}'"
82
+ )
83
+
84
+
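# --- Illustrative sketch (not part of the packaged module). ---------------
# The multiplier bootstrap relies on weights with E[w] = 0 and E[w^2] = 1.
# The helper name below is hypothetical; it spot-checks those moments for
# the three distributions implemented above.
def _sketch_check_weight_moments(n: int = 200_000, seed: int = 0) -> None:
    rng = np.random.default_rng(seed)
    for kind in ("rademacher", "mammen", "webb"):
        w = _generate_bootstrap_weights(n, kind, rng)
        # Both printed moments should be close to 0 and 1, respectively.
        print(f"{kind:>10}: mean={w.mean():+.4f}  var={w.var():.4f}")
# ---------------------------------------------------------------------------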
85
+ def _generate_bootstrap_weights_batch(
86
+ n_bootstrap: int,
87
+ n_units: int,
88
+ weight_type: str,
89
+ rng: np.random.Generator,
90
+ ) -> np.ndarray:
91
+ """
92
+ Generate all bootstrap weights at once (vectorized).
93
+
94
+ Parameters
95
+ ----------
96
+ n_bootstrap : int
97
+ Number of bootstrap iterations.
98
+ n_units : int
99
+ Number of units (clusters) to generate weights for.
100
+ weight_type : str
101
+ Type of weights: "rademacher", "mammen", or "webb".
102
+ rng : np.random.Generator
103
+ Random number generator.
104
+
105
+ Returns
106
+ -------
107
+ np.ndarray
108
+ Array of bootstrap weights with shape (n_bootstrap, n_units).
109
+ """
110
+ # Use Rust backend if available (parallel + fast RNG)
111
+ if HAS_RUST_BACKEND:
112
+ # Get seed from the NumPy RNG for reproducibility
113
+ seed = rng.integers(0, 2**63 - 1)
114
+ return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed)
115
+
116
+ # Fallback to NumPy implementation
117
+ return _generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng)
118
+
119
+
120
+ def _generate_bootstrap_weights_batch_numpy(
121
+ n_bootstrap: int,
122
+ n_units: int,
123
+ weight_type: str,
124
+ rng: np.random.Generator,
125
+ ) -> np.ndarray:
126
+ """
127
+ NumPy fallback implementation of _generate_bootstrap_weights_batch.
128
+
129
+ Generates multiplier bootstrap weights for wild cluster bootstrap.
130
+ All weight distributions satisfy E[w] = 0, E[w^2] = 1.
131
+
132
+ Parameters
133
+ ----------
134
+ n_bootstrap : int
135
+ Number of bootstrap iterations.
136
+ n_units : int
137
+ Number of units (clusters) to generate weights for.
138
+ weight_type : str
139
+ Type of weights: "rademacher" (+-1), "mammen" (2-point),
140
+ or "webb" (6-point).
141
+ rng : np.random.Generator
142
+ Random number generator for reproducibility.
143
+
144
+ Returns
145
+ -------
146
+ np.ndarray
147
+ Array of bootstrap weights with shape (n_bootstrap, n_units).
148
+ """
149
+ if weight_type == "rademacher":
150
+ # Rademacher: +1 or -1 with equal probability
151
+ return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))
152
+
153
+ elif weight_type == "mammen":
154
+ # Mammen's two-point distribution
155
+ sqrt5 = np.sqrt(5)
156
+ val1 = -(sqrt5 - 1) / 2
157
+ val2 = (sqrt5 + 1) / 2
158
+ p1 = (sqrt5 + 1) / (2 * sqrt5)
159
+ return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1])
160
+
161
+ elif weight_type == "webb":
162
+ # Webb's 6-point distribution
163
+ values = np.array([
164
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
165
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
166
+ ])
167
+ probs = np.full(6, 1 / 6)  # each of Webb's six points is equally likely, so E[w^2] = 1
168
+ return rng.choice(values, size=(n_bootstrap, n_units), p=probs)
169
+
170
+ else:
171
+ raise ValueError(
172
+ f"weight_type must be 'rademacher', 'mammen', or 'webb', "
173
+ f"got '{weight_type}'"
174
+ )
175
+
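# --- Illustrative sketch (not part of the packaged module). ---------------
# The batched generator returns one bootstrap replication per row, so every
# replication of an aggregated statistic can be formed with a single matrix
# product against unit-level influence contributions.  The helper name is
# hypothetical and `psi` stands in for a (n_units,) influence vector.
def _sketch_bootstrap_draws(psi: np.ndarray, n_bootstrap: int = 999,
                            weight_type: str = "rademacher",
                            seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    W = _generate_bootstrap_weights_batch(n_bootstrap, psi.shape[0], weight_type, rng)
    # Each row of W perturbs every unit's contribution once.
    return W @ psi  # shape (n_bootstrap,): bootstrap draws of the statistic
# ---------------------------------------------------------------------------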
176
+
177
+ # =============================================================================
178
+ # Bootstrap Results Container
179
+ # =============================================================================
180
+
181
+
182
+ @dataclass
183
+ class CSBootstrapResults:
184
+ """
185
+ Results from Callaway-Sant'Anna multiplier bootstrap inference.
186
+
187
+ Attributes
188
+ ----------
189
+ n_bootstrap : int
190
+ Number of bootstrap iterations.
191
+ weight_type : str
192
+ Type of bootstrap weights used.
193
+ alpha : float
194
+ Significance level used for confidence intervals.
195
+ overall_att_se : float
196
+ Bootstrap standard error for overall ATT.
197
+ overall_att_ci : Tuple[float, float]
198
+ Bootstrap confidence interval for overall ATT.
199
+ overall_att_p_value : float
200
+ Bootstrap p-value for overall ATT.
201
+ group_time_ses : Dict[Tuple[Any, Any], float]
202
+ Bootstrap SEs for each ATT(g,t).
203
+ group_time_cis : Dict[Tuple[Any, Any], Tuple[float, float]]
204
+ Bootstrap CIs for each ATT(g,t).
205
+ group_time_p_values : Dict[Tuple[Any, Any], float]
206
+ Bootstrap p-values for each ATT(g,t).
207
+ event_study_ses : Optional[Dict[int, float]]
208
+ Bootstrap SEs for event study effects.
209
+ event_study_cis : Optional[Dict[int, Tuple[float, float]]]
210
+ Bootstrap CIs for event study effects.
211
+ event_study_p_values : Optional[Dict[int, float]]
212
+ Bootstrap p-values for event study effects.
213
+ group_effect_ses : Optional[Dict[Any, float]]
214
+ Bootstrap SEs for group effects.
215
+ group_effect_cis : Optional[Dict[Any, Tuple[float, float]]]
216
+ Bootstrap CIs for group effects.
217
+ group_effect_p_values : Optional[Dict[Any, float]]
218
+ Bootstrap p-values for group effects.
219
+ bootstrap_distribution : Optional[np.ndarray]
220
+ Full bootstrap distribution of overall ATT (if requested).
221
+ """
222
+ n_bootstrap: int
223
+ weight_type: str
224
+ alpha: float
225
+ overall_att_se: float
226
+ overall_att_ci: Tuple[float, float]
227
+ overall_att_p_value: float
228
+ group_time_ses: Dict[Tuple[Any, Any], float]
229
+ group_time_cis: Dict[Tuple[Any, Any], Tuple[float, float]]
230
+ group_time_p_values: Dict[Tuple[Any, Any], float]
231
+ event_study_ses: Optional[Dict[int, float]] = None
232
+ event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
233
+ event_study_p_values: Optional[Dict[int, float]] = None
234
+ group_effect_ses: Optional[Dict[Any, float]] = None
235
+ group_effect_cis: Optional[Dict[Any, Tuple[float, float]]] = None
236
+ group_effect_p_values: Optional[Dict[Any, float]] = None
237
+ bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
238
+
239
+
240
+ def _logistic_regression(
241
+ X: np.ndarray,
242
+ y: np.ndarray,
243
+ max_iter: int = 100,
244
+ tol: float = 1e-6,
245
+ ) -> Tuple[np.ndarray, np.ndarray]:
246
+ """
247
+ Fit logistic regression using scipy optimize.
248
+
249
+ Parameters
250
+ ----------
251
+ X : np.ndarray
252
+ Feature matrix (n_samples, n_features). Intercept added automatically.
253
+ y : np.ndarray
254
+ Binary outcome (0/1).
255
+ max_iter : int
256
+ Maximum iterations.
257
+ tol : float
258
+ Convergence tolerance.
259
+
260
+ Returns
261
+ -------
262
+ beta : np.ndarray
263
+ Fitted coefficients (including intercept).
264
+ probs : np.ndarray
265
+ Predicted probabilities.
266
+ """
267
+ n, p = X.shape
268
+ # Add intercept
269
+ X_with_intercept = np.column_stack([np.ones(n), X])
270
+
271
+ def neg_log_likelihood(beta: np.ndarray) -> float:
272
+ z = X_with_intercept @ beta
273
+ # Clip to prevent overflow
274
+ z = np.clip(z, -500, 500)
275
+ log_lik = np.sum(y * z - np.log(1 + np.exp(z)))
276
+ return -log_lik
277
+
278
+ def gradient(beta: np.ndarray) -> np.ndarray:
279
+ z = X_with_intercept @ beta
280
+ z = np.clip(z, -500, 500)
281
+ probs = 1 / (1 + np.exp(-z))
282
+ return -X_with_intercept.T @ (y - probs)
283
+
284
+ # Initialize with zeros
285
+ beta_init = np.zeros(p + 1)
286
+
287
+ result = optimize.minimize(
288
+ neg_log_likelihood,
289
+ beta_init,
290
+ method='BFGS',
291
+ jac=gradient,
292
+ options={'maxiter': max_iter, 'gtol': tol}
293
+ )
294
+
295
+ beta = result.x
296
+ z = X_with_intercept @ beta
297
+ z = np.clip(z, -500, 500)
298
+ probs = 1 / (1 + np.exp(-z))
299
+
300
+ return beta, probs
301
+
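# --- Illustrative sketch (not part of the packaged module). ---------------
# _logistic_regression is used later to estimate propensity scores
# P(treated | X) for the IPW and doubly robust estimators.  A minimal
# synthetic example (hypothetical helper name, made-up toy data):
def _sketch_propensity_example(seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(500, 2))
    # True model: treatment more likely when the first covariate is large.
    d = (rng.random(500) < 1 / (1 + np.exp(-(0.5 + 1.0 * X[:, 0])))).astype(float)
    beta, probs = _logistic_regression(X, d)
    # beta should be near [0.5, 1.0, 0.0] up to sampling noise.
    return probs
# ---------------------------------------------------------------------------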
302
+
303
+ def _linear_regression(
304
+ X: np.ndarray,
305
+ y: np.ndarray,
306
+ ) -> Tuple[np.ndarray, np.ndarray]:
307
+ """
308
+ Fit OLS regression.
309
+
310
+ Parameters
311
+ ----------
312
+ X : np.ndarray
313
+ Feature matrix (n_samples, n_features). Intercept added automatically.
314
+ y : np.ndarray
315
+ Outcome variable.
316
+
317
+ Returns
318
+ -------
319
+ beta : np.ndarray
320
+ Fitted coefficients (including intercept).
321
+ residuals : np.ndarray
322
+ Residuals from the fit.
323
+ """
324
+ n = X.shape[0]
325
+ # Add intercept
326
+ X_with_intercept = np.column_stack([np.ones(n), X])
327
+
328
+ # Use unified OLS backend (no vcov needed)
329
+ beta, residuals, _ = solve_ols(X_with_intercept, y, return_vcov=False)
330
+
331
+ return beta, residuals
332
+
333
+
334
+ @dataclass
335
+ class GroupTimeEffect:
336
+ """
337
+ Treatment effect for a specific group-time combination.
338
+
339
+ Attributes
340
+ ----------
341
+ group : any
342
+ The treatment cohort (first treatment period).
343
+ time : any
344
+ The time period.
345
+ effect : float
346
+ The ATT(g,t) estimate.
347
+ se : float
348
+ Standard error.
+ t_stat : float
+ t-statistic for the effect.
+ p_value : float
+ Two-sided p-value.
+ conf_int : Tuple[float, float]
+ Confidence interval for the effect.
349
+ n_treated : int
350
+ Number of treated observations.
351
+ n_control : int
352
+ Number of control observations.
353
+ """
354
+ group: Any
355
+ time: Any
356
+ effect: float
357
+ se: float
358
+ t_stat: float
359
+ p_value: float
360
+ conf_int: Tuple[float, float]
361
+ n_treated: int
362
+ n_control: int
363
+
364
+ @property
365
+ def is_significant(self) -> bool:
366
+ """Check if effect is significant at 0.05 level."""
367
+ return bool(self.p_value < 0.05)
368
+
369
+ @property
370
+ def significance_stars(self) -> str:
371
+ """Return significance stars based on p-value."""
372
+ return _get_significance_stars(self.p_value)
373
+
374
+
375
+ @dataclass
376
+ class CallawaySantAnnaResults:
377
+ """
378
+ Results from Callaway-Sant'Anna (2021) staggered DiD estimation.
379
+
380
+ This class stores group-time average treatment effects ATT(g,t) and
381
+ provides methods for aggregation into summary measures.
382
+
383
+ Attributes
384
+ ----------
385
+ group_time_effects : dict
386
+ Dictionary mapping (group, time) tuples to effect dictionaries.
387
+ overall_att : float
388
+ Overall average treatment effect (weighted average of ATT(g,t)).
389
+ overall_se : float
390
+ Standard error of overall ATT.
+ overall_t_stat : float
+ t-statistic for overall ATT.
391
+ overall_p_value : float
392
+ P-value for overall ATT.
393
+ overall_conf_int : tuple
394
+ Confidence interval for overall ATT.
395
+ groups : list
396
+ List of treatment cohorts (first treatment periods).
397
+ time_periods : list
398
+ List of all time periods.
399
+ n_obs : int
400
+ Total number of observations.
401
+ n_treated_units : int
402
+ Number of ever-treated units.
403
+ n_control_units : int
404
+ Number of never-treated units.
405
+ event_study_effects : dict, optional
406
+ Effects aggregated by relative time (event study).
407
+ group_effects : dict, optional
408
+ Effects aggregated by treatment cohort.
+ alpha : float
+ Significance level used for confidence intervals.
+ control_group : str
+ Control group used: "never_treated" or "not_yet_treated".
+ influence_functions : np.ndarray, optional
+ Unit-level influence functions (excluded from repr).
+ bootstrap_results : CSBootstrapResults, optional
+ Multiplier bootstrap inference results, when n_bootstrap > 0.
409
+ """
410
+ group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]
411
+ overall_att: float
412
+ overall_se: float
413
+ overall_t_stat: float
414
+ overall_p_value: float
415
+ overall_conf_int: Tuple[float, float]
416
+ groups: List[Any]
417
+ time_periods: List[Any]
418
+ n_obs: int
419
+ n_treated_units: int
420
+ n_control_units: int
421
+ alpha: float = 0.05
422
+ control_group: str = "never_treated"
423
+ event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
424
+ group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
425
+ influence_functions: Optional[np.ndarray] = field(default=None, repr=False)
426
+ bootstrap_results: Optional[CSBootstrapResults] = field(default=None, repr=False)
427
+
428
+ def __repr__(self) -> str:
429
+ """Concise string representation."""
430
+ sig = _get_significance_stars(self.overall_p_value)
431
+ return (
432
+ f"CallawaySantAnnaResults(ATT={self.overall_att:.4f}{sig}, "
433
+ f"SE={self.overall_se:.4f}, "
434
+ f"n_groups={len(self.groups)}, "
435
+ f"n_periods={len(self.time_periods)})"
436
+ )
437
+
438
+ def summary(self, alpha: Optional[float] = None) -> str:
439
+ """
440
+ Generate formatted summary of estimation results.
441
+
442
+ Parameters
443
+ ----------
444
+ alpha : float, optional
445
+ Significance level. Defaults to alpha used in estimation.
446
+
447
+ Returns
448
+ -------
449
+ str
450
+ Formatted summary.
451
+ """
452
+ alpha = alpha or self.alpha
453
+ conf_level = int((1 - alpha) * 100)
454
+
455
+ lines = [
456
+ "=" * 85,
457
+ "Callaway-Sant'Anna Staggered Difference-in-Differences Results".center(85),
458
+ "=" * 85,
459
+ "",
460
+ f"{'Total observations:':<30} {self.n_obs:>10}",
461
+ f"{'Treated units:':<30} {self.n_treated_units:>10}",
462
+ f"{'Control units:':<30} {self.n_control_units:>10}",
463
+ f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
464
+ f"{'Time periods:':<30} {len(self.time_periods):>10}",
465
+ f"{'Control group:':<30} {self.control_group:>10}",
466
+ "",
467
+ ]
468
+
469
+ # Overall ATT
470
+ lines.extend([
471
+ "-" * 85,
472
+ "Overall Average Treatment Effect on the Treated".center(85),
473
+ "-" * 85,
474
+ f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
475
+ "-" * 85,
476
+ f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
477
+ f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
478
+ f"{_get_significance_stars(self.overall_p_value):>6}",
479
+ "-" * 85,
480
+ "",
481
+ f"{conf_level}% Confidence Interval: [{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
482
+ "",
483
+ ])
484
+
485
+ # Event study effects if available
486
+ if self.event_study_effects:
487
+ lines.extend([
488
+ "-" * 85,
489
+ "Event Study (Dynamic) Effects".center(85),
490
+ "-" * 85,
491
+ f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
492
+ "-" * 85,
493
+ ])
494
+
495
+ for rel_t in sorted(self.event_study_effects.keys()):
496
+ eff = self.event_study_effects[rel_t]
497
+ sig = _get_significance_stars(eff['p_value'])
498
+ lines.append(
499
+ f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
500
+ f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
501
+ )
502
+
503
+ lines.extend(["-" * 85, ""])
504
+
505
+ # Group effects if available
506
+ if self.group_effects:
507
+ lines.extend([
508
+ "-" * 85,
509
+ "Effects by Treatment Cohort".center(85),
510
+ "-" * 85,
511
+ f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
512
+ "-" * 85,
513
+ ])
514
+
515
+ for group in sorted(self.group_effects.keys()):
516
+ eff = self.group_effects[group]
517
+ sig = _get_significance_stars(eff['p_value'])
518
+ lines.append(
519
+ f"{group:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
520
+ f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
521
+ )
522
+
523
+ lines.extend(["-" * 85, ""])
524
+
525
+ lines.extend([
526
+ "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
527
+ "=" * 85,
528
+ ])
529
+
530
+ return "\n".join(lines)
531
+
532
+ def print_summary(self, alpha: Optional[float] = None) -> None:
533
+ """Print summary to stdout."""
534
+ print(self.summary(alpha))
535
+
536
+ def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
537
+ """
538
+ Convert results to DataFrame.
539
+
540
+ Parameters
541
+ ----------
542
+ level : str, default="group_time"
543
+ Level of aggregation: "group_time", "event_study", or "group".
544
+
545
+ Returns
546
+ -------
547
+ pd.DataFrame
548
+ Results as DataFrame.
549
+ """
550
+ if level == "group_time":
551
+ rows = []
552
+ for (g, t), data in self.group_time_effects.items():
553
+ rows.append({
554
+ 'group': g,
555
+ 'time': t,
556
+ 'effect': data['effect'],
557
+ 'se': data['se'],
558
+ 't_stat': data['t_stat'],
559
+ 'p_value': data['p_value'],
560
+ 'conf_int_lower': data['conf_int'][0],
561
+ 'conf_int_upper': data['conf_int'][1],
562
+ })
563
+ return pd.DataFrame(rows)
564
+
565
+ elif level == "event_study":
566
+ if self.event_study_effects is None:
567
+ raise ValueError("Event study effects not computed. Use aggregate='event_study'.")
568
+ rows = []
569
+ for rel_t, data in sorted(self.event_study_effects.items()):
570
+ rows.append({
571
+ 'relative_period': rel_t,
572
+ 'effect': data['effect'],
573
+ 'se': data['se'],
574
+ 't_stat': data['t_stat'],
575
+ 'p_value': data['p_value'],
576
+ 'conf_int_lower': data['conf_int'][0],
577
+ 'conf_int_upper': data['conf_int'][1],
578
+ })
579
+ return pd.DataFrame(rows)
580
+
581
+ elif level == "group":
582
+ if self.group_effects is None:
583
+ raise ValueError("Group effects not computed. Use aggregate='group'.")
584
+ rows = []
585
+ for group, data in sorted(self.group_effects.items()):
586
+ rows.append({
587
+ 'group': group,
588
+ 'effect': data['effect'],
589
+ 'se': data['se'],
590
+ 't_stat': data['t_stat'],
591
+ 'p_value': data['p_value'],
592
+ 'conf_int_lower': data['conf_int'][0],
593
+ 'conf_int_upper': data['conf_int'][1],
594
+ })
595
+ return pd.DataFrame(rows)
596
+
597
+ else:
598
+ raise ValueError(f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'.")
599
+
600
+ @property
601
+ def is_significant(self) -> bool:
602
+ """Check if overall ATT is significant."""
603
+ return bool(self.overall_p_value < self.alpha)
604
+
605
+ @property
606
+ def significance_stars(self) -> str:
607
+ """Significance stars for overall ATT."""
608
+ return _get_significance_stars(self.overall_p_value)
609
+
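# --- Illustrative sketch (not part of the packaged module). ---------------
# Typical ways to consume a fitted results object (hypothetical helper name;
# assumes `results` was produced by CallawaySantAnna.fit below):
def _sketch_inspect_results(results: "CallawaySantAnnaResults") -> pd.DataFrame:
    results.print_summary()                        # formatted text report
    gt = results.to_dataframe("group_time")        # one row per ATT(g,t)
    if results.event_study_effects is not None:
        es = results.to_dataframe("event_study")   # one row per relative period
        print(es[["relative_period", "effect", "se"]])
    print("significant at alpha:", results.is_significant)
    return gt
# ---------------------------------------------------------------------------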
610
+
611
+ class CallawaySantAnna:
612
+ """
613
+ Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
614
+
615
+ This estimator handles DiD designs with variation in treatment timing
616
+ (staggered adoption) and heterogeneous treatment effects. It avoids the
617
+ bias of traditional two-way fixed effects (TWFE) estimators by:
618
+
619
+ 1. Computing group-time average treatment effects ATT(g,t) for each
620
+ cohort g (units first treated in period g) and time t.
621
+ 2. Aggregating these to summary measures (overall ATT, event study, etc.)
622
+ using appropriate weights.
623
+
624
+ Parameters
625
+ ----------
626
+ control_group : str, default="never_treated"
627
+ Which units to use as controls:
628
+ - "never_treated": Use only never-treated units (recommended)
629
+ - "not_yet_treated": Use never-treated and not-yet-treated units
630
+ anticipation : int, default=0
631
+ Number of periods before treatment where effects may occur.
632
+ Set to > 0 if treatment effects can begin before the official
633
+ treatment date.
634
+ estimation_method : str, default="dr"
635
+ Estimation method:
636
+ - "dr": Doubly robust (recommended)
637
+ - "ipw": Inverse probability weighting
638
+ - "reg": Outcome regression
639
+ alpha : float, default=0.05
640
+ Significance level for confidence intervals.
641
+ cluster : str, optional
642
+ Column name for cluster-robust standard errors.
643
+ Defaults to unit-level clustering.
644
+ n_bootstrap : int, default=0
645
+ Number of bootstrap iterations for inference.
646
+ If 0, uses analytical standard errors.
647
+ Recommended: 999 or more for reliable inference.
648
+
649
+ .. note:: Memory Usage
650
+ The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
651
+ float64 array. For large datasets, this can be significant:
652
+ - 1K bootstrap × 10K units = ~80 MB
653
+ - 10K bootstrap × 100K units = ~8 GB
654
+ Consider reducing n_bootstrap if memory is constrained.
655
+
656
+ bootstrap_weights : str, default="rademacher"
657
+ Type of weights for multiplier bootstrap:
658
+ - "rademacher": +1/-1 with equal probability (standard choice)
659
+ - "mammen": Two-point distribution (asymptotically valid, matches skewness)
660
+ - "webb": Six-point distribution (recommended when n_clusters < 20)
661
+ bootstrap_weight_type : str, optional
662
+ .. deprecated:: 1.0.1
663
+ Use ``bootstrap_weights`` instead. Will be removed in v2.0.
664
+ seed : int, optional
665
+ Random seed for reproducibility.
666
+
667
+ Attributes
668
+ ----------
669
+ results_ : CallawaySantAnnaResults
670
+ Estimation results after calling fit().
671
+ is_fitted_ : bool
672
+ Whether the model has been fitted.
673
+
674
+ Examples
675
+ --------
676
+ Basic usage:
677
+
678
+ >>> import pandas as pd
679
+ >>> from diff_diff import CallawaySantAnna
680
+ >>>
681
+ >>> # Panel data with staggered treatment
682
+ >>> # 'first_treat' = period when unit was first treated (0 if never treated)
683
+ >>> data = pd.DataFrame({
684
+ ... 'unit': [...],
685
+ ... 'time': [...],
686
+ ... 'outcome': [...],
687
+ ... 'first_treat': [...] # 0 for never-treated, else first treatment period
688
+ ... })
689
+ >>>
690
+ >>> cs = CallawaySantAnna()
691
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
692
+ ... time='time', first_treat='first_treat')
693
+ >>>
694
+ >>> results.print_summary()
695
+
696
+ With event study aggregation:
697
+
698
+ >>> cs = CallawaySantAnna()
699
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
700
+ ... time='time', first_treat='first_treat',
701
+ ... aggregate='event_study')
702
+ >>>
703
+ >>> # Plot event study
704
+ >>> from diff_diff import plot_event_study
705
+ >>> plot_event_study(results)
706
+
707
+ With covariate adjustment (conditional parallel trends):
708
+
709
+ >>> # When parallel trends only holds conditional on covariates
710
+ >>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
711
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
712
+ ... time='time', first_treat='first_treat',
713
+ ... covariates=['age', 'income'])
714
+ >>>
715
+ >>> # DR is recommended: consistent if either outcome model
716
+ >>> # or propensity model is correctly specified
717
+
718
+ Notes
719
+ -----
720
+ The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
721
+ approach: instead of estimating a single treatment effect, they estimate
722
+ ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
723
+ problem where already-treated units act as controls.
724
+
725
+ The ATT(g,t) is identified under parallel trends conditional on covariates:
726
+
727
+ E[Y_t(0) - Y_{g-1}(0) | X, G=g] = E[Y_t(0) - Y_{g-1}(0) | X, C=1]
728
+
729
+ where G=g indicates treatment cohort g, C=1 indicates control units, and X
+ collects the conditioning covariates (omitted when no covariates are used).
730
+
731
+ References
732
+ ----------
733
+ Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
734
+ multiple time periods. Journal of Econometrics, 225(2), 200-230.
735
+ """
736
+
737
+ def __init__(
738
+ self,
739
+ control_group: str = "never_treated",
740
+ anticipation: int = 0,
741
+ estimation_method: str = "dr",
742
+ alpha: float = 0.05,
743
+ cluster: Optional[str] = None,
744
+ n_bootstrap: int = 0,
745
+ bootstrap_weights: Optional[str] = None,
746
+ bootstrap_weight_type: Optional[str] = None,
747
+ seed: Optional[int] = None,
748
+ ):
749
+ import warnings
750
+
751
+ if control_group not in ["never_treated", "not_yet_treated"]:
752
+ raise ValueError(
753
+ f"control_group must be 'never_treated' or 'not_yet_treated', "
754
+ f"got '{control_group}'"
755
+ )
756
+ if estimation_method not in ["dr", "ipw", "reg"]:
757
+ raise ValueError(
758
+ f"estimation_method must be 'dr', 'ipw', or 'reg', "
759
+ f"got '{estimation_method}'"
760
+ )
761
+
762
+ # Handle bootstrap_weight_type deprecation
763
+ if bootstrap_weight_type is not None:
764
+ warnings.warn(
765
+ "bootstrap_weight_type is deprecated and will be removed in v2.0. "
766
+ "Use bootstrap_weights instead.",
767
+ DeprecationWarning,
768
+ stacklevel=2
769
+ )
770
+ if bootstrap_weights is None:
771
+ bootstrap_weights = bootstrap_weight_type
772
+
773
+ # Default to rademacher if neither specified
774
+ if bootstrap_weights is None:
775
+ bootstrap_weights = "rademacher"
776
+
777
+ if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
778
+ raise ValueError(
779
+ f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
780
+ f"got '{bootstrap_weights}'"
781
+ )
782
+
783
+ self.control_group = control_group
784
+ self.anticipation = anticipation
785
+ self.estimation_method = estimation_method
786
+ self.alpha = alpha
787
+ self.cluster = cluster
788
+ self.n_bootstrap = n_bootstrap
789
+ self.bootstrap_weights = bootstrap_weights
790
+ # Keep bootstrap_weight_type for backward compatibility
791
+ self.bootstrap_weight_type = bootstrap_weights
792
+ self.seed = seed
793
+
794
+ self.is_fitted_ = False
795
+ self.results_ = None
796
+
797
+ def _precompute_structures(
798
+ self,
799
+ df: pd.DataFrame,
800
+ outcome: str,
801
+ unit: str,
802
+ time: str,
803
+ first_treat: str,
804
+ covariates: Optional[List[str]],
805
+ time_periods: List[Any],
806
+ treatment_groups: List[Any],
807
+ ) -> PrecomputedData:
808
+ """
809
+ Pre-compute data structures for efficient ATT(g,t) computation.
810
+
811
+ This pivots data to wide format and pre-computes:
812
+ - Outcome matrix (units x time periods)
813
+ - Covariate matrix (units x covariates) from base period
814
+ - Unit cohort membership masks
815
+ - Control unit masks
816
+
817
+ Returns
818
+ -------
819
+ PrecomputedData
820
+ Dictionary with pre-computed structures.
821
+ """
822
+ # Get unique units and their cohort assignments
823
+ unit_info = df.groupby(unit)[first_treat].first()
824
+ all_units = unit_info.index.values
825
+ unit_cohorts = unit_info.values
826
+ n_units = len(all_units)
827
+
828
+ # Create unit index mapping for fast lookups
829
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
830
+
831
+ # Pivot outcome to wide format: rows = units, columns = time periods
832
+ outcome_wide = df.pivot(index=unit, columns=time, values=outcome)
833
+ # Reindex to ensure all units are present (handles unbalanced panels)
834
+ outcome_wide = outcome_wide.reindex(all_units)
835
+ outcome_matrix = outcome_wide.values # Shape: (n_units, n_periods)
836
+ period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)}
837
+
838
+ # Pre-compute cohort masks (boolean arrays)
839
+ cohort_masks = {}
840
+ for g in treatment_groups:
841
+ cohort_masks[g] = (unit_cohorts == g)
842
+
843
+ # Never-treated mask
844
+ never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)
845
+
846
+ # Pre-compute covariate matrices by time period if needed
847
+ # (covariates are retrieved from the base period of each comparison)
848
+ covariate_by_period = None
849
+ if covariates:
850
+ covariate_by_period = {}
851
+ for t in time_periods:
852
+ period_data = df[df[time] == t].set_index(unit)
853
+ period_cov = period_data.reindex(all_units)[covariates]
854
+ covariate_by_period[t] = period_cov.values # Shape: (n_units, n_covariates)
855
+
856
+ return {
857
+ 'all_units': all_units,
858
+ 'unit_to_idx': unit_to_idx,
859
+ 'unit_cohorts': unit_cohorts,
860
+ 'outcome_matrix': outcome_matrix,
861
+ 'period_to_col': period_to_col,
862
+ 'cohort_masks': cohort_masks,
863
+ 'never_treated_mask': never_treated_mask,
864
+ 'covariate_by_period': covariate_by_period,
865
+ 'time_periods': time_periods,
866
+ }
867
+
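# --- Illustrative sketch (not part of the packaged module). ---------------
# The precomputation above boils down to one wide pivot plus boolean masks.
# A condensed stand-alone version of the same idea (hypothetical helper name,
# assuming columns named 'unit', 'time', 'y', and 'first_treat'):
def _sketch_wide_pivot(df: pd.DataFrame) -> dict:
    outcome_wide = df.pivot(index="unit", columns="time", values="y")
    cohorts = df.groupby("unit")["first_treat"].first().reindex(outcome_wide.index)
    return {
        "y": outcome_wide.values,                       # (n_units, n_periods)
        "never_treated": (cohorts == 0).values,         # control-pool mask
        "col_of": {t: j for j, t in enumerate(outcome_wide.columns)},
    }
# ---------------------------------------------------------------------------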
868
+ def _compute_att_gt_fast(
869
+ self,
870
+ precomputed: PrecomputedData,
871
+ g: Any,
872
+ t: Any,
873
+ covariates: Optional[List[str]],
874
+ ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]:
875
+ """
876
+ Compute ATT(g,t) using pre-computed data structures (fast version).
877
+
878
+ Uses vectorized numpy operations on pre-pivoted outcome matrix
879
+ instead of repeated pandas filtering.
880
+ """
881
+ time_periods = precomputed['time_periods']
882
+ period_to_col = precomputed['period_to_col']
883
+ outcome_matrix = precomputed['outcome_matrix']
884
+ cohort_masks = precomputed['cohort_masks']
885
+ never_treated_mask = precomputed['never_treated_mask']
886
+ unit_cohorts = precomputed['unit_cohorts']
887
+ all_units = precomputed['all_units']
888
+ covariate_by_period = precomputed['covariate_by_period']
889
+
890
+ # Base period for comparison
891
+ base_period = g - 1 - self.anticipation
892
+ if base_period not in period_to_col:
893
+ # Find closest earlier period
894
+ earlier = [p for p in time_periods if p < g - self.anticipation]
895
+ if not earlier:
896
+ return None, 0.0, 0, 0, None
897
+ base_period = max(earlier)
898
+
899
+ # Check if periods exist in the data
900
+ if base_period not in period_to_col or t not in period_to_col:
901
+ return None, 0.0, 0, 0, None
902
+
903
+ base_col = period_to_col[base_period]
904
+ post_col = period_to_col[t]
905
+
906
+ # Get treated units mask (cohort g)
907
+ treated_mask = cohort_masks[g]
908
+
909
+ # Get control units mask
910
+ if self.control_group == "never_treated":
911
+ control_mask = never_treated_mask
912
+ else: # not_yet_treated
913
+ # Not yet treated at time t: never-treated OR first_treat > t
914
+ control_mask = never_treated_mask | (unit_cohorts > t)
915
+
916
+ # Extract outcomes for base and post periods
917
+ y_base = outcome_matrix[:, base_col]
918
+ y_post = outcome_matrix[:, post_col]
919
+
920
+ # Compute outcome changes (vectorized)
921
+ outcome_change = y_post - y_base
922
+
923
+ # Filter to units with valid data (no NaN in either period)
924
+ valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
925
+
926
+ # Get treated and control with valid data
927
+ treated_valid = treated_mask & valid_mask
928
+ control_valid = control_mask & valid_mask
929
+
930
+ n_treated = np.sum(treated_valid)
931
+ n_control = np.sum(control_valid)
932
+
933
+ if n_treated == 0 or n_control == 0:
934
+ return None, 0.0, 0, 0, None
935
+
936
+ # Extract outcome changes for treated and control
937
+ treated_change = outcome_change[treated_valid]
938
+ control_change = outcome_change[control_valid]
939
+
940
+ # Get unit IDs for influence function
941
+ treated_units = all_units[treated_valid]
942
+ control_units = all_units[control_valid]
943
+
944
+ # Get covariates if specified (from the base period)
945
+ X_treated = None
946
+ X_control = None
947
+ if covariates and covariate_by_period is not None:
948
+ cov_matrix = covariate_by_period[base_period]
949
+ X_treated = cov_matrix[treated_valid]
950
+ X_control = cov_matrix[control_valid]
951
+
952
+ # Check for missing values
953
+ if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
954
+ warnings.warn(
955
+ f"Missing values in covariates for group {g}, time {t}. "
956
+ "Falling back to unconditional estimation.",
957
+ UserWarning,
958
+ stacklevel=3,
959
+ )
960
+ X_treated = None
961
+ X_control = None
962
+
963
+ # Estimation method
964
+ if self.estimation_method == "reg":
965
+ att_gt, se_gt, inf_func = self._outcome_regression(
966
+ treated_change, control_change, X_treated, X_control
967
+ )
968
+ elif self.estimation_method == "ipw":
969
+ att_gt, se_gt, inf_func = self._ipw_estimation(
970
+ treated_change, control_change,
971
+ int(n_treated), int(n_control),
972
+ X_treated, X_control
973
+ )
974
+ else: # doubly robust
975
+ att_gt, se_gt, inf_func = self._doubly_robust(
976
+ treated_change, control_change, X_treated, X_control
977
+ )
978
+
979
+ # Package influence function info with unit IDs for bootstrap
980
+ n_t = int(n_treated)
981
+ inf_func_info = {
982
+ 'treated_units': list(treated_units),
983
+ 'control_units': list(control_units),
984
+ 'treated_inf': inf_func[:n_t],
985
+ 'control_inf': inf_func[n_t:],
986
+ }
987
+
988
+ return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info
989
+
990
+ def fit(
991
+ self,
992
+ data: pd.DataFrame,
993
+ outcome: str,
994
+ unit: str,
995
+ time: str,
996
+ first_treat: str,
997
+ covariates: Optional[List[str]] = None,
998
+ aggregate: Optional[str] = None,
999
+ balance_e: Optional[int] = None,
1000
+ ) -> CallawaySantAnnaResults:
1001
+ """
1002
+ Fit the Callaway-Sant'Anna estimator.
1003
+
1004
+ Parameters
1005
+ ----------
1006
+ data : pd.DataFrame
1007
+ Panel data with unit and time identifiers.
1008
+ outcome : str
1009
+ Name of outcome variable column.
1010
+ unit : str
1011
+ Name of unit identifier column.
1012
+ time : str
1013
+ Name of time period column.
1014
+ first_treat : str
1015
+ Name of column indicating when unit was first treated.
1016
+ Use 0 (or np.inf) for never-treated units.
1017
+ covariates : list, optional
1018
+ List of covariate column names for conditional parallel trends.
1019
+ aggregate : str, optional
1020
+ How to aggregate group-time effects:
1021
+ - None: Only compute ATT(g,t) (default)
1022
+ - "simple": Simple weighted average (overall ATT)
1023
+ - "event_study": Aggregate by relative time (event study)
1024
+ - "group": Aggregate by treatment cohort
1025
+ - "all": Compute all aggregations
1026
+ balance_e : int, optional
1027
+ For event study, balance the panel at relative time e.
1028
+ Ensures all groups contribute to each relative period.
1029
+
1030
+ Returns
1031
+ -------
1032
+ CallawaySantAnnaResults
1033
+ Object containing all estimation results.
1034
+
1035
+ Raises
1036
+ ------
1037
+ ValueError
1038
+ If required columns are missing or data validation fails.
1039
+ """
1040
+ # Validate inputs
1041
+ required_cols = [outcome, unit, time, first_treat]
1042
+ if covariates:
1043
+ required_cols.extend(covariates)
1044
+
1045
+ missing = [c for c in required_cols if c not in data.columns]
1046
+ if missing:
1047
+ raise ValueError(f"Missing columns: {missing}")
1048
+
1049
+ # Create working copy
1050
+ df = data.copy()
1051
+
1052
+ # Ensure numeric types
1053
+ df[time] = pd.to_numeric(df[time])
1054
+ df[first_treat] = pd.to_numeric(df[first_treat])
1055
+
1056
+ # Identify groups and time periods
1057
+ time_periods = sorted(df[time].unique())
1058
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and np.isfinite(g)])
1059
+
1060
+ # Never-treated indicator (first_treat = 0 or inf)
1061
+ df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
1062
+
1063
+ # Get unique units
1064
+ unit_info = df.groupby(unit).agg({
1065
+ first_treat: 'first',
1066
+ '_never_treated': 'first'
1067
+ }).reset_index()
1068
+
1069
+ n_treated_units = ((unit_info[first_treat] > 0) & np.isfinite(unit_info[first_treat])).sum()
1070
+ n_control_units = (unit_info['_never_treated']).sum()
1071
+
1072
+ if n_control_units == 0:
1073
+ raise ValueError("No never-treated units found. Check 'first_treat' column.")
1074
+
1075
+ # Pre-compute data structures for efficient ATT(g,t) computation
1076
+ precomputed = self._precompute_structures(
1077
+ df, outcome, unit, time, first_treat,
1078
+ covariates, time_periods, treatment_groups
1079
+ )
1080
+
1081
+ # Compute ATT(g,t) for each group-time combination
1082
+ group_time_effects = {}
1083
+ influence_func_info = {} # Store influence functions for bootstrap
1084
+
1085
+ for g in treatment_groups:
1086
+ # Periods for which we compute effects (t >= g - anticipation)
1087
+ valid_periods = [t for t in time_periods if t >= g - self.anticipation]
1088
+
1089
+ for t in valid_periods:
1090
+ att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
1091
+ precomputed, g, t, covariates
1092
+ )
1093
+
1094
+ if att_gt is not None:
1095
+ t_stat = att_gt / se_gt if se_gt > 0 else 0.0
1096
+ p_val = compute_p_value(t_stat)
1097
+ ci = compute_confidence_interval(att_gt, se_gt, self.alpha)
1098
+
1099
+ group_time_effects[(g, t)] = {
1100
+ 'effect': att_gt,
1101
+ 'se': se_gt,
1102
+ 't_stat': t_stat,
1103
+ 'p_value': p_val,
1104
+ 'conf_int': ci,
1105
+ 'n_treated': n_treat,
1106
+ 'n_control': n_ctrl,
1107
+ }
1108
+
1109
+ if inf_info is not None:
1110
+ influence_func_info[(g, t)] = inf_info
1111
+
1112
+ if not group_time_effects:
1113
+ raise ValueError(
1114
+ "Could not estimate any group-time effects. "
1115
+ "Check that data has sufficient observations."
1116
+ )
1117
+
1118
+ # Compute overall ATT (simple aggregation)
1119
+ overall_att, overall_se = self._aggregate_simple(
1120
+ group_time_effects, influence_func_info, df, unit, precomputed
1121
+ )
1122
+ overall_t = overall_att / overall_se if overall_se > 0 else 0.0
1123
+ overall_p = compute_p_value(overall_t)
1124
+ overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)
1125
+
1126
+ # Compute additional aggregations if requested
1127
+ event_study_effects = None
1128
+ group_effects = None
1129
+
1130
+ if aggregate in ["event_study", "all"]:
1131
+ event_study_effects = self._aggregate_event_study(
1132
+ group_time_effects, influence_func_info,
1133
+ treatment_groups, time_periods, balance_e
1134
+ )
1135
+
1136
+ if aggregate in ["group", "all"]:
1137
+ group_effects = self._aggregate_by_group(
1138
+ group_time_effects, influence_func_info, treatment_groups
1139
+ )
1140
+
1141
+ # Run bootstrap inference if requested
1142
+ bootstrap_results = None
1143
+ if self.n_bootstrap > 0 and influence_func_info:
1144
+ bootstrap_results = self._run_multiplier_bootstrap(
1145
+ group_time_effects=group_time_effects,
1146
+ influence_func_info=influence_func_info,
1147
+ aggregate=aggregate,
1148
+ balance_e=balance_e,
1149
+ treatment_groups=treatment_groups,
1150
+ time_periods=time_periods,
1151
+ )
1152
+
1153
+ # Update estimates with bootstrap inference
1154
+ overall_se = bootstrap_results.overall_att_se
1155
+ overall_t = overall_att / overall_se if overall_se > 0 else 0.0
1156
+ overall_p = bootstrap_results.overall_att_p_value
1157
+ overall_ci = bootstrap_results.overall_att_ci
1158
+
1159
+ # Update group-time effects with bootstrap SEs
1160
+ for gt in group_time_effects:
1161
+ if gt in bootstrap_results.group_time_ses:
1162
+ group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt]
1163
+ group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt]
1164
+ group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt]
1165
+ effect = float(group_time_effects[gt]['effect'])
1166
+ se = float(group_time_effects[gt]['se'])
1167
+ group_time_effects[gt]['t_stat'] = effect / se if se > 0 else 0.0
1168
+
1169
+ # Update event study effects with bootstrap SEs
1170
+ if (event_study_effects is not None
1171
+ and bootstrap_results.event_study_ses is not None
1172
+ and bootstrap_results.event_study_cis is not None
1173
+ and bootstrap_results.event_study_p_values is not None):
1174
+ for e in event_study_effects:
1175
+ if e in bootstrap_results.event_study_ses:
1176
+ event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e]
1177
+ event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e]
1178
+ p_val = bootstrap_results.event_study_p_values[e]
1179
+ event_study_effects[e]['p_value'] = p_val
1180
+ effect = float(event_study_effects[e]['effect'])
1181
+ se = float(event_study_effects[e]['se'])
1182
+ event_study_effects[e]['t_stat'] = effect / se if se > 0 else 0.0
1183
+
1184
+ # Update group effects with bootstrap SEs
1185
+ if (group_effects is not None
1186
+ and bootstrap_results.group_effect_ses is not None
1187
+ and bootstrap_results.group_effect_cis is not None
1188
+ and bootstrap_results.group_effect_p_values is not None):
1189
+ for g in group_effects:
1190
+ if g in bootstrap_results.group_effect_ses:
1191
+ group_effects[g]['se'] = bootstrap_results.group_effect_ses[g]
1192
+ group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g]
1193
+ group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g]
1194
+ effect = float(group_effects[g]['effect'])
1195
+ se = float(group_effects[g]['se'])
1196
+ group_effects[g]['t_stat'] = effect / se if se > 0 else 0.0
1197
+
1198
+ # Store results
1199
+ self.results_ = CallawaySantAnnaResults(
1200
+ group_time_effects=group_time_effects,
1201
+ overall_att=overall_att,
1202
+ overall_se=overall_se,
1203
+ overall_t_stat=overall_t,
1204
+ overall_p_value=overall_p,
1205
+ overall_conf_int=overall_ci,
1206
+ groups=treatment_groups,
1207
+ time_periods=time_periods,
1208
+ n_obs=len(df),
1209
+ n_treated_units=n_treated_units,
1210
+ n_control_units=n_control_units,
1211
+ alpha=self.alpha,
1212
+ control_group=self.control_group,
1213
+ event_study_effects=event_study_effects,
1214
+ group_effects=group_effects,
1215
+ bootstrap_results=bootstrap_results,
1216
+ )
1217
+
1218
+ self.is_fitted_ = True
1219
+ return self.results_
1220
+
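# --- Illustrative sketch (not part of the packaged module). ---------------
# End-to-end usage on a simulated staggered panel.  The helper name and the
# data-generating process below are invented purely for illustration.
def _sketch_simulated_fit(seed: int = 0) -> "CallawaySantAnnaResults":
    rng = np.random.default_rng(seed)
    units, periods = 200, 8
    cohort = rng.choice([0, 4, 6], size=units)           # 0 = never treated
    rows = []
    for i in range(units):
        for t in range(1, periods + 1):
            treated = cohort[i] > 0 and t >= cohort[i]
            y = 1.0 * t + (2.0 if treated else 0.0) + rng.normal()
            rows.append({"unit": i, "time": t, "y": y, "first_treat": cohort[i]})
    data = pd.DataFrame(rows)
    cs = CallawaySantAnna(n_bootstrap=0, seed=seed)
    return cs.fit(data, outcome="y", unit="unit", time="time",
                  first_treat="first_treat", aggregate="all")
# ---------------------------------------------------------------------------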
1221
+ def _outcome_regression(
1222
+ self,
1223
+ treated_change: np.ndarray,
1224
+ control_change: np.ndarray,
1225
+ X_treated: Optional[np.ndarray] = None,
1226
+ X_control: Optional[np.ndarray] = None,
1227
+ ) -> Tuple[float, float, np.ndarray]:
1228
+ """
1229
+ Estimate ATT using outcome regression.
1230
+
1231
+ With covariates:
1232
+ 1. Regress outcome changes on covariates for control group
1233
+ 2. Predict counterfactual for treated using their covariates
1234
+ 3. ATT = mean(treated_change) - mean(predicted_counterfactual)
1235
+
1236
+ Without covariates:
1237
+ Simple difference in means.
1238
+ """
1239
+ n_t = len(treated_change)
1240
+ n_c = len(control_change)
1241
+
1242
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
1243
+ # Covariate-adjusted outcome regression
1244
+ # Fit regression on control units: E[Delta Y | X, D=0]
1245
+ beta, residuals = _linear_regression(X_control, control_change)
1246
+
1247
+ # Predict counterfactual for treated units
1248
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
1249
+ predicted_control = X_treated_with_intercept @ beta
1250
+
1251
+ # ATT = mean(observed treated change - predicted counterfactual)
1252
+ att = np.mean(treated_change - predicted_control)
1253
+
1254
+ # Standard error using sandwich estimator
1255
+ # Variance from treated: Var(Y_1 - m(X))
1256
+ treated_residuals = treated_change - predicted_control
1257
+ var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
1258
+
1259
+ # Variance from control regression (residual variance)
1260
+ var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0
1261
+
1262
+ # Approximate SE (ignoring estimation error in beta for simplicity)
1263
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
1264
+
1265
+ # Influence function
1266
+ inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
1267
+ inf_control = -residuals / n_c
1268
+ inf_func = np.concatenate([inf_treated, inf_control])
1269
+ else:
1270
+ # Simple difference in means (no covariates)
1271
+ att = np.mean(treated_change) - np.mean(control_change)
1272
+
1273
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
1274
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
1275
+
1276
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
1277
+
1278
+ # Influence function (for aggregation)
1279
+ inf_treated = treated_change - np.mean(treated_change)
1280
+ inf_control = control_change - np.mean(control_change)
1281
+ inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])
1282
+
1283
+ return att, se, inf_func
1284
+
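# --- Illustrative sketch (not part of the packaged module). ---------------
# The covariate-adjusted branch above is plain regression adjustment: fit
# E[dY | X] on the controls, predict that counterfactual change for the
# treated, and average the gap.  Compact version (hypothetical helper name):
def _sketch_regression_adjustment(dy_treat: np.ndarray, dy_ctrl: np.ndarray,
                                  X_treat: np.ndarray, X_ctrl: np.ndarray) -> float:
    beta, _ = _linear_regression(X_ctrl, dy_ctrl)
    m_treat = np.column_stack([np.ones(len(dy_treat)), X_treat]) @ beta
    return float(np.mean(dy_treat - m_treat))
# ---------------------------------------------------------------------------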
1285
+ def _ipw_estimation(
1286
+ self,
1287
+ treated_change: np.ndarray,
1288
+ control_change: np.ndarray,
1289
+ n_treated: int,
1290
+ n_control: int,
1291
+ X_treated: Optional[np.ndarray] = None,
1292
+ X_control: Optional[np.ndarray] = None,
1293
+ ) -> Tuple[float, float, np.ndarray]:
1294
+ """
1295
+ Estimate ATT using inverse probability weighting.
1296
+
1297
+ With covariates:
1298
+ 1. Estimate propensity score P(D=1|X) using logistic regression
1299
+ 2. Reweight control units to match treated covariate distribution
1300
+ 3. ATT = mean(treated) - weighted_mean(control)
1301
+
1302
+ Without covariates:
1303
+ Simple difference in means with unconditional propensity weighting.
1304
+ """
1305
+ n_t = len(treated_change)
1306
+ n_c = len(control_change)
1307
+ n_total = n_treated + n_control
1308
+
1309
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
1310
+ # Covariate-adjusted IPW estimation
1311
+ # Stack covariates and create treatment indicator
1312
+ X_all = np.vstack([X_treated, X_control])
1313
+ D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
1314
+
1315
+ # Estimate propensity scores using logistic regression
1316
+ try:
1317
+ _, pscore = _logistic_regression(X_all, D)
1318
+ except (np.linalg.LinAlgError, ValueError):
1319
+ # Fallback to unconditional if logistic regression fails
1320
+ warnings.warn(
1321
+ "Propensity score estimation failed. "
1322
+ "Falling back to unconditional estimation.",
1323
+ UserWarning,
1324
+ stacklevel=4,
1325
+ )
1326
+ pscore = np.full(len(D), n_t / (n_t + n_c))
1327
+
1328
+ # Propensity scores for treated and control
1329
+ pscore_treated = pscore[:n_t]
1330
+ pscore_control = pscore[n_t:]
1331
+
1332
+ # Clip propensity scores to avoid extreme weights
1333
+ pscore_control = np.clip(pscore_control, 0.01, 0.99)
1334
+ pscore_treated = np.clip(pscore_treated, 0.01, 0.99)
1335
+
1336
+ # IPW weights for control units: p(X) / (1 - p(X))
1337
+ # This reweights controls to have same covariate distribution as treated
1338
+ weights_control = pscore_control / (1 - pscore_control)
1339
+ weights_control = weights_control / np.sum(weights_control) # normalize
1340
+
1341
+ # ATT = mean(treated) - weighted_mean(control)
1342
+ att = np.mean(treated_change) - np.sum(weights_control * control_change)
1343
+
1344
+ # Compute standard error
1345
+ # Variance of treated mean
1346
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
1347
+
1348
+ # Variance of weighted control mean
1349
+ weighted_var_c = np.sum(weights_control * (control_change - np.sum(weights_control * control_change)) ** 2)
1350
+
1351
+ se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0
1352
+
1353
+ # Influence function
1354
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
1355
+ inf_control = -weights_control * (control_change - np.sum(weights_control * control_change))
1356
+ inf_func = np.concatenate([inf_treated, inf_control])
1357
+ else:
1358
+ # Unconditional IPW (reduces to difference in means)
1359
+ p_treat = n_treated / n_total # unconditional propensity score
1360
+
1361
+ att = np.mean(treated_change) - np.mean(control_change)
1362
+
1363
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
1364
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
1365
+
1366
+ # Adjusted variance for IPW
1367
+ se = np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) if (n_t > 0 and n_c > 0 and p_treat > 0) else 0.0
1368
+
1369
+ # Influence function (for aggregation)
1370
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
1371
+ inf_control = (control_change - np.mean(control_change)) / n_c
1372
+ inf_func = np.concatenate([inf_treated, -inf_control])
1373
+
1374
+ return att, se, inf_func
1375
+
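# --- Illustrative sketch (not part of the packaged module). ---------------
# The IPW branch reweights control units by the propensity odds p/(1-p) so
# their covariate distribution mimics the treated group.  Minimal version of
# that reweighting step (hypothetical name; same 0.01/0.99 clipping as above):
def _sketch_ipw_att(dy_treat: np.ndarray, dy_ctrl: np.ndarray,
                    X_treat: np.ndarray, X_ctrl: np.ndarray) -> float:
    X = np.vstack([X_treat, X_ctrl])
    d = np.concatenate([np.ones(len(dy_treat)), np.zeros(len(dy_ctrl))])
    _, p = _logistic_regression(X, d)
    p_ctrl = np.clip(p[len(dy_treat):], 0.01, 0.99)
    w = p_ctrl / (1 - p_ctrl)        # odds reweight controls toward treated
    w = w / w.sum()
    return float(np.mean(dy_treat) - np.sum(w * dy_ctrl))
# ---------------------------------------------------------------------------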
1376
+ def _doubly_robust(
1377
+ self,
1378
+ treated_change: np.ndarray,
1379
+ control_change: np.ndarray,
1380
+ X_treated: Optional[np.ndarray] = None,
1381
+ X_control: Optional[np.ndarray] = None,
1382
+ ) -> Tuple[float, float, np.ndarray]:
1383
+ """
1384
+ Estimate ATT using doubly robust estimation.
1385
+
1386
+ With covariates:
1387
+ Combines outcome regression and IPW for double robustness.
1388
+ The estimator is consistent if either the outcome model OR
1389
+ the propensity model is correctly specified.
1390
+
1391
+ ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
1392
+ + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]
1393
+
1394
+ where m(X) is the outcome model and w_i are IPW weights.
1395
+
1396
+ Without covariates:
1397
+ Reduces to simple difference in means.
1398
+ """
1399
+ n_t = len(treated_change)
1400
+ n_c = len(control_change)
1401
+
1402
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
1403
+ # Doubly robust estimation with covariates
1404
+ # Step 1: Outcome regression - fit E[Delta Y | X] on control
1405
+ beta, _ = _linear_regression(X_control, control_change)
1406
+
1407
+ # Predict counterfactual for both treated and control
1408
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
1409
+ X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
1410
+ m_treated = X_treated_with_intercept @ beta
1411
+ m_control = X_control_with_intercept @ beta
1412
+
1413
+ # Step 2: Propensity score estimation
1414
+ X_all = np.vstack([X_treated, X_control])
1415
+ D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
1416
+
1417
+ try:
1418
+ _, pscore = _logistic_regression(X_all, D)
1419
+ except (np.linalg.LinAlgError, ValueError):
1420
+ # Fallback to unconditional if logistic regression fails
1421
+ pscore = np.full(len(D), n_t / (n_t + n_c))
1422
+
1423
+ pscore_control = pscore[n_t:]
1424
+
1425
+ # Clip propensity scores
1426
+ pscore_control = np.clip(pscore_control, 0.01, 0.99)
1427
+
1428
+ # IPW weights for control: p(X) / (1 - p(X))
1429
+ weights_control = pscore_control / (1 - pscore_control)
1430
+
1431
+ # Step 3: Doubly robust ATT
1432
+ # ATT = mean(treated - m(X_treated))
1433
+ # + weighted_mean_control((m(X) - Y) * weight)
1434
+ att_treated_part = np.mean(treated_change - m_treated)
1435
+
1436
+ # Augmentation term from control
1437
+ augmentation = np.sum(weights_control * (m_control - control_change)) / n_t
1438
+
1439
+ att = att_treated_part + augmentation
1440
+
1441
+ # Step 4: Standard error using influence function
1442
+ # Influence function for DR estimator
1443
+ psi_treated = (treated_change - m_treated - att) / n_t
1444
+ psi_control = (weights_control * (m_control - control_change)) / n_t
1445
+
1446
+ # Variance is sum of squared influence functions
1447
+ var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2)
1448
+ se = np.sqrt(var_psi) if var_psi > 0 else 0.0
1449
+
1450
+ # Full influence function
1451
+ inf_func = np.concatenate([psi_treated, psi_control])
1452
+ else:
1453
+ # Without covariates, DR simplifies to difference in means
1454
+ att = np.mean(treated_change) - np.mean(control_change)
1455
+
1456
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
1457
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
1458
+
1459
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
1460
+
1461
+ # Influence function for DR estimator
1462
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
1463
+ inf_control = (control_change - np.mean(control_change)) / n_c
1464
+ inf_func = np.concatenate([inf_treated, -inf_control])
1465
+
1466
+ return att, se, inf_func
1467
+
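# --- Illustrative sketch (not part of the packaged module). ---------------
# The doubly robust estimator combines the two pieces above: the treated-arm
# regression residual plus an odds-weighted augmentation from the controls,
# mirroring the ATT_DR formula in the docstring.  Compact version
# (hypothetical name, same clipping as the method above):
def _sketch_dr_att(dy_treat: np.ndarray, dy_ctrl: np.ndarray,
                   X_treat: np.ndarray, X_ctrl: np.ndarray) -> float:
    n_t, n_c = len(dy_treat), len(dy_ctrl)
    beta, _ = _linear_regression(X_ctrl, dy_ctrl)
    m_treat = np.column_stack([np.ones(n_t), X_treat]) @ beta
    m_ctrl = np.column_stack([np.ones(n_c), X_ctrl]) @ beta
    _, p = _logistic_regression(np.vstack([X_treat, X_ctrl]),
                                np.concatenate([np.ones(n_t), np.zeros(n_c)]))
    p_ctrl = np.clip(p[n_t:], 0.01, 0.99)
    w = p_ctrl / (1 - p_ctrl)
    return float(np.mean(dy_treat - m_treat) + np.sum(w * (m_ctrl - dy_ctrl)) / n_t)
# ---------------------------------------------------------------------------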
1468
+ def _aggregate_simple(
1469
+ self,
1470
+ group_time_effects: Dict,
1471
+ influence_func_info: Dict,
1472
+ df: pd.DataFrame,
1473
+ unit: str,
1474
+ precomputed: Optional[PrecomputedData] = None,
1475
+ ) -> Tuple[float, float]:
1476
+ """
1477
+ Compute simple weighted average of ATT(g,t).
1478
+
1479
+ Weights by group size (number of treated units).
1480
+
1481
+ Standard errors are computed using influence function aggregation,
1482
+ which properly accounts for covariances across (g,t) pairs due to
1483
+ shared control units. This includes the wif (weight influence function)
1484
+ adjustment from R's `did` package that accounts for uncertainty in
1485
+ estimating the group-size weights.
1486
+ """
1487
+ effects = []
1488
+ weights_list = []
1489
+ gt_pairs = []
1490
+ groups_for_gt = []
1491
+
1492
+ for (g, t), data in group_time_effects.items():
1493
+ effects.append(data['effect'])
1494
+ weights_list.append(data['n_treated'])
1495
+ gt_pairs.append((g, t))
1496
+ groups_for_gt.append(g)
1497
+
1498
+ effects = np.array(effects)
1499
+ weights = np.array(weights_list, dtype=float)
1500
+ groups_for_gt = np.array(groups_for_gt)
1501
+
1502
+ # Normalize weights
1503
+ total_weight = np.sum(weights)
1504
+ weights_norm = weights / total_weight
1505
+
1506
+ # Weighted average
1507
+ overall_att = np.sum(weights_norm * effects)
1508
+
1509
+ # Compute SE using influence function aggregation with wif adjustment
1510
+ overall_se = self._compute_aggregated_se_with_wif(
1511
+ gt_pairs, weights_norm, effects, groups_for_gt,
1512
+ influence_func_info, df, unit, precomputed
1513
+ )
1514
+
1515
+ return overall_att, overall_se
1516
+
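# --- Illustrative sketch (not part of the packaged module). ---------------
# The "simple" aggregation is just a group-size-weighted mean of the ATT(g,t)
# estimates.  For example, ATT = 1.0 on 30 treated units and ATT = 2.0 on 10
# treated units gives an overall ATT of (30*1.0 + 10*2.0) / 40 = 1.25.
# (Hypothetical helper name.)
def _sketch_simple_average(effects: np.ndarray, n_treated: np.ndarray) -> float:
    w = n_treated / n_treated.sum()
    return float(np.sum(w * effects))
# ---------------------------------------------------------------------------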
1517
+ def _compute_aggregated_se(
1518
+ self,
1519
+ gt_pairs: List[Tuple[Any, Any]],
1520
+ weights: np.ndarray,
1521
+ influence_func_info: Dict,
1522
+ ) -> float:
1523
+ """
1524
+ Compute standard error using influence function aggregation.
1525
+
1526
+ This properly accounts for covariances across (g,t) pairs by
1527
+ aggregating unit-level influence functions:
1528
+
1529
+ ψ_i(overall) = Σ_{(g,t)} w_(g,t) × ψ_i(g,t)
1530
+ Var(overall) = (1/n) Σ_i [ψ_i]²
1531
+
1532
+ This matches R's `did` package analytical SE formula.
1533
+ """
1534
+ if not influence_func_info:
1535
+ # Fallback if no influence functions available
1536
+ return 0.0
1537
+
1538
+ # Build unit index mapping from all (g,t) pairs
1539
+ all_units = set()
1540
+ for (g, t) in gt_pairs:
1541
+ if (g, t) in influence_func_info:
1542
+ info = influence_func_info[(g, t)]
1543
+ all_units.update(info['treated_units'])
1544
+ all_units.update(info['control_units'])
1545
+
1546
+ if not all_units:
1547
+ return 0.0
1548
+
1549
+ all_units = sorted(all_units)
1550
+ n_units = len(all_units)
1551
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
1552
+
1553
+ # Aggregate influence functions across (g,t) pairs
1554
+ psi_overall = np.zeros(n_units)
1555
+
1556
+ for j, (g, t) in enumerate(gt_pairs):
1557
+ if (g, t) not in influence_func_info:
1558
+ continue
1559
+
1560
+ info = influence_func_info[(g, t)]
1561
+ w = weights[j]
1562
+
1563
+ # Treated unit contributions
1564
+ for i, unit_id in enumerate(info['treated_units']):
1565
+ idx = unit_to_idx[unit_id]
1566
+ psi_overall[idx] += w * info['treated_inf'][i]
1567
+
1568
+ # Control unit contributions
1569
+ for i, unit_id in enumerate(info['control_units']):
1570
+ idx = unit_to_idx[unit_id]
1571
+ psi_overall[idx] += w * info['control_inf'][i]
1572
+
1573
+ # The per-unit contributions are already scaled by cell sizes (1/n_t, 1/n_c),
+ # so the variance of the aggregate is the plain sum of squared contributions.
1574
+ variance = np.sum(psi_overall ** 2)
1575
+ return np.sqrt(variance)
1576
+
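# --- Illustrative sketch (not part of the packaged module). ---------------
# Influence-function aggregation in a nutshell: each (g,t) estimate carries a
# per-unit contribution psi_i(g,t); weighting and summing those gives one
# per-unit contribution to the overall ATT, and since the contributions are
# pre-scaled by cell sizes the variance is the plain sum of squares.
# (Hypothetical name; `psi` maps (g,t) -> aligned (n_units,) arrays.)
def _sketch_aggregate_se(psi: Dict[Tuple[Any, Any], np.ndarray],
                         weights: Dict[Tuple[Any, Any], float]) -> float:
    psi_overall = sum(weights[gt] * psi[gt] for gt in psi)
    return float(np.sqrt(np.sum(psi_overall ** 2)))
# ---------------------------------------------------------------------------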
1577
+ def _compute_aggregated_se_with_wif(
1578
+ self,
1579
+ gt_pairs: List[Tuple[Any, Any]],
1580
+ weights: np.ndarray,
1581
+ effects: np.ndarray,
1582
+ groups_for_gt: np.ndarray,
1583
+ influence_func_info: Dict,
1584
+ df: pd.DataFrame,
1585
+ unit: str,
1586
+ precomputed: Optional[PrecomputedData] = None,
1587
+ ) -> float:
1588
+ """
1589
+ Compute SE with weight influence function (wif) adjustment.
1590
+
1591
+ This matches R's `did` package approach for "simple" aggregation,
1592
+ which accounts for uncertainty in estimating group-size weights.
1593
+
1594
+ The wif adjustment adds the variance that arises because the aggregation
1595
+ weights w_g = n_g / N depend on estimated group sizes.
1596
+
1597
+ Formula (matching R's did::aggte):
1598
+ agg_inf_i = Σ_k w_k × inf_{i,k} + Σ_k wif_{i,k} × ATT_k
1599
+ se = sqrt(mean(agg_inf^2) / n)
1600
+
1601
+ where:
1602
+ - k indexes "keepers" (post-treatment (g,t) pairs)
1603
+ - w_k = pg[k] / sum(pg[keepers]) where pg = n_g / n_all
1604
+ - wif captures how unit i influences the weight estimation
1605
+ """
1606
+ if not influence_func_info:
1607
+ return 0.0
1608
+
1609
+ # Build unit index mapping
1610
+ all_units_set: Set[Any] = set()
1611
+ for (g, t) in gt_pairs:
1612
+ if (g, t) in influence_func_info:
1613
+ info = influence_func_info[(g, t)]
1614
+ all_units_set.update(info['treated_units'])
1615
+ all_units_set.update(info['control_units'])
1616
+
1617
+ if not all_units_set:
1618
+ return 0.0
1619
+
1620
+ all_units = sorted(all_units_set)
1621
+ n_units = len(all_units)
1622
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
1623
+
1624
+ # Get unique groups and their information
1625
+ unique_groups = sorted(set(groups_for_gt))
1626
+ unique_groups_set = set(unique_groups)
1627
+ group_to_idx = {g: i for i, g in enumerate(unique_groups)}
1628
+
1629
+ # Compute group-level probabilities matching R's formula:
1630
+ # pg[g] = n_g / n_all (fraction of ALL units in group g)
1631
+ # This differs from our old formula which used n_g / total_treated
1632
+ group_sizes = {}
1633
+ for g in unique_groups:
1634
+ treated_in_g = df[df['first_treat'] == g][unit].nunique()
1635
+ group_sizes[g] = treated_in_g
1636
+
1637
+ # pg indexed by group
1638
+ pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups])
1639
+
1640
+ # pg indexed by keeper (each (g,t) pair gets its group's pg)
1641
+ # This matches R's: pg <- pgg[match(group, originalglist)]
1642
+ pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
1643
+ sum_pg_keepers = np.sum(pg_keepers)
1644
+
1645
+ # Guard against zero weights (no keepers = no variance)
1646
+ if sum_pg_keepers == 0:
1647
+ return 0.0
1648
+
1649
+ # Standard aggregated influence (without wif)
1650
+ psi_standard = np.zeros(n_units)
1651
+
1652
+ for j, (g, t) in enumerate(gt_pairs):
1653
+ if (g, t) not in influence_func_info:
1654
+ continue
1655
+
1656
+ info = influence_func_info[(g, t)]
1657
+ w = weights[j]
1658
+
1659
+ # Vectorized influence function aggregation for treated units
1660
+ treated_indices = np.array([unit_to_idx[uid] for uid in info['treated_units']])
1661
+ if len(treated_indices) > 0:
1662
+ np.add.at(psi_standard, treated_indices, w * info['treated_inf'])
1663
+
1664
+ # Vectorized influence function aggregation for control units
1665
+ control_indices = np.array([unit_to_idx[uid] for uid in info['control_units']])
1666
+ if len(control_indices) > 0:
1667
+ np.add.at(psi_standard, control_indices, w * info['control_inf'])
1668
+
1669
+ # Build unit-group array using precomputed data if available
1670
+ # This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
1671
+ if precomputed is not None:
1672
+ # Use precomputed cohort mapping
1673
+ precomputed_units = precomputed['all_units']
1674
+ precomputed_cohorts = precomputed['unit_cohorts']
1675
+ precomputed_unit_to_idx = precomputed['unit_to_idx']
1676
+
1677
+ # Build unit_groups_array for the units in this SE computation
1678
+ # A value of -1 indicates never-treated or other (not in unique_groups)
1679
+ unit_groups_array = np.full(n_units, -1, dtype=np.float64)
1680
+ for i, uid in enumerate(all_units):
1681
+ if uid in precomputed_unit_to_idx:
1682
+ cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
1683
+ if cohort in unique_groups_set:
1684
+ unit_groups_array[i] = cohort
1685
+ else:
1686
+ # Fallback: build from DataFrame (slow path for backward compatibility)
1687
+ unit_groups_array = np.full(n_units, -1, dtype=np.float64)
1688
+ for i, uid in enumerate(all_units):
1689
+ unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
1690
+ if unit_first_treat in unique_groups_set:
1691
+ unit_groups_array[i] = unit_first_treat
1692
+
1693
+ # Vectorized WIF computation
1694
+ # R's wif formula:
1695
+ # if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
1696
+ # if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
1697
+ # wif[i,k] = if1[i,k] - if2[i,k]
1698
+ # wif_contrib[i] = sum_k(wif[i,k] * att[k])
1699
+
1700
+ # Build indicator matrix: (n_units, n_keepers)
1701
+ # indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
1702
+ groups_for_gt_array = np.array(groups_for_gt)
1703
+ indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)
1704
+
1705
+ # Vectorized indicator_sum: sum over keepers
1706
+ # indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
1707
+ indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
1708
+
1709
+ # Vectorized wif matrix computation
1710
+ # if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
1711
+ if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
1712
+ # if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
1713
+ if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
1714
+ wif_matrix = if1_matrix - if2_matrix
1715
+
1716
+ # Single matrix-vector multiply for all contributions
1717
+ # wif_contrib[i] = sum_k(wif[i,k] * att[k])
1718
+ wif_contrib = wif_matrix @ effects
1719
+
1720
+ # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
1721
+ psi_wif = wif_contrib / n_units
1722
+
1723
+ # Combine standard and wif terms
1724
+ psi_total = psi_standard + psi_wif
1725
+
1726
+ # Compute variance and SE
1727
+ # R's formula sqrt(mean(IF^2)/n) = sqrt(sum((IF/n)^2)); psi_total already carries the 1/n factor
1728
+ variance = np.sum(psi_total ** 2)
1729
+ return np.sqrt(variance)
1730
+
1731
+ def _aggregate_event_study(
1732
+ self,
1733
+ group_time_effects: Dict,
1734
+ influence_func_info: Dict,
1735
+ groups: List[Any],
1736
+ time_periods: List[Any],
1737
+ balance_e: Optional[int] = None,
1738
+ ) -> Dict[int, Dict[str, Any]]:
1739
+ """
1740
+ Aggregate effects by relative time (event study).
1741
+
1742
+ Computes average effect at each event time e = t - g.
1743
+
1744
+ Standard errors use influence function aggregation to account for
1745
+ covariances across (g,t) pairs.
1746
+ """
1747
+ # Organize effects by relative time, keeping track of (g,t) pairs
1748
+ effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {}
1749
+
1750
+ for (g, t), data in group_time_effects.items():
1751
+ e = t - g # Relative time
1752
+ if e not in effects_by_e:
1753
+ effects_by_e[e] = []
1754
+ effects_by_e[e].append((
1755
+ (g, t), # Keep track of the (g,t) pair
1756
+ data['effect'],
1757
+ data['n_treated']
1758
+ ))
1759
+
1760
+ # Balance the panel if requested
1761
+ if balance_e is not None:
1762
+ # Keep only groups that have effects at relative time balance_e
1763
+ groups_at_e = set()
1764
+ for (g, t), data in group_time_effects.items():
1765
+ if t - g == balance_e:
1766
+ groups_at_e.add(g)
1767
+
1768
+ # Filter effects to only include balanced groups
1769
+ balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {}
1770
+ for (g, t), data in group_time_effects.items():
1771
+ if g in groups_at_e:
1772
+ e = t - g
1773
+ if e not in balanced_effects:
1774
+ balanced_effects[e] = []
1775
+ balanced_effects[e].append((
1776
+ (g, t),
1777
+ data['effect'],
1778
+ data['n_treated']
1779
+ ))
1780
+ effects_by_e = balanced_effects
1781
+
1782
+ # Compute aggregated effects
1783
+ event_study_effects = {}
1784
+
1785
+ for e, effect_list in sorted(effects_by_e.items()):
1786
+ gt_pairs = [x[0] for x in effect_list]
1787
+ effs = np.array([x[1] for x in effect_list])
1788
+ ns = np.array([x[2] for x in effect_list], dtype=float)
1789
+
1790
+ # Weight by group size
1791
+ weights = ns / np.sum(ns)
1792
+
1793
+ agg_effect = np.sum(weights * effs)
1794
+
1795
+ # Compute SE using influence function aggregation
1796
+ agg_se = self._compute_aggregated_se(
1797
+ gt_pairs, weights, influence_func_info
1798
+ )
1799
+
1800
+ t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
1801
+ p_val = compute_p_value(t_stat)
1802
+ ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)
1803
+
1804
+ event_study_effects[e] = {
1805
+ 'effect': agg_effect,
1806
+ 'se': agg_se,
1807
+ 't_stat': t_stat,
1808
+ 'p_value': p_val,
1809
+ 'conf_int': ci,
1810
+ 'n_groups': len(effect_list),
1811
+ }
1812
+
1813
+ return event_study_effects
1814
+
1815
+ def _aggregate_by_group(
1816
+ self,
1817
+ group_time_effects: Dict,
1818
+ influence_func_info: Dict,
1819
+ groups: List[Any],
1820
+ ) -> Dict[Any, Dict[str, Any]]:
1821
+ """
1822
+ Aggregate effects by treatment cohort.
1823
+
1824
+ Computes average effect for each cohort across all post-treatment periods.
1825
+
1826
+ Standard errors use influence function aggregation to account for
1827
+ covariances across time periods within a cohort.
1828
+ """
1829
+ group_effects = {}
1830
+
1831
+ for g in groups:
1832
+ # Get all effects for this group (post-treatment only: t >= g)
1833
+ # Keep track of (g, t) pairs for influence function aggregation
1834
+ g_effects = [
1835
+ ((g, t), data['effect'])
1836
+ for (gg, t), data in group_time_effects.items()
1837
+ if gg == g and t >= g
1838
+ ]
1839
+
1840
+ if not g_effects:
1841
+ continue
1842
+
1843
+ gt_pairs = [x[0] for x in g_effects]
1844
+ effs = np.array([x[1] for x in g_effects])
1845
+
1846
+ # Equal weight across time periods for a group
1847
+ weights = np.ones(len(effs)) / len(effs)
1848
+
1849
+ agg_effect = np.sum(weights * effs)
1850
+
1851
+ # Compute SE using influence function aggregation
1852
+ agg_se = self._compute_aggregated_se(
1853
+ gt_pairs, weights, influence_func_info
1854
+ )
1855
+
1856
+ t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
1857
+ p_val = compute_p_value(t_stat)
1858
+ ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)
1859
+
1860
+ group_effects[g] = {
1861
+ 'effect': agg_effect,
1862
+ 'se': agg_se,
1863
+ 't_stat': t_stat,
1864
+ 'p_value': p_val,
1865
+ 'conf_int': ci,
1866
+ 'n_periods': len(g_effects),
1867
+ }
1868
+
1869
+ return group_effects
1870
+
1871
+ def _run_multiplier_bootstrap(
1872
+ self,
1873
+ group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
1874
+ influence_func_info: Dict[Tuple[Any, Any], Dict[str, Any]],
1875
+ aggregate: Optional[str],
1876
+ balance_e: Optional[int],
1877
+ treatment_groups: List[Any],
1878
+ time_periods: List[Any],
1879
+ ) -> CSBootstrapResults:
1880
+ """
1881
+ Run multiplier bootstrap for inference on all parameters.
1882
+
1883
+ This implements the multiplier bootstrap procedure from Callaway & Sant'Anna (2021).
1884
+ The key idea is to perturb the influence function contributions with random
1885
+ weights at the cluster (unit) level, then recompute aggregations.
1886
+
1887
+ Parameters
1888
+ ----------
1889
+ group_time_effects : dict
1890
+ Dictionary of ATT(g,t) effects with analytical SEs.
1891
+ influence_func_info : dict
1892
+ Dictionary mapping (g,t) to influence function information.
1893
+ aggregate : str, optional
1894
+ Type of aggregation requested.
1895
+ balance_e : int, optional
1896
+ Balance parameter for event study.
1897
+ treatment_groups : list
1898
+ List of treatment cohorts.
1899
+ time_periods : list
1900
+ List of time periods.
1901
+
1902
+ Returns
1903
+ -------
1904
+ CSBootstrapResults
1905
+ Bootstrap inference results.
1906
+ """
1907
+ # Warn about low bootstrap iterations
1908
+ if self.n_bootstrap < 50:
1909
+ warnings.warn(
1910
+ f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
1911
+ "for reliable inference. Percentile confidence intervals and p-values "
1912
+ "may be unreliable with few iterations.",
1913
+ UserWarning,
1914
+ stacklevel=3,
1915
+ )
1916
+
1917
+ rng = np.random.default_rng(self.seed)
1918
+
1919
+ # Collect all unique units across all (g,t) combinations
1920
+ all_units = set()
1921
+ for (g, t), info in influence_func_info.items():
1922
+ all_units.update(info['treated_units'])
1923
+ all_units.update(info['control_units'])
1924
+ all_units = sorted(all_units)
1925
+ n_units = len(all_units)
1926
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
1927
+
1928
+ # Get list of (g,t) pairs
1929
+ gt_pairs = list(group_time_effects.keys())
1930
+ n_gt = len(gt_pairs)
1931
+
1932
+ # Compute aggregation weights for overall ATT
1933
+ overall_weights = np.array([
1934
+ group_time_effects[gt]['n_treated'] for gt in gt_pairs
1935
+ ], dtype=float)
1936
+ overall_weights = overall_weights / np.sum(overall_weights)
1937
+
1938
+ # Original point estimates
1939
+ original_atts = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs])
1940
+ original_overall = np.sum(overall_weights * original_atts)
1941
+
1942
+ # Prepare event study and group aggregation info if needed
1943
+ event_study_info = None
1944
+ group_agg_info = None
1945
+
1946
+ if aggregate in ["event_study", "all"]:
1947
+ event_study_info = self._prepare_event_study_aggregation(
1948
+ gt_pairs, group_time_effects, balance_e
1949
+ )
1950
+
1951
+ if aggregate in ["group", "all"]:
1952
+ group_agg_info = self._prepare_group_aggregation(
1953
+ gt_pairs, group_time_effects, treatment_groups
1954
+ )
1955
+
1956
+ # Pre-compute unit index arrays for each (g,t) pair (done once, not per iteration)
1957
+ gt_treated_indices = []
1958
+ gt_control_indices = []
1959
+ gt_treated_inf = []
1960
+ gt_control_inf = []
1961
+
1962
+ for gt in gt_pairs:
1963
+ info = influence_func_info[gt]
1964
+ treated_idx = np.array([unit_to_idx[u] for u in info['treated_units']])
1965
+ control_idx = np.array([unit_to_idx[u] for u in info['control_units']])
1966
+ gt_treated_indices.append(treated_idx)
1967
+ gt_control_indices.append(control_idx)
1968
+ gt_treated_inf.append(np.asarray(info['treated_inf']))
1969
+ gt_control_inf.append(np.asarray(info['control_inf']))
1970
+
1971
+ # Generate ALL bootstrap weights upfront: shape (n_bootstrap, n_units)
1972
+ # This is much faster than generating one at a time
1973
+ all_bootstrap_weights = _generate_bootstrap_weights_batch(
1974
+ self.n_bootstrap, n_units, self.bootstrap_weight_type, rng
1975
+ )
1976
+
1977
+ # Vectorized bootstrap ATT(g,t) computation
1978
+ # Compute all bootstrap ATTs for all (g,t) pairs using matrix operations
1979
+ bootstrap_atts_gt = np.zeros((self.n_bootstrap, n_gt))
1980
+
1981
+ for j in range(n_gt):
1982
+ treated_idx = gt_treated_indices[j]
1983
+ control_idx = gt_control_indices[j]
1984
+ treated_inf = gt_treated_inf[j]
1985
+ control_inf = gt_control_inf[j]
1986
+
1987
+ # Extract weights for this (g,t)'s units across all bootstrap iterations
1988
+ # Shape: (n_bootstrap, n_treated) and (n_bootstrap, n_control)
1989
+ treated_weights = all_bootstrap_weights[:, treated_idx]
1990
+ control_weights = all_bootstrap_weights[:, control_idx]
1991
+
1992
+ # Vectorized perturbation: matrix-vector multiply
1993
+ # Shape: (n_bootstrap,)
1994
+ perturbations = (
1995
+ treated_weights @ treated_inf +
1996
+ control_weights @ control_inf
1997
+ )
1998
+
1999
+ bootstrap_atts_gt[:, j] = original_atts[j] + perturbations
2000
+
2001
+ # Vectorized overall ATT: matrix-vector multiply
2002
+ # Shape: (n_bootstrap,)
2003
+ bootstrap_overall = bootstrap_atts_gt @ overall_weights
2004
+
2005
+ # Vectorized event study aggregation
2006
+ if event_study_info is not None:
2007
+ rel_periods = sorted(event_study_info.keys())
2008
+ bootstrap_event_study = {}
2009
+ for e in rel_periods:
2010
+ agg_info = event_study_info[e]
2011
+ gt_indices = agg_info['gt_indices']
2012
+ weights = agg_info['weights']
2013
+ # Vectorized: select columns and multiply by weights
2014
+ bootstrap_event_study[e] = bootstrap_atts_gt[:, gt_indices] @ weights
2015
+ else:
2016
+ bootstrap_event_study = None
2017
+
2018
+ # Vectorized group aggregation
2019
+ if group_agg_info is not None:
2020
+ groups = sorted(group_agg_info.keys())
2021
+ bootstrap_group = {}
2022
+ for g in groups:
2023
+ agg_info = group_agg_info[g]
2024
+ gt_indices = agg_info['gt_indices']
2025
+ weights = agg_info['weights']
2026
+ bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights
2027
+ else:
2028
+ bootstrap_group = None
2029
+
2030
+ # Compute bootstrap statistics for ATT(g,t)
2031
+ gt_ses = {}
2032
+ gt_cis = {}
2033
+ gt_p_values = {}
2034
+
2035
+ for j, gt in enumerate(gt_pairs):
2036
+ se, ci, p_value = self._compute_effect_bootstrap_stats(
2037
+ original_atts[j], bootstrap_atts_gt[:, j]
2038
+ )
2039
+ gt_ses[gt] = se
2040
+ gt_cis[gt] = ci
2041
+ gt_p_values[gt] = p_value
2042
+
2043
+ # Compute bootstrap statistics for overall ATT
2044
+ overall_se, overall_ci, overall_p_value = self._compute_effect_bootstrap_stats(
2045
+ original_overall, bootstrap_overall
2046
+ )
2047
+
2048
+ # Compute bootstrap statistics for event study effects
2049
+ event_study_ses = None
2050
+ event_study_cis = None
2051
+ event_study_p_values = None
2052
+
2053
+ if bootstrap_event_study is not None and event_study_info is not None:
2054
+ event_study_ses = {}
2055
+ event_study_cis = {}
2056
+ event_study_p_values = {}
2057
+
2058
+ for e in rel_periods:
2059
+ se, ci, p_value = self._compute_effect_bootstrap_stats(
2060
+ event_study_info[e]['effect'], bootstrap_event_study[e]
2061
+ )
2062
+ event_study_ses[e] = se
2063
+ event_study_cis[e] = ci
2064
+ event_study_p_values[e] = p_value
2065
+
2066
+ # Compute bootstrap statistics for group effects
2067
+ group_effect_ses = None
2068
+ group_effect_cis = None
2069
+ group_effect_p_values = None
2070
+
2071
+ if bootstrap_group is not None and group_agg_info is not None:
2072
+ group_effect_ses = {}
2073
+ group_effect_cis = {}
2074
+ group_effect_p_values = {}
2075
+
2076
+ for g in groups:
2077
+ se, ci, p_value = self._compute_effect_bootstrap_stats(
2078
+ group_agg_info[g]['effect'], bootstrap_group[g]
2079
+ )
2080
+ group_effect_ses[g] = se
2081
+ group_effect_cis[g] = ci
2082
+ group_effect_p_values[g] = p_value
2083
+
2084
+ return CSBootstrapResults(
2085
+ n_bootstrap=self.n_bootstrap,
2086
+ weight_type=self.bootstrap_weight_type,
2087
+ alpha=self.alpha,
2088
+ overall_att_se=overall_se,
2089
+ overall_att_ci=overall_ci,
2090
+ overall_att_p_value=overall_p_value,
2091
+ group_time_ses=gt_ses,
2092
+ group_time_cis=gt_cis,
2093
+ group_time_p_values=gt_p_values,
2094
+ event_study_ses=event_study_ses,
2095
+ event_study_cis=event_study_cis,
2096
+ event_study_p_values=event_study_p_values,
2097
+ group_effect_ses=group_effect_ses,
2098
+ group_effect_cis=group_effect_cis,
2099
+ group_effect_p_values=group_effect_p_values,
2100
+ bootstrap_distribution=bootstrap_overall,
2101
+ )
2102
+
2103
+ def _prepare_event_study_aggregation(
2104
+ self,
2105
+ gt_pairs: List[Tuple[Any, Any]],
2106
+ group_time_effects: Dict,
2107
+ balance_e: Optional[int],
2108
+ ) -> Dict[int, Dict[str, Any]]:
2109
+ """Prepare aggregation info for event study bootstrap."""
2110
+ # Organize by relative time
2111
+ effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}
2112
+
2113
+ for j, (g, t) in enumerate(gt_pairs):
2114
+ e = t - g
2115
+ if e not in effects_by_e:
2116
+ effects_by_e[e] = []
2117
+ effects_by_e[e].append((
2118
+ j, # index in gt_pairs
2119
+ group_time_effects[(g, t)]['effect'],
2120
+ group_time_effects[(g, t)]['n_treated']
2121
+ ))
2122
+
2123
+ # Balance if requested
2124
+ if balance_e is not None:
2125
+ groups_at_e = set()
2126
+ for j, (g, t) in enumerate(gt_pairs):
2127
+ if t - g == balance_e:
2128
+ groups_at_e.add(g)
2129
+
2130
+ balanced_effects: Dict[int, List[Tuple[int, float, float]]] = {}
2131
+ for j, (g, t) in enumerate(gt_pairs):
2132
+ if g in groups_at_e:
2133
+ e = t - g
2134
+ if e not in balanced_effects:
2135
+ balanced_effects[e] = []
2136
+ balanced_effects[e].append((
2137
+ j,
2138
+ group_time_effects[(g, t)]['effect'],
2139
+ group_time_effects[(g, t)]['n_treated']
2140
+ ))
2141
+ effects_by_e = balanced_effects
2142
+
2143
+ # Compute aggregation weights
2144
+ result = {}
2145
+ for e, effect_list in effects_by_e.items():
2146
+ indices = np.array([x[0] for x in effect_list])
2147
+ effects = np.array([x[1] for x in effect_list])
2148
+ n_treated = np.array([x[2] for x in effect_list], dtype=float)
2149
+
2150
+ weights = n_treated / np.sum(n_treated)
2151
+ agg_effect = np.sum(weights * effects)
2152
+
2153
+ result[e] = {
2154
+ 'gt_indices': indices,
2155
+ 'weights': weights,
2156
+ 'effect': agg_effect,
2157
+ }
2158
+
2159
+ return result
2160
+
2161
+ def _prepare_group_aggregation(
2162
+ self,
2163
+ gt_pairs: List[Tuple[Any, Any]],
2164
+ group_time_effects: Dict,
2165
+ treatment_groups: List[Any],
2166
+ ) -> Dict[Any, Dict[str, Any]]:
2167
+ """Prepare aggregation info for group-level bootstrap."""
2168
+ result = {}
2169
+
2170
+ for g in treatment_groups:
2171
+ # Get all effects for this group (post-treatment only: t >= g)
2172
+ group_data = []
2173
+ for j, (gg, t) in enumerate(gt_pairs):
2174
+ if gg == g and t >= g:
2175
+ group_data.append((
2176
+ j,
2177
+ group_time_effects[(gg, t)]['effect'],
2178
+ ))
2179
+
2180
+ if not group_data:
2181
+ continue
2182
+
2183
+ indices = np.array([x[0] for x in group_data])
2184
+ effects = np.array([x[1] for x in group_data])
2185
+
2186
+ # Equal weights across time periods
2187
+ weights = np.ones(len(effects)) / len(effects)
2188
+ agg_effect = np.sum(weights * effects)
2189
+
2190
+ result[g] = {
2191
+ 'gt_indices': indices,
2192
+ 'weights': weights,
2193
+ 'effect': agg_effect,
2194
+ }
2195
+
2196
+ return result
2197
+
2198
+ def _compute_percentile_ci(
2199
+ self,
2200
+ boot_dist: np.ndarray,
2201
+ alpha: float,
2202
+ ) -> Tuple[float, float]:
2203
+ """Compute percentile confidence interval from bootstrap distribution."""
2204
+ lower = float(np.percentile(boot_dist, alpha / 2 * 100))
2205
+ upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
2206
+ return (lower, upper)
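+ # e.g. alpha=0.05 takes the 2.5th and 97.5th percentiles of the bootstrap draws.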
2207
+
2208
+ def _compute_bootstrap_pvalue(
2209
+ self,
2210
+ original_effect: float,
2211
+ boot_dist: np.ndarray,
2212
+ ) -> float:
2213
+ """
2214
+ Compute two-sided bootstrap p-value.
2215
+
2216
+ Uses the percentile method: p-value is the proportion of bootstrap
2217
+ estimates on the opposite side of zero from the original estimate,
2218
+ doubled for two-sided test.
2219
+ """
2220
+ if original_effect >= 0:
2221
+ # Proportion of bootstrap estimates <= 0
2222
+ p_one_sided = np.mean(boot_dist <= 0)
2223
+ else:
2224
+ # Proportion of bootstrap estimates >= 0
2225
+ p_one_sided = np.mean(boot_dist >= 0)
2226
+
2227
+ # Two-sided p-value
2228
+ p_value = min(2 * p_one_sided, 1.0)
2229
+
2230
+ # Ensure minimum p-value
2231
+ p_value = max(p_value, 1 / (self.n_bootstrap + 1))
2232
+
2233
+ return float(p_value)
2234
+
2235
+ def _compute_effect_bootstrap_stats(
2236
+ self,
2237
+ original_effect: float,
2238
+ boot_dist: np.ndarray,
2239
+ ) -> Tuple[float, Tuple[float, float], float]:
2240
+ """
2241
+ Compute bootstrap statistics for a single effect.
2242
+
2243
+ Parameters
2244
+ ----------
2245
+ original_effect : float
2246
+ Original point estimate.
2247
+ boot_dist : np.ndarray
2248
+ Bootstrap distribution of the effect.
2249
+
2250
+ Returns
2251
+ -------
2252
+ se : float
2253
+ Bootstrap standard error.
2254
+ ci : Tuple[float, float]
2255
+ Percentile confidence interval.
2256
+ p_value : float
2257
+ Bootstrap p-value.
2258
+ """
2259
+ se = float(np.std(boot_dist, ddof=1))
2260
+ ci = self._compute_percentile_ci(boot_dist, self.alpha)
2261
+ p_value = self._compute_bootstrap_pvalue(original_effect, boot_dist)
2262
+ return se, ci, p_value
2263
+
2264
+ def get_params(self) -> Dict[str, Any]:
2265
+ """Get estimator parameters (sklearn-compatible)."""
2266
+ return {
2267
+ "control_group": self.control_group,
2268
+ "anticipation": self.anticipation,
2269
+ "estimation_method": self.estimation_method,
2270
+ "alpha": self.alpha,
2271
+ "cluster": self.cluster,
2272
+ "n_bootstrap": self.n_bootstrap,
2273
+ "bootstrap_weights": self.bootstrap_weights,
2274
+ # Deprecated but kept for backward compatibility
2275
+ "bootstrap_weight_type": self.bootstrap_weight_type,
2276
+ "seed": self.seed,
2277
+ }
2278
+
2279
+ def set_params(self, **params) -> "CallawaySantAnna":
2280
+ """Set estimator parameters (sklearn-compatible)."""
2281
+ for key, value in params.items():
2282
+ if hasattr(self, key):
2283
+ setattr(self, key, value)
2284
+ else:
2285
+ raise ValueError(f"Unknown parameter: {key}")
2286
+ return self
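+ # Usage sketch (illustrative; assumes a default-constructed estimator):
+ #   est = CallawaySantAnna()
+ #   est.set_params(alpha=0.10, n_bootstrap=999)
+ #   est.get_params()["alpha"]   # -> 0.10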
2287
+
2288
+ def summary(self) -> str:
2289
+ """Get summary of estimation results."""
2290
+ if not self.is_fitted_:
2291
+ raise RuntimeError("Model must be fitted before calling summary()")
2292
+ assert self.results_ is not None
2293
+ return self.results_.summary()
2294
+
2295
+ def print_summary(self) -> None:
2296
+ """Print summary to stdout."""
2297
+ print(self.summary())