diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,753 @@
1
+ """
2
+ Bootstrap inference for Callaway-Sant'Anna estimator.
3
+
4
+ This module provides bootstrap weight generation functions, the bootstrap
5
+ results container, and the mixin class with bootstrap inference methods.
6
+ """
7
+
8
+ import warnings
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
11
+
12
+ import numpy as np
13
+
14
+ # Import Rust backend if available (from _backend to avoid circular imports)
15
+ from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
16
+
17
+ if TYPE_CHECKING:
18
+ pass
19
+
20
+
21
+ # =============================================================================
22
+ # Bootstrap Weight Generators
23
+ # =============================================================================
24
+
25
+
26
+ def _generate_bootstrap_weights(
27
+ n_units: int,
28
+ weight_type: str,
29
+ rng: np.random.Generator,
30
+ ) -> np.ndarray:
31
+ """
32
+ Generate bootstrap weights for multiplier bootstrap.
33
+
34
+ Parameters
35
+ ----------
36
+ n_units : int
37
+ Number of units (clusters) to generate weights for.
38
+ weight_type : str
39
+ Type of weights: "rademacher", "mammen", or "webb".
40
+ rng : np.random.Generator
41
+ Random number generator.
42
+
43
+ Returns
44
+ -------
45
+ np.ndarray
46
+ Array of bootstrap weights with shape (n_units,).
47
+ """
48
+ if weight_type == "rademacher":
49
+ # Rademacher: +1 or -1 with equal probability
50
+ return rng.choice([-1.0, 1.0], size=n_units)
51
+
52
+ elif weight_type == "mammen":
53
+ # Mammen's two-point distribution
54
+ # E[v] = 0, E[v^2] = 1, E[v^3] = 1
55
+ sqrt5 = np.sqrt(5)
56
+ val1 = -(sqrt5 - 1) / 2 # ≈ -0.618
57
+ val2 = (sqrt5 + 1) / 2 # ≈ 1.618 (golden ratio)
58
+ p1 = (sqrt5 + 1) / (2 * sqrt5) # ≈ 0.724
59
+ return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1])
60
+
61
+ elif weight_type == "webb":
62
+ # Webb's 6-point distribution (recommended for few clusters)
63
+ # Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
64
+ # This matches R's did package: E[w]=0, Var(w)=1.0
65
+ values = np.array([
66
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
67
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
68
+ ])
69
+ return rng.choice(values, size=n_units) # Equal probs (1/6 each)
70
+
71
+ else:
72
+ raise ValueError(
73
+ f"weight_type must be 'rademacher', 'mammen', or 'webb', "
74
+ f"got '{weight_type}'"
75
+ )
76
+
77
+
78
def _generate_bootstrap_weights_batch(
    n_bootstrap: int,
    n_units: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Generate the full matrix of bootstrap weights in one shot.

    Dispatches to the Rust backend when it is available (parallel, fast RNG)
    and falls back to the vectorized NumPy implementation otherwise.

    Parameters
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    n_units : int
        Number of units (clusters) to generate weights for.
    weight_type : str
        Type of weights: "rademacher", "mammen", or "webb".
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Array of bootstrap weights with shape (n_bootstrap, n_units).
    """
    rust_available = HAS_RUST_BACKEND and _rust_bootstrap_weights is not None
    if not rust_available:
        # Pure-NumPy path when the compiled extension is missing.
        return _generate_bootstrap_weights_batch_numpy(
            n_bootstrap, n_units, weight_type, rng
        )

    # Derive the backend seed from the caller's RNG so runs stay reproducible.
    seed = rng.integers(0, 2**63 - 1)
    return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed)
111
+
112
+
113
+ def _generate_bootstrap_weights_batch_numpy(
114
+ n_bootstrap: int,
115
+ n_units: int,
116
+ weight_type: str,
117
+ rng: np.random.Generator,
118
+ ) -> np.ndarray:
119
+ """
120
+ NumPy fallback implementation of _generate_bootstrap_weights_batch.
121
+
122
+ Generates multiplier bootstrap weights for wild cluster bootstrap.
123
+ All weight distributions satisfy E[w] = 0, E[w^2] = 1.
124
+
125
+ Parameters
126
+ ----------
127
+ n_bootstrap : int
128
+ Number of bootstrap iterations.
129
+ n_units : int
130
+ Number of units (clusters) to generate weights for.
131
+ weight_type : str
132
+ Type of weights: "rademacher" (+-1), "mammen" (2-point),
133
+ or "webb" (6-point).
134
+ rng : np.random.Generator
135
+ Random number generator for reproducibility.
136
+
137
+ Returns
138
+ -------
139
+ np.ndarray
140
+ Array of bootstrap weights with shape (n_bootstrap, n_units).
141
+ """
142
+ if weight_type == "rademacher":
143
+ # Rademacher: +1 or -1 with equal probability
144
+ return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))
145
+
146
+ elif weight_type == "mammen":
147
+ # Mammen's two-point distribution
148
+ sqrt5 = np.sqrt(5)
149
+ val1 = -(sqrt5 - 1) / 2
150
+ val2 = (sqrt5 + 1) / 2
151
+ p1 = (sqrt5 + 1) / (2 * sqrt5)
152
+ return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1])
153
+
154
+ elif weight_type == "webb":
155
+ # Webb's 6-point distribution
156
+ # Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
157
+ # This matches R's did package: E[w]=0, Var(w)=1.0
158
+ values = np.array([
159
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
160
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
161
+ ])
162
+ return rng.choice(values, size=(n_bootstrap, n_units)) # Equal probs (1/6 each)
163
+
164
+ else:
165
+ raise ValueError(
166
+ f"weight_type must be 'rademacher', 'mammen', or 'webb', "
167
+ f"got '{weight_type}'"
168
+ )
169
+
170
+
171
+ # =============================================================================
172
+ # Bootstrap Results Container
173
+ # =============================================================================
174
+
175
+
176
@dataclass
class CSBootstrapResults:
    """
    Container for Callaway-Sant'Anna multiplier-bootstrap inference output.

    Bundles the bootstrap settings together with standard errors, percentile
    confidence intervals, and p-values for the overall ATT, each ATT(g, t),
    and — when the corresponding aggregation was requested — event-study and
    group-level effects.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Multiplier weight distribution used ("rademacher", "mammen", or "webb").
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for the overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap confidence interval for the overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for the overall ATT.
    group_time_ses : Dict[Tuple[Any, Any], float]
        Bootstrap SEs for each ATT(g,t).
    group_time_cis : Dict[Tuple[Any, Any], Tuple[float, float]]
        Bootstrap CIs for each ATT(g,t).
    group_time_p_values : Dict[Tuple[Any, Any], float]
        Bootstrap p-values for each ATT(g,t).
    event_study_ses : Optional[Dict[int, float]]
        Bootstrap SEs for event study effects, keyed by relative period.
    event_study_cis : Optional[Dict[int, Tuple[float, float]]]
        Bootstrap CIs for event study effects.
    event_study_p_values : Optional[Dict[int, float]]
        Bootstrap p-values for event study effects.
    group_effect_ses : Optional[Dict[Any, float]]
        Bootstrap SEs for group effects, keyed by cohort.
    group_effect_cis : Optional[Dict[Any, Tuple[float, float]]]
        Bootstrap CIs for group effects.
    group_effect_p_values : Optional[Dict[Any, float]]
        Bootstrap p-values for group effects.
    bootstrap_distribution : Optional[np.ndarray]
        Full bootstrap distribution of the overall ATT (if requested).
    """

    # Required: bootstrap settings and overall-ATT statistics.
    n_bootstrap: int
    weight_type: str
    alpha: float
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float
    # Required: per-(g, t) statistics.
    group_time_ses: Dict[Tuple[Any, Any], float]
    group_time_cis: Dict[Tuple[Any, Any], Tuple[float, float]]
    group_time_p_values: Dict[Tuple[Any, Any], float]
    # Optional: aggregation statistics, populated only when requested.
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None
    group_effect_ses: Optional[Dict[Any, float]] = None
    group_effect_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_effect_p_values: Optional[Dict[Any, float]] = None
    # Potentially large array; excluded from repr to keep output readable.
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
232
+
233
+
234
+ # =============================================================================
235
+ # Bootstrap Mixin Class
236
+ # =============================================================================
237
+
238
+
239
class CallawaySantAnnaBootstrapMixin:
    """
    Mixin class providing bootstrap inference methods for CallawaySantAnna.

    This class is not intended to be used standalone. It provides methods
    that are used by the main CallawaySantAnna class for multiplier bootstrap
    inference.
    """

    # Type hints for attributes accessed from the main class
    n_bootstrap: int
    bootstrap_weight_type: str
    alpha: float
    seed: Optional[int]
    anticipation: int

    def _run_multiplier_bootstrap(
        self,
        group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
        influence_func_info: Dict[Tuple[Any, Any], Dict[str, Any]],
        aggregate: Optional[str],
        balance_e: Optional[int],
        treatment_groups: List[Any],
        time_periods: List[Any],
    ) -> "CSBootstrapResults":
        """
        Run multiplier bootstrap for inference on all parameters.

        This implements the multiplier bootstrap procedure from Callaway &
        Sant'Anna (2021). The key idea is to perturb the influence function
        contributions with random weights at the cluster (unit) level, then
        recompute all aggregations from the perturbed ATT(g,t) draws.

        Parameters
        ----------
        group_time_effects : dict
            Dictionary of ATT(g,t) effects with analytical SEs.
        influence_func_info : dict
            Dictionary mapping (g,t) to influence function information.
        aggregate : str, optional
            Type of aggregation requested.
        balance_e : int, optional
            Balance parameter for event study.
        treatment_groups : list
            List of treatment cohorts.
        time_periods : list
            List of time periods (unused here; kept for interface symmetry).

        Returns
        -------
        CSBootstrapResults
            Bootstrap inference results.
        """
        # Warn about low bootstrap iterations
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference. Percentile confidence intervals and p-values "
                "may be unreliable with few iterations.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)

        # Collect all unique units (clusters) across all (g,t) combinations.
        # Keys are not needed here, so iterate over the values only.
        all_units = set()
        for info in influence_func_info.values():
            all_units.update(info['treated_units'])
            all_units.update(info['control_units'])
        all_units = sorted(all_units)
        n_units = len(all_units)
        unit_to_idx = {u: i for i, u in enumerate(all_units)}

        # Get list of (g,t) pairs
        gt_pairs = list(group_time_effects.keys())
        n_gt = len(gt_pairs)

        # Identify post-treatment (g,t) pairs for overall ATT.
        # Pre-treatment effects are for parallel trends assessment, not aggregated.
        post_treatment_mask = np.array([
            t >= g - self.anticipation for (g, t) in gt_pairs
        ])
        post_treatment_indices = np.where(post_treatment_mask)[0]

        # Aggregation weights for overall ATT (post-treatment only),
        # proportional to the number of treated units in each cell.
        all_n_treated = np.array([
            group_time_effects[gt]['n_treated'] for gt in gt_pairs
        ], dtype=float)
        post_n_treated = all_n_treated[post_treatment_mask]

        # Flag to skip overall ATT aggregation when no post-treatment effects
        # exist; the bootstrap still runs so pre-treatment effects get SEs.
        skip_overall_aggregation = False
        if len(post_treatment_indices) == 0:
            warnings.warn(
                "No post-treatment effects for bootstrap aggregation. "
                "Overall ATT statistics will be NaN, but per-effect SEs will be computed.",
                UserWarning,
                stacklevel=2
            )
            skip_overall_aggregation = True
            overall_weights_post = np.array([])
        else:
            overall_weights_post = post_n_treated / np.sum(post_n_treated)

        # Original point estimates
        original_atts = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs])
        if skip_overall_aggregation:
            original_overall = np.nan
        else:
            original_overall = np.sum(overall_weights_post * original_atts[post_treatment_mask])

        # Prepare event study and group aggregation info if needed
        event_study_info = None
        group_agg_info = None

        if aggregate in ["event_study", "all"]:
            event_study_info = self._prepare_event_study_aggregation(
                gt_pairs, group_time_effects, balance_e
            )

        if aggregate in ["group", "all"]:
            group_agg_info = self._prepare_group_aggregation(
                gt_pairs, group_time_effects, treatment_groups
            )

        # Pre-compute unit index arrays for each (g,t) pair (done once,
        # not per iteration).
        gt_treated_indices = []
        gt_control_indices = []
        gt_treated_inf = []
        gt_control_inf = []

        for gt in gt_pairs:
            info = influence_func_info[gt]
            gt_treated_indices.append(np.array([unit_to_idx[u] for u in info['treated_units']]))
            gt_control_indices.append(np.array([unit_to_idx[u] for u in info['control_units']]))
            gt_treated_inf.append(np.asarray(info['treated_inf']))
            gt_control_inf.append(np.asarray(info['control_inf']))

        # Generate ALL bootstrap weights upfront: shape (n_bootstrap, n_units).
        # Much faster than drawing one replication at a time.
        all_bootstrap_weights = _generate_bootstrap_weights_batch(
            self.n_bootstrap, n_units, self.bootstrap_weight_type, rng
        )

        # Vectorized bootstrap ATT(g,t) computation: one matrix-vector
        # multiply per (g,t) cell across all bootstrap iterations.
        bootstrap_atts_gt = np.zeros((self.n_bootstrap, n_gt))

        for j in range(n_gt):
            treated_idx = gt_treated_indices[j]
            control_idx = gt_control_indices[j]
            treated_inf = gt_treated_inf[j]
            control_inf = gt_control_inf[j]

            # Weights for this (g,t)'s units across all bootstrap iterations:
            # shapes (n_bootstrap, n_treated) and (n_bootstrap, n_control).
            treated_weights = all_bootstrap_weights[:, treated_idx]
            control_weights = all_bootstrap_weights[:, control_idx]

            # Suppress RuntimeWarnings for edge cases (small samples, extreme
            # weights); non-finite values are handled at statistics time.
            with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                perturbations = (
                    treated_weights @ treated_inf +
                    control_weights @ control_inf
                )

            bootstrap_atts_gt[:, j] = original_atts[j] + perturbations

        # Vectorized overall ATT (post-treatment cells only): shape (n_bootstrap,).
        if skip_overall_aggregation:
            bootstrap_overall = np.full(self.n_bootstrap, np.nan)
        else:
            with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                bootstrap_overall = bootstrap_atts_gt[:, post_treatment_indices] @ overall_weights_post

        # Vectorized event study aggregation; non-finite values are handled
        # at the statistics computation stage.
        rel_periods: List[int] = []
        bootstrap_event_study: Optional[Dict[int, np.ndarray]] = None
        if event_study_info is not None:
            rel_periods = sorted(event_study_info.keys())
            bootstrap_event_study = {}
            for e in rel_periods:
                agg_info = event_study_info[e]
                gt_indices = agg_info['gt_indices']
                weights = agg_info['weights']
                with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                    bootstrap_event_study[e] = bootstrap_atts_gt[:, gt_indices] @ weights

        # Vectorized group aggregation (same pattern as event study).
        group_list: List[Any] = []
        bootstrap_group: Optional[Dict[Any, np.ndarray]] = None
        if group_agg_info is not None:
            group_list = sorted(group_agg_info.keys())
            bootstrap_group = {}
            for g in group_list:
                agg_info = group_agg_info[g]
                gt_indices = agg_info['gt_indices']
                weights = agg_info['weights']
                with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                    bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights

        # Bootstrap statistics for each ATT(g,t)
        gt_ses = {}
        gt_cis = {}
        gt_p_values = {}

        for j, gt in enumerate(gt_pairs):
            se, ci, p_value = self._compute_effect_bootstrap_stats(
                original_atts[j], bootstrap_atts_gt[:, j],
                context=f"ATT(g={gt[0]}, t={gt[1]})"
            )
            gt_ses[gt] = se
            gt_cis[gt] = ci
            gt_p_values[gt] = p_value

        # Bootstrap statistics for the overall ATT
        if skip_overall_aggregation:
            overall_se = np.nan
            overall_ci = (np.nan, np.nan)
            overall_p_value = np.nan
        else:
            overall_se, overall_ci, overall_p_value = self._compute_effect_bootstrap_stats(
                original_overall, bootstrap_overall,
                context="overall ATT"
            )

        # Bootstrap statistics for event study effects
        event_study_ses = None
        event_study_cis = None
        event_study_p_values = None

        if bootstrap_event_study is not None and event_study_info is not None:
            event_study_ses = {}
            event_study_cis = {}
            event_study_p_values = {}

            for e in rel_periods:
                se, ci, p_value = self._compute_effect_bootstrap_stats(
                    event_study_info[e]['effect'], bootstrap_event_study[e],
                    context=f"event study (e={e})"
                )
                event_study_ses[e] = se
                event_study_cis[e] = ci
                event_study_p_values[e] = p_value

        # Bootstrap statistics for group effects
        group_effect_ses = None
        group_effect_cis = None
        group_effect_p_values = None

        if bootstrap_group is not None and group_agg_info is not None:
            group_effect_ses = {}
            group_effect_cis = {}
            group_effect_p_values = {}

            for g in group_list:
                se, ci, p_value = self._compute_effect_bootstrap_stats(
                    group_agg_info[g]['effect'], bootstrap_group[g],
                    context=f"group effect (g={g})"
                )
                group_effect_ses[g] = se
                group_effect_cis[g] = ci
                group_effect_p_values[g] = p_value

        return CSBootstrapResults(
            n_bootstrap=self.n_bootstrap,
            weight_type=self.bootstrap_weight_type,
            alpha=self.alpha,
            overall_att_se=overall_se,
            overall_att_ci=overall_ci,
            overall_att_p_value=overall_p_value,
            group_time_ses=gt_ses,
            group_time_cis=gt_cis,
            group_time_p_values=gt_p_values,
            event_study_ses=event_study_ses,
            event_study_cis=event_study_cis,
            event_study_p_values=event_study_p_values,
            group_effect_ses=group_effect_ses,
            group_effect_cis=group_effect_cis,
            group_effect_p_values=group_effect_p_values,
            bootstrap_distribution=bootstrap_overall,
        )

    def _prepare_event_study_aggregation(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        group_time_effects: Dict,
        balance_e: Optional[int],
    ) -> Dict[int, Dict[str, Any]]:
        """
        Prepare aggregation info for event study bootstrap.

        Groups (g,t) cells by relative period e = t - g, optionally restricts
        to groups observed at ``balance_e`` (balanced panel in event time),
        and computes n_treated-proportional aggregation weights per period.

        Returns a dict mapping e to {'gt_indices', 'weights', 'effect'}.
        """
        # Organize by relative time
        effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}

        for j, (g, t) in enumerate(gt_pairs):
            e = t - g
            if e not in effects_by_e:
                effects_by_e[e] = []
            effects_by_e[e].append((
                j,  # index in gt_pairs
                group_time_effects[(g, t)]['effect'],
                group_time_effects[(g, t)]['n_treated']
            ))

        # Balance if requested: keep only groups observed at e == balance_e.
        if balance_e is not None:
            groups_at_e = {g for (g, t) in gt_pairs if t - g == balance_e}

            balanced_effects: Dict[int, List[Tuple[int, float, float]]] = {}
            for j, (g, t) in enumerate(gt_pairs):
                if g in groups_at_e:
                    e = t - g
                    if e not in balanced_effects:
                        balanced_effects[e] = []
                    balanced_effects[e].append((
                        j,
                        group_time_effects[(g, t)]['effect'],
                        group_time_effects[(g, t)]['n_treated']
                    ))
            effects_by_e = balanced_effects

        # Compute aggregation weights (proportional to n_treated within e).
        result = {}
        for e, effect_list in effects_by_e.items():
            indices = np.array([x[0] for x in effect_list])
            effects = np.array([x[1] for x in effect_list])
            n_treated = np.array([x[2] for x in effect_list], dtype=float)

            weights = n_treated / np.sum(n_treated)
            agg_effect = np.sum(weights * effects)

            result[e] = {
                'gt_indices': indices,
                'weights': weights,
                'effect': agg_effect,
            }

        return result

    def _prepare_group_aggregation(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        group_time_effects: Dict,
        treatment_groups: List[Any],
    ) -> Dict[Any, Dict[str, Any]]:
        """
        Prepare aggregation info for group-level bootstrap.

        For each cohort g, collects its post-treatment cells
        (t >= g - anticipation) and averages them with equal weights.
        Groups with no post-treatment cells are omitted from the result.
        """
        result = {}

        for g in treatment_groups:
            # Post-treatment effects only for this group.
            group_data = []
            for j, (gg, t) in enumerate(gt_pairs):
                if gg == g and t >= g - self.anticipation:
                    group_data.append((
                        j,
                        group_time_effects[(gg, t)]['effect'],
                    ))

            if not group_data:
                continue

            indices = np.array([x[0] for x in group_data])
            effects = np.array([x[1] for x in group_data])

            # Equal weights across time periods
            weights = np.ones(len(effects)) / len(effects)
            agg_effect = np.sum(weights * effects)

            result[g] = {
                'gt_indices': indices,
                'weights': weights,
                'effect': agg_effect,
            }

        return result

    def _compute_percentile_ci(
        self,
        boot_dist: np.ndarray,
        alpha: float,
    ) -> Tuple[float, float]:
        """Compute percentile confidence interval from bootstrap distribution."""
        lower = float(np.percentile(boot_dist, alpha / 2 * 100))
        upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
        return (lower, upper)

    def _compute_bootstrap_pvalue(
        self,
        original_effect: float,
        boot_dist: np.ndarray,
        n_valid: Optional[int] = None,
    ) -> float:
        """
        Compute two-sided bootstrap p-value.

        Uses the percentile method: p-value is the proportion of bootstrap
        estimates on the opposite side of zero from the original estimate,
        doubled for two-sided test.

        Parameters
        ----------
        original_effect : float
            Original point estimate.
        boot_dist : np.ndarray
            Bootstrap distribution of the effect.
        n_valid : int, optional
            Number of valid bootstrap samples. If None, uses self.n_bootstrap.
            Use this when boot_dist has already been filtered for non-finite
            values so the p-value floor reflects the actual valid sample count.

        Returns
        -------
        float
            Two-sided bootstrap p-value.
        """
        if original_effect >= 0:
            # Proportion of bootstrap estimates <= 0
            p_one_sided = np.mean(boot_dist <= 0)
        else:
            # Proportion of bootstrap estimates >= 0
            p_one_sided = np.mean(boot_dist >= 0)

        # Two-sided p-value
        p_value = min(2 * p_one_sided, 1.0)

        # Floor the p-value at 1/(B+1) so it is never exactly zero.
        n_for_floor = n_valid if n_valid is not None else self.n_bootstrap
        p_value = max(p_value, 1 / (n_for_floor + 1))

        return float(p_value)

    def _compute_effect_bootstrap_stats(
        self,
        original_effect: float,
        boot_dist: np.ndarray,
        context: str = "bootstrap distribution",
    ) -> Tuple[float, Tuple[float, float], float]:
        """
        Compute bootstrap statistics for a single effect.

        Non-finite bootstrap samples are dropped and a warning is issued if any
        are present. If too few valid samples remain (<50%), returns NaN for all
        statistics to signal invalid inference.

        Parameters
        ----------
        original_effect : float
            Original point estimate.
        boot_dist : np.ndarray
            Bootstrap distribution of the effect.
        context : str, optional
            Description for warning messages, by default "bootstrap distribution".

        Returns
        -------
        se : float
            Bootstrap standard error.
        ci : Tuple[float, float]
            Percentile confidence interval.
        p_value : float
            Bootstrap p-value.
        """
        # Filter out non-finite values (uses the module-level warnings import)
        finite_mask = np.isfinite(boot_dist)
        n_valid = int(np.sum(finite_mask))
        n_total = len(boot_dist)

        if n_valid < n_total:
            n_nonfinite = n_total - n_valid
            warnings.warn(
                f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples in {context}. "
                "This may occur with very small samples or extreme weights. "
                "Bootstrap estimates based on remaining valid samples.",
                RuntimeWarning,
                stacklevel=3
            )

        # Check if we have enough valid samples
        if n_valid < n_total * 0.5:
            warnings.warn(
                f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. "
                "Returning NaN for SE/CI/p-value to signal invalid inference.",
                RuntimeWarning,
                stacklevel=3
            )
            return np.nan, (np.nan, np.nan), np.nan

        # Use only valid samples
        valid_dist = boot_dist[finite_mask]

        se = float(np.std(valid_dist, ddof=1))
        ci = self._compute_percentile_ci(valid_dist, self.alpha)

        # P-value floor is based on the count of valid samples, not n_bootstrap.
        p_value = self._compute_bootstrap_pvalue(
            original_effect, valid_dist, n_valid=len(valid_dist)
        )

        return se, ci, p_value