diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,349 @@
1
+ """Results class for WooldridgeDiD (ETWFE) estimator."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ from diff_diff.utils import safe_inference
12
+
13
+
14
@dataclass
class WooldridgeDiDResults:
    """Results from WooldridgeDiD.fit().

    Core output is ``group_time_effects``: a dict keyed by (cohort_g, time_t)
    with per-cell ATT estimates and inference. Call ``.aggregate(type)`` to
    compute any of the four jwdid_estat aggregation types ("simple", "group",
    "calendar", "event"); the aggregated tables are stored on the instance.
    """

    # ------------------------------------------------------------------ #
    #  Core cohort x time estimates                                      #
    # ------------------------------------------------------------------ #
    # key=(g, t), value={att, se, t_stat, p_value, conf_int}
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]

    # ------------------------------------------------------------------ #
    #  Simple (overall) aggregation - always populated at fit time       #
    # ------------------------------------------------------------------ #
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]

    # ------------------------------------------------------------------ #
    #  Other aggregations - populated by .aggregate()                    #
    # ------------------------------------------------------------------ #
    group_effects: Optional[Dict[Any, Dict]] = field(default=None, repr=False)
    calendar_effects: Optional[Dict[Any, Dict]] = field(default=None, repr=False)
    event_study_effects: Optional[Dict[int, Dict]] = field(default=None, repr=False)

    # ------------------------------------------------------------------ #
    #  Metadata                                                          #
    # ------------------------------------------------------------------ #
    method: str = "ols"
    control_group: str = "not_yet_treated"
    groups: List[Any] = field(default_factory=list)
    time_periods: List[Any] = field(default_factory=list)
    n_obs: int = 0
    n_treated_units: int = 0
    n_control_units: int = 0
    alpha: float = 0.05
    anticipation: int = 0
    survey_metadata: Optional[Any] = field(default=None, repr=False)

    # ------------------------------------------------------------------ #
    #  Internal - used by aggregate() for delta-method SEs               #
    # ------------------------------------------------------------------ #
    # Per-cell aggregation weights; cells missing from the dict weigh 0.
    _gt_weights: Dict[Tuple[Any, Any], int] = field(default_factory=dict, repr=False)
    # Full vcov of all (g, t) coefficients. Rows/columns follow ``_gt_keys``;
    # when ``_gt_keys`` is empty, aggregate() assumes sorted
    # ``group_time_effects`` keys instead.
    _gt_vcov: Optional[np.ndarray] = field(default=None, repr=False)
    # Ordered list of (g, t) keys corresponding to ``_gt_vcov`` columns.
    _gt_keys: List[Tuple[Any, Any]] = field(default_factory=list, repr=False)
    # Survey degrees of freedom, forwarded to ``safe_inference`` as ``df``.
    _df_survey: Optional[int] = field(default=None, repr=False)

    # ------------------------------------------------------------------ #
    #  Private helpers                                                   #
    # ------------------------------------------------------------------ #

    def _combine_cells(
        self,
        cells: List[Tuple[Any, Any]],
        keys_ordered: List[Tuple[Any, Any]],
    ) -> Optional[Tuple[float, float]]:
        """Weighted-average ATT and delta-method SE over ``cells``.

        Returns ``None`` when ``cells`` is empty or its total weight is zero
        (nothing to aggregate). The SE is NaN when no vcov is stored or its
        size does not match ``keys_ordered``.
        """
        if not cells:
            return None
        weights = self._gt_weights
        w_total = sum(weights.get(c, 0) for c in cells)
        if w_total == 0:
            return None

        effects = self.group_time_effects
        att = sum(weights.get(c, 0) * effects[c]["att"] for c in cells) / w_total

        # Weight vector over *all* ordered keys so it lines up with the vcov;
        # a set keeps the membership test O(1) per key.
        members = set(cells)
        w_vec = np.array(
            [weights.get(c, 0) / w_total if c in members else 0.0 for c in keys_ordered]
        )

        vcov = self._gt_vcov
        if vcov is None or len(w_vec) != vcov.shape[0]:
            return att, float("nan")
        # Clamp tiny negative quadratic forms to zero before the square root.
        return att, float(np.sqrt(max(w_vec @ vcov @ w_vec, 0.0)))

    def _build_effect(self, att: float, se: float) -> Dict[str, Any]:
        """Bundle an estimate with its inference as an effect dict.

        ``safe_inference`` supplies the (t_stat, p_value, conf_int) triple
        using this result's ``alpha`` and survey degrees of freedom.
        """
        t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=self._df_survey)
        return {
            "att": att,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_value,
            "conf_int": conf_int,
        }

    # ------------------------------------------------------------------ #
    #  Public methods                                                    #
    # ------------------------------------------------------------------ #

    def aggregate(self, type: str) -> "WooldridgeDiDResults":  # noqa: A002
        """Compute and store one of the four jwdid_estat aggregation types.

        Parameters
        ----------
        type : "simple" | "group" | "calendar" | "event"
            "simple" is a no-op (the ``overall_*`` fields are already set).
            "group" averages post-treatment cells (t >= g) per cohort into
            ``group_effects``; "calendar" averages treated cells (t >= g) per
            period into ``calendar_effects``; "event" averages all cells per
            relative period t - g into ``event_study_effects``.

        Returns
        -------
        self, for chaining.

        Raises
        ------
        ValueError
            If ``type`` is not one of the four supported names.
        """
        valid = ("simple", "group", "calendar", "event")
        if type not in valid:
            raise ValueError(f"type must be one of {valid}, got {type!r}")

        if type == "simple":
            # Already stored in the overall_* fields; kept callable for symmetry.
            return self

        keys_ordered = (
            self._gt_keys if self._gt_keys else sorted(self.group_time_effects.keys())
        )

        # Each bucket pairs an output label with the (g, t) cells it averages.
        if type == "group":
            target = "group_effects"
            buckets = (
                (g, [(g2, t) for (g2, t) in keys_ordered if g2 == g and t >= g])
                for g in self.groups
            )
        elif type == "calendar":
            target = "calendar_effects"
            buckets = (
                (t, [(g, t2) for (g, t2) in keys_ordered if t2 == t and t >= g])
                for t in self.time_periods
            )
        else:  # "event"
            target = "event_study_effects"
            offsets = sorted({t - g for (g, t) in keys_ordered})
            buckets = (
                (k, [(g, t) for (g, t) in keys_ordered if t - g == k]) for k in offsets
            )

        result: Dict[Any, Dict[str, Any]] = {}
        for label, cells in buckets:
            combined = self._combine_cells(cells, keys_ordered)
            if combined is None:
                continue
            result[label] = self._build_effect(*combined)

        setattr(self, target, result)
        return self

    def summary(self, aggregation: str = "simple") -> str:
        """Return a formatted summary table as a string.

        Parameters
        ----------
        aggregation : which aggregation to display ("simple", "group", "calendar", "event")
            Aggregations other than "simple" must have been computed with
            ``.aggregate()`` beforehand; otherwise a hint line is shown in
            place of the rows.
        """
        lines = [
            "=" * 70,
            " Wooldridge Extended Two-Way Fixed Effects (ETWFE) Results",
            "=" * 70,
            f"Method: {self.method}",
            f"Control group: {self.control_group}",
            f"Observations: {self.n_obs}",
            f"Treated units: {self.n_treated_units}",
            f"Control units: {self.n_control_units}",
            "-" * 70,
        ]

        if self.survey_metadata is not None:
            from diff_diff.results import _format_survey_block

            lines.extend(_format_survey_block(self.survey_metadata, 70))
            lines.append("-" * 70)

        def _fmt_row(label: str, att: float, se: float, t: float, p: float, ci: Tuple) -> str:
            # Imported lazily so paths that render no rows need no extra import.
            from diff_diff.results import _get_significance_stars  # type: ignore

            stars = _get_significance_stars(p) if not np.isnan(p) else ""
            ci_lo = f"{ci[0]:.4f}" if not np.isnan(ci[0]) else "NaN"
            ci_hi = f"{ci[1]:.4f}" if not np.isnan(ci[1]) else "NaN"
            return (
                f"{label:<22} {att:>10.4f} {se:>10.4f} {t:>8.3f} "
                f"{p:>8.4f}{stars} [{ci_lo}, {ci_hi}]"
            )

        def _fmt_effect(label: str, eff: Dict[str, Any]) -> str:
            return _fmt_row(
                label, eff["att"], eff["se"], eff["t_stat"], eff["p_value"], eff["conf_int"]
            )

        ci_pct = f"{(1 - self.alpha) * 100:.0f}%"
        header = (
            f"{'Parameter':<22} {'Estimate':>10} {'Std. Err.':>10} "
            f"{'t-stat':>8} {'P>|t|':>8} [{ci_pct} CI]"
        )
        lines.append(header)
        lines.append("-" * 70)

        # Row-label symbol and stored table for each non-simple aggregation.
        tables = {
            "group": ("g", self.group_effects),
            "calendar": ("t", self.calendar_effects),
            "event": ("k", self.event_study_effects),
        }

        if aggregation == "simple":
            lines.append(
                _fmt_row(
                    "ATT (simple)",
                    self.overall_att,
                    self.overall_se,
                    self.overall_t_stat,
                    self.overall_p_value,
                    self.overall_conf_int,
                )
            )
        elif aggregation not in tables:
            # Telling the user to call aggregate() with this name would only
            # lead to a ValueError, so report the bad name directly.
            expected = ("simple",) + tuple(tables)
            lines.append(f" (unknown aggregation {aggregation!r}; expected one of {expected})")
        else:
            symbol, effects = tables[aggregation]
            if effects is None:
                lines.append(f" (call .aggregate({aggregation!r}) first)")
            elif not effects:
                # aggregate() ran but no cell group had positive total weight.
                lines.append(f" (no {aggregation} effects available)")
            else:
                for key, eff in sorted(effects.items()):
                    label = f"ATT({symbol}={key})"
                    if aggregation == "event":
                        # Tag leads: before the anticipation window vs inside it.
                        if key < -self.anticipation:
                            label += " [pre]"
                        elif key < 0:
                            label += " [antic]"
                    lines.append(_fmt_effect(label, eff))

        lines.append("=" * 70)
        return "\n".join(lines)

    def to_dataframe(self, aggregation: str = "event") -> pd.DataFrame:
        """Export aggregated effects to a DataFrame.

        Parameters
        ----------
        aggregation : "simple" | "group" | "calendar" | "event" | "gt"
            Use "gt" to export raw group-time effects (one row per (g, t)
            cell, with ``conf_int`` kept as stored). The other tables split
            ``conf_int`` into ``conf_int_lo`` / ``conf_int_hi`` columns.

        Returns
        -------
        pd.DataFrame
            Empty when the requested aggregation has not been computed yet
            or when ``aggregation`` is not a recognised name.
        """
        if aggregation == "gt":
            return pd.DataFrame(
                [
                    {"cohort": g, "time": t, "relative_period": t - g, **eff}
                    for (g, t), eff in sorted(self.group_time_effects.items())
                ]
            )

        if aggregation == "simple":
            return pd.DataFrame(
                [
                    {
                        "label": "ATT",
                        "att": self.overall_att,
                        "se": self.overall_se,
                        "t_stat": self.overall_t_stat,
                        "p_value": self.overall_p_value,
                        "conf_int_lo": self.overall_conf_int[0],
                        "conf_int_hi": self.overall_conf_int[1],
                    }
                ]
            )

        # Only the requested table is materialised, so a malformed entry in
        # an unrelated table cannot break this export.
        tables = {
            "group": ("cohort", self.group_effects),
            "calendar": ("time", self.calendar_effects),
            "event": ("relative_period", self.event_study_effects),
        }
        if aggregation not in tables:
            return pd.DataFrame([])

        index_col, effects = tables[aggregation]
        rows = [
            {
                index_col: key,
                **{name: value for name, value in eff.items() if name != "conf_int"},
                "conf_int_lo": eff["conf_int"][0],
                "conf_int_hi": eff["conf_int"][1],
            }
            for key, eff in sorted((effects or {}).items())
        ]
        return pd.DataFrame(rows)

    def plot_event_study(self, **kwargs) -> None:
        """Event study plot. Calls aggregate('event') if needed.

        Extra keyword arguments are forwarded to
        ``diff_diff.visualization.plot_event_study``.
        """
        if self.event_study_effects is None:
            self.aggregate("event")
        from diff_diff.visualization import plot_event_study  # type: ignore

        event_effects = self.event_study_effects or {}
        effects = {k: v["att"] for k, v in event_effects.items()}
        se = {k: v["se"] for k, v in event_effects.items()}
        plot_event_study(effects=effects, se=se, alpha=self.alpha, **kwargs)

    def __repr__(self) -> str:
        """Compact one-line summary: overall ATT, SE, p-value and cell count."""
        n_gt = len(self.group_time_effects)
        att_str = f"{self.overall_att:.4f}" if not np.isnan(self.overall_att) else "NaN"
        se_str = f"{self.overall_se:.4f}" if not np.isnan(self.overall_se) else "NaN"
        p_str = f"{self.overall_p_value:.4f}" if not np.isnan(self.overall_p_value) else "NaN"
        return (
            f"WooldridgeDiDResults("
            f"ATT={att_str}, SE={se_str}, p={p_str}, "
            f"n_gt={n_gt}, method={self.method!r})"
        )