diff-diff 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/bacon.py ADDED
@@ -0,0 +1,1027 @@
1
+ """
2
+ Goodman-Bacon Decomposition for Two-Way Fixed Effects.
3
+
4
+ Implements the decomposition from Goodman-Bacon (2021) that shows how
5
+ TWFE estimates with staggered treatment timing can be written as a
6
+ weighted average of all possible 2x2 DiD comparisons.
7
+
8
+ Reference:
9
+ Goodman-Bacon, A. (2021). Difference-in-differences with variation
10
+ in treatment timing. Journal of Econometrics, 225(2), 254-277.
11
+ """
12
+
13
+ import warnings
14
+ from dataclasses import dataclass, field
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+
21
@dataclass
class Comparison2x2:
    """
    One 2x2 difference-in-differences comparison from a Bacon decomposition.

    Attributes
    ----------
    treated_group : Any
        Timing group playing the "treated" role in this comparison.
    control_group : Any
        Timing group playing the "control" role in this comparison
        (the string "never_treated" for never-treated controls).
    comparison_type : str
        One of "treated_vs_never", "earlier_vs_later", or
        "later_vs_earlier".
    estimate : float
        The 2x2 DiD point estimate.
    weight : float
        Weight of this comparison in the TWFE weighted average.
    n_treated : int
        Number of distinct treated units in this comparison.
    n_control : int
        Number of distinct control units in this comparison.
    time_window : Tuple[float, float]
        (start, end) time periods spanned by this comparison.
    """

    treated_group: Any
    control_group: Any
    comparison_type: str
    estimate: float
    weight: float
    n_treated: int
    n_control: int
    time_window: Tuple[float, float]

    def __repr__(self) -> str:
        pieces = [
            f"{self.treated_group} vs {self.control_group}",
            f"type={self.comparison_type}",
            f"β={self.estimate:.4f}",
            f"weight={self.weight:.4f}",
        ]
        return "Comparison2x2(" + ", ".join(pieces) + ")"
62
+
63
+
64
@dataclass
class BaconDecompositionResults:
    """
    Results from Goodman-Bacon decomposition of TWFE.

    This decomposition shows that the TWFE estimate equals a weighted
    average of all possible 2x2 DiD comparisons between timing groups.

    Attributes
    ----------
    twfe_estimate : float
        The overall TWFE coefficient. It matches the weighted sum of the
        2x2 estimates up to ``decomposition_error``.
    comparisons : List[Comparison2x2]
        List of all 2x2 comparisons with their estimates and weights.
    total_weight_treated_vs_never : float
        Total weight on treated vs never-treated comparisons.
    total_weight_earlier_vs_later : float
        Total weight on earlier vs later treated comparisons.
    total_weight_later_vs_earlier : float
        Total weight on later vs earlier treated comparisons (forbidden).
    weighted_avg_treated_vs_never : float or None
        Weighted average effect from treated vs never-treated comparisons;
        None when there is no weight on this comparison type.
    weighted_avg_earlier_vs_later : float or None
        Weighted average effect from earlier vs later comparisons;
        None when there is no weight on this comparison type.
    weighted_avg_later_vs_earlier : float or None
        Weighted average effect from later vs earlier comparisons;
        None when there is no weight on this comparison type.
    n_timing_groups : int
        Number of distinct treatment timing groups.
    n_never_treated : int
        Number of never-treated units.
    timing_groups : List[Any]
        List of treatment timing cohorts.
    n_obs : int, default 0
        Number of observations (rows) in the decomposed panel.
    decomposition_error : float, default 0.0
        Absolute difference between ``twfe_estimate`` and the weighted sum
        of the 2x2 estimates.
    """

    twfe_estimate: float
    comparisons: List[Comparison2x2]
    total_weight_treated_vs_never: float
    total_weight_earlier_vs_later: float
    total_weight_later_vs_earlier: float
    weighted_avg_treated_vs_never: Optional[float]
    weighted_avg_earlier_vs_later: Optional[float]
    weighted_avg_later_vs_earlier: Optional[float]
    n_timing_groups: int
    n_never_treated: int
    timing_groups: List[Any]
    n_obs: int = 0
    decomposition_error: float = field(default=0.0)

    def __repr__(self) -> str:
        return (
            f"BaconDecompositionResults(TWFE={self.twfe_estimate:.4f}, "
            f"n_comparisons={len(self.comparisons)}, "
            f"n_groups={self.n_timing_groups})"
        )

    def summary(self) -> str:
        """
        Generate a formatted summary of the decomposition.

        Returns
        -------
        str
            Formatted summary table. Comparison types with zero total
            weight are omitted from the breakdown, and a warning is appended
            when more than 1% of the weight is on forbidden comparisons.
        """
        lines = [
            "=" * 85,
            "Goodman-Bacon Decomposition of Two-Way Fixed Effects".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<35} {self.n_obs:>10}",
            f"{'Treatment timing groups:':<35} {self.n_timing_groups:>10}",
            f"{'Never-treated units:':<35} {self.n_never_treated:>10}",
            f"{'Total 2x2 comparisons:':<35} {len(self.comparisons):>10}",
            "",
            "-" * 85,
            "TWFE Decomposition".center(85),
            "-" * 85,
            "",
            f"{'TWFE Estimate:':<35} {self.twfe_estimate:>12.4f}",
            f"{'Weighted Sum of 2x2 Estimates:':<35} {self._weighted_sum():>12.4f}",
            f"{'Decomposition Error:':<35} {self.decomposition_error:>12.6f}",
            "",
        ]

        # Weight breakdown by comparison type
        lines.extend([
            "-" * 85,
            "Weight Breakdown by Comparison Type".center(85),
            "-" * 85,
            f"{'Comparison Type':<30} {'Weight':>12} {'Avg Effect':>12} {'Contribution':>12}",
            "-" * 85,
        ])

        breakdown = [
            ("Treated vs Never-treated",
             self.total_weight_treated_vs_never,
             self.weighted_avg_treated_vs_never),
            ("Earlier vs Later treated",
             self.total_weight_earlier_vs_later,
             self.weighted_avg_earlier_vs_later),
            ("Later vs Earlier (forbidden)",
             self.total_weight_later_vs_earlier,
             self.weighted_avg_later_vs_earlier),
        ]
        for label, total_weight, avg_effect in breakdown:
            if total_weight > 0:
                # A None average (no weight of this type) prints as 0.
                avg_value = avg_effect or 0
                contrib = total_weight * avg_value
                lines.append(
                    f"{label:<30} "
                    f"{total_weight:>12.4f} "
                    f"{avg_value:>12.4f} "
                    f"{contrib:>12.4f}"
                )

        lines.extend([
            "-" * 85,
            f"{'Total':<30} {self._total_weight():>12.4f} "
            f"{'':>12} {self._weighted_sum():>12.4f}",
            "-" * 85,
            "",
        ])

        # Warning about forbidden comparisons
        if self.total_weight_later_vs_earlier > 0.01:
            pct = self.total_weight_later_vs_earlier * 100
            lines.extend([
                f"WARNING: {pct:.1f}% of weight is on 'forbidden' comparisons where",
                "already-treated units serve as controls. This can bias TWFE",
                "when treatment effects are heterogeneous over time.",
                "",
                "Consider using Callaway-Sant'Anna or other robust estimators.",
                "",
            ])

        lines.append("=" * 85)

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def _weighted_sum(self) -> float:
        """Calculate weighted sum of 2x2 estimates."""
        return sum(c.weight * c.estimate for c in self.comparisons)

    def _total_weight(self) -> float:
        """Calculate total weight (1.0 after normalization, 0 if no comparisons)."""
        return sum(c.weight for c in self.comparisons)

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert comparisons to a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with one row per 2x2 comparison. The columns are always
            present, even when there are no comparisons, so column access on
            the result never fails.
        """
        columns = [
            "treated_group",
            "control_group",
            "comparison_type",
            "estimate",
            "weight",
            "n_treated",
            "n_control",
            "time_start",
            "time_end",
        ]
        rows = [
            {
                "treated_group": c.treated_group,
                "control_group": c.control_group,
                "comparison_type": c.comparison_type,
                "estimate": c.estimate,
                "weight": c.weight,
                "n_treated": c.n_treated,
                "n_control": c.n_control,
                "time_start": c.time_window[0],
                "time_end": c.time_window[1],
            }
            for c in self.comparisons
        ]
        return pd.DataFrame(rows, columns=columns)

    def weight_by_type(self) -> Dict[str, float]:
        """
        Get total weight by comparison type.

        Returns
        -------
        Dict[str, float]
            Dictionary mapping comparison type to total weight.
        """
        return {
            "treated_vs_never": self.total_weight_treated_vs_never,
            "earlier_vs_later": self.total_weight_earlier_vs_later,
            "later_vs_earlier": self.total_weight_later_vs_earlier,
        }

    def effect_by_type(self) -> Dict[str, Optional[float]]:
        """
        Get weighted average effect by comparison type.

        Returns
        -------
        Dict[str, Optional[float]]
            Dictionary mapping comparison type to weighted average effect
            (None for types that carry no weight).
        """
        return {
            "treated_vs_never": self.weighted_avg_treated_vs_never,
            "earlier_vs_later": self.weighted_avg_earlier_vs_later,
            "later_vs_earlier": self.weighted_avg_later_vs_earlier,
        }
284
+
285
+
286
class BaconDecomposition:
    """
    Goodman-Bacon (2021) decomposition of Two-Way Fixed Effects estimator.

    This class decomposes a TWFE estimate into a weighted average of all
    possible 2x2 DiD comparisons, revealing the implicit comparisons that
    drive the TWFE estimate and their relative importance.

    The decomposition identifies three types of comparisons:

    1. **Treated vs Never-treated**: Uses never-treated units as controls.
       These are "clean" comparisons without bias concerns.

    2. **Earlier vs Later treated**: Units treated earlier are compared to
       units treated later, using the later group as controls before they
       are treated. These are valid comparisons.

    3. **Later vs Earlier treated**: Units treated later are compared to
       units treated earlier, using the earlier group as controls AFTER
       they are already treated. These are "forbidden comparisons" that
       can introduce bias when treatment effects vary over time.

    Parameters
    ----------
    weights : str, default="approximate"
        Weight calculation method:
        - "approximate": Fast simplified formula using group shares and
          treatment variance. Good for diagnostic purposes where relative
          weights are sufficient to identify problematic comparisons.
        - "exact": Weights from Goodman-Bacon (2021) Theorem 1, built from
          timing-group unit shares and the share of panel periods in which
          each group is treated. In a balanced panel with no missing
          outcomes, the weighted sum of the 2x2 estimates reproduces the
          TWFE estimate up to floating-point error.

    Attributes
    ----------
    weights : str
        The weight calculation method.
    results_ : BaconDecompositionResults
        Decomposition results after calling fit().
    is_fitted_ : bool
        Whether the model has been fitted.

    Examples
    --------
    Basic usage:

    >>> import pandas as pd
    >>> from diff_diff import BaconDecomposition
    >>>
    >>> # Panel data with staggered treatment
    >>> data = pd.DataFrame({
    ...     'unit': [...],
    ...     'time': [...],
    ...     'outcome': [...],
    ...     'first_treat': [...]  # 0 for never-treated
    ... })
    >>>
    >>> bacon = BaconDecomposition()
    >>> results = bacon.fit(data, outcome='outcome', unit='unit',
    ...                     time='time', first_treat='first_treat')
    >>> results.print_summary()

    Visualizing the decomposition:

    >>> from diff_diff import plot_bacon
    >>> plot_bacon(results)

    Notes
    -----
    The key insight from Goodman-Bacon (2021) is that TWFE with staggered
    treatment timing implicitly makes comparisons using already-treated
    units as controls. When treatment effects are dynamic (changing over
    time since treatment), these "forbidden comparisons" can bias the
    TWFE estimate, potentially even reversing its sign.

    The decomposition helps diagnose this issue by showing:
    - How much weight is on each type of comparison
    - Whether forbidden comparisons contribute significantly to the estimate
    - How the 2x2 estimates vary across comparison types

    If forbidden comparisons have substantial weight and different estimates
    than clean comparisons, consider using robust estimators like
    Callaway-Sant'Anna that avoid these problematic comparisons.

    References
    ----------
    Goodman-Bacon, A. (2021). Difference-in-differences with variation in
    treatment timing. Journal of Econometrics, 225(2), 254-277.

    See Also
    --------
    CallawaySantAnna : Robust estimator for staggered DiD
    TwoWayFixedEffects : The TWFE estimator being decomposed
    """

    # Accepted values for the ``weights`` parameter.
    _VALID_WEIGHTS = ("approximate", "exact")

    def __init__(self, weights: str = "approximate"):
        """
        Initialize BaconDecomposition.

        Parameters
        ----------
        weights : str, default="approximate"
            Weight calculation method:
            - "approximate": Fast simplified formula (default)
            - "exact": Goodman-Bacon (2021) Theorem 1 weights

        Raises
        ------
        ValueError
            If ``weights`` is not "approximate" or "exact".
        """
        self._validate_weights(weights)
        self.weights = weights
        self.results_: Optional[BaconDecompositionResults] = None
        self.is_fitted_: bool = False

    @classmethod
    def _validate_weights(cls, weights: str) -> None:
        """Raise ValueError unless ``weights`` names a supported method."""
        if weights not in cls._VALID_WEIGHTS:
            raise ValueError(
                f"weights must be 'approximate' or 'exact', got '{weights}'"
            )

    @staticmethod
    def _is_never_treated(first_treat_values: pd.Series) -> pd.Series:
        """Boolean mask that is True where ``first_treat`` is 0 or inf."""
        return (first_treat_values == 0) | (first_treat_values == np.inf)

    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
    ) -> BaconDecompositionResults:
        """
        Perform the Goodman-Bacon decomposition.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers. The input frame is
            not modified.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column (converted to numeric).
        first_treat : str
            Name of column indicating when unit was first treated
            (converted to numeric). Use 0 (or np.inf) for never-treated
            units. The value is expected to be constant within each unit.

        Returns
        -------
        BaconDecompositionResults
            Object containing decomposition results. Weights are normalized
            to sum to one whenever any comparison has positive weight.

        Raises
        ------
        ValueError
            If required columns are missing, or if the time / first_treat
            columns cannot be converted to numbers.

        Warns
        -----
        UserWarning
            If the panel is unbalanced (some unit is not observed exactly
            once in every period).
        """
        required_cols = [outcome, unit, time, first_treat]
        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Work on a copy so the caller's frame is never modified.
        df = data.copy()
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        # A balanced panel observes every unit exactly once in every period.
        # Comparing per-unit row counts alone would miss units that share a
        # count but not the same periods, or duplicated (unit, time) rows.
        n_units = df[unit].nunique()
        n_periods = df[time].nunique()
        if (
            len(df) != n_units * n_periods
            or df.duplicated(subset=[unit, time]).any()
        ):
            warnings.warn(
                "Unbalanced panel detected. Bacon decomposition assumes "
                "balanced panels. Results may be inaccurate.",
                UserWarning,
                stacklevel=2,
            )

        time_periods = sorted(df[time].unique())

        # Never-treated: first_treat == 0 or inf. Every other positive value
        # is a treatment timing cohort.
        never_treated_mask = self._is_never_treated(df[first_treat])
        timing_groups = sorted(
            g for g in df[first_treat].unique() if g > 0 and g != np.inf
        )

        unit_first_treat = df.groupby(unit)[first_treat].first()
        n_never_treated = int(self._is_never_treated(unit_first_treat).sum())

        # Treatment indicator D_it = 1 once a unit's cohort has started.
        # The unusual column name avoids clashing with user columns.
        treat_col = "__bacon_treated_internal__"
        df[treat_col] = (~never_treated_mask) & (df[time] >= df[first_treat])

        twfe_estimate = self._compute_twfe(df, outcome, unit, time, treat_col)

        comparisons: List[Comparison2x2] = []

        # 1. Treated vs Never-treated comparisons
        if n_never_treated > 0:
            for g in timing_groups:
                comp = self._compute_treated_vs_never(
                    df, outcome, unit, time, first_treat, g, time_periods
                )
                if comp is not None:
                    comparisons.append(comp)

        # 2. Each pair of cohorts yields an earlier-vs-later comparison and a
        #    (forbidden) later-vs-earlier comparison.
        for i, g_early in enumerate(timing_groups):
            for g_late in timing_groups[i + 1:]:
                for treated, control, ctype in (
                    (g_early, g_late, "earlier_vs_later"),
                    (g_late, g_early, "later_vs_earlier"),
                ):
                    comp = self._compute_timing_comparison(
                        df, outcome, unit, time, first_treat,
                        treated, control, time_periods, ctype,
                    )
                    if comp is not None:
                        comparisons.append(comp)

        if self.weights == "exact":
            self._recompute_exact_weights(
                comparisons, df, outcome, unit, time, first_treat, time_periods
            )

        # Normalize weights to sum to 1
        total_weight = sum(c.weight for c in comparisons)
        if total_weight > 0:
            for c in comparisons:
                c.weight = c.weight / total_weight

        weight_by_type, avg_by_type = self._aggregate_by_type(comparisons)

        weighted_sum = sum(c.weight * c.estimate for c in comparisons)
        decomp_error = abs(twfe_estimate - weighted_sum)

        self.results_ = BaconDecompositionResults(
            twfe_estimate=twfe_estimate,
            comparisons=comparisons,
            total_weight_treated_vs_never=weight_by_type["treated_vs_never"],
            total_weight_earlier_vs_later=weight_by_type["earlier_vs_later"],
            total_weight_later_vs_earlier=weight_by_type["later_vs_earlier"],
            weighted_avg_treated_vs_never=avg_by_type["treated_vs_never"],
            weighted_avg_earlier_vs_later=avg_by_type["earlier_vs_later"],
            weighted_avg_later_vs_earlier=avg_by_type["later_vs_earlier"],
            n_timing_groups=len(timing_groups),
            n_never_treated=n_never_treated,
            timing_groups=timing_groups,
            n_obs=len(df),
            decomposition_error=decomp_error,
        )

        self.is_fitted_ = True
        return self.results_

    @staticmethod
    def _aggregate_by_type(
        comparisons: List[Comparison2x2],
    ) -> Tuple[Dict[str, float], Dict[str, Optional[float]]]:
        """
        Total weight and weight-averaged estimate for each comparison type.

        The average is None for a type whose total weight is not positive.
        """
        types = ("treated_vs_never", "earlier_vs_later", "later_vs_earlier")
        weight_by_type = {ctype: 0.0 for ctype in types}
        weighted_sum_by_type = {ctype: 0.0 for ctype in types}
        for c in comparisons:
            weight_by_type[c.comparison_type] += c.weight
            weighted_sum_by_type[c.comparison_type] += c.weight * c.estimate

        avg_by_type: Dict[str, Optional[float]] = {
            ctype: (
                weighted_sum_by_type[ctype] / weight_by_type[ctype]
                if weight_by_type[ctype] > 0
                else None
            )
            for ctype in types
        }
        return weight_by_type, avg_by_type

    def _compute_twfe(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        treat_col: str = '__bacon_treated_internal__',
    ) -> float:
        """
        Compute the TWFE coefficient on ``treat_col`` via the within transformation.

        Outcome and treatment are each two-way demeaned
        (``x_it - mean_i - mean_t + overall_mean``), and the coefficient is the
        OLS slope of the demeaned outcome on the demeaned treatment. This
        one-pass demeaning equals the TWFE regression only in balanced panels.

        Returns 0.0 when the demeaned treatment has no variation.
        """
        # Integer codes 0..n-1 per unit / period, so group means are bincounts.
        unit_codes = pd.factorize(df[unit])[0]
        time_codes = pd.factorize(df[time])[0]
        unit_counts = np.bincount(unit_codes)
        time_counts = np.bincount(time_codes)

        def two_way_demean(values: np.ndarray) -> np.ndarray:
            unit_means = np.bincount(unit_codes, weights=values) / unit_counts
            time_means = np.bincount(time_codes, weights=values) / time_counts
            return (
                values
                - unit_means[unit_codes]
                - time_means[time_codes]
                + values.mean()
            )

        y_within = two_way_demean(
            df[outcome].to_numpy(dtype=float, na_value=np.nan)
        )
        d_within = two_way_demean(
            df[treat_col].to_numpy(dtype=float, na_value=np.nan)
        )

        d_var = float(np.sum(d_within ** 2))
        if d_var > 0:
            return float(np.sum(d_within * y_within) / d_var)
        return 0.0

    def _recompute_exact_weights(
        self,
        comparisons: List[Comparison2x2],
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        time_periods: List[Any],
    ) -> None:
        """
        Replace comparison weights in place with Goodman-Bacon (2021) Theorem 1 weights.

        Let n_g be the share of all units in group g, and D_g the share of
        panel periods in which group g is treated. Let k denote the cohort
        treated earlier and l the cohort treated later. Equation (10) of the
        paper then reduces to these unnormalized weights:

            treated_vs_never  (k vs U):  n_k * n_U * D_k * (1 - D_k)
            earlier_vs_later  (k vs l):  n_k * n_l * (D_k - D_l) * (1 - D_k)
            later_vs_earlier  (l vs k):  n_k * n_l * D_l * (D_k - D_l)

        The reduction uses (n_k + n_l)^2 * n_kl * (1 - n_kl) = n_k * n_l. The
        paper's common denominator (the variance of the two-way demeaned
        treatment) is omitted because ``fit`` normalizes the weights to sum
        to one.

        Only ``df`` (for the number of units) and ``time_periods`` are used;
        the remaining column-name parameters are accepted for signature
        compatibility.
        """
        n_units = df[unit].nunique()
        n_periods = len(time_periods)

        def treated_share(group: Any) -> float:
            # Share of panel periods with t >= group, matching D_it above.
            return sum(1 for t in time_periods if t >= group) / n_periods

        for comp in comparisons:
            share_treated = comp.n_treated / n_units
            share_control = comp.n_control / n_units
            d_treated = treated_share(comp.treated_group)

            if comp.comparison_type == "treated_vs_never":
                comp.weight = (
                    share_treated * share_control * d_treated * (1 - d_treated)
                )
            elif comp.comparison_type == "earlier_vs_later":
                # Treated group is the earlier cohort k; control is l.
                d_control = treated_share(comp.control_group)
                comp.weight = (
                    share_treated * share_control
                    * (d_treated - d_control) * (1 - d_treated)
                )
            else:  # later_vs_earlier: control group is the earlier cohort k.
                d_control = treated_share(comp.control_group)
                comp.weight = (
                    share_treated * share_control
                    * d_treated * (d_control - d_treated)
                )

    def _compute_treated_vs_never(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treated_group: Any,
        time_periods: List[Any],
    ) -> Optional[Comparison2x2]:
        """
        Compute the 2x2 DiD comparing one timing group with the never-treated.

        The full panel is used. Periods before ``treated_group`` are "pre",
        and periods at or after it are "post". Returns None in three cases:
        either group has no rows, the window lacks pre or post periods, or
        any of the four cell means is undefined (NaN).

        The returned weight is the "approximate" weight
        n_k(1 - n_k) * D(1 - D). Here n_k is the treated group's share of the
        units in this comparison, and D is the post-period share of all
        periods.
        """
        df_treated = df[df[first_treat] == treated_group]
        df_never = df[self._is_never_treated(df[first_treat])]
        if df_treated.empty or df_never.empty:
            return None

        pre_periods = [t for t in time_periods if t < treated_group]
        post_periods = [t for t in time_periods if t >= treated_group]
        if not pre_periods or not post_periods:
            return None

        treated_pre = df_treated.loc[df_treated[time].isin(pre_periods), outcome].mean()
        treated_post = df_treated.loc[df_treated[time].isin(post_periods), outcome].mean()
        never_pre = df_never.loc[df_never[time].isin(pre_periods), outcome].mean()
        never_post = df_never.loc[df_never[time].isin(post_periods), outcome].mean()

        # Same guard as the timing comparisons: an empty cell (possible in
        # unbalanced panels) would otherwise produce a NaN estimate.
        if np.isnan([treated_pre, treated_post, never_pre, never_post]).any():
            return None

        estimate = (treated_post - treated_pre) - (never_post - never_pre)

        n_treated = df_treated[unit].nunique()
        n_never = df_never[unit].nunique()
        share_treated = n_treated / (n_treated + n_never)
        d_share = len(post_periods) / len(time_periods)
        weight = share_treated * (1 - share_treated) * d_share * (1 - d_share)

        return Comparison2x2(
            treated_group=treated_group,
            control_group="never_treated",
            comparison_type="treated_vs_never",
            estimate=estimate,
            weight=weight,
            n_treated=n_treated,
            n_control=n_never,
            time_window=(min(time_periods), max(time_periods)),
        )

    def _compute_timing_comparison(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        treated_group: Any,
        control_group: Any,
        time_periods: List[Any],
        comparison_type: str,
    ) -> Optional[Comparison2x2]:
        """
        Compute the 2x2 DiD comparing two timing groups.

        For earlier_vs_later, the later group serves as control before it is
        treated. Pre is t < g_early, and post is g_early <= t < g_late.

        For later_vs_earlier (forbidden), the earlier group serves as control
        after its own treatment. Pre is g_early <= t < g_late, and post is
        t >= g_late.

        Returns None in three cases: either group has no rows, the window
        lacks pre or post periods, or any of the four cell means is NaN.
        The weight is the "approximate" weight
        n_k(1 - n_k) * D(1 - D) * (window periods / all periods).
        """
        df_treated = df[df[first_treat] == treated_group]
        df_control = df[df[first_treat] == control_group]
        if df_treated.empty or df_control.empty:
            return None

        if comparison_type == "earlier_vs_later":
            g_early, g_late = treated_group, control_group
            pre_periods = [t for t in time_periods if t < g_early]
            post_periods = [t for t in time_periods if g_early <= t < g_late]
        else:  # later_vs_earlier (forbidden)
            g_late, g_early = treated_group, control_group
            pre_periods = [t for t in time_periods if g_early <= t < g_late]
            post_periods = [t for t in time_periods if t >= g_late]

        if not pre_periods or not post_periods:
            return None

        # Report the periods actually used. Deriving the window from cohort
        # labels (e.g. ``g_late - 1``) assumes unit-spaced periods.
        time_window = (min(pre_periods), max(post_periods))

        treated_pre = df_treated.loc[df_treated[time].isin(pre_periods), outcome].mean()
        treated_post = df_treated.loc[df_treated[time].isin(post_periods), outcome].mean()
        control_pre = df_control.loc[df_control[time].isin(pre_periods), outcome].mean()
        control_post = df_control.loc[df_control[time].isin(post_periods), outcome].mean()

        if np.isnan([treated_pre, treated_post, control_pre, control_post]).any():
            return None

        estimate = (treated_post - treated_pre) - (control_post - control_pre)

        n_treated = df_treated[unit].nunique()
        n_control = df_control[unit].nunique()
        share_treated = n_treated / (n_treated + n_control)

        # Both period lists are non-empty here, so the window is non-empty.
        n_window = len(pre_periods) + len(post_periods)
        d_share = len(post_periods) / n_window
        time_share = n_window / len(time_periods)
        weight = (
            share_treated * (1 - share_treated)
            * d_share * (1 - d_share) * time_share
        )

        return Comparison2x2(
            treated_group=treated_group,
            control_group=control_group,
            comparison_type=comparison_type,
            estimate=estimate,
            weight=weight,
            n_treated=n_treated,
            n_control=n_control,
            time_window=time_window,
        )

    def get_params(self) -> Dict[str, Any]:
        """Get estimator parameters (sklearn-compatible)."""
        return {"weights": self.weights}

    def set_params(self, **params) -> "BaconDecomposition":
        """
        Set estimator parameters (sklearn-compatible).

        Only ``weights`` is recognized; other keys are ignored.

        Raises
        ------
        ValueError
            If ``weights`` is given and is not "approximate" or "exact".
        """
        if "weights" in params:
            self._validate_weights(params["weights"])
            self.weights = params["weights"]
        return self

    def summary(self) -> str:
        """
        Get summary of decomposition results.

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called yet.
        """
        # Explicit check rather than ``assert``, which is stripped under -O.
        if not self.is_fitted_ or self.results_ is None:
            raise RuntimeError("Model must be fitted before calling summary()")
        return self.results_.summary()

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())
942
+
943
+
944
def bacon_decompose(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    weights: str = "approximate",
) -> BaconDecompositionResults:
    """
    One-call Goodman-Bacon decomposition of a TWFE estimate.

    Builds a ``BaconDecomposition`` with the requested weighting method and
    fits it on ``data``. The result breaks the TWFE coefficient into its
    underlying 2x2 DiD comparisons, including any "forbidden" comparisons
    that use already-treated units as controls.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers.
    outcome : str
        Column holding the outcome variable.
    unit : str
        Column holding the unit identifier.
    time : str
        Column holding the time period.
    first_treat : str
        Column giving the period in which each unit is first treated;
        0 (or np.inf) marks never-treated units.
    weights : str, default="approximate"
        How comparison weights are computed:
        - "approximate": quick simplified formula, adequate for diagnosing
          which comparisons dominate.
        - "exact": variance-based weights following Goodman-Bacon (2021)
          Theorem 1.

    Returns
    -------
    BaconDecompositionResults
        The fitted decomposition: ``twfe_estimate``, the list of
        ``comparisons`` with their estimates and weights, weight totals and
        averages per comparison type, plus ``summary()`` and
        ``to_dataframe()`` for reporting.

    Raises
    ------
    ValueError
        If ``weights`` is not a supported method, or if required columns
        are missing from ``data``.

    Examples
    --------
    >>> from diff_diff import bacon_decompose
    >>>
    >>> # Quick diagnostic (default)
    >>> results = bacon_decompose(
    ...     data=panel_df,
    ...     outcome='earnings',
    ...     unit='state',
    ...     time='year',
    ...     first_treat='treatment_year'
    ... )
    >>>
    >>> # Theorem 1 weights
    >>> results = bacon_decompose(
    ...     data=panel_df,
    ...     outcome='earnings',
    ...     unit='state',
    ...     time='year',
    ...     first_treat='treatment_year',
    ...     weights='exact'
    ... )
    >>>
    >>> results.print_summary()
    >>> print(f"Forbidden weight: {results.total_weight_later_vs_earlier:.1%}")
    >>> df = results.to_dataframe()

    See Also
    --------
    BaconDecomposition : Class-based interface (reusable, sklearn-style params)
    plot_bacon : Visualize the decomposition
    CallawaySantAnna : Robust estimator that avoids forbidden comparisons
    """
    decomposer = BaconDecomposition(weights=weights)
    return decomposer.fit(
        data,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
    )