diff-diff 2.3.2__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +254 -0
- diff_diff/_backend.py +112 -0
- diff_diff/_rust_backend.cp313-win_amd64.pyd +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1161 -0
- diff_diff/honest_did.py +1511 -0
- diff_diff/imputation.py +2480 -0
- diff_diff/linalg.py +1537 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1241 -0
- diff_diff/prep_dgp.py +777 -0
- diff_diff/pretrends.py +1104 -0
- diff_diff/results.py +794 -0
- diff_diff/staggered.py +1120 -0
- diff_diff/staggered_aggregation.py +492 -0
- diff_diff/staggered_bootstrap.py +753 -0
- diff_diff/staggered_results.py +296 -0
- diff_diff/sun_abraham.py +1227 -0
- diff_diff/synthetic_did.py +858 -0
- diff_diff/triple_diff.py +1322 -0
- diff_diff/trop.py +2904 -0
- diff_diff/twfe.py +428 -0
- diff_diff/utils.py +1845 -0
- diff_diff/visualization.py +1676 -0
- diff_diff-2.3.2.dist-info/METADATA +2646 -0
- diff_diff-2.3.2.dist-info/RECORD +30 -0
- diff_diff-2.3.2.dist-info/WHEEL +4 -0
- diff_diff-2.3.2.dist-info/sboms/diff_diff_rust.cyclonedx.json +5952 -0
|
@@ -0,0 +1,753 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Bootstrap inference for Callaway-Sant'Anna estimator.
|
|
3
|
+
|
|
4
|
+
This module provides bootstrap weight generation functions, the bootstrap
|
|
5
|
+
results container, and the mixin class with bootstrap inference methods.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
# Import Rust backend if available (from _backend to avoid circular imports)
|
|
15
|
+
from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# =============================================================================
|
|
22
|
+
# Bootstrap Weight Generators
|
|
23
|
+
# =============================================================================
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _generate_bootstrap_weights(
|
|
27
|
+
n_units: int,
|
|
28
|
+
weight_type: str,
|
|
29
|
+
rng: np.random.Generator,
|
|
30
|
+
) -> np.ndarray:
|
|
31
|
+
"""
|
|
32
|
+
Generate bootstrap weights for multiplier bootstrap.
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
n_units : int
|
|
37
|
+
Number of units (clusters) to generate weights for.
|
|
38
|
+
weight_type : str
|
|
39
|
+
Type of weights: "rademacher", "mammen", or "webb".
|
|
40
|
+
rng : np.random.Generator
|
|
41
|
+
Random number generator.
|
|
42
|
+
|
|
43
|
+
Returns
|
|
44
|
+
-------
|
|
45
|
+
np.ndarray
|
|
46
|
+
Array of bootstrap weights with shape (n_units,).
|
|
47
|
+
"""
|
|
48
|
+
if weight_type == "rademacher":
|
|
49
|
+
# Rademacher: +1 or -1 with equal probability
|
|
50
|
+
return rng.choice([-1.0, 1.0], size=n_units)
|
|
51
|
+
|
|
52
|
+
elif weight_type == "mammen":
|
|
53
|
+
# Mammen's two-point distribution
|
|
54
|
+
# E[v] = 0, E[v^2] = 1, E[v^3] = 1
|
|
55
|
+
sqrt5 = np.sqrt(5)
|
|
56
|
+
val1 = -(sqrt5 - 1) / 2 # ≈ -0.618
|
|
57
|
+
val2 = (sqrt5 + 1) / 2 # ≈ 1.618 (golden ratio)
|
|
58
|
+
p1 = (sqrt5 + 1) / (2 * sqrt5) # ≈ 0.724
|
|
59
|
+
return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1])
|
|
60
|
+
|
|
61
|
+
elif weight_type == "webb":
|
|
62
|
+
# Webb's 6-point distribution (recommended for few clusters)
|
|
63
|
+
# Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
|
|
64
|
+
# This matches R's did package: E[w]=0, Var(w)=1.0
|
|
65
|
+
values = np.array([
|
|
66
|
+
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
67
|
+
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
68
|
+
])
|
|
69
|
+
return rng.choice(values, size=n_units) # Equal probs (1/6 each)
|
|
70
|
+
|
|
71
|
+
else:
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
|
|
74
|
+
f"got '{weight_type}'"
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _generate_bootstrap_weights_batch(
    n_bootstrap: int,
    n_units: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Generate all bootstrap weights at once (vectorized).

    Dispatches to the Rust backend when it is available (parallel, fast RNG)
    and otherwise falls back to the pure-NumPy implementation.

    Parameters
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    n_units : int
        Number of units (clusters) to generate weights for.
    weight_type : str
        Type of weights: "rademacher", "mammen", or "webb".
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Array of bootstrap weights with shape (n_bootstrap, n_units).
    """
    use_rust = HAS_RUST_BACKEND and _rust_bootstrap_weights is not None
    if use_rust:
        # Derive a seed from the NumPy RNG so Rust-backed runs stay
        # reproducible for a given `rng` state.
        seed = rng.integers(0, 2**63 - 1)
        return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed)

    # Pure-NumPy fallback.
    return _generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _generate_bootstrap_weights_batch_numpy(
|
|
114
|
+
n_bootstrap: int,
|
|
115
|
+
n_units: int,
|
|
116
|
+
weight_type: str,
|
|
117
|
+
rng: np.random.Generator,
|
|
118
|
+
) -> np.ndarray:
|
|
119
|
+
"""
|
|
120
|
+
NumPy fallback implementation of _generate_bootstrap_weights_batch.
|
|
121
|
+
|
|
122
|
+
Generates multiplier bootstrap weights for wild cluster bootstrap.
|
|
123
|
+
All weight distributions satisfy E[w] = 0, E[w^2] = 1.
|
|
124
|
+
|
|
125
|
+
Parameters
|
|
126
|
+
----------
|
|
127
|
+
n_bootstrap : int
|
|
128
|
+
Number of bootstrap iterations.
|
|
129
|
+
n_units : int
|
|
130
|
+
Number of units (clusters) to generate weights for.
|
|
131
|
+
weight_type : str
|
|
132
|
+
Type of weights: "rademacher" (+-1), "mammen" (2-point),
|
|
133
|
+
or "webb" (6-point).
|
|
134
|
+
rng : np.random.Generator
|
|
135
|
+
Random number generator for reproducibility.
|
|
136
|
+
|
|
137
|
+
Returns
|
|
138
|
+
-------
|
|
139
|
+
np.ndarray
|
|
140
|
+
Array of bootstrap weights with shape (n_bootstrap, n_units).
|
|
141
|
+
"""
|
|
142
|
+
if weight_type == "rademacher":
|
|
143
|
+
# Rademacher: +1 or -1 with equal probability
|
|
144
|
+
return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))
|
|
145
|
+
|
|
146
|
+
elif weight_type == "mammen":
|
|
147
|
+
# Mammen's two-point distribution
|
|
148
|
+
sqrt5 = np.sqrt(5)
|
|
149
|
+
val1 = -(sqrt5 - 1) / 2
|
|
150
|
+
val2 = (sqrt5 + 1) / 2
|
|
151
|
+
p1 = (sqrt5 + 1) / (2 * sqrt5)
|
|
152
|
+
return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1])
|
|
153
|
+
|
|
154
|
+
elif weight_type == "webb":
|
|
155
|
+
# Webb's 6-point distribution
|
|
156
|
+
# Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
|
|
157
|
+
# This matches R's did package: E[w]=0, Var(w)=1.0
|
|
158
|
+
values = np.array([
|
|
159
|
+
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
160
|
+
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
161
|
+
])
|
|
162
|
+
return rng.choice(values, size=(n_bootstrap, n_units)) # Equal probs (1/6 each)
|
|
163
|
+
|
|
164
|
+
else:
|
|
165
|
+
raise ValueError(
|
|
166
|
+
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
|
|
167
|
+
f"got '{weight_type}'"
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# =============================================================================
|
|
172
|
+
# Bootstrap Results Container
|
|
173
|
+
# =============================================================================
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@dataclass
class CSBootstrapResults:
    """
    Results from Callaway-Sant'Anna multiplier bootstrap inference.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Type of bootstrap weights used.
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap confidence interval for overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for overall ATT.
    group_time_ses : Dict[Tuple[Any, Any], float]
        Bootstrap SEs for each ATT(g,t).
    group_time_cis : Dict[Tuple[Any, Any], Tuple[float, float]]
        Bootstrap CIs for each ATT(g,t).
    group_time_p_values : Dict[Tuple[Any, Any], float]
        Bootstrap p-values for each ATT(g,t).
    event_study_ses : Optional[Dict[int, float]]
        Bootstrap SEs for event study effects.
    event_study_cis : Optional[Dict[int, Tuple[float, float]]]
        Bootstrap CIs for event study effects.
    event_study_p_values : Optional[Dict[int, float]]
        Bootstrap p-values for event study effects.
    group_effect_ses : Optional[Dict[Any, float]]
        Bootstrap SEs for group effects.
    group_effect_cis : Optional[Dict[Any, Tuple[float, float]]]
        Bootstrap CIs for group effects.
    group_effect_p_values : Optional[Dict[Any, float]]
        Bootstrap p-values for group effects.
    bootstrap_distribution : Optional[np.ndarray]
        Full bootstrap distribution of overall ATT (if requested).
    """
    # Bootstrap configuration.
    n_bootstrap: int
    weight_type: str
    alpha: float
    # Overall ATT inference.
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    # Per-ATT(g,t) inference, keyed by (group, time).
    group_time_ses: Dict[Tuple[Any, Any], float]
    group_time_cis: Dict[Tuple[Any, Any], Tuple[float, float]]
    group_time_p_values: Dict[Tuple[Any, Any], float]
    # Optional aggregations: only populated when the corresponding
    # aggregation was requested; otherwise left as None.
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None
    group_effect_ses: Optional[Dict[Any, float]] = None
    group_effect_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_effect_p_values: Optional[Dict[Any, float]] = None
    # Raw bootstrap draws of the overall ATT; excluded from repr() to keep
    # printed results compact.
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# =============================================================================
|
|
235
|
+
# Bootstrap Mixin Class
|
|
236
|
+
# =============================================================================
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class CallawaySantAnnaBootstrapMixin:
    """
    Mixin class providing bootstrap inference methods for CallawaySantAnna.

    This class is not intended to be used standalone. It provides methods
    that are used by the main CallawaySantAnna class for multiplier bootstrap
    inference.
    """

    # Type hints for attributes accessed from the main class
    n_bootstrap: int
    bootstrap_weight_type: str
    alpha: float
    seed: Optional[int]
    anticipation: int

    def _run_multiplier_bootstrap(
        self,
        group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
        influence_func_info: Dict[Tuple[Any, Any], Dict[str, Any]],
        aggregate: Optional[str],
        balance_e: Optional[int],
        treatment_groups: List[Any],
        time_periods: List[Any],
    ) -> "CSBootstrapResults":
        """
        Run multiplier bootstrap for inference on all parameters.

        This implements the multiplier bootstrap procedure from Callaway & Sant'Anna (2021).
        The key idea is to perturb the influence function contributions with random
        weights at the cluster (unit) level, then recompute aggregations.

        Parameters
        ----------
        group_time_effects : dict
            Dictionary of ATT(g,t) effects with analytical SEs.
        influence_func_info : dict
            Dictionary mapping (g,t) to influence function information.
        aggregate : str, optional
            Type of aggregation requested.
        balance_e : int, optional
            Balance parameter for event study.
        treatment_groups : list
            List of treatment cohorts.
        time_periods : list
            List of time periods.

        Returns
        -------
        CSBootstrapResults
            Bootstrap inference results.
        """
        # Warn about low bootstrap iterations
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference. Percentile confidence intervals and p-values "
                "may be unreliable with few iterations.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)

        # Collect all unique units across all (g,t) combinations
        all_units = set()
        for (g, t), info in influence_func_info.items():
            all_units.update(info['treated_units'])
            all_units.update(info['control_units'])
        all_units = sorted(all_units)
        n_units = len(all_units)
        unit_to_idx = {u: i for i, u in enumerate(all_units)}

        # Get list of (g,t) pairs
        gt_pairs = list(group_time_effects.keys())
        n_gt = len(gt_pairs)

        # Identify post-treatment (g,t) pairs for overall ATT
        # Pre-treatment effects are for parallel trends assessment, not aggregated
        post_treatment_mask = np.array([
            t >= g - self.anticipation for (g, t) in gt_pairs
        ])
        post_treatment_indices = np.where(post_treatment_mask)[0]

        # Compute aggregation weights for overall ATT (post-treatment only)
        all_n_treated = np.array([
            group_time_effects[gt]['n_treated'] for gt in gt_pairs
        ], dtype=float)
        post_n_treated = all_n_treated[post_treatment_mask]

        # Flag to skip overall ATT aggregation when no post-treatment effects
        # But continue bootstrap for per-effect SEs (pre-treatment effects need bootstrap SEs too)
        skip_overall_aggregation = False
        if len(post_treatment_indices) == 0:
            warnings.warn(
                "No post-treatment effects for bootstrap aggregation. "
                "Overall ATT statistics will be NaN, but per-effect SEs will be computed.",
                UserWarning,
                stacklevel=2
            )
            skip_overall_aggregation = True
            overall_weights_post = np.array([])
        else:
            overall_weights_post = post_n_treated / np.sum(post_n_treated)

        # Original point estimates
        original_atts = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs])
        if skip_overall_aggregation:
            original_overall = np.nan
        else:
            original_overall = np.sum(overall_weights_post * original_atts[post_treatment_mask])

        # Prepare event study and group aggregation info if needed
        event_study_info = None
        group_agg_info = None

        if aggregate in ["event_study", "all"]:
            event_study_info = self._prepare_event_study_aggregation(
                gt_pairs, group_time_effects, balance_e
            )

        if aggregate in ["group", "all"]:
            group_agg_info = self._prepare_group_aggregation(
                gt_pairs, group_time_effects, treatment_groups
            )

        # Pre-compute unit index arrays for each (g,t) pair (done once, not per iteration)
        gt_treated_indices = []
        gt_control_indices = []
        gt_treated_inf = []
        gt_control_inf = []

        for j, gt in enumerate(gt_pairs):
            info = influence_func_info[gt]
            treated_idx = np.array([unit_to_idx[u] for u in info['treated_units']])
            control_idx = np.array([unit_to_idx[u] for u in info['control_units']])
            gt_treated_indices.append(treated_idx)
            gt_control_indices.append(control_idx)
            gt_treated_inf.append(np.asarray(info['treated_inf']))
            gt_control_inf.append(np.asarray(info['control_inf']))

        # Generate ALL bootstrap weights upfront: shape (n_bootstrap, n_units)
        # This is much faster than generating one at a time
        all_bootstrap_weights = _generate_bootstrap_weights_batch(
            self.n_bootstrap, n_units, self.bootstrap_weight_type, rng
        )

        # Vectorized bootstrap ATT(g,t) computation
        # Compute all bootstrap ATTs for all (g,t) pairs using matrix operations
        bootstrap_atts_gt = np.zeros((self.n_bootstrap, n_gt))

        for j in range(n_gt):
            treated_idx = gt_treated_indices[j]
            control_idx = gt_control_indices[j]
            treated_inf = gt_treated_inf[j]
            control_inf = gt_control_inf[j]

            # Extract weights for this (g,t)'s units across all bootstrap iterations
            # Shape: (n_bootstrap, n_treated) and (n_bootstrap, n_control)
            treated_weights = all_bootstrap_weights[:, treated_idx]
            control_weights = all_bootstrap_weights[:, control_idx]

            # Vectorized perturbation: matrix-vector multiply
            # Shape: (n_bootstrap,)
            # Suppress RuntimeWarnings for edge cases (small samples, extreme weights)
            with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                perturbations = (
                    treated_weights @ treated_inf +
                    control_weights @ control_inf
                )

            # Let non-finite values propagate - they will be handled at statistics computation
            bootstrap_atts_gt[:, j] = original_atts[j] + perturbations

        # Vectorized overall ATT: matrix-vector multiply (post-treatment only)
        # Shape: (n_bootstrap,)
        if skip_overall_aggregation:
            bootstrap_overall = np.full(self.n_bootstrap, np.nan)
        else:
            # Suppress RuntimeWarnings for edge cases - non-finite values handled at statistics computation
            with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                bootstrap_overall = bootstrap_atts_gt[:, post_treatment_indices] @ overall_weights_post

        # Vectorized event study aggregation
        # Non-finite values handled at statistics computation stage
        rel_periods: List[int] = []
        bootstrap_event_study: Optional[Dict[int, np.ndarray]] = None
        if event_study_info is not None:
            rel_periods = sorted(event_study_info.keys())
            bootstrap_event_study = {}
            for e in rel_periods:
                agg_info = event_study_info[e]
                gt_indices = agg_info['gt_indices']
                weights = agg_info['weights']
                # Vectorized: select columns and multiply by weights
                # Suppress RuntimeWarnings for edge cases
                with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                    bootstrap_event_study[e] = bootstrap_atts_gt[:, gt_indices] @ weights

        # Vectorized group aggregation
        # Non-finite values handled at statistics computation stage
        group_list: List[Any] = []
        bootstrap_group: Optional[Dict[Any, np.ndarray]] = None
        if group_agg_info is not None:
            group_list = sorted(group_agg_info.keys())
            bootstrap_group = {}
            for g in group_list:
                agg_info = group_agg_info[g]
                gt_indices = agg_info['gt_indices']
                weights = agg_info['weights']
                # Suppress RuntimeWarnings for edge cases
                with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                    bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights

        # Compute bootstrap statistics for ATT(g,t)
        gt_ses = {}
        gt_cis = {}
        gt_p_values = {}

        for j, gt in enumerate(gt_pairs):
            se, ci, p_value = self._compute_effect_bootstrap_stats(
                original_atts[j], bootstrap_atts_gt[:, j],
                context=f"ATT(g={gt[0]}, t={gt[1]})"
            )
            gt_ses[gt] = se
            gt_cis[gt] = ci
            gt_p_values[gt] = p_value

        # Compute bootstrap statistics for overall ATT
        if skip_overall_aggregation:
            overall_se = np.nan
            overall_ci = (np.nan, np.nan)
            overall_p_value = np.nan
        else:
            overall_se, overall_ci, overall_p_value = self._compute_effect_bootstrap_stats(
                original_overall, bootstrap_overall,
                context="overall ATT"
            )

        # Compute bootstrap statistics for event study effects
        event_study_ses = None
        event_study_cis = None
        event_study_p_values = None

        if bootstrap_event_study is not None and event_study_info is not None:
            event_study_ses = {}
            event_study_cis = {}
            event_study_p_values = {}

            for e in rel_periods:
                se, ci, p_value = self._compute_effect_bootstrap_stats(
                    event_study_info[e]['effect'], bootstrap_event_study[e],
                    context=f"event study (e={e})"
                )
                event_study_ses[e] = se
                event_study_cis[e] = ci
                event_study_p_values[e] = p_value

        # Compute bootstrap statistics for group effects
        group_effect_ses = None
        group_effect_cis = None
        group_effect_p_values = None

        if bootstrap_group is not None and group_agg_info is not None:
            group_effect_ses = {}
            group_effect_cis = {}
            group_effect_p_values = {}

            for g in group_list:
                se, ci, p_value = self._compute_effect_bootstrap_stats(
                    group_agg_info[g]['effect'], bootstrap_group[g],
                    context=f"group effect (g={g})"
                )
                group_effect_ses[g] = se
                group_effect_cis[g] = ci
                group_effect_p_values[g] = p_value

        return CSBootstrapResults(
            n_bootstrap=self.n_bootstrap,
            weight_type=self.bootstrap_weight_type,
            alpha=self.alpha,
            overall_att_se=overall_se,
            overall_att_ci=overall_ci,
            overall_att_p_value=overall_p_value,
            group_time_ses=gt_ses,
            group_time_cis=gt_cis,
            group_time_p_values=gt_p_values,
            event_study_ses=event_study_ses,
            event_study_cis=event_study_cis,
            event_study_p_values=event_study_p_values,
            group_effect_ses=group_effect_ses,
            group_effect_cis=group_effect_cis,
            group_effect_p_values=group_effect_p_values,
            bootstrap_distribution=bootstrap_overall,
        )

    def _prepare_event_study_aggregation(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        group_time_effects: Dict,
        balance_e: Optional[int],
    ) -> Dict[int, Dict[str, Any]]:
        """Prepare aggregation info for event study bootstrap.

        Groups (g,t) effects by relative time e = t - g and computes
        n_treated-proportional aggregation weights per relative period.
        When ``balance_e`` is given, only cohorts observed at that relative
        period are retained so every period averages the same cohorts.
        """
        # Organize by relative time
        effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}

        for j, (g, t) in enumerate(gt_pairs):
            e = t - g
            if e not in effects_by_e:
                effects_by_e[e] = []
            effects_by_e[e].append((
                j,  # index in gt_pairs
                group_time_effects[(g, t)]['effect'],
                group_time_effects[(g, t)]['n_treated']
            ))

        # Balance if requested
        if balance_e is not None:
            groups_at_e = set()
            for j, (g, t) in enumerate(gt_pairs):
                if t - g == balance_e:
                    groups_at_e.add(g)

            balanced_effects: Dict[int, List[Tuple[int, float, float]]] = {}
            for j, (g, t) in enumerate(gt_pairs):
                if g in groups_at_e:
                    e = t - g
                    if e not in balanced_effects:
                        balanced_effects[e] = []
                    balanced_effects[e].append((
                        j,
                        group_time_effects[(g, t)]['effect'],
                        group_time_effects[(g, t)]['n_treated']
                    ))
            effects_by_e = balanced_effects

        # Compute aggregation weights
        result = {}
        for e, effect_list in effects_by_e.items():
            indices = np.array([x[0] for x in effect_list])
            effects = np.array([x[1] for x in effect_list])
            n_treated = np.array([x[2] for x in effect_list], dtype=float)

            weights = n_treated / np.sum(n_treated)
            agg_effect = np.sum(weights * effects)

            result[e] = {
                'gt_indices': indices,
                'weights': weights,
                'effect': agg_effect,
            }

        return result

    def _prepare_group_aggregation(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        group_time_effects: Dict,
        treatment_groups: List[Any],
    ) -> Dict[Any, Dict[str, Any]]:
        """Prepare aggregation info for group-level bootstrap.

        For each cohort g, averages its post-treatment effects
        (t >= g - anticipation) with equal weights across time periods.
        Cohorts with no post-treatment effects are omitted from the result.
        """
        result = {}

        for g in treatment_groups:
            # Get all effects for this group (post-treatment only: t >= g - anticipation)
            group_data = []
            for j, (gg, t) in enumerate(gt_pairs):
                if gg == g and t >= g - self.anticipation:
                    group_data.append((
                        j,
                        group_time_effects[(gg, t)]['effect'],
                    ))

            if not group_data:
                continue

            indices = np.array([x[0] for x in group_data])
            effects = np.array([x[1] for x in group_data])

            # Equal weights across time periods
            weights = np.ones(len(effects)) / len(effects)
            agg_effect = np.sum(weights * effects)

            result[g] = {
                'gt_indices': indices,
                'weights': weights,
                'effect': agg_effect,
            }

        return result

    def _compute_percentile_ci(
        self,
        boot_dist: np.ndarray,
        alpha: float,
    ) -> Tuple[float, float]:
        """Compute percentile confidence interval from bootstrap distribution."""
        lower = float(np.percentile(boot_dist, alpha / 2 * 100))
        upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
        return (lower, upper)

    def _compute_bootstrap_pvalue(
        self,
        original_effect: float,
        boot_dist: np.ndarray,
        n_valid: Optional[int] = None,
    ) -> float:
        """
        Compute two-sided bootstrap p-value.

        Uses the percentile method: p-value is the proportion of bootstrap
        estimates on the opposite side of zero from the original estimate,
        doubled for two-sided test.

        Parameters
        ----------
        original_effect : float
            Original point estimate.
        boot_dist : np.ndarray
            Bootstrap distribution of the effect.
        n_valid : int, optional
            Number of valid bootstrap samples. If None, uses self.n_bootstrap.
            Use this when boot_dist has already been filtered for non-finite values
            to ensure the p-value floor is based on the actual valid sample count.

        Returns
        -------
        float
            Two-sided bootstrap p-value.
        """
        if original_effect >= 0:
            # Proportion of bootstrap estimates <= 0
            p_one_sided = np.mean(boot_dist <= 0)
        else:
            # Proportion of bootstrap estimates >= 0
            p_one_sided = np.mean(boot_dist >= 0)

        # Two-sided p-value
        p_value = min(2 * p_one_sided, 1.0)

        # Ensure minimum p-value using n_valid if provided, otherwise n_bootstrap
        n_for_floor = n_valid if n_valid is not None else self.n_bootstrap
        p_value = max(p_value, 1 / (n_for_floor + 1))

        return float(p_value)

    def _compute_effect_bootstrap_stats(
        self,
        original_effect: float,
        boot_dist: np.ndarray,
        context: str = "bootstrap distribution",
    ) -> Tuple[float, Tuple[float, float], float]:
        """
        Compute bootstrap statistics for a single effect.

        Non-finite bootstrap samples are dropped and a warning is issued if any
        are present. If too few valid samples remain (<50%), returns NaN for all
        statistics to signal invalid inference.

        Parameters
        ----------
        original_effect : float
            Original point estimate.
        boot_dist : np.ndarray
            Bootstrap distribution of the effect.
        context : str, optional
            Description for warning messages, by default "bootstrap distribution".

        Returns
        -------
        se : float
            Bootstrap standard error.
        ci : Tuple[float, float]
            Percentile confidence interval.
        p_value : float
            Bootstrap p-value.
        """
        # Filter out non-finite values
        finite_mask = np.isfinite(boot_dist)
        n_valid = np.sum(finite_mask)
        n_total = len(boot_dist)

        if n_valid < n_total:
            # NOTE: uses the module-level `warnings` import; the previous
            # function-local re-imports were redundant.
            n_nonfinite = n_total - n_valid
            warnings.warn(
                f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples in {context}. "
                "This may occur with very small samples or extreme weights. "
                "Bootstrap estimates based on remaining valid samples.",
                RuntimeWarning,
                stacklevel=3
            )

        # Check if we have enough valid samples
        if n_valid < n_total * 0.5:
            warnings.warn(
                f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. "
                "Returning NaN for SE/CI/p-value to signal invalid inference.",
                RuntimeWarning,
                stacklevel=3
            )
            return np.nan, (np.nan, np.nan), np.nan

        # Use only valid samples
        valid_dist = boot_dist[finite_mask]
        n_valid_bootstrap = len(valid_dist)

        se = float(np.std(valid_dist, ddof=1))
        ci = self._compute_percentile_ci(valid_dist, self.alpha)

        # Compute p-value using shared method with correct floor based on valid sample count
        p_value = self._compute_bootstrap_pvalue(original_effect, valid_dist, n_valid=n_valid_bootstrap)

        return se, ci, p_value