diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,359 @@
1
+ """
2
+ Multiplier bootstrap inference for the Efficient DiD estimator.
3
+
4
+ Pattern follows CallawaySantAnnaBootstrapMixin (staggered_bootstrap.py).
5
+ Perturbs EIF values with random weights to obtain bootstrap distributions
6
+ of ATT(g,t) and aggregated parameters.
7
+ """
8
+
9
+ import warnings
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ import numpy as np
14
+
15
+ from diff_diff.bootstrap_utils import (
16
+ compute_effect_bootstrap_stats as _compute_effect_bootstrap_stats_func,
17
+ )
18
+ from diff_diff.bootstrap_utils import (
19
+ generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch,
20
+ )
21
+
22
+
23
@dataclass
class EDiDBootstrapResults:
    """Container for EfficientDiD multiplier-bootstrap inference.

    Holds the bootstrap configuration together with standard errors,
    confidence intervals and p-values for the overall ATT, every ATT(g,t)
    cell, and (optionally) the event-study and group aggregations.
    """

    # Bootstrap configuration that produced these results.
    n_bootstrap: int
    weight_type: str
    alpha: float

    # Overall ATT inference.
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float

    # Cell-level inference, keyed by (g, t).
    group_time_ses: Dict[Tuple[Any, Any], float]
    group_time_cis: Dict[Tuple[Any, Any], Tuple[float, float]]
    group_time_p_values: Dict[Tuple[Any, Any], float]

    # Event-study inference keyed by relative period; None when not computed.
    event_study_ses: Optional[Dict[int, float]] = field(default=None)
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = field(default=None)
    event_study_p_values: Optional[Dict[int, float]] = field(default=None)

    # Group (cohort) inference keyed by cohort; None when not computed.
    group_effect_ses: Optional[Dict[Any, float]] = field(default=None)
    group_effect_cis: Optional[Dict[Any, Tuple[float, float]]] = field(default=None)
    group_effect_p_values: Optional[Dict[Any, float]] = field(default=None)

    # Bootstrap draws of the overall ATT; excluded from repr to keep it short.
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
43
+
44
+
45
class EfficientDiDBootstrapMixin:
    """Mixin providing multiplier bootstrap inference for EfficientDiD.

    The host estimator supplies the attributes annotated below; the mixin
    only reads them.
    """

    # Number of bootstrap draws.
    n_bootstrap: int
    # Multiplier weight type, forwarded to the weight generators.
    bootstrap_weights: str
    # Passed as ``alpha`` to the bootstrap statistics helper.
    alpha: float
    # Seed for ``np.random.default_rng``.
    seed: Optional[int]
    # Cells with t >= g - anticipation are treated as post-treatment.
    anticipation: int

    def _run_multiplier_bootstrap(
        self,
        group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
        eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
        n_units: int,
        aggregate: Optional[str],
        balance_e: Optional[int],
        treatment_groups: List[Any],
        cohort_fractions: Dict[float, float],
        cluster_indices: Optional[np.ndarray] = None,
        n_clusters: Optional[int] = None,
        resolved_survey: object = None,
        unit_level_weights: Optional[np.ndarray] = None,
    ) -> "EDiDBootstrapResults":
        """Run the multiplier bootstrap on stored EIF values.

        For each bootstrap draw *b*, ATT(g,t) is perturbed as::

            ATT_b(g,t) = ATT(g,t) + (1/n) * xi_b @ eif_gt

        where ``xi_b`` is a multiplier weight vector of length ``n_units``.
        When ``unit_level_weights`` (``w``) is given, the perturbation is
        instead ``xi_b @ (w * eif_gt / sum(w))``.

        Multiplier weights are drawn per PSU when ``resolved_survey`` has
        strata, PSU or FPC information. They are drawn per cluster when both
        ``cluster_indices`` and ``n_clusters`` are given, and per unit
        otherwise. PSU/cluster draws are expanded to units before use.

        Aggregations (overall, event study, group) are recomputed from the
        perturbed ATT(g,t) values using fixed weights. This follows the
        Callaway-Sant'Anna bootstrap pattern (staggered_bootstrap.py). The
        analytical path includes a WIF correction for aggregated SEs. The
        bootstrap does not re-estimate the aggregation weights; it captures
        weight uncertainty through the EIF perturbation only. This is
        described by the authors as matching the R ``did`` package approach.

        Parameters
        ----------
        group_time_effects : dict
            ``(g, t) -> info``; only ``info["effect"]`` is read.
        eif_by_gt : dict
            ``(g, t) -> EIF vector``; each vector is multiplied against the
            ``(n_bootstrap, n_units)`` weight matrix.
        n_units : int
            Number of units.
        aggregate : str or None
            ``"event_study"``, ``"group"`` or ``"all"`` enable the matching
            aggregations. Any other value yields cell-level and overall
            inference only.
        balance_e : int or None
            Anchor horizon for the balanced event study.
        treatment_groups : list
            Cohorts considered for group aggregation.
        cohort_fractions : dict
            Cohort -> aggregation weight (missing cohorts get 0).
        cluster_indices, n_clusters : optional
            Cluster index per unit and number of clusters.
        resolved_survey : optional
            Survey design; its ``strata``, ``psu`` and ``fpc`` are read.
        unit_level_weights : np.ndarray, optional
            Per-unit weights used for the weighted score.

        Returns
        -------
        EDiDBootstrapResults
        """
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)
        gt_pairs = list(group_time_effects.keys())

        # Multiplier weights expanded to units: shape (n_bootstrap, n_units).
        all_weights = _draw_unit_multiplier_weights(
            self.n_bootstrap,
            self.bootstrap_weights,
            rng,
            n_units,
            cluster_indices=cluster_indices,
            n_clusters=n_clusters,
            resolved_survey=resolved_survey,
        )

        original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])

        # Perturbed ATTs: (n_bootstrap, n_gt). Under survey weights, the
        # survey score w_i * eif_i / sum(w) is perturbed, matching the
        # analytical variance convention (compute_survey_if_variance).
        # The weight total is cell-independent, so compute it once.
        total_w = None if unit_level_weights is None else float(np.sum(unit_level_weights))
        bootstrap_atts = np.zeros((self.n_bootstrap, len(gt_pairs)))
        for j, gt in enumerate(gt_pairs):
            eif_gt = eif_by_gt[gt]
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                if unit_level_weights is not None:
                    eif_scaled = unit_level_weights * eif_gt / total_w
                    perturbation = all_weights @ eif_scaled
                else:
                    perturbation = (all_weights @ eif_gt) / n_units
                bootstrap_atts[:, j] = original_atts[j] + perturbation

        # Post-treatment cells that also have a finite point estimate.
        post_indices = np.array(
            [
                j
                for j, (g, t) in enumerate(gt_pairs)
                if t >= g - self.anticipation and np.isfinite(original_atts[j])
            ],
            dtype=int,
        )

        # Overall ATT: fixed-weight re-aggregation of the perturbed cell ATTs
        # (same pattern as CallawaySantAnna._run_multiplier_bootstrap).
        skip_overall = post_indices.size == 0
        if skip_overall:
            bootstrap_overall = np.full(self.n_bootstrap, np.nan)
            original_overall = np.nan
        else:
            overall_w = _normalized_weights(
                np.array([cohort_fractions.get(gt_pairs[i][0], 0.0) for i in post_indices])
            )
            original_overall = float(np.sum(overall_w * original_atts[post_indices]))
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                bootstrap_overall = bootstrap_atts[:, post_indices] @ overall_w

        # Event study / group: fixed-weight re-aggregation, same as overall.
        event_study_info = None
        bootstrap_event_study = None
        if aggregate in ("event_study", "all"):
            event_study_info = self._prepare_es_agg_boot(
                gt_pairs, original_atts, cohort_fractions, balance_e
            )
            bootstrap_event_study = _reaggregate_draws(bootstrap_atts, event_study_info)

        group_agg_info = None
        bootstrap_group = None
        if aggregate in ("group", "all"):
            group_agg_info = self._prepare_group_agg_boot(gt_pairs, original_atts, treatment_groups)
            bootstrap_group = _reaggregate_draws(bootstrap_atts, group_agg_info)

        # Statistics, computed in the order: cells, overall, event study, group.
        gt_ses, gt_cis, gt_pvals = _summarize_draws(
            [
                (gt, original_atts[j], bootstrap_atts[:, j], f"ATT(g={gt[0]}, t={gt[1]})")
                for j, gt in enumerate(gt_pairs)
            ],
            self.alpha,
        )

        if skip_overall:
            ov_se, ov_ci, ov_pv = np.nan, (np.nan, np.nan), np.nan
        else:
            ov_se, ov_ci, ov_pv = _compute_effect_bootstrap_stats_func(
                original_overall,
                bootstrap_overall,
                alpha=self.alpha,
                context="overall ATT",
            )

        es_ses = es_cis = es_pvs = None
        if bootstrap_event_study is not None and event_study_info is not None:
            es_ses, es_cis, es_pvs = _summarize_draws(
                [
                    (
                        e,
                        event_study_info[e]["effect"],
                        bootstrap_event_study[e],
                        f"event study (e={e})",
                    )
                    for e in sorted(event_study_info.keys())
                ],
                self.alpha,
            )

        g_ses = g_cis = g_pvs = None
        if bootstrap_group is not None and group_agg_info is not None:
            g_ses, g_cis, g_pvs = _summarize_draws(
                [
                    (g, group_agg_info[g]["effect"], bootstrap_group[g], f"group effect (g={g})")
                    for g in sorted(group_agg_info.keys())
                ],
                self.alpha,
            )

        return EDiDBootstrapResults(
            n_bootstrap=self.n_bootstrap,
            weight_type=self.bootstrap_weights,
            alpha=self.alpha,
            overall_att_se=ov_se,
            overall_att_ci=ov_ci,
            overall_att_p_value=ov_pv,
            group_time_ses=gt_ses,
            group_time_cis=gt_cis,
            group_time_p_values=gt_pvals,
            event_study_ses=es_ses,
            event_study_cis=es_cis,
            event_study_p_values=es_pvs,
            group_effect_ses=g_ses,
            group_effect_cis=g_cis,
            group_effect_p_values=g_pvs,
            bootstrap_distribution=bootstrap_overall,
        )

    def _prepare_es_agg_boot(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        original_atts: np.ndarray,
        cohort_fractions: Dict[float, float],
        balance_e: Optional[int],
    ) -> Dict[int, Dict[str, Any]]:
        """Prepare event-study aggregation info for the bootstrap.

        Cells with a non-finite ATT are skipped. When ``balance_e`` is set,
        only cohorts with a finite ATT at relative period ``balance_e`` are
        kept, and a ``UserWarning`` is emitted if none qualify.

        Returns
        -------
        dict
            ``e -> {"gt_indices", "weights", "effect"}``. Here ``weights``
            are the cohort fractions normalized to sum to one (uniform if
            they sum to zero), and ``effect`` is the weighted mean ATT.
        """
        # (cell index, cohort, relative period) for each finite ATT(g,t).
        cells = [
            (j, g, t - g)
            for j, (g, t) in enumerate(gt_pairs)
            if np.isfinite(original_atts[j])
        ]

        if balance_e is not None:
            # Keep only cohorts with a finite effect at the anchor horizon.
            anchored = {g for _, g, e in cells if e == balance_e}
            cells = [cell for cell in cells if cell[1] in anchored]
            if not cells:
                warnings.warn(
                    f"balance_e={balance_e}: no cohort has a finite effect at the "
                    "anchor horizon. Event study will be empty.",
                    UserWarning,
                    # Raised one frame below _run_multiplier_bootstrap, whose
                    # own warning uses stacklevel=3; 4 targets the same caller.
                    stacklevel=4,
                )

        cells_by_e: Dict[Any, List[Tuple[int, Any]]] = {}
        for j, g, e in cells:
            cells_by_e.setdefault(e, []).append((j, g))

        result: Dict[int, Dict[str, Any]] = {}
        for e, members in cells_by_e.items():
            indices = np.array([j for j, _ in members])
            effs = np.array([original_atts[j] for j, _ in members])
            w = _normalized_weights(np.array([cohort_fractions.get(g, 0.0) for _, g in members]))
            result[e] = {
                "gt_indices": indices,
                "weights": w,
                "effect": float(np.sum(w * effs)),
            }
        return result

    def _prepare_group_agg_boot(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        original_atts: np.ndarray,
        treatment_groups: List[Any],
    ) -> Dict[Any, Dict[str, Any]]:
        """Prepare group-level aggregation info for the bootstrap.

        For each cohort ``g`` in ``treatment_groups``, average its finite
        post-treatment cells (``t >= g - anticipation``) with equal weights.
        Cohorts with no such cells are omitted.

        Returns
        -------
        dict
            ``g -> {"gt_indices", "weights", "effect"}``.
        """
        result: Dict[Any, Dict[str, Any]] = {}
        for g in treatment_groups:
            indices = [
                j
                for j, (gg, t) in enumerate(gt_pairs)
                if gg == g and t >= g - self.anticipation and np.isfinite(original_atts[j])
            ]
            if not indices:
                continue
            effs = np.array([original_atts[j] for j in indices])
            w = np.ones(len(effs)) / len(effs)
            result[g] = {
                "gt_indices": np.array(indices),
                "weights": w,
                "effect": float(np.sum(w * effs)),
            }
        return result


def _normalized_weights(raw: np.ndarray) -> np.ndarray:
    """Scale weights to sum to one; fall back to uniform when the sum is not positive."""
    total = raw.sum()
    return raw / total if total > 0 else np.ones(len(raw)) / len(raw)


def _reaggregate_draws(
    bootstrap_atts: np.ndarray, agg_info: Dict[Any, Dict[str, Any]]
) -> Dict[Any, np.ndarray]:
    """Apply each aggregation's fixed weights to the perturbed cell ATTs."""
    draws: Dict[Any, np.ndarray] = {}
    for key, info in agg_info.items():
        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            draws[key] = bootstrap_atts[:, info["gt_indices"]] @ info["weights"]
    return draws


def _summarize_draws(
    items: List[Tuple[Any, float, np.ndarray, str]], alpha: float
) -> Tuple[Dict[Any, float], Dict[Any, Tuple[float, float]], Dict[Any, float]]:
    """Compute (SE, CI, p-value) dicts from ``(key, estimate, draws, context)`` items."""
    ses: Dict[Any, float] = {}
    cis: Dict[Any, Tuple[float, float]] = {}
    pvals: Dict[Any, float] = {}
    for key, estimate, draws, context in items:
        se, ci, pv = _compute_effect_bootstrap_stats_func(
            estimate, draws, alpha=alpha, context=context
        )
        ses[key] = se
        cis[key] = ci
        pvals[key] = pv
    return ses, cis, pvals


def _draw_unit_multiplier_weights(
    n_bootstrap: int,
    weight_type: str,
    rng: np.random.Generator,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
    resolved_survey: Any = None,
) -> np.ndarray:
    """Draw multiplier weights and expand them to one column per unit.

    PSU-level draws are used when a survey design with strata/PSU/FPC is
    present. Cluster-level draws are used when clustered, and unit-level
    draws otherwise.
    """
    use_survey = resolved_survey is not None and (
        resolved_survey.strata is not None
        or resolved_survey.psu is not None
        or resolved_survey.fpc is not None
    )
    if use_survey:
        from diff_diff.bootstrap_utils import (
            generate_survey_multiplier_weights_batch as _gen_survey_weights,
        )

        psu_weights, psu_ids = _gen_survey_weights(n_bootstrap, resolved_survey, weight_type, rng)
        if resolved_survey.psu is not None:
            # Map each unit to the column of its PSU's weight draws.
            psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)}
            unit_to_psu_col = np.array(
                [psu_id_to_col[int(resolved_survey.psu[i])] for i in range(n_units)]
            )
        else:
            unit_to_psu_col = np.arange(n_units)
        return psu_weights[:, unit_to_psu_col]

    if cluster_indices is not None and n_clusters is not None:
        cluster_weights = _generate_bootstrap_weights_batch(
            n_bootstrap, n_clusters, weight_type, rng
        )
        return cluster_weights[:, cluster_indices]

    return _generate_bootstrap_weights_batch(n_bootstrap, n_units, weight_type, rng)