diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,833 @@
1
+ """Staggered DiD visualization functions (group effects, staircase, heatmap)."""
2
+
3
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ if TYPE_CHECKING:
9
+ from diff_diff.continuous_did_results import ContinuousDiDResults
10
+ from diff_diff.efficient_did_results import EfficientDiDResults
11
+ from diff_diff.staggered import CallawaySantAnnaResults
12
+
13
+
14
def plot_group_effects(
    results: "CallawaySantAnnaResults",
    *,
    groups: Optional[List[Any]] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Effects by Cohort",
    xlabel: str = "Time Period",
    ylabel: str = "Treatment Effect",
    alpha: float = 0.05,
    show: bool = True,
    ax: Optional[Any] = None,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot treatment effects by treatment cohort (group).

    Each cohort is drawn as its own series of ATT(g, t) point estimates over
    time, with normal-approximation confidence intervals
    ``effect +/- z_{1 - alpha/2} * se``.

    Parameters
    ----------
    results : CallawaySantAnnaResults
        Results from CallawaySantAnna estimator. Only its
        ``group_time_effects`` mapping ``(group, time) -> cell`` is read;
        each cell must provide ``"effect"`` and ``"se"``.
    groups : list, optional
        List of groups (cohorts) to plot. If None, plots all groups in
        sorted order. Requested groups that have no cells are skipped.
    figsize : tuple, default=(10, 6)
        Figure size (matplotlib backend only).
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    alpha : float, default=0.05
        Significance level for confidence intervals; must satisfy
        ``0 < alpha < 1``.
    show : bool, default=True
        Whether to display the figure (``plt.show()`` / ``fig.show()``).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on (matplotlib backend only).
    backend : str, default="matplotlib"
        Plotting backend: ``"plotly"`` selects plotly; any other value uses
        matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    TypeError
        If ``results`` has no ``group_time_effects`` attribute.
    ValueError
        If ``alpha`` is not strictly between 0 and 1.
    """
    from scipy import stats as scipy_stats

    if not hasattr(results, "group_time_effects"):
        raise TypeError("results must be a CallawaySantAnnaResults object")
    # alpha outside (0, 1) makes norm.ppf return inf/NaN, which would silently
    # yield infinite or NaN error bars instead of a usable plot.
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha!r}")

    # Bucket (time, cell) pairs by cohort in one pass, rather than rescanning
    # every (g, t) cell once per requested group.
    by_group = {}
    for (g, t), cell in results.group_time_effects.items():
        by_group.setdefault(g, []).append((t, cell))

    # Get groups to plot
    if groups is None:
        groups = sorted(by_group)

    critical_value = scipy_stats.norm.ppf(1 - alpha / 2)

    # Per-cohort effects sorted by time; cohorts without cells are omitted.
    group_data = {}
    for group in groups:
        group_effects = by_group.get(group)
        if group_effects:
            group_data[group] = sorted(group_effects, key=lambda pair: pair[0])

    if backend == "plotly":
        return _render_group_effects_plotly(
            group_data=group_data,
            groups=groups,
            critical_value=critical_value,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            show=show,
        )

    return _render_group_effects_mpl(
        group_data=group_data,
        groups=groups,
        critical_value=critical_value,
        figsize=figsize,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        ax=ax,
        show=show,
    )
102
+
103
+
104
def _render_group_effects_mpl(
    *, group_data, groups, critical_value, figsize, title, xlabel, ylabel, ax, show
):
    """Render the per-cohort treatment-effect plot with matplotlib.

    Parameters
    ----------
    group_data : dict
        Maps cohort -> list of ``(time, cell)`` pairs (sorted by time by the
        caller); each ``cell`` is indexed with ``"effect"`` and ``"se"``.
    groups : list
        Cohorts in plotting order. Cohorts missing from ``group_data`` are
        skipped but still consume a color slot, so colors stay stable.
    critical_value : float
        Multiplier applied to each standard error to get the CI half-width.
    figsize : tuple
        Figure size used when a new figure is created.
    title, xlabel, ylabel : str
        Axes title and labels.
    ax : matplotlib.axes.Axes or None
        Axes to draw on; a new figure/axes is created when None.
    show : bool
        Whether to call ``plt.show()``.

    Returns
    -------
    matplotlib.axes.Axes
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    cmap = getattr(plt.cm, "tab10", None) or plt.colormaps["tab10"]
    colors = cmap(np.linspace(0, 1, len(groups)))

    plotted_any = False
    for i, group in enumerate(groups):
        if group not in group_data:
            continue
        group_effects = group_data[group]
        times = [t for t, _ in group_effects]
        effects = [cell["effect"] for _, cell in group_effects]
        # The normal-approximation CI is symmetric, so a single half-width
        # per point (z * se) is all errorbar needs.
        half_widths = [critical_value * cell["se"] for _, cell in group_effects]

        ax.errorbar(
            times,
            effects,
            yerr=half_widths,
            label=f"Cohort {group}",
            color=colors[i],
            marker="o",
            capsize=3,
            linewidth=1.5,
        )
        plotted_any = True

    ax.axhline(y=0, color="gray", linestyle="--", linewidth=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Only build a legend when at least one labelled series exists; otherwise
    # matplotlib warns "No artists with labels found to put in legend".
    if plotted_any:
        ax.legend(loc="best")
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()

    if show:
        plt.show()

    return ax
157
+
158
+
159
def _render_group_effects_plotly(
    *, group_data, groups, critical_value, title, xlabel, ylabel, show
):
    """Draw cohort-level treatment effects as a plotly figure.

    A dashed grey reference line is placed at zero. Then, for every cohort in
    ``groups`` that has an entry in ``group_data``, one ``lines+markers``
    trace is added. Its asymmetric-form error bars reach from
    ``effect - critical_value * se`` to ``effect + critical_value * se``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    figure = go.Figure()

    # Reference line at zero effect
    figure.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1)

    for cohort in (g for g in groups if g in group_data):
        xs, ys, above, below = [], [], [], []
        for period, cell in group_data[cohort]:
            est = cell["effect"]
            width = critical_value * cell["se"]
            xs.append(period)
            ys.append(est)
            # Distances from the estimate to the CI bounds.
            above.append((est + width) - est)
            below.append(est - (est - width))

        error_spec = dict(
            type="data",
            symmetric=False,
            array=above,
            arrayminus=below,
        )
        figure.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode="lines+markers",
                name=f"Cohort {cohort}",
                error_y=error_spec,
            )
        )

    _plotly_default_layout(figure, title=title, xlabel=xlabel, ylabel=ylabel)

    if show:
        figure.show()

    return figure
204
+
205
+
206
def plot_staircase(
    results: Optional["CallawaySantAnnaResults"] = None,
    *,
    data: Optional[pd.DataFrame] = None,
    unit: Optional[str] = None,
    time: Optional[str] = None,
    first_treat: Optional[str] = None,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Treatment Adoption Over Time",
    color: str = "#2563eb",
    show_counts: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot the treatment-adoption "staircase" of a staggered design.

    The y-axis is the running total of treated units and the x-axis is the
    cohort entry period. Each new cohort therefore adds one step to the
    curve.

    Parameters
    ----------
    results : CallawaySantAnnaResults, optional
        Estimator results; cohorts are read from ``groups`` and cohort sizes
        from the per-cell counts in ``group_time_effects``.
    data : pd.DataFrame, optional
        Raw panel data; requires ``unit``, ``time`` and ``first_treat``
        column names. Pass exactly one of ``results`` and ``data``.
    unit : str, optional
        Unit-identifier column (required with ``data``).
    time : str, optional
        Time-period column (required with ``data``).
    first_treat : str, optional
        First-treatment-period column (required with ``data``).
    figsize : tuple, default=(10, 6)
        Figure size in inches (matplotlib only).
    title : str, default="Treatment Adoption Over Time"
        Plot title.
    color : str, default="#2563eb"
        Base color for the staircase and its annotations.
    show_counts : bool, default=True
        Whether to label each step with its cohort size.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on (matplotlib only); a new figure is created if None.
    show : bool, default=True
        Whether to display the figure when done.
    backend : str, default="matplotlib"
        ``"plotly"`` selects plotly; anything else uses matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).
    """
    cohort_counts = _extract_staircase_data(results, data, unit, time, first_treat)

    # Arguments common to both renderers.
    shared = dict(
        cohort_counts=cohort_counts,
        title=title,
        color=color,
        show_counts=show_counts,
        show=show,
    )

    if backend != "plotly":
        return _render_staircase_mpl(figsize=figsize, ax=ax, **shared)
    return _render_staircase_plotly(**shared)
282
+
283
+
284
+ def _extract_staircase_data(results, data, unit, time, first_treat):
285
+ """Extract cohort periods and counts for the staircase plot.
286
+
287
+ Returns
288
+ -------
289
+ list of (period, count) tuples, sorted by period.
290
+ """
291
+ if results is not None and data is not None:
292
+ raise ValueError("Provide either 'results' or 'data', not both.")
293
+
294
+ if results is not None:
295
+ if not hasattr(results, "group_time_effects") or not hasattr(results, "groups"):
296
+ raise TypeError("results must be a CallawaySantAnnaResults object")
297
+
298
+ groups = sorted(results.groups)
299
+ cohort_counts = []
300
+ for g in groups:
301
+ # Collect n_treated across all (g, t) cells for this cohort.
302
+ # n_treated is a per-cell observation count that can vary with
303
+ # missingness, so we use the max as the best cohort size estimate.
304
+ cell_counts = []
305
+ for (gg, _tt), eff in results.group_time_effects.items():
306
+ if gg == g:
307
+ n = eff.get("n_treated", eff.get("n_obs", None))
308
+ if n is not None:
309
+ cell_counts.append(int(n))
310
+ if not cell_counts:
311
+ cohort_counts.append((g, 0))
312
+ continue
313
+ max_count = max(cell_counts)
314
+ if min(cell_counts) != max_count:
315
+ import warnings
316
+
317
+ warnings.warn(
318
+ f"Cohort {g}: n_treated varies across cells "
319
+ f"({min(cell_counts)}-{max_count}). "
320
+ f"Using max as cohort size; pass data= for exact counts.",
321
+ stacklevel=3,
322
+ )
323
+ cohort_counts.append((g, max_count))
324
+
325
+ return cohort_counts
326
+
327
+ if data is not None:
328
+ if unit is None or time is None or first_treat is None:
329
+ raise ValueError(
330
+ "When using 'data', must provide 'unit', 'time', and 'first_treat' column names."
331
+ )
332
+ # Count unique units per first_treat cohort
333
+ cohort_df = data.groupby(first_treat)[unit].nunique().reset_index()
334
+ cohort_df.columns = ["period", "count"]
335
+ cohort_df = cohort_df.sort_values("period")
336
+ # Exclude never-treated (inf, NaN, or 0 conventions)
337
+ cohort_df = cohort_df[
338
+ cohort_df["period"].notna()
339
+ & np.isfinite(cohort_df["period"])
340
+ & (cohort_df["period"] > 0)
341
+ ]
342
+ return list(zip(cohort_df["period"], cohort_df["count"]))
343
+
344
+ raise ValueError("Must provide either 'results' or 'data'.")
345
+
346
+
347
def _render_staircase_mpl(*, cohort_counts, figsize, title, color, show_counts, ax, show):
    """Draw the cumulative treatment-adoption staircase on matplotlib axes.

    ``cohort_counts`` is a list of ``(period, count)`` pairs. The curve's
    value at each period is the running total of ``count``, drawn as a
    post-step line with a light fill underneath. When ``show_counts`` is
    set, each step is labelled ``+count``. An empty list produces titled
    axes with a centred "No treatment cohorts" message.

    Returns
    -------
    matplotlib.axes.Axes
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    fig, ax = plt.subplots(figsize=figsize) if ax is None else (ax.get_figure(), ax)

    if not cohort_counts:
        ax.set_title(title)
        ax.text(0.5, 0.5, "No treatment cohorts", ha="center", va="center", transform=ax.transAxes)
        if show:
            plt.show()
        return ax

    periods, counts = (list(column) for column in zip(*cohort_counts))
    running_total = np.cumsum(counts)

    # Step line plus shaded area below it
    ax.step(periods, running_total, where="post", color=color, linewidth=2, label="Cumulative treated")
    ax.fill_between(periods, running_total, step="post", alpha=0.15, color=color)

    # Label every step with the size of the cohort that entered there
    if show_counts:
        for period, count, total in zip(periods, counts, running_total):
            ax.annotate(
                f"+{count}",
                xy=(period, total),
                xytext=(0, 8),
                textcoords="offset points",
                ha="center",
                fontsize=9,
                color=color,
                fontweight="bold",
            )

    ax.set_xlabel("Time Period")
    ax.set_ylabel("Cumulative Treated Units")
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis="y")

    # Anchor the y-axis at zero
    ax.set_ylim(bottom=0)

    fig.tight_layout()

    if show:
        plt.show()

    return ax
402
+
403
+
404
def _render_staircase_plotly(*, cohort_counts, title, color, show_counts, show):
    """Render the cumulative treatment-adoption staircase with plotly.

    Parameters
    ----------
    cohort_counts : list of (period, count)
        Cohort entry periods and sizes (presumably already sorted by period
        by the caller — verify if reused elsewhere).
    title : str
        Figure title.
    color : str
        Line and annotation color; the fill color is derived from it via
        ``_color_to_rgba(color, 0.15)``.
    show_counts : bool
        Whether to label each step with ``+count``.
    show : bool
        Whether to call ``fig.show()``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from diff_diff.visualization._common import (
        _color_to_rgba,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()

    fig = go.Figure()

    if not cohort_counts:
        # Use paper (plot-area fraction) coordinates so the placeholder is
        # centred regardless of axis ranges, mirroring the matplotlib
        # renderer's transform=ax.transAxes. With the default data
        # coordinates it landed at data point (0.5, 0.5), off-centre.
        fig.add_annotation(
            text="No treatment cohorts",
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            showarrow=False,
        )
        _plotly_default_layout(fig, title=title)
        if show:
            fig.show()
        return fig

    periods = [p for p, _ in cohort_counts]
    counts = [c for _, c in cohort_counts]
    # Running total of treated units, as plain Python numbers.
    cumulative = np.cumsum(counts).tolist()

    # Step line ("hv": horizontal then vertical segment) with fill to zero
    fig.add_trace(
        go.Scatter(
            x=periods,
            y=cumulative,
            mode="lines",
            line=dict(color=color, width=2, shape="hv"),
            fill="tozeroy",
            fillcolor=_color_to_rgba(color, 0.15),
            name="Cumulative treated",
        )
    )

    # Annotations for cohort sizes
    if show_counts:
        for period, count, cum in zip(periods, counts, cumulative):
            fig.add_annotation(
                x=period,
                y=cum,
                text=f"+{count}",
                showarrow=False,
                yshift=15,
                font=dict(color=color, size=11),
            )

    _plotly_default_layout(
        fig,
        title=title,
        xlabel="Time Period",
        ylabel="Cumulative Treated Units",
    )
    fig.update_yaxes(rangemode="tozero")

    if show:
        fig.show()

    return fig
464
+
465
+
466
def plot_group_time_heatmap(
    results: Optional[
        Union["CallawaySantAnnaResults", "EfficientDiDResults", "ContinuousDiDResults"]
    ] = None,
    *,
    data: Optional[pd.DataFrame] = None,
    figsize: Tuple[float, float] = (10, 8),
    title: str = "Group-Time Treatment Effects",
    cmap: str = "RdBu_r",
    center: float = 0.0,
    annotate: bool = True,
    fmt: str = ".3f",
    mask_insignificant: bool = False,
    alpha: float = 0.05,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Draw ATT(g, t) as a heatmap with cohorts on the y-axis and calendar
    periods on the x-axis.

    Parameters
    ----------
    results : CallawaySantAnnaResults, EfficientDiDResults, or ContinuousDiDResults, optional
        Any results object exposing a ``group_time_effects`` dict.
    data : pd.DataFrame, optional
        Long-format table with ``group``, ``time`` and ``effect`` columns
        (``p_value`` optional). Pass exactly one of ``results`` and ``data``.
    figsize : tuple, default=(10, 8)
        Figure size in inches (matplotlib only).
    title : str, default="Group-Time Treatment Effects"
        Plot title.
    cmap : str, default="RdBu_r"
        Colormap name; diverging maps centred on zero read best.
    center : float, default=0.0
        Value placed at the midpoint of the colormap.
    annotate : bool, default=True
        Whether to print the effect value inside each cell.
    fmt : str, default=".3f"
        Format spec used for those cell labels.
    mask_insignificant : bool, default=False
        Whether to grey out cells whose p-value exceeds ``alpha``.
    alpha : float, default=0.05
        Significance threshold used by ``mask_insignificant``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on (matplotlib only); a new figure is created if None.
    show : bool, default=True
        Whether to display the figure when done.
    backend : str, default="matplotlib"
        ``"plotly"`` selects plotly; anything else uses matplotlib.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).
    """
    effect_matrix, p_matrix, group_labels, time_labels = _extract_heatmap_data(results, data)

    # Arguments accepted by both renderers.
    render_kwargs = {
        "effect_matrix": effect_matrix,
        "p_matrix": p_matrix,
        "group_labels": group_labels,
        "time_labels": time_labels,
        "title": title,
        "cmap": cmap,
        "center": center,
        "annotate": annotate,
        "fmt": fmt,
        "mask_insignificant": mask_insignificant,
        "alpha": alpha,
        "show": show,
    }

    if backend == "plotly":
        return _render_group_time_heatmap_plotly(**render_kwargs)
    return _render_group_time_heatmap_mpl(figsize=figsize, ax=ax, **render_kwargs)
560
+
561
+
562
+ def _extract_heatmap_data(results, data):
563
+ """Extract group-time effects into a 2D matrix.
564
+
565
+ Returns
566
+ -------
567
+ effect_matrix : np.ndarray
568
+ 2D array of effects (groups x time).
569
+ p_matrix : np.ndarray or None
570
+ 2D array of p-values, or None if unavailable.
571
+ group_labels : list
572
+ Sorted group labels.
573
+ time_labels : list
574
+ Sorted time labels.
575
+ """
576
+ if results is not None and data is not None:
577
+ raise ValueError("Provide either 'results' or 'data', not both.")
578
+
579
+ if results is not None:
580
+ if not hasattr(results, "group_time_effects"):
581
+ raise TypeError(f"{type(results).__name__} does not have group_time_effects attribute")
582
+ gte = results.group_time_effects
583
+ if not gte:
584
+ raise ValueError("group_time_effects is empty — nothing to plot.")
585
+
586
+ groups = sorted(set(g for g, t in gte.keys()))
587
+ times = sorted(set(t for g, t in gte.keys()))
588
+
589
+ effect_matrix = np.full((len(groups), len(times)), np.nan)
590
+ p_matrix = np.full((len(groups), len(times)), np.nan)
591
+
592
+ group_idx = {g: i for i, g in enumerate(groups)}
593
+ time_idx = {t: j for j, t in enumerate(times)}
594
+
595
+ for (g, t), eff_data in gte.items():
596
+ i, j = group_idx[g], time_idx[t]
597
+ # Handle different result type structures
598
+ if "effect" in eff_data:
599
+ effect_matrix[i, j] = eff_data["effect"]
600
+ elif "att_glob" in eff_data:
601
+ effect_matrix[i, j] = eff_data["att_glob"]
602
+ if "p_value" in eff_data:
603
+ p_matrix[i, j] = eff_data["p_value"]
604
+
605
+ has_p = np.any(np.isfinite(p_matrix))
606
+ return effect_matrix, p_matrix if has_p else None, groups, times
607
+
608
+ if data is not None:
609
+ required = {"group", "time", "effect"}
610
+ missing = required - set(data.columns)
611
+ if missing:
612
+ raise ValueError(f"DataFrame missing required columns: {missing}")
613
+
614
+ pivot = data.pivot(index="group", columns="time", values="effect")
615
+ pivot = pivot.sort_index(axis=0).sort_index(axis=1)
616
+
617
+ p_matrix = None
618
+ if "p_value" in data.columns:
619
+ p_pivot = data.pivot(index="group", columns="time", values="p_value")
620
+ p_pivot = p_pivot.sort_index(axis=0).sort_index(axis=1)
621
+ p_matrix = p_pivot.values
622
+
623
+ return pivot.values, p_matrix, list(pivot.index), list(pivot.columns)
624
+
625
+ raise ValueError("Must provide either 'results' or 'data'.")
626
+
627
+
628
def _render_group_time_heatmap_mpl(
    *,
    effect_matrix,
    p_matrix,
    group_labels,
    time_labels,
    figsize,
    title,
    cmap,
    center,
    annotate,
    fmt,
    mask_insignificant,
    alpha,
    ax,
    show,
):
    """Render ATT(g, t) as an ``imshow`` heatmap on matplotlib axes.

    Colors use a ``TwoSlopeNorm`` with ``vcenter=center``. The limits are the
    finite minimum and maximum of ``effect_matrix``; when the data does not
    straddle ``center``, the offending limit is moved to ``center -/+ 0.01``
    (TwoSlopeNorm needs vmin < vcenter < vmax). When ``mask_insignificant``
    is set and ``p_matrix`` is available, cells with ``p > alpha`` get grey
    label text plus a 60%-opaque white overlay. The overlay is drawn after
    the labels, so it sits on top of them.

    Returns
    -------
    matplotlib.axes.Axes
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()
    from matplotlib.colors import TwoSlopeNorm

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Boolean grid of cells to grey out (None when masking is off/unavailable)
    insignificant = (
        p_matrix > alpha if (mask_insignificant and p_matrix is not None) else None
    )

    # Color limits around `center`
    finite = effect_matrix[np.isfinite(effect_matrix)]
    lo = np.nanmin(finite) if len(finite) > 0 else center - 0.01
    hi = np.nanmax(finite) if len(finite) > 0 else center + 0.01
    if lo >= center:
        lo = center - 0.01
    if hi <= center:
        hi = center + 0.01
    norm = TwoSlopeNorm(vmin=lo, vcenter=center, vmax=hi)

    image = ax.imshow(effect_matrix.copy(), cmap=cmap, norm=norm, aspect="auto")
    fig.colorbar(image, ax=ax, label="Treatment Effect")

    # Axis ticks and labels
    ax.set_xticks(range(len(time_labels)))
    ax.set_xticklabels(list(map(str, time_labels)), rotation=45, ha="right")
    ax.set_yticks(range(len(group_labels)))
    ax.set_yticklabels(list(map(str, group_labels)))

    ax.set_xlabel("Time Period")
    ax.set_ylabel("Treatment Cohort")
    ax.set_title(title)

    # Per-cell value labels; strongly colored cells get white text for contrast
    if annotate:
        contrast_cutoff = (hi - lo) * 0.3
        for row in range(len(group_labels)):
            for col in range(len(time_labels)):
                value = effect_matrix[row, col]
                if np.isnan(value):
                    continue
                if insignificant is not None and insignificant[row, col]:
                    label_color = "gray"
                elif abs(value - center) > contrast_cutoff:
                    label_color = "white"
                else:
                    label_color = "black"
                ax.text(
                    col,
                    row,
                    f"{value:{fmt}}",
                    ha="center",
                    va="center",
                    fontsize=7,
                    color=label_color,
                )

    # Translucent white patch over every insignificant cell
    if insignificant is not None:
        for row, col in np.ndindex(insignificant.shape):
            if not insignificant[row, col]:
                continue
            ax.add_patch(
                plt.Rectangle(
                    (col - 0.5, row - 0.5),
                    1,
                    1,
                    fill=True,
                    facecolor="white",
                    alpha=0.6,
                    edgecolor="none",
                )
            )

    fig.tight_layout()

    if show:
        plt.show()

    return ax
738
+
739
+
740
def _render_group_time_heatmap_plotly(
    *,
    effect_matrix,
    p_matrix,
    group_labels,
    time_labels,
    title,
    cmap,
    center,
    annotate,
    fmt,
    mask_insignificant,
    alpha,
    show,
):
    """Render ATT(g, t) as a plotly heatmap.

    The color range is symmetric around ``center``: it spans the largest
    absolute deviation of any finite effect, or ``[-1, 1]`` when no effect is
    finite. With ``annotate`` each non-NaN cell shows its value formatted
    with ``fmt``. Cells with ``p > alpha`` (when ``mask_insignificant`` is
    set and p-values exist) are covered by a second, translucent white
    heatmap. They are not replaced with NaN, so insignificant cells are never
    confused with missing ones.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    x_labels = [str(t) for t in time_labels]
    y_labels = [str(g) for g in group_labels]

    # Cell labels (empty string for missing cells)
    cell_text = None
    if annotate:
        cell_text = [
            ["" if np.isnan(value) else f"{value:{fmt}}" for value in row]
            for row in effect_matrix
        ]

    insignificant = None
    if mask_insignificant and p_matrix is not None:
        insignificant = p_matrix > alpha

    # Symmetric color range around `center`
    finite = effect_matrix[np.isfinite(effect_matrix)]
    if len(finite) == 0:
        zmin, zmax = -1, 1
    else:
        spread = max(abs(np.nanmin(finite) - center), abs(np.nanmax(finite) - center))
        zmin, zmax = center - spread, center + spread

    # The colormap name is forwarded as-is; plotly understands the common
    # diverging names (RdBu, RdBu_r, ...) used by matplotlib.
    fig = go.Figure(
        data=go.Heatmap(
            z=effect_matrix,
            x=x_labels,
            y=y_labels,
            colorscale=cmap,
            zmin=zmin,
            zmax=zmax,
            text=cell_text,
            texttemplate="%{text}" if annotate else None,
            colorbar=dict(title="Effect"),
        )
    )

    # White veil over insignificant cells; NaN elsewhere leaves them untouched
    if insignificant is not None and np.any(insignificant):
        veil = np.where(insignificant, 1.0, np.nan)
        fig.add_trace(
            go.Heatmap(
                z=veil,
                x=x_labels,
                y=y_labels,
                colorscale=[[0, "rgba(255,255,255,0.6)"], [1, "rgba(255,255,255,0.6)"]],
                showscale=False,
                hoverinfo="skip",
            )
        )

    _plotly_default_layout(
        fig,
        title=title,
        xlabel="Time Period",
        ylabel="Treatment Cohort",
        show_legend=False,
    )

    if show:
        fig.show()

    return fig