diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,661 @@
1
+ """Power analysis visualization functions."""
2
+
3
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ if TYPE_CHECKING:
9
+ from diff_diff.power import PowerResults, SimulationPowerResults
10
+ from diff_diff.pretrends import PreTrendsPowerCurve, PreTrendsPowerResults
11
+
12
+
13
def plot_power_curve(
    results: Optional[Union["PowerResults", "SimulationPowerResults", pd.DataFrame]] = None,
    *,
    effect_sizes: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mde: Optional[float] = None,
    target_power: float = 0.80,
    plot_type: str = "effect",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Power",
    color: str = "#2563eb",
    mde_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mde_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Create a power curve visualization.

    Shows how statistical power changes with effect size or sample size,
    helping researchers understand the trade-offs in study design.

    Parameters
    ----------
    results : PowerResults, SimulationPowerResults, or DataFrame, optional
        Results object from PowerAnalysis or simulate_power(), or a DataFrame
        with columns 'effect_size' and 'power' (or 'sample_size' and 'power').
        If None, must provide effect_sizes and powers directly. Objects that
        expose ``effect_sizes`` and ``powers`` attributes are accepted; if they
        also expose ``true_effect`` and ``mde`` is None, that value is used as
        ``mde``.
    effect_sizes : list of float, optional
        Effect sizes (x-axis values). Required if results is None. Any
        sequence or array is accepted and converted to a list.
    powers : list of float, optional
        Power values (y-axis values). Required if results is None. Must have
        the same length as ``effect_sizes``.
    mde : float, optional
        Minimum detectable effect to mark on the plot.
    target_power : float, default=0.80
        Target power level to show as horizontal line.
    plot_type : str, default="effect"
        Type of power curve: "effect" (power vs effect size) or
        "sample" (power vs sample size). Overridden when a DataFrame is
        passed (chosen from which x column is present).
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str, optional
        Plot title. If None, uses a sensible default.
    xlabel : str, optional
        X-axis label. If None, uses a sensible default.
    ylabel : str, default="Power"
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mde_color : str, default="#dc2626"
        Color for the MDE vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mde_line : bool, default=True
        Whether to show vertical line at MDE.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``. Any value other
        than ``"plotly"`` renders with matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If a DataFrame lacks an 'effect_size'/'sample_size' column or a
        'power' column; if ``results`` looks like a ``PowerResults`` (has
        ``mde`` but no curve data); if neither ``results`` nor both
        ``effect_sizes`` and ``powers`` are given; or if the x values and
        powers have different lengths.
    TypeError
        If no power curve data can be extracted from ``results``.

    Examples
    --------
    From PowerAnalysis results:

    >>> from diff_diff import PowerAnalysis, plot_power_curve
    >>> pa = PowerAnalysis(power=0.80)
    >>> curve_df = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
    >>> mde_result = pa.mde(n_treated=50, n_control=50, sigma=5.0)
    >>> plot_power_curve(curve_df, mde=mde_result.mde)

    From simulation results:

    >>> from diff_diff import simulate_power, DifferenceInDifferences
    >>> results = simulate_power(
    ...     DifferenceInDifferences(),
    ...     effect_sizes=[1, 2, 3, 5, 7, 10],
    ...     n_simulations=200
    ... )
    >>> plot_power_curve(results)

    Manual data:

    >>> plot_power_curve(
    ...     effect_sizes=[1, 2, 3, 4, 5],
    ...     powers=[0.2, 0.5, 0.75, 0.90, 0.97],
    ...     mde=2.5,
    ...     target_power=0.80
    ... )
    """
    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            if "effect_size" in results.columns:
                x_column, plot_type = "effect_size", "effect"
            elif "sample_size" in results.columns:
                x_column, plot_type = "sample_size", "sample"
            else:
                raise ValueError("DataFrame must have 'effect_size' or 'sample_size' column")
            # Check explicitly so a missing column yields a clear ValueError
            # instead of a bare KeyError from pandas.
            if "power" not in results.columns:
                raise ValueError("DataFrame must have a 'power' column")
            effect_sizes = results[x_column].tolist()
            powers = results["power"].tolist()
        elif hasattr(results, "effect_sizes") and hasattr(results, "powers"):
            # SimulationPowerResults: attributes may be lists or arrays, so
            # normalise to plain lists for the renderers.
            effect_sizes = np.ravel(results.effect_sizes).tolist()
            powers = np.ravel(results.powers).tolist()
            if mde is None and hasattr(results, "true_effect"):
                mde = results.true_effect
        elif hasattr(results, "mde"):
            raise ValueError(
                "PowerResults should be used to get mde value, not as direct input. "
                "Use PowerAnalysis.power_curve() to generate curve data."
            )
        else:
            raise TypeError(f"Cannot extract power curve data from {type(results).__name__}")
    elif effect_sizes is None or powers is None:
        raise ValueError("Must provide either 'results' or both 'effect_sizes' and 'powers'")
    else:
        # Manual input: accept any sequence/array (e.g. numpy arrays, Series).
        effect_sizes = np.ravel(effect_sizes).tolist()
        powers = np.ravel(powers).tolist()

    # Mismatched lengths would error in matplotlib but plot silently
    # (and misleadingly) in plotly; reject them consistently here.
    if len(effect_sizes) != len(powers):
        raise ValueError(
            "effect_sizes and powers must have the same length "
            f"(got {len(effect_sizes)} and {len(powers)})"
        )

    # Default titles and labels
    if title is None:
        title = "Power Curve" if plot_type == "effect" else "Power vs Sample Size"
    if xlabel is None:
        xlabel = "Effect Size" if plot_type == "effect" else "Sample Size"

    if backend == "plotly":
        return _render_power_curve_plotly(
            effect_sizes=effect_sizes,
            powers=powers,
            mde=mde,
            target_power=target_power,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            color=color,
            mde_color=mde_color,
            target_color=target_color,
            linewidth=linewidth,
            show_mde_line=show_mde_line,
            show_target_line=show_target_line,
            show_grid=show_grid,
            show=show,
        )

    return _render_power_curve_mpl(
        effect_sizes=effect_sizes,
        powers=powers,
        mde=mde,
        target_power=target_power,
        figsize=figsize,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        color=color,
        mde_color=mde_color,
        target_color=target_color,
        linewidth=linewidth,
        show_mde_line=show_mde_line,
        show_target_line=show_target_line,
        show_grid=show_grid,
        ax=ax,
        show=show,
    )
194
+
195
+
196
def _render_power_curve_mpl(
    *,
    effect_sizes,
    powers,
    mde,
    target_power,
    figsize,
    title,
    xlabel,
    ylabel,
    color,
    mde_color,
    target_color,
    linewidth,
    show_mde_line,
    show_target_line,
    show_grid,
    ax,
    show,
):
    """
    Render power curve with matplotlib.

    Draws ``powers`` against ``effect_sizes``. Optionally adds a dashed
    horizontal line at ``target_power`` and a dotted vertical line at ``mde``.
    When the MDE line is drawn and ``mde`` lies within the plotted x-range,
    the curve's power at ``mde`` is marked with a dot.

    All arguments are keyword-only and mirror those of ``plot_power_curve``.
    If ``ax`` is None a new figure of size ``figsize`` is created; otherwise
    the curve is drawn on the given axes.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the curve was drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve
    ax.plot(effect_sizes, powers, color=color, linewidth=linewidth, label="Power")

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})",
        )

    # Add MDE line
    if show_mde_line and mde is not None:
        ax.axvline(
            x=mde,
            color=mde_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDE = {mde:.3f}",
        )

        # Locate the curve's power at the MDE. Working on arrays makes list,
        # tuple and ndarray inputs behave the same (ndarrays have no
        # ``.index``), and sorting first matters because ``np.interp``
        # assumes increasing ``xp`` and silently misbehaves otherwise.
        x_arr = np.asarray(effect_sizes, dtype=float)
        y_arr = np.asarray(powers, dtype=float)
        power_at_mde = None
        if x_arr.size > 0 and x_arr.shape == y_arr.shape:
            exact = np.flatnonzero(x_arr == mde)
            if exact.size > 0:
                # Exact grid hit: use the first matching point's power.
                power_at_mde = y_arr[exact[0]]
            elif x_arr.min() <= mde <= x_arr.max():
                order = np.argsort(x_arr)
                power_at_mde = np.interp(mde, x_arr[order], y_arr[order])

        # Mark intersection point
        if power_at_mde is not None:
            ax.scatter([mde], [power_at_mde], color=mde_color, s=50, zorder=5)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")
    fig.tight_layout()

    if show:
        plt.show()

    return ax
282
+
283
+
284
def _render_power_curve_plotly(
    *,
    effect_sizes,
    powers,
    mde,
    target_power,
    title,
    xlabel,
    ylabel,
    color,
    mde_color,
    target_color,
    linewidth,
    show_mde_line,
    show_target_line,
    show_grid,
    show,
):
    """
    Draw the power curve as an interactive plotly figure.

    The curve is a single line trace named "Power". An optional dashed
    horizontal reference sits at ``target_power`` and an optional dotted
    vertical reference sits at ``mde``; both carry text annotations. The
    y-axis is fixed to [0, 1.05] and formatted as percentages. This renderer
    does not place an intersection marker.

    Returns
    -------
    plotly.graph_objects.Figure
        The newly built figure (shown first when ``show`` is true).
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    figure = go.Figure()
    curve_line = {"color": color, "width": linewidth}
    figure.add_trace(
        go.Scatter(x=effect_sizes, y=powers, mode="lines", line=curve_line, name="Power")
    )

    # Horizontal reference at the desired power level.
    if show_target_line:
        target_label = f"Target ({target_power:.0%})"
        figure.add_hline(
            y=target_power,
            line_dash="dash",
            line_color=target_color,
            opacity=0.7,
            annotation_text=target_label,
        )

    # Vertical reference at the minimum detectable effect, when one is known.
    wants_mde_marker = show_mde_line and mde is not None
    if wants_mde_marker:
        mde_label = f"MDE = {mde:.3f}"
        figure.add_vline(
            x=mde,
            line_dash="dot",
            line_color=mde_color,
            opacity=0.7,
            annotation_text=mde_label,
        )

    _plotly_default_layout(figure, title=title, xlabel=xlabel, ylabel=ylabel)
    figure.update_xaxes(showgrid=show_grid)
    figure.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)

    if show:
        figure.show()

    return figure
345
+
346
+
347
def plot_pretrends_power(
    results: Optional[Union["PreTrendsPowerResults", "PreTrendsPowerCurve", pd.DataFrame]] = None,
    *,
    M_values: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mdv: Optional[float] = None,
    target_power: float = 0.80,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Pre-Trends Test Power Curve",
    xlabel: str = "Violation Magnitude (M)",
    ylabel: str = "Power",
    color: str = "#2563eb",
    mdv_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mdv_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot pre-trends test power curve.

    Visualizes how the power to detect parallel trends violations changes
    with the violation magnitude (M). This helps understand what violations
    your pre-trends test is capable of detecting.

    Parameters
    ----------
    results : PreTrendsPowerResults, PreTrendsPowerCurve, or DataFrame, optional
        Results from PreTrendsPower.fit() or power_curve(), or a DataFrame
        with columns 'M' and 'power'. If None, must provide M_values and powers.
        For a curve object (``M_values``/``powers`` attributes, list- or
        array-valued), ``mdv`` and ``target_power`` default to the object's
        values when passed as None. For a single result object (``mdv`` and
        ``power`` attributes), no curve is drawn, only the reference lines.
    M_values : list of float, optional
        Violation magnitudes (x-axis). Required if results is None.
    powers : list of float, optional
        Power values (y-axis). Required if results is None.
    mdv : float, optional
        Minimum detectable violation to mark on the plot. Non-finite values
        are not drawn.
    target_power : float, default=0.80
        Target power level to show as horizontal line.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mdv_color : str, default="#dc2626"
        Color for the MDV vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mdv_line : bool, default=True
        Whether to show vertical line at MDV.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``. Any value other
        than ``"plotly"`` renders with matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If a DataFrame lacks 'M' or 'power', or if neither ``results`` nor
        both ``M_values`` and ``powers`` are given.
    TypeError
        If no power curve data can be extracted from ``results``.

    Examples
    --------
    From PreTrendsPower results:

    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import PreTrendsPower
    >>> from diff_diff.visualization import plot_pretrends_power
    >>>
    >>> mp_did = MultiPeriodDiD()
    >>> event_results = mp_did.fit(data, outcome='y', treatment='treated',
    ...                            time='period', post_periods=[4, 5, 6, 7])
    >>>
    >>> pt = PreTrendsPower()
    >>> curve = pt.power_curve(event_results)
    >>> plot_pretrends_power(curve)

    Notes
    -----
    The power curve shows how likely you are to reject the null hypothesis
    of parallel trends given a true violation of magnitude M.

    See Also
    --------
    PreTrendsPower : Main class for pre-trends power analysis
    plot_sensitivity : Plot HonestDiD sensitivity analysis
    """
    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            if "M" not in results.columns or "power" not in results.columns:
                raise ValueError("DataFrame must have 'M' and 'power' columns")
            M_values = results["M"].tolist()
            powers = results["power"].tolist()
        elif hasattr(results, "M_values") and hasattr(results, "powers"):
            # PreTrendsPowerCurve: np.asarray accepts both list- and
            # array-valued attributes before converting to plain lists.
            M_values = np.asarray(results.M_values).tolist()
            powers = np.asarray(results.powers).tolist()
            if mdv is None:
                mdv = results.mdv
            if target_power is None:
                target_power = results.target_power
        elif hasattr(results, "mdv") and hasattr(results, "power"):
            # Single PreTrendsPowerResults: no curve, only reference lines.
            if mdv is None:
                mdv = results.mdv
            # Guard against a missing mdv; np.isfinite(None) raises TypeError.
            if mdv is not None and np.isfinite(mdv):
                M_values = [0, mdv * 0.5, mdv, mdv * 1.5, mdv * 2]
            else:
                M_values = [0, 1, 2, 3, 4]
            powers = None
        else:
            raise TypeError(f"Cannot extract power curve data from {type(results).__name__}")
    elif M_values is None or powers is None:
        raise ValueError("Must provide either 'results' or both 'M_values' and 'powers'")

    if backend == "plotly":
        return _render_pretrends_power_plotly(
            M_values=M_values,
            powers=powers,
            mdv=mdv,
            target_power=target_power,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            color=color,
            mdv_color=mdv_color,
            target_color=target_color,
            linewidth=linewidth,
            show_mdv_line=show_mdv_line,
            show_target_line=show_target_line,
            show_grid=show_grid,
            show=show,
        )

    return _render_pretrends_power_mpl(
        M_values=M_values,
        powers=powers,
        mdv=mdv,
        target_power=target_power,
        figsize=figsize,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        color=color,
        mdv_color=mdv_color,
        target_color=target_color,
        linewidth=linewidth,
        show_mdv_line=show_mdv_line,
        show_target_line=show_target_line,
        show_grid=show_grid,
        ax=ax,
        show=show,
    )
516
+
517
+
518
def _render_pretrends_power_mpl(
    *,
    M_values,
    powers,
    mdv,
    target_power,
    figsize,
    title,
    xlabel,
    ylabel,
    color,
    mdv_color,
    target_color,
    linewidth,
    show_mdv_line,
    show_target_line,
    show_grid,
    ax,
    show,
):
    """
    Render pre-trends power curve with matplotlib.

    Draws ``powers`` against ``M_values`` when ``powers`` is not None. Adds
    an optional dashed horizontal line at ``target_power`` and an optional
    dotted vertical line at a finite ``mdv``. When a curve is drawn and a
    finite ``mdv`` lies within the range of ``M_values``, the curve's power
    at ``mdv`` is marked with a dot.

    All arguments are keyword-only and mirror those of
    ``plot_pretrends_power``. If ``ax`` is None a new figure of size
    ``figsize`` is created.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the plot was drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve if we have powers
    if powers is not None:
        ax.plot(M_values, powers, color=color, linewidth=linewidth, label="Power")

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})",
        )

    mdv_is_usable = mdv is not None and np.isfinite(mdv)

    # Add MDV line
    if show_mdv_line and mdv_is_usable:
        ax.axvline(
            x=mdv,
            color=mdv_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDV = {mdv:.3f}",
        )

    # Mark intersection point. Requires a usable mdv: comparing against
    # None would raise TypeError when callers supply M_values/powers only.
    if powers is not None and mdv_is_usable:
        M_arr = np.asarray(M_values, dtype=float)
        power_arr = np.asarray(powers, dtype=float)
        if (
            M_arr.size > 0
            and M_arr.shape == power_arr.shape
            and M_arr.min() <= mdv <= M_arr.max()
        ):
            # np.interp assumes increasing xp; sort so unordered grids work.
            order = np.argsort(M_arr)
            power_at_mdv = np.interp(mdv, M_arr[order], power_arr[order])
            ax.scatter([mdv], [power_at_mdv], color=mdv_color, s=50, zorder=5)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    if show_grid:
        ax.grid(True, alpha=0.3)

    # With no curve and no reference lines there is nothing to label; skip
    # the legend rather than drawing an empty one (and triggering a warning).
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="lower right")
    fig.tight_layout()

    if show:
        plt.show()

    return ax
598
+
599
+
600
def _render_pretrends_power_plotly(
    *,
    M_values,
    powers,
    mdv,
    target_power,
    title,
    xlabel,
    ylabel,
    color,
    mdv_color,
    target_color,
    linewidth,
    show_mdv_line,
    show_target_line,
    show_grid,
    show,
):
    """
    Draw the pre-trends power curve as an interactive plotly figure.

    A "Power" line trace is added only when ``powers`` is provided. An
    optional dashed horizontal reference sits at ``target_power``, and an
    optional dotted vertical reference sits at ``mdv`` (skipped when ``mdv``
    is None or non-finite). The y-axis is fixed to [0, 1.05] and formatted
    as percentages.

    Returns
    -------
    plotly.graph_objects.Figure
        The newly built figure (shown first when ``show`` is true).
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    figure = go.Figure()

    has_curve = powers is not None
    if has_curve:
        curve_line = {"color": color, "width": linewidth}
        figure.add_trace(
            go.Scatter(x=M_values, y=powers, mode="lines", line=curve_line, name="Power")
        )

    # Horizontal reference at the desired power level.
    if show_target_line:
        target_label = f"Target ({target_power:.0%})"
        figure.add_hline(
            y=target_power,
            line_dash="dash",
            line_color=target_color,
            opacity=0.7,
            annotation_text=target_label,
        )

    # Vertical reference at the minimum detectable violation, if finite.
    draw_mdv = show_mdv_line and mdv is not None and np.isfinite(mdv)
    if draw_mdv:
        mdv_label = f"MDV = {mdv:.3f}"
        figure.add_vline(
            x=mdv,
            line_dash="dot",
            line_color=mdv_color,
            opacity=0.7,
            annotation_text=mdv_label,
        )

    _plotly_default_layout(figure, title=title, xlabel=xlabel, ylabel=ylabel)
    figure.update_xaxes(showgrid=show_grid)
    figure.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)

    if show:
        figure.show()

    return figure