diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/power.py ADDED
@@ -0,0 +1,1350 @@
1
+ """
2
+ Power analysis tools for difference-in-differences study design.
3
+
4
+ This module provides power calculations and simulation-based power analysis
5
+ for DiD study design, helping practitioners answer questions like:
6
+ - "How many units do I need to detect an effect of size X?"
7
+ - "What is the minimum detectable effect given my sample size?"
8
+ - "What power do I have to detect a given effect?"
9
+
10
+ References
11
+ ----------
12
+ Bloom, H. S. (1995). "Minimum Detectable Effects: A Simple Way to Report the
13
+ Statistical Power of Experimental Designs." Evaluation Review, 19(5), 547-556.
14
+
15
+ Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
16
+ Journal of Development Economics, 144, 102458.
17
+
18
+ Djimeu, E. W., & Houndolo, D.-G. (2016). "Power Calculation for Causal Inference
19
+ in Social Science: Sample Size and Minimum Detectable Effect Determination."
20
+ Journal of Development Effectiveness, 8(4), 508-527.
21
+ """
22
+
23
+ import warnings
24
+ from dataclasses import dataclass, field
25
+ from typing import Any, Callable, Dict, List, Optional, Tuple
26
+
27
+ import numpy as np
28
+ import pandas as pd
29
+ from scipy import stats
30
+
31
+ # Maximum sample size returned when effect is too small to detect
32
+ # (e.g., zero effect or extremely small relative to noise)
33
+ MAX_SAMPLE_SIZE = 2**31 - 1
34
+
35
+
36
+ @dataclass
37
+ class PowerResults:
38
+ """
39
+ Results from analytical power analysis.
40
+
41
+ Attributes
42
+ ----------
43
+ power : float
44
+ Statistical power (probability of rejecting H0 when effect exists).
45
+ mde : float
46
+ Minimum detectable effect size.
47
+ required_n : int
48
+ Required total sample size (treated + control).
49
+ effect_size : float
50
+ Effect size used in calculation.
51
+ alpha : float
52
+ Significance level.
53
+ alternative : str
54
+ Alternative hypothesis ('two-sided', 'greater', 'less').
55
+ n_treated : int
56
+ Number of treated units.
57
+ n_control : int
58
+ Number of control units.
59
+ n_pre : int
60
+ Number of pre-treatment periods.
61
+ n_post : int
62
+ Number of post-treatment periods.
63
+ sigma : float
64
+ Residual standard deviation.
65
+ rho : float
66
+ Intra-cluster correlation (for panel data).
67
+ design : str
68
+ Study design type ('basic_did', 'panel', 'staggered').
69
+ """
70
+
71
+ power: float
72
+ mde: float
73
+ required_n: int
74
+ effect_size: float
75
+ alpha: float
76
+ alternative: str
77
+ n_treated: int
78
+ n_control: int
79
+ n_pre: int
80
+ n_post: int
81
+ sigma: float
82
+ rho: float = 0.0
83
+ design: str = "basic_did"
84
+
85
+ def __repr__(self) -> str:
86
+ """Concise string representation."""
87
+ return (
88
+ f"PowerResults(power={self.power:.3f}, mde={self.mde:.4f}, "
89
+ f"required_n={self.required_n})"
90
+ )
91
+
92
+ def summary(self) -> str:
93
+ """
94
+ Generate a formatted summary of power analysis results.
95
+
96
+ Returns
97
+ -------
98
+ str
99
+ Formatted summary table.
100
+ """
101
+ lines = [
102
+ "=" * 60,
103
+ "Power Analysis for Difference-in-Differences".center(60),
104
+ "=" * 60,
105
+ "",
106
+ f"{'Design:':<30} {self.design}",
107
+ f"{'Significance level (alpha):':<30} {self.alpha:.3f}",
108
+ f"{'Alternative hypothesis:':<30} {self.alternative}",
109
+ "",
110
+ "-" * 60,
111
+ "Sample Size".center(60),
112
+ "-" * 60,
113
+ f"{'Treated units:':<30} {self.n_treated:>10}",
114
+ f"{'Control units:':<30} {self.n_control:>10}",
115
+ f"{'Total units:':<30} {self.n_treated + self.n_control:>10}",
116
+ f"{'Pre-treatment periods:':<30} {self.n_pre:>10}",
117
+ f"{'Post-treatment periods:':<30} {self.n_post:>10}",
118
+ "",
119
+ "-" * 60,
120
+ "Variance Parameters".center(60),
121
+ "-" * 60,
122
+ f"{'Residual SD (sigma):':<30} {self.sigma:>10.4f}",
123
+ f"{'Intra-cluster correlation:':<30} {self.rho:>10.4f}",
124
+ "",
125
+ "-" * 60,
126
+ "Power Analysis Results".center(60),
127
+ "-" * 60,
128
+ f"{'Effect size:':<30} {self.effect_size:>10.4f}",
129
+ f"{'Power:':<30} {self.power:>10.1%}",
130
+ f"{'Minimum detectable effect:':<30} {self.mde:>10.4f}",
131
+ f"{'Required sample size:':<30} {self.required_n:>10}",
132
+ "=" * 60,
133
+ ]
134
+ return "\n".join(lines)
135
+
136
+ def print_summary(self) -> None:
137
+ """Print the summary to stdout."""
138
+ print(self.summary())
139
+
140
+ def to_dict(self) -> Dict[str, Any]:
141
+ """
142
+ Convert results to a dictionary.
143
+
144
+ Returns
145
+ -------
146
+ Dict[str, Any]
147
+ Dictionary containing all power analysis results.
148
+ """
149
+ return {
150
+ "power": self.power,
151
+ "mde": self.mde,
152
+ "required_n": self.required_n,
153
+ "effect_size": self.effect_size,
154
+ "alpha": self.alpha,
155
+ "alternative": self.alternative,
156
+ "n_treated": self.n_treated,
157
+ "n_control": self.n_control,
158
+ "n_pre": self.n_pre,
159
+ "n_post": self.n_post,
160
+ "sigma": self.sigma,
161
+ "rho": self.rho,
162
+ "design": self.design,
163
+ }
164
+
165
+ def to_dataframe(self) -> pd.DataFrame:
166
+ """
167
+ Convert results to a pandas DataFrame.
168
+
169
+ Returns
170
+ -------
171
+ pd.DataFrame
172
+ DataFrame with power analysis results.
173
+ """
174
+ return pd.DataFrame([self.to_dict()])
175
+
176
+
177
+ @dataclass
178
+ class SimulationPowerResults:
179
+ """
180
+ Results from simulation-based power analysis.
181
+
182
+ Attributes
183
+ ----------
184
+ power : float
185
+ Estimated power (proportion of simulations rejecting H0).
186
+ power_se : float
187
+ Standard error of power estimate.
188
+ power_ci : Tuple[float, float]
189
+ Confidence interval for power estimate.
190
+ rejection_rate : float
191
+ Proportion of simulations with p-value < alpha.
192
+ mean_estimate : float
193
+ Mean treatment effect estimate across simulations.
194
+ std_estimate : float
195
+ Standard deviation of estimates across simulations.
196
+ mean_se : float
197
+ Mean standard error across simulations.
198
+ coverage : float
199
+ Proportion of CIs containing true effect.
200
+ n_simulations : int
201
+ Number of simulations performed.
202
+ effect_sizes : List[float]
203
+ Effect sizes tested (if multiple).
204
+ powers : List[float]
205
+ Power at each effect size (if multiple).
206
+ true_effect : float
207
+ True treatment effect used in simulation.
208
+ alpha : float
209
+ Significance level.
210
+ estimator_name : str
211
+ Name of the estimator used.
212
+ """
213
+
214
+ power: float
215
+ power_se: float
216
+ power_ci: Tuple[float, float]
217
+ rejection_rate: float
218
+ mean_estimate: float
219
+ std_estimate: float
220
+ mean_se: float
221
+ coverage: float
222
+ n_simulations: int
223
+ effect_sizes: List[float]
224
+ powers: List[float]
225
+ true_effect: float
226
+ alpha: float
227
+ estimator_name: str
228
+ bias: float = field(init=False)
229
+ rmse: float = field(init=False)
230
+ simulation_results: Optional[List[Dict[str, Any]]] = field(default=None, repr=False)
231
+
232
+ def __post_init__(self):
233
+ """Compute derived statistics."""
234
+ self.bias = self.mean_estimate - self.true_effect
235
+ self.rmse = np.sqrt(self.bias**2 + self.std_estimate**2)
236
+
237
+ def __repr__(self) -> str:
238
+ """Concise string representation."""
239
+ return (
240
+ f"SimulationPowerResults(power={self.power:.3f} "
241
+ f"[{self.power_ci[0]:.3f}, {self.power_ci[1]:.3f}], "
242
+ f"n_simulations={self.n_simulations})"
243
+ )
244
+
245
+ def summary(self) -> str:
246
+ """
247
+ Generate a formatted summary of simulation power results.
248
+
249
+ Returns
250
+ -------
251
+ str
252
+ Formatted summary table.
253
+ """
254
+ lines = [
255
+ "=" * 65,
256
+ "Simulation-Based Power Analysis Results".center(65),
257
+ "=" * 65,
258
+ "",
259
+ f"{'Estimator:':<35} {self.estimator_name}",
260
+ f"{'Number of simulations:':<35} {self.n_simulations}",
261
+ f"{'True treatment effect:':<35} {self.true_effect:.4f}",
262
+ f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
263
+ "",
264
+ "-" * 65,
265
+ "Power Estimates".center(65),
266
+ "-" * 65,
267
+ f"{'Power (rejection rate):':<35} {self.power:.1%}",
268
+ f"{'Standard error:':<35} {self.power_se:.4f}",
269
+ f"{'95% CI:':<35} [{self.power_ci[0]:.3f}, {self.power_ci[1]:.3f}]",
270
+ "",
271
+ "-" * 65,
272
+ "Estimation Performance".center(65),
273
+ "-" * 65,
274
+ f"{'Mean estimate:':<35} {self.mean_estimate:.4f}",
275
+ f"{'Bias:':<35} {self.bias:.4f}",
276
+ f"{'Std. deviation of estimates:':<35} {self.std_estimate:.4f}",
277
+ f"{'RMSE:':<35} {self.rmse:.4f}",
278
+ f"{'Mean standard error:':<35} {self.mean_se:.4f}",
279
+ f"{'Coverage (CI contains true):':<35} {self.coverage:.1%}",
280
+ "=" * 65,
281
+ ]
282
+ return "\n".join(lines)
283
+
284
+ def print_summary(self) -> None:
285
+ """Print the summary to stdout."""
286
+ print(self.summary())
287
+
288
+ def to_dict(self) -> Dict[str, Any]:
289
+ """
290
+ Convert results to a dictionary.
291
+
292
+ Returns
293
+ -------
294
+ Dict[str, Any]
295
+ Dictionary containing simulation power results.
296
+ """
297
+ return {
298
+ "power": self.power,
299
+ "power_se": self.power_se,
300
+ "power_ci_lower": self.power_ci[0],
301
+ "power_ci_upper": self.power_ci[1],
302
+ "rejection_rate": self.rejection_rate,
303
+ "mean_estimate": self.mean_estimate,
304
+ "std_estimate": self.std_estimate,
305
+ "bias": self.bias,
306
+ "rmse": self.rmse,
307
+ "mean_se": self.mean_se,
308
+ "coverage": self.coverage,
309
+ "n_simulations": self.n_simulations,
310
+ "true_effect": self.true_effect,
311
+ "alpha": self.alpha,
312
+ "estimator_name": self.estimator_name,
313
+ }
314
+
315
+ def to_dataframe(self) -> pd.DataFrame:
316
+ """
317
+ Convert results to a pandas DataFrame.
318
+
319
+ Returns
320
+ -------
321
+ pd.DataFrame
322
+ DataFrame with simulation power results.
323
+ """
324
+ return pd.DataFrame([self.to_dict()])
325
+
326
+ def power_curve_df(self) -> pd.DataFrame:
327
+ """
328
+ Get power curve data as a DataFrame.
329
+
330
+ Returns
331
+ -------
332
+ pd.DataFrame
333
+ DataFrame with effect_size and power columns.
334
+ """
335
+ return pd.DataFrame({
336
+ "effect_size": self.effect_sizes,
337
+ "power": self.powers
338
+ })
339
+
340
+
341
+ class PowerAnalysis:
342
+ """
343
+ Power analysis for difference-in-differences designs.
344
+
345
+ Provides analytical power calculations for basic 2x2 DiD and panel DiD
346
+ designs. For complex designs like staggered adoption, use simulate_power()
347
+ instead.
348
+
349
+ Parameters
350
+ ----------
351
+ alpha : float, default=0.05
352
+ Significance level for hypothesis testing.
353
+ power : float, default=0.80
354
+ Target statistical power.
355
+ alternative : str, default='two-sided'
356
+ Alternative hypothesis: 'two-sided', 'greater', or 'less'.
357
+
358
+ Examples
359
+ --------
360
+ Calculate minimum detectable effect:
361
+
362
+ >>> from diff_diff import PowerAnalysis
363
+ >>> pa = PowerAnalysis(alpha=0.05, power=0.80)
364
+ >>> results = pa.mde(n_treated=50, n_control=50, sigma=1.0)
365
+ >>> print(f"MDE: {results.mde:.3f}")
366
+
367
+ Calculate required sample size:
368
+
369
+ >>> results = pa.sample_size(effect_size=0.5, sigma=1.0)
370
+ >>> print(f"Required N: {results.required_n}")
371
+
372
+ Calculate power for given sample and effect:
373
+
374
+ >>> results = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0)
375
+ >>> print(f"Power: {results.power:.1%}")
376
+
377
+ Notes
378
+ -----
379
+ The power calculations are based on the variance of the DiD estimator:
380
+
381
+ For basic 2x2 DiD:
382
+ Var(ATT) = sigma^2 * (1/n_treated_post + 1/n_treated_pre
383
+ + 1/n_control_post + 1/n_control_pre)
384
+
385
+ For panel DiD with T periods:
386
+ Var(ATT) = sigma^2 * (1/N_treated + 1/N_control)
387
+ * (1 + (T-1)*rho) / T
388
+
389
+ Where rho is the intra-cluster correlation coefficient.
390
+
391
+ References
392
+ ----------
393
+ Bloom, H. S. (1995). "Minimum Detectable Effects."
394
+ Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
395
+ """
396
+
397
+ def __init__(
398
+ self,
399
+ alpha: float = 0.05,
400
+ power: float = 0.80,
401
+ alternative: str = "two-sided",
402
+ ):
403
+ if not 0 < alpha < 1:
404
+ raise ValueError("alpha must be between 0 and 1")
405
+ if not 0 < power < 1:
406
+ raise ValueError("power must be between 0 and 1")
407
+ if alternative not in ("two-sided", "greater", "less"):
408
+ raise ValueError("alternative must be 'two-sided', 'greater', or 'less'")
409
+
410
+ self.alpha = alpha
411
+ self.target_power = power
412
+ self.alternative = alternative
413
+
414
+ def _get_critical_values(self) -> Tuple[float, float]:
415
+ """Get z critical values for alpha and power."""
416
+ if self.alternative == "two-sided":
417
+ z_alpha = stats.norm.ppf(1 - self.alpha / 2)
418
+ else:
419
+ z_alpha = stats.norm.ppf(1 - self.alpha)
420
+ z_beta = stats.norm.ppf(self.target_power)
421
+ return z_alpha, z_beta
422
+
423
+ def _compute_variance(
424
+ self,
425
+ n_treated: int,
426
+ n_control: int,
427
+ n_pre: int,
428
+ n_post: int,
429
+ sigma: float,
430
+ rho: float = 0.0,
431
+ design: str = "basic_did",
432
+ ) -> float:
433
+ """
434
+ Compute variance of the DiD estimator.
435
+
436
+ Parameters
437
+ ----------
438
+ n_treated : int
439
+ Number of treated units.
440
+ n_control : int
441
+ Number of control units.
442
+ n_pre : int
443
+ Number of pre-treatment periods.
444
+ n_post : int
445
+ Number of post-treatment periods.
446
+ sigma : float
447
+ Residual standard deviation.
448
+ rho : float
449
+ Intra-cluster correlation (for panel data).
450
+ design : str
451
+ Study design type.
452
+
453
+ Returns
454
+ -------
455
+ float
456
+ Variance of the DiD estimator.
457
+ """
458
+ if design == "basic_did":
459
+ # For basic 2x2 DiD, each unit contributes one pre- and one post-period
460
+ # observation, so each cell has n_treated (or n_control) observations
461
+ n_t_pre = n_treated # treated units in pre-period
462
+ n_t_post = n_treated # treated units in post-period
463
+ n_c_pre = n_control
464
+ n_c_post = n_control
465
+
466
+ variance = sigma**2 * (
467
+ 1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre
468
+ )
469
+ elif design == "panel":
470
+ # Panel DiD with multiple periods
471
+ # Account for serial correlation via ICC
472
+ T = n_pre + n_post
473
+
474
+ # Design effect for clustering
475
+ design_effect = 1 + (T - 1) * rho
476
+
477
+ # Base variance (as if independent)
478
+ base_var = sigma**2 * (1 / n_treated + 1 / n_control)
479
+
480
+ # Average over the T periods and inflate by the design effect (Moulton factor)
481
+ variance = base_var * design_effect / T
482
+ else:
483
+ raise ValueError(f"Unknown design: {design}")
484
+
485
+ return variance
486
+
487
+ def power(
488
+ self,
489
+ effect_size: float,
490
+ n_treated: int,
491
+ n_control: int,
492
+ sigma: float,
493
+ n_pre: int = 1,
494
+ n_post: int = 1,
495
+ rho: float = 0.0,
496
+ ) -> PowerResults:
497
+ """
498
+ Calculate statistical power for given effect size and sample.
499
+
500
+ Parameters
501
+ ----------
502
+ effect_size : float
503
+ Expected treatment effect size.
504
+ n_treated : int
505
+ Number of treated units.
506
+ n_control : int
507
+ Number of control units.
508
+ sigma : float
509
+ Residual standard deviation.
510
+ n_pre : int, default=1
511
+ Number of pre-treatment periods.
512
+ n_post : int, default=1
513
+ Number of post-treatment periods.
514
+ rho : float, default=0.0
515
+ Intra-cluster correlation for panel data.
516
+
517
+ Returns
518
+ -------
519
+ PowerResults
520
+ Power analysis results.
521
+
522
+ Examples
523
+ --------
524
+ >>> pa = PowerAnalysis()
525
+ >>> results = pa.power(effect_size=2.0, n_treated=50, n_control=50, sigma=5.0)
526
+ >>> print(f"Power: {results.power:.1%}")
527
+ """
528
+ T = n_pre + n_post
529
+ design = "panel" if T > 2 else "basic_did"
530
+
531
+ variance = self._compute_variance(
532
+ n_treated, n_control, n_pre, n_post, sigma, rho, design
533
+ )
534
+ se = np.sqrt(variance)
535
+
536
+ # Calculate power
537
+ if self.alternative == "two-sided":
538
+ z_alpha = stats.norm.ppf(1 - self.alpha / 2)
539
+ # Power = P(reject | effect) = P(|Z| > z_alpha | effect)
540
+ power_val = (
541
+ 1 - stats.norm.cdf(z_alpha - effect_size / se)
542
+ + stats.norm.cdf(-z_alpha - effect_size / se)
543
+ )
544
+ elif self.alternative == "greater":
545
+ z_alpha = stats.norm.ppf(1 - self.alpha)
546
+ power_val = 1 - stats.norm.cdf(z_alpha - effect_size / se)
547
+ else: # less
548
+ z_alpha = stats.norm.ppf(1 - self.alpha)
549
+ power_val = stats.norm.cdf(-z_alpha - effect_size / se)
550
+
551
+ # Also compute MDE and required N for reference
552
+ mde = self._compute_mde_from_se(se)
553
+ required_n = self._compute_required_n(
554
+ effect_size, sigma, n_pre, n_post, rho, design,
555
+ n_treated / (n_treated + n_control)
556
+ )
557
+
558
+ return PowerResults(
559
+ power=power_val,
560
+ mde=mde,
561
+ required_n=required_n,
562
+ effect_size=effect_size,
563
+ alpha=self.alpha,
564
+ alternative=self.alternative,
565
+ n_treated=n_treated,
566
+ n_control=n_control,
567
+ n_pre=n_pre,
568
+ n_post=n_post,
569
+ sigma=sigma,
570
+ rho=rho,
571
+ design=design,
572
+ )
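As a quick numerical check of the two-sided power formula used in power() above (hypothetical effect size and standard error, chosen only for illustration):

    from scipy import stats

    delta, se, alpha = 2.0, 0.9, 0.05   # hypothetical effect, SE, significance level
    z = stats.norm.ppf(1 - alpha / 2)
    power = (1 - stats.norm.cdf(z - delta / se)) + stats.norm.cdf(-z - delta / se)
    print(round(power, 3))              # roughly 0.60 for these numbers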
573
+
574
+ def _compute_mde_from_se(self, se: float) -> float:
575
+ """Compute MDE given standard error."""
576
+ z_alpha, z_beta = self._get_critical_values()
577
+ return (z_alpha + z_beta) * se
578
+
579
+ def mde(
580
+ self,
581
+ n_treated: int,
582
+ n_control: int,
583
+ sigma: float,
584
+ n_pre: int = 1,
585
+ n_post: int = 1,
586
+ rho: float = 0.0,
587
+ ) -> PowerResults:
588
+ """
589
+ Calculate minimum detectable effect given sample size.
590
+
591
+ The MDE is the smallest effect size that can be detected with the
592
+ specified power and significance level.
593
+
594
+ Parameters
595
+ ----------
596
+ n_treated : int
597
+ Number of treated units.
598
+ n_control : int
599
+ Number of control units.
600
+ sigma : float
601
+ Residual standard deviation.
602
+ n_pre : int, default=1
603
+ Number of pre-treatment periods.
604
+ n_post : int, default=1
605
+ Number of post-treatment periods.
606
+ rho : float, default=0.0
607
+ Intra-cluster correlation for panel data.
608
+
609
+ Returns
610
+ -------
611
+ PowerResults
612
+ Power analysis results including MDE.
613
+
614
+ Examples
615
+ --------
616
+ >>> pa = PowerAnalysis(power=0.80)
617
+ >>> results = pa.mde(n_treated=100, n_control=100, sigma=10.0)
618
+ >>> print(f"MDE: {results.mde:.2f}")
619
+ """
620
+ T = n_pre + n_post
621
+ design = "panel" if T > 2 else "basic_did"
622
+
623
+ variance = self._compute_variance(
624
+ n_treated, n_control, n_pre, n_post, sigma, rho, design
625
+ )
626
+ se = np.sqrt(variance)
627
+
628
+ mde = self._compute_mde_from_se(se)
629
+
630
+ return PowerResults(
631
+ power=self.target_power,
632
+ mde=mde,
633
+ required_n=n_treated + n_control,
634
+ effect_size=mde,
635
+ alpha=self.alpha,
636
+ alternative=self.alternative,
637
+ n_treated=n_treated,
638
+ n_control=n_control,
639
+ n_pre=n_pre,
640
+ n_post=n_post,
641
+ sigma=sigma,
642
+ rho=rho,
643
+ design=design,
644
+ )
645
+
646
+ def _compute_required_n(
647
+ self,
648
+ effect_size: float,
649
+ sigma: float,
650
+ n_pre: int,
651
+ n_post: int,
652
+ rho: float,
653
+ design: str,
654
+ treat_frac: float = 0.5,
655
+ ) -> int:
656
+ """Compute required sample size for given effect."""
657
+ # Handle edge case of zero effect size
658
+ if effect_size == 0:
659
+ return MAX_SAMPLE_SIZE # Can't detect zero effect
660
+
661
+ z_alpha, z_beta = self._get_critical_values()
662
+
663
+ T = n_pre + n_post
664
+
665
+ if design == "basic_did":
666
+ # Var = sigma^2 * (1/n_t + 1/n_t + 1/n_c + 1/n_c) = sigma^2 * (2/n_t + 2/n_c)
667
+ # For balanced: Var = sigma^2 * 4/n where n = n_t = n_c
668
+ # SE = sqrt(Var), effect_size = (z_alpha + z_beta) * SE
669
+ # n = 4 * sigma^2 * (z_alpha + z_beta)^2 / effect_size^2
670
+
671
+ # For general allocation with treat_frac:
672
+ # Var = sigma^2 * 2 * (1/(N*p) + 1/(N*(1-p)))
673
+ # = 2 * sigma^2 / N * (1/p + 1/(1-p))
674
+ # = 2 * sigma^2 / N * (1/(p*(1-p)))
675
+
676
+ n_total = (
677
+ 2 * sigma**2 * (z_alpha + z_beta)**2
678
+ / (effect_size**2 * treat_frac * (1 - treat_frac))
679
+ )
680
+ else: # panel
681
+ design_effect = 1 + (T - 1) * rho
682
+
683
+ # Var = sigma^2 * (1/n_t + 1/n_c) * design_effect / T
684
+ # For balanced: Var = 2 * sigma^2 / N * design_effect / T
685
+
686
+ n_total = (
687
+ 2 * sigma**2 * (z_alpha + z_beta)**2 * design_effect
688
+ / (effect_size**2 * treat_frac * (1 - treat_frac) * T)
689
+ )
690
+
691
+ # Handle infinity case (extremely small effect)
692
+ if np.isinf(n_total):
693
+ return MAX_SAMPLE_SIZE
694
+
695
+ return max(4, int(np.ceil(n_total))) # At least 4 units
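The closed form in _compute_required_n follows from setting the MDE equal to the target effect and solving for N; a hand calculation for the balanced 2x2 case (arbitrary illustrative numbers, not output from the package):

    import numpy as np
    from scipy import stats

    effect, sigma, p = 5.0, 10.0, 0.5            # hypothetical effect, SD, treated share
    z_a = stats.norm.ppf(1 - 0.05 / 2)           # alpha = 0.05, two-sided
    z_b = stats.norm.ppf(0.80)                   # target power = 0.80
    n_total = 2 * sigma**2 * (z_a + z_b)**2 / (effect**2 * p * (1 - p))
    print(int(np.ceil(n_total)))                 # about 252 total units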
696
+
697
+ def sample_size(
698
+ self,
699
+ effect_size: float,
700
+ sigma: float,
701
+ n_pre: int = 1,
702
+ n_post: int = 1,
703
+ rho: float = 0.0,
704
+ treat_frac: float = 0.5,
705
+ ) -> PowerResults:
706
+ """
707
+ Calculate required sample size to detect given effect.
708
+
709
+ Parameters
710
+ ----------
711
+ effect_size : float
712
+ Treatment effect to detect.
713
+ sigma : float
714
+ Residual standard deviation.
715
+ n_pre : int, default=1
716
+ Number of pre-treatment periods.
717
+ n_post : int, default=1
718
+ Number of post-treatment periods.
719
+ rho : float, default=0.0
720
+ Intra-cluster correlation for panel data.
721
+ treat_frac : float, default=0.5
722
+ Fraction of units assigned to treatment.
723
+
724
+ Returns
725
+ -------
726
+ PowerResults
727
+ Power analysis results including required sample size.
728
+
729
+ Examples
730
+ --------
731
+ >>> pa = PowerAnalysis(power=0.80)
732
+ >>> results = pa.sample_size(effect_size=5.0, sigma=10.0)
733
+ >>> print(f"Required N: {results.required_n}")
734
+ """
735
+ T = n_pre + n_post
736
+ design = "panel" if T > 2 else "basic_did"
737
+
738
+ n_total = self._compute_required_n(
739
+ effect_size, sigma, n_pre, n_post, rho, design, treat_frac
740
+ )
741
+
742
+ n_treated = max(2, int(np.ceil(n_total * treat_frac)))
743
+ n_control = max(2, n_total - n_treated)
744
+ n_total = n_treated + n_control
745
+
746
+ # Compute the MDE implied by the rounded sample sizes
747
+ variance = self._compute_variance(
748
+ n_treated, n_control, n_pre, n_post, sigma, rho, design
749
+ )
750
+ se = np.sqrt(variance)
751
+ mde = self._compute_mde_from_se(se)
752
+
753
+ return PowerResults(
754
+ power=self.target_power,
755
+ mde=mde,
756
+ required_n=n_total,
757
+ effect_size=effect_size,
758
+ alpha=self.alpha,
759
+ alternative=self.alternative,
760
+ n_treated=n_treated,
761
+ n_control=n_control,
762
+ n_pre=n_pre,
763
+ n_post=n_post,
764
+ sigma=sigma,
765
+ rho=rho,
766
+ design=design,
767
+ )
768
+
769
+ def power_curve(
770
+ self,
771
+ n_treated: int,
772
+ n_control: int,
773
+ sigma: float,
774
+ effect_sizes: Optional[List[float]] = None,
775
+ n_pre: int = 1,
776
+ n_post: int = 1,
777
+ rho: float = 0.0,
778
+ ) -> pd.DataFrame:
779
+ """
780
+ Compute power for a range of effect sizes.
781
+
782
+ Parameters
783
+ ----------
784
+ n_treated : int
785
+ Number of treated units.
786
+ n_control : int
787
+ Number of control units.
788
+ sigma : float
789
+ Residual standard deviation.
790
+ effect_sizes : list of float, optional
791
+ Effect sizes to evaluate. If None, uses a range from 0 to 2.5*MDE.
792
+ n_pre : int, default=1
793
+ Number of pre-treatment periods.
794
+ n_post : int, default=1
795
+ Number of post-treatment periods.
796
+ rho : float, default=0.0
797
+ Intra-cluster correlation.
798
+
799
+ Returns
800
+ -------
801
+ pd.DataFrame
802
+ DataFrame with columns 'effect_size' and 'power'.
803
+
804
+ Examples
805
+ --------
806
+ >>> pa = PowerAnalysis()
807
+ >>> curve = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
808
+ >>> print(curve)
809
+ """
810
+ # First get MDE to determine default range
811
+ mde_result = self.mde(n_treated, n_control, sigma, n_pre, n_post, rho)
812
+
813
+ if effect_sizes is None:
814
+ # Generate range from 0 to 2.5*MDE
815
+ effect_sizes = np.linspace(0, 2.5 * mde_result.mde, 50).tolist()
816
+
817
+ powers = []
818
+ for es in effect_sizes:
819
+ result = self.power(
820
+ effect_size=es,
821
+ n_treated=n_treated,
822
+ n_control=n_control,
823
+ sigma=sigma,
824
+ n_pre=n_pre,
825
+ n_post=n_post,
826
+ rho=rho,
827
+ )
828
+ powers.append(result.power)
829
+
830
+ return pd.DataFrame({"effect_size": effect_sizes, "power": powers})
831
+
832
+ def sample_size_curve(
833
+ self,
834
+ effect_size: float,
835
+ sigma: float,
836
+ sample_sizes: Optional[List[int]] = None,
837
+ n_pre: int = 1,
838
+ n_post: int = 1,
839
+ rho: float = 0.0,
840
+ treat_frac: float = 0.5,
841
+ ) -> pd.DataFrame:
842
+ """
843
+ Compute power for a range of sample sizes.
844
+
845
+ Parameters
846
+ ----------
847
+ effect_size : float
848
+ Treatment effect size.
849
+ sigma : float
850
+ Residual standard deviation.
851
+ sample_sizes : list of int, optional
852
+ Total sample sizes to evaluate. If None, uses sensible range.
853
+ n_pre : int, default=1
854
+ Number of pre-treatment periods.
855
+ n_post : int, default=1
856
+ Number of post-treatment periods.
857
+ rho : float, default=0.0
858
+ Intra-cluster correlation.
859
+ treat_frac : float, default=0.5
860
+ Fraction assigned to treatment.
861
+
862
+ Returns
863
+ -------
864
+ pd.DataFrame
865
+ DataFrame with columns 'sample_size' and 'power'.
866
+ """
867
+ # Get required N to determine default range
868
+ required = self.sample_size(
869
+ effect_size, sigma, n_pre, n_post, rho, treat_frac
870
+ )
871
+
872
+ if sample_sizes is None:
873
+ min_n = max(10, required.required_n // 4)
874
+ max_n = required.required_n * 2
875
+ sample_sizes = list(range(min_n, max_n + 1, max(1, (max_n - min_n) // 50)))
876
+
877
+ powers = []
878
+ for n in sample_sizes:
879
+ n_treated = max(2, int(n * treat_frac))
880
+ n_control = max(2, n - n_treated)
881
+ result = self.power(
882
+ effect_size=effect_size,
883
+ n_treated=n_treated,
884
+ n_control=n_control,
885
+ sigma=sigma,
886
+ n_pre=n_pre,
887
+ n_post=n_post,
888
+ rho=rho,
889
+ )
890
+ powers.append(result.power)
891
+
892
+ return pd.DataFrame({"sample_size": sample_sizes, "power": powers})
893
+
894
+
895
+ def simulate_power(
896
+ estimator: Any,
897
+ n_units: int = 100,
898
+ n_periods: int = 4,
899
+ treatment_effect: float = 5.0,
900
+ treatment_fraction: float = 0.5,
901
+ treatment_period: int = 2,
902
+ sigma: float = 1.0,
903
+ n_simulations: int = 500,
904
+ alpha: float = 0.05,
905
+ effect_sizes: Optional[List[float]] = None,
906
+ seed: Optional[int] = None,
907
+ data_generator: Optional[Callable] = None,
908
+ data_generator_kwargs: Optional[Dict[str, Any]] = None,
909
+ estimator_kwargs: Optional[Dict[str, Any]] = None,
910
+ progress: bool = True,
911
+ ) -> SimulationPowerResults:
912
+ """
913
+ Estimate power using Monte Carlo simulation.
914
+
915
+ This function simulates datasets with known treatment effects and estimates
916
+ power as the fraction of simulations where the null hypothesis is rejected.
917
+ This is the recommended approach for complex designs like staggered adoption.
918
+
919
+ Parameters
920
+ ----------
921
+ estimator : estimator object
922
+ DiD estimator to use (e.g., DifferenceInDifferences, CallawaySantAnna).
923
+ n_units : int, default=100
924
+ Number of units per simulation.
925
+ n_periods : int, default=4
926
+ Number of time periods.
927
+ treatment_effect : float, default=5.0
928
+ True treatment effect to simulate.
929
+ treatment_fraction : float, default=0.5
930
+ Fraction of units that are treated.
931
+ treatment_period : int, default=2
932
+ First post-treatment period (0-indexed).
933
+ sigma : float, default=1.0
934
+ Residual standard deviation (noise level).
935
+ n_simulations : int, default=500
936
+ Number of Monte Carlo simulations.
937
+ alpha : float, default=0.05
938
+ Significance level for hypothesis tests.
939
+ effect_sizes : list of float, optional
940
+ Multiple effect sizes to evaluate for power curve.
941
+ If None, uses only treatment_effect.
942
+ seed : int, optional
943
+ Random seed for reproducibility.
944
+ data_generator : callable, optional
945
+ Custom data generation function. Should accept the same signature as
946
+ generate_did_data(). If None, uses generate_did_data().
947
+ data_generator_kwargs : dict, optional
948
+ Additional keyword arguments for data generator.
949
+ estimator_kwargs : dict, optional
950
+ Additional keyword arguments for estimator.fit().
951
+ progress : bool, default=True
952
+ Whether to print progress updates.
953
+
954
+ Returns
955
+ -------
956
+ SimulationPowerResults
957
+ Simulation-based power analysis results.
958
+
959
+ Examples
960
+ --------
961
+ Basic power simulation:
962
+
963
+ >>> from diff_diff import DifferenceInDifferences, simulate_power
964
+ >>> did = DifferenceInDifferences()
965
+ >>> results = simulate_power(
966
+ ... estimator=did,
967
+ ... n_units=100,
968
+ ... treatment_effect=5.0,
969
+ ... sigma=5.0,
970
+ ... n_simulations=500,
971
+ ... seed=42
972
+ ... )
973
+ >>> print(f"Power: {results.power:.1%}")
974
+
975
+ Power curve over multiple effect sizes:
976
+
977
+ >>> results = simulate_power(
978
+ ... estimator=did,
979
+ ... effect_sizes=[1.0, 2.0, 3.0, 5.0, 7.0],
980
+ ... n_simulations=200,
981
+ ... seed=42
982
+ ... )
983
+ >>> print(results.power_curve_df())
984
+
985
+ With Callaway-Sant'Anna for staggered designs:
986
+
987
+ >>> from diff_diff import CallawaySantAnna
988
+ >>> cs = CallawaySantAnna()
989
+ >>> # Custom data generator for staggered adoption
990
+ >>> def staggered_data(n_units, n_periods, treatment_effect, **kwargs):
991
+ ... # Your staggered data generation logic
992
+ ... ...
993
+ >>> results = simulate_power(cs, data_generator=staggered_data, ...)
994
+
995
+ Notes
996
+ -----
997
+ The simulation approach:
998
+ 1. Generate data with known treatment effect
999
+ 2. Fit the estimator and record the p-value
1000
+ 3. Repeat n_simulations times
1001
+ 4. Power = fraction of simulations where p-value < alpha
1002
+
1003
+ For staggered designs, you'll need to provide a custom data_generator
1004
+ that creates appropriate staggered treatment timing.
1005
+
1006
+ References
1007
+ ----------
1008
+ Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
1009
+ """
1010
+ from diff_diff.prep import generate_did_data
1011
+
1012
+ rng = np.random.default_rng(seed)
1013
+
1014
+ # Use default data generator if none provided
1015
+ if data_generator is None:
1016
+ data_generator = generate_did_data
1017
+
1018
+ data_gen_kwargs = data_generator_kwargs or {}
1019
+ est_kwargs = estimator_kwargs or {}
1020
+
1021
+ # Determine effect sizes to test
1022
+ if effect_sizes is None:
1023
+ effect_sizes = [treatment_effect]
1024
+
1025
+ all_powers = []
1026
+
1027
+ # Collect detailed results for the primary effect (treatment_effect if it
1028
+ # Use index-based comparison to avoid float precision issues
1029
+ if len(effect_sizes) == 1:
1030
+ primary_idx = 0
1031
+ else:
1032
+ # Find index of treatment_effect in effect_sizes
1033
+ primary_idx = -1
1034
+ for i, es in enumerate(effect_sizes):
1035
+ if np.isclose(es, treatment_effect):
1036
+ primary_idx = i
1037
+ break
1038
+ if primary_idx == -1:
1039
+ primary_idx = len(effect_sizes) - 1 # Default to last
1040
+
1041
+ primary_effect = effect_sizes[primary_idx]
1042
+
1043
+ for effect_idx, effect in enumerate(effect_sizes):
1044
+ is_primary = (effect_idx == primary_idx)
1045
+
1046
+ estimates = []
1047
+ ses = []
1048
+ p_values = []
1049
+ rejections = []
1050
+ ci_contains_true = []
1051
+ n_failures = 0
1052
+
1053
+ for sim in range(n_simulations):
1054
+ if progress and sim % 100 == 0 and sim > 0:
1055
+ pct = (sim + effect_idx * n_simulations) / (len(effect_sizes) * n_simulations)
1056
+ print(f" Simulation progress: {pct:.0%}")
1057
+
1058
+ # Generate data
1059
+ sim_seed = rng.integers(0, 2**31)
1060
+ data = data_generator(
1061
+ n_units=n_units,
1062
+ n_periods=n_periods,
1063
+ treatment_effect=effect,
1064
+ treatment_fraction=treatment_fraction,
1065
+ treatment_period=treatment_period,
1066
+ noise_sd=sigma,
1067
+ seed=sim_seed,
1068
+ **data_gen_kwargs
1069
+ )
1070
+
1071
+ try:
1072
+ # Fit estimator
1073
+ # Try to determine the right arguments based on estimator type
1074
+ estimator_name = type(estimator).__name__
1075
+
1076
+ if estimator_name == "DifferenceInDifferences":
1077
+ result = estimator.fit(
1078
+ data,
1079
+ outcome="outcome",
1080
+ treatment="treated",
1081
+ time="post",
1082
+ **est_kwargs
1083
+ )
1084
+ elif estimator_name == "TwoWayFixedEffects":
1085
+ result = estimator.fit(
1086
+ data,
1087
+ outcome="outcome",
1088
+ treatment="treated",
1089
+ time="period",
1090
+ unit="unit",
1091
+ **est_kwargs
1092
+ )
1093
+ elif estimator_name == "MultiPeriodDiD":
1094
+ post_periods = list(range(treatment_period, n_periods))
1095
+ result = estimator.fit(
1096
+ data,
1097
+ outcome="outcome",
1098
+ treatment="treated",
1099
+ time="period",
1100
+ post_periods=post_periods,
1101
+ **est_kwargs
1102
+ )
1103
+ elif estimator_name == "CallawaySantAnna":
1104
+ # Need to create first_treat column for staggered
1105
+ # For standard generate_did_data, convert to first_treat format
1106
+ data = data.copy()
1107
+ data["first_treat"] = np.where(
1108
+ data["treated"] == 1, treatment_period, 0
1109
+ )
1110
+ result = estimator.fit(
1111
+ data,
1112
+ outcome="outcome",
1113
+ unit="unit",
1114
+ time="period",
1115
+ first_treat="first_treat",
1116
+ **est_kwargs
1117
+ )
1118
+ else:
1119
+ # Generic fallback - try common signature
1120
+ result = estimator.fit(
1121
+ data,
1122
+ outcome="outcome",
1123
+ treatment="treated",
1124
+ time="post",
1125
+ **est_kwargs
1126
+ )
1127
+
1128
+ # Extract results
1129
+ att = result.att if hasattr(result, 'att') else result.avg_att
1130
+ se = result.se if hasattr(result, 'se') else result.avg_se
1131
+ p_val = result.p_value if hasattr(result, 'p_value') else result.avg_p_value
1132
+ ci = result.conf_int if hasattr(result, 'conf_int') else result.avg_conf_int
1133
+
1134
+ estimates.append(att)
1135
+ ses.append(se)
1136
+ p_values.append(p_val)
1137
+ rejections.append(p_val < alpha)
1138
+ ci_contains_true.append(ci[0] <= effect <= ci[1])
1139
+
1140
+ except Exception as e:
1141
+ # Track failed simulations
1142
+ n_failures += 1
1143
+ if progress:
1144
+ print(f" Warning: Simulation {sim} failed: {e}")
1145
+ continue
1146
+
1147
+ # Warn if too many simulations failed
1148
+ failure_rate = n_failures / n_simulations
1149
+ if failure_rate > 0.1:
1150
+ warnings.warn(
1151
+ f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) failed "
1152
+ f"for effect_size={effect}. Check estimator and data generator.",
1153
+ UserWarning
1154
+ )
1155
+
1156
+ if len(estimates) == 0:
1157
+ raise RuntimeError("All simulations failed. Check estimator and data generator.")
1158
+
1159
+ # Compute power and SE
1160
+ power_val = np.mean(rejections)
1161
+ power_se = np.sqrt(power_val * (1 - power_val) / len(rejections))
1162
+
1163
+ all_powers.append(power_val)
1164
+
1165
+ # Store detailed results for primary effect
1166
+ if is_primary:
1167
+ primary_estimates = estimates
1168
+ primary_ses = ses
1169
+ primary_p_values = p_values
1170
+ primary_rejections = rejections
1171
+ primary_ci_contains = ci_contains_true
1172
+
1173
+ # Compute confidence interval for power (primary effect)
1174
+ power_val = all_powers[primary_idx]
1175
+ n_valid = len(primary_rejections)
1176
+ power_se = np.sqrt(power_val * (1 - power_val) / n_valid)
1177
+ z = stats.norm.ppf(0.975)
1178
+ power_ci = (
1179
+ max(0.0, power_val - z * power_se),
1180
+ min(1.0, power_val + z * power_se)
1181
+ )
1182
+
1183
+ # Compute summary statistics
1184
+ mean_estimate = np.mean(primary_estimates)
1185
+ std_estimate = np.std(primary_estimates, ddof=1)
1186
+ mean_se = np.mean(primary_ses)
1187
+ coverage = np.mean(primary_ci_contains)
1188
+
1189
+ return SimulationPowerResults(
1190
+ power=power_val,
1191
+ power_se=power_se,
1192
+ power_ci=power_ci,
1193
+ rejection_rate=power_val,
1194
+ mean_estimate=mean_estimate,
1195
+ std_estimate=std_estimate,
1196
+ mean_se=mean_se,
1197
+ coverage=coverage,
1198
+ n_simulations=n_valid,
1199
+ effect_sizes=effect_sizes,
1200
+ powers=all_powers,
1201
+ true_effect=primary_effect,
1202
+ alpha=alpha,
1203
+ estimator_name=type(estimator).__name__,
1204
+ simulation_results=[
1205
+ {"estimate": e, "se": s, "p_value": p, "rejected": r}
1206
+ for e, s, p, r in zip(primary_estimates, primary_ses,
1207
+ primary_p_values, primary_rejections)
1208
+ ],
1209
+ )
1210
+
1211
+
1212
+ def compute_mde(
1213
+ n_treated: int,
1214
+ n_control: int,
1215
+ sigma: float,
1216
+ power: float = 0.80,
1217
+ alpha: float = 0.05,
1218
+ n_pre: int = 1,
1219
+ n_post: int = 1,
1220
+ rho: float = 0.0,
1221
+ ) -> float:
1222
+ """
1223
+ Convenience function to compute minimum detectable effect.
1224
+
1225
+ Parameters
1226
+ ----------
1227
+ n_treated : int
1228
+ Number of treated units.
1229
+ n_control : int
1230
+ Number of control units.
1231
+ sigma : float
1232
+ Residual standard deviation.
1233
+ power : float, default=0.80
1234
+ Target statistical power.
1235
+ alpha : float, default=0.05
1236
+ Significance level.
1237
+ n_pre : int, default=1
1238
+ Number of pre-treatment periods.
1239
+ n_post : int, default=1
1240
+ Number of post-treatment periods.
1241
+ rho : float, default=0.0
1242
+ Intra-cluster correlation.
1243
+
1244
+ Returns
1245
+ -------
1246
+ float
1247
+ Minimum detectable effect size.
1248
+
1249
+ Examples
1250
+ --------
1251
+ >>> mde = compute_mde(n_treated=50, n_control=50, sigma=10.0)
1252
+ >>> print(f"MDE: {mde:.2f}")
1253
+ """
1254
+ pa = PowerAnalysis(alpha=alpha, power=power)
1255
+ result = pa.mde(n_treated, n_control, sigma, n_pre, n_post, rho)
1256
+ return result.mde
1257
+
1258
+
1259
+ def compute_power(
1260
+ effect_size: float,
1261
+ n_treated: int,
1262
+ n_control: int,
1263
+ sigma: float,
1264
+ alpha: float = 0.05,
1265
+ n_pre: int = 1,
1266
+ n_post: int = 1,
1267
+ rho: float = 0.0,
1268
+ ) -> float:
1269
+ """
1270
+ Convenience function to compute power for given effect and sample.
1271
+
1272
+ Parameters
1273
+ ----------
1274
+ effect_size : float
1275
+ Expected treatment effect.
1276
+ n_treated : int
1277
+ Number of treated units.
1278
+ n_control : int
1279
+ Number of control units.
1280
+ sigma : float
1281
+ Residual standard deviation.
1282
+ alpha : float, default=0.05
1283
+ Significance level.
1284
+ n_pre : int, default=1
1285
+ Number of pre-treatment periods.
1286
+ n_post : int, default=1
1287
+ Number of post-treatment periods.
1288
+ rho : float, default=0.0
1289
+ Intra-cluster correlation.
1290
+
1291
+ Returns
1292
+ -------
1293
+ float
1294
+ Statistical power.
1295
+
1296
+ Examples
1297
+ --------
1298
+ >>> power = compute_power(effect_size=5.0, n_treated=50, n_control=50, sigma=10.0)
1299
+ >>> print(f"Power: {power:.1%}")
1300
+ """
1301
+ pa = PowerAnalysis(alpha=alpha)
1302
+ result = pa.power(effect_size, n_treated, n_control, sigma, n_pre, n_post, rho)
1303
+ return result.power
1304
+
1305
+
1306
+ def compute_sample_size(
1307
+ effect_size: float,
1308
+ sigma: float,
1309
+ power: float = 0.80,
1310
+ alpha: float = 0.05,
1311
+ n_pre: int = 1,
1312
+ n_post: int = 1,
1313
+ rho: float = 0.0,
1314
+ treat_frac: float = 0.5,
1315
+ ) -> int:
1316
+ """
1317
+ Convenience function to compute required sample size.
1318
+
1319
+ Parameters
1320
+ ----------
1321
+ effect_size : float
1322
+ Treatment effect to detect.
1323
+ sigma : float
1324
+ Residual standard deviation.
1325
+ power : float, default=0.80
1326
+ Target statistical power.
1327
+ alpha : float, default=0.05
1328
+ Significance level.
1329
+ n_pre : int, default=1
1330
+ Number of pre-treatment periods.
1331
+ n_post : int, default=1
1332
+ Number of post-treatment periods.
1333
+ rho : float, default=0.0
1334
+ Intra-cluster correlation.
1335
+ treat_frac : float, default=0.5
1336
+ Fraction assigned to treatment.
1337
+
1338
+ Returns
1339
+ -------
1340
+ int
1341
+ Required total sample size.
1342
+
1343
+ Examples
1344
+ --------
1345
+ >>> n = compute_sample_size(effect_size=5.0, sigma=10.0)
1346
+ >>> print(f"Required N: {n}")
1347
+ """
1348
+ pa = PowerAnalysis(alpha=alpha, power=power)
1349
+ result = pa.sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac)
1350
+ return result.required_n