diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
|
@@ -0,0 +1,661 @@
|
|
|
1
|
+
"""Power analysis visualization functions."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from diff_diff.power import PowerResults, SimulationPowerResults
|
|
10
|
+
from diff_diff.pretrends import PreTrendsPowerCurve, PreTrendsPowerResults
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def plot_power_curve(
    results: Optional[Union["PowerResults", "SimulationPowerResults", pd.DataFrame]] = None,
    *,
    effect_sizes: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mde: Optional[float] = None,
    target_power: float = 0.80,
    plot_type: str = "effect",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Power",
    color: str = "#2563eb",
    mde_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mde_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Draw a statistical power curve.

    The curve plots power against either effect size or sample size, so
    study designers can see how much power a given design buys them.

    Parameters
    ----------
    results : PowerResults, SimulationPowerResults, or DataFrame, optional
        Source of the curve. Accepted forms:

        - a DataFrame with a ``'power'`` column plus either an
          ``'effect_size'`` column (power vs effect size) or a
          ``'sample_size'`` column (power vs sample size);
        - any object exposing ``effect_sizes`` and ``powers`` attributes,
          such as the output of ``simulate_power()``.

        Passing an object that only exposes ``mde`` (a ``PowerResults``)
        raises ``ValueError``. When omitted, both ``effect_sizes`` and
        ``powers`` must be supplied.
    effect_sizes : list of float, optional
        X-axis values, used when ``results`` is None.
    powers : list of float, optional
        Y-axis values, used when ``results`` is None.
    mde : float, optional
        Minimum detectable effect to highlight. For simulation results that
        carry a ``true_effect`` attribute, that value is used if ``mde`` is
        not given.
    target_power : float, default=0.80
        Power level drawn as a horizontal reference line.
    plot_type : str, default="effect"
        ``"effect"`` (power vs effect size) or ``"sample"`` (power vs sample
        size). Overridden by the column found when a DataFrame is passed.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches; matplotlib only.
    title : str, optional
        Plot title. Defaults to "Power Curve" or "Power vs Sample Size".
    xlabel : str, optional
        X-axis label. Defaults to "Effect Size" or "Sample Size".
    ylabel : str, default="Power"
        Y-axis label.
    color : str, default="#2563eb"
        Colour of the power curve.
    mde_color : str, default="#dc2626"
        Colour of the MDE marker line.
    target_color : str, default="#22c55e"
        Colour of the target-power line.
    linewidth : float, default=2.0
        Width of the power curve.
    show_mde_line : bool, default=True
        Whether to draw a vertical line at the MDE.
    show_target_line : bool, default=True
        Whether to draw a horizontal line at the target power.
    show_grid : bool, default=True
        Whether to draw grid lines.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; a new figure is created when None.
        Ignored by the plotly backend.
    show : bool, default=True
        Whether to display the plot before returning.
    backend : str, default="matplotlib"
        ``"plotly"`` selects plotly; any other value renders with matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes (matplotlib) or figure (plotly) that was drawn on.

    Raises
    ------
    ValueError
        If neither ``results`` nor both ``effect_sizes`` and ``powers`` are
        given, if a DataFrame lacks an x-axis column, or if a
        ``PowerResults`` object is passed directly.
    TypeError
        If ``results`` is of an unsupported type.

    Examples
    --------
    From PowerAnalysis results:

    >>> from diff_diff import PowerAnalysis, plot_power_curve
    >>> pa = PowerAnalysis(power=0.80)
    >>> curve_df = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
    >>> mde_result = pa.mde(n_treated=50, n_control=50, sigma=5.0)
    >>> plot_power_curve(curve_df, mde=mde_result.mde)

    From simulation results:

    >>> from diff_diff import simulate_power, DifferenceInDifferences
    >>> results = simulate_power(
    ...     DifferenceInDifferences(),
    ...     effect_sizes=[1, 2, 3, 5, 7, 10],
    ...     n_simulations=200
    ... )
    >>> plot_power_curve(results)

    Manual data:

    >>> plot_power_curve(
    ...     effect_sizes=[1, 2, 3, 4, 5],
    ...     powers=[0.2, 0.5, 0.75, 0.90, 0.97],
    ...     mde=2.5,
    ...     target_power=0.80
    ... )
    """
    if results is None:
        if effect_sizes is None or powers is None:
            raise ValueError("Must provide either 'results' or both 'effect_sizes' and 'powers'")
    elif isinstance(results, pd.DataFrame):
        # The x-axis column decides which flavour of curve is drawn.
        if "effect_size" in results.columns:
            x_column, plot_type = "effect_size", "effect"
        elif "sample_size" in results.columns:
            x_column, plot_type = "sample_size", "sample"
        else:
            raise ValueError("DataFrame must have 'effect_size' or 'sample_size' column")
        effect_sizes = results[x_column].tolist()
        powers = results["power"].tolist()
    elif hasattr(results, "effect_sizes") and hasattr(results, "powers"):
        # Simulation-style results expose the grid and powers directly.
        effect_sizes, powers = results.effect_sizes, results.powers
        if mde is None and hasattr(results, "true_effect"):
            mde = results.true_effect
    elif hasattr(results, "mde"):
        raise ValueError(
            "PowerResults should be used to get mde value, not as direct input. "
            "Use PowerAnalysis.power_curve() to generate curve data."
        )
    else:
        raise TypeError(f"Cannot extract power curve data from {type(results).__name__}")

    versus_effect = plot_type == "effect"
    if title is None:
        title = "Power Curve" if versus_effect else "Power vs Sample Size"
    if xlabel is None:
        xlabel = "Effect Size" if versus_effect else "Sample Size"

    # Options understood by both renderers.
    shared = dict(
        effect_sizes=effect_sizes,
        powers=powers,
        mde=mde,
        target_power=target_power,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        color=color,
        mde_color=mde_color,
        target_color=target_color,
        linewidth=linewidth,
        show_mde_line=show_mde_line,
        show_target_line=show_target_line,
        show_grid=show_grid,
        show=show,
    )

    if backend == "plotly":
        return _render_power_curve_plotly(**shared)
    return _render_power_curve_mpl(figsize=figsize, ax=ax, **shared)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _render_power_curve_mpl(
    *,
    effect_sizes,
    powers,
    mde,
    target_power,
    figsize,
    title,
    xlabel,
    ylabel,
    color,
    mde_color,
    target_color,
    linewidth,
    show_mde_line,
    show_target_line,
    show_grid,
    ax,
    show,
):
    """Render power curve with matplotlib.

    Keyword arguments mirror the options of ``plot_power_curve``.
    ``effect_sizes`` and ``powers`` may be lists or NumPy arrays.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on (created if ``ax`` is None).
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve
    ax.plot(effect_sizes, powers, color=color, linewidth=linewidth, label="Power")

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})",
        )

    # Add MDE line. A non-finite MDE (e.g. inf for an undetectable effect)
    # cannot be drawn and would break autoscaling, so it is skipped, the same
    # way the pre-trends renderers treat a non-finite MDV.
    if show_mde_line and mde is not None and np.isfinite(mde):
        ax.axvline(
            x=mde,
            color=mde_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDE = {mde:.3f}",
        )

        # Mark intersection point. Work on arrays so that NumPy inputs work
        # too (ndarray has no .index()).
        effect_arr = np.asarray(effect_sizes, dtype=float)
        power_arr = np.asarray(powers, dtype=float)
        exact_hits = np.flatnonzero(effect_arr == mde)
        if exact_hits.size:
            # The MDE is a grid point: use its power directly (first match).
            power_at_mde = power_arr[exact_hits[0]]
        elif effect_arr.size and effect_arr.min() <= mde <= effect_arr.max():
            # np.interp requires increasing x values, so sort the grid first.
            order = np.argsort(effect_arr, kind="stable")
            power_at_mde = np.interp(mde, effect_arr[order], power_arr[order])
        else:
            power_at_mde = None

        if power_at_mde is not None:
            ax.scatter([mde], [power_at_mde], color=mde_color, s=50, zorder=5)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")
    fig.tight_layout()

    if show:
        plt.show()

    return ax
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _render_power_curve_plotly(
    *,
    effect_sizes,
    powers,
    mde,
    target_power,
    title,
    xlabel,
    ylabel,
    color,
    mde_color,
    target_color,
    linewidth,
    show_mde_line,
    show_target_line,
    show_grid,
    show,
):
    """Render power curve with plotly.

    Keyword arguments mirror the options of ``plot_power_curve``; there is no
    ``figsize``/``ax`` because plotly builds its own figure.

    Returns
    -------
    plotly.graph_objects.Figure
        The newly created figure.
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=effect_sizes,
            y=powers,
            mode="lines",
            line=dict(color=color, width=linewidth),
            name="Power",
        )
    )

    if show_target_line:
        fig.add_hline(
            y=target_power,
            line_dash="dash",
            line_color=target_color,
            opacity=0.7,
            annotation_text=f"Target ({target_power:.0%})",
        )

    # A non-finite MDE cannot be placed on the axis. Skip it, consistent with
    # the pre-trends plotly renderer's handling of a non-finite MDV.
    if show_mde_line and mde is not None and np.isfinite(mde):
        fig.add_vline(
            x=mde,
            line_dash="dot",
            line_color=mde_color,
            opacity=0.7,
            annotation_text=f"MDE = {mde:.3f}",
        )

    _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
    fig.update_xaxes(showgrid=show_grid)
    fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)

    if show:
        fig.show()

    return fig
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def plot_pretrends_power(
    results: Optional[Union["PreTrendsPowerResults", "PreTrendsPowerCurve", pd.DataFrame]] = None,
    *,
    M_values: Optional[List[float]] = None,
    powers: Optional[List[float]] = None,
    mdv: Optional[float] = None,
    target_power: float = 0.80,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Pre-Trends Test Power Curve",
    xlabel: str = "Violation Magnitude (M)",
    ylabel: str = "Power",
    color: str = "#2563eb",
    mdv_color: str = "#dc2626",
    target_color: str = "#22c55e",
    linewidth: float = 2.0,
    show_mdv_line: bool = True,
    show_target_line: bool = True,
    show_grid: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot pre-trends test power curve.

    Visualizes how the power to detect parallel trends violations changes
    with the violation magnitude (M). This helps understand what violations
    your pre-trends test is capable of detecting.

    Parameters
    ----------
    results : PreTrendsPowerResults, PreTrendsPowerCurve, or DataFrame, optional
        Results from PreTrendsPower.fit() or power_curve(), or a DataFrame
        with columns 'M' and 'power'. If None, must provide M_values and powers.
        A single PreTrendsPowerResults has no curve, so only the reference
        lines are drawn for it.
    M_values : list of float, optional
        Violation magnitudes (x-axis). Required if results is None.
    powers : list of float, optional
        Power values (y-axis). Required if results is None.
    mdv : float, optional
        Minimum detectable violation to mark on the plot. Taken from
        ``results`` when not given.
    target_power : float, default=0.80
        Target power level to show as horizontal line. If ``None`` is passed
        explicitly, the curve's own ``target_power`` is used for a
        PreTrendsPowerCurve; otherwise it falls back to 0.80.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    color : str, default="#2563eb"
        Color for the power curve line.
    mdv_color : str, default="#dc2626"
        Color for the MDV vertical line.
    target_color : str, default="#22c55e"
        Color for the target power horizontal line.
    linewidth : float, default=2.0
        Line width for the power curve.
    show_mdv_line : bool, default=True
        Whether to show vertical line at MDV.
    show_target_line : bool, default=True
        Whether to show horizontal line at target power.
    show_grid : bool, default=True
        Whether to show grid lines.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    show : bool, default=True
        Whether to call plt.show() at the end.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If a DataFrame lacks 'M' or 'power' columns, or if neither results
        nor both M_values and powers are given.
    TypeError
        If ``results`` is of an unsupported type.

    Examples
    --------
    From PreTrendsPower results:

    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import PreTrendsPower
    >>> from diff_diff.visualization import plot_pretrends_power
    >>>
    >>> mp_did = MultiPeriodDiD()
    >>> event_results = mp_did.fit(data, outcome='y', treatment='treated',
    ...                            time='period', post_periods=[4, 5, 6, 7])
    >>>
    >>> pt = PreTrendsPower()
    >>> curve = pt.power_curve(event_results)
    >>> plot_pretrends_power(curve)

    Notes
    -----
    The power curve shows how likely you are to reject the null hypothesis
    of parallel trends given a true violation of magnitude M.

    See Also
    --------
    PreTrendsPower : Main class for pre-trends power analysis
    plot_sensitivity : Plot HonestDiD sensitivity analysis
    """
    # Extract data from results if provided
    if results is not None:
        if isinstance(results, pd.DataFrame):
            if "M" not in results.columns or "power" not in results.columns:
                raise ValueError("DataFrame must have 'M' and 'power' columns")
            M_values = results["M"].tolist()
            powers = results["power"].tolist()
        elif hasattr(results, "M_values") and hasattr(results, "powers"):
            # PreTrendsPowerCurve. np.asarray accepts both arrays and plain
            # lists (a list has no .tolist()).
            M_values = np.asarray(results.M_values).tolist()
            powers = np.asarray(results.powers).tolist()
            if mdv is None:
                mdv = results.mdv
            if target_power is None:
                target_power = results.target_power
        elif hasattr(results, "mdv") and hasattr(results, "power"):
            # Single PreTrendsPowerResults: no curve, so build an x-grid
            # around the MDV for the reference lines.
            if mdv is None:
                mdv = results.mdv
            if mdv is not None and np.isfinite(mdv):
                M_values = [0, mdv * 0.5, mdv, mdv * 1.5, mdv * 2]
            else:
                M_values = [0, 1, 2, 3, 4]
            powers = None
        else:
            raise TypeError(f"Cannot extract power curve data from {type(results).__name__}")
    elif M_values is None or powers is None:
        raise ValueError("Must provide either 'results' or both 'M_values' and 'powers'")

    # An explicit None target with no curve value to inherit would otherwise
    # crash the renderers when they format the label.
    if target_power is None:
        target_power = 0.80

    if backend == "plotly":
        return _render_pretrends_power_plotly(
            M_values=M_values,
            powers=powers,
            mdv=mdv,
            target_power=target_power,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            color=color,
            mdv_color=mdv_color,
            target_color=target_color,
            linewidth=linewidth,
            show_mdv_line=show_mdv_line,
            show_target_line=show_target_line,
            show_grid=show_grid,
            show=show,
        )

    return _render_pretrends_power_mpl(
        M_values=M_values,
        powers=powers,
        mdv=mdv,
        target_power=target_power,
        figsize=figsize,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        color=color,
        mdv_color=mdv_color,
        target_color=target_color,
        linewidth=linewidth,
        show_mdv_line=show_mdv_line,
        show_target_line=show_target_line,
        show_grid=show_grid,
        ax=ax,
        show=show,
    )
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def _render_pretrends_power_mpl(
    *,
    M_values,
    powers,
    mdv,
    target_power,
    figsize,
    title,
    xlabel,
    ylabel,
    color,
    mdv_color,
    target_color,
    linewidth,
    show_mdv_line,
    show_target_line,
    show_grid,
    ax,
    show,
):
    """Render pre-trends power curve with matplotlib.

    Keyword arguments mirror the options of ``plot_pretrends_power``.
    ``powers`` may be None, in which case only the reference lines are drawn.
    ``mdv`` may be None or non-finite, in which case no MDV line or
    intersection marker is drawn.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on (created if ``ax`` is None).
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Plot power curve if we have powers
    if powers is not None:
        ax.plot(M_values, powers, color=color, linewidth=linewidth, label="Power")

    # Add target power line
    if show_target_line:
        ax.axhline(
            y=target_power,
            color=target_color,
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"Target power ({target_power:.0%})",
        )

    # Only a real, finite MDV can be drawn. A missing MDV (None, the default
    # for DataFrame/manual input) previously reached `M_arr.min() <= mdv` and
    # raised TypeError.
    has_mdv = mdv is not None and bool(np.isfinite(mdv))

    # Add MDV line
    if show_mdv_line and has_mdv:
        ax.axvline(
            x=mdv,
            color=mdv_color,
            linestyle=":",
            linewidth=1.5,
            alpha=0.7,
            label=f"MDV = {mdv:.3f}",
        )

    # Mark intersection point if we have powers and a usable MDV
    if powers is not None and has_mdv:
        M_arr = np.asarray(M_values, dtype=float)
        power_arr = np.asarray(powers, dtype=float)
        if M_arr.size and M_arr.min() <= mdv <= M_arr.max():
            # np.interp requires increasing x values, so sort the grid first.
            order = np.argsort(M_arr, kind="stable")
            power_at_mdv = np.interp(mdv, M_arr[order], power_arr[order])
            ax.scatter([mdv], [power_at_mdv], color=mdv_color, s=50, zorder=5)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    if show_grid:
        ax.grid(True, alpha=0.3)

    ax.legend(loc="lower right")
    fig.tight_layout()

    if show:
        plt.show()

    return ax
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def _render_pretrends_power_plotly(
    *,
    M_values,
    powers,
    mdv,
    target_power,
    title,
    xlabel,
    ylabel,
    color,
    mdv_color,
    target_color,
    linewidth,
    show_mdv_line,
    show_target_line,
    show_grid,
    show,
):
    """Build the pre-trends power plot as a plotly figure.

    Keyword arguments mirror the options of ``plot_pretrends_power``. The
    curve trace is added only when ``powers`` is not None. The MDV line is
    added only when requested and ``mdv`` is a finite number.

    Returns
    -------
    plotly.graph_objects.Figure
        The newly created figure.
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()
    figure = go.Figure()

    if powers is not None:
        curve_trace = go.Scatter(
            x=M_values,
            y=powers,
            mode="lines",
            line={"color": color, "width": linewidth},
            name="Power",
        )
        figure.add_trace(curve_trace)

    if show_target_line:
        figure.add_hline(
            y=target_power,
            line_dash="dash",
            line_color=target_color,
            opacity=0.7,
            annotation_text=f"Target ({target_power:.0%})",
        )

    # The finiteness check is nested so it only runs for a requested,
    # non-None MDV.
    if show_mdv_line and mdv is not None:
        if np.isfinite(mdv):
            figure.add_vline(
                x=mdv,
                line_dash="dot",
                line_color=mdv_color,
                opacity=0.7,
                annotation_text=f"MDV = {mdv:.3f}",
            )

    _plotly_default_layout(figure, title=title, xlabel=xlabel, ylabel=ylabel)
    figure.update_xaxes(showgrid=show_grid)
    figure.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)

    if show:
        figure.show()

    return figure
|