diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,416 @@
1
+ """
2
+ Result container classes for Callaway-Sant'Anna estimator.
3
+
4
+ This module provides dataclass containers for storing and presenting
5
+ group-time average treatment effects and their aggregations.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ from diff_diff.results import _format_survey_block, _get_significance_stars
15
+
16
+ if TYPE_CHECKING:
17
+ from diff_diff.staggered_bootstrap import CSBootstrapResults
18
+
19
+
20
@dataclass
class GroupTimeEffect:
    """
    Estimate of the treatment effect for one (group, time) cell.

    Attributes
    ----------
    group : any
        The treatment cohort (first treatment period).
    time : any
        The time period.
    effect : float
        The ATT(g,t) estimate.
    se : float
        Standard error.
    t_stat : float
        t-statistic associated with the estimate.
    p_value : float
        P-value associated with the estimate. Both ``is_significant`` and
        ``significance_stars`` are derived from it.
    conf_int : tuple of float
        Confidence interval as a pair of floats
        (presumably ``(lower, upper)`` -- confirm against the estimator).
    n_treated : int
        Number of treated observations.
    n_control : int
        Number of control observations.
    """

    group: Any
    time: Any
    effect: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_treated: int
    n_control: int

    @property
    def is_significant(self) -> bool:
        """Whether the p-value is strictly below the fixed 0.05 level."""
        level = 0.05
        # bool() turns a NumPy boolean into a plain Python bool.
        return bool(self.p_value < level)

    @property
    def significance_stars(self) -> str:
        """Significance marker string for this cell's p-value."""
        stars = _get_significance_stars(self.p_value)
        return stars
60
+
61
+
62
@dataclass
class CallawaySantAnnaResults:
    """
    Results from Callaway-Sant'Anna (2021) staggered DiD estimation.

    This class stores group-time average treatment effects ATT(g,t) together
    with optional aggregations, and renders them as text or DataFrames.

    Attributes
    ----------
    group_time_effects : dict
        Dictionary mapping (group, time) tuples to effect dictionaries.
        ``to_dataframe`` reads the keys ``"effect"``, ``"se"``, ``"t_stat"``,
        ``"p_value"`` and ``"conf_int"`` from each entry.
    overall_att : float
        Overall average treatment effect (weighted average of ATT(g,t)).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        t-statistic of overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units (excludes not-yet-treated dynamic controls).
    alpha : float, default 0.05
        Significance level used by ``is_significant`` and as the default
        level in ``summary``.
    control_group : str, default "never_treated"
        Control group label printed in the summary.
    base_period : str, default "varying"
        Base period label printed in the summary.
    panel : bool, default True
        When true the summary labels the counts as "units"; otherwise "obs".
    event_study_effects : dict, optional
        Effects aggregated by relative time (event study). Entries use the
        same keys as ``group_time_effects`` plus an optional
        ``"cband_conf_int"`` pair.
    group_effects : dict, optional
        Effects aggregated by treatment cohort.
    influence_functions : np.ndarray, optional
        Stored as given; not read by any method of this class.
    event_study_vcov : np.ndarray, optional
        Full event-study VCV matrix, indexed by ``event_study_vcov_index``.
    event_study_vcov_index : list, optional
        Index labels for ``event_study_vcov``.
    bootstrap_results : CSBootstrapResults, optional
        Stored as given; not read by any method of this class.
    cband_crit_value : float, optional
        Critical value reported in the event-study section of the summary
        (labelled there as a sup-t bootstrap critical value).
    pscore_trim : float
        Propensity score trimming bound used during estimation.
    survey_metadata : any, optional
        Survey design metadata; when set, ``summary`` adds a survey block.
    epv_diagnostics : dict, optional
        Per-(group, time) EPV diagnostics. Entries are read with the keys
        ``"epv"``, ``"n_events"``, ``"k"`` and ``"is_low"``.
    epv_threshold : float, default 10
        Threshold printed next to the minimum EPV in the summary.
    pscore_fallback : str, default "error"
        Stored as given; not read by any method of this class
        (presumably the estimator's propensity-score fallback mode -- confirm).
    """

    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    base_period: str = "varying"
    panel: bool = True
    event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
    group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
    influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
    # Full event-study VCV matrix (Phase 7d): indexed by event_study_vcov_index
    event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False)
    event_study_vcov_index: Optional[list] = field(default=None, repr=False)
    bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False)
    cband_crit_value: Optional[float] = None
    pscore_trim: float = 0.01
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None, repr=False)
    # EPV diagnostics per (group, time) cell
    epv_diagnostics: Optional[Dict[Tuple[Any, Any], Dict[str, Any]]] = field(
        default=None, repr=False
    )
    epv_threshold: float = 10
    pscore_fallback: str = "error"

    def __repr__(self) -> str:
        """Concise one-line representation: overall ATT (with stars), SE, counts."""
        stars = _get_significance_stars(self.overall_p_value)
        parts = [
            f"ATT={self.overall_att:.4f}{stars}",
            f"SE={self.overall_se:.4f}",
            f"n_groups={len(self.groups)}",
            f"n_periods={len(self.time_periods)}",
        ]
        return f"CallawaySantAnnaResults({', '.join(parts)})"

    @property
    def coef_var(self) -> float:
        """
        Coefficient of variation: SE / |overall ATT|.

        NaN when the SE is non-finite or negative, or when the ATT is
        non-finite or exactly zero.
        """
        se, att = self.overall_se, self.overall_att
        se_usable = np.isfinite(se) and se >= 0
        if not se_usable:
            return np.nan
        att_usable = np.isfinite(att) and att != 0
        if not att_usable:
            return np.nan
        return se / abs(att)

    @staticmethod
    def _table_header(label: str) -> str:
        """Build the column header line shared by all effect tables."""
        return (
            f"{label:<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}"
        )

    @staticmethod
    def _table_row(label: Any, effect: float, se: float, t_stat: float, p_value: float) -> str:
        """Build one effect-table row, including the significance stars."""
        stars = _get_significance_stars(p_value)
        return (
            f"{label:<15} {effect:>12.4f} {se:>12.4f} "
            f"{t_stat:>10.3f} {p_value:>10.4f} {stars:>6}"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation.
            It only affects the confidence level printed in the labels:
            the interval bounds shown are the stored ``overall_conf_int``
            and are not recomputed for a different ``alpha``.

        Returns
        -------
        str
            Formatted summary.
        """
        # Explicit None check: `alpha or self.alpha` would silently discard
        # an explicitly passed falsy value.
        if alpha is None:
            alpha = self.alpha
        conf_level = int((1 - alpha) * 100)

        width = 85
        thin = "-" * width
        thick = "=" * width
        unit_word = "units:" if self.panel else "obs:"
        treated_label = "Treated " + unit_word
        control_label = "Never-treated " + unit_word

        lines = [
            thick,
            "Callaway-Sant'Anna Staggered Difference-in-Differences Results".center(width),
            thick,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{treated_label:<30} {self.n_treated_units:>10}",
            f"{control_label:<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            f"{'Base period:':<30} {self.base_period:>10}",
            "",
        ]

        # Survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, width))

        # Overall ATT
        ci_low, ci_high = self.overall_conf_int[0], self.overall_conf_int[1]
        lines.extend(
            [
                thin,
                "Overall Average Treatment Effect on the Treated".center(width),
                thin,
                self._table_header("Parameter"),
                thin,
                self._table_row(
                    "ATT",
                    self.overall_att,
                    self.overall_se,
                    self.overall_t_stat,
                    self.overall_p_value,
                ),
                thin,
                "",
                f"{conf_level}% Confidence Interval: [{ci_low:.4f}, {ci_high:.4f}]",
            ]
        )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")

        lines.append("")

        # EPV diagnostics block (only when at least one cell is flagged low)
        if self.epv_diagnostics:
            low_epv = {
                cell: diag for cell, diag in self.epv_diagnostics.items() if diag.get("is_low")
            }
            if low_epv:
                # One pass yields both the cell and its diagnostics; ties
                # resolve to the first entry in dict order.
                min_cell, min_diag = min(low_epv.items(), key=lambda item: item[1]["epv"])
                lines.extend(
                    [
                        thin,
                        "Propensity Score Diagnostics".center(width),
                        thin,
                        f"WARNING: Low Events Per Variable (EPV) in "
                        f"{len(low_epv)} of {len(self.epv_diagnostics)} cohort-time cell(s).",
                        f"Minimum EPV: {min_diag['epv']:.1f} "
                        f"(cohort g={min_cell[0]}). Threshold: {self.epv_threshold:.0f}.",
                        "Consider: estimation_method='reg' or fewer covariates.",
                        "Call results.epv_summary() for per-cohort details.",
                        thin,
                        "",
                    ]
                )

        # Event study effects if available
        if self.event_study_effects:
            lines.extend(
                [
                    thin,
                    "Event Study (Dynamic) Effects".center(width),
                    thin,
                    self._table_header("Rel. Period"),
                    thin,
                ]
            )
            for rel_t in sorted(self.event_study_effects.keys()):
                eff = self.event_study_effects[rel_t]
                lines.append(
                    self._table_row(
                        rel_t, eff["effect"], eff["se"], eff["t_stat"], eff["p_value"]
                    )
                )
            lines.append(thin)
            # The critical-value line is printed only when a simultaneous
            # band critical value is stored.
            if self.cband_crit_value is not None:
                lines.append(
                    f"Simult. CI: critical value = {self.cband_crit_value:.4f} "
                    f"(sup-t bootstrap, {conf_level}% family-wise)"
                )
            lines.append("")

        # Group effects if available
        if self.group_effects:
            lines.extend(
                [
                    thin,
                    "Effects by Treatment Cohort".center(width),
                    thin,
                    self._table_header("Cohort"),
                    thin,
                ]
            )
            for group in sorted(self.group_effects.keys()):
                eff = self.group_effects[group]
                lines.append(
                    self._table_row(
                        group, eff["effect"], eff["se"], eff["t_stat"], eff["p_value"]
                    )
                )
            lines.extend([thin, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                thick,
            ]
        )

        return "\n".join(lines)

    @staticmethod
    def _frame(rows: List[Dict[str, Any]], columns: List[str]) -> pd.DataFrame:
        """Build a DataFrame that carries ``columns`` even when ``rows`` is empty."""
        if not rows:
            return pd.DataFrame(columns=columns)
        return pd.DataFrame(rows, columns=columns)

    @staticmethod
    def _effect_fields(data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the statistic columns shared by every ``to_dataframe`` level."""
        interval = data["conf_int"]
        return {
            "effect": data["effect"],
            "se": data["se"],
            "t_stat": data["t_stat"],
            "p_value": data["p_value"],
            "conf_int_lower": interval[0],
            "conf_int_upper": interval[1],
        }

    def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
        """
        Return per-cohort EPV diagnostics as a DataFrame.

        Parameters
        ----------
        show_all : bool, default False
            If False, only show cells with low EPV. If True, show all cells.

        Returns
        -------
        pd.DataFrame
            Columns: group, time, epv, n_events, n_params, is_low. Rows are
            ordered by sorted (group, time) key. Empty (but with all columns)
            when no diagnostics are stored or no cell matches.
        """
        cols = ["group", "time", "epv", "n_events", "n_params", "is_low"]
        if not self.epv_diagnostics:
            return pd.DataFrame(columns=cols)
        rows = [
            {
                "group": g,
                "time": t,
                "epv": diag.get("epv"),
                "n_events": diag.get("n_events"),
                # The diagnostics dict stores the parameter count under "k".
                "n_params": diag.get("k"),
                "is_low": diag.get("is_low", False),
            }
            for (g, t), diag in sorted(self.epv_diagnostics.items())
            if show_all or diag.get("is_low", False)
        ]
        return self._frame(rows, cols)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the text produced by ``summary(alpha)`` to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="group_time"
            Level of aggregation: "group_time", "event_study", or "group".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame. The statistic columns are always present,
            even when there are no rows. For "group_time", an ``epv`` column
            is appended when at least one cell has EPV diagnostics.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or the requested aggregation is ``None``.
        """
        stat_cols = ["effect", "se", "t_stat", "p_value", "conf_int_lower", "conf_int_upper"]

        if level == "group_time":
            columns = ["group", "time", *stat_cols]
            diagnostics = self.epv_diagnostics or {}
            rows = []
            for (g, t), data in self.group_time_effects.items():
                row = {"group": g, "time": t, **self._effect_fields(data)}
                if (g, t) in diagnostics:
                    row["epv"] = diagnostics[(g, t)].get("epv")
                rows.append(row)
            # Cells without diagnostics get NaN in the epv column.
            if any("epv" in row for row in rows):
                columns.append("epv")
            return self._frame(rows, columns)

        if level == "event_study":
            if self.event_study_effects is None:
                raise ValueError("Event study effects not computed. Use aggregate='event_study'.")
            columns = ["relative_period", *stat_cols, "cband_lower", "cband_upper"]
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                # Simultaneous band is optional per entry; NaN when absent.
                band = data.get("cband_conf_int", (np.nan, np.nan))
                rows.append(
                    {
                        "relative_period": rel_t,
                        **self._effect_fields(data),
                        "cband_lower": band[0],
                        "cband_upper": band[1],
                    }
                )
            return self._frame(rows, columns)

        if level == "group":
            if self.group_effects is None:
                raise ValueError("Group effects not computed. Use aggregate='group'.")
            columns = ["group", *stat_cols]
            rows = [
                {"group": group, **self._effect_fields(data)}
                for group, data in sorted(self.group_effects.items())
            ]
            return self._frame(rows, columns)

        raise ValueError(
            f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'."
        )

    @property
    def is_significant(self) -> bool:
        """Whether the overall ATT p-value is strictly below ``self.alpha``."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)