diff-diff 2.2.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1681 @@
1
+ """
2
+ Visualization functions for difference-in-differences analysis.
3
+
4
+ Provides event study plots and other diagnostic visualizations.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ if TYPE_CHECKING:
13
+ from diff_diff.bacon import BaconDecompositionResults
14
+ from diff_diff.honest_did import HonestDiDResults, SensitivityResults
15
+ from diff_diff.power import PowerResults, SimulationPowerResults
16
+ from diff_diff.pretrends import PreTrendsPowerCurve, PreTrendsPowerResults
17
+ from diff_diff.results import MultiPeriodDiDResults
18
+ from diff_diff.staggered import CallawaySantAnnaResults
19
+ from diff_diff.sun_abraham import SunAbrahamResults
20
+
21
+ # Type alias for results that can be plotted
22
+ PlottableResults = Union[
23
+ "MultiPeriodDiDResults",
24
+ "CallawaySantAnnaResults",
25
+ "SunAbrahamResults",
26
+ pd.DataFrame,
27
+ ]
28
+
29
+
30
def plot_event_study(
    results: Optional["PlottableResults"] = None,
    *,
    effects: Optional[Dict[Any, float]] = None,
    se: Optional[Dict[Any, float]] = None,
    periods: Optional[List[Any]] = None,
    reference_period: Optional[Any] = None,
    pre_periods: Optional[List[Any]] = None,
    post_periods: Optional[List[Any]] = None,
    alpha: float = 0.05,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Event Study",
    xlabel: str = "Period Relative to Treatment",
    ylabel: str = "Treatment Effect",
    color: str = "#2563eb",
    marker: str = "o",
    markersize: int = 8,
    linewidth: float = 1.5,
    capsize: int = 4,
    show_zero_line: bool = True,
    show_reference_line: bool = True,
    shade_pre: bool = True,
    shade_color: str = "#f0f0f0",
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create an event study plot showing treatment effects over time.

    This function creates a coefficient plot with point estimates and
    confidence intervals for each time period, commonly used to visualize
    dynamic treatment effects and assess pre-trends.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults, SunAbrahamResults, or DataFrame, optional
        Results object, or a DataFrame with columns 'period', 'effect' and
        'se'. Confidence intervals are always recomputed from 'se' and
        ``alpha``; any other DataFrame columns are ignored. If None, must
        provide effects and se directly.
    effects : dict, optional
        Dictionary mapping periods to effect estimates. Used if results is None.
    se : dict, optional
        Dictionary mapping periods to standard errors. Used if results is None.
    periods : list, optional
        List of periods to plot. If None, uses the periods from results
        (or the sorted keys of ``effects``).
    reference_period : any, optional
        The reference period to highlight. When explicitly provided, effects
        are normalized (ref effect subtracted) and ref SE is set to NaN.
        When None and auto-inferred from results, only hollow marker styling
        is applied (no normalization).
    pre_periods : list, optional
        List of pre-treatment periods. Used for shading.
    post_periods : list, optional
        List of post-treatment periods. Forwarded to the results extraction;
        not otherwise used for drawing.
    alpha : float, default=0.05
        Significance level for confidence intervals. Must lie strictly
        between 0 and 1.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, default="Event Study"
        Plot title.
    xlabel : str, default="Period Relative to Treatment"
        X-axis label.
    ylabel : str, default="Treatment Effect"
        Y-axis label.
    color : str, default="#2563eb"
        Color for points and error bars.
    marker : str, default="o"
        Marker style for point estimates.
    markersize : int, default=8
        Size of markers.
    linewidth : float, default=1.5
        Width of error bar lines.
    capsize : int, default=4
        Size of error bar caps.
    show_zero_line : bool, default=True
        Whether to show a horizontal line at y=0.
    show_reference_line : bool, default=True
        Whether to show a vertical line at the reference period.
    shade_pre : bool, default=True
        Whether to shade the pre-treatment region.
    shade_color : str, default="#f0f0f0"
        Color for pre-treatment shading.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ValueError
        If ``alpha`` is not strictly between 0 and 1, if neither ``results``
        nor both ``effects`` and ``se`` are given, or if no period has a
        non-missing effect.
    TypeError
        If ``effects`` or ``se`` is supplied directly and is not a dict.
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    Using with MultiPeriodDiD results:

    >>> from diff_diff import MultiPeriodDiD, plot_event_study
    >>> did = MultiPeriodDiD()
    >>> results = did.fit(data, outcome='y', treatment='treated',
    ...                   time='period', post_periods=[3, 4, 5])
    >>> plot_event_study(results)

    Using with a DataFrame:

    >>> df = pd.DataFrame({
    ...     'period': [-2, -1, 0, 1, 2],
    ...     'effect': [0.1, 0.05, 0.0, 0.5, 0.6],
    ...     'se': [0.1, 0.1, 0.0, 0.15, 0.15]
    ... })
    >>> plot_event_study(df, reference_period=0)

    Using with manual effects:

    >>> effects = {-2: 0.1, -1: 0.05, 0: 0.0, 1: 0.5, 2: 0.6}
    >>> se = {-2: 0.1, -1: 0.1, 0: 0.0, 1: 0.15, 2: 0.15}
    >>> plot_event_study(effects=effects, se=se, reference_period=0)

    Notes
    -----
    Event study plots are a standard visualization in difference-in-differences
    analysis. They show:

    1. **Pre-treatment periods**: Effects should be close to zero if parallel
       trends holds. Large pre-treatment effects suggest the assumption may
       be violated.

    2. **Reference period**: Usually the last pre-treatment period (t=-1).
       When explicitly specified via ``reference_period``, effects are normalized
       to zero at this period. When auto-inferred, shown with hollow marker only.

    3. **Post-treatment periods**: The treatment effects of interest. These
       show how the outcome evolved after treatment.

    The confidence intervals help assess statistical significance. Effects
    whose CIs don't include zero are typically considered significant.
    """
    # Validate cheap arguments first so that a bad call fails fast, before
    # any optional plotting dependency is imported.
    if not 0 < alpha < 1:
        raise ValueError(
            f"alpha must be strictly between 0 and 1, got {alpha!r}"
        )
    if results is None:
        if effects is None or se is None:
            raise ValueError(
                "Must provide either 'results' or both 'effects' and 'se'"
            )
        if not isinstance(effects, dict):
            raise TypeError("effects must be a dictionary mapping periods to values")
        if not isinstance(se, dict):
            raise TypeError("se must be a dictionary mapping periods to values")

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from err

    from scipy import stats as scipy_stats

    # Track if reference_period was explicitly provided by user
    reference_period_explicit = reference_period is not None

    # Extract data from results if provided (overrides effects/se arguments)
    if results is not None:
        (
            effects,
            se,
            periods,
            pre_periods,
            post_periods,
            reference_period,
            reference_inferred,
        ) = _extract_plot_data(
            results, periods, pre_periods, post_periods, reference_period
        )
        # If reference was inferred from results, it was NOT explicitly provided
        if reference_inferred:
            reference_period_explicit = False

    # Get periods to plot
    if periods is None:
        periods = sorted(effects)

    # Two-sided normal critical value for the requested significance level
    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    # Normalize effects to reference period ONLY if explicitly specified by user.
    # Auto-inferred reference periods just get hollow marker styling, NO
    # normalization. This prevents unintended normalization when the reference
    # period isn't a true identifying constraint (e.g., CallawaySantAnna with
    # base_period="varying").
    if (
        reference_period_explicit
        and reference_period is not None
        and reference_period in effects
    ):
        ref_effect = effects[reference_period]
        if np.isfinite(ref_effect):
            effects = {p: e - ref_effect for p, e in effects.items()}
            # Set reference SE to NaN (it's now a constraint, not an estimate).
            # This follows fixest convention where the omitted category has
            # no SE/CI.
            se = {
                p: (np.nan if p == reference_period else s)
                for p, s in se.items()
            }

    plot_data = []
    for period in periods:
        effect = effects.get(period, np.nan)
        std_err = se.get(period, np.nan)

        # Skip entries with a missing effect. pd.isna also covers None and
        # pd.NA, on which np.isnan would raise.
        if pd.isna(effect):
            continue

        # A missing SE is allowed: the point is drawn without error bars.
        if pd.isna(std_err):
            std_err = np.nan

        # Compute CI only if SE is finite
        if np.isfinite(std_err):
            ci_lower = effect - critical_value * std_err
            ci_upper = effect + critical_value * std_err
        else:
            ci_lower = np.nan
            ci_upper = np.nan

        plot_data.append({
            'period': period,
            'effect': effect,
            'se': std_err,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'is_reference': period == reference_period,
        })

    if not plot_data:
        raise ValueError("No valid data to plot")

    df = pd.DataFrame(plot_data)

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Periods are drawn at evenly spaced integer positions, in plotting order
    period_to_x = {p: i for i, p in enumerate(df['period'])}
    x_vals = [period_to_x[p] for p in df['period']]

    # Shade pre-treatment region
    if shade_pre and pre_periods is not None:
        pre_x = [period_to_x[p] for p in pre_periods if p in period_to_x]
        if pre_x:
            ax.axvspan(min(pre_x) - 0.5, max(pre_x) + 0.5,
                       color=shade_color, alpha=0.5, zorder=0)

    # Draw horizontal zero line
    if show_zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, zorder=1)

    # Draw vertical reference line
    if show_reference_line and reference_period is not None:
        if reference_period in period_to_x:
            ref_x = period_to_x[reference_period]
            ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1)

    # Plot error bars (only for entries with finite CI)
    has_ci = df['ci_lower'].notna() & df['ci_upper'].notna()
    if has_ci.any():
        df_with_ci = df[has_ci]
        x_with_ci = [period_to_x[p] for p in df_with_ci['period']]
        yerr = [
            df_with_ci['effect'] - df_with_ci['ci_lower'],
            df_with_ci['ci_upper'] - df_with_ci['effect'],
        ]
        ax.errorbar(
            x_with_ci, df_with_ci['effect'], yerr=yerr,
            fmt='none', color=color, capsize=capsize, linewidth=linewidth,
            capthick=linewidth, zorder=2
        )

    # Plot point estimates; the reference period gets a hollow marker
    for x, effect, is_reference in zip(x_vals, df['effect'], df['is_reference']):
        if is_reference:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    markerfacecolor='white', markeredgecolor=color,
                    markeredgewidth=2, zorder=3)
        else:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    color=color, zorder=3)

    # Set labels and title
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Label each integer position with its period
    ax.set_xticks(x_vals)
    ax.set_xticklabels([str(p) for p in df['period']])

    # Add grid
    ax.grid(True, alpha=0.3, axis='y')

    # Tight layout
    fig.tight_layout()

    if show:
        plt.show()

    return ax
323
+
324
+
325
+ def _extract_plot_data(
326
+ results: PlottableResults,
327
+ periods: Optional[List[Any]],
328
+ pre_periods: Optional[List[Any]],
329
+ post_periods: Optional[List[Any]],
330
+ reference_period: Optional[Any],
331
+ ) -> Tuple[Dict, Dict, List, List, List, Any, bool]:
332
+ """
333
+ Extract plotting data from various result types.
334
+
335
+ Returns
336
+ -------
337
+ tuple
338
+ (effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred)
339
+
340
+ reference_inferred is True if reference_period was auto-detected from results
341
+ rather than explicitly provided by the user.
342
+ """
343
+ # Handle DataFrame input
344
+ if isinstance(results, pd.DataFrame):
345
+ if 'period' not in results.columns:
346
+ raise ValueError("DataFrame must have 'period' column")
347
+ if 'effect' not in results.columns:
348
+ raise ValueError("DataFrame must have 'effect' column")
349
+ if 'se' not in results.columns:
350
+ raise ValueError("DataFrame must have 'se' column")
351
+
352
+ effects = dict(zip(results['period'], results['effect']))
353
+ se = dict(zip(results['period'], results['se']))
354
+
355
+ if periods is None:
356
+ periods = list(results['period'])
357
+
358
+ # DataFrame input: reference_period was already set by caller, never inferred here
359
+ return effects, se, periods, pre_periods, post_periods, reference_period, False
360
+
361
+ # Handle MultiPeriodDiDResults
362
+ if hasattr(results, 'period_effects'):
363
+ effects = {}
364
+ se = {}
365
+
366
+ for period, pe in results.period_effects.items():
367
+ effects[period] = pe.effect
368
+ se[period] = pe.se
369
+
370
+ if pre_periods is None and hasattr(results, 'pre_periods'):
371
+ pre_periods = results.pre_periods
372
+
373
+ if post_periods is None and hasattr(results, 'post_periods'):
374
+ post_periods = results.post_periods
375
+
376
+ if periods is None:
377
+ periods = post_periods
378
+
379
+ # MultiPeriodDiDResults: reference_period was already set by caller, never inferred here
380
+ return effects, se, periods, pre_periods, post_periods, reference_period, False
381
+
382
+ # Handle CallawaySantAnnaResults (event study aggregation)
383
+ if hasattr(results, 'event_study_effects') and results.event_study_effects is not None:
384
+ effects = {}
385
+ se = {}
386
+
387
+ for rel_period, effect_data in results.event_study_effects.items():
388
+ effects[rel_period] = effect_data['effect']
389
+ se[rel_period] = effect_data['se']
390
+
391
+ if periods is None:
392
+ periods = sorted(effects.keys())
393
+
394
+ # Track if reference_period was explicitly provided vs auto-inferred
395
+ reference_inferred = False
396
+
397
+ # Reference period is typically -1 for event study
398
+ if reference_period is None:
399
+ reference_inferred = True # We're about to infer it
400
+ # Detect reference period from n_groups=0 marker (normalization constraint)
401
+ # This handles anticipation > 0 where reference is at e = -1 - anticipation
402
+ for period, effect_data in results.event_study_effects.items():
403
+ if effect_data.get('n_groups', 1) == 0:
404
+ reference_period = period
405
+ break
406
+ # Fallback to -1 if no marker found (backward compatibility)
407
+ if reference_period is None:
408
+ reference_period = -1
409
+
410
+ if pre_periods is None:
411
+ pre_periods = [p for p in periods if p < 0]
412
+
413
+ if post_periods is None:
414
+ post_periods = [p for p in periods if p >= 0]
415
+
416
+ return effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred
417
+
418
+ raise TypeError(
419
+ f"Cannot extract plot data from {type(results).__name__}. "
420
+ "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, "
421
+ "SunAbrahamResults, or DataFrame."
422
+ )
423
+
424
+
425
def plot_group_effects(
    results: "CallawaySantAnnaResults",
    *,
    groups: Optional[List[Any]] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Effects by Cohort",
    xlabel: str = "Time Period",
    ylabel: str = "Treatment Effect",
    alpha: float = 0.05,
    show: bool = True,
    ax: Optional[Any] = None,
) -> Any:
    """
    Plot treatment effects by treatment cohort (group).

    One error-bar series is drawn per cohort, with time on the x-axis and
    normal-approximation confidence intervals computed from each estimate's
    standard error.

    Parameters
    ----------
    results : CallawaySantAnnaResults
        Results from CallawaySantAnna estimator. Must expose
        ``group_time_effects``, a mapping from ``(group, time)`` to a dict
        with 'effect' and 'se' entries.
    groups : list, optional
        List of groups (cohorts) to plot. If None, plots all groups.
        Requested groups without any estimate are skipped.
    figsize : tuple, default=(10, 6)
        Figure size.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    alpha : float, default=0.05
        Significance level for confidence intervals. Must lie strictly
        between 0 and 1.
    show : bool, default=True
        Whether to call plt.show().
    ax : matplotlib.axes.Axes, optional
        Axes to plot on.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    TypeError
        If ``results`` has no ``group_time_effects`` attribute.
    ValueError
        If ``alpha`` is not strictly between 0 and 1.
    ImportError
        If matplotlib is not installed.
    """
    # Validate arguments before importing optional plotting dependencies
    if not hasattr(results, 'group_time_effects'):
        raise TypeError("results must be a CallawaySantAnnaResults object")
    if not 0 < alpha < 1:
        raise ValueError(
            f"alpha must be strictly between 0 and 1, got {alpha!r}"
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from err

    from scipy import stats as scipy_stats

    # Bucket the (group, time) estimates by group in a single pass instead of
    # rescanning the whole mapping once per group.
    by_group: Dict[Any, List[Tuple[Any, Any]]] = {}
    for (g, t), data in results.group_time_effects.items():
        by_group.setdefault(g, []).append((t, data))

    # Get groups to plot
    if groups is None:
        groups = sorted(by_group)

    # Create figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Color palette: one entry per requested group, by position
    colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))

    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    for group_color, group in zip(colors, groups):
        group_effects = sorted(by_group.get(group, ()), key=lambda item: item[0])
        if not group_effects:
            continue

        times = [t for t, _ in group_effects]
        effects = [data['effect'] for _, data in group_effects]
        # The interval is symmetric, so one half-width per point suffices
        half_widths = [critical_value * data['se'] for _, data in group_effects]

        ax.errorbar(
            times, effects, yerr=half_widths,
            label=f'Cohort {group}', color=group_color,
            marker='o', capsize=3, linewidth=1.5
        )

    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Only build a legend when there is something labelled to show; calling
    # legend() on an axes without labelled artists just emits a warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='best')
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
533
+
534
+
535
def plot_sensitivity(
    sensitivity_results: "SensitivityResults",
    *,
    show_bounds: bool = True,
    show_ci: bool = True,
    breakdown_line: bool = True,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Honest DiD Sensitivity Analysis",
    xlabel: str = "M (restriction parameter)",
    ylabel: str = "Treatment Effect",
    bounds_color: str = "#2563eb",
    bounds_alpha: float = 0.3,
    ci_color: str = "#2563eb",
    ci_linewidth: float = 1.5,
    breakdown_color: str = "#dc2626",
    original_color: str = "#1f2937",
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Plot sensitivity analysis results from Honest DiD.

    Shows how treatment effect bounds and confidence intervals
    change as the restriction parameter M varies.

    Parameters
    ----------
    sensitivity_results : SensitivityResults
        Results from HonestDiD.sensitivity_analysis(). The attributes read
        are ``M_values``, ``bounds``, ``robust_cis``, ``original_estimate``
        and ``breakdown_M``.
    show_bounds : bool, default=True
        Whether to show the identified set bounds as shaded region.
    show_ci : bool, default=True
        Whether to show robust confidence interval lines.
    breakdown_line : bool, default=True
        Whether to show vertical line at breakdown value.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    bounds_color : str
        Color for identified set shading.
    bounds_alpha : float
        Transparency for identified set shading.
    ci_color : str
        Color for confidence interval lines.
    ci_linewidth : float
        Line width for CI lines.
    breakdown_color : str
        Color for breakdown value line.
    original_color : str
        Color for original estimate line.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show().

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ValueError
        If there are no M values, or if a series that is to be drawn
        (``bounds`` / ``robust_cis``) is not one (lower, upper) pair per
        M value.
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.honest_did import HonestDiD
    >>> from diff_diff.visualization import plot_sensitivity
    >>>
    >>> # Fit event study and run sensitivity analysis
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> honest = HonestDiD(method='relative_magnitude')
    >>> sensitivity = honest.sensitivity_analysis(results)
    >>>
    >>> # Create sensitivity plot
    >>> plot_sensitivity(sensitivity)
    """
    # Validate the data before importing optional plotting dependencies, so
    # malformed results produce a clear message rather than an indexing error.
    m_values = np.atleast_1d(np.asarray(sensitivity_results.M_values))
    if m_values.size == 0:
        raise ValueError("sensitivity_results has no M values to plot")

    def _interval_columns(values: Any, name: str) -> Tuple[Any, Any]:
        """Return (lower, upper) columns after checking one pair per M value."""
        arr = np.asarray(values)
        if arr.ndim != 2 or arr.shape != (m_values.size, 2):
            raise ValueError(
                f"sensitivity_results.{name} must contain one (lower, upper) "
                f"pair per M value; expected shape ({m_values.size}, 2), "
                f"got {arr.shape}"
            )
        return arr[:, 0], arr[:, 1]

    # Only the series that will actually be drawn need to be well-formed
    bounds_cols = (
        _interval_columns(sensitivity_results.bounds, "bounds")
        if show_bounds else None
    )
    ci_cols = (
        _interval_columns(sensitivity_results.robust_cis, "robust_cis")
        if show_ci else None
    )

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from err

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot original estimate
    ax.axhline(
        y=sensitivity_results.original_estimate,
        color=original_color,
        linestyle='-',
        linewidth=1.5,
        label='Original estimate',
        alpha=0.7
    )

    # Plot zero line
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    # Plot identified set bounds
    if bounds_cols is not None:
        ax.fill_between(
            m_values, bounds_cols[0], bounds_cols[1],
            alpha=bounds_alpha,
            color=bounds_color,
            label='Identified set'
        )

    # Plot confidence intervals; only the lower line carries the legend label
    if ci_cols is not None:
        ax.plot(
            m_values, ci_cols[0],
            color=ci_color,
            linewidth=ci_linewidth,
            label='Robust CI'
        )
        ax.plot(
            m_values, ci_cols[1],
            color=ci_color,
            linewidth=ci_linewidth
        )

    # Plot breakdown line
    if breakdown_line and sensitivity_results.breakdown_M is not None:
        ax.axvline(
            x=sensitivity_results.breakdown_M,
            color=breakdown_color,
            linestyle=':',
            linewidth=2,
            label=f'Breakdown (M={sensitivity_results.breakdown_M:.2f})'
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
690
+
691
+
692
def plot_honest_event_study(
    honest_results: "HonestDiDResults",
    *,
    periods: Optional[List[Any]] = None,
    reference_period: Optional[Any] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Event Study with Honest Confidence Intervals",
    xlabel: str = "Period Relative to Treatment",
    ylabel: str = "Treatment Effect",
    original_color: str = "#6b7280",
    honest_color: str = "#2563eb",
    marker: str = "o",
    markersize: int = 8,
    capsize: int = 4,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create event study plot with Honest DiD confidence intervals.

    Shows both the original confidence intervals (assuming parallel trends)
    and the robust confidence intervals that allow for bounded violations.

    Parameters
    ----------
    honest_results : HonestDiDResults
        Results from HonestDiD.fit(). The attributes read are
        ``original_results``, ``alpha``, ``M``, ``event_study_bounds`` and,
        when no per-period bounds are present, ``ci_lb`` / ``ci_ub``.
    periods : list, optional
        Periods to plot. If None, uses all available periods. Periods
        without an estimate in the original results are skipped.
    reference_period : any, optional
        Reference period to show as hollow marker.
    figsize : tuple, default=(10, 6)
        Figure size.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    original_color : str
        Color for original (standard) confidence intervals.
    honest_color : str
        Color for honest (robust) confidence intervals.
    marker : str
        Marker style.
    markersize : int
        Marker size.
    capsize : int
        Error bar cap size.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on.
    show : bool, default=True
        Whether to call plt.show().

    Returns
    -------
    matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    ValueError
        If ``honest_results.original_results`` is None, or if none of the
        requested periods has an estimate.
    TypeError
        If the original results expose neither ``period_effects`` nor
        ``event_study_effects``.
    ImportError
        If matplotlib is not installed.

    Notes
    -----
    When ``event_study_bounds`` is empty or missing, the scalar
    ``ci_lb`` / ``ci_ub`` interval is drawn at every period. Honest
    intervals are drawn at their own position, so an interval that does not
    contain a period's point estimate is still rendered. A period missing
    from ``event_study_bounds`` is drawn without an honest interval.
    """
    # Resolve and validate the data before importing optional plotting
    # dependencies, so unusable inputs fail fast with a clear message.
    original_results = honest_results.original_results
    if original_results is None:
        raise ValueError(
            "HonestDiDResults must have original_results to plot event study"
        )

    period_effects = getattr(original_results, 'period_effects', None)
    event_study = getattr(original_results, 'event_study_effects', None)

    if period_effects is not None:
        # MultiPeriodDiDResults
        effects_dict = {p: pe.effect for p, pe in period_effects.items()}
        se_dict = {p: pe.se for p, pe in period_effects.items()}
        if periods is None:
            periods = list(period_effects.keys())
    elif event_study is not None:
        # CallawaySantAnnaResults
        effects_dict = {t: data['effect'] for t, data in event_study.items()}
        se_dict = {t: data['se'] for t, data in event_study.items()}
        if periods is None:
            periods = sorted(event_study.keys())
    else:
        raise TypeError("Cannot extract event study data from original_results")

    # Skip requested periods that have no estimate instead of failing on a
    # KeyError.
    periods = [p for p in periods if p in effects_dict]
    if not periods:
        raise ValueError("No valid data to plot")

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from err

    from scipy import stats as scipy_stats

    # Create figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Standard (symmetric) CIs from the original standard errors
    alpha = honest_results.alpha
    z = scipy_stats.norm.ppf(1 - alpha / 2)

    x_vals = list(range(len(periods)))
    effects = np.array([effects_dict[p] for p in periods], dtype=float)
    standard_half_width = z * np.array(
        [se_dict[p] for p in periods], dtype=float
    )

    # Get honest bounds if available for each period
    bounds = honest_results.event_study_bounds
    if bounds:
        honest_ci_lower = np.array(
            [bounds[p]['ci_lb'] if p in bounds else np.nan for p in periods],
            dtype=float,
        )
        honest_ci_upper = np.array(
            [bounds[p]['ci_ub'] if p in bounds else np.nan for p in periods],
            dtype=float,
        )
    else:
        # Use scalar bounds applied to all periods
        honest_ci_lower = np.full(len(periods), honest_results.ci_lb, dtype=float)
        honest_ci_upper = np.full(len(periods), honest_results.ci_ub, dtype=float)

    # Zero line
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    # Plot original CIs (thinner, background)
    ax.errorbar(
        x_vals, effects, yerr=standard_half_width,
        fmt='none', color=original_color, capsize=capsize - 1,
        linewidth=1, alpha=0.6, label='Standard CI'
    )

    # Plot honest CIs (thicker, foreground). The bar is anchored at the
    # interval midpoint rather than the point estimate: errorbar rejects
    # negative offsets, which would occur whenever the honest interval does
    # not contain the estimate. The rendered span is lower..upper either way.
    honest_mid = (honest_ci_lower + honest_ci_upper) / 2
    honest_half_width = (honest_ci_upper - honest_ci_lower) / 2
    ax.errorbar(
        x_vals, honest_mid, yerr=honest_half_width,
        fmt='none', color=honest_color, capsize=capsize,
        linewidth=2, label=f'Honest CI (M={honest_results.M:.2f})'
    )

    # Plot point estimates; the reference period gets a hollow marker
    for x, effect, period in zip(x_vals, effects, periods):
        if period == reference_period:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    markerfacecolor='white', markeredgecolor=honest_color,
                    markeredgewidth=2, zorder=3)
        else:
            ax.plot(x, effect, marker=marker, markersize=markersize,
                    color=honest_color, zorder=3)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x_vals)
    ax.set_xticklabels([str(p) for p in periods])
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if show:
        plt.show()

    return ax
881
+
882
+
883
+ def plot_bacon(
884
+ results: "BaconDecompositionResults",
885
+ *,
886
+ plot_type: str = "scatter",
887
+ figsize: Tuple[float, float] = (10, 6),
888
+ title: Optional[str] = None,
889
+ xlabel: str = "2x2 DiD Estimate",
890
+ ylabel: str = "Weight",
891
+ colors: Optional[Dict[str, str]] = None,
892
+ marker: str = "o",
893
+ markersize: int = 80,
894
+ alpha: float = 0.7,
895
+ show_weighted_avg: bool = True,
896
+ show_twfe_line: bool = True,
897
+ ax: Optional[Any] = None,
898
+ show: bool = True,
899
+ ) -> Any:
900
+ """
901
+ Visualize Goodman-Bacon decomposition results.
902
+
903
+ Creates either a scatter plot showing the weight and estimate for each
904
+ 2x2 comparison, or a stacked bar chart showing total weight by comparison
905
+ type.
906
+
907
+ Parameters
908
+ ----------
909
+ results : BaconDecompositionResults
910
+ Results from BaconDecomposition.fit() or bacon_decompose().
911
+ plot_type : str, default="scatter"
912
+ Type of plot to create:
913
+ - "scatter": Scatter plot with estimates on x-axis, weights on y-axis
914
+ - "bar": Stacked bar chart of weights by comparison type
915
+ figsize : tuple, default=(10, 6)
916
+ Figure size (width, height) in inches.
917
+ title : str, optional
918
+ Plot title. If None, uses a default based on plot_type.
919
+ xlabel : str, default="2x2 DiD Estimate"
920
+ X-axis label (scatter plot only).
921
+ ylabel : str, default="Weight"
922
+ Y-axis label.
923
+ colors : dict, optional
924
+ Dictionary mapping comparison types to colors. Keys are:
925
+ "treated_vs_never", "earlier_vs_later", "later_vs_earlier".
926
+ If None, uses default colors.
927
+ marker : str, default="o"
928
+ Marker style for scatter plot.
929
+ markersize : int, default=80
930
+ Marker size for scatter plot.
931
+ alpha : float, default=0.7
932
+ Transparency for markers/bars.
933
+ show_weighted_avg : bool, default=True
934
+ Whether to show weighted average lines for each comparison type
935
+ (scatter plot only).
936
+ show_twfe_line : bool, default=True
937
+ Whether to show a vertical line at the TWFE estimate (scatter plot only).
938
+ ax : matplotlib.axes.Axes, optional
939
+ Axes to plot on. If None, creates new figure.
940
+ show : bool, default=True
941
+ Whether to call plt.show() at the end.
942
+
943
+ Returns
944
+ -------
945
+ matplotlib.axes.Axes
946
+ The axes object containing the plot.
947
+
948
+ Examples
949
+ --------
950
+ Scatter plot (default):
951
+
952
+ >>> from diff_diff import bacon_decompose, plot_bacon
953
+ >>> results = bacon_decompose(data, outcome='y', unit='id',
954
+ ... time='t', first_treat='first_treat')
955
+ >>> plot_bacon(results)
956
+
957
+ Bar chart of weights by type:
958
+
959
+ >>> plot_bacon(results, plot_type='bar')
960
+
961
+ Customized scatter plot:
962
+
963
+ >>> plot_bacon(results,
964
+ ... colors={'treated_vs_never': 'green',
965
+ ... 'earlier_vs_later': 'blue',
966
+ ... 'later_vs_earlier': 'red'},
967
+ ... title='My Bacon Decomposition')
968
+
969
+ Notes
970
+ -----
971
+ The scatter plot is particularly useful for understanding:
972
+
973
+ 1. **Distribution of estimates**: Are 2x2 estimates clustered or spread?
974
+ Wide spread suggests heterogeneous treatment effects.
975
+
976
+ 2. **Weight concentration**: Do a few comparisons dominate the TWFE?
977
+ Points with high weights have more influence.
978
+
979
+ 3. **Forbidden comparison problem**: Red points (later_vs_earlier) show
980
+ comparisons using already-treated units as controls. If these have
981
+ different estimates than clean comparisons, TWFE may be biased.
982
+
983
+ The bar chart provides a quick summary of how much weight falls on
984
+ each comparison type, which is useful for assessing the severity
985
+ of potential TWFE bias.
986
+
987
+ See Also
988
+ --------
989
+ bacon_decompose : Perform the decomposition
990
+ BaconDecomposition : Class-based interface
991
+ """
992
+ try:
993
+ import matplotlib.pyplot as plt
994
+ except ImportError:
995
+ raise ImportError(
996
+ "matplotlib is required for plotting. "
997
+ "Install it with: pip install matplotlib"
998
+ )
999
+
1000
+ # Default colors
1001
+ if colors is None:
1002
+ colors = {
1003
+ "treated_vs_never": "#22c55e", # Green - clean comparison
1004
+ "earlier_vs_later": "#3b82f6", # Blue - valid comparison
1005
+ "later_vs_earlier": "#ef4444", # Red - forbidden comparison
1006
+ }
1007
+
1008
+ # Default titles
1009
+ if title is None:
1010
+ if plot_type == "scatter":
1011
+ title = "Goodman-Bacon Decomposition"
1012
+ else:
1013
+ title = "TWFE Weight by Comparison Type"
1014
+
1015
+ # Create figure if needed
1016
+ if ax is None:
1017
+ fig, ax = plt.subplots(figsize=figsize)
1018
+ else:
1019
+ fig = ax.get_figure()
1020
+
1021
+ if plot_type == "scatter":
1022
+ _plot_bacon_scatter(
1023
+ ax, results, colors, marker, markersize, alpha,
1024
+ show_weighted_avg, show_twfe_line, xlabel, ylabel, title
1025
+ )
1026
+ elif plot_type == "bar":
1027
+ _plot_bacon_bar(ax, results, colors, alpha, ylabel, title)
1028
+ else:
1029
+ raise ValueError(f"Unknown plot_type: {plot_type}. Use 'scatter' or 'bar'.")
1030
+
1031
+ fig.tight_layout()
1032
+
1033
+ if show:
1034
+ plt.show()
1035
+
1036
+ return ax
1037
+
1038
+
1039
+ def _plot_bacon_scatter(
1040
+ ax: Any,
1041
+ results: "BaconDecompositionResults",
1042
+ colors: Dict[str, str],
1043
+ marker: str,
1044
+ markersize: int,
1045
+ alpha: float,
1046
+ show_weighted_avg: bool,
1047
+ show_twfe_line: bool,
1048
+ xlabel: str,
1049
+ ylabel: str,
1050
+ title: str,
1051
+ ) -> None:
1052
+ """Create scatter plot of Bacon decomposition."""
1053
+ # Separate comparisons by type
1054
+ by_type: Dict[str, List[Tuple[float, float]]] = {
1055
+ "treated_vs_never": [],
1056
+ "earlier_vs_later": [],
1057
+ "later_vs_earlier": [],
1058
+ }
1059
+
1060
+ for comp in results.comparisons:
1061
+ by_type[comp.comparison_type].append((comp.estimate, comp.weight))
1062
+
1063
+ # Plot each type
1064
+ labels = {
1065
+ "treated_vs_never": "Treated vs Never-treated",
1066
+ "earlier_vs_later": "Earlier vs Later treated",
1067
+ "later_vs_earlier": "Later vs Earlier (forbidden)",
1068
+ }
1069
+
1070
+ for ctype, points in by_type.items():
1071
+ if not points:
1072
+ continue
1073
+ estimates = [p[0] for p in points]
1074
+ weights = [p[1] for p in points]
1075
+ ax.scatter(
1076
+ estimates, weights,
1077
+ c=colors[ctype],
1078
+ label=labels[ctype],
1079
+ marker=marker,
1080
+ s=markersize,
1081
+ alpha=alpha,
1082
+ edgecolors='white',
1083
+ linewidths=0.5,
1084
+ )
1085
+
1086
+ # Show weighted average lines
1087
+ if show_weighted_avg:
1088
+ effect_by_type = results.effect_by_type()
1089
+ for ctype, avg_effect in effect_by_type.items():
1090
+ if avg_effect is not None and by_type[ctype]:
1091
+ ax.axvline(
1092
+ x=avg_effect,
1093
+ color=colors[ctype],
1094
+ linestyle='--',
1095
+ alpha=0.5,
1096
+ linewidth=1.5,
1097
+ )
1098
+
1099
+ # Show TWFE estimate line
1100
+ if show_twfe_line:
1101
+ ax.axvline(
1102
+ x=results.twfe_estimate,
1103
+ color='black',
1104
+ linestyle='-',
1105
+ linewidth=2,
1106
+ label=f'TWFE = {results.twfe_estimate:.4f}',
1107
+ )
1108
+
1109
+ ax.set_xlabel(xlabel)
1110
+ ax.set_ylabel(ylabel)
1111
+ ax.set_title(title)
1112
+ ax.legend(loc='best')
1113
+ ax.grid(True, alpha=0.3)
1114
+
1115
+ # Add zero line
1116
+ ax.axvline(x=0, color='gray', linestyle=':', alpha=0.5)
1117
+
1118
+
1119
+ def _plot_bacon_bar(
1120
+ ax: Any,
1121
+ results: "BaconDecompositionResults",
1122
+ colors: Dict[str, str],
1123
+ alpha: float,
1124
+ ylabel: str,
1125
+ title: str,
1126
+ ) -> None:
1127
+ """Create stacked bar chart of weights by comparison type."""
1128
+ # Get weights
1129
+ weights = results.weight_by_type()
1130
+
1131
+ # Labels and colors
1132
+ type_order = ["treated_vs_never", "earlier_vs_later", "later_vs_earlier"]
1133
+ labels = {
1134
+ "treated_vs_never": "Treated vs Never-treated",
1135
+ "earlier_vs_later": "Earlier vs Later",
1136
+ "later_vs_earlier": "Later vs Earlier\n(forbidden)",
1137
+ }
1138
+
1139
+ # Create bar data
1140
+ bar_labels = [labels[t] for t in type_order]
1141
+ bar_weights = [weights[t] for t in type_order]
1142
+ bar_colors = [colors[t] for t in type_order]
1143
+
1144
+ # Create bars
1145
+ bars = ax.bar(
1146
+ bar_labels,
1147
+ bar_weights,
1148
+ color=bar_colors,
1149
+ alpha=alpha,
1150
+ edgecolor='white',
1151
+ linewidth=1,
1152
+ )
1153
+
1154
+ # Add percentage labels on bars
1155
+ for bar, weight in zip(bars, bar_weights):
1156
+ if weight > 0.01: # Only label if > 1%
1157
+ height = bar.get_height()
1158
+ ax.annotate(
1159
+ f'{weight:.1%}',
1160
+ xy=(bar.get_x() + bar.get_width() / 2, height),
1161
+ xytext=(0, 3),
1162
+ textcoords="offset points",
1163
+ ha='center',
1164
+ va='bottom',
1165
+ fontsize=10,
1166
+ fontweight='bold',
1167
+ )
1168
+
1169
+ # Add weighted average effect annotations
1170
+ effects = results.effect_by_type()
1171
+ for bar, ctype in zip(bars, type_order):
1172
+ effect = effects[ctype]
1173
+ if effect is not None and weights[ctype] > 0.01:
1174
+ ax.annotate(
1175
+ f'β = {effect:.3f}',
1176
+ xy=(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2),
1177
+ ha='center',
1178
+ va='center',
1179
+ fontsize=9,
1180
+ color='white',
1181
+ fontweight='bold',
1182
+ )
1183
+
1184
+ ax.set_ylabel(ylabel)
1185
+ ax.set_title(title)
1186
+ ax.set_ylim(0, 1.1)
1187
+
1188
+ # Add horizontal line at total weight = 1
1189
+ ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
1190
+
1191
+ # Add TWFE estimate as text
1192
+ ax.text(
1193
+ 0.98, 0.98,
1194
+ f'TWFE = {results.twfe_estimate:.4f}',
1195
+ transform=ax.transAxes,
1196
+ ha='right',
1197
+ va='top',
1198
+ fontsize=10,
1199
+ bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
1200
+ )
1201
+
1202
+
1203
def plot_power_curve(
    results: Optional[Union["PowerResults", "SimulationPowerResults", pd.DataFrame]] = None,
    *,
    effect_sizes: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mde: Optional[float] = None,
    target_power: float = 0.80,
    plot_type: str = "effect",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Power",
    color: str = "#2563eb",
    mde_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mde_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Create a power curve visualization.

    Shows how statistical power changes with effect size or sample size,
    helping researchers understand the trade-offs in study design.

    Parameters
    ----------
    results : PowerResults, SimulationPowerResults, or DataFrame, optional
        Results object from PowerAnalysis or simulate_power(), or a DataFrame
        with columns 'effect_size' and 'power' (or 'sample_size' and 'power').
        If None, must provide effect_sizes and powers directly.
        A DataFrame overrides ``plot_type`` according to which of the two
        columns it contains. An object with ``effect_sizes`` and ``powers``
        attributes is read directly; if it also has ``true_effect`` and no
        ``mde`` was passed, that value is marked on the plot.
    effect_sizes : list of float, optional
        Effect sizes (x-axis values). Required if results is None.
        Any array-like (list, tuple, NumPy array, pandas Series) works.
    powers : list of float, optional
        Power values (y-axis values). Required if results is None.
    mde : float, optional
        Minimum detectable effect to mark on the plot.
    target_power : float, default=0.80
        Target power level to show as horizontal line.
    plot_type : str, default="effect"
        Type of power curve: "effect" (power vs effect size) or
        "sample" (power vs sample size). Only affects the default title
        and x-axis label.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, optional
        Plot title. If None, uses a sensible default.
    xlabel : str, optional
        X-axis label. If None, uses a sensible default.
    ylabel : str, default="Power"
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mde_color : str, default="#dc2626"
        Color for the MDE vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mde_line : bool, default=True
        Whether to show vertical line at MDE.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If neither ``results`` nor both ``effect_sizes`` and ``powers`` are
        given, if a DataFrame has neither an 'effect_size' nor a
        'sample_size' column, or if ``results`` only exposes ``mde``
        (a PowerResults object) and therefore carries no curve data.
    TypeError
        If curve data cannot be extracted from ``results``.

    Examples
    --------
    From PowerAnalysis results:

    >>> from diff_diff import PowerAnalysis, plot_power_curve
    >>> pa = PowerAnalysis(power=0.80)
    >>> curve_df = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
    >>> mde_result = pa.mde(n_treated=50, n_control=50, sigma=5.0)
    >>> plot_power_curve(curve_df, mde=mde_result.mde)

    From simulation results:

    >>> from diff_diff import simulate_power, DifferenceInDifferences
    >>> results = simulate_power(
    ...     DifferenceInDifferences(),
    ...     effect_sizes=[1, 2, 3, 5, 7, 10],
    ...     n_simulations=200
    ... )
    >>> plot_power_curve(results)

    Manual data:

    >>> plot_power_curve(
    ...     effect_sizes=[1, 2, 3, 4, 5],
    ...     powers=[0.2, 0.5, 0.75, 0.90, 0.97],
    ...     mde=2.5,
    ...     target_power=0.80
    ... )
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.ticker import FuncFormatter
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            # DataFrame input: the x column present decides the plot type.
            if "effect_size" in results.columns:
                effect_sizes = results["effect_size"].tolist()
                plot_type = "effect"
            elif "sample_size" in results.columns:
                effect_sizes = results["sample_size"].tolist()
                plot_type = "sample"
            else:
                raise ValueError(
                    "DataFrame must have 'effect_size' or 'sample_size' column"
                )
            powers = results["power"].tolist()
        elif hasattr(results, "effect_sizes") and hasattr(results, "powers"):
            # SimulationPowerResults
            effect_sizes = results.effect_sizes
            powers = results.powers
            if mde is None and hasattr(results, "true_effect"):
                # Mark true effect on plot
                mde = results.true_effect
        elif hasattr(results, "mde"):
            # PowerResults holds a single MDE, not a curve.
            raise ValueError(
                "PowerResults should be used to get mde value, not as direct input. "
                "Use PowerAnalysis.power_curve() to generate curve data."
            )
        else:
            raise TypeError(
                f"Cannot extract power curve data from {type(results).__name__}"
            )
    elif effect_sizes is None or powers is None:
        raise ValueError(
            "Must provide either 'results' or both 'effect_sizes' and 'powers'"
        )

    # Default titles and labels
    if title is None:
        title = "Power Curve" if plot_type == "effect" else "Power vs Sample Size"
    if xlabel is None:
        xlabel = "Effect Size" if plot_type == "effect" else "Sample Size"

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve
    ax.plot(
        effect_sizes, powers,
        color=color,
        linewidth=linewidth,
        label="Power"
    )

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})"
        )

    # Add MDE line
    if show_mde_line and mde is not None:
        ax.axvline(
            x=mde,
            color=mde_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDE = {mde:.3f}"
        )

        # Mark the point where the MDE line meets the curve. Work on arrays
        # so lists, tuples, NumPy arrays and pandas Series all behave alike
        # (list-only ``.index`` would fail on the latter two).
        x_vals = np.asarray(effect_sizes, dtype=float)
        y_vals = np.asarray(powers, dtype=float)
        power_at_mde = None

        exact_hits = np.flatnonzero(x_vals == mde)
        if exact_hits.size:
            # MDE sits on a grid point: use the first matching power as-is.
            power_at_mde = y_vals[exact_hits[0]]
        elif x_vals.size and x_vals.min() <= mde <= x_vals.max():
            # np.interp needs increasing x, so sort the pairs first.
            order = np.argsort(x_vals, kind="stable")
            power_at_mde = np.interp(mde, x_vals[order], y_vals[order])
        # Otherwise the MDE is outside the plotted range (or there is no
        # data): no marker is drawn.

        if power_at_mde is not None:
            ax.scatter(
                [mde], [power_at_mde],
                color=mde_color,
                s=50,
                zorder=5
            )

    # Configure axes
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Y-axis from 0 to 1
    ax.set_ylim(0, 1.05)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")

    fig.tight_layout()

    if show:
        plt.show()

    return ax
1447
+
1448
+
1449
def plot_pretrends_power(
    results: Optional[Union["PreTrendsPowerResults", "PreTrendsPowerCurve", pd.DataFrame]] = None,
    *,
    M_values: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mdv: Optional[float] = None,
    target_power: Optional[float] = 0.80,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Pre-Trends Test Power Curve",
    xlabel: str = "Violation Magnitude (M)",
    ylabel: str = "Power",
    color: str = "#2563eb",
    mdv_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mdv_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
) -> Any:
    """
    Plot pre-trends test power curve.

    Visualizes how the power to detect parallel trends violations changes
    with the violation magnitude (M). This helps understand what violations
    your pre-trends test is capable of detecting.

    Parameters
    ----------
    results : PreTrendsPowerResults, PreTrendsPowerCurve, or DataFrame, optional
        Results from PreTrendsPower.fit() or power_curve(), or a DataFrame
        with columns 'M' and 'power'. If None, must provide M_values and powers.
        A single PreTrendsPowerResults carries no curve, so only the target
        power and MDV reference lines are drawn for it.
    M_values : list of float, optional
        Violation magnitudes (x-axis). Required if results is None.
    powers : list of float, optional
        Power values (y-axis). Required if results is None.
    mdv : float, optional
        Minimum detectable violation to mark on the plot. If None and
        ``results`` is a results/curve object, its ``mdv`` is used.
    target_power : float or None, default=0.80
        Target power level to show as horizontal line. If None and
        ``results`` is a power curve object, the curve's ``target_power`` is
        used; if it is still None the line is omitted.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mdv_color : str, default="#dc2626"
        Color for the MDV vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mdv_line : bool, default=True
        Whether to show vertical line at MDV.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If a DataFrame lacks the 'M' or 'power' column, or if neither
        ``results`` nor both ``M_values`` and ``powers`` are given.
    TypeError
        If curve data cannot be extracted from ``results``.

    Examples
    --------
    From PreTrendsPower results:

    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import PreTrendsPower
    >>> from diff_diff.visualization import plot_pretrends_power
    >>>
    >>> mp_did = MultiPeriodDiD()
    >>> event_results = mp_did.fit(data, outcome='y', treatment='treated',
    ...                            time='period', post_periods=[4, 5, 6, 7])
    >>>
    >>> pt = PreTrendsPower()
    >>> curve = pt.power_curve(event_results)
    >>> plot_pretrends_power(curve)

    From manual data:

    >>> plot_pretrends_power(
    ...     M_values=[0, 0.5, 1, 1.5, 2],
    ...     powers=[0.05, 0.3, 0.6, 0.85, 0.95],
    ...     mdv=1.2,
    ...     target_power=0.80
    ... )

    Notes
    -----
    The power curve shows how likely you are to reject the null hypothesis
    of parallel trends given a true violation of magnitude M. Key points:

    1. **At M=0**: Power equals alpha (size of the test).
    2. **At MDV**: Power equals target power (default 80%).
    3. **Beyond MDV**: Power increases toward 100%.

    A steep power curve indicates a sensitive pre-trends test. A flat curve
    indicates the test has limited ability to detect violations, suggesting
    you should use HonestDiD sensitivity analysis for robust inference.

    See Also
    --------
    PreTrendsPower : Main class for pre-trends power analysis
    plot_sensitivity : Plot HonestDiD sensitivity analysis
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.ticker import FuncFormatter
    except ImportError as exc:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib"
        ) from exc

    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            if "M" not in results.columns or "power" not in results.columns:
                raise ValueError("DataFrame must have 'M' and 'power' columns")
            M_values = results["M"].tolist()
            powers = results["power"].tolist()
        elif hasattr(results, "M_values") and hasattr(results, "powers"):
            # PreTrendsPowerCurve; np.asarray accepts arrays and plain
            # sequences alike before converting to lists.
            M_values = np.asarray(results.M_values).tolist()
            powers = np.asarray(results.powers).tolist()
            if mdv is None:
                mdv = results.mdv
            if target_power is None:
                target_power = results.target_power
        elif hasattr(results, "mdv") and hasattr(results, "power"):
            # Single PreTrendsPowerResults: there is no curve to draw, so
            # only the reference lines (target power, MDV) are shown.
            if mdv is None:
                mdv = results.mdv
            powers = None
        else:
            raise TypeError(
                f"Cannot extract power curve data from {type(results).__name__}"
            )
    elif M_values is None or powers is None:
        raise ValueError(
            "Must provide either 'results' or both 'M_values' and 'powers'"
        )

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve if we have powers
    if powers is not None:
        ax.plot(
            M_values, powers,
            color=color,
            linewidth=linewidth,
            label="Power"
        )

    # Add target power line (skipped when no target level is known)
    if show_target_line and target_power is not None:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})"
        )

    # Add MDV line
    if show_mdv_line and mdv is not None and np.isfinite(mdv):
        ax.axvline(
            x=mdv,
            color=mdv_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDV = {mdv:.3f}"
        )

        # Mark intersection point if we have powers
        if powers is not None:
            m_arr = np.asarray(M_values, dtype=float)
            power_arr = np.asarray(powers, dtype=float)
            # Guard against empty data before taking min/max.
            if m_arr.size and m_arr.min() <= mdv <= m_arr.max():
                # np.interp needs increasing x, so sort the pairs first.
                order = np.argsort(m_arr, kind="stable")
                power_at_mdv = np.interp(mdv, m_arr[order], power_arr[order])
                ax.scatter(
                    [mdv], [power_at_mdv],
                    color=mdv_color,
                    s=50,
                    zorder=5
                )

    # Configure axes
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Y-axis from 0 to 1
    ax.set_ylim(0, 1.05)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")

    fig.tight_layout()

    if show:
        plt.show()

    return ax