diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/pretrends.py ADDED
@@ -0,0 +1,1067 @@
1
+ """
2
+ Pre-trends power analysis for difference-in-differences designs.
3
+
4
+ This module implements the power analysis framework from Roth (2022) for assessing
5
+ the informativeness of pre-trends tests. It answers the question: "If my pre-trends
6
+ test passed, what violations would I have been able to detect?"
7
+
8
+ Key concepts:
9
+ - **Minimum Detectable Violation (MDV)**: The smallest pre-trends violation that
10
+ would be detected with given power (e.g., 80%).
11
+ - **Power of Pre-Trends Test**: Probability of rejecting parallel trends given
12
+ a specific violation pattern.
13
+ - **Relationship to HonestDiD**: If MDV is large relative to your estimated effect,
14
+ a passing pre-trends test provides limited reassurance.
15
+
16
+ References
17
+ ----------
18
+ Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for
19
+ Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
20
+ https://doi.org/10.1257/aeri.20210236
21
+
22
+ See Also
23
+ --------
24
+ https://github.com/jonathandroth/pretrends - R package implementation
25
+ diff_diff.honest_did - Sensitivity analysis for parallel trends violations
26
+ """
27
+
28
+ from dataclasses import dataclass, field
29
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
30
+
31
+ import numpy as np
32
+ import pandas as pd
33
+ from scipy import stats, optimize
34
+
35
+ from diff_diff.results import MultiPeriodDiDResults
36
+
37
+
38
+ # =============================================================================
39
+ # Results Classes
40
+ # =============================================================================
41
+
42
+
43
@dataclass
class PreTrendsPowerResults:
    """
    Results from pre-trends power analysis.

    Attributes
    ----------
    power : float
        Power to detect the specified violation pattern at given alpha.
    mdv : float
        Minimum detectable violation (smallest M detectable at target power).
    violation_magnitude : float
        The magnitude of violation tested (M parameter).
    violation_type : str
        Type of violation pattern ('linear', 'constant', 'last_period', 'custom').
    alpha : float
        Significance level for the pre-trends test.
    target_power : float
        Target power level used for MDV calculation.
    n_pre_periods : int
        Number of pre-treatment periods in the event study.
    test_statistic : float
        Expected test statistic under the specified violation.
    critical_value : float
        Critical value for the pre-trends test.
    noncentrality : float
        Non-centrality parameter under the alternative hypothesis.
    pre_period_effects : np.ndarray
        Estimated pre-period effects from the event study.
    pre_period_ses : np.ndarray
        Standard errors of pre-period effects.
    vcov : np.ndarray
        Variance-covariance matrix of pre-period effects.
    original_results : object, optional
        The event-study results object the analysis was computed from.
    """

    power: float
    mdv: float
    violation_magnitude: float
    violation_type: str
    alpha: float
    target_power: float
    n_pre_periods: int
    test_statistic: float
    critical_value: float
    noncentrality: float
    pre_period_effects: np.ndarray = field(repr=False)
    pre_period_ses: np.ndarray = field(repr=False)
    vcov: np.ndarray = field(repr=False)
    original_results: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerResults(power={self.power:.3f}, "
            f"mdv={self.mdv:.4f}, M={self.violation_magnitude:.4f})"
        )

    @property
    def is_informative(self) -> bool:
        """
        Check if the pre-trends test is informative.

        Heuristic only: returns True when the MDV is smaller than twice the
        largest pre-period standard error. When no standard errors are
        stored, a reference SE of 1.0 is used.
        """
        max_se = np.max(self.pre_period_ses) if len(self.pre_period_ses) > 0 else 1.0
        return bool(self.mdv < 2 * max_se)

    @property
    def power_adequate(self) -> bool:
        """
        Check if power meets the target threshold.

        The comparison allows a tiny absolute tolerance. When power is
        evaluated at the MDV itself, it equals the target only up to the
        numerical tolerance of the root finder, and an exact ``>=`` would
        then flip on floating-point noise.
        """
        tolerance = 1e-8
        return bool(self.power >= self.target_power - tolerance)

    def summary(self) -> str:
        """
        Generate formatted summary of pre-trends power analysis.

        Returns
        -------
        str
            Formatted summary.
        """
        lines = [
            "=" * 70,
            "Pre-Trends Power Analysis Results".center(70),
            "(Roth 2022)".center(70),
            "=" * 70,
            "",
            f"{'Number of pre-periods:':<35} {self.n_pre_periods}",
            f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
            f"{'Target power:':<35} {self.target_power:.1%}",
            f"{'Violation type:':<35} {self.violation_type}",
            "",
            "-" * 70,
            "Power Analysis".center(70),
            "-" * 70,
            f"{'Violation magnitude (M):':<35} {self.violation_magnitude:.4f}",
            f"{'Power to detect this violation:':<35} {self.power:.1%}",
            f"{'Minimum detectable violation:':<35} {self.mdv:.4f}",
            "",
            f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}",
            f"{'Critical value:':<35} {self.critical_value:.4f}",
            f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}",
            "",
            "-" * 70,
            "Interpretation".center(70),
            "-" * 70,
        ]

        if self.power_adequate:
            lines.append(
                f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%})."
            )
            lines.append(
                f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}."
            )
        else:
            lines.append(
                f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%})."
            )
            lines.append(
                f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power."
            )

        lines.append("")
        lines.append(
            f"Minimum detectable violation (MDV): {self.mdv:.4f}"
        )
        lines.append(
            " → Passing pre-trends test does NOT rule out violations up to this size."
        )

        lines.extend(["", "=" * 70])

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert the scalar results and the two derived flags to a dictionary."""
        return {
            "power": self.power,
            "mdv": self.mdv,
            "violation_magnitude": self.violation_magnitude,
            "violation_type": self.violation_type,
            "alpha": self.alpha,
            "target_power": self.target_power,
            "n_pre_periods": self.n_pre_periods,
            "test_statistic": self.test_statistic,
            "critical_value": self.critical_value,
            "noncentrality": self.noncentrality,
            "is_informative": self.is_informative,
            "power_adequate": self.power_adequate,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to a one-row DataFrame (same fields as ``to_dict``)."""
        return pd.DataFrame([self.to_dict()])
204
+
205
+
206
@dataclass
class PreTrendsPowerCurve:
    """
    Power of the pre-trends test evaluated on a grid of violation magnitudes.

    Attributes
    ----------
    M_values : np.ndarray
        Violation magnitudes at which power was evaluated.
    powers : np.ndarray
        Power corresponding to each entry of ``M_values``.
    mdv : float
        Minimum detectable violation.
    alpha : float
        Significance level.
    target_power : float
        Target power level.
    violation_type : str
        Type of violation pattern.
    """

    M_values: np.ndarray
    powers: np.ndarray
    mdv: float
    alpha: float
    target_power: float
    violation_type: str

    def __repr__(self) -> str:
        n_points = len(self.M_values)
        return f"PreTrendsPowerCurve(n_points={n_points}, " + f"mdv={self.mdv:.4f})"

    def to_dataframe(self) -> pd.DataFrame:
        """Return the curve as a two-column DataFrame (``M``, ``power``)."""
        columns = {"M": self.M_values, "power": self.powers}
        return pd.DataFrame(columns)

    def plot(self, ax=None, show_mdv: bool = True, show_target: bool = True,
             color: str = "#2563eb", mdv_color: str = "#dc2626",
             target_color: str = "#22c55e", **kwargs):
        """
        Draw the power curve with matplotlib.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to draw on. A new 10x6 figure is created when omitted.
        show_mdv : bool, default=True
            Draw a vertical line at the MDV (skipped when the MDV is None
            or not finite).
        show_target : bool, default=True
            Draw a horizontal line at the target power.
        color : str
            Color of the power curve.
        mdv_color : str
            Color of the MDV line.
        target_color : str
            Color of the target-power line.
        **kwargs
            Forwarded to ``ax.plot`` for the power curve.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes that were drawn on.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            raise ImportError("matplotlib is required for plotting")

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        def _as_percent(value, _position):
            # Tick formatter: render the power axis as whole percentages.
            return f'{value:.0%}'

        # The curve itself.
        ax.plot(
            self.M_values,
            self.powers,
            color=color,
            linewidth=2,
            label="Power",
            **kwargs,
        )

        # Horizontal reference at the target power.
        if show_target:
            target_label = f"Target power ({self.target_power:.0%})"
            ax.axhline(
                y=self.target_power,
                color=target_color,
                linestyle="--",
                linewidth=1.5,
                alpha=0.7,
                label=target_label,
            )

        # Vertical reference at the MDV, only when it is a usable number.
        mdv_drawable = self.mdv is not None and np.isfinite(self.mdv)
        if show_mdv and mdv_drawable:
            mdv_label = f"MDV = {self.mdv:.3f}"
            ax.axvline(
                x=self.mdv,
                color=mdv_color,
                linestyle=":",
                linewidth=1.5,
                alpha=0.7,
                label=mdv_label,
            )

        ax.set_xlabel("Violation Magnitude (M)")
        ax.set_ylabel("Power")
        ax.set_title("Pre-Trends Test Power Curve")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(_as_percent))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        return ax
308
+
309
+
310
+ # =============================================================================
311
+ # Main Class
312
+ # =============================================================================
313
+
314
+
315
class PreTrendsPower:
    """
    Pre-trends power analysis (Roth 2022).

    Computes the power of pre-trends tests to detect violations of parallel
    trends, and the minimum detectable violation (MDV).

    Parameters
    ----------
    alpha : float, default=0.05
        Significance level for the pre-trends test.
    power : float, default=0.80
        Target power level for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern to consider:
        - 'linear': Violations follow a linear trend (most common)
        - 'constant': Same violation in all pre-periods
        - 'last_period': Violation only in the last pre-period
        - 'custom': User-specified violation pattern (via violation_weights)
    violation_weights : array-like, optional
        Custom weights for violation pattern. Length must equal number of
        pre-periods. Only used when violation_type='custom'.

    Examples
    --------
    Basic usage with MultiPeriodDiD results:

    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import PreTrendsPower
    >>>
    >>> # Fit event study
    >>> mp_did = MultiPeriodDiD()
    >>> results = mp_did.fit(data, outcome='y', treatment='treated',
    ...                      time='period', post_periods=[4, 5, 6, 7])
    >>>
    >>> # Analyze pre-trends power
    >>> pt = PreTrendsPower(alpha=0.05, power=0.80)
    >>> power_results = pt.fit(results)
    >>> print(power_results.summary())
    >>>
    >>> # Get power curve
    >>> curve = pt.power_curve(results)
    >>> curve.plot()

    Notes
    -----
    The pre-trends test is typically a joint test that all pre-period
    coefficients are zero. This test has limited power to detect small
    violations, especially when:

    1. There are few pre-periods
    2. Standard errors are large
    3. The violation pattern is smooth (e.g., linear trend)

    Passing a pre-trends test does NOT mean parallel trends holds. It means
    violations smaller than the MDV cannot be ruled out. For robust inference,
    combine with HonestDiD sensitivity analysis.

    References
    ----------
    Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing
    for Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
    """

    # Accepted values of ``violation_type``.
    _VIOLATION_TYPES = ("linear", "constant", "last_period", "custom")

    # Safety stop for the MDV bracket search: if the target power is still not
    # reached at this non-centrality, the search gives up and reports ``inf``.
    _MAX_NONCENTRALITY = 1e6

    def __init__(
        self,
        alpha: float = 0.05,
        power: float = 0.80,
        violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear",
        violation_weights: Optional[np.ndarray] = None,
    ):
        self._validate(alpha, power, violation_type)
        if violation_type == "custom" and violation_weights is None:
            raise ValueError(
                "violation_weights must be provided when violation_type='custom'"
            )

        self.alpha = alpha
        self.target_power = power
        self.violation_type = violation_type
        self.violation_weights = (
            np.asarray(violation_weights) if violation_weights is not None else None
        )

    @classmethod
    def _validate(cls, alpha: float, power: float, violation_type: str) -> None:
        """Raise ValueError if a scalar parameter is outside its allowed range."""
        if not 0 < alpha < 1:
            raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
        if not 0 < power < 1:
            raise ValueError(f"power must be between 0 and 1, got {power}")
        if violation_type not in cls._VIOLATION_TYPES:
            raise ValueError(
                f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', "
                f"got '{violation_type}'"
            )

    def get_params(self) -> Dict[str, Any]:
        """Get parameters for this estimator."""
        return {
            "alpha": self.alpha,
            "power": self.target_power,
            "violation_type": self.violation_type,
            "violation_weights": self.violation_weights,
        }

    def set_params(self, **params) -> "PreTrendsPower":
        """
        Set parameters for this estimator.

        Accepts the keys returned by ``get_params`` (``target_power`` is
        accepted as an alias of ``power``). Values are validated with the
        same rules as the constructor, and nothing is modified if any value
        is rejected. ``violation_weights`` is converted to an array.

        Raises
        ------
        ValueError
            If a key is unknown or a value is out of range.
        """
        updated = self.get_params()
        for key, value in params.items():
            name = "power" if key == "target_power" else key
            if name not in updated:
                raise ValueError(f"Invalid parameter: {key}")
            updated[name] = value

        # Validate the full candidate state before assigning anything.
        self._validate(updated["alpha"], updated["power"], updated["violation_type"])

        weights = updated["violation_weights"]
        self.alpha = updated["alpha"]
        self.target_power = updated["power"]
        self.violation_type = updated["violation_type"]
        self.violation_weights = np.asarray(weights) if weights is not None else None
        return self

    def _get_violation_weights(self, n_pre: int) -> np.ndarray:
        """
        Get violation weights based on violation type.

        Parameters
        ----------
        n_pre : int
            Number of pre-treatment periods.

        Returns
        -------
        np.ndarray
            Violation weights, normalized to have L2 norm of 1 (left as all
            zeros when the un-normalized pattern is all zeros).

        Raises
        ------
        ValueError
            If custom weights are missing or do not have length ``n_pre``.
        """
        if self.violation_type == "custom":
            if self.violation_weights is None:
                raise ValueError(
                    "violation_weights must be provided when violation_type='custom'"
                )
            weights = np.atleast_1d(np.asarray(self.violation_weights, dtype=float))
            if weights.ndim != 1 or len(weights) != n_pre:
                raise ValueError(
                    f"violation_weights has length {len(weights)}, "
                    f"but there are {n_pre} pre-periods"
                )
        elif self.violation_type == "linear":
            # [n-1, n-2, ..., 1, 0]: zero at the last pre-period, growing
            # linearly going backward in time.
            # NOTE(review): with a single pre-period this pattern is all zeros,
            # which makes the MDV infinite. Whether the last *estimated*
            # pre-period should carry zero weight depends on whether the
            # reference period is among the estimated coefficients - confirm.
            weights = np.arange(n_pre - 1, -1, -1, dtype=float)
        elif self.violation_type == "constant":
            # Same violation in all periods.
            weights = np.ones(n_pre)
        elif self.violation_type == "last_period":
            # Violation only in the last pre-period.
            weights = np.zeros(n_pre)
            weights[-1] = 1.0
        else:
            raise ValueError(f"Unknown violation_type: {self.violation_type}")

        # Normalize to unit norm (if not all zeros).
        norm = np.linalg.norm(weights)
        if norm > 0:
            weights = weights / norm

        return weights

    @staticmethod
    def _quadratic_form(vec: np.ndarray, vcov: np.ndarray) -> float:
        """Return vec' V^{-1} vec, using the pseudo-inverse when V is singular."""
        try:
            vcov_inv = np.linalg.inv(vcov)
        except np.linalg.LinAlgError:
            vcov_inv = np.linalg.pinv(vcov)
        return float(vec @ vcov_inv @ vec)

    @staticmethod
    def _extract_from_multi_period(
        results: Any,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """Extract pre-period effects, SEs and vcov from MultiPeriodDiDResults."""
        all_pre_periods = results.pre_periods

        if len(all_pre_periods) == 0:
            raise ValueError(
                "No pre-treatment periods found in results. "
                "Pre-trends power analysis requires pre-period coefficients."
            )

        coefficients = getattr(results, "coefficients", None)
        if coefficients:
            # Only periods with an estimated interaction coefficient are used;
            # the reference period is omitted from estimation.
            estimated_pre_periods = [
                p for p in all_pre_periods
                if f"treated:period_{p}" in coefficients
            ]

            if len(estimated_pre_periods) == 0:
                raise ValueError(
                    "No estimated pre-period coefficients found. "
                    "The pre-trends test requires at least one estimated "
                    "pre-period coefficient (excluding the reference period)."
                )

            n_pre = len(estimated_pre_periods)
            effects = np.array([
                coefficients[f"treated:period_{p}"]
                for p in estimated_pre_periods
            ])

            # Per-period SE when available, otherwise fall back to avg_se.
            ses = np.array([
                results.period_effects[p].se
                if p in results.period_effects
                else results.avg_se
                for p in estimated_pre_periods
            ])

            vcov = None
            if results.vcov is not None:
                # Assumes vcov rows/columns follow the key order of
                # ``coefficients`` - TODO confirm against the estimator.
                coef_keys = list(coefficients.keys())
                pre_indices = [
                    coef_keys.index(f"treated:period_{p}")
                    for p in estimated_pre_periods
                ]
                if results.vcov.shape[0] > max(pre_indices):
                    vcov = results.vcov[np.ix_(pre_indices, pre_indices)]
            if vcov is None:
                # No usable full matrix: ignore covariances.
                vcov = np.diag(ses ** 2)
        else:
            # No coefficients available - use period_effects, dropping
            # entries without a positive SE (e.g. a reference period).
            estimated_pre_periods = [
                p for p in all_pre_periods
                if p in results.period_effects
                and results.period_effects[p].se > 0
            ]

            if len(estimated_pre_periods) == 0:
                raise ValueError(
                    "No estimated pre-period effects found. "
                    "The pre-trends test requires at least one estimated "
                    "pre-period effect (excluding the reference period)."
                )

            n_pre = len(estimated_pre_periods)
            effects = np.array([
                results.period_effects[p].effect
                for p in estimated_pre_periods
            ])
            ses = np.array([
                results.period_effects[p].se
                for p in estimated_pre_periods
            ])
            vcov = np.diag(ses ** 2)

        return effects, ses, vcov, n_pre

    @staticmethod
    def _extract_from_event_study(
        event_study_effects: Dict[Any, Dict[str, Any]],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Extract pre-period parameters from an event-study mapping.

        The mapping is keyed by relative time; entries with negative keys are
        pre-periods and must provide 'effect' and 'se'. Entries whose SE is
        not a positive finite number are dropped, since they cannot enter a
        Wald test. The returned vcov is diagonal (covariances are ignored).
        """
        pre_effects = {
            t: data for t, data in event_study_effects.items()
            if t < 0
        }

        if not pre_effects:
            raise ValueError("No pre-treatment periods found in event study.")

        pre_periods = [
            t for t in sorted(pre_effects)
            if np.isfinite(pre_effects[t]['se']) and pre_effects[t]['se'] > 0
        ]

        if not pre_periods:
            raise ValueError(
                "No estimated pre-period effects found. "
                "The pre-trends test requires at least one estimated "
                "pre-period effect (excluding the reference period)."
            )

        effects = np.array([pre_effects[t]['effect'] for t in pre_periods])
        ses = np.array([pre_effects[t]['se'] for t in pre_periods])
        vcov = np.diag(ses ** 2)

        return effects, ses, vcov, len(pre_periods)

    def _extract_pre_period_params(
        self,
        results: Union["MultiPeriodDiDResults", Any],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Extract pre-period parameters from results.

        Returns
        -------
        effects : np.ndarray
            Pre-period effect estimates.
        ses : np.ndarray
            Pre-period standard errors.
        vcov : np.ndarray
            Variance-covariance matrix for pre-period effects.
        n_pre : int
            Number of pre-periods.

        Raises
        ------
        ValueError
            If the results contain no usable pre-period estimates.
        TypeError
            If ``results`` is not a supported results type.
        """
        if isinstance(results, MultiPeriodDiDResults):
            return self._extract_from_multi_period(results)

        # The optional result classes are imported lazily; only the import
        # itself is guarded so that errors from extraction are not masked.
        try:
            from diff_diff.staggered import CallawaySantAnnaResults
        except ImportError:
            CallawaySantAnnaResults = None

        if CallawaySantAnnaResults is not None and isinstance(
            results, CallawaySantAnnaResults
        ):
            if results.event_study_effects is None:
                raise ValueError(
                    "CallawaySantAnnaResults must have event_study_effects. "
                    "Re-run with aggregate='event_study'."
                )
            return self._extract_from_event_study(results.event_study_effects)

        try:
            from diff_diff.sun_abraham import SunAbrahamResults
        except ImportError:
            SunAbrahamResults = None

        if SunAbrahamResults is not None and isinstance(results, SunAbrahamResults):
            # A missing mapping is reported as "no pre-treatment periods".
            return self._extract_from_event_study(results.event_study_effects or {})

        raise TypeError(
            f"Unsupported results type: {type(results)}. "
            "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults."
        )

    def _compute_power(
        self,
        M: float,
        weights: np.ndarray,
        vcov: np.ndarray,
    ) -> Tuple[float, float, float, float]:
        """
        Compute power to detect violation of magnitude M.

        The pre-trends test is a Wald test: H0: delta = 0 vs H1: delta != 0
        Under H1 with violation delta = M * weights, the test statistic follows
        a non-central chi-squared distribution.

        Parameters
        ----------
        M : float
            Violation magnitude.
        weights : np.ndarray
            Normalized violation pattern.
        vcov : np.ndarray
            Variance-covariance matrix.

        Returns
        -------
        power : float
            Power to detect this violation.
        noncentrality : float
            Non-centrality parameter.
        test_stat : float
            Expected test statistic under H1.
        critical_value : float
            Critical value for the test.
        """
        n_pre = len(weights)

        # lambda = delta' V^{-1} delta with delta = M * weights
        noncentrality = self._quadratic_form(M * weights, vcov)

        critical_value = float(stats.chi2.ppf(1 - self.alpha, df=n_pre))

        if noncentrality > 0:
            # Survival function rather than 1 - cdf: same quantity, better
            # precision in the upper tail.
            power = float(stats.ncx2.sf(critical_value, df=n_pre, nc=noncentrality))
        else:
            power = self.alpha  # Size under null

        # Mean of a non-central chi-squared with df=n_pre.
        test_stat = n_pre + noncentrality

        return power, noncentrality, test_stat, critical_value

    def _compute_mdv(
        self,
        weights: np.ndarray,
        vcov: np.ndarray,
    ) -> float:
        """
        Compute minimum detectable violation.

        Find the smallest M such that power >= target_power.

        Parameters
        ----------
        weights : np.ndarray
            Normalized violation pattern.
        vcov : np.ndarray
            Variance-covariance matrix.

        Returns
        -------
        mdv : float
            Minimum detectable violation. ``0.0`` when the target power does
            not exceed alpha (the test already rejects that often with no
            violation); ``inf`` when the violation pattern carries no signal
            (w' V^{-1} w <= 0) or the bracket search hits its safety stop.
        """
        # Power at M = 0 equals alpha, so such a target is met trivially.
        if self.target_power <= self.alpha:
            return 0.0

        # nc = M^2 * w' V^{-1} w, so M = sqrt(nc / w' V^{-1} w).
        w_Vinv_w = self._quadratic_form(weights, vcov)
        if not np.isfinite(w_Vinv_w) or w_Vinv_w <= 0:
            return np.inf

        n_pre = len(weights)
        critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)

        def power_minus_target(nc):
            if nc <= 0:
                return self.alpha - self.target_power
            return stats.ncx2.sf(critical_value, df=n_pre, nc=nc) - self.target_power

        # Bracket the root by doubling. The lower end always has power below
        # target; the loop stops as soon as the upper end reaches the target,
        # so a successful bracket is never discarded.
        nc_low, nc_high = 0.0, 1.0
        while power_minus_target(nc_high) < 0:
            if nc_high >= self._MAX_NONCENTRALITY:
                return np.inf
            nc_low = nc_high
            nc_high *= 2

        target_nc = optimize.brentq(power_minus_target, nc_low, nc_high)

        return float(np.sqrt(target_nc / w_Vinv_w))

    def fit(
        self,
        results: Union["MultiPeriodDiDResults", Any],
        M: Optional[float] = None,
    ) -> "PreTrendsPowerResults":
        """
        Compute pre-trends power analysis.

        Parameters
        ----------
        results : MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults
            Results from an event study estimation.
        M : float, optional
            Specific violation magnitude to evaluate. If None, the MDV is
            used when it is finite, otherwise the largest pre-period SE.

        Returns
        -------
        PreTrendsPowerResults
            Power analysis results including power and MDV.
        """
        effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
        weights = self._get_violation_weights(n_pre)
        mdv = self._compute_mdv(weights, vcov)

        if M is None:
            M = mdv if np.isfinite(mdv) else np.max(ses)

        power, noncentrality, test_stat, critical_value = self._compute_power(
            M, weights, vcov
        )

        return PreTrendsPowerResults(
            power=power,
            mdv=mdv,
            violation_magnitude=M,
            violation_type=self.violation_type,
            alpha=self.alpha,
            target_power=self.target_power,
            n_pre_periods=n_pre,
            test_statistic=test_stat,
            critical_value=critical_value,
            noncentrality=noncentrality,
            pre_period_effects=effects,
            pre_period_ses=ses,
            vcov=vcov,
            original_results=results,
        )

    def power_at(
        self,
        results: Union["MultiPeriodDiDResults", Any],
        M: float,
    ) -> float:
        """
        Compute power to detect a specific violation magnitude.

        Parameters
        ----------
        results : results object
            Event study results.
        M : float
            Violation magnitude.

        Returns
        -------
        float
            Power to detect violation of magnitude M.
        """
        # Power at a given M does not depend on the MDV, so the root-finding
        # step performed by ``fit`` is skipped here.
        _, _, vcov, n_pre = self._extract_pre_period_params(results)
        weights = self._get_violation_weights(n_pre)
        return self._compute_power(M, weights, vcov)[0]

    def power_curve(
        self,
        results: Union["MultiPeriodDiDResults", Any],
        M_grid: Optional[List[float]] = None,
        n_points: int = 50,
    ) -> "PreTrendsPowerCurve":
        """
        Compute power across a range of violation magnitudes.

        Parameters
        ----------
        results : results object
            Event study results.
        M_grid : list of float, optional
            Specific violation magnitudes to evaluate. If None, creates
            automatic grid from 0 to 2.5 * MDV (or to 10x the largest
            pre-period SE when the MDV is zero or not finite).
        n_points : int, default=50
            Number of points in automatic grid.

        Returns
        -------
        PreTrendsPowerCurve
            Power curve data with plot method.
        """
        effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
        weights = self._get_violation_weights(n_pre)
        mdv = self._compute_mdv(weights, vcov)

        if M_grid is None:
            # The upper end scales with the data; no absolute cap is applied,
            # so the grid always extends past the MDV whatever the outcome's
            # units are.
            if np.isfinite(mdv) and mdv > 0:
                max_M = 2.5 * mdv
            else:
                max_M = 10 * np.max(ses)
            M_grid = np.linspace(0, max_M, n_points)
        else:
            M_grid = np.asarray(M_grid)

        powers = np.array([
            self._compute_power(M, weights, vcov)[0]
            for M in M_grid
        ])

        return PreTrendsPowerCurve(
            M_values=M_grid,
            powers=powers,
            mdv=mdv,
            alpha=self.alpha,
            target_power=self.target_power,
            violation_type=self.violation_type,
        )

    def sensitivity_to_honest_did(
        self,
        results: Union["MultiPeriodDiDResults", Any],
    ) -> Dict[str, Any]:
        """
        Compare pre-trends power analysis with HonestDiD sensitivity.

        Expresses the MDV in units of the largest pre-period standard error
        and attaches a short textual reading of that ratio.

        Parameters
        ----------
        results : results object
            Event study results.

        Returns
        -------
        dict
            Dictionary with:
            - mdv: Minimum detectable violation from pre-trends test
            - max_pre_se: Largest pre-period standard error
            - mdv_in_ses: mdv / max_pre_se (inf when the MDV is not finite
              or max_pre_se is not positive)
            - interpretation: Text explaining the relationship
        """
        pt_results = self.fit(results)
        mdv = pt_results.mdv
        max_pre_se = np.max(pt_results.pre_period_ses)

        if max_pre_se > 0 and np.isfinite(mdv):
            mdv_in_ses = mdv / max_pre_se
        else:
            mdv_in_ses = np.inf

        interpretation = [
            f"Minimum Detectable Violation (MDV): {mdv:.4f}",
            f"Max pre-period SE: {max_pre_se:.4f}",
        ]

        if np.isfinite(mdv):
            interpretation.append(
                f"MDV / max(SE): {mdv_in_ses:.2f}"
            )

            if mdv_in_ses < 1:
                interpretation.append(
                    "→ Pre-trends test is fairly sensitive to violations."
                )
            elif mdv_in_ses < 2:
                interpretation.append(
                    "→ Pre-trends test has moderate sensitivity."
                )
            else:
                interpretation.append(
                    "→ Pre-trends test has low power to detect violations."
                )
                interpretation.append(
                    " Consider using HonestDiD with larger M values for robustness."
                )
        else:
            interpretation.append(
                "→ Pre-trends test cannot achieve target power for any violation size."
            )
            interpretation.append(
                " Use HonestDiD sensitivity analysis for inference."
            )

        return {
            "mdv": mdv,
            "max_pre_se": max_pre_se,
            "mdv_in_ses": mdv_in_ses,
            "interpretation": "\n".join(interpretation),
        }
983
+
984
+
985
+ # =============================================================================
986
+ # Convenience Functions
987
+ # =============================================================================
988
+
989
+
990
def compute_pretrends_power(
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    violation_weights: Optional[np.ndarray] = None,
) -> PreTrendsPowerResults:
    """
    Convenience function for pre-trends power analysis.

    Builds a ``PreTrendsPower`` with the given settings and returns the
    result of its ``fit`` method.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float, optional
        Violation magnitude to evaluate.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern.
    violation_weights : array-like, optional
        Custom violation pattern; required when violation_type='custom'.

    Returns
    -------
    PreTrendsPowerResults
        Power analysis results.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import compute_pretrends_power
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> power_results = compute_pretrends_power(results)
    >>> print(f"MDV: {power_results.mdv:.3f}")
    >>> print(f"Power: {power_results.power:.1%}")
    """
    pt = PreTrendsPower(
        alpha=alpha,
        power=target_power,
        violation_type=violation_type,
        violation_weights=violation_weights,
    )
    return pt.fit(results, M=M)
1034
+
1035
+
1036
def compute_mdv(
    results: Union[MultiPeriodDiDResults, Any],
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    violation_weights: Optional[np.ndarray] = None,
) -> float:
    """
    Compute minimum detectable violation.

    Builds a ``PreTrendsPower`` with the given settings, fits it to
    ``results`` and returns the ``mdv`` field of the outcome.

    Parameters
    ----------
    results : results object
        Event study results.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power.
    violation_type : str, default='linear'
        Type of violation pattern.
    violation_weights : array-like, optional
        Custom violation pattern; required when violation_type='custom'.

    Returns
    -------
    float
        Minimum detectable violation.
    """
    pt = PreTrendsPower(
        alpha=alpha,
        power=target_power,
        violation_type=violation_type,
        violation_weights=violation_weights,
    )
    return pt.fit(results).mdv