diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,817 @@
1
+ """Diagnostic visualization functions (sensitivity, Bacon decomposition)."""
2
+
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ if TYPE_CHECKING:
8
+ from diff_diff.bacon import BaconDecompositionResults
9
+ from diff_diff.honest_did import SensitivityResults
10
+
11
+
12
def plot_sensitivity(
    sensitivity_results: "SensitivityResults",
    *,
    show_bounds: bool = True,
    show_ci: bool = True,
    breakdown_line: bool = True,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Honest DiD Sensitivity Analysis",
    xlabel: str = "M (restriction parameter)",
    ylabel: str = "Treatment Effect",
    bounds_color: str = "#2563eb",
    bounds_alpha: float = 0.3,
    ci_color: str = "#2563eb",
    ci_linewidth: float = 1.5,
    breakdown_color: str = "#dc2626",
    original_color: str = "#1f2937",
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot sensitivity analysis results from Honest DiD.

    Shows how treatment effect bounds and confidence intervals
    change as the restriction parameter M varies.

    Parameters
    ----------
    sensitivity_results : SensitivityResults
        Results from HonestDiD.sensitivity_analysis(). The attributes read
        are ``M_values``, ``bounds``, ``robust_cis``, ``original_estimate``
        and ``breakdown_M``. ``bounds`` and ``robust_cis`` are expected to
        hold one (lower, upper) pair per M value.
    show_bounds : bool, default=True
        Whether to show the identified set bounds as shaded region.
    show_ci : bool, default=True
        Whether to show robust confidence interval lines.
    breakdown_line : bool, default=True
        Whether to show vertical line at breakdown value (skipped when
        ``breakdown_M`` is None).
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches (matplotlib only).
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    bounds_color : str
        Color for identified set shading.
    bounds_alpha : float
        Transparency for identified set shading.
    ci_color : str
        Color for confidence interval lines.
    ci_linewidth : float
        Line width for CI lines.
    breakdown_color : str
        Color for breakdown value line.
    original_color : str
        Color for original estimate line.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure (matplotlib only).
    show : bool, default=True
        Whether to call plt.show() / fig.show().
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``. Any value other
        than ``"plotly"`` renders with matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Notes
    -----
    ``None`` entries in ``bounds`` / ``robust_cis`` are drawn as gaps (NaN).
    An empty M grid produces an empty plot instead of an IndexError.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.honest_did import HonestDiD
    >>> from diff_diff.visualization import plot_sensitivity
    >>>
    >>> # Fit event study and run sensitivity analysis
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> honest = HonestDiD(method='relative_magnitude')
    >>> sensitivity = honest.sensitivity_analysis(results)
    >>>
    >>> # Create sensitivity plot
    >>> plot_sensitivity(sensitivity)
    """
    # Arguments shared by both renderers; backend-specific extras are added
    # at the call site below.
    render_kwargs = dict(
        M=sensitivity_results.M_values,
        bounds_arr=_as_interval_array(sensitivity_results.bounds),
        ci_arr=_as_interval_array(sensitivity_results.robust_cis),
        original_estimate=sensitivity_results.original_estimate,
        breakdown_M=sensitivity_results.breakdown_M,
        show_bounds=show_bounds,
        show_ci=show_ci,
        breakdown_line=breakdown_line,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        bounds_color=bounds_color,
        bounds_alpha=bounds_alpha,
        ci_color=ci_color,
        ci_linewidth=ci_linewidth,
        breakdown_color=breakdown_color,
        original_color=original_color,
        show=show,
    )

    if backend == "plotly":
        return _render_sensitivity_plotly(**render_kwargs)

    return _render_sensitivity_mpl(**render_kwargs, figsize=figsize, ax=ax)


def _as_interval_array(values: Any) -> np.ndarray:
    """Convert a sequence of (lower, upper) pairs to a float ndarray.

    ``None`` entries become NaN. An empty input is returned with shape
    ``(0, 2)`` so the renderers' ``[:, 0]`` / ``[:, 1]`` column slices do not
    raise IndexError (``np.array([])`` alone is 1-D).
    """
    arr = np.asarray(values, dtype=float)
    if arr.size == 0:
        return arr.reshape(0, 2)
    return arr
142
+
143
+
144
def _render_sensitivity_mpl(
    *,
    M,
    bounds_arr,
    ci_arr,
    original_estimate,
    breakdown_M,
    show_bounds,
    show_ci,
    breakdown_line,
    figsize,
    title,
    xlabel,
    ylabel,
    bounds_color,
    bounds_alpha,
    ci_color,
    ci_linewidth,
    breakdown_color,
    original_color,
    ax,
    show,
):
    """
    Draw the Honest DiD sensitivity figure on a matplotlib Axes.

    Layers are drawn in this order:

    1. the original estimate, as a solid horizontal line;
    2. a dashed zero reference line;
    3. the identified-set band, if ``show_bounds``;
    4. the lower and upper robust-CI curves, if ``show_ci``;
    5. a dotted vertical line at ``breakdown_M``, if requested and not None.

    ``bounds_arr`` and ``ci_arr`` are indexed as ``[:, 0]`` (lower) and
    ``[:, 1]`` (upper). When ``ax`` is None a new figure of size ``figsize``
    is created. Returns the Axes that was drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    # Draw on the caller's Axes if one was supplied, else open a new figure.
    fig, ax = plt.subplots(figsize=figsize) if ax is None else (ax.get_figure(), ax)

    # --- horizontal references -------------------------------------------
    ax.axhline(
        y=original_estimate,
        color=original_color,
        linestyle="-",
        linewidth=1.5,
        label="Original estimate",
        alpha=0.7,
    )
    ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5)

    # --- identified set ---------------------------------------------------
    if show_bounds:
        lower_bound, upper_bound = bounds_arr[:, 0], bounds_arr[:, 1]
        ax.fill_between(
            M,
            lower_bound,
            upper_bound,
            alpha=bounds_alpha,
            color=bounds_color,
            label="Identified set",
        )

    # --- robust CI (only the lower curve carries a legend entry) ----------
    if show_ci:
        ci_style = {"color": ci_color, "linewidth": ci_linewidth}
        ax.plot(M, ci_arr[:, 0], label="Robust CI", **ci_style)
        ax.plot(M, ci_arr[:, 1], **ci_style)

    # --- breakdown value --------------------------------------------------
    has_breakdown = breakdown_M is not None
    if breakdown_line and has_breakdown:
        ax.axvline(
            x=breakdown_M,
            color=breakdown_color,
            linestyle=":",
            linewidth=2,
            label=f"Breakdown (M={breakdown_M:.2f})",
        )

    # --- decoration -------------------------------------------------------
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
228
+
229
+
230
def _render_sensitivity_plotly(
    *,
    M,
    bounds_arr,
    ci_arr,
    original_estimate,
    breakdown_M,
    show_bounds,
    show_ci,
    breakdown_line,
    title,
    xlabel,
    ylabel,
    bounds_color,
    bounds_alpha,
    ci_color,
    ci_linewidth,
    breakdown_color,
    original_color,
    show,
):
    """
    Build the Honest DiD sensitivity figure with plotly.

    This mirrors the matplotlib renderer. The original estimate and zero are
    drawn as horizontal lines. The identified set is a filled polygon: the
    upper bound is traced left to right, then the lower bound right to left.
    The two robust-CI curves are line traces. The breakdown value is an
    optional dotted vertical line. Returns the ``go.Figure``.
    """
    from diff_diff.visualization._common import (
        _color_to_rgba,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()

    fig = go.Figure()

    # Plain list so it can be concatenated / reversed for the fill polygon.
    xs = M if isinstance(M, list) else list(M)

    fig.add_hline(
        y=original_estimate,
        line_color=original_color,
        line_width=1.5,
        opacity=0.7,
        annotation_text="Original estimate",
    )
    fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1, opacity=0.5)

    if show_bounds:
        upper_forward = list(bounds_arr[:, 1])
        lower_backward = list(bounds_arr[:, 0])[::-1]
        fig.add_trace(
            go.Scatter(
                x=xs + xs[::-1],
                y=upper_forward + lower_backward,
                fill="toself",
                fillcolor=_color_to_rgba(bounds_color, bounds_alpha),
                line=dict(color="rgba(0,0,0,0)"),
                name="Identified set",
            )
        )

    if show_ci:
        # Lower curve is named in the legend; the upper curve is hidden from it.
        for column, legend_kwargs in ((0, {"name": "Robust CI"}), (1, {"showlegend": False})):
            fig.add_trace(
                go.Scatter(
                    x=xs,
                    y=list(ci_arr[:, column]),
                    mode="lines",
                    line=dict(color=ci_color, width=ci_linewidth),
                    **legend_kwargs,
                )
            )

    has_breakdown = breakdown_M is not None
    if breakdown_line and has_breakdown:
        fig.add_vline(
            x=breakdown_M,
            line_dash="dot",
            line_color=breakdown_color,
            line_width=2,
            annotation_text=f"Breakdown (M={breakdown_M:.2f})",
        )

    _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)

    if show:
        fig.show()

    return fig
326
+
327
+
328
def plot_bacon(
    results: "BaconDecompositionResults",
    *,
    plot_type: str = "scatter",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: str = "2x2 DiD Estimate",
    ylabel: str = "Weight",
    colors: Optional[Dict[str, str]] = None,
    marker: str = "o",
    markersize: int = 80,
    alpha: float = 0.7,
    show_weighted_avg: bool = True,
    show_twfe_line: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Visualize Goodman-Bacon decomposition results.

    Creates either a scatter plot showing the weight and estimate for each
    2x2 comparison, or a bar chart showing total weight by comparison
    type.

    Parameters
    ----------
    results : BaconDecompositionResults
        Results from BaconDecomposition.fit() or bacon_decompose().
    plot_type : str, default="scatter"
        Type of plot to create:
        - "scatter": Scatter plot with estimates on x-axis, weights on y-axis
        - "bar": Bar chart of total weight by comparison type
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches (matplotlib only).
    title : str, optional
        Plot title. If None, uses a default based on plot_type.
    xlabel : str, default="2x2 DiD Estimate"
        X-axis label (scatter plot only).
    ylabel : str, default="Weight"
        Y-axis label.
    colors : dict, optional
        Dictionary mapping comparison types to colors. Keys are:
        "treated_vs_never", "earlier_vs_later", "later_vs_earlier".
        Any key that is omitted (or ``colors=None``) falls back to the
        default color. The caller's dict is not modified.
    marker : str, default="o"
        Marker style for scatter plot.
    markersize : int, default=80
        Marker size for scatter plot (matplotlib area units, points^2).
    alpha : float, default=0.7
        Transparency for markers/bars.
    show_weighted_avg : bool, default=True
        Whether to show weighted average lines for each comparison type
        (scatter plot only).
    show_twfe_line : bool, default=True
        Whether to show a vertical line at the TWFE estimate (scatter plot only).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure (matplotlib only).
    show : bool, default=True
        Whether to call plt.show() / fig.show() at the end.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``. Any value other
        than ``"plotly"`` renders with matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If ``plot_type`` is not ``"scatter"`` or ``"bar"``.

    Examples
    --------
    Scatter plot (default):

    >>> from diff_diff import bacon_decompose, plot_bacon
    >>> results = bacon_decompose(data, outcome='y', unit='id',
    ...                           time='t', first_treat='first_treat')
    >>> plot_bacon(results)

    Bar chart of weights by type:

    >>> plot_bacon(results, plot_type='bar')

    Notes
    -----
    The scatter plot is particularly useful for understanding:

    1. **Distribution of estimates**: Are 2x2 estimates clustered or spread?
       Wide spread suggests heterogeneous treatment effects.

    2. **Weight concentration**: Do a few comparisons dominate the TWFE?
       Points with high weights have more influence.

    3. **Forbidden comparison problem**: Red points (later_vs_earlier) show
       comparisons using already-treated units as controls. If these have
       different estimates than clean comparisons, TWFE may be biased.

    See Also
    --------
    bacon_decompose : Perform the decomposition
    BaconDecomposition : Class-based interface
    """
    # Validate first so a bad plot_type fails before any other work.
    if plot_type not in ("scatter", "bar"):
        raise ValueError(f"Unknown plot_type: {plot_type}. Use 'scatter' or 'bar'.")

    default_colors = {
        "treated_vs_never": "#22c55e",  # Green - clean comparison
        "earlier_vs_later": "#3b82f6",  # Blue - valid comparison
        "later_vs_earlier": "#ef4444",  # Red - forbidden comparison
    }
    # Overlay the user's colors on the defaults. The renderers look up every
    # comparison type, so a partial mapping would otherwise raise KeyError.
    # A new dict is built, so the caller's mapping is left untouched.
    colors = {**default_colors, **(colors or {})}

    if title is None:
        title = (
            "Goodman-Bacon Decomposition"
            if plot_type == "scatter"
            else "TWFE Weight by Comparison Type"
        )

    # Arguments common to both renderers.
    render_kwargs = dict(
        results=results,
        plot_type=plot_type,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        colors=colors,
        marker=marker,
        markersize=markersize,
        alpha=alpha,
        show_weighted_avg=show_weighted_avg,
        show_twfe_line=show_twfe_line,
        show=show,
    )

    if backend == "plotly":
        return _render_bacon_plotly(**render_kwargs)

    return _render_bacon_mpl(**render_kwargs, figsize=figsize, ax=ax)
478
+
479
+
480
def _render_bacon_mpl(
    *,
    results,
    plot_type,
    figsize,
    title,
    xlabel,
    ylabel,
    colors,
    marker,
    markersize,
    alpha,
    show_weighted_avg,
    show_twfe_line,
    ax,
    show,
):
    """
    Render the Goodman-Bacon decomposition with matplotlib.

    If ``plot_type`` is ``"scatter"``, drawing is delegated to
    ``_plot_bacon_scatter``; any other value is drawn by ``_plot_bacon_bar``.
    A new figure of size ``figsize`` is created when ``ax`` is None. The
    figure is laid out with ``tight_layout``, optionally shown, and the Axes
    is returned.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    # Reuse the caller's Axes when provided.
    fig, ax = plt.subplots(figsize=figsize) if ax is None else (ax.get_figure(), ax)

    if plot_type != "scatter":
        _plot_bacon_bar(ax, results, colors, alpha, ylabel, title)
    else:
        _plot_bacon_scatter(
            ax,
            results,
            colors,
            marker,
            markersize,
            alpha,
            show_weighted_avg,
            show_twfe_line,
            xlabel,
            ylabel,
            title,
        )

    fig.tight_layout()

    if show:
        plt.show()

    return ax
530
+
531
+
532
def _plot_bacon_scatter(
    ax: Any,
    results: "BaconDecompositionResults",
    colors: Dict[str, str],
    marker: str,
    markersize: int,
    alpha: float,
    show_weighted_avg: bool,
    show_twfe_line: bool,
    xlabel: str,
    ylabel: str,
    title: str,
) -> None:
    """
    Draw the Bacon decomposition scatter (2x2 estimate vs. weight) onto ``ax``.

    Plots one marker per comparison in ``results.comparisons``, colored by
    ``comparison_type``. Optionally adds:

    - a dashed vertical line at each type's weighted-average effect, taken
      from ``results.effect_by_type()``;
    - a solid black line at ``results.twfe_estimate``.

    Also draws a dotted zero line. Modifies ``ax`` in place.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    results : BaconDecompositionResults
        Decomposition results. Comparison types other than the three known
        keys raise KeyError.
    colors : dict
        Color per comparison type. A single color spec (string or RGB/RGBA
        tuple) per key.
    marker, markersize, alpha : marker style, area (points^2), transparency.
    show_weighted_avg, show_twfe_line : toggles for the reference lines.
    xlabel, ylabel, title : axis labels and title.
    """
    # Group (estimate, weight) pairs by comparison type. Dict order fixes the
    # draw and legend order.
    by_type: Dict[str, List[Tuple[float, float]]] = {
        "treated_vs_never": [],
        "earlier_vs_later": [],
        "later_vs_earlier": [],
    }

    for comp in results.comparisons:
        by_type[comp.comparison_type].append((comp.estimate, comp.weight))

    labels = {
        "treated_vs_never": "Treated vs Never-treated",
        "earlier_vs_later": "Earlier vs Later treated",
        "later_vs_earlier": "Later vs Earlier (forbidden)",
    }

    for ctype, points in by_type.items():
        if not points:
            continue
        estimates = [p[0] for p in points]
        weights = [p[1] for p in points]
        # Use ``color=`` rather than ``c=``. With ``c=``, an RGB/RGBA tuple
        # whose length equals the number of points is treated as data values
        # to colormap, not as a single color.
        ax.scatter(
            estimates,
            weights,
            color=colors[ctype],
            label=labels[ctype],
            marker=marker,
            s=markersize,
            alpha=alpha,
            edgecolors="white",
            linewidths=0.5,
        )

    # Weighted-average effect per comparison type, drawn only for types that
    # actually have points. ``.get`` tolerates extra keys from effect_by_type().
    if show_weighted_avg:
        effect_by_type = results.effect_by_type()
        for ctype, avg_effect in effect_by_type.items():
            if avg_effect is not None and by_type.get(ctype):
                ax.axvline(
                    x=avg_effect,
                    color=colors[ctype],
                    linestyle="--",
                    alpha=0.5,
                    linewidth=1.5,
                )

    if show_twfe_line:
        ax.axvline(
            x=results.twfe_estimate,
            color="black",
            linestyle="-",
            linewidth=2,
            label=f"TWFE = {results.twfe_estimate:.4f}",
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Only build a legend when something is labelled. This avoids
    # matplotlib's "No artists with labels" warning and an empty legend box.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    # Zero reference line
    ax.axvline(x=0, color="gray", linestyle=":", alpha=0.5)
611
+
612
+
613
def _plot_bacon_bar(
    ax: Any,
    results: "BaconDecompositionResults",
    colors: Dict[str, str],
    alpha: float,
    ylabel: str,
    title: str,
) -> None:
    """
    Draw total TWFE weight per comparison type as bars on ``ax``.

    Bars follow a fixed order: treated_vs_never, earlier_vs_later,
    later_vs_earlier. Heights come from ``results.weight_by_type()``.

    Bars with weight above 1% get two annotations: the percentage above the
    bar, and the weighted-average effect from ``results.effect_by_type()``
    (white, centered in the bar) when that effect is not None.

    The y-axis is fixed to [0, 1.1], with a dashed line at 1.0. The TWFE
    estimate is shown in a box in the top-right corner. Modifies ``ax`` in
    place.
    """
    type_weights = results.weight_by_type()

    ordered_types = ("treated_vs_never", "earlier_vs_later", "later_vs_earlier")
    display_names = {
        "treated_vs_never": "Treated vs Never-treated",
        "earlier_vs_later": "Earlier vs Later",
        "later_vs_earlier": "Later vs Earlier\n(forbidden)",
    }

    heights = [type_weights[t] for t in ordered_types]
    bars = ax.bar(
        [display_names[t] for t in ordered_types],
        heights,
        color=[colors[t] for t in ordered_types],
        alpha=alpha,
        edgecolor="white",
        linewidth=1,
    )

    # Percentage labels just above each sufficiently large bar (> 1%).
    for rect, share in zip(bars, heights):
        if share > 0.01:
            center_x = rect.get_x() + rect.get_width() / 2
            ax.annotate(
                f"{share:.1%}",
                xy=(center_x, rect.get_height()),
                xytext=(0, 3),
                textcoords="offset points",
                ha="center",
                va="bottom",
                fontsize=10,
                fontweight="bold",
            )

    # Weighted-average 2x2 effect printed inside the bar.
    effects = results.effect_by_type()
    for rect, ctype in zip(bars, ordered_types):
        effect = effects[ctype]
        if effect is None or not type_weights[ctype] > 0.01:
            continue
        ax.annotate(
            f"β = {effect:.3f}",
            xy=(rect.get_x() + rect.get_width() / 2, rect.get_height() / 2),
            ha="center",
            va="center",
            fontsize=9,
            color="white",
            fontweight="bold",
        )

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.1)

    # Weights sum to at most 1; mark that ceiling.
    ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)

    # TWFE estimate in the top-right corner (axes coordinates).
    ax.text(
        0.98,
        0.98,
        f"TWFE = {results.twfe_estimate:.4f}",
        transform=ax.transAxes,
        ha="right",
        va="top",
        fontsize=10,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
696
+
697
+
698
def _render_bacon_plotly(
    *,
    results,
    plot_type,
    title,
    xlabel,
    ylabel,
    colors,
    marker,
    markersize,
    alpha,
    show_weighted_avg,
    show_twfe_line,
    show,
):
    """
    Render the Goodman-Bacon decomposition with plotly.

    Mirrors the matplotlib renderer.

    - ``plot_type == "scatter"``: one marker per 2x2 comparison (estimate on
      x, weight on y), plus optional dashed lines at each type's weighted
      average effect, an optional TWFE line, and a dotted zero line.
    - Otherwise (``"bar"``): total weight per comparison type. Bars above 1%
      carry a percentage label and a centered β annotation. A dashed line
      marks total weight 1.0, and the TWFE estimate is shown in the top-right
      corner, as in the matplotlib bar chart.

    Parameters
    ----------
    results : BaconDecompositionResults
        Decomposition results. Reads ``comparisons``, ``effect_by_type()``,
        ``weight_by_type()`` and ``twfe_estimate``.
    colors : dict
        Color per comparison type; all three keys must be present.
    marker, markersize :
        Matplotlib-style marker and area (points^2), converted to plotly.
    Remaining parameters match ``plot_bacon``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from diff_diff.visualization._common import (
        _mpl_marker_to_plotly_symbol,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()

    fig = go.Figure()

    if plot_type == "scatter":
        # Separate comparisons by type
        by_type = {
            "treated_vs_never": [],
            "earlier_vs_later": [],
            "later_vs_earlier": [],
        }
        for comp in results.comparisons:
            by_type[comp.comparison_type].append((comp.estimate, comp.weight))

        labels = {
            "treated_vs_never": "Treated vs Never-treated",
            "earlier_vs_later": "Earlier vs Later treated",
            "later_vs_earlier": "Later vs Earlier (forbidden)",
        }

        # Convert matplotlib scatter area (points^2) to plotly diameter (px)
        plotly_size = max(1, int(round(markersize**0.5)))
        symbol = _mpl_marker_to_plotly_symbol(marker)

        for ctype, points in by_type.items():
            if not points:
                continue
            estimates = [p[0] for p in points]
            weights = [p[1] for p in points]
            fig.add_trace(
                go.Scatter(
                    x=estimates,
                    y=weights,
                    mode="markers",
                    marker=dict(
                        color=colors[ctype],
                        size=plotly_size,
                        symbol=symbol,
                        opacity=alpha,
                    ),
                    name=labels[ctype],
                )
            )

        # Weighted average lines
        if show_weighted_avg:
            effect_by_type = results.effect_by_type()
            for ctype, avg_effect in effect_by_type.items():
                if avg_effect is not None and by_type.get(ctype):
                    fig.add_vline(
                        x=avg_effect,
                        line_dash="dash",
                        line_color=colors[ctype],
                        opacity=0.5,
                        line_width=1.5,
                    )

        # TWFE line
        if show_twfe_line:
            fig.add_vline(
                x=results.twfe_estimate,
                line_color="black",
                line_width=2,
                annotation_text=f"TWFE = {results.twfe_estimate:.4f}",
            )

        # Zero line
        fig.add_vline(x=0, line_dash="dot", line_color="gray", opacity=0.5)

        _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)

    else:  # bar
        weights = results.weight_by_type()
        effects = results.effect_by_type()
        type_order = ["treated_vs_never", "earlier_vs_later", "later_vs_earlier"]
        labels = {
            "treated_vs_never": "Treated vs Never-treated",
            "earlier_vs_later": "Earlier vs Later",
            "later_vs_earlier": "Later vs Earlier (forbidden)",
        }
        bar_labels = [labels[t] for t in type_order]
        bar_weights = [weights[t] for t in type_order]

        fig.add_trace(
            go.Bar(
                x=bar_labels,
                y=bar_weights,
                marker_color=[colors[t] for t in type_order],
                opacity=alpha,
                # Same rule as the matplotlib bar chart: only label bars > 1%.
                text=[f"{w:.1%}" if w > 0.01 else "" for w in bar_weights],
                textposition="outside",
            )
        )

        # Weighted-average 2x2 effect centered inside each visible bar,
        # matching the matplotlib chart.
        for label, ctype, weight in zip(bar_labels, type_order, bar_weights):
            effect = effects[ctype]
            if effect is not None and weight > 0.01:
                fig.add_annotation(
                    x=label,
                    y=weight / 2,
                    text=f"β = {effect:.3f}",
                    showarrow=False,
                    font=dict(color="white", size=9),
                )

        # Weights sum to at most 1; mark that ceiling.
        fig.add_hline(y=1.0, line_dash="dash", line_color="gray", opacity=0.5)

        # TWFE estimate in the top-right corner of the plotting area.
        fig.add_annotation(
            text=f"TWFE = {results.twfe_estimate:.4f}",
            xref="paper",
            yref="paper",
            x=0.98,
            y=0.98,
            xanchor="right",
            yanchor="top",
            showarrow=False,
            bgcolor="rgba(255,255,255,0.8)",
        )

        fig.update_layout(yaxis_range=[0, 1.1])
        _plotly_default_layout(fig, title=title, xlabel=None, ylabel=ylabel, show_legend=False)

    if show:
        fig.show()

    return fig