diff-diff 2.2.0__cp311-cp311-win_amd64.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
diff_diff/staggered.py ADDED
@@ -0,0 +1,1117 @@
1
+ """
2
+ Staggered Difference-in-Differences estimators.
3
+
4
+ Implements modern methods for DiD with variation in treatment timing,
5
+ including the Callaway-Sant'Anna (2021) estimator.
6
+ """
7
+
8
+ import warnings
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from scipy import optimize
14
+
15
+ from diff_diff.linalg import solve_ols
16
+ from diff_diff.utils import (
17
+ compute_confidence_interval,
18
+ compute_p_value,
19
+ )
20
+
21
+ # Import from split modules
22
+ from diff_diff.staggered_results import (
23
+ GroupTimeEffect,
24
+ CallawaySantAnnaResults,
25
+ )
26
+ from diff_diff.staggered_bootstrap import (
27
+ CSBootstrapResults,
28
+ CallawaySantAnnaBootstrapMixin,
29
+ )
30
+ from diff_diff.staggered_aggregation import (
31
+ CallawaySantAnnaAggregationMixin,
32
+ )
33
+
34
+ # Re-export for backward compatibility
35
+ __all__ = [
36
+ "CallawaySantAnna",
37
+ "CallawaySantAnnaResults",
38
+ "CSBootstrapResults",
39
+ "GroupTimeEffect",
40
+ ]
41
+
42
+ # Type alias for pre-computed structures
43
+ PrecomputedData = Dict[str, Any]
44
+
45
+
46
+ def _logistic_regression(
47
+ X: np.ndarray,
48
+ y: np.ndarray,
49
+ max_iter: int = 100,
50
+ tol: float = 1e-6,
51
+ ) -> Tuple[np.ndarray, np.ndarray]:
52
+ """
53
+ Fit logistic regression using scipy optimize.
54
+
55
+ Parameters
56
+ ----------
57
+ X : np.ndarray
58
+ Feature matrix (n_samples, n_features). Intercept added automatically.
59
+ y : np.ndarray
60
+ Binary outcome (0/1).
61
+ max_iter : int
62
+ Maximum iterations.
63
+ tol : float
64
+ Convergence tolerance.
65
+
66
+ Returns
67
+ -------
68
+ beta : np.ndarray
69
+ Fitted coefficients (including intercept).
70
+ probs : np.ndarray
71
+ Predicted probabilities.
72
+ """
73
+ n, p = X.shape
74
+ # Add intercept
75
+ X_with_intercept = np.column_stack([np.ones(n), X])
76
+
77
+ def neg_log_likelihood(beta: np.ndarray) -> float:
78
+ z = X_with_intercept @ beta
79
+ # Clip to keep exp() within float64 range (overflow starts near z = 709)
80
+ z = np.clip(z, -500, 500)
81
+ log_lik = np.sum(y * z - np.logaddexp(0, z))
82
+ return -log_lik
83
+
84
+ def gradient(beta: np.ndarray) -> np.ndarray:
85
+ z = X_with_intercept @ beta
86
+ z = np.clip(z, -500, 500)
87
+ probs = 1 / (1 + np.exp(-z))
88
+ return -X_with_intercept.T @ (y - probs)
89
+
90
+ # Initialize with zeros
91
+ beta_init = np.zeros(p + 1)
92
+
93
+ result = optimize.minimize(
94
+ neg_log_likelihood,
95
+ beta_init,
96
+ method='BFGS',
97
+ jac=gradient,
98
+ options={'maxiter': max_iter, 'gtol': tol}
99
+ )
100
+
101
+ beta = result.x
102
+ z = X_with_intercept @ beta
103
+ z = np.clip(z, -500, 500)
104
+ probs = 1 / (1 + np.exp(-z))
105
+
106
+ return beta, probs
107
+
108
+
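+ # Illustrative sketch (synthetic data, not from the package docs): calling
+ # _logistic_regression directly.
+ #
+ #     rng = np.random.default_rng(0)
+ #     X = rng.normal(size=(200, 2))
+ #     y = (X[:, 0] + rng.normal(size=200) > 0).astype(float)
+ #     beta, probs = _logistic_regression(X, y)
+ #     beta.shape   # (3,): intercept plus one coefficient per feature
+ #     probs.shape  # (200,): fitted P(y = 1 | X)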
109
+ def _linear_regression(
110
+ X: np.ndarray,
111
+ y: np.ndarray,
112
+ rank_deficient_action: str = "warn",
113
+ ) -> Tuple[np.ndarray, np.ndarray]:
114
+ """
115
+ Fit OLS regression.
116
+
117
+ Parameters
118
+ ----------
119
+ X : np.ndarray
120
+ Feature matrix (n_samples, n_features). Intercept added automatically.
121
+ y : np.ndarray
122
+ Outcome variable.
123
+ rank_deficient_action : str, default "warn"
124
+ Action when design matrix is rank-deficient:
125
+ - "warn": Issue warning and drop linearly dependent columns (default)
126
+ - "error": Raise ValueError
127
+ - "silent": Drop columns silently without warning
128
+
129
+ Returns
130
+ -------
131
+ beta : np.ndarray
132
+ Fitted coefficients (including intercept).
133
+ residuals : np.ndarray
134
+ Residuals from the fit.
135
+ """
136
+ n = X.shape[0]
137
+ # Add intercept
138
+ X_with_intercept = np.column_stack([np.ones(n), X])
139
+
140
+ # Use unified OLS backend (no vcov needed)
141
+ beta, residuals, _ = solve_ols(
142
+ X_with_intercept, y, return_vcov=False,
143
+ rank_deficient_action=rank_deficient_action,
144
+ )
145
+
146
+ return beta, residuals
147
+
148
+
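+ # Illustrative sketch (synthetic data): _linear_regression adds the
+ # intercept itself, so X carries only the covariates.
+ #
+ #     rng = np.random.default_rng(0)
+ #     X = rng.normal(size=(100, 3))
+ #     y = 2.0 + X @ np.array([1.0, -0.5, 0.0]) + rng.normal(size=100)
+ #     beta, residuals = _linear_regression(X, y)
+ #     beta.shape  # (4,): intercept first, then one slope per column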
149
+ class CallawaySantAnna(
150
+ CallawaySantAnnaBootstrapMixin,
151
+ CallawaySantAnnaAggregationMixin,
152
+ ):
153
+ """
154
+ Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
155
+
156
+ This estimator handles DiD designs with variation in treatment timing
157
+ (staggered adoption) and heterogeneous treatment effects. It avoids the
158
+ bias of traditional two-way fixed effects (TWFE) estimators by:
159
+
160
+ 1. Computing group-time average treatment effects ATT(g,t) for each
161
+ cohort g (units first treated in period g) and time t.
162
+ 2. Aggregating these to summary measures (overall ATT, event study, etc.)
163
+ using appropriate weights.
164
+
165
+ Parameters
166
+ ----------
167
+ control_group : str, default="never_treated"
168
+ Which units to use as controls:
169
+ - "never_treated": Use only never-treated units (recommended)
170
+ - "not_yet_treated": Use never-treated and not-yet-treated units
171
+ anticipation : int, default=0
172
+ Number of periods before treatment where effects may occur.
173
+ Set to > 0 if treatment effects can begin before the official
174
+ treatment date.
175
+ estimation_method : str, default="dr"
176
+ Estimation method:
177
+ - "dr": Doubly robust (recommended)
178
+ - "ipw": Inverse probability weighting
179
+ - "reg": Outcome regression
180
+ alpha : float, default=0.05
181
+ Significance level for confidence intervals.
182
+ cluster : str, optional
183
+ Column name for cluster-robust standard errors.
184
+ Defaults to unit-level clustering.
185
+ n_bootstrap : int, default=0
186
+ Number of bootstrap iterations for inference.
187
+ If 0, uses analytical standard errors.
188
+ Recommended: 999 or more for reliable inference.
189
+
190
+ .. note:: Memory Usage
191
+ The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
192
+ float64 array. For large datasets, this can be significant:
193
+ - 1K bootstrap × 10K units = ~80 MB
194
+ - 10K bootstrap × 100K units = ~8 GB
195
+ Consider reducing n_bootstrap if memory is constrained.
196
+
197
+ bootstrap_weights : str, default="rademacher"
198
+ Type of weights for multiplier bootstrap:
199
+ - "rademacher": +1/-1 with equal probability (standard choice)
200
+ - "mammen": Two-point distribution (asymptotically valid, matches skewness)
201
+ - "webb": Six-point distribution (recommended when n_clusters < 20)
202
+ bootstrap_weight_type : str, optional
203
+ .. deprecated:: 1.0.1
204
+ Use ``bootstrap_weights`` instead. Will be removed in v3.0.
205
+ seed : int, optional
206
+ Random seed for reproducibility.
207
+ rank_deficient_action : str, default="warn"
208
+ Action when design matrix is rank-deficient (linearly dependent columns):
209
+ - "warn": Issue warning and drop linearly dependent columns (default)
210
+ - "error": Raise ValueError
211
+ - "silent": Drop columns silently without warning
212
+ base_period : str, default="varying"
213
+ Method for selecting the base (reference) period for computing
214
+ ATT(g,t). Options:
215
+ - "varying": For pre-treatment periods (t < g - anticipation), use
216
+ t-1 as base (consecutive comparisons). For post-treatment, use
217
+ g-1-anticipation. Requires t-1 to exist in data.
218
+ - "universal": Always use g-1-anticipation as base period.
219
+ Both produce identical post-treatment effects. Matches R's
220
+ did::att_gt() base_period parameter.
221
+
222
+ Attributes
223
+ ----------
224
+ results_ : CallawaySantAnnaResults
225
+ Estimation results after calling fit().
226
+ is_fitted_ : bool
227
+ Whether the model has been fitted.
228
+
229
+ Examples
230
+ --------
231
+ Basic usage:
232
+
233
+ >>> import pandas as pd
234
+ >>> from diff_diff import CallawaySantAnna
235
+ >>>
236
+ >>> # Panel data with staggered treatment
237
+ >>> # 'first_treat' = period when unit was first treated (0 if never treated)
238
+ >>> data = pd.DataFrame({
239
+ ... 'unit': [...],
240
+ ... 'time': [...],
241
+ ... 'outcome': [...],
242
+ ... 'first_treat': [...] # 0 for never-treated, else first treatment period
243
+ ... })
244
+ >>>
245
+ >>> cs = CallawaySantAnna()
246
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
247
+ ... time='time', first_treat='first_treat')
248
+ >>>
249
+ >>> results.print_summary()
250
+
251
+ With event study aggregation:
252
+
253
+ >>> cs = CallawaySantAnna()
254
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
255
+ ... time='time', first_treat='first_treat',
256
+ ... aggregate='event_study')
257
+ >>>
258
+ >>> # Plot event study
259
+ >>> from diff_diff import plot_event_study
260
+ >>> plot_event_study(results)
261
+
262
+ With covariate adjustment (conditional parallel trends):
263
+
264
+ >>> # When parallel trends only holds conditional on covariates
265
+ >>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
266
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
267
+ ... time='time', first_treat='first_treat',
268
+ ... covariates=['age', 'income'])
269
+ >>>
270
+ >>> # DR is recommended: consistent if either outcome model
271
+ >>> # or propensity model is correctly specified
272
+
273
+ Notes
274
+ -----
275
+ The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
276
+ approach: instead of estimating a single treatment effect, they estimate
277
+ ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
278
+ problem where already-treated units act as controls.
279
+
280
+ The ATT(g,t) is identified under parallel trends conditional on covariates:
281
+
282
+ E[Y_t(0) - Y_{g-1}(0) | G = g] = E[Y_t(0) - Y_{g-1}(0) | C = 1]
283
+
284
+ where G = g indicates treatment cohort g and C = 1 indicates control units.
285
+ This uses g-1 (less any anticipation) as the base period for post-treatment
286
+ comparisons (t >= g). With base_period="varying" (default), pre-treatment
287
+ periods use t-1 as the base, which aids parallel-trends diagnostics.
288
+
289
+ References
290
+ ----------
291
+ Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
292
+ multiple time periods. Journal of Econometrics, 225(2), 200-230.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ control_group: str = "never_treated",
298
+ anticipation: int = 0,
299
+ estimation_method: str = "dr",
300
+ alpha: float = 0.05,
301
+ cluster: Optional[str] = None,
302
+ n_bootstrap: int = 0,
303
+ bootstrap_weights: Optional[str] = None,
304
+ bootstrap_weight_type: Optional[str] = None,
305
+ seed: Optional[int] = None,
306
+ rank_deficient_action: str = "warn",
307
+ base_period: str = "varying",
308
+ ):
311
+ if control_group not in ["never_treated", "not_yet_treated"]:
312
+ raise ValueError(
313
+ f"control_group must be 'never_treated' or 'not_yet_treated', "
314
+ f"got '{control_group}'"
315
+ )
316
+ if estimation_method not in ["dr", "ipw", "reg"]:
317
+ raise ValueError(
318
+ f"estimation_method must be 'dr', 'ipw', or 'reg', "
319
+ f"got '{estimation_method}'"
320
+ )
321
+
322
+ # Handle bootstrap_weight_type deprecation
323
+ if bootstrap_weight_type is not None:
324
+ warnings.warn(
325
+ "bootstrap_weight_type is deprecated and will be removed in v3.0. "
326
+ "Use bootstrap_weights instead.",
327
+ DeprecationWarning,
328
+ stacklevel=2
329
+ )
330
+ if bootstrap_weights is None:
331
+ bootstrap_weights = bootstrap_weight_type
332
+
333
+ # Default to rademacher if neither specified
334
+ if bootstrap_weights is None:
335
+ bootstrap_weights = "rademacher"
336
+
337
+ if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
338
+ raise ValueError(
339
+ f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
340
+ f"got '{bootstrap_weights}'"
341
+ )
342
+
343
+ if rank_deficient_action not in ["warn", "error", "silent"]:
344
+ raise ValueError(
345
+ f"rank_deficient_action must be 'warn', 'error', or 'silent', "
346
+ f"got '{rank_deficient_action}'"
347
+ )
348
+
349
+ if base_period not in ["varying", "universal"]:
350
+ raise ValueError(
351
+ f"base_period must be 'varying' or 'universal', "
352
+ f"got '{base_period}'"
353
+ )
354
+
355
+ self.control_group = control_group
356
+ self.anticipation = anticipation
357
+ self.estimation_method = estimation_method
358
+ self.alpha = alpha
359
+ self.cluster = cluster
360
+ self.n_bootstrap = n_bootstrap
361
+ self.bootstrap_weights = bootstrap_weights
362
+ # Keep bootstrap_weight_type for backward compatibility
363
+ self.bootstrap_weight_type = bootstrap_weights
364
+ self.seed = seed
365
+ self.rank_deficient_action = rank_deficient_action
366
+ self.base_period = base_period
367
+
368
+ self.is_fitted_ = False
369
+ self.results_: Optional[CallawaySantAnnaResults] = None
370
+
371
+ def _precompute_structures(
372
+ self,
373
+ df: pd.DataFrame,
374
+ outcome: str,
375
+ unit: str,
376
+ time: str,
377
+ first_treat: str,
378
+ covariates: Optional[List[str]],
379
+ time_periods: List[Any],
380
+ treatment_groups: List[Any],
381
+ ) -> PrecomputedData:
382
+ """
383
+ Pre-compute data structures for efficient ATT(g,t) computation.
384
+
385
+ This pivots data to wide format and pre-computes:
386
+ - Outcome matrix (units x time periods)
387
+ - Covariate matrix (units x covariates) from base period
388
+ - Unit cohort membership masks
389
+ - Control unit masks
390
+
391
+ Returns
392
+ -------
393
+ PrecomputedData
394
+ Dictionary with pre-computed structures.
395
+ """
396
+ # Get unique units and their cohort assignments
397
+ unit_info = df.groupby(unit)[first_treat].first()
398
+ all_units = unit_info.index.values
399
+ unit_cohorts = unit_info.values
400
+ n_units = len(all_units)
401
+
402
+ # Create unit index mapping for fast lookups
403
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
404
+
405
+ # Pivot outcome to wide format: rows = units, columns = time periods
406
+ outcome_wide = df.pivot(index=unit, columns=time, values=outcome)
407
+ # Reindex to ensure all units are present (handles unbalanced panels)
408
+ outcome_wide = outcome_wide.reindex(all_units)
409
+ outcome_matrix = outcome_wide.values # Shape: (n_units, n_periods)
410
+ period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)}
411
+
412
+ # Pre-compute cohort masks (boolean arrays)
413
+ cohort_masks = {}
414
+ for g in treatment_groups:
415
+ cohort_masks[g] = (unit_cohorts == g)
416
+
417
+ # Never-treated mask
418
+ never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)
419
+
420
+ # Pre-compute covariate matrices by time period if needed
421
+ # (covariates are retrieved from the base period of each comparison)
422
+ covariate_by_period = None
423
+ if covariates:
424
+ covariate_by_period = {}
425
+ for t in time_periods:
426
+ period_data = df[df[time] == t].set_index(unit)
427
+ period_cov = period_data.reindex(all_units)[covariates]
428
+ covariate_by_period[t] = period_cov.values # Shape: (n_units, n_covariates)
429
+
430
+ return {
431
+ 'all_units': all_units,
432
+ 'unit_to_idx': unit_to_idx,
433
+ 'unit_cohorts': unit_cohorts,
434
+ 'outcome_matrix': outcome_matrix,
435
+ 'period_to_col': period_to_col,
436
+ 'cohort_masks': cohort_masks,
437
+ 'never_treated_mask': never_treated_mask,
438
+ 'covariate_by_period': covariate_by_period,
439
+ 'time_periods': time_periods,
440
+ }
441
+
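+ # Shape sketch (illustrative): for 3 units observed in periods [2000, 2001,
+ # 2002], the pivot above yields outcome_matrix.shape == (3, 3) and
+ # period_to_col == {2000: 0, 2001: 1, 2002: 2}, so each ATT(g,t) below
+ # reduces to column lookups plus boolean cohort masks.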
442
+ def _compute_att_gt_fast(
443
+ self,
444
+ precomputed: PrecomputedData,
445
+ g: Any,
446
+ t: Any,
447
+ covariates: Optional[List[str]],
448
+ ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]:
449
+ """
450
+ Compute ATT(g,t) using pre-computed data structures (fast version).
451
+
452
+ Uses vectorized numpy operations on pre-pivoted outcome matrix
453
+ instead of repeated pandas filtering.
454
+ """
455
+ time_periods = precomputed['time_periods']
456
+ period_to_col = precomputed['period_to_col']
457
+ outcome_matrix = precomputed['outcome_matrix']
458
+ cohort_masks = precomputed['cohort_masks']
459
+ never_treated_mask = precomputed['never_treated_mask']
460
+ unit_cohorts = precomputed['unit_cohorts']
461
+ all_units = precomputed['all_units']
462
+ covariate_by_period = precomputed['covariate_by_period']
463
+
464
+ # Base period selection based on mode
465
+ if self.base_period == "universal":
466
+ # Universal: always use g - 1 - anticipation
467
+ base_period_val = g - 1 - self.anticipation
468
+ else: # varying
469
+ if t < g - self.anticipation:
470
+ # Pre-treatment: use t - 1 (consecutive comparison)
471
+ base_period_val = t - 1
472
+ else:
473
+ # Post-treatment: use g - 1 - anticipation
474
+ base_period_val = g - 1 - self.anticipation
475
+
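+ # Worked example (anticipation=0, cohort g=2004):
+ #   varying:   t=2002 -> base 2001; t=2005 -> base 2003
+ #   universal: every t is compared against base 2003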
476
+ # Both the base period and period t must exist in the data; there is no
477
+ # fallback, to maintain methodological consistency
478
+ if base_period_val not in period_to_col or t not in period_to_col:
479
+ return None, 0.0, 0, 0, None
483
+
484
+ base_col = period_to_col[base_period_val]
485
+ post_col = period_to_col[t]
486
+
487
+ # Get treated units mask (cohort g)
488
+ treated_mask = cohort_masks[g]
489
+
490
+ # Get control units mask
491
+ if self.control_group == "never_treated":
492
+ control_mask = never_treated_mask
493
+ else: # not_yet_treated
494
+ # Not yet treated at time t: never-treated OR (first_treat > t AND not cohort g)
495
+ # Must exclude cohort g since they are the treated group for this ATT(g,t)
496
+ control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g))
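+ # e.g. at t=2003, units first treated in 2004 or later (plus never-treated
+ # units) serve as controls, while cohort g itself is excluded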
497
+
498
+ # Extract outcomes for base and post periods
499
+ y_base = outcome_matrix[:, base_col]
500
+ y_post = outcome_matrix[:, post_col]
501
+
502
+ # Compute outcome changes (vectorized)
503
+ outcome_change = y_post - y_base
504
+
505
+ # Filter to units with valid data (no NaN in either period)
506
+ valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
507
+
508
+ # Get treated and control with valid data
509
+ treated_valid = treated_mask & valid_mask
510
+ control_valid = control_mask & valid_mask
511
+
512
+ n_treated = np.sum(treated_valid)
513
+ n_control = np.sum(control_valid)
514
+
515
+ if n_treated == 0 or n_control == 0:
516
+ return None, 0.0, 0, 0, None
517
+
518
+ # Extract outcome changes for treated and control
519
+ treated_change = outcome_change[treated_valid]
520
+ control_change = outcome_change[control_valid]
521
+
522
+ # Get unit IDs for influence function
523
+ treated_units = all_units[treated_valid]
524
+ control_units = all_units[control_valid]
525
+
526
+ # Get covariates if specified (from the base period)
527
+ X_treated = None
528
+ X_control = None
529
+ if covariates and covariate_by_period is not None:
530
+ cov_matrix = covariate_by_period[base_period_val]
531
+ X_treated = cov_matrix[treated_valid]
532
+ X_control = cov_matrix[control_valid]
533
+
534
+ # Check for missing values
535
+ if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
536
+ warnings.warn(
537
+ f"Missing values in covariates for group {g}, time {t}. "
538
+ "Falling back to unconditional estimation.",
539
+ UserWarning,
540
+ stacklevel=3,
541
+ )
542
+ X_treated = None
543
+ X_control = None
544
+
545
+ # Estimation method
546
+ if self.estimation_method == "reg":
547
+ att_gt, se_gt, inf_func = self._outcome_regression(
548
+ treated_change, control_change, X_treated, X_control
549
+ )
550
+ elif self.estimation_method == "ipw":
551
+ att_gt, se_gt, inf_func = self._ipw_estimation(
552
+ treated_change, control_change,
553
+ int(n_treated), int(n_control),
554
+ X_treated, X_control
555
+ )
556
+ else: # doubly robust
557
+ att_gt, se_gt, inf_func = self._doubly_robust(
558
+ treated_change, control_change, X_treated, X_control
559
+ )
560
+
561
+ # Package influence function info with unit IDs for bootstrap
562
+ n_t = int(n_treated)
563
+ inf_func_info = {
564
+ 'treated_units': list(treated_units),
565
+ 'control_units': list(control_units),
566
+ 'treated_inf': inf_func[:n_t],
567
+ 'control_inf': inf_func[n_t:],
568
+ }
569
+
570
+ return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info
571
+
572
+ def fit(
573
+ self,
574
+ data: pd.DataFrame,
575
+ outcome: str,
576
+ unit: str,
577
+ time: str,
578
+ first_treat: str,
579
+ covariates: Optional[List[str]] = None,
580
+ aggregate: Optional[str] = None,
581
+ balance_e: Optional[int] = None,
582
+ ) -> CallawaySantAnnaResults:
583
+ """
584
+ Fit the Callaway-Sant'Anna estimator.
585
+
586
+ Parameters
587
+ ----------
588
+ data : pd.DataFrame
589
+ Panel data with unit and time identifiers.
590
+ outcome : str
591
+ Name of outcome variable column.
592
+ unit : str
593
+ Name of unit identifier column.
594
+ time : str
595
+ Name of time period column.
596
+ first_treat : str
597
+ Name of column indicating when unit was first treated.
598
+ Use 0 (or np.inf) for never-treated units.
599
+ covariates : list, optional
600
+ List of covariate column names for conditional parallel trends.
601
+ aggregate : str, optional
602
+ How to aggregate group-time effects:
603
+ - None: Only compute ATT(g,t) (default)
604
+ - "simple": Simple weighted average (overall ATT)
605
+ - "event_study": Aggregate by relative time (event study)
606
+ - "group": Aggregate by treatment cohort
607
+ - "all": Compute all aggregations
608
+ balance_e : int, optional
609
+ For event study, balance the panel at relative time e.
610
+ Ensures all groups contribute to each relative period.
611
+
612
+ Returns
613
+ -------
614
+ CallawaySantAnnaResults
615
+ Object containing all estimation results.
616
+
617
+ Raises
618
+ ------
619
+ ValueError
620
+ If required columns are missing or data validation fails.
621
+ """
622
+ # Validate inputs
623
+ required_cols = [outcome, unit, time, first_treat]
624
+ if covariates:
625
+ required_cols.extend(covariates)
626
+
627
+ missing = [c for c in required_cols if c not in data.columns]
628
+ if missing:
629
+ raise ValueError(f"Missing columns: {missing}")
630
+
631
+ # Create working copy
632
+ df = data.copy()
633
+
634
+ # Ensure numeric types
635
+ df[time] = pd.to_numeric(df[time])
636
+ df[first_treat] = pd.to_numeric(df[first_treat])
637
+
638
+ # Standardize the first_treat column name for internal use
639
+ # This avoids hardcoding column names in internal methods
640
+ df['first_treat'] = df[first_treat]
641
+
642
+ # Identify groups and time periods
643
+ time_periods = sorted(df[time].unique())
644
+ # Exclude never-treated codes (0 or np.inf) from the treatment cohorts
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and np.isfinite(g)])
645
+
646
+ # Never-treated indicator (first_treat = 0 or inf)
647
+ df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
648
+
649
+ # Get unique units
650
+ unit_info = df.groupby(unit).agg({
651
+ first_treat: 'first',
652
+ '_never_treated': 'first'
653
+ }).reset_index()
654
+
655
+ n_treated_units = (unit_info[first_treat] > 0).sum()
656
+ n_control_units = (unit_info['_never_treated']).sum()
657
+
658
+ if n_control_units == 0:
659
+ raise ValueError("No never-treated units found. Check 'first_treat' column.")
660
+
661
+ # Pre-compute data structures for efficient ATT(g,t) computation
662
+ precomputed = self._precompute_structures(
663
+ df, outcome, unit, time, first_treat,
664
+ covariates, time_periods, treatment_groups
665
+ )
666
+
667
+ # Compute ATT(g,t) for each group-time combination
668
+ group_time_effects = {}
669
+ influence_func_info = {} # Store influence functions for bootstrap
670
+
671
+ # Get minimum period for determining valid pre-treatment periods
672
+ min_period = min(time_periods)
673
+
674
+ for g in treatment_groups:
675
+ # Compute valid periods including pre-treatment
676
+ if self.base_period == "universal":
677
+ # Universal: all periods except the base period (which is normalized to 0)
678
+ universal_base = g - 1 - self.anticipation
679
+ valid_periods = [t for t in time_periods if t != universal_base]
680
+ else:
681
+ # Varying: post-treatment + pre-treatment where t-1 exists
682
+ valid_periods = [
683
+ t for t in time_periods
684
+ if t >= g - self.anticipation or t > min_period
685
+ ]
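+ # e.g. with periods [2000, ..., 2005], g=2003, anticipation=0:
+ # valid t = 2001..2005 (2000 is dropped since no t-1 exists for it)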
686
+
687
+ for t in valid_periods:
688
+ att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
689
+ precomputed, g, t, covariates
690
+ )
691
+
692
+ if att_gt is not None:
693
+ t_stat = att_gt / se_gt if np.isfinite(se_gt) and se_gt > 0 else np.nan
694
+ p_val = compute_p_value(t_stat)
695
+ ci = compute_confidence_interval(att_gt, se_gt, self.alpha)
696
+
697
+ group_time_effects[(g, t)] = {
698
+ 'effect': att_gt,
699
+ 'se': se_gt,
700
+ 't_stat': t_stat,
701
+ 'p_value': p_val,
702
+ 'conf_int': ci,
703
+ 'n_treated': n_treat,
704
+ 'n_control': n_ctrl,
705
+ }
706
+
707
+ if inf_info is not None:
708
+ influence_func_info[(g, t)] = inf_info
709
+
710
+ if not group_time_effects:
711
+ raise ValueError(
712
+ "Could not estimate any group-time effects. "
713
+ "Check that data has sufficient observations."
714
+ )
715
+
716
+ # Compute overall ATT (simple aggregation)
717
+ overall_att, overall_se = self._aggregate_simple(
718
+ group_time_effects, influence_func_info, df, unit, precomputed
719
+ )
720
+ # Use NaN for t-stat and p-value when SE is undefined (NaN or non-positive)
721
+ if np.isfinite(overall_se) and overall_se > 0:
722
+ overall_t = overall_att / overall_se
723
+ overall_p = compute_p_value(overall_t)
724
+ else:
725
+ overall_t = np.nan
726
+ overall_p = np.nan
727
+ overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)
728
+
729
+ # Compute additional aggregations if requested
730
+ event_study_effects = None
731
+ group_effects = None
732
+
733
+ if aggregate in ["event_study", "all"]:
734
+ event_study_effects = self._aggregate_event_study(
735
+ group_time_effects, influence_func_info,
736
+ treatment_groups, time_periods, balance_e
737
+ )
738
+
739
+ if aggregate in ["group", "all"]:
740
+ group_effects = self._aggregate_by_group(
741
+ group_time_effects, influence_func_info, treatment_groups
742
+ )
743
+
744
+ # Run bootstrap inference if requested
745
+ bootstrap_results = None
746
+ if self.n_bootstrap > 0 and influence_func_info:
747
+ bootstrap_results = self._run_multiplier_bootstrap(
748
+ group_time_effects=group_time_effects,
749
+ influence_func_info=influence_func_info,
750
+ aggregate=aggregate,
751
+ balance_e=balance_e,
752
+ treatment_groups=treatment_groups,
753
+ time_periods=time_periods,
754
+ )
755
+
756
+ # Update estimates with bootstrap inference
757
+ overall_se = bootstrap_results.overall_att_se
758
+ # Use NaN for t-stat when SE is undefined; p-value comes from bootstrap
759
+ if np.isfinite(overall_se) and overall_se > 0:
760
+ overall_t = overall_att / overall_se
761
+ else:
762
+ overall_t = np.nan
763
+ overall_p = bootstrap_results.overall_att_p_value
764
+ overall_ci = bootstrap_results.overall_att_ci
765
+
766
+ # Update group-time effects with bootstrap SEs
767
+ for gt in group_time_effects:
768
+ if gt in bootstrap_results.group_time_ses:
769
+ group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt]
770
+ group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt]
771
+ group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt]
772
+ effect = float(group_time_effects[gt]['effect'])
773
+ se = float(group_time_effects[gt]['se'])
774
+ group_time_effects[gt]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan
775
+
776
+ # Update event study effects with bootstrap SEs
777
+ if (event_study_effects is not None
778
+ and bootstrap_results.event_study_ses is not None
779
+ and bootstrap_results.event_study_cis is not None
780
+ and bootstrap_results.event_study_p_values is not None):
781
+ for e in event_study_effects:
782
+ if e in bootstrap_results.event_study_ses:
783
+ event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e]
784
+ event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e]
785
+ p_val = bootstrap_results.event_study_p_values[e]
786
+ event_study_effects[e]['p_value'] = p_val
787
+ effect = float(event_study_effects[e]['effect'])
788
+ se = float(event_study_effects[e]['se'])
789
+ event_study_effects[e]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan
790
+
791
+ # Update group effects with bootstrap SEs
792
+ if (group_effects is not None
793
+ and bootstrap_results.group_effect_ses is not None
794
+ and bootstrap_results.group_effect_cis is not None
795
+ and bootstrap_results.group_effect_p_values is not None):
796
+ for g in group_effects:
797
+ if g in bootstrap_results.group_effect_ses:
798
+ group_effects[g]['se'] = bootstrap_results.group_effect_ses[g]
799
+ group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g]
800
+ group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g]
801
+ effect = float(group_effects[g]['effect'])
802
+ se = float(group_effects[g]['se'])
803
+ group_effects[g]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan
804
+
805
+ # Store results
806
+ self.results_ = CallawaySantAnnaResults(
807
+ group_time_effects=group_time_effects,
808
+ overall_att=overall_att,
809
+ overall_se=overall_se,
810
+ overall_t_stat=overall_t,
811
+ overall_p_value=overall_p,
812
+ overall_conf_int=overall_ci,
813
+ groups=treatment_groups,
814
+ time_periods=time_periods,
815
+ n_obs=len(df),
816
+ n_treated_units=n_treated_units,
817
+ n_control_units=n_control_units,
818
+ alpha=self.alpha,
819
+ control_group=self.control_group,
820
+ base_period=self.base_period,
821
+ event_study_effects=event_study_effects,
822
+ group_effects=group_effects,
823
+ bootstrap_results=bootstrap_results,
824
+ )
825
+
826
+ self.is_fitted_ = True
827
+ return self.results_
828
+
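+ # End-to-end sketch (hypothetical column names; assumes the results object
+ # exposes its constructor fields as attributes):
+ #
+ #     cs = CallawaySantAnna(n_bootstrap=999, seed=42)
+ #     res = cs.fit(panel, outcome='y', unit='id', time='year',
+ #                  first_treat='g', aggregate='all')
+ #     res.overall_att, res.overall_se  # bootstrap-based inference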
829
+ def _outcome_regression(
830
+ self,
831
+ treated_change: np.ndarray,
832
+ control_change: np.ndarray,
833
+ X_treated: Optional[np.ndarray] = None,
834
+ X_control: Optional[np.ndarray] = None,
835
+ ) -> Tuple[float, float, np.ndarray]:
836
+ """
837
+ Estimate ATT using outcome regression.
838
+
839
+ With covariates:
840
+ 1. Regress outcome changes on covariates for control group
841
+ 2. Predict counterfactual for treated using their covariates
842
+ 3. ATT = mean(treated_change) - mean(predicted_counterfactual)
843
+
844
+ Without covariates:
845
+ Simple difference in means.
846
+ """
847
+ n_t = len(treated_change)
848
+ n_c = len(control_change)
849
+
850
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
851
+ # Covariate-adjusted outcome regression
852
+ # Fit regression on control units: E[Delta Y | X, D=0]
853
+ beta, residuals = _linear_regression(
854
+ X_control, control_change,
855
+ rank_deficient_action=self.rank_deficient_action,
856
+ )
857
+
858
+ # Predict counterfactual for treated units
859
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
860
+ predicted_control = X_treated_with_intercept @ beta
861
+
862
+ # ATT = mean(observed treated change - predicted counterfactual)
863
+ att = np.mean(treated_change - predicted_control)
864
+
865
+ # Standard error using sandwich estimator
866
+ # Variance from treated: Var(Y_1 - m(X))
867
+ treated_residuals = treated_change - predicted_control
868
+ var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
869
+
870
+ # Variance from control regression (residual variance)
871
+ var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0
872
+
873
+ # Approximate SE (ignoring estimation error in beta for simplicity)
874
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
875
+
876
+ # Influence function
877
+ inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
878
+ inf_control = -residuals / n_c
879
+ inf_func = np.concatenate([inf_treated, inf_control])
880
+ else:
881
+ # Simple difference in means (no covariates)
882
+ att = np.mean(treated_change) - np.mean(control_change)
883
+
884
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
885
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
886
+
887
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
888
+
889
+ # Influence function (for aggregation)
890
+ inf_treated = treated_change - np.mean(treated_change)
891
+ inf_control = control_change - np.mean(control_change)
892
+ inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])
893
+
894
+ return att, se, inf_func
895
+
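+ # No-covariate case in numbers (illustrative): treated_change = [3, 5] and
+ # control_change = [1, 1] give att = 4 - 1 = 3 and
+ # se = sqrt(var_t/2 + var_c/2) = sqrt(2/2 + 0/2) = 1.0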
896
+ def _ipw_estimation(
897
+ self,
898
+ treated_change: np.ndarray,
899
+ control_change: np.ndarray,
900
+ n_treated: int,
901
+ n_control: int,
902
+ X_treated: Optional[np.ndarray] = None,
903
+ X_control: Optional[np.ndarray] = None,
904
+ ) -> Tuple[float, float, np.ndarray]:
905
+ """
906
+ Estimate ATT using inverse probability weighting.
907
+
908
+ With covariates:
909
+ 1. Estimate propensity score P(D=1|X) using logistic regression
910
+ 2. Reweight control units to match treated covariate distribution
911
+ 3. ATT = mean(treated) - weighted_mean(control)
912
+
913
+ Without covariates:
914
+ Simple difference in means with unconditional propensity weighting.
915
+ """
916
+ n_t = len(treated_change)
917
+ n_c = len(control_change)
918
+ n_total = n_treated + n_control
919
+
920
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
921
+ # Covariate-adjusted IPW estimation
922
+ # Stack covariates and create treatment indicator
923
+ X_all = np.vstack([X_treated, X_control])
924
+ D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
925
+
926
+ # Estimate propensity scores using logistic regression
927
+ try:
928
+ _, pscore = _logistic_regression(X_all, D)
929
+ except (np.linalg.LinAlgError, ValueError):
930
+ # Fallback to unconditional if logistic regression fails
931
+ warnings.warn(
932
+ "Propensity score estimation failed. "
933
+ "Falling back to unconditional estimation.",
934
+ UserWarning,
935
+ stacklevel=4,
936
+ )
937
+ pscore = np.full(len(D), n_t / (n_t + n_c))
938
+
939
+ # Propensity scores for treated and control
940
+ pscore_treated = pscore[:n_t]
941
+ pscore_control = pscore[n_t:]
942
+
943
+ # Clip propensity scores to avoid extreme weights
944
+ pscore_control = np.clip(pscore_control, 0.01, 0.99)
945
+ pscore_treated = np.clip(pscore_treated, 0.01, 0.99)
946
+
947
+ # IPW weights for control units: p(X) / (1 - p(X))
948
+ # This reweights controls to have same covariate distribution as treated
949
+ weights_control = pscore_control / (1 - pscore_control)
950
+ weights_control = weights_control / np.sum(weights_control) # normalize
951
+
952
+ # ATT = mean(treated) - weighted_mean(control)
953
+ att = np.mean(treated_change) - np.sum(weights_control * control_change)
954
+
955
+ # Compute standard error
956
+ # Variance of treated mean
957
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
958
+
959
+ # Variance of weighted control mean
960
+ weighted_var_c = np.sum(weights_control * (control_change - np.sum(weights_control * control_change)) ** 2)
961
+
962
+ se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0
963
+
964
+ # Influence function
965
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
966
+ inf_control = -weights_control * (control_change - np.sum(weights_control * control_change))
967
+ inf_func = np.concatenate([inf_treated, inf_control])
968
+ else:
969
+ # Unconditional IPW (reduces to difference in means)
970
+ p_treat = n_treated / n_total # unconditional propensity score
971
+
972
+ att = np.mean(treated_change) - np.mean(control_change)
973
+
974
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
975
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
976
+
977
+ # Adjusted variance for IPW
978
+ se = np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) if (n_t > 0 and n_c > 0 and p_treat > 0) else 0.0
979
+
980
+ # Influence function (for aggregation)
981
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
982
+ inf_control = (control_change - np.mean(control_change)) / n_c
983
+ inf_func = np.concatenate([inf_treated, -inf_control])
984
+
985
+ return att, se, inf_func
986
+
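+ # Weight sketch (illustrative): a control unit with estimated propensity
+ # p(X) = 0.75 gets raw weight 0.75 / 0.25 = 3.0 before normalization, so
+ # controls that resemble treated units count more in the comparison.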
987
+ def _doubly_robust(
988
+ self,
989
+ treated_change: np.ndarray,
990
+ control_change: np.ndarray,
991
+ X_treated: Optional[np.ndarray] = None,
992
+ X_control: Optional[np.ndarray] = None,
993
+ ) -> Tuple[float, float, np.ndarray]:
994
+ """
995
+ Estimate ATT using doubly robust estimation.
996
+
997
+ With covariates:
998
+ Combines outcome regression and IPW for double robustness.
999
+ The estimator is consistent if either the outcome model OR
1000
+ the propensity model is correctly specified.
1001
+
1002
+ ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
1003
+ + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]
1004
+
1005
+ where m(X) is the outcome model and w_i are IPW weights.
1006
+
1007
+ Without covariates:
1008
+ Reduces to simple difference in means.
1009
+ """
1010
+ n_t = len(treated_change)
1011
+ n_c = len(control_change)
1012
+
1013
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
1014
+ # Doubly robust estimation with covariates
1015
+ # Step 1: Outcome regression - fit E[Delta Y | X] on control
1016
+ beta, _ = _linear_regression(
1017
+ X_control, control_change,
1018
+ rank_deficient_action=self.rank_deficient_action,
1019
+ )
1020
+
1021
+ # Predict counterfactual for both treated and control
1022
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
1023
+ X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
1024
+ m_treated = X_treated_with_intercept @ beta
1025
+ m_control = X_control_with_intercept @ beta
1026
+
1027
+ # Step 2: Propensity score estimation
1028
+ X_all = np.vstack([X_treated, X_control])
1029
+ D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
1030
+
1031
+ try:
1032
+ _, pscore = _logistic_regression(X_all, D)
1033
+ except (np.linalg.LinAlgError, ValueError):
1034
+ # Fallback to unconditional if logistic regression fails
1035
+ pscore = np.full(len(D), n_t / (n_t + n_c))
1036
+
1037
+ pscore_control = pscore[n_t:]
1038
+
1039
+ # Clip propensity scores
1040
+ pscore_control = np.clip(pscore_control, 0.01, 0.99)
1041
+
1042
+ # IPW weights for control: p(X) / (1 - p(X))
1043
+ weights_control = pscore_control / (1 - pscore_control)
1044
+
1045
+ # Step 3: Doubly robust ATT
1046
+ # ATT = mean(treated - m(X_treated))
1047
+ # + weighted_mean_control((m(X) - Y) * weight)
1048
+ att_treated_part = np.mean(treated_change - m_treated)
1049
+
1050
+ # Augmentation term from control
1051
+ augmentation = np.sum(weights_control * (m_control - control_change)) / n_t
1052
+
1053
+ att = att_treated_part + augmentation
1054
+
1055
+ # Step 4: Standard error using influence function
1056
+ # Influence function for DR estimator
1057
+ psi_treated = (treated_change - m_treated - att) / n_t
1058
+ psi_control = (weights_control * (m_control - control_change)) / n_t
1059
+
1060
+ # Variance is sum of squared influence functions
1061
+ var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2)
1062
+ se = np.sqrt(var_psi) if var_psi > 0 else 0.0
1063
+
1064
+ # Full influence function
1065
+ inf_func = np.concatenate([psi_treated, psi_control])
1066
+ else:
1067
+ # Without covariates, DR simplifies to difference in means
1068
+ att = np.mean(treated_change) - np.mean(control_change)
1069
+
1070
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
1071
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
1072
+
1073
+ se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
1074
+
1075
+ # Influence function for DR estimator
1076
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
1077
+ inf_control = (control_change - np.mean(control_change)) / n_c
1078
+ inf_func = np.concatenate([inf_treated, -inf_control])
1079
+
1080
+ return att, se, inf_func
1081
+
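+ # DR identity sketch, matching the docstring above: with
+ # w_i = p(X_i) / (1 - p(X_i)),
+ #
+ #     ATT_dr = (1/n_t) * sum_{i treated} (dY_i - m(X_i))
+ #              + (1/n_t) * sum_{i control} w_i * (m(X_i) - dY_i)
+ #
+ # which stays consistent if either m(X) or p(X) is correctly specified.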
1082
+ def get_params(self) -> Dict[str, Any]:
1083
+ """Get estimator parameters (sklearn-compatible)."""
1084
+ return {
1085
+ "control_group": self.control_group,
1086
+ "anticipation": self.anticipation,
1087
+ "estimation_method": self.estimation_method,
1088
+ "alpha": self.alpha,
1089
+ "cluster": self.cluster,
1090
+ "n_bootstrap": self.n_bootstrap,
1091
+ "bootstrap_weights": self.bootstrap_weights,
1092
+ # Deprecated but kept for backward compatibility
1093
+ "bootstrap_weight_type": self.bootstrap_weight_type,
1094
+ "seed": self.seed,
1095
+ "rank_deficient_action": self.rank_deficient_action,
1096
+ "base_period": self.base_period,
1097
+ }
1098
+
1099
+ def set_params(self, **params) -> "CallawaySantAnna":
1100
+ """Set estimator parameters (sklearn-compatible)."""
1101
+ for key, value in params.items():
1102
+ if hasattr(self, key):
1103
+ setattr(self, key, value)
1104
+ else:
1105
+ raise ValueError(f"Unknown parameter: {key}")
1106
+ return self
1107
+
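+ # Sketch (illustrative): parameters can be tweaked sklearn-style, e.g.
+ #
+ #     cs = CallawaySantAnna().set_params(alpha=0.10, estimation_method='ipw')
+ #
+ # note that set_params assigns attributes directly without re-validation.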
1108
+ def summary(self) -> str:
1109
+ """Get summary of estimation results."""
1110
+ if not self.is_fitted_:
1111
+ raise RuntimeError("Model must be fitted before calling summary()")
1112
+ assert self.results_ is not None
1113
+ return self.results_.summary()
1114
+
1115
+ def print_summary(self) -> None:
1116
+ """Print summary to stdout."""
1117
+ print(self.summary())