diff-diff 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1391 @@
1
+ """
2
+ Visualization functions for difference-in-differences analysis.
3
+
4
+ Provides event study plots and other diagnostic visualizations.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ if TYPE_CHECKING:
13
+ from diff_diff.bacon import BaconDecompositionResults
14
+ from diff_diff.honest_did import HonestDiDResults, SensitivityResults
15
+ from diff_diff.power import PowerResults, SimulationPowerResults
16
+ from diff_diff.results import MultiPeriodDiDResults
17
+ from diff_diff.staggered import CallawaySantAnnaResults
18
+ from diff_diff.sun_abraham import SunAbrahamResults
19
+
20
# Inputs accepted by plot_event_study(): an estimator results object, written
# as a forward-reference string because those classes are only imported under
# TYPE_CHECKING, or a tidy DataFrame of per-period estimates.
PlottableResults = Union["MultiPeriodDiDResults", "CallawaySantAnnaResults", "SunAbrahamResults", pd.DataFrame]
27
+
28
+
29
def plot_event_study(
    results: Optional["PlottableResults"] = None,
    *,
    effects: Optional[Dict[Any, float]] = None,
    se: Optional[Dict[Any, float]] = None,
    periods: Optional[List[Any]] = None,
    reference_period: Optional[Any] = None,
    pre_periods: Optional[List[Any]] = None,
    post_periods: Optional[List[Any]] = None,
    alpha: float = 0.05,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Event Study",
    xlabel: str = "Period Relative to Treatment",
    ylabel: str = "Treatment Effect",
    color: str = "#2563eb",
    marker: str = "o",
    markersize: int = 8,
    linewidth: float = 1.5,
    capsize: int = 4,
    show_zero_line: bool = True,
    show_reference_line: bool = True,
    shade_pre: bool = True,
    shade_color: str = "#f0f0f0",
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create an event study plot showing treatment effects over time.

    This function creates a coefficient plot with point estimates and
    confidence intervals for each time period, commonly used to visualize
    dynamic treatment effects and assess pre-trends.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults, or DataFrame, optional
        Results object from MultiPeriodDiD, CallawaySantAnna, or a DataFrame
        with columns 'period', 'effect', 'se'. If None, must provide effects
        and se directly.
    effects : dict, optional
        Dictionary mapping periods to effect estimates. Used if results is None.
    se : dict, optional
        Dictionary mapping periods to standard errors. Used if results is None.
    periods : list, optional
        List of periods to plot. If None, uses all periods from results.
    reference_period : any, optional
        The reference period (normalized to effect=0). Will be shown as a
        hollow marker. If None, tries to infer from results. The reference
        period is plotted even when its standard error is missing/NaN.
    pre_periods : list, optional
        List of pre-treatment periods. Used for shading.
    post_periods : list, optional
        List of post-treatment periods.
    alpha : float, default=0.05
        Significance level for confidence intervals; must be in (0, 1).
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, default="Event Study"
        Plot title.
    xlabel : str, default="Period Relative to Treatment"
        X-axis label.
    ylabel : str, default="Treatment Effect"
        Y-axis label.
    color : str, default="#2563eb"
        Color for points and error bars.
    marker : str, default="o"
        Marker style for point estimates.
    markersize : int, default=8
        Size of markers.
    linewidth : float, default=1.5
        Width of error bar lines.
    capsize : int, default=4
        Size of error bar caps.
    show_zero_line : bool, default=True
        Whether to show a horizontal line at y=0.
    show_reference_line : bool, default=True
        Whether to show a vertical line at the reference period.
    shade_pre : bool, default=True
        Whether to shade the pre-treatment region.
    shade_color : str, default="#f0f0f0"
        Color for pre-treatment shading.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ValueError
        If ``alpha`` is not strictly between 0 and 1, if neither ``results``
        nor both ``effects`` and ``se`` are given, or if no period has a
        usable estimate.
    TypeError
        If ``effects``/``se`` are not dicts or ``results`` has an
        unsupported type.
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    Using with MultiPeriodDiD results:

    >>> from diff_diff import MultiPeriodDiD, plot_event_study
    >>> did = MultiPeriodDiD()
    >>> results = did.fit(data, outcome='y', treatment='treated',
    ...                   time='period', post_periods=[3, 4, 5])
    >>> plot_event_study(results)

    Using with a DataFrame:

    >>> df = pd.DataFrame({
    ...     'period': [-2, -1, 0, 1, 2],
    ...     'effect': [0.1, 0.05, 0.0, 0.5, 0.6],
    ...     'se': [0.1, 0.1, 0.0, 0.15, 0.15]
    ... })
    >>> plot_event_study(df, reference_period=0)

    Using with manual effects:

    >>> effects = {-2: 0.1, -1: 0.05, 0: 0.0, 1: 0.5, 2: 0.6}
    >>> se = {-2: 0.1, -1: 0.1, 0: 0.0, 1: 0.15, 2: 0.15}
    >>> plot_event_study(effects=effects, se=se, reference_period=0)

    Notes
    -----
    Event study plots are a standard visualization in difference-in-differences
    analysis. They show:

    1. **Pre-treatment periods**: Effects should be close to zero if parallel
       trends holds. Large pre-treatment effects suggest the assumption may
       be violated.

    2. **Reference period**: Usually the last pre-treatment period (t=-1),
       normalized to zero. This is the omitted category.

    3. **Post-treatment periods**: The treatment effects of interest. These
       show how the outcome evolved after treatment.

    The confidence intervals help assess statistical significance. Effects
    whose CIs don't include zero are typically considered significant.
    """
    # Validate arguments before importing the plotting stack so callers get
    # a precise error even in environments without matplotlib.
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha}")
    if results is None and (effects is None or se is None):
        raise ValueError(
            "Must provide either 'results' or both 'effects' and 'se'"
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    from scipy import stats as scipy_stats

    # Extract data from results if provided
    if results is not None:
        (effects, se, periods, pre_periods, post_periods,
         reference_period) = _extract_plot_data(
            results, periods, pre_periods, post_periods, reference_period
        )

    # Ensure effects and se are dicts
    if not isinstance(effects, dict):
        raise TypeError("effects must be a dictionary mapping periods to values")
    if not isinstance(se, dict):
        raise TypeError("se must be a dictionary mapping periods to values")

    # Get periods to plot
    if periods is None:
        periods = sorted(effects.keys())

    # Two-sided normal critical value for the (1 - alpha) CI
    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    plot_data = []
    for period in periods:
        effect = effects.get(period, np.nan)
        std_err = se.get(period, np.nan)
        is_reference = period == reference_period

        # pd.isna also treats None as missing (np.isnan would raise).
        if pd.isna(effect):
            continue
        if pd.isna(std_err):
            if not is_reference:
                continue
            # The normalized reference period typically has no standard
            # error; keep it with a degenerate CI so its marker is drawn.
            std_err = 0.0

        plot_data.append({
            'period': period,
            'effect': effect,
            'se': std_err,
            'ci_lower': effect - critical_value * std_err,
            'ci_upper': effect + critical_value * std_err,
            'is_reference': is_reference,
        })

    if not plot_data:
        raise ValueError("No valid data to plot")

    df = pd.DataFrame(plot_data)

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Periods may be arbitrary labels; place them at evenly spaced positions.
    period_to_x = {p: i for i, p in enumerate(df['period'])}
    x_vals = [period_to_x[p] for p in df['period']]

    # Shade pre-treatment region
    if shade_pre and pre_periods is not None:
        pre_x = [period_to_x[p] for p in pre_periods if p in period_to_x]
        if pre_x:
            ax.axvspan(min(pre_x) - 0.5, max(pre_x) + 0.5,
                       color=shade_color, alpha=0.5, zorder=0)

    # Draw horizontal zero line
    if show_zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, zorder=1)

    # Draw vertical reference line
    if show_reference_line and reference_period is not None:
        if reference_period in period_to_x:
            ref_x = period_to_x[reference_period]
            ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1)

    # Plot error bars (asymmetric form: distances below / above the estimate)
    yerr = [df['effect'] - df['ci_lower'], df['ci_upper'] - df['effect']]
    ax.errorbar(
        x_vals, df['effect'], yerr=yerr,
        fmt='none', color=color, capsize=capsize, linewidth=linewidth,
        capthick=linewidth, zorder=2
    )

    # Plot point estimates; the reference period gets a hollow marker
    for row in df.itertuples(index=False):
        x = period_to_x[row.period]
        if row.is_reference:
            ax.plot(x, row.effect, marker=marker, markersize=markersize,
                    markerfacecolor='white', markeredgecolor=color,
                    markeredgewidth=2, zorder=3)
        else:
            ax.plot(x, row.effect, marker=marker, markersize=markersize,
                    color=color, zorder=3)

    # Set labels and title
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Label the evenly spaced positions with the original period values
    ax.set_xticks(x_vals)
    ax.set_xticklabels([str(p) for p in df['period']])

    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
285
+
286
+
287
def _extract_plot_data(
    results: "PlottableResults",
    periods: Optional[List[Any]],
    pre_periods: Optional[List[Any]],
    post_periods: Optional[List[Any]],
    reference_period: Optional[Any],
) -> Tuple[Dict, Dict, List, List, List, Any]:
    """
    Extract plotting data from various result types.

    Explicitly passed ``periods``/``pre_periods``/``post_periods``/
    ``reference_period`` always take precedence over values inferred from
    ``results``.

    Parameters
    ----------
    results : DataFrame or results object
        A DataFrame with 'period', 'effect', 'se' columns; an object with a
        ``period_effects`` mapping whose values expose ``.effect``/``.se``
        (MultiPeriodDiDResults); or an object with a non-None
        ``event_study_effects`` mapping of relative period to a dict with
        'effect'/'se' keys (CallawaySantAnnaResults and similar).
    periods, pre_periods, post_periods, reference_period
        User overrides; None means "infer from results" where possible.

    Returns
    -------
    tuple
        (effects, se, periods, pre_periods, post_periods, reference_period)

    Raises
    ------
    ValueError
        If a DataFrame lacks a required column.
    TypeError
        If ``results`` is of an unsupported type or carries no event study
        effects.
    """
    # Handle DataFrame input
    if isinstance(results, pd.DataFrame):
        for column in ('period', 'effect', 'se'):
            if column not in results.columns:
                raise ValueError(f"DataFrame must have '{column}' column")

        effects = dict(zip(results['period'], results['effect']))
        se = dict(zip(results['period'], results['se']))

        if periods is None:
            periods = list(results['period'])

        return effects, se, periods, pre_periods, post_periods, reference_period

    # Handle MultiPeriodDiDResults
    if hasattr(results, 'period_effects'):
        effects = {p: pe.effect for p, pe in results.period_effects.items()}
        se = {p: pe.se for p, pe in results.period_effects.items()}

        if pre_periods is None and hasattr(results, 'pre_periods'):
            pre_periods = results.pre_periods

        if post_periods is None and hasattr(results, 'post_periods'):
            post_periods = results.post_periods

        if periods is None:
            # Plot every estimated period in stored order. Defaulting to
            # post_periods alone hid any pre-period estimates (and made the
            # pre-period shading unreachable).
            periods = list(effects)

        return effects, se, periods, pre_periods, post_periods, reference_period

    # Handle CallawaySantAnnaResults (event study aggregation)
    event_study_effects = getattr(results, 'event_study_effects', None)
    if event_study_effects is not None:
        effects = {}
        se = {}

        for rel_period, effect_data in event_study_effects.items():
            effects[rel_period] = effect_data['effect']
            se[rel_period] = effect_data['se']

        if periods is None:
            periods = sorted(effects.keys())

        # Reference period is typically -1 for event study
        if reference_period is None:
            reference_period = -1

        # Relative time: negative periods precede treatment onset
        if pre_periods is None:
            pre_periods = [p for p in periods if p < 0]

        if post_periods is None:
            post_periods = [p for p in periods if p >= 0]

        return effects, se, periods, pre_periods, post_periods, reference_period

    if hasattr(results, 'event_study_effects'):
        # Right result type, but the event-study aggregation is absent.
        raise TypeError(
            f"{type(results).__name__} has no event study effects to plot "
            "(event_study_effects is None); compute the event-study "
            "aggregation when fitting the estimator."
        )

    raise TypeError(
        f"Cannot extract plot data from {type(results).__name__}. "
        "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, "
        "SunAbrahamResults, or DataFrame."
    )
368
+
369
+
370
def plot_group_effects(
    results: "CallawaySantAnnaResults",
    *,
    groups: Optional[List[Any]] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Effects by Cohort",
    xlabel: str = "Time Period",
    ylabel: str = "Treatment Effect",
    alpha: float = 0.05,
    show: bool = True,
    ax: Optional[Any] = None,
) -> Any:
    """
    Plot treatment effects by treatment cohort (group).

    Each cohort is drawn as one error-bar series over calendar time, using
    the (group, time) cells in ``results.group_time_effects``.

    Parameters
    ----------
    results : CallawaySantAnnaResults
        Results from CallawaySantAnna estimator. Must expose
        ``group_time_effects`` mapping ``(group, time)`` to a dict with
        'effect' and 'se' keys.
    groups : list, optional
        List of groups (cohorts) to plot. If None, plots all groups.
    figsize : tuple, default=(10, 6)
        Figure size.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    alpha : float, default=0.05
        Significance level for confidence intervals; must be in (0, 1).
    show : bool, default=True
        Whether to call plt.show().
    ax : matplotlib.axes.Axes, optional
        Axes to plot on.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    TypeError
        If ``results`` has no ``group_time_effects``.
    ValueError
        If ``alpha`` is not strictly between 0 and 1.
    ImportError
        If matplotlib is not installed.
    """
    # Validate arguments before importing the plotting stack.
    if not hasattr(results, 'group_time_effects'):
        raise TypeError("results must be a CallawaySantAnnaResults object")
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha}")

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    from scipy import stats as scipy_stats

    # Bucket (group, time) cells by group in one pass instead of rescanning
    # the whole table once per group.
    cells_by_group: Dict[Any, List[Tuple[Any, Dict[str, float]]]] = {}
    for (g, t), data in results.group_time_effects.items():
        cells_by_group.setdefault(g, []).append((t, data))

    # Get groups to plot
    if groups is None:
        groups = sorted(cells_by_group)

    # Create figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # One color per requested group, sampled across the tab10 palette
    colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))

    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    plotted_any = False
    for group_color, group in zip(colors, groups):
        cells = sorted(cells_by_group.get(group, []), key=lambda cell: cell[0])
        if not cells:
            continue

        times = [t for t, _ in cells]
        effects = [data['effect'] for _, data in cells]
        # Normal-approximation CIs are symmetric: half-width = z * se.
        half_widths = [critical_value * data['se'] for _, data in cells]

        ax.errorbar(
            times, effects, yerr=half_widths,
            label=f'Cohort {group}', color=group_color,
            marker='o', capsize=3, linewidth=1.5
        )
        plotted_any = True

    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Skip the legend when no labelled series was drawn (avoids the
    # "No artists with labels" warning and an empty legend box).
    if plotted_any:
        ax.legend(loc='best')
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
478
+
479
+
480
def plot_sensitivity(
    sensitivity_results: "SensitivityResults",
    *,
    show_bounds: bool = True,
    show_ci: bool = True,
    breakdown_line: bool = True,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Honest DiD Sensitivity Analysis",
    xlabel: str = "M (restriction parameter)",
    ylabel: str = "Treatment Effect",
    bounds_color: str = "#2563eb",
    bounds_alpha: float = 0.3,
    ci_color: str = "#2563eb",
    ci_linewidth: float = 1.5,
    breakdown_color: str = "#dc2626",
    original_color: str = "#1f2937",
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Plot Honest DiD sensitivity results against the restriction parameter M.

    Draws, in order: a horizontal line at the original estimate, a dashed
    zero line, the identified set as a shaded band, the robust confidence
    interval as two lines, and optionally a vertical line at the breakdown
    value of M.

    Parameters
    ----------
    sensitivity_results : SensitivityResults
        Results from HonestDiD.sensitivity_analysis(). The attributes read
        are ``M_values``, ``bounds`` and ``robust_cis`` (sequences of
        ``(lower, upper)`` pairs, one per M), ``original_estimate``, and
        ``breakdown_M`` (only when ``breakdown_line`` is True).
    show_bounds : bool, default=True
        Shade the identified set between its lower and upper bounds.
    show_ci : bool, default=True
        Draw the lower and upper robust confidence interval lines.
    breakdown_line : bool, default=True
        Draw a vertical line at ``breakdown_M`` when it is not None.
    figsize : tuple, default=(10, 6)
        Size (width, height) in inches of a newly created figure.
    title, xlabel, ylabel : str
        Axis title and labels.
    bounds_color : str
        Fill color of the identified set.
    bounds_alpha : float
        Fill transparency of the identified set.
    ci_color : str
        Color of the confidence interval lines.
    ci_linewidth : float
        Width of the confidence interval lines.
    breakdown_color : str
        Color of the breakdown line.
    original_color : str
        Color of the original estimate line.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; a new figure is created when omitted.
    show : bool, default=True
        Call plt.show() before returning.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.honest_did import HonestDiD
    >>> from diff_diff.visualization import plot_sensitivity
    >>>
    >>> # Fit event study and run sensitivity analysis
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> honest = HonestDiD(method='relative_magnitude')
    >>> sensitivity = honest.sensitivity_analysis(results)
    >>>
    >>> # Create sensitivity plot
    >>> plot_sensitivity(sensitivity)
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    # Reuse the caller's axes when given, otherwise start a fresh figure.
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    m_grid = sensitivity_results.M_values
    identified_set = np.array(sensitivity_results.bounds)
    robust_intervals = np.array(sensitivity_results.robust_cis)

    # Reference lines: the unadjusted point estimate and zero.
    ax.axhline(
        y=sensitivity_results.original_estimate,
        color=original_color,
        linestyle='-',
        linewidth=1.5,
        label='Original estimate',
        alpha=0.7,
    )
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    if show_bounds:
        lower_set, upper_set = identified_set[:, 0], identified_set[:, 1]
        ax.fill_between(
            m_grid, lower_set, upper_set,
            alpha=bounds_alpha,
            color=bounds_color,
            label='Identified set',
        )

    if show_ci:
        lower_ci, upper_ci = robust_intervals[:, 0], robust_intervals[:, 1]
        # Only the lower line is labelled so the legend lists the CI once.
        ax.plot(m_grid, lower_ci, color=ci_color, linewidth=ci_linewidth,
                label='Robust CI')
        ax.plot(m_grid, upper_ci, color=ci_color, linewidth=ci_linewidth)

    if breakdown_line:
        breakdown_value = sensitivity_results.breakdown_M
        if breakdown_value is not None:
            ax.axvline(
                x=breakdown_value,
                color=breakdown_color,
                linestyle=':',
                linewidth=2,
                label=f'Breakdown (M={breakdown_value:.2f})',
            )

    for apply_text, text in (
        (ax.set_xlabel, xlabel),
        (ax.set_ylabel, ylabel),
        (ax.set_title, title),
    ):
        apply_text(text)
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
635
+
636
+
637
def plot_honest_event_study(
    honest_results: "HonestDiDResults",
    *,
    periods: Optional[List[Any]] = None,
    reference_period: Optional[Any] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Event Study with Honest Confidence Intervals",
    xlabel: str = "Period Relative to Treatment",
    ylabel: str = "Treatment Effect",
    original_color: str = "#6b7280",
    honest_color: str = "#2563eb",
    marker: str = "o",
    markersize: int = 8,
    capsize: int = 4,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create event study plot with Honest DiD confidence intervals.

    Shows both the original confidence intervals (assuming parallel trends)
    and the robust confidence intervals that allow for bounded violations.

    Parameters
    ----------
    honest_results : HonestDiDResults
        Results from HonestDiD.fit(). The attributes read are
        ``original_results``, ``alpha``, ``event_study_bounds``, ``M`` and,
        when ``event_study_bounds`` is empty, ``ci_lb``/``ci_ub``.
    periods : list, optional
        Periods to plot. If None, uses all available periods.
    reference_period : any, optional
        Reference period to show as hollow marker.
    figsize : tuple, default=(10, 6)
        Figure size.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    original_color : str
        Color for original (standard) confidence intervals.
    honest_color : str
        Color for honest (robust) confidence intervals.
    marker : str
        Marker style.
    markersize : int
        Marker size.
    capsize : int
        Error bar cap size.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on.
    show : bool, default=True
        Whether to call plt.show().

    Returns
    -------
    matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    ValueError
        If ``honest_results.original_results`` is None.
    TypeError
        If the original results carry neither ``period_effects`` nor a
        non-None ``event_study_effects``.
    ImportError
        If matplotlib is not installed.

    Notes
    -----
    Per-period honest intervals come from ``event_study_bounds``; periods
    missing from it (e.g. the reference period) get no honest interval.
    If ``event_study_bounds`` is empty, the scalar robust CI
    (``ci_lb``, ``ci_ub``) is drawn at every period. That interval refers to
    a single target parameter rather than to each period, so
    plot_sensitivity() is usually more informative in that case.

    Each interval is drawn from its own lower to upper endpoint, so it
    renders correctly even when it does not contain the point estimate.
    """
    # Get original results for standard CIs (checked before importing the
    # plotting stack so the error is precise without matplotlib).
    original_results = honest_results.original_results
    if original_results is None:
        raise ValueError(
            "HonestDiDResults must have original_results to plot event study"
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    from scipy import stats as scipy_stats

    # Extract data from original results
    if hasattr(original_results, 'period_effects'):
        # MultiPeriodDiDResults
        effects_dict = {
            p: pe.effect for p, pe in original_results.period_effects.items()
        }
        se_dict = {
            p: pe.se for p, pe in original_results.period_effects.items()
        }
        if periods is None:
            periods = list(original_results.period_effects.keys())
    elif getattr(original_results, 'event_study_effects', None) is not None:
        # CallawaySantAnnaResults (event-study aggregation present)
        effects_dict = {
            t: data['effect']
            for t, data in original_results.event_study_effects.items()
        }
        se_dict = {
            t: data['se']
            for t, data in original_results.event_study_effects.items()
        }
        if periods is None:
            periods = sorted(original_results.event_study_effects.keys())
    else:
        raise TypeError("Cannot extract event study data from original_results")

    # Create figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Standard (parallel-trends) CIs use the normal critical value
    z = scipy_stats.norm.ppf(1 - honest_results.alpha / 2)

    x_vals = list(range(len(periods)))
    effects = [effects_dict[p] for p in periods]
    original_intervals = [
        (effects_dict[p] - z * se_dict[p], effects_dict[p] + z * se_dict[p])
        for p in periods
    ]

    # Honest bounds per period if available, else the scalar bound
    event_bounds = honest_results.event_study_bounds
    if event_bounds:
        honest_intervals = [
            (event_bounds[p]['ci_lb'], event_bounds[p]['ci_ub'])
            if p in event_bounds else (np.nan, np.nan)
            for p in periods
        ]
    else:
        honest_intervals = [
            (honest_results.ci_lb, honest_results.ci_ub)
        ] * len(periods)

    def _draw_intervals(intervals, **errorbar_kwargs):
        """Draw [lower, upper] intervals at x_vals, skipping undefined ones."""
        kept = [
            (x, lower, upper)
            for x, (lower, upper) in zip(x_vals, intervals)
            if lower is not None and upper is not None
            and np.isfinite(lower) and np.isfinite(upper)
        ]
        if not kept:
            return
        # Center each bar on its interval midpoint: error-bar lengths measured
        # from the point estimate would be negative (and rejected by
        # matplotlib) whenever the interval excludes the estimate.
        ax.errorbar(
            [x for x, _, _ in kept],
            [(lower + upper) / 2 for _, lower, upper in kept],
            yerr=[abs(upper - lower) / 2 for _, lower, upper in kept],
            fmt='none',
            **errorbar_kwargs,
        )

    # Zero line
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    # Plot original CIs (thinner, background)
    _draw_intervals(
        original_intervals,
        color=original_color, capsize=capsize - 1,
        linewidth=1, alpha=0.6, label='Standard CI',
    )

    # Plot honest CIs (thicker, foreground)
    _draw_intervals(
        honest_intervals,
        color=honest_color, capsize=capsize,
        linewidth=2, label=f'Honest CI (M={honest_results.M:.2f})',
    )

    # Plot point estimates; the reference period gets a hollow marker
    for x, effect, period in zip(x_vals, effects, periods):
        if period == reference_period:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    markerfacecolor='white', markeredgecolor=honest_color,
                    markeredgewidth=2, zorder=3)
        else:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    color=honest_color, zorder=3)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x_vals)
    ax.set_xticklabels([str(p) for p in periods])
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
826
+
827
+
828
def plot_bacon(
    results: "BaconDecompositionResults",
    *,
    plot_type: str = "scatter",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: str = "2x2 DiD Estimate",
    ylabel: str = "Weight",
    colors: Optional[Dict[str, str]] = None,
    marker: str = "o",
    markersize: int = 80,
    alpha: float = 0.7,
    show_weighted_avg: bool = True,
    show_twfe_line: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Visualize Goodman-Bacon decomposition results.

    Creates either a scatter plot showing the weight and estimate for each
    2x2 comparison, or a stacked bar chart showing total weight by comparison
    type.

    Parameters
    ----------
    results : BaconDecompositionResults
        Results from BaconDecomposition.fit() or bacon_decompose().
    plot_type : str, default="scatter"
        Type of plot to create:
        - "scatter": Scatter plot with estimates on x-axis, weights on y-axis
        - "bar": Stacked bar chart of weights by comparison type
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, optional
        Plot title. If None, uses a default based on plot_type.
    xlabel : str, default="2x2 DiD Estimate"
        X-axis label (scatter plot only).
    ylabel : str, default="Weight"
        Y-axis label.
    colors : dict, optional
        Dictionary mapping comparison types to colors. Keys are:
        "treated_vs_never", "earlier_vs_later", "later_vs_earlier".
        Types not given (or all of them, if None) use the default colors.
    marker : str, default="o"
        Marker style for scatter plot.
    markersize : int, default=80
        Marker size for scatter plot.
    alpha : float, default=0.7
        Transparency for markers/bars.
    show_weighted_avg : bool, default=True
        Whether to show weighted average lines for each comparison type
        (scatter plot only).
    show_twfe_line : bool, default=True
        Whether to show a vertical line at the TWFE estimate (scatter plot only).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ValueError
        If ``plot_type`` is not "scatter" or "bar" (checked before any
        figure is created).
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    Scatter plot (default):

    >>> from diff_diff import bacon_decompose, plot_bacon
    >>> results = bacon_decompose(data, outcome='y', unit='id',
    ...                           time='t', first_treat='first_treat')
    >>> plot_bacon(results)

    Bar chart of weights by type:

    >>> plot_bacon(results, plot_type='bar')

    Customized scatter plot:

    >>> plot_bacon(results,
    ...            colors={'treated_vs_never': 'green',
    ...                    'earlier_vs_later': 'blue',
    ...                    'later_vs_earlier': 'red'},
    ...            title='My Bacon Decomposition')

    Notes
    -----
    The scatter plot is particularly useful for understanding:

    1. **Distribution of estimates**: Are 2x2 estimates clustered or spread?
       Wide spread suggests heterogeneous treatment effects.

    2. **Weight concentration**: Do a few comparisons dominate the TWFE?
       Points with high weights have more influence.

    3. **Forbidden comparison problem**: Red points (later_vs_earlier) show
       comparisons using already-treated units as controls. If these have
       different estimates than clean comparisons, TWFE may be biased.

    The bar chart provides a quick summary of how much weight falls on
    each comparison type, which is useful for assessing the severity
    of potential TWFE bias.

    See Also
    --------
    bacon_decompose : Perform the decomposition
    BaconDecomposition : Class-based interface
    """
    # Validate before creating a figure: previously an unknown plot_type was
    # only rejected after plt.subplots(), leaving an orphaned open figure.
    if plot_type not in ("scatter", "bar"):
        raise ValueError(f"Unknown plot_type: {plot_type}. Use 'scatter' or 'bar'.")

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        )

    default_colors = {
        "treated_vs_never": "#22c55e",  # Green - clean comparison
        "earlier_vs_later": "#3b82f6",  # Blue - valid comparison
        "later_vs_earlier": "#ef4444",  # Red - forbidden comparison
    }
    # Merge so a partial mapping overrides only the types it names instead
    # of raising KeyError for the missing ones.
    colors = {**default_colors, **(colors or {})}

    # Default titles
    if title is None:
        if plot_type == "scatter":
            title = "Goodman-Bacon Decomposition"
        else:
            title = "TWFE Weight by Comparison Type"

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    if plot_type == "scatter":
        _plot_bacon_scatter(
            ax, results, colors, marker, markersize, alpha,
            show_weighted_avg, show_twfe_line, xlabel, ylabel, title
        )
    else:
        _plot_bacon_bar(ax, results, colors, alpha, ylabel, title)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
982
+
983
+
984
+ def _plot_bacon_scatter(
985
+ ax: Any,
986
+ results: "BaconDecompositionResults",
987
+ colors: Dict[str, str],
988
+ marker: str,
989
+ markersize: int,
990
+ alpha: float,
991
+ show_weighted_avg: bool,
992
+ show_twfe_line: bool,
993
+ xlabel: str,
994
+ ylabel: str,
995
+ title: str,
996
+ ) -> None:
997
+ """Create scatter plot of Bacon decomposition."""
998
+ # Separate comparisons by type
999
+ by_type: Dict[str, List[Tuple[float, float]]] = {
1000
+ "treated_vs_never": [],
1001
+ "earlier_vs_later": [],
1002
+ "later_vs_earlier": [],
1003
+ }
1004
+
1005
+ for comp in results.comparisons:
1006
+ by_type[comp.comparison_type].append((comp.estimate, comp.weight))
1007
+
1008
+ # Plot each type
1009
+ labels = {
1010
+ "treated_vs_never": "Treated vs Never-treated",
1011
+ "earlier_vs_later": "Earlier vs Later treated",
1012
+ "later_vs_earlier": "Later vs Earlier (forbidden)",
1013
+ }
1014
+
1015
+ for ctype, points in by_type.items():
1016
+ if not points:
1017
+ continue
1018
+ estimates = [p[0] for p in points]
1019
+ weights = [p[1] for p in points]
1020
+ ax.scatter(
1021
+ estimates, weights,
1022
+ c=colors[ctype],
1023
+ label=labels[ctype],
1024
+ marker=marker,
1025
+ s=markersize,
1026
+ alpha=alpha,
1027
+ edgecolors='white',
1028
+ linewidths=0.5,
1029
+ )
1030
+
1031
+ # Show weighted average lines
1032
+ if show_weighted_avg:
1033
+ effect_by_type = results.effect_by_type()
1034
+ for ctype, avg_effect in effect_by_type.items():
1035
+ if avg_effect is not None and by_type[ctype]:
1036
+ ax.axvline(
1037
+ x=avg_effect,
1038
+ color=colors[ctype],
1039
+ linestyle='--',
1040
+ alpha=0.5,
1041
+ linewidth=1.5,
1042
+ )
1043
+
1044
+ # Show TWFE estimate line
1045
+ if show_twfe_line:
1046
+ ax.axvline(
1047
+ x=results.twfe_estimate,
1048
+ color='black',
1049
+ linestyle='-',
1050
+ linewidth=2,
1051
+ label=f'TWFE = {results.twfe_estimate:.4f}',
1052
+ )
1053
+
1054
+ ax.set_xlabel(xlabel)
1055
+ ax.set_ylabel(ylabel)
1056
+ ax.set_title(title)
1057
+ ax.legend(loc='best')
1058
+ ax.grid(True, alpha=0.3)
1059
+
1060
+ # Add zero line
1061
+ ax.axvline(x=0, color='gray', linestyle=':', alpha=0.5)
1062
+
1063
+
1064
+ def _plot_bacon_bar(
1065
+ ax: Any,
1066
+ results: "BaconDecompositionResults",
1067
+ colors: Dict[str, str],
1068
+ alpha: float,
1069
+ ylabel: str,
1070
+ title: str,
1071
+ ) -> None:
1072
+ """Create stacked bar chart of weights by comparison type."""
1073
+ # Get weights
1074
+ weights = results.weight_by_type()
1075
+
1076
+ # Labels and colors
1077
+ type_order = ["treated_vs_never", "earlier_vs_later", "later_vs_earlier"]
1078
+ labels = {
1079
+ "treated_vs_never": "Treated vs Never-treated",
1080
+ "earlier_vs_later": "Earlier vs Later",
1081
+ "later_vs_earlier": "Later vs Earlier\n(forbidden)",
1082
+ }
1083
+
1084
+ # Create bar data
1085
+ bar_labels = [labels[t] for t in type_order]
1086
+ bar_weights = [weights[t] for t in type_order]
1087
+ bar_colors = [colors[t] for t in type_order]
1088
+
1089
+ # Create bars
1090
+ bars = ax.bar(
1091
+ bar_labels,
1092
+ bar_weights,
1093
+ color=bar_colors,
1094
+ alpha=alpha,
1095
+ edgecolor='white',
1096
+ linewidth=1,
1097
+ )
1098
+
1099
+ # Add percentage labels on bars
1100
+ for bar, weight in zip(bars, bar_weights):
1101
+ if weight > 0.01: # Only label if > 1%
1102
+ height = bar.get_height()
1103
+ ax.annotate(
1104
+ f'{weight:.1%}',
1105
+ xy=(bar.get_x() + bar.get_width() / 2, height),
1106
+ xytext=(0, 3),
1107
+ textcoords="offset points",
1108
+ ha='center',
1109
+ va='bottom',
1110
+ fontsize=10,
1111
+ fontweight='bold',
1112
+ )
1113
+
1114
+ # Add weighted average effect annotations
1115
+ effects = results.effect_by_type()
1116
+ for bar, ctype in zip(bars, type_order):
1117
+ effect = effects[ctype]
1118
+ if effect is not None and weights[ctype] > 0.01:
1119
+ ax.annotate(
1120
+ f'β = {effect:.3f}',
1121
+ xy=(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2),
1122
+ ha='center',
1123
+ va='center',
1124
+ fontsize=9,
1125
+ color='white',
1126
+ fontweight='bold',
1127
+ )
1128
+
1129
+ ax.set_ylabel(ylabel)
1130
+ ax.set_title(title)
1131
+ ax.set_ylim(0, 1.1)
1132
+
1133
+ # Add horizontal line at total weight = 1
1134
+ ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
1135
+
1136
+ # Add TWFE estimate as text
1137
+ ax.text(
1138
+ 0.98, 0.98,
1139
+ f'TWFE = {results.twfe_estimate:.4f}',
1140
+ transform=ax.transAxes,
1141
+ ha='right',
1142
+ va='top',
1143
+ fontsize=10,
1144
+ bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
1145
+ )
1146
+
1147
+
1148
def plot_power_curve(
    results: Optional[Union["PowerResults", "SimulationPowerResults", pd.DataFrame]] = None,
    *,
    effect_sizes: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mde: Optional[float] = None,
    target_power: float = 0.80,
    plot_type: str = "effect",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Power",
    color: str = "#2563eb",
    mde_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mde_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create a power curve visualization.

    Shows how statistical power changes with effect size or sample size,
    helping researchers understand the trade-offs in study design.

    Parameters
    ----------
    results : PowerResults, SimulationPowerResults, or DataFrame, optional
        Results object from PowerAnalysis or simulate_power(), or a DataFrame
        with columns 'effect_size' and 'power' (or 'sample_size' and 'power').
        If None, must provide effect_sizes and powers directly.
    effect_sizes : list of float, optional
        Effect sizes (x-axis values). Required if results is None. Any
        sequence (list, numpy array, pandas Series) is accepted and the values
        need not be sorted.
    powers : list of float, optional
        Power values (y-axis values). Required if results is None. Must have
        the same length as ``effect_sizes``.
    mde : float, optional
        Minimum detectable effect to mark on the plot.
    target_power : float, default=0.80
        Target power level to show as horizontal line.
    plot_type : str, default="effect"
        Type of power curve: "effect" (power vs effect size) or
        "sample" (power vs sample size).
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, optional
        Plot title. If None, uses a sensible default.
    xlabel : str, optional
        X-axis label. If None, uses a sensible default.
    ylabel : str, default="Power"
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mde_color : str, default="#dc2626"
        Color for the MDE vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mde_line : bool, default=True
        Whether to show vertical line at MDE.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        Raised in any of these cases:

        - a DataFrame ``results`` has neither an 'effect_size' nor a
          'sample_size' column;
        - an object with an ``mde`` attribute (PowerResults) is passed as
          ``results``;
        - neither ``results`` nor both ``effect_sizes`` and ``powers`` are
          given;
        - ``effect_sizes`` and ``powers`` differ in length.
    TypeError
        If ``results`` is of an unsupported type.

    Notes
    -----
    - A DataFrame ``results`` overrides ``plot_type``: an 'effect_size'
      column selects "effect"; otherwise a 'sample_size' column selects
      "sample".
    - For results objects exposing ``effect_sizes`` and ``powers``, a
      ``true_effect`` attribute (if present) is marked as ``mde`` when no
      explicit ``mde`` is given.
    - When ``mde`` is marked and lies within the range of the x-values, a dot
      is drawn at the curve's power there. An exact x-match uses the first
      matching point; otherwise the power is linearly interpolated.

    Examples
    --------
    From PowerAnalysis results:

    >>> from diff_diff import PowerAnalysis, plot_power_curve
    >>> pa = PowerAnalysis(power=0.80)
    >>> curve_df = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
    >>> mde_result = pa.mde(n_treated=50, n_control=50, sigma=5.0)
    >>> plot_power_curve(curve_df, mde=mde_result.mde)

    From simulation results:

    >>> from diff_diff import simulate_power, DifferenceInDifferences
    >>> results = simulate_power(
    ...     DifferenceInDifferences(),
    ...     effect_sizes=[1, 2, 3, 5, 7, 10],
    ...     n_simulations=200
    ... )
    >>> plot_power_curve(results)

    Manual data:

    >>> plot_power_curve(
    ...     effect_sizes=[1, 2, 3, 4, 5],
    ...     powers=[0.2, 0.5, 0.75, 0.90, 0.97],
    ...     mde=2.5,
    ...     target_power=0.80
    ... )
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from err

    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            # DataFrame input; the column present decides the plot type.
            if "effect_size" in results.columns:
                effect_sizes = results["effect_size"].tolist()
                plot_type = "effect"
            elif "sample_size" in results.columns:
                effect_sizes = results["sample_size"].tolist()
                plot_type = "sample"
            else:
                raise ValueError(
                    "DataFrame must have 'effect_size' or 'sample_size' column"
                )
            powers = results["power"].tolist()
        elif hasattr(results, "effect_sizes") and hasattr(results, "powers"):
            # SimulationPowerResults
            effect_sizes = results.effect_sizes
            powers = results.powers
            if mde is None and hasattr(results, "true_effect"):
                # Mark true effect on plot
                mde = results.true_effect
        elif hasattr(results, "mde"):
            # PowerResults carries a single MDE, not curve data.
            raise ValueError(
                "PowerResults should be used to get mde value, not as direct input. "
                "Use PowerAnalysis.power_curve() to generate curve data."
            )
        else:
            raise TypeError(
                f"Cannot extract power curve data from {type(results).__name__}"
            )
    elif effect_sizes is None or powers is None:
        raise ValueError(
            "Must provide either 'results' or both 'effect_sizes' and 'powers'"
        )

    # Validate before creating a figure so a bad call leaves no empty figure.
    if len(effect_sizes) != len(powers):
        raise ValueError(
            "effect_sizes and powers must have the same length "
            f"(got {len(effect_sizes)} and {len(powers)})"
        )

    # Default titles and labels
    if title is None:
        title = "Power Curve" if plot_type == "effect" else "Power vs Sample Size"
    if xlabel is None:
        xlabel = "Effect Size" if plot_type == "effect" else "Sample Size"

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve
    ax.plot(
        effect_sizes, powers,
        color=color,
        linewidth=linewidth,
        label="Power"
    )

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})"
        )

    # Add MDE line
    if show_mde_line and mde is not None:
        ax.axvline(
            x=mde,
            color=mde_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDE = {mde:.3f}"
        )

        # Mark where the curve meets the MDE line. Use array operations
        # rather than list methods so numpy/pandas inputs work. Sort by x
        # before interpolating: np.interp assumes increasing xp and silently
        # returns wrong values otherwise.
        x_arr = np.asarray(effect_sizes, dtype=float)
        y_arr = np.asarray(powers, dtype=float)
        power_at_mde: Optional[float] = None
        exact = np.flatnonzero(x_arr == mde)
        if exact.size > 0:
            # Exact hit: use the first matching point.
            power_at_mde = float(y_arr[exact[0]])
        elif x_arr.size > 0 and x_arr.min() <= mde <= x_arr.max():
            order = np.argsort(x_arr, kind="stable")
            power_at_mde = float(np.interp(mde, x_arr[order], y_arr[order]))

        if power_at_mde is not None:
            ax.scatter(
                [mde], [power_at_mde],
                color=mde_color,
                s=50,
                zorder=5
            )

    # Configure axes
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Power is a probability: fix the y-range and show it as a percentage.
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")

    fig.tight_layout()

    if show:
        plt.show()

    return ax