diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,752 @@
1
+ """
2
+ Bootstrap inference for Callaway-Sant'Anna estimator.
3
+
4
+ This module provides the bootstrap results container and the mixin class
5
+ with bootstrap inference methods. Weight generation and statistical helpers
6
+ are in :mod:`diff_diff.bootstrap_utils`.
7
+ """
8
+
9
+ import warnings
10
+ from dataclasses import dataclass, field
11
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
12
+
13
+ import numpy as np
14
+
15
+ from diff_diff.bootstrap_utils import (
16
+ compute_bootstrap_pvalue as _compute_bootstrap_pvalue_func,
17
+ )
18
+ from diff_diff.bootstrap_utils import (
19
+ compute_effect_bootstrap_stats as _compute_effect_bootstrap_stats_func,
20
+ )
21
+ from diff_diff.bootstrap_utils import (
22
+ compute_effect_bootstrap_stats_batch as _compute_effect_bootstrap_stats_batch_func,
23
+ )
24
+ from diff_diff.bootstrap_utils import (
25
+ compute_percentile_ci as _compute_percentile_ci_func,
26
+ )
27
+ from diff_diff.bootstrap_utils import (
28
+ generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch,
29
+ )
30
+ from diff_diff.bootstrap_utils import (
31
+ generate_survey_multiplier_weights_batch as _generate_survey_multiplier_weights_batch,
32
+ )
33
+
34
+ if TYPE_CHECKING:
35
+ import pandas as pd
36
+
37
+ from diff_diff.staggered_aggregation import PrecomputedData
38
+
39
+
40
+ # =============================================================================
41
+ # Bootstrap Results Container
42
+ # =============================================================================
43
+
44
+
45
@dataclass
class CSBootstrapResults:
    """
    Results from Callaway-Sant'Anna multiplier bootstrap inference.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Type of bootstrap weights used.
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap confidence interval for overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for overall ATT.
    group_time_ses : Dict[Tuple[Any, Any], float]
        Bootstrap SEs for each ATT(g,t).
    group_time_cis : Dict[Tuple[Any, Any], Tuple[float, float]]
        Bootstrap CIs for each ATT(g,t).
    group_time_p_values : Dict[Tuple[Any, Any], float]
        Bootstrap p-values for each ATT(g,t).
    event_study_ses : Optional[Dict[int, float]]
        Bootstrap SEs for event study effects.
    event_study_cis : Optional[Dict[int, Tuple[float, float]]]
        Bootstrap CIs for event study effects.
    event_study_p_values : Optional[Dict[int, float]]
        Bootstrap p-values for event study effects.
    group_effect_ses : Optional[Dict[Any, float]]
        Bootstrap SEs for group effects.
    group_effect_cis : Optional[Dict[Any, Tuple[float, float]]]
        Bootstrap CIs for group effects.
    group_effect_p_values : Optional[Dict[Any, float]]
        Bootstrap p-values for group effects.
    bootstrap_distribution : Optional[np.ndarray]
        Full bootstrap distribution of overall ATT (if requested).
        Excluded from ``repr`` to keep the printed result compact.
    cband_crit_value : Optional[float]
        Critical value for a simultaneous (sup-t) confidence band over the
        event study effects. ``None`` when no critical value was computed.
    """

    # --- Required fields: always populated by the bootstrap ---
    n_bootstrap: int
    weight_type: str
    alpha: float
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    group_time_ses: Dict[Tuple[Any, Any], float]
    group_time_cis: Dict[Tuple[Any, Any], Tuple[float, float]]
    group_time_p_values: Dict[Tuple[Any, Any], float]

    # --- Optional fields: filled only when the matching aggregation ran ---
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None
    group_effect_ses: Optional[Dict[Any, float]] = None
    group_effect_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_effect_p_values: Optional[Dict[Any, float]] = None
    # repr=False: the raw draws can be long, so keep them out of repr().
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
    cband_crit_value: Optional[float] = None
103
+
104
+
105
+ # =============================================================================
106
+ # Bootstrap Mixin Class
107
+ # =============================================================================
108
+
109
+
110
class CallawaySantAnnaBootstrapMixin:
    """
    Mixin class providing bootstrap inference methods for CallawaySantAnna.

    This class is not intended to be used standalone. It provides methods
    that are used by the main CallawaySantAnna class for multiplier bootstrap
    inference. The host class must supply the attributes annotated below and
    a ``_compute_combined_influence_function`` method (only its signature is
    declared here, for type checkers).
    """

    # Type hints for attributes accessed from the main class
    n_bootstrap: int
    bootstrap_weights: str
    alpha: float
    seed: Optional[int]
    anticipation: int

    if TYPE_CHECKING:

        def _compute_combined_influence_function(
            self,
            gt_pairs: List[Tuple[Any, Any]],
            weights: np.ndarray,
            effects: np.ndarray,
            groups_for_gt: np.ndarray,
            influence_func_info: Dict,
            df: "pd.DataFrame",
            unit: str,
            precomputed: Optional["PrecomputedData"] = None,
            global_unit_to_idx: Optional[Dict[Any, int]] = None,
            n_global_units: Optional[int] = None,
        ) -> Tuple[np.ndarray, Optional[List]]: ...

    def _run_multiplier_bootstrap(
        self,
        group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
        influence_func_info: Dict[Tuple[Any, Any], Dict[str, Any]],
        aggregate: Optional[str],
        balance_e: Optional[int],
        treatment_groups: List[Any],
        time_periods: List[Any],
        df: Any = None,
        unit: Optional[str] = None,
        precomputed: Any = None,
        cband: bool = True,
    ) -> "CSBootstrapResults":
        """
        Run multiplier bootstrap for inference on all parameters.

        This implements the multiplier bootstrap procedure from Callaway & Sant'Anna (2021).
        The key idea is to perturb the influence function contributions with random
        weights at the cluster (unit) level, then recompute aggregations.

        Parameters
        ----------
        group_time_effects : dict
            Dictionary of ATT(g,t) effects with analytical SEs. Each value is
            read for ``"effect"`` and ``"n_treated"`` (and ``"agg_weight"``
            when present).
        influence_func_info : dict
            Dictionary mapping (g,t) to influence function information. Each
            value is read for ``"treated_idx"``, ``"control_idx"``,
            ``"treated_inf"`` and ``"control_inf"`` (plus ``"treated_units"``
            and ``"control_units"`` when ``precomputed`` is None).
        aggregate : str, optional
            Type of aggregation requested. ``"event_study"`` and ``"group"``
            enable the matching aggregation; ``"all"`` enables both.
        balance_e : int, optional
            Balance parameter for event study.
        treatment_groups : list
            List of treatment cohorts.
        time_periods : list
            List of time periods. Not read by this method.
        df : optional
            Data frame forwarded to ``_compute_combined_influence_function``.
            When ``precomputed`` is None and ``unit`` is given, it is also
            used to count the global number of units via
            ``df[unit].nunique()``.
        unit : str, optional
            Name of the unit column in ``df``.
        precomputed : optional
            Mapping of precomputed structures. Keys read here: ``"all_units"``,
            ``"canonical_size"`` (optional), ``"unit_to_idx"``,
            ``"survey_weights"`` (optional), ``"unit_cohorts"`` (when survey
            weights are present) and ``"resolved_survey_unit"`` (optional).
        cband : bool, default True
            If True and event study effects were bootstrapped, compute the
            sup-t critical value for a simultaneous confidence band.

        Returns
        -------
        CSBootstrapResults
            Bootstrap inference results.
        """
        # Warn about low bootstrap iterations
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference. Percentile confidence intervals and p-values "
                "may be unreliable with few iterations.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)

        # Use global unit set for correct pg = n_g / N_total scaling.
        # Without this, pg is overestimated in unbalanced panels where some
        # units don't appear in any influence function.
        if precomputed is not None:
            all_units = precomputed["all_units"]
            n_units = precomputed.get("canonical_size", len(all_units))
            unit_to_idx = precomputed["unit_to_idx"]  # None for RCS
        else:
            # Fallback: collect units from influence functions
            all_units_set = set()
            for info in influence_func_info.values():
                all_units_set.update(info["treated_units"])
                all_units_set.update(info["control_units"])
            all_units = sorted(all_units_set)
            # Use global N from dataframe when available
            n_units = (
                df[unit].nunique() if (df is not None and unit is not None) else len(all_units)
            )
            unit_to_idx = {u: i for i, u in enumerate(all_units)}

        # Get list of (g,t) pairs that have influence function info
        # (skip zero-mass cells that recorded NaN ATT without IF)
        gt_pairs = [gt for gt in group_time_effects if gt in influence_func_info]
        n_gt = len(gt_pairs)

        # Identify post-treatment (g,t) pairs for overall ATT
        # Pre-treatment effects are for parallel trends assessment, not aggregated.
        # dtype=bool is explicit so an empty gt_pairs still yields a boolean
        # mask; an empty float array is rejected by NumPy as an index.
        post_treatment_mask = np.array(
            [t >= g - self.anticipation for (g, t) in gt_pairs], dtype=bool
        )
        post_treatment_indices = np.where(post_treatment_mask)[0]

        # Compute aggregation weights for overall ATT (post-treatment only)
        # When survey weights are present, use fixed cohort survey masses
        # (from precomputed survey_weights × unit_cohorts), matching the
        # analytical _aggregate_simple() path in staggered_aggregation.py.
        # Do NOT use per-cell survey_weight_sum (which varies by cell on
        # unbalanced panels).
        survey_w = precomputed.get("survey_weights") if precomputed is not None else None
        if survey_w is not None:
            unit_cohorts = precomputed["unit_cohorts"]
            # Precompute fixed cohort masses (same formula as _aggregate_simple)
            _cohort_mass_cache: dict = {}
            for gt in gt_pairs:
                g = gt[0]
                if g not in _cohort_mass_cache:
                    _cohort_mass_cache[g] = float(np.sum(survey_w[unit_cohorts == g]))
            all_n_treated = np.array([_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float)
        else:
            # Use agg_weight if available (RCS: fixed cohort mass);
            # fall back to n_treated for panel data
            all_n_treated = np.array(
                [
                    group_time_effects[gt].get("agg_weight", group_time_effects[gt]["n_treated"])
                    for gt in gt_pairs
                ],
                dtype=float,
            )
        post_n_treated = all_n_treated[post_treatment_mask]

        # Filter out NaN ATT(g,t) cells from overall aggregation (matches analytical path)
        post_effects_raw = np.array(
            [group_time_effects[gt_pairs[i]]["effect"] for i in post_treatment_indices]
        )
        finite_post = np.isfinite(post_effects_raw)
        if not np.all(finite_post):
            post_treatment_indices = post_treatment_indices[finite_post]
            post_n_treated = post_n_treated[finite_post]

        # Flag to skip overall ATT aggregation when no post-treatment effects
        # But continue bootstrap for per-effect SEs (pre-treatment effects need bootstrap SEs too)
        skip_overall_aggregation = False
        if len(post_treatment_indices) == 0:
            warnings.warn(
                "No post-treatment effects for bootstrap aggregation. "
                "Overall ATT statistics will be NaN, but per-effect SEs will be computed.",
                UserWarning,
                stacklevel=2,
            )
            skip_overall_aggregation = True
            overall_weights_post = np.array([])
        else:
            overall_weights_post = post_n_treated / np.sum(post_n_treated)

        # Original point estimates
        original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])
        if skip_overall_aggregation:
            original_overall = np.nan
        else:
            original_overall = np.sum(overall_weights_post * original_atts[post_treatment_indices])

        # Prepare event study and group aggregation info if needed
        event_study_info = None
        group_agg_info = None

        if aggregate in ("event_study", "all"):
            event_study_info = self._prepare_event_study_aggregation(
                gt_pairs,
                group_time_effects,
                balance_e,
                influence_func_info=influence_func_info,
                df=df,
                unit=unit,
                precomputed=precomputed,
                global_unit_to_idx=unit_to_idx,
                n_global_units=n_units,
            )

        if aggregate in ("group", "all"):
            group_agg_info = self._prepare_group_aggregation(
                gt_pairs, group_time_effects, treatment_groups
            )

        # Pre-compute unit index arrays for each (g,t) pair (done once, not per iteration)
        gt_infos = [influence_func_info[gt] for gt in gt_pairs]
        gt_treated_indices = [info["treated_idx"] for info in gt_infos]
        gt_control_indices = [info["control_idx"] for info in gt_infos]
        gt_treated_inf = [np.asarray(info["treated_inf"]) for info in gt_infos]
        gt_control_inf = [np.asarray(info["control_inf"]) for info in gt_infos]

        # Generate bootstrap weights — PSU-level when survey design is present,
        # unit-level otherwise.
        resolved_survey_unit = (
            precomputed.get("resolved_survey_unit") if precomputed is not None else None
        )
        _use_survey_bootstrap = resolved_survey_unit is not None and (
            resolved_survey_unit.strata is not None
            or resolved_survey_unit.psu is not None
            or resolved_survey_unit.fpc is not None
        )

        if _use_survey_bootstrap:
            # PSU-level multiplier weights
            psu_weights, psu_ids = _generate_survey_multiplier_weights_batch(
                self.n_bootstrap, resolved_survey_unit, self.bootstrap_weights, rng
            )
            # Build unit → PSU column map
            if resolved_survey_unit.psu is not None:
                unit_psu = resolved_survey_unit.psu
                psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)}
                unit_to_psu_col = np.array(
                    [psu_id_to_col[int(unit_psu[i])] for i in range(n_units)]
                )
            else:
                # Each unit is its own PSU — identity mapping
                unit_to_psu_col = np.arange(n_units)

            # Expand PSU weights to unit level for per-(g,t) perturbation
            # Shape: (n_bootstrap, n_units)
            all_bootstrap_weights = psu_weights[:, unit_to_psu_col]
        else:
            # Standard unit-level weights (no survey or weights-only)
            all_bootstrap_weights = _generate_bootstrap_weights_batch(
                self.n_bootstrap, n_units, self.bootstrap_weights, rng
            )

        # Vectorized bootstrap ATT(g,t) computation
        # Compute all bootstrap ATTs for all (g,t) pairs using matrix operations
        bootstrap_atts_gt = np.zeros((self.n_bootstrap, n_gt))

        for j in range(n_gt):
            # Extract weights for this (g,t)'s units across all bootstrap iterations
            # Shape: (n_bootstrap, n_treated) and (n_bootstrap, n_control)
            treated_weights = all_bootstrap_weights[:, gt_treated_indices[j]]
            control_weights = all_bootstrap_weights[:, gt_control_indices[j]]

            # Vectorized perturbation: matrix-vector multiply
            # Shape: (n_bootstrap,)
            # Suppress RuntimeWarnings for edge cases (small samples, extreme weights)
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                perturbations = (
                    treated_weights @ gt_treated_inf[j] + control_weights @ gt_control_inf[j]
                )

            # Let non-finite values propagate - they will be handled at statistics computation
            bootstrap_atts_gt[:, j] = original_atts[j] + perturbations

        # Vectorized overall ATT using combined IF (includes WIF)
        # Shape: (n_bootstrap,)
        if skip_overall_aggregation:
            bootstrap_overall = np.full(self.n_bootstrap, np.nan)
        else:
            # Use combined IF (standard IF + WIF) for proper bootstrap
            post_gt_pairs = [gt_pairs[i] for i in post_treatment_indices]
            post_groups = np.array([gt_pairs[i][0] for i in post_treatment_indices])
            post_effects = original_atts[post_treatment_indices]
            overall_combined_if, _ = self._compute_combined_influence_function(
                post_gt_pairs,
                overall_weights_post,
                post_effects,
                post_groups,
                influence_func_info,
                df,
                unit,
                precomputed,
                global_unit_to_idx=unit_to_idx,
                n_global_units=n_units,
            )
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                bootstrap_overall = original_overall + all_bootstrap_weights @ overall_combined_if

        # Vectorized event study aggregation using combined IFs
        # Non-finite values handled at statistics computation stage
        rel_periods: List[int] = []
        bootstrap_event_study: Optional[Dict[int, np.ndarray]] = None
        if event_study_info is not None:
            rel_periods = sorted(event_study_info.keys())
            bootstrap_event_study = {}
            for e in rel_periods:
                agg_info = event_study_info[e]
                # Use combined IF (standard IF + WIF) for proper bootstrap.
                # NOTE(review): "combined_if" is only set by
                # _prepare_event_study_aggregation when df and unit are both
                # provided — confirm callers always pass them on this path.
                with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                    bootstrap_event_study[e] = (
                        agg_info["effect"] + all_bootstrap_weights @ agg_info["combined_if"]
                    )

        # Vectorized group aggregation
        # Non-finite values handled at statistics computation stage
        group_list: List[Any] = []
        bootstrap_group: Optional[Dict[Any, np.ndarray]] = None
        if group_agg_info is not None:
            group_list = sorted(group_agg_info.keys())
            bootstrap_group = {}
            for g in group_list:
                agg_info = group_agg_info[g]
                gt_indices = agg_info["gt_indices"]
                weights = agg_info["weights"]
                # Suppress RuntimeWarnings for edge cases
                with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                    bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights

        # Batch compute bootstrap statistics for ATT(g,t)
        batch_ses, batch_ci_lo, batch_ci_hi, batch_pv = _compute_effect_bootstrap_stats_batch_func(
            original_atts,
            bootstrap_atts_gt,
            alpha=self.alpha,
        )
        gt_ses = {}
        gt_cis = {}
        gt_p_values = {}
        for j, gt in enumerate(gt_pairs):
            gt_ses[gt] = float(batch_ses[j])
            gt_cis[gt] = (float(batch_ci_lo[j]), float(batch_ci_hi[j]))
            gt_p_values[gt] = float(batch_pv[j])

        # Compute bootstrap statistics for overall ATT
        if skip_overall_aggregation:
            overall_se = np.nan
            overall_ci = (np.nan, np.nan)
            overall_p_value = np.nan
        else:
            overall_se, overall_ci, overall_p_value = _compute_effect_bootstrap_stats_func(
                original_overall,
                bootstrap_overall,
                alpha=self.alpha,
                context="overall ATT",
            )

        # Batch compute bootstrap statistics for event study effects
        event_study_ses = None
        event_study_cis = None
        event_study_p_values = None

        if bootstrap_event_study is not None and event_study_info is not None:
            es_effects = np.array([event_study_info[e]["effect"] for e in rel_periods])
            es_boot_matrix = np.column_stack([bootstrap_event_study[e] for e in rel_periods])
            es_ses, es_ci_lo, es_ci_hi, es_pv = _compute_effect_bootstrap_stats_batch_func(
                es_effects,
                es_boot_matrix,
                alpha=self.alpha,
            )
            event_study_ses = {e: float(es_ses[i]) for i, e in enumerate(rel_periods)}
            event_study_cis = {
                e: (float(es_ci_lo[i]), float(es_ci_hi[i])) for i, e in enumerate(rel_periods)
            }
            event_study_p_values = {e: float(es_pv[i]) for i, e in enumerate(rel_periods)}

        # Batch compute bootstrap statistics for group effects
        group_effect_ses = None
        group_effect_cis = None
        group_effect_p_values = None

        if bootstrap_group is not None and group_agg_info is not None:
            grp_effects = np.array([group_agg_info[g]["effect"] for g in group_list])
            grp_boot_matrix = np.column_stack([bootstrap_group[g] for g in group_list])
            grp_ses, grp_ci_lo, grp_ci_hi, grp_pv = _compute_effect_bootstrap_stats_batch_func(
                grp_effects,
                grp_boot_matrix,
                alpha=self.alpha,
            )
            group_effect_ses = {g: float(grp_ses[i]) for i, g in enumerate(group_list)}
            group_effect_cis = {
                g: (float(grp_ci_lo[i]), float(grp_ci_hi[i])) for i, g in enumerate(group_list)
            }
            group_effect_p_values = {g: float(grp_pv[i]) for i, g in enumerate(group_list)}

        # Compute simultaneous confidence band critical value (sup-t)
        cband_crit_value = None
        if (
            cband
            and bootstrap_event_study is not None
            and event_study_ses is not None
            and event_study_info is not None
        ):
            # Only periods with a finite, strictly positive SE can be studentized.
            valid_es = [
                e
                for e in rel_periods
                if e in event_study_ses
                and np.isfinite(event_study_ses[e])
                and event_study_ses[e] > 0
            ]
            if valid_es:
                # Vectorized sup_t: max_e |(boot_att_e[b] - att_e) / se_e|
                boot_matrix = np.array([bootstrap_event_study[e] for e in valid_es])
                effects_vec = np.array([event_study_info[e]["effect"] for e in valid_es])
                ses_vec = np.array([event_study_ses[e] for e in valid_es])
                with np.errstate(divide="ignore", invalid="ignore"):
                    sup_t_dist = np.max(
                        np.abs((boot_matrix - effects_vec[:, None]) / ses_vec[:, None]),
                        axis=0,
                    )
                finite_mask = np.isfinite(sup_t_dist)
                n_valid = int(np.sum(finite_mask))
                n_total = len(sup_t_dist)
                if n_valid < n_total * 0.5:
                    warnings.warn(
                        f"Too few valid sup-t bootstrap samples ({n_valid}/{n_total}). "
                        "Returning None for cband critical value.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                elif n_valid > 0:
                    cband_crit_value = float(np.quantile(sup_t_dist[finite_mask], 1 - self.alpha))

        return CSBootstrapResults(
            n_bootstrap=self.n_bootstrap,
            weight_type=self.bootstrap_weights,
            alpha=self.alpha,
            overall_att_se=overall_se,
            overall_att_ci=overall_ci,
            overall_att_p_value=overall_p_value,
            group_time_ses=gt_ses,
            group_time_cis=gt_cis,
            group_time_p_values=gt_p_values,
            event_study_ses=event_study_ses,
            event_study_cis=event_study_cis,
            event_study_p_values=event_study_p_values,
            group_effect_ses=group_effect_ses,
            group_effect_cis=group_effect_cis,
            group_effect_p_values=group_effect_p_values,
            bootstrap_distribution=bootstrap_overall,
            cband_crit_value=cband_crit_value,
        )

    def _prepare_event_study_aggregation(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        group_time_effects: Dict,
        balance_e: Optional[int],
        influence_func_info: Any = None,
        df: Any = None,
        unit: Optional[str] = None,
        precomputed: Any = None,
        global_unit_to_idx: Optional[Dict[Any, int]] = None,
        n_global_units: Optional[int] = None,
    ) -> Dict[int, Dict[str, Any]]:
        """
        Prepare aggregation info for event study bootstrap.

        Cells are grouped by relative time ``e = t - g``; within each ``e``
        the finite effects are averaged with weights proportional to the
        cohort weight of each cell.

        Parameters
        ----------
        gt_pairs : list of (g, t)
            Group-time cells, in the order used to index bootstrap columns.
        group_time_effects : dict
            Maps (g, t) to a dict read for ``"effect"`` and ``"n_treated"``
            (and ``"agg_weight"`` when present).
        balance_e : int, optional
            When given, only cohorts with a finite effect at relative time
            ``balance_e`` are kept.
        influence_func_info, df, unit : optional
            When all three are provided, a combined influence function is
            computed per relative time and stored under ``"combined_if"``.
        precomputed : optional
            Mapping read for ``"survey_weights"`` and ``"unit_cohorts"``;
            also forwarded to ``_compute_combined_influence_function``.
        global_unit_to_idx, n_global_units : optional
            Forwarded to ``_compute_combined_influence_function``.

        Returns
        -------
        dict
            Maps relative time ``e`` to a dict with ``"gt_indices"``,
            ``"weights"``, ``"effect"`` and, when computed, ``"combined_if"``.
            Relative times whose effects are all non-finite are omitted.
        """
        # Use fixed cohort survey masses (not per-cell survey_weight_sum) when
        # survey weights are present, matching the analytical
        # _aggregate_event_study() path.
        survey_w = precomputed.get("survey_weights") if precomputed is not None else None
        _cohort_mass: Optional[dict] = None
        if survey_w is not None:
            unit_cohorts = precomputed["unit_cohorts"]
            _cohort_mass = {}

        def _agg_weight(g: Any, t: Any) -> float:
            if _cohort_mass is not None:
                # Cache the mass per cohort: it does not depend on t.
                if g not in _cohort_mass:
                    _cohort_mass[g] = float(np.sum(survey_w[unit_cohorts == g]))
                return _cohort_mass[g]
            # Use agg_weight if available (RCS: fixed cohort mass)
            return group_time_effects[(g, t)].get(
                "agg_weight", group_time_effects[(g, t)]["n_treated"]
            )

        # Balance if requested: keep only cohorts observed (with a finite
        # effect) at relative time balance_e.
        allowed_groups: Optional[set] = None
        if balance_e is not None:
            allowed_groups = {
                g
                for (g, t) in gt_pairs
                if t - g == balance_e and np.isfinite(group_time_effects[(g, t)]["effect"])
            }

        # Organize by relative time: (index in gt_pairs, effect, weight)
        effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}
        for j, (g, t) in enumerate(gt_pairs):
            if allowed_groups is not None and g not in allowed_groups:
                continue
            effects_by_e.setdefault(t - g, []).append(
                (
                    j,  # index in gt_pairs
                    group_time_effects[(g, t)]["effect"],
                    _agg_weight(g, t),
                )
            )

        # Compute aggregation weights
        result = {}
        for e, effect_list in effects_by_e.items():
            indices = np.array([x[0] for x in effect_list])
            effects = np.array([x[1] for x in effect_list])
            n_treated = np.array([x[2] for x in effect_list], dtype=float)

            # Exclude NaN effects (matches analytical aggregation path)
            finite_mask = np.isfinite(effects)
            if not np.all(finite_mask):
                indices = indices[finite_mask]
                effects = effects[finite_mask]
                n_treated = n_treated[finite_mask]
                if len(effects) == 0:
                    continue

            weights = n_treated / np.sum(n_treated)
            agg_effect = np.sum(weights * effects)

            entry: Dict[str, Any] = {
                "gt_indices": indices,
                "weights": weights,
                "effect": agg_effect,
            }

            # Compute combined IF for this event time if args available
            if influence_func_info is not None and df is not None and unit is not None:
                gt_pairs_for_e = [gt_pairs[i] for i in indices]
                groups_for_gt = np.array([gt_pairs[i][0] for i in indices])
                combined_if, _ = self._compute_combined_influence_function(
                    gt_pairs_for_e,
                    weights,
                    effects,
                    groups_for_gt,
                    influence_func_info,
                    df,
                    unit,
                    precomputed,
                    global_unit_to_idx=global_unit_to_idx,
                    n_global_units=n_global_units,
                )
                entry["combined_if"] = combined_if

            result[e] = entry

        return result

    def _prepare_group_aggregation(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        group_time_effects: Dict,
        treatment_groups: List[Any],
    ) -> Dict[Any, Dict[str, Any]]:
        """
        Prepare aggregation info for group-level bootstrap.

        For each cohort in ``treatment_groups``, the finite post-treatment
        effects (``t >= g - anticipation``) are averaged with equal weights.

        Returns
        -------
        dict
            Maps cohort ``g`` to a dict with ``"gt_indices"`` (positions in
            ``gt_pairs``), ``"weights"`` and ``"effect"``. Cohorts with no
            finite post-treatment effect are omitted.
        """
        result = {}

        for g in treatment_groups:
            # Get all effects for this group (post-treatment only: t >= g - anticipation)
            group_data = [
                (j, group_time_effects[(gg, t)]["effect"])
                for j, (gg, t) in enumerate(gt_pairs)
                if gg == g and t >= g - self.anticipation
            ]

            if not group_data:
                continue

            indices = np.array([x[0] for x in group_data])
            effects = np.array([x[1] for x in group_data])

            # Exclude NaN effects (matches analytical aggregation path)
            finite_mask = np.isfinite(effects)
            if not np.all(finite_mask):
                indices = indices[finite_mask]
                effects = effects[finite_mask]
                if len(effects) == 0:
                    continue

            # Equal weights across time periods
            weights = np.ones(len(effects)) / len(effects)
            agg_effect = np.sum(weights * effects)

            result[g] = {
                "gt_indices": indices,
                "weights": weights,
                "effect": agg_effect,
            }

        return result

    def _compute_percentile_ci(
        self,
        boot_dist: np.ndarray,
        alpha: float,
    ) -> Tuple[float, float]:
        """
        Compute percentile confidence interval from bootstrap distribution.

        Delegates to :func:`bootstrap_utils.compute_percentile_ci`.
        """
        return _compute_percentile_ci_func(boot_dist, alpha)

    def _compute_bootstrap_pvalue(
        self,
        original_effect: float,
        boot_dist: np.ndarray,
        n_valid: Optional[int] = None,
    ) -> float:
        """
        Compute two-sided bootstrap p-value.

        Delegates to :func:`bootstrap_utils.compute_bootstrap_pvalue`.
        """
        return _compute_bootstrap_pvalue_func(original_effect, boot_dist, n_valid=n_valid)

    def _compute_effect_bootstrap_stats(
        self,
        original_effect: float,
        boot_dist: np.ndarray,
        context: str = "bootstrap distribution",
    ) -> Tuple[float, Tuple[float, float], float]:
        """
        Compute bootstrap statistics for a single effect.

        Delegates to :func:`bootstrap_utils.compute_effect_bootstrap_stats`,
        passing ``self.alpha`` as the significance level.
        """
        return _compute_effect_bootstrap_stats_func(
            original_effect, boot_dist, alpha=self.alpha, context=context
        )