diff-diff 2.3.2__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +254 -0
- diff_diff/_backend.py +112 -0
- diff_diff/_rust_backend.cp313-win_amd64.pyd +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1161 -0
- diff_diff/honest_did.py +1511 -0
- diff_diff/imputation.py +2480 -0
- diff_diff/linalg.py +1537 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1241 -0
- diff_diff/prep_dgp.py +777 -0
- diff_diff/pretrends.py +1104 -0
- diff_diff/results.py +794 -0
- diff_diff/staggered.py +1120 -0
- diff_diff/staggered_aggregation.py +492 -0
- diff_diff/staggered_bootstrap.py +753 -0
- diff_diff/staggered_results.py +296 -0
- diff_diff/sun_abraham.py +1227 -0
- diff_diff/synthetic_did.py +858 -0
- diff_diff/triple_diff.py +1322 -0
- diff_diff/trop.py +2904 -0
- diff_diff/twfe.py +428 -0
- diff_diff/utils.py +1845 -0
- diff_diff/visualization.py +1676 -0
- diff_diff-2.3.2.dist-info/METADATA +2646 -0
- diff_diff-2.3.2.dist-info/RECORD +30 -0
- diff_diff-2.3.2.dist-info/WHEEL +4 -0
- diff_diff-2.3.2.dist-info/sboms/diff_diff_rust.cyclonedx.json +5952 -0
diff_diff/staggered.py
ADDED
|
@@ -0,0 +1,1120 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Staggered Difference-in-Differences estimators.
|
|
3
|
+
|
|
4
|
+
Implements modern methods for DiD with variation in treatment timing,
|
|
5
|
+
including the Callaway-Sant'Anna (2021) estimator.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from scipy import optimize
|
|
14
|
+
|
|
15
|
+
from diff_diff.linalg import solve_ols
|
|
16
|
+
from diff_diff.utils import (
|
|
17
|
+
compute_confidence_interval,
|
|
18
|
+
compute_p_value,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Import from split modules
|
|
22
|
+
from diff_diff.staggered_results import (
|
|
23
|
+
GroupTimeEffect,
|
|
24
|
+
CallawaySantAnnaResults,
|
|
25
|
+
)
|
|
26
|
+
from diff_diff.staggered_bootstrap import (
|
|
27
|
+
CSBootstrapResults,
|
|
28
|
+
CallawaySantAnnaBootstrapMixin,
|
|
29
|
+
)
|
|
30
|
+
from diff_diff.staggered_aggregation import (
|
|
31
|
+
CallawaySantAnnaAggregationMixin,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Re-export names from the split modules for backward compatibility,
# so `from diff_diff.staggered import ...` keeps working after the split.
__all__ = [
    "CallawaySantAnna",
    "CallawaySantAnnaResults",
    "CSBootstrapResults",
    "GroupTimeEffect",
]

# Type alias for the pre-computed structures built by
# CallawaySantAnna._precompute_structures (keys documented there).
PrecomputedData = Dict[str, Any]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _logistic_regression(
|
|
47
|
+
X: np.ndarray,
|
|
48
|
+
y: np.ndarray,
|
|
49
|
+
max_iter: int = 100,
|
|
50
|
+
tol: float = 1e-6,
|
|
51
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
52
|
+
"""
|
|
53
|
+
Fit logistic regression using scipy optimize.
|
|
54
|
+
|
|
55
|
+
Parameters
|
|
56
|
+
----------
|
|
57
|
+
X : np.ndarray
|
|
58
|
+
Feature matrix (n_samples, n_features). Intercept added automatically.
|
|
59
|
+
y : np.ndarray
|
|
60
|
+
Binary outcome (0/1).
|
|
61
|
+
max_iter : int
|
|
62
|
+
Maximum iterations.
|
|
63
|
+
tol : float
|
|
64
|
+
Convergence tolerance.
|
|
65
|
+
|
|
66
|
+
Returns
|
|
67
|
+
-------
|
|
68
|
+
beta : np.ndarray
|
|
69
|
+
Fitted coefficients (including intercept).
|
|
70
|
+
probs : np.ndarray
|
|
71
|
+
Predicted probabilities.
|
|
72
|
+
"""
|
|
73
|
+
n, p = X.shape
|
|
74
|
+
# Add intercept
|
|
75
|
+
X_with_intercept = np.column_stack([np.ones(n), X])
|
|
76
|
+
|
|
77
|
+
def neg_log_likelihood(beta: np.ndarray) -> float:
|
|
78
|
+
z = X_with_intercept @ beta
|
|
79
|
+
# Clip to prevent overflow
|
|
80
|
+
z = np.clip(z, -500, 500)
|
|
81
|
+
log_lik = np.sum(y * z - np.log(1 + np.exp(z)))
|
|
82
|
+
return -log_lik
|
|
83
|
+
|
|
84
|
+
def gradient(beta: np.ndarray) -> np.ndarray:
|
|
85
|
+
z = X_with_intercept @ beta
|
|
86
|
+
z = np.clip(z, -500, 500)
|
|
87
|
+
probs = 1 / (1 + np.exp(-z))
|
|
88
|
+
return -X_with_intercept.T @ (y - probs)
|
|
89
|
+
|
|
90
|
+
# Initialize with zeros
|
|
91
|
+
beta_init = np.zeros(p + 1)
|
|
92
|
+
|
|
93
|
+
result = optimize.minimize(
|
|
94
|
+
neg_log_likelihood,
|
|
95
|
+
beta_init,
|
|
96
|
+
method='BFGS',
|
|
97
|
+
jac=gradient,
|
|
98
|
+
options={'maxiter': max_iter, 'gtol': tol}
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
beta = result.x
|
|
102
|
+
z = X_with_intercept @ beta
|
|
103
|
+
z = np.clip(z, -500, 500)
|
|
104
|
+
probs = 1 / (1 + np.exp(-z))
|
|
105
|
+
|
|
106
|
+
return beta, probs
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit OLS regression with an automatically prepended intercept column.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features). Intercept added automatically.
    y : np.ndarray
        Outcome variable.
    rank_deficient_action : str, default "warn"
        Action when design matrix is rank-deficient:
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients (including intercept).
    residuals : np.ndarray
        Residuals from the fit.
    """
    # Prepend a column of ones so the backend fits an intercept.
    design = np.column_stack([np.ones(X.shape[0]), X])

    # Delegate to the unified OLS backend; the vcov slot is not needed here.
    coefs, resid, _ = solve_ols(
        design,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
    )
    return coefs, resid
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class CallawaySantAnna(
|
|
150
|
+
CallawaySantAnnaBootstrapMixin,
|
|
151
|
+
CallawaySantAnnaAggregationMixin,
|
|
152
|
+
):
|
|
153
|
+
"""
|
|
154
|
+
Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
|
|
155
|
+
|
|
156
|
+
This estimator handles DiD designs with variation in treatment timing
|
|
157
|
+
(staggered adoption) and heterogeneous treatment effects. It avoids the
|
|
158
|
+
bias of traditional two-way fixed effects (TWFE) estimators by:
|
|
159
|
+
|
|
160
|
+
1. Computing group-time average treatment effects ATT(g,t) for each
|
|
161
|
+
cohort g (units first treated in period g) and time t.
|
|
162
|
+
2. Aggregating these to summary measures (overall ATT, event study, etc.)
|
|
163
|
+
using appropriate weights.
|
|
164
|
+
|
|
165
|
+
Parameters
|
|
166
|
+
----------
|
|
167
|
+
control_group : str, default="never_treated"
|
|
168
|
+
Which units to use as controls:
|
|
169
|
+
- "never_treated": Use only never-treated units (recommended)
|
|
170
|
+
- "not_yet_treated": Use never-treated and not-yet-treated units
|
|
171
|
+
anticipation : int, default=0
|
|
172
|
+
Number of periods before treatment where effects may occur.
|
|
173
|
+
Set to > 0 if treatment effects can begin before the official
|
|
174
|
+
treatment date.
|
|
175
|
+
estimation_method : str, default="dr"
|
|
176
|
+
Estimation method:
|
|
177
|
+
- "dr": Doubly robust (recommended)
|
|
178
|
+
- "ipw": Inverse probability weighting
|
|
179
|
+
- "reg": Outcome regression
|
|
180
|
+
alpha : float, default=0.05
|
|
181
|
+
Significance level for confidence intervals.
|
|
182
|
+
cluster : str, optional
|
|
183
|
+
Column name for cluster-robust standard errors.
|
|
184
|
+
Defaults to unit-level clustering.
|
|
185
|
+
n_bootstrap : int, default=0
|
|
186
|
+
Number of bootstrap iterations for inference.
|
|
187
|
+
If 0, uses analytical standard errors.
|
|
188
|
+
Recommended: 999 or more for reliable inference.
|
|
189
|
+
|
|
190
|
+
.. note:: Memory Usage
|
|
191
|
+
The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
|
|
192
|
+
float64 array. For large datasets, this can be significant:
|
|
193
|
+
- 1K bootstrap × 10K units = ~80 MB
|
|
194
|
+
- 10K bootstrap × 100K units = ~8 GB
|
|
195
|
+
Consider reducing n_bootstrap if memory is constrained.
|
|
196
|
+
|
|
197
|
+
bootstrap_weights : str, default="rademacher"
|
|
198
|
+
Type of weights for multiplier bootstrap:
|
|
199
|
+
- "rademacher": +1/-1 with equal probability (standard choice)
|
|
200
|
+
- "mammen": Two-point distribution (asymptotically valid, matches skewness)
|
|
201
|
+
- "webb": Six-point distribution (recommended when n_clusters < 20)
|
|
202
|
+
bootstrap_weight_type : str, optional
|
|
203
|
+
.. deprecated:: 1.0.1
|
|
204
|
+
Use ``bootstrap_weights`` instead. Will be removed in v3.0.
|
|
205
|
+
seed : int, optional
|
|
206
|
+
Random seed for reproducibility.
|
|
207
|
+
rank_deficient_action : str, default="warn"
|
|
208
|
+
Action when design matrix is rank-deficient (linearly dependent columns):
|
|
209
|
+
- "warn": Issue warning and drop linearly dependent columns (default)
|
|
210
|
+
- "error": Raise ValueError
|
|
211
|
+
- "silent": Drop columns silently without warning
|
|
212
|
+
base_period : str, default="varying"
|
|
213
|
+
Method for selecting the base (reference) period for computing
|
|
214
|
+
ATT(g,t). Options:
|
|
215
|
+
- "varying": For pre-treatment periods (t < g - anticipation), use
|
|
216
|
+
t-1 as base (consecutive comparisons). For post-treatment, use
|
|
217
|
+
g-1-anticipation. Requires t-1 to exist in data.
|
|
218
|
+
- "universal": Always use g-1-anticipation as base period.
|
|
219
|
+
Both produce identical post-treatment effects. Matches R's
|
|
220
|
+
did::att_gt() base_period parameter.
|
|
221
|
+
|
|
222
|
+
Attributes
|
|
223
|
+
----------
|
|
224
|
+
results_ : CallawaySantAnnaResults
|
|
225
|
+
Estimation results after calling fit().
|
|
226
|
+
is_fitted_ : bool
|
|
227
|
+
Whether the model has been fitted.
|
|
228
|
+
|
|
229
|
+
Examples
|
|
230
|
+
--------
|
|
231
|
+
Basic usage:
|
|
232
|
+
|
|
233
|
+
>>> import pandas as pd
|
|
234
|
+
>>> from diff_diff import CallawaySantAnna
|
|
235
|
+
>>>
|
|
236
|
+
>>> # Panel data with staggered treatment
|
|
237
|
+
>>> # 'first_treat' = period when unit was first treated (0 if never treated)
|
|
238
|
+
>>> data = pd.DataFrame({
|
|
239
|
+
... 'unit': [...],
|
|
240
|
+
... 'time': [...],
|
|
241
|
+
... 'outcome': [...],
|
|
242
|
+
... 'first_treat': [...] # 0 for never-treated, else first treatment period
|
|
243
|
+
... })
|
|
244
|
+
>>>
|
|
245
|
+
>>> cs = CallawaySantAnna()
|
|
246
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
247
|
+
... time='time', first_treat='first_treat')
|
|
248
|
+
>>>
|
|
249
|
+
>>> results.print_summary()
|
|
250
|
+
|
|
251
|
+
With event study aggregation:
|
|
252
|
+
|
|
253
|
+
>>> cs = CallawaySantAnna()
|
|
254
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
255
|
+
... time='time', first_treat='first_treat',
|
|
256
|
+
... aggregate='event_study')
|
|
257
|
+
>>>
|
|
258
|
+
>>> # Plot event study
|
|
259
|
+
>>> from diff_diff import plot_event_study
|
|
260
|
+
>>> plot_event_study(results)
|
|
261
|
+
|
|
262
|
+
With covariate adjustment (conditional parallel trends):
|
|
263
|
+
|
|
264
|
+
>>> # When parallel trends only holds conditional on covariates
|
|
265
|
+
>>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
|
|
266
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
267
|
+
... time='time', first_treat='first_treat',
|
|
268
|
+
... covariates=['age', 'income'])
|
|
269
|
+
>>>
|
|
270
|
+
>>> # DR is recommended: consistent if either outcome model
|
|
271
|
+
>>> # or propensity model is correctly specified
|
|
272
|
+
|
|
273
|
+
Notes
|
|
274
|
+
-----
|
|
275
|
+
The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
|
|
276
|
+
approach: instead of estimating a single treatment effect, they estimate
|
|
277
|
+
ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
|
|
278
|
+
problem where already-treated units act as controls.
|
|
279
|
+
|
|
280
|
+
The ATT(g,t) is identified under parallel trends conditional on covariates:
|
|
281
|
+
|
|
282
|
+
E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]
|
|
283
|
+
|
|
284
|
+
where G=g indicates treatment cohort g and C=1 indicates control units.
|
|
285
|
+
This uses g-1 as the base period, which applies to post-treatment (t >= g).
|
|
286
|
+
With base_period="varying" (default), pre-treatment uses t-1 as base for
|
|
287
|
+
consecutive comparisons useful in parallel trends diagnostics.
|
|
288
|
+
|
|
289
|
+
References
|
|
290
|
+
----------
|
|
291
|
+
Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
|
|
292
|
+
multiple time periods. Journal of Econometrics, 225(2), 200-230.
|
|
293
|
+
"""
|
|
294
|
+
|
|
295
|
+
def __init__(
|
|
296
|
+
self,
|
|
297
|
+
control_group: str = "never_treated",
|
|
298
|
+
anticipation: int = 0,
|
|
299
|
+
estimation_method: str = "dr",
|
|
300
|
+
alpha: float = 0.05,
|
|
301
|
+
cluster: Optional[str] = None,
|
|
302
|
+
n_bootstrap: int = 0,
|
|
303
|
+
bootstrap_weights: Optional[str] = None,
|
|
304
|
+
bootstrap_weight_type: Optional[str] = None,
|
|
305
|
+
seed: Optional[int] = None,
|
|
306
|
+
rank_deficient_action: str = "warn",
|
|
307
|
+
base_period: str = "varying",
|
|
308
|
+
):
|
|
309
|
+
import warnings
|
|
310
|
+
|
|
311
|
+
if control_group not in ["never_treated", "not_yet_treated"]:
|
|
312
|
+
raise ValueError(
|
|
313
|
+
f"control_group must be 'never_treated' or 'not_yet_treated', "
|
|
314
|
+
f"got '{control_group}'"
|
|
315
|
+
)
|
|
316
|
+
if estimation_method not in ["dr", "ipw", "reg"]:
|
|
317
|
+
raise ValueError(
|
|
318
|
+
f"estimation_method must be 'dr', 'ipw', or 'reg', "
|
|
319
|
+
f"got '{estimation_method}'"
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
# Handle bootstrap_weight_type deprecation
|
|
323
|
+
if bootstrap_weight_type is not None:
|
|
324
|
+
warnings.warn(
|
|
325
|
+
"bootstrap_weight_type is deprecated and will be removed in v3.0. "
|
|
326
|
+
"Use bootstrap_weights instead.",
|
|
327
|
+
DeprecationWarning,
|
|
328
|
+
stacklevel=2
|
|
329
|
+
)
|
|
330
|
+
if bootstrap_weights is None:
|
|
331
|
+
bootstrap_weights = bootstrap_weight_type
|
|
332
|
+
|
|
333
|
+
# Default to rademacher if neither specified
|
|
334
|
+
if bootstrap_weights is None:
|
|
335
|
+
bootstrap_weights = "rademacher"
|
|
336
|
+
|
|
337
|
+
if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
|
|
338
|
+
raise ValueError(
|
|
339
|
+
f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
|
|
340
|
+
f"got '{bootstrap_weights}'"
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
if rank_deficient_action not in ["warn", "error", "silent"]:
|
|
344
|
+
raise ValueError(
|
|
345
|
+
f"rank_deficient_action must be 'warn', 'error', or 'silent', "
|
|
346
|
+
f"got '{rank_deficient_action}'"
|
|
347
|
+
)
|
|
348
|
+
|
|
349
|
+
if base_period not in ["varying", "universal"]:
|
|
350
|
+
raise ValueError(
|
|
351
|
+
f"base_period must be 'varying' or 'universal', "
|
|
352
|
+
f"got '{base_period}'"
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
self.control_group = control_group
|
|
356
|
+
self.anticipation = anticipation
|
|
357
|
+
self.estimation_method = estimation_method
|
|
358
|
+
self.alpha = alpha
|
|
359
|
+
self.cluster = cluster
|
|
360
|
+
self.n_bootstrap = n_bootstrap
|
|
361
|
+
self.bootstrap_weights = bootstrap_weights
|
|
362
|
+
# Keep bootstrap_weight_type for backward compatibility
|
|
363
|
+
self.bootstrap_weight_type = bootstrap_weights
|
|
364
|
+
self.seed = seed
|
|
365
|
+
self.rank_deficient_action = rank_deficient_action
|
|
366
|
+
self.base_period = base_period
|
|
367
|
+
|
|
368
|
+
self.is_fitted_ = False
|
|
369
|
+
self.results_: Optional[CallawaySantAnnaResults] = None
|
|
370
|
+
|
|
371
|
+
def _precompute_structures(
|
|
372
|
+
self,
|
|
373
|
+
df: pd.DataFrame,
|
|
374
|
+
outcome: str,
|
|
375
|
+
unit: str,
|
|
376
|
+
time: str,
|
|
377
|
+
first_treat: str,
|
|
378
|
+
covariates: Optional[List[str]],
|
|
379
|
+
time_periods: List[Any],
|
|
380
|
+
treatment_groups: List[Any],
|
|
381
|
+
) -> PrecomputedData:
|
|
382
|
+
"""
|
|
383
|
+
Pre-compute data structures for efficient ATT(g,t) computation.
|
|
384
|
+
|
|
385
|
+
This pivots data to wide format and pre-computes:
|
|
386
|
+
- Outcome matrix (units x time periods)
|
|
387
|
+
- Covariate matrix (units x covariates) from base period
|
|
388
|
+
- Unit cohort membership masks
|
|
389
|
+
- Control unit masks
|
|
390
|
+
|
|
391
|
+
Returns
|
|
392
|
+
-------
|
|
393
|
+
PrecomputedData
|
|
394
|
+
Dictionary with pre-computed structures.
|
|
395
|
+
"""
|
|
396
|
+
# Get unique units and their cohort assignments
|
|
397
|
+
unit_info = df.groupby(unit)[first_treat].first()
|
|
398
|
+
all_units = unit_info.index.values
|
|
399
|
+
unit_cohorts = unit_info.values
|
|
400
|
+
n_units = len(all_units)
|
|
401
|
+
|
|
402
|
+
# Create unit index mapping for fast lookups
|
|
403
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
404
|
+
|
|
405
|
+
# Pivot outcome to wide format: rows = units, columns = time periods
|
|
406
|
+
outcome_wide = df.pivot(index=unit, columns=time, values=outcome)
|
|
407
|
+
# Reindex to ensure all units are present (handles unbalanced panels)
|
|
408
|
+
outcome_wide = outcome_wide.reindex(all_units)
|
|
409
|
+
outcome_matrix = outcome_wide.values # Shape: (n_units, n_periods)
|
|
410
|
+
period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)}
|
|
411
|
+
|
|
412
|
+
# Pre-compute cohort masks (boolean arrays)
|
|
413
|
+
cohort_masks = {}
|
|
414
|
+
for g in treatment_groups:
|
|
415
|
+
cohort_masks[g] = (unit_cohorts == g)
|
|
416
|
+
|
|
417
|
+
# Never-treated mask
|
|
418
|
+
# np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
|
|
419
|
+
never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)
|
|
420
|
+
|
|
421
|
+
# Pre-compute covariate matrices by time period if needed
|
|
422
|
+
# (covariates are retrieved from the base period of each comparison)
|
|
423
|
+
covariate_by_period = None
|
|
424
|
+
if covariates:
|
|
425
|
+
covariate_by_period = {}
|
|
426
|
+
for t in time_periods:
|
|
427
|
+
period_data = df[df[time] == t].set_index(unit)
|
|
428
|
+
period_cov = period_data.reindex(all_units)[covariates]
|
|
429
|
+
covariate_by_period[t] = period_cov.values # Shape: (n_units, n_covariates)
|
|
430
|
+
|
|
431
|
+
return {
|
|
432
|
+
'all_units': all_units,
|
|
433
|
+
'unit_to_idx': unit_to_idx,
|
|
434
|
+
'unit_cohorts': unit_cohorts,
|
|
435
|
+
'outcome_matrix': outcome_matrix,
|
|
436
|
+
'period_to_col': period_to_col,
|
|
437
|
+
'cohort_masks': cohort_masks,
|
|
438
|
+
'never_treated_mask': never_treated_mask,
|
|
439
|
+
'covariate_by_period': covariate_by_period,
|
|
440
|
+
'time_periods': time_periods,
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
def _compute_att_gt_fast(
|
|
444
|
+
self,
|
|
445
|
+
precomputed: PrecomputedData,
|
|
446
|
+
g: Any,
|
|
447
|
+
t: Any,
|
|
448
|
+
covariates: Optional[List[str]],
|
|
449
|
+
) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]:
|
|
450
|
+
"""
|
|
451
|
+
Compute ATT(g,t) using pre-computed data structures (fast version).
|
|
452
|
+
|
|
453
|
+
Uses vectorized numpy operations on pre-pivoted outcome matrix
|
|
454
|
+
instead of repeated pandas filtering.
|
|
455
|
+
"""
|
|
456
|
+
time_periods = precomputed['time_periods']
|
|
457
|
+
period_to_col = precomputed['period_to_col']
|
|
458
|
+
outcome_matrix = precomputed['outcome_matrix']
|
|
459
|
+
cohort_masks = precomputed['cohort_masks']
|
|
460
|
+
never_treated_mask = precomputed['never_treated_mask']
|
|
461
|
+
unit_cohorts = precomputed['unit_cohorts']
|
|
462
|
+
all_units = precomputed['all_units']
|
|
463
|
+
covariate_by_period = precomputed['covariate_by_period']
|
|
464
|
+
|
|
465
|
+
# Base period selection based on mode
|
|
466
|
+
if self.base_period == "universal":
|
|
467
|
+
# Universal: always use g - 1 - anticipation
|
|
468
|
+
base_period_val = g - 1 - self.anticipation
|
|
469
|
+
else: # varying
|
|
470
|
+
if t < g - self.anticipation:
|
|
471
|
+
# Pre-treatment: use t - 1 (consecutive comparison)
|
|
472
|
+
base_period_val = t - 1
|
|
473
|
+
else:
|
|
474
|
+
# Post-treatment: use g - 1 - anticipation
|
|
475
|
+
base_period_val = g - 1 - self.anticipation
|
|
476
|
+
|
|
477
|
+
if base_period_val not in period_to_col:
|
|
478
|
+
# Base period must exist; no fallback to maintain methodological consistency
|
|
479
|
+
return None, 0.0, 0, 0, None
|
|
480
|
+
|
|
481
|
+
# Check if periods exist in the data
|
|
482
|
+
if base_period_val not in period_to_col or t not in period_to_col:
|
|
483
|
+
return None, 0.0, 0, 0, None
|
|
484
|
+
|
|
485
|
+
base_col = period_to_col[base_period_val]
|
|
486
|
+
post_col = period_to_col[t]
|
|
487
|
+
|
|
488
|
+
# Get treated units mask (cohort g)
|
|
489
|
+
treated_mask = cohort_masks[g]
|
|
490
|
+
|
|
491
|
+
# Get control units mask
|
|
492
|
+
if self.control_group == "never_treated":
|
|
493
|
+
control_mask = never_treated_mask
|
|
494
|
+
else: # not_yet_treated
|
|
495
|
+
# Not yet treated at time t: never-treated OR (first_treat > t AND not cohort g)
|
|
496
|
+
# Must exclude cohort g since they are the treated group for this ATT(g,t)
|
|
497
|
+
control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g))
|
|
498
|
+
|
|
499
|
+
# Extract outcomes for base and post periods
|
|
500
|
+
y_base = outcome_matrix[:, base_col]
|
|
501
|
+
y_post = outcome_matrix[:, post_col]
|
|
502
|
+
|
|
503
|
+
# Compute outcome changes (vectorized)
|
|
504
|
+
outcome_change = y_post - y_base
|
|
505
|
+
|
|
506
|
+
# Filter to units with valid data (no NaN in either period)
|
|
507
|
+
valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
|
|
508
|
+
|
|
509
|
+
# Get treated and control with valid data
|
|
510
|
+
treated_valid = treated_mask & valid_mask
|
|
511
|
+
control_valid = control_mask & valid_mask
|
|
512
|
+
|
|
513
|
+
n_treated = np.sum(treated_valid)
|
|
514
|
+
n_control = np.sum(control_valid)
|
|
515
|
+
|
|
516
|
+
if n_treated == 0 or n_control == 0:
|
|
517
|
+
return None, 0.0, 0, 0, None
|
|
518
|
+
|
|
519
|
+
# Extract outcome changes for treated and control
|
|
520
|
+
treated_change = outcome_change[treated_valid]
|
|
521
|
+
control_change = outcome_change[control_valid]
|
|
522
|
+
|
|
523
|
+
# Get unit IDs for influence function
|
|
524
|
+
treated_units = all_units[treated_valid]
|
|
525
|
+
control_units = all_units[control_valid]
|
|
526
|
+
|
|
527
|
+
# Get covariates if specified (from the base period)
|
|
528
|
+
X_treated = None
|
|
529
|
+
X_control = None
|
|
530
|
+
if covariates and covariate_by_period is not None:
|
|
531
|
+
cov_matrix = covariate_by_period[base_period_val]
|
|
532
|
+
X_treated = cov_matrix[treated_valid]
|
|
533
|
+
X_control = cov_matrix[control_valid]
|
|
534
|
+
|
|
535
|
+
# Check for missing values
|
|
536
|
+
if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
|
|
537
|
+
warnings.warn(
|
|
538
|
+
f"Missing values in covariates for group {g}, time {t}. "
|
|
539
|
+
"Falling back to unconditional estimation.",
|
|
540
|
+
UserWarning,
|
|
541
|
+
stacklevel=3,
|
|
542
|
+
)
|
|
543
|
+
X_treated = None
|
|
544
|
+
X_control = None
|
|
545
|
+
|
|
546
|
+
# Estimation method
|
|
547
|
+
if self.estimation_method == "reg":
|
|
548
|
+
att_gt, se_gt, inf_func = self._outcome_regression(
|
|
549
|
+
treated_change, control_change, X_treated, X_control
|
|
550
|
+
)
|
|
551
|
+
elif self.estimation_method == "ipw":
|
|
552
|
+
att_gt, se_gt, inf_func = self._ipw_estimation(
|
|
553
|
+
treated_change, control_change,
|
|
554
|
+
int(n_treated), int(n_control),
|
|
555
|
+
X_treated, X_control
|
|
556
|
+
)
|
|
557
|
+
else: # doubly robust
|
|
558
|
+
att_gt, se_gt, inf_func = self._doubly_robust(
|
|
559
|
+
treated_change, control_change, X_treated, X_control
|
|
560
|
+
)
|
|
561
|
+
|
|
562
|
+
# Package influence function info with unit IDs for bootstrap
|
|
563
|
+
n_t = int(n_treated)
|
|
564
|
+
inf_func_info = {
|
|
565
|
+
'treated_units': list(treated_units),
|
|
566
|
+
'control_units': list(control_units),
|
|
567
|
+
'treated_inf': inf_func[:n_t],
|
|
568
|
+
'control_inf': inf_func[n_t:],
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info
|
|
572
|
+
|
|
573
|
+
def fit(
|
|
574
|
+
self,
|
|
575
|
+
data: pd.DataFrame,
|
|
576
|
+
outcome: str,
|
|
577
|
+
unit: str,
|
|
578
|
+
time: str,
|
|
579
|
+
first_treat: str,
|
|
580
|
+
covariates: Optional[List[str]] = None,
|
|
581
|
+
aggregate: Optional[str] = None,
|
|
582
|
+
balance_e: Optional[int] = None,
|
|
583
|
+
) -> CallawaySantAnnaResults:
|
|
584
|
+
"""
|
|
585
|
+
Fit the Callaway-Sant'Anna estimator.
|
|
586
|
+
|
|
587
|
+
Parameters
|
|
588
|
+
----------
|
|
589
|
+
data : pd.DataFrame
|
|
590
|
+
Panel data with unit and time identifiers.
|
|
591
|
+
outcome : str
|
|
592
|
+
Name of outcome variable column.
|
|
593
|
+
unit : str
|
|
594
|
+
Name of unit identifier column.
|
|
595
|
+
time : str
|
|
596
|
+
Name of time period column.
|
|
597
|
+
first_treat : str
|
|
598
|
+
Name of column indicating when unit was first treated.
|
|
599
|
+
Use 0 (or np.inf) for never-treated units.
|
|
600
|
+
covariates : list, optional
|
|
601
|
+
List of covariate column names for conditional parallel trends.
|
|
602
|
+
aggregate : str, optional
|
|
603
|
+
How to aggregate group-time effects:
|
|
604
|
+
- None: Only compute ATT(g,t) (default)
|
|
605
|
+
- "simple": Simple weighted average (overall ATT)
|
|
606
|
+
- "event_study": Aggregate by relative time (event study)
|
|
607
|
+
- "group": Aggregate by treatment cohort
|
|
608
|
+
- "all": Compute all aggregations
|
|
609
|
+
balance_e : int, optional
|
|
610
|
+
For event study, balance the panel at relative time e.
|
|
611
|
+
Ensures all groups contribute to each relative period.
|
|
612
|
+
|
|
613
|
+
Returns
|
|
614
|
+
-------
|
|
615
|
+
CallawaySantAnnaResults
|
|
616
|
+
Object containing all estimation results.
|
|
617
|
+
|
|
618
|
+
Raises
|
|
619
|
+
------
|
|
620
|
+
ValueError
|
|
621
|
+
If required columns are missing or data validation fails.
|
|
622
|
+
"""
|
|
623
|
+
# Validate inputs
|
|
624
|
+
required_cols = [outcome, unit, time, first_treat]
|
|
625
|
+
if covariates:
|
|
626
|
+
required_cols.extend(covariates)
|
|
627
|
+
|
|
628
|
+
missing = [c for c in required_cols if c not in data.columns]
|
|
629
|
+
if missing:
|
|
630
|
+
raise ValueError(f"Missing columns: {missing}")
|
|
631
|
+
|
|
632
|
+
# Create working copy
|
|
633
|
+
df = data.copy()
|
|
634
|
+
|
|
635
|
+
# Ensure numeric types
|
|
636
|
+
df[time] = pd.to_numeric(df[time])
|
|
637
|
+
df[first_treat] = pd.to_numeric(df[first_treat])
|
|
638
|
+
|
|
639
|
+
# Standardize the first_treat column name for internal use
|
|
640
|
+
# This avoids hardcoding column names in internal methods
|
|
641
|
+
df['first_treat'] = df[first_treat]
|
|
642
|
+
|
|
643
|
+
# Never-treated indicator (must precede treatment_groups to exclude np.inf)
|
|
644
|
+
df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
|
|
645
|
+
# Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
|
|
646
|
+
df.loc[df[first_treat] == np.inf, first_treat] = 0
|
|
647
|
+
|
|
648
|
+
# Identify groups and time periods
|
|
649
|
+
time_periods = sorted(df[time].unique())
|
|
650
|
+
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
|
|
651
|
+
|
|
652
|
+
# Get unique units
|
|
653
|
+
unit_info = df.groupby(unit).agg({
|
|
654
|
+
first_treat: 'first',
|
|
655
|
+
'_never_treated': 'first'
|
|
656
|
+
}).reset_index()
|
|
657
|
+
|
|
658
|
+
n_treated_units = (unit_info[first_treat] > 0).sum()
|
|
659
|
+
n_control_units = (unit_info['_never_treated']).sum()
|
|
660
|
+
|
|
661
|
+
if n_control_units == 0:
|
|
662
|
+
raise ValueError("No never-treated units found. Check 'first_treat' column.")
|
|
663
|
+
|
|
664
|
+
# Pre-compute data structures for efficient ATT(g,t) computation
|
|
665
|
+
precomputed = self._precompute_structures(
|
|
666
|
+
df, outcome, unit, time, first_treat,
|
|
667
|
+
covariates, time_periods, treatment_groups
|
|
668
|
+
)
|
|
669
|
+
|
|
670
|
+
# Compute ATT(g,t) for each group-time combination
|
|
671
|
+
group_time_effects = {}
|
|
672
|
+
influence_func_info = {} # Store influence functions for bootstrap
|
|
673
|
+
|
|
674
|
+
# Get minimum period for determining valid pre-treatment periods
|
|
675
|
+
min_period = min(time_periods)
|
|
676
|
+
|
|
677
|
+
for g in treatment_groups:
|
|
678
|
+
# Compute valid periods including pre-treatment
|
|
679
|
+
if self.base_period == "universal":
|
|
680
|
+
# Universal: all periods except the base period (which is normalized to 0)
|
|
681
|
+
universal_base = g - 1 - self.anticipation
|
|
682
|
+
valid_periods = [t for t in time_periods if t != universal_base]
|
|
683
|
+
else:
|
|
684
|
+
# Varying: post-treatment + pre-treatment where t-1 exists
|
|
685
|
+
valid_periods = [
|
|
686
|
+
t for t in time_periods
|
|
687
|
+
if t >= g - self.anticipation or t > min_period
|
|
688
|
+
]
|
|
689
|
+
|
|
690
|
+
for t in valid_periods:
|
|
691
|
+
att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
|
|
692
|
+
precomputed, g, t, covariates
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
if att_gt is not None:
|
|
696
|
+
t_stat = att_gt / se_gt if np.isfinite(se_gt) and se_gt > 0 else np.nan
|
|
697
|
+
p_val = compute_p_value(t_stat)
|
|
698
|
+
ci = compute_confidence_interval(att_gt, se_gt, self.alpha)
|
|
699
|
+
|
|
700
|
+
group_time_effects[(g, t)] = {
|
|
701
|
+
'effect': att_gt,
|
|
702
|
+
'se': se_gt,
|
|
703
|
+
't_stat': t_stat,
|
|
704
|
+
'p_value': p_val,
|
|
705
|
+
'conf_int': ci,
|
|
706
|
+
'n_treated': n_treat,
|
|
707
|
+
'n_control': n_ctrl,
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
if inf_info is not None:
|
|
711
|
+
influence_func_info[(g, t)] = inf_info
|
|
712
|
+
|
|
713
|
+
if not group_time_effects:
|
|
714
|
+
raise ValueError(
|
|
715
|
+
"Could not estimate any group-time effects. "
|
|
716
|
+
"Check that data has sufficient observations."
|
|
717
|
+
)
|
|
718
|
+
|
|
719
|
+
# Compute overall ATT (simple aggregation)
|
|
720
|
+
overall_att, overall_se = self._aggregate_simple(
|
|
721
|
+
group_time_effects, influence_func_info, df, unit, precomputed
|
|
722
|
+
)
|
|
723
|
+
# Use NaN for t-stat and p-value when SE is undefined (NaN or non-positive)
|
|
724
|
+
if np.isfinite(overall_se) and overall_se > 0:
|
|
725
|
+
overall_t = overall_att / overall_se
|
|
726
|
+
overall_p = compute_p_value(overall_t)
|
|
727
|
+
else:
|
|
728
|
+
overall_t = np.nan
|
|
729
|
+
overall_p = np.nan
|
|
730
|
+
overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)
|
|
731
|
+
|
|
732
|
+
# Compute additional aggregations if requested
|
|
733
|
+
event_study_effects = None
|
|
734
|
+
group_effects = None
|
|
735
|
+
|
|
736
|
+
if aggregate in ["event_study", "all"]:
|
|
737
|
+
event_study_effects = self._aggregate_event_study(
|
|
738
|
+
group_time_effects, influence_func_info,
|
|
739
|
+
treatment_groups, time_periods, balance_e
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
if aggregate in ["group", "all"]:
|
|
743
|
+
group_effects = self._aggregate_by_group(
|
|
744
|
+
group_time_effects, influence_func_info, treatment_groups
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
# Run bootstrap inference if requested
|
|
748
|
+
bootstrap_results = None
|
|
749
|
+
if self.n_bootstrap > 0 and influence_func_info:
|
|
750
|
+
bootstrap_results = self._run_multiplier_bootstrap(
|
|
751
|
+
group_time_effects=group_time_effects,
|
|
752
|
+
influence_func_info=influence_func_info,
|
|
753
|
+
aggregate=aggregate,
|
|
754
|
+
balance_e=balance_e,
|
|
755
|
+
treatment_groups=treatment_groups,
|
|
756
|
+
time_periods=time_periods,
|
|
757
|
+
)
|
|
758
|
+
|
|
759
|
+
# Update estimates with bootstrap inference
|
|
760
|
+
overall_se = bootstrap_results.overall_att_se
|
|
761
|
+
# Use NaN for t-stat when SE is undefined; p-value comes from bootstrap
|
|
762
|
+
if np.isfinite(overall_se) and overall_se > 0:
|
|
763
|
+
overall_t = overall_att / overall_se
|
|
764
|
+
else:
|
|
765
|
+
overall_t = np.nan
|
|
766
|
+
overall_p = bootstrap_results.overall_att_p_value
|
|
767
|
+
overall_ci = bootstrap_results.overall_att_ci
|
|
768
|
+
|
|
769
|
+
# Update group-time effects with bootstrap SEs
|
|
770
|
+
for gt in group_time_effects:
|
|
771
|
+
if gt in bootstrap_results.group_time_ses:
|
|
772
|
+
group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt]
|
|
773
|
+
group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt]
|
|
774
|
+
group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt]
|
|
775
|
+
effect = float(group_time_effects[gt]['effect'])
|
|
776
|
+
se = float(group_time_effects[gt]['se'])
|
|
777
|
+
group_time_effects[gt]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan
|
|
778
|
+
|
|
779
|
+
# Update event study effects with bootstrap SEs
|
|
780
|
+
if (event_study_effects is not None
|
|
781
|
+
and bootstrap_results.event_study_ses is not None
|
|
782
|
+
and bootstrap_results.event_study_cis is not None
|
|
783
|
+
and bootstrap_results.event_study_p_values is not None):
|
|
784
|
+
for e in event_study_effects:
|
|
785
|
+
if e in bootstrap_results.event_study_ses:
|
|
786
|
+
event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e]
|
|
787
|
+
event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e]
|
|
788
|
+
p_val = bootstrap_results.event_study_p_values[e]
|
|
789
|
+
event_study_effects[e]['p_value'] = p_val
|
|
790
|
+
effect = float(event_study_effects[e]['effect'])
|
|
791
|
+
se = float(event_study_effects[e]['se'])
|
|
792
|
+
event_study_effects[e]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan
|
|
793
|
+
|
|
794
|
+
# Update group effects with bootstrap SEs
|
|
795
|
+
if (group_effects is not None
|
|
796
|
+
and bootstrap_results.group_effect_ses is not None
|
|
797
|
+
and bootstrap_results.group_effect_cis is not None
|
|
798
|
+
and bootstrap_results.group_effect_p_values is not None):
|
|
799
|
+
for g in group_effects:
|
|
800
|
+
if g in bootstrap_results.group_effect_ses:
|
|
801
|
+
group_effects[g]['se'] = bootstrap_results.group_effect_ses[g]
|
|
802
|
+
group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g]
|
|
803
|
+
group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g]
|
|
804
|
+
effect = float(group_effects[g]['effect'])
|
|
805
|
+
se = float(group_effects[g]['se'])
|
|
806
|
+
group_effects[g]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan
|
|
807
|
+
|
|
808
|
+
# Store results
|
|
809
|
+
self.results_ = CallawaySantAnnaResults(
|
|
810
|
+
group_time_effects=group_time_effects,
|
|
811
|
+
overall_att=overall_att,
|
|
812
|
+
overall_se=overall_se,
|
|
813
|
+
overall_t_stat=overall_t,
|
|
814
|
+
overall_p_value=overall_p,
|
|
815
|
+
overall_conf_int=overall_ci,
|
|
816
|
+
groups=treatment_groups,
|
|
817
|
+
time_periods=time_periods,
|
|
818
|
+
n_obs=len(df),
|
|
819
|
+
n_treated_units=n_treated_units,
|
|
820
|
+
n_control_units=n_control_units,
|
|
821
|
+
alpha=self.alpha,
|
|
822
|
+
control_group=self.control_group,
|
|
823
|
+
base_period=self.base_period,
|
|
824
|
+
event_study_effects=event_study_effects,
|
|
825
|
+
group_effects=group_effects,
|
|
826
|
+
bootstrap_results=bootstrap_results,
|
|
827
|
+
)
|
|
828
|
+
|
|
829
|
+
self.is_fitted_ = True
|
|
830
|
+
return self.results_
|
|
831
|
+
|
|
832
|
+
def _outcome_regression(
|
|
833
|
+
self,
|
|
834
|
+
treated_change: np.ndarray,
|
|
835
|
+
control_change: np.ndarray,
|
|
836
|
+
X_treated: Optional[np.ndarray] = None,
|
|
837
|
+
X_control: Optional[np.ndarray] = None,
|
|
838
|
+
) -> Tuple[float, float, np.ndarray]:
|
|
839
|
+
"""
|
|
840
|
+
Estimate ATT using outcome regression.
|
|
841
|
+
|
|
842
|
+
With covariates:
|
|
843
|
+
1. Regress outcome changes on covariates for control group
|
|
844
|
+
2. Predict counterfactual for treated using their covariates
|
|
845
|
+
3. ATT = mean(treated_change) - mean(predicted_counterfactual)
|
|
846
|
+
|
|
847
|
+
Without covariates:
|
|
848
|
+
Simple difference in means.
|
|
849
|
+
"""
|
|
850
|
+
n_t = len(treated_change)
|
|
851
|
+
n_c = len(control_change)
|
|
852
|
+
|
|
853
|
+
if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
|
|
854
|
+
# Covariate-adjusted outcome regression
|
|
855
|
+
# Fit regression on control units: E[Delta Y | X, D=0]
|
|
856
|
+
beta, residuals = _linear_regression(
|
|
857
|
+
X_control, control_change,
|
|
858
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
859
|
+
)
|
|
860
|
+
|
|
861
|
+
# Predict counterfactual for treated units
|
|
862
|
+
X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
|
|
863
|
+
predicted_control = X_treated_with_intercept @ beta
|
|
864
|
+
|
|
865
|
+
# ATT = mean(observed treated change - predicted counterfactual)
|
|
866
|
+
att = np.mean(treated_change - predicted_control)
|
|
867
|
+
|
|
868
|
+
# Standard error using sandwich estimator
|
|
869
|
+
# Variance from treated: Var(Y_1 - m(X))
|
|
870
|
+
treated_residuals = treated_change - predicted_control
|
|
871
|
+
var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
|
|
872
|
+
|
|
873
|
+
# Variance from control regression (residual variance)
|
|
874
|
+
var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0
|
|
875
|
+
|
|
876
|
+
# Approximate SE (ignoring estimation error in beta for simplicity)
|
|
877
|
+
se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
|
|
878
|
+
|
|
879
|
+
# Influence function
|
|
880
|
+
inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
|
|
881
|
+
inf_control = -residuals / n_c
|
|
882
|
+
inf_func = np.concatenate([inf_treated, inf_control])
|
|
883
|
+
else:
|
|
884
|
+
# Simple difference in means (no covariates)
|
|
885
|
+
att = np.mean(treated_change) - np.mean(control_change)
|
|
886
|
+
|
|
887
|
+
var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
|
|
888
|
+
var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
|
|
889
|
+
|
|
890
|
+
se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
|
|
891
|
+
|
|
892
|
+
# Influence function (for aggregation)
|
|
893
|
+
inf_treated = treated_change - np.mean(treated_change)
|
|
894
|
+
inf_control = control_change - np.mean(control_change)
|
|
895
|
+
inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])
|
|
896
|
+
|
|
897
|
+
return att, se, inf_func
|
|
898
|
+
|
|
899
|
+
def _ipw_estimation(
|
|
900
|
+
self,
|
|
901
|
+
treated_change: np.ndarray,
|
|
902
|
+
control_change: np.ndarray,
|
|
903
|
+
n_treated: int,
|
|
904
|
+
n_control: int,
|
|
905
|
+
X_treated: Optional[np.ndarray] = None,
|
|
906
|
+
X_control: Optional[np.ndarray] = None,
|
|
907
|
+
) -> Tuple[float, float, np.ndarray]:
|
|
908
|
+
"""
|
|
909
|
+
Estimate ATT using inverse probability weighting.
|
|
910
|
+
|
|
911
|
+
With covariates:
|
|
912
|
+
1. Estimate propensity score P(D=1|X) using logistic regression
|
|
913
|
+
2. Reweight control units to match treated covariate distribution
|
|
914
|
+
3. ATT = mean(treated) - weighted_mean(control)
|
|
915
|
+
|
|
916
|
+
Without covariates:
|
|
917
|
+
Simple difference in means with unconditional propensity weighting.
|
|
918
|
+
"""
|
|
919
|
+
n_t = len(treated_change)
|
|
920
|
+
n_c = len(control_change)
|
|
921
|
+
n_total = n_treated + n_control
|
|
922
|
+
|
|
923
|
+
if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
|
|
924
|
+
# Covariate-adjusted IPW estimation
|
|
925
|
+
# Stack covariates and create treatment indicator
|
|
926
|
+
X_all = np.vstack([X_treated, X_control])
|
|
927
|
+
D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
|
|
928
|
+
|
|
929
|
+
# Estimate propensity scores using logistic regression
|
|
930
|
+
try:
|
|
931
|
+
_, pscore = _logistic_regression(X_all, D)
|
|
932
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
933
|
+
# Fallback to unconditional if logistic regression fails
|
|
934
|
+
warnings.warn(
|
|
935
|
+
"Propensity score estimation failed. "
|
|
936
|
+
"Falling back to unconditional estimation.",
|
|
937
|
+
UserWarning,
|
|
938
|
+
stacklevel=4,
|
|
939
|
+
)
|
|
940
|
+
pscore = np.full(len(D), n_t / (n_t + n_c))
|
|
941
|
+
|
|
942
|
+
# Propensity scores for treated and control
|
|
943
|
+
pscore_treated = pscore[:n_t]
|
|
944
|
+
pscore_control = pscore[n_t:]
|
|
945
|
+
|
|
946
|
+
# Clip propensity scores to avoid extreme weights
|
|
947
|
+
pscore_control = np.clip(pscore_control, 0.01, 0.99)
|
|
948
|
+
pscore_treated = np.clip(pscore_treated, 0.01, 0.99)
|
|
949
|
+
|
|
950
|
+
# IPW weights for control units: p(X) / (1 - p(X))
|
|
951
|
+
# This reweights controls to have same covariate distribution as treated
|
|
952
|
+
weights_control = pscore_control / (1 - pscore_control)
|
|
953
|
+
weights_control = weights_control / np.sum(weights_control) # normalize
|
|
954
|
+
|
|
955
|
+
# ATT = mean(treated) - weighted_mean(control)
|
|
956
|
+
att = np.mean(treated_change) - np.sum(weights_control * control_change)
|
|
957
|
+
|
|
958
|
+
# Compute standard error
|
|
959
|
+
# Variance of treated mean
|
|
960
|
+
var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
|
|
961
|
+
|
|
962
|
+
# Variance of weighted control mean
|
|
963
|
+
weighted_var_c = np.sum(weights_control * (control_change - np.sum(weights_control * control_change)) ** 2)
|
|
964
|
+
|
|
965
|
+
se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0
|
|
966
|
+
|
|
967
|
+
# Influence function
|
|
968
|
+
inf_treated = (treated_change - np.mean(treated_change)) / n_t
|
|
969
|
+
inf_control = -weights_control * (control_change - np.sum(weights_control * control_change))
|
|
970
|
+
inf_func = np.concatenate([inf_treated, inf_control])
|
|
971
|
+
else:
|
|
972
|
+
# Unconditional IPW (reduces to difference in means)
|
|
973
|
+
p_treat = n_treated / n_total # unconditional propensity score
|
|
974
|
+
|
|
975
|
+
att = np.mean(treated_change) - np.mean(control_change)
|
|
976
|
+
|
|
977
|
+
var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
|
|
978
|
+
var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
|
|
979
|
+
|
|
980
|
+
# Adjusted variance for IPW
|
|
981
|
+
se = np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) if (n_t > 0 and n_c > 0 and p_treat > 0) else 0.0
|
|
982
|
+
|
|
983
|
+
# Influence function (for aggregation)
|
|
984
|
+
inf_treated = (treated_change - np.mean(treated_change)) / n_t
|
|
985
|
+
inf_control = (control_change - np.mean(control_change)) / n_c
|
|
986
|
+
inf_func = np.concatenate([inf_treated, -inf_control])
|
|
987
|
+
|
|
988
|
+
return att, se, inf_func
|
|
989
|
+
|
|
990
|
+
def _doubly_robust(
|
|
991
|
+
self,
|
|
992
|
+
treated_change: np.ndarray,
|
|
993
|
+
control_change: np.ndarray,
|
|
994
|
+
X_treated: Optional[np.ndarray] = None,
|
|
995
|
+
X_control: Optional[np.ndarray] = None,
|
|
996
|
+
) -> Tuple[float, float, np.ndarray]:
|
|
997
|
+
"""
|
|
998
|
+
Estimate ATT using doubly robust estimation.
|
|
999
|
+
|
|
1000
|
+
With covariates:
|
|
1001
|
+
Combines outcome regression and IPW for double robustness.
|
|
1002
|
+
The estimator is consistent if either the outcome model OR
|
|
1003
|
+
the propensity model is correctly specified.
|
|
1004
|
+
|
|
1005
|
+
ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
|
|
1006
|
+
+ (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]
|
|
1007
|
+
|
|
1008
|
+
where m(X) is the outcome model and w_i are IPW weights.
|
|
1009
|
+
|
|
1010
|
+
Without covariates:
|
|
1011
|
+
Reduces to simple difference in means.
|
|
1012
|
+
"""
|
|
1013
|
+
n_t = len(treated_change)
|
|
1014
|
+
n_c = len(control_change)
|
|
1015
|
+
|
|
1016
|
+
if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
|
|
1017
|
+
# Doubly robust estimation with covariates
|
|
1018
|
+
# Step 1: Outcome regression - fit E[Delta Y | X] on control
|
|
1019
|
+
beta, _ = _linear_regression(
|
|
1020
|
+
X_control, control_change,
|
|
1021
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
1022
|
+
)
|
|
1023
|
+
|
|
1024
|
+
# Predict counterfactual for both treated and control
|
|
1025
|
+
X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
|
|
1026
|
+
X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
|
|
1027
|
+
m_treated = X_treated_with_intercept @ beta
|
|
1028
|
+
m_control = X_control_with_intercept @ beta
|
|
1029
|
+
|
|
1030
|
+
# Step 2: Propensity score estimation
|
|
1031
|
+
X_all = np.vstack([X_treated, X_control])
|
|
1032
|
+
D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
|
|
1033
|
+
|
|
1034
|
+
try:
|
|
1035
|
+
_, pscore = _logistic_regression(X_all, D)
|
|
1036
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
1037
|
+
# Fallback to unconditional if logistic regression fails
|
|
1038
|
+
pscore = np.full(len(D), n_t / (n_t + n_c))
|
|
1039
|
+
|
|
1040
|
+
pscore_control = pscore[n_t:]
|
|
1041
|
+
|
|
1042
|
+
# Clip propensity scores
|
|
1043
|
+
pscore_control = np.clip(pscore_control, 0.01, 0.99)
|
|
1044
|
+
|
|
1045
|
+
# IPW weights for control: p(X) / (1 - p(X))
|
|
1046
|
+
weights_control = pscore_control / (1 - pscore_control)
|
|
1047
|
+
|
|
1048
|
+
# Step 3: Doubly robust ATT
|
|
1049
|
+
# ATT = mean(treated - m(X_treated))
|
|
1050
|
+
# + weighted_mean_control((m(X) - Y) * weight)
|
|
1051
|
+
att_treated_part = np.mean(treated_change - m_treated)
|
|
1052
|
+
|
|
1053
|
+
# Augmentation term from control
|
|
1054
|
+
augmentation = np.sum(weights_control * (m_control - control_change)) / n_t
|
|
1055
|
+
|
|
1056
|
+
att = att_treated_part + augmentation
|
|
1057
|
+
|
|
1058
|
+
# Step 4: Standard error using influence function
|
|
1059
|
+
# Influence function for DR estimator
|
|
1060
|
+
psi_treated = (treated_change - m_treated - att) / n_t
|
|
1061
|
+
psi_control = (weights_control * (m_control - control_change)) / n_t
|
|
1062
|
+
|
|
1063
|
+
# Variance is sum of squared influence functions
|
|
1064
|
+
var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2)
|
|
1065
|
+
se = np.sqrt(var_psi) if var_psi > 0 else 0.0
|
|
1066
|
+
|
|
1067
|
+
# Full influence function
|
|
1068
|
+
inf_func = np.concatenate([psi_treated, psi_control])
|
|
1069
|
+
else:
|
|
1070
|
+
# Without covariates, DR simplifies to difference in means
|
|
1071
|
+
att = np.mean(treated_change) - np.mean(control_change)
|
|
1072
|
+
|
|
1073
|
+
var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
|
|
1074
|
+
var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
|
|
1075
|
+
|
|
1076
|
+
se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0
|
|
1077
|
+
|
|
1078
|
+
# Influence function for DR estimator
|
|
1079
|
+
inf_treated = (treated_change - np.mean(treated_change)) / n_t
|
|
1080
|
+
inf_control = (control_change - np.mean(control_change)) / n_c
|
|
1081
|
+
inf_func = np.concatenate([inf_treated, -inf_control])
|
|
1082
|
+
|
|
1083
|
+
return att, se, inf_func
|
|
1084
|
+
|
|
1085
|
+
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's configuration as a dict (sklearn-compatible).

    The ``bootstrap_weight_type`` entry is deprecated but still reported
    for backward compatibility.
    """
    param_names = (
        "control_group",
        "anticipation",
        "estimation_method",
        "alpha",
        "cluster",
        "n_bootstrap",
        "bootstrap_weights",
        "bootstrap_weight_type",  # deprecated, kept for backward compatibility
        "seed",
        "rank_deficient_action",
        "base_period",
    )
    return {name: getattr(self, name) for name in param_names}
|
|
1101
|
+
|
|
1102
|
+
def set_params(self, **params) -> "CallawaySantAnna":
    """Update estimator configuration in place (sklearn-compatible).

    Parameters
    ----------
    **params
        Attribute name/value pairs; each name must already exist on the
        estimator.

    Returns
    -------
    self, to allow chaining.

    Raises
    ------
    ValueError
        If a name does not correspond to an existing attribute.
    """
    for name, value in params.items():
        # Guard clause: reject names that are not existing attributes.
        if not hasattr(self, name):
            raise ValueError(f"Unknown parameter: {name}")
        setattr(self, name, value)
    return self
|
|
1110
|
+
|
|
1111
|
+
def summary(self) -> str:
    """Return a human-readable summary of the estimation results.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet.
    """
    if not self.is_fitted_:
        raise RuntimeError("Model must be fitted before calling summary()")
    results = self.results_
    assert results is not None  # guaranteed once is_fitted_ is True
    return results.summary()
|
|
1117
|
+
|
|
1118
|
+
def print_summary(self) -> None:
    """Render the results summary and write it to stdout."""
    text = self.summary()
    print(text)
|