diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/power.py ADDED
@@ -0,0 +1,2588 @@
1
+ """
2
+ Power analysis tools for difference-in-differences study design.
3
+
4
+ This module provides power calculations and simulation-based power analysis
5
+ for DiD study design, helping practitioners answer questions like:
6
+ - "How many units do I need to detect an effect of size X?"
7
+ - "What is the minimum detectable effect given my sample size?"
8
+ - "What power do I have to detect a given effect?"
9
+
10
+ References
11
+ ----------
12
+ Bloom, H. S. (1995). "Minimum Detectable Effects: A Simple Way to Report the
13
+ Statistical Power of Experimental Designs." Evaluation Review, 19(5), 547-556.
14
+
15
+ Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
16
+ Journal of Development Economics, 144, 102458.
17
+
18
+ Djimeu, E. W., & Houndolo, D.-G. (2016). "Power Calculation for Causal Inference
19
+ in Social Science: Sample Size and Minimum Detectable Effect Determination."
20
+ Journal of Development Effectiveness, 8(4), 508-527.
21
+ """
22
+
23
+ import warnings
24
+ from dataclasses import dataclass, field
25
+ from typing import Any, Callable, Dict, List, Optional, Tuple
26
+
27
+ import numpy as np
28
+ import pandas as pd
29
+ from scipy import stats
30
+
31
# Sentinel upper bound on sample size, returned when an effect is too small to
# detect (for example a zero effect, or one that is tiny relative to the noise).
# The value is the largest signed 32-bit integer, 2**31 - 1.
MAX_SAMPLE_SIZE = 2_147_483_647
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Estimator registry — maps estimator class names to DGP/fit/extract profiles
38
+ # ---------------------------------------------------------------------------
39
+
40
+
41
@dataclass
class _EstimatorProfile:
    """Recipe for running power simulations with one estimator class (internal).

    A profile bundles the synthetic-data generator used by default, the two
    adapters that turn simulation settings into generator / ``fit()`` keyword
    arguments, and the function that reads the headline numbers off a fitted
    result.
    """

    # Data-generating function used when no custom generator is supplied.
    default_dgp: Callable[..., Any]
    # (n_units, n_periods, treatment_effect, treatment_fraction,
    #  treatment_period, sigma) -> keyword arguments for ``default_dgp``.
    dgp_kwargs_builder: Callable[..., Dict[str, Any]]
    # (data, n_units, n_periods, treatment_period) -> keyword arguments for fit().
    fit_kwargs_builder: Callable[..., Dict[str, Any]]
    # fitted result -> (estimate, standard error, p-value, confidence interval).
    result_extractor: Callable[..., Tuple[float, float, float, Tuple[float, float]]]
    # Per-estimator minimum sample size; presumably a floor used by
    # sample-size searches — confirm against the callers.
    min_n: int = 20
50
+
51
+
52
+ # -- DGP kwargs adapters -----------------------------------------------------
53
+
54
+
55
def _basic_dgp_kwargs(
    n_units: int,
    n_periods: int,
    treatment_effect: float,
    treatment_fraction: float,
    treatment_period: int,
    sigma: float,
) -> Dict[str, Any]:
    """Map simulation settings onto the basic DiD generator's arguments.

    Every setting is forwarded under its own name except ``sigma``, which the
    generator receives as ``noise_sd``.
    """
    kwargs: Dict[str, Any] = {
        "n_units": n_units,
        "n_periods": n_periods,
        "treatment_effect": treatment_effect,
        "treatment_fraction": treatment_fraction,
        "treatment_period": treatment_period,
    }
    kwargs["noise_sd"] = sigma
    return kwargs
71
+
72
+
73
def _staggered_dgp_kwargs(
    n_units: int,
    n_periods: int,
    treatment_effect: float,
    treatment_fraction: float,
    treatment_period: int,
    sigma: float,
) -> Dict[str, Any]:
    """Map simulation settings onto the staggered-adoption generator.

    The default design is a single treatment cohort starting at
    ``treatment_period``, with dynamic effects disabled. The never-treated
    share is the complement of ``treatment_fraction``, and ``sigma`` is
    passed as ``noise_sd``.
    """
    never_treated_share = 1 - treatment_fraction
    single_cohort = [treatment_period]
    return {
        "n_units": n_units,
        "n_periods": n_periods,
        "treatment_effect": treatment_effect,
        "never_treated_frac": never_treated_share,
        "cohort_periods": single_cohort,
        "dynamic_effects": False,
        "noise_sd": sigma,
    }
90
+
91
+
92
def _factor_dgp_kwargs(
    n_units: int,
    n_periods: int,
    treatment_effect: float,
    treatment_fraction: float,
    treatment_period: int,
    sigma: float,
) -> Dict[str, Any]:
    """Map simulation settings onto the factor-model generator.

    The factor-model generator is parameterized by pre/post period counts
    rather than a treatment start. ``treatment_period`` therefore becomes
    ``n_pre``, and the remaining periods become ``n_post``. The treated count
    is ``n_units * treatment_fraction`` truncated toward zero, but never
    below 1.
    """
    treated_count = int(n_units * treatment_fraction)
    if treated_count < 1:
        treated_count = 1
    return {
        "n_units": n_units,
        "n_pre": treatment_period,
        "n_post": n_periods - treatment_period,
        "n_treated": treated_count,
        "treatment_effect": treatment_effect,
        "noise_sd": sigma,
    }
110
+
111
+
112
def _ddd_dgp_kwargs(
    n_units: int,
    n_periods: int,
    treatment_effect: float,
    treatment_fraction: float,
    treatment_period: int,
    sigma: float,
) -> Dict[str, Any]:
    """Map simulation settings onto the triple-difference generator.

    The DDD generator takes a per-cell count rather than a total, so
    ``n_units`` is floor-divided across 8 cells with at least 2 units per
    cell. ``n_periods``, ``treatment_fraction`` and ``treatment_period`` are
    accepted for interface compatibility but unused.
    """
    per_cell = n_units // 8
    return {
        "n_per_cell": per_cell if per_cell > 2 else 2,
        "treatment_effect": treatment_effect,
        "noise_sd": sigma,
    }
125
+
126
+
127
+ # -- Fit kwargs builders ------------------------------------------------------
128
+
129
+
130
def _basic_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping for the basic 2x2 DiD ``fit()`` call.

    Uses the ``treated`` column as the treatment indicator and ``post`` as
    the time variable. The data and simulation settings are not consulted.
    """
    mapping: Dict[str, Any] = {"outcome": "outcome"}
    mapping["treatment"] = "treated"
    mapping["time"] = "post"
    return mapping
137
+
138
+
139
def _twfe_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping for the two-way fixed effects ``fit()`` call.

    Same as the basic 2x2 mapping (``treated`` / ``post``), plus ``unit`` as
    the unit identifier. The data and simulation settings are not consulted.
    """
    return {
        "outcome": "outcome",
        "treatment": "treated",
        "time": "post",
        "unit": "unit",
    }
146
+
147
+
148
def _multiperiod_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping and post-period list for the multi-period DiD ``fit()``.

    Periods ``treatment_period`` through ``n_periods - 1`` are declared
    post-treatment. This assumes the generator labels periods
    0..n_periods-1 — TODO confirm against the generator.
    """
    post = [*range(treatment_period, n_periods)]
    return {
        "outcome": "outcome",
        "treatment": "treated",
        "time": "period",
        "post_periods": post,
    }
160
+
161
+
162
def _staggered_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping shared by the staggered-adoption estimators' ``fit()``.

    Treatment timing comes from the ``first_treat`` column rather than a
    treatment indicator. The data and simulation settings are not consulted.
    """
    columns = ("outcome", "unit", "period", "first_treat")
    return dict(zip(("outcome", "unit", "time", "first_treat"), columns))
169
+
170
+
171
def _ddd_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping for the triple-difference ``fit()`` call.

    Identifies the three factors ``group``, ``partition`` and ``time``. The
    data and simulation settings are not consulted.
    """
    return {
        "outcome": "outcome",
        "group": "group",
        "partition": "partition",
        "time": "time",
    }
178
+
179
+
180
def _trop_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping for the TROP ``fit()`` call.

    Uses the ``treated`` column as the treatment indicator and ``period`` as
    time. The data and simulation settings are not consulted.
    """
    return {
        "outcome": "outcome",
        "treatment": "treated",
        "unit": "unit",
        "time": "period",
    }
187
+
188
+
189
def _sdid_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
) -> Dict[str, Any]:
    """Column mapping and post-period list for the SyntheticDiD ``fit()`` call.

    The post-treatment periods are read from the ``period`` values actually
    present in *data*: the distinct values, sorted ascending, that are at or
    after ``treatment_period``. Treatment is read from the ``treat`` column.
    """
    ordered_periods = sorted(data["period"].unique())
    post = [t for t in ordered_periods if t >= treatment_period]
    return {
        "outcome": "outcome",
        "treatment": "treat",
        "unit": "unit",
        "time": "period",
        "post_periods": post,
    }
204
+
205
+
206
+ # -- Result extractors --------------------------------------------------------
207
+
208
+
209
def _extract_simple(result: Any) -> Tuple[float, float, float, Tuple[float, float]]:
    """Return ``(att, se, p_value, conf_int)`` read directly off *result*."""
    estimate, std_err = result.att, result.se
    p_val, interval = result.p_value, result.conf_int
    return estimate, std_err, p_val, interval
211
+
212
+
213
def _extract_multiperiod(
    result: Any,
) -> Tuple[float, float, float, Tuple[float, float]]:
    """Return the period-averaged ``(att, se, p_value, conf_int)`` of *result*.

    Reads the ``avg_*`` attributes rather than the per-period estimates.
    """
    estimate, std_err = result.avg_att, result.avg_se
    p_val, interval = result.avg_p_value, result.avg_conf_int
    return estimate, std_err, p_val, interval
217
+
218
+
219
def _extract_staggered(
    result: Any,
) -> Tuple[float, float, float, Tuple[float, float]]:
    """Return the overall ``(att, se, p_value, conf_int)`` of a staggered result.

    Staggered result classes spell the inference attributes differently
    (``overall_se`` vs ``overall_att_se`` and so on). For each quantity, the
    first spelling that exists and is not ``None`` is used. A quantity with
    no usable spelling falls back to NaN, or to ``(nan, nan)`` for the
    interval. ``overall_att`` itself is required.
    """
    missing = float("nan")

    def lookup(candidates: Tuple[str, ...], fallback: Any) -> Any:
        # Walk the alternative attribute names in priority order.
        for attr_name in candidates:
            found = getattr(result, attr_name, None)
            if found is not None:
                return found
        return fallback

    estimate = result.overall_att
    std_err = lookup(("overall_se", "overall_att_se"), missing)
    p_val = lookup(("overall_p_value", "overall_att_p_value"), missing)
    interval = lookup(("overall_conf_int", "overall_att_ci"), (missing, missing))
    return estimate, std_err, p_val, interval
238
+
239
+
240
# Generator keyword arguments that simulate_power() derives from its own public
# parameters. Letting data_generator_kwargs override any of them would make the
# simulated data disagree with what the result object reports.
_PROTECTED_DGP_KEYS = frozenset(
    (
        # reported as the true effect in results; MDE search variable
        "treatment_effect",
        # set from the ``sigma`` parameter
        "noise_sd",
        # sample-size search variable
        "n_units",
        "n_periods",
        "treatment_fraction",
        "treatment_period",
        # factor-model generators: n_pre comes from treatment_period and
        # n_post from n_periods - treatment_period
        "n_pre",
        "n_post",
    )
)
254
+
255
+
256
# -- Staggered DGP compatibility check ----------------------------------------

# Estimator class names that use the staggered-adoption default DGP.
_STAGGERED_ESTIMATORS = frozenset(
    (
        "CallawaySantAnna",
        "SunAbraham",
        "ImputationDiD",
        "TwoStageDiD",
        "StackedDiD",
        "EfficientDiD",
    )
)


def _check_staggered_dgp_compat(
    estimator: Any,
    data_generator_kwargs: Optional[Dict[str, Any]],
) -> None:
    """Warn if a staggered estimator's settings don't match the default DGP.

    The default staggered generator produces a single treatment cohort with
    never-treated controls and no anticipatory effects. This function
    inspects the estimator (by class name plus its ``control_group``,
    ``anticipation`` and, for StackedDiD, ``clean_control`` attributes) and
    any ``cohort_periods`` override. Every mismatch found is listed in a
    single ``UserWarning``. Estimators outside ``_STAGGERED_ESTIMATORS`` are
    ignored.
    """
    est_name = type(estimator).__name__
    if est_name not in _STAGGERED_ESTIMATORS:
        return

    cohorts = (data_generator_kwargs or {}).get("cohort_periods")
    # Only an override with at least two distinct cohorts gives the DGP the
    # structure that not-yet-treated / strict clean controls rely on.
    multi_cohort = False if cohorts is None else len(set(cohorts)) >= 2

    problems: List[str] = []

    # control_group="not_yet_treated" (CS, SA)
    control_group = getattr(estimator, "control_group", "never_treated")
    if control_group == "not_yet_treated" and not multi_cohort:
        problems.append(
            f' - {est_name} has control_group="not_yet_treated" but the default '
            "DGP generates a single treatment cohort with never-treated "
            "controls. Power may not reflect the intended not-yet-treated "
            "design.\n"
            " Fix: pass data_generator_kwargs="
            '{"cohort_periods": [2, 4], "never_treated_frac": 0.0} '
            "(or a custom data_generator)."
        )

    # anticipation > 0 (any staggered estimator)
    anticipation = getattr(estimator, "anticipation", 0)
    if anticipation > 0:
        problems.append(
            f" - {est_name} has anticipation={anticipation} but the default DGP does "
            "not model anticipatory effects. The estimator will look for "
            f"treatment effects {anticipation} period(s) before the DGP generates "
            "them, biasing power estimates.\n"
            " Fix: supply a custom data_generator that shifts the "
            "effect onset."
        )

    # clean_control="strict" only matters for StackedDiD
    if est_name == "StackedDiD":
        clean_control = getattr(estimator, "clean_control", "not_yet_treated")
        if clean_control == "strict" and not multi_cohort:
            problems.append(
                ' - StackedDiD has clean_control="strict" but the default '
                "single-cohort DGP makes strict controls equivalent to "
                "never-treated controls.\n"
                " Fix: pass data_generator_kwargs="
                '{"cohort_periods": [2, 4]} '
                "to test true strict clean-control behavior."
            )

    if not problems:
        return

    header = (
        f"Staggered power DGP mismatch for {est_name}. The default "
        "single-cohort DGP may not match the estimator "
        "configuration:\n"
    )
    warnings.warn(header + "\n".join(problems), UserWarning, stacklevel=2)
329
+
330
+
331
def _ddd_effective_n(
    n_units: int, data_generator_kwargs: Optional[Dict[str, Any]]
) -> Optional[int]:
    """Return the DDD sample size actually simulated, if it differs from *n_units*.

    The DDD generator works in 8 equal cells, so the realized total is
    ``8 * n_per_cell``. ``n_per_cell`` is either a user-supplied override or
    ``max(2, n_units // 8)``. Returns ``None`` when the total equals
    *n_units*.
    """
    overrides = data_generator_kwargs or {}
    if "n_per_cell" in overrides:
        per_cell = overrides["n_per_cell"]
    else:
        per_cell = max(2, n_units // 8)
    effective = per_cell * 8
    return None if effective == n_units else effective


def _check_ddd_dgp_compat(
    n_units: int,
    n_periods: int,
    treatment_fraction: float,
    treatment_period: int,
    data_generator_kwargs: Optional[Dict[str, Any]],
) -> None:
    """Warn when simulation inputs don't fit TripleDifference's fixed 2×2×2 DGP.

    ``n_periods``, ``treatment_period`` and ``treatment_fraction`` are
    ignored by the DDD generator, and ``n_units`` is rounded to whole cells.
    Any non-default value, and any rounding, is reported in a single
    ``UserWarning``.
    """
    notes: List[str] = []

    # The generator always simulates two periods (pre/post) and treats the second.
    if n_periods != 2:
        notes.append(
            f"n_periods={n_periods} is ignored (DDD uses a fixed 2-period design: pre/post)"
        )
    if treatment_period != 1:
        notes.append(
            f"treatment_period={treatment_period} is ignored (DDD "
            "always treats in the second period)"
        )

    # Half of the balanced 2×2×2 cells are treated by construction.
    if treatment_fraction != 0.5:
        notes.append(
            f"treatment_fraction={treatment_fraction} is ignored "
            "(DDD uses a balanced 2×2×2 factorial where 50% of "
            "groups are treated)"
        )

    effective_n = _ddd_effective_n(n_units, data_generator_kwargs)
    if effective_n is not None:
        notes.append(
            f"effective sample size is {effective_n} "
            f"(n_per_cell={effective_n // 8} × 8 cells), "
            f"not the requested n_units={n_units}"
        )

    if not notes:
        return

    message = (
        "TripleDifference uses a fixed 2×2×2 factorial DGP "
        "(group × partition × time). "
        + "; ".join(notes)
        + ". Pass a custom data_generator for non-standard DDD designs."
    )
    warnings.warn(message, UserWarning, stacklevel=2)
391
+
392
+
393
def _check_sdid_placebo_data(
    data: pd.DataFrame,
    estimator: Any,
    est_kwargs: Dict[str, Any],
) -> None:
    """Check SyntheticDiD placebo feasibility on realized data.

    This catches infeasible designs on the custom-DGP path. There, the
    pre-generation check (which uses ``n_units * treatment_fraction``)
    cannot run because treatment allocation is determined by the DGP.

    A unit counts as treated if its treatment indicator is set in any of its
    rows. The count is therefore right whether the DGP emits a time-invariant
    group indicator or one that switches on only in post-treatment periods.

    Parameters
    ----------
    data : pd.DataFrame
        Simulated panel produced by the data generator.
    estimator : Any
        The SyntheticDiD instance. Only ``variance_method`` is read
        (treated as ``"placebo"`` when absent).
    est_kwargs : dict
        ``fit()`` keyword arguments. ``treatment`` and ``unit`` name the
        columns to inspect (defaults ``"treat"`` and ``"unit"``).

    Raises
    ------
    ValueError
        If placebo variance is requested and the data has no more control
        units than treated units.
    """
    if getattr(estimator, "variance_method", "placebo") != "placebo":
        return

    treat_col = est_kwargs.get("treatment", "treat")
    unit_col = est_kwargs.get("unit", "unit")

    if treat_col not in data.columns or unit_col not in data.columns:
        return  # fit will fail with a more specific error

    # Per-unit max() means "ever treated". It equals first() for a
    # time-invariant indicator. first() would instead read the pre-period
    # value (0) of an absorbing treat×post indicator and report no treated
    # units, letting an infeasible design slip through.
    unit_treat = data.groupby(unit_col)[treat_col].max()
    n_treated = int(unit_treat.sum())
    n_control = len(unit_treat) - n_treated

    if n_control <= n_treated:
        raise ValueError(
            f"SyntheticDiD placebo variance requires more control than "
            f"treated units, but the generated data has n_control={n_control}, "
            f"n_treated={n_treated}. Either adjust your data_generator so that "
            f"n_control > n_treated, or use "
            f"SyntheticDiD(variance_method='bootstrap')."
        )
426
+
427
+
428
# -- Registry construction (deferred to avoid import-time cost) ---------------

# Module-level cache for _get_registry(); stays None until the first call.
_ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None


def _build_registry() -> Dict[str, _EstimatorProfile]:
    """Construct the estimator-name -> profile mapping.

    The generator functions are imported here, not at module load, so that
    importing this module does not pull in ``diff_diff.prep``.
    """
    from diff_diff.prep import (
        generate_ddd_data,
        generate_did_data,
        generate_factor_data,
        generate_staggered_data,
    )

    def profile(dgp, dgp_kwargs, fit_kwargs, extractor, min_n):
        # Small positional shorthand for the keyword-heavy constructor.
        return _EstimatorProfile(
            default_dgp=dgp,
            dgp_kwargs_builder=dgp_kwargs,
            fit_kwargs_builder=fit_kwargs,
            result_extractor=extractor,
            min_n=min_n,
        )

    # --- Basic DiD group ---
    registry: Dict[str, _EstimatorProfile] = {
        "DifferenceInDifferences": profile(
            generate_did_data, _basic_dgp_kwargs, _basic_fit_kwargs, _extract_simple, 20
        ),
        "TwoWayFixedEffects": profile(
            generate_did_data, _basic_dgp_kwargs, _twfe_fit_kwargs, _extract_simple, 20
        ),
        "MultiPeriodDiD": profile(
            generate_did_data,
            _basic_dgp_kwargs,
            _multiperiod_fit_kwargs,
            _extract_multiperiod,
            20,
        ),
    }

    # --- Staggered group: identical profiles, inserted in a fixed order ---
    for staggered_name in (
        "CallawaySantAnna",
        "SunAbraham",
        "ImputationDiD",
        "TwoStageDiD",
        "StackedDiD",
        "EfficientDiD",
    ):
        registry[staggered_name] = profile(
            generate_staggered_data,
            _staggered_dgp_kwargs,
            _staggered_fit_kwargs,
            _extract_staggered,
            40,
        )

    # --- Factor model group ---
    registry["TROP"] = profile(
        generate_factor_data, _factor_dgp_kwargs, _trop_fit_kwargs, _extract_simple, 30
    )
    registry["SyntheticDiD"] = profile(
        generate_factor_data, _factor_dgp_kwargs, _sdid_fit_kwargs, _extract_simple, 30
    )

    # --- Triple difference ---
    registry["TripleDifference"] = profile(
        generate_ddd_data, _ddd_dgp_kwargs, _ddd_fit_kwargs, _extract_simple, 64
    )
    return registry


def _get_registry() -> Dict[str, _EstimatorProfile]:
    """Lazily build and return the estimator registry.

    The first call builds the mapping and stores it in the module-level
    ``_ESTIMATOR_REGISTRY``. Later calls return that same dict.
    """
    global _ESTIMATOR_REGISTRY  # noqa: PLW0603
    if _ESTIMATOR_REGISTRY is None:
        _ESTIMATOR_REGISTRY = _build_registry()
    return _ESTIMATOR_REGISTRY
537
+
538
+
539
@dataclass
class PowerResults:
    """
    Outcome of an analytical power calculation.

    Attributes
    ----------
    power : float
        Statistical power (probability of rejecting H0 when effect exists).
    mde : float
        Minimum detectable effect size.
    required_n : int
        Required total sample size (treated + control).
    effect_size : float
        Effect size used in calculation.
    alpha : float
        Significance level.
    alternative : str
        Alternative hypothesis ('two-sided', 'greater', 'less').
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_pre : int
        Number of pre-treatment periods.
    n_post : int
        Number of post-treatment periods.
    sigma : float
        Residual standard deviation.
    rho : float
        Intra-cluster correlation (for panel data). Defaults to 0.0.
    design : str
        Study design type ('basic_did', 'panel', 'staggered').
        Defaults to 'basic_did'.
    """

    power: float
    mde: float
    required_n: int
    effect_size: float
    alpha: float
    alternative: str
    n_treated: int
    n_control: int
    n_pre: int
    n_post: int
    sigma: float
    rho: float = 0.0
    design: str = "basic_did"

    def __repr__(self) -> str:
        """One-line form showing power, MDE and required sample size."""
        parts = [
            f"power={self.power:.3f}",
            f"mde={self.mde:.4f}",
            f"required_n={self.required_n}",
        ]
        return "PowerResults(" + ", ".join(parts) + ")"

    def summary(self) -> str:
        """
        Render the results as a fixed-width (60-column) text report.

        Returns
        -------
        str
            Multi-line report with design, sample-size, variance and
            power sections.
        """
        width = 60
        rule_major = "=" * width
        rule_minor = "-" * width

        def banner(title: str) -> List[str]:
            # Blank spacer, then a centered title framed by thin rules.
            return ["", rule_minor, title.center(width), rule_minor]

        def row(label: str, value: Any) -> str:
            return f"{label:<30} {value}"

        out: List[str] = [
            rule_major,
            "Power Analysis for Difference-in-Differences".center(width),
            rule_major,
            "",
            row("Design:", self.design),
            row("Significance level (alpha):", f"{self.alpha:.3f}"),
            row("Alternative hypothesis:", self.alternative),
        ]
        out += banner("Sample Size")
        out += [
            row("Treated units:", f"{self.n_treated:>10}"),
            row("Control units:", f"{self.n_control:>10}"),
            row("Total units:", f"{self.n_treated + self.n_control:>10}"),
            row("Pre-treatment periods:", f"{self.n_pre:>10}"),
            row("Post-treatment periods:", f"{self.n_post:>10}"),
        ]
        out += banner("Variance Parameters")
        out += [
            row("Residual SD (sigma):", f"{self.sigma:>10.4f}"),
            row("Intra-cluster correlation:", f"{self.rho:>10.4f}"),
        ]
        out += banner("Power Analysis Results")
        out += [
            row("Effect size:", f"{self.effect_size:>10.4f}"),
            row("Power:", f"{self.power:>10.1%}"),
            row("Minimum detectable effect:", f"{self.mde:>10.4f}"),
            row("Required sample size:", f"{self.required_n:>10}"),
            rule_major,
        ]
        return "\n".join(out)

    def print_summary(self) -> None:
        """Write the text report from ``summary()`` to stdout."""
        report = self.summary()
        print(report)

    def to_dict(self) -> Dict[str, Any]:
        """
        Return every field as a plain dictionary, in declaration order.

        Returns
        -------
        Dict[str, Any]
            Mapping of field name to value.
        """
        keys = (
            "power",
            "mde",
            "required_n",
            "effect_size",
            "alpha",
            "alternative",
            "n_treated",
            "n_control",
            "n_pre",
            "n_post",
            "sigma",
            "rho",
            "design",
        )
        return {key: getattr(self, key) for key in keys}

    def to_dataframe(self) -> pd.DataFrame:
        """
        Return the results as a single-row DataFrame.

        Returns
        -------
        pd.DataFrame
            One row whose columns are the keys of ``to_dict()``.
        """
        record = self.to_dict()
        return pd.DataFrame([record])
678
+
679
+
680
@dataclass
class SimulationPowerResults:
    """
    Outcome of a simulation-based (Monte Carlo) power analysis.

    Attributes
    ----------
    power : float
        Estimated power (proportion of simulations rejecting H0).
    power_se : float
        Standard error of power estimate.
    power_ci : Tuple[float, float]
        Confidence interval for power estimate.
    rejection_rate : float
        Proportion of simulations with p-value < alpha.
    mean_estimate : float
        Mean treatment effect estimate across simulations.
    std_estimate : float
        Standard deviation of estimates across simulations.
    mean_se : float
        Mean standard error across simulations.
    coverage : float
        Proportion of CIs containing true effect.
    n_simulations : int
        Number of simulations performed.
    effect_sizes : List[float]
        Effect sizes tested (if multiple).
    powers : List[float]
        Power at each effect size (if multiple).
    true_effect : float
        True treatment effect used in simulation.
    alpha : float
        Significance level.
    estimator_name : str
        Name of the estimator used.
    bias : float
        ``mean_estimate - true_effect``; computed in ``__post_init__``,
        not a constructor argument.
    rmse : float
        ``sqrt(bias**2 + std_estimate**2)``; computed in ``__post_init__``,
        not a constructor argument.
    simulation_results : list of dict or None
        Optional per-simulation records (omitted from the dataclass repr).
    effective_n_units : int or None
        Effective sample size when it differs from the requested ``n_units``
        (e.g., due to DDD grid rounding). ``None`` when no rounding occurred.
    """

    power: float
    power_se: float
    power_ci: Tuple[float, float]
    rejection_rate: float
    mean_estimate: float
    std_estimate: float
    mean_se: float
    coverage: float
    n_simulations: int
    effect_sizes: List[float]
    powers: List[float]
    true_effect: float
    alpha: float
    estimator_name: str
    bias: float = field(init=False)
    rmse: float = field(init=False)
    simulation_results: Optional[List[Dict[str, Any]]] = field(default=None, repr=False)
    effective_n_units: Optional[int] = None

    def __post_init__(self):
        """Fill in ``bias`` and ``rmse`` from the supplied summary statistics."""
        bias = self.mean_estimate - self.true_effect
        self.bias = bias
        self.rmse = np.sqrt(bias**2 + self.std_estimate**2)

    def __repr__(self) -> str:
        """One-line form: power, its interval, and the simulation count."""
        lo, hi = self.power_ci[0], self.power_ci[1]
        return (
            f"SimulationPowerResults(power={self.power:.3f} "
            f"[{lo:.3f}, {hi:.3f}], n_simulations={self.n_simulations})"
        )

    def summary(self) -> str:
        """
        Render the results as a fixed-width (65-column) text report.

        Returns
        -------
        str
            Multi-line report with setup, power and estimation-performance
            sections. An extra "Effective sample size" line appears when
            ``effective_n_units`` is set.
        """
        width = 65
        thick = "=" * width
        thin = "-" * width

        def heading(title: str) -> List[str]:
            # Blank spacer, then a centered title framed by thin rules.
            return ["", thin, title.center(width), thin]

        def kv(label: str, value: Any) -> str:
            return f"{label:<35} {value}"

        ci_lo, ci_hi = self.power_ci[0], self.power_ci[1]
        out: List[str] = [
            thick,
            "Simulation-Based Power Analysis Results".center(width),
            thick,
            "",
            kv("Estimator:", self.estimator_name),
            kv("Number of simulations:", self.n_simulations),
            kv("True treatment effect:", f"{self.true_effect:.4f}"),
            kv("Significance level (alpha):", f"{self.alpha:.3f}"),
        ]
        out += heading("Power Estimates")
        out += [
            kv("Power (rejection rate):", f"{self.power:.1%}"),
            kv("Standard error:", f"{self.power_se:.4f}"),
            kv("95% CI:", f"[{ci_lo:.3f}, {ci_hi:.3f}]"),
        ]
        out += heading("Estimation Performance")
        out += [
            kv("Mean estimate:", f"{self.mean_estimate:.4f}"),
            kv("Bias:", f"{self.bias:.4f}"),
            kv("Std. deviation of estimates:", f"{self.std_estimate:.4f}"),
            kv("RMSE:", f"{self.rmse:.4f}"),
            kv("Mean standard error:", f"{self.mean_se:.4f}"),
            kv("Coverage (CI contains true):", f"{self.coverage:.1%}"),
        ]
        if self.effective_n_units is not None:
            out.append(
                kv("Effective sample size:", f"{self.effective_n_units} (DDD grid-rounded)")
            )
        out.append(thick)
        return "\n".join(out)

    def print_summary(self) -> None:
        """Write the text report from ``summary()`` to stdout."""
        report = self.summary()
        print(report)

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the scalar results as a flat dictionary.

        ``power_ci`` is split into ``power_ci_lower`` / ``power_ci_upper``.
        The list-valued fields (``effect_sizes``, ``powers``,
        ``simulation_results``) are left out.

        Returns
        -------
        Dict[str, Any]
            Mapping of result name to value.
        """
        return dict(
            power=self.power,
            power_se=self.power_se,
            power_ci_lower=self.power_ci[0],
            power_ci_upper=self.power_ci[1],
            rejection_rate=self.rejection_rate,
            mean_estimate=self.mean_estimate,
            std_estimate=self.std_estimate,
            bias=self.bias,
            rmse=self.rmse,
            mean_se=self.mean_se,
            coverage=self.coverage,
            n_simulations=self.n_simulations,
            true_effect=self.true_effect,
            alpha=self.alpha,
            estimator_name=self.estimator_name,
            effective_n_units=self.effective_n_units,
        )

    def to_dataframe(self) -> pd.DataFrame:
        """
        Return the scalar results as a single-row DataFrame.

        Returns
        -------
        pd.DataFrame
            One row whose columns are the keys of ``to_dict()``.
        """
        record = self.to_dict()
        return pd.DataFrame([record])

    def power_curve_df(self) -> pd.DataFrame:
        """
        Return the power curve as a two-column DataFrame.

        Returns
        -------
        pd.DataFrame
            Columns ``effect_size`` and ``power``, one row per tested effect.
        """
        curve = {"effect_size": self.effect_sizes, "power": self.powers}
        return pd.DataFrame(curve)
849
+
850
+
851
class PowerAnalysis:
    """
    Power analysis for difference-in-differences designs.

    Provides analytical power calculations for basic 2x2 DiD and panel DiD
    designs. For complex designs like staggered adoption, use simulate_power()
    instead.

    Parameters
    ----------
    alpha : float, default=0.05
        Significance level for hypothesis testing.
    power : float, default=0.80
        Target statistical power.
    alternative : str, default='two-sided'
        Alternative hypothesis: 'two-sided', 'greater', or 'less'.

    Examples
    --------
    Calculate minimum detectable effect:

    >>> from diff_diff import PowerAnalysis
    >>> pa = PowerAnalysis(alpha=0.05, power=0.80)
    >>> results = pa.mde(n_treated=50, n_control=50, sigma=1.0)
    >>> print(f"MDE: {results.mde:.3f}")

    Calculate required sample size:

    >>> results = pa.sample_size(effect_size=0.5, sigma=1.0)
    >>> print(f"Required N: {results.required_n}")

    Calculate power for given sample and effect:

    >>> results = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0)
    >>> print(f"Power: {results.power:.1%}")

    Notes
    -----
    The power calculations are based on the variance of the DiD estimator.
    The "basic_did" formula is used when ``n_pre + n_post <= 2`` and the
    "panel" formula otherwise.

    For basic 2x2 DiD (every unit observed once before and once after):
        Var(ATT) = sigma^2 * (1/n_treated_post + 1/n_treated_pre
                              + 1/n_control_post + 1/n_control_pre)

    For panel DiD with T = n_pre + n_post periods:
        Var(ATT) = sigma^2 * (1/N_treated + 1/N_control)
                   * (1 + (T-1)*rho) / T

    Where rho is the intra-cluster correlation coefficient and
    ``1 + (T-1)*rho`` is the Moulton design effect.

    References
    ----------
    Bloom, H. S. (1995). "Minimum Detectable Effects."
    Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
    """

    def __init__(
        self,
        alpha: float = 0.05,
        power: float = 0.80,
        alternative: str = "two-sided",
    ):
        if not 0 < alpha < 1:
            raise ValueError("alpha must be between 0 and 1")
        if not 0 < power < 1:
            raise ValueError("power must be between 0 and 1")
        if alternative not in ("two-sided", "greater", "less"):
            raise ValueError("alternative must be 'two-sided', 'greater', or 'less'")

        self.alpha = alpha
        # Stored as ``target_power`` because ``power`` is a method name.
        self.target_power = power
        self.alternative = alternative

    @staticmethod
    def _select_design(n_pre: int, n_post: int) -> str:
        """Return "panel" when there are more than two periods, else "basic_did"."""
        return "panel" if n_pre + n_post > 2 else "basic_did"

    def _get_critical_values(self) -> Tuple[float, float]:
        """
        Get z critical values for alpha and power.

        Returns
        -------
        tuple of float
            ``(z_alpha, z_beta)``. ``z_alpha`` uses ``alpha / 2`` for a
            two-sided test and ``alpha`` for one-sided tests; ``z_beta`` is
            the standard-normal quantile of the target power.
        """
        if self.alternative == "two-sided":
            z_alpha = stats.norm.ppf(1 - self.alpha / 2)
        else:
            z_alpha = stats.norm.ppf(1 - self.alpha)
        z_beta = stats.norm.ppf(self.target_power)
        return z_alpha, z_beta

    def _compute_variance(
        self,
        n_treated: int,
        n_control: int,
        n_pre: int,
        n_post: int,
        sigma: float,
        rho: float = 0.0,
        design: str = "basic_did",
    ) -> float:
        """
        Compute variance of the DiD estimator.

        Parameters
        ----------
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        n_pre : int
            Number of pre-treatment periods.
        n_post : int
            Number of post-treatment periods.
        sigma : float
            Residual standard deviation.
        rho : float
            Intra-cluster correlation (for panel data).
        design : str
            Study design type: "basic_did" or "panel".

        Returns
        -------
        float
            Variance of the DiD estimator.

        Raises
        ------
        ValueError
            If ``n_treated`` or ``n_control`` is not positive, or ``design``
            is unknown.
        """
        if n_treated <= 0 or n_control <= 0:
            raise ValueError(
                f"n_treated and n_control must be positive "
                f"(got n_treated={n_treated}, n_control={n_control})"
            )

        if design == "basic_did":
            # Every unit is observed once pre and once post, so each of the
            # four group-by-period cells contains all units of its group.
            n_t_pre = n_treated
            n_t_post = n_treated
            n_c_pre = n_control
            n_c_post = n_control

            variance = sigma**2 * (1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre)
        elif design == "panel":
            # Panel DiD with multiple periods; serial correlation enters via
            # the ICC-based design effect.
            # NOTE(review): at T=2 the basic formula gives
            # 2*sigma^2*(1/n_t + 1/n_c), whereas this branch at T=3, rho=0
            # gives sigma^2*(1/n_t + 1/n_c)/3 -- a larger drop than one extra
            # period alone would suggest. Confirm against Burlig et al. (2020).
            T = n_pre + n_post

            # Design effect for clustering
            design_effect = 1 + (T - 1) * rho

            # Base variance (as if independent)
            base_var = sigma**2 * (1 / n_treated + 1 / n_control)

            # Adjust for clustering (Moulton factor)
            variance = base_var * design_effect / T
        else:
            raise ValueError(f"Unknown design: {design}")

        return variance

    def power(
        self,
        effect_size: float,
        n_treated: int,
        n_control: int,
        sigma: float,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
    ) -> PowerResults:
        """
        Calculate statistical power for given effect size and sample.

        Parameters
        ----------
        effect_size : float
            Expected treatment effect size.
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        sigma : float
            Residual standard deviation.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Intra-cluster correlation for panel data.

        Returns
        -------
        PowerResults
            Power analysis results.

        Raises
        ------
        ValueError
            If ``n_treated`` or ``n_control`` is not positive.

        Examples
        --------
        >>> pa = PowerAnalysis()
        >>> results = pa.power(effect_size=2.0, n_treated=50, n_control=50, sigma=5.0)
        >>> print(f"Power: {results.power:.1%}")
        """
        design = self._select_design(n_pre, n_post)

        variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design)
        se = np.sqrt(variance)

        z_alpha, _ = self._get_critical_values()
        # Mean of the test statistic under the alternative.
        shift = effect_size / se

        if self.alternative == "two-sided":
            # Power = P(reject | effect) = P(|Z| > z_alpha | effect)
            power_val = (
                1
                - stats.norm.cdf(z_alpha - shift)
                + stats.norm.cdf(-z_alpha - shift)
            )
        elif self.alternative == "greater":
            power_val = 1 - stats.norm.cdf(z_alpha - shift)
        else:  # less
            power_val = stats.norm.cdf(-z_alpha - shift)

        # Also compute MDE and required N for reference
        mde = self._compute_mde_from_se(se)
        required_n = self._compute_required_n(
            effect_size, sigma, n_pre, n_post, rho, design, n_treated / (n_treated + n_control)
        )

        return PowerResults(
            power=power_val,
            mde=mde,
            required_n=required_n,
            effect_size=effect_size,
            alpha=self.alpha,
            alternative=self.alternative,
            n_treated=n_treated,
            n_control=n_control,
            n_pre=n_pre,
            n_post=n_post,
            sigma=sigma,
            rho=rho,
            design=design,
        )

    def _compute_mde_from_se(self, se: float) -> float:
        """Compute MDE given standard error: ``(z_alpha + z_beta) * se``."""
        z_alpha, z_beta = self._get_critical_values()
        return (z_alpha + z_beta) * se

    def mde(
        self,
        n_treated: int,
        n_control: int,
        sigma: float,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
    ) -> PowerResults:
        """
        Calculate minimum detectable effect given sample size.

        The MDE is the smallest effect size that can be detected with the
        specified power and significance level.

        Parameters
        ----------
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        sigma : float
            Residual standard deviation.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Intra-cluster correlation for panel data.

        Returns
        -------
        PowerResults
            Power analysis results including MDE.

        Raises
        ------
        ValueError
            If ``n_treated`` or ``n_control`` is not positive.

        Examples
        --------
        >>> pa = PowerAnalysis(power=0.80)
        >>> results = pa.mde(n_treated=100, n_control=100, sigma=10.0)
        >>> print(f"MDE: {results.mde:.2f}")
        """
        design = self._select_design(n_pre, n_post)

        variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design)
        se = np.sqrt(variance)

        mde = self._compute_mde_from_se(se)

        return PowerResults(
            power=self.target_power,
            mde=mde,
            required_n=n_treated + n_control,
            effect_size=mde,
            alpha=self.alpha,
            alternative=self.alternative,
            n_treated=n_treated,
            n_control=n_control,
            n_pre=n_pre,
            n_post=n_post,
            sigma=sigma,
            rho=rho,
            design=design,
        )

    def _compute_required_n(
        self,
        effect_size: float,
        sigma: float,
        n_pre: int,
        n_post: int,
        rho: float,
        design: str,
        treat_frac: float = 0.5,
    ) -> int:
        """
        Compute the total sample size needed to detect ``effect_size``.

        Inverts the variance formula of ``_compute_variance`` for an
        allocation in which a fraction ``treat_frac`` of units is treated.

        Returns
        -------
        int
            Required total number of units (at least 4), or
            ``MAX_SAMPLE_SIZE`` when ``effect_size`` is zero or the required
            size is infinite.
        """
        # Handle edge case of zero effect size
        if effect_size == 0:
            return MAX_SAMPLE_SIZE  # Can't detect zero effect

        z_alpha, z_beta = self._get_critical_values()

        if design == "basic_did":
            # Var = sigma^2 * (1/n_t + 1/n_t + 1/n_c + 1/n_c) = sigma^2 * (2/n_t + 2/n_c)
            # With n_t = N*p and n_c = N*(1-p):
            #   Var = 2 * sigma^2 / N * (1/p + 1/(1-p)) = 2 * sigma^2 / (N * p * (1-p))
            # Setting effect_size = (z_alpha + z_beta) * sqrt(Var) and solving for N:
            n_total = (
                2
                * sigma**2
                * (z_alpha + z_beta) ** 2
                / (effect_size**2 * treat_frac * (1 - treat_frac))
            )
        else:  # panel
            T = n_pre + n_post
            design_effect = 1 + (T - 1) * rho

            # Var = sigma^2 * (1/n_t + 1/n_c) * design_effect / T
            #     = sigma^2 * design_effect / (N * p * (1-p) * T)
            n_total = (
                2
                * sigma**2
                * (z_alpha + z_beta) ** 2
                * design_effect
                / (effect_size**2 * treat_frac * (1 - treat_frac) * T)
            )

        # Handle infinity case (extremely small effect)
        if np.isinf(n_total):
            return MAX_SAMPLE_SIZE

        return max(4, int(np.ceil(n_total)))  # At least 4 units

    def sample_size(
        self,
        effect_size: float,
        sigma: float,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        treat_frac: float = 0.5,
    ) -> PowerResults:
        """
        Calculate required sample size to detect given effect.

        Parameters
        ----------
        effect_size : float
            Treatment effect to detect.
        sigma : float
            Residual standard deviation.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Intra-cluster correlation for panel data.
        treat_frac : float, default=0.5
            Fraction of units assigned to treatment; must lie strictly
            between 0 and 1.

        Returns
        -------
        PowerResults
            Power analysis results including required sample size.

        Raises
        ------
        ValueError
            If ``treat_frac`` is not strictly between 0 and 1.

        Examples
        --------
        >>> pa = PowerAnalysis(power=0.80)
        >>> results = pa.sample_size(effect_size=5.0, sigma=10.0)
        >>> print(f"Required N: {results.required_n}")
        """
        # A degenerate allocation makes the variance infinite (or negative
        # for treat_frac > 1), which previously produced meaningless sizes.
        if not 0 < treat_frac < 1:
            raise ValueError("treat_frac must be between 0 and 1")

        design = self._select_design(n_pre, n_post)

        n_total = self._compute_required_n(
            effect_size, sigma, n_pre, n_post, rho, design, treat_frac
        )

        # Each arm needs at least two units.
        n_treated = max(2, int(np.ceil(n_total * treat_frac)))
        n_control = max(2, n_total - n_treated)
        n_total = n_treated + n_control

        # MDE actually achieved at the rounded allocation
        variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design)
        se = np.sqrt(variance)
        mde = self._compute_mde_from_se(se)

        return PowerResults(
            power=self.target_power,
            mde=mde,
            required_n=n_total,
            effect_size=effect_size,
            alpha=self.alpha,
            alternative=self.alternative,
            n_treated=n_treated,
            n_control=n_control,
            n_pre=n_pre,
            n_post=n_post,
            sigma=sigma,
            rho=rho,
            design=design,
        )

    def power_curve(
        self,
        n_treated: int,
        n_control: int,
        sigma: float,
        effect_sizes: Optional[List[float]] = None,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
    ) -> pd.DataFrame:
        """
        Compute power for a range of effect sizes.

        Parameters
        ----------
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        sigma : float
            Residual standard deviation.
        effect_sizes : list of float, optional
            Effect sizes to evaluate. If None, uses 50 evenly spaced values
            from 0 to 2.5*MDE.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Intra-cluster correlation.

        Returns
        -------
        pd.DataFrame
            DataFrame with columns 'effect_size' and 'power'.

        Examples
        --------
        >>> pa = PowerAnalysis()
        >>> curve = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
        >>> print(curve)
        """
        # First get MDE to determine default range
        mde_result = self.mde(n_treated, n_control, sigma, n_pre, n_post, rho)

        if effect_sizes is None:
            # Generate range from 0 to 2.5*MDE
            effect_sizes = np.linspace(0, 2.5 * mde_result.mde, 50).tolist()

        powers = [
            self.power(
                effect_size=es,
                n_treated=n_treated,
                n_control=n_control,
                sigma=sigma,
                n_pre=n_pre,
                n_post=n_post,
                rho=rho,
            ).power
            for es in effect_sizes
        ]

        return pd.DataFrame({"effect_size": effect_sizes, "power": powers})

    def sample_size_curve(
        self,
        effect_size: float,
        sigma: float,
        sample_sizes: Optional[List[int]] = None,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        treat_frac: float = 0.5,
    ) -> pd.DataFrame:
        """
        Compute power for a range of sample sizes.

        Parameters
        ----------
        effect_size : float
            Treatment effect size.
        sigma : float
            Residual standard deviation.
        sample_sizes : list of int, optional
            Total sample sizes to evaluate. If None, uses roughly 50 evenly
            spaced sizes from ``max(10, required_n // 4)`` up to
            ``2 * required_n`` (never below the lower bound, so the default
            range is never empty).
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Intra-cluster correlation.
        treat_frac : float, default=0.5
            Fraction assigned to treatment.

        Returns
        -------
        pd.DataFrame
            DataFrame with columns 'sample_size' and 'power'.

        Raises
        ------
        ValueError
            If ``treat_frac`` is not strictly between 0 and 1.
        """
        # Get required N to determine default range
        required = self.sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac)

        if sample_sizes is None:
            min_n = max(10, required.required_n // 4)
            # For large effects required_n can be as small as 4, which made
            # 2 * required_n < min_n and produced an empty range.
            max_n = max(min_n, required.required_n * 2)
            step = max(1, (max_n - min_n) // 50)
            sample_sizes = list(range(min_n, max_n + 1, step))

        powers = []
        for n in sample_sizes:
            n_treated = max(2, int(n * treat_frac))
            n_control = max(2, n - n_treated)
            result = self.power(
                effect_size=effect_size,
                n_treated=n_treated,
                n_control=n_control,
                sigma=sigma,
                n_pre=n_pre,
                n_post=n_post,
                rho=rho,
            )
            powers.append(result.power)

        return pd.DataFrame({"sample_size": sample_sizes, "power": powers})
1398
+
1399
+
1400
def simulate_power(
    estimator: Any,
    n_units: int = 100,
    n_periods: int = 4,
    treatment_effect: float = 5.0,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    sigma: float = 1.0,
    n_simulations: int = 500,
    alpha: float = 0.05,
    effect_sizes: Optional[List[float]] = None,
    seed: Optional[int] = None,
    data_generator: Optional[Callable] = None,
    data_generator_kwargs: Optional[Dict[str, Any]] = None,
    estimator_kwargs: Optional[Dict[str, Any]] = None,
    result_extractor: Optional[Callable] = None,
    progress: bool = True,
) -> SimulationPowerResults:
    """
    Estimate power using Monte Carlo simulation.

    This function simulates datasets with known treatment effects and estimates
    power as the fraction of simulations where the null hypothesis is rejected.
    Most built-in estimators are supported via an internal registry that selects
    the appropriate data-generating process and fit signature automatically.

    Parameters
    ----------
    estimator : estimator object
        DiD estimator to use (e.g., DifferenceInDifferences, CallawaySantAnna).
    n_units : int, default=100
        Number of units per simulation.
    n_periods : int, default=4
        Number of time periods.
    treatment_effect : float, default=5.0
        True treatment effect to simulate.
    treatment_fraction : float, default=0.5
        Fraction of units that are treated.
    treatment_period : int, default=2
        First post-treatment period (0-indexed).
    sigma : float, default=1.0
        Residual standard deviation (noise level).
    n_simulations : int, default=500
        Number of Monte Carlo simulations; must be at least 1.
    alpha : float, default=0.05
        Significance level for hypothesis tests.
    effect_sizes : list of float, optional
        Multiple effect sizes to evaluate for power curve.
        If None, uses only treatment_effect.
    seed : int, optional
        Random seed for reproducibility.
    data_generator : callable, optional
        Custom data generation function. When provided, bypasses the
        registry DGP and calls this function with the standard kwargs
        (n_units, n_periods, treatment_effect, etc.).
    data_generator_kwargs : dict, optional
        Additional keyword arguments for data generator.
    estimator_kwargs : dict, optional
        Additional keyword arguments for estimator.fit().
    result_extractor : callable, optional
        Custom function to extract results from the estimator output.
        Takes the estimator result object and returns a tuple of
        ``(att, se, p_value, conf_int)``. When provided it takes precedence
        over the registry's built-in extractor. Useful for unregistered
        estimators with non-standard result schemas.
    progress : bool, default=True
        Whether to print progress updates.

    Returns
    -------
    SimulationPowerResults
        Simulation-based power analysis results.

    Raises
    ------
    ValueError
        If ``n_simulations < 1``; if the estimator is not in the registry and
        no ``data_generator`` is given; if ``data_generator_kwargs`` collides
        with registry-managed keys; or if a SyntheticDiD placebo design has no
        more control than treated units.
    RuntimeError
        If every simulation for some effect size fails.

    Examples
    --------
    Basic power simulation:

    >>> from diff_diff import DifferenceInDifferences, simulate_power
    >>> did = DifferenceInDifferences()
    >>> results = simulate_power(
    ...     estimator=did,
    ...     n_units=100,
    ...     treatment_effect=5.0,
    ...     sigma=5.0,
    ...     n_simulations=500,
    ...     seed=42
    ... )
    >>> print(f"Power: {results.power:.1%}")

    Power curve over multiple effect sizes:

    >>> results = simulate_power(
    ...     estimator=did,
    ...     effect_sizes=[1.0, 2.0, 3.0, 5.0, 7.0],
    ...     n_simulations=200,
    ...     seed=42
    ... )
    >>> print(results.power_curve_df())

    With Callaway-Sant'Anna (auto-detected, no custom DGP needed):

    >>> from diff_diff import CallawaySantAnna
    >>> cs = CallawaySantAnna()
    >>> results = simulate_power(cs, n_simulations=200, seed=42)

    Notes
    -----
    The simulation approach:
    1. Generate data with known treatment effect
    2. Fit the estimator and record the p-value
    3. Repeat n_simulations times
    4. Power = fraction of simulations where p-value < alpha

    A replication whose fit or result extraction raises is counted as a
    failure and contributes nothing to the power estimate.

    References
    ----------
    Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
    """
    # Without this guard n_simulations=0 crashed later with ZeroDivisionError.
    if n_simulations < 1:
        raise ValueError(f"n_simulations must be at least 1 (got {n_simulations})")

    rng = np.random.default_rng(seed)

    estimator_name = type(estimator).__name__
    registry = _get_registry()
    profile = registry.get(estimator_name)

    # If no profile and no custom data_generator, raise
    if profile is None and data_generator is None:
        raise ValueError(
            f"Estimator '{estimator_name}' not in registry. "
            f"Provide a custom data_generator and estimator_kwargs "
            f"(the full dict of keyword arguments for estimator.fit(), "
            f"e.g. dict(outcome='y', treatment='treat', time='period'))."
        )

    # When a custom data_generator is provided, bypass registry DGP
    use_custom_dgp = data_generator is not None

    data_gen_kwargs = data_generator_kwargs or {}
    est_kwargs = estimator_kwargs or {}

    # SyntheticDiD placebo variance requires n_control > n_treated.
    # Check after merging data_generator_kwargs so overrides of n_treated
    # are accounted for.
    if estimator_name == "SyntheticDiD" and not use_custom_dgp:
        vm = getattr(estimator, "variance_method", "placebo")
        effective_n_treated = data_gen_kwargs.get(
            "n_treated", max(1, int(n_units * treatment_fraction))
        )
        n_control = n_units - effective_n_treated
        if vm == "placebo" and n_control <= effective_n_treated:
            raise ValueError(
                f"SyntheticDiD placebo variance requires more control than "
                f"treated units (got n_control={n_control}, "
                f"n_treated={effective_n_treated}). Either lower "
                f"treatment_fraction so that n_control > n_treated, or use "
                f"SyntheticDiD(variance_method='bootstrap')."
            )

    # Warn if staggered estimator settings don't match auto DGP
    if profile is not None and not use_custom_dgp:
        _check_staggered_dgp_compat(estimator, data_generator_kwargs)

    # Block registry-path collisions on search-critical keys
    if profile is not None and not use_custom_dgp and data_gen_kwargs:
        sample_dgp_keys = set(
            profile.dgp_kwargs_builder(
                n_units=n_units,
                n_periods=n_periods,
                treatment_effect=treatment_effect,
                treatment_fraction=treatment_fraction,
                treatment_period=treatment_period,
                sigma=sigma,
            ).keys()
        )
        collisions = _PROTECTED_DGP_KEYS & set(data_gen_kwargs) & sample_dgp_keys
        if collisions:
            raise ValueError(
                f"data_generator_kwargs contains keys that conflict with "
                f"registry-managed simulation inputs: {sorted(collisions)}. "
                f"These are controlled by simulate_power() parameters directly. "
                f"Use the corresponding function parameters instead, or pass a "
                f"custom data_generator to override the DGP entirely."
            )

    # Warn if DDD design inputs are silently ignored
    if estimator_name == "TripleDifference" and not use_custom_dgp:
        _check_ddd_dgp_compat(
            n_units,
            n_periods,
            treatment_fraction,
            treatment_period,
            data_generator_kwargs,
        )
        effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs)
    else:
        effective_n_units = None

    # Determine effect sizes to test
    if effect_sizes is None:
        effect_sizes = [treatment_effect]

    all_powers = []

    # The "primary" effect gets detailed per-simulation diagnostics: the one
    # matching treatment_effect, else the last effect size.
    if len(effect_sizes) == 1:
        primary_idx = 0
    else:
        primary_idx = -1
        for i, es in enumerate(effect_sizes):
            if np.isclose(es, treatment_effect):
                primary_idx = i
                break
        if primary_idx == -1:
            primary_idx = len(effect_sizes) - 1

    primary_effect = effect_sizes[primary_idx]

    # Initialize so they are always bound
    primary_estimates: List[float] = []
    primary_ses: List[float] = []
    primary_p_values: List[float] = []
    primary_rejections: List[bool] = []
    primary_ci_contains: List[bool] = []

    for effect_idx, effect in enumerate(effect_sizes):
        is_primary = effect_idx == primary_idx

        estimates: List[float] = []
        ses: List[float] = []
        p_values: List[float] = []
        rejections: List[bool] = []
        ci_contains_true: List[bool] = []
        n_failures = 0

        for sim in range(n_simulations):
            if progress and sim % 100 == 0 and sim > 0:
                pct = (sim + effect_idx * n_simulations) / (len(effect_sizes) * n_simulations)
                print(f" Simulation progress: {pct:.0%}")

            sim_seed = rng.integers(0, 2**31)

            # --- Generate data ---
            if use_custom_dgp:
                assert data_generator is not None
                data = data_generator(
                    n_units=n_units,
                    n_periods=n_periods,
                    treatment_effect=effect,
                    treatment_fraction=treatment_fraction,
                    treatment_period=treatment_period,
                    noise_sd=sigma,
                    seed=sim_seed,
                    **data_gen_kwargs,
                )
            else:
                assert profile is not None
                dgp_kwargs = profile.dgp_kwargs_builder(
                    n_units=n_units,
                    n_periods=n_periods,
                    treatment_effect=effect,
                    treatment_fraction=treatment_fraction,
                    treatment_period=treatment_period,
                    sigma=sigma,
                )
                dgp_kwargs.update(data_gen_kwargs)
                # The per-simulation seed always wins over a user-supplied one.
                dgp_kwargs.pop("seed", None)
                data = profile.default_dgp(seed=sim_seed, **dgp_kwargs)

            # Check SDID placebo feasibility once, on the first realized
            # dataset (runs for both registry and custom DGP paths).
            if effect_idx == 0 and sim == 0 and estimator_name == "SyntheticDiD":
                _check_sdid_placebo_data(data, estimator, est_kwargs)

            try:
                # --- Fit estimator ---
                # Registered estimators get their registry fit signature (for
                # either DGP path); otherwise est_kwargs is the full signature.
                if profile is not None:
                    fit_kwargs = profile.fit_kwargs_builder(
                        data, n_units, n_periods, treatment_period
                    )
                    fit_kwargs.update(est_kwargs)
                else:
                    fit_kwargs = dict(est_kwargs)

                result = estimator.fit(data, **fit_kwargs)

                # --- Extract results ---
                # An explicitly supplied extractor overrides the registry's.
                if result_extractor is not None:
                    att, se, p_val, ci = result_extractor(result)
                elif profile is not None:
                    att, se, p_val, ci = profile.result_extractor(result)
                else:
                    att = result.att if hasattr(result, "att") else result.avg_att
                    se = result.se if hasattr(result, "se") else result.avg_se
                    p_val = result.p_value if hasattr(result, "p_value") else result.avg_p_value
                    ci = result.conf_int if hasattr(result, "conf_int") else result.avg_conf_int

                # NaN p-value → treat as non-rejection
                rejected = bool(p_val < alpha) if not np.isnan(p_val) else False
                # Evaluate coverage before appending anything, so a malformed
                # CI fails the whole replication instead of leaving the
                # per-simulation lists misaligned.
                covered = bool(ci[0] <= effect <= ci[1])
            except Exception as e:
                n_failures += 1
                if progress:
                    print(f" Warning: Simulation {sim} failed: {e}")
                continue

            estimates.append(att)
            ses.append(se)
            p_values.append(p_val)
            rejections.append(rejected)
            ci_contains_true.append(covered)

        # Warn if too many simulations failed
        failure_rate = n_failures / n_simulations
        if failure_rate > 0.1:
            warnings.warn(
                f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) "
                f"failed for effect_size={effect}. "
                f"Check estimator and data generator.",
                UserWarning,
            )

        if len(estimates) == 0:
            raise RuntimeError("All simulations failed. Check estimator and data generator.")

        power_val = np.mean(rejections)
        all_powers.append(power_val)

        if is_primary:
            primary_estimates = estimates
            primary_ses = ses
            primary_p_values = p_values
            primary_rejections = rejections
            primary_ci_contains = ci_contains_true

    # Wald (normal-approximation) confidence interval for the primary power,
    # clipped to [0, 1].
    power_val = all_powers[primary_idx]
    n_valid = len(primary_rejections)
    power_se = np.sqrt(power_val * (1 - power_val) / n_valid)
    z = stats.norm.ppf(0.975)
    power_ci = (
        max(0.0, power_val - z * power_se),
        min(1.0, power_val + z * power_se),
    )

    mean_estimate = np.mean(primary_estimates)
    std_estimate = np.std(primary_estimates, ddof=1)
    mean_se = np.mean(primary_ses)
    coverage = np.mean(primary_ci_contains)

    return SimulationPowerResults(
        power=power_val,
        power_se=power_se,
        power_ci=power_ci,
        rejection_rate=power_val,
        mean_estimate=mean_estimate,
        std_estimate=std_estimate,
        mean_se=mean_se,
        coverage=coverage,
        n_simulations=n_valid,
        effect_sizes=effect_sizes,
        powers=all_powers,
        true_effect=primary_effect,
        alpha=alpha,
        estimator_name=estimator_name,
        simulation_results=[
            {"estimate": e, "se": s, "p_value": p, "rejected": r}
            for e, s, p, r in zip(
                primary_estimates,
                primary_ses,
                primary_p_values,
                primary_rejections,
            )
        ],
        effective_n_units=effective_n_units,
    )
1779
+
1780
+
1781
+ # ---------------------------------------------------------------------------
1782
+ # Simulation-based MDE and sample-size search
1783
+ # ---------------------------------------------------------------------------
1784
+
1785
+
1786
@dataclass
class SimulationMDEResults:
    """
    Outcome of a simulation-based search for the minimum detectable effect.

    Attributes
    ----------
    mde : float
        Smallest effect size found to reach the target power.
    power_at_mde : float
        Simulated power at ``mde``.
    target_power : float
        Power level the search aimed for.
    alpha : float
        Significance level of the tests.
    n_units : int
        Requested number of units.
    n_simulations_per_step : int
        Monte Carlo replications evaluated at each bisection step.
    n_steps : int
        How many bisection steps were taken.
    search_path : list of dict
        Per-step trace of ``{effect_size, power}`` pairs.
    estimator_name : str
        Class name of the estimator that was simulated.
    effective_n_units : int or None
        Realized sample size when it differs from ``n_units`` (e.g. from DDD
        grid rounding); ``None`` when no rounding happened.
    """

    mde: float
    power_at_mde: float
    target_power: float
    alpha: float
    n_units: int
    n_simulations_per_step: int
    n_steps: int
    search_path: List[Dict[str, float]]
    estimator_name: str
    effective_n_units: Optional[int] = None

    def __repr__(self) -> str:
        parts = (
            f"mde={self.mde:.4f}",
            f"power_at_mde={self.power_at_mde:.3f}",
            f"n_steps={self.n_steps}",
        )
        return "SimulationMDEResults(" + ", ".join(parts) + ")"

    def summary(self) -> str:
        """Return a fixed-width, human-readable report of the MDE search."""
        width = 65
        heavy = "=" * width
        light = "-" * width

        def row(label: str, value: object) -> str:
            # Labels are left-aligned in a 35-character column.
            return f"{label:<35} {value}"

        header = [
            heavy,
            "Simulation-Based MDE Results".center(width),
            heavy,
            "",
            row("Estimator:", self.estimator_name),
            row("Significance level (alpha):", f"{self.alpha:.3f}"),
            row("Target power:", f"{self.target_power:.1%}"),
            row("Sample size (n_units):", self.n_units),
        ]
        rounding_note = (
            []
            if self.effective_n_units is None
            else [
                row(
                    "Effective sample size:",
                    f"{self.effective_n_units} (DDD grid-rounded)",
                )
            ]
        )
        body = [
            row("Simulations per step:", self.n_simulations_per_step),
            "",
            light,
            "Search Results".center(width),
            light,
            row("Minimum detectable effect:", f"{self.mde:.4f}"),
            row("Power at MDE:", f"{self.power_at_mde:.1%}"),
            row("Bisection steps:", self.n_steps),
            heavy,
        ]
        return "\n".join(header + rounding_note + body)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar result fields as a plain dict (``search_path`` is omitted)."""
        field_order = (
            "mde",
            "power_at_mde",
            "target_power",
            "alpha",
            "n_units",
            "effective_n_units",
            "n_simulations_per_step",
            "n_steps",
            "estimator_name",
        )
        return {name: getattr(self, name) for name in field_order}

    def to_dataframe(self) -> pd.DataFrame:
        """Return the results of ``to_dict`` as a one-row DataFrame."""
        record = self.to_dict()
        return pd.DataFrame([record])
1880
+
1881
+
1882
@dataclass
class SimulationSampleSizeResults:
    """
    Outcome of a simulation-based search for the required sample size.

    Attributes
    ----------
    required_n : int
        Number of units found to reach the target power.
    power_at_n : float
        Simulated power at ``required_n``.
    target_power : float
        Power level the search aimed for.
    alpha : float
        Significance level.
    effect_size : float
        Treatment effect assumed throughout the search.
    n_simulations_per_step : int
        Monte Carlo replications run at every evaluated sample size.
    n_steps : int
        Number of search steps performed.
    search_path : list of dict
        Diagnostic trace with one ``{n_units, power}`` entry per step.
    estimator_name : str
        Name of the estimator used.
    effective_n_units : int or None
        Effective sample size when it differs from ``required_n``
        (e.g., due to DDD grid rounding). ``None`` when no rounding occurred
        or when the search already snapped to the estimator's grid.
    """

    required_n: int
    power_at_n: float
    target_power: float
    alpha: float
    effect_size: float
    n_simulations_per_step: int
    n_steps: int
    search_path: List[Dict[str, float]]
    estimator_name: str
    effective_n_units: Optional[int] = None

    def __repr__(self) -> str:
        # Compact one-line form; the full report is available via summary().
        head = f"SimulationSampleSizeResults(required_n={self.required_n}, "
        middle = f"power_at_n={self.power_at_n:.3f}, "
        tail = f"n_steps={self.n_steps})"
        return head + middle + tail

    def summary(self) -> str:
        """
        Render the search outcome as a fixed-width (65-column) text report.

        The "Effective sample size" row is included only when
        ``effective_n_units`` is not None.
        """
        heavy_rule = "=" * 65
        light_rule = "-" * 65
        report = [
            heavy_rule,
            "Simulation-Based Sample Size Results".center(65),
            heavy_rule,
            "",
            f"{'Estimator:':<35} {self.estimator_name}",
            f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
            f"{'Target power:':<35} {self.target_power:.1%}",
            f"{'Effect size:':<35} {self.effect_size:.4f}",
            f"{'Simulations per step:':<35} {self.n_simulations_per_step}",
            "",
            light_rule,
            "Search Results".center(65),
            light_rule,
            f"{'Required sample size:':<35} {self.required_n}",
            f"{'Power at required N:':<35} {self.power_at_n:.1%}",
            f"{'Bisection steps:':<35} {self.n_steps}",
        ]
        if self.effective_n_units is not None:
            report += [
                f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)"
            ]
        report += [heavy_rule]
        return "\n".join(report)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar result fields as a ``dict`` (``search_path`` is omitted)."""
        exported_fields = (
            "required_n",
            "power_at_n",
            "target_power",
            "alpha",
            "effect_size",
            "n_simulations_per_step",
            "n_steps",
            "estimator_name",
            "effective_n_units",
        )
        return {field_name: getattr(self, field_name) for field_name in exported_fields}

    def to_dataframe(self) -> pd.DataFrame:
        """Return the results as a one-row ``pandas.DataFrame`` built from ``to_dict()``."""
        single_record = self.to_dict()
        return pd.DataFrame([single_record])
1975
+
1976
+
1977
def simulate_mde(
    estimator: Any,
    n_units: int = 100,
    n_periods: int = 4,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    sigma: float = 1.0,
    n_simulations: int = 200,
    power: float = 0.80,
    alpha: float = 0.05,
    effect_range: Optional[Tuple[float, float]] = None,
    tol: float = 0.02,
    max_steps: int = 15,
    seed: Optional[int] = None,
    data_generator: Optional[Callable] = None,
    data_generator_kwargs: Optional[Dict[str, Any]] = None,
    estimator_kwargs: Optional[Dict[str, Any]] = None,
    result_extractor: Optional[Callable] = None,
    progress: bool = True,
) -> SimulationMDEResults:
    """
    Find the minimum detectable effect via simulation-based bisection search.

    Searches over effect sizes to find the smallest effect that achieves the
    target power, using ``simulate_power()`` at each step. Each evaluation
    draws its own sub-seed from a master generator seeded with ``seed``.

    Parameters
    ----------
    estimator : estimator object
        DiD estimator to use.
    n_units : int, default=100
        Number of units per simulation.
    n_periods : int, default=4
        Number of time periods.
    treatment_fraction : float, default=0.5
        Fraction of units that are treated.
    treatment_period : int, default=2
        First post-treatment period (0-indexed).
    sigma : float, default=1.0
        Residual standard deviation. Also the starting upper bound of the
        automatic bracket.
    n_simulations : int, default=200
        Simulations per bisection step.
    power : float, default=0.80
        Target power.
    alpha : float, default=0.05
        Significance level.
    effect_range : tuple of (float, float), optional
        ``(lo, hi)`` bracket for the search. If None, the search first checks
        power at effect 0. It then doubles an upper bound, starting from
        ``sigma``, for at most 10 evaluations until the target is reached.
    tol : float, default=0.02
        Convergence tolerance. Bisection stops when the bracket width falls
        below ``max(tol * hi, 1e-6)`` or when simulated power is within
        ``tol`` of the target.
    max_steps : int, default=15
        Maximum bisection steps.
    seed : int, optional
        Random seed for reproducibility.
    data_generator : callable, optional
        Custom data generation function.
    data_generator_kwargs : dict, optional
        Additional keyword arguments for data generator.
    estimator_kwargs : dict, optional
        Additional keyword arguments for estimator.fit().
    result_extractor : callable, optional
        Custom function to extract results from the estimator output.
        Forwarded to ``simulate_power()``.
    progress : bool, default=True
        Whether to print progress updates.

    Returns
    -------
    SimulationMDEResults
        Results including the MDE and search diagnostics. ``n_steps`` counts
        every power evaluation, including bracketing.

    Warns
    -----
    UserWarning
        If the lower bound (or effect 0) already reaches the target power,
        or if the target power could not be bracketed. In the latter case
        the largest evaluated effect is returned.

    Examples
    --------
    >>> from diff_diff import simulate_mde, DifferenceInDifferences
    >>> result = simulate_mde(DifferenceInDifferences(), n_simulations=100, seed=42)
    >>> print(f"MDE: {result.mde:.3f}")
    """
    master_rng = np.random.default_rng(seed)
    estimator_name = type(estimator).__name__
    search_path: List[Dict[str, float]] = []

    # Compute effective N for DDD (N is fixed throughout MDE search)
    if estimator_name == "TripleDifference" and data_generator is None:
        effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs)
    else:
        effective_n_units = None

    common_kwargs: Dict[str, Any] = dict(
        estimator=estimator,
        n_units=n_units,
        n_periods=n_periods,
        treatment_fraction=treatment_fraction,
        treatment_period=treatment_period,
        sigma=sigma,
        n_simulations=n_simulations,
        alpha=alpha,
        data_generator=data_generator,
        data_generator_kwargs=data_generator_kwargs,
        estimator_kwargs=estimator_kwargs,
        result_extractor=result_extractor,
        progress=False,
    )

    def _power_at(effect: float) -> float:
        """Simulate power at ``effect`` with a fresh sub-seed and record it in the path."""
        step_seed = int(master_rng.integers(0, 2**31))
        res = simulate_power(treatment_effect=effect, seed=step_seed, **common_kwargs)
        pwr = float(res.power)
        search_path.append({"effect_size": effect, "power": pwr})
        if progress:
            print(f" MDE search: effect={effect:.4f}, power={pwr:.3f}")
        return pwr

    def _result(mde: float, power_at_mde: float) -> SimulationMDEResults:
        """Package the outcome; ``n_steps`` is the number of evaluations made so far."""
        return SimulationMDEResults(
            mde=mde,
            power_at_mde=power_at_mde,
            target_power=power,
            alpha=alpha,
            n_units=n_units,
            n_simulations_per_step=n_simulations,
            n_steps=len(search_path),
            search_path=search_path,
            estimator_name=estimator_name,
            effective_n_units=effective_n_units,
        )

    # --- Bracket ---
    if effect_range is not None:
        lo, hi = effect_range
        power_lo = _power_at(lo)
        power_hi = _power_at(hi)
        if power_lo >= power:
            warnings.warn(
                f"Power at effect={lo} is {power_lo:.2f} >= target {power}. "
                f"Lower bound already exceeds target power. Returning lo as MDE.",
                UserWarning,
            )
            return _result(lo, power_lo)
        if power_hi < power:
            warnings.warn(
                f"Target power {power} not bracketed: power at effect={hi} "
                f"is {power_hi:.2f}. Upper bound may be too low.",
                UserWarning,
            )
    else:
        lo = 0.0
        # Check that power at zero is below target (no inflated Type I error)
        power_at_zero = _power_at(0.0)
        if power_at_zero >= power:
            warnings.warn(
                f"Power at effect=0 is {power_at_zero:.2f} >= target {power}. "
                f"This suggests inflated Type I error. Returning MDE=0.",
                UserWarning,
            )
            return _result(0.0, power_at_zero)

        hi = sigma
        max_doublings = 10
        for attempt in range(max_doublings):
            if _power_at(hi) >= power:
                break
            # Do not double after the final failed check: ``hi`` must remain
            # the last *evaluated* effect so the warning, best_effect and
            # best_power below all refer to the same point.
            if attempt < max_doublings - 1:
                hi *= 2
        else:
            warnings.warn(
                f"Could not bracket MDE (power at effect={hi} still below "
                f"{power}). Returning best upper bound.",
                UserWarning,
            )

    # --- Bisect ---
    # In both bracketing branches the last recorded evaluation is at ``hi``.
    best_effect = hi
    best_power = search_path[-1]["power"] if search_path else 0.0

    for _ in range(max_steps):
        mid = (lo + hi) / 2
        pwr = _power_at(mid)

        if pwr >= power:
            hi = mid
            best_effect = mid
            best_power = pwr
        else:
            lo = mid

        # Convergence: effect range is tight or power is close enough
        if hi - lo < max(tol * hi, 1e-6) or abs(pwr - power) < tol:
            break

    return _result(best_effect, best_power)
2184
+
2185
+
2186
def simulate_sample_size(
    estimator: Any,
    treatment_effect: float = 5.0,
    n_periods: int = 4,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    sigma: float = 1.0,
    n_simulations: int = 200,
    power: float = 0.80,
    alpha: float = 0.05,
    n_range: Optional[Tuple[int, int]] = None,
    max_steps: int = 15,
    seed: Optional[int] = None,
    data_generator: Optional[Callable] = None,
    data_generator_kwargs: Optional[Dict[str, Any]] = None,
    estimator_kwargs: Optional[Dict[str, Any]] = None,
    result_extractor: Optional[Callable] = None,
    progress: bool = True,
) -> SimulationSampleSizeResults:
    """
    Find the required sample size via simulation-based bisection search.

    Searches over ``n_units`` to find the smallest N that achieves the
    target power, using ``simulate_power()`` at each step. For
    ``TripleDifference`` with the built-in data generator, every candidate N
    is snapped to a multiple of 8.

    Parameters
    ----------
    estimator : estimator object
        DiD estimator to use.
    treatment_effect : float, default=5.0
        True treatment effect to simulate.
    n_periods : int, default=4
        Number of time periods.
    treatment_fraction : float, default=0.5
        Fraction of units that are treated.
    treatment_period : int, default=2
        First post-treatment period (0-indexed).
    sigma : float, default=1.0
        Residual standard deviation.
    n_simulations : int, default=200
        Simulations per bisection step.
    power : float, default=0.80
        Target power.
    alpha : float, default=0.05
        Significance level.
    n_range : tuple of (int, int), optional
        ``(lo, hi)`` bracket for sample size. Both ends are snapped to the
        grid and floored at 4 (16 on the DDD grid). If None, the search
        starts at the estimator's registry ``min_n`` (20 if unregistered).
        It then either probes downward or doubles an upper bound, for at
        most 10 evaluations.
    max_steps : int, default=15
        Maximum bisection steps.
    seed : int, optional
        Random seed for reproducibility.
    data_generator : callable, optional
        Custom data generation function.
    data_generator_kwargs : dict, optional
        Additional keyword arguments for data generator.
    estimator_kwargs : dict, optional
        Additional keyword arguments for estimator.fit().
    result_extractor : callable, optional
        Custom function to extract results from the estimator output.
        Forwarded to ``simulate_power()``.
    progress : bool, default=True
        Whether to print progress updates.

    Returns
    -------
    SimulationSampleSizeResults
        Results including the required N and search diagnostics. ``n_steps``
        counts every power evaluation, including bracketing.

    Raises
    ------
    ValueError
        If ``data_generator_kwargs`` contains ``n_per_cell`` while searching
        on the DDD grid.

    Warns
    -----
    UserWarning
        If the lower bound already reaches the target, or if the target
        could not be bracketed. In the latter case the largest evaluated N
        is returned.

    Examples
    --------
    >>> from diff_diff import simulate_sample_size, DifferenceInDifferences
    >>> result = simulate_sample_size(
    ...     DifferenceInDifferences(), treatment_effect=5.0, n_simulations=100, seed=42
    ... )
    >>> print(f"Required N: {result.required_n}")
    """
    master_rng = np.random.default_rng(seed)
    estimator_name = type(estimator).__name__
    search_path: List[Dict[str, float]] = []

    # Determine min_n from registry
    registry = _get_registry()
    profile = registry.get(estimator_name)
    min_n = profile.min_n if profile is not None else 20

    # DDD grid snapping: bisection candidates must be multiples of 8
    is_ddd_grid = estimator_name == "TripleDifference" and data_generator is None
    grid_step = 8 if is_ddd_grid else 1
    convergence_threshold = grid_step + 1  # 9 for DDD, 2 for others
    # Absolute lower limit for any evaluated N.
    abs_min = 16 if is_ddd_grid else 4

    if is_ddd_grid and data_generator_kwargs and "n_per_cell" in data_generator_kwargs:
        raise ValueError(
            "data_generator_kwargs contains 'n_per_cell', which conflicts with "
            "the sample-size search in simulate_sample_size(). For "
            "TripleDifference, n_per_cell is derived from n_units (the search "
            "variable). Use simulate_power() with a fixed n_per_cell override "
            "instead, or pass a custom data_generator."
        )

    def _snap_n(n: int, direction: str = "down", floor: Optional[int] = None) -> int:
        """Snap ``n`` to a multiple of ``grid_step`` (rounding by ``direction``), never below ``floor``.

        ``floor`` defaults to ``min_n``. The floor applies on the unit grid
        too; previously it was ignored there, so ``n_range`` values below
        ``abs_min`` slipped through.
        """
        actual_floor = floor if floor is not None else min_n
        if grid_step == 1:
            return max(actual_floor, n)
        if direction == "up":
            return max(actual_floor, ((n + grid_step - 1) // grid_step) * grid_step)
        return max(actual_floor, (n // grid_step) * grid_step)

    common_kwargs: Dict[str, Any] = dict(
        estimator=estimator,
        n_periods=n_periods,
        treatment_effect=treatment_effect,
        treatment_fraction=treatment_fraction,
        treatment_period=treatment_period,
        sigma=sigma,
        n_simulations=n_simulations,
        alpha=alpha,
        data_generator=data_generator,
        data_generator_kwargs=data_generator_kwargs,
        estimator_kwargs=estimator_kwargs,
        result_extractor=result_extractor,
        progress=False,
    )

    def _power_at_n(n: int) -> float:
        """Simulate power at ``n`` units with a fresh sub-seed and record it in the path."""
        step_seed = int(master_rng.integers(0, 2**31))
        res = simulate_power(n_units=n, seed=step_seed, **common_kwargs)
        pwr = float(res.power)
        search_path.append({"n_units": float(n), "power": pwr})
        if progress:
            print(f" Sample size search: n={n}, power={pwr:.3f}")
        return pwr

    def _result(required_n: int, power_at_n: float) -> SimulationSampleSizeResults:
        """Package the outcome; ``n_steps`` is the number of evaluations made so far."""
        return SimulationSampleSizeResults(
            required_n=required_n,
            power_at_n=power_at_n,
            target_power=power,
            alpha=alpha,
            effect_size=treatment_effect,
            n_simulations_per_step=n_simulations,
            n_steps=len(search_path),
            search_path=search_path,
            estimator_name=estimator_name,
        )

    # --- Bracket ---
    if n_range is not None:
        lo = _snap_n(n_range[0], "up", floor=abs_min)
        hi = _snap_n(n_range[1], "down", floor=abs_min)
        if lo > hi:
            lo = hi  # collapsed bracket — evaluate single point
        power_lo = _power_at_n(lo)
        if power_lo >= power:
            warnings.warn(
                f"Power at n={lo} is {power_lo:.2f} >= target {power}. "
                f"Lower bound already achieves target power. Returning lo.",
                UserWarning,
            )
            return _result(lo, power_lo)
        # A collapsed bracket would re-simulate the same N with another seed;
        # reuse the existing estimate instead.
        power_hi = power_lo if hi == lo else _power_at_n(hi)
        if power_hi < power:
            warnings.warn(
                f"Target power {power} not bracketed: power at n={hi} "
                f"is {power_hi:.2f}. Upper bound may be too low.",
                UserWarning,
            )
    else:
        # Start on the grid so every candidate respects the DDD invariant.
        lo = _snap_n(min_n, "up")
        power_lo = _power_at_n(lo)
        if power_lo >= power:
            # Floor achieves target — search downward for true minimum
            hi = lo
            found_lower = False
            probe = _snap_n(max(abs_min, lo // 2), floor=abs_min)
            for _ in range(8):
                if probe >= hi or probe < abs_min:
                    break
                pwr = _power_at_n(probe)
                if pwr < power:
                    lo = probe
                    found_lower = True
                    break
                hi = probe
                probe = _snap_n(max(abs_min, probe // 2), floor=abs_min)
            if not found_lower:
                # Even smallest viable N achieves target — return best found
                best = min(
                    (s for s in search_path if s["power"] >= power),
                    key=lambda s: s["n_units"],
                )
                warnings.warn(
                    f"Power at n={int(best['n_units'])} is "
                    f"{best['power']:.2f} >= target {power}. Could not "
                    f"find a smaller N below target power. Pass "
                    f"n_range=(lo, hi) to refine.",
                    UserWarning,
                )
                return _result(int(best["n_units"]), best["power"])
            # Fall through to bisection with lo..hi bracket
        else:
            # Snap the starting upper bound onto the grid; doubling keeps it there.
            hi = _snap_n(max(100, 2 * min_n), "up")
            max_doublings = 10
            for attempt in range(max_doublings):
                if _power_at_n(hi) >= power:
                    break
                # Do not double after the final failed check: ``hi`` must stay
                # the last *evaluated* N so the reported power matches it.
                if attempt < max_doublings - 1:
                    hi *= 2
            else:
                warnings.warn(
                    f"Could not bracket required N (power at n={hi} still "
                    f"below {power}). Returning best upper bound.",
                    UserWarning,
                )

    # --- Bisect on integer n_units ---
    # Look up power at hi (search_path[-1] may not be hi after downward search)
    best_power = next(
        (s["power"] for s in reversed(search_path) if int(s["n_units"]) == hi),
        search_path[-1]["power"] if search_path else 0.0,
    )

    for _ in range(max_steps):
        if hi - lo <= convergence_threshold:
            break
        mid = _snap_n((lo + hi) // 2, floor=abs_min)
        if mid <= lo or mid >= hi:
            break
        pwr = _power_at_n(mid)

        if pwr >= power:
            hi = mid
            best_power = pwr
        else:
            lo = mid

    # ``hi`` is the conservative ceiling of the bracket and has always been
    # evaluated; ``best_power`` is the power simulated at it.
    return _result(hi, best_power)
2448
+
2449
+
2450
def compute_mde(
    n_treated: int,
    n_control: int,
    sigma: float,
    power: float = 0.80,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
) -> float:
    """
    Convenience function to compute minimum detectable effect.

    Builds a ``PowerAnalysis`` with the given ``alpha`` and ``power``, runs
    its ``mde`` method, and returns only the ``mde`` value of the result.

    Parameters
    ----------
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    sigma : float
        Residual standard deviation.
    power : float, default=0.80
        Target statistical power.
    alpha : float, default=0.05
        Significance level.
    n_pre : int, default=1
        Number of pre-treatment periods.
    n_post : int, default=1
        Number of post-treatment periods.
    rho : float, default=0.0
        Intra-cluster correlation.

    Returns
    -------
    float
        Minimum detectable effect size.

    Examples
    --------
    >>> mde = compute_mde(n_treated=50, n_control=50, sigma=10.0)
    >>> print(f"MDE: {mde:.2f}")
    """
    return (
        PowerAnalysis(alpha=alpha, power=power)
        .mde(n_treated, n_control, sigma, n_pre, n_post, rho)
        .mde
    )
2495
+
2496
+
2497
def compute_power(
    effect_size: float,
    n_treated: int,
    n_control: int,
    sigma: float,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
) -> float:
    """
    Convenience function to compute power for given effect and sample.

    Builds a ``PowerAnalysis`` with the given ``alpha``, runs its ``power``
    method, and returns only the ``power`` value of the result.

    Parameters
    ----------
    effect_size : float
        Expected treatment effect.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    sigma : float
        Residual standard deviation.
    alpha : float, default=0.05
        Significance level.
    n_pre : int, default=1
        Number of pre-treatment periods.
    n_post : int, default=1
        Number of post-treatment periods.
    rho : float, default=0.0
        Intra-cluster correlation.

    Returns
    -------
    float
        Statistical power.

    Examples
    --------
    >>> power = compute_power(effect_size=5.0, n_treated=50, n_control=50, sigma=10.0)
    >>> print(f"Power: {power:.1%}")
    """
    analysis = PowerAnalysis(alpha=alpha)
    outcome = analysis.power(
        effect_size, n_treated, n_control, sigma, n_pre, n_post, rho
    )
    return outcome.power
2542
+
2543
+
2544
def compute_sample_size(
    effect_size: float,
    sigma: float,
    power: float = 0.80,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
    treat_frac: float = 0.5,
) -> int:
    """
    Convenience function to compute required sample size.

    Builds a ``PowerAnalysis`` with the given ``alpha`` and ``power``, runs
    its ``sample_size`` method, and returns only the ``required_n`` value of
    the result.

    Parameters
    ----------
    effect_size : float
        Treatment effect to detect.
    sigma : float
        Residual standard deviation.
    power : float, default=0.80
        Target statistical power.
    alpha : float, default=0.05
        Significance level.
    n_pre : int, default=1
        Number of pre-treatment periods.
    n_post : int, default=1
        Number of post-treatment periods.
    rho : float, default=0.0
        Intra-cluster correlation.
    treat_frac : float, default=0.5
        Fraction assigned to treatment.

    Returns
    -------
    int
        Required total sample size.

    Examples
    --------
    >>> n = compute_sample_size(effect_size=5.0, sigma=10.0)
    >>> print(f"Required N: {n}")
    """
    return (
        PowerAnalysis(alpha=alpha, power=power)
        .sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac)
        .required_n
    )