diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/pretrends.py ADDED
@@ -0,0 +1,1104 @@
1
+ """
2
+ Pre-trends power analysis for difference-in-differences designs.
3
+
4
+ This module implements the power analysis framework from Roth (2022) for assessing
5
+ the informativeness of pre-trends tests. It answers the question: "If my pre-trends
6
+ test passed, what violations would I have been able to detect?"
7
+
8
+ Key concepts:
9
+ - **Minimum Detectable Violation (MDV)**: The smallest pre-trends violation that
10
+ would be detected with given power (e.g., 80%).
11
+ - **Power of Pre-Trends Test**: Probability of rejecting parallel trends given
12
+ a specific violation pattern.
13
+ - **Relationship to HonestDiD**: If MDV is large relative to your estimated effect,
14
+ a passing pre-trends test provides limited reassurance.
15
+
16
+ References
17
+ ----------
18
+ Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for
19
+ Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
20
+ https://doi.org/10.1257/aeri.20210236
21
+
22
+ See Also
23
+ --------
24
+ https://github.com/jonathandroth/pretrends - R package implementation
25
+ diff_diff.honest_did - Sensitivity analysis for parallel trends violations
26
+ """
27
+
28
+ from dataclasses import dataclass, field
29
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
30
+
31
+ import numpy as np
32
+ import pandas as pd
33
+ from scipy import stats, optimize
34
+
35
+ from diff_diff.results import MultiPeriodDiDResults
36
+
37
+
38
+ # =============================================================================
39
+ # Results Classes
40
+ # =============================================================================
41
+
42
+
43
@dataclass
class PreTrendsPowerResults:
    """
    Results from pre-trends power analysis.

    Attributes
    ----------
    power : float
        Power to detect the specified violation pattern at given alpha.
    mdv : float
        Minimum detectable violation (smallest M detectable at target power).
        May be ``inf`` when no finite magnitude is found to reach the target.
    violation_magnitude : float
        The magnitude of violation tested (M parameter).
    violation_type : str
        Type of violation pattern ('linear', 'constant', 'last_period', 'custom').
    alpha : float
        Significance level for the pre-trends test.
    target_power : float
        Target power level used for MDV calculation.
    n_pre_periods : int
        Number of pre-treatment periods in the event study.
    test_statistic : float
        Expected test statistic under the specified violation.
    critical_value : float
        Critical value for the pre-trends test.
    noncentrality : float
        Non-centrality parameter under the alternative hypothesis.
    pre_period_effects : np.ndarray
        Estimated pre-period effects from the event study.
    pre_period_ses : np.ndarray
        Standard errors of pre-period effects.
    vcov : np.ndarray
        Variance-covariance matrix of pre-period effects.
    original_results : object, optional
        The results object the analysis was computed from.
    violation_weights : np.ndarray, optional
        Violation pattern (one weight per pre-period) used by :meth:`power_at`.
        When provided it takes precedence over ``violation_type``. It is
        required for an exact ``power_at`` with ``violation_type='custom'``;
        for the built-in types the pattern is rebuilt when this is None.
    """

    power: float
    mdv: float
    violation_magnitude: float
    violation_type: str
    alpha: float
    target_power: float
    n_pre_periods: int
    test_statistic: float
    critical_value: float
    noncentrality: float
    pre_period_effects: np.ndarray = field(repr=False)
    pre_period_ses: np.ndarray = field(repr=False)
    vcov: np.ndarray = field(repr=False)
    original_results: Optional[Any] = field(default=None, repr=False)
    violation_weights: Optional[np.ndarray] = field(default=None, repr=False)

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerResults(power={self.power:.3f}, "
            f"mdv={self.mdv:.4f}, M={self.violation_magnitude:.4f})"
        )

    @property
    def is_informative(self) -> bool:
        """
        Check if the pre-trends test is informative.

        A pre-trends test is considered informative if the MDV is reasonably
        small relative to typical effect sizes. This is a heuristic check;
        see the summary for interpretation guidance.
        """
        # Heuristic: MDV < 2x the max observed pre-period SE
        max_se = np.max(self.pre_period_ses) if len(self.pre_period_ses) > 0 else 1.0
        return bool(self.mdv < 2 * max_se)

    @property
    def power_adequate(self) -> bool:
        """Check if power meets the target threshold."""
        return bool(self.power >= self.target_power)

    def summary(self) -> str:
        """
        Generate formatted summary of pre-trends power analysis.

        Returns
        -------
        str
            Formatted summary.
        """
        lines = [
            "=" * 70,
            "Pre-Trends Power Analysis Results".center(70),
            "(Roth 2022)".center(70),
            "=" * 70,
            "",
            f"{'Number of pre-periods:':<35} {self.n_pre_periods}",
            f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
            f"{'Target power:':<35} {self.target_power:.1%}",
            f"{'Violation type:':<35} {self.violation_type}",
            "",
            "-" * 70,
            "Power Analysis".center(70),
            "-" * 70,
            f"{'Violation magnitude (M):':<35} {self.violation_magnitude:.4f}",
            f"{'Power to detect this violation:':<35} {self.power:.1%}",
            f"{'Minimum detectable violation:':<35} {self.mdv:.4f}",
            "",
            f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}",
            f"{'Critical value:':<35} {self.critical_value:.4f}",
            f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}",
            "",
            "-" * 70,
            "Interpretation".center(70),
            "-" * 70,
        ]

        if self.power_adequate:
            lines.append(f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%}).")
            lines.append(
                f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}."
            )
        elif np.isfinite(self.mdv):
            lines.append(f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%}).")
            lines.append(
                f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power."
            )
        else:
            # An infinite MDV means no finite violation was found that reaches
            # the target, so "would need violations of inf" is not useful.
            lines.append(f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%}).")
            lines.append(
                f" Target power ({self.target_power:.0%}) is not reached for any violation magnitude."
            )

        lines.append("")
        lines.append(f"Minimum detectable violation (MDV): {self.mdv:.4f}")
        lines.append(" → Passing pre-trends test does NOT rule out violations up to this size.")

        lines.extend(["", "=" * 70])

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert results to dictionary."""
        return {
            "power": self.power,
            "mdv": self.mdv,
            "violation_magnitude": self.violation_magnitude,
            "violation_type": self.violation_type,
            "alpha": self.alpha,
            "target_power": self.target_power,
            "n_pre_periods": self.n_pre_periods,
            "test_statistic": self.test_statistic,
            "critical_value": self.critical_value,
            "noncentrality": self.noncentrality,
            "is_informative": self.is_informative,
            "power_adequate": self.power_adequate,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to DataFrame."""
        return pd.DataFrame([self.to_dict()])

    def _violation_pattern(self) -> np.ndarray:
        """
        Return the unit-norm violation pattern used by :meth:`power_at`.

        Stored ``violation_weights`` take precedence. Otherwise the pattern is
        rebuilt from ``violation_type`` with the same formulas as
        ``PreTrendsPower._get_violation_weights``.

        Raises
        ------
        ValueError
            If stored ``violation_weights`` do not have ``n_pre_periods`` entries.

        Warns
        -----
        UserWarning
            If the pattern cannot be rebuilt (e.g. 'custom' without stored
            weights). Equal weights are then used, and the resulting power may
            not match the original analysis.
        """
        n_pre = self.n_pre_periods
        if self.violation_weights is not None:
            weights = np.array(self.violation_weights, dtype=float)
            if len(weights) != n_pre:
                raise ValueError(
                    f"violation_weights has length {len(weights)}, "
                    f"but there are {n_pre} pre-periods"
                )
        elif self.violation_type == "linear":
            # [n-1, n-2, ..., 1, 0] for n pre-periods
            weights = np.arange(n_pre - 1, -1, -1, dtype=float)
        elif self.violation_type == "constant":
            weights = np.ones(n_pre)
        elif self.violation_type == "last_period":
            weights = np.zeros(n_pre)
            weights[-1] = 1.0
        else:
            import warnings

            warnings.warn(
                f"Cannot reconstruct the violation pattern for violation_type="
                f"'{self.violation_type}' without stored violation_weights; "
                "falling back to equal weights, so power_at() may not match "
                "the original analysis.",
                UserWarning,
                stacklevel=3,
            )
            weights = np.ones(n_pre)

        # Normalize weights to unit L2 norm (leave an all-zero pattern as is)
        norm = np.linalg.norm(weights)
        if norm > 0:
            weights = weights / norm
        return weights

    def power_at(self, M: float) -> float:
        """
        Compute power to detect a specific violation magnitude.

        This method allows computing power at different M values without
        re-fitting the model, using the stored variance-covariance matrix and
        critical value.

        Parameters
        ----------
        M : float
            Violation magnitude to evaluate.

        Returns
        -------
        float
            Power to detect violation of magnitude M. Equals ``alpha`` when the
            non-centrality is zero (no violation, or a pattern the test cannot see).
        """
        weights = self._violation_pattern()

        try:
            vcov_inv = np.linalg.inv(self.vcov)
        except np.linalg.LinAlgError:
            vcov_inv = np.linalg.pinv(self.vcov)

        # delta = M * weights  =>  nc = delta' V^{-1} delta = M^2 * w' V^{-1} w
        noncentrality = M**2 * float(weights @ vcov_inv @ weights)

        if noncentrality <= 0:
            # Under the null, the rejection rate is the test size.
            return float(self.alpha)

        # Survival function is more accurate than 1 - cdf in the upper tail.
        return float(
            stats.ncx2.sf(self.critical_value, df=self.n_pre_periods, nc=noncentrality)
        )
253
+
254
+
255
@dataclass
class PreTrendsPowerCurve:
    """
    Power curve across violation magnitudes.

    Attributes
    ----------
    M_values : np.ndarray
        Grid of violation magnitudes tested.
    powers : np.ndarray
        Power at each violation magnitude.
    mdv : float
        Minimum detectable violation.
    alpha : float
        Significance level.
    target_power : float
        Target power level.
    violation_type : str
        Type of violation pattern.
    """

    M_values: np.ndarray
    powers: np.ndarray
    mdv: float
    alpha: float
    target_power: float
    violation_type: str

    def __repr__(self) -> str:
        return f"PreTrendsPowerCurve(n_points={len(self.M_values)}, mdv={self.mdv:.4f})"

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to DataFrame with M and power columns."""
        return pd.DataFrame(
            {
                "M": self.M_values,
                "power": self.powers,
            }
        )

    def plot(
        self,
        ax=None,
        show_mdv: bool = True,
        show_target: bool = True,
        color: str = "#2563eb",
        mdv_color: str = "#dc2626",
        target_color: str = "#22c55e",
        **kwargs,
    ):
        """
        Plot the power curve.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        show_mdv : bool, default=True
            Whether to show vertical line at MDV (only when the MDV is finite).
        show_target : bool, default=True
            Whether to show horizontal line at target power.
        color : str
            Color for power curve line.
        mdv_color : str
            Color for MDV vertical line.
        target_color : str
            Color for target power horizontal line.
        **kwargs
            Additional arguments passed to ``ax.plot()`` for the power curve.
            They override the defaults ``linewidth=2`` and ``label="Power"``.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes with the plot.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
            from matplotlib.ticker import FuncFormatter
        except ImportError as err:
            raise ImportError("matplotlib is required for plotting") from err

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        # Merge the defaults with the user's kwargs so that e.g. linewidth=3 or
        # label="..." overrides them instead of raising "got multiple values
        # for keyword argument".
        line_kwargs = {"linewidth": 2, "label": "Power", **kwargs}
        ax.plot(self.M_values, self.powers, color=color, **line_kwargs)

        # Target power line
        if show_target:
            ax.axhline(
                y=self.target_power,
                color=target_color,
                linestyle="--",
                linewidth=1.5,
                alpha=0.7,
                label=f"Target power ({self.target_power:.0%})",
            )

        # MDV line (skipped when the MDV is infinite, e.g. target not reachable)
        if show_mdv and self.mdv is not None and np.isfinite(self.mdv):
            ax.axvline(
                x=self.mdv,
                color=mdv_color,
                linestyle=":",
                linewidth=1.5,
                alpha=0.7,
                label=f"MDV = {self.mdv:.3f}",
            )

        ax.set_xlabel("Violation Magnitude (M)")
        ax.set_ylabel("Power")
        ax.set_title("Pre-Trends Test Power Curve")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.0%}"))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        return ax
372
+
373
+
374
+ # =============================================================================
375
+ # Main Class
376
+ # =============================================================================
377
+
378
+
379
+ class PreTrendsPower:
380
+ """
381
+ Pre-trends power analysis (Roth 2022).
382
+
383
+ Computes the power of pre-trends tests to detect violations of parallel
384
+ trends, and the minimum detectable violation (MDV).
385
+
386
+ Parameters
387
+ ----------
388
+ alpha : float, default=0.05
389
+ Significance level for the pre-trends test.
390
+ power : float, default=0.80
391
+ Target power level for MDV calculation.
392
+ violation_type : str, default='linear'
393
+ Type of violation pattern to consider:
394
+ - 'linear': Violations follow a linear trend (most common)
395
+ - 'constant': Same violation in all pre-periods
396
+ - 'last_period': Violation only in the last pre-period
397
+ - 'custom': User-specified violation pattern (via violation_weights)
398
+ violation_weights : array-like, optional
399
+ Custom weights for violation pattern. Length must equal number of
400
+ pre-periods. Only used when violation_type='custom'.
401
+
402
+ Examples
403
+ --------
404
+ Basic usage with MultiPeriodDiD results:
405
+
406
+ >>> from diff_diff import MultiPeriodDiD
407
+ >>> from diff_diff.pretrends import PreTrendsPower
408
+ >>>
409
+ >>> # Fit event study
410
+ >>> mp_did = MultiPeriodDiD()
411
+ >>> results = mp_did.fit(data, outcome='y', treatment='treated',
412
+ ... time='period', post_periods=[4, 5, 6, 7])
413
+ >>>
414
+ >>> # Analyze pre-trends power
415
+ >>> pt = PreTrendsPower(alpha=0.05, power=0.80)
416
+ >>> power_results = pt.fit(results)
417
+ >>> print(power_results.summary())
418
+ >>>
419
+ >>> # Get power curve
420
+ >>> curve = pt.power_curve(results)
421
+ >>> curve.plot()
422
+
423
+ Notes
424
+ -----
425
+ The pre-trends test is typically a joint test that all pre-period
426
+ coefficients are zero. This test has limited power to detect small
427
+ violations, especially when:
428
+
429
+ 1. There are few pre-periods
430
+ 2. Standard errors are large
431
+ 3. The violation pattern is smooth (e.g., linear trend)
432
+
433
+ Passing a pre-trends test does NOT mean parallel trends holds. It means
434
+ violations smaller than the MDV cannot be ruled out. For robust inference,
435
+ combine with HonestDiD sensitivity analysis.
436
+
437
+ References
438
+ ----------
439
+ Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing
440
+ for Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ alpha: float = 0.05,
446
+ power: float = 0.80,
447
+ violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear",
448
+ violation_weights: Optional[np.ndarray] = None,
449
+ ):
450
+ if not 0 < alpha < 1:
451
+ raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
452
+ if not 0 < power < 1:
453
+ raise ValueError(f"power must be between 0 and 1, got {power}")
454
+ if violation_type not in ["linear", "constant", "last_period", "custom"]:
455
+ raise ValueError(
456
+ f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', "
457
+ f"got '{violation_type}'"
458
+ )
459
+ if violation_type == "custom" and violation_weights is None:
460
+ raise ValueError("violation_weights must be provided when violation_type='custom'")
461
+
462
+ self.alpha = alpha
463
+ self.target_power = power
464
+ self.violation_type = violation_type
465
+ self.violation_weights = (
466
+ np.asarray(violation_weights) if violation_weights is not None else None
467
+ )
468
+
469
+ def get_params(self) -> Dict[str, Any]:
470
+ """Get parameters for this estimator."""
471
+ return {
472
+ "alpha": self.alpha,
473
+ "power": self.target_power,
474
+ "violation_type": self.violation_type,
475
+ "violation_weights": self.violation_weights,
476
+ }
477
+
478
+ def set_params(self, **params) -> "PreTrendsPower":
479
+ """Set parameters for this estimator."""
480
+ for key, value in params.items():
481
+ if key == "power":
482
+ self.target_power = value
483
+ elif hasattr(self, key):
484
+ setattr(self, key, value)
485
+ else:
486
+ raise ValueError(f"Invalid parameter: {key}")
487
+ return self
488
+
489
+ def _get_violation_weights(self, n_pre: int) -> np.ndarray:
490
+ """
491
+ Get violation weights based on violation type.
492
+
493
+ Parameters
494
+ ----------
495
+ n_pre : int
496
+ Number of pre-treatment periods.
497
+
498
+ Returns
499
+ -------
500
+ np.ndarray
501
+ Violation weights, normalized to have L2 norm of 1.
502
+ """
503
+ if self.violation_type == "custom":
504
+ if len(self.violation_weights) != n_pre:
505
+ raise ValueError(
506
+ f"violation_weights has length {len(self.violation_weights)}, "
507
+ f"but there are {n_pre} pre-periods"
508
+ )
509
+ weights = self.violation_weights.copy()
510
+ elif self.violation_type == "linear":
511
+ # Linear trend: weights = [-n+1, -n+2, ..., -1, 0] for periods ending at -1
512
+ # Normalized so that violation at period -1 = 0 and grows linearly backward
513
+ weights = np.arange(-n_pre + 1, 1, dtype=float)
514
+ # Shift so that weights are positive and represent deviation from PT
515
+ weights = -weights # Now [n-1, n-2, ..., 1, 0]
516
+ elif self.violation_type == "constant":
517
+ # Same violation in all periods
518
+ weights = np.ones(n_pre)
519
+ elif self.violation_type == "last_period":
520
+ # Violation only in last pre-period (period -1)
521
+ weights = np.zeros(n_pre)
522
+ weights[-1] = 1.0
523
+ else:
524
+ raise ValueError(f"Unknown violation_type: {self.violation_type}")
525
+
526
+ # Normalize to unit norm (if not all zeros)
527
+ norm = np.linalg.norm(weights)
528
+ if norm > 0:
529
+ weights = weights / norm
530
+
531
+ return weights
532
+
533
+ def _extract_pre_period_params(
534
+ self,
535
+ results: Union[MultiPeriodDiDResults, Any],
536
+ pre_periods: Optional[List[int]] = None,
537
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
538
+ """
539
+ Extract pre-period parameters from results.
540
+
541
+ Parameters
542
+ ----------
543
+ results : MultiPeriodDiDResults or similar
544
+ Results object from event study estimation.
545
+ pre_periods : list of int, optional
546
+ Explicit list of pre-treatment periods. If None, uses results.pre_periods.
547
+
548
+ Returns
549
+ -------
550
+ effects : np.ndarray
551
+ Pre-period effect estimates.
552
+ ses : np.ndarray
553
+ Pre-period standard errors.
554
+ vcov : np.ndarray
555
+ Variance-covariance matrix for pre-period effects.
556
+ n_pre : int
557
+ Number of pre-periods.
558
+ """
559
+ if isinstance(results, MultiPeriodDiDResults):
560
+ # Get pre-period information - use explicit pre_periods if provided
561
+ if pre_periods is not None:
562
+ all_pre_periods = list(pre_periods)
563
+ else:
564
+ all_pre_periods = results.pre_periods
565
+
566
+ if len(all_pre_periods) == 0:
567
+ raise ValueError(
568
+ "No pre-treatment periods found in results. "
569
+ "Pre-trends power analysis requires pre-period coefficients. "
570
+ "If you estimated all periods as post_periods, use the pre_periods "
571
+ "parameter to specify which are actually pre-treatment."
572
+ )
573
+
574
+ # Pre-period effects are in period_effects (excluding reference period)
575
+ estimated_pre_periods = [
576
+ p
577
+ for p in all_pre_periods
578
+ if p in results.period_effects and results.period_effects[p].se > 0
579
+ ]
580
+
581
+ if len(estimated_pre_periods) == 0:
582
+ raise ValueError(
583
+ "No estimated pre-period coefficients found. "
584
+ "The pre-trends test requires at least one estimated "
585
+ "pre-period coefficient (excluding the reference period)."
586
+ )
587
+
588
+ n_pre = len(estimated_pre_periods)
589
+ effects = np.array([results.period_effects[p].effect for p in estimated_pre_periods])
590
+ ses = np.array([results.period_effects[p].se for p in estimated_pre_periods])
591
+
592
+ # Extract vcov using stored interaction indices for robust extraction
593
+ if (
594
+ results.vcov is not None
595
+ and hasattr(results, "interaction_indices")
596
+ and results.interaction_indices is not None
597
+ ):
598
+ indices = [results.interaction_indices[p] for p in estimated_pre_periods]
599
+ vcov = results.vcov[np.ix_(indices, indices)]
600
+ else:
601
+ vcov = np.diag(ses**2)
602
+
603
+ return effects, ses, vcov, n_pre
604
+
605
+ # Try CallawaySantAnnaResults
606
+ try:
607
+ from diff_diff.staggered import CallawaySantAnnaResults
608
+
609
+ if isinstance(results, CallawaySantAnnaResults):
610
+ if results.event_study_effects is None:
611
+ raise ValueError(
612
+ "CallawaySantAnnaResults must have event_study_effects. "
613
+ "Re-run with aggregate='event_study'."
614
+ )
615
+
616
+ # Get pre-period effects (negative relative times)
617
+ # Filter out normalization constraints (n_groups=0) and non-finite SEs
618
+ pre_effects = {
619
+ t: data
620
+ for t, data in results.event_study_effects.items()
621
+ if t < 0 and data.get("n_groups", 1) > 0 and np.isfinite(data.get("se", np.nan))
622
+ }
623
+
624
+ if not pre_effects:
625
+ raise ValueError("No pre-treatment periods found in event study.")
626
+
627
+ pre_periods = sorted(pre_effects.keys())
628
+ n_pre = len(pre_periods)
629
+
630
+ effects = np.array([pre_effects[t]["effect"] for t in pre_periods])
631
+ ses = np.array([pre_effects[t]["se"] for t in pre_periods])
632
+ vcov = np.diag(ses**2)
633
+
634
+ return effects, ses, vcov, n_pre
635
+ except ImportError:
636
+ pass
637
+
638
+ # Try SunAbrahamResults
639
+ try:
640
+ from diff_diff.sun_abraham import SunAbrahamResults
641
+
642
+ if isinstance(results, SunAbrahamResults):
643
+ # Get pre-period effects (negative relative times)
644
+ # Filter out normalization constraints (n_groups=0) and non-finite SEs
645
+ pre_effects = {
646
+ t: data
647
+ for t, data in results.event_study_effects.items()
648
+ if t < 0 and data.get("n_groups", 1) > 0 and np.isfinite(data.get("se", np.nan))
649
+ }
650
+
651
+ if not pre_effects:
652
+ raise ValueError("No pre-treatment periods found in event study.")
653
+
654
+ pre_periods = sorted(pre_effects.keys())
655
+ n_pre = len(pre_periods)
656
+
657
+ effects = np.array([pre_effects[t]["effect"] for t in pre_periods])
658
+ ses = np.array([pre_effects[t]["se"] for t in pre_periods])
659
+ vcov = np.diag(ses**2)
660
+
661
+ return effects, ses, vcov, n_pre
662
+ except ImportError:
663
+ pass
664
+
665
+ raise TypeError(
666
+ f"Unsupported results type: {type(results)}. "
667
+ "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults."
668
+ )
669
+
670
+ def _compute_power(
671
+ self,
672
+ M: float,
673
+ weights: np.ndarray,
674
+ vcov: np.ndarray,
675
+ ) -> Tuple[float, float, float, float]:
676
+ """
677
+ Compute power to detect violation of magnitude M.
678
+
679
+ The pre-trends test is a Wald test: H0: delta = 0 vs H1: delta != 0
680
+ Under H1 with violation delta = M * weights, the test statistic follows
681
+ a non-central chi-squared distribution.
682
+
683
+ Parameters
684
+ ----------
685
+ M : float
686
+ Violation magnitude.
687
+ weights : np.ndarray
688
+ Normalized violation pattern.
689
+ vcov : np.ndarray
690
+ Variance-covariance matrix.
691
+
692
+ Returns
693
+ -------
694
+ power : float
695
+ Power to detect this violation.
696
+ noncentrality : float
697
+ Non-centrality parameter.
698
+ test_stat : float
699
+ Expected test statistic under H1.
700
+ critical_value : float
701
+ Critical value for the test.
702
+ """
703
+ n_pre = len(weights)
704
+
705
+ # Violation vector: delta = M * weights
706
+ delta = M * weights
707
+
708
+ # Non-centrality parameter for chi-squared test
709
+ # lambda = delta' * V^{-1} * delta
710
+ try:
711
+ vcov_inv = np.linalg.inv(vcov)
712
+ noncentrality = delta @ vcov_inv @ delta
713
+ except np.linalg.LinAlgError:
714
+ # Singular matrix - use pseudo-inverse
715
+ vcov_inv = np.linalg.pinv(vcov)
716
+ noncentrality = delta @ vcov_inv @ delta
717
+
718
+ # Critical value from chi-squared distribution
719
+ critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)
720
+
721
+ # Power = P(chi2_nc > critical_value) where chi2_nc is non-central chi2
722
+ if noncentrality > 0:
723
+ power = 1 - stats.ncx2.cdf(critical_value, df=n_pre, nc=noncentrality)
724
+ else:
725
+ power = self.alpha # Size under null
726
+
727
+ # Expected test statistic under H1
728
+ test_stat = n_pre + noncentrality # Mean of non-central chi2
729
+
730
+ return power, noncentrality, test_stat, critical_value
731
+
732
+ def _compute_mdv(
733
+ self,
734
+ weights: np.ndarray,
735
+ vcov: np.ndarray,
736
+ ) -> float:
737
+ """
738
+ Compute minimum detectable violation.
739
+
740
+ Find the smallest M such that power >= target_power.
741
+
742
+ Parameters
743
+ ----------
744
+ weights : np.ndarray
745
+ Normalized violation pattern.
746
+ vcov : np.ndarray
747
+ Variance-covariance matrix.
748
+
749
+ Returns
750
+ -------
751
+ mdv : float
752
+ Minimum detectable violation.
753
+ """
754
+ n_pre = len(weights)
755
+
756
+ # Critical value
757
+ critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)
758
+
759
+ # Find non-centrality parameter for target power
760
+ # We need: P(ncx2 > critical_value) = target_power
761
+ # Use inverse: find lambda such that ncx2.cdf(cv, df, lambda) = 1 - target_power
762
+
763
+ def power_minus_target(nc):
764
+ if nc <= 0:
765
+ return self.alpha - self.target_power
766
+ return stats.ncx2.sf(critical_value, df=n_pre, nc=nc) - self.target_power
767
+
768
+ # Binary search for non-centrality parameter
769
+ # Start with bounds
770
+ nc_low, nc_high = 0, 1
771
+
772
+ # Expand upper bound until power exceeds target
773
+ while power_minus_target(nc_high) < 0 and nc_high < 1000:
774
+ nc_high *= 2
775
+
776
+ if nc_high >= 1000:
777
+ # Target power not achievable - return inf
778
+ return np.inf
779
+
780
+ # Binary search
781
+ try:
782
+ result = optimize.brentq(power_minus_target, nc_low, nc_high)
783
+ target_nc = result
784
+ except ValueError:
785
+ # Fallback: use approximate formula
786
+ # For chi2, power ≈ Phi(sqrt(2*nc) - sqrt(2*cv))
787
+ # Solving: sqrt(2*nc) = z_power + sqrt(2*cv)
788
+ z_power = stats.norm.ppf(self.target_power)
789
+ target_nc = 0.5 * (z_power + np.sqrt(2 * critical_value)) ** 2
790
+
791
+ # Convert non-centrality to M
792
+ # nc = delta' * V^{-1} * delta = M^2 * w' * V^{-1} * w
793
+ try:
794
+ vcov_inv = np.linalg.inv(vcov)
795
+ w_Vinv_w = weights @ vcov_inv @ weights
796
+ except np.linalg.LinAlgError:
797
+ vcov_inv = np.linalg.pinv(vcov)
798
+ w_Vinv_w = weights @ vcov_inv @ weights
799
+
800
+ if w_Vinv_w > 0:
801
+ mdv = np.sqrt(target_nc / w_Vinv_w)
802
+ else:
803
+ mdv = np.inf
804
+
805
+ return mdv
806
+
807
+ def fit(
808
+ self,
809
+ results: Union[MultiPeriodDiDResults, Any],
810
+ M: Optional[float] = None,
811
+ pre_periods: Optional[List[int]] = None,
812
+ ) -> PreTrendsPowerResults:
813
+ """
814
+ Compute pre-trends power analysis.
815
+
816
+ Parameters
817
+ ----------
818
+ results : MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults
819
+ Results from an event study estimation.
820
+ M : float, optional
821
+ Specific violation magnitude to evaluate. If None, evaluates at
822
+ a default magnitude based on the data.
823
+ pre_periods : list of int, optional
824
+ Explicit list of pre-treatment periods to use for power analysis.
825
+ If None, attempts to infer from results.pre_periods. Use this when
826
+ you've estimated an event study with all periods in post_periods
827
+ and need to specify which are actually pre-treatment.
828
+
829
+ Returns
830
+ -------
831
+ PreTrendsPowerResults
832
+ Power analysis results including power and MDV.
833
+ """
834
+ # Extract pre-period parameters
835
+ effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
836
+
837
+ # Get violation weights
838
+ weights = self._get_violation_weights(n_pre)
839
+
840
+ # Compute MDV
841
+ mdv = self._compute_mdv(weights, vcov)
842
+
843
+ # Default M: use MDV if not specified
844
+ if M is None:
845
+ M = mdv if np.isfinite(mdv) else np.max(ses)
846
+
847
+ # Compute power at specified M
848
+ power, noncentrality, test_stat, critical_value = self._compute_power(M, weights, vcov)
849
+
850
+ return PreTrendsPowerResults(
851
+ power=power,
852
+ mdv=mdv,
853
+ violation_magnitude=M,
854
+ violation_type=self.violation_type,
855
+ alpha=self.alpha,
856
+ target_power=self.target_power,
857
+ n_pre_periods=n_pre,
858
+ test_statistic=test_stat,
859
+ critical_value=critical_value,
860
+ noncentrality=noncentrality,
861
+ pre_period_effects=effects,
862
+ pre_period_ses=ses,
863
+ vcov=vcov,
864
+ original_results=results,
865
+ )
866
+
867
+ def power_at(
868
+ self,
869
+ results: Union[MultiPeriodDiDResults, Any],
870
+ M: float,
871
+ pre_periods: Optional[List[int]] = None,
872
+ ) -> float:
873
+ """
874
+ Compute power to detect a specific violation magnitude.
875
+
876
+ Parameters
877
+ ----------
878
+ results : results object
879
+ Event study results.
880
+ M : float
881
+ Violation magnitude.
882
+ pre_periods : list of int, optional
883
+ Explicit list of pre-treatment periods. See fit() for details.
884
+
885
+ Returns
886
+ -------
887
+ float
888
+ Power to detect violation of magnitude M.
889
+ """
890
+ result = self.fit(results, M=M, pre_periods=pre_periods)
891
+ return result.power
892
+
893
+ def power_curve(
894
+ self,
895
+ results: Union[MultiPeriodDiDResults, Any],
896
+ M_grid: Optional[List[float]] = None,
897
+ n_points: int = 50,
898
+ pre_periods: Optional[List[int]] = None,
899
+ ) -> PreTrendsPowerCurve:
900
+ """
901
+ Compute power across a range of violation magnitudes.
902
+
903
+ Parameters
904
+ ----------
905
+ results : results object
906
+ Event study results.
907
+ M_grid : list of float, optional
908
+ Specific violation magnitudes to evaluate. If None, creates
909
+ automatic grid from 0 to 2.5 * MDV.
910
+ n_points : int, default=50
911
+ Number of points in automatic grid.
912
+ pre_periods : list of int, optional
913
+ Explicit list of pre-treatment periods. See fit() for details.
914
+
915
+ Returns
916
+ -------
917
+ PreTrendsPowerCurve
918
+ Power curve data with plot method.
919
+ """
920
+ # Extract parameters
921
+ _, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
922
+ weights = self._get_violation_weights(n_pre)
923
+
924
+ # Compute MDV
925
+ mdv = self._compute_mdv(weights, vcov)
926
+
927
+ # Create M grid if not provided
928
+ if M_grid is None:
929
+ max_M = min(2.5 * mdv if np.isfinite(mdv) else 10 * np.max(ses), 100)
930
+ M_grid = np.linspace(0, max_M, n_points)
931
+ else:
932
+ M_grid = np.asarray(M_grid)
933
+
934
+ # Compute power at each M
935
+ powers = np.array([self._compute_power(M, weights, vcov)[0] for M in M_grid])
936
+
937
+ return PreTrendsPowerCurve(
938
+ M_values=M_grid,
939
+ powers=powers,
940
+ mdv=mdv,
941
+ alpha=self.alpha,
942
+ target_power=self.target_power,
943
+ violation_type=self.violation_type,
944
+ )
945
+
946
+ def sensitivity_to_honest_did(
947
+ self,
948
+ results: Union[MultiPeriodDiDResults, Any],
949
+ pre_periods: Optional[List[int]] = None,
950
+ ) -> Dict[str, Any]:
951
+ """
952
+ Compare pre-trends power analysis with HonestDiD sensitivity.
953
+
954
+ This method helps interpret how informative a passing pre-trends
955
+ test is in the context of HonestDiD's relative magnitudes restriction.
956
+
957
+ Parameters
958
+ ----------
959
+ results : results object
960
+ Event study results.
961
+ pre_periods : list of int, optional
962
+ Explicit list of pre-treatment periods. See fit() for details.
963
+
964
+ Returns
965
+ -------
966
+ dict
967
+ Dictionary with:
968
+ - mdv: Minimum detectable violation from pre-trends test
969
+ - honest_M_at_mdv: Corresponding M value for HonestDiD
970
+ - interpretation: Text explaining the relationship
971
+ """
972
+ pt_results = self.fit(results, pre_periods=pre_periods)
973
+ mdv = pt_results.mdv
974
+
975
+ # The MDV represents the size of violation the test could detect
976
+ # In HonestDiD's relative magnitudes framework, M=1 means
977
+ # post-treatment violations can be as large as the max pre-period violation
978
+ # The MDV gives us a sense of how large that max violation could be
979
+
980
+ max_pre_se = np.max(pt_results.pre_period_ses)
981
+
982
+ interpretation = []
983
+ interpretation.append(f"Minimum Detectable Violation (MDV): {mdv:.4f}")
984
+ interpretation.append(f"Max pre-period SE: {max_pre_se:.4f}")
985
+
986
+ if np.isfinite(mdv):
987
+ # Ratio of MDV to max SE - gives sense of how many SEs the MDV is
988
+ mdv_in_ses = mdv / max_pre_se if max_pre_se > 0 else np.inf
989
+ interpretation.append(f"MDV / max(SE): {mdv_in_ses:.2f}")
990
+
991
+ if mdv_in_ses < 1:
992
+ interpretation.append("→ Pre-trends test is fairly sensitive to violations.")
993
+ elif mdv_in_ses < 2:
994
+ interpretation.append("→ Pre-trends test has moderate sensitivity.")
995
+ else:
996
+ interpretation.append("→ Pre-trends test has low power to detect violations.")
997
+ interpretation.append(
998
+ " Consider using HonestDiD with larger M values for robustness."
999
+ )
1000
+ else:
1001
+ interpretation.append(
1002
+ "→ Pre-trends test cannot achieve target power for any violation size."
1003
+ )
1004
+ interpretation.append(" Use HonestDiD sensitivity analysis for inference.")
1005
+
1006
+ return {
1007
+ "mdv": mdv,
1008
+ "max_pre_se": max_pre_se,
1009
+ "mdv_in_ses": mdv / max_pre_se if max_pre_se > 0 and np.isfinite(mdv) else np.inf,
1010
+ "interpretation": "\n".join(interpretation),
1011
+ }
1012
+
1013
+
1014
+ # =============================================================================
1015
+ # Convenience Functions
1016
+ # =============================================================================
1017
+
1018
+
1019
def compute_pretrends_power(
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
    """
    Run a one-shot pre-trends power analysis.

    Builds a ``PreTrendsPower`` analyzer from the supplied settings and
    immediately fits it to ``results``.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float, optional
        Violation magnitude at which to evaluate power; forwarded to fit().
    alpha : float, default=0.05
        Significance level of the pre-trends test.
    target_power : float, default=0.80
        Power level used when computing the minimum detectable violation.
    violation_type : str, default='linear'
        Shape of the hypothesized pre-trends violation.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, the analyzer
        attempts to infer them from ``results``. Use this when all periods
        were estimated as post_periods.

    Returns
    -------
    PreTrendsPowerResults
        The object returned by ``PreTrendsPower.fit``.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import compute_pretrends_power
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
    >>> print(f"MDV: {power_results.mdv:.3f}")
    >>> print(f"Power: {power_results.power:.1%}")
    """
    analyzer = PreTrendsPower(
        violation_type=violation_type,
        power=target_power,
        alpha=alpha,
    )
    return analyzer.fit(results, pre_periods=pre_periods, M=M)
1067
+
1068
+
1069
def compute_mdv(
    results: Union[MultiPeriodDiDResults, Any],
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
) -> float:
    """
    Return only the minimum detectable violation (MDV) for ``results``.

    Shortcut that fits a ``PreTrendsPower`` analyzer and extracts the
    ``mdv`` attribute of the fitted result.

    Parameters
    ----------
    results : results object
        Event study results.
    alpha : float, default=0.05
        Significance level of the pre-trends test.
    target_power : float, default=0.80
        Power level the MDV is defined against.
    violation_type : str, default='linear'
        Shape of the hypothesized pre-trends violation.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, the analyzer
        attempts to infer them from ``results``. Use this when all periods
        were estimated as post_periods.

    Returns
    -------
    float
        Minimum detectable violation.
    """
    analyzer = PreTrendsPower(
        violation_type=violation_type,
        power=target_power,
        alpha=alpha,
    )
    fitted = analyzer.fit(results, pre_periods=pre_periods)
    mdv = fitted.mdv
    return mdv