diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +234 -0
- diff_diff/_backend.py +64 -0
- diff_diff/_rust_backend.cpython-39-darwin.so +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1000 -0
- diff_diff/honest_did.py +1493 -0
- diff_diff/linalg.py +980 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1338 -0
- diff_diff/pretrends.py +1067 -0
- diff_diff/results.py +703 -0
- diff_diff/staggered.py +2297 -0
- diff_diff/sun_abraham.py +1176 -0
- diff_diff/synthetic_did.py +738 -0
- diff_diff/triple_diff.py +1291 -0
- diff_diff/trop.py +1348 -0
- diff_diff/twfe.py +344 -0
- diff_diff/utils.py +1481 -0
- diff_diff/visualization.py +1627 -0
- diff_diff-2.1.0.dist-info/METADATA +2511 -0
- diff_diff-2.1.0.dist-info/RECORD +24 -0
- diff_diff-2.1.0.dist-info/WHEEL +4 -0
diff_diff/staggered.py
ADDED
|
@@ -0,0 +1,2297 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Staggered Difference-in-Differences estimators.
|
|
3
|
+
|
|
4
|
+
Implements modern methods for DiD with variation in treatment timing,
|
|
5
|
+
including the Callaway-Sant'Anna (2021) estimator.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from scipy import optimize
|
|
15
|
+
|
|
16
|
+
from diff_diff.linalg import solve_ols
|
|
17
|
+
from diff_diff.results import _get_significance_stars
|
|
18
|
+
from diff_diff.utils import (
|
|
19
|
+
compute_confidence_interval,
|
|
20
|
+
compute_p_value,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Import Rust backend if available (from _backend to avoid circular imports)
|
|
24
|
+
from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
|
|
25
|
+
|
|
26
|
+
# Type alias for pre-computed structures
|
|
27
|
+
PrecomputedData = Dict[str, Any]
|
|
28
|
+
|
|
29
|
+
# =============================================================================
|
|
30
|
+
# Bootstrap Weight Generators
|
|
31
|
+
# =============================================================================
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _generate_bootstrap_weights(
|
|
35
|
+
n_units: int,
|
|
36
|
+
weight_type: str,
|
|
37
|
+
rng: np.random.Generator,
|
|
38
|
+
) -> np.ndarray:
|
|
39
|
+
"""
|
|
40
|
+
Generate bootstrap weights for multiplier bootstrap.
|
|
41
|
+
|
|
42
|
+
Parameters
|
|
43
|
+
----------
|
|
44
|
+
n_units : int
|
|
45
|
+
Number of units (clusters) to generate weights for.
|
|
46
|
+
weight_type : str
|
|
47
|
+
Type of weights: "rademacher", "mammen", or "webb".
|
|
48
|
+
rng : np.random.Generator
|
|
49
|
+
Random number generator.
|
|
50
|
+
|
|
51
|
+
Returns
|
|
52
|
+
-------
|
|
53
|
+
np.ndarray
|
|
54
|
+
Array of bootstrap weights with shape (n_units,).
|
|
55
|
+
"""
|
|
56
|
+
if weight_type == "rademacher":
|
|
57
|
+
# Rademacher: +1 or -1 with equal probability
|
|
58
|
+
return rng.choice([-1.0, 1.0], size=n_units)
|
|
59
|
+
|
|
60
|
+
elif weight_type == "mammen":
|
|
61
|
+
# Mammen's two-point distribution
|
|
62
|
+
# E[v] = 0, E[v^2] = 1, E[v^3] = 1
|
|
63
|
+
sqrt5 = np.sqrt(5)
|
|
64
|
+
val1 = -(sqrt5 - 1) / 2 # ≈ -0.618
|
|
65
|
+
val2 = (sqrt5 + 1) / 2 # ≈ 1.618 (golden ratio)
|
|
66
|
+
p1 = (sqrt5 + 1) / (2 * sqrt5) # ≈ 0.724
|
|
67
|
+
return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1])
|
|
68
|
+
|
|
69
|
+
elif weight_type == "webb":
|
|
70
|
+
# Webb's 6-point distribution (recommended for few clusters)
|
|
71
|
+
values = np.array([
|
|
72
|
+
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
73
|
+
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
74
|
+
])
|
|
75
|
+
probs = np.array([1, 2, 3, 3, 2, 1]) / 12
|
|
76
|
+
return rng.choice(values, size=n_units, p=probs)
|
|
77
|
+
|
|
78
|
+
else:
|
|
79
|
+
raise ValueError(
|
|
80
|
+
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
|
|
81
|
+
f"got '{weight_type}'"
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _generate_bootstrap_weights_batch(
    n_bootstrap: int,
    n_units: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Draw the whole matrix of multiplier-bootstrap weights in a single call.

    When the compiled backend is available (``HAS_RUST_BACKEND``), a seed is
    drawn from ``rng`` and handed to it, so a seeded NumPy generator still
    gives reproducible weights. Otherwise the pure-NumPy generator is used
    directly with ``rng``.

    Parameters
    ----------
    n_bootstrap : int
        Number of bootstrap iterations (rows of the result).
    n_units : int
        Number of units (clusters), i.e. columns of the result.
    weight_type : str
        Weight distribution name: "rademacher", "mammen", or "webb".
    rng : np.random.Generator
        Source of randomness (or of the backend seed).

    Returns
    -------
    np.ndarray
        Weights with shape (n_bootstrap, n_units).
    """
    if not HAS_RUST_BACKEND:
        # No compiled backend: generate everything with NumPy.
        return _generate_bootstrap_weights_batch_numpy(
            n_bootstrap, n_units, weight_type, rng
        )

    # Derive the backend's seed from the caller's generator for reproducibility.
    backend_seed = rng.integers(0, 2**63 - 1)
    return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, backend_seed)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _generate_bootstrap_weights_batch_numpy(
|
|
121
|
+
n_bootstrap: int,
|
|
122
|
+
n_units: int,
|
|
123
|
+
weight_type: str,
|
|
124
|
+
rng: np.random.Generator,
|
|
125
|
+
) -> np.ndarray:
|
|
126
|
+
"""
|
|
127
|
+
NumPy fallback implementation of _generate_bootstrap_weights_batch.
|
|
128
|
+
|
|
129
|
+
Generates multiplier bootstrap weights for wild cluster bootstrap.
|
|
130
|
+
All weight distributions satisfy E[w] = 0, E[w^2] = 1.
|
|
131
|
+
|
|
132
|
+
Parameters
|
|
133
|
+
----------
|
|
134
|
+
n_bootstrap : int
|
|
135
|
+
Number of bootstrap iterations.
|
|
136
|
+
n_units : int
|
|
137
|
+
Number of units (clusters) to generate weights for.
|
|
138
|
+
weight_type : str
|
|
139
|
+
Type of weights: "rademacher" (+-1), "mammen" (2-point),
|
|
140
|
+
or "webb" (6-point).
|
|
141
|
+
rng : np.random.Generator
|
|
142
|
+
Random number generator for reproducibility.
|
|
143
|
+
|
|
144
|
+
Returns
|
|
145
|
+
-------
|
|
146
|
+
np.ndarray
|
|
147
|
+
Array of bootstrap weights with shape (n_bootstrap, n_units).
|
|
148
|
+
"""
|
|
149
|
+
if weight_type == "rademacher":
|
|
150
|
+
# Rademacher: +1 or -1 with equal probability
|
|
151
|
+
return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))
|
|
152
|
+
|
|
153
|
+
elif weight_type == "mammen":
|
|
154
|
+
# Mammen's two-point distribution
|
|
155
|
+
sqrt5 = np.sqrt(5)
|
|
156
|
+
val1 = -(sqrt5 - 1) / 2
|
|
157
|
+
val2 = (sqrt5 + 1) / 2
|
|
158
|
+
p1 = (sqrt5 + 1) / (2 * sqrt5)
|
|
159
|
+
return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1])
|
|
160
|
+
|
|
161
|
+
elif weight_type == "webb":
|
|
162
|
+
# Webb's 6-point distribution
|
|
163
|
+
values = np.array([
|
|
164
|
+
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
165
|
+
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
166
|
+
])
|
|
167
|
+
probs = np.array([1, 2, 3, 3, 2, 1]) / 12
|
|
168
|
+
return rng.choice(values, size=(n_bootstrap, n_units), p=probs)
|
|
169
|
+
|
|
170
|
+
else:
|
|
171
|
+
raise ValueError(
|
|
172
|
+
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
|
|
173
|
+
f"got '{weight_type}'"
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# =============================================================================
|
|
178
|
+
# Bootstrap Results Container
|
|
179
|
+
# =============================================================================
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@dataclass
class CSBootstrapResults:
    """
    Container for Callaway-Sant'Anna multiplier-bootstrap inference.

    The required fields describe the bootstrap run and carry inference for
    the overall ATT and for each ATT(g,t). The optional fields carry
    inference for the event-study and cohort aggregations and default to
    ``None``.

    Attributes
    ----------
    n_bootstrap : int
        How many bootstrap replications were drawn.
    weight_type : str
        Name of the multiplier-weight distribution that was used.
    alpha : float
        Significance level behind the confidence intervals.
    overall_att_se, overall_att_ci, overall_att_p_value
        Bootstrap standard error, ``(lower, upper)`` confidence interval and
        p-value for the overall ATT.
    group_time_ses, group_time_cis, group_time_p_values : dict
        The same three quantities for every ``(group, time)`` cell.
    event_study_ses, event_study_cis, event_study_p_values : dict or None
        The same three quantities keyed by event-study relative period.
    group_effect_ses, group_effect_cis, group_effect_p_values : dict or None
        The same three quantities keyed by treatment cohort.
    bootstrap_distribution : np.ndarray or None
        Full bootstrap distribution of the overall ATT, when retained.
        Excluded from ``repr``.
    """

    # -- Run description ------------------------------------------------
    n_bootstrap: int
    weight_type: str
    alpha: float

    # -- Overall ATT ----------------------------------------------------
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float

    # -- Per (group, time) cell -----------------------------------------
    group_time_ses: Dict[Tuple[Any, Any], float]
    group_time_cis: Dict[Tuple[Any, Any], Tuple[float, float]]
    group_time_p_values: Dict[Tuple[Any, Any], float]

    # -- Event-study aggregation (optional) -----------------------------
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None

    # -- Cohort aggregation (optional) ----------------------------------
    group_effect_ses: Optional[Dict[Any, float]] = None
    group_effect_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_effect_p_values: Optional[Dict[Any, float]] = None

    # -- Raw draws, hidden from repr ------------------------------------
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _logistic_regression(
|
|
241
|
+
X: np.ndarray,
|
|
242
|
+
y: np.ndarray,
|
|
243
|
+
max_iter: int = 100,
|
|
244
|
+
tol: float = 1e-6,
|
|
245
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
246
|
+
"""
|
|
247
|
+
Fit logistic regression using scipy optimize.
|
|
248
|
+
|
|
249
|
+
Parameters
|
|
250
|
+
----------
|
|
251
|
+
X : np.ndarray
|
|
252
|
+
Feature matrix (n_samples, n_features). Intercept added automatically.
|
|
253
|
+
y : np.ndarray
|
|
254
|
+
Binary outcome (0/1).
|
|
255
|
+
max_iter : int
|
|
256
|
+
Maximum iterations.
|
|
257
|
+
tol : float
|
|
258
|
+
Convergence tolerance.
|
|
259
|
+
|
|
260
|
+
Returns
|
|
261
|
+
-------
|
|
262
|
+
beta : np.ndarray
|
|
263
|
+
Fitted coefficients (including intercept).
|
|
264
|
+
probs : np.ndarray
|
|
265
|
+
Predicted probabilities.
|
|
266
|
+
"""
|
|
267
|
+
n, p = X.shape
|
|
268
|
+
# Add intercept
|
|
269
|
+
X_with_intercept = np.column_stack([np.ones(n), X])
|
|
270
|
+
|
|
271
|
+
def neg_log_likelihood(beta: np.ndarray) -> float:
|
|
272
|
+
z = X_with_intercept @ beta
|
|
273
|
+
# Clip to prevent overflow
|
|
274
|
+
z = np.clip(z, -500, 500)
|
|
275
|
+
log_lik = np.sum(y * z - np.log(1 + np.exp(z)))
|
|
276
|
+
return -log_lik
|
|
277
|
+
|
|
278
|
+
def gradient(beta: np.ndarray) -> np.ndarray:
|
|
279
|
+
z = X_with_intercept @ beta
|
|
280
|
+
z = np.clip(z, -500, 500)
|
|
281
|
+
probs = 1 / (1 + np.exp(-z))
|
|
282
|
+
return -X_with_intercept.T @ (y - probs)
|
|
283
|
+
|
|
284
|
+
# Initialize with zeros
|
|
285
|
+
beta_init = np.zeros(p + 1)
|
|
286
|
+
|
|
287
|
+
result = optimize.minimize(
|
|
288
|
+
neg_log_likelihood,
|
|
289
|
+
beta_init,
|
|
290
|
+
method='BFGS',
|
|
291
|
+
jac=gradient,
|
|
292
|
+
options={'maxiter': max_iter, 'gtol': tol}
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
beta = result.x
|
|
296
|
+
z = X_with_intercept @ beta
|
|
297
|
+
z = np.clip(z, -500, 500)
|
|
298
|
+
probs = 1 / (1 + np.exp(-z))
|
|
299
|
+
|
|
300
|
+
return beta, probs
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Ordinary least squares with an automatically prepended intercept.

    Parameters
    ----------
    X : np.ndarray
        Regressors of shape (n_samples, n_features). A column of ones is
        prepended before fitting.
    y : np.ndarray
        Outcome variable.

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients, intercept first.
    residuals : np.ndarray
        Residuals from the fit, as returned by ``solve_ols``.
    """
    design = np.column_stack([np.ones(X.shape[0]), X])

    # Only coefficients and residuals are used, so skip the variance matrix.
    coef, resid, _unused_vcov = solve_ols(design, y, return_vcov=False)
    return coef, resid
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@dataclass
class GroupTimeEffect:
    """
    Estimated treatment effect for one (cohort, period) cell, ATT(g,t).

    Attributes
    ----------
    group : any
        The treatment cohort (first treatment period).
    time : any
        The calendar period of the estimate.
    effect : float
        The ATT(g,t) point estimate.
    se : float
        Standard error of ``effect``.
    t_stat : float
        t-statistic of the estimate.
    p_value : float
        p-value of the estimate.
    conf_int : Tuple[float, float]
        ``(lower, upper)`` confidence interval.
    n_treated : int
        Number of treated observations.
    n_control : int
        Number of control observations.
    """

    group: Any
    time: Any
    effect: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_treated: int
    n_control: int

    @property
    def is_significant(self) -> bool:
        """Whether the effect is significant at the fixed 5% level."""
        below_threshold = self.p_value < 0.05
        return bool(below_threshold)

    @property
    def significance_stars(self) -> str:
        """Significance marker string for this effect's p-value."""
        stars = _get_significance_stars(self.p_value)
        return stars
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@dataclass
class CallawaySantAnnaResults:
    """
    Results from Callaway-Sant'Anna (2021) staggered DiD estimation.

    This class stores group-time average treatment effects ATT(g,t) and
    provides methods for aggregation into summary measures.

    Attributes
    ----------
    group_time_effects : dict
        Dictionary mapping (group, time) tuples to effect dictionaries with
        keys 'effect', 'se', 't_stat', 'p_value' and 'conf_int'.
    overall_att : float
        Overall average treatment effect (weighted average of ATT(g,t)).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        t-statistic of overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT at the estimation ``alpha``.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units.
    alpha : float, default=0.05
        Significance level used in estimation.
    control_group : str, default="never_treated"
        Control group used in estimation.
    event_study_effects : dict, optional
        Effects aggregated by relative time (event study), with the same
        per-entry keys as ``group_time_effects``.
    group_effects : dict, optional
        Effects aggregated by treatment cohort, with the same per-entry keys.
    influence_functions : np.ndarray, optional
        Influence functions (excluded from ``repr``).
    bootstrap_results : CSBootstrapResults, optional
        Bootstrap inference results when available (excluded from ``repr``).
    """
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
    group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
    influence_functions: Optional[np.ndarray] = field(default=None, repr=False)
    bootstrap_results: Optional[CSBootstrapResults] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        return (
            f"CallawaySantAnnaResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_periods={len(self.time_periods)})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level for the displayed overall confidence interval.
            Defaults to the alpha used in estimation, in which case the stored
            ``overall_conf_int`` is shown. For any other value a normal-
            approximation interval ``ATT ± z_(1-alpha/2) * SE`` is computed so
            that the printed bounds match the printed confidence level.

        Returns
        -------
        str
            Formatted summary.
        """
        if alpha is None:
            alpha = self.alpha
        # ':g' avoids int() truncation (alpha=0.025 -> "97.5", not "97") and
        # hides float noise such as 92.99999999999999.
        conf_level = f"{(1 - alpha) * 100:g}"

        # The stored interval was computed at the estimation alpha; for any
        # other level recompute it so the label and the bounds agree.
        if alpha == self.alpha:
            ci_lower, ci_upper = self.overall_conf_int
        else:
            from scipy import stats

            z_crit = stats.norm.ppf(1 - alpha / 2)
            ci_lower = self.overall_att - z_crit * self.overall_se
            ci_upper = self.overall_att + z_crit * self.overall_se

        lines = [
            "=" * 85,
            "Callaway-Sant'Anna Staggered Difference-in-Differences Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            "",
        ]

        # Overall ATT
        lines.extend([
            "-" * 85,
            "Overall Average Treatment Effect on the Treated".center(85),
            "-" * 85,
            f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
            "-" * 85,
            f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
            f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
            f"{_get_significance_stars(self.overall_p_value):>6}",
            "-" * 85,
            "",
            f"{conf_level}% Confidence Interval: [{ci_lower:.4f}, {ci_upper:.4f}]",
            "",
        ])

        # Event study effects if available
        if self.event_study_effects:
            lines.extend([
                "-" * 85,
                "Event Study (Dynamic) Effects".center(85),
                "-" * 85,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ])

            for rel_t in sorted(self.event_study_effects.keys()):
                eff = self.event_study_effects[rel_t]
                sig = _get_significance_stars(eff['p_value'])
                lines.append(
                    f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                    f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
                )

            lines.extend(["-" * 85, ""])

        # Group effects if available
        if self.group_effects:
            lines.extend([
                "-" * 85,
                "Effects by Treatment Cohort".center(85),
                "-" * 85,
                f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ])

            for group in sorted(self.group_effects.keys()):
                eff = self.group_effects[group]
                sig = _get_significance_stars(eff['p_value'])
                lines.append(
                    f"{group:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                    f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
                )

            lines.extend(["-" * 85, ""])

        lines.extend([
            "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
            "=" * 85,
        ])

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    @staticmethod
    def _effect_rows_to_dataframe(keyed_effects, key_names: Tuple[str, ...]) -> pd.DataFrame:
        """
        Build a results DataFrame from ``(key, effect_dict)`` pairs.

        ``key_names`` names the leading column(s). When it has more than one
        entry, each key is unpacked positionally into those columns. The
        columns are always present, even when ``keyed_effects`` is empty.
        """
        stat_columns = [
            'effect', 'se', 't_stat', 'p_value', 'conf_int_lower', 'conf_int_upper',
        ]
        rows = []
        for key, data in keyed_effects:
            key_values = key if len(key_names) > 1 else (key,)
            row = dict(zip(key_names, key_values))
            row.update({
                'effect': data['effect'],
                'se': data['se'],
                't_stat': data['t_stat'],
                'p_value': data['p_value'],
                'conf_int_lower': data['conf_int'][0],
                'conf_int_upper': data['conf_int'][1],
            })
            rows.append(row)
        return pd.DataFrame(rows, columns=[*key_names, *stat_columns])

    def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="group_time"
            Level of aggregation: "group_time", "event_study", or "group".

        Returns
        -------
        pd.DataFrame
            One row per effect with its estimate, standard error, t-statistic,
            p-value and confidence bounds. "event_study" and "group" rows are
            sorted by key; "group_time" rows follow dictionary order.

        Raises
        ------
        ValueError
            If the requested aggregation was not computed, or ``level`` is
            not recognised.
        """
        if level == "group_time":
            return self._effect_rows_to_dataframe(
                self.group_time_effects.items(), ('group', 'time')
            )

        elif level == "event_study":
            if self.event_study_effects is None:
                raise ValueError("Event study effects not computed. Use aggregate='event_study'.")
            return self._effect_rows_to_dataframe(
                sorted(self.event_study_effects.items()), ('relative_period',)
            )

        elif level == "group":
            if self.group_effects is None:
                raise ValueError("Group effects not computed. Use aggregate='group'.")
            return self._effect_rows_to_dataframe(
                sorted(self.group_effects.items()), ('group',)
            )

        else:
            raise ValueError(f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'.")

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
class CallawaySantAnna:
|
|
612
|
+
"""
|
|
613
|
+
Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
|
|
614
|
+
|
|
615
|
+
This estimator handles DiD designs with variation in treatment timing
|
|
616
|
+
(staggered adoption) and heterogeneous treatment effects. It avoids the
|
|
617
|
+
bias of traditional two-way fixed effects (TWFE) estimators by:
|
|
618
|
+
|
|
619
|
+
1. Computing group-time average treatment effects ATT(g,t) for each
|
|
620
|
+
cohort g (units first treated in period g) and time t.
|
|
621
|
+
2. Aggregating these to summary measures (overall ATT, event study, etc.)
|
|
622
|
+
using appropriate weights.
|
|
623
|
+
|
|
624
|
+
Parameters
|
|
625
|
+
----------
|
|
626
|
+
control_group : str, default="never_treated"
|
|
627
|
+
Which units to use as controls:
|
|
628
|
+
- "never_treated": Use only never-treated units (recommended)
|
|
629
|
+
- "not_yet_treated": Use never-treated and not-yet-treated units
|
|
630
|
+
anticipation : int, default=0
|
|
631
|
+
Number of periods before treatment where effects may occur.
|
|
632
|
+
Set to > 0 if treatment effects can begin before the official
|
|
633
|
+
treatment date.
|
|
634
|
+
estimation_method : str, default="dr"
|
|
635
|
+
Estimation method:
|
|
636
|
+
- "dr": Doubly robust (recommended)
|
|
637
|
+
- "ipw": Inverse probability weighting
|
|
638
|
+
- "reg": Outcome regression
|
|
639
|
+
alpha : float, default=0.05
|
|
640
|
+
Significance level for confidence intervals.
|
|
641
|
+
cluster : str, optional
|
|
642
|
+
Column name for cluster-robust standard errors.
|
|
643
|
+
Defaults to unit-level clustering.
|
|
644
|
+
n_bootstrap : int, default=0
|
|
645
|
+
Number of bootstrap iterations for inference.
|
|
646
|
+
If 0, uses analytical standard errors.
|
|
647
|
+
Recommended: 999 or more for reliable inference.
|
|
648
|
+
|
|
649
|
+
.. note:: Memory Usage
|
|
650
|
+
The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
|
|
651
|
+
float64 array. For large datasets, this can be significant:
|
|
652
|
+
- 1K bootstrap × 10K units = ~80 MB
|
|
653
|
+
- 10K bootstrap × 100K units = ~8 GB
|
|
654
|
+
Consider reducing n_bootstrap if memory is constrained.
|
|
655
|
+
|
|
656
|
+
bootstrap_weights : str, default="rademacher"
|
|
657
|
+
Type of weights for multiplier bootstrap:
|
|
658
|
+
- "rademacher": +1/-1 with equal probability (standard choice)
|
|
659
|
+
- "mammen": Two-point distribution (asymptotically valid, matches skewness)
|
|
660
|
+
- "webb": Six-point distribution (recommended when n_clusters < 20)
|
|
661
|
+
bootstrap_weight_type : str, optional
|
|
662
|
+
.. deprecated:: 1.0.1
|
|
663
|
+
Use ``bootstrap_weights`` instead. Will be removed in v2.0.
|
|
664
|
+
seed : int, optional
|
|
665
|
+
Random seed for reproducibility.
|
|
666
|
+
|
|
667
|
+
Attributes
|
|
668
|
+
----------
|
|
669
|
+
results_ : CallawaySantAnnaResults
|
|
670
|
+
Estimation results after calling fit().
|
|
671
|
+
is_fitted_ : bool
|
|
672
|
+
Whether the model has been fitted.
|
|
673
|
+
|
|
674
|
+
Examples
|
|
675
|
+
--------
|
|
676
|
+
Basic usage:
|
|
677
|
+
|
|
678
|
+
>>> import pandas as pd
|
|
679
|
+
>>> from diff_diff import CallawaySantAnna
|
|
680
|
+
>>>
|
|
681
|
+
>>> # Panel data with staggered treatment
|
|
682
|
+
>>> # 'first_treat' = period when unit was first treated (0 if never treated)
|
|
683
|
+
>>> data = pd.DataFrame({
|
|
684
|
+
... 'unit': [...],
|
|
685
|
+
... 'time': [...],
|
|
686
|
+
... 'outcome': [...],
|
|
687
|
+
... 'first_treat': [...] # 0 for never-treated, else first treatment period
|
|
688
|
+
... })
|
|
689
|
+
>>>
|
|
690
|
+
>>> cs = CallawaySantAnna()
|
|
691
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
692
|
+
... time='time', first_treat='first_treat')
|
|
693
|
+
>>>
|
|
694
|
+
>>> results.print_summary()
|
|
695
|
+
|
|
696
|
+
With event study aggregation:
|
|
697
|
+
|
|
698
|
+
>>> cs = CallawaySantAnna()
|
|
699
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
700
|
+
... time='time', first_treat='first_treat',
|
|
701
|
+
... aggregate='event_study')
|
|
702
|
+
>>>
|
|
703
|
+
>>> # Plot event study
|
|
704
|
+
>>> from diff_diff import plot_event_study
|
|
705
|
+
>>> plot_event_study(results)
|
|
706
|
+
|
|
707
|
+
With covariate adjustment (conditional parallel trends):
|
|
708
|
+
|
|
709
|
+
>>> # When parallel trends only holds conditional on covariates
|
|
710
|
+
>>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
|
|
711
|
+
>>> results = cs.fit(data, outcome='outcome', unit='unit',
|
|
712
|
+
... time='time', first_treat='first_treat',
|
|
713
|
+
... covariates=['age', 'income'])
|
|
714
|
+
>>>
|
|
715
|
+
>>> # DR is recommended: consistent if either outcome model
|
|
716
|
+
>>> # or propensity model is correctly specified
|
|
717
|
+
|
|
718
|
+
Notes
|
|
719
|
+
-----
|
|
720
|
+
The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
|
|
721
|
+
approach: instead of estimating a single treatment effect, they estimate
|
|
722
|
+
ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
|
|
723
|
+
problem where already-treated units act as controls.
|
|
724
|
+
|
|
725
|
+
The ATT(g,t) is identified under parallel trends conditional on covariates:
|
|
726
|
+
|
|
727
|
+
E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]
|
|
728
|
+
|
|
729
|
+
where G=g indicates treatment cohort g and C=1 indicates control units.
|
|
730
|
+
|
|
731
|
+
References
|
|
732
|
+
----------
|
|
733
|
+
Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
|
|
734
|
+
multiple time periods. Journal of Econometrics, 225(2), 200-230.
|
|
735
|
+
"""
|
|
736
|
+
|
|
737
|
+
def __init__(
|
|
738
|
+
self,
|
|
739
|
+
control_group: str = "never_treated",
|
|
740
|
+
anticipation: int = 0,
|
|
741
|
+
estimation_method: str = "dr",
|
|
742
|
+
alpha: float = 0.05,
|
|
743
|
+
cluster: Optional[str] = None,
|
|
744
|
+
n_bootstrap: int = 0,
|
|
745
|
+
bootstrap_weights: Optional[str] = None,
|
|
746
|
+
bootstrap_weight_type: Optional[str] = None,
|
|
747
|
+
seed: Optional[int] = None,
|
|
748
|
+
):
|
|
749
|
+
import warnings
|
|
750
|
+
|
|
751
|
+
if control_group not in ["never_treated", "not_yet_treated"]:
|
|
752
|
+
raise ValueError(
|
|
753
|
+
f"control_group must be 'never_treated' or 'not_yet_treated', "
|
|
754
|
+
f"got '{control_group}'"
|
|
755
|
+
)
|
|
756
|
+
if estimation_method not in ["dr", "ipw", "reg"]:
|
|
757
|
+
raise ValueError(
|
|
758
|
+
f"estimation_method must be 'dr', 'ipw', or 'reg', "
|
|
759
|
+
f"got '{estimation_method}'"
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
# Handle bootstrap_weight_type deprecation
|
|
763
|
+
if bootstrap_weight_type is not None:
|
|
764
|
+
warnings.warn(
|
|
765
|
+
"bootstrap_weight_type is deprecated and will be removed in v2.0. "
|
|
766
|
+
"Use bootstrap_weights instead.",
|
|
767
|
+
DeprecationWarning,
|
|
768
|
+
stacklevel=2
|
|
769
|
+
)
|
|
770
|
+
if bootstrap_weights is None:
|
|
771
|
+
bootstrap_weights = bootstrap_weight_type
|
|
772
|
+
|
|
773
|
+
# Default to rademacher if neither specified
|
|
774
|
+
if bootstrap_weights is None:
|
|
775
|
+
bootstrap_weights = "rademacher"
|
|
776
|
+
|
|
777
|
+
if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
|
|
778
|
+
raise ValueError(
|
|
779
|
+
f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
|
|
780
|
+
f"got '{bootstrap_weights}'"
|
|
781
|
+
)
|
|
782
|
+
|
|
783
|
+
self.control_group = control_group
|
|
784
|
+
self.anticipation = anticipation
|
|
785
|
+
self.estimation_method = estimation_method
|
|
786
|
+
self.alpha = alpha
|
|
787
|
+
self.cluster = cluster
|
|
788
|
+
self.n_bootstrap = n_bootstrap
|
|
789
|
+
self.bootstrap_weights = bootstrap_weights
|
|
790
|
+
# Keep bootstrap_weight_type for backward compatibility
|
|
791
|
+
self.bootstrap_weight_type = bootstrap_weights
|
|
792
|
+
self.seed = seed
|
|
793
|
+
|
|
794
|
+
self.is_fitted_ = False
|
|
795
|
+
self.results_ = None
|
|
796
|
+
|
|
797
|
+
def _precompute_structures(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    time_periods: List[Any],
    treatment_groups: List[Any],
) -> PrecomputedData:
    """
    Pre-compute data structures for efficient ATT(g,t) computation.

    The long panel is pivoted to wide format once so that every ATT(g,t)
    can be computed by array slicing instead of repeated pandas filtering.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format panel data.
    outcome, unit, time, first_treat : str
        Column names of the outcome, unit identifier, time period and
        first-treatment period.
    covariates : list of str, optional
        Covariate columns. When given, a (units x covariates) matrix is
        built for every period in ``time_periods``.
    time_periods : list
        Periods for which covariate matrices are built.
    treatment_groups : list
        Cohorts (first-treatment periods) for which membership masks are built.

    Returns
    -------
    PrecomputedData
        Dictionary with keys ``all_units``, ``unit_to_idx``, ``unit_cohorts``,
        ``outcome_matrix`` (float array, units x periods, NaN where a
        unit-period is missing), ``period_to_col``, ``cohort_masks``,
        ``never_treated_mask``, ``covariate_by_period`` (None when no
        covariates are given; otherwise period -> float array) and
        ``time_periods``.
    """
    # One entry per unit: its (first observed) cohort assignment.
    unit_info = df.groupby(unit)[first_treat].first()
    all_units = unit_info.index.values
    unit_cohorts = unit_info.values

    # Unit id -> row index of the wide matrices.
    unit_to_idx = {u: i for i, u in enumerate(all_units)}

    # Wide outcome matrix: rows follow all_units, columns are periods.
    # Reindexing keeps every unit in unbalanced panels. Converting to float
    # guarantees that missing cells are NaN even for integer or nullable
    # extension dtypes, which the np.isnan checks downstream rely on.
    outcome_wide = df.pivot(index=unit, columns=time, values=outcome)
    outcome_wide = outcome_wide.reindex(all_units)
    outcome_matrix = outcome_wide.to_numpy(dtype=float, na_value=np.nan)
    period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)}

    # Boolean cohort-membership masks aligned with all_units.
    cohort_masks = {g: (unit_cohorts == g) for g in treatment_groups}

    # Never-treated units are coded with first_treat equal to 0 or +inf.
    never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)

    # Covariates per period (each comparison later reads its base period).
    covariate_by_period = None
    if covariates:
        covariate_by_period = {}
        for t in time_periods:
            # Select the covariate columns before reindexing so only the
            # needed columns are aligned to all_units.
            period_cov = (
                df.loc[df[time] == t]
                .set_index(unit)[covariates]
                .reindex(all_units)
            )
            # Shape: (n_units, n_covariates); missing values become NaN.
            covariate_by_period[t] = period_cov.to_numpy(
                dtype=float, na_value=np.nan
            )

    return {
        'all_units': all_units,
        'unit_to_idx': unit_to_idx,
        'unit_cohorts': unit_cohorts,
        'outcome_matrix': outcome_matrix,
        'period_to_col': period_to_col,
        'cohort_masks': cohort_masks,
        'never_treated_mask': never_treated_mask,
        'covariate_by_period': covariate_by_period,
        'time_periods': time_periods,
    }
|
|
867
|
+
|
|
868
|
+
def _compute_att_gt_fast(
    self,
    precomputed: PrecomputedData,
    g: Any,
    t: Any,
    covariates: Optional[List[str]],
) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]:
    """
    Compute ATT(g,t) using pre-computed data structures (fast version).

    Uses vectorized numpy operations on the pre-pivoted outcome matrix
    instead of repeated pandas filtering.

    Parameters
    ----------
    precomputed : PrecomputedData
        Output of ``_precompute_structures``.
    g : Any
        Treatment cohort (first-treatment period).
    t : Any
        Period for which the effect is computed.
    covariates : list of str, optional
        Only used as a flag; covariate values are read from
        ``precomputed['covariate_by_period']`` at the base period.

    Returns
    -------
    tuple
        ``(att, se, n_treated, n_control, inf_func_info)``.
        ``(None, 0.0, 0, 0, None)`` is returned when the cell cannot be
        estimated: there is no period before ``g - anticipation``, a needed
        period is absent, or no treated/control unit has outcomes in both
        periods. ``inf_func_info`` holds the treated/control unit ids and
        the matching influence-function values.
    """
    time_periods = precomputed['time_periods']
    period_to_col = precomputed['period_to_col']
    outcome_matrix = precomputed['outcome_matrix']
    cohort_masks = precomputed['cohort_masks']
    never_treated_mask = precomputed['never_treated_mask']
    unit_cohorts = precomputed['unit_cohorts']
    all_units = precomputed['all_units']
    covariate_by_period = precomputed['covariate_by_period']

    # Base period: last period before (anticipated) treatment of cohort g.
    base_period = g - 1 - self.anticipation
    if base_period not in period_to_col:
        # Fall back to the closest earlier observed period.
        earlier = [p for p in time_periods if p < g - self.anticipation]
        if not earlier:
            return None, 0.0, 0, 0, None
        base_period = max(earlier)

    if base_period not in period_to_col or t not in period_to_col:
        return None, 0.0, 0, 0, None

    base_col = period_to_col[base_period]
    post_col = period_to_col[t]

    treated_mask = cohort_masks[g]

    if self.control_group == "never_treated":
        control_mask = never_treated_mask
    else:  # not_yet_treated
        # A valid not-yet-treated control must be neither treated nor
        # anticipating treatment in either compared period:
        # first_treat > max(t, base_period) + anticipation. Cohort g is
        # excluded explicitly. With anticipation > 0, cells with t < g are
        # computed, and `first_treat > t` alone would let cohort g serve as
        # its own control.
        latest_period = max(t, base_period)
        not_yet_treated = unit_cohorts > latest_period + self.anticipation
        control_mask = never_treated_mask | (not_yet_treated & ~treated_mask)

    y_base = outcome_matrix[:, base_col]
    y_post = outcome_matrix[:, post_col]
    outcome_change = y_post - y_base

    # Only units observed in both periods contribute.
    valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
    treated_valid = treated_mask & valid_mask
    control_valid = control_mask & valid_mask

    n_treated = int(np.sum(treated_valid))
    n_control = int(np.sum(control_valid))

    if n_treated == 0 or n_control == 0:
        return None, 0.0, 0, 0, None

    treated_change = outcome_change[treated_valid]
    control_change = outcome_change[control_valid]

    # Unit ids so influence functions can be aggregated across (g,t) cells.
    treated_units = all_units[treated_valid]
    control_units = all_units[control_valid]

    # Covariates are taken from the base period.
    X_treated = None
    X_control = None
    if covariates and covariate_by_period is not None:
        cov_matrix = covariate_by_period[base_period]
        X_treated = cov_matrix[treated_valid]
        X_control = cov_matrix[control_valid]

        if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
            warnings.warn(
                f"Missing values in covariates for group {g}, time {t}. "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=3,
            )
            X_treated = None
            X_control = None

    if self.estimation_method == "reg":
        att_gt, se_gt, inf_func = self._outcome_regression(
            treated_change, control_change, X_treated, X_control
        )
    elif self.estimation_method == "ipw":
        att_gt, se_gt, inf_func = self._ipw_estimation(
            treated_change, control_change,
            n_treated, n_control,
            X_treated, X_control
        )
    else:  # doubly robust
        att_gt, se_gt, inf_func = self._doubly_robust(
            treated_change, control_change, X_treated, X_control
        )

    # Estimators return the influence function with treated units first.
    inf_func_info = {
        'treated_units': list(treated_units),
        'control_units': list(control_units),
        'treated_inf': inf_func[:n_treated],
        'control_inf': inf_func[n_treated:],
    }

    return att_gt, se_gt, n_treated, n_control, inf_func_info
|
|
989
|
+
|
|
990
|
+
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
) -> CallawaySantAnnaResults:
    """
    Fit the Callaway-Sant'Anna estimator.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers.
    outcome : str
        Name of outcome variable column.
    unit : str
        Name of unit identifier column.
    time : str
        Name of time period column.
    first_treat : str
        Name of column indicating when unit was first treated.
        Use 0 (or np.inf) for never-treated units.
    covariates : list, optional
        List of covariate column names for conditional parallel trends.
    aggregate : str, optional
        How to aggregate group-time effects:
        - None: Only compute ATT(g,t) (default)
        - "simple": Simple weighted average (overall ATT)
        - "event_study": Aggregate by relative time (event study)
        - "group": Aggregate by treatment cohort
        - "all": Compute all aggregations
    balance_e : int, optional
        For event study, balance the panel at relative time e.
        Ensures all groups contribute to each relative period.

    Returns
    -------
    CallawaySantAnnaResults
        Object containing all estimation results.

    Raises
    ------
    ValueError
        If ``aggregate`` is not a supported option, required columns are
        missing, no never-treated units exist, or no group-time effect can
        be estimated.
    """
    # Reject typos such as "event-study" instead of silently skipping them.
    if aggregate not in (None, "simple", "event_study", "group", "all"):
        raise ValueError(
            "aggregate must be None, 'simple', 'event_study', 'group', "
            f"or 'all', got '{aggregate}'"
        )

    required_cols = [outcome, unit, time, first_treat]
    if covariates:
        required_cols.extend(covariates)

    missing = [c for c in required_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    df = data.copy()
    df[time] = pd.to_numeric(df[time])
    df[first_treat] = pd.to_numeric(df[first_treat])

    time_periods = sorted(df[time].unique())
    # Cohorts are finite, positive first-treatment periods. Both 0 and +inf
    # encode never-treated units, so +inf must not become a cohort.
    treatment_groups = sorted(
        g for g in df[first_treat].unique() if g > 0 and np.isfinite(g)
    )

    # Never-treated indicator (first_treat = 0 or inf)
    df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)

    unit_info = df.groupby(unit).agg({
        first_treat: 'first',
        '_never_treated': 'first'
    }).reset_index()

    unit_cohort_codes = unit_info[first_treat]
    n_treated_units = (
        (unit_cohort_codes > 0) & np.isfinite(unit_cohort_codes)
    ).sum()
    n_control_units = (unit_info['_never_treated']).sum()

    if n_control_units == 0:
        raise ValueError("No never-treated units found. Check 'first_treat' column.")

    precomputed = self._precompute_structures(
        df, outcome, unit, time, first_treat,
        covariates, time_periods, treatment_groups
    )

    group_time_effects = {}
    influence_func_info = {}  # Store influence functions for bootstrap

    for g in treatment_groups:
        # Periods for which we compute effects (t >= g - anticipation)
        valid_periods = [t for t in time_periods if t >= g - self.anticipation]

        for t in valid_periods:
            att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
                precomputed, g, t, covariates
            )

            if att_gt is not None:
                t_stat = att_gt / se_gt if se_gt > 0 else 0.0
                p_val = compute_p_value(t_stat)
                ci = compute_confidence_interval(att_gt, se_gt, self.alpha)

                group_time_effects[(g, t)] = {
                    'effect': att_gt,
                    'se': se_gt,
                    't_stat': t_stat,
                    'p_value': p_val,
                    'conf_int': ci,
                    'n_treated': n_treat,
                    'n_control': n_ctrl,
                }

                if inf_info is not None:
                    influence_func_info[(g, t)] = inf_info

    if not group_time_effects:
        raise ValueError(
            "Could not estimate any group-time effects. "
            "Check that data has sufficient observations."
        )

    # Overall ATT (simple aggregation) is always computed.
    overall_att, overall_se = self._aggregate_simple(
        group_time_effects, influence_func_info, df, unit, precomputed
    )
    overall_t = overall_att / overall_se if overall_se > 0 else 0.0
    overall_p = compute_p_value(overall_t)
    overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha)

    event_study_effects = None
    group_effects = None

    if aggregate in ["event_study", "all"]:
        event_study_effects = self._aggregate_event_study(
            group_time_effects, influence_func_info,
            treatment_groups, time_periods, balance_e
        )

    if aggregate in ["group", "all"]:
        group_effects = self._aggregate_by_group(
            group_time_effects, influence_func_info, treatment_groups
        )

    bootstrap_results = None
    if self.n_bootstrap > 0 and influence_func_info:
        bootstrap_results = self._run_multiplier_bootstrap(
            group_time_effects=group_time_effects,
            influence_func_info=influence_func_info,
            aggregate=aggregate,
            balance_e=balance_e,
            treatment_groups=treatment_groups,
            time_periods=time_periods,
        )

        # Replace analytical inference with bootstrap inference.
        overall_se = bootstrap_results.overall_att_se
        overall_t = overall_att / overall_se if overall_se > 0 else 0.0
        overall_p = bootstrap_results.overall_att_p_value
        overall_ci = bootstrap_results.overall_att_ci

        self._apply_bootstrap_to_effects(
            group_time_effects,
            bootstrap_results.group_time_ses,
            bootstrap_results.group_time_cis,
            bootstrap_results.group_time_p_values,
        )
        if event_study_effects is not None:
            self._apply_bootstrap_to_effects(
                event_study_effects,
                bootstrap_results.event_study_ses,
                bootstrap_results.event_study_cis,
                bootstrap_results.event_study_p_values,
            )
        if group_effects is not None:
            self._apply_bootstrap_to_effects(
                group_effects,
                bootstrap_results.group_effect_ses,
                bootstrap_results.group_effect_cis,
                bootstrap_results.group_effect_p_values,
            )

    self.results_ = CallawaySantAnnaResults(
        group_time_effects=group_time_effects,
        overall_att=overall_att,
        overall_se=overall_se,
        overall_t_stat=overall_t,
        overall_p_value=overall_p,
        overall_conf_int=overall_ci,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=n_treated_units,
        n_control_units=n_control_units,
        alpha=self.alpha,
        control_group=self.control_group,
        event_study_effects=event_study_effects,
        group_effects=group_effects,
        bootstrap_results=bootstrap_results,
    )

    self.is_fitted_ = True
    return self.results_

@staticmethod
def _apply_bootstrap_to_effects(
    effects: Dict[Any, Dict[str, Any]],
    ses: Optional[Dict[Any, Any]],
    cis: Optional[Dict[Any, Any]],
    p_values: Optional[Dict[Any, Any]],
) -> None:
    """
    Overwrite analytical inference in ``effects`` with bootstrap results.

    For every key of ``effects`` that is also in ``ses``, the entry's
    ``se``, ``conf_int`` and ``p_value`` are replaced. ``t_stat`` is then
    recomputed as effect / se (0.0 when se is not positive). Nothing
    happens when any of the three bootstrap mappings is None.
    """
    if ses is None or cis is None or p_values is None:
        return
    for key, entry in effects.items():
        if key not in ses:
            continue
        entry['se'] = ses[key]
        entry['conf_int'] = cis[key]
        entry['p_value'] = p_values[key]
        effect = float(entry['effect'])
        se = float(entry['se'])
        entry['t_stat'] = effect / se if se > 0 else 0.0
|
|
1220
|
+
|
|
1221
|
+
def _outcome_regression(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using outcome regression.

    With covariates:
        1. Regress outcome changes on covariates (with intercept) for the
           control group
        2. Predict the counterfactual change for treated units
        3. ATT = mean(treated_change - predicted_counterfactual)

    Without covariates:
        Simple difference in means.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes (post minus base period) of treated / control units.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices whose rows align with the change arrays.

    Returns
    -------
    att : float
    se : float
        Standard error (treated variance uses ddof=1).
    inf_func : np.ndarray
        Influence function with treated units first, then control units.
        Its control part propagates the estimation error of the regression
        coefficients.
    """
    n_t = len(treated_change)
    n_c = len(control_change)

    if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
        # Fit E[Delta Y | X, D=0] on control units (beta includes intercept).
        beta, residuals = _linear_regression(X_control, control_change)

        X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
        X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
        predicted_control = X_treated_with_intercept @ beta

        treated_residuals = treated_change - predicted_control
        att = np.mean(treated_residuals)

        # Treated contribution: variation of observed minus predicted change.
        inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t

        # Control contribution: ATT depends on beta through
        # -mean(X_treated)' beta, and beta_hat - beta = (X'X)^{-1} sum x_i e_i
        # over controls. Each control unit therefore contributes
        # -mean(X_treated)' (X'X)^{-1} x_i e_i. This reduces to -e_i / n_c
        # only when treated and control covariate means coincide.
        x_bar_treated = X_treated_with_intercept.mean(axis=0)
        gram = X_control_with_intercept.T @ X_control_with_intercept
        influence_direction = np.linalg.pinv(gram) @ x_bar_treated
        inf_control = -(X_control_with_intercept @ influence_direction) * residuals

        var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
        se = (
            np.sqrt(var_t / n_t + np.sum(inf_control ** 2))
            if (n_t > 0 and n_c > 0) else 0.0
        )

        inf_func = np.concatenate([inf_treated, inf_control])
    else:
        # Simple difference in means (no covariates)
        att = np.mean(treated_change) - np.mean(control_change)

        var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
        var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

        se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0

        inf_treated = treated_change - np.mean(treated_change)
        inf_control = control_change - np.mean(control_change)
        inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])

    return att, se, inf_func
|
|
1284
|
+
|
|
1285
|
+
def _ipw_estimation(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    n_treated: int,
    n_control: int,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using inverse probability weighting.

    With covariates:
        1. Estimate propensity score P(D=1|X) using logistic regression
        2. Reweight control units by p(X) / (1 - p(X)), normalized to sum
           to one, to match the treated covariate distribution
        3. ATT = mean(treated) - weighted_mean(control)

    Without covariates:
        Reduces to the simple difference in means.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes (post minus base period) of treated / control units.
    n_treated, n_control : int
        Group sizes counted by the caller. Kept for API compatibility; the
        estimator uses the lengths of the change arrays.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices whose rows align with the change arrays.

    Returns
    -------
    att : float
    se : float
        Standard error of the difference of the treated mean and the
        (weighted) control mean. Treated variance uses ddof=1. Propensity
        score estimation error is not propagated.
    inf_func : np.ndarray
        Influence function with treated units first, then control units.
    """
    n_t = len(treated_change)
    n_c = len(control_change)

    mean_treated = np.mean(treated_change)
    var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
    inf_treated = (treated_change - mean_treated) / n_t

    if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
        X_all = np.vstack([X_treated, X_control])
        D = np.concatenate([np.ones(n_t), np.zeros(n_c)])

        try:
            _, pscore = _logistic_regression(X_all, D)
        except (np.linalg.LinAlgError, ValueError):
            # Fallback to unconditional if logistic regression fails
            warnings.warn(
                "Propensity score estimation failed. "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=4,
            )
            pscore = np.full(len(D), n_t / (n_t + n_c))

        # Clip propensity scores to avoid extreme weights.
        pscore_control = np.clip(pscore[n_t:], 0.01, 0.99)

        # Odds weights p/(1-p) reweight controls toward the treated covariate
        # distribution; normalization makes the control term a weighted mean.
        weights_control = pscore_control / (1 - pscore_control)
        weights_control = weights_control / np.sum(weights_control)

        weighted_mean_control = np.sum(weights_control * control_change)
        att = mean_treated - weighted_mean_control

        inf_control = -weights_control * (control_change - weighted_mean_control)
        # Variance of a normalized weighted mean is sum(w_i^2 (y_i - mu_w)^2),
        # i.e. the sum of squared control influence values (not the weighted
        # variance of Y itself).
        var_control_mean = np.sum(inf_control ** 2)
    else:
        # Unconditional IPW is numerically the difference in means, so its
        # variance is that of the difference in means.
        mean_control = np.mean(control_change)
        att = mean_treated - mean_control

        var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
        var_control_mean = var_c / n_c if n_c > 0 else 0.0
        inf_control = -(control_change - mean_control) / n_c

    se = np.sqrt(var_t / n_t + var_control_mean) if (n_t > 0 and n_c > 0) else 0.0
    inf_func = np.concatenate([inf_treated, inf_control])

    return att, se, inf_func
|
|
1375
|
+
|
|
1376
|
+
def _doubly_robust(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using doubly robust estimation.

    With covariates:
        Combines outcome regression and IPW for double robustness.
        The estimator is consistent if either the outcome model OR
        the propensity model is correctly specified.

        ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
                 + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]

        where m(X) is the outcome model fitted on controls and
        w_i = p(X_i) / (1 - p(X_i)) are (unnormalized, clipped) IPW weights.

    Without covariates:
        Reduces to simple difference in means.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes (post minus base period) of treated / control units.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices whose rows align with the change arrays.

    Returns
    -------
    att : float
    se : float
    inf_func : np.ndarray
        Influence function with treated units first, then control units.
        Estimation error in the outcome and propensity models is not
        propagated.

    Warns
    -----
    UserWarning
        If the propensity score model fails and constant scores are used.
    """
    n_t = len(treated_change)
    n_c = len(control_change)

    if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
        # Step 1: outcome regression E[Delta Y | X, D=0] on control units.
        beta, _ = _linear_regression(X_control, control_change)

        X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
        X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
        m_treated = X_treated_with_intercept @ beta
        m_control = X_control_with_intercept @ beta

        # Step 2: propensity score estimation.
        X_all = np.vstack([X_treated, X_control])
        D = np.concatenate([np.ones(n_t), np.zeros(n_c)])

        try:
            _, pscore = _logistic_regression(X_all, D)
        except (np.linalg.LinAlgError, ValueError):
            # Same fallback (and warning) as the IPW estimator, so users know
            # the propensity model was not used.
            warnings.warn(
                "Propensity score estimation failed. "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=4,
            )
            pscore = np.full(len(D), n_t / (n_t + n_c))

        # Clip propensity scores, then form odds weights p / (1 - p).
        pscore_control = np.clip(pscore[n_t:], 0.01, 0.99)
        weights_control = pscore_control / (1 - pscore_control)

        # Step 3: doubly robust ATT (treated part plus control augmentation).
        att_treated_part = np.mean(treated_change - m_treated)
        augmentation = np.sum(weights_control * (m_control - control_change)) / n_t
        att = att_treated_part + augmentation

        # Step 4: influence function of the ratio estimator
        # [sum_T (Y - m) + sum_C w (m - Y)] / n_t. Treated terms are centered
        # at att (not their own mean); the full vector sums to zero.
        psi_treated = (treated_change - m_treated - att) / n_t
        psi_control = (weights_control * (m_control - control_change)) / n_t

        var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2)
        se = np.sqrt(var_psi) if var_psi > 0 else 0.0

        inf_func = np.concatenate([psi_treated, psi_control])
    else:
        # Without covariates, DR simplifies to difference in means
        att = np.mean(treated_change) - np.mean(control_change)

        var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
        var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

        se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0

        inf_treated = (treated_change - np.mean(treated_change)) / n_t
        inf_control = (control_change - np.mean(control_change)) / n_c
        inf_func = np.concatenate([inf_treated, -inf_control])

    return att, se, inf_func
|
|
1467
|
+
|
|
1468
|
+
def _aggregate_simple(
    self,
    group_time_effects: Dict,
    influence_func_info: Dict,
    df: pd.DataFrame,
    unit: str,
    precomputed: Optional[PrecomputedData] = None,
) -> Tuple[float, float]:
    """
    Average all estimated ATT(g,t) cells into one overall ATT.

    Each cell is weighted by its number of treated units, normalized to
    sum to one. The standard error is delegated to
    ``_compute_aggregated_se_with_wif``, which aggregates unit-level
    influence functions across cells (capturing covariance from shared
    control units). It also applies the weight influence function (wif)
    adjustment, described as following R's `did` package, for uncertainty
    in the group-size weights.

    NOTE(review): every stored cell is averaged. ``fit`` stores cells for
    t >= g - anticipation, so with anticipation > 0 the pre-treatment
    anticipation periods enter the overall ATT. Confirm this is intended.

    Returns
    -------
    tuple of float
        ``(overall_att, overall_se)``.
    """
    gt_pairs = list(group_time_effects)
    effects = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs])
    weights = np.array(
        [group_time_effects[gt]['n_treated'] for gt in gt_pairs], dtype=float
    )
    groups_for_gt = np.array([g for g, _ in gt_pairs])

    weights_norm = weights / weights.sum()
    overall_att = np.sum(weights_norm * effects)

    overall_se = self._compute_aggregated_se_with_wif(
        gt_pairs,
        weights_norm,
        effects,
        groups_for_gt,
        influence_func_info,
        df,
        unit,
        precomputed,
    )

    return overall_att, overall_se
|
|
1516
|
+
|
|
1517
|
+
def _compute_aggregated_se(
|
|
1518
|
+
self,
|
|
1519
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
1520
|
+
weights: np.ndarray,
|
|
1521
|
+
influence_func_info: Dict,
|
|
1522
|
+
) -> float:
|
|
1523
|
+
"""
|
|
1524
|
+
Compute standard error using influence function aggregation.
|
|
1525
|
+
|
|
1526
|
+
This properly accounts for covariances across (g,t) pairs by
|
|
1527
|
+
aggregating unit-level influence functions:
|
|
1528
|
+
|
|
1529
|
+
ψ_i(overall) = Σ_{(g,t)} w_(g,t) × ψ_i(g,t)
|
|
1530
|
+
Var(overall) = (1/n) Σ_i [ψ_i]²
|
|
1531
|
+
|
|
1532
|
+
This matches R's `did` package analytical SE formula.
|
|
1533
|
+
"""
|
|
1534
|
+
if not influence_func_info:
|
|
1535
|
+
# Fallback if no influence functions available
|
|
1536
|
+
return 0.0
|
|
1537
|
+
|
|
1538
|
+
# Build unit index mapping from all (g,t) pairs
|
|
1539
|
+
all_units = set()
|
|
1540
|
+
for (g, t) in gt_pairs:
|
|
1541
|
+
if (g, t) in influence_func_info:
|
|
1542
|
+
info = influence_func_info[(g, t)]
|
|
1543
|
+
all_units.update(info['treated_units'])
|
|
1544
|
+
all_units.update(info['control_units'])
|
|
1545
|
+
|
|
1546
|
+
if not all_units:
|
|
1547
|
+
return 0.0
|
|
1548
|
+
|
|
1549
|
+
all_units = sorted(all_units)
|
|
1550
|
+
n_units = len(all_units)
|
|
1551
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
1552
|
+
|
|
1553
|
+
# Aggregate influence functions across (g,t) pairs
|
|
1554
|
+
psi_overall = np.zeros(n_units)
|
|
1555
|
+
|
|
1556
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
1557
|
+
if (g, t) not in influence_func_info:
|
|
1558
|
+
continue
|
|
1559
|
+
|
|
1560
|
+
info = influence_func_info[(g, t)]
|
|
1561
|
+
w = weights[j]
|
|
1562
|
+
|
|
1563
|
+
# Treated unit contributions
|
|
1564
|
+
for i, unit_id in enumerate(info['treated_units']):
|
|
1565
|
+
idx = unit_to_idx[unit_id]
|
|
1566
|
+
psi_overall[idx] += w * info['treated_inf'][i]
|
|
1567
|
+
|
|
1568
|
+
# Control unit contributions
|
|
1569
|
+
for i, unit_id in enumerate(info['control_units']):
|
|
1570
|
+
idx = unit_to_idx[unit_id]
|
|
1571
|
+
psi_overall[idx] += w * info['control_inf'][i]
|
|
1572
|
+
|
|
1573
|
+
# Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ²
|
|
1574
|
+
variance = np.sum(psi_overall ** 2)
|
|
1575
|
+
return np.sqrt(variance)
|
|
1576
|
+
|
|
1577
|
+
def _compute_aggregated_se_with_wif(
|
|
1578
|
+
self,
|
|
1579
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
1580
|
+
weights: np.ndarray,
|
|
1581
|
+
effects: np.ndarray,
|
|
1582
|
+
groups_for_gt: np.ndarray,
|
|
1583
|
+
influence_func_info: Dict,
|
|
1584
|
+
df: pd.DataFrame,
|
|
1585
|
+
unit: str,
|
|
1586
|
+
precomputed: Optional[PrecomputedData] = None,
|
|
1587
|
+
) -> float:
|
|
1588
|
+
"""
|
|
1589
|
+
Compute SE with weight influence function (wif) adjustment.
|
|
1590
|
+
|
|
1591
|
+
This matches R's `did` package approach for "simple" aggregation,
|
|
1592
|
+
which accounts for uncertainty in estimating group-size weights.
|
|
1593
|
+
|
|
1594
|
+
The wif adjustment adds variance due to the fact that aggregation
|
|
1595
|
+
weights w_g = n_g / N depend on estimated group sizes.
|
|
1596
|
+
|
|
1597
|
+
Formula (matching R's did::aggte):
|
|
1598
|
+
agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
|
|
1599
|
+
se = sqrt(mean(agg_inf^2) / n)
|
|
1600
|
+
|
|
1601
|
+
where:
|
|
1602
|
+
- k indexes "keepers" (post-treatment (g,t) pairs)
|
|
1603
|
+
- w_k = pg[k] / sum(pg[keepers]) where pg = n_g / n_all
|
|
1604
|
+
- wif captures how unit i influences the weight estimation
|
|
1605
|
+
"""
|
|
1606
|
+
if not influence_func_info:
|
|
1607
|
+
return 0.0
|
|
1608
|
+
|
|
1609
|
+
# Build unit index mapping
|
|
1610
|
+
all_units_set: Set[Any] = set()
|
|
1611
|
+
for (g, t) in gt_pairs:
|
|
1612
|
+
if (g, t) in influence_func_info:
|
|
1613
|
+
info = influence_func_info[(g, t)]
|
|
1614
|
+
all_units_set.update(info['treated_units'])
|
|
1615
|
+
all_units_set.update(info['control_units'])
|
|
1616
|
+
|
|
1617
|
+
if not all_units_set:
|
|
1618
|
+
return 0.0
|
|
1619
|
+
|
|
1620
|
+
all_units = sorted(all_units_set)
|
|
1621
|
+
n_units = len(all_units)
|
|
1622
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
1623
|
+
|
|
1624
|
+
# Get unique groups and their information
|
|
1625
|
+
unique_groups = sorted(set(groups_for_gt))
|
|
1626
|
+
unique_groups_set = set(unique_groups)
|
|
1627
|
+
group_to_idx = {g: i for i, g in enumerate(unique_groups)}
|
|
1628
|
+
|
|
1629
|
+
# Compute group-level probabilities matching R's formula:
|
|
1630
|
+
# pg[g] = n_g / n_all (fraction of ALL units in group g)
|
|
1631
|
+
# This differs from our old formula which used n_g / total_treated
|
|
1632
|
+
group_sizes = {}
|
|
1633
|
+
for g in unique_groups:
|
|
1634
|
+
treated_in_g = df[df['first_treat'] == g][unit].nunique()
|
|
1635
|
+
group_sizes[g] = treated_in_g
|
|
1636
|
+
|
|
1637
|
+
# pg indexed by group
|
|
1638
|
+
pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups])
|
|
1639
|
+
|
|
1640
|
+
# pg indexed by keeper (each (g,t) pair gets its group's pg)
|
|
1641
|
+
# This matches R's: pg <- pgg[match(group, originalglist)]
|
|
1642
|
+
pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
|
|
1643
|
+
sum_pg_keepers = np.sum(pg_keepers)
|
|
1644
|
+
|
|
1645
|
+
# Guard against zero weights (no keepers = no variance)
|
|
1646
|
+
if sum_pg_keepers == 0:
|
|
1647
|
+
return 0.0
|
|
1648
|
+
|
|
1649
|
+
# Standard aggregated influence (without wif)
|
|
1650
|
+
psi_standard = np.zeros(n_units)
|
|
1651
|
+
|
|
1652
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
1653
|
+
if (g, t) not in influence_func_info:
|
|
1654
|
+
continue
|
|
1655
|
+
|
|
1656
|
+
info = influence_func_info[(g, t)]
|
|
1657
|
+
w = weights[j]
|
|
1658
|
+
|
|
1659
|
+
# Vectorized influence function aggregation for treated units
|
|
1660
|
+
treated_indices = np.array([unit_to_idx[uid] for uid in info['treated_units']])
|
|
1661
|
+
if len(treated_indices) > 0:
|
|
1662
|
+
np.add.at(psi_standard, treated_indices, w * info['treated_inf'])
|
|
1663
|
+
|
|
1664
|
+
# Vectorized influence function aggregation for control units
|
|
1665
|
+
control_indices = np.array([unit_to_idx[uid] for uid in info['control_units']])
|
|
1666
|
+
if len(control_indices) > 0:
|
|
1667
|
+
np.add.at(psi_standard, control_indices, w * info['control_inf'])
|
|
1668
|
+
|
|
1669
|
+
# Build unit-group array using precomputed data if available
|
|
1670
|
+
# This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
|
|
1671
|
+
if precomputed is not None:
|
|
1672
|
+
# Use precomputed cohort mapping
|
|
1673
|
+
precomputed_units = precomputed['all_units']
|
|
1674
|
+
precomputed_cohorts = precomputed['unit_cohorts']
|
|
1675
|
+
precomputed_unit_to_idx = precomputed['unit_to_idx']
|
|
1676
|
+
|
|
1677
|
+
# Build unit_groups_array for the units in this SE computation
|
|
1678
|
+
# A value of -1 indicates never-treated or other (not in unique_groups)
|
|
1679
|
+
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
|
|
1680
|
+
for i, uid in enumerate(all_units):
|
|
1681
|
+
if uid in precomputed_unit_to_idx:
|
|
1682
|
+
cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
|
|
1683
|
+
if cohort in unique_groups_set:
|
|
1684
|
+
unit_groups_array[i] = cohort
|
|
1685
|
+
else:
|
|
1686
|
+
# Fallback: build from DataFrame (slow path for backward compatibility)
|
|
1687
|
+
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
|
|
1688
|
+
for i, uid in enumerate(all_units):
|
|
1689
|
+
unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
|
|
1690
|
+
if unit_first_treat in unique_groups_set:
|
|
1691
|
+
unit_groups_array[i] = unit_first_treat
|
|
1692
|
+
|
|
1693
|
+
# Vectorized WIF computation
|
|
1694
|
+
# R's wif formula:
|
|
1695
|
+
# if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
|
|
1696
|
+
# if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
|
|
1697
|
+
# wif[i,k] = if1[i,k] - if2[i,k]
|
|
1698
|
+
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
|
|
1699
|
+
|
|
1700
|
+
# Build indicator matrix: (n_units, n_keepers)
|
|
1701
|
+
# indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
|
|
1702
|
+
groups_for_gt_array = np.array(groups_for_gt)
|
|
1703
|
+
indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)
|
|
1704
|
+
|
|
1705
|
+
# Vectorized indicator_sum: sum over keepers
|
|
1706
|
+
# indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
|
|
1707
|
+
indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
|
|
1708
|
+
|
|
1709
|
+
# Vectorized wif matrix computation
|
|
1710
|
+
# if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
|
|
1711
|
+
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
|
|
1712
|
+
# if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
|
|
1713
|
+
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
|
|
1714
|
+
wif_matrix = if1_matrix - if2_matrix
|
|
1715
|
+
|
|
1716
|
+
# Single matrix-vector multiply for all contributions
|
|
1717
|
+
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
|
|
1718
|
+
wif_contrib = wif_matrix @ effects
|
|
1719
|
+
|
|
1720
|
+
# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
|
|
1721
|
+
psi_wif = wif_contrib / n_units
|
|
1722
|
+
|
|
1723
|
+
# Combine standard and wif terms
|
|
1724
|
+
psi_total = psi_standard + psi_wif
|
|
1725
|
+
|
|
1726
|
+
# Compute variance and SE
|
|
1727
|
+
# R's formula: sqrt(mean(IF^2) / n) = sqrt(sum(IF^2) / n^2)
|
|
1728
|
+
variance = np.sum(psi_total ** 2)
|
|
1729
|
+
return np.sqrt(variance)
|
|
1730
|
+
|
|
1731
|
+
def _aggregate_event_study(
    self,
    group_time_effects: Dict,
    influence_func_info: Dict,
    groups: List[Any],
    time_periods: List[Any],
    balance_e: Optional[int] = None,
) -> Dict[int, Dict[str, Any]]:
    """
    Average ATT(g,t) estimates by relative time e = t - g (event study).

    Within each event time, estimates are weighted by their ``n_treated``
    count. Standard errors come from ``_compute_aggregated_se``, which
    combines influence functions so covariance across (g,t) pairs is kept.

    Parameters
    ----------
    group_time_effects : dict
        Maps (g, t) to a dict holding at least ``effect`` and ``n_treated``.
    influence_func_info : dict
        Influence-function data keyed by (g, t).
    groups, time_periods : list
        Accepted for interface compatibility; not used in this method.
    balance_e : int, optional
        When given, only cohorts that have an estimate at event time
        ``balance_e`` contribute to any event time.

    Returns
    -------
    dict
        Event time -> {'effect', 'se', 't_stat', 'p_value', 'conf_int',
        'n_groups'}, with keys in ascending order.
    """
    if balance_e is None:
        eligible_cohorts = None
    else:
        eligible_cohorts = {
            cohort for (cohort, period) in group_time_effects
            if period - cohort == balance_e
        }

    by_event_time: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {}
    for (cohort, period), data in group_time_effects.items():
        if eligible_cohorts is not None and cohort not in eligible_cohorts:
            continue
        by_event_time.setdefault(period - cohort, []).append(
            ((cohort, period), data['effect'], data['n_treated'])
        )

    summary_by_e = {}
    for rel_time in sorted(by_event_time):
        entries = by_event_time[rel_time]
        pair_keys = [entry[0] for entry in entries]
        estimates = np.array([entry[1] for entry in entries])
        sizes = np.array([entry[2] for entry in entries], dtype=float)

        share = sizes / np.sum(sizes)
        estimate = np.sum(share * estimates)
        std_err = self._compute_aggregated_se(pair_keys, share, influence_func_info)

        # A zero SE yields a t-statistic of 0 rather than dividing by zero.
        t_value = estimate / std_err if std_err > 0 else 0.0

        summary_by_e[rel_time] = {
            'effect': estimate,
            'se': std_err,
            't_stat': t_value,
            'p_value': compute_p_value(t_value),
            'conf_int': compute_confidence_interval(estimate, std_err, self.alpha),
            'n_groups': len(entries),
        }

    return summary_by_e
|
|
1814
|
+
|
|
1815
|
+
def _aggregate_by_group(
    self,
    group_time_effects: Dict,
    influence_func_info: Dict,
    groups: List[Any],
) -> Dict[Any, Dict[str, Any]]:
    """
    Average each treatment cohort's post-treatment ATT(g,t) estimates.

    For cohort g, all pairs with ``t >= g`` are averaged with equal
    weights. Standard errors come from ``_compute_aggregated_se``, which
    keeps the covariance between periods of the same cohort.

    Returns
    -------
    dict
        Cohort -> {'effect', 'se', 't_stat', 'p_value', 'conf_int',
        'n_periods'}. Cohorts with no post-treatment estimate are omitted.
    """
    per_cohort = {}

    for g in groups:
        pair_keys = []
        estimate_values = []
        for (cohort, period), data in group_time_effects.items():
            if cohort == g and period >= g:
                pair_keys.append((g, period))
                estimate_values.append(data['effect'])

        if not pair_keys:
            continue

        estimates = np.array(estimate_values)
        n_periods = len(estimates)
        share = np.ones(n_periods) / n_periods

        estimate = np.sum(share * estimates)
        std_err = self._compute_aggregated_se(pair_keys, share, influence_func_info)

        # A zero SE yields a t-statistic of 0 rather than dividing by zero.
        t_value = estimate / std_err if std_err > 0 else 0.0

        per_cohort[g] = {
            'effect': estimate,
            'se': std_err,
            't_stat': t_value,
            'p_value': compute_p_value(t_value),
            'conf_int': compute_confidence_interval(estimate, std_err, self.alpha),
            'n_periods': len(pair_keys),
        }

    return per_cohort
|
|
1870
|
+
|
|
1871
|
+
def _run_multiplier_bootstrap(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    influence_func_info: Dict[Tuple[Any, Any], Dict[str, Any]],
    aggregate: Optional[str],
    balance_e: Optional[int],
    treatment_groups: List[Any],
    time_periods: List[Any],
) -> CSBootstrapResults:
    """
    Run multiplier bootstrap for inference on all parameters.

    This implements the multiplier bootstrap procedure from Callaway & Sant'Anna (2021).
    Each bootstrap draw perturbs the unit-level influence function
    contributions with random unit-level weights. It adds the perturbation to
    each original ATT(g,t), then re-aggregates with the same (fixed)
    aggregation weights used for the point estimates.

    Parameters
    ----------
    group_time_effects : dict
        Dictionary of ATT(g,t) effects with analytical SEs.
    influence_func_info : dict
        Dictionary mapping (g,t) to influence function information. Every key
        of ``group_time_effects`` must be present (looked up directly).
    aggregate : str, optional
        Type of aggregation requested ("event_study", "group" or "all"
        trigger the corresponding bootstrap aggregations).
    balance_e : int, optional
        Balance parameter for event study.
    treatment_groups : list
        List of treatment cohorts.
    time_periods : list
        List of time periods (not used in this method).

    Returns
    -------
    CSBootstrapResults
        Bootstrap inference results.
    """
    # Warn about low bootstrap iterations
    if self.n_bootstrap < 50:
        warnings.warn(
            f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
            "for reliable inference. Percentile confidence intervals and p-values "
            "may be unreliable with few iterations.",
            UserWarning,
            stacklevel=3,
        )

    rng = np.random.default_rng(self.seed)

    # Collect all unique units across all (g,t) combinations
    all_units = set()
    for (g, t), info in influence_func_info.items():
        all_units.update(info['treated_units'])
        all_units.update(info['control_units'])
    all_units = sorted(all_units)
    n_units = len(all_units)
    unit_to_idx = {u: i for i, u in enumerate(all_units)}

    # Get list of (g,t) pairs
    gt_pairs = list(group_time_effects.keys())
    n_gt = len(gt_pairs)

    # Aggregation weights for overall ATT (proportional to n_treated)
    overall_weights = np.array([
        group_time_effects[gt]['n_treated'] for gt in gt_pairs
    ], dtype=float)
    overall_weights = overall_weights / np.sum(overall_weights)

    # Original point estimates
    original_atts = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs])
    original_overall = np.sum(overall_weights * original_atts)

    # Prepare event study and group aggregation info if needed
    event_study_info = None
    group_agg_info = None

    if aggregate in ["event_study", "all"]:
        event_study_info = self._prepare_event_study_aggregation(
            gt_pairs, group_time_effects, balance_e
        )

    if aggregate in ["group", "all"]:
        group_agg_info = self._prepare_group_aggregation(
            gt_pairs, group_time_effects, treatment_groups
        )

    # Pre-compute unit index arrays for each (g,t) pair (done once, not per iteration).
    # dtype=np.intp is required: an empty unit list would otherwise produce a
    # float64 array, and indexing the weight matrix with it raises IndexError.
    gt_treated_indices = []
    gt_control_indices = []
    gt_treated_inf = []
    gt_control_inf = []

    for gt in gt_pairs:
        info = influence_func_info[gt]
        treated_idx = np.array(
            [unit_to_idx[u] for u in info['treated_units']], dtype=np.intp
        )
        control_idx = np.array(
            [unit_to_idx[u] for u in info['control_units']], dtype=np.intp
        )
        gt_treated_indices.append(treated_idx)
        gt_control_indices.append(control_idx)
        gt_treated_inf.append(np.asarray(info['treated_inf']))
        gt_control_inf.append(np.asarray(info['control_inf']))

    # Generate ALL bootstrap weights upfront: shape (n_bootstrap, n_units)
    all_bootstrap_weights = _generate_bootstrap_weights_batch(
        self.n_bootstrap, n_units, self.bootstrap_weight_type, rng
    )

    # Bootstrap ATT(g,t) for every pair: shape (n_bootstrap, n_gt)
    bootstrap_atts_gt = np.zeros((self.n_bootstrap, n_gt))

    for j in range(n_gt):
        treated_idx = gt_treated_indices[j]
        control_idx = gt_control_indices[j]
        treated_inf = gt_treated_inf[j]
        control_inf = gt_control_inf[j]

        # Shape: (n_bootstrap, n_treated) and (n_bootstrap, n_control)
        treated_weights = all_bootstrap_weights[:, treated_idx]
        control_weights = all_bootstrap_weights[:, control_idx]

        # Perturbation per draw, shape (n_bootstrap,)
        perturbations = (
            treated_weights @ treated_inf +
            control_weights @ control_inf
        )

        bootstrap_atts_gt[:, j] = original_atts[j] + perturbations

    # Overall ATT per draw, using the fixed point-estimate weights
    bootstrap_overall = bootstrap_atts_gt @ overall_weights

    # Event study aggregation per draw
    if event_study_info is not None:
        rel_periods = sorted(event_study_info.keys())
        bootstrap_event_study = {}
        for e in rel_periods:
            agg_info = event_study_info[e]
            gt_indices = agg_info['gt_indices']
            weights = agg_info['weights']
            bootstrap_event_study[e] = bootstrap_atts_gt[:, gt_indices] @ weights
    else:
        bootstrap_event_study = None

    # Group aggregation per draw
    if group_agg_info is not None:
        groups = sorted(group_agg_info.keys())
        bootstrap_group = {}
        for g in groups:
            agg_info = group_agg_info[g]
            gt_indices = agg_info['gt_indices']
            weights = agg_info['weights']
            bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights
    else:
        bootstrap_group = None

    # Bootstrap statistics for ATT(g,t)
    gt_ses = {}
    gt_cis = {}
    gt_p_values = {}

    for j, gt in enumerate(gt_pairs):
        se, ci, p_value = self._compute_effect_bootstrap_stats(
            original_atts[j], bootstrap_atts_gt[:, j]
        )
        gt_ses[gt] = se
        gt_cis[gt] = ci
        gt_p_values[gt] = p_value

    # Bootstrap statistics for overall ATT
    overall_se, overall_ci, overall_p_value = self._compute_effect_bootstrap_stats(
        original_overall, bootstrap_overall
    )

    # Bootstrap statistics for event study effects
    event_study_ses = None
    event_study_cis = None
    event_study_p_values = None

    if bootstrap_event_study is not None and event_study_info is not None:
        event_study_ses = {}
        event_study_cis = {}
        event_study_p_values = {}

        for e in rel_periods:
            se, ci, p_value = self._compute_effect_bootstrap_stats(
                event_study_info[e]['effect'], bootstrap_event_study[e]
            )
            event_study_ses[e] = se
            event_study_cis[e] = ci
            event_study_p_values[e] = p_value

    # Bootstrap statistics for group effects
    group_effect_ses = None
    group_effect_cis = None
    group_effect_p_values = None

    if bootstrap_group is not None and group_agg_info is not None:
        group_effect_ses = {}
        group_effect_cis = {}
        group_effect_p_values = {}

        for g in groups:
            se, ci, p_value = self._compute_effect_bootstrap_stats(
                group_agg_info[g]['effect'], bootstrap_group[g]
            )
            group_effect_ses[g] = se
            group_effect_cis[g] = ci
            group_effect_p_values[g] = p_value

    return CSBootstrapResults(
        n_bootstrap=self.n_bootstrap,
        weight_type=self.bootstrap_weight_type,
        alpha=self.alpha,
        overall_att_se=overall_se,
        overall_att_ci=overall_ci,
        overall_att_p_value=overall_p_value,
        group_time_ses=gt_ses,
        group_time_cis=gt_cis,
        group_time_p_values=gt_p_values,
        event_study_ses=event_study_ses,
        event_study_cis=event_study_cis,
        event_study_p_values=event_study_p_values,
        group_effect_ses=group_effect_ses,
        group_effect_cis=group_effect_cis,
        group_effect_p_values=group_effect_p_values,
        bootstrap_distribution=bootstrap_overall,
    )
|
|
2102
|
+
|
|
2103
|
+
def _prepare_event_study_aggregation(
|
|
2104
|
+
self,
|
|
2105
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
2106
|
+
group_time_effects: Dict,
|
|
2107
|
+
balance_e: Optional[int],
|
|
2108
|
+
) -> Dict[int, Dict[str, Any]]:
|
|
2109
|
+
"""Prepare aggregation info for event study bootstrap."""
|
|
2110
|
+
# Organize by relative time
|
|
2111
|
+
effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}
|
|
2112
|
+
|
|
2113
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
2114
|
+
e = t - g
|
|
2115
|
+
if e not in effects_by_e:
|
|
2116
|
+
effects_by_e[e] = []
|
|
2117
|
+
effects_by_e[e].append((
|
|
2118
|
+
j, # index in gt_pairs
|
|
2119
|
+
group_time_effects[(g, t)]['effect'],
|
|
2120
|
+
group_time_effects[(g, t)]['n_treated']
|
|
2121
|
+
))
|
|
2122
|
+
|
|
2123
|
+
# Balance if requested
|
|
2124
|
+
if balance_e is not None:
|
|
2125
|
+
groups_at_e = set()
|
|
2126
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
2127
|
+
if t - g == balance_e:
|
|
2128
|
+
groups_at_e.add(g)
|
|
2129
|
+
|
|
2130
|
+
balanced_effects: Dict[int, List[Tuple[int, float, float]]] = {}
|
|
2131
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
2132
|
+
if g in groups_at_e:
|
|
2133
|
+
e = t - g
|
|
2134
|
+
if e not in balanced_effects:
|
|
2135
|
+
balanced_effects[e] = []
|
|
2136
|
+
balanced_effects[e].append((
|
|
2137
|
+
j,
|
|
2138
|
+
group_time_effects[(g, t)]['effect'],
|
|
2139
|
+
group_time_effects[(g, t)]['n_treated']
|
|
2140
|
+
))
|
|
2141
|
+
effects_by_e = balanced_effects
|
|
2142
|
+
|
|
2143
|
+
# Compute aggregation weights
|
|
2144
|
+
result = {}
|
|
2145
|
+
for e, effect_list in effects_by_e.items():
|
|
2146
|
+
indices = np.array([x[0] for x in effect_list])
|
|
2147
|
+
effects = np.array([x[1] for x in effect_list])
|
|
2148
|
+
n_treated = np.array([x[2] for x in effect_list], dtype=float)
|
|
2149
|
+
|
|
2150
|
+
weights = n_treated / np.sum(n_treated)
|
|
2151
|
+
agg_effect = np.sum(weights * effects)
|
|
2152
|
+
|
|
2153
|
+
result[e] = {
|
|
2154
|
+
'gt_indices': indices,
|
|
2155
|
+
'weights': weights,
|
|
2156
|
+
'effect': agg_effect,
|
|
2157
|
+
}
|
|
2158
|
+
|
|
2159
|
+
return result
|
|
2160
|
+
|
|
2161
|
+
def _prepare_group_aggregation(
|
|
2162
|
+
self,
|
|
2163
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
2164
|
+
group_time_effects: Dict,
|
|
2165
|
+
treatment_groups: List[Any],
|
|
2166
|
+
) -> Dict[Any, Dict[str, Any]]:
|
|
2167
|
+
"""Prepare aggregation info for group-level bootstrap."""
|
|
2168
|
+
result = {}
|
|
2169
|
+
|
|
2170
|
+
for g in treatment_groups:
|
|
2171
|
+
# Get all effects for this group (post-treatment only: t >= g)
|
|
2172
|
+
group_data = []
|
|
2173
|
+
for j, (gg, t) in enumerate(gt_pairs):
|
|
2174
|
+
if gg == g and t >= g:
|
|
2175
|
+
group_data.append((
|
|
2176
|
+
j,
|
|
2177
|
+
group_time_effects[(gg, t)]['effect'],
|
|
2178
|
+
))
|
|
2179
|
+
|
|
2180
|
+
if not group_data:
|
|
2181
|
+
continue
|
|
2182
|
+
|
|
2183
|
+
indices = np.array([x[0] for x in group_data])
|
|
2184
|
+
effects = np.array([x[1] for x in group_data])
|
|
2185
|
+
|
|
2186
|
+
# Equal weights across time periods
|
|
2187
|
+
weights = np.ones(len(effects)) / len(effects)
|
|
2188
|
+
agg_effect = np.sum(weights * effects)
|
|
2189
|
+
|
|
2190
|
+
result[g] = {
|
|
2191
|
+
'gt_indices': indices,
|
|
2192
|
+
'weights': weights,
|
|
2193
|
+
'effect': agg_effect,
|
|
2194
|
+
}
|
|
2195
|
+
|
|
2196
|
+
return result
|
|
2197
|
+
|
|
2198
|
+
def _compute_percentile_ci(
|
|
2199
|
+
self,
|
|
2200
|
+
boot_dist: np.ndarray,
|
|
2201
|
+
alpha: float,
|
|
2202
|
+
) -> Tuple[float, float]:
|
|
2203
|
+
"""Compute percentile confidence interval from bootstrap distribution."""
|
|
2204
|
+
lower = float(np.percentile(boot_dist, alpha / 2 * 100))
|
|
2205
|
+
upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
|
|
2206
|
+
return (lower, upper)
|
|
2207
|
+
|
|
2208
|
+
def _compute_bootstrap_pvalue(
|
|
2209
|
+
self,
|
|
2210
|
+
original_effect: float,
|
|
2211
|
+
boot_dist: np.ndarray,
|
|
2212
|
+
) -> float:
|
|
2213
|
+
"""
|
|
2214
|
+
Compute two-sided bootstrap p-value.
|
|
2215
|
+
|
|
2216
|
+
Uses the percentile method: p-value is the proportion of bootstrap
|
|
2217
|
+
estimates on the opposite side of zero from the original estimate,
|
|
2218
|
+
doubled for two-sided test.
|
|
2219
|
+
"""
|
|
2220
|
+
if original_effect >= 0:
|
|
2221
|
+
# Proportion of bootstrap estimates <= 0
|
|
2222
|
+
p_one_sided = np.mean(boot_dist <= 0)
|
|
2223
|
+
else:
|
|
2224
|
+
# Proportion of bootstrap estimates >= 0
|
|
2225
|
+
p_one_sided = np.mean(boot_dist >= 0)
|
|
2226
|
+
|
|
2227
|
+
# Two-sided p-value
|
|
2228
|
+
p_value = min(2 * p_one_sided, 1.0)
|
|
2229
|
+
|
|
2230
|
+
# Ensure minimum p-value
|
|
2231
|
+
p_value = max(p_value, 1 / (self.n_bootstrap + 1))
|
|
2232
|
+
|
|
2233
|
+
return float(p_value)
|
|
2234
|
+
|
|
2235
|
+
def _compute_effect_bootstrap_stats(
    self,
    original_effect: float,
    boot_dist: np.ndarray,
) -> Tuple[float, Tuple[float, float], float]:
    """
    Summarize the bootstrap distribution of one effect.

    Parameters
    ----------
    original_effect : float
        Point estimate the distribution is centred around.
    boot_dist : np.ndarray
        Bootstrap draws of the effect.

    Returns
    -------
    se : float
        Sample standard deviation of the draws (ddof=1).
    ci : Tuple[float, float]
        Percentile interval at level ``self.alpha``.
    p_value : float
        Two-sided bootstrap p-value.
    """
    return (
        float(np.std(boot_dist, ddof=1)),
        self._compute_percentile_ci(boot_dist, self.alpha),
        self._compute_bootstrap_pvalue(original_effect, boot_dist),
    )
|
|
2263
|
+
|
|
2264
|
+
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's constructor parameters (sklearn-compatible)."""
    # bootstrap_weight_type is a deprecated alias of bootstrap_weights,
    # still reported for backward compatibility.
    param_names = (
        "control_group",
        "anticipation",
        "estimation_method",
        "alpha",
        "cluster",
        "n_bootstrap",
        "bootstrap_weights",
        "bootstrap_weight_type",
        "seed",
    )
    return {name: getattr(self, name) for name in param_names}
|
|
2278
|
+
|
|
2279
|
+
def set_params(self, **params) -> "CallawaySantAnna":
    """
    Set estimator parameters (sklearn-compatible).

    Every key must name an existing attribute. Otherwise ``ValueError`` is
    raised before any attribute is modified, so a bad call leaves the
    estimator unchanged.

    ``bootstrap_weights`` and its deprecated alias ``bootstrap_weight_type``
    are kept in sync. The bootstrap reads ``bootstrap_weight_type``, so
    setting only the new name would otherwise have no effect. When both are
    supplied, ``bootstrap_weights`` takes precedence.

    Returns
    -------
    self
    """
    unknown = [key for key in params if not hasattr(self, key)]
    if unknown:
        raise ValueError(f"Unknown parameter: {unknown[0]}")

    for key, value in params.items():
        setattr(self, key, value)

    if "bootstrap_weights" in params:
        resolved = params["bootstrap_weights"]
    elif "bootstrap_weight_type" in params:
        resolved = params["bootstrap_weight_type"]
    else:
        return self

    for alias in ("bootstrap_weights", "bootstrap_weight_type"):
        if hasattr(self, alias):
            setattr(self, alias, resolved)
    return self
|
|
2287
|
+
|
|
2288
|
+
def summary(self) -> str:
    """
    Return the text summary of the fitted results.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet.
    """
    if self.is_fitted_:
        assert self.results_ is not None
        return self.results_.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
|
|
2294
|
+
|
|
2295
|
+
def print_summary(self) -> None:
    """Write the result of ``self.summary()`` to standard output."""
    text = self.summary()
    print(text)
|