diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,520 @@
1
+ """
2
+ Bootstrap inference methods for the Two-Stage DiD estimator.
3
+
4
+ This module contains TwoStageDiDBootstrapMixin, which provides multiplier
5
+ bootstrap inference on the GMM influence function. Extracted from two_stage.py
6
+ for module size management.
7
+ """
8
+
9
+ import warnings
10
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ from scipy.sparse.linalg import factorized as sparse_factorized
15
+
16
+ from diff_diff.bootstrap_utils import (
17
+ compute_effect_bootstrap_stats as _compute_effect_bootstrap_stats,
18
+ )
19
+ from diff_diff.bootstrap_utils import (
20
+ generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch,
21
+ )
22
+ from diff_diff.bootstrap_utils import (
23
+ generate_survey_multiplier_weights_batch as _generate_survey_multiplier_weights_batch,
24
+ )
25
+ from diff_diff.linalg import solve_ols
26
+ from diff_diff.two_stage_results import TwoStageBootstrapResults
27
+
28
# Above this many dense elements (rows * columns), per-cluster aggregation of
# the sparse fixed-effect design switches to a column-by-column path instead
# of densifying the whole matrix. two_stage.py is expected to use the same
# value -- keep the two in sync.
_SPARSE_DENSE_THRESHOLD = 10 ** 7

# Names exported by ``from diff_diff.two_stage_bootstrap import *``.
__all__ = ["TwoStageDiDBootstrapMixin"]
35
+
36
+
37
class TwoStageDiDBootstrapMixin:
    """Mixin providing multiplier-bootstrap inference for TwoStageDiD.

    Instead of re-estimating the model, each bootstrap draw perturbs the
    per-cluster GMM influence scores with random multiplier weights. The host
    class must provide the attributes annotated below, plus the
    ``_build_fe_design`` and ``_compute_gmm_scores`` helpers.
    """

    # Attributes read from the host estimator class.
    n_bootstrap: int
    bootstrap_weights: str
    alpha: float
    seed: Optional[int]
    horizon_max: Optional[int]
    pretrends: bool

    if TYPE_CHECKING:
        from scipy import sparse

        # Helpers implemented by the host class. These are declared for static
        # type checkers only. At runtime, a stub here could shadow an
        # implementation inherited from a base listed after this mixin, and a
        # missing implementation should fail loudly instead of silently
        # returning None.
        def _build_fe_design(
            self,
            df: pd.DataFrame,
            unit: str,
            time: str,
            covariates: Optional[List[str]],
            omega_0_mask: pd.Series,
        ) -> Tuple["sparse.csr_matrix", "sparse.csr_matrix", Dict[Any, int], Dict[Any, int]]: ...

        @staticmethod
        def _compute_gmm_scores(
            c_by_cluster: np.ndarray,
            gamma_hat: np.ndarray,
            s2_by_cluster: np.ndarray,
        ) -> np.ndarray: ...

    # =========================================================================
    # Score computation
    # =========================================================================

    @staticmethod
    def _sum_rows_by_cluster(
        values: Any,
        cluster_indices: np.ndarray,
        n_clusters: int,
        threshold: Optional[int] = None,
    ) -> np.ndarray:
        """
        Sum the rows of ``values`` within each cluster.

        Parameters
        ----------
        values : np.ndarray or scipy sparse matrix/array, shape (n, m)
            Row-wise contributions to aggregate.
        cluster_indices : np.ndarray of int, shape (n,)
            Position (0 .. n_clusters - 1) of each row's cluster.
        n_clusters : int
            Number of clusters (rows of the result).
        threshold : int, optional
            Sparse inputs with more than this many dense elements are
            aggregated column by column over their stored entries, rather
            than being densified. Defaults to ``_SPARSE_DENSE_THRESHOLD``.

        Returns
        -------
        np.ndarray, shape (n_clusters, m)
            Row g holds the sum of the rows belonging to cluster g.
        """
        n_rows, n_cols = values.shape
        out = np.zeros((n_clusters, n_cols))

        if isinstance(values, np.ndarray):
            np.add.at(out, cluster_indices, np.asarray(values))
            return out

        if threshold is None:
            threshold = _SPARSE_DENSE_THRESHOLD

        if n_rows * n_cols > threshold:
            # Per-column path: limits peak memory for large FE matrices and
            # only touches stored entries. Sorting indices keeps the
            # accumulation in ascending row order, as in the dense path.
            csc = values.tocsc()
            csc.sort_indices()
            for j in range(n_cols):
                start, stop = csc.indptr[j], csc.indptr[j + 1]
                np.add.at(
                    out[:, j],
                    cluster_indices[csc.indices[start:stop]],
                    csc.data[start:stop],
                )
        else:
            # Dense path: faster for moderate-size matrices.
            np.add.at(out, cluster_indices, values.toarray())
        return out

    @staticmethod
    def _zero_empty_column_coefs(X: np.ndarray, coef: np.ndarray) -> np.ndarray:
        """
        Return ``coef`` with the entries of all-zero columns of ``X`` set to 0.

        An all-zero column contributes nothing to ``X @ coef``. However, if its
        coefficient is reported as NaN (for example, dropped as rank
        deficient), the IEEE product ``0 * NaN`` would turn every fitted value
        into NaN. ``coef`` is returned unchanged when no column is empty.
        """
        empty = ~np.any(X != 0, axis=0)
        if not empty.any():
            return coef
        return np.where(empty, 0.0, coef)

    def _stage2_residuals(
        self,
        X_2: np.ndarray,
        y: np.ndarray,
        weights: Optional[np.ndarray],
        weight_type: str,
    ) -> np.ndarray:
        """Fit the Stage-2 (weighted) OLS of ``y`` on ``X_2`` and return residuals."""
        coef = solve_ols(
            X_2, y, return_vcov=False,
            weights=weights, weight_type=weight_type,
        )[0]
        coef = self._zero_empty_column_coefs(X_2, coef)
        return y - np.dot(X_2, coef)

    @staticmethod
    def _indicator_design(
        labels: np.ndarray,
        include: np.ndarray,
        levels: List[Any],
    ) -> np.ndarray:
        """
        Build a 0/1 design with one column per entry of ``levels``.

        Row i gets a 1.0 in the column whose level equals ``labels[i]``,
        provided ``include[i]`` is true. Rows whose label is not a level, and
        excluded rows, stay all-zero. If ``levels`` contains duplicates, the
        last occurrence wins.

        Returns
        -------
        np.ndarray, shape (len(labels), len(levels))
        """
        col_of = {level: j for j, level in enumerate(levels)}
        X = np.zeros((len(labels), len(levels)))
        for i in np.flatnonzero(include):
            j = col_of.get(labels[i])
            if j is not None:
                X[i, j] = 1.0
        return X

    def _compute_cluster_S_scores(
        self,
        df: pd.DataFrame,
        unit: str,
        time: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
        unit_fe: Dict[Any, float],
        time_fe: Dict[Any, float],
        delta_hat: Optional[np.ndarray],
        kept_cov_mask: Optional[np.ndarray],
        X_2: np.ndarray,
        eps_2: np.ndarray,
        cluster_ids: np.ndarray,
        survey_weights: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute per-cluster S_g scores for the bootstrap.

        Parameters
        ----------
        df : pd.DataFrame
            Estimation sample; must contain a ``"_y_tilde"`` column.
        unit, time : str
            Unit and time column names.
        covariates : list of str, optional
            Stage-1 covariates. Entries whose ``kept_cov_mask`` flag is False
            are dropped.
        omega_0_mask : pd.Series of bool
            Mask selecting the omega_0 observations.
        unit_fe, time_fe : dict
            Estimated fixed effects. Levels missing from the dicts contribute 0.
        delta_hat : np.ndarray, optional
            Stage-1 covariate coefficients.
        kept_cov_mask : np.ndarray of bool, optional
            Which covariates were kept.
        X_2 : np.ndarray, shape (n, k)
            Stage-2 design.
        eps_2 : np.ndarray, shape (n,)
            Stage-2 residuals.
        cluster_ids : np.ndarray, shape (n,)
            Cluster label of each observation.
        survey_weights : np.ndarray, shape (n,), optional
            Observation weights applied to every cross-product and score.

        Returns
        -------
        S : np.ndarray, shape (G, k)
            Per-cluster influence scores.
        bread : np.ndarray, shape (k, k)
            (X'_2 W X_2)^{-1}, where W is the identity without survey weights.
        unique_clusters : np.ndarray
            Sorted unique cluster identifiers; row g of ``S`` belongs to
            ``unique_clusters[g]``.
        """
        n = len(df)
        k = X_2.shape[1]

        subset_covs = kept_cov_mask is not None and not np.all(kept_cov_mask)
        cov_list = covariates
        if covariates and subset_covs:
            cov_list = [c for c, keep in zip(covariates, kept_cov_mask) if keep]

        X_1_sparse, X_10_sparse, _, _ = self._build_fe_design(
            df, unit, time, cov_list, omega_0_mask
        )

        # Reconstruct Stage-1 fitted values and the outcome.
        alpha_i = df[unit].map(unit_fe).values
        beta_t = df[time].map(time_fe).values
        alpha_i = np.where(pd.isna(alpha_i), 0.0, alpha_i).astype(float)
        beta_t = np.where(pd.isna(beta_t), 0.0, beta_t).astype(float)
        fitted_1 = alpha_i + beta_t
        if delta_hat is not None and cov_list:
            delta = delta_hat[kept_cov_mask] if subset_covs else delta_hat
            fitted_1 = fitted_1 + np.dot(df[cov_list].values, delta)

        y_vals = df["_y_tilde"].values + fitted_1

        # On omega_0, eps_10 is Y - fitted_1. Elsewhere the reconstructed Y is
        # stored. NOTE(review): those rows are presumably all-zero in X_10, so
        # they should not reach the scores -- confirm in _build_fe_design.
        omega_0 = omega_0_mask.values
        eps_10 = np.empty(n)
        eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0]
        eps_10[~omega_0] = y_vals[~omega_0]

        # gamma_hat. With survey weights, both cross-products need W.
        if survey_weights is not None:
            XtX_10 = X_10_sparse.T @ X_10_sparse.multiply(survey_weights[:, None])
            Xt1_X2 = X_1_sparse.T @ (X_2 * survey_weights[:, None])
        else:
            XtX_10 = X_10_sparse.T @ X_10_sparse
            Xt1_X2 = X_1_sparse.T @ X_2

        try:
            solve_XtX = sparse_factorized(XtX_10.tocsc())
            if Xt1_X2.ndim == 1:
                gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1)
            else:
                gamma_hat = np.column_stack(
                    [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])]
                )
        except RuntimeError:
            # Sparse factorization failed (e.g. singular matrix): fall back
            # to dense least squares.
            gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0]
            if gamma_hat.ndim == 1:
                gamma_hat = gamma_hat.reshape(-1, 1)

        # Per-cluster aggregation. Survey weights multiply the residuals.
        unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True)
        n_clusters = len(unique_clusters)

        w_eps_10 = eps_10 if survey_weights is None else survey_weights * eps_10
        c_by_cluster = self._sum_rows_by_cluster(
            X_10_sparse.multiply(w_eps_10[:, None]), cluster_indices, n_clusters
        )

        w_eps_2 = eps_2 if survey_weights is None else survey_weights * eps_2
        s2_by_cluster = self._sum_rows_by_cluster(
            X_2 * w_eps_2[:, None], cluster_indices, n_clusters
        )

        S = self._compute_gmm_scores(c_by_cluster, gamma_hat, s2_by_cluster)

        # Bread: (X'_2 W X_2)^{-1}; least squares if the system is singular.
        with np.errstate(invalid="ignore", over="ignore", divide="ignore"):
            if survey_weights is not None:
                XtX_2 = X_2.T @ (X_2 * survey_weights[:, None])
            else:
                XtX_2 = np.dot(X_2.T, X_2)
            try:
                bread = np.linalg.solve(XtX_2, np.eye(k))
            except np.linalg.LinAlgError:
                bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0]

        return S, bread, unique_clusters

    # =========================================================================
    # Bootstrap driver
    # =========================================================================

    def _run_bootstrap(
        self,
        df: pd.DataFrame,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]],
        omega_0_mask: pd.Series,
        omega_1_mask: pd.Series,
        unit_fe: Dict[Any, float],
        time_fe: Dict[Any, float],
        grand_mean: float,
        delta_hat: Optional[np.ndarray],
        cluster_var: str,
        kept_cov_mask: Optional[np.ndarray],
        treatment_groups: List[Any],
        ref_period: int,
        balance_e: Optional[int],
        original_att: float,
        original_event_study: Optional[Dict[int, Dict[str, Any]]],
        original_group: Optional[Dict[Any, Dict[str, Any]]],
        aggregate: Optional[str],
        resolved_survey: Optional[Any] = None,
    ) -> Optional["TwoStageBootstrapResults"]:
        """
        Run the multiplier bootstrap on the GMM influence function.

        For each draw b, the coefficient perturbation is
        ``bread @ sum_g w_bg * S_g``. The draws are re-centred on the original
        point estimates before computing SEs, CIs and p-values. The same
        multiplier weights are reused for the static, event-study and group
        specifications.

        Parameters
        ----------
        grand_mean : float
            Not used by the bootstrap; kept for interface compatibility.
        aggregate : str, optional
            ``"event_study"``, ``"group"`` or ``"all"`` enables the
            corresponding extra bootstrap.
        resolved_survey : optional
            Resolved survey design. When it has strata, PSU or FPC, the
            multiplier weights are drawn at PSU level.

        Returns
        -------
        TwoStageBootstrapResults or None
            None when no treated observation has a finite ``_y_tilde``.
        """
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)

        n = len(df)
        cluster_ids = df[cluster_var].values
        # Copy so that zero-filling below never mutates the DataFrame column.
        y_tilde = df["_y_tilde"].values.copy()

        # Survey weights for S-score computation and Stage-2 WLS.
        survey_weights: Optional[np.ndarray] = None
        survey_weight_type: str = "pweight"
        if resolved_survey is not None:
            survey_weights = resolved_survey.weights
            survey_weight_type = resolved_survey.weight_type

        # Non-finite y_tilde (unidentified FEs) is zero-filled, and those rows
        # are excluded from every Stage-2 design below (matches _stage2_static).
        nan_mask = ~np.isfinite(y_tilde)
        y_tilde[nan_mask] = 0.0

        score_inputs = dict(
            df=df,
            unit=unit,
            time=time,
            covariates=covariates,
            omega_0_mask=omega_0_mask,
            unit_fe=unit_fe,
            time_fe=time_fe,
            delta_hat=delta_hat,
            kept_cov_mask=kept_cov_mask,
            cluster_ids=cluster_ids,
            survey_weights=survey_weights,
        )

        def cluster_scores(X_2: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
            """Fit Stage 2 on ``X_2``; return (S, bread, unique_clusters)."""
            eps_2 = self._stage2_residuals(
                X_2, y_tilde, survey_weights, survey_weight_type
            )
            return self._compute_cluster_S_scores(X_2=X_2, eps_2=eps_2, **score_inputs)

        # --- Static specification bootstrap ---
        D = omega_1_mask.values.astype(float)  # astype() returns a copy
        D[nan_mask] = 0.0
        if D.sum() == 0:
            return None  # no treated observation with a finite y_tilde

        S_static, bread_static, unique_clusters = cluster_scores(D.reshape(-1, 1))

        # Multiplier weights, shape (B, G). They are PSU-level when a survey
        # design is present.
        use_survey_bootstrap = resolved_survey is not None and (
            resolved_survey.strata is not None
            or resolved_survey.psu is not None
            or resolved_survey.fpc is not None
        )
        if use_survey_bootstrap:
            psu_weights, psu_ids = _generate_survey_multiplier_weights_batch(
                self.n_bootstrap, resolved_survey, self.bootstrap_weights, rng
            )
            # With survey PSUs, cluster_var is "_survey_cluster", so the
            # clusters used for the scores are PSU ids. Align weight columns.
            psu_id_to_col = {int(pid): col for col, pid in enumerate(psu_ids)}
            cluster_cols = np.array([psu_id_to_col[int(cl)] for cl in unique_clusters])
            all_weights = psu_weights[:, cluster_cols]
        else:
            all_weights = _generate_bootstrap_weights_batch(
                self.n_bootstrap, len(unique_clusters), self.bootstrap_weights, rng
            )

        # boot = W @ S @ bread.T -> (B, k); the static spec has k = 1.
        boot_overall = np.dot(np.dot(all_weights, S_static), bread_static.T)[:, 0]
        boot_overall_shifted = boot_overall + original_att
        overall_se, overall_ci, overall_p = _compute_effect_bootstrap_stats(
            original_att,
            boot_overall_shifted,
            alpha=self.alpha,
            context="TwoStageDiD overall ATT",
        )

        # --- Event study bootstrap ---
        event_study_ses = event_study_cis = event_study_p_values = None

        if original_event_study and aggregate in ("event_study", "all"):
            rel_times = df["_rel_time"].values
            if self.pretrends:
                evt_rel = rel_times[~df["_never_treated"].values]
            else:
                evt_rel = rel_times[omega_1_mask.values]
            all_horizons = sorted(set(int(h) for h in evt_rel if np.isfinite(h)))
            if self.horizon_max is not None:
                all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]

            # Optional balancing: keep only cohorts observed at every relative
            # time in [-balance_e, max horizon].
            balance_mask = np.ones(n, dtype=bool)
            if balance_e is not None:
                cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
                balanced_cohorts = set()
                if all_horizons:
                    required_range = set(range(-balance_e, max(all_horizons) + 1))
                    balanced_cohorts = {
                        g for g, hs in cohort_rel_times.items()
                        if required_range.issubset(hs)
                    }
                if balanced_cohorts:
                    balance_mask = df[first_treat].isin(balanced_cohorts).values
                else:
                    all_horizons = []  # no qualifying cohorts -> skip

            est_horizons = [h for h in all_horizons if h != ref_period]

            # Filter out Prop 5 horizons (same logic as _stage2_event_study).
            h_bar_boot = np.inf
            if not df["_never_treated"].any() and len(treatment_groups) > 1:
                h_bar_boot = max(treatment_groups) - min(treatment_groups)
            if h_bar_boot < np.inf:
                est_horizons = [h for h in est_horizons if h < h_bar_boot]

            if est_horizons:
                horizon_to_col = {h: j for j, h in enumerate(est_horizons)}
                rel_float = np.asarray(rel_times, dtype=float)
                finite_rel = np.isfinite(rel_float)
                rel_int = np.zeros(n, dtype=np.int64)
                # astype truncates toward zero, like int().
                rel_int[finite_rel] = rel_float[finite_rel].astype(np.int64)
                X_2_es = self._indicator_design(
                    rel_int, balance_mask & ~nan_mask & finite_rel, est_horizons
                )

                S_es, bread_es, _ = cluster_scores(X_2_es)
                boot_coef_es = np.dot(np.dot(all_weights, S_es), bread_es.T)  # (B, k_es)

                event_study_ses, event_study_cis, event_study_p_values = {}, {}, {}
                for h, info in original_event_study.items():
                    if info.get("n_obs", 0) == 0:
                        continue
                    if np.isnan(info["effect"]):
                        continue  # Prop 5 and other NaN-effect horizons
                    if h not in horizon_to_col:
                        continue
                    orig_eff = info["effect"]
                    se_h, ci_h, p_h = _compute_effect_bootstrap_stats(
                        orig_eff,
                        boot_coef_es[:, horizon_to_col[h]] + orig_eff,
                        alpha=self.alpha,
                        context=f"TwoStageDiD event study (h={h})",
                    )
                    event_study_ses[h] = se_h
                    event_study_cis[h] = ci_h
                    event_study_p_values[h] = p_h

        # --- Group bootstrap ---
        group_ses = group_cis = group_p_values = None

        if original_group and aggregate in ("group", "all"):
            group_to_col = {g: j for j, g in enumerate(treatment_groups)}
            treated_finite = np.asarray(omega_1_mask.values, dtype=bool) & ~nan_mask
            X_2_grp = self._indicator_design(
                df[first_treat].values, treated_finite, treatment_groups
            )

            S_grp, bread_grp, _ = cluster_scores(X_2_grp)
            boot_coef_grp = np.dot(np.dot(all_weights, S_grp), bread_grp.T)

            group_ses, group_cis, group_p_values = {}, {}, {}
            for g, info in original_group.items():
                if g not in group_to_col:
                    continue
                orig_eff = info["effect"]
                se_g, ci_g, p_g = _compute_effect_bootstrap_stats(
                    orig_eff,
                    boot_coef_grp[:, group_to_col[g]] + orig_eff,
                    alpha=self.alpha,
                    context=f"TwoStageDiD group effect (g={g})",
                )
                group_ses[g] = se_g
                group_cis[g] = ci_g
                group_p_values[g] = p_g

        return TwoStageBootstrapResults(
            n_bootstrap=self.n_bootstrap,
            weight_type=self.bootstrap_weights,
            alpha=self.alpha,
            overall_att_se=overall_se,
            overall_att_ci=overall_ci,
            overall_att_p_value=overall_p,
            event_study_ses=event_study_ses,
            event_study_cis=event_study_cis,
            event_study_p_values=event_study_p_values,
            group_ses=group_ses,
            group_cis=group_cis,
            group_p_values=group_p_values,
            bootstrap_distribution=boot_overall_shifted,
        )

    # =========================================================================
    # Utility
    # =========================================================================

    @staticmethod
    def _build_cohort_rel_times(
        df: pd.DataFrame,
        first_treat: str,
    ) -> Dict[Any, Set[int]]:
        """
        Map each ever-treated cohort to the set of relative times observed.

        Rows flagged in ``"_never_treated"`` are excluded. Non-finite
        ``"_rel_time"`` values are skipped, and finite ones are truncated
        with ``int()``.
        """
        treated_df = df.loc[~df["_never_treated"]]
        result: Dict[Any, Set[int]] = {}
        for cohort, rel in zip(treated_df[first_treat].values, treated_df["_rel_time"].values):
            if np.isfinite(rel):
                result.setdefault(cohort, set()).add(int(rel))
        return result