diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +226 -0
- diff_diff/_backend.py +64 -0
- diff_diff/_rust_backend.cpython-312-darwin.so +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1000 -0
- diff_diff/honest_did.py +1493 -0
- diff_diff/linalg.py +980 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1338 -0
- diff_diff/pretrends.py +1067 -0
- diff_diff/results.py +703 -0
- diff_diff/staggered.py +2297 -0
- diff_diff/sun_abraham.py +1176 -0
- diff_diff/synthetic_did.py +738 -0
- diff_diff/triple_diff.py +1291 -0
- diff_diff/twfe.py +344 -0
- diff_diff/utils.py +1481 -0
- diff_diff/visualization.py +1627 -0
- diff_diff-2.0.4.dist-info/METADATA +2257 -0
- diff_diff-2.0.4.dist-info/RECORD +23 -0
- diff_diff-2.0.4.dist-info/WHEEL +4 -0
diff_diff/pretrends.py
ADDED
|
@@ -0,0 +1,1067 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pre-trends power analysis for difference-in-differences designs.
|
|
3
|
+
|
|
4
|
+
This module implements the power analysis framework from Roth (2022) for assessing
|
|
5
|
+
the informativeness of pre-trends tests. It answers the question: "If my pre-trends
|
|
6
|
+
test passed, what violations would I have been able to detect?"
|
|
7
|
+
|
|
8
|
+
Key concepts:
|
|
9
|
+
- **Minimum Detectable Violation (MDV)**: The smallest pre-trends violation that
|
|
10
|
+
would be detected with given power (e.g., 80%).
|
|
11
|
+
- **Power of Pre-Trends Test**: Probability of rejecting parallel trends given
|
|
12
|
+
a specific violation pattern.
|
|
13
|
+
- **Relationship to HonestDiD**: If MDV is large relative to your estimated effect,
|
|
14
|
+
a passing pre-trends test provides limited reassurance.
|
|
15
|
+
|
|
16
|
+
References
|
|
17
|
+
----------
|
|
18
|
+
Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for
|
|
19
|
+
Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
|
|
20
|
+
https://doi.org/10.1257/aeri.20210236
|
|
21
|
+
|
|
22
|
+
See Also
|
|
23
|
+
--------
|
|
24
|
+
https://github.com/jonathandroth/pretrends - R package implementation
|
|
25
|
+
diff_diff.honest_did - Sensitivity analysis for parallel trends violations
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from dataclasses import dataclass, field
|
|
29
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
30
|
+
|
|
31
|
+
import numpy as np
|
|
32
|
+
import pandas as pd
|
|
33
|
+
from scipy import stats, optimize
|
|
34
|
+
|
|
35
|
+
from diff_diff.results import MultiPeriodDiDResults
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# =============================================================================
|
|
39
|
+
# Results Classes
|
|
40
|
+
# =============================================================================
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class PreTrendsPowerResults:
    """
    Results from pre-trends power analysis.

    Attributes
    ----------
    power : float
        Power to detect the specified violation pattern at given alpha.
    mdv : float
        Minimum detectable violation (smallest M detectable at target power).
    violation_magnitude : float
        The magnitude of violation tested (M parameter).
    violation_type : str
        Type of violation pattern ('linear', 'constant', 'last_period', 'custom').
    alpha : float
        Significance level for the pre-trends test.
    target_power : float
        Target power level used for MDV calculation.
    n_pre_periods : int
        Number of pre-treatment periods in the event study.
    test_statistic : float
        Expected test statistic under the specified violation.
    critical_value : float
        Critical value for the pre-trends test.
    noncentrality : float
        Non-centrality parameter under the alternative hypothesis.
    pre_period_effects : np.ndarray
        Estimated pre-period effects from the event study.
    pre_period_ses : np.ndarray
        Standard errors of pre-period effects.
    vcov : np.ndarray
        Variance-covariance matrix of pre-period effects.
    original_results : Any, optional
        The results object the analysis was computed from, if retained.

    Notes
    -----
    Equality is array-aware: the three array fields are compared with
    ``np.array_equal`` and ``original_results`` is compared by identity, so
    ``a == b`` never raises on multi-element arrays.
    """

    power: float
    mdv: float
    violation_magnitude: float
    violation_type: str
    alpha: float
    target_power: float
    n_pre_periods: int
    test_statistic: float
    critical_value: float
    noncentrality: float
    pre_period_effects: np.ndarray = field(repr=False)
    pre_period_ses: np.ndarray = field(repr=False)
    vcov: np.ndarray = field(repr=False)
    original_results: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerResults(power={self.power:.3f}, "
            f"mdv={self.mdv:.4f}, M={self.violation_magnitude:.4f})"
        )

    @staticmethod
    def _arrays_equal(left: Any, right: Any) -> bool:
        """Return True when two array-likes are the same object or equal element-wise."""
        return left is right or bool(np.array_equal(left, right))

    def __eq__(self, other: object) -> bool:
        """
        Compare two result objects field by field.

        The dataclass-generated ``__eq__`` compares ndarray fields with a
        plain ``==`` and raises ``ValueError`` for multi-element arrays, so
        the comparison is spelled out here instead.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented

        scalar_names = (
            "power", "mdv", "violation_magnitude", "violation_type", "alpha",
            "target_power", "n_pre_periods", "test_statistic",
            "critical_value", "noncentrality",
        )
        mine = tuple(getattr(self, name) for name in scalar_names)
        theirs = tuple(getattr(other, name) for name in scalar_names)
        if mine != theirs:
            return False

        array_names = ("pre_period_effects", "pre_period_ses", "vcov")
        if not all(
            self._arrays_equal(getattr(self, name), getattr(other, name))
            for name in array_names
        ):
            return False

        # Identity only: the source results object may itself hold arrays
        # whose ``==`` is not a plain boolean.
        return self.original_results is other.original_results

    @property
    def is_informative(self) -> bool:
        """
        Check if the pre-trends test is informative.

        A pre-trends test is considered informative if the MDV is reasonably
        small relative to typical effect sizes. This is a heuristic check;
        see the summary for interpretation guidance.
        """
        # Heuristic: MDV < 2x the max observed pre-period SE
        max_se = np.max(self.pre_period_ses) if len(self.pre_period_ses) > 0 else 1.0
        return bool(self.mdv < 2 * max_se)

    @property
    def power_adequate(self) -> bool:
        """Check if power meets the target threshold."""
        return bool(self.power >= self.target_power)

    def summary(self) -> str:
        """
        Generate formatted summary of pre-trends power analysis.

        Returns
        -------
        str
            Formatted summary.
        """
        lines = [
            "=" * 70,
            "Pre-Trends Power Analysis Results".center(70),
            "(Roth 2022)".center(70),
            "=" * 70,
            "",
            f"{'Number of pre-periods:':<35} {self.n_pre_periods}",
            f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
            f"{'Target power:':<35} {self.target_power:.1%}",
            f"{'Violation type:':<35} {self.violation_type}",
            "",
            "-" * 70,
            "Power Analysis".center(70),
            "-" * 70,
            f"{'Violation magnitude (M):':<35} {self.violation_magnitude:.4f}",
            f"{'Power to detect this violation:':<35} {self.power:.1%}",
            f"{'Minimum detectable violation:':<35} {self.mdv:.4f}",
            "",
            f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}",
            f"{'Critical value:':<35} {self.critical_value:.4f}",
            f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}",
            "",
            "-" * 70,
            "Interpretation".center(70),
            "-" * 70,
        ]

        if self.power_adequate:
            lines.append(
                f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%})."
            )
            lines.append(
                f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}."
            )
        else:
            lines.append(
                f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%})."
            )
            lines.append(
                f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power."
            )

        lines.append("")
        lines.append(
            f"Minimum detectable violation (MDV): {self.mdv:.4f}"
        )
        lines.append(
            " → Passing pre-trends test does NOT rule out violations up to this size."
        )

        lines.extend(["", "=" * 70])

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert the scalar results and derived flags to a dictionary."""
        return {
            "power": self.power,
            "mdv": self.mdv,
            "violation_magnitude": self.violation_magnitude,
            "violation_type": self.violation_type,
            "alpha": self.alpha,
            "target_power": self.target_power,
            "n_pre_periods": self.n_pre_periods,
            "test_statistic": self.test_statistic,
            "critical_value": self.critical_value,
            "noncentrality": self.noncentrality,
            "is_informative": self.is_informative,
            "power_adequate": self.power_adequate,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to a one-row DataFrame built from ``to_dict()``."""
        return pd.DataFrame([self.to_dict()])
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@dataclass
class PreTrendsPowerCurve:
    """
    Power curve across violation magnitudes.

    Attributes
    ----------
    M_values : np.ndarray
        Grid of violation magnitudes tested.
    powers : np.ndarray
        Power at each violation magnitude.
    mdv : float
        Minimum detectable violation.
    alpha : float
        Significance level.
    target_power : float
        Target power level.
    violation_type : str
        Type of violation pattern.

    Notes
    -----
    Equality is array-aware: ``M_values`` and ``powers`` are compared with
    ``np.array_equal`` so ``a == b`` never raises on multi-element arrays.
    """

    M_values: np.ndarray
    powers: np.ndarray
    mdv: float
    alpha: float
    target_power: float
    violation_type: str

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerCurve(n_points={len(self.M_values)}, "
            f"mdv={self.mdv:.4f})"
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two curves field by field.

        The dataclass-generated ``__eq__`` compares ndarray fields with a
        plain ``==`` and raises ``ValueError`` for multi-element arrays, so
        the comparison is spelled out here instead.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented

        mine = (self.mdv, self.alpha, self.target_power, self.violation_type)
        theirs = (other.mdv, other.alpha, other.target_power, other.violation_type)
        if mine != theirs:
            return False

        return all(
            left is right or bool(np.array_equal(left, right))
            for left, right in (
                (self.M_values, other.M_values),
                (self.powers, other.powers),
            )
        )

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to DataFrame with M and power columns."""
        return pd.DataFrame({
            "M": self.M_values,
            "power": self.powers,
        })

    def plot(self, ax=None, show_mdv: bool = True, show_target: bool = True,
             color: str = "#2563eb", mdv_color: str = "#dc2626",
             target_color: str = "#22c55e", **kwargs):
        """
        Plot the power curve.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        show_mdv : bool, default=True
            Whether to show vertical line at MDV.
        show_target : bool, default=True
            Whether to show horizontal line at target power.
        color : str
            Color for power curve line.
        mdv_color : str
            Color for MDV vertical line.
        target_color : str
            Color for target power horizontal line.
        **kwargs
            Additional arguments passed to ``ax.plot()``.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes with the plot.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        # matplotlib is imported lazily so the module works without it.
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required for plotting") from err

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        # Plot power curve
        ax.plot(self.M_values, self.powers, color=color, linewidth=2,
                label="Power", **kwargs)

        # Target power line
        if show_target:
            ax.axhline(y=self.target_power, color=target_color, linestyle="--",
                       linewidth=1.5, alpha=0.7,
                       label=f"Target power ({self.target_power:.0%})")

        # MDV line (skipped when the MDV is missing or infinite)
        if show_mdv and self.mdv is not None and np.isfinite(self.mdv):
            ax.axvline(x=self.mdv, color=mdv_color, linestyle=":",
                       linewidth=1.5, alpha=0.7,
                       label=f"MDV = {self.mdv:.3f}")

        ax.set_xlabel("Violation Magnitude (M)")
        ax.set_ylabel("Power")
        ax.set_title("Pre-Trends Test Power Curve")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        return ax
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# =============================================================================
|
|
311
|
+
# Main Class
|
|
312
|
+
# =============================================================================
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class PreTrendsPower:
|
|
316
|
+
"""
|
|
317
|
+
Pre-trends power analysis (Roth 2022).
|
|
318
|
+
|
|
319
|
+
Computes the power of pre-trends tests to detect violations of parallel
|
|
320
|
+
trends, and the minimum detectable violation (MDV).
|
|
321
|
+
|
|
322
|
+
Parameters
|
|
323
|
+
----------
|
|
324
|
+
alpha : float, default=0.05
|
|
325
|
+
Significance level for the pre-trends test.
|
|
326
|
+
power : float, default=0.80
|
|
327
|
+
Target power level for MDV calculation.
|
|
328
|
+
violation_type : str, default='linear'
|
|
329
|
+
Type of violation pattern to consider:
|
|
330
|
+
- 'linear': Violations follow a linear trend (most common)
|
|
331
|
+
- 'constant': Same violation in all pre-periods
|
|
332
|
+
- 'last_period': Violation only in the last pre-period
|
|
333
|
+
- 'custom': User-specified violation pattern (via violation_weights)
|
|
334
|
+
violation_weights : array-like, optional
|
|
335
|
+
Custom weights for violation pattern. Length must equal number of
|
|
336
|
+
pre-periods. Only used when violation_type='custom'.
|
|
337
|
+
|
|
338
|
+
Examples
|
|
339
|
+
--------
|
|
340
|
+
Basic usage with MultiPeriodDiD results:
|
|
341
|
+
|
|
342
|
+
>>> from diff_diff import MultiPeriodDiD
|
|
343
|
+
>>> from diff_diff.pretrends import PreTrendsPower
|
|
344
|
+
>>>
|
|
345
|
+
>>> # Fit event study
|
|
346
|
+
>>> mp_did = MultiPeriodDiD()
|
|
347
|
+
>>> results = mp_did.fit(data, outcome='y', treatment='treated',
|
|
348
|
+
... time='period', post_periods=[4, 5, 6, 7])
|
|
349
|
+
>>>
|
|
350
|
+
>>> # Analyze pre-trends power
|
|
351
|
+
>>> pt = PreTrendsPower(alpha=0.05, power=0.80)
|
|
352
|
+
>>> power_results = pt.fit(results)
|
|
353
|
+
>>> print(power_results.summary())
|
|
354
|
+
>>>
|
|
355
|
+
>>> # Get power curve
|
|
356
|
+
>>> curve = pt.power_curve(results)
|
|
357
|
+
>>> curve.plot()
|
|
358
|
+
|
|
359
|
+
Notes
|
|
360
|
+
-----
|
|
361
|
+
The pre-trends test is typically a joint test that all pre-period
|
|
362
|
+
coefficients are zero. This test has limited power to detect small
|
|
363
|
+
violations, especially when:
|
|
364
|
+
|
|
365
|
+
1. There are few pre-periods
|
|
366
|
+
2. Standard errors are large
|
|
367
|
+
3. The violation pattern is smooth (e.g., linear trend)
|
|
368
|
+
|
|
369
|
+
Passing a pre-trends test does NOT mean parallel trends holds. It means
|
|
370
|
+
violations smaller than the MDV cannot be ruled out. For robust inference,
|
|
371
|
+
combine with HonestDiD sensitivity analysis.
|
|
372
|
+
|
|
373
|
+
References
|
|
374
|
+
----------
|
|
375
|
+
Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing
|
|
376
|
+
for Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
|
|
377
|
+
"""
|
|
378
|
+
|
|
379
|
+
def __init__(
|
|
380
|
+
self,
|
|
381
|
+
alpha: float = 0.05,
|
|
382
|
+
power: float = 0.80,
|
|
383
|
+
violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear",
|
|
384
|
+
violation_weights: Optional[np.ndarray] = None,
|
|
385
|
+
):
|
|
386
|
+
if not 0 < alpha < 1:
|
|
387
|
+
raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
|
|
388
|
+
if not 0 < power < 1:
|
|
389
|
+
raise ValueError(f"power must be between 0 and 1, got {power}")
|
|
390
|
+
if violation_type not in ["linear", "constant", "last_period", "custom"]:
|
|
391
|
+
raise ValueError(
|
|
392
|
+
f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', "
|
|
393
|
+
f"got '{violation_type}'"
|
|
394
|
+
)
|
|
395
|
+
if violation_type == "custom" and violation_weights is None:
|
|
396
|
+
raise ValueError(
|
|
397
|
+
"violation_weights must be provided when violation_type='custom'"
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
self.alpha = alpha
|
|
401
|
+
self.target_power = power
|
|
402
|
+
self.violation_type = violation_type
|
|
403
|
+
self.violation_weights = (
|
|
404
|
+
np.asarray(violation_weights) if violation_weights is not None else None
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
def get_params(self) -> Dict[str, Any]:
|
|
408
|
+
"""Get parameters for this estimator."""
|
|
409
|
+
return {
|
|
410
|
+
"alpha": self.alpha,
|
|
411
|
+
"power": self.target_power,
|
|
412
|
+
"violation_type": self.violation_type,
|
|
413
|
+
"violation_weights": self.violation_weights,
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
def set_params(self, **params) -> "PreTrendsPower":
|
|
417
|
+
"""Set parameters for this estimator."""
|
|
418
|
+
for key, value in params.items():
|
|
419
|
+
if key == "power":
|
|
420
|
+
self.target_power = value
|
|
421
|
+
elif hasattr(self, key):
|
|
422
|
+
setattr(self, key, value)
|
|
423
|
+
else:
|
|
424
|
+
raise ValueError(f"Invalid parameter: {key}")
|
|
425
|
+
return self
|
|
426
|
+
|
|
427
|
+
def _get_violation_weights(self, n_pre: int) -> np.ndarray:
|
|
428
|
+
"""
|
|
429
|
+
Get violation weights based on violation type.
|
|
430
|
+
|
|
431
|
+
Parameters
|
|
432
|
+
----------
|
|
433
|
+
n_pre : int
|
|
434
|
+
Number of pre-treatment periods.
|
|
435
|
+
|
|
436
|
+
Returns
|
|
437
|
+
-------
|
|
438
|
+
np.ndarray
|
|
439
|
+
Violation weights, normalized to have L2 norm of 1.
|
|
440
|
+
"""
|
|
441
|
+
if self.violation_type == "custom":
|
|
442
|
+
if len(self.violation_weights) != n_pre:
|
|
443
|
+
raise ValueError(
|
|
444
|
+
f"violation_weights has length {len(self.violation_weights)}, "
|
|
445
|
+
f"but there are {n_pre} pre-periods"
|
|
446
|
+
)
|
|
447
|
+
weights = self.violation_weights.copy()
|
|
448
|
+
elif self.violation_type == "linear":
|
|
449
|
+
# Linear trend: weights = [-n+1, -n+2, ..., -1, 0] for periods ending at -1
|
|
450
|
+
# Normalized so that violation at period -1 = 0 and grows linearly backward
|
|
451
|
+
weights = np.arange(-n_pre + 1, 1, dtype=float)
|
|
452
|
+
# Shift so that weights are positive and represent deviation from PT
|
|
453
|
+
weights = -weights # Now [n-1, n-2, ..., 1, 0]
|
|
454
|
+
elif self.violation_type == "constant":
|
|
455
|
+
# Same violation in all periods
|
|
456
|
+
weights = np.ones(n_pre)
|
|
457
|
+
elif self.violation_type == "last_period":
|
|
458
|
+
# Violation only in last pre-period (period -1)
|
|
459
|
+
weights = np.zeros(n_pre)
|
|
460
|
+
weights[-1] = 1.0
|
|
461
|
+
else:
|
|
462
|
+
raise ValueError(f"Unknown violation_type: {self.violation_type}")
|
|
463
|
+
|
|
464
|
+
# Normalize to unit norm (if not all zeros)
|
|
465
|
+
norm = np.linalg.norm(weights)
|
|
466
|
+
if norm > 0:
|
|
467
|
+
weights = weights / norm
|
|
468
|
+
|
|
469
|
+
return weights
|
|
470
|
+
|
|
471
|
+
def _extract_pre_period_params(
    self,
    results: Union[MultiPeriodDiDResults, Any],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """
    Pull pre-period estimates out of a supported results object.

    Handles ``MultiPeriodDiDResults`` directly and, via lazy imports,
    ``CallawaySantAnnaResults`` and ``SunAbrahamResults``.

    Returns
    -------
    effects : np.ndarray
        Pre-period effect estimates.
    ses : np.ndarray
        Pre-period standard errors.
    vcov : np.ndarray
        Variance-covariance matrix for pre-period effects. A diagonal
        matrix of squared standard errors is used whenever a full
        sub-matrix cannot be taken from ``results.vcov``.
    n_pre : int
        Number of pre-periods.

    Raises
    ------
    ValueError
        If no usable pre-period estimates are present.
    TypeError
        If ``results`` is not one of the supported result types.
    """

    def _from_event_study(event_study_effects):
        # Shared path for event-study style results keyed by relative
        # time: keep negative horizons, in ascending order.
        pre_effects = {
            t: data for t, data in event_study_effects.items()
            if t < 0
        }
        if not pre_effects:
            raise ValueError("No pre-treatment periods found in event study.")

        horizons = sorted(pre_effects.keys())
        count = len(horizons)
        es_effects = np.array([pre_effects[t]['effect'] for t in horizons])
        es_ses = np.array([pre_effects[t]['se'] for t in horizons])
        return es_effects, es_ses, np.diag(es_ses ** 2), count

    if isinstance(results, MultiPeriodDiDResults):
        candidate_periods = results.pre_periods

        if len(candidate_periods) == 0:
            raise ValueError(
                "No pre-treatment periods found in results. "
                "Pre-trends power analysis requires pre-period coefficients."
            )

        if hasattr(results, 'coefficients') and results.coefficients:
            # Keep only periods with an estimated interaction coefficient;
            # the omitted reference period has none.
            kept_periods = [
                p for p in candidate_periods
                if f"treated:period_{p}" in results.coefficients
            ]
            if not kept_periods:
                raise ValueError(
                    "No estimated pre-period coefficients found. "
                    "The pre-trends test requires at least one estimated "
                    "pre-period coefficient (excluding the reference period)."
                )

            n_pre = len(kept_periods)
            coef_names = [f"treated:period_{p}" for p in kept_periods]
            effects = np.array([results.coefficients[name] for name in coef_names])

            # Per-period SE when available, otherwise the average SE.
            ses = np.array([
                results.period_effects[p].se
                if p in results.period_effects
                else results.avg_se
                for p in kept_periods
            ])

            # Take the pre-period sub-matrix of the full vcov when the
            # coefficient positions can all be resolved inside it.
            vcov = None
            if results.vcov is not None:
                ordered_keys = list(results.coefficients.keys())
                positions = [
                    ordered_keys.index(name)
                    for name in coef_names
                    if name in ordered_keys
                ]
                if len(positions) == n_pre and results.vcov.shape[0] > max(positions):
                    vcov = results.vcov[np.ix_(positions, positions)]
            if vcov is None:
                # Fall back to diagonal
                vcov = np.diag(ses ** 2)
        else:
            # No coefficients: use period_effects, dropping entries whose
            # SE is not positive (the reference period).
            kept_periods = [
                p for p in candidate_periods
                if p in results.period_effects
                and results.period_effects[p].se > 0
            ]
            if not kept_periods:
                raise ValueError(
                    "No estimated pre-period effects found. "
                    "The pre-trends test requires at least one estimated "
                    "pre-period effect (excluding the reference period)."
                )

            n_pre = len(kept_periods)
            effects = np.array([
                results.period_effects[p].effect for p in kept_periods
            ])
            ses = np.array([
                results.period_effects[p].se for p in kept_periods
            ])
            vcov = np.diag(ses ** 2)

        return effects, ses, vcov, n_pre

    # Try CallawaySantAnnaResults
    try:
        from diff_diff.staggered import CallawaySantAnnaResults
        if isinstance(results, CallawaySantAnnaResults):
            if results.event_study_effects is None:
                raise ValueError(
                    "CallawaySantAnnaResults must have event_study_effects. "
                    "Re-run with aggregate='event_study'."
                )
            return _from_event_study(results.event_study_effects)
    except ImportError:
        pass

    # Try SunAbrahamResults
    try:
        from diff_diff.sun_abraham import SunAbrahamResults
        if isinstance(results, SunAbrahamResults):
            return _from_event_study(results.event_study_effects)
    except ImportError:
        pass

    raise TypeError(
        f"Unsupported results type: {type(results)}. "
        "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults."
    )
|
|
636
|
+
|
|
637
|
+
def _compute_power(
|
|
638
|
+
self,
|
|
639
|
+
M: float,
|
|
640
|
+
weights: np.ndarray,
|
|
641
|
+
vcov: np.ndarray,
|
|
642
|
+
) -> Tuple[float, float, float, float]:
|
|
643
|
+
"""
|
|
644
|
+
Compute power to detect violation of magnitude M.
|
|
645
|
+
|
|
646
|
+
The pre-trends test is a Wald test: H0: delta = 0 vs H1: delta != 0
|
|
647
|
+
Under H1 with violation delta = M * weights, the test statistic follows
|
|
648
|
+
a non-central chi-squared distribution.
|
|
649
|
+
|
|
650
|
+
Parameters
|
|
651
|
+
----------
|
|
652
|
+
M : float
|
|
653
|
+
Violation magnitude.
|
|
654
|
+
weights : np.ndarray
|
|
655
|
+
Normalized violation pattern.
|
|
656
|
+
vcov : np.ndarray
|
|
657
|
+
Variance-covariance matrix.
|
|
658
|
+
|
|
659
|
+
Returns
|
|
660
|
+
-------
|
|
661
|
+
power : float
|
|
662
|
+
Power to detect this violation.
|
|
663
|
+
noncentrality : float
|
|
664
|
+
Non-centrality parameter.
|
|
665
|
+
test_stat : float
|
|
666
|
+
Expected test statistic under H1.
|
|
667
|
+
critical_value : float
|
|
668
|
+
Critical value for the test.
|
|
669
|
+
"""
|
|
670
|
+
n_pre = len(weights)
|
|
671
|
+
|
|
672
|
+
# Violation vector: delta = M * weights
|
|
673
|
+
delta = M * weights
|
|
674
|
+
|
|
675
|
+
# Non-centrality parameter for chi-squared test
|
|
676
|
+
# lambda = delta' * V^{-1} * delta
|
|
677
|
+
try:
|
|
678
|
+
vcov_inv = np.linalg.inv(vcov)
|
|
679
|
+
noncentrality = delta @ vcov_inv @ delta
|
|
680
|
+
except np.linalg.LinAlgError:
|
|
681
|
+
# Singular matrix - use pseudo-inverse
|
|
682
|
+
vcov_inv = np.linalg.pinv(vcov)
|
|
683
|
+
noncentrality = delta @ vcov_inv @ delta
|
|
684
|
+
|
|
685
|
+
# Critical value from chi-squared distribution
|
|
686
|
+
critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)
|
|
687
|
+
|
|
688
|
+
# Power = P(chi2_nc > critical_value) where chi2_nc is non-central chi2
|
|
689
|
+
if noncentrality > 0:
|
|
690
|
+
power = 1 - stats.ncx2.cdf(critical_value, df=n_pre, nc=noncentrality)
|
|
691
|
+
else:
|
|
692
|
+
power = self.alpha # Size under null
|
|
693
|
+
|
|
694
|
+
# Expected test statistic under H1
|
|
695
|
+
test_stat = n_pre + noncentrality # Mean of non-central chi2
|
|
696
|
+
|
|
697
|
+
return power, noncentrality, test_stat, critical_value
|
|
698
|
+
|
|
699
|
+
def _compute_mdv(
|
|
700
|
+
self,
|
|
701
|
+
weights: np.ndarray,
|
|
702
|
+
vcov: np.ndarray,
|
|
703
|
+
) -> float:
|
|
704
|
+
"""
|
|
705
|
+
Compute minimum detectable violation.
|
|
706
|
+
|
|
707
|
+
Find the smallest M such that power >= target_power.
|
|
708
|
+
|
|
709
|
+
Parameters
|
|
710
|
+
----------
|
|
711
|
+
weights : np.ndarray
|
|
712
|
+
Normalized violation pattern.
|
|
713
|
+
vcov : np.ndarray
|
|
714
|
+
Variance-covariance matrix.
|
|
715
|
+
|
|
716
|
+
Returns
|
|
717
|
+
-------
|
|
718
|
+
mdv : float
|
|
719
|
+
Minimum detectable violation.
|
|
720
|
+
"""
|
|
721
|
+
n_pre = len(weights)
|
|
722
|
+
|
|
723
|
+
# Critical value
|
|
724
|
+
critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)
|
|
725
|
+
|
|
726
|
+
# Find non-centrality parameter for target power
|
|
727
|
+
# We need: P(ncx2 > critical_value) = target_power
|
|
728
|
+
# Use inverse: find lambda such that ncx2.cdf(cv, df, lambda) = 1 - target_power
|
|
729
|
+
|
|
730
|
+
def power_minus_target(nc):
|
|
731
|
+
if nc <= 0:
|
|
732
|
+
return self.alpha - self.target_power
|
|
733
|
+
return stats.ncx2.sf(critical_value, df=n_pre, nc=nc) - self.target_power
|
|
734
|
+
|
|
735
|
+
# Binary search for non-centrality parameter
|
|
736
|
+
# Start with bounds
|
|
737
|
+
nc_low, nc_high = 0, 1
|
|
738
|
+
|
|
739
|
+
# Expand upper bound until power exceeds target
|
|
740
|
+
while power_minus_target(nc_high) < 0 and nc_high < 1000:
|
|
741
|
+
nc_high *= 2
|
|
742
|
+
|
|
743
|
+
if nc_high >= 1000:
|
|
744
|
+
# Target power not achievable - return inf
|
|
745
|
+
return np.inf
|
|
746
|
+
|
|
747
|
+
# Binary search
|
|
748
|
+
try:
|
|
749
|
+
result = optimize.brentq(power_minus_target, nc_low, nc_high)
|
|
750
|
+
target_nc = result
|
|
751
|
+
except ValueError:
|
|
752
|
+
# Fallback: use approximate formula
|
|
753
|
+
# For chi2, power ≈ Phi(sqrt(2*nc) - sqrt(2*cv))
|
|
754
|
+
# Solving: sqrt(2*nc) = z_power + sqrt(2*cv)
|
|
755
|
+
z_power = stats.norm.ppf(self.target_power)
|
|
756
|
+
target_nc = 0.5 * (z_power + np.sqrt(2 * critical_value)) ** 2
|
|
757
|
+
|
|
758
|
+
# Convert non-centrality to M
|
|
759
|
+
# nc = delta' * V^{-1} * delta = M^2 * w' * V^{-1} * w
|
|
760
|
+
try:
|
|
761
|
+
vcov_inv = np.linalg.inv(vcov)
|
|
762
|
+
w_Vinv_w = weights @ vcov_inv @ weights
|
|
763
|
+
except np.linalg.LinAlgError:
|
|
764
|
+
vcov_inv = np.linalg.pinv(vcov)
|
|
765
|
+
w_Vinv_w = weights @ vcov_inv @ weights
|
|
766
|
+
|
|
767
|
+
if w_Vinv_w > 0:
|
|
768
|
+
mdv = np.sqrt(target_nc / w_Vinv_w)
|
|
769
|
+
else:
|
|
770
|
+
mdv = np.inf
|
|
771
|
+
|
|
772
|
+
return mdv
|
|
773
|
+
|
|
774
|
+
def fit(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
) -> PreTrendsPowerResults:
    """
    Run the pre-trends power analysis on an event-study result.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults
        Results from an event study estimation.
    M : float, optional
        Violation magnitude at which power is evaluated. When omitted,
        the minimum detectable violation (MDV) is used if it is finite;
        otherwise the largest pre-period standard error is used.

    Returns
    -------
    PreTrendsPowerResults
        Power at ``M``, the MDV, the test quantities returned by
        ``_compute_power`` and the pre-period inputs they were built from.
    """
    # Pre-period coefficients, their standard errors and covariance.
    pre_effects, pre_ses, pre_vcov, num_pre = self._extract_pre_period_params(results)

    # Violation pattern weights for this number of pre-periods.
    shape_weights = self._get_violation_weights(num_pre)

    # Smallest violation detectable at the configured target power.
    min_detectable = self._compute_mdv(shape_weights, pre_vcov)

    # Pick the default evaluation point only when the caller gave none.
    if M is None:
        if np.isfinite(min_detectable):
            M = min_detectable
        else:
            M = np.max(pre_ses)

    power_outputs = self._compute_power(M, shape_weights, pre_vcov)
    achieved_power, nc_param, stat_value, crit_value = power_outputs

    return PreTrendsPowerResults(
        power=achieved_power,
        mdv=min_detectable,
        violation_magnitude=M,
        violation_type=self.violation_type,
        alpha=self.alpha,
        target_power=self.target_power,
        n_pre_periods=num_pre,
        test_statistic=stat_value,
        critical_value=crit_value,
        noncentrality=nc_param,
        pre_period_effects=pre_effects,
        pre_period_ses=pre_ses,
        vcov=pre_vcov,
        original_results=results,
    )
|
|
829
|
+
|
|
830
|
+
def power_at(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: float,
) -> float:
    """
    Return the power of the pre-trends test against a given violation size.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float
        Violation magnitude.

    Returns
    -------
    float
        The ``power`` attribute of the result produced by ``fit`` for ``M``.
    """
    # Delegate to fit() and keep only the power figure.
    return self.fit(results, M=M).power
|
|
852
|
+
|
|
853
|
+
def power_curve(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M_grid: Optional[List[float]] = None,
    n_points: int = 50,
) -> PreTrendsPowerCurve:
    """
    Evaluate pre-trends test power over a grid of violation magnitudes.

    Parameters
    ----------
    results : results object
        Event study results.
    M_grid : list of float, optional
        Violation magnitudes to evaluate. If None, an evenly spaced grid
        starting at 0 is built. Its upper end is 2.5 * MDV when the MDV is
        finite, otherwise 10 * max(pre-period SE), and in both cases it is
        capped at 100.
    n_points : int, default=50
        Number of points in the automatic grid (ignored when ``M_grid``
        is supplied).

    Returns
    -------
    PreTrendsPowerCurve
        The grid, the power at each grid point, the MDV and the settings
        (alpha, target power, violation type) used.
    """
    # Only the standard errors and covariance are needed here.
    _, pre_ses, pre_vcov, num_pre = self._extract_pre_period_params(results)
    shape_weights = self._get_violation_weights(num_pre)

    # Minimum detectable violation anchors the automatic grid.
    min_detectable = self._compute_mdv(shape_weights, pre_vcov)

    if M_grid is not None:
        M_grid = np.asarray(M_grid)
    else:
        if np.isfinite(min_detectable):
            upper = 2.5 * min_detectable
        else:
            upper = 10 * np.max(pre_ses)
        # NOTE(review): the cap of 100 is in outcome units, so a large-scale
        # outcome gets a truncated grid - confirm this is intended.
        M_grid = np.linspace(0, min(upper, 100), n_points)

    # First element of _compute_power's return value is the power.
    power_values = np.array(
        [self._compute_power(m, shape_weights, pre_vcov)[0] for m in M_grid]
    )

    return PreTrendsPowerCurve(
        M_values=M_grid,
        powers=power_values,
        mdv=min_detectable,
        alpha=self.alpha,
        target_power=self.target_power,
        violation_type=self.violation_type,
    )
|
|
905
|
+
|
|
906
|
+
def sensitivity_to_honest_did(
    self,
    results: Union[MultiPeriodDiDResults, Any],
) -> Dict[str, Any]:
    """
    Compare pre-trends power analysis with HonestDiD sensitivity.

    This method helps interpret how informative a passing pre-trends
    test is in the context of HonestDiD's relative magnitudes restriction,
    by expressing the minimum detectable violation (MDV) in units of the
    largest pre-period standard error.

    Parameters
    ----------
    results : results object
        Event study results.

    Returns
    -------
    dict
        Dictionary with:
        - mdv: Minimum detectable violation from the pre-trends test
        - max_pre_se: Largest pre-period standard error
        - mdv_in_ses: ``mdv / max_pre_se``; ``inf`` when the MDV is not
          finite or ``max_pre_se`` is not positive
        - interpretation: Multi-line text explaining the relationship
    """
    pt_results = self.fit(results)
    mdv = pt_results.mdv
    max_pre_se = np.max(pt_results.pre_period_ses)
    mdv_is_finite = np.isfinite(mdv)

    # Computed once and shared by the message and the returned dict, so
    # the two can never disagree.
    if mdv_is_finite and max_pre_se > 0:
        mdv_in_ses = mdv / max_pre_se
    else:
        mdv_in_ses = np.inf

    interpretation = [
        f"Minimum Detectable Violation (MDV): {mdv:.4f}",
        f"Max pre-period SE: {max_pre_se:.4f}",
    ]

    if mdv_is_finite:
        interpretation.append(f"MDV / max(SE): {mdv_in_ses:.2f}")

        # Thresholds of 1 and 2 standard errors grade the test's sensitivity.
        if mdv_in_ses < 1:
            interpretation.append(
                "→ Pre-trends test is fairly sensitive to violations."
            )
        elif mdv_in_ses < 2:
            interpretation.append(
                "→ Pre-trends test has moderate sensitivity."
            )
        else:
            interpretation.append(
                "→ Pre-trends test has low power to detect violations."
            )
            interpretation.append(
                " Consider using HonestDiD with larger M values for robustness."
            )
    else:
        interpretation.append(
            "→ Pre-trends test cannot achieve target power for any violation size."
        )
        interpretation.append(
            " Use HonestDiD sensitivity analysis for inference."
        )

    return {
        "mdv": mdv,
        "max_pre_se": max_pre_se,
        "mdv_in_ses": mdv_in_ses,
        "interpretation": "\n".join(interpretation),
    }
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
# =============================================================================
|
|
986
|
+
# Convenience Functions
|
|
987
|
+
# =============================================================================
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
def compute_pretrends_power(
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
) -> PreTrendsPowerResults:
    """
    One-call pre-trends power analysis.

    Builds a ``PreTrendsPower`` analyzer from the given settings and
    returns the result of its ``fit`` method.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float, optional
        Violation magnitude to evaluate.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern.

    Returns
    -------
    PreTrendsPowerResults
        Power analysis results.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import compute_pretrends_power
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> power_results = compute_pretrends_power(results)
    >>> print(f"MDV: {power_results.mdv:.3f}")
    >>> print(f"Power: {power_results.power:.1%}")
    """
    # Note that the analyzer's keyword is ``power``, not ``target_power``.
    analyzer = PreTrendsPower(
        alpha=alpha, power=target_power, violation_type=violation_type
    )
    return analyzer.fit(results, M=M)
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
def compute_mdv(
    results: Union[MultiPeriodDiDResults, Any],
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
) -> float:
    """
    Return the minimum detectable violation for an event-study result.

    Builds a ``PreTrendsPower`` analyzer from the given settings, fits it
    with the default violation magnitude and returns the ``mdv`` attribute
    of the fitted result.

    Parameters
    ----------
    results : results object
        Event study results.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power.
    violation_type : str, default='linear'
        Type of violation pattern.

    Returns
    -------
    float
        Minimum detectable violation.
    """
    # Note that the analyzer's keyword is ``power``, not ``target_power``.
    analyzer = PreTrendsPower(
        alpha=alpha, power=target_power, violation_type=violation_type
    )
    return analyzer.fit(results).mdv
|