diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1626 @@
1
+ """
2
+ Continuous Difference-in-Differences estimator.
3
+
4
+ Implements Callaway, Goodman-Bacon & Sant'Anna (2024),
5
+ "Difference-in-Differences with a Continuous Treatment" (NBER WP 32117).
6
+
7
+ Estimates dose-response curves ATT(d) and ACRT(d), as well as summary
8
+ parameters ATT^{glob} and ACRT^{glob}, with optional multiplier bootstrap
9
+ inference.
10
+ """
11
+
12
+ import warnings
13
+ from typing import Any, Dict, List, Optional, Tuple
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from diff_diff.bootstrap_utils import (
19
+ compute_effect_bootstrap_stats,
20
+ generate_bootstrap_weights_batch,
21
+ )
22
+ from diff_diff.continuous_did_bspline import (
23
+ bspline_derivative_design_matrix,
24
+ bspline_design_matrix,
25
+ build_bspline_basis,
26
+ default_dose_grid,
27
+ )
28
+ from diff_diff.continuous_did_results import (
29
+ ContinuousDiDResults,
30
+ DoseResponseCurve,
31
+ )
32
+ from diff_diff.linalg import solve_ols
33
+ from diff_diff.survey import (
34
+ ResolvedSurveyDesign,
35
+ _resolve_survey_for_fit,
36
+ _validate_unit_constant_survey,
37
+ compute_survey_vcov,
38
+ )
39
+ from diff_diff.utils import safe_inference
40
+
41
+ __all__ = ["ContinuousDiD", "ContinuousDiDResults", "DoseResponseCurve"]
42
+
43
+
44
+ class ContinuousDiD:
45
+ """
46
+ Continuous Difference-in-Differences estimator.
47
+
48
+ Implements the methodology from Callaway, Goodman-Bacon & Sant'Anna (2024)
49
+ for estimating dose-response curves when treatment has a continuous intensity.
50
+
51
+ Parameters
52
+ ----------
53
+ degree : int, default=3
54
+ B-spline degree (3 = cubic).
55
+ num_knots : int, default=0
56
+ Number of interior knots for the B-spline basis.
57
+ dvals : array-like, optional
58
+ Custom dose evaluation grid. If None, uses quantile-based default.
59
+ control_group : str, default="never_treated"
60
+ ``"never_treated"`` or ``"not_yet_treated"``.
61
+ anticipation : int, default=0
62
+ Number of periods of treatment anticipation.
63
+ base_period : str, default="varying"
64
+ ``"varying"`` or ``"universal"``.
65
+ alpha : float, default=0.05
66
+ Significance level for confidence intervals.
67
+ n_bootstrap : int, default=0
68
+ Number of multiplier bootstrap iterations. 0 for analytical SEs only.
69
+ bootstrap_weights : str, default="rademacher"
70
+ Bootstrap weight type: ``"rademacher"``, ``"mammen"``, or ``"webb"``.
71
+ seed : int, optional
72
+ Random seed for reproducibility.
73
+ rank_deficient_action : str, default="warn"
74
+ Action for rank-deficient B-spline OLS: ``"warn"``, ``"error"``, or ``"silent"``.
75
+
76
+ Examples
77
+ --------
78
+ >>> from diff_diff import ContinuousDiD, generate_continuous_did_data
79
+ >>> data = generate_continuous_did_data(n_units=200, seed=42)
80
+ >>> est = ContinuousDiD(n_bootstrap=199, seed=42)
81
+ >>> results = est.fit(data, outcome="outcome", unit="unit",
82
+ ... time="period", first_treat="first_treat",
83
+ ... dose="dose", aggregate="dose")
84
+ >>> results.overall_att # doctest: +SKIP
85
+ """
86
+
87
+ _VALID_CONTROL_GROUPS = {"never_treated", "not_yet_treated"}
88
+ _VALID_BASE_PERIODS = {"varying", "universal"}
89
+
90
def __init__(
    self,
    degree: int = 3,
    num_knots: int = 0,
    dvals: Optional[np.ndarray] = None,
    control_group: str = "never_treated",
    anticipation: int = 0,
    base_period: str = "varying",
    alpha: float = 0.05,
    n_bootstrap: int = 0,
    bootstrap_weights: str = "rademacher",
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
):
    """Store the estimator configuration and validate the enumerated options.

    Every argument is saved verbatim as an attribute of the same name,
    except ``dvals``, which is coerced to a float ``numpy`` array when it
    is given. After all attributes are set, ``control_group`` and
    ``base_period`` are checked via ``_validate_constrained_params``,
    which raises ``ValueError`` on an unsupported value.
    """
    # B-spline basis configuration.
    self.degree = degree
    self.num_knots = num_knots
    # Optional custom evaluation grid, normalised to a float array.
    self.dvals = None if dvals is None else np.asarray(dvals, dtype=float)
    # Identification / design choices.
    self.control_group = control_group
    self.anticipation = anticipation
    self.base_period = base_period
    # Inference settings.
    self.alpha = alpha
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed
    self.rank_deficient_action = rank_deficient_action
    # Validate only after every attribute exists.
    self._validate_constrained_params()
116
+
117
def _validate_constrained_params(self) -> None:
    """Raise ``ValueError`` if ``control_group`` or ``base_period`` is unsupported.

    The allowed values come from the ``_VALID_CONTROL_GROUPS`` and
    ``_VALID_BASE_PERIODS`` attributes. ``control_group`` is checked first,
    so when both values are invalid the error names the control group.
    """
    control_ok = self.control_group in self._VALID_CONTROL_GROUPS
    if not control_ok:
        raise ValueError(
            f"Invalid control_group: '{self.control_group}'. "
            f"Must be one of {self._VALID_CONTROL_GROUPS}."
        )

    base_ok = self.base_period in self._VALID_BASE_PERIODS
    if not base_ok:
        raise ValueError(
            f"Invalid base_period: '{self.base_period}'. "
            f"Must be one of {self._VALID_BASE_PERIODS}."
        )
129
+
130
def get_params(self) -> Dict[str, Any]:
    """Return the constructor parameters of this estimator.

    Returns
    -------
    dict
        Mapping from each ``__init__`` argument name to the current value of
        the attribute with the same name, in constructor order.
    """
    names = (
        "degree",
        "num_knots",
        "dvals",
        "control_group",
        "anticipation",
        "base_period",
        "alpha",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "rank_deficient_action",
    )
    return {name: getattr(self, name) for name in names}
145
+
146
def set_params(self, **params) -> "ContinuousDiD":
    """Set estimator parameters and return ``self``.

    Only constructor parameters, i.e. the keys returned by ``get_params()``,
    can be set. This stops callers from accidentally overwriting methods or
    class constants. ``dvals`` is coerced to a float array, as in
    ``__init__``.

    The update is all-or-nothing:

    * Unknown keys are rejected before any attribute changes.
    * If the new ``control_group`` / ``base_period`` fails validation, every
      changed attribute is restored before the error propagates.

    Raises
    ------
    ValueError
        If a key is not a constructor parameter, or if validation of
        ``control_group`` / ``base_period`` fails.
    """
    valid_keys = self.get_params()
    for key in params:
        if key not in valid_keys:
            raise ValueError(f"Invalid parameter: {key}")

    # Snapshot current values so a failed validation can be rolled back.
    previous = {key: getattr(self, key) for key in params}
    for key, value in params.items():
        if key == "dvals" and value is not None:
            # Keep the same normalisation that __init__ applies.
            value = np.asarray(value, dtype=float)
        setattr(self, key, value)

    try:
        self._validate_constrained_params()
    except ValueError:
        for key, value in previous.items():
            setattr(self, key, value)
        raise
    return self
154
+
155
+ # ------------------------------------------------------------------
156
+ # Main fit
157
+ # ------------------------------------------------------------------
158
+
159
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    dose: str,
    aggregate: Optional[str] = None,
    survey_design: object = None,
) -> ContinuousDiDResults:
    """
    Fit the continuous DiD estimator.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with exactly one row per (unit, period).
    outcome : str
        Outcome column name.
    unit : str
        Unit identifier column.
    time : str
        Time period column.
    first_treat : str
        First treatment period column (0 or inf for never-treated).
    dose : str
        Continuous dose column.
    aggregate : str, optional
        ``"dose"`` for dose-response aggregation, ``"eventstudy"`` for
        binarized event study.
    survey_design : SurveyDesign, optional
        Survey design specification for design-based inference.
        Supports weighted estimation and Taylor series linearization
        variance with strata, PSU, and FPC.

    Returns
    -------
    ContinuousDiDResults

    Raises
    ------
    ValueError
        Raised in any of these cases:

        * ``aggregate`` is not one of ``None``, ``"dose"``, ``"eventstudy"``.
        * A required column is missing.
        * Dose varies within a unit.
        * A treated unit has a negative dose.
        * The panel has duplicate (unit, time) rows or is unbalanced.
        * There are no treated units or no never-treated (D=0) units.
        * No (g,t) cell yields a finite estimate.

    Notes
    -----
    Units with positive ``first_treat`` but zero dose are dropped with a
    ``UserWarning``. Nonzero doses of never-treated units are forced to 0.
    """
    # 1. Validate & prepare
    _VALID_AGGREGATES = (None, "dose", "eventstudy")
    if aggregate not in _VALID_AGGREGATES:
        raise ValueError(
            f"Invalid aggregate: '{aggregate}'. " f"Must be one of {_VALID_AGGREGATES}."
        )

    # Check columns first so a missing column gives a clear error instead of
    # surfacing from inside survey resolution.
    for col in [outcome, unit, time, first_treat, dose]:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in data.")

    # Resolve survey design if provided
    resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
        _resolve_survey_for_fit(survey_design, data, "analytical")
    )

    # Validate within-unit constancy for panel survey designs
    if resolved_survey is not None:
        _validate_unit_constant_survey(data, unit, survey_design)

    # Bootstrap + survey supported via PSU-level multiplier bootstrap.

    df = data.copy()

    # Verify dose is time-invariant
    dose_nunique = df.groupby(unit)[dose].nunique()
    if dose_nunique.max() > 1:
        bad_units = dose_nunique[dose_nunique > 1].index.tolist()
        raise ValueError(
            f"Dose must be time-invariant. Units with varying dose: {bad_units[:5]}"
        )

    # Normalize first_treat: inf → 0 (never-treated)
    df[first_treat] = df[first_treat].replace(np.inf, 0)

    # Drop units with positive first_treat but zero dose (R convention)
    unit_info = df.groupby(unit).first()[[first_treat, dose]]
    drop_units = unit_info[(unit_info[first_treat] > 0) & (unit_info[dose] == 0)].index
    if len(drop_units) > 0:
        warnings.warn(
            f"Dropping {len(drop_units)} units with positive first_treat but zero dose.",
            UserWarning,
            stacklevel=2,
        )
        # .copy() so the later .loc assignment works on an owned frame.
        df = df[~df[unit].isin(drop_units)].copy()

    # Validate no negative doses among treated units. Count distinct units,
    # not panel rows, so the message's "unit(s)" is accurate.
    neg_units = df.loc[(df[first_treat] > 0) & (df[dose] < 0), unit].unique()
    if len(neg_units) > 0:
        raise ValueError(
            f"Found {len(neg_units)} treated unit(s) with negative dose. "
            f"Dose must be strictly positive for treated units (D > 0)."
        )

    # Detect discrete (integer-valued) dose among treated units
    unit_doses = df.loc[df[first_treat] > 0].groupby(unit)[dose].first()
    unique_pos_doses = unit_doses[unit_doses > 0].unique()
    is_integer = len(unique_pos_doses) > 0 and np.allclose(
        unique_pos_doses, np.round(unique_pos_doses)
    )
    if is_integer:
        warnings.warn(
            f"Dose appears discrete ({len(unique_pos_doses)} unique integer values). "
            "B-spline smoothing may be inappropriate for discrete treatments. "
            "Consider a saturated regression approach (not yet implemented).",
            UserWarning,
            stacklevel=2,
        )

    # Force dose=0 for never-treated units with nonzero dose
    never_treated_mask = df[first_treat] == 0
    if (df.loc[never_treated_mask, dose] != 0).any():
        df.loc[never_treated_mask, dose] = 0.0

    # Reject duplicate (unit, time) rows. The set-based balance check below
    # cannot see them, and they would silently corrupt the wide outcome
    # matrix and the unit -> panel-row mapping built during precomputation.
    n_dup = int(df.duplicated(subset=[unit, time]).sum())
    if n_dup > 0:
        raise ValueError(
            f"Found {n_dup} duplicate (unit, time) row(s). "
            "ContinuousDiD requires exactly one observation per unit and period."
        )

    # Verify balanced panel
    all_periods = set(df[time].unique())
    unit_periods = df.groupby(unit)[time].apply(set)
    is_unbalanced = unit_periods.apply(lambda s: s != all_periods)
    if is_unbalanced.any():
        n_bad = int(is_unbalanced.sum())
        raise ValueError(
            "Unbalanced panel detected. ContinuousDiD requires a balanced panel. "
            f"{n_bad} unit(s) have missing periods."
        )

    # Identify groups and time periods
    unit_cohort = df.groupby(unit)[first_treat].first()
    treatment_groups = sorted([g for g in unit_cohort.unique() if g > 0])
    time_periods = sorted(df[time].unique())

    if len(treatment_groups) == 0:
        raise ValueError("No treated units found (all first_treat == 0).")

    n_control = int((unit_cohort == 0).sum())
    if self.control_group == "never_treated" and n_control == 0:
        raise ValueError(
            "No never-treated units found. Use control_group='not_yet_treated' "
            "or add never-treated units."
        )

    if self.control_group == "not_yet_treated" and n_control == 0:
        raise ValueError(
            "No never-treated (D=0) units found. With control_group='not_yet_treated', "
            "dose-response curve identification requires P(D=0) > 0 "
            "(Remark 3.1 in Callaway et al. is not yet implemented). "
            "Add never-treated units or use a dataset with D=0 observations."
        )

    # Re-resolve survey design on filtered df if rows were dropped
    # (survey arrays must align with df, not the original data)
    if resolved_survey is not None and len(df) < len(data):
        resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
            _resolve_survey_for_fit(survey_design, df, "analytical")
        )

    # 2. Precompute structures
    precomp = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        first_treat,
        dose,
        time_periods,
        survey_weights=survey_weights,
    )

    # Compute dvals (evaluation grid)
    all_treated_doses = precomp["dose_vector"][precomp["dose_vector"] > 0]
    if self.dvals is not None:
        dvals = self.dvals
    else:
        dvals = default_dose_grid(all_treated_doses)

    # Build B-spline knots from all treated doses
    knots, degree = build_bspline_basis(
        all_treated_doses, degree=self.degree, num_knots=self.num_knots
    )

    # 3. Iterate over (g,t) cells
    gt_results = {}
    gt_bootstrap_info = {}

    for g in treatment_groups:
        for t in time_periods:
            result = self._compute_dose_response_gt(
                precomp,
                g,
                t,
                knots,
                degree,
                dvals,
                survey_weights=precomp.get("unit_survey_weights"),
                resolved_survey=resolved_survey,
            )
            if result is not None:
                gt_results[(g, t)] = result
                gt_bootstrap_info[(g, t)] = result.get("_bootstrap_info", {})

    # Filter out NaN cells (e.g., from zero effective survey mass)
    gt_results = {
        gt: r for gt, r in gt_results.items()
        if np.isfinite(r.get("att_glob", np.nan))
    }
    # Keep the per-cell inference inputs in sync with the retained cells so
    # dropped (non-finite) cells cannot leak into bootstrap/analytical SEs.
    gt_bootstrap_info = {gt: gt_bootstrap_info[gt] for gt in gt_results}

    if len(gt_results) == 0:
        raise ValueError("No valid (g,t) cells computed.")

    # 4. Aggregate
    post_gt = {(g, t): r for (g, t), r in gt_results.items() if t >= g - self.anticipation}

    # Dose-response aggregation
    n_grid = len(dvals)

    # NaN-initialized SE/CI fields (used when post_gt is empty or as defaults)
    att_d_se = np.full(n_grid, np.nan)
    att_d_ci_lower = np.full(n_grid, np.nan)
    att_d_ci_upper = np.full(n_grid, np.nan)
    acrt_d_se = np.full(n_grid, np.nan)
    acrt_d_ci_lower = np.full(n_grid, np.nan)
    acrt_d_ci_upper = np.full(n_grid, np.nan)
    overall_att_se = np.nan
    overall_att_t = np.nan
    overall_att_p = np.nan
    overall_att_ci = (np.nan, np.nan)
    overall_acrt_se = np.nan
    overall_acrt_t = np.nan
    overall_acrt_p = np.nan
    overall_acrt_ci = (np.nan, np.nan)
    att_d_p = None
    acrt_d_p = None

    # Event study aggregation (binarized) — runs on ALL (g,t) cells
    event_study_effects = None
    if aggregate == "eventstudy":
        event_study_effects = self._aggregate_event_study(
            gt_results,
            gt_bootstrap_info=gt_bootstrap_info,
            unit_survey_weights=precomp.get("unit_survey_weights"),
            unit_cohorts=precomp["unit_cohorts"],
            anticipation=self.anticipation,
        )

    _survey_df = None  # Set by analytical branch when survey is active

    if len(post_gt) == 0:
        warnings.warn(
            "No post-treatment (g,t) cells available for aggregation. "
            "This can occur when all treatments start after the last observed "
            "period or all cells were skipped due to insufficient data.",
            UserWarning,
            stacklevel=2,
        )
        overall_att = np.nan
        overall_acrt = np.nan
        agg_att_d = np.full(n_grid, np.nan)
        agg_acrt_d = np.full(n_grid, np.nan)
    else:
        # Compute cell weights: group-proportional (matching R's contdid convention).
        # Each group g gets weight proportional to its number of treated units.
        # When survey weights present, use sum(w_g) / sum(w) instead of n_g / N.
        # Within each group, weight is divided equally among post-treatment cells.
        group_n_treated = {}
        group_n_post_cells = {}
        unit_sw = precomp.get("unit_survey_weights")
        for (g, t), r in post_gt.items():
            if g not in group_n_treated:
                if unit_sw is not None:
                    # Survey-weighted group size: sum of weights for treated units in g
                    g_mask = precomp["unit_cohorts"] == g
                    group_n_treated[g] = float(np.sum(unit_sw[g_mask]))
                else:
                    group_n_treated[g] = float(r["n_treated"])
                group_n_post_cells[g] = 0
            group_n_post_cells[g] += 1

        total_treated = sum(group_n_treated.values())
        cell_weights = {}
        if total_treated > 0:
            for (g, t), r in post_gt.items():
                pg = group_n_treated[g] / total_treated
                cell_weights[(g, t)] = pg / group_n_post_cells[g]

        agg_att_d = np.zeros(n_grid)
        agg_acrt_d = np.zeros(n_grid)
        overall_att = 0.0
        overall_acrt = 0.0

        for gt, w in cell_weights.items():
            r = post_gt[gt]
            agg_att_d += w * r["att_d"]
            agg_acrt_d += w * r["acrt_d"]
            overall_att += w * r["att_glob"]
            overall_acrt += w * r["acrt_glob"]

        # 5. Bootstrap / Analytical SE
        if self.n_bootstrap > 0:
            boot_result = self._run_bootstrap(
                precomp,
                gt_results,
                gt_bootstrap_info,
                post_gt,
                cell_weights,
                knots,
                degree,
                dvals,
                overall_att,
                overall_acrt,
                agg_att_d,
                agg_acrt_d,
                event_study_effects,
                resolved_survey=resolved_survey,
            )
            att_d_se = boot_result["att_d_se"]
            att_d_ci_lower = boot_result["att_d_ci_lower"]
            att_d_ci_upper = boot_result["att_d_ci_upper"]
            acrt_d_se = boot_result["acrt_d_se"]
            acrt_d_ci_lower = boot_result["acrt_d_ci_lower"]
            acrt_d_ci_upper = boot_result["acrt_d_ci_upper"]
            att_d_p = boot_result["att_d_p"]
            acrt_d_p = boot_result["acrt_d_p"]
            overall_att_se = boot_result["overall_att_se"]
            overall_att_t = safe_inference(overall_att, overall_att_se, self.alpha)[0]
            overall_att_p = boot_result["overall_att_p"]
            overall_att_ci = boot_result["overall_att_ci"]
            overall_acrt_se = boot_result["overall_acrt_se"]
            overall_acrt_t = safe_inference(overall_acrt, overall_acrt_se, self.alpha)[0]
            overall_acrt_p = boot_result["overall_acrt_p"]
            overall_acrt_ci = boot_result["overall_acrt_ci"]
            if event_study_effects is not None:
                for e, info in event_study_effects.items():
                    if e in boot_result.get("es_se", {}):
                        info["se"] = boot_result["es_se"][e]
                        info["t_stat"] = safe_inference(
                            info["effect"], info["se"], self.alpha
                        )[0]
                        info["p_value"] = boot_result["es_p"][e]
                        info["conf_int"] = boot_result["es_ci"][e]
        else:
            # Analytical SEs via influence functions
            analytic = self._compute_analytical_se(
                precomp,
                gt_results,
                gt_bootstrap_info,
                post_gt,
                cell_weights,
                knots,
                degree,
                dvals,
                agg_att_d,
                agg_acrt_d,
                resolved_survey=resolved_survey,
            )
            att_d_se = analytic["att_d_se"]
            acrt_d_se = analytic["acrt_d_se"]
            overall_att_se = analytic["overall_att_se"]
            overall_acrt_se = analytic["overall_acrt_se"]

            # Survey df for t-distribution inference (unit-level, not panel-level)
            _survey_df = analytic.get("df_survey")
            # Guard: replicate design with undefined df → NaN inference
            if (
                _survey_df is None
                and resolved_survey is not None
                and hasattr(resolved_survey, "uses_replicate_variance")
                and resolved_survey.uses_replicate_variance
            ):
                _survey_df = 0

            # Recompute survey_metadata from unit-level design so reported
            # effective_n/n_psu/df_survey match the inference actually run
            _unit_resolved = analytic.get("unit_resolved")
            if _unit_resolved is not None:
                from diff_diff.survey import compute_survey_metadata

                raw_w_unit = _unit_resolved.weights
                survey_metadata = compute_survey_metadata(_unit_resolved, raw_w_unit)

            # Propagate replicate df override to survey_metadata for display
            # (but not the df=0 sentinel — keep metadata as None for undefined df)
            if (
                _survey_df is not None
                and _survey_df != 0
                and survey_metadata is not None
            ):
                if survey_metadata.df_survey != _survey_df:
                    survey_metadata.df_survey = _survey_df

            overall_att_t, overall_att_p, overall_att_ci = safe_inference(
                overall_att, overall_att_se, self.alpha, df=_survey_df
            )
            overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference(
                overall_acrt, overall_acrt_se, self.alpha, df=_survey_df
            )

            # Per-grid-point inference for dose-response
            for idx in range(n_grid):
                _, _, ci = safe_inference(
                    agg_att_d[idx], att_d_se[idx], self.alpha, df=_survey_df
                )
                att_d_ci_lower[idx] = ci[0]
                att_d_ci_upper[idx] = ci[1]

                _, _, ci = safe_inference(
                    agg_acrt_d[idx], acrt_d_se[idx], self.alpha, df=_survey_df
                )
                acrt_d_ci_lower[idx] = ci[0]
                acrt_d_ci_upper[idx] = ci[1]

            # Event study analytical SEs
            if event_study_effects is not None:
                n_units = precomp["n_units"]
                unit_sw = precomp.get("unit_survey_weights")

                # Build unit-level ResolvedSurveyDesign once (reused per bin)
                unit_resolved_es = None
                if resolved_survey is not None:
                    row_idx = precomp["unit_first_panel_row"]
                    uw = unit_sw if unit_sw is not None else np.ones(n_units)
                    us = (
                        resolved_survey.strata[row_idx]
                        if resolved_survey.strata is not None
                        else None
                    )
                    up = (
                        resolved_survey.psu[row_idx]
                        if resolved_survey.psu is not None
                        else None
                    )
                    uf = (
                        resolved_survey.fpc[row_idx]
                        if resolved_survey.fpc is not None
                        else None
                    )
                    n_strata_u = len(np.unique(us)) if us is not None else 0
                    n_psu_u = len(np.unique(up)) if up is not None else 0
                    unit_resolved_es = resolved_survey.subset_to_units(
                        row_idx, uw, us, up, uf, n_strata_u, n_psu_u,
                    )

                for e_val, info_e in event_study_effects.items():
                    # Collect (g,t) cells for this event-time bin
                    e_gts = [gt for gt in gt_results if gt[1] - gt[0] == e_val]
                    if not e_gts:
                        continue
                    # Weights within this bin: survey-weighted mass or n_treated
                    if unit_sw is not None:
                        unit_cohorts = precomp["unit_cohorts"]
                        ns = np.array(
                            [float(np.sum(unit_sw[unit_cohorts == gt[0]])) for gt in e_gts],
                            dtype=float,
                        )
                    else:
                        ns = np.array(
                            [gt_results[gt]["n_treated"] for gt in e_gts],
                            dtype=float,
                        )
                    total_n = ns.sum()
                    if total_n == 0:
                        continue
                    ws = ns / total_n

                    # Build per-unit IF for this event-time bin
                    if_es = np.zeros(n_units)
                    for idx_cell, gt in enumerate(e_gts):
                        b_info = gt_bootstrap_info.get(gt, {})
                        if not b_info:
                            continue
                        w = ws[idx_cell]
                        treated_idx = b_info["treated_indices"]
                        control_idx = b_info["control_indices"]
                        n_t = b_info["n_treated"]
                        n_c = b_info["n_control"]
                        # Use survey-weighted masses when available
                        if "w_treated" in b_info:
                            n_t = b_info["w_treated"]
                            n_c = b_info["w_control"]
                        n_total_gt = n_t + n_c
                        p_1 = n_t / n_total_gt
                        p_0 = n_c / n_total_gt
                        att_glob_gt = b_info["att_glob"]
                        mu_0 = b_info["mu_0"]
                        delta_y_treated = b_info["delta_y_treated"]
                        ee_control = b_info["ee_control"]
                        sw_treated = b_info.get("w_treated_arr")

                        for k, uid in enumerate(treated_idx):
                            score_k = delta_y_treated[k] - att_glob_gt - mu_0
                            if sw_treated is not None:
                                score_k = sw_treated[k] * score_k
                            if_es[uid] += w * score_k / p_1 / n_total_gt
                        for k, uid in enumerate(control_idx):
                            if_es[uid] -= w * ee_control[k] / p_0 / n_total_gt

                    # Compute SE: survey-aware TSL or standard sqrt(sum(IF^2))
                    if unit_resolved_es is not None:
                        if unit_resolved_es.uses_replicate_variance:
                            from diff_diff.survey import compute_replicate_if_variance

                            # Score-scale: psi = w * if_es (matches TSL bread)
                            psi_es = unit_resolved_es.weights * if_es
                            variance, _nv = compute_replicate_if_variance(
                                psi_es, unit_resolved_es
                            )
                            es_se = (
                                float(np.sqrt(max(variance, 0.0)))
                                if np.isfinite(variance)
                                else np.nan
                            )
                        else:
                            X_ones_es = np.ones((n_units, 1))
                            tsl_scale_es = float(unit_resolved_es.weights.sum())
                            if_es_tsl = if_es * tsl_scale_es
                            vcov_es = compute_survey_vcov(
                                X_ones_es, if_es_tsl, unit_resolved_es
                            )
                            es_se = float(np.sqrt(np.abs(vcov_es[0, 0])))
                    else:
                        es_se = float(np.sqrt(np.sum(if_es**2)))

                    t_stat, p_val, ci_es = safe_inference(
                        info_e["effect"], es_se, self.alpha, df=_survey_df
                    )
                    info_e["se"] = es_se
                    info_e["t_stat"] = t_stat
                    info_e["p_value"] = p_val
                    info_e["conf_int"] = ci_es

    # 6. Assemble results
    dose_response_att = DoseResponseCurve(
        dose_grid=dvals,
        effects=agg_att_d,
        se=att_d_se,
        conf_int_lower=att_d_ci_lower,
        conf_int_upper=att_d_ci_upper,
        target="att",
        p_value=att_d_p,
        n_bootstrap=self.n_bootstrap,
        df_survey=_survey_df,
    )
    dose_response_acrt = DoseResponseCurve(
        dose_grid=dvals,
        effects=agg_acrt_d,
        se=acrt_d_se,
        conf_int_lower=acrt_d_ci_lower,
        conf_int_upper=acrt_d_ci_upper,
        target="acrt",
        p_value=acrt_d_p,
        n_bootstrap=self.n_bootstrap,
        df_survey=_survey_df,
    )

    # Strip bootstrap internals (keys starting with "_") from gt_results
    clean_gt = {}
    for gt, r in gt_results.items():
        clean_gt[gt] = {k: v for k, v in r.items() if not k.startswith("_")}

    return ContinuousDiDResults(
        dose_response_att=dose_response_att,
        dose_response_acrt=dose_response_acrt,
        overall_att=overall_att,
        overall_att_se=overall_att_se,
        overall_att_t_stat=overall_att_t,
        overall_att_p_value=overall_att_p,
        overall_att_conf_int=overall_att_ci,
        overall_acrt=overall_acrt,
        overall_acrt_se=overall_acrt_se,
        overall_acrt_t_stat=overall_acrt_t,
        overall_acrt_p_value=overall_acrt_p,
        overall_acrt_conf_int=overall_acrt_ci,
        group_time_effects=clean_gt,
        dose_grid=dvals,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=int((unit_cohort > 0).sum()),
        n_control_units=n_control,
        alpha=self.alpha,
        control_group=self.control_group,
        degree=self.degree,
        num_knots=self.num_knots,
        base_period=self.base_period,
        anticipation=self.anticipation,
        n_bootstrap=self.n_bootstrap,
        bootstrap_weights=self.bootstrap_weights,
        seed=self.seed,
        rank_deficient_action=self.rank_deficient_action,
        event_study_effects=event_study_effects,
        survey_metadata=survey_metadata,
    )
739
+
740
+ # ------------------------------------------------------------------
741
+ # Internal helpers
742
+ # ------------------------------------------------------------------
743
+
744
def _precompute_structures(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    dose: str,
    time_periods: List[Any],
    survey_weights: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
    """Pivot the long panel to wide arrays and build per-unit lookup structures.

    Units are indexed in sorted order of ``df[unit]``. Periods are indexed by
    their position in ``time_periods``. The work is vectorised (no
    ``DataFrame.iterrows``), so it scales linearly in the number of rows.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format panel.
    outcome, unit, time, first_treat, dose : str
        Column names in ``df``.
    time_periods : list
        Ordered periods. Column ``j`` of the outcome matrix is
        ``time_periods[j]``. Every value of ``df[time]`` must appear here;
        otherwise a ``KeyError`` is raised.
    survey_weights : np.ndarray, optional
        Weights aligned positionally with the rows of ``df``. Each unit's
        weight is taken from its first row in ``df``.

    Returns
    -------
    dict
        The following keys:

        * ``all_units``, ``unit_to_idx``, ``period_to_col``, ``time_periods``,
          ``n_units``: index bookkeeping.
        * ``outcome_matrix``: ``(n_units, n_periods)``, NaN where missing.
        * ``unit_cohorts`` / ``dose_vector``: first non-null
          ``first_treat`` / ``dose`` per unit.
        * ``cohort_masks``: cohort value -> boolean unit mask.
        * ``never_treated_mask``: ``unit_cohorts == 0``.
        * ``unit_survey_weights``: per-unit weights or ``None``.
        * ``unit_first_panel_row``: positional row of each unit's first
          observation.
    """
    all_units = sorted(df[unit].unique())
    unit_to_idx = {u: i for i, u in enumerate(all_units)}
    n_units = len(all_units)
    n_periods = len(time_periods)
    period_to_col = {t: j for j, t in enumerate(time_periods)}

    # Positional row -> (unit index, period column). Dict lookups raise
    # KeyError for an unknown unit/period.
    row_unit_idx = np.array([unit_to_idx[u] for u in df[unit]], dtype=int)
    row_time_col = np.array([period_to_col[t] for t in df[time]], dtype=int)

    # Outcome matrix: (n_units, n_periods)
    outcome_matrix = np.full((n_units, n_periods), np.nan)
    values = df[outcome].to_numpy(dtype=float, na_value=np.nan)
    # NumPy leaves the winner of a repeated fancy-index assignment
    # unspecified. Keep the LAST row per (unit, period) explicitly so
    # repeated cells resolve the same way as row-by-row assignment.
    flat_cell = row_unit_idx * n_periods + row_time_col
    _, first_from_end = np.unique(flat_cell[::-1], return_index=True)
    last_rows = len(flat_cell) - 1 - first_from_end
    outcome_matrix[row_unit_idx[last_rows], row_time_col[last_rows]] = values[last_rows]

    # Per-unit cohort and dose (first non-null value per unit), in all_units order
    unit_first = df.groupby(unit).first().reindex(all_units)
    unit_cohorts = unit_first[first_treat].to_numpy(dtype=float, na_value=np.nan)
    dose_vector = unit_first[dose].to_numpy(dtype=float, na_value=np.nan)

    # Build unit-to-first-panel-row mapping (for subsetting panel-level arrays).
    # np.unique(return_index=True) gives the first positional occurrence.
    unit_first_panel_row = np.zeros(n_units, dtype=int)
    present_units, first_positions = np.unique(row_unit_idx, return_index=True)
    unit_first_panel_row[present_units] = first_positions

    # Per-unit survey weights (take first obs per unit from panel data)
    unit_survey_weights = None
    if survey_weights is not None:
        unit_survey_weights = survey_weights[unit_first_panel_row]

    # Cohort masks
    cohort_masks = {c: unit_cohorts == c for c in np.unique(unit_cohorts)}

    never_treated_mask = unit_cohorts == 0

    return {
        "all_units": all_units,
        "unit_to_idx": unit_to_idx,
        "outcome_matrix": outcome_matrix,
        "period_to_col": period_to_col,
        "unit_cohorts": unit_cohorts,
        "dose_vector": dose_vector,
        "cohort_masks": cohort_masks,
        "never_treated_mask": never_treated_mask,
        "time_periods": time_periods,
        "n_units": n_units,
        "unit_survey_weights": unit_survey_weights,
        "unit_first_panel_row": unit_first_panel_row,
    }
815
+
816
def _compute_dose_response_gt(
    self,
    precomp: Dict[str, Any],
    g: Any,
    t: Any,
    knots: np.ndarray,
    degree: int,
    dvals: np.ndarray,
    survey_weights: Optional[np.ndarray] = None,
    resolved_survey: object = None,
) -> Optional[Dict[str, Any]]:
    """Compute dose-response for a single (g,t) cell.

    Treated units are those with ``first_treat == g`` and a positive dose.
    Their outcome change between ``base_t`` and ``t`` is demeaned by the
    (survey-weighted, when ``survey_weights`` is given) mean change of the
    control units. It is then regressed on a B-spline basis of the dose
    (WLS via a sqrt-weight transform when survey weights are present).

    Parameters
    ----------
    precomp : dict
        Output of the precompute step. Reads ``period_to_col``,
        ``outcome_matrix``, ``unit_cohorts``, ``dose_vector``,
        ``never_treated_mask`` and ``time_periods``.
    g, t : Any
        Cohort (first treatment period) and calendar period of the cell.
    knots, degree : np.ndarray, int
        B-spline knots and degree for the dose basis.
    dvals : np.ndarray
        Dose grid at which ATT(d) and ACRT(d) are evaluated.
    survey_weights : np.ndarray, optional
        Per-unit survey weights aligned with the rows of ``outcome_matrix``.
    resolved_survey : object, optional
        Accepted for interface compatibility; not used in this method.

    Returns
    -------
    dict or None
        ``None`` when the cell is skipped. This happens when there is no
        prior period for a varying-base pre-period, when a needed period is
        missing, when there are no treated units, when there are no control
        units (with a warning), or when there are too few effective treated
        units for the basis (with a warning). A dict of NaNs is returned
        when survey weights leave zero treated or control mass. Otherwise the
        dict holds ``att_d``, ``acrt_d``, ``att_glob``, ``acrt_glob``,
        ``beta_hat``, ``n_treated``, ``n_control`` and ``_bootstrap_info``
        (the influence-function ingredients used for inference).
    """
    period_to_col = precomp["period_to_col"]
    outcome_matrix = precomp["outcome_matrix"]
    unit_cohorts = precomp["unit_cohorts"]
    dose_vector = precomp["dose_vector"]
    never_treated_mask = precomp["never_treated_mask"]
    time_periods = precomp["time_periods"]

    # Base period selection
    is_post = t >= g - self.anticipation
    if self.base_period == "varying":
        if is_post:
            base_t = g - 1 - self.anticipation
        else:
            # Pre-treatment: use t-1
            t_idx = time_periods.index(t)
            if t_idx == 0:
                return None  # No prior period
            base_t = time_periods[t_idx - 1]
    else:
        # Universal base period
        base_t = g - 1 - self.anticipation

    if base_t not in period_to_col or t not in period_to_col:
        return None

    col_t = period_to_col[t]
    col_base = period_to_col[base_t]

    # Treated units: first_treat == g and dose > 0
    treated_mask = (unit_cohorts == g) & (dose_vector > 0)
    n_treated = int(np.sum(treated_mask))
    if n_treated == 0:
        return None

    # Control units
    if self.control_group == "never_treated":
        control_mask = never_treated_mask
    else:
        # Not-yet-treated: never-treated units plus cohorts that are still
        # untreated in BOTH periods being differenced. Under a universal base
        # period, pre-treatment cells have base_t > t. Thresholding on t alone
        # would admit cohorts already treated at base_t and contaminate mu_0,
        # so use the later of the two periods (R `did` uses max(t, pret)).
        last_period = max(t, base_t)
        control_mask = never_treated_mask | (
            (unit_cohorts > last_period + self.anticipation) & (unit_cohorts != g)
        )
    n_control = int(np.sum(control_mask))
    if n_control == 0:
        warnings.warn(
            f"No control units for (g={g}, t={t}). Skipping.",
            UserWarning,
            stacklevel=3,
        )
        return None

    # Outcome changes
    delta_y_treated = (
        outcome_matrix[treated_mask, col_t] - outcome_matrix[treated_mask, col_base]
    )
    delta_y_control = (
        outcome_matrix[control_mask, col_t] - outcome_matrix[control_mask, col_base]
    )

    # Subset survey weights to the cell
    w_treated = None
    w_control = None
    if survey_weights is not None:
        w_treated = survey_weights[treated_mask]
        w_control = survey_weights[control_mask]
        # Guard against zero effective mass (e.g., after subpopulation)
        if np.sum(w_treated) <= 0 or np.sum(w_control) <= 0:
            return {
                "att_glob": np.nan, "acrt_glob": np.nan,
                "n_treated": 0, "n_control": 0,
                "att_d": np.full(len(dvals), np.nan),
                "acrt_d": np.full(len(dvals), np.nan),
            }

    # Control counterfactual (weighted mean when survey weights present)
    if w_control is not None:
        mu_0 = float(np.average(delta_y_control, weights=w_control))
    else:
        mu_0 = float(np.mean(delta_y_control))

    # Demean
    delta_tilde_y = delta_y_treated - mu_0

    # Treated doses
    treated_doses = dose_vector[treated_mask]

    # B-spline OLS
    Psi = bspline_design_matrix(treated_doses, knots, degree, include_intercept=True)
    n_basis = Psi.shape[1]

    # Check for all-same dose
    if np.all(treated_doses == treated_doses[0]):
        warnings.warn(
            f"All treated doses identical in (g={g}, t={t}). " "ACRT(d) will be 0 everywhere.",
            UserWarning,
            stacklevel=3,
        )

    # Skip if not enough treated units for OLS (need n > K for residual df).
    # With survey weights, the positive-weight count is the effective sample
    # size: subpopulation() can zero weights, leaving rows present but the
    # weighted regression underidentified.
    n_eff = int(np.count_nonzero(w_treated > 0)) if w_treated is not None else n_treated
    if n_eff <= n_basis:
        label = "positive-weight treated units" if w_treated is not None else "treated units"
        warnings.warn(
            f"Not enough {label} ({n_eff}) for {n_basis} basis functions "
            f"in (g={g}, t={t}). Skipping cell.",
            UserWarning,
            stacklevel=3,
        )
        return None

    # OLS or WLS regression
    if w_treated is not None:
        # WLS: apply sqrt(w) transformation
        sqrt_w = np.sqrt(w_treated)
        Psi_w = Psi * sqrt_w[:, np.newaxis]
        delta_tilde_y_w = delta_tilde_y * sqrt_w
        beta_hat, _, _ = solve_ols(
            Psi_w,
            delta_tilde_y_w,
            return_vcov=False,
            rank_deficient_action=self.rank_deficient_action,
        )
        # Residuals on original scale (for influence functions)
        beta_pred_tmp = np.where(np.isnan(beta_hat), 0.0, beta_hat)
        residuals = delta_tilde_y - Psi @ beta_pred_tmp
    else:
        beta_hat, residuals, _ = solve_ols(
            Psi,
            delta_tilde_y,
            return_vcov=False,
            rank_deficient_action=self.rank_deficient_action,
        )

    # For prediction: zero out NaN (dropped rank-deficient columns).
    # solve_ols sets dropped-column coefficients to NaN (R convention).
    # Zeroing them produces correct predictions: ATT(d) = intercept
    # (constant), ACRT(d) = 0 (derivative of intercept is 0).
    beta_pred = np.where(np.isnan(beta_hat), 0.0, beta_hat)

    # Evaluate ATT(d) and ACRT(d) at dvals
    Psi_eval = bspline_design_matrix(dvals, knots, degree, include_intercept=True)
    dPsi_eval = bspline_derivative_design_matrix(dvals, knots, degree, include_intercept=True)

    att_d = Psi_eval @ beta_pred
    acrt_d = dPsi_eval @ beta_pred

    # Summary parameters
    if w_treated is not None:
        att_glob = float(np.average(delta_y_treated, weights=w_treated) - mu_0)
    else:
        att_glob = float(np.mean(delta_y_treated) - mu_0)

    # ACRT^{glob}: plug-in average of ACRT(D_i) for treated
    dPsi_treated = bspline_derivative_design_matrix(
        treated_doses, knots, degree, include_intercept=True
    )
    if w_treated is not None:
        acrt_glob = float(np.average(dPsi_treated @ beta_pred, weights=w_treated))
    else:
        acrt_glob = float(np.mean(dPsi_treated @ beta_pred))

    # Bread for influence functions:
    # (Psi'WPsi / sum(w))^{-1} with survey weights, (Psi'Psi / n_treated)^{-1} otherwise.
    if w_treated is not None:
        w_treated_sum = float(np.sum(w_treated))
        PtWP = Psi.T @ (Psi * w_treated[:, np.newaxis])
        # Normalize by weighted mass (not raw count) for consistency with
        # downstream IF score denominators that also use weighted mass
        try:
            bread = np.linalg.inv(PtWP / w_treated_sum)
        except np.linalg.LinAlgError:
            bread = np.linalg.pinv(PtWP / w_treated_sum)
    else:
        PtP = Psi.T @ Psi
        try:
            bread = np.linalg.inv(PtP / n_treated)
        except np.linalg.LinAlgError:
            bread = np.linalg.pinv(PtP / n_treated)

    # ee_treated: per-unit estimating-equation vectors (K-vector per unit).
    # For WLS the score is w_i * X_i * u_i, matching the weighted bread
    # inv(X'WX / sum(w)). For OLS the score is X_i * u_i.
    if w_treated is not None:
        ee_treated = Psi * (w_treated * residuals)[:, np.newaxis]  # (n_treated, K)
    else:
        ee_treated = Psi * residuals[:, np.newaxis]  # (n_treated, K)

    # ee_control: per-unit deviation from control mean (weighted for WLS)
    if w_control is not None:
        ee_control = w_control * (delta_y_control - mu_0)  # (n_control,)
    else:
        ee_control = delta_y_control - mu_0  # (n_control,)

    # psi_bar: mean basis vector for treated (weighted when survey)
    if w_treated is not None:
        psi_bar = np.average(Psi, axis=0, weights=w_treated)
    else:
        psi_bar = np.mean(Psi, axis=0)  # (K,)

    # Unit indices for bootstrap
    treated_indices = np.where(treated_mask)[0]
    control_indices = np.where(control_mask)[0]

    # dpsi_bar: mean derivative basis vector (weighted when survey)
    if w_treated is not None:
        dpsi_bar = np.average(dPsi_treated, axis=0, weights=w_treated)
    else:
        dpsi_bar = np.mean(dPsi_treated, axis=0)

    bootstrap_info = {
        "bread": bread,
        "ee_treated": ee_treated,
        "ee_control": ee_control,
        "psi_bar": psi_bar,
        "dpsi_bar": dpsi_bar,
        "beta_hat": beta_hat,
        "beta_pred": beta_pred,
        "treated_indices": treated_indices,
        "control_indices": control_indices,
        "n_treated": n_treated,
        "n_control": n_control,
        "Psi_eval": Psi_eval,
        "dPsi_eval": dPsi_eval,
        "dPsi_treated": dPsi_treated,
        "delta_y_treated": delta_y_treated,
        "delta_y_control": delta_y_control,
        "mu_0": mu_0,
        "att_glob": att_glob,
        "acrt_glob": acrt_glob,
    }

    # Store survey-weighted masses and per-unit arrays for IF linearization
    if w_treated is not None:
        bootstrap_info["w_treated"] = float(np.sum(w_treated))
        bootstrap_info["w_control"] = float(np.sum(w_control))
        bootstrap_info["w_treated_arr"] = w_treated
        bootstrap_info["w_control_arr"] = w_control

    return {
        "att_d": att_d,
        "acrt_d": acrt_d,
        "att_glob": att_glob,
        "acrt_glob": acrt_glob,
        "beta_hat": beta_hat,
        "n_treated": n_treated,
        "n_control": n_control,
        "_bootstrap_info": bootstrap_info,
    }
1080
+
1081
def _aggregate_event_study(
    self,
    gt_results: Dict[Tuple, Dict],
    gt_bootstrap_info: Dict[Tuple, Dict] = None,
    unit_survey_weights: Optional[np.ndarray] = None,
    unit_cohorts: Optional[np.ndarray] = None,
    anticipation: int = 0,
) -> Dict[int, Dict[str, Any]]:
    """Aggregate binarized ATT_glob by relative period.

    Cells are grouped by event time ``e = t - g``. When ``anticipation > 0``,
    cells with ``e < -anticipation`` are dropped. Within each event time the
    cell ATT_glob values are averaged using cell weights. With both
    ``unit_survey_weights`` and ``unit_cohorts`` given, a cell's weight is the
    summed survey weight of cohort ``g``; otherwise it is the cell's
    ``n_treated``.

    Zero-weight cells are excluded from the average. Without this, a NaN
    placeholder cell with weight 0 (for example a zero-mass survey cell)
    would turn the whole bin into NaN via ``0 * NaN``. A NaN effect carried
    by a cell with positive weight still propagates.

    Parameters
    ----------
    gt_results : dict
        Maps ``(g, t)`` to a dict with at least ``att_glob`` and ``n_treated``.
    gt_bootstrap_info : dict, optional
        Accepted for interface compatibility; not used here.
    unit_survey_weights, unit_cohorts : np.ndarray, optional
        Per-unit survey weights and cohort labels for survey-weighted cells.
    anticipation : int
        Number of anticipation periods.

    Returns
    -------
    dict
        Maps event time to ``{"effect", "se", "t_stat", "p_value",
        "conf_int"}``. Only ``effect`` is filled; inference fields are NaN.
    """
    use_survey = unit_survey_weights is not None and unit_cohorts is not None
    # Survey mass per cohort is identical for every cell of that cohort.
    cohort_mass: Dict[Any, float] = {}
    effects_by_e: Dict[int, List[Tuple[float, float, Tuple]]] = {}

    for (g, t), r in gt_results.items():
        e = t - g
        if anticipation > 0 and e < -anticipation:
            continue
        if use_survey:
            if g not in cohort_mass:
                cohort_mass[g] = float(np.sum(unit_survey_weights[unit_cohorts == g]))
            cell_weight = cohort_mass[g]
        else:
            cell_weight = float(r["n_treated"])
        effects_by_e.setdefault(e, []).append((r["att_glob"], cell_weight, (g, t)))

    result = {}
    for e, entries in sorted(effects_by_e.items()):
        effects = np.array([x[0] for x in entries], dtype=float)
        weights = np.array([x[1] for x in entries], dtype=float)
        active = weights > 0
        total = float(np.sum(weights[active]))
        if total > 0:
            agg = float(np.sum(weights[active] / total * effects[active]))
        else:
            agg = np.nan
        result[e] = {
            "effect": agg,
            "se": np.nan,
            "t_stat": np.nan,
            "p_value": np.nan,
            "conf_int": (np.nan, np.nan),
        }
    return result
1124
+
1125
def _compute_analytical_se(
    self,
    precomp: Dict[str, Any],
    gt_results: Dict[Tuple, Dict],
    gt_bootstrap_info: Dict[Tuple, Dict],
    post_gt: Dict[Tuple, Dict],
    cell_weights: Dict[Tuple, float],
    knots: np.ndarray,
    degree: int,
    dvals: np.ndarray,
    agg_att_d: np.ndarray,
    agg_acrt_d: np.ndarray,
    resolved_survey: object = None,
) -> Dict[str, Any]:
    """Compute analytical SEs using influence functions.

    Per-unit influence functions (IFs) of the aggregated ATT_glob, ACRT_glob,
    ATT(d) and ACRT(d) are accumulated across (g, t) cells, each scaled by its
    aggregation weight in ``cell_weights``. Cells with weight 0 or empty
    bootstrap info are skipped.

    Without a survey design, SE = sqrt(sum IF_i^2); the per-unit IFs already
    carry the 1/n_t and 1/n_c scaling. With a survey design, the panel-level
    design is subset to one row per unit. Variance then comes from replicate
    weights (when the unit-level design uses them) or from Taylor
    linearization via ``compute_survey_vcov``.

    Returns
    -------
    dict
        ``overall_att_se``, ``overall_acrt_se``, ``att_d_se``, ``acrt_d_se``,
        ``df_survey`` (unit-level survey df, or None) and ``unit_resolved``
        (the unit-level design, or None without a survey).
    """
    n_units = precomp["n_units"]
    n_grid = len(dvals)

    if_att_glob = np.zeros(n_units)
    if_acrt_glob = np.zeros(n_units)
    if_att_d = np.zeros((n_units, n_grid))
    if_acrt_d = np.zeros((n_units, n_grid))

    for gt, w in cell_weights.items():
        if w == 0:
            continue
        info = gt_bootstrap_info[gt]
        if not info:
            continue
        # Index arrays come from np.where, so each is duplicate-free and the
        # fancy-index `+=` below adds exactly once per unit.
        treated_idx = info["treated_indices"]
        control_idx = info["control_indices"]
        # Use survey-weighted masses when available
        if "w_treated" in info:
            n_t = info["w_treated"]
            n_c = info["w_control"]
        else:
            n_t = info["n_treated"]
            n_c = info["n_control"]
        bread = info["bread"]
        ee_treated = info["ee_treated"]  # (n_treated, K)
        ee_control = info["ee_control"]  # (n_control,), already survey-weighted
        Psi_eval = info["Psi_eval"]
        dPsi_eval = info["dPsi_eval"]
        dpsi_bar = info["dpsi_bar"]
        sw_treated = info.get("w_treated_arr")  # None when no survey

        n_total = n_t + n_c
        p_1 = n_t / n_total
        p_0 = n_c / n_total

        # IF for ATT_glob (binarized DiD). Treated scores need the survey
        # weight factor here; ee_control already contains it.
        att_scores = info["delta_y_treated"] - info["att_glob"] - info["mu_0"]
        if sw_treated is not None:
            att_scores = sw_treated * att_scores
        if_att_glob[treated_idx] += w * att_scores / p_1 / n_total
        if_att_glob[control_idx] -= w * ee_control / p_0 / n_total

        # Per-unit beta perturbations (one row per unit):
        #   treated: bread @ ee_treated_i / n_t
        #   control: -(bread @ psi_bar) * ee_control_i / n_c   (through mu_0)
        beta_pert_t = ee_treated @ bread.T / n_t
        beta_pert_c = np.outer(ee_control, -bread @ info["psi_bar"]) / n_c

        # Propagate beta perturbations to ATT(d), ACRT(d) and ACRT_glob
        if_att_d[treated_idx] += w * (beta_pert_t @ Psi_eval.T)
        if_acrt_d[treated_idx] += w * (beta_pert_t @ dPsi_eval.T)
        if_att_d[control_idx] += w * (beta_pert_c @ Psi_eval.T)
        if_acrt_d[control_idx] += w * (beta_pert_c @ dPsi_eval.T)
        if_acrt_glob[treated_idx] += w * (beta_pert_t @ dpsi_bar)
        if_acrt_glob[control_idx] += w * (beta_pert_c @ dpsi_bar)

    unit_resolved = None
    unit_df_survey = None

    if resolved_survey is None:
        # SE = sqrt(sum(IF_i^2)), matching CallawaySantAnna's convention
        overall_att_se = float(np.sqrt(np.sum(if_att_glob**2)))
        overall_acrt_se = float(np.sqrt(np.sum(if_acrt_glob**2)))
        att_d_se = np.sqrt(np.sum(if_att_d**2, axis=0))
        acrt_d_se = np.sqrt(np.sum(if_acrt_d**2, axis=0))
    else:
        # The resolved design holds panel-level arrays, while IFs are
        # unit-level, so build a unit-level design from one row per unit.
        row_idx = precomp["unit_first_panel_row"]
        unit_weights = precomp.get("unit_survey_weights")
        if unit_weights is None:
            unit_weights = np.ones(n_units)
        unit_strata = (
            resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
        )
        unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
        unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
        n_strata_unit = len(np.unique(unit_strata)) if unit_strata is not None else 0
        n_psu_unit = len(np.unique(unit_psu)) if unit_psu is not None else 0
        unit_resolved = resolved_survey.subset_to_units(
            row_idx, unit_weights, unit_strata, unit_psu, unit_fpc,
            n_strata_unit, n_psu_unit,
        )

        # The same flag selects the variance method AND the df rule below.
        use_replicates = bool(unit_resolved.uses_replicate_variance)
        # Worst-case count of valid replicates across all SE evaluations
        rep_n_valid = [unit_resolved.n_replicates] if use_replicates else [None]

        if use_replicates:
            # Replicate variance on score-scaled IFs: w * if_vals is the
            # equivalent of the TSL path's w * (if * sum(w)) with its
            # 1/sum(w)^2 bread.
            from diff_diff.survey import compute_replicate_if_variance

            rep_weights = unit_resolved.weights

            def _se(if_vals):
                v, nv = compute_replicate_if_variance(rep_weights * if_vals, unit_resolved)
                rep_n_valid[0] = min(rep_n_valid[0], nv)
                return float(np.sqrt(max(v, 0.0))) if np.isfinite(v) else np.nan
        else:
            # TSL: the IFs act as residuals against a column of ones, rescaled
            # from the 1/n convention to the score scale.
            X_ones = np.ones((n_units, 1))
            tsl_scale = float(unit_resolved.weights.sum())

            def _se(if_vals):
                vcov = compute_survey_vcov(X_ones, if_vals * tsl_scale, unit_resolved)
                return float(np.sqrt(np.abs(vcov[0, 0])))

        overall_att_se = _se(if_att_glob)
        overall_acrt_se = _se(if_acrt_glob)
        att_d_se = np.zeros(n_grid)
        acrt_d_se = np.zeros(n_grid)
        for d_idx in range(n_grid):
            att_d_se[d_idx] = _se(if_att_d[:, d_idx])
            acrt_d_se[d_idx] = _se(if_acrt_d[:, d_idx])

        # Override df with an n_valid-based value only when replicates were dropped
        if use_replicates and rep_n_valid[0] < unit_resolved.n_replicates:
            unit_df_survey = rep_n_valid[0] - 1 if rep_n_valid[0] > 1 else None
        else:
            unit_df_survey = unit_resolved.df_survey

    return {
        "overall_att_se": overall_att_se,
        "overall_acrt_se": overall_acrt_se,
        "att_d_se": att_d_se,
        "acrt_d_se": acrt_d_se,
        "df_survey": unit_df_survey,
        "unit_resolved": unit_resolved,
    }
1322
+
1323
def _run_bootstrap(
    self,
    precomp: Dict[str, Any],
    gt_results: Dict[Tuple, Dict],
    gt_bootstrap_info: Dict[Tuple, Dict],
    post_gt: Dict[Tuple, Dict],
    cell_weights: Dict[Tuple, float],
    knots: np.ndarray,
    degree: int,
    dvals: np.ndarray,
    original_att: float,
    original_acrt: float,
    original_att_d: np.ndarray,
    original_acrt_d: np.ndarray,
    event_study_effects: Optional[Dict[int, Dict]],
    resolved_survey: object = None,
) -> Dict[str, Any]:
    """Run multiplier bootstrap inference.

    One matrix of multiplier weights, shape (n_bootstrap, n_units), is drawn.
    The draws are PSU-level when the unit-level survey design has strata,
    PSU or FPC information. Every (g, t) cell perturbs its stored
    influence-function components with these shared weights. The per-cell
    draws are combined with ``cell_weights`` for ATT(d), ACRT(d), ATT_glob
    and ACRT_glob, and with within-bin cohort-mass weights for the event
    study. Each cell's ATT_glob draws are computed once and reused by the
    event-study pass.

    Returns
    -------
    dict
        SEs, percentile CIs and p-values for ATT(d)/ACRT(d) on the grid, the
        overall ATT_glob/ACRT_glob, and (when ``event_study_effects`` is not
        None) the event-study effects (``es_se``, ``es_ci``, ``es_p``).

    Raises
    ------
    NotImplementedError
        If ``resolved_survey`` uses replicate-weight variance.
    """
    if self.n_bootstrap < 50:
        warnings.warn(
            f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
            "for reliable inference.",
            UserWarning,
            stacklevel=3,
        )

    # Replicate-weight variance is an analytical alternative to the
    # bootstrap and is not combined with it.
    if (
        resolved_survey is not None
        and hasattr(resolved_survey, "uses_replicate_variance")
        and resolved_survey.uses_replicate_variance
    ):
        raise NotImplementedError(
            "ContinuousDiD bootstrap (n_bootstrap > 0) is not supported "
            "with replicate-weight survey designs. Replicate weights provide "
            "analytical variance; use n_bootstrap=0 instead."
        )

    rng = np.random.default_rng(self.seed)
    n_units = precomp["n_units"]
    n_grid = len(dvals)

    # Unit-level design (one panel row per unit) for survey-aware draws
    unit_resolved = None
    if resolved_survey is not None:
        row_idx = precomp["unit_first_panel_row"]
        unit_weights = precomp.get("unit_survey_weights")
        if unit_weights is None:
            unit_weights = np.ones(n_units)
        unit_strata = (
            resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
        )
        unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
        unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
        n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
        n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
        unit_resolved = resolved_survey.subset_to_units(
            row_idx, unit_weights, unit_strata, unit_psu, unit_fpc,
            n_strata_u, n_psu_u,
        )

    # Generate bootstrap weights: PSU-level when a survey design is present
    use_survey_bootstrap = unit_resolved is not None and (
        unit_resolved.strata is not None
        or unit_resolved.psu is not None
        or unit_resolved.fpc is not None
    )
    if use_survey_bootstrap:
        from diff_diff.bootstrap_utils import (
            generate_survey_multiplier_weights_batch,
        )

        psu_weights, psu_ids = generate_survey_multiplier_weights_batch(
            self.n_bootstrap, unit_resolved, self.bootstrap_weights, rng
        )
        # Map each unit to its PSU's weight column
        if unit_resolved.psu is not None:
            psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)}
            unit_to_psu_col = np.array(
                [psu_id_to_col[int(unit_resolved.psu[i])] for i in range(n_units)]
            )
        else:
            unit_to_psu_col = np.arange(n_units)
        all_weights = psu_weights[:, unit_to_psu_col]
    else:
        all_weights = generate_bootstrap_weights_batch(
            self.n_bootstrap, n_units, self.bootstrap_weights, rng
        )

    boot_att_glob = np.zeros(self.n_bootstrap)
    boot_acrt_glob = np.zeros(self.n_bootstrap)
    boot_att_d = np.zeros((self.n_bootstrap, n_grid))
    boot_acrt_d = np.zeros((self.n_bootstrap, n_grid))

    # Event-study accumulators and within-bin cell weights. A cell's mass is
    # its cohort's survey-weight total when survey weights exist, else its
    # n_treated, matching _aggregate_event_study.
    es_keys = sorted(event_study_effects.keys()) if event_study_effects else []
    boot_es = {e: np.zeros(self.n_bootstrap) for e in es_keys}
    unit_sw = precomp.get("unit_survey_weights")
    unit_cohorts = precomp["unit_cohorts"]
    es_cell_weights: Dict[Tuple, float] = {}
    if event_study_effects is not None:
        from collections import defaultdict

        es_cell_mass: Dict[Tuple, float] = {}
        es_bin_total: Dict[int, float] = defaultdict(float)
        for gt, r in gt_results.items():
            g_val, t_val = gt
            e = t_val - g_val
            if self.anticipation > 0 and e < -self.anticipation:
                continue
            if unit_sw is not None:
                cell_mass = float(np.sum(unit_sw[unit_cohorts == g_val]))
            else:
                cell_mass = float(r["n_treated"])
            es_cell_mass[gt] = cell_mass
            es_bin_total[e] += cell_mass
        for gt, cell_mass in es_cell_mass.items():
            e = gt[1] - gt[0]
            if es_bin_total[e] > 0:
                es_cell_weights[gt] = cell_mass / es_bin_total[e]

    def _bootstrap_gt_cell(gt, info):
        """Return (att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, acrt_glob) for one cell."""
        treated_idx = info["treated_indices"]
        control_idx = info["control_indices"]
        # Use survey-weighted masses when available (matching analytical SE)
        if "w_treated" in info:
            n_t = info["w_treated"]
            n_c = info["w_control"]
        else:
            n_t = info["n_treated"]
            n_c = info["n_control"]
        bread = info["bread"]
        ee_treated = info["ee_treated"]
        ee_control = info["ee_control"]
        psi_bar = info["psi_bar"]
        beta_pred = info["beta_pred"]
        Psi_eval = info["Psi_eval"]
        dPsi_eval = info["dPsi_eval"]
        dPsi_treated = info["dPsi_treated"]
        delta_y_treated = info["delta_y_treated"]
        mu_0 = info["mu_0"]
        att_glob_gt = info["att_glob"]
        sw_treated = info.get("w_treated_arr")

        w_treated = all_weights[:, treated_idx]
        w_control = all_weights[:, control_idx]

        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            treated_sum = w_treated @ ee_treated / n_t
            control_sum = (w_control @ ee_control) / n_c
            psi_bar_outer = psi_bar[np.newaxis, :]

        delta_beta = (treated_sum - control_sum[:, np.newaxis] * psi_bar_outer) @ bread.T
        beta_b = beta_pred[np.newaxis, :] + delta_beta

        att_d_b = beta_b @ Psi_eval.T
        acrt_d_b = beta_b @ dPsi_eval.T

        mu_0_pert = (w_control @ ee_control) / n_c
        # ATT_glob perturbation: scores carry the survey weight when present,
        # matching the analytical IF path.
        att_glob_score = delta_y_treated - att_glob_gt - mu_0
        if sw_treated is not None:
            att_glob_score = sw_treated * att_glob_score
        mean_dy_treated_pert = (w_treated @ att_glob_score) / n_t
        att_glob_b = att_glob_gt + mean_dy_treated_pert - mu_0_pert

        if sw_treated is not None:
            sw_norm = sw_treated / sw_treated.sum()
            dpsi_mean = sw_norm @ dPsi_treated
        else:
            dpsi_mean = np.mean(dPsi_treated, axis=0)
        acrt_glob_b = delta_beta @ dpsi_mean

        return att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, info.get("acrt_glob", 0.0)

    # ATT_glob draws per cell, reused by the event-study pass. The draws are
    # deterministic given all_weights, so recomputing them would be pure waste.
    att_glob_draws: Dict[Tuple, np.ndarray] = {}

    # Post-treatment cells: dose-response and overall aggregation
    for gt, w in cell_weights.items():
        if w == 0:
            continue
        info = gt_bootstrap_info[gt]
        if not info:
            continue

        att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, acrt_glob_pt = _bootstrap_gt_cell(gt, info)
        if event_study_effects is not None:
            att_glob_draws[gt] = att_glob_b

        boot_att_d += w * att_d_b
        boot_acrt_d += w * acrt_d_b
        boot_att_glob += w * att_glob_b
        boot_acrt_glob += w * (acrt_glob_pt + acrt_glob_b)

    # Event study: iterate over ALL (g,t) cells
    if event_study_effects is not None:
        for gt in gt_results:
            info = gt_bootstrap_info[gt]
            if not info:
                continue
            e = gt[1] - gt[0]
            if e not in boot_es:
                continue
            es_w = es_cell_weights.get(gt, 0.0)
            if es_w == 0:
                continue
            att_glob_b = att_glob_draws.get(gt)
            if att_glob_b is None:
                att_glob_b = _bootstrap_gt_cell(gt, info)[2]
            boot_es[e] += es_w * att_glob_b

    result: Dict[str, Any] = {}

    # Per-grid-point statistics
    att_d_se = np.full(n_grid, np.nan)
    att_d_ci_lower = np.full(n_grid, np.nan)
    att_d_ci_upper = np.full(n_grid, np.nan)
    acrt_d_se = np.full(n_grid, np.nan)
    acrt_d_ci_lower = np.full(n_grid, np.nan)
    acrt_d_ci_upper = np.full(n_grid, np.nan)
    att_d_p = np.full(n_grid, np.nan)
    acrt_d_p = np.full(n_grid, np.nan)

    for idx in range(n_grid):
        se, ci, p = compute_effect_bootstrap_stats(
            original_att_d[idx],
            boot_att_d[:, idx],
            alpha=self.alpha,
            context=f"ATT(d) at grid point {idx}",
        )
        att_d_se[idx] = se
        att_d_ci_lower[idx] = ci[0]
        att_d_ci_upper[idx] = ci[1]
        att_d_p[idx] = p

        se, ci, p = compute_effect_bootstrap_stats(
            original_acrt_d[idx],
            boot_acrt_d[:, idx],
            alpha=self.alpha,
            context=f"ACRT(d) at grid point {idx}",
        )
        acrt_d_se[idx] = se
        acrt_d_ci_lower[idx] = ci[0]
        acrt_d_ci_upper[idx] = ci[1]
        acrt_d_p[idx] = p

    result["att_d_se"] = att_d_se
    result["att_d_ci_lower"] = att_d_ci_lower
    result["att_d_ci_upper"] = att_d_ci_upper
    result["acrt_d_se"] = acrt_d_se
    result["acrt_d_ci_lower"] = acrt_d_ci_lower
    result["acrt_d_ci_upper"] = acrt_d_ci_upper
    result["att_d_p"] = att_d_p
    result["acrt_d_p"] = acrt_d_p

    # Overall
    se, ci, p = compute_effect_bootstrap_stats(
        original_att,
        boot_att_glob,
        alpha=self.alpha,
        context="overall ATT_glob",
    )
    result["overall_att_se"] = se
    result["overall_att_ci"] = ci
    result["overall_att_p"] = p

    se, ci, p = compute_effect_bootstrap_stats(
        original_acrt,
        boot_acrt_glob,
        alpha=self.alpha,
        context="overall ACRT_glob",
    )
    result["overall_acrt_se"] = se
    result["overall_acrt_ci"] = ci
    result["overall_acrt_p"] = p

    # Event study SEs
    if event_study_effects is not None:
        es_se = {}
        es_ci = {}
        es_p = {}
        for e in es_keys:
            se_e, ci_e, p_e = compute_effect_bootstrap_stats(
                event_study_effects[e]["effect"],
                boot_es[e],
                alpha=self.alpha,
                context=f"event study e={e}",
            )
            es_se[e] = se_e
            es_ci[e] = ci_e
            es_p[e] = p_e
        result["es_se"] = es_se
        result["es_ci"] = es_ci
        result["es_p"] = es_p

    return result