diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1627 @@
1
+ """
2
+ Visualization functions for difference-in-differences analysis.
3
+
4
+ Provides event study plots and other diagnostic visualizations.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ if TYPE_CHECKING:
13
+ from diff_diff.bacon import BaconDecompositionResults
14
+ from diff_diff.honest_did import HonestDiDResults, SensitivityResults
15
+ from diff_diff.power import PowerResults, SimulationPowerResults
16
+ from diff_diff.pretrends import PreTrendsPowerCurve, PreTrendsPowerResults
17
+ from diff_diff.results import MultiPeriodDiDResults
18
+ from diff_diff.staggered import CallawaySantAnnaResults
19
+ from diff_diff.sun_abraham import SunAbrahamResults
20
+
21
+ # Union of the inputs plot_event_study() accepts as `results`; the quoted names are forward references to classes imported only under TYPE_CHECKING
22
+ PlottableResults = Union[
23
+ "MultiPeriodDiDResults",
24
+ "CallawaySantAnnaResults",
25
+ "SunAbrahamResults",
26
+ pd.DataFrame,
27
+ ]
28
+
29
+
30
def plot_event_study(
    results: Optional["PlottableResults"] = None,
    *,
    effects: Optional[Dict[Any, float]] = None,
    se: Optional[Dict[Any, float]] = None,
    periods: Optional[List[Any]] = None,
    reference_period: Optional[Any] = None,
    pre_periods: Optional[List[Any]] = None,
    post_periods: Optional[List[Any]] = None,
    alpha: float = 0.05,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Event Study",
    xlabel: str = "Period Relative to Treatment",
    ylabel: str = "Treatment Effect",
    color: str = "#2563eb",
    marker: str = "o",
    markersize: int = 8,
    linewidth: float = 1.5,
    capsize: int = 4,
    show_zero_line: bool = True,
    show_reference_line: bool = True,
    shade_pre: bool = True,
    shade_color: str = "#f0f0f0",
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Draw a coefficient plot of per-period treatment effects.

    Each period is drawn as a point estimate with a normal-approximation
    confidence interval (``effect +/- z * se``). The plot is the usual way
    to inspect dynamic effects and to eyeball pre-trends.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults, SunAbrahamResults or DataFrame, optional
        Fitted results object, or a DataFrame with ``'period'``, ``'effect'``
        and ``'se'`` columns. When given, it takes precedence over
        ``effects`` / ``se``. When omitted, both ``effects`` and ``se`` are
        required.
    effects : dict, optional
        Mapping ``period -> point estimate``. Only consulted when
        ``results`` is None.
    se : dict, optional
        Mapping ``period -> standard error``. Only consulted when
        ``results`` is None.
    periods : list, optional
        Periods to draw, in plotting order. Defaults to whatever the results
        object supplies, or to the sorted keys of ``effects``.
    reference_period : any, optional
        Period drawn with a hollow marker (and, optionally, a dotted
        vertical line).
    pre_periods : list, optional
        Pre-treatment periods; their span is shaded when ``shade_pre`` is on.
    post_periods : list, optional
        Post-treatment periods. Forwarded to the data-extraction step; the
        drawing code in this function does not use it further.
    alpha : float, default=0.05
        Significance level; intervals are two-sided ``1 - alpha`` intervals.
    figsize : tuple, default=(10, 6)
        Size (width, height) in inches of the figure created when ``ax`` is
        None.
    title : str, default="Event Study"
        Axes title.
    xlabel : str, default="Period Relative to Treatment"
        Label of the x axis.
    ylabel : str, default="Treatment Effect"
        Label of the y axis.
    color : str, default="#2563eb"
        Colour of markers and error bars.
    marker : str, default="o"
        Marker style of the point estimates.
    markersize : int, default=8
        Marker size.
    linewidth : float, default=1.5
        Width of the error-bar lines and caps.
    capsize : int, default=4
        Length of the error-bar caps.
    show_zero_line : bool, default=True
        Draw a dashed horizontal line at ``y = 0``.
    show_reference_line : bool, default=True
        Draw a dotted vertical line at the reference period.
    shade_pre : bool, default=True
        Shade the span covered by the plotted pre-treatment periods.
    shade_color : str, default="#f0f0f0"
        Fill colour of the pre-treatment shading.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; a new figure is created when None.
    show : bool, default=True
        Call ``plt.show()`` before returning.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If neither ``results`` nor both of ``effects`` and ``se`` are given,
        or if no period has both a non-NaN effect and a non-NaN standard
        error.
    TypeError
        If ``effects`` or ``se`` is not a dict.

    Examples
    --------
    From fitted results:

    >>> from diff_diff import MultiPeriodDiD, plot_event_study
    >>> did = MultiPeriodDiD()
    >>> results = did.fit(data, outcome='y', treatment='treated',
    ...                   time='period', post_periods=[3, 4, 5])
    >>> plot_event_study(results)

    From a DataFrame:

    >>> df = pd.DataFrame({
    ...     'period': [-2, -1, 0, 1, 2],
    ...     'effect': [0.1, 0.05, 0.0, 0.5, 0.6],
    ...     'se': [0.1, 0.1, 0.0, 0.15, 0.15]
    ... })
    >>> plot_event_study(df, reference_period=0)

    From plain dictionaries:

    >>> effects = {-2: 0.1, -1: 0.05, 0: 0.0, 1: 0.5, 2: 0.6}
    >>> se = {-2: 0.1, -1: 0.1, 0: 0.0, 1: 0.15, 2: 0.15}
    >>> plot_event_study(effects=effects, se=se, reference_period=0)

    Notes
    -----
    Reading an event study plot:

    1. **Pre-treatment periods** should hover around zero when parallel
       trends is plausible; sizeable pre-treatment estimates are a warning
       sign.
    2. **Reference period** is the omitted category (commonly t=-1) and is
       normalised to zero.
    3. **Post-treatment periods** carry the effects of interest and show how
       the outcome evolves after treatment.

    Periods are placed at consecutive integer x positions in plotting order,
    so the x axis is categorical rather than to scale. Periods whose effect
    or standard error is missing or NaN are skipped silently.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    from scipy import stats as scipy_stats

    # Resolve the data source: a results object wins over manual dicts.
    if results is None:
        if effects is None or se is None:
            raise ValueError(
                "Must provide either 'results' or both 'effects' and 'se'"
            )
    else:
        extracted = _extract_plot_data(
            results, periods, pre_periods, post_periods, reference_period
        )
        (effects, se, periods,
         pre_periods, post_periods, reference_period) = extracted

    if not isinstance(effects, dict):
        raise TypeError("effects must be a dictionary mapping periods to values")
    if not isinstance(se, dict):
        raise TypeError("se must be a dictionary mapping periods to values")

    if periods is None:
        periods = sorted(effects.keys())

    # Two-sided normal critical value for the requested level.
    z_crit = scipy_stats.norm.ppf(1 - alpha / 2)

    rows = []
    for period in periods:
        estimate = effects.get(period, np.nan)
        std_err = se.get(period, np.nan)

        # A period needs both numbers to be drawable.
        if np.isnan(estimate) or np.isnan(std_err):
            continue

        rows.append({
            'period': period,
            'effect': estimate,
            'se': std_err,
            'ci_lower': estimate - z_crit * std_err,
            'ci_upper': estimate + z_crit * std_err,
            'is_reference': period == reference_period,
        })

    if not rows:
        raise ValueError("No valid data to plot")

    df = pd.DataFrame(rows)

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # Categorical x axis: each plotted period gets its row position.
    period_to_x = {p: pos for pos, p in enumerate(df['period'])}
    x_vals = [period_to_x[p] for p in df['period']]

    if shade_pre and pre_periods is not None:
        shaded = [period_to_x[p] for p in pre_periods if p in period_to_x]
        if shaded:
            # Pad by half a slot so the band encloses the outer markers.
            ax.axvspan(min(shaded) - 0.5, max(shaded) + 0.5,
                       color=shade_color, alpha=0.5, zorder=0)

    if show_zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, zorder=1)

    if (show_reference_line and reference_period is not None
            and reference_period in period_to_x):
        ax.axvline(x=period_to_x[reference_period], color='gray',
                   linestyle=':', linewidth=1, zorder=1)

    # Error bars only (fmt='none'); markers are drawn separately below so
    # the reference period can be styled differently.
    below = df['effect'] - df['ci_lower']
    above = df['ci_upper'] - df['effect']
    ax.errorbar(
        x_vals, df['effect'], yerr=[below, above],
        fmt='none', color=color, capsize=capsize, linewidth=linewidth,
        capthick=linewidth, zorder=2
    )

    shared_style = dict(marker=marker, markersize=markersize, zorder=3)
    for x, estimate, hollow in zip(x_vals, df['effect'], df['is_reference']):
        if hollow:
            # Reference period: white face with a coloured outline.
            ax.plot(x, estimate, markerfacecolor='white',
                    markeredgecolor=color, markeredgewidth=2, **shared_style)
        else:
            ax.plot(x, estimate, color=color, **shared_style)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    ax.set_xticks(x_vals)
    ax.set_xticklabels([str(p) for p in df['period']])

    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
286
+
287
+
288
+ def _extract_plot_data(
289
+ results: PlottableResults,
290
+ periods: Optional[List[Any]],
291
+ pre_periods: Optional[List[Any]],
292
+ post_periods: Optional[List[Any]],
293
+ reference_period: Optional[Any],
294
+ ) -> Tuple[Dict, Dict, List, List, List, Any]:
295
+ """
296
+ Extract plotting data from various result types.
297
+
298
+ Returns
299
+ -------
300
+ tuple
301
+ (effects, se, periods, pre_periods, post_periods, reference_period)
302
+ """
303
+ # Handle DataFrame input
304
+ if isinstance(results, pd.DataFrame):
305
+ if 'period' not in results.columns:
306
+ raise ValueError("DataFrame must have 'period' column")
307
+ if 'effect' not in results.columns:
308
+ raise ValueError("DataFrame must have 'effect' column")
309
+ if 'se' not in results.columns:
310
+ raise ValueError("DataFrame must have 'se' column")
311
+
312
+ effects = dict(zip(results['period'], results['effect']))
313
+ se = dict(zip(results['period'], results['se']))
314
+
315
+ if periods is None:
316
+ periods = list(results['period'])
317
+
318
+ return effects, se, periods, pre_periods, post_periods, reference_period
319
+
320
+ # Handle MultiPeriodDiDResults
321
+ if hasattr(results, 'period_effects'):
322
+ effects = {}
323
+ se = {}
324
+
325
+ for period, pe in results.period_effects.items():
326
+ effects[period] = pe.effect
327
+ se[period] = pe.se
328
+
329
+ if pre_periods is None and hasattr(results, 'pre_periods'):
330
+ pre_periods = results.pre_periods
331
+
332
+ if post_periods is None and hasattr(results, 'post_periods'):
333
+ post_periods = results.post_periods
334
+
335
+ if periods is None:
336
+ periods = post_periods
337
+
338
+ return effects, se, periods, pre_periods, post_periods, reference_period
339
+
340
+ # Handle CallawaySantAnnaResults (event study aggregation)
341
+ if hasattr(results, 'event_study_effects') and results.event_study_effects is not None:
342
+ effects = {}
343
+ se = {}
344
+
345
+ for rel_period, effect_data in results.event_study_effects.items():
346
+ effects[rel_period] = effect_data['effect']
347
+ se[rel_period] = effect_data['se']
348
+
349
+ if periods is None:
350
+ periods = sorted(effects.keys())
351
+
352
+ # Reference period is typically -1 for event study
353
+ if reference_period is None:
354
+ reference_period = -1
355
+
356
+ if pre_periods is None:
357
+ pre_periods = [p for p in periods if p < 0]
358
+
359
+ if post_periods is None:
360
+ post_periods = [p for p in periods if p >= 0]
361
+
362
+ return effects, se, periods, pre_periods, post_periods, reference_period
363
+
364
+ raise TypeError(
365
+ f"Cannot extract plot data from {type(results).__name__}. "
366
+ "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, "
367
+ "SunAbrahamResults, or DataFrame."
368
+ )
369
+
370
+
371
def plot_group_effects(
    results: "CallawaySantAnnaResults",
    *,
    groups: Optional[List[Any]] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Effects by Cohort",
    xlabel: str = "Time Period",
    ylabel: str = "Treatment Effect",
    alpha: float = 0.05,
    show: bool = True,
    ax: Optional[Any] = None,
) -> Any:
    """
    Plot group-time treatment effects, one series per treatment cohort.

    Each cohort is drawn as a line of point estimates over time with
    symmetric normal-approximation error bars (``z * se``).

    Parameters
    ----------
    results : CallawaySantAnnaResults
        Any object exposing ``group_time_effects``: a mapping from
        ``(group, time)`` to a dict with ``'effect'`` and ``'se'`` keys.
    groups : list, optional
        Cohorts to plot. Defaults to every cohort found in ``results``,
        sorted. Requested cohorts without any estimate are skipped.
    figsize : tuple, default=(10, 6)
        Size of the figure created when ``ax`` is None.
    title : str
        Axes title.
    xlabel : str
        Label of the x axis.
    ylabel : str
        Label of the y axis.
    alpha : float, default=0.05
        Significance level of the two-sided confidence intervals.
    show : bool, default=True
        Call ``plt.show()`` before returning.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``results`` has no ``group_time_effects`` attribute.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    from scipy import stats as scipy_stats

    if not hasattr(results, 'group_time_effects'):
        raise TypeError("results must be a CallawaySantAnnaResults object")

    gt_effects = results.group_time_effects

    if groups is None:
        groups = sorted({g for g, _ in gt_effects})

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # One colour per requested cohort, spread across the tab10 palette.
    colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))

    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    for cohort_color, group in zip(colors, groups):
        series = sorted(
            ((t, data) for (g, t), data in gt_effects.items() if g == group),
            key=lambda item: item[0],
        )
        if not series:
            continue

        times = [t for t, _ in series]
        effects = [data['effect'] for _, data in series]
        # The interval is symmetric, so the bar half-width is simply z * se.
        half_widths = [critical_value * data['se'] for _, data in series]

        ax.errorbar(
            times, effects, yerr=half_widths,
            label=f'Cohort {group}', color=cohort_color,
            marker='o', capsize=3, linewidth=1.5
        )

    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Only build a legend when there is something to list; an empty legend
    # makes matplotlib warn and adds an empty box to the axes.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='best')

    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
479
+
480
+
481
def plot_sensitivity(
    sensitivity_results: "SensitivityResults",
    *,
    show_bounds: bool = True,
    show_ci: bool = True,
    breakdown_line: bool = True,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Honest DiD Sensitivity Analysis",
    xlabel: str = "M (restriction parameter)",
    ylabel: str = "Treatment Effect",
    bounds_color: str = "#2563eb",
    bounds_alpha: float = 0.3,
    ci_color: str = "#2563eb",
    ci_linewidth: float = 1.5,
    breakdown_color: str = "#dc2626",
    original_color: str = "#1f2937",
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Plot an Honest DiD sensitivity analysis against the restriction M.

    The x axis is the grid of M values; the y axis shows the identified-set
    bounds (shaded band), the robust confidence interval (two lines), the
    original point estimate (horizontal line) and, when available, the
    breakdown value of M (vertical line).

    Parameters
    ----------
    sensitivity_results : SensitivityResults
        Output of ``HonestDiD.sensitivity_analysis()``. The attributes read
        here are ``M_values``, ``bounds``, ``robust_cis``,
        ``original_estimate`` and ``breakdown_M``; ``bounds`` and
        ``robust_cis`` are indexed as (lower, upper) pairs, one per M value.
    show_bounds : bool, default=True
        Shade the identified set.
    show_ci : bool, default=True
        Draw the lower and upper robust CI lines.
    breakdown_line : bool, default=True
        Draw a vertical line at ``breakdown_M`` when it is not None.
    figsize : tuple, default=(10, 6)
        Size (width, height) in inches of the figure created when ``ax`` is
        None.
    title : str
        Axes title.
    xlabel : str
        Label of the x axis.
    ylabel : str
        Label of the y axis.
    bounds_color : str
        Fill colour of the identified set.
    bounds_alpha : float
        Opacity of the identified-set fill.
    ci_color : str
        Colour of the CI lines.
    ci_linewidth : float
        Width of the CI lines.
    breakdown_color : str
        Colour of the breakdown line.
    original_color : str
        Colour of the original-estimate line.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; a new figure is created when None.
    show : bool, default=True
        Call ``plt.show()`` before returning.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.honest_did import HonestDiD
    >>> from diff_diff.visualization import plot_sensitivity
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> honest = HonestDiD(method='relative_magnitude')
    >>> sensitivity = honest.sensitivity_analysis(results)
    >>> plot_sensitivity(sensitivity)
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    m_grid = sensitivity_results.M_values
    # Converted up front (even if a layer is switched off) so both can be
    # sliced column-wise below.
    bounds_arr = np.array(sensitivity_results.bounds)
    ci_arr = np.array(sensitivity_results.robust_cis)

    # Horizontal guides: the original estimate, then zero.
    ax.axhline(
        y=sensitivity_results.original_estimate,
        color=original_color,
        linestyle='-',
        linewidth=1.5,
        label='Original estimate',
        alpha=0.7
    )
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    if show_bounds:
        lower_bound = bounds_arr[:, 0]
        upper_bound = bounds_arr[:, 1]
        ax.fill_between(
            m_grid, lower_bound, upper_bound,
            alpha=bounds_alpha,
            color=bounds_color,
            label='Identified set'
        )

    if show_ci:
        ci_style = dict(color=ci_color, linewidth=ci_linewidth)
        # Only the lower line is labelled so the legend shows one entry.
        ax.plot(m_grid, ci_arr[:, 0], label='Robust CI', **ci_style)
        ax.plot(m_grid, ci_arr[:, 1], **ci_style)

    if breakdown_line:
        breakdown = sensitivity_results.breakdown_M
        if breakdown is not None:
            ax.axvline(
                x=breakdown,
                color=breakdown_color,
                linestyle=':',
                linewidth=2,
                label=f'Breakdown (M={breakdown:.2f})'
            )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
636
+
637
+
638
def plot_honest_event_study(
    honest_results: "HonestDiDResults",
    *,
    periods: Optional[List[Any]] = None,
    reference_period: Optional[Any] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Event Study with Honest Confidence Intervals",
    xlabel: str = "Period Relative to Treatment",
    ylabel: str = "Treatment Effect",
    original_color: str = "#6b7280",
    honest_color: str = "#2563eb",
    marker: str = "o",
    markersize: int = 8,
    capsize: int = 4,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create an event study plot with Honest DiD confidence intervals.

    Two sets of error bars are overlaid on the point estimates: the
    standard normal-approximation interval (``effect +/- z * se``, thin,
    in the background) and the honest/robust interval (thick, in the
    foreground).

    Parameters
    ----------
    honest_results : HonestDiDResults
        Results from ``HonestDiD.fit()``. The attributes read here are
        ``original_results``, ``alpha``, ``M``, ``event_study_bounds`` and,
        when the latter is empty, ``ci_lb`` / ``ci_ub``.
    periods : list, optional
        Periods to plot, in plotting order. Defaults to all periods of the
        original results (sorted for event-study style results).
    reference_period : any, optional
        Period drawn with a hollow marker.
    figsize : tuple, default=(10, 6)
        Size of the figure created when ``ax`` is None.
    title : str
        Axes title.
    xlabel : str
        Label of the x axis.
    ylabel : str
        Label of the y axis.
    original_color : str
        Colour of the standard confidence intervals.
    honest_color : str
        Colour of the honest confidence intervals and of the markers.
    marker : str
        Marker style.
    markersize : int
        Marker size.
    capsize : int
        Cap size of the honest error bars; the standard ones use
        ``capsize - 1``.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on.
    show : bool, default=True
        Call ``plt.show()`` before returning.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``honest_results.original_results`` is None.
    TypeError
        If the original results expose neither ``period_effects`` nor a
        non-None ``event_study_effects``.

    Notes
    -----
    Per-period honest intervals are taken from
    ``honest_results.event_study_bounds`` (keys ``'ci_lb'`` / ``'ci_ub'``).
    When that attribute is empty, the scalar ``ci_lb`` / ``ci_ub`` are
    applied to every period; if only a scalar bound was computed,
    ``plot_sensitivity()`` is usually the better view.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    from scipy import stats as scipy_stats

    original_results = honest_results.original_results
    if original_results is None:
        raise ValueError(
            "HonestDiDResults must have original_results to plot event study"
        )

    # Pull per-period estimates out of whichever results flavour we got.
    if hasattr(original_results, 'period_effects'):
        # MultiPeriodDiD-style results: values expose .effect / .se
        period_effects = original_results.period_effects
        effects_dict = {p: pe.effect for p, pe in period_effects.items()}
        se_dict = {p: pe.se for p, pe in period_effects.items()}
        if periods is None:
            periods = list(period_effects.keys())
    elif getattr(original_results, 'event_study_effects', None) is not None:
        # Event-study aggregation: values are dicts with 'effect' / 'se'.
        # The attribute may exist but be None when no event-study
        # aggregation is stored; that case must reach the TypeError below
        # rather than crash on `.items()`.
        event_effects = original_results.event_study_effects
        effects_dict = {t: data['effect'] for t, data in event_effects.items()}
        se_dict = {t: data['se'] for t, data in event_effects.items()}
        if periods is None:
            periods = sorted(event_effects.keys())
    else:
        raise TypeError("Cannot extract event study data from original_results")

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Standard intervals use the same level as the honest ones.
    alpha = honest_results.alpha
    z = scipy_stats.norm.ppf(1 - alpha / 2)

    # Categorical x axis: one slot per plotted period.
    x_vals = list(range(len(periods)))

    effects = [effects_dict[p] for p in periods]
    original_ci_lower = [effects_dict[p] - z * se_dict[p] for p in periods]
    original_ci_upper = [effects_dict[p] + z * se_dict[p] for p in periods]

    if honest_results.event_study_bounds:
        bounds = honest_results.event_study_bounds
        honest_ci_lower = [bounds[p]['ci_lb'] for p in periods]
        honest_ci_upper = [bounds[p]['ci_ub'] for p in periods]
    else:
        # Only a scalar bound is available: reuse it for every period.
        # NOTE(review): if a period's estimate lies outside [ci_lb, ci_ub]
        # the bar length below becomes negative, which matplotlib may
        # reject - confirm whether that situation can arise in practice.
        honest_ci_lower = [honest_results.ci_lb] * len(periods)
        honest_ci_upper = [honest_results.ci_ub] * len(periods)

    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    # Standard CIs: thinner and semi-transparent, drawn first.
    yerr_orig = [
        [e - lower for e, lower in zip(effects, original_ci_lower)],
        [upper - e for e, upper in zip(effects, original_ci_upper)],
    ]
    ax.errorbar(
        x_vals, effects, yerr=yerr_orig,
        fmt='none', color=original_color, capsize=capsize - 1,
        linewidth=1, alpha=0.6, label='Standard CI'
    )

    # Honest CIs: thicker, drawn on top.
    yerr_honest = [
        [e - lower for e, lower in zip(effects, honest_ci_lower)],
        [upper - e for e, upper in zip(effects, honest_ci_upper)],
    ]
    ax.errorbar(
        x_vals, effects, yerr=yerr_honest,
        fmt='none', color=honest_color, capsize=capsize,
        linewidth=2, label=f'Honest CI (M={honest_results.M:.2f})'
    )

    # Point estimates; the reference period gets a hollow marker.
    for x, effect, period in zip(x_vals, effects, periods):
        if period == reference_period:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    markerfacecolor='white', markeredgecolor=honest_color,
                    markeredgewidth=2, zorder=3)
        else:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    color=honest_color, zorder=3)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x_vals)
    ax.set_xticklabels([str(p) for p in periods])
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
827
+
828
+
829
def plot_bacon(
    results: "BaconDecompositionResults",
    *,
    plot_type: str = "scatter",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: str = "2x2 DiD Estimate",
    ylabel: str = "Weight",
    colors: Optional[Dict[str, str]] = None,
    marker: str = "o",
    markersize: int = 80,
    alpha: float = 0.7,
    show_weighted_avg: bool = True,
    show_twfe_line: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Visualize Goodman-Bacon decomposition results.

    Draws either a scatter plot of every 2x2 comparison (estimate on the
    x axis, weight on the y axis) or a bar chart of the total weight per
    comparison type. The actual drawing is delegated to
    ``_plot_bacon_scatter`` / ``_plot_bacon_bar``.

    Parameters
    ----------
    results : BaconDecompositionResults
        Results from ``BaconDecomposition.fit()`` or ``bacon_decompose()``.
    plot_type : str, default="scatter"
        ``"scatter"`` or ``"bar"``. Any other value raises ``ValueError``
        before a figure is created.
    figsize : tuple, default=(10, 6)
        Size (width, height) in inches of the figure created when ``ax`` is
        None.
    title : str, optional
        Plot title. When None, "Goodman-Bacon Decomposition" is used for
        the scatter plot and "TWFE Weight by Comparison Type" for the bar
        chart.
    xlabel : str, default="2x2 DiD Estimate"
        X-axis label (scatter plot only).
    ylabel : str, default="Weight"
        Y-axis label.
    colors : dict, optional
        Mapping from comparison type (``"treated_vs_never"``,
        ``"earlier_vs_later"``, ``"later_vs_earlier"``) to colour. Missing
        keys fall back to the defaults (green, blue and red respectively),
        so a partial mapping is accepted.
    marker : str, default="o"
        Marker style for the scatter plot.
    markersize : int, default=80
        Marker size for the scatter plot.
    alpha : float, default=0.7
        Opacity of markers/bars.
    show_weighted_avg : bool, default=True
        Show the weighted-average line of each comparison type (scatter
        plot only).
    show_twfe_line : bool, default=True
        Show a vertical line at the TWFE estimate (scatter plot only).
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; a new figure is created when None.
    show : bool, default=True
        Call ``plt.show()`` before returning.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``plot_type`` is neither ``"scatter"`` nor ``"bar"``.

    Examples
    --------
    Scatter plot (default):

    >>> from diff_diff import bacon_decompose, plot_bacon
    >>> results = bacon_decompose(data, outcome='y', unit='id',
    ...                           time='t', first_treat='first_treat')
    >>> plot_bacon(results)

    Bar chart of weights by type:

    >>> plot_bacon(results, plot_type='bar')

    Customized scatter plot:

    >>> plot_bacon(results,
    ...            colors={'treated_vs_never': 'green',
    ...                    'earlier_vs_later': 'blue',
    ...                    'later_vs_earlier': 'red'},
    ...            title='My Bacon Decomposition')

    Notes
    -----
    What to look for in the scatter plot:

    1. **Spread of estimates**: widely dispersed 2x2 estimates point to
       heterogeneous treatment effects.
    2. **Weight concentration**: high-weight points dominate the TWFE
       estimate.
    3. **Forbidden comparisons**: ``later_vs_earlier`` points (red by
       default) use already-treated units as controls. If their estimates
       differ from the clean comparisons, TWFE may be biased.

    The bar chart summarises how much weight each comparison type carries,
    which helps gauge how severe that bias could be.

    See Also
    --------
    bacon_decompose : Perform the decomposition
    BaconDecomposition : Class-based interface
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    # Validate before touching pyplot state so that a bad plot_type does
    # not leave an empty figure registered with pyplot.
    if plot_type not in ("scatter", "bar"):
        raise ValueError(f"Unknown plot_type: {plot_type}. Use 'scatter' or 'bar'.")

    default_colors = {
        "treated_vs_never": "#22c55e",  # Green - clean comparison
        "earlier_vs_later": "#3b82f6",  # Blue - valid comparison
        "later_vs_earlier": "#ef4444",  # Red - forbidden comparison
    }
    # User entries win; defaults fill in any comparison type left out.
    colors = default_colors if colors is None else {**default_colors, **colors}

    if title is None:
        if plot_type == "scatter":
            title = "Goodman-Bacon Decomposition"
        else:
            title = "TWFE Weight by Comparison Type"

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    if plot_type == "scatter":
        _plot_bacon_scatter(
            ax, results, colors, marker, markersize, alpha,
            show_weighted_avg, show_twfe_line, xlabel, ylabel, title
        )
    else:
        _plot_bacon_bar(ax, results, colors, alpha, ylabel, title)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
983
+
984
+
985
+ def _plot_bacon_scatter(
986
+ ax: Any,
987
+ results: "BaconDecompositionResults",
988
+ colors: Dict[str, str],
989
+ marker: str,
990
+ markersize: int,
991
+ alpha: float,
992
+ show_weighted_avg: bool,
993
+ show_twfe_line: bool,
994
+ xlabel: str,
995
+ ylabel: str,
996
+ title: str,
997
+ ) -> None:
998
+ """Create scatter plot of Bacon decomposition."""
999
+ # Separate comparisons by type
1000
+ by_type: Dict[str, List[Tuple[float, float]]] = {
1001
+ "treated_vs_never": [],
1002
+ "earlier_vs_later": [],
1003
+ "later_vs_earlier": [],
1004
+ }
1005
+
1006
+ for comp in results.comparisons:
1007
+ by_type[comp.comparison_type].append((comp.estimate, comp.weight))
1008
+
1009
+ # Plot each type
1010
+ labels = {
1011
+ "treated_vs_never": "Treated vs Never-treated",
1012
+ "earlier_vs_later": "Earlier vs Later treated",
1013
+ "later_vs_earlier": "Later vs Earlier (forbidden)",
1014
+ }
1015
+
1016
+ for ctype, points in by_type.items():
1017
+ if not points:
1018
+ continue
1019
+ estimates = [p[0] for p in points]
1020
+ weights = [p[1] for p in points]
1021
+ ax.scatter(
1022
+ estimates, weights,
1023
+ c=colors[ctype],
1024
+ label=labels[ctype],
1025
+ marker=marker,
1026
+ s=markersize,
1027
+ alpha=alpha,
1028
+ edgecolors='white',
1029
+ linewidths=0.5,
1030
+ )
1031
+
1032
+ # Show weighted average lines
1033
+ if show_weighted_avg:
1034
+ effect_by_type = results.effect_by_type()
1035
+ for ctype, avg_effect in effect_by_type.items():
1036
+ if avg_effect is not None and by_type[ctype]:
1037
+ ax.axvline(
1038
+ x=avg_effect,
1039
+ color=colors[ctype],
1040
+ linestyle='--',
1041
+ alpha=0.5,
1042
+ linewidth=1.5,
1043
+ )
1044
+
1045
+ # Show TWFE estimate line
1046
+ if show_twfe_line:
1047
+ ax.axvline(
1048
+ x=results.twfe_estimate,
1049
+ color='black',
1050
+ linestyle='-',
1051
+ linewidth=2,
1052
+ label=f'TWFE = {results.twfe_estimate:.4f}',
1053
+ )
1054
+
1055
+ ax.set_xlabel(xlabel)
1056
+ ax.set_ylabel(ylabel)
1057
+ ax.set_title(title)
1058
+ ax.legend(loc='best')
1059
+ ax.grid(True, alpha=0.3)
1060
+
1061
+ # Add zero line
1062
+ ax.axvline(x=0, color='gray', linestyle=':', alpha=0.5)
1063
+
1064
+
1065
+ def _plot_bacon_bar(
1066
+ ax: Any,
1067
+ results: "BaconDecompositionResults",
1068
+ colors: Dict[str, str],
1069
+ alpha: float,
1070
+ ylabel: str,
1071
+ title: str,
1072
+ ) -> None:
1073
+ """Create stacked bar chart of weights by comparison type."""
1074
+ # Get weights
1075
+ weights = results.weight_by_type()
1076
+
1077
+ # Labels and colors
1078
+ type_order = ["treated_vs_never", "earlier_vs_later", "later_vs_earlier"]
1079
+ labels = {
1080
+ "treated_vs_never": "Treated vs Never-treated",
1081
+ "earlier_vs_later": "Earlier vs Later",
1082
+ "later_vs_earlier": "Later vs Earlier\n(forbidden)",
1083
+ }
1084
+
1085
+ # Create bar data
1086
+ bar_labels = [labels[t] for t in type_order]
1087
+ bar_weights = [weights[t] for t in type_order]
1088
+ bar_colors = [colors[t] for t in type_order]
1089
+
1090
+ # Create bars
1091
+ bars = ax.bar(
1092
+ bar_labels,
1093
+ bar_weights,
1094
+ color=bar_colors,
1095
+ alpha=alpha,
1096
+ edgecolor='white',
1097
+ linewidth=1,
1098
+ )
1099
+
1100
+ # Add percentage labels on bars
1101
+ for bar, weight in zip(bars, bar_weights):
1102
+ if weight > 0.01: # Only label if > 1%
1103
+ height = bar.get_height()
1104
+ ax.annotate(
1105
+ f'{weight:.1%}',
1106
+ xy=(bar.get_x() + bar.get_width() / 2, height),
1107
+ xytext=(0, 3),
1108
+ textcoords="offset points",
1109
+ ha='center',
1110
+ va='bottom',
1111
+ fontsize=10,
1112
+ fontweight='bold',
1113
+ )
1114
+
1115
+ # Add weighted average effect annotations
1116
+ effects = results.effect_by_type()
1117
+ for bar, ctype in zip(bars, type_order):
1118
+ effect = effects[ctype]
1119
+ if effect is not None and weights[ctype] > 0.01:
1120
+ ax.annotate(
1121
+ f'β = {effect:.3f}',
1122
+ xy=(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2),
1123
+ ha='center',
1124
+ va='center',
1125
+ fontsize=9,
1126
+ color='white',
1127
+ fontweight='bold',
1128
+ )
1129
+
1130
+ ax.set_ylabel(ylabel)
1131
+ ax.set_title(title)
1132
+ ax.set_ylim(0, 1.1)
1133
+
1134
+ # Add horizontal line at total weight = 1
1135
+ ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
1136
+
1137
+ # Add TWFE estimate as text
1138
+ ax.text(
1139
+ 0.98, 0.98,
1140
+ f'TWFE = {results.twfe_estimate:.4f}',
1141
+ transform=ax.transAxes,
1142
+ ha='right',
1143
+ va='top',
1144
+ fontsize=10,
1145
+ bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
1146
+ )
1147
+
1148
+
1149
def plot_power_curve(
    results: Optional[Union["PowerResults", "SimulationPowerResults", pd.DataFrame]] = None,
    *,
    effect_sizes: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mde: Optional[float] = None,
    target_power: float = 0.80,
    plot_type: str = "effect",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Power",
    color: str = "#2563eb",
    mde_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mde_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create a power curve visualization.

    Shows how statistical power changes with effect size or sample size,
    helping researchers understand the trade-offs in study design.

    Parameters
    ----------
    results : PowerResults, SimulationPowerResults, or DataFrame, optional
        Results object from PowerAnalysis or simulate_power(), or a DataFrame
        with columns 'effect_size' and 'power' (or 'sample_size' and 'power').
        If None, must provide effect_sizes and powers directly.
        For a DataFrame, ``plot_type`` is overridden by whichever of the two
        x-columns is present ('effect_size' takes precedence).
        For an object exposing ``effect_sizes`` and ``powers``, its
        ``true_effect`` attribute (if any) is used as ``mde`` when ``mde``
        is not given.
    effect_sizes : list of float, optional
        Effect sizes (x-axis values). Required if results is None.
        Any 1-D sequence (list, tuple, NumPy array) is accepted; the values
        do not need to be sorted for the MDE marker to be placed correctly.
    powers : list of float, optional
        Power values (y-axis values). Required if results is None.
    mde : float, optional
        Minimum detectable effect to mark on the plot.
    target_power : float, default=0.80
        Target power level to show as horizontal line.
    plot_type : str, default="effect"
        Type of power curve: "effect" (power vs effect size) or
        "sample" (power vs sample size).
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, optional
        Plot title. If None, uses a sensible default.
    xlabel : str, optional
        X-axis label. If None, uses a sensible default.
    ylabel : str, default="Power"
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mde_color : str, default="#dc2626"
        Color for the MDE vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mde_line : bool, default=True
        Whether to show vertical line at MDE.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If a DataFrame has neither an 'effect_size' nor a 'sample_size'
        column, if a PowerResults-like object (one with only ``mde``) is
        passed as ``results``, or if neither ``results`` nor both of
        ``effect_sizes`` and ``powers`` are supplied.
    TypeError
        If ``results`` is an object no curve data can be extracted from.

    Examples
    --------
    From PowerAnalysis results:

    >>> from diff_diff import PowerAnalysis, plot_power_curve
    >>> pa = PowerAnalysis(power=0.80)
    >>> curve_df = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
    >>> mde_result = pa.mde(n_treated=50, n_control=50, sigma=5.0)
    >>> plot_power_curve(curve_df, mde=mde_result.mde)

    From simulation results:

    >>> from diff_diff import simulate_power, DifferenceInDifferences
    >>> results = simulate_power(
    ...     DifferenceInDifferences(),
    ...     effect_sizes=[1, 2, 3, 5, 7, 10],
    ...     n_simulations=200
    ... )
    >>> plot_power_curve(results)

    Manual data:

    >>> plot_power_curve(
    ...     effect_sizes=[1, 2, 3, 4, 5],
    ...     powers=[0.2, 0.5, 0.75, 0.90, 0.97],
    ...     mde=2.5,
    ...     target_power=0.80
    ... )
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            # DataFrame input: the x-column that is present decides the plot type
            if "effect_size" in results.columns:
                effect_sizes = results["effect_size"].tolist()
                plot_type = "effect"
            elif "sample_size" in results.columns:
                effect_sizes = results["sample_size"].tolist()
                plot_type = "sample"
            else:
                raise ValueError(
                    "DataFrame must have 'effect_size' or 'sample_size' column"
                )
            powers = results["power"].tolist()
        elif hasattr(results, "effect_sizes") and hasattr(results, "powers"):
            # SimulationPowerResults
            effect_sizes = results.effect_sizes
            powers = results.powers
            if mde is None and hasattr(results, "true_effect"):
                # Mark true effect on plot
                mde = results.true_effect
        elif hasattr(results, "mde"):
            # PowerResults holds a single point, not a curve
            raise ValueError(
                "PowerResults should be used to get mde value, not as direct input. "
                "Use PowerAnalysis.power_curve() to generate curve data."
            )
        else:
            raise TypeError(
                f"Cannot extract power curve data from {type(results).__name__}"
            )
    elif effect_sizes is None or powers is None:
        raise ValueError(
            "Must provide either 'results' or both 'effect_sizes' and 'powers'"
        )

    # Default titles and labels
    if title is None:
        title = "Power Curve" if plot_type == "effect" else "Power vs Sample Size"

    if xlabel is None:
        xlabel = "Effect Size" if plot_type == "effect" else "Sample Size"

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve
    ax.plot(
        effect_sizes, powers,
        color=color,
        linewidth=linewidth,
        label="Power"
    )

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})"
        )

    # Add MDE line
    if show_mde_line and mde is not None:
        ax.axvline(
            x=mde,
            color=mde_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDE = {mde:.3f}"
        )

        # Mark the point where the MDE line meets the curve.  Working on
        # arrays (instead of list.index) supports NumPy inputs, and sorting
        # by x is required because np.interp silently returns wrong values
        # for non-increasing sample points.
        x_arr = np.asarray(effect_sizes, dtype=float)
        y_arr = np.asarray(powers, dtype=float)
        power_at_mde = None
        if (
            x_arr.size > 0
            and x_arr.size == y_arr.size
            and x_arr.min() <= mde <= x_arr.max()
        ):
            order = np.argsort(x_arr, kind="stable")
            power_at_mde = float(np.interp(mde, x_arr[order], y_arr[order]))

        # No marker when the MDE lies outside the plotted range (or no data).
        if power_at_mde is not None:
            ax.scatter(
                [mde], [power_at_mde],
                color=mde_color,
                s=50,
                zorder=5
            )

    # Configure axes
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Y-axis from 0 to 1
    ax.set_ylim(0, 1.05)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")

    fig.tight_layout()

    if show:
        plt.show()

    return ax
1393
+
1394
+
1395
def plot_pretrends_power(
    results: Optional[Union["PreTrendsPowerResults", "PreTrendsPowerCurve", pd.DataFrame]] = None,
    *,
    M_values: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mdv: Optional[float] = None,
    target_power: float = 0.80,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Pre-Trends Test Power Curve",
    xlabel: str = "Violation Magnitude (M)",
    ylabel: str = "Power",
    color: str = "#2563eb",
    mdv_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mdv_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Plot pre-trends test power curve.

    Visualizes how the power to detect parallel trends violations changes
    with the violation magnitude (M). This helps understand what violations
    your pre-trends test is capable of detecting.

    Parameters
    ----------
    results : PreTrendsPowerResults, PreTrendsPowerCurve, or DataFrame, optional
        Results from PreTrendsPower.fit() or power_curve(), or a DataFrame
        with columns 'M' and 'power'. If None, must provide M_values and powers.
        For a curve object (one exposing ``M_values`` and ``powers``), its
        ``mdv`` is used when ``mdv`` is not given.  For a single result
        (one exposing ``mdv`` and ``power``) no curve is drawn: only the
        target-power line and the MDV line are shown.
    M_values : list of float, optional
        Violation magnitudes (x-axis). Required if results is None.
        The values do not need to be sorted for the MDV marker to be
        placed correctly.
    powers : list of float, optional
        Power values (y-axis). Required if results is None.
    mdv : float, optional
        Minimum detectable violation to mark on the plot.  Ignored when
        None or not finite.
    target_power : float, default=0.80
        Target power level to show as horizontal line.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mdv_color : str, default="#dc2626"
        Color for the MDV vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mdv_line : bool, default=True
        Whether to show vertical line at MDV.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If a DataFrame lacks the 'M' or 'power' column, or if neither
        ``results`` nor both of ``M_values`` and ``powers`` are supplied.
    TypeError
        If ``results`` is an object no curve data can be extracted from.

    Examples
    --------
    From PreTrendsPower results:

    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import PreTrendsPower
    >>> from diff_diff.visualization import plot_pretrends_power
    >>>
    >>> mp_did = MultiPeriodDiD()
    >>> event_results = mp_did.fit(data, outcome='y', treatment='treated',
    ...                            time='period', post_periods=[4, 5, 6, 7])
    >>>
    >>> pt = PreTrendsPower()
    >>> curve = pt.power_curve(event_results)
    >>> plot_pretrends_power(curve)

    From manual data:

    >>> plot_pretrends_power(
    ...     M_values=[0, 0.5, 1, 1.5, 2],
    ...     powers=[0.05, 0.3, 0.6, 0.85, 0.95],
    ...     mdv=1.2,
    ...     target_power=0.80
    ... )

    Notes
    -----
    The power curve shows how likely you are to reject the null hypothesis
    of parallel trends given a true violation of magnitude M. Key points:

    1. **At M=0**: Power equals alpha (size of the test).
    2. **At MDV**: Power equals target power (default 80%).
    3. **Beyond MDV**: Power increases toward 100%.

    A steep power curve indicates a sensitive pre-trends test. A flat curve
    indicates the test has limited ability to detect violations, suggesting
    you should use HonestDiD sensitivity analysis for robust inference.

    See Also
    --------
    PreTrendsPower : Main class for pre-trends power analysis
    plot_sensitivity : Plot HonestDiD sensitivity analysis
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            if "M" not in results.columns or "power" not in results.columns:
                raise ValueError("DataFrame must have 'M' and 'power' columns")
            M_values = results["M"].tolist()
            powers = results["power"].tolist()
        elif hasattr(results, "M_values") and hasattr(results, "powers"):
            # PreTrendsPowerCurve; np.asarray lets plain sequences through too
            M_values = np.asarray(results.M_values).tolist()
            powers = np.asarray(results.powers).tolist()
            if mdv is None:
                mdv = results.mdv
            if target_power is None:
                target_power = results.target_power
        elif hasattr(results, "mdv") and hasattr(results, "power"):
            # Single PreTrendsPowerResults: there is no curve to draw, so only
            # the MDV marker line (and target line) will appear.
            if mdv is None:
                mdv = results.mdv
            powers = None
        else:
            raise TypeError(
                f"Cannot extract power curve data from {type(results).__name__}"
            )
    elif M_values is None or powers is None:
        raise ValueError(
            "Must provide either 'results' or both 'M_values' and 'powers'"
        )

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve if we have powers
    if powers is not None:
        ax.plot(
            M_values, powers,
            color=color,
            linewidth=linewidth,
            label="Power"
        )

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})"
        )

    # Add MDV line
    if show_mdv_line and mdv is not None and np.isfinite(mdv):
        ax.axvline(
            x=mdv,
            color=mdv_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDV = {mdv:.3f}"
        )

        # Mark intersection point if we have powers
        if powers is not None:
            m_arr = np.asarray(M_values, dtype=float)
            power_arr = np.asarray(powers, dtype=float)
            # Guard against empty data, and sort by M because np.interp
            # silently returns wrong values for non-increasing sample points.
            if (
                m_arr.size > 0
                and m_arr.size == power_arr.size
                and m_arr.min() <= mdv <= m_arr.max()
            ):
                order = np.argsort(m_arr, kind="stable")
                power_at_mdv = float(
                    np.interp(mdv, m_arr[order], power_arr[order])
                )
                ax.scatter(
                    [mdv], [power_at_mdv],
                    color=mdv_color,
                    s=50,
                    zorder=5
                )

    # Configure axes
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Y-axis from 0 to 1
    ax.set_ylim(0, 1.05)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")

    fig.tight_layout()

    if show:
        plt.show()

    return ax