diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,61 @@
1
+ """
2
+ Visualization functions for difference-in-differences analysis.
3
+
4
+ Provides event study plots, diagnostic visualizations, and other plotting
5
+ utilities with support for matplotlib (default) and plotly backends.
6
+ """
7
+
8
# Continuous DiD plots
from diff_diff.visualization._continuous import (
    plot_dose_response,
)

# Diagnostic plots
from diff_diff.visualization._diagnostic import (
    plot_bacon,
    plot_sensitivity,
)

# Event study plots. ``_extract_plot_data`` and ``PlottableResults`` are also
# re-exported here for backward compatibility (listed at the end of __all__).
from diff_diff.visualization._event_study import (
    PlottableResults,
    _extract_plot_data,
    plot_event_study,
    plot_honest_event_study,
)

# Power analysis plots
from diff_diff.visualization._power import (
    plot_power_curve,
    plot_pretrends_power,
)

# Staggered DiD plots
from diff_diff.visualization._staggered import (
    plot_group_effects,
    plot_group_time_heatmap,
    plot_staircase,
)

# Synthetic control plots
from diff_diff.visualization._synthetic import (
    plot_synth_weights,
)

# Public API, grouped by plot family (same grouping as the imports above).
__all__ = [
    # Event study plots
    "plot_event_study",
    "plot_honest_event_study",
    # Staggered DiD plots
    "plot_group_effects",
    "plot_staircase",
    "plot_group_time_heatmap",
    # Diagnostic plots
    "plot_sensitivity",
    "plot_bacon",
    # Power analysis plots
    "plot_power_curve",
    "plot_pretrends_power",
    # Synthetic control plots
    "plot_synth_weights",
    # Continuous DiD plots
    "plot_dose_response",
    # Re-exported for backward compatibility (used in tests)
    "_extract_plot_data",
    "PlottableResults",
]
@@ -0,0 +1,328 @@
1
+ """Shared utilities for the visualization subpackage."""
2
+
3
+ import re
4
+
5
+
6
+ def _require_matplotlib():
7
+ """Lazy import matplotlib with clear error message.
8
+
9
+ Returns
10
+ -------
11
+ module
12
+ The ``matplotlib.pyplot`` module.
13
+ """
14
+ try:
15
+ import matplotlib.pyplot as plt
16
+
17
+ return plt
18
+ except ImportError:
19
+ raise ImportError(
20
+ "matplotlib is required for plotting. " "Install it with: pip install matplotlib"
21
+ )
22
+
23
+
24
+ def _require_plotly():
25
+ """Lazy import plotly with clear error message.
26
+
27
+ Returns
28
+ -------
29
+ module
30
+ The ``plotly.graph_objects`` module.
31
+ """
32
+ try:
33
+ import plotly.graph_objects as go
34
+
35
+ return go
36
+ except ImportError:
37
+ raise ImportError(
38
+ "plotly is required for interactive plots. "
39
+ "Install with: pip install diff-diff[plotly]"
40
+ )
41
+
42
+
43
+ def _plotly_default_layout(fig, *, title=None, xlabel=None, ylabel=None, show_legend=True):
44
+ """Apply standard plotly layout settings.
45
+
46
+ Parameters
47
+ ----------
48
+ fig : plotly.graph_objects.Figure
49
+ The figure to configure.
50
+ title : str, optional
51
+ Plot title.
52
+ xlabel : str, optional
53
+ X-axis label.
54
+ ylabel : str, optional
55
+ Y-axis label.
56
+ show_legend : bool, default=True
57
+ Whether to show the legend.
58
+ """
59
+ fig.update_layout(
60
+ title=title,
61
+ xaxis_title=xlabel,
62
+ yaxis_title=ylabel,
63
+ showlegend=show_legend,
64
+ template="plotly_white",
65
+ font=dict(size=12),
66
+ margin=dict(l=60, r=30, t=50, b=50),
67
+ )
68
+
69
+
70
+ # Complete CSS named color table (all 148 standard CSS colors).
71
+ # No matplotlib dependency required for any color used by plotly/CSS.
72
+ _CSS_COLORS = {
73
+ "aliceblue": (240, 248, 255),
74
+ "antiquewhite": (250, 235, 215),
75
+ "aqua": (0, 255, 255),
76
+ "aquamarine": (127, 255, 212),
77
+ "azure": (240, 255, 255),
78
+ "beige": (245, 245, 220),
79
+ "bisque": (255, 228, 196),
80
+ "black": (0, 0, 0),
81
+ "blanchedalmond": (255, 235, 205),
82
+ "blue": (0, 0, 255),
83
+ "blueviolet": (138, 43, 226),
84
+ "brown": (165, 42, 42),
85
+ "burlywood": (222, 184, 135),
86
+ "cadetblue": (95, 158, 160),
87
+ "chartreuse": (127, 255, 0),
88
+ "chocolate": (210, 105, 30),
89
+ "coral": (255, 127, 80),
90
+ "cornflowerblue": (100, 149, 237),
91
+ "cornsilk": (255, 248, 220),
92
+ "crimson": (220, 20, 60),
93
+ "cyan": (0, 255, 255),
94
+ "darkblue": (0, 0, 139),
95
+ "darkcyan": (0, 139, 139),
96
+ "darkgoldenrod": (184, 134, 11),
97
+ "darkgray": (169, 169, 169),
98
+ "darkgreen": (0, 100, 0),
99
+ "darkgrey": (169, 169, 169),
100
+ "darkkhaki": (189, 183, 107),
101
+ "darkmagenta": (139, 0, 139),
102
+ "darkolivegreen": (85, 107, 47),
103
+ "darkorange": (255, 140, 0),
104
+ "darkorchid": (153, 50, 204),
105
+ "darkred": (139, 0, 0),
106
+ "darksalmon": (233, 150, 122),
107
+ "darkseagreen": (143, 188, 143),
108
+ "darkslateblue": (72, 61, 139),
109
+ "darkslategray": (47, 79, 79),
110
+ "darkslategrey": (47, 79, 79),
111
+ "darkturquoise": (0, 206, 209),
112
+ "darkviolet": (148, 0, 211),
113
+ "deeppink": (255, 20, 147),
114
+ "deepskyblue": (0, 191, 255),
115
+ "dimgray": (105, 105, 105),
116
+ "dimgrey": (105, 105, 105),
117
+ "dodgerblue": (30, 144, 255),
118
+ "firebrick": (178, 34, 34),
119
+ "floralwhite": (255, 250, 240),
120
+ "forestgreen": (34, 139, 34),
121
+ "fuchsia": (255, 0, 255),
122
+ "gainsboro": (220, 220, 220),
123
+ "ghostwhite": (248, 248, 255),
124
+ "gold": (255, 215, 0),
125
+ "goldenrod": (218, 165, 32),
126
+ "gray": (128, 128, 128),
127
+ "green": (0, 128, 0),
128
+ "greenyellow": (173, 255, 47),
129
+ "grey": (128, 128, 128),
130
+ "honeydew": (240, 255, 240),
131
+ "hotpink": (255, 105, 180),
132
+ "indianred": (205, 92, 92),
133
+ "indigo": (75, 0, 130),
134
+ "ivory": (255, 255, 240),
135
+ "khaki": (240, 230, 140),
136
+ "lavender": (230, 230, 250),
137
+ "lavenderblush": (255, 240, 245),
138
+ "lawngreen": (124, 252, 0),
139
+ "lemonchiffon": (255, 250, 205),
140
+ "lightblue": (173, 216, 230),
141
+ "lightcoral": (240, 128, 128),
142
+ "lightcyan": (224, 255, 255),
143
+ "lightgoldenrodyellow": (250, 250, 210),
144
+ "lightgray": (211, 211, 211),
145
+ "lightgreen": (144, 238, 144),
146
+ "lightgrey": (211, 211, 211),
147
+ "lightpink": (255, 182, 193),
148
+ "lightsalmon": (255, 160, 122),
149
+ "lightseagreen": (32, 178, 170),
150
+ "lightskyblue": (135, 206, 250),
151
+ "lightslategray": (119, 136, 153),
152
+ "lightslategrey": (119, 136, 153),
153
+ "lightsteelblue": (176, 196, 222),
154
+ "lightyellow": (255, 255, 224),
155
+ "lime": (0, 255, 0),
156
+ "limegreen": (50, 205, 50),
157
+ "linen": (250, 240, 230),
158
+ "magenta": (255, 0, 255),
159
+ "maroon": (128, 0, 0),
160
+ "mediumaquamarine": (102, 205, 170),
161
+ "mediumblue": (0, 0, 205),
162
+ "mediumorchid": (186, 85, 211),
163
+ "mediumpurple": (147, 112, 219),
164
+ "mediumseagreen": (60, 179, 113),
165
+ "mediumslateblue": (123, 104, 238),
166
+ "mediumspringgreen": (0, 250, 154),
167
+ "mediumturquoise": (72, 209, 204),
168
+ "mediumvioletred": (199, 21, 133),
169
+ "midnightblue": (25, 25, 112),
170
+ "mintcream": (245, 255, 250),
171
+ "mistyrose": (255, 228, 225),
172
+ "moccasin": (255, 228, 181),
173
+ "navajowhite": (255, 222, 173),
174
+ "navy": (0, 0, 128),
175
+ "oldlace": (253, 245, 230),
176
+ "olive": (128, 128, 0),
177
+ "olivedrab": (107, 142, 35),
178
+ "orange": (255, 165, 0),
179
+ "orangered": (255, 69, 0),
180
+ "orchid": (218, 112, 214),
181
+ "palegoldenrod": (238, 232, 170),
182
+ "palegreen": (152, 251, 152),
183
+ "paleturquoise": (175, 238, 238),
184
+ "palevioletred": (219, 112, 147),
185
+ "papayawhip": (255, 239, 213),
186
+ "peachpuff": (255, 218, 185),
187
+ "peru": (205, 133, 63),
188
+ "pink": (255, 192, 203),
189
+ "plum": (221, 160, 221),
190
+ "powderblue": (176, 224, 230),
191
+ "purple": (128, 0, 128),
192
+ "rebeccapurple": (102, 51, 153),
193
+ "red": (255, 0, 0),
194
+ "rosybrown": (188, 143, 143),
195
+ "royalblue": (65, 105, 225),
196
+ "saddlebrown": (139, 69, 19),
197
+ "salmon": (250, 128, 114),
198
+ "sandybrown": (244, 164, 96),
199
+ "seagreen": (46, 139, 87),
200
+ "seashell": (255, 245, 238),
201
+ "sienna": (160, 82, 45),
202
+ "silver": (192, 192, 192),
203
+ "skyblue": (135, 206, 235),
204
+ "slateblue": (106, 90, 205),
205
+ "slategray": (112, 128, 144),
206
+ "slategrey": (112, 128, 144),
207
+ "snow": (255, 250, 250),
208
+ "springgreen": (0, 255, 127),
209
+ "steelblue": (70, 130, 180),
210
+ "tan": (210, 180, 140),
211
+ "teal": (0, 128, 128),
212
+ "thistle": (216, 191, 216),
213
+ "tomato": (255, 99, 71),
214
+ "turquoise": (64, 224, 208),
215
+ "violet": (238, 130, 238),
216
+ "wheat": (245, 222, 179),
217
+ "white": (255, 255, 255),
218
+ "whitesmoke": (245, 245, 245),
219
+ "yellow": (255, 255, 0),
220
+ "yellowgreen": (154, 205, 50),
221
+ }
222
+
223
+ _RGB_RE = re.compile(r"^rgb\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$")
224
+ _RGBA_RE = re.compile(r"^rgba\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*([0-9.]+)\s*\)$")
225
+
226
+
227
+ def _color_to_rgba(color, alpha=1.0):
228
+ """Convert any color to an ``rgba(r, g, b, a)`` string for plotly.
229
+
230
+ Accepts hex colors (``#rrggbb``, ``#rgb``), all 148 CSS named colors,
231
+ and ``rgb(r,g,b)`` / ``rgba(r,g,b,a)`` strings. Does **not** require
232
+ matplotlib.
233
+
234
+ Parameters
235
+ ----------
236
+ color : str
237
+ Color specification.
238
+ alpha : float, default=1.0
239
+ Opacity value between 0 and 1.
240
+
241
+ Returns
242
+ -------
243
+ str
244
+ An ``rgba(r, g, b, a)`` string.
245
+ """
246
+ if not isinstance(color, str):
247
+ raise ValueError(f"Expected a color string, got {type(color).__name__}")
248
+
249
+ # 1. Hex colors: #rrggbb or #rgb
250
+ stripped = color.lstrip("#")
251
+ if color.startswith("#") and all(c in "0123456789abcdefABCDEF" for c in stripped):
252
+ if len(stripped) == 6:
253
+ r = int(stripped[0:2], 16)
254
+ g = int(stripped[2:4], 16)
255
+ b = int(stripped[4:6], 16)
256
+ return f"rgba({r}, {g}, {b}, {alpha})"
257
+ if len(stripped) == 3:
258
+ r = int(stripped[0] * 2, 16)
259
+ g = int(stripped[1] * 2, 16)
260
+ b = int(stripped[2] * 2, 16)
261
+ return f"rgba({r}, {g}, {b}, {alpha})"
262
+
263
+ # 2. Named CSS colors (complete table — no matplotlib needed)
264
+ if color.lower() in _CSS_COLORS:
265
+ r, g, b = _CSS_COLORS[color.lower()]
266
+ return f"rgba({r}, {g}, {b}, {alpha})"
267
+
268
+ # 3. rgb(r, g, b) — parse and apply alpha
269
+ m = _RGB_RE.match(color.strip())
270
+ if m:
271
+ r, g, b = int(m.group(1)), int(m.group(2)), int(m.group(3))
272
+ return f"rgba({r}, {g}, {b}, {alpha})"
273
+
274
+ # 4. rgba(r, g, b, a) — parse and override alpha
275
+ m = _RGBA_RE.match(color.strip())
276
+ if m:
277
+ r, g, b = int(m.group(1)), int(m.group(2)), int(m.group(3))
278
+ return f"rgba({r}, {g}, {b}, {alpha})"
279
+
280
+ raise ValueError(
281
+ f"Cannot parse color '{color}'. Use hex (#rrggbb), a CSS color name, "
282
+ "or rgb(r,g,b) / rgba(r,g,b,a) format."
283
+ )
284
+
285
+
286
+ # Matplotlib marker code -> plotly symbol name mapping
287
+ _MPL_TO_PLOTLY_SYMBOL = {
288
+ "o": "circle",
289
+ "s": "square",
290
+ "D": "diamond",
291
+ "d": "diamond",
292
+ "^": "triangle-up",
293
+ "v": "triangle-down",
294
+ "<": "triangle-left",
295
+ ">": "triangle-right",
296
+ "p": "pentagon",
297
+ "h": "hexagon",
298
+ "+": "cross",
299
+ "x": "x",
300
+ "*": "star",
301
+ ".": "circle",
302
+ }
303
+
304
+
305
+ def _mpl_marker_to_plotly_symbol(marker):
306
+ """Convert a matplotlib marker code to a plotly symbol name.
307
+
308
+ Parameters
309
+ ----------
310
+ marker : str
311
+ Matplotlib marker shorthand (e.g., ``"o"``, ``"s"``, ``"D"``).
312
+
313
+ Returns
314
+ -------
315
+ str
316
+ Plotly symbol name (e.g., ``"circle"``, ``"square"``, ``"diamond"``).
317
+ Returns ``"circle"`` for unrecognized markers.
318
+ """
319
+ return _MPL_TO_PLOTLY_SYMBOL.get(marker, "circle")
320
+
321
+
322
# Default color constants (hex strings) for the visualization subpackage.
# Accent colors:
DEFAULT_BLUE = "#2563eb"
DEFAULT_RED = "#dc2626"
DEFAULT_GREEN = "#22c55e"
# Neutral tones, from mid gray through near-black to a very light shade:
DEFAULT_GRAY = "#6b7280"
DEFAULT_DARK = "#1f2937"
DEFAULT_SHADE = "#f0f0f0"
@@ -0,0 +1,274 @@
1
+ """Continuous DiD visualization functions (dose-response curves)."""
2
+
3
+ from typing import TYPE_CHECKING, Any, Optional, Tuple
4
+
5
+ import pandas as pd
6
+
7
+ if TYPE_CHECKING:
8
+ from diff_diff.continuous_did_results import ContinuousDiDResults, DoseResponseCurve
9
+
10
+
11
+ def plot_dose_response(
12
+ results: Optional["ContinuousDiDResults"] = None,
13
+ *,
14
+ curve: Optional["DoseResponseCurve"] = None,
15
+ data: Optional[pd.DataFrame] = None,
16
+ target: str = "att",
17
+ alpha: float = 0.05,
18
+ figsize: Tuple[float, float] = (10, 6),
19
+ title: Optional[str] = None,
20
+ xlabel: str = "Dose",
21
+ ylabel: str = "Treatment Effect",
22
+ color: str = "#2563eb",
23
+ ci_color: Optional[str] = None,
24
+ show_zero_line: bool = True,
25
+ ax: Optional[Any] = None,
26
+ show: bool = True,
27
+ backend: str = "matplotlib",
28
+ ) -> Any:
29
+ """
30
+ Plot dose-response curve from Continuous DiD estimation.
31
+
32
+ Visualizes how the treatment effect varies with the treatment dose
33
+ (intensity), with confidence bands.
34
+
35
+ Parameters
36
+ ----------
37
+ results : ContinuousDiDResults, optional
38
+ Results from ContinuousDiD estimator. Extracts the dose-response
39
+ curve based on ``target``.
40
+ curve : DoseResponseCurve, optional
41
+ A DoseResponseCurve object directly.
42
+ data : pd.DataFrame, optional
43
+ DataFrame with columns ``dose``, ``effect``, ``se`` (and optionally
44
+ ``conf_int_lower``, ``conf_int_upper``).
45
+ target : str, default="att"
46
+ Which dose-response curve: ``"att"`` or ``"acrt"``.
47
+ alpha : float, default=0.05
48
+ Significance level for confidence intervals (used with DataFrame input).
49
+ figsize : tuple, default=(10, 6)
50
+ Figure size (width, height) in inches.
51
+ title : str, optional
52
+ Plot title. Auto-generated if None.
53
+ xlabel : str, default="Dose"
54
+ X-axis label.
55
+ ylabel : str, default="Treatment Effect"
56
+ Y-axis label.
57
+ color : str, default="#2563eb"
58
+ Color for the line.
59
+ ci_color : str, optional
60
+ Color for confidence band. Defaults to ``color`` with transparency.
61
+ show_zero_line : bool, default=True
62
+ Whether to show a horizontal line at y=0.
63
+ ax : matplotlib.axes.Axes, optional
64
+ Axes to plot on. If None, creates new figure.
65
+ show : bool, default=True
66
+ Whether to call plt.show() at the end.
67
+ backend : str, default="matplotlib"
68
+ Plotting backend: ``"matplotlib"`` or ``"plotly"``.
69
+
70
+ Returns
71
+ -------
72
+ matplotlib.axes.Axes or plotly.graph_objects.Figure
73
+ The axes object (matplotlib) or figure (plotly).
74
+ """
75
+ from scipy import stats as scipy_stats
76
+
77
+ # Extract dose-response data
78
+ if sum(x is not None for x in (results, curve, data)) != 1:
79
+ raise ValueError("Provide exactly one of 'results', 'curve', or 'data'.")
80
+
81
+ if results is not None:
82
+ if target == "att":
83
+ curve = results.dose_response_att
84
+ elif target == "acrt":
85
+ curve = results.dose_response_acrt
86
+ else:
87
+ raise ValueError(f"target must be 'att' or 'acrt', got '{target}'")
88
+
89
+ if curve is not None:
90
+ # Infer target from curve when passed directly (not via results)
91
+ if results is None and hasattr(curve, "target") and curve.target:
92
+ target = curve.target
93
+ dose_grid = curve.dose_grid
94
+ effects = curve.effects
95
+ ci_lower = curve.conf_int_lower
96
+ ci_upper = curve.conf_int_upper
97
+ elif data is not None:
98
+ if "dose" not in data.columns or "effect" not in data.columns:
99
+ raise ValueError("DataFrame must have 'dose' and 'effect' columns")
100
+ dose_grid = data["dose"].values
101
+ effects = data["effect"].values
102
+ if "conf_int_lower" in data.columns and "conf_int_upper" in data.columns:
103
+ ci_lower = data["conf_int_lower"].values
104
+ ci_upper = data["conf_int_upper"].values
105
+ elif "se" in data.columns:
106
+ z = scipy_stats.norm.ppf(1 - alpha / 2)
107
+ ci_lower = effects - z * data["se"].values
108
+ ci_upper = effects + z * data["se"].values
109
+ else:
110
+ ci_lower = None
111
+ ci_upper = None
112
+ else:
113
+ raise ValueError("Must provide 'results', 'curve', or 'data'.")
114
+
115
+ # Auto-generate title
116
+ if title is None:
117
+ if target == "att":
118
+ title = "ATT Dose-Response Curve"
119
+ else:
120
+ title = "ACRT Dose-Response Curve"
121
+
122
+ if backend == "plotly":
123
+ return _render_dose_response_plotly(
124
+ dose_grid=dose_grid,
125
+ effects=effects,
126
+ ci_lower=ci_lower,
127
+ ci_upper=ci_upper,
128
+ title=title,
129
+ xlabel=xlabel,
130
+ ylabel=ylabel,
131
+ color=color,
132
+ ci_color=ci_color,
133
+ show_zero_line=show_zero_line,
134
+ show=show,
135
+ )
136
+
137
+ return _render_dose_response_mpl(
138
+ dose_grid=dose_grid,
139
+ effects=effects,
140
+ ci_lower=ci_lower,
141
+ ci_upper=ci_upper,
142
+ figsize=figsize,
143
+ title=title,
144
+ xlabel=xlabel,
145
+ ylabel=ylabel,
146
+ color=color,
147
+ ci_color=ci_color,
148
+ show_zero_line=show_zero_line,
149
+ ax=ax,
150
+ show=show,
151
+ )
152
+
153
+
154
+ def _render_dose_response_mpl(
155
+ *,
156
+ dose_grid,
157
+ effects,
158
+ ci_lower,
159
+ ci_upper,
160
+ figsize,
161
+ title,
162
+ xlabel,
163
+ ylabel,
164
+ color,
165
+ ci_color,
166
+ show_zero_line,
167
+ ax,
168
+ show,
169
+ ):
170
+ """Render dose-response curve with matplotlib."""
171
+ from diff_diff.visualization._common import _require_matplotlib
172
+
173
+ plt = _require_matplotlib()
174
+
175
+ if ax is None:
176
+ fig, ax = plt.subplots(figsize=figsize)
177
+ else:
178
+ fig = ax.get_figure()
179
+
180
+ # Zero line
181
+ if show_zero_line:
182
+ ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5)
183
+
184
+ # Confidence band
185
+ if ci_lower is not None and ci_upper is not None:
186
+ band_color = ci_color or color
187
+ ax.fill_between(
188
+ dose_grid,
189
+ ci_lower,
190
+ ci_upper,
191
+ alpha=0.15,
192
+ color=band_color,
193
+ label="95% CI",
194
+ )
195
+
196
+ # Effect line
197
+ ax.plot(dose_grid, effects, color=color, linewidth=2, label="Effect")
198
+
199
+ ax.set_xlabel(xlabel)
200
+ ax.set_ylabel(ylabel)
201
+ ax.set_title(title)
202
+ ax.legend(loc="best")
203
+ ax.grid(True, alpha=0.3)
204
+
205
+ fig.tight_layout()
206
+
207
+ if show:
208
+ plt.show()
209
+
210
+ return ax
211
+
212
+
213
+ def _render_dose_response_plotly(
214
+ *,
215
+ dose_grid,
216
+ effects,
217
+ ci_lower,
218
+ ci_upper,
219
+ title,
220
+ xlabel,
221
+ ylabel,
222
+ color,
223
+ ci_color,
224
+ show_zero_line,
225
+ show,
226
+ ):
227
+ """Render dose-response curve with plotly."""
228
+ from diff_diff.visualization._common import (
229
+ _color_to_rgba,
230
+ _plotly_default_layout,
231
+ _require_plotly,
232
+ )
233
+
234
+ go = _require_plotly()
235
+
236
+ fig = go.Figure()
237
+
238
+ # Zero line
239
+ if show_zero_line:
240
+ fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1, opacity=0.5)
241
+
242
+ # Confidence band
243
+ if ci_lower is not None and ci_upper is not None:
244
+ band_color = ci_color or color
245
+ dose_list = list(dose_grid)
246
+ fig.add_trace(
247
+ go.Scatter(
248
+ x=dose_list + dose_list[::-1],
249
+ y=list(ci_upper) + list(ci_lower)[::-1],
250
+ fill="toself",
251
+ fillcolor=_color_to_rgba(band_color, 0.15),
252
+ line=dict(color="rgba(0,0,0,0)"),
253
+ name="95% CI",
254
+ hoverinfo="skip",
255
+ )
256
+ )
257
+
258
+ # Effect line
259
+ fig.add_trace(
260
+ go.Scatter(
261
+ x=list(dose_grid),
262
+ y=list(effects),
263
+ mode="lines",
264
+ line=dict(color=color, width=2),
265
+ name="Effect",
266
+ )
267
+ )
268
+
269
+ _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
270
+
271
+ if show:
272
+ fig.show()
273
+
274
+ return fig