diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,416 @@
1
+ """
2
+ Result container classes for Staggered Triple Difference estimator.
3
+
4
+ This module provides dataclass containers for storing and presenting
5
+ group-time DDD effects and their aggregations.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ from diff_diff.results import _format_survey_block, _get_significance_stars
15
+
16
+ if TYPE_CHECKING:
17
+ from diff_diff.staggered_bootstrap import CSBootstrapResults
18
+
19
+
20
@dataclass
class StaggeredTripleDiffResults:
    """
    Results from Staggered Triple Difference (DDD) estimation.

    Implements the Ortiz-Villavicencio & Sant'Anna (2025) estimator for
    staggered adoption settings with an eligibility dimension.

    Attributes
    ----------
    group_time_effects : dict
        Dictionary mapping (group, time) tuples to effect dictionaries with
        keys ``effect``, ``se``, ``t_stat``, ``p_value`` and ``conf_int``
        (a ``(lower, upper)`` pair).
    overall_att : float
        Overall average treatment effect (weighted average of ATT(g,t)).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    groups : list
        List of enabling cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of treated units (S < inf AND Q = 1).
    n_control_units : int
        Number of units not in treated group.
    n_never_enabled : int
        Number of never-enabled units (S = inf or 0).
    n_eligible : int
        Number of eligible units (Q = 1).
    n_ineligible : int
        Number of ineligible units (Q = 0).
    alpha : float, default 0.05
        Significance level used in estimation; ``is_significant`` compares
        ``overall_p_value`` against it.
    control_group : str, default "notyettreated"
        Control-group specification, reported in the summary.
    base_period : str, default "varying"
        Base-period specification, reported in the summary.
    estimation_method : str, default "dr"
        Estimation method, reported in the summary.
    event_study_effects : dict, optional
        Relative period -> effect dictionary (same keys as
        ``group_time_effects``, plus an optional ``cband_conf_int`` pair).
    group_effects : dict, optional
        Enabling cohort -> effect dictionary (same keys as
        ``group_time_effects``).
    influence_functions : np.ndarray, optional
        Stored for downstream use; not read by this class.
    bootstrap_results : CSBootstrapResults, optional
        Bootstrap output, if any; not read by this class.
    cband_crit_value : float, optional
        Sup-t bootstrap critical value for simultaneous event-study bands;
        ``None`` when no simultaneous band was computed.
    pscore_trim : float, default 0.01
        Presumably the propensity-score trimming bound; reported by
        ``to_dict``.
    survey_metadata : object, optional
        Survey-design metadata, rendered in the summary when present.
    comparison_group_counts : dict, optional
        Included in ``to_dict`` output when present.
    gmm_weights : dict, optional
        Stored for downstream use; not read by this class.
    epv_diagnostics : dict, optional
        (group, time) -> dict with keys ``epv``, ``n_events``, ``k`` and
        ``is_low`` (events-per-variable diagnostics).
    epv_threshold : float, default 10
        EPV threshold reported in the summary warning.
    pscore_fallback : str, default "error"
        Stored setting; not read by this class.
    """

    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    n_never_enabled: int
    n_eligible: int
    n_ineligible: int
    alpha: float = 0.05
    control_group: str = "notyettreated"
    base_period: str = "varying"
    estimation_method: str = "dr"
    event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
    group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
    influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
    bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False)
    cband_crit_value: Optional[float] = None
    pscore_trim: float = 0.01
    survey_metadata: Optional[Any] = field(default=None, repr=False)
    comparison_group_counts: Optional[Dict[Tuple, int]] = field(default=None, repr=False)
    gmm_weights: Optional[Dict[Tuple, Dict]] = field(default=None, repr=False)
    epv_diagnostics: Optional[Dict[Tuple[Any, Any], Dict[str, Any]]] = field(
        default=None, repr=False
    )
    epv_threshold: float = 10
    pscore_fallback: str = "error"

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        return (
            f"StaggeredTripleDiffResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_periods={len(self.time_periods)})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / |overall ATT|. NaN when ATT is 0 or SE non-finite."""
        if not (np.isfinite(self.overall_se) and self.overall_se >= 0):
            return np.nan
        if not np.isfinite(self.overall_att) or self.overall_att == 0:
            return np.nan
        return self.overall_se / abs(self.overall_att)

    @staticmethod
    def _conf_level_label(alpha: float) -> str:
        """Return ``(1 - alpha)`` as a percentage string, e.g. 0.05 -> '95', 0.025 -> '97.5'."""
        # int((1 - alpha) * 100) truncates float noise: alpha=0.07 yields
        # 92.99999999999999 -> 92. '%g' rounds to 6 significant digits instead.
        return f"{(1 - alpha) * 100:g}"

    @staticmethod
    def _table_header(first_column: str) -> str:
        """Column header line shared by the ATT, event-study and cohort tables."""
        return (
            f"{first_column:<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}"
        )

    @staticmethod
    def _format_effect_row(label: Any, eff: Dict[str, Any]) -> str:
        """One table row for an effect dictionary (effect, se, t_stat, p_value)."""
        sig = _get_significance_stars(eff["p_value"])
        return (
            f"{label:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
            f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation.
            NOTE: the printed intervals are the ones stored on this object;
            ``alpha`` only changes the confidence-level labels.

        Returns
        -------
        str
            Formatted summary.
        """
        alpha = alpha or self.alpha
        conf_level = self._conf_level_label(alpha)
        banner = "=" * 85
        rule = "-" * 85

        lines = [
            banner,
            "Staggered Triple Difference (DDD) Results".center(85),
            banner,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units (S<inf, Q=1):':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Never-enabled units:':<30} {self.n_never_enabled:>10}",
            f"{'Eligible units (Q=1):':<30} {self.n_eligible:>10}",
            f"{'Ineligible units (Q=0):':<30} {self.n_ineligible:>10}",
            f"{'Enabling cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Estimation method:':<30} {self.estimation_method:>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            f"{'Base period:':<30} {self.base_period:>10}",
            "",
        ]

        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, 85))

        # Overall ATT
        lines.extend(
            [
                rule,
                "Overall Average Treatment Effect on the Treated".center(85),
                rule,
                self._table_header("Parameter"),
                rule,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                rule,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
            ]
        )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")

        lines.append("")

        # EPV diagnostics block (only when at least one cell is flagged low)
        if self.epv_diagnostics:
            low_epv = {k: v for k, v in self.epv_diagnostics.items() if v.get("is_low")}
            if low_epv:
                # Key of the lowest-EPV cell; its entry supplies the reported minimum.
                min_key = min(low_epv, key=lambda k: low_epv[k]["epv"])
                lines.extend(
                    [
                        rule,
                        "Propensity Score Diagnostics".center(85),
                        rule,
                        f"WARNING: Low Events Per Variable (EPV) in "
                        f"{len(low_epv)} of {len(self.epv_diagnostics)} cohort-time cell(s).",
                        f"Minimum EPV: {low_epv[min_key]['epv']:.1f} "
                        f"(cohort g={min_key[0]}). "
                        f"Threshold: {self.epv_threshold:.0f}.",
                        "Consider: estimation_method='reg' or fewer covariates.",
                        "Call results.epv_summary() for per-cohort details.",
                        rule,
                        "",
                    ]
                )

        # Event study effects
        if self.event_study_effects:
            lines.extend(
                [
                    rule,
                    "Event Study (Dynamic) Effects".center(85),
                    rule,
                    self._table_header("Rel. Period"),
                    rule,
                ]
            )
            for rel_t in sorted(self.event_study_effects):
                lines.append(self._format_effect_row(rel_t, self.event_study_effects[rel_t]))
            lines.append(rule)
            # The critical-value line is only printed for simultaneous bands.
            if self.cband_crit_value is not None:
                lines.append(
                    f"Simult. CI: critical value = {self.cband_crit_value:.4f} "
                    f"(sup-t bootstrap, {conf_level}% family-wise)"
                )
            lines.append("")

        # Group effects
        if self.group_effects:
            lines.extend(
                [
                    rule,
                    "Effects by Enabling Cohort".center(85),
                    rule,
                    self._table_header("Cohort"),
                    rule,
                ]
            )
            for group in sorted(self.group_effects):
                lines.append(self._format_effect_row(group, self.group_effects[group]))
            lines.extend([rule, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                banner,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
        """
        Return per-cohort EPV diagnostics as a DataFrame.

        Parameters
        ----------
        show_all : bool, default False
            If False, only show cells with low EPV. If True, show all cells.

        Returns
        -------
        pd.DataFrame
            Columns: group, time, epv, n_events, n_params, is_low.
        """
        cols = ["group", "time", "epv", "n_events", "n_params", "is_low"]
        if not self.epv_diagnostics:
            return pd.DataFrame(columns=cols)
        rows = [
            {
                "group": g,
                "time": t,
                "epv": diag.get("epv"),
                "n_events": diag.get("n_events"),
                # Stored under "k" in the diagnostics dict.
                "n_params": diag.get("k"),
                "is_low": diag.get("is_low", False),
            }
            for (g, t), diag in sorted(self.epv_diagnostics.items())
            if show_all or diag.get("is_low", False)
        ]
        return pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)

    @staticmethod
    def _effect_columns(data: Dict[str, Any]) -> Dict[str, Any]:
        """Estimate/inference columns shared by every ``to_dataframe`` level."""
        return {
            "effect": data["effect"],
            "se": data["se"],
            "t_stat": data["t_stat"],
            "p_value": data["p_value"],
            "conf_int_lower": data["conf_int"][0],
            "conf_int_upper": data["conf_int"][1],
        }

    def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="group_time"
            Level of aggregation: "group_time", "event_study", or "group".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame.

        Raises
        ------
        ValueError
            If the requested aggregation was not computed, or ``level`` is
            unknown.
        """
        if level == "group_time":
            rows = []
            for (g, t), data in self.group_time_effects.items():
                row = {"group": g, "time": t, **self._effect_columns(data)}
                if self.epv_diagnostics and (g, t) in self.epv_diagnostics:
                    row["epv"] = self.epv_diagnostics[(g, t)].get("epv")
                rows.append(row)
            return pd.DataFrame(rows)

        if level == "event_study":
            if self.event_study_effects is None:
                raise ValueError("Event study effects not computed. Use aggregate='event_study'.")
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                # Simultaneous bands are optional; missing ones become NaN.
                cband_ci = data.get("cband_conf_int", (np.nan, np.nan))
                rows.append(
                    {
                        "relative_period": rel_t,
                        **self._effect_columns(data),
                        "cband_lower": cband_ci[0],
                        "cband_upper": cband_ci[1],
                    }
                )
            return pd.DataFrame(rows)

        if level == "group":
            if self.group_effects is None:
                raise ValueError("Group effects not computed. Use aggregate='group'.")
            rows = [
                {"group": group, **self._effect_columns(data)}
                for group, data in sorted(self.group_effects.items())
            ]
            return pd.DataFrame(rows)

        raise ValueError(
            f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'."
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert results to dictionary."""
        d = {
            "overall_att": self.overall_att,
            "overall_se": self.overall_se,
            "overall_t_stat": self.overall_t_stat,
            "overall_p_value": self.overall_p_value,
            "overall_conf_int": self.overall_conf_int,
            "n_obs": self.n_obs,
            "n_treated_units": self.n_treated_units,
            "n_control_units": self.n_control_units,
            "n_never_enabled": self.n_never_enabled,
            "n_eligible": self.n_eligible,
            "n_ineligible": self.n_ineligible,
            "n_groups": len(self.groups),
            "n_periods": len(self.time_periods),
            "groups": self.groups,
            "time_periods": self.time_periods,
            "estimation_method": self.estimation_method,
            "control_group": self.control_group,
            "base_period": self.base_period,
            "alpha": self.alpha,
            "pscore_trim": self.pscore_trim,
        }
        # Optional aggregations are included only when they were computed.
        if self.event_study_effects is not None:
            d["event_study_effects"] = self.event_study_effects
        if self.group_effects is not None:
            d["group_effects"] = self.group_effects
        if self.comparison_group_counts is not None:
            d["comparison_group_counts"] = self.comparison_group_counts
        return d

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)