diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/staggered.py ADDED
@@ -0,0 +1,1120 @@
1
+ """
2
+ Staggered Difference-in-Differences estimators.
3
+
4
+ Implements modern methods for DiD with variation in treatment timing,
5
+ including the Callaway-Sant'Anna (2021) estimator.
6
+ """
7
+
8
+ import warnings
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from scipy import optimize
14
+
15
+ from diff_diff.linalg import solve_ols
16
+ from diff_diff.utils import (
17
+ compute_confidence_interval,
18
+ compute_p_value,
19
+ )
20
+
21
+ # Import from split modules
22
+ from diff_diff.staggered_results import (
23
+ GroupTimeEffect,
24
+ CallawaySantAnnaResults,
25
+ )
26
+ from diff_diff.staggered_bootstrap import (
27
+ CSBootstrapResults,
28
+ CallawaySantAnnaBootstrapMixin,
29
+ )
30
+ from diff_diff.staggered_aggregation import (
31
+ CallawaySantAnnaAggregationMixin,
32
+ )
33
+
34
+ # Re-export for backward compatibility
35
+ __all__ = [
36
+ "CallawaySantAnna",
37
+ "CallawaySantAnnaResults",
38
+ "CSBootstrapResults",
39
+ "GroupTimeEffect",
40
+ ]
41
+
42
+ # Type alias for pre-computed structures
43
+ PrecomputedData = Dict[str, Any]
44
+
45
+
46
+ def _logistic_regression(
47
+ X: np.ndarray,
48
+ y: np.ndarray,
49
+ max_iter: int = 100,
50
+ tol: float = 1e-6,
51
+ ) -> Tuple[np.ndarray, np.ndarray]:
52
+ """
53
+ Fit logistic regression using scipy optimize.
54
+
55
+ Parameters
56
+ ----------
57
+ X : np.ndarray
58
+ Feature matrix (n_samples, n_features). Intercept added automatically.
59
+ y : np.ndarray
60
+ Binary outcome (0/1).
61
+ max_iter : int
62
+ Maximum iterations.
63
+ tol : float
64
+ Convergence tolerance.
65
+
66
+ Returns
67
+ -------
68
+ beta : np.ndarray
69
+ Fitted coefficients (including intercept).
70
+ probs : np.ndarray
71
+ Predicted probabilities.
72
+ """
73
+ n, p = X.shape
74
+ # Add intercept
75
+ X_with_intercept = np.column_stack([np.ones(n), X])
76
+
77
+ def neg_log_likelihood(beta: np.ndarray) -> float:
78
+ z = X_with_intercept @ beta
79
+ # Clip to prevent overflow
80
+ z = np.clip(z, -500, 500)
81
+ log_lik = np.sum(y * z - np.log(1 + np.exp(z)))
82
+ return -log_lik
83
+
84
+ def gradient(beta: np.ndarray) -> np.ndarray:
85
+ z = X_with_intercept @ beta
86
+ z = np.clip(z, -500, 500)
87
+ probs = 1 / (1 + np.exp(-z))
88
+ return -X_with_intercept.T @ (y - probs)
89
+
90
+ # Initialize with zeros
91
+ beta_init = np.zeros(p + 1)
92
+
93
+ result = optimize.minimize(
94
+ neg_log_likelihood,
95
+ beta_init,
96
+ method='BFGS',
97
+ jac=gradient,
98
+ options={'maxiter': max_iter, 'gtol': tol}
99
+ )
100
+
101
+ beta = result.x
102
+ z = X_with_intercept @ beta
103
+ z = np.clip(z, -500, 500)
104
+ probs = 1 / (1 + np.exp(-z))
105
+
106
+ return beta, probs
107
+
108
+
109
def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit an OLS regression of ``y`` on ``X`` with an automatic intercept.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features); an intercept column is
        prepended internally.
    y : np.ndarray
        Outcome variable.
    rank_deficient_action : str, default "warn"
        Behavior when the design matrix is rank-deficient:
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients, intercept first.
    residuals : np.ndarray
        Residuals from the fit.
    """
    # Prepend a column of ones, then delegate to the shared OLS backend
    # (no variance-covariance matrix is needed here).
    design = np.hstack([np.ones((X.shape[0], 1)), X])
    beta, residuals, _ = solve_ols(
        design,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
    )
    return beta, residuals
147
+
148
+
149
+ class CallawaySantAnna(
150
+ CallawaySantAnnaBootstrapMixin,
151
+ CallawaySantAnnaAggregationMixin,
152
+ ):
153
+ """
154
+ Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
155
+
156
+ This estimator handles DiD designs with variation in treatment timing
157
+ (staggered adoption) and heterogeneous treatment effects. It avoids the
158
+ bias of traditional two-way fixed effects (TWFE) estimators by:
159
+
160
+ 1. Computing group-time average treatment effects ATT(g,t) for each
161
+ cohort g (units first treated in period g) and time t.
162
+ 2. Aggregating these to summary measures (overall ATT, event study, etc.)
163
+ using appropriate weights.
164
+
165
+ Parameters
166
+ ----------
167
+ control_group : str, default="never_treated"
168
+ Which units to use as controls:
169
+ - "never_treated": Use only never-treated units (recommended)
170
+ - "not_yet_treated": Use never-treated and not-yet-treated units
171
+ anticipation : int, default=0
172
+ Number of periods before treatment where effects may occur.
173
+ Set to > 0 if treatment effects can begin before the official
174
+ treatment date.
175
+ estimation_method : str, default="dr"
176
+ Estimation method:
177
+ - "dr": Doubly robust (recommended)
178
+ - "ipw": Inverse probability weighting
179
+ - "reg": Outcome regression
180
+ alpha : float, default=0.05
181
+ Significance level for confidence intervals.
182
+ cluster : str, optional
183
+ Column name for cluster-robust standard errors.
184
+ Defaults to unit-level clustering.
185
+ n_bootstrap : int, default=0
186
+ Number of bootstrap iterations for inference.
187
+ If 0, uses analytical standard errors.
188
+ Recommended: 999 or more for reliable inference.
189
+
190
+ .. note:: Memory Usage
191
+ The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
192
+ float64 array. For large datasets, this can be significant:
193
+ - 1K bootstrap × 10K units = ~80 MB
194
+ - 10K bootstrap × 100K units = ~8 GB
195
+ Consider reducing n_bootstrap if memory is constrained.
196
+
197
+ bootstrap_weights : str, default="rademacher"
198
+ Type of weights for multiplier bootstrap:
199
+ - "rademacher": +1/-1 with equal probability (standard choice)
200
+ - "mammen": Two-point distribution (asymptotically valid, matches skewness)
201
+ - "webb": Six-point distribution (recommended when n_clusters < 20)
202
+ bootstrap_weight_type : str, optional
203
+ .. deprecated:: 1.0.1
204
+ Use ``bootstrap_weights`` instead. Will be removed in v3.0.
205
+ seed : int, optional
206
+ Random seed for reproducibility.
207
+ rank_deficient_action : str, default="warn"
208
+ Action when design matrix is rank-deficient (linearly dependent columns):
209
+ - "warn": Issue warning and drop linearly dependent columns (default)
210
+ - "error": Raise ValueError
211
+ - "silent": Drop columns silently without warning
212
+ base_period : str, default="varying"
213
+ Method for selecting the base (reference) period for computing
214
+ ATT(g,t). Options:
215
+ - "varying": For pre-treatment periods (t < g - anticipation), use
216
+ t-1 as base (consecutive comparisons). For post-treatment, use
217
+ g-1-anticipation. Requires t-1 to exist in data.
218
+ - "universal": Always use g-1-anticipation as base period.
219
+ Both produce identical post-treatment effects. Matches R's
220
+ did::att_gt() base_period parameter.
221
+
222
+ Attributes
223
+ ----------
224
+ results_ : CallawaySantAnnaResults
225
+ Estimation results after calling fit().
226
+ is_fitted_ : bool
227
+ Whether the model has been fitted.
228
+
229
+ Examples
230
+ --------
231
+ Basic usage:
232
+
233
+ >>> import pandas as pd
234
+ >>> from diff_diff import CallawaySantAnna
235
+ >>>
236
+ >>> # Panel data with staggered treatment
237
+ >>> # 'first_treat' = period when unit was first treated (0 if never treated)
238
+ >>> data = pd.DataFrame({
239
+ ... 'unit': [...],
240
+ ... 'time': [...],
241
+ ... 'outcome': [...],
242
+ ... 'first_treat': [...] # 0 for never-treated, else first treatment period
243
+ ... })
244
+ >>>
245
+ >>> cs = CallawaySantAnna()
246
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
247
+ ... time='time', first_treat='first_treat')
248
+ >>>
249
+ >>> results.print_summary()
250
+
251
+ With event study aggregation:
252
+
253
+ >>> cs = CallawaySantAnna()
254
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
255
+ ... time='time', first_treat='first_treat',
256
+ ... aggregate='event_study')
257
+ >>>
258
+ >>> # Plot event study
259
+ >>> from diff_diff import plot_event_study
260
+ >>> plot_event_study(results)
261
+
262
+ With covariate adjustment (conditional parallel trends):
263
+
264
+ >>> # When parallel trends only holds conditional on covariates
265
+ >>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
266
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
267
+ ... time='time', first_treat='first_treat',
268
+ ... covariates=['age', 'income'])
269
+ >>>
270
+ >>> # DR is recommended: consistent if either outcome model
271
+ >>> # or propensity model is correctly specified
272
+
273
+ Notes
274
+ -----
275
+ The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
276
+ approach: instead of estimating a single treatment effect, they estimate
277
+ ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
278
+ problem where already-treated units act as controls.
279
+
280
+ The ATT(g,t) is identified under parallel trends conditional on covariates:
281
+
282
+ E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]
283
+
284
+ where G=g indicates treatment cohort g and C=1 indicates control units.
285
+ This uses g-1 as the base period, which applies to post-treatment (t >= g).
286
+ With base_period="varying" (default), pre-treatment uses t-1 as base for
287
+ consecutive comparisons useful in parallel trends diagnostics.
288
+
289
+ References
290
+ ----------
291
+ Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
292
+ multiple time periods. Journal of Econometrics, 225(2), 200-230.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ control_group: str = "never_treated",
298
+ anticipation: int = 0,
299
+ estimation_method: str = "dr",
300
+ alpha: float = 0.05,
301
+ cluster: Optional[str] = None,
302
+ n_bootstrap: int = 0,
303
+ bootstrap_weights: Optional[str] = None,
304
+ bootstrap_weight_type: Optional[str] = None,
305
+ seed: Optional[int] = None,
306
+ rank_deficient_action: str = "warn",
307
+ base_period: str = "varying",
308
+ ):
309
+ import warnings
310
+
311
+ if control_group not in ["never_treated", "not_yet_treated"]:
312
+ raise ValueError(
313
+ f"control_group must be 'never_treated' or 'not_yet_treated', "
314
+ f"got '{control_group}'"
315
+ )
316
+ if estimation_method not in ["dr", "ipw", "reg"]:
317
+ raise ValueError(
318
+ f"estimation_method must be 'dr', 'ipw', or 'reg', "
319
+ f"got '{estimation_method}'"
320
+ )
321
+
322
+ # Handle bootstrap_weight_type deprecation
323
+ if bootstrap_weight_type is not None:
324
+ warnings.warn(
325
+ "bootstrap_weight_type is deprecated and will be removed in v3.0. "
326
+ "Use bootstrap_weights instead.",
327
+ DeprecationWarning,
328
+ stacklevel=2
329
+ )
330
+ if bootstrap_weights is None:
331
+ bootstrap_weights = bootstrap_weight_type
332
+
333
+ # Default to rademacher if neither specified
334
+ if bootstrap_weights is None:
335
+ bootstrap_weights = "rademacher"
336
+
337
+ if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
338
+ raise ValueError(
339
+ f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
340
+ f"got '{bootstrap_weights}'"
341
+ )
342
+
343
+ if rank_deficient_action not in ["warn", "error", "silent"]:
344
+ raise ValueError(
345
+ f"rank_deficient_action must be 'warn', 'error', or 'silent', "
346
+ f"got '{rank_deficient_action}'"
347
+ )
348
+
349
+ if base_period not in ["varying", "universal"]:
350
+ raise ValueError(
351
+ f"base_period must be 'varying' or 'universal', "
352
+ f"got '{base_period}'"
353
+ )
354
+
355
+ self.control_group = control_group
356
+ self.anticipation = anticipation
357
+ self.estimation_method = estimation_method
358
+ self.alpha = alpha
359
+ self.cluster = cluster
360
+ self.n_bootstrap = n_bootstrap
361
+ self.bootstrap_weights = bootstrap_weights
362
+ # Keep bootstrap_weight_type for backward compatibility
363
+ self.bootstrap_weight_type = bootstrap_weights
364
+ self.seed = seed
365
+ self.rank_deficient_action = rank_deficient_action
366
+ self.base_period = base_period
367
+
368
+ self.is_fitted_ = False
369
+ self.results_: Optional[CallawaySantAnnaResults] = None
370
+
371
+ def _precompute_structures(
372
+ self,
373
+ df: pd.DataFrame,
374
+ outcome: str,
375
+ unit: str,
376
+ time: str,
377
+ first_treat: str,
378
+ covariates: Optional[List[str]],
379
+ time_periods: List[Any],
380
+ treatment_groups: List[Any],
381
+ ) -> PrecomputedData:
382
+ """
383
+ Pre-compute data structures for efficient ATT(g,t) computation.
384
+
385
+ This pivots data to wide format and pre-computes:
386
+ - Outcome matrix (units x time periods)
387
+ - Covariate matrix (units x covariates) from base period
388
+ - Unit cohort membership masks
389
+ - Control unit masks
390
+
391
+ Returns
392
+ -------
393
+ PrecomputedData
394
+ Dictionary with pre-computed structures.
395
+ """
396
+ # Get unique units and their cohort assignments
397
+ unit_info = df.groupby(unit)[first_treat].first()
398
+ all_units = unit_info.index.values
399
+ unit_cohorts = unit_info.values
400
+ n_units = len(all_units)
401
+
402
+ # Create unit index mapping for fast lookups
403
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
404
+
405
+ # Pivot outcome to wide format: rows = units, columns = time periods
406
+ outcome_wide = df.pivot(index=unit, columns=time, values=outcome)
407
+ # Reindex to ensure all units are present (handles unbalanced panels)
408
+ outcome_wide = outcome_wide.reindex(all_units)
409
+ outcome_matrix = outcome_wide.values # Shape: (n_units, n_periods)
410
+ period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)}
411
+
412
+ # Pre-compute cohort masks (boolean arrays)
413
+ cohort_masks = {}
414
+ for g in treatment_groups:
415
+ cohort_masks[g] = (unit_cohorts == g)
416
+
417
+ # Never-treated mask
418
+ # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
419
+ never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)
420
+
421
+ # Pre-compute covariate matrices by time period if needed
422
+ # (covariates are retrieved from the base period of each comparison)
423
+ covariate_by_period = None
424
+ if covariates:
425
+ covariate_by_period = {}
426
+ for t in time_periods:
427
+ period_data = df[df[time] == t].set_index(unit)
428
+ period_cov = period_data.reindex(all_units)[covariates]
429
+ covariate_by_period[t] = period_cov.values # Shape: (n_units, n_covariates)
430
+
431
+ return {
432
+ 'all_units': all_units,
433
+ 'unit_to_idx': unit_to_idx,
434
+ 'unit_cohorts': unit_cohorts,
435
+ 'outcome_matrix': outcome_matrix,
436
+ 'period_to_col': period_to_col,
437
+ 'cohort_masks': cohort_masks,
438
+ 'never_treated_mask': never_treated_mask,
439
+ 'covariate_by_period': covariate_by_period,
440
+ 'time_periods': time_periods,
441
+ }
442
+
443
+ def _compute_att_gt_fast(
444
+ self,
445
+ precomputed: PrecomputedData,
446
+ g: Any,
447
+ t: Any,
448
+ covariates: Optional[List[str]],
449
+ ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]:
450
+ """
451
+ Compute ATT(g,t) using pre-computed data structures (fast version).
452
+
453
+ Uses vectorized numpy operations on pre-pivoted outcome matrix
454
+ instead of repeated pandas filtering.
455
+ """
456
+ time_periods = precomputed['time_periods']
457
+ period_to_col = precomputed['period_to_col']
458
+ outcome_matrix = precomputed['outcome_matrix']
459
+ cohort_masks = precomputed['cohort_masks']
460
+ never_treated_mask = precomputed['never_treated_mask']
461
+ unit_cohorts = precomputed['unit_cohorts']
462
+ all_units = precomputed['all_units']
463
+ covariate_by_period = precomputed['covariate_by_period']
464
+
465
+ # Base period selection based on mode
466
+ if self.base_period == "universal":
467
+ # Universal: always use g - 1 - anticipation
468
+ base_period_val = g - 1 - self.anticipation
469
+ else: # varying
470
+ if t < g - self.anticipation:
471
+ # Pre-treatment: use t - 1 (consecutive comparison)
472
+ base_period_val = t - 1
473
+ else:
474
+ # Post-treatment: use g - 1 - anticipation
475
+ base_period_val = g - 1 - self.anticipation
476
+
477
+ if base_period_val not in period_to_col:
478
+ # Base period must exist; no fallback to maintain methodological consistency
479
+ return None, 0.0, 0, 0, None
480
+
481
+ # Check if periods exist in the data
482
+ if base_period_val not in period_to_col or t not in period_to_col:
483
+ return None, 0.0, 0, 0, None
484
+
485
+ base_col = period_to_col[base_period_val]
486
+ post_col = period_to_col[t]
487
+
488
+ # Get treated units mask (cohort g)
489
+ treated_mask = cohort_masks[g]
490
+
491
+ # Get control units mask
492
+ if self.control_group == "never_treated":
493
+ control_mask = never_treated_mask
494
+ else: # not_yet_treated
495
+ # Not yet treated at time t: never-treated OR (first_treat > t AND not cohort g)
496
+ # Must exclude cohort g since they are the treated group for this ATT(g,t)
497
+ control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g))
498
+
499
+ # Extract outcomes for base and post periods
500
+ y_base = outcome_matrix[:, base_col]
501
+ y_post = outcome_matrix[:, post_col]
502
+
503
+ # Compute outcome changes (vectorized)
504
+ outcome_change = y_post - y_base
505
+
506
+ # Filter to units with valid data (no NaN in either period)
507
+ valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
508
+
509
+ # Get treated and control with valid data
510
+ treated_valid = treated_mask & valid_mask
511
+ control_valid = control_mask & valid_mask
512
+
513
+ n_treated = np.sum(treated_valid)
514
+ n_control = np.sum(control_valid)
515
+
516
+ if n_treated == 0 or n_control == 0:
517
+ return None, 0.0, 0, 0, None
518
+
519
+ # Extract outcome changes for treated and control
520
+ treated_change = outcome_change[treated_valid]
521
+ control_change = outcome_change[control_valid]
522
+
523
+ # Get unit IDs for influence function
524
+ treated_units = all_units[treated_valid]
525
+ control_units = all_units[control_valid]
526
+
527
+ # Get covariates if specified (from the base period)
528
+ X_treated = None
529
+ X_control = None
530
+ if covariates and covariate_by_period is not None:
531
+ cov_matrix = covariate_by_period[base_period_val]
532
+ X_treated = cov_matrix[treated_valid]
533
+ X_control = cov_matrix[control_valid]
534
+
535
+ # Check for missing values
536
+ if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
537
+ warnings.warn(
538
+ f"Missing values in covariates for group {g}, time {t}. "
539
+ "Falling back to unconditional estimation.",
540
+ UserWarning,
541
+ stacklevel=3,
542
+ )
543
+ X_treated = None
544
+ X_control = None
545
+
546
+ # Estimation method
547
+ if self.estimation_method == "reg":
548
+ att_gt, se_gt, inf_func = self._outcome_regression(
549
+ treated_change, control_change, X_treated, X_control
550
+ )
551
+ elif self.estimation_method == "ipw":
552
+ att_gt, se_gt, inf_func = self._ipw_estimation(
553
+ treated_change, control_change,
554
+ int(n_treated), int(n_control),
555
+ X_treated, X_control
556
+ )
557
+ else: # doubly robust
558
+ att_gt, se_gt, inf_func = self._doubly_robust(
559
+ treated_change, control_change, X_treated, X_control
560
+ )
561
+
562
+ # Package influence function info with unit IDs for bootstrap
563
+ n_t = int(n_treated)
564
+ inf_func_info = {
565
+ 'treated_units': list(treated_units),
566
+ 'control_units': list(control_units),
567
+ 'treated_inf': inf_func[:n_t],
568
+ 'control_inf': inf_func[n_t:],
569
+ }
570
+
571
+ return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info
572
+
573
    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]] = None,
        aggregate: Optional[str] = None,
        balance_e: Optional[int] = None,
    ) -> CallawaySantAnnaResults:
        """
        Fit the Callaway-Sant'Anna estimator.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column.
        first_treat : str
            Name of column indicating when unit was first treated.
            Use 0 (or np.inf) for never-treated units.
        covariates : list, optional
            List of covariate column names for conditional parallel trends.
        aggregate : str, optional
            How to aggregate group-time effects:
            - None: Only compute ATT(g,t) (default)
            - "simple": Simple weighted average (overall ATT)
            - "event_study": Aggregate by relative time (event study)
            - "group": Aggregate by treatment cohort
            - "all": Compute all aggregations
        balance_e : int, optional
            For event study, balance the panel at relative time e.
            Ensures all groups contribute to each relative period.

        Returns
        -------
        CallawaySantAnnaResults
            Object containing all estimation results.

        Raises
        ------
        ValueError
            If required columns are missing or data validation fails.
        """
        # Validate inputs
        required_cols = [outcome, unit, time, first_treat]
        if covariates:
            required_cols.extend(covariates)

        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Create working copy (the caller's DataFrame is never mutated)
        df = data.copy()

        # Ensure numeric types
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        # Standardize the first_treat column name for internal use
        # This avoids hardcoding column names in internal methods
        df['first_treat'] = df[first_treat]

        # Never-treated indicator (must precede treatment_groups to exclude np.inf)
        df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
        # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
        df.loc[df[first_treat] == np.inf, first_treat] = 0

        # Identify groups and time periods
        time_periods = sorted(df[time].unique())
        treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

        # Get unique units (one row per unit; first_treat is constant within unit)
        unit_info = df.groupby(unit).agg({
            first_treat: 'first',
            '_never_treated': 'first'
        }).reset_index()

        n_treated_units = (unit_info[first_treat] > 0).sum()
        n_control_units = (unit_info['_never_treated']).sum()

        if n_control_units == 0:
            raise ValueError("No never-treated units found. Check 'first_treat' column.")

        # Pre-compute data structures for efficient ATT(g,t) computation
        precomputed = self._precompute_structures(
            df, outcome, unit, time, first_treat,
            covariates, time_periods, treatment_groups
        )

        # Compute ATT(g,t) for each group-time combination
        group_time_effects = {}
        influence_func_info = {}  # Store influence functions for bootstrap

        # Get minimum period for determining valid pre-treatment periods
        min_period = min(time_periods)

        for g in treatment_groups:
            # Compute valid periods including pre-treatment
            if self.base_period == "universal":
                # Universal: all periods except the base period (which is normalized to 0)
                universal_base = g - 1 - self.anticipation
                valid_periods = [t for t in time_periods if t != universal_base]
            else:
                # Varying: post-treatment + pre-treatment where t-1 exists
                valid_periods = [
                    t for t in time_periods
                    if t >= g - self.anticipation or t > min_period
                ]

            for t in valid_periods:
                att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
                    precomputed, g, t, covariates
                )

                # att_gt is None when the (g, t) cell could not be estimated;
                # such cells are silently skipped.
                if att_gt is not None:
                    t_stat = att_gt / se_gt if np.isfinite(se_gt) and se_gt > 0 else np.nan
                    p_val = compute_p_value(t_stat)
                    ci = compute_confidence_interval(att_gt, se_gt, self.alpha)

                    group_time_effects[(g, t)] = {
                        'effect': att_gt,
                        'se': se_gt,
                        't_stat': t_stat,
                        'p_value': p_val,
                        'conf_int': ci,
                        'n_treated': n_treat,
                        'n_control': n_ctrl,
                    }

                    if inf_info is not None:
                        influence_func_info[(g, t)] = inf_info

        if not group_time_effects:
            raise ValueError(
                "Could not estimate any group-time effects. "
                "Check that data has sufficient observations."
            )

        # Compute overall ATT (simple aggregation); always computed, even
        # when `aggregate` is None.
        overall_att, overall_se = self._aggregate_simple(
            group_time_effects, influence_func_info, df, unit, precomputed
        )
        # Use NaN for t-stat and p-value when SE is undefined (NaN or non-positive)
        if np.isfinite(overall_se) and overall_se > 0:
            overall_t = overall_att / overall_se
            overall_p = compute_p_value(overall_t)
        else:
            overall_t = np.nan
            overall_p = np.nan
        overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)

        # Compute additional aggregations if requested
        event_study_effects = None
        group_effects = None

        if aggregate in ["event_study", "all"]:
            event_study_effects = self._aggregate_event_study(
                group_time_effects, influence_func_info,
                treatment_groups, time_periods, balance_e
            )

        if aggregate in ["group", "all"]:
            group_effects = self._aggregate_by_group(
                group_time_effects, influence_func_info, treatment_groups
            )

        # Run bootstrap inference if requested; bootstrap SEs/CIs/p-values
        # replace the analytical ones computed above, in place.
        bootstrap_results = None
        if self.n_bootstrap > 0 and influence_func_info:
            bootstrap_results = self._run_multiplier_bootstrap(
                group_time_effects=group_time_effects,
                influence_func_info=influence_func_info,
                aggregate=aggregate,
                balance_e=balance_e,
                treatment_groups=treatment_groups,
                time_periods=time_periods,
            )

            # Update estimates with bootstrap inference
            overall_se = bootstrap_results.overall_att_se
            # Use NaN for t-stat when SE is undefined; p-value comes from bootstrap
            if np.isfinite(overall_se) and overall_se > 0:
                overall_t = overall_att / overall_se
            else:
                overall_t = np.nan
            overall_p = bootstrap_results.overall_att_p_value
            overall_ci = bootstrap_results.overall_att_ci

            # Update group-time effects with bootstrap SEs
            for gt in group_time_effects:
                if gt in bootstrap_results.group_time_ses:
                    group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt]
                    group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt]
                    group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt]
                    effect = float(group_time_effects[gt]['effect'])
                    se = float(group_time_effects[gt]['se'])
                    group_time_effects[gt]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan

            # Update event study effects with bootstrap SEs
            if (event_study_effects is not None
                    and bootstrap_results.event_study_ses is not None
                    and bootstrap_results.event_study_cis is not None
                    and bootstrap_results.event_study_p_values is not None):
                for e in event_study_effects:
                    if e in bootstrap_results.event_study_ses:
                        event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e]
                        event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e]
                        p_val = bootstrap_results.event_study_p_values[e]
                        event_study_effects[e]['p_value'] = p_val
                        effect = float(event_study_effects[e]['effect'])
                        se = float(event_study_effects[e]['se'])
                        event_study_effects[e]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan

            # Update group effects with bootstrap SEs
            if (group_effects is not None
                    and bootstrap_results.group_effect_ses is not None
                    and bootstrap_results.group_effect_cis is not None
                    and bootstrap_results.group_effect_p_values is not None):
                for g in group_effects:
                    if g in bootstrap_results.group_effect_ses:
                        group_effects[g]['se'] = bootstrap_results.group_effect_ses[g]
                        group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g]
                        group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g]
                        effect = float(group_effects[g]['effect'])
                        se = float(group_effects[g]['se'])
                        group_effects[g]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan

        # Store results
        self.results_ = CallawaySantAnnaResults(
            group_time_effects=group_time_effects,
            overall_att=overall_att,
            overall_se=overall_se,
            overall_t_stat=overall_t,
            overall_p_value=overall_p,
            overall_conf_int=overall_ci,
            groups=treatment_groups,
            time_periods=time_periods,
            n_obs=len(df),
            n_treated_units=n_treated_units,
            n_control_units=n_control_units,
            alpha=self.alpha,
            control_group=self.control_group,
            base_period=self.base_period,
            event_study_effects=event_study_effects,
            group_effects=group_effects,
            bootstrap_results=bootstrap_results,
        )

        self.is_fitted_ = True
        return self.results_
831
+
832
+ def _outcome_regression(
833
+ self,
834
+ treated_change: np.ndarray,
835
+ control_change: np.ndarray,
836
+ X_treated: Optional[np.ndarray] = None,
837
+ X_control: Optional[np.ndarray] = None,
838
+ ) -> Tuple[float, float, np.ndarray]:
839
+ """
840
+ Estimate ATT using outcome regression.
841
+
842
+ With covariates:
843
+ 1. Regress outcome changes on covariates for control group
844
+ 2. Predict counterfactual for treated using their covariates
845
+ 3. ATT = mean(treated_change) - mean(predicted_counterfactual)
846
+
847
+ Without covariates:
848
+ Simple difference in means.
849
+ """
850
+ n_t = len(treated_change)
851
+ n_c = len(control_change)
852
+
853
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
854
+ # Covariate-adjusted outcome regression
855
+ # Fit regression on control units: E[Delta Y | X, D=0]
856
+ beta, residuals = _linear_regression(
857
+ X_control, control_change,
858
+ rank_deficient_action=self.rank_deficient_action,
859
+ )
860
+
861
+ # Predict counterfactual for treated units
862
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
863
+ predicted_control = X_treated_with_intercept @ beta
864
+
865
+ # ATT = mean(observed treated change - predicted counterfactual)
866
+ att = np.mean(treated_change - predicted_control)
867
+
868
+ # Standard error using sandwich estimator
869
+ # Variance from treated: Var(Y_1 - m(X))
870
+ treated_residuals = treated_change - predicted_control
871
+ var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
872
+
873
+ # Variance from control regression (residual variance)
874
+ var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0
875
+
876
+ # Approximate SE (ignoring estimation error in beta for simplicity)
877
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
878
+
879
+ # Influence function
880
+ inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
881
+ inf_control = -residuals / n_c
882
+ inf_func = np.concatenate([inf_treated, inf_control])
883
+ else:
884
+ # Simple difference in means (no covariates)
885
+ att = np.mean(treated_change) - np.mean(control_change)
886
+
887
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
888
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
889
+
890
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
891
+
892
+ # Influence function (for aggregation)
893
+ inf_treated = treated_change - np.mean(treated_change)
894
+ inf_control = control_change - np.mean(control_change)
895
+ inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])
896
+
897
+ return att, se, inf_func
898
+
899
    def _ipw_estimation(
        self,
        treated_change: np.ndarray,
        control_change: np.ndarray,
        n_treated: int,
        n_control: int,
        X_treated: Optional[np.ndarray] = None,
        X_control: Optional[np.ndarray] = None,
    ) -> Tuple[float, float, np.ndarray]:
        """
        Estimate ATT using inverse probability weighting.

        With covariates:
            1. Estimate propensity score P(D=1|X) using logistic regression
            2. Reweight control units to match treated covariate distribution
               via normalized odds weights p(X) / (1 - p(X))
            3. ATT = mean(treated) - weighted_mean(control)

        Without covariates:
            Simple difference in means with unconditional propensity weighting.

        Parameters
        ----------
        treated_change, control_change : np.ndarray
            Outcome changes for treated / control units.
        n_treated, n_control : int
            Group counts; only used for the unconditional propensity
            share in the no-covariate branch.
        X_treated, X_control : np.ndarray, optional
            Covariate matrices; IPW is used only when both are provided
            and have at least one column.

        Returns
        -------
        (att, se, inf_func)
            Point estimate, standard error, and the stacked influence
            function (treated first, then control) used for aggregation.
        """
        n_t = len(treated_change)
        n_c = len(control_change)
        n_total = n_treated + n_control

        if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
            # Covariate-adjusted IPW estimation
            # Stack covariates and create treatment indicator
            X_all = np.vstack([X_treated, X_control])
            D = np.concatenate([np.ones(n_t), np.zeros(n_c)])

            # Estimate propensity scores using logistic regression
            try:
                _, pscore = _logistic_regression(X_all, D)
            except (np.linalg.LinAlgError, ValueError):
                # Fallback to unconditional if logistic regression fails
                # (stacklevel=4 points the warning at the user-facing caller).
                warnings.warn(
                    "Propensity score estimation failed. "
                    "Falling back to unconditional estimation.",
                    UserWarning,
                    stacklevel=4,
                )
                pscore = np.full(len(D), n_t / (n_t + n_c))

            # Propensity scores for treated and control
            pscore_treated = pscore[:n_t]
            pscore_control = pscore[n_t:]

            # Clip propensity scores to avoid extreme weights
            pscore_control = np.clip(pscore_control, 0.01, 0.99)
            pscore_treated = np.clip(pscore_treated, 0.01, 0.99)

            # IPW weights for control units: p(X) / (1 - p(X))
            # This reweights controls to have same covariate distribution as treated
            weights_control = pscore_control / (1 - pscore_control)
            weights_control = weights_control / np.sum(weights_control)  # normalize

            # ATT = mean(treated) - weighted_mean(control)
            att = np.mean(treated_change) - np.sum(weights_control * control_change)

            # Compute standard error
            # Variance of treated mean
            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0

            # Variance of weighted control mean
            # NOTE(review): this is the weighted population variance
            # sum(w_i * (x_i - wmean)^2); the variance of the weighted mean
            # itself would use sum(w_i^2 * ...). Verify this is the intended
            # (likely conservative) SE before changing.
            weighted_var_c = np.sum(weights_control * (control_change - np.sum(weights_control * control_change)) ** 2)

            se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function (weights are already normalized, so no
            # extra 1/n_c scaling on the control side)
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = -weights_control * (control_change - np.sum(weights_control * control_change))
            inf_func = np.concatenate([inf_treated, inf_control])
        else:
            # Unconditional IPW (reduces to difference in means)
            p_treat = n_treated / n_total  # unconditional propensity score

            att = np.mean(treated_change) - np.mean(control_change)

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

            # Adjusted variance for IPW: control-side variance inflated by
            # the odds of treatment (1 - p) / p
            se = np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) if (n_t > 0 and n_c > 0 and p_treat > 0) else 0.0

            # Influence function (for aggregation)
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = (control_change - np.mean(control_change)) / n_c
            inf_func = np.concatenate([inf_treated, -inf_control])

        return att, se, inf_func
+
990
+ def _doubly_robust(
991
+ self,
992
+ treated_change: np.ndarray,
993
+ control_change: np.ndarray,
994
+ X_treated: Optional[np.ndarray] = None,
995
+ X_control: Optional[np.ndarray] = None,
996
+ ) -> Tuple[float, float, np.ndarray]:
997
+ """
998
+ Estimate ATT using doubly robust estimation.
999
+
1000
+ With covariates:
1001
+ Combines outcome regression and IPW for double robustness.
1002
+ The estimator is consistent if either the outcome model OR
1003
+ the propensity model is correctly specified.
1004
+
1005
+ ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
1006
+ + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]
1007
+
1008
+ where m(X) is the outcome model and w_i are IPW weights.
1009
+
1010
+ Without covariates:
1011
+ Reduces to simple difference in means.
1012
+ """
1013
+ n_t = len(treated_change)
1014
+ n_c = len(control_change)
1015
+
1016
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
1017
+ # Doubly robust estimation with covariates
1018
+ # Step 1: Outcome regression - fit E[Delta Y | X] on control
1019
+ beta, _ = _linear_regression(
1020
+ X_control, control_change,
1021
+ rank_deficient_action=self.rank_deficient_action,
1022
+ )
1023
+
1024
+ # Predict counterfactual for both treated and control
1025
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
1026
+ X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
1027
+ m_treated = X_treated_with_intercept @ beta
1028
+ m_control = X_control_with_intercept @ beta
1029
+
1030
+ # Step 2: Propensity score estimation
1031
+ X_all = np.vstack([X_treated, X_control])
1032
+ D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
1033
+
1034
+ try:
1035
+ _, pscore = _logistic_regression(X_all, D)
1036
+ except (np.linalg.LinAlgError, ValueError):
1037
+ # Fallback to unconditional if logistic regression fails
1038
+ pscore = np.full(len(D), n_t / (n_t + n_c))
1039
+
1040
+ pscore_control = pscore[n_t:]
1041
+
1042
+ # Clip propensity scores
1043
+ pscore_control = np.clip(pscore_control, 0.01, 0.99)
1044
+
1045
+ # IPW weights for control: p(X) / (1 - p(X))
1046
+ weights_control = pscore_control / (1 - pscore_control)
1047
+
1048
+ # Step 3: Doubly robust ATT
1049
+ # ATT = mean(treated - m(X_treated))
1050
+ # + weighted_mean_control((m(X) - Y) * weight)
1051
+ att_treated_part = np.mean(treated_change - m_treated)
1052
+
1053
+ # Augmentation term from control
1054
+ augmentation = np.sum(weights_control * (m_control - control_change)) / n_t
1055
+
1056
+ att = att_treated_part + augmentation
1057
+
1058
+ # Step 4: Standard error using influence function
1059
+ # Influence function for DR estimator
1060
+ psi_treated = (treated_change - m_treated - att) / n_t
1061
+ psi_control = (weights_control * (m_control - control_change)) / n_t
1062
+
1063
+ # Variance is sum of squared influence functions
1064
+ var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2)
1065
+ se = np.sqrt(var_psi) if var_psi > 0 else 0.0
1066
+
1067
+ # Full influence function
1068
+ inf_func = np.concatenate([psi_treated, psi_control])
1069
+ else:
1070
+ # Without covariates, DR simplifies to difference in means
1071
+ att = np.mean(treated_change) - np.mean(control_change)
1072
+
1073
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
1074
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
1075
+
1076
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
1077
+
1078
+ # Influence function for DR estimator
1079
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
1080
+ inf_control = (control_change - np.mean(control_change)) / n_c
1081
+ inf_func = np.concatenate([inf_treated, -inf_control])
1082
+
1083
+ return att, se, inf_func
1084
+
1085
+ def get_params(self) -> Dict[str, Any]:
1086
+ """Get estimator parameters (sklearn-compatible)."""
1087
+ return {
1088
+ "control_group": self.control_group,
1089
+ "anticipation": self.anticipation,
1090
+ "estimation_method": self.estimation_method,
1091
+ "alpha": self.alpha,
1092
+ "cluster": self.cluster,
1093
+ "n_bootstrap": self.n_bootstrap,
1094
+ "bootstrap_weights": self.bootstrap_weights,
1095
+ # Deprecated but kept for backward compatibility
1096
+ "bootstrap_weight_type": self.bootstrap_weight_type,
1097
+ "seed": self.seed,
1098
+ "rank_deficient_action": self.rank_deficient_action,
1099
+ "base_period": self.base_period,
1100
+ }
1101
+
1102
+ def set_params(self, **params) -> "CallawaySantAnna":
1103
+ """Set estimator parameters (sklearn-compatible)."""
1104
+ for key, value in params.items():
1105
+ if hasattr(self, key):
1106
+ setattr(self, key, value)
1107
+ else:
1108
+ raise ValueError(f"Unknown parameter: {key}")
1109
+ return self
1110
+
1111
+ def summary(self) -> str:
1112
+ """Get summary of estimation results."""
1113
+ if not self.is_fitted_:
1114
+ raise RuntimeError("Model must be fitted before calling summary()")
1115
+ assert self.results_ is not None
1116
+ return self.results_.summary()
1117
+
1118
+ def print_summary(self) -> None:
1119
+ """Print summary to stdout."""
1120
+ print(self.summary())