diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/utils.py ADDED
@@ -0,0 +1,1481 @@
1
+ """
2
+ Utility functions for difference-in-differences estimation.
3
+ """
4
+
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from scipy import stats
12
+
13
+ from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
14
+ from diff_diff.linalg import solve_ols as _solve_ols_linalg
15
+
16
+ # Import Rust backend if available (from _backend to avoid circular imports)
17
+ from diff_diff._backend import (
18
+ HAS_RUST_BACKEND,
19
+ _rust_project_simplex,
20
+ _rust_synthetic_weights,
21
+ )
22
+
23
+ # Numerical constants for optimization algorithms
24
+ _OPTIMIZATION_MAX_ITER = 1000 # Maximum iterations for weight optimization
25
+ _OPTIMIZATION_TOL = 1e-8 # Convergence tolerance for optimization
26
+ _NUMERICAL_EPS = 1e-10 # Small constant to prevent division by zero
27
+
28
+
29
+ def validate_binary(arr: np.ndarray, name: str) -> None:
30
+ """
31
+ Validate that an array contains only binary values (0 or 1).
32
+
33
+ Parameters
34
+ ----------
35
+ arr : np.ndarray
36
+ Array to validate.
37
+ name : str
38
+ Name of the variable (for error messages).
39
+
40
+ Raises
41
+ ------
42
+ ValueError
43
+ If array contains non-binary values.
44
+ """
45
+ unique_values = np.unique(arr[~np.isnan(arr)])
46
+ if not np.all(np.isin(unique_values, [0, 1])):
47
+ raise ValueError(
48
+ f"{name} must be binary (0 or 1). "
49
+ f"Found values: {unique_values}"
50
+ )
51
+
52
+
53
+ def compute_robust_se(
54
+ X: np.ndarray,
55
+ residuals: np.ndarray,
56
+ cluster_ids: Optional[np.ndarray] = None
57
+ ) -> np.ndarray:
58
+ """
59
+ Compute heteroskedasticity-robust (HC1) or cluster-robust standard errors.
60
+
61
+ This function is a thin wrapper around the optimized implementation in
62
+ diff_diff.linalg for backwards compatibility.
63
+
64
+ Parameters
65
+ ----------
66
+ X : np.ndarray
67
+ Design matrix of shape (n, k).
68
+ residuals : np.ndarray
69
+ Residuals from regression of shape (n,).
70
+ cluster_ids : np.ndarray, optional
71
+ Cluster identifiers for cluster-robust SEs.
72
+
73
+ Returns
74
+ -------
75
+ np.ndarray
76
+ Variance-covariance matrix of shape (k, k).
77
+ """
78
+ return _compute_robust_vcov_linalg(X, residuals, cluster_ids)
79
+
80
+
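# Illustrative sketch (editorial, not part of the released file): obtaining
# HC1 standard errors for a simple OLS fit via the wrapper above. The
# simulated, heteroskedastic design is hypothetical and for demonstration only.
import numpy as np
from diff_diff.utils import compute_robust_se

rng = np.random.default_rng(0)
x = rng.normal(size=200)
X = np.column_stack([np.ones(200), x])
y = 1.0 + 0.5 * x + rng.normal(size=200) * (1 + np.abs(x))  # variance grows with |x|
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
vcov = compute_robust_se(X, y - X @ beta)
print(np.sqrt(np.diag(vcov)))  # robust standard errors for intercept and slope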
81
+ def compute_confidence_interval(
82
+ estimate: float,
83
+ se: float,
84
+ alpha: float = 0.05,
85
+ df: Optional[int] = None
86
+ ) -> Tuple[float, float]:
87
+ """
88
+ Compute confidence interval for an estimate.
89
+
90
+ Parameters
91
+ ----------
92
+ estimate : float
93
+ Point estimate.
94
+ se : float
95
+ Standard error.
96
+ alpha : float
97
+ Significance level (default 0.05 for 95% CI).
98
+ df : int, optional
99
+ Degrees of freedom. If None, uses normal distribution.
100
+
101
+ Returns
102
+ -------
103
+ tuple
104
+ (lower_bound, upper_bound) of confidence interval.
105
+ """
106
+ if df is not None:
107
+ critical_value = stats.t.ppf(1 - alpha / 2, df)
108
+ else:
109
+ critical_value = stats.norm.ppf(1 - alpha / 2)
110
+
111
+ lower = estimate - critical_value * se
112
+ upper = estimate + critical_value * se
113
+
114
+ return (lower, upper)
115
+
116
+
117
+ def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
118
+ """
119
+ Compute p-value for a t-statistic.
120
+
121
+ Parameters
122
+ ----------
123
+ t_stat : float
124
+ T-statistic.
125
+ df : int, optional
126
+ Degrees of freedom. If None, uses normal distribution.
127
+ two_sided : bool
128
+ Whether to compute two-sided p-value (default True).
129
+
130
+ Returns
131
+ -------
132
+ float
133
+ P-value.
134
+ """
135
+ if df is not None:
136
+ p_value = stats.t.sf(np.abs(t_stat), df)
137
+ else:
138
+ p_value = stats.norm.sf(np.abs(t_stat))
139
+
140
+ if two_sided:
141
+ p_value *= 2
142
+
143
+ return float(p_value)
144
+
145
+
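# Illustrative sketch (editorial, not part of the released file): combining the
# two helpers above for a simple t-based test. The estimate, standard error,
# and degrees of freedom are made-up numbers.
from diff_diff.utils import compute_confidence_interval, compute_p_value

estimate, se, df = 1.8, 0.7, 48
lower, upper = compute_confidence_interval(estimate, se, alpha=0.05, df=df)
p = compute_p_value(estimate / se, df=df)  # two-sided by default
print((lower, upper), p)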
146
+ # =============================================================================
147
+ # Wild Cluster Bootstrap
148
+ # =============================================================================
149
+
150
+
151
+ @dataclass
152
+ class WildBootstrapResults:
153
+ """
154
+ Results from wild cluster bootstrap inference.
155
+
156
+ Attributes
157
+ ----------
158
+ se : float
159
+ Bootstrap standard error of the coefficient.
160
+ p_value : float
161
+ Bootstrap p-value (two-sided).
162
+ t_stat_original : float
163
+ Original t-statistic from the data.
164
+ ci_lower : float
165
+ Lower bound of the confidence interval.
166
+ ci_upper : float
167
+ Upper bound of the confidence interval.
168
+ n_clusters : int
169
+ Number of clusters in the data.
170
+ n_bootstrap : int
171
+ Number of bootstrap replications.
172
+ weight_type : str
173
+ Type of bootstrap weights used ("rademacher", "webb", or "mammen").
174
+ alpha : float
175
+ Significance level used for confidence interval.
176
+ bootstrap_distribution : np.ndarray, optional
177
+ Full bootstrap distribution of coefficients (if requested).
178
+
179
+ References
180
+ ----------
181
+ Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
182
+ Bootstrap-Based Improvements for Inference with Clustered Errors.
183
+ The Review of Economics and Statistics, 90(3), 414-427.
184
+ """
185
+
186
+ se: float
187
+ p_value: float
188
+ t_stat_original: float
189
+ ci_lower: float
190
+ ci_upper: float
191
+ n_clusters: int
192
+ n_bootstrap: int
193
+ weight_type: str
194
+ alpha: float = 0.05
195
+ bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
196
+
197
+ def summary(self) -> str:
198
+ """Generate formatted summary of bootstrap results."""
199
+ lines = [
200
+ "Wild Cluster Bootstrap Results",
201
+ "=" * 40,
202
+ f"Bootstrap SE: {self.se:.6f}",
203
+ f"Bootstrap p-value: {self.p_value:.4f}",
204
+ f"Original t-stat: {self.t_stat_original:.4f}",
205
+ f"CI ({int((1-self.alpha)*100)}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
206
+ f"Number of clusters: {self.n_clusters}",
207
+ f"Bootstrap reps: {self.n_bootstrap}",
208
+ f"Weight type: {self.weight_type}",
209
+ ]
210
+ return "\n".join(lines)
211
+
212
+ def print_summary(self) -> None:
213
+ """Print formatted summary to stdout."""
214
+ print(self.summary())
215
+
216
+
217
+ def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
218
+ """
219
+ Generate Rademacher weights: +1 or -1 with probability 0.5.
220
+
221
+ Parameters
222
+ ----------
223
+ n_clusters : int
224
+ Number of clusters.
225
+ rng : np.random.Generator
226
+ Random number generator.
227
+
228
+ Returns
229
+ -------
230
+ np.ndarray
231
+ Array of Rademacher weights.
232
+ """
233
+ return np.asarray(rng.choice([-1.0, 1.0], size=n_clusters))
234
+
235
+
236
+ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
237
+ """
238
+ Generate Webb's 6-point distribution weights.
239
+
240
+ Values: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)}
241
+ each with probability 1/6.
242
+
243
+ This distribution is recommended for very few clusters (G < 10) as it
244
+ provides better finite-sample properties than Rademacher weights.
245
+
246
+ Parameters
247
+ ----------
248
+ n_clusters : int
249
+ Number of clusters.
250
+ rng : np.random.Generator
251
+ Random number generator.
252
+
253
+ Returns
254
+ -------
255
+ np.ndarray
256
+ Array of Webb weights.
257
+
258
+ References
259
+ ----------
260
+ Webb, M. D. (2014). Reworking wild bootstrap based inference for
261
+ clustered errors. Queen's Economics Department Working Paper No. 1315.
262
+ """
263
+ values = np.array([
264
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
265
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
266
+ ])
267
+ probs = np.full(6, 1 / 6)  # each of the six points has probability 1/6 (Webb, 2014)
268
+ return np.asarray(rng.choice(values, size=n_clusters, p=probs))
269
+
270
+
271
+ def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
272
+ """
273
+ Generate Mammen's two-point distribution weights.
274
+
275
+ Values: {-(sqrt(5)-1)/2, (sqrt(5)+1)/2}
276
+ with probabilities {(sqrt(5)+1)/(2*sqrt(5)), (sqrt(5)-1)/(2*sqrt(5))}.
277
+
278
+ This distribution satisfies E[v]=0, E[v^2]=1, E[v^3]=1, which provides
279
+ asymptotic refinement for skewed error distributions.
280
+
281
+ Parameters
282
+ ----------
283
+ n_clusters : int
284
+ Number of clusters.
285
+ rng : np.random.Generator
286
+ Random number generator.
287
+
288
+ Returns
289
+ -------
290
+ np.ndarray
291
+ Array of Mammen weights.
292
+
293
+ References
294
+ ----------
295
+ Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
296
+ Linear Models. The Annals of Statistics, 21(1), 255-285.
297
+ """
298
+ sqrt5 = np.sqrt(5)
299
+ # Values from Mammen (1993)
300
+ val1 = -(sqrt5 - 1) / 2 # approximately -0.618
301
+ val2 = (sqrt5 + 1) / 2 # approximately 1.618 (golden ratio)
302
+
303
+ # Probability of val1
304
+ p1 = (sqrt5 + 1) / (2 * sqrt5) # approximately 0.724
305
+
306
+ return np.asarray(rng.choice([val1, val2], size=n_clusters, p=[p1, 1 - p1]))
307
+
308
+
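# Illustrative sketch (editorial, not part of the released file): all three
# weight distributions above must satisfy E[v] = 0 and E[v^2] = 1 for the wild
# bootstrap to be valid. A quick Monte Carlo check of those moments, assuming
# the private helpers are importable from diff_diff.utils:
import numpy as np
from diff_diff.utils import (
    _generate_mammen_weights,
    _generate_rademacher_weights,
    _generate_webb_weights,
)

rng = np.random.default_rng(0)
for gen in (_generate_rademacher_weights, _generate_webb_weights, _generate_mammen_weights):
    v = gen(200_000, rng)
    print(gen.__name__, round(float(v.mean()), 3), round(float(v.var()), 3))  # ~0, ~1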
309
+ def wild_bootstrap_se(
310
+ X: np.ndarray,
311
+ y: np.ndarray,
312
+ residuals: np.ndarray,
313
+ cluster_ids: np.ndarray,
314
+ coefficient_index: int,
315
+ n_bootstrap: int = 999,
316
+ weight_type: str = "rademacher",
317
+ null_hypothesis: float = 0.0,
318
+ alpha: float = 0.05,
319
+ seed: Optional[int] = None,
320
+ return_distribution: bool = False
321
+ ) -> WildBootstrapResults:
322
+ """
323
+ Compute wild cluster bootstrap standard errors and p-values.
324
+
325
+ Implements the Wild Cluster Residual (WCR) bootstrap procedure from
326
+ Cameron, Gelbach, and Miller (2008). Uses the restricted residuals
327
+ approach (imposing H0: coefficient = null_hypothesis) for more accurate
328
+ p-value computation.
329
+
330
+ Parameters
331
+ ----------
332
+ X : np.ndarray
333
+ Design matrix of shape (n, k).
334
+ y : np.ndarray
335
+ Outcome vector of shape (n,).
336
+ residuals : np.ndarray
337
+ OLS residuals from unrestricted regression, shape (n,).
338
+ cluster_ids : np.ndarray
339
+ Cluster identifiers of shape (n,).
340
+ coefficient_index : int
341
+ Index of the coefficient for which to compute bootstrap inference.
342
+ For DiD, this is typically 3 (the treatment*post interaction term).
343
+ n_bootstrap : int, default=999
344
+ Number of bootstrap replications. Odd numbers are recommended for
345
+ exact p-value computation.
346
+ weight_type : str, default="rademacher"
347
+ Type of bootstrap weights:
348
+ - "rademacher": +1 or -1 with equal probability (standard choice)
349
+ - "webb": 6-point distribution (recommended for <10 clusters)
350
+ - "mammen": Two-point distribution with skewness correction
351
+ null_hypothesis : float, default=0.0
352
+ Value of the null hypothesis for p-value computation.
353
+ alpha : float, default=0.05
354
+ Significance level for confidence interval.
355
+ seed : int, optional
356
+ Random seed for reproducibility. If None (default), results
357
+ will vary between runs.
358
+ return_distribution : bool, default=False
359
+ If True, include full bootstrap distribution in results.
360
+
361
+ Returns
362
+ -------
363
+ WildBootstrapResults
364
+ Dataclass containing bootstrap SE, p-value, confidence interval,
365
+ and other inference results.
366
+
367
+ Raises
368
+ ------
369
+ ValueError
370
+ If weight_type is not recognized or if there are fewer than 2 clusters.
371
+
372
+ Warns
373
+ -----
374
+ UserWarning
375
+ If the number of clusters is less than 5, as bootstrap inference
376
+ may be unreliable.
377
+
378
+ Examples
379
+ --------
380
+ >>> from diff_diff.utils import wild_bootstrap_se
381
+ >>> results = wild_bootstrap_se(
382
+ ... X, y, residuals, cluster_ids,
383
+ ... coefficient_index=3, # ATT coefficient
384
+ ... n_bootstrap=999,
385
+ ... weight_type="rademacher",
386
+ ... seed=42
387
+ ... )
388
+ >>> print(f"Bootstrap SE: {results.se:.4f}")
389
+ >>> print(f"Bootstrap p-value: {results.p_value:.4f}")
390
+
391
+ References
392
+ ----------
393
+ Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
394
+ Bootstrap-Based Improvements for Inference with Clustered Errors.
395
+ The Review of Economics and Statistics, 90(3), 414-427.
396
+
397
+ MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for
398
+ few (treated) clusters. The Econometrics Journal, 21(2), 114-135.
399
+ """
400
+ # Validate inputs
401
+ valid_weight_types = ["rademacher", "webb", "mammen"]
402
+ if weight_type not in valid_weight_types:
403
+ raise ValueError(
404
+ f"weight_type must be one of {valid_weight_types}, got '{weight_type}'"
405
+ )
406
+
407
+ unique_clusters = np.unique(cluster_ids)
408
+ n_clusters = len(unique_clusters)
409
+
410
+ if n_clusters < 2:
411
+ raise ValueError(
412
+ f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}"
413
+ )
414
+
415
+ if n_clusters < 5:
416
+ warnings.warn(
417
+ f"Only {n_clusters} clusters detected. Wild bootstrap inference may be "
418
+ "unreliable with fewer than 5 clusters. Consider using Webb weights "
419
+ "(weight_type='webb') for improved finite-sample properties.",
420
+ UserWarning
421
+ )
422
+
423
+ # Initialize RNG
424
+ rng = np.random.default_rng(seed)
425
+
426
+ # Select weight generator
427
+ weight_generators = {
428
+ "rademacher": _generate_rademacher_weights,
429
+ "webb": _generate_webb_weights,
430
+ "mammen": _generate_mammen_weights,
431
+ }
432
+ generate_weights = weight_generators[weight_type]
433
+
434
+ n = X.shape[0]
435
+
436
+ # Step 1: Compute original coefficient and cluster-robust SE
437
+ beta_hat, _, vcov_original = _solve_ols_linalg(
438
+ X, y, cluster_ids=cluster_ids, return_vcov=True
439
+ )
440
+ original_coef = beta_hat[coefficient_index]
441
+ se_original = np.sqrt(vcov_original[coefficient_index, coefficient_index])
442
+ t_stat_original = (original_coef - null_hypothesis) / se_original
443
+
444
+ # Step 2: Impose null hypothesis (restricted estimation)
445
+ # Create restricted y: y_restricted = y - X[:, coef_index] * null_hypothesis
446
+ # This imposes the null that the coefficient equals null_hypothesis
447
+ y_restricted = y - X[:, coefficient_index] * null_hypothesis
448
+
449
+ # Fit the restricted model. To actually impose the null, the restricted
+ # coefficient's column is dropped from the design matrix; regressing
+ # y_restricted on the full X would simply reproduce the unrestricted fit
+ # (shifted by the null) and leave the residuals unrestricted.
+ X_restricted = np.delete(X, coefficient_index, axis=1)
+ beta_restricted, residuals_restricted, _ = _solve_ols_linalg(
+ X_restricted, y_restricted, return_vcov=False
+ )
+ # Restricted fitted values for the original outcome
+ fitted_restricted = (
+ X_restricted @ beta_restricted + X[:, coefficient_index] * null_hypothesis
+ )
455
+
456
+ # Create cluster-to-observation mapping for efficiency
457
+ cluster_map = {c: np.where(cluster_ids == c)[0] for c in unique_clusters}
458
+ cluster_indices = [cluster_map[c] for c in unique_clusters]
459
+
460
+ # Step 3: Bootstrap loop
461
+ bootstrap_t_stats = np.zeros(n_bootstrap)
462
+ bootstrap_coefs = np.zeros(n_bootstrap)
463
+
464
+ for b in range(n_bootstrap):
465
+ # Generate cluster-level weights
466
+ cluster_weights = generate_weights(n_clusters, rng)
467
+
468
+ # Map cluster weights to observations
469
+ obs_weights = np.zeros(n)
470
+ for g, indices in enumerate(cluster_indices):
471
+ obs_weights[indices] = cluster_weights[g]
472
+
473
+ # Construct bootstrap sample: y* = restricted fitted values + e_restricted * weights
+ y_star = fitted_restricted + residuals_restricted * obs_weights
475
+
476
+ # Estimate bootstrap coefficients with cluster-robust SE
477
+ beta_star, residuals_star, vcov_star = _solve_ols_linalg(
478
+ X, y_star, cluster_ids=cluster_ids, return_vcov=True
479
+ )
480
+ bootstrap_coefs[b] = beta_star[coefficient_index]
481
+ se_star = np.sqrt(vcov_star[coefficient_index, coefficient_index])
482
+
483
+ # Compute bootstrap t-statistic (under null hypothesis)
484
+ if se_star > 0:
485
+ bootstrap_t_stats[b] = (beta_star[coefficient_index] - null_hypothesis) / se_star
486
+ else:
487
+ bootstrap_t_stats[b] = 0.0
488
+
489
+ # Step 4: Compute bootstrap p-value
490
+ # P-value is proportion of |t*| >= |t_original|
491
+ p_value = np.mean(np.abs(bootstrap_t_stats) >= np.abs(t_stat_original))
492
+
493
+ # Ensure p-value is at least 1/(n_bootstrap+1) to avoid exact zero
494
+ p_value = float(max(float(p_value), 1 / (n_bootstrap + 1)))
495
+
496
+ # Step 5: Compute bootstrap SE and confidence interval
497
+ # SE from standard deviation of bootstrap coefficient distribution
498
+ se_bootstrap = float(np.std(bootstrap_coefs, ddof=1))
499
+
500
+ # Percentile-style confidence interval. Under the restricted bootstrap DGP the
+ # true coefficient equals null_hypothesis, so bootstrap deviations from the null
+ # approximate the sampling error and are recentered on the original estimate.
+ lower_percentile = alpha / 2 * 100
+ upper_percentile = (1 - alpha / 2) * 100
+ deviations = bootstrap_coefs - null_hypothesis
+ ci_lower = float(original_coef + np.percentile(deviations, lower_percentile))
+ ci_upper = float(original_coef + np.percentile(deviations, upper_percentile))
505
+
506
+ return WildBootstrapResults(
507
+ se=se_bootstrap,
508
+ p_value=p_value,
509
+ t_stat_original=t_stat_original,
510
+ ci_lower=ci_lower,
511
+ ci_upper=ci_upper,
512
+ n_clusters=n_clusters,
513
+ n_bootstrap=n_bootstrap,
514
+ weight_type=weight_type,
515
+ alpha=alpha,
516
+ bootstrap_distribution=bootstrap_coefs if return_distribution else None
517
+ )
518
+
519
+
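# Illustrative sketch (editorial, not part of the released file): wiring up
# wild_bootstrap_se for a simulated two-period DiD with unit-level clustering.
# Column 3 of X is the treated*post interaction, matching coefficient_index=3;
# all names and numbers here are hypothetical.
import numpy as np
from diff_diff.utils import wild_bootstrap_se

rng = np.random.default_rng(0)
n_units, tau = 40, 2.0
unit = np.repeat(np.arange(n_units), 2)          # two periods per unit
post = np.tile([0.0, 1.0], n_units)
treated = (unit < n_units // 2).astype(float)
unit_effect = np.repeat(rng.normal(0, 0.5, n_units), 2)
y = (1.0 + 0.5 * treated + 0.8 * post + tau * treated * post
     + unit_effect + rng.normal(0, 1.0, 2 * n_units))
X = np.column_stack([np.ones(2 * n_units), treated, post, treated * post])

beta, *_ = np.linalg.lstsq(X, y, rcond=None)
results = wild_bootstrap_se(
    X, y, residuals=y - X @ beta, cluster_ids=unit,
    coefficient_index=3, n_bootstrap=199, seed=42,
)
results.print_summary()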
520
+ def check_parallel_trends(
521
+ data: pd.DataFrame,
522
+ outcome: str,
523
+ time: str,
524
+ treatment_group: str,
525
+ pre_periods: Optional[List[Any]] = None
526
+ ) -> Dict[str, Any]:
527
+ """
528
+ Perform a simple check of the parallel trends assumption.
529
+
530
+ This computes the trend (slope) in the outcome variable for both
531
+ treatment and control groups during pre-treatment periods.
532
+
533
+ Parameters
534
+ ----------
535
+ data : pd.DataFrame
536
+ Panel data.
537
+ outcome : str
538
+ Name of outcome variable column.
539
+ time : str
540
+ Name of time period column.
541
+ treatment_group : str
542
+ Name of treatment group indicator column.
543
+ pre_periods : list, optional
544
+ List of pre-treatment time periods. If None, infers from data.
545
+
546
+ Returns
547
+ -------
548
+ dict
549
+ Dictionary with trend statistics and test results.
550
+ """
551
+ if pre_periods is None:
552
+ # Assume treatment begins at the midpoint of the observed time periods
553
+ all_periods = sorted(data[time].unique())
554
+ mid_point = len(all_periods) // 2
555
+ pre_periods = all_periods[:mid_point]
556
+
557
+ pre_data = data[data[time].isin(pre_periods)]
558
+
559
+ # Compute trends for each group
560
+ treated_data = pre_data[pre_data[treatment_group] == 1]
561
+ control_data = pre_data[pre_data[treatment_group] == 0]
562
+
563
+ # Simple linear regression for trends
564
+ def compute_trend(group_data: pd.DataFrame) -> Tuple[float, float]:
565
+ time_values = group_data[time].values
566
+ outcome_values = group_data[outcome].values
567
+
568
+ # Normalize time to start at 0
569
+ time_norm = time_values - time_values.min()
570
+
571
+ # Compute slope using least squares
572
+ n = len(time_norm)
573
+ if n < 2:
574
+ return np.nan, np.nan
575
+
576
+ mean_t = np.mean(time_norm)
577
+ mean_y = np.mean(outcome_values)
578
+
579
+ # Check for zero variance in time (all same time period)
580
+ time_var = np.sum((time_norm - mean_t) ** 2)
581
+ if time_var == 0:
582
+ return np.nan, np.nan
583
+
584
+ slope = np.sum((time_norm - mean_t) * (outcome_values - mean_y)) / time_var
585
+
586
+ # Compute standard error of slope
587
+ y_hat = mean_y + slope * (time_norm - mean_t)
588
+ residuals = outcome_values - y_hat
589
+ mse = np.sum(residuals ** 2) / (n - 2)
590
+ se_slope = np.sqrt(mse / time_var)
591
+
592
+ return slope, se_slope
593
+
594
+ treated_slope, treated_se = compute_trend(treated_data)
595
+ control_slope, control_se = compute_trend(control_data)
596
+
597
+ # Test for difference in trends
598
+ slope_diff = treated_slope - control_slope
599
+ se_diff = np.sqrt(treated_se ** 2 + control_se ** 2)
600
+ t_stat = slope_diff / se_diff if se_diff > 0 else np.nan
601
+ p_value = compute_p_value(t_stat) if not np.isnan(t_stat) else np.nan
602
+
603
+ return {
604
+ "treated_trend": treated_slope,
605
+ "treated_trend_se": treated_se,
606
+ "control_trend": control_slope,
607
+ "control_trend_se": control_se,
608
+ "trend_difference": slope_diff,
609
+ "trend_difference_se": se_diff,
610
+ "t_statistic": t_stat,
611
+ "p_value": p_value,
612
+ "parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
613
+ }
614
+
615
+
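# Illustrative sketch (editorial, not part of the released file): running the
# pre-trend check on a small simulated panel with a common linear trend. The
# column names are hypothetical; by default the first half of the periods is
# treated as the pre-treatment window.
import numpy as np
import pandas as pd
from diff_diff.utils import check_parallel_trends

rng = np.random.default_rng(1)
rows = [(u, t, float(u < 10)) for u in range(20) for t in range(2000, 2006)]
df = pd.DataFrame(rows, columns=["firm", "year", "treated"])
df["sales"] = (10 + 0.3 * (df["year"] - 2000) + 2.0 * df["treated"]
               + rng.normal(0, 1.0, len(df)))   # parallel trends by construction

report = check_parallel_trends(df, outcome="sales", time="year", treatment_group="treated")
print(report["trend_difference"], report["p_value"], report["parallel_trends_plausible"])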
616
+ def check_parallel_trends_robust(
617
+ data: pd.DataFrame,
618
+ outcome: str,
619
+ time: str,
620
+ treatment_group: str,
621
+ unit: Optional[str] = None,
622
+ pre_periods: Optional[List[Any]] = None,
623
+ n_permutations: int = 1000,
624
+ seed: Optional[int] = None,
625
+ wasserstein_threshold: float = 0.2
626
+ ) -> Dict[str, Any]:
627
+ """
628
+ Perform robust parallel trends testing using distributional comparisons.
629
+
630
+ Uses the Wasserstein (Earth Mover's) distance to compare the full
631
+ distribution of outcome changes between treated and control groups,
632
+ with permutation-based inference.
633
+
634
+ Parameters
635
+ ----------
636
+ data : pd.DataFrame
637
+ Panel data with repeated observations over time.
638
+ outcome : str
639
+ Name of outcome variable column.
640
+ time : str
641
+ Name of time period column.
642
+ treatment_group : str
643
+ Name of treatment group indicator column (0/1).
644
+ unit : str, optional
645
+ Name of unit identifier column. If provided, computes unit-level
646
+ changes. Otherwise uses observation-level data.
647
+ pre_periods : list, optional
648
+ List of pre-treatment time periods. If None, uses first half of periods.
649
+ n_permutations : int, default=1000
650
+ Number of permutations for computing p-value.
651
+ seed : int, optional
652
+ Random seed for reproducibility.
653
+ wasserstein_threshold : float, default=0.2
654
+ Threshold for normalized Wasserstein distance. Values below this
655
+ threshold (combined with p > 0.05) suggest parallel trends are plausible.
656
+
657
+ Returns
658
+ -------
659
+ dict
660
+ Dictionary containing:
661
+ - wasserstein_distance: Wasserstein distance between group distributions
662
+ - wasserstein_p_value: Permutation-based p-value
663
+ - ks_statistic: Kolmogorov-Smirnov test statistic
664
+ - ks_p_value: KS test p-value
665
+ - mean_difference: Difference in mean changes
666
+ - variance_ratio: Ratio of variances in changes
667
+ - treated_changes: Array of outcome changes for treated
668
+ - control_changes: Array of outcome changes for control
669
+ - parallel_trends_plausible: Boolean assessment
670
+
671
+ Examples
672
+ --------
673
+ >>> results = check_parallel_trends_robust(
674
+ ... data, outcome='sales', time='year',
675
+ ... treatment_group='treated', unit='firm_id'
676
+ ... )
677
+ >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
678
+ >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")
679
+
680
+ Notes
681
+ -----
682
+ The Wasserstein distance (Earth Mover's Distance) measures the minimum
683
+ "cost" of transforming one distribution into another. Unlike simple
684
+ mean comparisons, it captures differences in the entire distribution
685
+ shape, making it more robust to non-normal data and heterogeneous effects.
686
+
687
+ A small Wasserstein distance and high p-value suggest the distributions
688
+ of pre-treatment changes are similar, supporting the parallel trends
689
+ assumption.
690
+ """
691
+ # Use local RNG to avoid affecting global random state
692
+ rng = np.random.default_rng(seed)
693
+
694
+ # Identify pre-treatment periods
695
+ if pre_periods is None:
696
+ all_periods = sorted(data[time].unique())
697
+ mid_point = len(all_periods) // 2
698
+ pre_periods = all_periods[:mid_point]
699
+
700
+ pre_data = data[data[time].isin(pre_periods)].copy()
701
+
702
+ # Compute outcome changes
703
+ treated_changes, control_changes = _compute_outcome_changes(
704
+ pre_data, outcome, time, treatment_group, unit
705
+ )
706
+
707
+ if len(treated_changes) < 2 or len(control_changes) < 2:
708
+ return {
709
+ "wasserstein_distance": np.nan,
710
+ "wasserstein_p_value": np.nan,
711
+ "ks_statistic": np.nan,
712
+ "ks_p_value": np.nan,
713
+ "mean_difference": np.nan,
714
+ "variance_ratio": np.nan,
715
+ "treated_changes": treated_changes,
716
+ "control_changes": control_changes,
717
+ "parallel_trends_plausible": None,
718
+ "error": "Insufficient data for comparison",
719
+ }
720
+
721
+ # Compute Wasserstein distance
722
+ wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)
723
+
724
+ # Permutation test for Wasserstein distance
725
+ all_changes = np.concatenate([treated_changes, control_changes])
726
+ n_treated = len(treated_changes)
727
+ n_total = len(all_changes)
728
+
729
+ permuted_distances = np.zeros(n_permutations)
730
+ for i in range(n_permutations):
731
+ perm_idx = rng.permutation(n_total)
732
+ perm_treated = all_changes[perm_idx[:n_treated]]
733
+ perm_control = all_changes[perm_idx[n_treated:]]
734
+ permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)
735
+
736
+ # P-value: proportion of permuted distances >= observed
737
+ wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)
738
+
739
+ # Kolmogorov-Smirnov test
740
+ ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)
741
+
742
+ # Additional summary statistics
743
+ mean_diff = np.mean(treated_changes) - np.mean(control_changes)
744
+ var_treated = np.var(treated_changes, ddof=1)
745
+ var_control = np.var(control_changes, ddof=1)
746
+ var_ratio = var_treated / var_control if var_control > 0 else np.nan
747
+
748
+ # Normalized Wasserstein (relative to pooled std)
749
+ pooled_std = np.std(all_changes, ddof=1)
750
+ wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan
751
+
752
+ # Assessment: parallel trends plausible if p-value > 0.05
753
+ # and normalized Wasserstein is small (below threshold)
754
+ plausible = bool(
755
+ wasserstein_p > 0.05 and
756
+ (wasserstein_normalized < wasserstein_threshold if not np.isnan(wasserstein_normalized) else True)
757
+ )
758
+
759
+ return {
760
+ "wasserstein_distance": wasserstein_dist,
761
+ "wasserstein_normalized": wasserstein_normalized,
762
+ "wasserstein_p_value": wasserstein_p,
763
+ "ks_statistic": ks_stat,
764
+ "ks_p_value": ks_p,
765
+ "mean_difference": mean_diff,
766
+ "variance_ratio": var_ratio,
767
+ "n_treated": len(treated_changes),
768
+ "n_control": len(control_changes),
769
+ "treated_changes": treated_changes,
770
+ "control_changes": control_changes,
771
+ "parallel_trends_plausible": plausible,
772
+ }
773
+
774
+
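# Illustrative sketch (editorial, not part of the released file): the
# permutation step used above, in isolation. Group labels are shuffled, the
# Wasserstein distance is recomputed each time, and the p-value is the share
# of permuted distances at least as large as the observed one. The two samples
# are simulated for demonstration.
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
treated_changes = rng.normal(0.0, 1.0, 60)
control_changes = rng.normal(0.1, 1.0, 80)
observed = stats.wasserstein_distance(treated_changes, control_changes)

pooled = np.concatenate([treated_changes, control_changes])
n_treated = len(treated_changes)
perm = np.empty(999)
for i in range(999):
    idx = rng.permutation(len(pooled))
    perm[i] = stats.wasserstein_distance(pooled[idx[:n_treated]], pooled[idx[n_treated:]])
print(observed, float((perm >= observed).mean()))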
775
+ def _compute_outcome_changes(
776
+ data: pd.DataFrame,
777
+ outcome: str,
778
+ time: str,
779
+ treatment_group: str,
780
+ unit: Optional[str] = None
781
+ ) -> Tuple[np.ndarray, np.ndarray]:
782
+ """
783
+ Compute period-to-period outcome changes for treated and control groups.
784
+
785
+ Parameters
786
+ ----------
787
+ data : pd.DataFrame
788
+ Panel data.
789
+ outcome : str
790
+ Outcome variable column.
791
+ time : str
792
+ Time period column.
793
+ treatment_group : str
794
+ Treatment group indicator column.
795
+ unit : str, optional
796
+ Unit identifier column.
797
+
798
+ Returns
799
+ -------
800
+ tuple
801
+ (treated_changes, control_changes) as numpy arrays.
802
+ """
803
+ if unit is not None:
804
+ # Unit-level changes: compute change for each unit across periods
805
+ data_sorted = data.sort_values([unit, time])
806
+ data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()
807
+
808
+ # Remove NaN from first period of each unit
809
+ changes_data = data_sorted.dropna(subset=["_outcome_change"])
810
+
811
+ treated_changes = changes_data[
812
+ changes_data[treatment_group] == 1
813
+ ]["_outcome_change"].values
814
+
815
+ control_changes = changes_data[
816
+ changes_data[treatment_group] == 0
817
+ ]["_outcome_change"].values
818
+ else:
819
+ # Aggregate changes: compute mean change per period per group
820
+ treated_data = data[data[treatment_group] == 1]
821
+ control_data = data[data[treatment_group] == 0]
822
+
823
+ # Compute period means
824
+ treated_means = treated_data.groupby(time)[outcome].mean()
825
+ control_means = control_data.groupby(time)[outcome].mean()
826
+
827
+ # Compute changes between consecutive periods
828
+ treated_changes = np.diff(treated_means.values)
829
+ control_changes = np.diff(control_means.values)
830
+
831
+ return treated_changes.astype(float), control_changes.astype(float)
832
+
833
+
834
+ def equivalence_test_trends(
835
+ data: pd.DataFrame,
836
+ outcome: str,
837
+ time: str,
838
+ treatment_group: str,
839
+ unit: Optional[str] = None,
840
+ pre_periods: Optional[List[Any]] = None,
841
+ equivalence_margin: Optional[float] = None
842
+ ) -> Dict[str, Any]:
843
+ """
844
+ Perform equivalence testing (TOST) for parallel trends.
845
+
846
+ Tests whether the difference in trends is practically equivalent to zero
847
+ using the Two One-Sided Tests (TOST) procedure.
848
+
849
+ Parameters
850
+ ----------
851
+ data : pd.DataFrame
852
+ Panel data.
853
+ outcome : str
854
+ Name of outcome variable column.
855
+ time : str
856
+ Name of time period column.
857
+ treatment_group : str
858
+ Name of treatment group indicator column.
859
+ unit : str, optional
860
+ Name of unit identifier column.
861
+ pre_periods : list, optional
862
+ List of pre-treatment time periods.
863
+ equivalence_margin : float, optional
864
+ The margin for equivalence (delta). If None, uses 0.5 * pooled SD
865
+ of outcome changes as a default.
866
+
867
+ Returns
868
+ -------
869
+ dict
870
+ Dictionary containing:
871
+ - mean_difference: Difference in mean changes
872
+ - equivalence_margin: The margin used
873
+ - lower_p_value: P-value for lower bound test
874
+ - upper_p_value: P-value for upper bound test
875
+ - tost_p_value: Maximum of the two p-values
876
+ - equivalent: Boolean indicating equivalence at alpha=0.05
877
+ """
878
+ # Get pre-treatment periods
879
+ if pre_periods is None:
880
+ all_periods = sorted(data[time].unique())
881
+ mid_point = len(all_periods) // 2
882
+ pre_periods = all_periods[:mid_point]
883
+
884
+ pre_data = data[data[time].isin(pre_periods)].copy()
885
+
886
+ # Compute outcome changes
887
+ treated_changes, control_changes = _compute_outcome_changes(
888
+ pre_data, outcome, time, treatment_group, unit
889
+ )
890
+
891
+ # Need at least 2 observations per group to compute variance
892
+ # and at least 3 total for meaningful df calculation
893
+ if len(treated_changes) < 2 or len(control_changes) < 2:
894
+ return {
895
+ "mean_difference": np.nan,
896
+ "se_difference": np.nan,
897
+ "equivalence_margin": np.nan,
898
+ "lower_t_stat": np.nan,
899
+ "upper_t_stat": np.nan,
900
+ "lower_p_value": np.nan,
901
+ "upper_p_value": np.nan,
902
+ "tost_p_value": np.nan,
903
+ "degrees_of_freedom": np.nan,
904
+ "equivalent": None,
905
+ "error": "Insufficient data (need at least 2 observations per group)",
906
+ }
907
+
908
+ # Compute statistics
909
+ var_t = np.var(treated_changes, ddof=1)
910
+ var_c = np.var(control_changes, ddof=1)
911
+ n_t = len(treated_changes)
912
+ n_c = len(control_changes)
913
+
914
+ mean_diff = np.mean(treated_changes) - np.mean(control_changes)
915
+
916
+ # Handle zero variance case
917
+ if var_t == 0 and var_c == 0:
918
+ return {
919
+ "mean_difference": mean_diff,
920
+ "se_difference": 0.0,
921
+ "equivalence_margin": np.nan,
922
+ "lower_t_stat": np.nan,
923
+ "upper_t_stat": np.nan,
924
+ "lower_p_value": np.nan,
925
+ "upper_p_value": np.nan,
926
+ "tost_p_value": np.nan,
927
+ "degrees_of_freedom": np.nan,
928
+ "equivalent": None,
929
+ "error": "Zero variance in both groups - cannot perform t-test",
930
+ }
931
+
932
+ se_diff = np.sqrt(var_t / n_t + var_c / n_c)
933
+
934
+ # Handle zero SE case (cannot divide by zero in t-stat calculation)
935
+ if se_diff == 0:
936
+ return {
937
+ "mean_difference": mean_diff,
938
+ "se_difference": 0.0,
939
+ "equivalence_margin": np.nan,
940
+ "lower_t_stat": np.nan,
941
+ "upper_t_stat": np.nan,
942
+ "lower_p_value": np.nan,
943
+ "upper_p_value": np.nan,
944
+ "tost_p_value": np.nan,
945
+ "degrees_of_freedom": np.nan,
946
+ "equivalent": None,
947
+ "error": "Zero standard error - cannot perform t-test",
948
+ }
949
+
950
+ # Set equivalence margin if not provided
951
+ if equivalence_margin is None:
952
+ pooled_changes = np.concatenate([treated_changes, control_changes])
953
+ equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)
954
+
955
+ # Degrees of freedom (Welch-Satterthwaite approximation)
956
+ # Guard against division by zero when one group has zero variance
957
+ numerator = (var_t/n_t + var_c/n_c)**2
958
+ denom_t = (var_t/n_t)**2/(n_t-1) if var_t > 0 else 0
959
+ denom_c = (var_c/n_c)**2/(n_c-1) if var_c > 0 else 0
960
+ denominator = denom_t + denom_c
961
+
962
+ if denominator == 0:
963
+ # Fall back to minimum of n_t-1 and n_c-1 when one variance is zero
964
+ df = min(n_t - 1, n_c - 1)
965
+ else:
966
+ df = numerator / denominator
967
+
968
+ # TOST: Two one-sided tests
969
+ # Test 1: H0: diff <= -margin vs H1: diff > -margin
970
+ t_lower = (mean_diff - (-equivalence_margin)) / se_diff
971
+ p_lower = stats.t.sf(t_lower, df)
972
+
973
+ # Test 2: H0: diff >= margin vs H1: diff < margin
974
+ t_upper = (mean_diff - equivalence_margin) / se_diff
975
+ p_upper = stats.t.cdf(t_upper, df)
976
+
977
+ # TOST p-value is the maximum of the two
978
+ tost_p = max(p_lower, p_upper)
979
+
980
+ return {
981
+ "mean_difference": mean_diff,
982
+ "se_difference": se_diff,
983
+ "equivalence_margin": equivalence_margin,
984
+ "lower_t_stat": t_lower,
985
+ "upper_t_stat": t_upper,
986
+ "lower_p_value": p_lower,
987
+ "upper_p_value": p_upper,
988
+ "tost_p_value": tost_p,
989
+ "degrees_of_freedom": df,
990
+ "equivalent": bool(tost_p < 0.05),
991
+ }
992
+
993
+
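# Illustrative sketch (editorial, not part of the released file): the TOST
# arithmetic used above on made-up numbers. Equivalence is declared only when
# both one-sided tests reject, i.e. when max(p_lower, p_upper) < alpha.
from scipy import stats

mean_diff, se_diff, df, margin = 0.05, 0.04, 30, 0.20      # hypothetical values
p_lower = stats.t.sf((mean_diff + margin) / se_diff, df)   # H0: diff <= -margin
p_upper = stats.t.cdf((mean_diff - margin) / se_diff, df)  # H0: diff >= +margin
print(p_lower, p_upper, max(p_lower, p_upper) < 0.05)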
994
+ def compute_synthetic_weights(
995
+ Y_control: np.ndarray,
996
+ Y_treated: np.ndarray,
997
+ lambda_reg: float = 0.0,
998
+ min_weight: float = 1e-6
999
+ ) -> np.ndarray:
1000
+ """
1001
+ Compute synthetic control unit weights using constrained optimization.
1002
+
1003
+ Finds weights ω that minimize the squared difference between the
1004
+ weighted average of control unit outcomes and the treated unit outcomes
1005
+ during pre-treatment periods.
1006
+
1007
+ Parameters
1008
+ ----------
1009
+ Y_control : np.ndarray
1010
+ Control unit outcomes matrix of shape (n_pre_periods, n_control_units).
1011
+ Each column is a control unit, each row is a pre-treatment period.
1012
+ Y_treated : np.ndarray
1013
+ Treated unit mean outcomes of shape (n_pre_periods,).
1014
+ Average across treated units for each pre-treatment period.
1015
+ lambda_reg : float, default=0.0
1016
+ L2 regularization parameter. Larger values shrink weights toward
1017
+ uniform (1/n_control). Helps prevent overfitting when n_pre < n_control.
1018
+ min_weight : float, default=1e-6
1019
+ Minimum weight threshold. Weights below this are set to zero.
1020
+
1021
+ Returns
1022
+ -------
1023
+ np.ndarray
1024
+ Unit weights of shape (n_control_units,) that sum to 1.
1025
+
1026
+ Notes
1027
+ -----
1028
+ Solves the quadratic program:
1029
+
1030
+ min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
1031
+ s.t. ω >= 0, sum(ω) = 1
1032
+
1033
+ Uses a simplified coordinate descent approach with projection onto simplex.
1034
+ """
1035
+ n_pre, n_control = Y_control.shape
1036
+
1037
+ if n_control == 0:
1038
+ return np.asarray([])
1039
+
1040
+ if n_control == 1:
1041
+ return np.asarray([1.0])
1042
+
1043
+ # Use Rust backend if available
1044
+ if HAS_RUST_BACKEND:
1045
+ Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
1046
+ Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64)
1047
+ weights = _rust_synthetic_weights(
1048
+ Y_control, Y_treated, lambda_reg,
1049
+ _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL
1050
+ )
1051
+ else:
1052
+ # Fallback to NumPy implementation
1053
+ weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)
1054
+
1055
+ # Set small weights to zero for interpretability
1056
+ weights[weights < min_weight] = 0
1057
+ if np.sum(weights) > 0:
1058
+ weights = weights / np.sum(weights)
1059
+ else:
1060
+ # Fallback to uniform if all weights are zeroed
1061
+ weights = np.ones(n_control) / n_control
1062
+
1063
+ return np.asarray(weights)
1064
+
1065
+
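# Illustrative sketch (editorial, not part of the released file): recovering
# known unit weights on simulated pre-treatment outcomes. The treated series
# is built as a 60/40 mix of the first two controls, so the fitted weights
# should concentrate on those two columns and sum to one.
import numpy as np
from diff_diff.utils import compute_synthetic_weights

rng = np.random.default_rng(0)
Y_control = rng.normal(size=(12, 5))           # 12 pre-periods, 5 control units
Y_treated = 0.6 * Y_control[:, 0] + 0.4 * Y_control[:, 1] + rng.normal(0, 0.01, 12)
w = compute_synthetic_weights(Y_control, Y_treated)
print(np.round(w, 3), float(w.sum()))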
1066
+ def _compute_synthetic_weights_numpy(
1067
+ Y_control: np.ndarray,
1068
+ Y_treated: np.ndarray,
1069
+ lambda_reg: float = 0.0,
1070
+ ) -> np.ndarray:
1071
+ """NumPy fallback implementation of compute_synthetic_weights."""
1072
+ n_pre, n_control = Y_control.shape
1073
+
1074
+ # Initialize with uniform weights
1075
+ weights = np.ones(n_control) / n_control
1076
+
1077
+ # Precompute matrices for optimization
1078
+ # Objective: ||Y_treated - Y_control @ w||^2 + lambda * ||w - w_uniform||^2
1079
+ # = w' @ (Y_control' @ Y_control + lambda * I) @ w - 2 * (Y_control' @ Y_treated + lambda * w_uniform)' @ w + const
1080
+ YtY = Y_control.T @ Y_control
1081
+ YtT = Y_control.T @ Y_treated
1082
+ w_uniform = np.ones(n_control) / n_control
1083
+
1084
+ # Add regularization
1085
+ H = YtY + lambda_reg * np.eye(n_control)
1086
+ f = YtT + lambda_reg * w_uniform
1087
+
1088
+ # Solve with projected gradient descent
1089
+ # Project onto probability simplex
1090
+ step_size = 1.0 / (np.linalg.norm(H, 2) + _NUMERICAL_EPS)
1091
+
1092
+ for _ in range(_OPTIMIZATION_MAX_ITER):
1093
+ weights_old = weights.copy()
1094
+
1095
+ # Gradient step: minimize ||Y - Y_control @ w||^2
1096
+ grad = H @ weights - f
1097
+ weights = weights - step_size * grad
1098
+
1099
+ # Project onto simplex (sum to 1, non-negative)
1100
+ weights = _project_simplex(weights)
1101
+
1102
+ # Check convergence
1103
+ if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL:
1104
+ break
1105
+
1106
+ return weights
1107
+
1108
+
1109
+ def _project_simplex(v: np.ndarray) -> np.ndarray:
1110
+ """
1111
+ Project vector onto probability simplex (sum to 1, non-negative).
1112
+
1113
+ Uses the algorithm from Duchi et al. (2008).
1114
+
1115
+ Parameters
1116
+ ----------
1117
+ v : np.ndarray
1118
+ Vector to project.
1119
+
1120
+ Returns
1121
+ -------
1122
+ np.ndarray
1123
+ Projected vector on the simplex.
1124
+ """
1125
+ n = len(v)
1126
+ if n == 0:
1127
+ return v
1128
+
1129
+ # Sort in descending order
1130
+ u = np.sort(v)[::-1]
1131
+
1132
+ # Find the threshold
1133
+ cssv = np.cumsum(u)
1134
+ rho = np.where(u > (cssv - 1) / np.arange(1, n + 1))[0]
1135
+
1136
+ if len(rho) == 0:
1137
+ # All elements are negative or zero
1138
+ rho_val = 0
1139
+ else:
1140
+ rho_val = rho[-1]
1141
+
1142
+ theta = (cssv[rho_val] - 1) / (rho_val + 1)
1143
+
1144
+ return np.asarray(np.maximum(v - theta, 0))
1145
+
1146
+
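# Illustrative sketch (editorial, not part of the released file): the simplex
# projection above maps an arbitrary vector to the closest point with
# non-negative entries summing to one. Assumes the private helper is importable.
import numpy as np
from diff_diff.utils import _project_simplex

p = _project_simplex(np.array([0.9, 0.6, -0.4]))
print(p, float(p.sum()))   # [0.65, 0.35, 0.0], sums to 1.0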
1147
+ def compute_time_weights(
1148
+ Y_control: np.ndarray,
1149
+ Y_treated: np.ndarray,
1150
+ zeta: float = 1.0
1151
+ ) -> np.ndarray:
1152
+ """
1153
+ Compute time weights for synthetic DiD.
1154
+
1155
+ Time weights emphasize pre-treatment periods where the outcome
1156
+ is more informative for constructing the synthetic control.
1157
+ Based on the SDID approach from Arkhangelsky et al. (2021).
1158
+
1159
+ Parameters
1160
+ ----------
1161
+ Y_control : np.ndarray
1162
+ Control unit outcomes of shape (n_pre_periods, n_control_units).
1163
+ Y_treated : np.ndarray
1164
+ Treated unit mean outcomes of shape (n_pre_periods,).
1165
+ zeta : float, default=1.0
1166
+ Regularization parameter for time weights. Higher values
1167
+ give more uniform weights.
1168
+
1169
+ Returns
1170
+ -------
1171
+ np.ndarray
1172
+ Time weights of shape (n_pre_periods,) that sum to 1.
1173
+
1174
+ Notes
1175
+ -----
1176
+ The time weights help interpolate between DiD (uniform weights)
1177
+ and synthetic control (weights concentrated on similar periods).
1178
+ """
1179
+ n_pre = len(Y_treated)
1180
+
1181
+ if n_pre <= 1:
1182
+ return np.asarray(np.ones(n_pre))
1183
+
1184
+ # Compute mean control outcomes per period
1185
+ control_means = np.mean(Y_control, axis=1)
1186
+
1187
+ # Compute differences from treated
1188
+ diffs = np.abs(Y_treated - control_means)
1189
+
1190
+ # Inverse weighting: periods with smaller differences get higher weight
1191
+ # Add regularization to prevent extreme weights
1192
+ inv_diffs = 1.0 / (diffs + zeta * np.std(diffs) + _NUMERICAL_EPS)
1193
+
1194
+ # Normalize to sum to 1
1195
+ weights = inv_diffs / np.sum(inv_diffs)
1196
+
1197
+ return np.asarray(weights)
1198
+
1199
+
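# Illustrative sketch (editorial, not part of the released file): pre-periods
# where the control average tracks the treated series closely get larger time
# weights, and a larger zeta pulls the weights back toward uniform. The arrays
# are toy values.
import numpy as np
from diff_diff.utils import compute_time_weights

Y_control = np.array([[1.0, 1.2], [2.0, 2.2], [3.0, 5.0]])   # 3 pre-periods, 2 controls
Y_treated = np.array([1.1, 2.1, 3.1])                        # poorly matched in period 3
print(np.round(compute_time_weights(Y_control, Y_treated, zeta=0.1), 3))
print(np.round(compute_time_weights(Y_control, Y_treated, zeta=10.0), 3))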
1200
+ def compute_sdid_estimator(
1201
+ Y_pre_control: np.ndarray,
1202
+ Y_post_control: np.ndarray,
1203
+ Y_pre_treated: np.ndarray,
1204
+ Y_post_treated: np.ndarray,
1205
+ unit_weights: np.ndarray,
1206
+ time_weights: np.ndarray
1207
+ ) -> float:
1208
+ """
1209
+ Compute the Synthetic DiD estimator.
1210
+
1211
+ Parameters
1212
+ ----------
1213
+ Y_pre_control : np.ndarray
1214
+ Control outcomes in pre-treatment periods, shape (n_pre, n_control).
1215
+ Y_post_control : np.ndarray
1216
+ Control outcomes in post-treatment periods, shape (n_post, n_control).
1217
+ Y_pre_treated : np.ndarray
1218
+ Treated unit outcomes in pre-treatment periods, shape (n_pre,).
1219
+ Y_post_treated : np.ndarray
1220
+ Treated unit outcomes in post-treatment periods, shape (n_post,).
1221
+ unit_weights : np.ndarray
1222
+ Weights for control units, shape (n_control,).
1223
+ time_weights : np.ndarray
1224
+ Weights for pre-treatment periods, shape (n_pre,).
1225
+
1226
+ Returns
1227
+ -------
1228
+ float
1229
+ The synthetic DiD treatment effect estimate.
1230
+
1231
+ Notes
1232
+ -----
1233
+ The SDID estimator is:
1234
+
1235
+ τ̂ = (Ȳ_treated,post - Σ_t λ_t * Y_treated,t)
1236
+ - Σ_j ω_j * (Ȳ_j,post - Σ_t λ_t * Y_j,t)
1237
+
1238
+ Where:
1239
+ - ω_j are unit weights
1240
+ - λ_t are time weights
1241
+ - Ȳ denotes average over post periods
1242
+ """
1243
+ # Weighted pre-treatment averages
1244
+ weighted_pre_control = time_weights @ Y_pre_control # shape: (n_control,)
1245
+ weighted_pre_treated = time_weights @ Y_pre_treated # scalar
1246
+
1247
+ # Post-treatment averages
1248
+ mean_post_control = np.mean(Y_post_control, axis=0) # shape: (n_control,)
1249
+ mean_post_treated = np.mean(Y_post_treated) # scalar
1250
+
1251
+ # DiD for treated: post - weighted pre
1252
+ did_treated = mean_post_treated - weighted_pre_treated
1253
+
1254
+ # Weighted DiD for controls: sum over j of omega_j * (post_j - weighted_pre_j)
1255
+ did_control = unit_weights @ (mean_post_control - weighted_pre_control)
1256
+
1257
+ # SDID estimator
1258
+ tau = did_treated - did_control
1259
+
1260
+ return float(tau)
1261
+
1262
+
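# Illustrative sketch (editorial, not part of the released file): a
# hand-checkable SDID computation. With uniform unit and time weights the
# estimator reduces to the familiar 2x2 difference-in-differences.
import numpy as np
from diff_diff.utils import compute_sdid_estimator

Y_pre_control = np.array([[1.0, 2.0], [1.0, 2.0]])   # 2 pre-periods, 2 controls
Y_post_control = np.array([[1.5, 2.5]])              # controls rise by 0.5
Y_pre_treated = np.array([3.0, 3.0])
Y_post_treated = np.array([5.5])                     # treated rises by 2.5
tau = compute_sdid_estimator(
    Y_pre_control, Y_post_control, Y_pre_treated, Y_post_treated,
    unit_weights=np.array([0.5, 0.5]), time_weights=np.array([0.5, 0.5]),
)
print(tau)   # 2.5 - 0.5 = 2.0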
1263
+ def compute_placebo_effects(
1264
+ Y_pre_control: np.ndarray,
1265
+ Y_post_control: np.ndarray,
1266
+ Y_pre_treated: np.ndarray,
1267
+ unit_weights: np.ndarray,
1268
+ time_weights: np.ndarray,
1269
+ control_unit_ids: List[Any],
1270
+ n_placebo: Optional[int] = None
1271
+ ) -> np.ndarray:
1272
+ """
1273
+ Compute placebo treatment effects by treating each control as treated.
1274
+
1275
+ Used for inference in synthetic DiD when bootstrap is not appropriate.
1276
+
1277
+ Parameters
1278
+ ----------
1279
+ Y_pre_control : np.ndarray
1280
+ Control outcomes in pre-treatment periods, shape (n_pre, n_control).
1281
+ Y_post_control : np.ndarray
1282
+ Control outcomes in post-treatment periods, shape (n_post, n_control).
1283
+ Y_pre_treated : np.ndarray
1284
+ Treated outcomes in pre-treatment periods, shape (n_pre,).
1285
+ unit_weights : np.ndarray
1286
+ Unit weights, shape (n_control,).
1287
+ time_weights : np.ndarray
1288
+ Time weights, shape (n_pre,).
1289
+ control_unit_ids : list
1290
+ List of control unit identifiers.
1291
+ n_placebo : int, optional
1292
+ Number of placebo tests. If None, uses all control units.
1293
+
1294
+ Returns
1295
+ -------
1296
+ np.ndarray
1297
+ Array of placebo treatment effects.
1298
+
1299
+ Notes
1300
+ -----
1301
+ For each control unit j, we pretend it was treated and compute
1302
+ the SDID estimate using the remaining controls. The distribution
1303
+ of these placebo effects provides a reference for inference.
1304
+ """
1305
+ n_pre, n_control = Y_pre_control.shape
1306
+
1307
+ if n_placebo is None:
1308
+ n_placebo = n_control
1309
+
1310
+ placebo_effects = []
1311
+
1312
+ for j in range(min(n_placebo, n_control)):
1313
+ # Treat unit j as the "treated" unit
1314
+ Y_pre_placebo_treated = Y_pre_control[:, j]
1315
+ Y_post_placebo_treated = Y_post_control[:, j]
1316
+
1317
+ # Use remaining units as controls
1318
+ remaining_idx = [i for i in range(n_control) if i != j]
1319
+
1320
+ if len(remaining_idx) == 0:
1321
+ continue
1322
+
1323
+ Y_pre_remaining = Y_pre_control[:, remaining_idx]
1324
+ Y_post_remaining = Y_post_control[:, remaining_idx]
1325
+
1326
+ # Recompute weights for remaining controls
1327
+ remaining_weights = compute_synthetic_weights(
1328
+ Y_pre_remaining,
1329
+ Y_pre_placebo_treated
1330
+ )
1331
+
1332
+ # Compute placebo effect
1333
+ placebo_tau = compute_sdid_estimator(
1334
+ Y_pre_remaining,
1335
+ Y_post_remaining,
1336
+ Y_pre_placebo_treated,
1337
+ Y_post_placebo_treated,
1338
+ remaining_weights,
1339
+ time_weights
1340
+ )
1341
+
1342
+ placebo_effects.append(placebo_tau)
1343
+
1344
+ return np.asarray(placebo_effects)
1345
+
1346
+
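# Illustrative sketch (editorial, not part of the released file): turning
# placebo effects into a permutation-style p-value, i.e. the share of placebo
# estimates at least as large in magnitude as the actual estimate. The numbers
# are made up.
import numpy as np

tau_hat = 2.0                                            # hypothetical SDID estimate
placebo_effects = np.array([0.3, -0.5, 1.1, -0.2, 0.7, -1.4, 0.4, 0.9])
p_value = float(np.mean(np.abs(placebo_effects) >= abs(tau_hat)))
print(p_value)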
1347
+ def demean_by_group(
1348
+ data: pd.DataFrame,
1349
+ variables: List[str],
1350
+ group_var: str,
1351
+ inplace: bool = False,
1352
+ suffix: str = "",
1353
+ ) -> Tuple[pd.DataFrame, int]:
1354
+ """
1355
+ Demean variables by a grouping variable (one-way within transformation).
1356
+
1357
+ For each variable, computes: x_ig - mean(x_g) where g is the group.
1358
+
1359
+ Parameters
1360
+ ----------
1361
+ data : pd.DataFrame
1362
+ DataFrame containing the variables to demean.
1363
+ variables : list of str
1364
+ Column names to demean.
1365
+ group_var : str
1366
+ Column name for the grouping variable.
1367
+ inplace : bool, default False
1368
+ If True, modifies the original columns in place. If False, operates on a
+ copy, leaving the caller's DataFrame unchanged; the demeaned values appear
+ in the returned DataFrame.
1370
+ suffix : str, default ""
1371
+ Suffix to add to demeaned column names (only used when inplace=False
1372
+ and you want to keep both original and demeaned columns).
1373
+
1374
+ Returns
1375
+ -------
1376
+ data : pd.DataFrame
1377
+ DataFrame with demeaned variables.
1378
+ n_effects : int
1379
+ Number of absorbed fixed effects (nunique - 1).
1380
+
1381
+ Examples
1382
+ --------
1383
+ >>> df, n_fe = demean_by_group(df, ['y', 'x1', 'x2'], 'unit')
1384
+ >>> # df['y'], df['x1'], df['x2'] are now demeaned by unit
1385
+ """
1386
+ if not inplace:
1387
+ data = data.copy()
1388
+
1389
+ # Count fixed effects (categories - 1 for identification)
1390
+ n_effects = data[group_var].nunique() - 1
1391
+
1392
+ # Cache the groupby object for efficiency
1393
+ grouper = data.groupby(group_var, sort=False)
1394
+
1395
+ for var in variables:
1396
+ col_name = var if not suffix else f"{var}{suffix}"
1397
+ group_means = grouper[var].transform("mean")
1398
+ data[col_name] = data[var] - group_means
1399
+
1400
+ return data, n_effects
1401
+
1402
+
1403
+ def within_transform(
1404
+ data: pd.DataFrame,
1405
+ variables: List[str],
1406
+ unit: str,
1407
+ time: str,
1408
+ inplace: bool = False,
1409
+ suffix: str = "_demeaned",
1410
+ ) -> pd.DataFrame:
1411
+ """
1412
+ Apply two-way within transformation to remove unit and time fixed effects.
1413
+
1414
+ Computes: y_it - y_i. - y_.t + y_.. for each variable.
1415
+
1416
+ This is the standard fixed effects transformation for panel data that
1417
+ removes both unit-specific and time-specific effects.
1418
+
1419
+ Parameters
1420
+ ----------
1421
+ data : pd.DataFrame
1422
+ Panel data containing the variables to transform.
1423
+ variables : list of str
1424
+ Column names to transform.
1425
+ unit : str
1426
+ Column name for unit identifier.
1427
+ time : str
1428
+ Column name for time period identifier.
1429
+ inplace : bool, default False
1430
+ If True, modifies the original columns. If False, creates new columns
1431
+ with the specified suffix.
1432
+ suffix : str, default "_demeaned"
1433
+ Suffix for new column names when inplace=False.
1434
+
1435
+ Returns
1436
+ -------
1437
+ pd.DataFrame
1438
+ DataFrame with within-transformed variables.
1439
+
1440
+ Notes
1441
+ -----
1442
+ The within transformation removes variation that is constant within units
1443
+ (unit fixed effects) and constant within time periods (time fixed effects).
1444
+ The resulting estimates are equivalent to including unit and time dummies,
+ but the transformation is computationally more efficient for large panels.
1446
+
1447
+ Examples
1448
+ --------
1449
+ >>> df = within_transform(df, ['y', 'x'], 'unit_id', 'year')
1450
+ >>> # df now has 'y_demeaned' and 'x_demeaned' columns
1451
+ """
1452
+ if not inplace:
1453
+ data = data.copy()
1454
+
1455
+ # Cache groupby objects for efficiency
1456
+ unit_grouper = data.groupby(unit, sort=False)
1457
+ time_grouper = data.groupby(time, sort=False)
1458
+
1459
+ if inplace:
1460
+ # Modify columns in place
1461
+ for var in variables:
1462
+ unit_means = unit_grouper[var].transform("mean")
1463
+ time_means = time_grouper[var].transform("mean")
1464
+ grand_mean = data[var].mean()
1465
+ data[var] = data[var] - unit_means - time_means + grand_mean
1466
+ else:
1467
+ # Build all demeaned columns at once to avoid DataFrame fragmentation
1468
+ demeaned_data = {}
1469
+ for var in variables:
1470
+ unit_means = unit_grouper[var].transform("mean")
1471
+ time_means = time_grouper[var].transform("mean")
1472
+ grand_mean = data[var].mean()
1473
+ demeaned_data[f"{var}{suffix}"] = (
1474
+ data[var] - unit_means - time_means + grand_mean
1475
+ ).values
1476
+
1477
+ # Add all columns at once
1478
+ demeaned_df = pd.DataFrame(demeaned_data, index=data.index)
1479
+ data = pd.concat([data, demeaned_df], axis=1)
1480
+
1481
+ return data
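# Illustrative sketch (editorial, not part of the released file): on a balanced
# panel the two-way within transformation removes unit and time means, so both
# sets of group means of the transformed column are numerically zero. Column
# names are hypothetical.
import numpy as np
import pandas as pd
from diff_diff.utils import within_transform

rng = np.random.default_rng(0)
df = pd.DataFrame([(u, t) for u in range(5) for t in range(4)],
                  columns=["unit_id", "year"])
df["y"] = 2.0 * df["unit_id"] + 0.5 * df["year"] + rng.normal(size=len(df))
out = within_transform(df, ["y"], unit="unit_id", time="year")
print(out.groupby("unit_id")["y_demeaned"].mean().abs().max())   # ~0
print(out.groupby("year")["y_demeaned"].mean().abs().max())      # ~0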