diff-diff 2.2.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/pretrends.py ADDED
@@ -0,0 +1,1166 @@
1
+ """
2
+ Pre-trends power analysis for difference-in-differences designs.
3
+
4
+ This module implements the power analysis framework from Roth (2022) for assessing
5
+ the informativeness of pre-trends tests. It answers the question: "If my pre-trends
6
+ test passed, what violations would I have been able to detect?"
7
+
8
+ Key concepts:
9
+ - **Minimum Detectable Violation (MDV)**: The smallest pre-trends violation that
10
+ would be detected with given power (e.g., 80%).
11
+ - **Power of Pre-Trends Test**: Probability of rejecting parallel trends given
12
+ a specific violation pattern.
13
+ - **Relationship to HonestDiD**: If MDV is large relative to your estimated effect,
14
+ a passing pre-trends test provides limited reassurance.
15
+
16
+ References
17
+ ----------
18
+ Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for
19
+ Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
20
+ https://doi.org/10.1257/aeri.20210236
21
+
22
+ See Also
23
+ --------
24
+ https://github.com/jonathandroth/pretrends - R package implementation
25
+ diff_diff.honest_did - Sensitivity analysis for parallel trends violations
26
+ """
27
+
28
+ from dataclasses import dataclass, field
29
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
30
+
31
+ import numpy as np
32
+ import pandas as pd
33
+ from scipy import stats, optimize
34
+
35
+ from diff_diff.results import MultiPeriodDiDResults
36
+
37
+
38
+ # =============================================================================
39
+ # Results Classes
40
+ # =============================================================================
41
+
42
+
43
@dataclass
class PreTrendsPowerResults:
    """
    Results from pre-trends power analysis.

    Attributes
    ----------
    power : float
        Power to detect the specified violation pattern at given alpha.
    mdv : float
        Minimum detectable violation (smallest M detectable at target power).
    violation_magnitude : float
        The magnitude of violation tested (M parameter).
    violation_type : str
        Type of violation pattern ('linear', 'constant', 'last_period', 'custom').
    alpha : float
        Significance level for the pre-trends test.
    target_power : float
        Target power level used for MDV calculation.
    n_pre_periods : int
        Number of pre-treatment periods in the event study.
    test_statistic : float
        Expected test statistic under the specified violation.
    critical_value : float
        Critical value for the pre-trends test.
    noncentrality : float
        Non-centrality parameter under the alternative hypothesis.
    pre_period_effects : np.ndarray
        Estimated pre-period effects from the event study.
    pre_period_ses : np.ndarray
        Standard errors of pre-period effects.
    vcov : np.ndarray
        Variance-covariance matrix of pre-period effects.
    original_results : Any, optional
        The event-study results object the analysis was computed from.
    violation_weights : np.ndarray, optional
        Violation pattern (length ``n_pre_periods``) used by :meth:`power_at`.
        When ``None`` the pattern is rebuilt from ``violation_type``, which is
        not possible for ``'custom'`` patterns.
    """

    power: float
    mdv: float
    violation_magnitude: float
    violation_type: str
    alpha: float
    target_power: float
    n_pre_periods: int
    test_statistic: float
    critical_value: float
    noncentrality: float
    pre_period_effects: np.ndarray = field(repr=False)
    pre_period_ses: np.ndarray = field(repr=False)
    vcov: np.ndarray = field(repr=False)
    original_results: Optional[Any] = field(default=None, repr=False)
    violation_weights: Optional[np.ndarray] = field(default=None, repr=False)

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerResults(power={self.power:.3f}, "
            f"mdv={self.mdv:.4f}, M={self.violation_magnitude:.4f})"
        )

    @property
    def is_informative(self) -> bool:
        """
        Check if the pre-trends test is informative.

        A pre-trends test is considered informative if the MDV is reasonably
        small relative to typical effect sizes. This is a heuristic check;
        see the summary for interpretation guidance.
        """
        # Heuristic: MDV < 2x the max observed pre-period SE
        max_se = np.max(self.pre_period_ses) if len(self.pre_period_ses) > 0 else 1.0
        return bool(self.mdv < 2 * max_se)

    @property
    def power_adequate(self) -> bool:
        """Check if power meets the target threshold."""
        return bool(self.power >= self.target_power)

    def summary(self) -> str:
        """
        Generate formatted summary of pre-trends power analysis.

        Returns
        -------
        str
            Formatted summary.
        """
        lines = [
            "=" * 70,
            "Pre-Trends Power Analysis Results".center(70),
            "(Roth 2022)".center(70),
            "=" * 70,
            "",
            f"{'Number of pre-periods:':<35} {self.n_pre_periods}",
            f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
            f"{'Target power:':<35} {self.target_power:.1%}",
            f"{'Violation type:':<35} {self.violation_type}",
            "",
            "-" * 70,
            "Power Analysis".center(70),
            "-" * 70,
            f"{'Violation magnitude (M):':<35} {self.violation_magnitude:.4f}",
            f"{'Power to detect this violation:':<35} {self.power:.1%}",
            f"{'Minimum detectable violation:':<35} {self.mdv:.4f}",
            "",
            f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}",
            f"{'Critical value:':<35} {self.critical_value:.4f}",
            f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}",
            "",
            "-" * 70,
            "Interpretation".center(70),
            "-" * 70,
        ]

        if self.power_adequate:
            lines.append(
                f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%})."
            )
            lines.append(
                f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}."
            )
        else:
            lines.append(
                f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%})."
            )
            lines.append(
                f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power."
            )

        lines.append("")
        lines.append(
            f"Minimum detectable violation (MDV): {self.mdv:.4f}"
        )
        lines.append(
            " → Passing pre-trends test does NOT rule out violations up to this size."
        )

        lines.extend(["", "=" * 70])

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert results to dictionary."""
        return {
            "power": self.power,
            "mdv": self.mdv,
            "violation_magnitude": self.violation_magnitude,
            "violation_type": self.violation_type,
            "alpha": self.alpha,
            "target_power": self.target_power,
            "n_pre_periods": self.n_pre_periods,
            "test_statistic": self.test_statistic,
            "critical_value": self.critical_value,
            "noncentrality": self.noncentrality,
            "is_informative": self.is_informative,
            "power_adequate": self.power_adequate,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to DataFrame."""
        return pd.DataFrame([self.to_dict()])

    def _violation_pattern(self) -> np.ndarray:
        """
        Return the unit-L2-norm violation pattern used by :meth:`power_at`.

        Uses ``violation_weights`` when set. Otherwise it rebuilds the
        built-in pattern named by ``violation_type``:

        - linear: ``[n-1, ..., 1, 0]``
        - constant: all ones
        - last_period: a single one in the final pre-period

        Raises
        ------
        ValueError
            If stored weights do not have length ``n_pre_periods``, or the
            pattern is ``'custom'`` (or unknown) and no weights are stored.
        """
        n_pre = self.n_pre_periods
        if self.violation_weights is not None:
            weights = np.asarray(self.violation_weights, dtype=float)
            if weights.shape != (n_pre,):
                raise ValueError(
                    f"violation_weights has shape {weights.shape}, "
                    f"expected ({n_pre},)"
                )
        elif self.violation_type == "linear":
            # Largest violation furthest from treatment, zero at the last pre-period.
            weights = np.arange(n_pre - 1, -1, -1, dtype=float)
        elif self.violation_type == "constant":
            weights = np.ones(n_pre)
        elif self.violation_type == "last_period":
            weights = np.zeros(n_pre)
            weights[-1] = 1.0
        else:
            raise ValueError(
                f"Cannot reconstruct the violation pattern for "
                f"violation_type='{self.violation_type}'; set violation_weights "
                "on the results object to use power_at()."
            )

        norm = np.linalg.norm(weights)
        if norm > 0:
            weights = weights / norm
        return weights

    def power_at(self, M: float) -> float:
        """
        Compute power to detect a specific violation magnitude.

        This method allows computing power at different M values without
        re-fitting the model, using the stored variance-covariance matrix.

        Parameters
        ----------
        M : float
            Violation magnitude to evaluate.

        Returns
        -------
        float
            Power to detect violation of magnitude M.

        Raises
        ------
        ValueError
            If the violation pattern cannot be determined. This happens for
            ``violation_type='custom'`` without stored ``violation_weights``,
            or when the stored weights have the wrong length.
        """
        weights = self._violation_pattern()

        try:
            vcov_inv = np.linalg.inv(self.vcov)
        except np.linalg.LinAlgError:
            vcov_inv = np.linalg.pinv(self.vcov)

        # delta = M * weights  =>  nc = delta' V^{-1} delta = M^2 * w' V^{-1} w
        noncentrality = M ** 2 * float(weights @ vcov_inv @ weights)

        # No detectable violation: the test rejects at its nominal size.
        if noncentrality <= 0:
            return float(self.alpha)

        # sf is the accurate upper-tail probability (vs. 1 - cdf).
        return float(
            stats.ncx2.sf(self.critical_value, df=self.n_pre_periods, nc=noncentrality)
        )
261
+
262
+
263
@dataclass
class PreTrendsPowerCurve:
    """
    Power curve across violation magnitudes.

    Attributes
    ----------
    M_values : np.ndarray
        Grid of violation magnitudes tested.
    powers : np.ndarray
        Power at each violation magnitude.
    mdv : float
        Minimum detectable violation.
    alpha : float
        Significance level.
    target_power : float
        Target power level.
    violation_type : str
        Type of violation pattern.
    """

    M_values: np.ndarray
    powers: np.ndarray
    mdv: float
    alpha: float
    target_power: float
    violation_type: str

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerCurve(n_points={len(self.M_values)}, "
            f"mdv={self.mdv:.4f})"
        )

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to DataFrame with M and power columns."""
        return pd.DataFrame({
            "M": self.M_values,
            "power": self.powers,
        })

    def plot(self, ax=None, show_mdv: bool = True, show_target: bool = True,
             color: str = "#2563eb", mdv_color: str = "#dc2626",
             target_color: str = "#22c55e", **kwargs):
        """
        Plot the power curve.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        show_mdv : bool, default=True
            Whether to show vertical line at MDV.
        show_target : bool, default=True
            Whether to show horizontal line at target power.
        color : str
            Color for power curve line.
        mdv_color : str
            Color for MDV vertical line.
        target_color : str
            Color for target power horizontal line.
        **kwargs
            Additional arguments passed to ``ax.plot()`` for the power curve.
            They override the defaults (``color``, ``linewidth=2``,
            ``label="Power"``) instead of clashing with them.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes with the plot.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required for plotting") from err

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        # Merge defaults with user kwargs so e.g. plot(linewidth=3) does not
        # raise "got multiple values for keyword argument".
        line_kwargs = {"color": color, "linewidth": 2, "label": "Power"}
        line_kwargs.update(kwargs)
        ax.plot(self.M_values, self.powers, **line_kwargs)

        # Target power line
        if show_target:
            ax.axhline(y=self.target_power, color=target_color, linestyle="--",
                       linewidth=1.5, alpha=0.7,
                       label=f"Target power ({self.target_power:.0%})")

        # MDV line (skipped when the MDV is infinite / undefined)
        if show_mdv and self.mdv is not None and np.isfinite(self.mdv):
            ax.axvline(x=self.mdv, color=mdv_color, linestyle=":",
                       linewidth=1.5, alpha=0.7,
                       label=f"MDV = {self.mdv:.3f}")

        ax.set_xlabel("Violation Magnitude (M)")
        ax.set_ylabel("Power")
        ax.set_title("Pre-Trends Test Power Curve")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        return ax
365
+
366
+
367
+ # =============================================================================
368
+ # Main Class
369
+ # =============================================================================
370
+
371
+
372
+ class PreTrendsPower:
373
+ """
374
+ Pre-trends power analysis (Roth 2022).
375
+
376
+ Computes the power of pre-trends tests to detect violations of parallel
377
+ trends, and the minimum detectable violation (MDV).
378
+
379
+ Parameters
380
+ ----------
381
+ alpha : float, default=0.05
382
+ Significance level for the pre-trends test.
383
+ power : float, default=0.80
384
+ Target power level for MDV calculation.
385
+ violation_type : str, default='linear'
386
+ Type of violation pattern to consider:
387
+ - 'linear': Violations follow a linear trend (most common)
388
+ - 'constant': Same violation in all pre-periods
389
+ - 'last_period': Violation only in the last pre-period
390
+ - 'custom': User-specified violation pattern (via violation_weights)
391
+ violation_weights : array-like, optional
392
+ Custom weights for violation pattern. Length must equal number of
393
+ pre-periods. Only used when violation_type='custom'.
394
+
395
+ Examples
396
+ --------
397
+ Basic usage with MultiPeriodDiD results:
398
+
399
+ >>> from diff_diff import MultiPeriodDiD
400
+ >>> from diff_diff.pretrends import PreTrendsPower
401
+ >>>
402
+ >>> # Fit event study
403
+ >>> mp_did = MultiPeriodDiD()
404
+ >>> results = mp_did.fit(data, outcome='y', treatment='treated',
405
+ ... time='period', post_periods=[4, 5, 6, 7])
406
+ >>>
407
+ >>> # Analyze pre-trends power
408
+ >>> pt = PreTrendsPower(alpha=0.05, power=0.80)
409
+ >>> power_results = pt.fit(results)
410
+ >>> print(power_results.summary())
411
+ >>>
412
+ >>> # Get power curve
413
+ >>> curve = pt.power_curve(results)
414
+ >>> curve.plot()
415
+
416
+ Notes
417
+ -----
418
+ The pre-trends test is typically a joint test that all pre-period
419
+ coefficients are zero. This test has limited power to detect small
420
+ violations, especially when:
421
+
422
+ 1. There are few pre-periods
423
+ 2. Standard errors are large
424
+ 3. The violation pattern is smooth (e.g., linear trend)
425
+
426
+ Passing a pre-trends test does NOT mean parallel trends holds. It means
427
+ violations smaller than the MDV cannot be ruled out. For robust inference,
428
+ combine with HonestDiD sensitivity analysis.
429
+
430
+ References
431
+ ----------
432
+ Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing
433
+ for Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
434
+ """
435
+
436
def __init__(
    self,
    alpha: float = 0.05,
    power: float = 0.80,
    violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear",
    violation_weights: Optional[np.ndarray] = None,
):
    """
    Store and validate the analysis settings.

    Raises
    ------
    ValueError
        If ``alpha`` or ``power`` lies outside the open interval (0, 1), if
        ``violation_type`` is not recognised, or if ``'custom'`` is requested
        without ``violation_weights``.
    """
    known_types = ("linear", "constant", "last_period", "custom")

    if not (0 < alpha < 1):
        raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
    if not (0 < power < 1):
        raise ValueError(f"power must be between 0 and 1, got {power}")
    if violation_type not in known_types:
        raise ValueError(
            f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', "
            f"got '{violation_type}'"
        )
    missing_custom_weights = violation_type == "custom" and violation_weights is None
    if missing_custom_weights:
        raise ValueError(
            "violation_weights must be provided when violation_type='custom'"
        )

    self.alpha = alpha
    # Stored as target_power; exposed as "power" through get_params/set_params.
    self.target_power = power
    self.violation_type = violation_type
    if violation_weights is None:
        self.violation_weights = None
    else:
        self.violation_weights = np.asarray(violation_weights)
463
+
464
def get_params(self) -> Dict[str, Any]:
    """
    Return the estimator's constructor arguments.

    The stored ``target_power`` attribute is reported under the constructor
    name ``power``.
    """
    return dict(
        alpha=self.alpha,
        power=self.target_power,
        violation_type=self.violation_type,
        violation_weights=self.violation_weights,
    )
472
+
473
def set_params(self, **params) -> "PreTrendsPower":
    """
    Update estimator parameters in place, re-validating them.

    Accepts the constructor argument names (``alpha``, ``power``,
    ``violation_type``, ``violation_weights``). ``target_power``, the stored
    attribute name, is accepted as an alias for ``power``.

    Returns
    -------
    PreTrendsPower
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        For an unknown parameter name or a value the constructor would
        reject. Nothing is modified when an error is raised.
    """
    # Public parameter name -> attribute that stores it. Restricting to this
    # map prevents callers from overwriting methods or other attributes.
    attr_for = {
        "alpha": "alpha",
        "power": "target_power",
        "target_power": "target_power",
        "violation_type": "violation_type",
        "violation_weights": "violation_weights",
    }
    new_values = {
        "alpha": self.alpha,
        "target_power": self.target_power,
        "violation_type": self.violation_type,
        "violation_weights": self.violation_weights,
    }
    for key, value in params.items():
        if key not in attr_for:
            raise ValueError(f"Invalid parameter: {key}")
        new_values[attr_for[key]] = value

    alpha = new_values["alpha"]
    target_power = new_values["target_power"]
    violation_type = new_values["violation_type"]
    violation_weights = new_values["violation_weights"]

    # Apply the constructor's checks to the combined state before assigning
    # anything, so a rejected update leaves the estimator untouched.
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
    if not 0 < target_power < 1:
        raise ValueError(f"power must be between 0 and 1, got {target_power}")
    if violation_type not in ["linear", "constant", "last_period", "custom"]:
        raise ValueError(
            f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', "
            f"got '{violation_type}'"
        )
    if violation_type == "custom" and violation_weights is None:
        raise ValueError(
            "violation_weights must be provided when violation_type='custom'"
        )

    self.alpha = alpha
    self.target_power = target_power
    self.violation_type = violation_type
    self.violation_weights = (
        np.asarray(violation_weights) if violation_weights is not None else None
    )
    return self
483
+
484
def _get_violation_weights(self, n_pre: int) -> np.ndarray:
    """
    Get violation weights based on violation type.

    Parameters
    ----------
    n_pre : int
        Number of pre-treatment periods.

    Returns
    -------
    np.ndarray
        Float violation weights of length ``n_pre``, normalized to L2 norm 1
        (an all-zero pattern is returned unnormalized).

    Raises
    ------
    ValueError
        If ``violation_type='custom'`` and the weights are missing, not
        one-dimensional, or not of length ``n_pre``; or if
        ``violation_type`` is unknown.
    """
    if self.violation_type == "custom":
        # Reachable via set_params/attribute assignment even though __init__
        # requires weights for 'custom'.
        if self.violation_weights is None:
            raise ValueError(
                "violation_weights must be provided when violation_type='custom'"
            )
        # np.array copies, so the stored weights are never mutated; float
        # dtype keeps the result consistent with the built-in patterns.
        weights = np.array(self.violation_weights, dtype=float)
        if weights.ndim != 1:
            raise ValueError(
                f"violation_weights must be one-dimensional, got shape {weights.shape}"
            )
        if len(weights) != n_pre:
            raise ValueError(
                f"violation_weights has length {len(weights)}, "
                f"but there are {n_pre} pre-periods"
            )
    elif self.violation_type == "linear":
        # Linear trend: [n-1, n-2, ..., 1, 0], largest furthest from treatment
        # and zero at the last pre-period (-1).
        weights = np.arange(n_pre - 1, -1, -1, dtype=float)
    elif self.violation_type == "constant":
        # Same violation in all periods
        weights = np.ones(n_pre)
    elif self.violation_type == "last_period":
        # Violation only in last pre-period (period -1)
        weights = np.zeros(n_pre)
        weights[-1] = 1.0
    else:
        raise ValueError(f"Unknown violation_type: {self.violation_type}")

    # Normalize to unit norm (if not all zeros)
    norm = np.linalg.norm(weights)
    if norm > 0:
        weights = weights / norm

    return weights
527
+
528
def _extract_pre_period_params(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    pre_periods: Optional[List[int]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """
    Extract pre-period parameters from results.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults or SunAbrahamResults
        Results object from event study estimation.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, uses
        ``results.pre_periods``. Only consulted for MultiPeriodDiDResults.

    Returns
    -------
    effects : np.ndarray
        Pre-period effect estimates.
    ses : np.ndarray
        Pre-period standard errors.
    vcov : np.ndarray
        Variance-covariance matrix for pre-period effects. For
        MultiPeriodDiDResults this is the sub-block of ``results.vcov`` when
        it can be located. In all other cases it is a diagonal matrix of
        squared SEs.
    n_pre : int
        Number of pre-periods.

    Raises
    ------
    ValueError
        If no usable pre-period estimates are found, or a staggered results
        object has no event-study effects.
    TypeError
        If ``results`` is not one of the supported result types.
    """

    def _diagonal_from_event_study(event_study_effects):
        # Shared path for staggered estimators. Keep negative relative times,
        # dropping normalization constraints (n_groups == 0) and entries with
        # a missing or non-finite SE. No covariances are used here, so the
        # returned vcov is diagonal.
        pre_effects = {
            t: data
            for t, data in event_study_effects.items()
            if t < 0
            and data.get('n_groups', 1) > 0
            and np.isfinite(data.get('se', np.nan))
        }
        if not pre_effects:
            raise ValueError("No pre-treatment periods found in event study.")
        rel_times = sorted(pre_effects)
        es_effects = np.array([pre_effects[t]['effect'] for t in rel_times])
        es_ses = np.array([pre_effects[t]['se'] for t in rel_times])
        return es_effects, es_ses, np.diag(es_ses ** 2), len(rel_times)

    if isinstance(results, MultiPeriodDiDResults):
        # Explicit pre_periods take precedence over results.pre_periods.
        all_pre_periods = (
            list(pre_periods) if pre_periods is not None else results.pre_periods
        )

        if len(all_pre_periods) == 0:
            raise ValueError(
                "No pre-treatment periods found in results. "
                "Pre-trends power analysis requires pre-period coefficients. "
                "If you estimated all periods as post_periods, use the pre_periods "
                "parameter to specify which are actually pre-treatment."
            )

        if hasattr(results, 'coefficients') and results.coefficients:
            # Only periods with an estimated coefficient count (the omitted
            # reference period has none).
            estimated_pre_periods = [
                p for p in all_pre_periods
                if f"treated:period_{p}" in results.coefficients
            ]
            if not estimated_pre_periods:
                raise ValueError(
                    "No estimated pre-period coefficients found. "
                    "The pre-trends test requires at least one estimated "
                    "pre-period coefficient (excluding the reference period)."
                )

            n_pre = len(estimated_pre_periods)
            effects = np.array([
                results.coefficients[f"treated:period_{p}"]
                for p in estimated_pre_periods
            ])
            # Per-period SE when available, otherwise the average SE.
            ses = np.array([
                results.period_effects[p].se
                if p in results.period_effects
                else results.avg_se
                for p in estimated_pre_periods
            ])

            vcov = np.diag(ses ** 2)
            if results.vcov is not None:
                # NOTE(review): assumes rows/columns of results.vcov follow the
                # insertion order of results.coefficients -- confirm upstream.
                position = {
                    name: i for i, name in enumerate(results.coefficients.keys())
                }
                pre_indices = [
                    position[f"treated:period_{p}"] for p in estimated_pre_periods
                ]
                if results.vcov.shape[0] > max(pre_indices):
                    vcov = results.vcov[np.ix_(pre_indices, pre_indices)]
        else:
            # No coefficients available: use period_effects, excluding the
            # reference period (missing, or reported with a non-positive SE).
            estimated_pre_periods = [
                p for p in all_pre_periods
                if p in results.period_effects
                and results.period_effects[p].se > 0
            ]
            if not estimated_pre_periods:
                raise ValueError(
                    "No estimated pre-period effects found. "
                    "The pre-trends test requires at least one estimated "
                    "pre-period effect (excluding the reference period)."
                )

            n_pre = len(estimated_pre_periods)
            effects = np.array([
                results.period_effects[p].effect for p in estimated_pre_periods
            ])
            ses = np.array([
                results.period_effects[p].se for p in estimated_pre_periods
            ])
            vcov = np.diag(ses ** 2)

        return effects, ses, vcov, n_pre

    # Optional estimators: only the import itself is allowed to fail quietly.
    try:
        from diff_diff.staggered import CallawaySantAnnaResults
    except ImportError:
        CallawaySantAnnaResults = None

    if CallawaySantAnnaResults is not None and isinstance(results, CallawaySantAnnaResults):
        if results.event_study_effects is None:
            raise ValueError(
                "CallawaySantAnnaResults must have event_study_effects. "
                "Re-run with aggregate='event_study'."
            )
        return _diagonal_from_event_study(results.event_study_effects)

    try:
        from diff_diff.sun_abraham import SunAbrahamResults
    except ImportError:
        SunAbrahamResults = None

    if SunAbrahamResults is not None and isinstance(results, SunAbrahamResults):
        # Same guard as the Callaway-Sant'Anna path; without it a missing
        # event study surfaced as an AttributeError on None.items().
        if results.event_study_effects is None:
            raise ValueError("SunAbrahamResults must have event_study_effects.")
        return _diagonal_from_event_study(results.event_study_effects)

    raise TypeError(
        f"Unsupported results type: {type(results)}. "
        "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults."
    )
712
+
713
def _compute_power(
    self,
    M: float,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> Tuple[float, float, float, float]:
    """
    Compute power to detect violation of magnitude M.

    The pre-trends test is a Wald test: H0: delta = 0 vs H1: delta != 0.
    Under H1 with violation delta = M * weights, the test statistic follows
    a non-central chi-squared distribution with ``len(weights)`` degrees of
    freedom.

    Parameters
    ----------
    M : float
        Violation magnitude.
    weights : np.ndarray
        Normalized violation pattern.
    vcov : np.ndarray
        Variance-covariance matrix.

    Returns
    -------
    power : float
        Power to detect this violation (``alpha`` when the non-centrality
        is not positive).
    noncentrality : float
        Non-centrality parameter.
    test_stat : float
        Expected test statistic under H1 (mean of the non-central chi2).
    critical_value : float
        Critical value for the test.
    """
    n_pre = len(weights)
    delta = M * weights

    # lambda = delta' V^{-1} delta; fall back to the pseudo-inverse when the
    # covariance matrix is singular.
    try:
        vcov_inv = np.linalg.inv(vcov)
    except np.linalg.LinAlgError:
        vcov_inv = np.linalg.pinv(vcov)
    noncentrality = float(delta @ vcov_inv @ delta)

    critical_value = float(stats.chi2.ppf(1 - self.alpha, df=n_pre))

    if noncentrality > 0:
        # sf rather than 1 - cdf: exact upper-tail probability, and the same
        # formulation used when solving for the MDV.
        power = float(stats.ncx2.sf(critical_value, df=n_pre, nc=noncentrality))
    else:
        power = self.alpha  # Size under null

    test_stat = n_pre + noncentrality
    return power, noncentrality, test_stat, critical_value
774
+
775
def _compute_mdv(
    self,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> float:
    """
    Compute minimum detectable violation.

    Find the smallest M such that the Wald pre-trends test rejects with
    probability >= ``target_power`` when the violation is ``M * weights``.

    Parameters
    ----------
    weights : np.ndarray
        Normalized violation pattern.
    vcov : np.ndarray
        Variance-covariance matrix.

    Returns
    -------
    mdv : float
        Minimum detectable violation. The edge cases are:

        - ``0.0`` when ``target_power`` does not exceed ``alpha``, since the
          test already rejects at rate ``alpha`` with no violation.
        - ``np.inf`` when ``w' V^{-1} w`` is not positive (the pattern is
          invisible to the test).
        - ``np.inf`` when the target power is not reached within the
          non-centrality search range.
    """
    # Search cap on the non-centrality parameter; beyond it the target power
    # is treated as unreachable.
    max_nc = 1000.0

    n_pre = len(weights)

    # With no violation the test rejects with probability alpha, so any
    # target at or below alpha is met by M = 0. (Without this, the root
    # search has no sign change and a spurious positive MDV was returned.)
    if self.target_power <= self.alpha:
        return 0.0

    # nc = delta' V^{-1} delta = M^2 * (w' V^{-1} w)
    try:
        vcov_inv = np.linalg.inv(vcov)
    except np.linalg.LinAlgError:
        vcov_inv = np.linalg.pinv(vcov)
    w_Vinv_w = weights @ vcov_inv @ weights
    if not w_Vinv_w > 0:  # also catches NaN
        return np.inf

    critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)

    def power_minus_target(nc):
        if nc <= 0:
            return self.alpha - self.target_power
        return stats.ncx2.sf(critical_value, df=n_pre, nc=nc) - self.target_power

    # Power increases with nc. Double the upper bound until the target is
    # reached, carrying the lower bound along so [nc_low, nc_high] always
    # brackets the root. The reachability check is made on power, not on the
    # bound, so a target first reached just past the cap is not lost.
    nc_low, nc_high = 0.0, 1.0
    while power_minus_target(nc_high) < 0:
        if nc_high >= max_nc:
            return np.inf
        nc_low, nc_high = nc_high, 2.0 * nc_high

    try:
        target_nc = optimize.brentq(power_minus_target, nc_low, nc_high)
    except ValueError:
        # Fallback: use approximate formula
        # For chi2, power ≈ Phi(sqrt(2*nc) - sqrt(2*cv))
        # Solving: sqrt(2*nc) = z_power + sqrt(2*cv)
        z_power = stats.norm.ppf(self.target_power)
        target_nc = 0.5 * (z_power + np.sqrt(2 * critical_value)) ** 2

    return float(np.sqrt(target_nc / w_Vinv_w))
849
+
850
def fit(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
    """
    Compute pre-trends power analysis.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults
        Results from an event study estimation.
    M : float, optional
        Specific violation magnitude to evaluate. If None, the MDV is used.
        If the MDV is infinite, the largest pre-period SE is used instead.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods to use for power analysis.
        If None, attempts to infer from results.pre_periods. Use this when
        you've estimated an event study with all periods in post_periods
        and need to specify which are actually pre-treatment.

    Returns
    -------
    PreTrendsPowerResults
        Power analysis results including power and MDV. The unit-norm
        violation pattern that was analysed is attached as the
        ``violation_weights`` attribute.
    """
    effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
    weights = self._get_violation_weights(n_pre)
    mdv = self._compute_mdv(weights, vcov)

    # Default M: use MDV if not specified (largest SE when MDV is infinite)
    if M is None:
        M = mdv if np.isfinite(mdv) else np.max(ses)

    power, noncentrality, test_stat, critical_value = self._compute_power(
        M, weights, vcov
    )

    power_results = PreTrendsPowerResults(
        power=power,
        mdv=mdv,
        violation_magnitude=M,
        violation_type=self.violation_type,
        alpha=self.alpha,
        target_power=self.target_power,
        n_pre_periods=n_pre,
        test_statistic=test_stat,
        critical_value=critical_value,
        noncentrality=noncentrality,
        pre_period_effects=effects,
        pre_period_ses=ses,
        vcov=vcov,
        original_results=results,
    )
    # Keep the exact pattern that was tested; a 'custom' pattern cannot be
    # rebuilt later from violation_type alone.
    power_results.violation_weights = weights
    return power_results
911
+
912
def power_at(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: float,
    pre_periods: Optional[List[int]] = None,
) -> float:
    """
    Return the power of the pre-trends test against a violation of size M.

    This is a convenience wrapper: it runs :meth:`fit` at the given ``M`` and
    returns only the ``power`` attribute of the fitted results.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float
        Violation magnitude.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. See fit() for details.

    Returns
    -------
    float
        Power to detect violation of magnitude M.
    """
    return self.fit(results, M=M, pre_periods=pre_periods).power
937
+
938
def power_curve(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M_grid: Optional[List[float]] = None,
    n_points: int = 50,
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerCurve:
    """
    Compute power across a range of violation magnitudes.

    Parameters
    ----------
    results : results object
        Event study results.
    M_grid : list of float, optional
        Specific violation magnitudes to evaluate. If None, creates an
        automatic grid from 0 to 2.5 * MDV. When the MDV is infinite or zero,
        the grid runs to 10 times the largest pre-period SE instead.
    n_points : int, default=50
        Number of points in automatic grid.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. See fit() for details.

    Returns
    -------
    PreTrendsPowerCurve
        Power curve data with plot method.
    """
    _, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
    weights = self._get_violation_weights(n_pre)
    mdv = self._compute_mdv(weights, vcov)

    if M_grid is None:
        # No absolute ceiling: M is in outcome units, so a fixed cap would
        # truncate the curve (and hide the MDV) for large-scale outcomes.
        if np.isfinite(mdv) and mdv > 0:
            max_M = 2.5 * mdv
        else:
            max_M = 10 * np.max(ses)
        M_grid = np.linspace(0, max_M, n_points)
    else:
        M_grid = np.asarray(M_grid)

    powers = np.array([
        self._compute_power(M, weights, vcov)[0]
        for M in M_grid
    ])

    return PreTrendsPowerCurve(
        M_values=M_grid,
        powers=powers,
        mdv=mdv,
        alpha=self.alpha,
        target_power=self.target_power,
        violation_type=self.violation_type,
    )
993
+
994
def sensitivity_to_honest_did(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    pre_periods: Optional[List[int]] = None,
) -> Dict[str, Any]:
    """
    Compare pre-trends power analysis with HonestDiD sensitivity.

    This method helps interpret how informative a passing pre-trends
    test is in the context of HonestDiD's relative magnitudes restriction.

    Parameters
    ----------
    results : results object
        Event study results.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. See fit() for details.

    Returns
    -------
    dict
        Dictionary with:
        - mdv: Minimum detectable violation from the pre-trends test
          (non-finite when the target power cannot be achieved).
        - max_pre_se: Largest pre-period standard error.
        - mdv_in_ses: MDV divided by ``max_pre_se``; ``inf`` when the MDV
          is not finite or ``max_pre_se`` is not positive.
        - interpretation: Multi-line text explaining the relationship.
    """
    pt_results = self.fit(results, pre_periods=pre_periods)
    mdv = pt_results.mdv

    # The MDV represents the size of violation the test could detect.
    # In HonestDiD's relative magnitudes framework, M=1 means post-treatment
    # violations can be as large as the max pre-period violation; the MDV
    # gives a sense of how large that max violation could be.
    max_pre_se = np.max(pt_results.pre_period_ses)

    # Computed once and reused for both the text and the returned dict.
    if np.isfinite(mdv) and max_pre_se > 0:
        mdv_in_ses = mdv / max_pre_se
    else:
        mdv_in_ses = np.inf

    interpretation = [
        f"Minimum Detectable Violation (MDV): {mdv:.4f}",
        f"Max pre-period SE: {max_pre_se:.4f}",
    ]

    if np.isfinite(mdv):
        # How many (max) standard errors the MDV amounts to.
        interpretation.append(f"MDV / max(SE): {mdv_in_ses:.2f}")

        if mdv_in_ses < 1:
            interpretation.append(
                "→ Pre-trends test is fairly sensitive to violations."
            )
        elif mdv_in_ses < 2:
            interpretation.append(
                "→ Pre-trends test has moderate sensitivity."
            )
        else:
            interpretation.append(
                "→ Pre-trends test has low power to detect violations."
            )
            interpretation.append(
                " Consider using HonestDiD with larger M values for robustness."
            )
    else:
        interpretation.append(
            "→ Pre-trends test cannot achieve target power for any violation size."
        )
        interpretation.append(
            " Use HonestDiD sensitivity analysis for inference."
        )

    return {
        "mdv": mdv,
        "max_pre_se": max_pre_se,
        "mdv_in_ses": mdv_in_ses,
        "interpretation": "\n".join(interpretation),
    }
1074
+
1075
+
1076
+ # =============================================================================
1077
+ # Convenience Functions
1078
+ # =============================================================================
1079
+
1080
+
1081
def compute_pretrends_power(
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
    """
    Run a pre-trends power analysis in a single call.

    Builds a ``PreTrendsPower`` analyzer from the given settings and fits
    it to ``results``.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float, optional
        Violation magnitude at which power is evaluated. If None, ``fit``
        uses the MDV when it is finite and otherwise the largest
        pre-period standard error.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, attempts to infer
        from results. Use when you've estimated all periods as post_periods.

    Returns
    -------
    PreTrendsPowerResults
        Power analysis results.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import compute_pretrends_power
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
    >>> print(f"MDV: {power_results.mdv:.3f}")
    >>> print(f"Power: {power_results.power:.1%}")
    """
    analyzer = PreTrendsPower(
        alpha=alpha, power=target_power, violation_type=violation_type
    )
    fitted = analyzer.fit(results, M=M, pre_periods=pre_periods)
    return fitted
1129
+
1130
+
1131
def compute_mdv(
    results: Union[MultiPeriodDiDResults, Any],
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
) -> float:
    """
    Return the minimum detectable violation (MDV) for an event study.

    Shortcut that fits a ``PreTrendsPower`` analyzer with the given
    settings and extracts only its ``mdv`` attribute.

    Parameters
    ----------
    results : results object
        Event study results.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, attempts to infer
        from results. Use when you've estimated all periods as post_periods.

    Returns
    -------
    float
        Minimum detectable violation. May be non-finite when the target
        power cannot be achieved for any violation size.
    """
    return PreTrendsPower(
        alpha=alpha,
        power=target_power,
        violation_type=violation_type,
    ).fit(results, pre_periods=pre_periods).mdv