diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization functions for difference-in-differences analysis.
|
|
3
|
+
|
|
4
|
+
Provides event study plots, diagnostic visualizations, and other plotting
|
|
5
|
+
utilities with support for matplotlib (default) and plotly backends.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Event study plots are imported further below, after the diagnostic plots
|
|
9
|
+
# Continuous DiD plots
|
|
10
|
+
from diff_diff.visualization._continuous import (
|
|
11
|
+
plot_dose_response,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Diagnostic plots
|
|
15
|
+
from diff_diff.visualization._diagnostic import (
|
|
16
|
+
plot_bacon,
|
|
17
|
+
plot_sensitivity,
|
|
18
|
+
)
|
|
19
|
+
from diff_diff.visualization._event_study import (
|
|
20
|
+
PlottableResults,
|
|
21
|
+
_extract_plot_data,
|
|
22
|
+
plot_event_study,
|
|
23
|
+
plot_honest_event_study,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Power analysis plots
|
|
27
|
+
from diff_diff.visualization._power import (
|
|
28
|
+
plot_power_curve,
|
|
29
|
+
plot_pretrends_power,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Staggered DiD plots
|
|
33
|
+
from diff_diff.visualization._staggered import (
|
|
34
|
+
plot_group_effects,
|
|
35
|
+
plot_group_time_heatmap,
|
|
36
|
+
plot_staircase,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Synthetic control plots
|
|
40
|
+
from diff_diff.visualization._synthetic import (
|
|
41
|
+
plot_synth_weights,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Names exported by ``from diff_diff.visualization import *``.
__all__ = [
    # Pre-existing public plotting functions
    "plot_event_study", "plot_honest_event_study", "plot_group_effects",
    "plot_sensitivity", "plot_bacon", "plot_power_curve", "plot_pretrends_power",
    # Newer public plotting functions
    "plot_synth_weights", "plot_staircase", "plot_dose_response",
    "plot_group_time_heatmap",
    # Re-exported for backward compatibility (used in tests)
    "_extract_plot_data", "PlottableResults",
]
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
"""Shared utilities for the visualization subpackage."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _require_matplotlib():
|
|
7
|
+
"""Lazy import matplotlib with clear error message.
|
|
8
|
+
|
|
9
|
+
Returns
|
|
10
|
+
-------
|
|
11
|
+
module
|
|
12
|
+
The ``matplotlib.pyplot`` module.
|
|
13
|
+
"""
|
|
14
|
+
try:
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
|
|
17
|
+
return plt
|
|
18
|
+
except ImportError:
|
|
19
|
+
raise ImportError(
|
|
20
|
+
"matplotlib is required for plotting. " "Install it with: pip install matplotlib"
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _require_plotly():
|
|
25
|
+
"""Lazy import plotly with clear error message.
|
|
26
|
+
|
|
27
|
+
Returns
|
|
28
|
+
-------
|
|
29
|
+
module
|
|
30
|
+
The ``plotly.graph_objects`` module.
|
|
31
|
+
"""
|
|
32
|
+
try:
|
|
33
|
+
import plotly.graph_objects as go
|
|
34
|
+
|
|
35
|
+
return go
|
|
36
|
+
except ImportError:
|
|
37
|
+
raise ImportError(
|
|
38
|
+
"plotly is required for interactive plots. "
|
|
39
|
+
"Install with: pip install diff-diff[plotly]"
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _plotly_default_layout(fig, *, title=None, xlabel=None, ylabel=None, show_legend=True):
|
|
44
|
+
"""Apply standard plotly layout settings.
|
|
45
|
+
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
48
|
+
fig : plotly.graph_objects.Figure
|
|
49
|
+
The figure to configure.
|
|
50
|
+
title : str, optional
|
|
51
|
+
Plot title.
|
|
52
|
+
xlabel : str, optional
|
|
53
|
+
X-axis label.
|
|
54
|
+
ylabel : str, optional
|
|
55
|
+
Y-axis label.
|
|
56
|
+
show_legend : bool, default=True
|
|
57
|
+
Whether to show the legend.
|
|
58
|
+
"""
|
|
59
|
+
fig.update_layout(
|
|
60
|
+
title=title,
|
|
61
|
+
xaxis_title=xlabel,
|
|
62
|
+
yaxis_title=ylabel,
|
|
63
|
+
showlegend=show_legend,
|
|
64
|
+
template="plotly_white",
|
|
65
|
+
font=dict(size=12),
|
|
66
|
+
margin=dict(l=60, r=30, t=50, b=50),
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Complete CSS named color table: all 148 standard CSS color keywords,
# including both the "gray" and "grey" spellings of every gray.
# Keys are lowercase names; values are (r, g, b) tuples of 0-255 integers.
# Having the full table here means named colors can be converted for plotly
# without importing matplotlib.
_CSS_COLORS = {
    "aliceblue": (240, 248, 255),
    "antiquewhite": (250, 235, 215),
    "aqua": (0, 255, 255),
    "aquamarine": (127, 255, 212),
    "azure": (240, 255, 255),
    "beige": (245, 245, 220),
    "bisque": (255, 228, 196),
    "black": (0, 0, 0),
    "blanchedalmond": (255, 235, 205),
    "blue": (0, 0, 255),
    "blueviolet": (138, 43, 226),
    "brown": (165, 42, 42),
    "burlywood": (222, 184, 135),
    "cadetblue": (95, 158, 160),
    "chartreuse": (127, 255, 0),
    "chocolate": (210, 105, 30),
    "coral": (255, 127, 80),
    "cornflowerblue": (100, 149, 237),
    "cornsilk": (255, 248, 220),
    "crimson": (220, 20, 60),
    "cyan": (0, 255, 255),
    "darkblue": (0, 0, 139),
    "darkcyan": (0, 139, 139),
    "darkgoldenrod": (184, 134, 11),
    "darkgray": (169, 169, 169),
    "darkgreen": (0, 100, 0),
    "darkgrey": (169, 169, 169),
    "darkkhaki": (189, 183, 107),
    "darkmagenta": (139, 0, 139),
    "darkolivegreen": (85, 107, 47),
    "darkorange": (255, 140, 0),
    "darkorchid": (153, 50, 204),
    "darkred": (139, 0, 0),
    "darksalmon": (233, 150, 122),
    "darkseagreen": (143, 188, 143),
    "darkslateblue": (72, 61, 139),
    "darkslategray": (47, 79, 79),
    "darkslategrey": (47, 79, 79),
    "darkturquoise": (0, 206, 209),
    "darkviolet": (148, 0, 211),
    "deeppink": (255, 20, 147),
    "deepskyblue": (0, 191, 255),
    "dimgray": (105, 105, 105),
    "dimgrey": (105, 105, 105),
    "dodgerblue": (30, 144, 255),
    "firebrick": (178, 34, 34),
    "floralwhite": (255, 250, 240),
    "forestgreen": (34, 139, 34),
    "fuchsia": (255, 0, 255),
    "gainsboro": (220, 220, 220),
    "ghostwhite": (248, 248, 255),
    "gold": (255, 215, 0),
    "goldenrod": (218, 165, 32),
    "gray": (128, 128, 128),
    "green": (0, 128, 0),
    "greenyellow": (173, 255, 47),
    "grey": (128, 128, 128),
    "honeydew": (240, 255, 240),
    "hotpink": (255, 105, 180),
    "indianred": (205, 92, 92),
    "indigo": (75, 0, 130),
    "ivory": (255, 255, 240),
    "khaki": (240, 230, 140),
    "lavender": (230, 230, 250),
    "lavenderblush": (255, 240, 245),
    "lawngreen": (124, 252, 0),
    "lemonchiffon": (255, 250, 205),
    "lightblue": (173, 216, 230),
    "lightcoral": (240, 128, 128),
    "lightcyan": (224, 255, 255),
    "lightgoldenrodyellow": (250, 250, 210),
    "lightgray": (211, 211, 211),
    "lightgreen": (144, 238, 144),
    "lightgrey": (211, 211, 211),
    "lightpink": (255, 182, 193),
    "lightsalmon": (255, 160, 122),
    "lightseagreen": (32, 178, 170),
    "lightskyblue": (135, 206, 250),
    "lightslategray": (119, 136, 153),
    "lightslategrey": (119, 136, 153),
    "lightsteelblue": (176, 196, 222),
    "lightyellow": (255, 255, 224),
    "lime": (0, 255, 0),
    "limegreen": (50, 205, 50),
    "linen": (250, 240, 230),
    "magenta": (255, 0, 255),
    "maroon": (128, 0, 0),
    "mediumaquamarine": (102, 205, 170),
    "mediumblue": (0, 0, 205),
    "mediumorchid": (186, 85, 211),
    "mediumpurple": (147, 112, 219),
    "mediumseagreen": (60, 179, 113),
    "mediumslateblue": (123, 104, 238),
    "mediumspringgreen": (0, 250, 154),
    "mediumturquoise": (72, 209, 204),
    "mediumvioletred": (199, 21, 133),
    "midnightblue": (25, 25, 112),
    "mintcream": (245, 255, 250),
    "mistyrose": (255, 228, 225),
    "moccasin": (255, 228, 181),
    "navajowhite": (255, 222, 173),
    "navy": (0, 0, 128),
    "oldlace": (253, 245, 230),
    "olive": (128, 128, 0),
    "olivedrab": (107, 142, 35),
    "orange": (255, 165, 0),
    "orangered": (255, 69, 0),
    "orchid": (218, 112, 214),
    "palegoldenrod": (238, 232, 170),
    "palegreen": (152, 251, 152),
    "paleturquoise": (175, 238, 238),
    "palevioletred": (219, 112, 147),
    "papayawhip": (255, 239, 213),
    "peachpuff": (255, 218, 185),
    "peru": (205, 133, 63),
    "pink": (255, 192, 203),
    "plum": (221, 160, 221),
    "powderblue": (176, 224, 230),
    "purple": (128, 0, 128),
    "rebeccapurple": (102, 51, 153),
    "red": (255, 0, 0),
    "rosybrown": (188, 143, 143),
    "royalblue": (65, 105, 225),
    "saddlebrown": (139, 69, 19),
    "salmon": (250, 128, 114),
    "sandybrown": (244, 164, 96),
    "seagreen": (46, 139, 87),
    "seashell": (255, 245, 238),
    "sienna": (160, 82, 45),
    "silver": (192, 192, 192),
    "skyblue": (135, 206, 235),
    "slateblue": (106, 90, 205),
    "slategray": (112, 128, 144),
    "slategrey": (112, 128, 144),
    "snow": (255, 250, 250),
    "springgreen": (0, 255, 127),
    "steelblue": (70, 130, 180),
    "tan": (210, 180, 140),
    "teal": (0, 128, 128),
    "thistle": (216, 191, 216),
    "tomato": (255, 99, 71),
    "turquoise": (64, 224, 208),
    "violet": (238, 130, 238),
    "wheat": (245, 222, 179),
    "white": (255, 255, 255),
    "whitesmoke": (245, 245, 245),
    "yellow": (255, 255, 0),
    "yellowgreen": (154, 205, 50),
}
|
|
222
|
+
|
|
223
|
+
_RGB_RE = re.compile(r"^rgb\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$")
|
|
224
|
+
_RGBA_RE = re.compile(r"^rgba\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*([0-9.]+)\s*\)$")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _color_to_rgba(color, alpha=1.0):
|
|
228
|
+
"""Convert any color to an ``rgba(r, g, b, a)`` string for plotly.
|
|
229
|
+
|
|
230
|
+
Accepts hex colors (``#rrggbb``, ``#rgb``), all 148 CSS named colors,
|
|
231
|
+
and ``rgb(r,g,b)`` / ``rgba(r,g,b,a)`` strings. Does **not** require
|
|
232
|
+
matplotlib.
|
|
233
|
+
|
|
234
|
+
Parameters
|
|
235
|
+
----------
|
|
236
|
+
color : str
|
|
237
|
+
Color specification.
|
|
238
|
+
alpha : float, default=1.0
|
|
239
|
+
Opacity value between 0 and 1.
|
|
240
|
+
|
|
241
|
+
Returns
|
|
242
|
+
-------
|
|
243
|
+
str
|
|
244
|
+
An ``rgba(r, g, b, a)`` string.
|
|
245
|
+
"""
|
|
246
|
+
if not isinstance(color, str):
|
|
247
|
+
raise ValueError(f"Expected a color string, got {type(color).__name__}")
|
|
248
|
+
|
|
249
|
+
# 1. Hex colors: #rrggbb or #rgb
|
|
250
|
+
stripped = color.lstrip("#")
|
|
251
|
+
if color.startswith("#") and all(c in "0123456789abcdefABCDEF" for c in stripped):
|
|
252
|
+
if len(stripped) == 6:
|
|
253
|
+
r = int(stripped[0:2], 16)
|
|
254
|
+
g = int(stripped[2:4], 16)
|
|
255
|
+
b = int(stripped[4:6], 16)
|
|
256
|
+
return f"rgba({r}, {g}, {b}, {alpha})"
|
|
257
|
+
if len(stripped) == 3:
|
|
258
|
+
r = int(stripped[0] * 2, 16)
|
|
259
|
+
g = int(stripped[1] * 2, 16)
|
|
260
|
+
b = int(stripped[2] * 2, 16)
|
|
261
|
+
return f"rgba({r}, {g}, {b}, {alpha})"
|
|
262
|
+
|
|
263
|
+
# 2. Named CSS colors (complete table — no matplotlib needed)
|
|
264
|
+
if color.lower() in _CSS_COLORS:
|
|
265
|
+
r, g, b = _CSS_COLORS[color.lower()]
|
|
266
|
+
return f"rgba({r}, {g}, {b}, {alpha})"
|
|
267
|
+
|
|
268
|
+
# 3. rgb(r, g, b) — parse and apply alpha
|
|
269
|
+
m = _RGB_RE.match(color.strip())
|
|
270
|
+
if m:
|
|
271
|
+
r, g, b = int(m.group(1)), int(m.group(2)), int(m.group(3))
|
|
272
|
+
return f"rgba({r}, {g}, {b}, {alpha})"
|
|
273
|
+
|
|
274
|
+
# 4. rgba(r, g, b, a) — parse and override alpha
|
|
275
|
+
m = _RGBA_RE.match(color.strip())
|
|
276
|
+
if m:
|
|
277
|
+
r, g, b = int(m.group(1)), int(m.group(2)), int(m.group(3))
|
|
278
|
+
return f"rgba({r}, {g}, {b}, {alpha})"
|
|
279
|
+
|
|
280
|
+
raise ValueError(
|
|
281
|
+
f"Cannot parse color '{color}'. Use hex (#rrggbb), a CSS color name, "
|
|
282
|
+
"or rgb(r,g,b) / rgba(r,g,b,a) format."
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# Matplotlib marker code -> plotly symbol name mapping
|
|
287
|
+
_MPL_TO_PLOTLY_SYMBOL = {
|
|
288
|
+
"o": "circle",
|
|
289
|
+
"s": "square",
|
|
290
|
+
"D": "diamond",
|
|
291
|
+
"d": "diamond",
|
|
292
|
+
"^": "triangle-up",
|
|
293
|
+
"v": "triangle-down",
|
|
294
|
+
"<": "triangle-left",
|
|
295
|
+
">": "triangle-right",
|
|
296
|
+
"p": "pentagon",
|
|
297
|
+
"h": "hexagon",
|
|
298
|
+
"+": "cross",
|
|
299
|
+
"x": "x",
|
|
300
|
+
"*": "star",
|
|
301
|
+
".": "circle",
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _mpl_marker_to_plotly_symbol(marker):
|
|
306
|
+
"""Convert a matplotlib marker code to a plotly symbol name.
|
|
307
|
+
|
|
308
|
+
Parameters
|
|
309
|
+
----------
|
|
310
|
+
marker : str
|
|
311
|
+
Matplotlib marker shorthand (e.g., ``"o"``, ``"s"``, ``"D"``).
|
|
312
|
+
|
|
313
|
+
Returns
|
|
314
|
+
-------
|
|
315
|
+
str
|
|
316
|
+
Plotly symbol name (e.g., ``"circle"``, ``"square"``, ``"diamond"``).
|
|
317
|
+
Returns ``"circle"`` for unrecognized markers.
|
|
318
|
+
"""
|
|
319
|
+
return _MPL_TO_PLOTLY_SYMBOL.get(marker, "circle")
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# Default color constants, stored as 6-digit hex strings.

# Neutral tones
DEFAULT_DARK = "#1f2937"
DEFAULT_GRAY = "#6b7280"
DEFAULT_SHADE = "#f0f0f0"

# Accent colors
DEFAULT_BLUE = "#2563eb"
DEFAULT_RED = "#dc2626"
DEFAULT_GREEN = "#22c55e"
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Continuous DiD visualization functions (dose-response curves)."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from diff_diff.continuous_did_results import ContinuousDiDResults, DoseResponseCurve
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def plot_dose_response(
|
|
12
|
+
results: Optional["ContinuousDiDResults"] = None,
|
|
13
|
+
*,
|
|
14
|
+
curve: Optional["DoseResponseCurve"] = None,
|
|
15
|
+
data: Optional[pd.DataFrame] = None,
|
|
16
|
+
target: str = "att",
|
|
17
|
+
alpha: float = 0.05,
|
|
18
|
+
figsize: Tuple[float, float] = (10, 6),
|
|
19
|
+
title: Optional[str] = None,
|
|
20
|
+
xlabel: str = "Dose",
|
|
21
|
+
ylabel: str = "Treatment Effect",
|
|
22
|
+
color: str = "#2563eb",
|
|
23
|
+
ci_color: Optional[str] = None,
|
|
24
|
+
show_zero_line: bool = True,
|
|
25
|
+
ax: Optional[Any] = None,
|
|
26
|
+
show: bool = True,
|
|
27
|
+
backend: str = "matplotlib",
|
|
28
|
+
) -> Any:
|
|
29
|
+
"""
|
|
30
|
+
Plot dose-response curve from Continuous DiD estimation.
|
|
31
|
+
|
|
32
|
+
Visualizes how the treatment effect varies with the treatment dose
|
|
33
|
+
(intensity), with confidence bands.
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
results : ContinuousDiDResults, optional
|
|
38
|
+
Results from ContinuousDiD estimator. Extracts the dose-response
|
|
39
|
+
curve based on ``target``.
|
|
40
|
+
curve : DoseResponseCurve, optional
|
|
41
|
+
A DoseResponseCurve object directly.
|
|
42
|
+
data : pd.DataFrame, optional
|
|
43
|
+
DataFrame with columns ``dose``, ``effect``, ``se`` (and optionally
|
|
44
|
+
``conf_int_lower``, ``conf_int_upper``).
|
|
45
|
+
target : str, default="att"
|
|
46
|
+
Which dose-response curve: ``"att"`` or ``"acrt"``.
|
|
47
|
+
alpha : float, default=0.05
|
|
48
|
+
Significance level for confidence intervals (used with DataFrame input).
|
|
49
|
+
figsize : tuple, default=(10, 6)
|
|
50
|
+
Figure size (width, height) in inches.
|
|
51
|
+
title : str, optional
|
|
52
|
+
Plot title. Auto-generated if None.
|
|
53
|
+
xlabel : str, default="Dose"
|
|
54
|
+
X-axis label.
|
|
55
|
+
ylabel : str, default="Treatment Effect"
|
|
56
|
+
Y-axis label.
|
|
57
|
+
color : str, default="#2563eb"
|
|
58
|
+
Color for the line.
|
|
59
|
+
ci_color : str, optional
|
|
60
|
+
Color for confidence band. Defaults to ``color`` with transparency.
|
|
61
|
+
show_zero_line : bool, default=True
|
|
62
|
+
Whether to show a horizontal line at y=0.
|
|
63
|
+
ax : matplotlib.axes.Axes, optional
|
|
64
|
+
Axes to plot on. If None, creates new figure.
|
|
65
|
+
show : bool, default=True
|
|
66
|
+
Whether to call plt.show() at the end.
|
|
67
|
+
backend : str, default="matplotlib"
|
|
68
|
+
Plotting backend: ``"matplotlib"`` or ``"plotly"``.
|
|
69
|
+
|
|
70
|
+
Returns
|
|
71
|
+
-------
|
|
72
|
+
matplotlib.axes.Axes or plotly.graph_objects.Figure
|
|
73
|
+
The axes object (matplotlib) or figure (plotly).
|
|
74
|
+
"""
|
|
75
|
+
from scipy import stats as scipy_stats
|
|
76
|
+
|
|
77
|
+
# Extract dose-response data
|
|
78
|
+
if sum(x is not None for x in (results, curve, data)) != 1:
|
|
79
|
+
raise ValueError("Provide exactly one of 'results', 'curve', or 'data'.")
|
|
80
|
+
|
|
81
|
+
if results is not None:
|
|
82
|
+
if target == "att":
|
|
83
|
+
curve = results.dose_response_att
|
|
84
|
+
elif target == "acrt":
|
|
85
|
+
curve = results.dose_response_acrt
|
|
86
|
+
else:
|
|
87
|
+
raise ValueError(f"target must be 'att' or 'acrt', got '{target}'")
|
|
88
|
+
|
|
89
|
+
if curve is not None:
|
|
90
|
+
# Infer target from curve when passed directly (not via results)
|
|
91
|
+
if results is None and hasattr(curve, "target") and curve.target:
|
|
92
|
+
target = curve.target
|
|
93
|
+
dose_grid = curve.dose_grid
|
|
94
|
+
effects = curve.effects
|
|
95
|
+
ci_lower = curve.conf_int_lower
|
|
96
|
+
ci_upper = curve.conf_int_upper
|
|
97
|
+
elif data is not None:
|
|
98
|
+
if "dose" not in data.columns or "effect" not in data.columns:
|
|
99
|
+
raise ValueError("DataFrame must have 'dose' and 'effect' columns")
|
|
100
|
+
dose_grid = data["dose"].values
|
|
101
|
+
effects = data["effect"].values
|
|
102
|
+
if "conf_int_lower" in data.columns and "conf_int_upper" in data.columns:
|
|
103
|
+
ci_lower = data["conf_int_lower"].values
|
|
104
|
+
ci_upper = data["conf_int_upper"].values
|
|
105
|
+
elif "se" in data.columns:
|
|
106
|
+
z = scipy_stats.norm.ppf(1 - alpha / 2)
|
|
107
|
+
ci_lower = effects - z * data["se"].values
|
|
108
|
+
ci_upper = effects + z * data["se"].values
|
|
109
|
+
else:
|
|
110
|
+
ci_lower = None
|
|
111
|
+
ci_upper = None
|
|
112
|
+
else:
|
|
113
|
+
raise ValueError("Must provide 'results', 'curve', or 'data'.")
|
|
114
|
+
|
|
115
|
+
# Auto-generate title
|
|
116
|
+
if title is None:
|
|
117
|
+
if target == "att":
|
|
118
|
+
title = "ATT Dose-Response Curve"
|
|
119
|
+
else:
|
|
120
|
+
title = "ACRT Dose-Response Curve"
|
|
121
|
+
|
|
122
|
+
if backend == "plotly":
|
|
123
|
+
return _render_dose_response_plotly(
|
|
124
|
+
dose_grid=dose_grid,
|
|
125
|
+
effects=effects,
|
|
126
|
+
ci_lower=ci_lower,
|
|
127
|
+
ci_upper=ci_upper,
|
|
128
|
+
title=title,
|
|
129
|
+
xlabel=xlabel,
|
|
130
|
+
ylabel=ylabel,
|
|
131
|
+
color=color,
|
|
132
|
+
ci_color=ci_color,
|
|
133
|
+
show_zero_line=show_zero_line,
|
|
134
|
+
show=show,
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
return _render_dose_response_mpl(
|
|
138
|
+
dose_grid=dose_grid,
|
|
139
|
+
effects=effects,
|
|
140
|
+
ci_lower=ci_lower,
|
|
141
|
+
ci_upper=ci_upper,
|
|
142
|
+
figsize=figsize,
|
|
143
|
+
title=title,
|
|
144
|
+
xlabel=xlabel,
|
|
145
|
+
ylabel=ylabel,
|
|
146
|
+
color=color,
|
|
147
|
+
ci_color=ci_color,
|
|
148
|
+
show_zero_line=show_zero_line,
|
|
149
|
+
ax=ax,
|
|
150
|
+
show=show,
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _render_dose_response_mpl(
|
|
155
|
+
*,
|
|
156
|
+
dose_grid,
|
|
157
|
+
effects,
|
|
158
|
+
ci_lower,
|
|
159
|
+
ci_upper,
|
|
160
|
+
figsize,
|
|
161
|
+
title,
|
|
162
|
+
xlabel,
|
|
163
|
+
ylabel,
|
|
164
|
+
color,
|
|
165
|
+
ci_color,
|
|
166
|
+
show_zero_line,
|
|
167
|
+
ax,
|
|
168
|
+
show,
|
|
169
|
+
):
|
|
170
|
+
"""Render dose-response curve with matplotlib."""
|
|
171
|
+
from diff_diff.visualization._common import _require_matplotlib
|
|
172
|
+
|
|
173
|
+
plt = _require_matplotlib()
|
|
174
|
+
|
|
175
|
+
if ax is None:
|
|
176
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
177
|
+
else:
|
|
178
|
+
fig = ax.get_figure()
|
|
179
|
+
|
|
180
|
+
# Zero line
|
|
181
|
+
if show_zero_line:
|
|
182
|
+
ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5)
|
|
183
|
+
|
|
184
|
+
# Confidence band
|
|
185
|
+
if ci_lower is not None and ci_upper is not None:
|
|
186
|
+
band_color = ci_color or color
|
|
187
|
+
ax.fill_between(
|
|
188
|
+
dose_grid,
|
|
189
|
+
ci_lower,
|
|
190
|
+
ci_upper,
|
|
191
|
+
alpha=0.15,
|
|
192
|
+
color=band_color,
|
|
193
|
+
label="95% CI",
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
# Effect line
|
|
197
|
+
ax.plot(dose_grid, effects, color=color, linewidth=2, label="Effect")
|
|
198
|
+
|
|
199
|
+
ax.set_xlabel(xlabel)
|
|
200
|
+
ax.set_ylabel(ylabel)
|
|
201
|
+
ax.set_title(title)
|
|
202
|
+
ax.legend(loc="best")
|
|
203
|
+
ax.grid(True, alpha=0.3)
|
|
204
|
+
|
|
205
|
+
fig.tight_layout()
|
|
206
|
+
|
|
207
|
+
if show:
|
|
208
|
+
plt.show()
|
|
209
|
+
|
|
210
|
+
return ax
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _render_dose_response_plotly(
|
|
214
|
+
*,
|
|
215
|
+
dose_grid,
|
|
216
|
+
effects,
|
|
217
|
+
ci_lower,
|
|
218
|
+
ci_upper,
|
|
219
|
+
title,
|
|
220
|
+
xlabel,
|
|
221
|
+
ylabel,
|
|
222
|
+
color,
|
|
223
|
+
ci_color,
|
|
224
|
+
show_zero_line,
|
|
225
|
+
show,
|
|
226
|
+
):
|
|
227
|
+
"""Render dose-response curve with plotly."""
|
|
228
|
+
from diff_diff.visualization._common import (
|
|
229
|
+
_color_to_rgba,
|
|
230
|
+
_plotly_default_layout,
|
|
231
|
+
_require_plotly,
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
go = _require_plotly()
|
|
235
|
+
|
|
236
|
+
fig = go.Figure()
|
|
237
|
+
|
|
238
|
+
# Zero line
|
|
239
|
+
if show_zero_line:
|
|
240
|
+
fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1, opacity=0.5)
|
|
241
|
+
|
|
242
|
+
# Confidence band
|
|
243
|
+
if ci_lower is not None and ci_upper is not None:
|
|
244
|
+
band_color = ci_color or color
|
|
245
|
+
dose_list = list(dose_grid)
|
|
246
|
+
fig.add_trace(
|
|
247
|
+
go.Scatter(
|
|
248
|
+
x=dose_list + dose_list[::-1],
|
|
249
|
+
y=list(ci_upper) + list(ci_lower)[::-1],
|
|
250
|
+
fill="toself",
|
|
251
|
+
fillcolor=_color_to_rgba(band_color, 0.15),
|
|
252
|
+
line=dict(color="rgba(0,0,0,0)"),
|
|
253
|
+
name="95% CI",
|
|
254
|
+
hoverinfo="skip",
|
|
255
|
+
)
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
# Effect line
|
|
259
|
+
fig.add_trace(
|
|
260
|
+
go.Scatter(
|
|
261
|
+
x=list(dose_grid),
|
|
262
|
+
y=list(effects),
|
|
263
|
+
mode="lines",
|
|
264
|
+
line=dict(color=color, width=2),
|
|
265
|
+
name="Effect",
|
|
266
|
+
)
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
|
|
270
|
+
|
|
271
|
+
if show:
|
|
272
|
+
fig.show()
|
|
273
|
+
|
|
274
|
+
return fig
|