diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/power.py
ADDED
|
@@ -0,0 +1,2588 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Power analysis tools for difference-in-differences study design.
|
|
3
|
+
|
|
4
|
+
This module provides power calculations and simulation-based power analysis
|
|
5
|
+
for DiD study design, helping practitioners answer questions like:
|
|
6
|
+
- "How many units do I need to detect an effect of size X?"
|
|
7
|
+
- "What is the minimum detectable effect given my sample size?"
|
|
8
|
+
- "What power do I have to detect a given effect?"
|
|
9
|
+
|
|
10
|
+
References
|
|
11
|
+
----------
|
|
12
|
+
Bloom, H. S. (1995). "Minimum Detectable Effects: A Simple Way to Report the
|
|
13
|
+
Statistical Power of Experimental Designs." Evaluation Review, 19(5), 547-556.
|
|
14
|
+
|
|
15
|
+
Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
|
|
16
|
+
Journal of Development Economics, 144, 102458.
|
|
17
|
+
|
|
18
|
+
Djimeu, E. W., & Houndolo, D.-G. (2016). "Power Calculation for Causal Inference
|
|
19
|
+
in Social Science: Sample Size and Minimum Detectable Effect Determination."
|
|
20
|
+
Journal of Development Effectiveness, 8(4), 508-527.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import warnings
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
import pandas as pd
|
|
29
|
+
from scipy import stats
|
|
30
|
+
|
|
31
|
+
# Sentinel upper bound reported as the sample size when the target effect is
# too small to detect (for example a zero effect, or one that is negligible
# relative to the noise). It equals the largest signed 32-bit integer,
# 2**31 - 1.
MAX_SAMPLE_SIZE = 2_147_483_647
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ---------------------------------------------------------------------------
|
|
37
|
+
# Estimator registry — maps estimator class names to DGP/fit/extract profiles
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
|
|
42
|
+
class _EstimatorProfile:
|
|
43
|
+
"""Internal profile describing how to run power simulations for an estimator."""
|
|
44
|
+
|
|
45
|
+
default_dgp: Callable
|
|
46
|
+
dgp_kwargs_builder: Callable
|
|
47
|
+
fit_kwargs_builder: Callable
|
|
48
|
+
result_extractor: Callable
|
|
49
|
+
min_n: int = 20
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# -- DGP kwargs adapters -----------------------------------------------------
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _basic_dgp_kwargs(
|
|
56
|
+
n_units: int,
|
|
57
|
+
n_periods: int,
|
|
58
|
+
treatment_effect: float,
|
|
59
|
+
treatment_fraction: float,
|
|
60
|
+
treatment_period: int,
|
|
61
|
+
sigma: float,
|
|
62
|
+
) -> Dict[str, Any]:
|
|
63
|
+
return dict(
|
|
64
|
+
n_units=n_units,
|
|
65
|
+
n_periods=n_periods,
|
|
66
|
+
treatment_effect=treatment_effect,
|
|
67
|
+
treatment_fraction=treatment_fraction,
|
|
68
|
+
treatment_period=treatment_period,
|
|
69
|
+
noise_sd=sigma,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _staggered_dgp_kwargs(
|
|
74
|
+
n_units: int,
|
|
75
|
+
n_periods: int,
|
|
76
|
+
treatment_effect: float,
|
|
77
|
+
treatment_fraction: float,
|
|
78
|
+
treatment_period: int,
|
|
79
|
+
sigma: float,
|
|
80
|
+
) -> Dict[str, Any]:
|
|
81
|
+
return dict(
|
|
82
|
+
n_units=n_units,
|
|
83
|
+
n_periods=n_periods,
|
|
84
|
+
treatment_effect=treatment_effect,
|
|
85
|
+
never_treated_frac=1 - treatment_fraction,
|
|
86
|
+
cohort_periods=[treatment_period],
|
|
87
|
+
dynamic_effects=False,
|
|
88
|
+
noise_sd=sigma,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _factor_dgp_kwargs(
|
|
93
|
+
n_units: int,
|
|
94
|
+
n_periods: int,
|
|
95
|
+
treatment_effect: float,
|
|
96
|
+
treatment_fraction: float,
|
|
97
|
+
treatment_period: int,
|
|
98
|
+
sigma: float,
|
|
99
|
+
) -> Dict[str, Any]:
|
|
100
|
+
n_pre = treatment_period
|
|
101
|
+
n_post = n_periods - treatment_period
|
|
102
|
+
return dict(
|
|
103
|
+
n_units=n_units,
|
|
104
|
+
n_pre=n_pre,
|
|
105
|
+
n_post=n_post,
|
|
106
|
+
n_treated=max(1, int(n_units * treatment_fraction)),
|
|
107
|
+
treatment_effect=treatment_effect,
|
|
108
|
+
noise_sd=sigma,
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _ddd_dgp_kwargs(
|
|
113
|
+
n_units: int,
|
|
114
|
+
n_periods: int,
|
|
115
|
+
treatment_effect: float,
|
|
116
|
+
treatment_fraction: float,
|
|
117
|
+
treatment_period: int,
|
|
118
|
+
sigma: float,
|
|
119
|
+
) -> Dict[str, Any]:
|
|
120
|
+
return dict(
|
|
121
|
+
n_per_cell=max(2, n_units // 8),
|
|
122
|
+
treatment_effect=treatment_effect,
|
|
123
|
+
noise_sd=sigma,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# -- Fit kwargs builders ------------------------------------------------------
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _basic_fit_kwargs(
|
|
131
|
+
data: pd.DataFrame,
|
|
132
|
+
n_units: int,
|
|
133
|
+
n_periods: int,
|
|
134
|
+
treatment_period: int,
|
|
135
|
+
) -> Dict[str, Any]:
|
|
136
|
+
return dict(outcome="outcome", treatment="treated", time="post")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _twfe_fit_kwargs(
|
|
140
|
+
data: pd.DataFrame,
|
|
141
|
+
n_units: int,
|
|
142
|
+
n_periods: int,
|
|
143
|
+
treatment_period: int,
|
|
144
|
+
) -> Dict[str, Any]:
|
|
145
|
+
return dict(outcome="outcome", treatment="treated", time="post", unit="unit")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _multiperiod_fit_kwargs(
|
|
149
|
+
data: pd.DataFrame,
|
|
150
|
+
n_units: int,
|
|
151
|
+
n_periods: int,
|
|
152
|
+
treatment_period: int,
|
|
153
|
+
) -> Dict[str, Any]:
|
|
154
|
+
return dict(
|
|
155
|
+
outcome="outcome",
|
|
156
|
+
treatment="treated",
|
|
157
|
+
time="period",
|
|
158
|
+
post_periods=list(range(treatment_period, n_periods)),
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _staggered_fit_kwargs(
|
|
163
|
+
data: pd.DataFrame,
|
|
164
|
+
n_units: int,
|
|
165
|
+
n_periods: int,
|
|
166
|
+
treatment_period: int,
|
|
167
|
+
) -> Dict[str, Any]:
|
|
168
|
+
return dict(outcome="outcome", unit="unit", time="period", first_treat="first_treat")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _ddd_fit_kwargs(
|
|
172
|
+
data: pd.DataFrame,
|
|
173
|
+
n_units: int,
|
|
174
|
+
n_periods: int,
|
|
175
|
+
treatment_period: int,
|
|
176
|
+
) -> Dict[str, Any]:
|
|
177
|
+
return dict(outcome="outcome", group="group", partition="partition", time="time")
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _trop_fit_kwargs(
|
|
181
|
+
data: pd.DataFrame,
|
|
182
|
+
n_units: int,
|
|
183
|
+
n_periods: int,
|
|
184
|
+
treatment_period: int,
|
|
185
|
+
) -> Dict[str, Any]:
|
|
186
|
+
return dict(outcome="outcome", treatment="treated", unit="unit", time="period")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _sdid_fit_kwargs(
|
|
190
|
+
data: pd.DataFrame,
|
|
191
|
+
n_units: int,
|
|
192
|
+
n_periods: int,
|
|
193
|
+
treatment_period: int,
|
|
194
|
+
) -> Dict[str, Any]:
|
|
195
|
+
periods = sorted(data["period"].unique())
|
|
196
|
+
post_periods = [p for p in periods if p >= treatment_period]
|
|
197
|
+
return dict(
|
|
198
|
+
outcome="outcome",
|
|
199
|
+
treatment="treat",
|
|
200
|
+
unit="unit",
|
|
201
|
+
time="period",
|
|
202
|
+
post_periods=post_periods,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# -- Result extractors --------------------------------------------------------
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _extract_simple(result: Any) -> Tuple[float, float, float, Tuple[float, float]]:
|
|
210
|
+
return (result.att, result.se, result.p_value, result.conf_int)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _extract_multiperiod(
|
|
214
|
+
result: Any,
|
|
215
|
+
) -> Tuple[float, float, float, Tuple[float, float]]:
|
|
216
|
+
return (result.avg_att, result.avg_se, result.avg_p_value, result.avg_conf_int)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _extract_staggered(
|
|
220
|
+
result: Any,
|
|
221
|
+
) -> Tuple[float, float, float, Tuple[float, float]]:
|
|
222
|
+
_nan = float("nan")
|
|
223
|
+
_nan_ci = (_nan, _nan)
|
|
224
|
+
|
|
225
|
+
def _first(r: Any, *attrs: str, default: Any = _nan) -> Any:
|
|
226
|
+
for a in attrs:
|
|
227
|
+
v = getattr(r, a, None)
|
|
228
|
+
if v is not None:
|
|
229
|
+
return v
|
|
230
|
+
return default
|
|
231
|
+
|
|
232
|
+
return (
|
|
233
|
+
result.overall_att,
|
|
234
|
+
_first(result, "overall_se", "overall_att_se"),
|
|
235
|
+
_first(result, "overall_p_value", "overall_att_p_value"),
|
|
236
|
+
_first(result, "overall_conf_int", "overall_att_ci", default=_nan_ci),
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# Keys derived from simulate_power() public params — overriding these
|
|
241
|
+
# via data_generator_kwargs would desync the DGP from the result object.
|
|
242
|
+
_PROTECTED_DGP_KEYS = frozenset(
|
|
243
|
+
{
|
|
244
|
+
"treatment_effect", # → true_effect in results / MDE search variable
|
|
245
|
+
"noise_sd", # → sigma param
|
|
246
|
+
"n_units", # → sample-size search variable
|
|
247
|
+
"n_periods", # → n_periods param
|
|
248
|
+
"treatment_fraction", # → treatment_fraction param
|
|
249
|
+
"treatment_period", # → treatment_period param
|
|
250
|
+
"n_pre", # → derived from treatment_period in factor-model DGPs
|
|
251
|
+
"n_post", # → derived from n_periods - treatment_period in factor-model DGPs
|
|
252
|
+
}
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# -- Staggered DGP compatibility check ----------------------------------------
|
|
257
|
+
|
|
258
|
+
_STAGGERED_ESTIMATORS = frozenset(
|
|
259
|
+
{
|
|
260
|
+
"CallawaySantAnna",
|
|
261
|
+
"SunAbraham",
|
|
262
|
+
"ImputationDiD",
|
|
263
|
+
"TwoStageDiD",
|
|
264
|
+
"StackedDiD",
|
|
265
|
+
"EfficientDiD",
|
|
266
|
+
}
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _check_staggered_dgp_compat(
    estimator: Any,
    data_generator_kwargs: Optional[Dict[str, Any]],
) -> None:
    """Warn if a staggered estimator's settings don't match the default DGP.

    Only estimators whose class name is in ``_STAGGERED_ESTIMATORS`` are
    checked. Three mismatches are detected:

    * ``control_group == "not_yet_treated"`` without at least two distinct
      ``cohort_periods`` in ``data_generator_kwargs``;
    * ``anticipation > 0``;
    * ``StackedDiD`` with ``clean_control == "strict"`` and fewer than two
      distinct cohorts.

    All findings are combined into a single ``UserWarning``.
    """
    est_name = type(estimator).__name__
    if est_name not in _STAGGERED_ESTIMATORS:
        return

    overrides = data_generator_kwargs or {}
    cohorts = overrides.get("cohort_periods")
    multi_cohort = False if cohorts is None else len(set(cohorts)) >= 2

    problems: List[str] = []

    # Not-yet-treated controls need a later cohort to act as controls
    # (CallawaySantAnna / SunAbraham expose control_group).
    control_group = getattr(estimator, "control_group", "never_treated")
    if control_group == "not_yet_treated" and not multi_cohort:
        problems.append(
            f' - {est_name} has control_group="not_yet_treated" but the default '
            "DGP generates a single treatment cohort with never-treated "
            "controls. Power may not reflect the intended not-yet-treated "
            "design.\n"
            " Fix: pass data_generator_kwargs="
            '{"cohort_periods": [2, 4], "never_treated_frac": 0.0} '
            "(or a custom data_generator)."
        )

    # The default DGP has no anticipatory effects, applicable to every
    # staggered estimator that exposes an anticipation setting.
    anticipation = getattr(estimator, "anticipation", 0)
    if anticipation > 0:
        problems.append(
            f" - {est_name} has anticipation={anticipation} but the default DGP does "
            "not model anticipatory effects. The estimator will look for "
            f"treatment effects {anticipation} period(s) before the DGP generates "
            "them, biasing power estimates.\n"
            " Fix: supply a custom data_generator that shifts the "
            "effect onset."
        )

    # With one cohort, StackedDiD's strict clean controls collapse to
    # never-treated controls.
    if est_name == "StackedDiD":
        clean_control = getattr(estimator, "clean_control", "not_yet_treated")
        if clean_control == "strict" and not multi_cohort:
            problems.append(
                ' - StackedDiD has clean_control="strict" but the default '
                "single-cohort DGP makes strict controls equivalent to "
                "never-treated controls.\n"
                " Fix: pass data_generator_kwargs="
                '{"cohort_periods": [2, 4]} '
                "to test true strict clean-control behavior."
            )

    if not problems:
        return

    header = (
        f"Staggered power DGP mismatch for {est_name}. The default "
        "single-cohort DGP may not match the estimator "
        "configuration:\n"
    )
    # warn() is called directly here so stacklevel=2 still points at our caller.
    warnings.warn(header + "\n".join(problems), UserWarning, stacklevel=2)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _ddd_effective_n(
|
|
332
|
+
n_units: int, data_generator_kwargs: Optional[Dict[str, Any]]
|
|
333
|
+
) -> Optional[int]:
|
|
334
|
+
"""Return effective DDD sample size, or None if no rounding occurred."""
|
|
335
|
+
overrides = data_generator_kwargs or {}
|
|
336
|
+
if "n_per_cell" in overrides:
|
|
337
|
+
eff = overrides["n_per_cell"] * 8
|
|
338
|
+
else:
|
|
339
|
+
eff = max(2, n_units // 8) * 8
|
|
340
|
+
return eff if eff != n_units else None
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _check_ddd_dgp_compat(
    n_units: int,
    n_periods: int,
    treatment_fraction: float,
    treatment_period: int,
    data_generator_kwargs: Optional[Dict[str, Any]],
) -> None:
    """Warn when simulation inputs don't match DDD's fixed 2×2×2 design.

    The default DDD generator ignores ``n_periods`` (other than 2),
    ``treatment_period`` (other than 1) and ``treatment_fraction`` (other
    than 0.5). It also rounds the sample size to a multiple of 8 cells.
    Every such discrepancy is listed in a single ``UserWarning``.
    """
    notes: List[str] = []

    if n_periods != 2:
        notes.append(
            f"n_periods={n_periods} is ignored (DDD uses a fixed 2-period design: pre/post)"
        )
    if treatment_period != 1:
        notes.append(
            f"treatment_period={treatment_period} is ignored (DDD "
            "always treats in the second period)"
        )
    if treatment_fraction != 0.5:
        notes.append(
            f"treatment_fraction={treatment_fraction} is ignored "
            "(DDD uses a balanced 2×2×2 factorial where 50% of "
            "groups are treated)"
        )

    # The grid is n_per_cell × 8, so n_units may be rounded.
    effective = _ddd_effective_n(n_units, data_generator_kwargs)
    if effective is not None:
        notes.append(
            f"effective sample size is {effective} "
            f"(n_per_cell={effective // 8} × 8 cells), "
            f"not the requested n_units={n_units}"
        )

    if notes:
        warnings.warn(
            "TripleDifference uses a fixed 2×2×2 factorial DGP (group × partition × time). "
            + "; ".join(notes)
            + ". Pass a custom data_generator for non-standard DDD designs.",
            UserWarning,
            stacklevel=2,
        )
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _check_sdid_placebo_data(
|
|
394
|
+
data: pd.DataFrame,
|
|
395
|
+
estimator: Any,
|
|
396
|
+
est_kwargs: Dict[str, Any],
|
|
397
|
+
) -> None:
|
|
398
|
+
"""Check SyntheticDiD placebo feasibility on realized data.
|
|
399
|
+
|
|
400
|
+
This catches infeasible designs on the custom-DGP path where the
|
|
401
|
+
pre-generation check (which uses ``n_units * treatment_fraction``)
|
|
402
|
+
cannot run because treatment allocation is determined by the DGP.
|
|
403
|
+
"""
|
|
404
|
+
vm = getattr(estimator, "variance_method", "placebo")
|
|
405
|
+
if vm != "placebo":
|
|
406
|
+
return
|
|
407
|
+
|
|
408
|
+
treat_col = est_kwargs.get("treatment", "treat")
|
|
409
|
+
unit_col = est_kwargs.get("unit", "unit")
|
|
410
|
+
|
|
411
|
+
if treat_col not in data.columns or unit_col not in data.columns:
|
|
412
|
+
return # fit will fail with a more specific error
|
|
413
|
+
|
|
414
|
+
unit_treat = data.groupby(unit_col)[treat_col].first()
|
|
415
|
+
n_treated = int(unit_treat.sum())
|
|
416
|
+
n_control = len(unit_treat) - n_treated
|
|
417
|
+
|
|
418
|
+
if n_control <= n_treated:
|
|
419
|
+
raise ValueError(
|
|
420
|
+
f"SyntheticDiD placebo variance requires more control than "
|
|
421
|
+
f"treated units, but the generated data has n_control={n_control}, "
|
|
422
|
+
f"n_treated={n_treated}. Either adjust your data_generator so that "
|
|
423
|
+
f"n_control > n_treated, or use "
|
|
424
|
+
f"SyntheticDiD(variance_method='bootstrap')."
|
|
425
|
+
)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# -- Registry construction (deferred to avoid import-time cost) ---------------
|
|
429
|
+
|
|
430
|
+
_ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None


def _get_registry() -> Dict[str, _EstimatorProfile]:
    """Return the estimator-name -> profile mapping, building it on first call.

    The default data generators are imported from ``diff_diff.prep`` inside
    this function, so that import only happens once a simulation needs the
    registry. The built mapping is stored in the module-level
    ``_ESTIMATOR_REGISTRY`` and returned unchanged on later calls.
    """
    global _ESTIMATOR_REGISTRY  # noqa: PLW0603
    if _ESTIMATOR_REGISTRY is None:
        from diff_diff.prep import (
            generate_ddd_data,
            generate_did_data,
            generate_factor_data,
            generate_staggered_data,
        )

        def _make(dgp, dgp_kw, fit_kw, extract, min_n):
            # Positional shorthand for one _EstimatorProfile entry.
            return _EstimatorProfile(
                default_dgp=dgp,
                dgp_kwargs_builder=dgp_kw,
                fit_kwargs_builder=fit_kw,
                result_extractor=extract,
                min_n=min_n,
            )

        # --- Basic DiD group ---
        registry: Dict[str, _EstimatorProfile] = {
            "DifferenceInDifferences": _make(
                generate_did_data, _basic_dgp_kwargs, _basic_fit_kwargs, _extract_simple, 20
            ),
            "TwoWayFixedEffects": _make(
                generate_did_data, _basic_dgp_kwargs, _twfe_fit_kwargs, _extract_simple, 20
            ),
            "MultiPeriodDiD": _make(
                generate_did_data,
                _basic_dgp_kwargs,
                _multiperiod_fit_kwargs,
                _extract_multiperiod,
                20,
            ),
        }

        # --- Staggered group: identical settings, one profile instance each ---
        for est_name in (
            "CallawaySantAnna",
            "SunAbraham",
            "ImputationDiD",
            "TwoStageDiD",
            "StackedDiD",
            "EfficientDiD",
        ):
            registry[est_name] = _make(
                generate_staggered_data,
                _staggered_dgp_kwargs,
                _staggered_fit_kwargs,
                _extract_staggered,
                40,
            )

        # --- Factor model group ---
        registry["TROP"] = _make(
            generate_factor_data, _factor_dgp_kwargs, _trop_fit_kwargs, _extract_simple, 30
        )
        registry["SyntheticDiD"] = _make(
            generate_factor_data, _factor_dgp_kwargs, _sdid_fit_kwargs, _extract_simple, 30
        )

        # --- Triple difference ---
        registry["TripleDifference"] = _make(
            generate_ddd_data, _ddd_dgp_kwargs, _ddd_fit_kwargs, _extract_simple, 64
        )

        _ESTIMATOR_REGISTRY = registry
    return _ESTIMATOR_REGISTRY
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
@dataclass
class PowerResults:
    """
    Outcome of an analytical power analysis.

    Attributes
    ----------
    power : float
        Probability of rejecting H0 when the effect is present.
    mde : float
        Minimum detectable effect size.
    required_n : int
        Total sample size needed (treated plus control).
    effect_size : float
        Effect size the calculation was run for.
    alpha : float
        Significance level.
    alternative : str
        Alternative hypothesis: 'two-sided', 'greater' or 'less'.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_pre : int
        Number of pre-treatment periods.
    n_post : int
        Number of post-treatment periods.
    sigma : float
        Residual standard deviation.
    rho : float
        Intra-cluster correlation for panel data (default 0.0).
    design : str
        Study design label such as 'basic_did', 'panel' or 'staggered'
        (default 'basic_did').
    """

    power: float
    mde: float
    required_n: int
    effect_size: float
    alpha: float
    alternative: str
    n_treated: int
    n_control: int
    n_pre: int
    n_post: int
    sigma: float
    rho: float = 0.0
    design: str = "basic_did"

    def __repr__(self) -> str:
        """One-line view showing power, MDE and required sample size."""
        shown = f"power={self.power:.3f}, mde={self.mde:.4f}, required_n={self.required_n}"
        return f"PowerResults({shown})"

    def summary(self) -> str:
        """
        Render the power analysis as a fixed-width text report.

        Returns
        -------
        str
            Multi-line summary, 60 characters wide.
        """
        width = 60
        heavy, light = "=" * width, "-" * width

        def row(label: str, value: str) -> str:
            # Label left-justified in 30 columns, then a single space.
            return f"{label:<30} {value}"

        def section(title: str) -> List[str]:
            # Blank spacer followed by a dashed, centred section heading.
            return ["", light, title.center(width), light]

        out: List[str] = [
            heavy,
            "Power Analysis for Difference-in-Differences".center(width),
            heavy,
            "",
            row("Design:", f"{self.design}"),
            row("Significance level (alpha):", f"{self.alpha:.3f}"),
            row("Alternative hypothesis:", f"{self.alternative}"),
        ]
        out += section("Sample Size")
        out += [
            row("Treated units:", f"{self.n_treated:>10}"),
            row("Control units:", f"{self.n_control:>10}"),
            row("Total units:", f"{self.n_treated + self.n_control:>10}"),
            row("Pre-treatment periods:", f"{self.n_pre:>10}"),
            row("Post-treatment periods:", f"{self.n_post:>10}"),
        ]
        out += section("Variance Parameters")
        out += [
            row("Residual SD (sigma):", f"{self.sigma:>10.4f}"),
            row("Intra-cluster correlation:", f"{self.rho:>10.4f}"),
        ]
        out += section("Power Analysis Results")
        out += [
            row("Effect size:", f"{self.effect_size:>10.4f}"),
            row("Power:", f"{self.power:>10.1%}"),
            row("Minimum detectable effect:", f"{self.mde:>10.4f}"),
            row("Required sample size:", f"{self.required_n:>10}"),
            heavy,
        ]
        return "\n".join(out)

    def print_summary(self) -> None:
        """Write the text report from ``summary()`` to stdout."""
        text = self.summary()
        print(text)

    def to_dict(self) -> Dict[str, Any]:
        """
        Export every result field as a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            Field name -> value, in declaration order.
        """
        keys = (
            "power",
            "mde",
            "required_n",
            "effect_size",
            "alpha",
            "alternative",
            "n_treated",
            "n_control",
            "n_pre",
            "n_post",
            "sigma",
            "rho",
            "design",
        )
        return {key: getattr(self, key) for key in keys}

    def to_dataframe(self) -> pd.DataFrame:
        """
        Export the results as a single-row DataFrame.

        Returns
        -------
        pd.DataFrame
            One row whose columns are the keys of ``to_dict()``.
        """
        record = self.to_dict()
        return pd.DataFrame([record])
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
@dataclass
|
|
681
|
+
class SimulationPowerResults:
|
|
682
|
+
"""
|
|
683
|
+
Results from simulation-based power analysis.
|
|
684
|
+
|
|
685
|
+
Attributes
|
|
686
|
+
----------
|
|
687
|
+
power : float
|
|
688
|
+
Estimated power (proportion of simulations rejecting H0).
|
|
689
|
+
power_se : float
|
|
690
|
+
Standard error of power estimate.
|
|
691
|
+
power_ci : Tuple[float, float]
|
|
692
|
+
Confidence interval for power estimate.
|
|
693
|
+
rejection_rate : float
|
|
694
|
+
Proportion of simulations with p-value < alpha.
|
|
695
|
+
mean_estimate : float
|
|
696
|
+
Mean treatment effect estimate across simulations.
|
|
697
|
+
std_estimate : float
|
|
698
|
+
Standard deviation of estimates across simulations.
|
|
699
|
+
mean_se : float
|
|
700
|
+
Mean standard error across simulations.
|
|
701
|
+
coverage : float
|
|
702
|
+
Proportion of CIs containing true effect.
|
|
703
|
+
n_simulations : int
|
|
704
|
+
Number of simulations performed.
|
|
705
|
+
effect_sizes : List[float]
|
|
706
|
+
Effect sizes tested (if multiple).
|
|
707
|
+
powers : List[float]
|
|
708
|
+
Power at each effect size (if multiple).
|
|
709
|
+
true_effect : float
|
|
710
|
+
True treatment effect used in simulation.
|
|
711
|
+
alpha : float
|
|
712
|
+
Significance level.
|
|
713
|
+
estimator_name : str
|
|
714
|
+
Name of the estimator used.
|
|
715
|
+
effective_n_units : int or None
|
|
716
|
+
Effective sample size when it differs from the requested ``n_units``
|
|
717
|
+
(e.g., due to DDD grid rounding). ``None`` when no rounding occurred.
|
|
718
|
+
"""
|
|
719
|
+
|
|
720
|
+
power: float
|
|
721
|
+
power_se: float
|
|
722
|
+
power_ci: Tuple[float, float]
|
|
723
|
+
rejection_rate: float
|
|
724
|
+
mean_estimate: float
|
|
725
|
+
std_estimate: float
|
|
726
|
+
mean_se: float
|
|
727
|
+
coverage: float
|
|
728
|
+
n_simulations: int
|
|
729
|
+
effect_sizes: List[float]
|
|
730
|
+
powers: List[float]
|
|
731
|
+
true_effect: float
|
|
732
|
+
alpha: float
|
|
733
|
+
estimator_name: str
|
|
734
|
+
bias: float = field(init=False)
|
|
735
|
+
rmse: float = field(init=False)
|
|
736
|
+
simulation_results: Optional[List[Dict[str, Any]]] = field(default=None, repr=False)
|
|
737
|
+
effective_n_units: Optional[int] = None
|
|
738
|
+
|
|
739
|
+
def __post_init__(self):
|
|
740
|
+
"""Compute derived statistics."""
|
|
741
|
+
self.bias = self.mean_estimate - self.true_effect
|
|
742
|
+
self.rmse = np.sqrt(self.bias**2 + self.std_estimate**2)
|
|
743
|
+
|
|
744
|
+
def __repr__(self) -> str:
|
|
745
|
+
"""Concise string representation."""
|
|
746
|
+
return (
|
|
747
|
+
f"SimulationPowerResults(power={self.power:.3f} "
|
|
748
|
+
f"[{self.power_ci[0]:.3f}, {self.power_ci[1]:.3f}], "
|
|
749
|
+
f"n_simulations={self.n_simulations})"
|
|
750
|
+
)
|
|
751
|
+
|
|
752
|
+
def summary(self) -> str:
|
|
753
|
+
"""
|
|
754
|
+
Generate a formatted summary of simulation power results.
|
|
755
|
+
|
|
756
|
+
Returns
|
|
757
|
+
-------
|
|
758
|
+
str
|
|
759
|
+
Formatted summary table.
|
|
760
|
+
"""
|
|
761
|
+
lines = [
|
|
762
|
+
"=" * 65,
|
|
763
|
+
"Simulation-Based Power Analysis Results".center(65),
|
|
764
|
+
"=" * 65,
|
|
765
|
+
"",
|
|
766
|
+
f"{'Estimator:':<35} {self.estimator_name}",
|
|
767
|
+
f"{'Number of simulations:':<35} {self.n_simulations}",
|
|
768
|
+
f"{'True treatment effect:':<35} {self.true_effect:.4f}",
|
|
769
|
+
f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
|
|
770
|
+
"",
|
|
771
|
+
"-" * 65,
|
|
772
|
+
"Power Estimates".center(65),
|
|
773
|
+
"-" * 65,
|
|
774
|
+
f"{'Power (rejection rate):':<35} {self.power:.1%}",
|
|
775
|
+
f"{'Standard error:':<35} {self.power_se:.4f}",
|
|
776
|
+
f"{'95% CI:':<35} [{self.power_ci[0]:.3f}, {self.power_ci[1]:.3f}]",
|
|
777
|
+
"",
|
|
778
|
+
"-" * 65,
|
|
779
|
+
"Estimation Performance".center(65),
|
|
780
|
+
"-" * 65,
|
|
781
|
+
f"{'Mean estimate:':<35} {self.mean_estimate:.4f}",
|
|
782
|
+
f"{'Bias:':<35} {self.bias:.4f}",
|
|
783
|
+
f"{'Std. deviation of estimates:':<35} {self.std_estimate:.4f}",
|
|
784
|
+
f"{'RMSE:':<35} {self.rmse:.4f}",
|
|
785
|
+
f"{'Mean standard error:':<35} {self.mean_se:.4f}",
|
|
786
|
+
f"{'Coverage (CI contains true):':<35} {self.coverage:.1%}",
|
|
787
|
+
]
|
|
788
|
+
if self.effective_n_units is not None:
|
|
789
|
+
lines.append(
|
|
790
|
+
f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)"
|
|
791
|
+
)
|
|
792
|
+
lines.append("=" * 65)
|
|
793
|
+
return "\n".join(lines)
|
|
794
|
+
|
|
795
|
+
def print_summary(self) -> None:
|
|
796
|
+
"""Print the summary to stdout."""
|
|
797
|
+
print(self.summary())
|
|
798
|
+
|
|
799
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
800
|
+
"""
|
|
801
|
+
Convert results to a dictionary.
|
|
802
|
+
|
|
803
|
+
Returns
|
|
804
|
+
-------
|
|
805
|
+
Dict[str, Any]
|
|
806
|
+
Dictionary containing simulation power results.
|
|
807
|
+
"""
|
|
808
|
+
d: Dict[str, Any] = {
|
|
809
|
+
"power": self.power,
|
|
810
|
+
"power_se": self.power_se,
|
|
811
|
+
"power_ci_lower": self.power_ci[0],
|
|
812
|
+
"power_ci_upper": self.power_ci[1],
|
|
813
|
+
"rejection_rate": self.rejection_rate,
|
|
814
|
+
"mean_estimate": self.mean_estimate,
|
|
815
|
+
"std_estimate": self.std_estimate,
|
|
816
|
+
"bias": self.bias,
|
|
817
|
+
"rmse": self.rmse,
|
|
818
|
+
"mean_se": self.mean_se,
|
|
819
|
+
"coverage": self.coverage,
|
|
820
|
+
"n_simulations": self.n_simulations,
|
|
821
|
+
"true_effect": self.true_effect,
|
|
822
|
+
"alpha": self.alpha,
|
|
823
|
+
"estimator_name": self.estimator_name,
|
|
824
|
+
"effective_n_units": self.effective_n_units,
|
|
825
|
+
}
|
|
826
|
+
return d
|
|
827
|
+
|
|
828
|
+
def to_dataframe(self) -> pd.DataFrame:
|
|
829
|
+
"""
|
|
830
|
+
Convert results to a pandas DataFrame.
|
|
831
|
+
|
|
832
|
+
Returns
|
|
833
|
+
-------
|
|
834
|
+
pd.DataFrame
|
|
835
|
+
DataFrame with simulation power results.
|
|
836
|
+
"""
|
|
837
|
+
return pd.DataFrame([self.to_dict()])
|
|
838
|
+
|
|
839
|
+
def power_curve_df(self) -> pd.DataFrame:
|
|
840
|
+
"""
|
|
841
|
+
Get power curve data as a DataFrame.
|
|
842
|
+
|
|
843
|
+
Returns
|
|
844
|
+
-------
|
|
845
|
+
pd.DataFrame
|
|
846
|
+
DataFrame with effect_size and power columns.
|
|
847
|
+
"""
|
|
848
|
+
return pd.DataFrame({"effect_size": self.effect_sizes, "power": self.powers})
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
class PowerAnalysis:
|
|
852
|
+
"""
|
|
853
|
+
Power analysis for difference-in-differences designs.
|
|
854
|
+
|
|
855
|
+
Provides analytical power calculations for basic 2x2 DiD and panel DiD
|
|
856
|
+
designs. For complex designs like staggered adoption, use simulate_power()
|
|
857
|
+
instead.
|
|
858
|
+
|
|
859
|
+
Parameters
|
|
860
|
+
----------
|
|
861
|
+
alpha : float, default=0.05
|
|
862
|
+
Significance level for hypothesis testing.
|
|
863
|
+
power : float, default=0.80
|
|
864
|
+
Target statistical power.
|
|
865
|
+
alternative : str, default='two-sided'
|
|
866
|
+
Alternative hypothesis: 'two-sided', 'greater', or 'less'.
|
|
867
|
+
|
|
868
|
+
Examples
|
|
869
|
+
--------
|
|
870
|
+
Calculate minimum detectable effect:
|
|
871
|
+
|
|
872
|
+
>>> from diff_diff import PowerAnalysis
|
|
873
|
+
>>> pa = PowerAnalysis(alpha=0.05, power=0.80)
|
|
874
|
+
>>> results = pa.mde(n_treated=50, n_control=50, sigma=1.0)
|
|
875
|
+
>>> print(f"MDE: {results.mde:.3f}")
|
|
876
|
+
|
|
877
|
+
Calculate required sample size:
|
|
878
|
+
|
|
879
|
+
>>> results = pa.sample_size(effect_size=0.5, sigma=1.0)
|
|
880
|
+
>>> print(f"Required N: {results.required_n}")
|
|
881
|
+
|
|
882
|
+
Calculate power for given sample and effect:
|
|
883
|
+
|
|
884
|
+
>>> results = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0)
|
|
885
|
+
>>> print(f"Power: {results.power:.1%}")
|
|
886
|
+
|
|
887
|
+
Notes
|
|
888
|
+
-----
|
|
889
|
+
The power calculations are based on the variance of the DiD estimator:
|
|
890
|
+
|
|
891
|
+
For basic 2x2 DiD:
|
|
892
|
+
Var(ATT) = sigma^2 * (1/n_treated_post + 1/n_treated_pre
|
|
893
|
+
+ 1/n_control_post + 1/n_control_pre)
|
|
894
|
+
|
|
895
|
+
For panel DiD with T periods:
|
|
896
|
+
Var(ATT) = sigma^2 * (1/(N_treated * T) + 1/(N_control * T))
|
|
897
|
+
* (1 + (T-1)*rho) / (1 + (T-1)*rho)
|
|
898
|
+
|
|
899
|
+
Where rho is the intra-cluster correlation coefficient.
|
|
900
|
+
|
|
901
|
+
References
|
|
902
|
+
----------
|
|
903
|
+
Bloom, H. S. (1995). "Minimum Detectable Effects."
|
|
904
|
+
Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
|
|
905
|
+
"""
|
|
906
|
+
|
|
907
|
+
def __init__(
|
|
908
|
+
self,
|
|
909
|
+
alpha: float = 0.05,
|
|
910
|
+
power: float = 0.80,
|
|
911
|
+
alternative: str = "two-sided",
|
|
912
|
+
):
|
|
913
|
+
if not 0 < alpha < 1:
|
|
914
|
+
raise ValueError("alpha must be between 0 and 1")
|
|
915
|
+
if not 0 < power < 1:
|
|
916
|
+
raise ValueError("power must be between 0 and 1")
|
|
917
|
+
if alternative not in ("two-sided", "greater", "less"):
|
|
918
|
+
raise ValueError("alternative must be 'two-sided', 'greater', or 'less'")
|
|
919
|
+
|
|
920
|
+
self.alpha = alpha
|
|
921
|
+
self.target_power = power
|
|
922
|
+
self.alternative = alternative
|
|
923
|
+
|
|
924
|
+
def _get_critical_values(self) -> Tuple[float, float]:
|
|
925
|
+
"""Get z critical values for alpha and power."""
|
|
926
|
+
if self.alternative == "two-sided":
|
|
927
|
+
z_alpha = stats.norm.ppf(1 - self.alpha / 2)
|
|
928
|
+
else:
|
|
929
|
+
z_alpha = stats.norm.ppf(1 - self.alpha)
|
|
930
|
+
z_beta = stats.norm.ppf(self.target_power)
|
|
931
|
+
return z_alpha, z_beta
|
|
932
|
+
|
|
933
|
+
def _compute_variance(
|
|
934
|
+
self,
|
|
935
|
+
n_treated: int,
|
|
936
|
+
n_control: int,
|
|
937
|
+
n_pre: int,
|
|
938
|
+
n_post: int,
|
|
939
|
+
sigma: float,
|
|
940
|
+
rho: float = 0.0,
|
|
941
|
+
design: str = "basic_did",
|
|
942
|
+
) -> float:
|
|
943
|
+
"""
|
|
944
|
+
Compute variance of the DiD estimator.
|
|
945
|
+
|
|
946
|
+
Parameters
|
|
947
|
+
----------
|
|
948
|
+
n_treated : int
|
|
949
|
+
Number of treated units.
|
|
950
|
+
n_control : int
|
|
951
|
+
Number of control units.
|
|
952
|
+
n_pre : int
|
|
953
|
+
Number of pre-treatment periods.
|
|
954
|
+
n_post : int
|
|
955
|
+
Number of post-treatment periods.
|
|
956
|
+
sigma : float
|
|
957
|
+
Residual standard deviation.
|
|
958
|
+
rho : float
|
|
959
|
+
Intra-cluster correlation (for panel data).
|
|
960
|
+
design : str
|
|
961
|
+
Study design type.
|
|
962
|
+
|
|
963
|
+
Returns
|
|
964
|
+
-------
|
|
965
|
+
float
|
|
966
|
+
Variance of the DiD estimator.
|
|
967
|
+
"""
|
|
968
|
+
if design == "basic_did":
|
|
969
|
+
# For basic 2x2 DiD, each cell has n_treated/2 or n_control/2 obs
|
|
970
|
+
# assuming balanced design
|
|
971
|
+
n_t_pre = n_treated # treated units in pre-period
|
|
972
|
+
n_t_post = n_treated # treated units in post-period
|
|
973
|
+
n_c_pre = n_control
|
|
974
|
+
n_c_post = n_control
|
|
975
|
+
|
|
976
|
+
variance = sigma**2 * (1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre)
|
|
977
|
+
elif design == "panel":
|
|
978
|
+
# Panel DiD with multiple periods
|
|
979
|
+
# Account for serial correlation via ICC
|
|
980
|
+
T = n_pre + n_post
|
|
981
|
+
|
|
982
|
+
# Design effect for clustering
|
|
983
|
+
design_effect = 1 + (T - 1) * rho
|
|
984
|
+
|
|
985
|
+
# Base variance (as if independent)
|
|
986
|
+
base_var = sigma**2 * (1 / n_treated + 1 / n_control)
|
|
987
|
+
|
|
988
|
+
# Adjust for clustering (Moulton factor)
|
|
989
|
+
variance = base_var * design_effect / T
|
|
990
|
+
else:
|
|
991
|
+
raise ValueError(f"Unknown design: {design}")
|
|
992
|
+
|
|
993
|
+
return variance
|
|
994
|
+
|
|
995
|
+
def power(
|
|
996
|
+
self,
|
|
997
|
+
effect_size: float,
|
|
998
|
+
n_treated: int,
|
|
999
|
+
n_control: int,
|
|
1000
|
+
sigma: float,
|
|
1001
|
+
n_pre: int = 1,
|
|
1002
|
+
n_post: int = 1,
|
|
1003
|
+
rho: float = 0.0,
|
|
1004
|
+
) -> PowerResults:
|
|
1005
|
+
"""
|
|
1006
|
+
Calculate statistical power for given effect size and sample.
|
|
1007
|
+
|
|
1008
|
+
Parameters
|
|
1009
|
+
----------
|
|
1010
|
+
effect_size : float
|
|
1011
|
+
Expected treatment effect size.
|
|
1012
|
+
n_treated : int
|
|
1013
|
+
Number of treated units.
|
|
1014
|
+
n_control : int
|
|
1015
|
+
Number of control units.
|
|
1016
|
+
sigma : float
|
|
1017
|
+
Residual standard deviation.
|
|
1018
|
+
n_pre : int, default=1
|
|
1019
|
+
Number of pre-treatment periods.
|
|
1020
|
+
n_post : int, default=1
|
|
1021
|
+
Number of post-treatment periods.
|
|
1022
|
+
rho : float, default=0.0
|
|
1023
|
+
Intra-cluster correlation for panel data.
|
|
1024
|
+
|
|
1025
|
+
Returns
|
|
1026
|
+
-------
|
|
1027
|
+
PowerResults
|
|
1028
|
+
Power analysis results.
|
|
1029
|
+
|
|
1030
|
+
Examples
|
|
1031
|
+
--------
|
|
1032
|
+
>>> pa = PowerAnalysis()
|
|
1033
|
+
>>> results = pa.power(effect_size=2.0, n_treated=50, n_control=50, sigma=5.0)
|
|
1034
|
+
>>> print(f"Power: {results.power:.1%}")
|
|
1035
|
+
"""
|
|
1036
|
+
T = n_pre + n_post
|
|
1037
|
+
design = "panel" if T > 2 else "basic_did"
|
|
1038
|
+
|
|
1039
|
+
variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design)
|
|
1040
|
+
se = np.sqrt(variance)
|
|
1041
|
+
|
|
1042
|
+
# Calculate power
|
|
1043
|
+
if self.alternative == "two-sided":
|
|
1044
|
+
z_alpha = stats.norm.ppf(1 - self.alpha / 2)
|
|
1045
|
+
# Power = P(reject | effect) = P(|Z| > z_alpha | effect)
|
|
1046
|
+
power_val = (
|
|
1047
|
+
1
|
|
1048
|
+
- stats.norm.cdf(z_alpha - effect_size / se)
|
|
1049
|
+
+ stats.norm.cdf(-z_alpha - effect_size / se)
|
|
1050
|
+
)
|
|
1051
|
+
elif self.alternative == "greater":
|
|
1052
|
+
z_alpha = stats.norm.ppf(1 - self.alpha)
|
|
1053
|
+
power_val = 1 - stats.norm.cdf(z_alpha - effect_size / se)
|
|
1054
|
+
else: # less
|
|
1055
|
+
z_alpha = stats.norm.ppf(1 - self.alpha)
|
|
1056
|
+
power_val = stats.norm.cdf(-z_alpha - effect_size / se)
|
|
1057
|
+
|
|
1058
|
+
# Also compute MDE and required N for reference
|
|
1059
|
+
mde = self._compute_mde_from_se(se)
|
|
1060
|
+
required_n = self._compute_required_n(
|
|
1061
|
+
effect_size, sigma, n_pre, n_post, rho, design, n_treated / (n_treated + n_control)
|
|
1062
|
+
)
|
|
1063
|
+
|
|
1064
|
+
return PowerResults(
|
|
1065
|
+
power=power_val,
|
|
1066
|
+
mde=mde,
|
|
1067
|
+
required_n=required_n,
|
|
1068
|
+
effect_size=effect_size,
|
|
1069
|
+
alpha=self.alpha,
|
|
1070
|
+
alternative=self.alternative,
|
|
1071
|
+
n_treated=n_treated,
|
|
1072
|
+
n_control=n_control,
|
|
1073
|
+
n_pre=n_pre,
|
|
1074
|
+
n_post=n_post,
|
|
1075
|
+
sigma=sigma,
|
|
1076
|
+
rho=rho,
|
|
1077
|
+
design=design,
|
|
1078
|
+
)
|
|
1079
|
+
|
|
1080
|
+
def _compute_mde_from_se(self, se: float) -> float:
|
|
1081
|
+
"""Compute MDE given standard error."""
|
|
1082
|
+
z_alpha, z_beta = self._get_critical_values()
|
|
1083
|
+
return (z_alpha + z_beta) * se
|
|
1084
|
+
|
|
1085
|
+
def mde(
|
|
1086
|
+
self,
|
|
1087
|
+
n_treated: int,
|
|
1088
|
+
n_control: int,
|
|
1089
|
+
sigma: float,
|
|
1090
|
+
n_pre: int = 1,
|
|
1091
|
+
n_post: int = 1,
|
|
1092
|
+
rho: float = 0.0,
|
|
1093
|
+
) -> PowerResults:
|
|
1094
|
+
"""
|
|
1095
|
+
Calculate minimum detectable effect given sample size.
|
|
1096
|
+
|
|
1097
|
+
The MDE is the smallest effect size that can be detected with the
|
|
1098
|
+
specified power and significance level.
|
|
1099
|
+
|
|
1100
|
+
Parameters
|
|
1101
|
+
----------
|
|
1102
|
+
n_treated : int
|
|
1103
|
+
Number of treated units.
|
|
1104
|
+
n_control : int
|
|
1105
|
+
Number of control units.
|
|
1106
|
+
sigma : float
|
|
1107
|
+
Residual standard deviation.
|
|
1108
|
+
n_pre : int, default=1
|
|
1109
|
+
Number of pre-treatment periods.
|
|
1110
|
+
n_post : int, default=1
|
|
1111
|
+
Number of post-treatment periods.
|
|
1112
|
+
rho : float, default=0.0
|
|
1113
|
+
Intra-cluster correlation for panel data.
|
|
1114
|
+
|
|
1115
|
+
Returns
|
|
1116
|
+
-------
|
|
1117
|
+
PowerResults
|
|
1118
|
+
Power analysis results including MDE.
|
|
1119
|
+
|
|
1120
|
+
Examples
|
|
1121
|
+
--------
|
|
1122
|
+
>>> pa = PowerAnalysis(power=0.80)
|
|
1123
|
+
>>> results = pa.mde(n_treated=100, n_control=100, sigma=10.0)
|
|
1124
|
+
>>> print(f"MDE: {results.mde:.2f}")
|
|
1125
|
+
"""
|
|
1126
|
+
T = n_pre + n_post
|
|
1127
|
+
design = "panel" if T > 2 else "basic_did"
|
|
1128
|
+
|
|
1129
|
+
variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design)
|
|
1130
|
+
se = np.sqrt(variance)
|
|
1131
|
+
|
|
1132
|
+
mde = self._compute_mde_from_se(se)
|
|
1133
|
+
|
|
1134
|
+
return PowerResults(
|
|
1135
|
+
power=self.target_power,
|
|
1136
|
+
mde=mde,
|
|
1137
|
+
required_n=n_treated + n_control,
|
|
1138
|
+
effect_size=mde,
|
|
1139
|
+
alpha=self.alpha,
|
|
1140
|
+
alternative=self.alternative,
|
|
1141
|
+
n_treated=n_treated,
|
|
1142
|
+
n_control=n_control,
|
|
1143
|
+
n_pre=n_pre,
|
|
1144
|
+
n_post=n_post,
|
|
1145
|
+
sigma=sigma,
|
|
1146
|
+
rho=rho,
|
|
1147
|
+
design=design,
|
|
1148
|
+
)
|
|
1149
|
+
|
|
1150
|
+
def _compute_required_n(
|
|
1151
|
+
self,
|
|
1152
|
+
effect_size: float,
|
|
1153
|
+
sigma: float,
|
|
1154
|
+
n_pre: int,
|
|
1155
|
+
n_post: int,
|
|
1156
|
+
rho: float,
|
|
1157
|
+
design: str,
|
|
1158
|
+
treat_frac: float = 0.5,
|
|
1159
|
+
) -> int:
|
|
1160
|
+
"""Compute required sample size for given effect."""
|
|
1161
|
+
# Handle edge case of zero effect size
|
|
1162
|
+
if effect_size == 0:
|
|
1163
|
+
return MAX_SAMPLE_SIZE # Can't detect zero effect
|
|
1164
|
+
|
|
1165
|
+
z_alpha, z_beta = self._get_critical_values()
|
|
1166
|
+
|
|
1167
|
+
T = n_pre + n_post
|
|
1168
|
+
|
|
1169
|
+
if design == "basic_did":
|
|
1170
|
+
# Var = sigma^2 * (1/n_t + 1/n_t + 1/n_c + 1/n_c) = sigma^2 * (2/n_t + 2/n_c)
|
|
1171
|
+
# For balanced: Var = sigma^2 * 4/n where n = n_t = n_c
|
|
1172
|
+
# SE = sqrt(Var), effect_size = (z_alpha + z_beta) * SE
|
|
1173
|
+
# n = 4 * sigma^2 * (z_alpha + z_beta)^2 / effect_size^2
|
|
1174
|
+
|
|
1175
|
+
# For general allocation with treat_frac:
|
|
1176
|
+
# Var = sigma^2 * 2 * (1/(N*p) + 1/(N*(1-p)))
|
|
1177
|
+
# = 2 * sigma^2 / N * (1/p + 1/(1-p))
|
|
1178
|
+
# = 2 * sigma^2 / N * (1/(p*(1-p)))
|
|
1179
|
+
|
|
1180
|
+
n_total = (
|
|
1181
|
+
2
|
|
1182
|
+
* sigma**2
|
|
1183
|
+
* (z_alpha + z_beta) ** 2
|
|
1184
|
+
/ (effect_size**2 * treat_frac * (1 - treat_frac))
|
|
1185
|
+
)
|
|
1186
|
+
else: # panel
|
|
1187
|
+
design_effect = 1 + (T - 1) * rho
|
|
1188
|
+
|
|
1189
|
+
# Var = sigma^2 * (1/n_t + 1/n_c) * design_effect / T
|
|
1190
|
+
# For balanced: Var = 2 * sigma^2 / N * design_effect / T
|
|
1191
|
+
|
|
1192
|
+
n_total = (
|
|
1193
|
+
2
|
|
1194
|
+
* sigma**2
|
|
1195
|
+
* (z_alpha + z_beta) ** 2
|
|
1196
|
+
* design_effect
|
|
1197
|
+
/ (effect_size**2 * treat_frac * (1 - treat_frac) * T)
|
|
1198
|
+
)
|
|
1199
|
+
|
|
1200
|
+
# Handle infinity case (extremely small effect)
|
|
1201
|
+
if np.isinf(n_total):
|
|
1202
|
+
return MAX_SAMPLE_SIZE
|
|
1203
|
+
|
|
1204
|
+
return max(4, int(np.ceil(n_total))) # At least 4 units
|
|
1205
|
+
|
|
1206
|
+
def sample_size(
|
|
1207
|
+
self,
|
|
1208
|
+
effect_size: float,
|
|
1209
|
+
sigma: float,
|
|
1210
|
+
n_pre: int = 1,
|
|
1211
|
+
n_post: int = 1,
|
|
1212
|
+
rho: float = 0.0,
|
|
1213
|
+
treat_frac: float = 0.5,
|
|
1214
|
+
) -> PowerResults:
|
|
1215
|
+
"""
|
|
1216
|
+
Calculate required sample size to detect given effect.
|
|
1217
|
+
|
|
1218
|
+
Parameters
|
|
1219
|
+
----------
|
|
1220
|
+
effect_size : float
|
|
1221
|
+
Treatment effect to detect.
|
|
1222
|
+
sigma : float
|
|
1223
|
+
Residual standard deviation.
|
|
1224
|
+
n_pre : int, default=1
|
|
1225
|
+
Number of pre-treatment periods.
|
|
1226
|
+
n_post : int, default=1
|
|
1227
|
+
Number of post-treatment periods.
|
|
1228
|
+
rho : float, default=0.0
|
|
1229
|
+
Intra-cluster correlation for panel data.
|
|
1230
|
+
treat_frac : float, default=0.5
|
|
1231
|
+
Fraction of units assigned to treatment.
|
|
1232
|
+
|
|
1233
|
+
Returns
|
|
1234
|
+
-------
|
|
1235
|
+
PowerResults
|
|
1236
|
+
Power analysis results including required sample size.
|
|
1237
|
+
|
|
1238
|
+
Examples
|
|
1239
|
+
--------
|
|
1240
|
+
>>> pa = PowerAnalysis(power=0.80)
|
|
1241
|
+
>>> results = pa.sample_size(effect_size=5.0, sigma=10.0)
|
|
1242
|
+
>>> print(f"Required N: {results.required_n}")
|
|
1243
|
+
"""
|
|
1244
|
+
T = n_pre + n_post
|
|
1245
|
+
design = "panel" if T > 2 else "basic_did"
|
|
1246
|
+
|
|
1247
|
+
n_total = self._compute_required_n(
|
|
1248
|
+
effect_size, sigma, n_pre, n_post, rho, design, treat_frac
|
|
1249
|
+
)
|
|
1250
|
+
|
|
1251
|
+
n_treated = max(2, int(np.ceil(n_total * treat_frac)))
|
|
1252
|
+
n_control = max(2, n_total - n_treated)
|
|
1253
|
+
n_total = n_treated + n_control
|
|
1254
|
+
|
|
1255
|
+
# Compute actual power achieved
|
|
1256
|
+
variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design)
|
|
1257
|
+
se = np.sqrt(variance)
|
|
1258
|
+
mde = self._compute_mde_from_se(se)
|
|
1259
|
+
|
|
1260
|
+
return PowerResults(
|
|
1261
|
+
power=self.target_power,
|
|
1262
|
+
mde=mde,
|
|
1263
|
+
required_n=n_total,
|
|
1264
|
+
effect_size=effect_size,
|
|
1265
|
+
alpha=self.alpha,
|
|
1266
|
+
alternative=self.alternative,
|
|
1267
|
+
n_treated=n_treated,
|
|
1268
|
+
n_control=n_control,
|
|
1269
|
+
n_pre=n_pre,
|
|
1270
|
+
n_post=n_post,
|
|
1271
|
+
sigma=sigma,
|
|
1272
|
+
rho=rho,
|
|
1273
|
+
design=design,
|
|
1274
|
+
)
|
|
1275
|
+
|
|
1276
|
+
def power_curve(
|
|
1277
|
+
self,
|
|
1278
|
+
n_treated: int,
|
|
1279
|
+
n_control: int,
|
|
1280
|
+
sigma: float,
|
|
1281
|
+
effect_sizes: Optional[List[float]] = None,
|
|
1282
|
+
n_pre: int = 1,
|
|
1283
|
+
n_post: int = 1,
|
|
1284
|
+
rho: float = 0.0,
|
|
1285
|
+
) -> pd.DataFrame:
|
|
1286
|
+
"""
|
|
1287
|
+
Compute power for a range of effect sizes.
|
|
1288
|
+
|
|
1289
|
+
Parameters
|
|
1290
|
+
----------
|
|
1291
|
+
n_treated : int
|
|
1292
|
+
Number of treated units.
|
|
1293
|
+
n_control : int
|
|
1294
|
+
Number of control units.
|
|
1295
|
+
sigma : float
|
|
1296
|
+
Residual standard deviation.
|
|
1297
|
+
effect_sizes : list of float, optional
|
|
1298
|
+
Effect sizes to evaluate. If None, uses a range from 0 to 3*MDE.
|
|
1299
|
+
n_pre : int, default=1
|
|
1300
|
+
Number of pre-treatment periods.
|
|
1301
|
+
n_post : int, default=1
|
|
1302
|
+
Number of post-treatment periods.
|
|
1303
|
+
rho : float, default=0.0
|
|
1304
|
+
Intra-cluster correlation.
|
|
1305
|
+
|
|
1306
|
+
Returns
|
|
1307
|
+
-------
|
|
1308
|
+
pd.DataFrame
|
|
1309
|
+
DataFrame with columns 'effect_size' and 'power'.
|
|
1310
|
+
|
|
1311
|
+
Examples
|
|
1312
|
+
--------
|
|
1313
|
+
>>> pa = PowerAnalysis()
|
|
1314
|
+
>>> curve = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
|
|
1315
|
+
>>> print(curve)
|
|
1316
|
+
"""
|
|
1317
|
+
# First get MDE to determine default range
|
|
1318
|
+
mde_result = self.mde(n_treated, n_control, sigma, n_pre, n_post, rho)
|
|
1319
|
+
|
|
1320
|
+
if effect_sizes is None:
|
|
1321
|
+
# Generate range from 0 to 2*MDE
|
|
1322
|
+
effect_sizes = np.linspace(0, 2.5 * mde_result.mde, 50).tolist()
|
|
1323
|
+
|
|
1324
|
+
powers = []
|
|
1325
|
+
for es in effect_sizes:
|
|
1326
|
+
result = self.power(
|
|
1327
|
+
effect_size=es,
|
|
1328
|
+
n_treated=n_treated,
|
|
1329
|
+
n_control=n_control,
|
|
1330
|
+
sigma=sigma,
|
|
1331
|
+
n_pre=n_pre,
|
|
1332
|
+
n_post=n_post,
|
|
1333
|
+
rho=rho,
|
|
1334
|
+
)
|
|
1335
|
+
powers.append(result.power)
|
|
1336
|
+
|
|
1337
|
+
return pd.DataFrame({"effect_size": effect_sizes, "power": powers})
|
|
1338
|
+
|
|
1339
|
+
def sample_size_curve(
|
|
1340
|
+
self,
|
|
1341
|
+
effect_size: float,
|
|
1342
|
+
sigma: float,
|
|
1343
|
+
sample_sizes: Optional[List[int]] = None,
|
|
1344
|
+
n_pre: int = 1,
|
|
1345
|
+
n_post: int = 1,
|
|
1346
|
+
rho: float = 0.0,
|
|
1347
|
+
treat_frac: float = 0.5,
|
|
1348
|
+
) -> pd.DataFrame:
|
|
1349
|
+
"""
|
|
1350
|
+
Compute power for a range of sample sizes.
|
|
1351
|
+
|
|
1352
|
+
Parameters
|
|
1353
|
+
----------
|
|
1354
|
+
effect_size : float
|
|
1355
|
+
Treatment effect size.
|
|
1356
|
+
sigma : float
|
|
1357
|
+
Residual standard deviation.
|
|
1358
|
+
sample_sizes : list of int, optional
|
|
1359
|
+
Total sample sizes to evaluate. If None, uses sensible range.
|
|
1360
|
+
n_pre : int, default=1
|
|
1361
|
+
Number of pre-treatment periods.
|
|
1362
|
+
n_post : int, default=1
|
|
1363
|
+
Number of post-treatment periods.
|
|
1364
|
+
rho : float, default=0.0
|
|
1365
|
+
Intra-cluster correlation.
|
|
1366
|
+
treat_frac : float, default=0.5
|
|
1367
|
+
Fraction assigned to treatment.
|
|
1368
|
+
|
|
1369
|
+
Returns
|
|
1370
|
+
-------
|
|
1371
|
+
pd.DataFrame
|
|
1372
|
+
DataFrame with columns 'sample_size' and 'power'.
|
|
1373
|
+
"""
|
|
1374
|
+
# Get required N to determine default range
|
|
1375
|
+
required = self.sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac)
|
|
1376
|
+
|
|
1377
|
+
if sample_sizes is None:
|
|
1378
|
+
min_n = max(10, required.required_n // 4)
|
|
1379
|
+
max_n = required.required_n * 2
|
|
1380
|
+
sample_sizes = list(range(min_n, max_n + 1, max(1, (max_n - min_n) // 50)))
|
|
1381
|
+
|
|
1382
|
+
powers = []
|
|
1383
|
+
for n in sample_sizes:
|
|
1384
|
+
n_treated = max(2, int(n * treat_frac))
|
|
1385
|
+
n_control = max(2, n - n_treated)
|
|
1386
|
+
result = self.power(
|
|
1387
|
+
effect_size=effect_size,
|
|
1388
|
+
n_treated=n_treated,
|
|
1389
|
+
n_control=n_control,
|
|
1390
|
+
sigma=sigma,
|
|
1391
|
+
n_pre=n_pre,
|
|
1392
|
+
n_post=n_post,
|
|
1393
|
+
rho=rho,
|
|
1394
|
+
)
|
|
1395
|
+
powers.append(result.power)
|
|
1396
|
+
|
|
1397
|
+
return pd.DataFrame({"sample_size": sample_sizes, "power": powers})
|
|
1398
|
+
|
|
1399
|
+
|
|
1400
|
+
def simulate_power(
|
|
1401
|
+
estimator: Any,
|
|
1402
|
+
n_units: int = 100,
|
|
1403
|
+
n_periods: int = 4,
|
|
1404
|
+
treatment_effect: float = 5.0,
|
|
1405
|
+
treatment_fraction: float = 0.5,
|
|
1406
|
+
treatment_period: int = 2,
|
|
1407
|
+
sigma: float = 1.0,
|
|
1408
|
+
n_simulations: int = 500,
|
|
1409
|
+
alpha: float = 0.05,
|
|
1410
|
+
effect_sizes: Optional[List[float]] = None,
|
|
1411
|
+
seed: Optional[int] = None,
|
|
1412
|
+
data_generator: Optional[Callable] = None,
|
|
1413
|
+
data_generator_kwargs: Optional[Dict[str, Any]] = None,
|
|
1414
|
+
estimator_kwargs: Optional[Dict[str, Any]] = None,
|
|
1415
|
+
result_extractor: Optional[Callable] = None,
|
|
1416
|
+
progress: bool = True,
|
|
1417
|
+
) -> SimulationPowerResults:
|
|
1418
|
+
"""
|
|
1419
|
+
Estimate power using Monte Carlo simulation.
|
|
1420
|
+
|
|
1421
|
+
This function simulates datasets with known treatment effects and estimates
|
|
1422
|
+
power as the fraction of simulations where the null hypothesis is rejected.
|
|
1423
|
+
Most built-in estimators are supported via an internal registry that selects
|
|
1424
|
+
the appropriate data-generating process and fit signature automatically.
|
|
1425
|
+
|
|
1426
|
+
Parameters
|
|
1427
|
+
----------
|
|
1428
|
+
estimator : estimator object
|
|
1429
|
+
DiD estimator to use (e.g., DifferenceInDifferences, CallawaySantAnna).
|
|
1430
|
+
n_units : int, default=100
|
|
1431
|
+
Number of units per simulation.
|
|
1432
|
+
n_periods : int, default=4
|
|
1433
|
+
Number of time periods.
|
|
1434
|
+
treatment_effect : float, default=5.0
|
|
1435
|
+
True treatment effect to simulate.
|
|
1436
|
+
treatment_fraction : float, default=0.5
|
|
1437
|
+
Fraction of units that are treated.
|
|
1438
|
+
treatment_period : int, default=2
|
|
1439
|
+
First post-treatment period (0-indexed).
|
|
1440
|
+
sigma : float, default=1.0
|
|
1441
|
+
Residual standard deviation (noise level).
|
|
1442
|
+
n_simulations : int, default=500
|
|
1443
|
+
Number of Monte Carlo simulations.
|
|
1444
|
+
alpha : float, default=0.05
|
|
1445
|
+
Significance level for hypothesis tests.
|
|
1446
|
+
effect_sizes : list of float, optional
|
|
1447
|
+
Multiple effect sizes to evaluate for power curve.
|
|
1448
|
+
If None, uses only treatment_effect.
|
|
1449
|
+
seed : int, optional
|
|
1450
|
+
Random seed for reproducibility.
|
|
1451
|
+
data_generator : callable, optional
|
|
1452
|
+
Custom data generation function. When provided, bypasses the
|
|
1453
|
+
registry DGP and calls this function with the standard kwargs
|
|
1454
|
+
(n_units, n_periods, treatment_effect, etc.).
|
|
1455
|
+
data_generator_kwargs : dict, optional
|
|
1456
|
+
Additional keyword arguments for data generator.
|
|
1457
|
+
estimator_kwargs : dict, optional
|
|
1458
|
+
Additional keyword arguments for estimator.fit().
|
|
1459
|
+
result_extractor : callable, optional
|
|
1460
|
+
Custom function to extract results from the estimator output.
|
|
1461
|
+
Takes the estimator result object and returns a tuple of
|
|
1462
|
+
``(att, se, p_value, conf_int)``. Useful for unregistered
|
|
1463
|
+
estimators with non-standard result schemas.
|
|
1464
|
+
progress : bool, default=True
|
|
1465
|
+
Whether to print progress updates.
|
|
1466
|
+
|
|
1467
|
+
Returns
|
|
1468
|
+
-------
|
|
1469
|
+
SimulationPowerResults
|
|
1470
|
+
Simulation-based power analysis results.
|
|
1471
|
+
|
|
1472
|
+
Examples
|
|
1473
|
+
--------
|
|
1474
|
+
Basic power simulation:
|
|
1475
|
+
|
|
1476
|
+
>>> from diff_diff import DifferenceInDifferences, simulate_power
|
|
1477
|
+
>>> did = DifferenceInDifferences()
|
|
1478
|
+
>>> results = simulate_power(
|
|
1479
|
+
... estimator=did,
|
|
1480
|
+
... n_units=100,
|
|
1481
|
+
... treatment_effect=5.0,
|
|
1482
|
+
... sigma=5.0,
|
|
1483
|
+
... n_simulations=500,
|
|
1484
|
+
... seed=42
|
|
1485
|
+
... )
|
|
1486
|
+
>>> print(f"Power: {results.power:.1%}")
|
|
1487
|
+
|
|
1488
|
+
Power curve over multiple effect sizes:
|
|
1489
|
+
|
|
1490
|
+
>>> results = simulate_power(
|
|
1491
|
+
... estimator=did,
|
|
1492
|
+
... effect_sizes=[1.0, 2.0, 3.0, 5.0, 7.0],
|
|
1493
|
+
... n_simulations=200,
|
|
1494
|
+
... seed=42
|
|
1495
|
+
... )
|
|
1496
|
+
>>> print(results.power_curve_df())
|
|
1497
|
+
|
|
1498
|
+
With Callaway-Sant'Anna (auto-detected, no custom DGP needed):
|
|
1499
|
+
|
|
1500
|
+
>>> from diff_diff import CallawaySantAnna
|
|
1501
|
+
>>> cs = CallawaySantAnna()
|
|
1502
|
+
>>> results = simulate_power(cs, n_simulations=200, seed=42)
|
|
1503
|
+
|
|
1504
|
+
Notes
|
|
1505
|
+
-----
|
|
1506
|
+
The simulation approach:
|
|
1507
|
+
1. Generate data with known treatment effect
|
|
1508
|
+
2. Fit the estimator and record the p-value
|
|
1509
|
+
3. Repeat n_simulations times
|
|
1510
|
+
4. Power = fraction of simulations where p-value < alpha
|
|
1511
|
+
|
|
1512
|
+
References
|
|
1513
|
+
----------
|
|
1514
|
+
Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
|
|
1515
|
+
"""
|
|
1516
|
+
rng = np.random.default_rng(seed)
|
|
1517
|
+
|
|
1518
|
+
estimator_name = type(estimator).__name__
|
|
1519
|
+
registry = _get_registry()
|
|
1520
|
+
profile = registry.get(estimator_name)
|
|
1521
|
+
|
|
1522
|
+
# If no profile and no custom data_generator, raise
|
|
1523
|
+
if profile is None and data_generator is None:
|
|
1524
|
+
raise ValueError(
|
|
1525
|
+
f"Estimator '{estimator_name}' not in registry. "
|
|
1526
|
+
f"Provide a custom data_generator and estimator_kwargs "
|
|
1527
|
+
f"(the full dict of keyword arguments for estimator.fit(), "
|
|
1528
|
+
f"e.g. dict(outcome='y', treatment='treat', time='period'))."
|
|
1529
|
+
)
|
|
1530
|
+
|
|
1531
|
+
# When a custom data_generator is provided, bypass registry DGP
|
|
1532
|
+
use_custom_dgp = data_generator is not None
|
|
1533
|
+
|
|
1534
|
+
data_gen_kwargs = data_generator_kwargs or {}
|
|
1535
|
+
est_kwargs = estimator_kwargs or {}
|
|
1536
|
+
|
|
1537
|
+
# SyntheticDiD placebo variance requires n_control > n_treated.
|
|
1538
|
+
# Check after merging data_generator_kwargs so overrides of n_treated
|
|
1539
|
+
# are accounted for.
|
|
1540
|
+
if estimator_name == "SyntheticDiD" and not use_custom_dgp:
|
|
1541
|
+
vm = getattr(estimator, "variance_method", "placebo")
|
|
1542
|
+
effective_n_treated = data_gen_kwargs.get(
|
|
1543
|
+
"n_treated", max(1, int(n_units * treatment_fraction))
|
|
1544
|
+
)
|
|
1545
|
+
n_control = n_units - effective_n_treated
|
|
1546
|
+
if vm == "placebo" and n_control <= effective_n_treated:
|
|
1547
|
+
raise ValueError(
|
|
1548
|
+
f"SyntheticDiD placebo variance requires more control than "
|
|
1549
|
+
f"treated units (got n_control={n_control}, "
|
|
1550
|
+
f"n_treated={effective_n_treated}). Either lower "
|
|
1551
|
+
f"treatment_fraction so that n_control > n_treated, or use "
|
|
1552
|
+
f"SyntheticDiD(variance_method='bootstrap')."
|
|
1553
|
+
)
|
|
1554
|
+
|
|
1555
|
+
# Warn if staggered estimator settings don't match auto DGP
|
|
1556
|
+
if profile is not None and not use_custom_dgp:
|
|
1557
|
+
_check_staggered_dgp_compat(estimator, data_generator_kwargs)
|
|
1558
|
+
|
|
1559
|
+
# Block registry-path collisions on search-critical keys
|
|
1560
|
+
if profile is not None and not use_custom_dgp and data_gen_kwargs:
|
|
1561
|
+
sample_dgp_keys = set(
|
|
1562
|
+
profile.dgp_kwargs_builder(
|
|
1563
|
+
n_units=n_units,
|
|
1564
|
+
n_periods=n_periods,
|
|
1565
|
+
treatment_effect=treatment_effect,
|
|
1566
|
+
treatment_fraction=treatment_fraction,
|
|
1567
|
+
treatment_period=treatment_period,
|
|
1568
|
+
sigma=sigma,
|
|
1569
|
+
).keys()
|
|
1570
|
+
)
|
|
1571
|
+
collisions = _PROTECTED_DGP_KEYS & set(data_gen_kwargs) & sample_dgp_keys
|
|
1572
|
+
if collisions:
|
|
1573
|
+
raise ValueError(
|
|
1574
|
+
f"data_generator_kwargs contains keys that conflict with "
|
|
1575
|
+
f"registry-managed simulation inputs: {sorted(collisions)}. "
|
|
1576
|
+
f"These are controlled by simulate_power() parameters directly. "
|
|
1577
|
+
f"Use the corresponding function parameters instead, or pass a "
|
|
1578
|
+
f"custom data_generator to override the DGP entirely."
|
|
1579
|
+
)
|
|
1580
|
+
|
|
1581
|
+
# Warn if DDD design inputs are silently ignored
|
|
1582
|
+
if estimator_name == "TripleDifference" and not use_custom_dgp:
|
|
1583
|
+
_check_ddd_dgp_compat(
|
|
1584
|
+
n_units,
|
|
1585
|
+
n_periods,
|
|
1586
|
+
treatment_fraction,
|
|
1587
|
+
treatment_period,
|
|
1588
|
+
data_generator_kwargs,
|
|
1589
|
+
)
|
|
1590
|
+
effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs)
|
|
1591
|
+
else:
|
|
1592
|
+
effective_n_units = None
|
|
1593
|
+
|
|
1594
|
+
# Determine effect sizes to test
|
|
1595
|
+
if effect_sizes is None:
|
|
1596
|
+
effect_sizes = [treatment_effect]
|
|
1597
|
+
|
|
1598
|
+
all_powers = []
|
|
1599
|
+
|
|
1600
|
+
# For the primary effect, collect detailed results
|
|
1601
|
+
if len(effect_sizes) == 1:
|
|
1602
|
+
primary_idx = 0
|
|
1603
|
+
else:
|
|
1604
|
+
primary_idx = -1
|
|
1605
|
+
for i, es in enumerate(effect_sizes):
|
|
1606
|
+
if np.isclose(es, treatment_effect):
|
|
1607
|
+
primary_idx = i
|
|
1608
|
+
break
|
|
1609
|
+
if primary_idx == -1:
|
|
1610
|
+
primary_idx = len(effect_sizes) - 1
|
|
1611
|
+
|
|
1612
|
+
primary_effect = effect_sizes[primary_idx]
|
|
1613
|
+
|
|
1614
|
+
# Initialize so they are always bound
|
|
1615
|
+
primary_estimates: List[float] = []
|
|
1616
|
+
primary_ses: List[float] = []
|
|
1617
|
+
primary_p_values: List[float] = []
|
|
1618
|
+
primary_rejections: List[bool] = []
|
|
1619
|
+
primary_ci_contains: List[bool] = []
|
|
1620
|
+
|
|
1621
|
+
for effect_idx, effect in enumerate(effect_sizes):
|
|
1622
|
+
is_primary = effect_idx == primary_idx
|
|
1623
|
+
|
|
1624
|
+
estimates: List[float] = []
|
|
1625
|
+
ses: List[float] = []
|
|
1626
|
+
p_values: List[float] = []
|
|
1627
|
+
rejections: List[bool] = []
|
|
1628
|
+
ci_contains_true: List[bool] = []
|
|
1629
|
+
n_failures = 0
|
|
1630
|
+
|
|
1631
|
+
for sim in range(n_simulations):
|
|
1632
|
+
if progress and sim % 100 == 0 and sim > 0:
|
|
1633
|
+
pct = (sim + effect_idx * n_simulations) / (len(effect_sizes) * n_simulations)
|
|
1634
|
+
print(f" Simulation progress: {pct:.0%}")
|
|
1635
|
+
|
|
1636
|
+
sim_seed = rng.integers(0, 2**31)
|
|
1637
|
+
|
|
1638
|
+
# --- Generate data ---
|
|
1639
|
+
if use_custom_dgp:
|
|
1640
|
+
assert data_generator is not None
|
|
1641
|
+
data = data_generator(
|
|
1642
|
+
n_units=n_units,
|
|
1643
|
+
n_periods=n_periods,
|
|
1644
|
+
treatment_effect=effect,
|
|
1645
|
+
treatment_fraction=treatment_fraction,
|
|
1646
|
+
treatment_period=treatment_period,
|
|
1647
|
+
noise_sd=sigma,
|
|
1648
|
+
seed=sim_seed,
|
|
1649
|
+
**data_gen_kwargs,
|
|
1650
|
+
)
|
|
1651
|
+
else:
|
|
1652
|
+
assert profile is not None
|
|
1653
|
+
dgp_kwargs = profile.dgp_kwargs_builder(
|
|
1654
|
+
n_units=n_units,
|
|
1655
|
+
n_periods=n_periods,
|
|
1656
|
+
treatment_effect=effect,
|
|
1657
|
+
treatment_fraction=treatment_fraction,
|
|
1658
|
+
treatment_period=treatment_period,
|
|
1659
|
+
sigma=sigma,
|
|
1660
|
+
)
|
|
1661
|
+
dgp_kwargs.update(data_gen_kwargs)
|
|
1662
|
+
dgp_kwargs.pop("seed", None)
|
|
1663
|
+
data = profile.default_dgp(seed=sim_seed, **dgp_kwargs)
|
|
1664
|
+
|
|
1665
|
+
# Check SDID placebo feasibility on realized data (custom DGP path)
|
|
1666
|
+
if effect_idx == 0 and sim == 0 and estimator_name == "SyntheticDiD":
|
|
1667
|
+
_check_sdid_placebo_data(data, estimator, est_kwargs)
|
|
1668
|
+
|
|
1669
|
+
try:
|
|
1670
|
+
# --- Fit estimator ---
|
|
1671
|
+
if profile is not None and not use_custom_dgp:
|
|
1672
|
+
fit_kwargs = profile.fit_kwargs_builder(
|
|
1673
|
+
data, n_units, n_periods, treatment_period
|
|
1674
|
+
)
|
|
1675
|
+
fit_kwargs.update(est_kwargs)
|
|
1676
|
+
else:
|
|
1677
|
+
# Custom DGP fallback: use registry fit kwargs if available,
|
|
1678
|
+
# otherwise use basic DiD signature
|
|
1679
|
+
if profile is not None:
|
|
1680
|
+
fit_kwargs = profile.fit_kwargs_builder(
|
|
1681
|
+
data, n_units, n_periods, treatment_period
|
|
1682
|
+
)
|
|
1683
|
+
fit_kwargs.update(est_kwargs)
|
|
1684
|
+
else:
|
|
1685
|
+
fit_kwargs = dict(est_kwargs)
|
|
1686
|
+
|
|
1687
|
+
result = estimator.fit(data, **fit_kwargs)
|
|
1688
|
+
|
|
1689
|
+
# --- Extract results ---
|
|
1690
|
+
if profile is not None:
|
|
1691
|
+
att, se, p_val, ci = profile.result_extractor(result)
|
|
1692
|
+
elif result_extractor is not None:
|
|
1693
|
+
att, se, p_val, ci = result_extractor(result)
|
|
1694
|
+
else:
|
|
1695
|
+
att = result.att if hasattr(result, "att") else result.avg_att
|
|
1696
|
+
se = result.se if hasattr(result, "se") else result.avg_se
|
|
1697
|
+
p_val = result.p_value if hasattr(result, "p_value") else result.avg_p_value
|
|
1698
|
+
ci = result.conf_int if hasattr(result, "conf_int") else result.avg_conf_int
|
|
1699
|
+
|
|
1700
|
+
# NaN p-value → treat as non-rejection
|
|
1701
|
+
rejected = bool(p_val < alpha) if not np.isnan(p_val) else False
|
|
1702
|
+
|
|
1703
|
+
estimates.append(att)
|
|
1704
|
+
ses.append(se)
|
|
1705
|
+
p_values.append(p_val)
|
|
1706
|
+
rejections.append(rejected)
|
|
1707
|
+
ci_contains_true.append(ci[0] <= effect <= ci[1])
|
|
1708
|
+
|
|
1709
|
+
except Exception as e:
|
|
1710
|
+
n_failures += 1
|
|
1711
|
+
if progress:
|
|
1712
|
+
print(f" Warning: Simulation {sim} failed: {e}")
|
|
1713
|
+
continue
|
|
1714
|
+
|
|
1715
|
+
# Warn if too many simulations failed
|
|
1716
|
+
failure_rate = n_failures / n_simulations
|
|
1717
|
+
if failure_rate > 0.1:
|
|
1718
|
+
warnings.warn(
|
|
1719
|
+
f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) "
|
|
1720
|
+
f"failed for effect_size={effect}. "
|
|
1721
|
+
f"Check estimator and data generator.",
|
|
1722
|
+
UserWarning,
|
|
1723
|
+
)
|
|
1724
|
+
|
|
1725
|
+
if len(estimates) == 0:
|
|
1726
|
+
raise RuntimeError("All simulations failed. Check estimator and data generator.")
|
|
1727
|
+
|
|
1728
|
+
power_val = np.mean(rejections)
|
|
1729
|
+
all_powers.append(power_val)
|
|
1730
|
+
|
|
1731
|
+
if is_primary:
|
|
1732
|
+
primary_estimates = estimates
|
|
1733
|
+
primary_ses = ses
|
|
1734
|
+
primary_p_values = p_values
|
|
1735
|
+
primary_rejections = rejections
|
|
1736
|
+
primary_ci_contains = ci_contains_true
|
|
1737
|
+
|
|
1738
|
+
# Compute confidence interval for power (primary effect)
|
|
1739
|
+
power_val = all_powers[primary_idx]
|
|
1740
|
+
n_valid = len(primary_rejections)
|
|
1741
|
+
power_se = np.sqrt(power_val * (1 - power_val) / n_valid)
|
|
1742
|
+
z = stats.norm.ppf(0.975)
|
|
1743
|
+
power_ci = (
|
|
1744
|
+
max(0.0, power_val - z * power_se),
|
|
1745
|
+
min(1.0, power_val + z * power_se),
|
|
1746
|
+
)
|
|
1747
|
+
|
|
1748
|
+
mean_estimate = np.mean(primary_estimates)
|
|
1749
|
+
std_estimate = np.std(primary_estimates, ddof=1)
|
|
1750
|
+
mean_se = np.mean(primary_ses)
|
|
1751
|
+
coverage = np.mean(primary_ci_contains)
|
|
1752
|
+
|
|
1753
|
+
return SimulationPowerResults(
|
|
1754
|
+
power=power_val,
|
|
1755
|
+
power_se=power_se,
|
|
1756
|
+
power_ci=power_ci,
|
|
1757
|
+
rejection_rate=power_val,
|
|
1758
|
+
mean_estimate=mean_estimate,
|
|
1759
|
+
std_estimate=std_estimate,
|
|
1760
|
+
mean_se=mean_se,
|
|
1761
|
+
coverage=coverage,
|
|
1762
|
+
n_simulations=n_valid,
|
|
1763
|
+
effect_sizes=effect_sizes,
|
|
1764
|
+
powers=all_powers,
|
|
1765
|
+
true_effect=primary_effect,
|
|
1766
|
+
alpha=alpha,
|
|
1767
|
+
estimator_name=estimator_name,
|
|
1768
|
+
simulation_results=[
|
|
1769
|
+
{"estimate": e, "se": s, "p_value": p, "rejected": r}
|
|
1770
|
+
for e, s, p, r in zip(
|
|
1771
|
+
primary_estimates,
|
|
1772
|
+
primary_ses,
|
|
1773
|
+
primary_p_values,
|
|
1774
|
+
primary_rejections,
|
|
1775
|
+
)
|
|
1776
|
+
],
|
|
1777
|
+
effective_n_units=effective_n_units,
|
|
1778
|
+
)
|
|
1779
|
+
|
|
1780
|
+
|
|
1781
|
+
# ---------------------------------------------------------------------------
|
|
1782
|
+
# Simulation-based MDE and sample-size search
|
|
1783
|
+
# ---------------------------------------------------------------------------
|
|
1784
|
+
|
|
1785
|
+
|
|
1786
|
+
@dataclass
class SimulationMDEResults:
    """
    Outcome of a simulation-based minimum-detectable-effect search.

    Attributes
    ----------
    mde : float
        Smallest effect size found to reach the target power.
    power_at_mde : float
        Simulated power at ``mde``.
    target_power : float
        Power level the search aimed for.
    alpha : float
        Significance level of the underlying tests.
    n_units : int
        Requested number of units per simulated dataset.
    n_simulations_per_step : int
        Monte Carlo replications behind each power evaluation.
    n_steps : int
        Number of power evaluations performed. ``simulate_mde`` fills this
        from ``len(search_path)``, so bracketing probes are counted as well
        as bisection steps.
    search_path : list of dict
        Per-evaluation trace of ``{effect_size, power}``.
    estimator_name : str
        Class name of the estimator that was evaluated.
    effective_n_units : int or None
        Effective sample size when it differs from the requested ``n_units``
        (e.g., due to DDD grid rounding). ``None`` when no rounding occurred.
    """

    mde: float
    power_at_mde: float
    target_power: float
    alpha: float
    n_units: int
    n_simulations_per_step: int
    n_steps: int
    search_path: List[Dict[str, float]]
    estimator_name: str
    effective_n_units: Optional[int] = None

    def __repr__(self) -> str:
        pieces = (
            f"mde={self.mde:.4f}",
            f"power_at_mde={self.power_at_mde:.3f}",
            f"n_steps={self.n_steps}",
        )
        return "SimulationMDEResults(" + ", ".join(pieces) + ")"

    def summary(self) -> str:
        """Return a fixed-width, human-readable report of the MDE search."""
        width = 65
        heavy_rule = "=" * width
        light_rule = "-" * width

        def row(label: str, value: str) -> str:
            # Labels are left-aligned in a 35-character column.
            return f"{label:<35} {value}"

        report = [
            heavy_rule,
            "Simulation-Based MDE Results".center(width),
            heavy_rule,
            "",
            row("Estimator:", f"{self.estimator_name}"),
            row("Significance level (alpha):", f"{self.alpha:.3f}"),
            row("Target power:", f"{self.target_power:.1%}"),
            row("Sample size (n_units):", f"{self.n_units}"),
        ]
        if self.effective_n_units is not None:
            report.append(
                row("Effective sample size:", f"{self.effective_n_units} (DDD grid-rounded)")
            )
        report.extend(
            [
                row("Simulations per step:", f"{self.n_simulations_per_step}"),
                "",
                light_rule,
                "Search Results".center(width),
                light_rule,
                row("Minimum detectable effect:", f"{self.mde:.4f}"),
                row("Power at MDE:", f"{self.power_at_mde:.1%}"),
                row("Bisection steps:", f"{self.n_steps}"),
                heavy_rule,
            ]
        )
        return "\n".join(report)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar results (``search_path`` excluded) as a dict."""
        exported = (
            "mde",
            "power_at_mde",
            "target_power",
            "alpha",
            "n_units",
            "effective_n_units",
            "n_simulations_per_step",
            "n_steps",
            "estimator_name",
        )
        return {name: getattr(self, name) for name in exported}

    def to_dataframe(self) -> pd.DataFrame:
        """Return ``to_dict()`` as a one-row DataFrame."""
        record = self.to_dict()
        return pd.DataFrame([record])
|
|
1880
|
+
|
|
1881
|
+
|
|
1882
|
+
@dataclass
class SimulationSampleSizeResults:
    """
    Outcome of a simulation-based sample-size search.

    Attributes
    ----------
    required_n : int
        Smallest number of units found to reach the target power.
    power_at_n : float
        Simulated power at ``required_n``.
    target_power : float
        Power level the search aimed for.
    alpha : float
        Significance level of the underlying tests.
    effect_size : float
        Treatment effect that was simulated during the search.
    n_simulations_per_step : int
        Monte Carlo replications behind each power evaluation.
    n_steps : int
        Number of power evaluations performed. ``simulate_sample_size``
        fills this from ``len(search_path)``, so bracketing probes are
        counted as well as bisection steps.
    search_path : list of dict
        Per-evaluation trace of ``{n_units, power}``.
    estimator_name : str
        Class name of the estimator that was evaluated.
    effective_n_units : int or None
        Effective sample size when it differs from ``required_n``
        (e.g., due to DDD grid rounding). ``None`` when no rounding occurred
        or when the search already snapped to the estimator's grid.
    """

    required_n: int
    power_at_n: float
    target_power: float
    alpha: float
    effect_size: float
    n_simulations_per_step: int
    n_steps: int
    search_path: List[Dict[str, float]]
    estimator_name: str
    effective_n_units: Optional[int] = None

    def __repr__(self) -> str:
        pieces = (
            f"required_n={self.required_n}",
            f"power_at_n={self.power_at_n:.3f}",
            f"n_steps={self.n_steps}",
        )
        return "SimulationSampleSizeResults(" + ", ".join(pieces) + ")"

    def summary(self) -> str:
        """Return a fixed-width, human-readable report of the sample-size search."""
        width = 65
        heavy_rule = "=" * width
        light_rule = "-" * width

        def row(label: str, value: str) -> str:
            # Labels are left-aligned in a 35-character column.
            return f"{label:<35} {value}"

        report = [
            heavy_rule,
            "Simulation-Based Sample Size Results".center(width),
            heavy_rule,
            "",
            row("Estimator:", f"{self.estimator_name}"),
            row("Significance level (alpha):", f"{self.alpha:.3f}"),
            row("Target power:", f"{self.target_power:.1%}"),
            row("Effect size:", f"{self.effect_size:.4f}"),
            row("Simulations per step:", f"{self.n_simulations_per_step}"),
            "",
            light_rule,
            "Search Results".center(width),
            light_rule,
            row("Required sample size:", f"{self.required_n}"),
            row("Power at required N:", f"{self.power_at_n:.1%}"),
            row("Bisection steps:", f"{self.n_steps}"),
        ]
        if self.effective_n_units is not None:
            report.append(
                row("Effective sample size:", f"{self.effective_n_units} (DDD grid-rounded)")
            )
        report.append(heavy_rule)
        return "\n".join(report)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar results (``search_path`` excluded) as a dict."""
        exported = (
            "required_n",
            "power_at_n",
            "target_power",
            "alpha",
            "effect_size",
            "n_simulations_per_step",
            "n_steps",
            "estimator_name",
            "effective_n_units",
        )
        return {name: getattr(self, name) for name in exported}

    def to_dataframe(self) -> pd.DataFrame:
        """Return ``to_dict()`` as a one-row DataFrame."""
        record = self.to_dict()
        return pd.DataFrame([record])
|
|
1975
|
+
|
|
1976
|
+
|
|
1977
|
+
def simulate_mde(
    estimator: Any,
    n_units: int = 100,
    n_periods: int = 4,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    sigma: float = 1.0,
    n_simulations: int = 200,
    power: float = 0.80,
    alpha: float = 0.05,
    effect_range: Optional[Tuple[float, float]] = None,
    tol: float = 0.02,
    max_steps: int = 15,
    seed: Optional[int] = None,
    data_generator: Optional[Callable] = None,
    data_generator_kwargs: Optional[Dict[str, Any]] = None,
    estimator_kwargs: Optional[Dict[str, Any]] = None,
    result_extractor: Optional[Callable] = None,
    progress: bool = True,
) -> SimulationMDEResults:
    """
    Find the minimum detectable effect via simulation-based bisection search.

    Searches over effect sizes to find the smallest effect that achieves the
    target power, using ``simulate_power()`` at each step.

    Parameters
    ----------
    estimator : estimator object
        DiD estimator to use.
    n_units : int, default=100
        Number of units per simulation.
    n_periods : int, default=4
        Number of time periods.
    treatment_fraction : float, default=0.5
        Fraction of units that are treated.
    treatment_period : int, default=2
        First post-treatment period (0-indexed).
    sigma : float, default=1.0
        Residual standard deviation. Also used as the first upper-bound
        probe when auto-bracketing.
    n_simulations : int, default=200
        Simulations per bisection step.
    power : float, default=0.80
        Target power.
    alpha : float, default=0.05
        Significance level.
    effect_range : tuple of (float, float), optional
        ``(lo, hi)`` bracket for the search. If None, auto-brackets: the
        lower bound is 0 and the upper bound starts at ``sigma`` and is
        doubled (at most 10 evaluations) until the target power is reached.
    tol : float, default=0.02
        Convergence tolerance. The search stops when simulated power is
        within ``tol`` of the target, or when the bracket width falls below
        ``tol * hi``, with an absolute floor of 1e-6.
    max_steps : int, default=15
        Maximum bisection steps.
    seed : int, optional
        Random seed for reproducibility. Each power evaluation draws its
        own seed from a master generator seeded with this value.
    data_generator : callable, optional
        Custom data generation function.
    data_generator_kwargs : dict, optional
        Additional keyword arguments for data generator.
    estimator_kwargs : dict, optional
        Additional keyword arguments for estimator.fit().
    result_extractor : callable, optional
        Custom function to extract results from the estimator output.
        Forwarded to ``simulate_power()``.
    progress : bool, default=True
        Whether to print progress updates.

    Returns
    -------
    SimulationMDEResults
        Results including the MDE and search diagnostics. ``n_steps`` counts
        every power evaluation, including bracketing probes.

    Warns
    -----
    UserWarning
        If the lower bound (or effect 0 when auto-bracketing) already
        reaches the target power, or if the target power could not be
        bracketed from above.

    Examples
    --------
    >>> from diff_diff import simulate_mde, DifferenceInDifferences
    >>> result = simulate_mde(DifferenceInDifferences(), n_simulations=100, seed=42)
    >>> print(f"MDE: {result.mde:.3f}")
    """
    master_rng = np.random.default_rng(seed)
    estimator_name = type(estimator).__name__
    search_path: List[Dict[str, float]] = []

    # Compute effective N for DDD (N is fixed throughout MDE search)
    if estimator_name == "TripleDifference" and data_generator is None:
        effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs)
    else:
        effective_n_units = None

    # Everything except the effect size and seed is held fixed across steps.
    common_kwargs: Dict[str, Any] = dict(
        estimator=estimator,
        n_units=n_units,
        n_periods=n_periods,
        treatment_fraction=treatment_fraction,
        treatment_period=treatment_period,
        sigma=sigma,
        n_simulations=n_simulations,
        alpha=alpha,
        data_generator=data_generator,
        data_generator_kwargs=data_generator_kwargs,
        estimator_kwargs=estimator_kwargs,
        result_extractor=result_extractor,
        progress=False,
    )

    def _power_at(effect: float) -> float:
        """Simulate power at ``effect`` with a fresh seed and record it."""
        step_seed = int(master_rng.integers(0, 2**31))
        res = simulate_power(treatment_effect=effect, seed=step_seed, **common_kwargs)
        pwr = float(res.power)
        search_path.append({"effect_size": effect, "power": pwr})
        if progress:
            print(f"  MDE search: effect={effect:.4f}, power={pwr:.3f}")
        return pwr

    def _result(mde: float, power_at_mde: float) -> SimulationMDEResults:
        """Package a search outcome; n_steps counts all evaluations so far."""
        return SimulationMDEResults(
            mde=mde,
            power_at_mde=power_at_mde,
            target_power=power,
            alpha=alpha,
            n_units=n_units,
            n_simulations_per_step=n_simulations,
            n_steps=len(search_path),
            search_path=search_path,
            estimator_name=estimator_name,
            effective_n_units=effective_n_units,
        )

    # --- Bracket ---
    if effect_range is not None:
        lo, hi = effect_range
        power_lo = _power_at(lo)
        power_hi = _power_at(hi)
        if power_lo >= power:
            warnings.warn(
                f"Power at effect={lo} is {power_lo:.2f} >= target {power}. "
                "Lower bound already exceeds target power. Returning lo as MDE.",
                UserWarning,
            )
            return _result(lo, power_lo)
        if power_hi < power:
            warnings.warn(
                f"Target power {power} not bracketed: power at effect={hi} "
                f"is {power_hi:.2f}. Upper bound may be too low.",
                UserWarning,
            )
    else:
        lo = 0.0
        # Check that power at zero is below target (no inflated Type I error)
        power_at_zero = _power_at(0.0)
        if power_at_zero >= power:
            warnings.warn(
                f"Power at effect=0 is {power_at_zero:.2f} >= target {power}. "
                "This suggests inflated Type I error. Returning MDE=0.",
                UserWarning,
            )
            return _result(0.0, power_at_zero)

        # Probe sigma * 2**k for k = 0..9 until the target is reached.
        # `hi` is only ever advanced together with an evaluation, so after
        # the loop (bracketed or not) `power_hi` is the power *at* `hi`.
        hi = sigma
        power_hi = _power_at(hi)
        for _ in range(9):
            if power_hi >= power:
                break
            hi *= 2
            power_hi = _power_at(hi)
        if power_hi < power:
            warnings.warn(
                f"Could not bracket MDE (power at effect={hi} still below "
                f"{power}). Returning best upper bound.",
                UserWarning,
            )

    # --- Bisect ---
    # Invariant: best_effect/best_power is an evaluated pair; best_effect is
    # the smallest effect seen so far that met the target (or hi if none).
    best_effect = hi
    best_power = power_hi

    for _ in range(max_steps):
        mid = (lo + hi) / 2
        pwr = _power_at(mid)

        if pwr >= power:
            hi = mid
            best_effect = mid
            best_power = pwr
        else:
            lo = mid

        # Convergence: effect range is tight or power is close enough
        if hi - lo < max(tol * hi, 1e-6) or abs(pwr - power) < tol:
            break

    return _result(best_effect, best_power)
|
|
2184
|
+
|
|
2185
|
+
|
|
2186
|
+
def simulate_sample_size(
|
|
2187
|
+
estimator: Any,
|
|
2188
|
+
treatment_effect: float = 5.0,
|
|
2189
|
+
n_periods: int = 4,
|
|
2190
|
+
treatment_fraction: float = 0.5,
|
|
2191
|
+
treatment_period: int = 2,
|
|
2192
|
+
sigma: float = 1.0,
|
|
2193
|
+
n_simulations: int = 200,
|
|
2194
|
+
power: float = 0.80,
|
|
2195
|
+
alpha: float = 0.05,
|
|
2196
|
+
n_range: Optional[Tuple[int, int]] = None,
|
|
2197
|
+
max_steps: int = 15,
|
|
2198
|
+
seed: Optional[int] = None,
|
|
2199
|
+
data_generator: Optional[Callable] = None,
|
|
2200
|
+
data_generator_kwargs: Optional[Dict[str, Any]] = None,
|
|
2201
|
+
estimator_kwargs: Optional[Dict[str, Any]] = None,
|
|
2202
|
+
result_extractor: Optional[Callable] = None,
|
|
2203
|
+
progress: bool = True,
|
|
2204
|
+
) -> SimulationSampleSizeResults:
|
|
2205
|
+
"""
|
|
2206
|
+
Find the required sample size via simulation-based bisection search.
|
|
2207
|
+
|
|
2208
|
+
Searches over ``n_units`` to find the smallest N that achieves the
|
|
2209
|
+
target power, using ``simulate_power()`` at each step.
|
|
2210
|
+
|
|
2211
|
+
Parameters
|
|
2212
|
+
----------
|
|
2213
|
+
estimator : estimator object
|
|
2214
|
+
DiD estimator to use.
|
|
2215
|
+
treatment_effect : float, default=5.0
|
|
2216
|
+
True treatment effect to simulate.
|
|
2217
|
+
n_periods : int, default=4
|
|
2218
|
+
Number of time periods.
|
|
2219
|
+
treatment_fraction : float, default=0.5
|
|
2220
|
+
Fraction of units that are treated.
|
|
2221
|
+
treatment_period : int, default=2
|
|
2222
|
+
First post-treatment period (0-indexed).
|
|
2223
|
+
sigma : float, default=1.0
|
|
2224
|
+
Residual standard deviation.
|
|
2225
|
+
n_simulations : int, default=200
|
|
2226
|
+
Simulations per bisection step.
|
|
2227
|
+
power : float, default=0.80
|
|
2228
|
+
Target power.
|
|
2229
|
+
alpha : float, default=0.05
|
|
2230
|
+
Significance level.
|
|
2231
|
+
n_range : tuple of (int, int), optional
|
|
2232
|
+
``(lo, hi)`` bracket for sample size. If None, auto-brackets.
|
|
2233
|
+
max_steps : int, default=15
|
|
2234
|
+
Maximum bisection steps.
|
|
2235
|
+
seed : int, optional
|
|
2236
|
+
Random seed for reproducibility.
|
|
2237
|
+
data_generator : callable, optional
|
|
2238
|
+
Custom data generation function.
|
|
2239
|
+
data_generator_kwargs : dict, optional
|
|
2240
|
+
Additional keyword arguments for data generator.
|
|
2241
|
+
estimator_kwargs : dict, optional
|
|
2242
|
+
Additional keyword arguments for estimator.fit().
|
|
2243
|
+
result_extractor : callable, optional
|
|
2244
|
+
Custom function to extract results from the estimator output.
|
|
2245
|
+
Forwarded to ``simulate_power()``.
|
|
2246
|
+
progress : bool, default=True
|
|
2247
|
+
Whether to print progress updates.
|
|
2248
|
+
|
|
2249
|
+
Returns
|
|
2250
|
+
-------
|
|
2251
|
+
SimulationSampleSizeResults
|
|
2252
|
+
Results including the required N and search diagnostics.
|
|
2253
|
+
|
|
2254
|
+
Examples
|
|
2255
|
+
--------
|
|
2256
|
+
>>> from diff_diff import simulate_sample_size, DifferenceInDifferences
|
|
2257
|
+
>>> result = simulate_sample_size(
|
|
2258
|
+
... DifferenceInDifferences(), treatment_effect=5.0, n_simulations=100, seed=42
|
|
2259
|
+
... )
|
|
2260
|
+
>>> print(f"Required N: {result.required_n}")
|
|
2261
|
+
"""
|
|
2262
|
+
master_rng = np.random.default_rng(seed)
|
|
2263
|
+
estimator_name = type(estimator).__name__
|
|
2264
|
+
search_path: List[Dict[str, float]] = []
|
|
2265
|
+
|
|
2266
|
+
# Determine min_n from registry
|
|
2267
|
+
registry = _get_registry()
|
|
2268
|
+
profile = registry.get(estimator_name)
|
|
2269
|
+
min_n = profile.min_n if profile is not None else 20
|
|
2270
|
+
|
|
2271
|
+
# DDD grid snapping: bisection candidates must be multiples of 8
|
|
2272
|
+
is_ddd_grid = estimator_name == "TripleDifference" and data_generator is None
|
|
2273
|
+
grid_step = 8 if is_ddd_grid else 1
|
|
2274
|
+
convergence_threshold = grid_step + 1 # 9 for DDD, 2 for others
|
|
2275
|
+
|
|
2276
|
+
if is_ddd_grid and data_generator_kwargs and "n_per_cell" in data_generator_kwargs:
|
|
2277
|
+
raise ValueError(
|
|
2278
|
+
"data_generator_kwargs contains 'n_per_cell', which conflicts with "
|
|
2279
|
+
"the sample-size search in simulate_sample_size(). For "
|
|
2280
|
+
"TripleDifference, n_per_cell is derived from n_units (the search "
|
|
2281
|
+
"variable). Use simulate_power() with a fixed n_per_cell override "
|
|
2282
|
+
"instead, or pass a custom data_generator."
|
|
2283
|
+
)
|
|
2284
|
+
|
|
2285
|
+
def _snap_n(n: int, direction: str = "down", floor: Optional[int] = None) -> int:
|
|
2286
|
+
if grid_step == 1:
|
|
2287
|
+
return n
|
|
2288
|
+
actual_floor = floor if floor is not None else min_n
|
|
2289
|
+
if direction == "up":
|
|
2290
|
+
return max(actual_floor, ((n + grid_step - 1) // grid_step) * grid_step)
|
|
2291
|
+
return max(actual_floor, (n // grid_step) * grid_step)
|
|
2292
|
+
|
|
2293
|
+
common_kwargs: Dict[str, Any] = dict(
|
|
2294
|
+
estimator=estimator,
|
|
2295
|
+
n_periods=n_periods,
|
|
2296
|
+
treatment_effect=treatment_effect,
|
|
2297
|
+
treatment_fraction=treatment_fraction,
|
|
2298
|
+
treatment_period=treatment_period,
|
|
2299
|
+
sigma=sigma,
|
|
2300
|
+
n_simulations=n_simulations,
|
|
2301
|
+
alpha=alpha,
|
|
2302
|
+
data_generator=data_generator,
|
|
2303
|
+
data_generator_kwargs=data_generator_kwargs,
|
|
2304
|
+
estimator_kwargs=estimator_kwargs,
|
|
2305
|
+
result_extractor=result_extractor,
|
|
2306
|
+
progress=False,
|
|
2307
|
+
)
|
|
2308
|
+
|
|
2309
|
+
def _power_at_n(n: int) -> float:
|
|
2310
|
+
step_seed = int(master_rng.integers(0, 2**31))
|
|
2311
|
+
res = simulate_power(n_units=n, seed=step_seed, **common_kwargs)
|
|
2312
|
+
pwr = float(res.power)
|
|
2313
|
+
search_path.append({"n_units": float(n), "power": pwr})
|
|
2314
|
+
if progress:
|
|
2315
|
+
print(f" Sample size search: n={n}, power={pwr:.3f}")
|
|
2316
|
+
return pwr
|
|
2317
|
+
|
|
2318
|
+
# --- Bracket ---
|
|
2319
|
+
abs_min = 16 if is_ddd_grid else 4
|
|
2320
|
+
if n_range is not None:
|
|
2321
|
+
lo, hi = _snap_n(n_range[0], "up", floor=abs_min), _snap_n(
|
|
2322
|
+
n_range[1], "down", floor=abs_min
|
|
2323
|
+
)
|
|
2324
|
+
if lo > hi:
|
|
2325
|
+
lo = hi # collapsed bracket — evaluate single point
|
|
2326
|
+
power_lo = _power_at_n(lo)
|
|
2327
|
+
if power_lo >= power:
|
|
2328
|
+
warnings.warn(
|
|
2329
|
+
f"Power at n={lo} is {power_lo:.2f} >= target {power}. "
|
|
2330
|
+
f"Lower bound already achieves target power. Returning lo.",
|
|
2331
|
+
UserWarning,
|
|
2332
|
+
)
|
|
2333
|
+
return SimulationSampleSizeResults(
|
|
2334
|
+
required_n=lo,
|
|
2335
|
+
power_at_n=power_lo,
|
|
2336
|
+
target_power=power,
|
|
2337
|
+
alpha=alpha,
|
|
2338
|
+
effect_size=treatment_effect,
|
|
2339
|
+
n_simulations_per_step=n_simulations,
|
|
2340
|
+
n_steps=len(search_path),
|
|
2341
|
+
search_path=search_path,
|
|
2342
|
+
estimator_name=estimator_name,
|
|
2343
|
+
)
|
|
2344
|
+
power_hi = _power_at_n(hi)
|
|
2345
|
+
if power_hi < power:
|
|
2346
|
+
warnings.warn(
|
|
2347
|
+
f"Target power {power} not bracketed: power at n={hi} "
|
|
2348
|
+
f"is {power_hi:.2f}. Upper bound may be too low.",
|
|
2349
|
+
UserWarning,
|
|
2350
|
+
)
|
|
2351
|
+
else:
|
|
2352
|
+
lo = min_n
|
|
2353
|
+
power_lo = _power_at_n(lo)
|
|
2354
|
+
if power_lo >= power:
|
|
2355
|
+
# Floor achieves target — search downward for true minimum
|
|
2356
|
+
hi = lo
|
|
2357
|
+
found_lower = False
|
|
2358
|
+
probe = _snap_n(max(abs_min, lo // 2), floor=abs_min)
|
|
2359
|
+
for _ in range(8):
|
|
2360
|
+
if probe >= hi or probe < abs_min:
|
|
2361
|
+
break
|
|
2362
|
+
pwr = _power_at_n(probe)
|
|
2363
|
+
if pwr < power:
|
|
2364
|
+
lo = probe
|
|
2365
|
+
found_lower = True
|
|
2366
|
+
break
|
|
2367
|
+
hi = probe
|
|
2368
|
+
probe = _snap_n(max(abs_min, probe // 2), floor=abs_min)
|
|
2369
|
+
if not found_lower:
|
|
2370
|
+
# Even smallest viable N achieves target — return best found
|
|
2371
|
+
best = min(
|
|
2372
|
+
(s for s in search_path if s["power"] >= power),
|
|
2373
|
+
key=lambda s: s["n_units"],
|
|
2374
|
+
)
|
|
2375
|
+
warnings.warn(
|
|
2376
|
+
f"Power at n={int(best['n_units'])} is "
|
|
2377
|
+
f"{best['power']:.2f} >= target {power}. Could not "
|
|
2378
|
+
f"find a smaller N below target power. Pass "
|
|
2379
|
+
f"n_range=(lo, hi) to refine.",
|
|
2380
|
+
UserWarning,
|
|
2381
|
+
)
|
|
2382
|
+
return SimulationSampleSizeResults(
|
|
2383
|
+
required_n=int(best["n_units"]),
|
|
2384
|
+
power_at_n=best["power"],
|
|
2385
|
+
target_power=power,
|
|
2386
|
+
alpha=alpha,
|
|
2387
|
+
effect_size=treatment_effect,
|
|
2388
|
+
n_simulations_per_step=n_simulations,
|
|
2389
|
+
n_steps=len(search_path),
|
|
2390
|
+
search_path=search_path,
|
|
2391
|
+
estimator_name=estimator_name,
|
|
2392
|
+
)
|
|
2393
|
+
# Fall through to bisection with lo..hi bracket
|
|
2394
|
+
else:
|
|
2395
|
+
hi = max(100, 2 * min_n)
|
|
2396
|
+
for _ in range(10):
|
|
2397
|
+
if _power_at_n(hi) >= power:
|
|
2398
|
+
break
|
|
2399
|
+
hi *= 2
|
|
2400
|
+
else:
|
|
2401
|
+
warnings.warn(
|
|
2402
|
+
f"Could not bracket required N (power at n={hi} still "
|
|
2403
|
+
f"below {power}). Returning best upper bound.",
|
|
2404
|
+
UserWarning,
|
|
2405
|
+
)
|
|
2406
|
+
|
|
2407
|
+
# --- Bisect on integer n_units ---
|
|
2408
|
+
best_n = hi
|
|
2409
|
+
# Look up power at hi (search_path[-1] may not be hi after downward search)
|
|
2410
|
+
best_power = next(
|
|
2411
|
+
(s["power"] for s in reversed(search_path) if int(s["n_units"]) == hi),
|
|
2412
|
+
search_path[-1]["power"] if search_path else 0.0,
|
|
2413
|
+
)
|
|
2414
|
+
|
|
2415
|
+
for _ in range(max_steps):
|
|
2416
|
+
if hi - lo <= convergence_threshold:
|
|
2417
|
+
break
|
|
2418
|
+
mid = _snap_n((lo + hi) // 2, floor=abs_min)
|
|
2419
|
+
if mid <= lo or mid >= hi:
|
|
2420
|
+
break
|
|
2421
|
+
pwr = _power_at_n(mid)
|
|
2422
|
+
|
|
2423
|
+
if pwr >= power:
|
|
2424
|
+
hi = mid
|
|
2425
|
+
best_n = mid
|
|
2426
|
+
best_power = pwr
|
|
2427
|
+
else:
|
|
2428
|
+
lo = mid
|
|
2429
|
+
|
|
2430
|
+
# Final answer is hi (conservative ceiling) — skip if already evaluated
|
|
2431
|
+
if best_n != hi:
|
|
2432
|
+
final_pwr = _power_at_n(hi)
|
|
2433
|
+
if final_pwr >= power:
|
|
2434
|
+
best_n = hi
|
|
2435
|
+
best_power = final_pwr
|
|
2436
|
+
|
|
2437
|
+
return SimulationSampleSizeResults(
|
|
2438
|
+
required_n=best_n,
|
|
2439
|
+
power_at_n=best_power,
|
|
2440
|
+
target_power=power,
|
|
2441
|
+
alpha=alpha,
|
|
2442
|
+
effect_size=treatment_effect,
|
|
2443
|
+
n_simulations_per_step=n_simulations,
|
|
2444
|
+
n_steps=len(search_path),
|
|
2445
|
+
search_path=search_path,
|
|
2446
|
+
estimator_name=estimator_name,
|
|
2447
|
+
)
|
|
2448
|
+
|
|
2449
|
+
|
|
2450
|
+
def compute_mde(
    n_treated: int,
    n_control: int,
    sigma: float,
    power: float = 0.80,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
) -> float:
    """
    Return the minimum detectable effect for the given design.

    This is a shortcut for building a ``PowerAnalysis`` configured with
    ``alpha`` and ``power``, calling its ``mde`` method, and returning
    only the ``mde`` attribute of the result object.

    Parameters
    ----------
    n_treated : int
        How many treated units the design contains.
    n_control : int
        How many control units the design contains.
    sigma : float
        Standard deviation of the residual noise.
    power : float, default=0.80
        Desired probability of detecting the effect.
    alpha : float, default=0.05
        Significance level of the test.
    n_pre : int, default=1
        Count of periods before treatment.
    n_post : int, default=1
        Count of periods after treatment.
    rho : float, default=0.0
        Intra-cluster correlation.

    Returns
    -------
    float
        The minimum detectable effect size.

    Examples
    --------
    >>> mde = compute_mde(n_treated=50, n_control=50, sigma=10.0)
    >>> print(f"MDE: {mde:.2f}")
    """
    analyzer = PowerAnalysis(alpha=alpha, power=power)
    # Forwarded positionally, in the same order the original call used.
    design = (n_treated, n_control, sigma, n_pre, n_post, rho)
    return analyzer.mde(*design).mde
|
|
2495
|
+
|
|
2496
|
+
|
|
2497
|
+
def compute_power(
    effect_size: float,
    n_treated: int,
    n_control: int,
    sigma: float,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
) -> float:
    """
    Return the statistical power for a given effect and sample.

    This is a shortcut for building a ``PowerAnalysis`` configured with
    ``alpha`` only (its power target is left at that class's default),
    calling its ``power`` method, and returning the ``power`` attribute
    of the result object.

    Parameters
    ----------
    effect_size : float
        Anticipated size of the treatment effect.
    n_treated : int
        How many treated units the design contains.
    n_control : int
        How many control units the design contains.
    sigma : float
        Standard deviation of the residual noise.
    alpha : float, default=0.05
        Significance level of the test.
    n_pre : int, default=1
        Count of periods before treatment.
    n_post : int, default=1
        Count of periods after treatment.
    rho : float, default=0.0
        Intra-cluster correlation.

    Returns
    -------
    float
        The statistical power.

    Examples
    --------
    >>> power = compute_power(effect_size=5.0, n_treated=50, n_control=50, sigma=10.0)
    >>> print(f"Power: {power:.1%}")
    """
    analyzer = PowerAnalysis(alpha=alpha)
    # Forwarded positionally, in the same order the original call used.
    design = (effect_size, n_treated, n_control, sigma, n_pre, n_post, rho)
    return analyzer.power(*design).power
|
|
2542
|
+
|
|
2543
|
+
|
|
2544
|
+
def compute_sample_size(
    effect_size: float,
    sigma: float,
    power: float = 0.80,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
    treat_frac: float = 0.5,
) -> int:
    """
    Return the total sample size needed to detect an effect.

    This is a shortcut for building a ``PowerAnalysis`` configured with
    ``alpha`` and ``power``, calling its ``sample_size`` method, and
    returning the ``required_n`` attribute of the result object.

    Parameters
    ----------
    effect_size : float
        Size of the treatment effect to be detected.
    sigma : float
        Standard deviation of the residual noise.
    power : float, default=0.80
        Desired probability of detecting the effect.
    alpha : float, default=0.05
        Significance level of the test.
    n_pre : int, default=1
        Count of periods before treatment.
    n_post : int, default=1
        Count of periods after treatment.
    rho : float, default=0.0
        Intra-cluster correlation.
    treat_frac : float, default=0.5
        Share of units allocated to treatment.

    Returns
    -------
    int
        The required total sample size.

    Examples
    --------
    >>> n = compute_sample_size(effect_size=5.0, sigma=10.0)
    >>> print(f"Required N: {n}")
    """
    analyzer = PowerAnalysis(alpha=alpha, power=power)
    # Forwarded positionally, in the same order the original call used.
    design = (effect_size, sigma, n_pre, n_post, rho, treat_frac)
    return analyzer.sample_size(*design).required_n
|