diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
|
@@ -0,0 +1,833 @@
|
|
|
1
|
+
"""Staggered DiD visualization functions (group effects, staircase, heatmap)."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from diff_diff.continuous_did_results import ContinuousDiDResults
|
|
10
|
+
from diff_diff.efficient_did_results import EfficientDiDResults
|
|
11
|
+
from diff_diff.staggered import CallawaySantAnnaResults
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def plot_group_effects(
    results: "CallawaySantAnnaResults",
    *,
    groups: Optional[List[Any]] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Effects by Cohort",
    xlabel: str = "Time Period",
    ylabel: str = "Treatment Effect",
    alpha: float = 0.05,
    show: bool = True,
    ax: Optional[Any] = None,
    backend: str = "matplotlib",
) -> Any:
    """
    Draw one effect-over-time series with confidence bars per cohort.

    Parameters
    ----------
    results : CallawaySantAnnaResults
        Object exposing a ``group_time_effects`` mapping keyed by
        ``(group, time)``. The renderers read ``"effect"`` and ``"se"``
        from every cell.
    groups : list, optional
        Cohorts to draw. When omitted, every cohort appearing in
        ``group_time_effects`` is drawn, in sorted order. Requested cohorts
        that have no cells are skipped.
    figsize : tuple, default=(10, 6)
        Figure size; only forwarded to the matplotlib renderer.
    title : str
        Plot title.
    xlabel : str
        Label of the horizontal axis.
    ylabel : str
        Label of the vertical axis.
    alpha : float, default=0.05
        Significance level. Interval half-widths are
        ``norm.ppf(1 - alpha / 2) * se``.
    show : bool, default=True
        Whether the renderer displays the figure before returning.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; only forwarded to the matplotlib renderer.
    backend : str, default="matplotlib"
        ``"plotly"`` selects the plotly renderer; every other value is
        routed to the matplotlib renderer.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        Axes for the matplotlib backend, figure for the plotly backend.

    Raises
    ------
    TypeError
        If ``results`` has no ``group_time_effects`` attribute.
    """
    from scipy import stats as scipy_stats

    if not hasattr(results, "group_time_effects"):
        raise TypeError("results must be a CallawaySantAnnaResults object")

    cells_by_key = results.group_time_effects

    # Default to every cohort that appears in the (group, time) keys.
    if groups is None:
        groups = sorted({cohort for cohort, _ in cells_by_key.keys()})

    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    # Collect each cohort's cells ordered by time; cohorts without any cell
    # are left out of the mapping and skipped by the renderers.
    group_data = {}
    for cohort in groups:
        ordered_cells = sorted(
            ((period, cell) for (g, period), cell in cells_by_key.items() if g == cohort),
            key=lambda pair: pair[0],
        )
        if ordered_cells:
            group_data[cohort] = ordered_cells

    shared_kwargs = dict(
        group_data=group_data,
        groups=groups,
        critical_value=critical_value,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        show=show,
    )

    if backend == "plotly":
        return _render_group_effects_plotly(**shared_kwargs)

    return _render_group_effects_mpl(figsize=figsize, ax=ax, **shared_kwargs)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _render_group_effects_mpl(
    *, group_data, groups, critical_value, figsize, title, xlabel, ylabel, ax, show
):
    """Draw the per-cohort effect series as matplotlib error bars.

    Adds one ``ax.errorbar`` series for every entry of ``groups`` found in
    ``group_data`` (a mapping of cohort -> list of ``(time, cell)`` pairs,
    where each cell provides ``"effect"`` and ``"se"``), then a dashed zero
    line, labels, legend and grid. Returns the axes that were drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # One palette row per requested cohort, indexed by position in ``groups``
    # so skipped cohorts still consume their colour slot.
    palette = getattr(plt.cm, "tab10", None) or plt.colormaps["tab10"]
    palette_rows = palette(np.linspace(0, 1, len(groups)))

    for position, cohort in enumerate(groups):
        if cohort not in group_data:
            continue

        cells = group_data[cohort]
        periods = [period for period, _ in cells]
        estimates = [cell["effect"] for _, cell in cells]
        std_errs = [cell["se"] for _, cell in cells]

        # Interval bounds first, then the distances matplotlib expects.
        lower = [est - critical_value * se for est, se in zip(estimates, std_errs)]
        upper = [est + critical_value * se for est, se in zip(estimates, std_errs)]
        below = [est - lo for est, lo in zip(estimates, lower)]
        above = [hi - est for est, hi in zip(estimates, upper)]

        ax.errorbar(
            periods,
            estimates,
            yerr=[below, above],
            label=f"Cohort {cohort}",
            color=palette_rows[position],
            marker="o",
            capsize=3,
            linewidth=1.5,
        )

    ax.axhline(y=0, color="gray", linestyle="--", linewidth=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()

    if show:
        plt.show()

    return ax
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _render_group_effects_plotly(
    *, group_data, groups, critical_value, title, xlabel, ylabel, show
):
    """Draw the per-cohort effect series as plotly scatter traces.

    Adds a dashed zero line, then one ``lines+markers`` trace with
    asymmetric ``error_y`` bars for every entry of ``groups`` found in
    ``group_data`` (cohort -> list of ``(time, cell)`` pairs, each cell
    providing ``"effect"`` and ``"se"``). Returns the figure.
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    fig = go.Figure()

    # Reference line at zero effect.
    fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1)

    for cohort in groups:
        if cohort not in group_data:
            continue

        cells = group_data[cohort]
        periods = [period for period, _ in cells]
        estimates = [cell["effect"] for _, cell in cells]
        std_errs = [cell["se"] for _, cell in cells]

        lower = [est - critical_value * se for est, se in zip(estimates, std_errs)]
        upper = [est + critical_value * se for est, se in zip(estimates, std_errs)]

        # plotly takes bar lengths relative to the point, not absolute bounds.
        error_bars = dict(
            type="data",
            symmetric=False,
            array=[hi - est for est, hi in zip(estimates, upper)],
            arrayminus=[est - lo for est, lo in zip(estimates, lower)],
        )
        trace = go.Scatter(
            x=periods,
            y=estimates,
            mode="lines+markers",
            name=f"Cohort {cohort}",
            error_y=error_bars,
        )
        fig.add_trace(trace)

    _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)

    if show:
        fig.show()

    return fig
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def plot_staircase(
    results: Optional["CallawaySantAnnaResults"] = None,
    *,
    data: Optional[pd.DataFrame] = None,
    unit: Optional[str] = None,
    time: Optional[str] = None,
    first_treat: Optional[str] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Adoption Over Time",
    color: str = "#2563eb",
    show_counts: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot the cumulative number of treated units as a step function.

    Each cohort contributes one step at its first-treatment period, so the
    curve shows how treatment adoption is staggered over time.

    Parameters
    ----------
    results : CallawaySantAnnaResults, optional
        Fitted results; cohorts and cohort sizes are derived from its
        ``groups`` and ``group_time_effects`` attributes.
    data : pd.DataFrame, optional
        Raw panel data, used instead of ``results``. Requires ``unit``,
        ``time`` and ``first_treat`` to be given.
    unit : str, optional
        Name of the unit-identifier column (required with ``data``).
    time : str, optional
        Name of the time-period column (required with ``data``).
    first_treat : str, optional
        Name of the first-treatment-period column (required with ``data``).
    figsize : tuple, default=(10, 6)
        Figure size; only forwarded to the matplotlib renderer.
    title : str, default="Treatment Adoption Over Time"
        Plot title.
    color : str, default="#2563eb"
        Colour used for the step line, its fill and the annotations.
    show_counts : bool, default=True
        Whether each step is annotated with its cohort size.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; only forwarded to the matplotlib renderer.
    show : bool, default=True
        Whether the renderer displays the figure before returning.
    backend : str, default="matplotlib"
        ``"plotly"`` selects the plotly renderer; every other value is
        routed to the matplotlib renderer.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        Axes for the matplotlib backend, figure for the plotly backend.
    """
    # (period, count) pairs, one per treatment cohort.
    cohort_counts = _extract_staircase_data(results, data, unit, time, first_treat)

    shared_kwargs = dict(
        cohort_counts=cohort_counts,
        title=title,
        color=color,
        show_counts=show_counts,
        show=show,
    )

    if backend == "plotly":
        return _render_staircase_plotly(**shared_kwargs)

    return _render_staircase_mpl(figsize=figsize, ax=ax, **shared_kwargs)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _extract_staircase_data(results, data, unit, time, first_treat):
    """Build the cohort sizes that the staircase plot accumulates.

    Exactly one of ``results`` and ``data`` must be supplied.

    Returns
    -------
    list of (period, count) tuples, sorted by period.

    Raises
    ------
    ValueError
        If both or neither of ``results`` / ``data`` are given, or if
        ``data`` is given without all three column names.
    TypeError
        If ``results`` lacks ``group_time_effects`` or ``groups``.
    """
    import warnings

    if results is not None and data is not None:
        raise ValueError("Provide either 'results' or 'data', not both.")
    if results is None and data is None:
        raise ValueError("Must provide either 'results' or 'data'.")

    if results is not None:
        if not (hasattr(results, "group_time_effects") and hasattr(results, "groups")):
            raise TypeError("results must be a CallawaySantAnnaResults object")

        sizes = []
        for cohort in sorted(results.groups):
            # Per-cell treated counts for this cohort (falling back to n_obs).
            # They can differ between cells, so the largest one is used as
            # the cohort size.
            reported = [
                cell.get("n_treated", cell.get("n_obs", None))
                for (g, _period), cell in results.group_time_effects.items()
                if g == cohort
            ]
            observed = [int(n) for n in reported if n is not None]

            if not observed:
                sizes.append((cohort, 0))
                continue

            smallest, largest = min(observed), max(observed)
            if smallest != largest:
                # stacklevel=3 attributes the warning two frames above this
                # helper (the caller of plot_staircase on that call path).
                warnings.warn(
                    f"Cohort {cohort}: n_treated varies across cells "
                    f"({smallest}-{largest}). "
                    f"Using max as cohort size; pass data= for exact counts.",
                    stacklevel=3,
                )
            sizes.append((cohort, largest))

        return sizes

    # Raw-data path.
    if any(column is None for column in (unit, time, first_treat)):
        raise ValueError(
            "When using 'data', must provide 'unit', 'time', and 'first_treat' column names."
        )

    # Distinct units per first-treatment period.
    per_cohort = data.groupby(first_treat)[unit].nunique().reset_index()
    per_cohort.columns = ["period", "count"]
    per_cohort = per_cohort.sort_values("period")

    # Drop never-treated cohorts, encoded as inf, NaN or a non-positive period.
    period = per_cohort["period"]
    is_treated = period.notna() & np.isfinite(period) & (period > 0)
    treated = per_cohort[is_treated]

    return list(zip(treated["period"], treated["count"]))
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _render_staircase_mpl(*, cohort_counts, figsize, title, color, show_counts, ax, show):
    """Draw the cumulative-adoption step plot with matplotlib.

    ``cohort_counts`` is a list of ``(period, count)`` pairs. When it is
    empty, only the title and a "No treatment cohorts" note are drawn.
    Returns the axes that were drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # Nothing to accumulate: show a placeholder message and stop early.
    if not cohort_counts:
        ax.set_title(title)
        ax.text(0.5, 0.5, "No treatment cohorts", ha="center", va="center", transform=ax.transAxes)
        if show:
            plt.show()
        return ax

    periods, counts = map(list, zip(*cohort_counts))
    running_total = np.cumsum(counts)

    # Step line plus a light fill underneath it.
    ax.step(periods, running_total, where="post", color=color, linewidth=2, label="Cumulative treated")
    ax.fill_between(periods, running_total, step="post", alpha=0.15, color=color)

    # Label each step with the size of the cohort that created it.
    if show_counts:
        for (period, count), total in zip(cohort_counts, running_total):
            ax.annotate(
                f"+{count}",
                xy=(period, total),
                xytext=(0, 8),
                textcoords="offset points",
                ha="center",
                fontsize=9,
                color=color,
                fontweight="bold",
            )

    ax.set_title(title)
    ax.set_xlabel("Time Period")
    ax.set_ylabel("Cumulative Treated Units")
    ax.grid(True, alpha=0.3, axis="y")

    # Anchor the vertical axis at zero.
    ax.set_ylim(bottom=0)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _render_staircase_plotly(*, cohort_counts, title, color, show_counts, show):
    """Draw the cumulative-adoption step plot with plotly.

    ``cohort_counts`` is a list of ``(period, count)`` pairs. When it is
    empty, the figure only carries a "No treatment cohorts" annotation.
    Returns the figure.
    """
    from diff_diff.visualization._common import (
        _color_to_rgba,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()

    fig = go.Figure()

    # Nothing to accumulate: show a placeholder message and stop early.
    if not cohort_counts:
        fig.add_annotation(text="No treatment cohorts", x=0.5, y=0.5, showarrow=False)
        _plotly_default_layout(fig, title=title)
        if show:
            fig.show()
        return fig

    periods, counts = map(list, zip(*cohort_counts))
    running_total = list(np.cumsum(counts))

    # "hv" line shape yields the staircase; the area below is filled.
    staircase = go.Scatter(
        x=periods,
        y=running_total,
        mode="lines",
        line=dict(color=color, width=2, shape="hv"),
        fill="tozeroy",
        fillcolor=_color_to_rgba(color, 0.15),
        name="Cumulative treated",
    )
    fig.add_trace(staircase)

    # Label each step with the size of the cohort that created it.
    if show_counts:
        for period, count, total in zip(periods, counts, running_total):
            fig.add_annotation(
                x=period,
                y=total,
                text=f"+{count}",
                showarrow=False,
                yshift=15,
                font=dict(color=color, size=11),
            )

    _plotly_default_layout(
        fig,
        title=title,
        xlabel="Time Period",
        ylabel="Cumulative Treated Units",
    )
    fig.update_yaxes(rangemode="tozero")

    if show:
        fig.show()

    return fig
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def plot_group_time_heatmap(
    results: Optional[
        Union["CallawaySantAnnaResults", "EfficientDiDResults", "ContinuousDiDResults"]
    ] = None,
    *,
    data: Optional[pd.DataFrame] = None,
    figsize: Tuple[float, float] = (10, 8),
    title: str = "Group-Time Treatment Effects",
    cmap: str = "RdBu_r",
    center: float = 0.0,
    annotate: bool = True,
    fmt: str = ".3f",
    mask_insignificant: bool = False,
    alpha: float = 0.05,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Render the group-time effects ATT(g,t) as a coloured matrix.

    Rows are treatment cohorts (groups) and columns are calendar time
    periods; the colour of each cell encodes its effect.

    Parameters
    ----------
    results : CallawaySantAnnaResults, EfficientDiDResults, or ContinuousDiDResults, optional
        Object exposing a ``group_time_effects`` dict.
    data : pd.DataFrame, optional
        Alternative input with columns ``group``, ``time``, ``effect``
        and, optionally, ``p_value``.
    figsize : tuple, default=(10, 8)
        Figure size; only forwarded to the matplotlib renderer.
    title : str, default="Group-Time Treatment Effects"
        Plot title.
    cmap : str, default="RdBu_r"
        Colormap name, passed unchanged to the selected backend.
    center : float, default=0.0
        Effect value the colour scale is centred on.
    annotate : bool, default=True
        Whether each cell shows its formatted effect value.
    fmt : str, default=".3f"
        Format spec applied to the annotated values.
    mask_insignificant : bool, default=False
        Whether cells whose p-value exceeds ``alpha`` are greyed out.
        Has no effect when no p-values are available.
    alpha : float, default=0.05
        Significance threshold used when ``mask_insignificant=True``.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on; only forwarded to the matplotlib renderer.
    show : bool, default=True
        Whether the renderer displays the figure before returning.
    backend : str, default="matplotlib"
        ``"plotly"`` selects the plotly renderer; every other value is
        routed to the matplotlib renderer.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        Axes for the matplotlib backend, figure for the plotly backend.
    """
    # Reshape the effects (and p-values, if any) into groups x time matrices.
    effect_matrix, p_matrix, group_labels, time_labels = _extract_heatmap_data(results, data)

    shared_kwargs = dict(
        effect_matrix=effect_matrix,
        p_matrix=p_matrix,
        group_labels=group_labels,
        time_labels=time_labels,
        title=title,
        cmap=cmap,
        center=center,
        annotate=annotate,
        fmt=fmt,
        mask_insignificant=mask_insignificant,
        alpha=alpha,
        show=show,
    )

    if backend == "plotly":
        return _render_group_time_heatmap_plotly(**shared_kwargs)

    return _render_group_time_heatmap_mpl(figsize=figsize, ax=ax, **shared_kwargs)
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def _extract_heatmap_data(results, data):
    """Extract group-time effects into 2D float matrices.

    Exactly one of ``results`` and ``data`` must be supplied. Both input
    paths return float ndarrays, so the renderers can apply ``np.isfinite``
    and ``np.isnan`` regardless of the dtype of the incoming columns.

    Parameters
    ----------
    results : object or None
        Object with a non-empty ``group_time_effects`` mapping keyed by
        ``(group, time)``. Each cell's effect is read from ``"effect"``,
        falling back to ``"att_glob"``; ``"p_value"`` is read when present.
    data : pd.DataFrame or None
        Frame with columns ``group``, ``time``, ``effect`` and, optionally,
        ``p_value``.

    Returns
    -------
    effect_matrix : np.ndarray
        2D float array of effects (groups x time); NaN marks missing cells.
    p_matrix : np.ndarray or None
        2D float array of p-values. ``None`` when ``results`` carries no
        finite p-value, or when ``data`` has no ``p_value`` column.
    group_labels : list
        Sorted group labels.
    time_labels : list
        Sorted time labels.

    Raises
    ------
    ValueError
        If both or neither input is given, if ``group_time_effects`` is
        empty, or if ``data`` lacks a required column.
    TypeError
        If ``results`` has no ``group_time_effects`` attribute.
    """
    if results is not None and data is not None:
        raise ValueError("Provide either 'results' or 'data', not both.")

    if results is not None:
        if not hasattr(results, "group_time_effects"):
            raise TypeError(f"{type(results).__name__} does not have group_time_effects attribute")
        gte = results.group_time_effects
        if not gte:
            raise ValueError("group_time_effects is empty — nothing to plot.")

        groups = sorted({g for g, _ in gte})
        times = sorted({t for _, t in gte})

        # Cells that never appear in group_time_effects stay NaN.
        effect_matrix = np.full((len(groups), len(times)), np.nan)
        p_matrix = np.full((len(groups), len(times)), np.nan)

        group_idx = {g: i for i, g in enumerate(groups)}
        time_idx = {t: j for j, t in enumerate(times)}

        for (g, t), eff_data in gte.items():
            i, j = group_idx[g], time_idx[t]
            # Result types differ in the key that holds the point estimate.
            if "effect" in eff_data:
                effect_matrix[i, j] = eff_data["effect"]
            elif "att_glob" in eff_data:
                effect_matrix[i, j] = eff_data["att_glob"]
            if "p_value" in eff_data:
                p_matrix[i, j] = eff_data["p_value"]

        has_p = np.any(np.isfinite(p_matrix))
        return effect_matrix, p_matrix if has_p else None, groups, times

    if data is not None:
        required = {"group", "time", "effect"}
        missing = required - set(data.columns)
        if missing:
            raise ValueError(f"DataFrame missing required columns: {missing}")

        pivot = data.pivot(index="group", columns="time", values="effect")
        pivot = pivot.sort_index(axis=0).sort_index(axis=1)
        # Force float so integer/object columns give the same matrix type as
        # the results path (object arrays break np.isfinite downstream).
        effect_matrix = pivot.to_numpy(dtype=float, na_value=np.nan)

        p_matrix = None
        if "p_value" in data.columns:
            p_pivot = data.pivot(index="group", columns="time", values="p_value")
            p_pivot = p_pivot.sort_index(axis=0).sort_index(axis=1)
            p_matrix = p_pivot.to_numpy(dtype=float, na_value=np.nan)

        return effect_matrix, p_matrix, list(pivot.index), list(pivot.columns)

    raise ValueError("Must provide either 'results' or 'data'.")
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def _render_group_time_heatmap_mpl(
    *,
    effect_matrix,
    p_matrix,
    group_labels,
    time_labels,
    figsize,
    title,
    cmap,
    center,
    annotate,
    fmt,
    mask_insignificant,
    alpha,
    ax,
    show,
):
    """Draw the group-time heatmap with matplotlib.

    The effects are shown via ``imshow`` with a ``TwoSlopeNorm`` centred on
    ``center``. Cells can be annotated with their value, and cells whose
    p-value exceeds ``alpha`` can be covered by a translucent white patch.
    Returns the axes that were drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()
    from matplotlib.colors import TwoSlopeNorm

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    shown = effect_matrix.copy()

    # True where a cell should be greyed out; None disables masking.
    insignificant = None
    if mask_insignificant and p_matrix is not None:
        insignificant = p_matrix > alpha

    # Colour limits: span of the finite effects, falling back to a tiny
    # window around ``center`` when there are none.
    observed = effect_matrix[np.isfinite(effect_matrix)]
    if len(observed) > 0:
        vmin, vmax = np.nanmin(observed), np.nanmax(observed)
    else:
        vmin, vmax = center - 0.01, center + 0.01
    # TwoSlopeNorm needs vmin < vcenter < vmax, so nudge a bound that sits
    # on the wrong side of ``center``.
    if vmin >= center:
        vmin = center - 0.01
    if vmax <= center:
        vmax = center + 0.01

    image = ax.imshow(
        shown,
        cmap=cmap,
        norm=TwoSlopeNorm(vmin=vmin, vcenter=center, vmax=vmax),
        aspect="auto",
    )
    fig.colorbar(image, ax=ax, label="Treatment Effect")

    # Tick labels come from the time (columns) and cohort (rows) labels.
    ax.set_xticks(range(len(time_labels)))
    ax.set_xticklabels([str(t) for t in time_labels], rotation=45, ha="right")
    ax.set_yticks(range(len(group_labels)))
    ax.set_yticklabels([str(g) for g in group_labels])

    ax.set_xlabel("Time Period")
    ax.set_ylabel("Treatment Cohort")
    ax.set_title(title)

    if annotate:
        # Values far from the centre sit on saturated colours, so they get
        # white text; masked cells get gray text.
        contrast_cutoff = (vmax - vmin) * 0.3
        for i, j in np.ndindex(len(group_labels), len(time_labels)):
            val = effect_matrix[i, j]
            if np.isnan(val):
                continue
            if insignificant is not None and insignificant[i, j]:
                text_color = "gray"
            elif abs(val - center) > contrast_cutoff:
                text_color = "white"
            else:
                text_color = "black"
            ax.text(
                j,
                i,
                f"{val:{fmt}}",
                ha="center",
                va="center",
                fontsize=7,
                color=text_color,
            )

    # Cover each insignificant cell with a translucent white square.
    if insignificant is not None:
        for i, j in np.ndindex(insignificant.shape[0], insignificant.shape[1]):
            if not insignificant[i, j]:
                continue
            overlay = plt.Rectangle(
                (j - 0.5, i - 0.5),
                1,
                1,
                fill=True,
                facecolor="white",
                alpha=0.6,
                edgecolor="none",
            )
            ax.add_patch(overlay)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def _render_group_time_heatmap_plotly(
    *,
    effect_matrix,
    p_matrix,
    group_labels,
    time_labels,
    title,
    cmap,
    center,
    annotate,
    fmt,
    mask_insignificant,
    alpha,
    show,
):
    """Draw the group-time heatmap with plotly.

    The effects are shown as a ``go.Heatmap`` whose colour range is
    symmetric around ``center``. When masking is requested, a second
    translucent-white heatmap is layered over the cells whose p-value
    exceeds ``alpha``. Returns the figure.
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    # The colormap name is handed to plotly as-is (no translation step).
    plotly_cmap = cmap

    # Per-cell annotation strings; missing cells get an empty string.
    if annotate:
        text = [
            ["" if np.isnan(val) else f"{val:{fmt}}" for val in row]
            for row in effect_matrix
        ]
    else:
        text = None

    # Insignificant cells are overlaid rather than set to NaN, so they stay
    # distinguishable from genuinely missing cells.
    insignificant = None
    if mask_insignificant and p_matrix is not None:
        insignificant = p_matrix > alpha

    # Symmetric colour range around ``center``.
    observed = effect_matrix[np.isfinite(effect_matrix)]
    if len(observed) > 0:
        reach = max(abs(np.nanmin(observed) - center), abs(np.nanmax(observed) - center))
        zmin, zmax = center - reach, center + reach
    else:
        zmin, zmax = -1, 1

    x_labels = [str(t) for t in time_labels]
    y_labels = [str(g) for g in group_labels]

    # Base layer: every effect value, masked or not.
    fig = go.Figure(
        data=go.Heatmap(
            z=effect_matrix,
            x=x_labels,
            y=y_labels,
            colorscale=plotly_cmap,
            zmin=zmin,
            zmax=zmax,
            text=text,
            texttemplate="%{text}" if annotate else None,
            colorbar=dict(title="Effect"),
        )
    )

    # Overlay layer: constant translucent white on the insignificant cells.
    if insignificant is not None and np.any(insignificant):
        veil = go.Heatmap(
            z=np.where(insignificant, 1.0, np.nan),
            x=[str(t) for t in time_labels],
            y=[str(g) for g in group_labels],
            colorscale=[[0, "rgba(255,255,255,0.6)"], [1, "rgba(255,255,255,0.6)"]],
            showscale=False,
            hoverinfo="skip",
        )
        fig.add_trace(veil)

    _plotly_default_layout(
        fig,
        title=title,
        xlabel="Time Period",
        ylabel="Treatment Cohort",
        show_legend=False,
    )

    if show:
        fig.show()

    return fig
|