diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/bacon.py ADDED
@@ -0,0 +1,1140 @@
1
+ """
2
+ Goodman-Bacon Decomposition for Two-Way Fixed Effects.
3
+
4
+ Implements the decomposition from Goodman-Bacon (2021) that shows how
5
+ TWFE estimates with staggered treatment timing can be written as a
6
+ weighted average of all possible 2x2 DiD comparisons.
7
+
8
+ Reference:
9
+ Goodman-Bacon, A. (2021). Difference-in-differences with variation
10
+ in treatment timing. Journal of Econometrics, 225(2), 254-277.
11
+ """
12
+
13
+ import warnings
14
+ from dataclasses import dataclass, field
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ from diff_diff.results import _format_survey_block
21
+ from diff_diff.utils import within_transform as _within_transform_util
22
+
23
+
24
@dataclass
class Comparison2x2:
    """
    One 2x2 difference-in-differences cell of the Bacon decomposition.

    Attributes
    ----------
    treated_group : Any
        Timing group playing the "treated" role in this cell.
    control_group : Any
        Timing group playing the "control" role in this cell.
    comparison_type : str
        One of "treated_vs_never", "earlier_vs_later" or
        "later_vs_earlier".
    estimate : float
        2x2 DiD estimate for this cell.
    weight : float
        Weight this cell receives in the TWFE average.
    n_treated : int
        Number of treated observations in this comparison.
    n_control : int
        Number of control observations in this comparison.
    time_window : Tuple[float, float]
        (start, end) time span covered by this cell.
    """

    treated_group: Any
    control_group: Any
    comparison_type: str
    estimate: float
    weight: float
    n_treated: int
    n_control: int
    time_window: Tuple[float, float]

    def __repr__(self) -> str:
        # Compact one-line form: who is compared, the kind, and the numbers.
        pair = f"{self.treated_group} vs {self.control_group}"
        numbers = f"β={self.estimate:.4f}, weight={self.weight:.4f}"
        return f"Comparison2x2({pair}, type={self.comparison_type}, {numbers})"
65
+
66
+
67
@dataclass
class BaconDecompositionResults:
    """
    Results from Goodman-Bacon decomposition of TWFE.

    This decomposition shows that the TWFE estimate equals a weighted
    average of all possible 2x2 DiD comparisons between timing groups.

    Attributes
    ----------
    twfe_estimate : float
        The overall TWFE coefficient (should equal weighted sum of 2x2 estimates).
    comparisons : List[Comparison2x2]
        List of all 2x2 comparisons with their estimates and weights.
    total_weight_treated_vs_never : float
        Total weight on treated vs never-treated comparisons.
    total_weight_earlier_vs_later : float
        Total weight on earlier vs later treated comparisons.
    total_weight_later_vs_earlier : float
        Total weight on later vs earlier treated comparisons (forbidden).
    weighted_avg_treated_vs_never : float or None
        Weighted average effect from treated vs never-treated comparisons.
    weighted_avg_earlier_vs_later : float or None
        Weighted average effect from earlier vs later comparisons.
    weighted_avg_later_vs_earlier : float or None
        Weighted average effect from later vs earlier comparisons.
    n_timing_groups : int
        Number of distinct treatment timing groups.
    n_never_treated : int
        Number of never-treated units.
    timing_groups : List[Any]
        List of treatment timing cohorts.
    n_obs : int
        Total number of observations reported in the summary.
    decomposition_error : float
        Gap between the TWFE estimate and the weighted sum of 2x2 estimates.
    survey_metadata : Any, optional
        Survey design metadata (SurveyMetadata instance from diff_diff.survey).
    """

    twfe_estimate: float
    # Forward reference so this class does not require Comparison2x2 to be
    # defined before it at import time.
    comparisons: List["Comparison2x2"]
    total_weight_treated_vs_never: float
    total_weight_earlier_vs_later: float
    total_weight_later_vs_earlier: float
    weighted_avg_treated_vs_never: Optional[float]
    weighted_avg_earlier_vs_later: Optional[float]
    weighted_avg_later_vs_earlier: Optional[float]
    n_timing_groups: int
    n_never_treated: int
    timing_groups: List[Any]
    n_obs: int = 0
    decomposition_error: float = field(default=0.0)
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None)

    def __repr__(self) -> str:
        return (
            f"BaconDecompositionResults(TWFE={self.twfe_estimate:.4f}, "
            f"n_comparisons={len(self.comparisons)}, "
            f"n_groups={self.n_timing_groups})"
        )

    def summary(self) -> str:
        """
        Generate a formatted summary of the decomposition.

        Returns
        -------
        str
            Formatted summary table.
        """
        width = 85
        heavy_rule = "=" * width
        light_rule = "-" * width

        lines = [
            heavy_rule,
            "Goodman-Bacon Decomposition of Two-Way Fixed Effects".center(width),
            heavy_rule,
            "",
            f"{'Total observations:':<35} {self.n_obs:>10}",
            f"{'Treatment timing groups:':<35} {self.n_timing_groups:>10}",
            f"{'Never-treated units:':<35} {self.n_never_treated:>10}",
            f"{'Total 2x2 comparisons:':<35} {len(self.comparisons):>10}",
            "",
        ]

        # Add survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, width))

        lines.extend(
            [
                light_rule,
                "TWFE Decomposition".center(width),
                light_rule,
                "",
                f"{'TWFE Estimate:':<35} {self.twfe_estimate:>12.4f}",
                f"{'Weighted Sum of 2x2 Estimates:':<35} {self._weighted_sum():>12.4f}",
                f"{'Decomposition Error:':<35} {self.decomposition_error:>12.6f}",
                "",
            ]
        )

        # Weight breakdown by comparison type
        lines.extend(
            [
                light_rule,
                "Weight Breakdown by Comparison Type".center(width),
                light_rule,
                f"{'Comparison Type':<30} {'Weight':>12} {'Avg Effect':>12} {'Contribution':>12}",
                light_rule,
            ]
        )

        breakdown = (
            (
                "Treated vs Never-treated",
                self.total_weight_treated_vs_never,
                self.weighted_avg_treated_vs_never,
            ),
            (
                "Earlier vs Later treated",
                self.total_weight_earlier_vs_later,
                self.weighted_avg_earlier_vs_later,
            ),
            (
                "Later vs Earlier (forbidden)",
                self.total_weight_later_vs_earlier,
                self.weighted_avg_later_vs_earlier,
            ),
        )
        for label, share, avg_effect in breakdown:
            # Only comparison types that actually carry weight get a row.
            if share > 0:
                # A missing average (None) is displayed as zero.
                effect = avg_effect or 0
                lines.append(
                    f"{label:<30} {share:>12.4f} {effect:>12.4f} {share * effect:>12.4f}"
                )

        lines.extend(
            [
                light_rule,
                f"{'Total':<30} {self._total_weight():>12.4f} "
                f"{'':>12} {self._weighted_sum():>12.4f}",
                light_rule,
                "",
            ]
        )

        # Warning about forbidden comparisons (more than 1% of total weight)
        if self.total_weight_later_vs_earlier > 0.01:
            pct = self.total_weight_later_vs_earlier * 100
            lines.extend(
                [
                    f"WARNING: {pct:.1f}% of weight is on 'forbidden' comparisons where",
                    "already-treated units serve as controls. This can bias TWFE",
                    "when treatment effects are heterogeneous over time.",
                    "",
                    "Consider using Callaway-Sant'Anna or other robust estimators.",
                    "",
                ]
            )

        lines.append(heavy_rule)

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def _weighted_sum(self) -> float:
        """Calculate weighted sum of 2x2 estimates."""
        return sum(c.weight * c.estimate for c in self.comparisons)

    def _total_weight(self) -> float:
        """Calculate total weight (should be 1.0)."""
        return sum(c.weight for c in self.comparisons)

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert comparisons to a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with one row per 2x2 comparison. The column set is
            fixed, so an empty list of comparisons still yields a frame
            with the expected columns (and zero rows).
        """
        columns = [
            "treated_group",
            "control_group",
            "comparison_type",
            "estimate",
            "weight",
            "n_treated",
            "n_control",
            "time_start",
            "time_end",
        ]
        rows = [
            {
                "treated_group": c.treated_group,
                "control_group": c.control_group,
                "comparison_type": c.comparison_type,
                "estimate": c.estimate,
                "weight": c.weight,
                "n_treated": c.n_treated,
                "n_control": c.n_control,
                "time_start": c.time_window[0],
                "time_end": c.time_window[1],
            }
            for c in self.comparisons
        ]
        # Passing columns explicitly keeps the schema stable when rows is empty.
        return pd.DataFrame(rows, columns=columns)

    def weight_by_type(self) -> Dict[str, float]:
        """
        Get total weight by comparison type.

        Returns
        -------
        Dict[str, float]
            Dictionary mapping comparison type to total weight.
        """
        return {
            "treated_vs_never": self.total_weight_treated_vs_never,
            "earlier_vs_later": self.total_weight_earlier_vs_later,
            "later_vs_earlier": self.total_weight_later_vs_earlier,
        }

    def effect_by_type(self) -> Dict[str, Optional[float]]:
        """
        Get weighted average effect by comparison type.

        Returns
        -------
        Dict[str, Optional[float]]
            Dictionary mapping comparison type to weighted average effect
            (None for a type that carries no weight).
        """
        return {
            "treated_vs_never": self.weighted_avg_treated_vs_never,
            "earlier_vs_later": self.weighted_avg_earlier_vs_later,
            "later_vs_earlier": self.weighted_avg_later_vs_earlier,
        }
299
+
300
+
301
+ class BaconDecomposition:
302
+ """
303
+ Goodman-Bacon (2021) decomposition of Two-Way Fixed Effects estimator.
304
+
305
+ This class decomposes a TWFE estimate into a weighted average of all
306
+ possible 2x2 DiD comparisons, revealing the implicit comparisons that
307
+ drive the TWFE estimate and their relative importance.
308
+
309
+ The decomposition identifies three types of comparisons:
310
+
311
+ 1. **Treated vs Never-treated**: Uses never-treated units as controls.
312
+ These are "clean" comparisons without bias concerns.
313
+
314
+ 2. **Earlier vs Later treated**: Units treated earlier are compared to
315
+ units treated later, using the later group as controls before they
316
+ are treated. These are valid comparisons.
317
+
318
+ 3. **Later vs Earlier treated**: Units treated later are compared to
319
+ units treated earlier, using the earlier group as controls AFTER
320
+ they are already treated. These are "forbidden comparisons" that
321
+ can introduce bias when treatment effects vary over time.
322
+
323
+ Parameters
324
+ ----------
325
+ weights : str, default="approximate"
326
+ Weight calculation method:
327
+ - "approximate": Fast simplified formula using group shares and
328
+ treatment variance. Good for diagnostic purposes where relative
329
+ weights are sufficient to identify problematic comparisons.
330
+ - "exact": Variance-based weights from Goodman-Bacon (2021) Theorem 1.
331
+ Use for publication-quality decompositions where the weighted sum
332
+ must closely match the TWFE estimate.
333
+
334
+ Attributes
335
+ ----------
336
+ weights : str
337
+ The weight calculation method.
338
+ results_ : BaconDecompositionResults
339
+ Decomposition results after calling fit().
340
+ is_fitted_ : bool
341
+ Whether the model has been fitted.
342
+
343
+ Examples
344
+ --------
345
+ Basic usage:
346
+
347
+ >>> import pandas as pd
348
+ >>> from diff_diff import BaconDecomposition
349
+ >>>
350
+ >>> # Panel data with staggered treatment
351
+ >>> data = pd.DataFrame({
352
+ ... 'unit': [...],
353
+ ... 'time': [...],
354
+ ... 'outcome': [...],
355
+ ... 'first_treat': [...] # 0 for never-treated
356
+ ... })
357
+ >>>
358
+ >>> bacon = BaconDecomposition()
359
+ >>> results = bacon.fit(data, outcome='outcome', unit='unit',
360
+ ... time='time', first_treat='first_treat')
361
+ >>> results.print_summary()
362
+
363
+ Visualizing the decomposition:
364
+
365
+ >>> from diff_diff import plot_bacon
366
+ >>> plot_bacon(results)
367
+
368
+ Notes
369
+ -----
370
+ The key insight from Goodman-Bacon (2021) is that TWFE with staggered
371
+ treatment timing implicitly makes comparisons using already-treated
372
+ units as controls. When treatment effects are dynamic (changing over
373
+ time since treatment), these "forbidden comparisons" can bias the
374
+ TWFE estimate, potentially even reversing its sign.
375
+
376
+ The decomposition helps diagnose this issue by showing:
377
+ - How much weight is on each type of comparison
378
+ - Whether forbidden comparisons contribute significantly to the estimate
379
+ - How the 2x2 estimates vary across comparison types
380
+
381
+ If forbidden comparisons have substantial weight and different estimates
382
+ than clean comparisons, consider using robust estimators like
383
+ Callaway-Sant'Anna that avoid these problematic comparisons.
384
+
385
+ References
386
+ ----------
387
+ Goodman-Bacon, A. (2021). Difference-in-differences with variation in
388
+ treatment timing. Journal of Econometrics, 225(2), 254-277.
389
+
390
+ See Also
391
+ --------
392
+ CallawaySantAnna : Robust estimator for staggered DiD
393
+ TwoWayFixedEffects : The TWFE estimator being decomposed
394
+ """
395
+
396
+ def __init__(self, weights: str = "approximate"):
397
+ """
398
+ Initialize BaconDecomposition.
399
+
400
+ Parameters
401
+ ----------
402
+ weights : str, default="approximate"
403
+ Weight calculation method:
404
+ - "approximate": Fast simplified formula (default)
405
+ - "exact": Variance-based weights from Goodman-Bacon (2021)
406
+ """
407
+ if weights not in ("approximate", "exact"):
408
+ raise ValueError(f"weights must be 'approximate' or 'exact', got '{weights}'")
409
+ self.weights = weights
410
+ self.results_: Optional[BaconDecompositionResults] = None
411
+ self.is_fitted_: bool = False
412
+
413
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    survey_design=None,
) -> BaconDecompositionResults:
    """
    Run the Goodman-Bacon decomposition on a panel.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers.
    outcome : str
        Name of the outcome column.
    unit : str
        Name of the unit identifier column.
    time : str
        Name of the time period column.
    first_treat : str
        Name of the column holding each unit's first treatment period.
        Use 0 (or np.inf) for never-treated units.
    survey_design : SurveyDesign, optional
        Survey design specification for weighted estimation. When given,
        means and group shares use the survey weights.

    Returns
    -------
    BaconDecompositionResults
        The decomposition; also stored on ``self.results_``.

    Raises
    ------
    ValueError
        If a required column is missing, or no valid 2x2 comparison
        can be formed.
    NotImplementedError
        If the survey design uses replicate-weight variance.
    """
    # --- input validation -------------------------------------------------
    absent = [col for col in [outcome, unit, time, first_treat] if col not in data.columns]
    if absent:
        raise ValueError(f"Missing columns: {absent}")

    # --- survey design ----------------------------------------------------
    from diff_diff.survey import _resolve_survey_for_fit

    survey_bundle = _resolve_survey_for_fit(survey_design, data, "analytical")
    resolved_survey, survey_weights, _weight_type, survey_metadata = survey_bundle

    if resolved_survey is not None:
        # The decomposition is a diagnostic and computes no
        # replicate-based variance, so such designs are rejected.
        if resolved_survey.uses_replicate_variance:
            raise NotImplementedError(
                "BaconDecomposition does not support replicate-weight survey "
                "designs. Use a TSL-based survey design (strata/psu/fpc)."
            )
        # Only the exact-weight path collapses to per-unit weights via
        # groupby().first(), so only it needs survey columns to be
        # constant within each unit.
        if self.weights == "exact":
            from diff_diff.survey import _validate_unit_constant_survey

            _validate_unit_constant_survey(data, unit, survey_design)

    # --- working copy with numeric time / cohort columns --------------------
    panel = data.copy()
    panel[time] = pd.to_numeric(panel[time])
    panel[first_treat] = pd.to_numeric(panel[first_treat])

    # A differing row count across units signals an unbalanced panel.
    rows_per_unit = panel.groupby(unit)[time].count()
    if rows_per_unit.nunique() > 1:
        warnings.warn(
            "Unbalanced panel detected. Bacon decomposition assumes "
            "balanced panels. Results may be inaccurate.",
            UserWarning,
            stacklevel=2,
        )

    time_periods = sorted(panel[time].unique())

    # Never-treated rows are coded as first_treat == 0 or +inf.
    is_never_row = (panel[first_treat] == 0) | (panel[first_treat] == np.inf)
    timing_groups = sorted(
        [cohort for cohort in panel[first_treat].unique() if cohort > 0 and cohort != np.inf]
    )

    # One row per unit to count never-treated units.
    per_unit = panel.groupby(unit).agg({first_treat: "first"}).reset_index()
    n_never_treated = ((per_unit[first_treat] == 0) | (per_unit[first_treat] == np.inf)).sum()

    # D_it = 1 once a (not never-treated) unit has reached its first
    # treatment period. The odd column name avoids clashing with user data.
    treat_flag = "__bacon_treated_internal__"
    panel[treat_flag] = (~is_never_row) & (panel[time] >= panel[first_treat])

    # Reference TWFE coefficient.
    twfe_estimate = self._compute_twfe(
        panel, outcome, unit, time, treat_flag, weights=survey_weights
    )

    # --- build the 2x2 comparisons ----------------------------------------
    comparisons = []

    # 1. Each timing group against the never-treated units.
    if n_never_treated > 0:
        for cohort in timing_groups:
            clean = self._compute_treated_vs_never(
                panel,
                outcome,
                unit,
                time,
                first_treat,
                cohort,
                time_periods,
                weights=survey_weights,
            )
            if clean is not None:
                comparisons.append(clean)

    # 2. Every ordered pair of timing groups: first the earlier group as
    #    treated, then the later group as treated (the "forbidden" one).
    for position, g_early in enumerate(timing_groups):
        for g_late in timing_groups[position + 1 :]:
            roles = (
                (g_early, g_late, "earlier_vs_later"),
                (g_late, g_early, "later_vs_earlier"),
            )
            for treated_cohort, control_cohort, label in roles:
                timing_comp = self._compute_timing_comparison(
                    panel,
                    outcome,
                    unit,
                    time,
                    first_treat,
                    treated_cohort,
                    control_cohort,
                    time_periods,
                    label,
                    weights=survey_weights,
                )
                if timing_comp is not None:
                    comparisons.append(timing_comp)

    # Optionally replace the approximate weights in place.
    if self.weights == "exact":
        self._recompute_exact_weights(
            comparisons,
            panel,
            outcome,
            unit,
            time,
            first_treat,
            time_periods,
            weights=survey_weights,
        )

    if not comparisons:
        raise ValueError(
            "No valid 2x2 comparisons remain after filtering. "
            "All cells have zero effective weight or insufficient data. "
            "Check subpopulation/domain definition."
        )

    # --- normalise weights to sum to one ------------------------------------
    weight_mass = sum(comp.weight for comp in comparisons)
    if weight_mass > 0:
        for comp in comparisons:
            comp.weight = comp.weight / weight_mass

    # --- totals and weighted averages per comparison type -------------------
    share_by_kind = {"treated_vs_never": 0.0, "earlier_vs_later": 0.0, "later_vs_earlier": 0.0}
    effect_mass_by_kind = {
        "treated_vs_never": 0.0,
        "earlier_vs_later": 0.0,
        "later_vs_earlier": 0.0,
    }
    for comp in comparisons:
        share_by_kind[comp.comparison_type] += comp.weight
        effect_mass_by_kind[comp.comparison_type] += comp.weight * comp.estimate

    # A type with no weight has no defined average effect.
    mean_by_kind = {}
    for kind, share in share_by_kind.items():
        mean_by_kind[kind] = effect_mass_by_kind[kind] / share if share > 0 else None

    # Gap between TWFE and the weighted sum of the 2x2 estimates.
    weighted_sum = sum(comp.weight * comp.estimate for comp in comparisons)
    decomp_error = abs(twfe_estimate - weighted_sum)

    self.results_ = BaconDecompositionResults(
        twfe_estimate=twfe_estimate,
        comparisons=comparisons,
        total_weight_treated_vs_never=share_by_kind["treated_vs_never"],
        total_weight_earlier_vs_later=share_by_kind["earlier_vs_later"],
        total_weight_later_vs_earlier=share_by_kind["later_vs_earlier"],
        weighted_avg_treated_vs_never=mean_by_kind["treated_vs_never"],
        weighted_avg_earlier_vs_later=mean_by_kind["earlier_vs_later"],
        weighted_avg_later_vs_earlier=mean_by_kind["later_vs_earlier"],
        n_timing_groups=len(timing_groups),
        n_never_treated=n_never_treated,
        timing_groups=timing_groups,
        n_obs=len(panel),
        decomposition_error=decomp_error,
        survey_metadata=survey_metadata,
    )

    self.is_fitted_ = True
    return self.results_
644
+
645
def _compute_twfe(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    treat_col: str = "__bacon_treated_internal__",
    weights: Optional[np.ndarray] = None,
) -> float:
    """
    Estimate the TWFE coefficient from two-way demeaned data.

    The outcome and treatment columns are passed through the shared
    within-transformation helper (weighted when ``weights`` is given),
    then the slope is sum(w * d * y) / sum(w * d^2).

    Returns
    -------
    float
        The slope, or 0.0 when the weighted sum of squared demeaned
        treatment is not positive.
    """
    demeaned = _within_transform_util(
        df,
        [outcome, treat_col],
        unit,
        time,
        suffix="_within",
        weights=weights,
    )

    y_tilde = demeaned[f"{outcome}_within"].values
    d_tilde = demeaned[f"{treat_col}_within"].values

    # Unit weights when no survey weights were supplied.
    obs_w = np.ones(len(y_tilde)) if weights is None else weights

    denominator = np.sum(obs_w * d_tilde**2)
    if denominator > 0:
        return np.sum(obs_w * d_tilde * y_tilde) / denominator
    # No usable variation in the demeaned treatment indicator.
    return 0.0
678
+
679
+ def _recompute_exact_weights(
680
+ self,
681
+ comparisons: List[Comparison2x2],
682
+ df: pd.DataFrame,
683
+ outcome: str,
684
+ unit: str,
685
+ time: str,
686
+ first_treat: str,
687
+ time_periods: List[Any],
688
+ weights: Optional[np.ndarray] = None,
689
+ ) -> None:
690
+ """
691
+ Recompute weights using exact variance-based formula from Theorem 1.
692
+
693
+ This modifies comparison weights in-place to use the exact formula
694
+ from Goodman-Bacon (2021) which accounts for within-group variance
695
+ of the treatment indicator in each 2x2 comparison window.
696
+
697
+ When survey weights are provided, uses weighted unit counts and
698
+ within-group variance of the treatment indicator.
699
+ """
700
+ n_total_obs = len(df)
701
+ w_arr = weights if weights is not None else np.ones(n_total_obs)
702
+ # Store weights as a column for safe label-based subsetting
703
+ df = df.copy()
704
+ df["_sw"] = w_arr
705
+ w_total = np.sum(w_arr)
706
+ n_total_units = df[unit].nunique()
707
+
708
+ for comp in comparisons:
709
+ # Get data for this specific comparison
710
+ if comp.comparison_type == "treated_vs_never":
711
+ pre_periods = [t for t in time_periods if t < comp.treated_group]
712
+ post_periods = [t for t in time_periods if t >= comp.treated_group]
713
+ # Get units in each group
714
+ units_treated = df[df[first_treat] == comp.treated_group][unit].unique()
715
+ units_control = df[(df[first_treat] == 0) | (df[first_treat] == np.inf)][
716
+ unit
717
+ ].unique()
718
+ elif comp.comparison_type == "earlier_vs_later":
719
+ g_early = comp.treated_group
720
+ g_late = comp.control_group
721
+ pre_periods = [t for t in time_periods if t < g_early]
722
+ post_periods = [t for t in time_periods if g_early <= t < g_late]
723
+ units_treated = df[df[first_treat] == g_early][unit].unique()
724
+ units_control = df[df[first_treat] == g_late][unit].unique()
725
+ else: # later_vs_earlier
726
+ g_late = comp.treated_group
727
+ g_early = comp.control_group
728
+ pre_periods = [t for t in time_periods if g_early <= t < g_late]
729
+ post_periods = [t for t in time_periods if t >= g_late]
730
+ units_treated = df[df[first_treat] == g_late][unit].unique()
731
+ units_control = df[df[first_treat] == g_early][unit].unique()
732
+
733
+ if not pre_periods or not post_periods:
734
+ comp.weight = 0.0
735
+ continue
736
+
737
+ # Subset to the 2x2 comparison sample
738
+ relevant_periods = set(pre_periods) | set(post_periods)
739
+ all_units = set(units_treated) | set(units_control)
740
+
741
+ df_22 = df[(df[unit].isin(all_units)) & (df[time].isin(relevant_periods))]
742
+
743
+ if len(df_22) == 0:
744
+ comp.weight = 0.0
745
+ continue
746
+
747
+ # Count units in this comparison
748
+ n_k = len(units_treated)
749
+ n_l = len(units_control)
750
+
751
+ if n_k == 0 or n_l == 0:
752
+ comp.weight = 0.0
753
+ continue
754
+
755
+ # Weighted observation counts for the 2x2 sample
756
+ w_22 = df_22["_sw"].values
757
+ w_22_sum = np.sum(w_22)
758
+
759
+ # Sample share of this comparison (weighted)
760
+ sample_share = w_22_sum / w_total
761
+
762
+ # Weighted group shares within the 2x2
763
+ treated_mask_22 = df_22[unit].isin(units_treated)
764
+ w_k = np.sum(w_22[treated_mask_22.values])
765
+ n_k_share = w_k / w_22_sum if w_22_sum > 0 else 0.0
766
+
767
+ # Create treatment indicator for the 2x2
768
+ T_pre = len(pre_periods)
769
+ T_post = len(post_periods)
770
+ T_window = T_pre + T_post
771
+
772
+ # Variance of D within the 2x2 for treated group
773
+ # D = 0 in pre, D = 1 in post for treated units
774
+ # D = 0 for all periods for control units in this window
775
+ D_k = T_post / T_window # proportion treated for treated group
776
+
777
+ # Within-comparison variance of treatment (weighted)
778
+ # Var(D) = n_k/(n_k+n_l) * D_k * (1-D_k) for the 2x2
779
+ var_D_22 = n_k_share * D_k * (1 - D_k)
780
+
781
+ # Exact weight: proportional to sample share * variance
782
+ # Scale by weighted unit share to account for subsample
783
+ # Use survey-weighted unit mass when weights present
784
+ if weights is not None:
785
+ # Sum of per-unit weights for treated + control units in this 2x2
786
+ unit_w_k = (
787
+ df_22.loc[treated_mask_22, "_sw"]
788
+ .groupby(df_22.loc[treated_mask_22, unit])
789
+ .first()
790
+ .sum()
791
+ )
792
+ unit_w_l = (
793
+ df_22.loc[~treated_mask_22, "_sw"]
794
+ .groupby(df_22.loc[~treated_mask_22, unit])
795
+ .first()
796
+ .sum()
797
+ )
798
+ unit_share = (unit_w_k + unit_w_l) / w_total
799
+ else:
800
+ unit_share = (n_k + n_l) / n_total_units
801
+ comp.weight = sample_share * var_D_22 * unit_share
802
+
803
+ def _compute_treated_vs_never(
804
+ self,
805
+ df: pd.DataFrame,
806
+ outcome: str,
807
+ unit: str,
808
+ time: str,
809
+ first_treat: str,
810
+ treated_group: Any,
811
+ time_periods: List[Any],
812
+ weights: Optional[np.ndarray] = None,
813
+ ) -> Optional[Comparison2x2]:
814
+ """
815
+ Compute 2x2 DiD comparing treated group to never-treated.
816
+
817
+ This is a "clean" comparison using the full sample of a treated
818
+ cohort versus never-treated units.
819
+ """
820
+ # Get treated and never-treated units
821
+ never_mask = (df[first_treat] == 0) | (df[first_treat] == np.inf)
822
+ treated_mask = df[first_treat] == treated_group
823
+
824
+ df_treated = df[treated_mask]
825
+ df_never = df[never_mask]
826
+
827
+ if len(df_treated) == 0 or len(df_never) == 0:
828
+ return None
829
+
830
+ # Time window: all periods
831
+ t_min = min(time_periods)
832
+ t_max = max(time_periods)
833
+
834
+ # Pre and post periods for this group
835
+ pre_periods = [t for t in time_periods if t < treated_group]
836
+ post_periods = [t for t in time_periods if t >= treated_group]
837
+
838
+ if not pre_periods or not post_periods:
839
+ return None
840
+
841
+ # Compute 2x2 DiD estimate using weighted means if survey weights provided
842
+ w = weights if weights is not None else np.ones(len(df))
843
+ y = df[outcome].values
844
+
845
+ treated_pre_mask = treated_mask & df[time].isin(pre_periods)
846
+ treated_post_mask = treated_mask & df[time].isin(post_periods)
847
+ never_pre_mask = never_mask & df[time].isin(pre_periods)
848
+ never_post_mask = never_mask & df[time].isin(post_periods)
849
+
850
+ # Guard against empty cells (unbalanced/filtered panels)
851
+ # Also check positive weight mass for survey/subpopulation designs
852
+ if not (
853
+ np.any(treated_pre_mask)
854
+ and np.any(treated_post_mask)
855
+ and np.any(never_pre_mask)
856
+ and np.any(never_post_mask)
857
+ ):
858
+ return None
859
+ if (
860
+ np.sum(w[treated_pre_mask]) <= 0
861
+ or np.sum(w[treated_post_mask]) <= 0
862
+ or np.sum(w[never_pre_mask]) <= 0
863
+ or np.sum(w[never_post_mask]) <= 0
864
+ ):
865
+ return None
866
+
867
+ treated_pre = np.average(y[treated_pre_mask], weights=w[treated_pre_mask])
868
+ treated_post = np.average(y[treated_post_mask], weights=w[treated_post_mask])
869
+ never_pre = np.average(y[never_pre_mask], weights=w[never_pre_mask])
870
+ never_post = np.average(y[never_post_mask], weights=w[never_post_mask])
871
+
872
+ estimate = (treated_post - treated_pre) - (never_post - never_pre)
873
+
874
+ # Calculate weight components using weighted group shares
875
+ n_treated = df_treated[unit].nunique()
876
+ n_never = df_never[unit].nunique()
877
+
878
+ w_treated_sum = np.sum(w[treated_mask])
879
+ w_never_sum = np.sum(w[never_mask])
880
+ w_total = w_treated_sum + w_never_sum
881
+
882
+ # Weighted group share
883
+ n_k = w_treated_sum / w_total if w_total > 0 else 0.0
884
+
885
+ # Variance of treatment: proportion of post-treatment periods
886
+ D_k = len(post_periods) / len(time_periods)
887
+
888
+ # Weight is proportional to n_k * (1 - n_k) * Var(D_k)
889
+ # Var(D) for treated group = D_k * (1 - D_k)
890
+ weight = n_k * (1 - n_k) * D_k * (1 - D_k)
891
+
892
+ return Comparison2x2(
893
+ treated_group=treated_group,
894
+ control_group="never_treated",
895
+ comparison_type="treated_vs_never",
896
+ estimate=estimate,
897
+ weight=weight,
898
+ n_treated=n_treated,
899
+ n_control=n_never,
900
+ time_window=(t_min, t_max),
901
+ )
902
+
903
+ def _compute_timing_comparison(
904
+ self,
905
+ df: pd.DataFrame,
906
+ outcome: str,
907
+ unit: str,
908
+ time: str,
909
+ first_treat: str,
910
+ treated_group: Any,
911
+ control_group: Any,
912
+ time_periods: List[Any],
913
+ comparison_type: str,
914
+ weights: Optional[np.ndarray] = None,
915
+ ) -> Optional[Comparison2x2]:
916
+ """
917
+ Compute 2x2 DiD comparing two timing groups.
918
+
919
+ For earlier_vs_later: uses later group as controls before they're treated.
920
+ For later_vs_earlier: uses earlier group as controls after treatment (forbidden).
921
+ """
922
+ treated_mask = df[first_treat] == treated_group
923
+ control_mask = df[first_treat] == control_group
924
+
925
+ df_treated = df[treated_mask]
926
+ df_control = df[control_mask]
927
+
928
+ if len(df_treated) == 0 or len(df_control) == 0:
929
+ return None
930
+
931
+ n_treated = df_treated[unit].nunique()
932
+ n_control = df_control[unit].nunique()
933
+
934
+ if comparison_type == "earlier_vs_later":
935
+ # Earlier treated vs Later treated
936
+ # Time window: from start to when later group gets treated
937
+ # Pre: before earlier group treated
938
+ # Post: after earlier treated but before later treated
939
+ g_early = treated_group
940
+ g_late = control_group
941
+
942
+ # Pre-period: before g_early
943
+ pre_periods = [t for t in time_periods if t < g_early]
944
+ # Post-period: g_early <= t < g_late (middle period)
945
+ post_periods = [t for t in time_periods if g_early <= t < g_late]
946
+
947
+ if not pre_periods or not post_periods:
948
+ return None
949
+
950
+ time_window = (min(time_periods), g_late - 1)
951
+
952
+ else: # later_vs_earlier (forbidden)
953
+ # Later treated vs Earlier treated (used as control after treatment)
954
+ g_late = treated_group
955
+ g_early = control_group
956
+
957
+ # Pre-period: after g_early treated but before g_late treated
958
+ pre_periods = [t for t in time_periods if g_early <= t < g_late]
959
+ # Post-period: after g_late treated
960
+ post_periods = [t for t in time_periods if t >= g_late]
961
+
962
+ if not pre_periods or not post_periods:
963
+ return None
964
+
965
+ time_window = (g_early, max(time_periods))
966
+
967
+ # Compute 2x2 DiD estimate using weighted means if survey weights provided
968
+ w = weights if weights is not None else np.ones(len(df))
969
+ y = df[outcome].values
970
+
971
+ treated_pre_mask = treated_mask & df[time].isin(pre_periods)
972
+ treated_post_mask = treated_mask & df[time].isin(post_periods)
973
+ control_pre_mask = control_mask & df[time].isin(pre_periods)
974
+ control_post_mask = control_mask & df[time].isin(post_periods)
975
+
976
+ # Skip if any cell is empty or has zero effective weight
977
+ if (
978
+ treated_pre_mask.sum() == 0
979
+ or treated_post_mask.sum() == 0
980
+ or control_pre_mask.sum() == 0
981
+ or control_post_mask.sum() == 0
982
+ ):
983
+ return None
984
+ if (
985
+ np.sum(w[treated_pre_mask]) <= 0
986
+ or np.sum(w[treated_post_mask]) <= 0
987
+ or np.sum(w[control_pre_mask]) <= 0
988
+ or np.sum(w[control_post_mask]) <= 0
989
+ ):
990
+ return None
991
+
992
+ treated_pre = np.average(y[treated_pre_mask], weights=w[treated_pre_mask])
993
+ treated_post = np.average(y[treated_post_mask], weights=w[treated_post_mask])
994
+ control_pre = np.average(y[control_pre_mask], weights=w[control_pre_mask])
995
+ control_post = np.average(y[control_post_mask], weights=w[control_post_mask])
996
+
997
+ if np.isnan(treated_pre) or np.isnan(treated_post):
998
+ return None
999
+ if np.isnan(control_pre) or np.isnan(control_post):
1000
+ return None
1001
+
1002
+ estimate = (treated_post - treated_pre) - (control_post - control_pre)
1003
+
1004
+ # Calculate weight using weighted group shares
1005
+ w_treated_sum = np.sum(w[treated_mask])
1006
+ w_control_sum = np.sum(w[control_mask])
1007
+ w_total = w_treated_sum + w_control_sum
1008
+ n_k = w_treated_sum / w_total if w_total > 0 else 0.0
1009
+
1010
+ # Variance of treatment within the comparison window
1011
+ total_periods_in_window = len(pre_periods) + len(post_periods)
1012
+ D_k = len(post_periods) / total_periods_in_window if total_periods_in_window > 0 else 0
1013
+
1014
+ # Weight proportional to group sizes and treatment variance
1015
+ # Scale by the fraction of total time this comparison covers
1016
+ time_share = total_periods_in_window / len(time_periods)
1017
+ weight = n_k * (1 - n_k) * D_k * (1 - D_k) * time_share
1018
+
1019
+ return Comparison2x2(
1020
+ treated_group=treated_group,
1021
+ control_group=control_group,
1022
+ comparison_type=comparison_type,
1023
+ estimate=estimate,
1024
+ weight=weight,
1025
+ n_treated=n_treated,
1026
+ n_control=n_control,
1027
+ time_window=time_window,
1028
+ )
1029
+
1030
def get_params(self) -> Dict[str, Any]:
    """
    Return the estimator's parameters as a dict (sklearn-compatible).

    Returns
    -------
    Dict[str, Any]
        Mapping with the single key ``"weights"`` holding ``self.weights``.
    """
    current: Dict[str, Any] = {}
    current["weights"] = self.weights
    return current
1033
+
1034
def set_params(self, **params) -> "BaconDecomposition":
    """
    Set estimator parameters (sklearn-compatible).

    All arguments are validated before anything is assigned, so a call
    that raises leaves the estimator unchanged.

    Parameters
    ----------
    **params
        Parameter names and new values. The only settable parameter is
        ``weights``, which must be ``"approximate"`` or ``"exact"``.

    Returns
    -------
    BaconDecomposition
        ``self``, to allow call chaining.

    Raises
    ------
    ValueError
        If an unknown parameter name is passed, or if ``weights`` is not
        one of the accepted values.
    """
    valid_names = ("weights",)

    # Reject unknown names instead of silently ignoring them, so a typo
    # such as ``wieghts=`` cannot turn into a no-op.
    unknown = sorted(name for name in params if name not in valid_names)
    if unknown:
        raise ValueError(
            f"Invalid parameter(s) {unknown} for BaconDecomposition. "
            f"Valid parameters are: {list(valid_names)}"
        )

    if "weights" in params:
        method = params["weights"]
        if method not in ("approximate", "exact"):
            raise ValueError(f"weights must be 'approximate' or 'exact', got '{method}'")
        self.weights = method
    return self
1043
+
1044
def summary(self) -> str:
    """
    Return the text summary of the fitted decomposition.

    Returns
    -------
    str
        Whatever ``self.results_.summary()`` returns.

    Raises
    ------
    RuntimeError
        If ``self.is_fitted_`` is falsy.
    """
    if self.is_fitted_:
        # A fitted model is expected to carry results; the assert also
        # narrows the Optional type for static checkers.
        assert self.results_ is not None
        return self.results_.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
1050
+
1051
def print_summary(self) -> None:
    """Write the text returned by ``self.summary()`` to stdout."""
    report = self.summary()
    print(report)
1054
+
1055
+
1056
def bacon_decompose(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    weights: str = "approximate",
    survey_design: object = None,
) -> BaconDecompositionResults:
    """
    Run a Goodman-Bacon decomposition in a single call.

    Builds a ``BaconDecomposition`` with the requested ``weights`` method
    and immediately fits it. The decomposition breaks a TWFE estimate into
    weighted 2x2 DiD comparisons, showing which comparisons drive the
    estimate and whether problematic "forbidden comparisons" are involved.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers.
    outcome : str
        Name of outcome variable column.
    unit : str
        Name of unit identifier column.
    time : str
        Name of time period column.
    first_treat : str
        Name of column indicating when unit was first treated.
        Use 0 (or np.inf) for never-treated units.
    weights : str, default="approximate"
        Weight calculation method:
        - "approximate": Fast simplified formula (default). Good for
          diagnostic purposes where relative weights are sufficient.
        - "exact": Variance-based weights from Goodman-Bacon (2021)
          Theorem 1. Use for publication-quality decompositions.
    survey_design : object, optional
        Forwarded unchanged to ``BaconDecomposition.fit`` as its
        ``survey_design`` keyword. NOTE(review): the accepted type and its
        effect are defined by that method and are not visible here —
        presumably a survey design enabling survey-weighted results.

    Returns
    -------
    BaconDecompositionResults
        Object containing decomposition results with:
        - twfe_estimate: The overall TWFE coefficient
        - comparisons: List of all 2x2 comparisons with estimates and weights
        - Weight totals by comparison type
        - Methods for visualization and export

    Examples
    --------
    >>> from diff_diff import bacon_decompose
    >>>
    >>> # Quick diagnostic (default)
    >>> results = bacon_decompose(
    ...     data=panel_df,
    ...     outcome='earnings',
    ...     unit='state',
    ...     time='year',
    ...     first_treat='treatment_year'
    ... )
    >>>
    >>> # Publication-quality exact decomposition
    >>> results = bacon_decompose(
    ...     data=panel_df,
    ...     outcome='earnings',
    ...     unit='state',
    ...     time='year',
    ...     first_treat='treatment_year',
    ...     weights='exact'
    ... )
    >>>
    >>> # View summary
    >>> results.print_summary()
    >>>
    >>> # Check weight on forbidden comparisons
    >>> print(f"Forbidden weight: {results.total_weight_later_vs_earlier:.1%}")
    >>>
    >>> # Convert to DataFrame for analysis
    >>> df = results.to_dataframe()

    See Also
    --------
    BaconDecomposition : Class-based interface with more options
    plot_bacon : Visualize the decomposition
    CallawaySantAnna : Robust estimator that avoids forbidden comparisons
    """
    return BaconDecomposition(weights=weights).fit(
        data,
        outcome,
        unit,
        time,
        first_treat,
        survey_design=survey_design,
    )