diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/bacon.py ADDED
@@ -0,0 +1,979 @@
1
+ """
2
+ Goodman-Bacon Decomposition for Two-Way Fixed Effects.
3
+
4
+ Implements the decomposition from Goodman-Bacon (2021) that shows how
5
+ TWFE estimates with staggered treatment timing can be written as a
6
+ weighted average of all possible 2x2 DiD comparisons.
7
+
8
+ Reference:
9
+ Goodman-Bacon, A. (2021). Difference-in-differences with variation
10
+ in treatment timing. Journal of Econometrics, 225(2), 254-277.
11
+ """
12
+
13
+ import warnings
14
+ from dataclasses import dataclass, field
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ from diff_diff.utils import within_transform as _within_transform_util
21
+
22
+
23
@dataclass
class Comparison2x2:
    """
    One 2x2 difference-in-differences cell of the Bacon decomposition.

    Each instance pairs a "treated" timing group with a "control" group
    and records the resulting DiD estimate together with the weight that
    the comparison receives in the overall TWFE average.

    Attributes
    ----------
    treated_group : Any
        Timing group playing the "treated" role in this comparison.
    control_group : Any
        Timing group playing the "control" role in this comparison.
    comparison_type : str
        One of "treated_vs_never", "earlier_vs_later" or
        "later_vs_earlier".
    estimate : float
        2x2 DiD estimate for this pair of groups.
    weight : float
        Weight of this comparison in the TWFE average.
    n_treated : int
        Number of treated observations in this comparison.
    n_control : int
        Number of control observations in this comparison.
    time_window : Tuple[float, float]
        (start, end) time period covered by this comparison.
    """

    treated_group: Any
    control_group: Any
    comparison_type: str
    estimate: float
    weight: float
    n_treated: int
    n_control: int
    time_window: Tuple[float, float]

    def __repr__(self) -> str:
        # Compact one-line form: the group pair first, then type/estimate/weight.
        details = ", ".join(
            (
                f"{self.treated_group} vs {self.control_group}",
                f"type={self.comparison_type}",
                f"β={self.estimate:.4f}",
                f"weight={self.weight:.4f}",
            )
        )
        return f"Comparison2x2({details})"
64
+
65
+
66
@dataclass
class BaconDecompositionResults:
    """
    Container for the output of a Goodman-Bacon decomposition of TWFE.

    The decomposition expresses the TWFE coefficient as a weighted
    average of every 2x2 DiD comparison between timing groups; this
    object stores those comparisons plus per-type aggregates and offers
    reporting helpers.

    Attributes
    ----------
    twfe_estimate : float
        The overall TWFE coefficient (should equal weighted sum of 2x2 estimates).
    comparisons : List[Comparison2x2]
        Every 2x2 comparison with its estimate and weight.
    total_weight_treated_vs_never : float
        Total weight on treated vs never-treated comparisons.
    total_weight_earlier_vs_later : float
        Total weight on earlier vs later treated comparisons.
    total_weight_later_vs_earlier : float
        Total weight on later vs earlier treated comparisons (forbidden).
    weighted_avg_treated_vs_never : Optional[float]
        Weighted average effect from treated vs never-treated comparisons.
    weighted_avg_earlier_vs_later : Optional[float]
        Weighted average effect from earlier vs later comparisons.
    weighted_avg_later_vs_earlier : Optional[float]
        Weighted average effect from later vs earlier comparisons.
    n_timing_groups : int
        Number of distinct treatment timing groups.
    n_never_treated : int
        Number of never-treated units.
    timing_groups : List[Any]
        The treatment timing cohorts.
    n_obs : int
        Total number of observations shown in the summary (default 0).
    decomposition_error : float
        Decomposition error shown in the summary (default 0.0).
    """

    twfe_estimate: float
    comparisons: List[Comparison2x2]
    total_weight_treated_vs_never: float
    total_weight_earlier_vs_later: float
    total_weight_later_vs_earlier: float
    weighted_avg_treated_vs_never: Optional[float]
    weighted_avg_earlier_vs_later: Optional[float]
    weighted_avg_later_vs_earlier: Optional[float]
    n_timing_groups: int
    n_never_treated: int
    timing_groups: List[Any]
    n_obs: int = 0
    decomposition_error: float = field(default=0.0)

    def __repr__(self) -> str:
        fields_text = ", ".join(
            (
                f"TWFE={self.twfe_estimate:.4f}",
                f"n_comparisons={len(self.comparisons)}",
                f"n_groups={self.n_timing_groups}",
            )
        )
        return f"BaconDecompositionResults({fields_text})"

    def summary(self) -> str:
        """
        Render the decomposition as a fixed-width text report.

        Returns
        -------
        str
            Multi-line summary table.
        """
        width = 85
        double_rule = "=" * width
        single_rule = "-" * width
        combined = self._weighted_sum()

        report = [
            double_rule,
            "Goodman-Bacon Decomposition of Two-Way Fixed Effects".center(width),
            double_rule,
            "",
            f"{'Total observations:':<35} {self.n_obs:>10}",
            f"{'Treatment timing groups:':<35} {self.n_timing_groups:>10}",
            f"{'Never-treated units:':<35} {self.n_never_treated:>10}",
            f"{'Total 2x2 comparisons:':<35} {len(self.comparisons):>10}",
            "",
            single_rule,
            "TWFE Decomposition".center(width),
            single_rule,
            "",
            f"{'TWFE Estimate:':<35} {self.twfe_estimate:>12.4f}",
            f"{'Weighted Sum of 2x2 Estimates:':<35} {combined:>12.4f}",
            f"{'Decomposition Error:':<35} {self.decomposition_error:>12.6f}",
            "",
            single_rule,
            "Weight Breakdown by Comparison Type".center(width),
            single_rule,
            f"{'Comparison Type':<30} {'Weight':>12} {'Avg Effect':>12} {'Contribution':>12}",
            single_rule,
        ]

        # One table row per comparison type; a type is listed only when it
        # carries strictly positive weight.
        breakdown = (
            (
                "Treated vs Never-treated",
                self.total_weight_treated_vs_never,
                self.weighted_avg_treated_vs_never,
            ),
            (
                "Earlier vs Later treated",
                self.total_weight_earlier_vs_later,
                self.weighted_avg_earlier_vs_later,
            ),
            (
                "Later vs Earlier (forbidden)",
                self.total_weight_later_vs_earlier,
                self.weighted_avg_later_vs_earlier,
            ),
        )
        for label, share, average in breakdown:
            if share > 0:
                # A missing (None) average is displayed as zero.
                effect = average or 0
                contribution = share * effect
                report.append(
                    f"{label:<30} {share:>12.4f} {effect:>12.4f} {contribution:>12.4f}"
                )

        report += [
            single_rule,
            f"{'Total':<30} {self._total_weight():>12.4f} {'':>12} {combined:>12.4f}",
            single_rule,
            "",
        ]

        # Flag the decomposition when more than 1% of the weight sits on
        # comparisons that use already-treated units as controls.
        forbidden_share = self.total_weight_later_vs_earlier
        if forbidden_share > 0.01:
            pct = forbidden_share * 100
            report += [
                f"WARNING: {pct:.1f}% of weight is on 'forbidden' comparisons where",
                "already-treated units serve as controls. This can bias TWFE",
                "when treatment effects are heterogeneous over time.",
                "",
                "Consider using Callaway-Sant'Anna or other robust estimators.",
                "",
            ]

        report.append(double_rule)
        return "\n".join(report)

    def print_summary(self) -> None:
        """Write the text report produced by :meth:`summary` to stdout."""
        print(self.summary())

    def _weighted_sum(self) -> float:
        """Return the sum of weight * estimate over all comparisons."""
        running = 0
        for comparison in self.comparisons:
            running += comparison.weight * comparison.estimate
        return running

    def _total_weight(self) -> float:
        """Return the sum of all comparison weights (should be 1.0)."""
        running = 0
        for comparison in self.comparisons:
            running += comparison.weight
        return running

    def to_dataframe(self) -> pd.DataFrame:
        """
        Tabulate the comparisons.

        Returns
        -------
        pd.DataFrame
            One row per 2x2 comparison; the time window is split into
            ``time_start`` and ``time_end`` columns.
        """
        records = [
            {
                "treated_group": item.treated_group,
                "control_group": item.control_group,
                "comparison_type": item.comparison_type,
                "estimate": item.estimate,
                "weight": item.weight,
                "n_treated": item.n_treated,
                "n_control": item.n_control,
                "time_start": item.time_window[0],
                "time_end": item.time_window[1],
            }
            for item in self.comparisons
        ]
        return pd.DataFrame(records)

    def weight_by_type(self) -> Dict[str, float]:
        """
        Report the total weight carried by each comparison type.

        Returns
        -------
        Dict[str, float]
            Mapping from comparison type to its total weight.
        """
        kinds = ("treated_vs_never", "earlier_vs_later", "later_vs_earlier")
        return {kind: getattr(self, f"total_weight_{kind}") for kind in kinds}

    def effect_by_type(self) -> Dict[str, Optional[float]]:
        """
        Report the weighted average effect of each comparison type.

        Returns
        -------
        Dict[str, Optional[float]]
            Mapping from comparison type to its weighted average effect.
        """
        kinds = ("treated_vs_never", "earlier_vs_later", "later_vs_earlier")
        return {kind: getattr(self, f"weighted_avg_{kind}") for kind in kinds}
286
+
287
+
288
class BaconDecomposition:
    """
    Goodman-Bacon (2021) decomposition of Two-Way Fixed Effects estimator.

    This class decomposes a TWFE estimate into a weighted average of all
    possible 2x2 DiD comparisons, revealing the implicit comparisons that
    drive the TWFE estimate and their relative importance.

    The decomposition identifies three types of comparisons:

    1. **Treated vs Never-treated**: Uses never-treated units as controls.
       These are "clean" comparisons without bias concerns.

    2. **Earlier vs Later treated**: Units treated earlier are compared to
       units treated later, using the later group as controls before they
       are treated. These are valid comparisons.

    3. **Later vs Earlier treated**: Units treated later are compared to
       units treated earlier, using the earlier group as controls AFTER
       they are already treated. These are "forbidden comparisons" that
       can introduce bias when treatment effects vary over time.

    Parameters
    ----------
    weights : str, default="approximate"
        Weight calculation method:
        - "approximate": Fast simplified formula using group shares and
          treatment variance. Good for diagnostic purposes where relative
          weights are sufficient to identify problematic comparisons.
        - "exact": Variance-based weights from Goodman-Bacon (2021) Theorem 1.
          Use for publication-quality decompositions where the weighted sum
          must closely match the TWFE estimate.

    Attributes
    ----------
    weights : str
        The weight calculation method.
    results_ : BaconDecompositionResults
        Decomposition results after calling fit().
    is_fitted_ : bool
        Whether the model has been fitted.

    Examples
    --------
    Basic usage:

    >>> import pandas as pd
    >>> from diff_diff import BaconDecomposition
    >>>
    >>> # Panel data with staggered treatment
    >>> data = pd.DataFrame({
    ...     'unit': [...],
    ...     'time': [...],
    ...     'outcome': [...],
    ...     'first_treat': [...]  # 0 for never-treated
    ... })
    >>>
    >>> bacon = BaconDecomposition()
    >>> results = bacon.fit(data, outcome='outcome', unit='unit',
    ...                     time='time', first_treat='first_treat')
    >>> results.print_summary()

    Visualizing the decomposition:

    >>> from diff_diff import plot_bacon
    >>> plot_bacon(results)

    Notes
    -----
    The key insight from Goodman-Bacon (2021) is that TWFE with staggered
    treatment timing implicitly makes comparisons using already-treated
    units as controls. When treatment effects are dynamic (changing over
    time since treatment), these "forbidden comparisons" can bias the
    TWFE estimate, potentially even reversing its sign.

    The decomposition helps diagnose this issue by showing:
    - How much weight is on each type of comparison
    - Whether forbidden comparisons contribute significantly to the estimate
    - How the 2x2 estimates vary across comparison types

    If forbidden comparisons have substantial weight and different estimates
    than clean comparisons, consider using robust estimators like
    Callaway-Sant'Anna that avoid these problematic comparisons.

    References
    ----------
    Goodman-Bacon, A. (2021). Difference-in-differences with variation in
    treatment timing. Journal of Econometrics, 225(2), 254-277.

    See Also
    --------
    CallawaySantAnna : Robust estimator for staggered DiD
    TwoWayFixedEffects : The TWFE estimator being decomposed
    """

    def __init__(self, weights: str = "approximate"):
        """
        Initialize BaconDecomposition.

        Parameters
        ----------
        weights : str, default="approximate"
            Weight calculation method:
            - "approximate": Fast simplified formula (default)
            - "exact": Variance-based weights from Goodman-Bacon (2021)

        Raises
        ------
        ValueError
            If ``weights`` is neither "approximate" nor "exact".
        """
        self.weights = self._check_weights(weights)
        self.results_: Optional[BaconDecompositionResults] = None
        self.is_fitted_: bool = False

    @staticmethod
    def _check_weights(weights: Any) -> Any:
        """Return ``weights`` unchanged, or raise ValueError if unsupported."""
        if weights not in ("approximate", "exact"):
            raise ValueError(
                f"weights must be 'approximate' or 'exact', got '{weights}'"
            )
        return weights

    @staticmethod
    def _never_treated_mask(cohort: pd.Series) -> pd.Series:
        """Return a boolean mask of entries coded as never-treated (0 or inf)."""
        return (cohort == 0) | (cohort == np.inf)

    @staticmethod
    def _did_from_means(
        df_treated: pd.DataFrame,
        df_control: pd.DataFrame,
        time: str,
        outcome: str,
        pre_periods: List[Any],
        post_periods: List[Any],
    ) -> Optional[float]:
        """
        Return the 2x2 DiD of cell means, or None if any cell mean is NaN.

        A cell mean is NaN when a group has no observations (or only
        missing outcomes) in the pre or post window; such a comparison
        cannot be estimated, so the caller should drop it.
        """
        cell_means = []
        for frame in (df_treated, df_control):
            for periods in (pre_periods, post_periods):
                cell = frame.loc[frame[time].isin(periods), outcome].mean()
                if np.isnan(cell):
                    return None
                cell_means.append(cell)
        treated_pre, treated_post, control_pre, control_post = cell_means
        return (treated_post - treated_pre) - (control_post - control_pre)

    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
    ) -> BaconDecompositionResults:
        """
        Perform the Goodman-Bacon decomposition.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers. The frame is copied;
            the caller's object is not modified.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column.
        first_treat : str
            Name of column indicating when unit was first treated.
            Use 0 (or np.inf) for never-treated units.

        Returns
        -------
        BaconDecompositionResults
            Object containing decomposition results. It is also stored on
            ``self.results_``.

        Raises
        ------
        ValueError
            If required columns are missing or data validation fails.

        Warns
        -----
        UserWarning
            If units do not all have the same number of time observations.
        """
        missing = [
            c for c in (outcome, unit, time, first_treat)
            if c not in data.columns
        ]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Work on a copy so the internal treatment column never leaks out.
        df = data.copy()
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        if df.groupby(unit)[time].count().nunique() > 1:
            warnings.warn(
                "Unbalanced panel detected. Bacon decomposition assumes "
                "balanced panels. Results may be inaccurate.",
                UserWarning,
                stacklevel=2,
            )

        time_periods = sorted(df[time].unique())

        # Never-treated rows are coded first_treat == 0 or inf; every other
        # strictly positive finite value is a treatment-timing cohort.
        never_treated_mask = self._never_treated_mask(df[first_treat])
        timing_groups = sorted(
            g for g in df[first_treat].unique() if g > 0 and g != np.inf
        )

        # Count never-treated *units* (not rows), using each unit's first
        # recorded cohort value. Cast so the result is a built-in int.
        unit_cohort = df.groupby(unit)[first_treat].first()
        n_never_treated = int(self._never_treated_mask(unit_cohort).sum())

        # D_it = 1 once a (eventually treated) unit has reached its
        # treatment date. The column name is deliberately unusual to avoid
        # clashing with user columns.
        treat_col = '__bacon_treated_internal__'
        df[treat_col] = (~never_treated_mask) & (df[time] >= df[first_treat])

        twfe_estimate = self._compute_twfe(df, outcome, unit, time, treat_col)

        comparisons: List[Comparison2x2] = []

        # 1. Treated vs never-treated comparisons.
        if n_never_treated > 0:
            for g in timing_groups:
                comp = self._compute_treated_vs_never(
                    df, outcome, unit, time, first_treat, g, time_periods
                )
                if comp is not None:
                    comparisons.append(comp)

        # 2. Every ordered pair of timing cohorts, in both directions.
        for i, g_early in enumerate(timing_groups):
            for g_late in timing_groups[i + 1:]:
                # g_early treated, g_late as not-yet-treated control.
                comp_early = self._compute_timing_comparison(
                    df, outcome, unit, time, first_treat,
                    g_early, g_late, time_periods, "earlier_vs_later"
                )
                if comp_early is not None:
                    comparisons.append(comp_early)

                # g_late treated, g_early as already-treated control
                # (the "forbidden" comparison).
                comp_late = self._compute_timing_comparison(
                    df, outcome, unit, time, first_treat,
                    g_late, g_early, time_periods, "later_vs_earlier"
                )
                if comp_late is not None:
                    comparisons.append(comp_late)

        if self.weights == "exact":
            self._recompute_exact_weights(
                comparisons, df, outcome, unit, time, first_treat, time_periods
            )

        # Normalize weights to sum to 1 (skipped when nothing has weight).
        total_weight = sum(c.weight for c in comparisons)
        if total_weight > 0:
            for c in comparisons:
                c.weight = c.weight / total_weight

        kinds = ("treated_vs_never", "earlier_vs_later", "later_vs_earlier")
        weight_by_type = dict.fromkeys(kinds, 0.0)
        weighted_sum_by_type = dict.fromkeys(kinds, 0.0)
        for c in comparisons:
            weight_by_type[c.comparison_type] += c.weight
            weighted_sum_by_type[c.comparison_type] += c.weight * c.estimate

        # A type with no weight has no defined average effect.
        avg_by_type = {
            kind: (
                weighted_sum_by_type[kind] / weight_by_type[kind]
                if weight_by_type[kind] > 0
                else None
            )
            for kind in kinds
        }

        weighted_sum = sum(c.weight * c.estimate for c in comparisons)
        decomp_error = float(abs(twfe_estimate - weighted_sum))

        self.results_ = BaconDecompositionResults(
            twfe_estimate=twfe_estimate,
            comparisons=comparisons,
            total_weight_treated_vs_never=weight_by_type["treated_vs_never"],
            total_weight_earlier_vs_later=weight_by_type["earlier_vs_later"],
            total_weight_later_vs_earlier=weight_by_type["later_vs_earlier"],
            weighted_avg_treated_vs_never=avg_by_type["treated_vs_never"],
            weighted_avg_earlier_vs_later=avg_by_type["earlier_vs_later"],
            weighted_avg_later_vs_earlier=avg_by_type["later_vs_earlier"],
            n_timing_groups=len(timing_groups),
            n_never_treated=n_never_treated,
            timing_groups=timing_groups,
            n_obs=len(df),
            decomposition_error=decomp_error,
        )

        self.is_fitted_ = True
        return self.results_

    def _compute_twfe(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        treat_col: str = '__bacon_treated_internal__',
    ) -> float:
        """
        Compute the TWFE coefficient via the two-way within transformation.

        The outcome and the treatment indicator are passed through the
        package's ``within_transform`` helper, then the slope of the
        no-intercept regression of transformed outcome on transformed
        treatment is returned. When the transformed treatment has no
        variation the estimate is defined as 0.0.
        """
        df_dm = _within_transform_util(
            df, [outcome, treat_col], unit, time, suffix="_within"
        )

        y_within = df_dm[f"{outcome}_within"].values
        d_within = df_dm[f"{treat_col}_within"].values

        # OLS on demeaned data: beta = sum(d * y) / sum(d^2)
        d_var = np.sum(d_within ** 2)
        if d_var > 0:
            return float(np.sum(d_within * y_within) / d_var)
        return 0.0

    def _recompute_exact_weights(
        self,
        comparisons: List[Comparison2x2],
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        time_periods: List[Any],
    ) -> None:
        """
        Overwrite each comparison's weight with the variance-based weight.

        For every comparison the (unnormalized) weight is set in place to

            (rows in the 2x2 sample / total rows)
            * n_k/(n_k+n_l) * D_k * (1 - D_k)
            * (n_k + n_l) / total units

        where n_k / n_l are the numbers of treated / control units and D_k
        is the share of post periods within the comparison window.
        Comparisons with an empty window, empty sample or empty group get
        weight 0.0.

        NOTE(review): intended to follow Goodman-Bacon (2021) Theorem 1;
        confirm the formula against the paper.
        """
        n_total_obs = len(df)
        n_total_units = df[unit].nunique()

        cohort_col = df[first_treat]
        units_by_cohort: Dict[Any, Any] = {}
        never_units = None

        def cohort_units(group: Any) -> Any:
            """Return unit ids of timing cohort ``group`` (memoized)."""
            if group not in units_by_cohort:
                units_by_cohort[group] = df.loc[cohort_col == group, unit].unique()
            return units_by_cohort[group]

        for comp in comparisons:
            if comp.comparison_type == "treated_vs_never":
                pre_periods = [t for t in time_periods if t < comp.treated_group]
                post_periods = [t for t in time_periods if t >= comp.treated_group]
                units_treated = cohort_units(comp.treated_group)
                if never_units is None:
                    never_units = df.loc[
                        self._never_treated_mask(cohort_col), unit
                    ].unique()
                units_control = never_units
            elif comp.comparison_type == "earlier_vs_later":
                g_early = comp.treated_group
                g_late = comp.control_group
                pre_periods = [t for t in time_periods if t < g_early]
                post_periods = [t for t in time_periods if g_early <= t < g_late]
                units_treated = cohort_units(g_early)
                units_control = cohort_units(g_late)
            else:  # later_vs_earlier
                g_late = comp.treated_group
                g_early = comp.control_group
                pre_periods = [t for t in time_periods if g_early <= t < g_late]
                post_periods = [t for t in time_periods if t >= g_late]
                units_treated = cohort_units(g_late)
                units_control = cohort_units(g_early)

            if not pre_periods or not post_periods:
                comp.weight = 0.0
                continue

            # Rows belonging to the 2x2 sample: either group, inside the window.
            relevant_periods = set(pre_periods) | set(post_periods)
            all_units = set(units_treated) | set(units_control)
            n_22 = int(
                (df[unit].isin(all_units) & df[time].isin(relevant_periods)).sum()
            )

            n_k = len(units_treated)
            n_l = len(units_control)
            if n_22 == 0 or n_k == 0 or n_l == 0:
                comp.weight = 0.0
                continue

            sample_share = n_22 / n_total_obs
            n_k_share = n_k / (n_k + n_l)

            # Within the window the treated group has D = 0 in the pre
            # periods and D = 1 in the post periods; controls are constant.
            T_post = len(post_periods)
            T_window = len(pre_periods) + T_post
            D_k = T_post / T_window

            var_D_22 = n_k_share * D_k * (1 - D_k)

            # Scale by the share of all units that take part in this 2x2.
            unit_share = (n_k + n_l) / n_total_units
            comp.weight = sample_share * var_D_22 * unit_share

    def _compute_treated_vs_never(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treated_group: Any,
        time_periods: List[Any],
    ) -> Optional[Comparison2x2]:
        """
        Compute 2x2 DiD comparing treated group to never-treated.

        This is a "clean" comparison using the full sample of a treated
        cohort versus never-treated units.

        Returns
        -------
        Optional[Comparison2x2]
            The comparison, or None when either group is empty, when the
            cohort has no pre or no post periods, or when any of the four
            cell means is NaN (e.g. a group has no observations in a
            window of an unbalanced panel).
        """
        df_treated = df[df[first_treat] == treated_group]
        df_never = df[self._never_treated_mask(df[first_treat])]

        if len(df_treated) == 0 or len(df_never) == 0:
            return None

        pre_periods = [t for t in time_periods if t < treated_group]
        post_periods = [t for t in time_periods if t >= treated_group]
        if not pre_periods or not post_periods:
            return None

        # Drop the comparison instead of emitting a NaN estimate, matching
        # the behaviour of the timing-group comparisons.
        estimate = self._did_from_means(
            df_treated, df_never, time, outcome, pre_periods, post_periods
        )
        if estimate is None:
            return None

        n_treated = df_treated[unit].nunique()
        n_never = df_never[unit].nunique()

        # Share of treated units within this pair of groups.
        n_k = n_treated / (n_treated + n_never)

        # Share of all periods in which the cohort is treated.
        D_k = len(post_periods) / len(time_periods)

        # Weight is proportional to n_k * (1 - n_k) * Var(D_k), with
        # Var(D) for the treated group = D_k * (1 - D_k).
        weight = n_k * (1 - n_k) * D_k * (1 - D_k)

        return Comparison2x2(
            treated_group=treated_group,
            control_group="never_treated",
            comparison_type="treated_vs_never",
            estimate=estimate,
            weight=weight,
            n_treated=n_treated,
            n_control=n_never,
            # The window spans every observed period.
            time_window=(min(time_periods), max(time_periods)),
        )

    def _compute_timing_comparison(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treated_group: Any,
        control_group: Any,
        time_periods: List[Any],
        comparison_type: str,
    ) -> Optional[Comparison2x2]:
        """
        Compute 2x2 DiD comparing two timing groups.

        For earlier_vs_later: uses later group as controls before they're treated.
        For later_vs_earlier: uses earlier group as controls after treatment (forbidden).

        Returns
        -------
        Optional[Comparison2x2]
            The comparison, or None when either group is empty, when the
            pre or post window is empty, or when any cell mean is NaN.
        """
        df_treated = df[df[first_treat] == treated_group]
        df_control = df[df[first_treat] == control_group]

        if len(df_treated) == 0 or len(df_control) == 0:
            return None

        if comparison_type == "earlier_vs_later":
            g_early, g_late = treated_group, control_group
            # Pre: before the earlier cohort is treated.
            pre_periods = [t for t in time_periods if t < g_early]
            # Post: earlier cohort treated, later cohort not yet treated.
            post_periods = [t for t in time_periods if g_early <= t < g_late]
            time_window = (min(time_periods), g_late - 1)
        else:  # later_vs_earlier (forbidden)
            g_late, g_early = treated_group, control_group
            # Pre: earlier cohort already treated, later cohort not yet.
            pre_periods = [t for t in time_periods if g_early <= t < g_late]
            # Post: both cohorts treated.
            post_periods = [t for t in time_periods if t >= g_late]
            time_window = (g_early, max(time_periods))

        if not pre_periods or not post_periods:
            return None

        estimate = self._did_from_means(
            df_treated, df_control, time, outcome, pre_periods, post_periods
        )
        if estimate is None:
            return None

        n_treated = df_treated[unit].nunique()
        n_control = df_control[unit].nunique()
        n_k = n_treated / (n_treated + n_control)

        # Share of post periods within the comparison window (the window is
        # non-empty here because both period lists are non-empty).
        total_periods_in_window = len(pre_periods) + len(post_periods)
        D_k = len(post_periods) / total_periods_in_window

        # Weight proportional to group sizes and treatment variance, scaled
        # by the fraction of total time this comparison covers.
        time_share = total_periods_in_window / len(time_periods)
        weight = n_k * (1 - n_k) * D_k * (1 - D_k) * time_share

        return Comparison2x2(
            treated_group=treated_group,
            control_group=control_group,
            comparison_type=comparison_type,
            estimate=estimate,
            weight=weight,
            n_treated=n_treated,
            n_control=n_control,
            time_window=time_window,
        )

    def get_params(self) -> Dict[str, Any]:
        """Get estimator parameters (sklearn-compatible)."""
        return {"weights": self.weights}

    def set_params(self, **params) -> "BaconDecomposition":
        """
        Set estimator parameters (sklearn-compatible).

        Only ``weights`` is recognised; other keyword arguments are ignored.

        Raises
        ------
        ValueError
            If ``weights`` is given and is neither "approximate" nor "exact".
        """
        if "weights" in params:
            self.weights = self._check_weights(params["weights"])
        return self

    def summary(self) -> str:
        """
        Get summary of decomposition results.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if not self.is_fitted_ or self.results_ is None:
            raise RuntimeError("Model must be fitted before calling summary()")
        return self.results_.summary()

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())
894
+
895
+
896
def bacon_decompose(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    weights: str = "approximate",
) -> BaconDecompositionResults:
    """
    One-call interface to the Goodman-Bacon decomposition.

    Builds a :class:`BaconDecomposition` with the requested weighting
    method, fits it on ``data`` and returns the fitted results, which
    break the TWFE estimate into weighted 2x2 DiD comparisons and expose
    how much weight falls on "forbidden comparisons".

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers.
    outcome : str
        Name of outcome variable column.
    unit : str
        Name of unit identifier column.
    time : str
        Name of time period column.
    first_treat : str
        Name of column indicating when unit was first treated.
        Use 0 (or np.inf) for never-treated units.
    weights : str, default="approximate"
        Weight calculation method:
        - "approximate": Fast simplified formula (default). Good for
          diagnostic purposes where relative weights are sufficient.
        - "exact": Variance-based weights from Goodman-Bacon (2021)
          Theorem 1. Use for publication-quality decompositions.

    Returns
    -------
    BaconDecompositionResults
        Object containing decomposition results with:
        - twfe_estimate: The overall TWFE coefficient
        - comparisons: List of all 2x2 comparisons with estimates and weights
        - Weight totals by comparison type
        - Methods for visualization and export

    Examples
    --------
    >>> from diff_diff import bacon_decompose
    >>>
    >>> # Quick diagnostic (default)
    >>> results = bacon_decompose(
    ...     data=panel_df,
    ...     outcome='earnings',
    ...     unit='state',
    ...     time='year',
    ...     first_treat='treatment_year'
    ... )
    >>>
    >>> # Publication-quality exact decomposition
    >>> results = bacon_decompose(
    ...     data=panel_df,
    ...     outcome='earnings',
    ...     unit='state',
    ...     time='year',
    ...     first_treat='treatment_year',
    ...     weights='exact'
    ... )
    >>>
    >>> # View summary
    >>> results.print_summary()
    >>>
    >>> # Check weight on forbidden comparisons
    >>> print(f"Forbidden weight: {results.total_weight_later_vs_earlier:.1%}")
    >>>
    >>> # Convert to DataFrame for analysis
    >>> df = results.to_dataframe()

    See Also
    --------
    BaconDecomposition : Class-based interface with more options
    plot_bacon : Visualize the decomposition
    CallawaySantAnna : Robust estimator that avoids forbidden comparisons
    """
    estimator = BaconDecomposition(weights=weights)
    fitted = estimator.fit(
        data,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
    )
    return fitted