diff-diff 2.2.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/utils.py ADDED
@@ -0,0 +1,1484 @@
1
+ """
2
+ Utility functions for difference-in-differences estimation.
3
+ """
4
+
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from scipy import stats
12
+
13
+ from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
14
+ from diff_diff.linalg import solve_ols as _solve_ols_linalg
15
+
16
+ # Import Rust backend if available (from _backend to avoid circular imports)
17
+ from diff_diff._backend import (
18
+ HAS_RUST_BACKEND,
19
+ _rust_project_simplex,
20
+ _rust_synthetic_weights,
21
+ )
22
+
23
+ # Numerical constants for optimization algorithms
24
+ _OPTIMIZATION_MAX_ITER = 1000 # Maximum iterations for weight optimization
25
+ _OPTIMIZATION_TOL = 1e-8 # Convergence tolerance for optimization
26
+ _NUMERICAL_EPS = 1e-10 # Small constant to prevent division by zero
27
+
28
+
29
+ def validate_binary(arr: np.ndarray, name: str) -> None:
30
+ """
31
+ Validate that an array contains only binary values (0 or 1).
32
+
33
+ Parameters
34
+ ----------
35
+ arr : np.ndarray
36
+ Array to validate.
37
+ name : str
38
+ Name of the variable (for error messages).
39
+
40
+ Raises
41
+ ------
42
+ ValueError
43
+ If array contains non-binary values.
44
+ """
45
+ unique_values = np.unique(arr[~np.isnan(arr)])
46
+ if not np.all(np.isin(unique_values, [0, 1])):
47
+ raise ValueError(
48
+ f"{name} must be binary (0 or 1). "
49
+ f"Found values: {unique_values}"
50
+ )
51
+
52
+
53
+ def compute_robust_se(
54
+ X: np.ndarray,
55
+ residuals: np.ndarray,
56
+ cluster_ids: Optional[np.ndarray] = None
57
+ ) -> np.ndarray:
58
+ """
59
+ Compute the heteroskedasticity-robust (HC1) or cluster-robust variance-covariance matrix.
60
+
61
+ This function is a thin wrapper around the optimized implementation in
62
+ diff_diff.linalg for backwards compatibility.
63
+
64
+ Parameters
65
+ ----------
66
+ X : np.ndarray
67
+ Design matrix of shape (n, k).
68
+ residuals : np.ndarray
69
+ Residuals from regression of shape (n,).
70
+ cluster_ids : np.ndarray, optional
71
+ Cluster identifiers for cluster-robust SEs.
72
+
73
+ Returns
74
+ -------
75
+ np.ndarray
76
+ Variance-covariance matrix of shape (k, k).
77
+ """
78
+ return _compute_robust_vcov_linalg(X, residuals, cluster_ids)
79
+
80
+
81
+ def compute_confidence_interval(
82
+ estimate: float,
83
+ se: float,
84
+ alpha: float = 0.05,
85
+ df: Optional[int] = None
86
+ ) -> Tuple[float, float]:
87
+ """
88
+ Compute confidence interval for an estimate.
89
+
90
+ Parameters
91
+ ----------
92
+ estimate : float
93
+ Point estimate.
94
+ se : float
95
+ Standard error.
96
+ alpha : float
97
+ Significance level (default 0.05 for 95% CI).
98
+ df : int, optional
99
+ Degrees of freedom. If None, uses normal distribution.
100
+
101
+ Returns
102
+ -------
103
+ tuple
104
+ (lower_bound, upper_bound) of confidence interval.
105
+ """
106
+ if df is not None:
107
+ critical_value = stats.t.ppf(1 - alpha / 2, df)
108
+ else:
109
+ critical_value = stats.norm.ppf(1 - alpha / 2)
110
+
111
+ lower = estimate - critical_value * se
112
+ upper = estimate + critical_value * se
113
+
114
+ return (lower, upper)
115
+
116
+
117
+ def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
118
+ """
119
+ Compute p-value for a t-statistic.
120
+
121
+ Parameters
122
+ ----------
123
+ t_stat : float
124
+ T-statistic.
125
+ df : int, optional
126
+ Degrees of freedom. If None, uses normal distribution.
127
+ two_sided : bool
128
+ Whether to compute two-sided p-value (default True).
129
+
130
+ Returns
131
+ -------
132
+ float
133
+ P-value.
134
+ """
135
+ if df is not None:
136
+ p_value = stats.t.sf(np.abs(t_stat), df)
137
+ else:
138
+ p_value = stats.norm.sf(np.abs(t_stat))
139
+
140
+ if two_sided:
141
+ p_value *= 2
142
+
143
+ return float(p_value)
144
+
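The two helpers above are typically used together to turn a point estimate and its standard error into an interval and a p-value. A minimal sketch, assuming the package is installed as `diff_diff`; the estimate, standard error, and degrees of freedom below are hypothetical values:

```python
from diff_diff.utils import compute_confidence_interval, compute_p_value

estimate, se, df = 1.8, 0.75, 40   # hypothetical ATT, standard error, residual df

lower, upper = compute_confidence_interval(estimate, se, alpha=0.05, df=df)
p = compute_p_value(estimate / se, df=df)

print(f"95% CI: [{lower:.3f}, {upper:.3f}], two-sided p-value: {p:.4f}")
```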
145
+
146
+ # =============================================================================
147
+ # Wild Cluster Bootstrap
148
+ # =============================================================================
149
+
150
+
151
+ @dataclass
152
+ class WildBootstrapResults:
153
+ """
154
+ Results from wild cluster bootstrap inference.
155
+
156
+ Attributes
157
+ ----------
158
+ se : float
159
+ Bootstrap standard error of the coefficient.
160
+ p_value : float
161
+ Bootstrap p-value (two-sided).
162
+ t_stat_original : float
163
+ Original t-statistic from the data.
164
+ ci_lower : float
165
+ Lower bound of the confidence interval.
166
+ ci_upper : float
167
+ Upper bound of the confidence interval.
168
+ n_clusters : int
169
+ Number of clusters in the data.
170
+ n_bootstrap : int
171
+ Number of bootstrap replications.
172
+ weight_type : str
173
+ Type of bootstrap weights used ("rademacher", "webb", or "mammen").
174
+ alpha : float
175
+ Significance level used for confidence interval.
176
+ bootstrap_distribution : np.ndarray, optional
177
+ Full bootstrap distribution of coefficients (if requested).
178
+
179
+ References
180
+ ----------
181
+ Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
182
+ Bootstrap-Based Improvements for Inference with Clustered Errors.
183
+ The Review of Economics and Statistics, 90(3), 414-427.
184
+ """
185
+
186
+ se: float
187
+ p_value: float
188
+ t_stat_original: float
189
+ ci_lower: float
190
+ ci_upper: float
191
+ n_clusters: int
192
+ n_bootstrap: int
193
+ weight_type: str
194
+ alpha: float = 0.05
195
+ bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
196
+
197
+ def summary(self) -> str:
198
+ """Generate formatted summary of bootstrap results."""
199
+ lines = [
200
+ "Wild Cluster Bootstrap Results",
201
+ "=" * 40,
202
+ f"Bootstrap SE: {self.se:.6f}",
203
+ f"Bootstrap p-value: {self.p_value:.4f}",
204
+ f"Original t-stat: {self.t_stat_original:.4f}",
205
+ f"CI ({int((1-self.alpha)*100)}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
206
+ f"Number of clusters: {self.n_clusters}",
207
+ f"Bootstrap reps: {self.n_bootstrap}",
208
+ f"Weight type: {self.weight_type}",
209
+ ]
210
+ return "\n".join(lines)
211
+
212
+ def print_summary(self) -> None:
213
+ """Print formatted summary to stdout."""
214
+ print(self.summary())
215
+
216
+
217
+ def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
218
+ """
219
+ Generate Rademacher weights: +1 or -1 with probability 0.5.
220
+
221
+ Parameters
222
+ ----------
223
+ n_clusters : int
224
+ Number of clusters.
225
+ rng : np.random.Generator
226
+ Random number generator.
227
+
228
+ Returns
229
+ -------
230
+ np.ndarray
231
+ Array of Rademacher weights.
232
+ """
233
+ return np.asarray(rng.choice([-1.0, 1.0], size=n_clusters))
234
+
235
+
236
+ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
237
+ """
238
+ Generate Webb's 6-point distribution weights.
239
+
240
+ Values: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)}
241
+ with equal probabilities (1/6 each), giving E[w]=0 and Var(w)=1.0.
242
+
243
+ This distribution is recommended for very few clusters (G < 10) as it
244
+ provides better finite-sample properties than Rademacher weights.
245
+
246
+ Parameters
247
+ ----------
248
+ n_clusters : int
249
+ Number of clusters.
250
+ rng : np.random.Generator
251
+ Random number generator.
252
+
253
+ Returns
254
+ -------
255
+ np.ndarray
256
+ Array of Webb weights.
257
+
258
+ References
259
+ ----------
260
+ Webb, M. D. (2014). Reworking wild bootstrap based inference for
261
+ clustered errors. Queen's Economics Department Working Paper No. 1315.
262
+
263
+ Note: Uses equal probabilities (1/6 each) matching R's `did` package,
264
+ which gives unit variance for consistency with other weight distributions.
265
+ """
266
+ values = np.array([
267
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
268
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
269
+ ])
270
+ # Equal probabilities (1/6 each) matching R's did package, giving Var(w) = 1.0
271
+ return np.asarray(rng.choice(values, size=n_clusters))
272
+
273
+
274
+ def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
275
+ """
276
+ Generate Mammen's two-point distribution weights.
277
+
278
+ Values: {-(sqrt(5)-1)/2, (sqrt(5)+1)/2}
279
+ with probabilities {(sqrt(5)+1)/(2*sqrt(5)), (sqrt(5)-1)/(2*sqrt(5))}.
280
+
281
+ This distribution satisfies E[v]=0, E[v^2]=1, E[v^3]=1, which provides
282
+ asymptotic refinement for skewed error distributions.
283
+
284
+ Parameters
285
+ ----------
286
+ n_clusters : int
287
+ Number of clusters.
288
+ rng : np.random.Generator
289
+ Random number generator.
290
+
291
+ Returns
292
+ -------
293
+ np.ndarray
294
+ Array of Mammen weights.
295
+
296
+ References
297
+ ----------
298
+ Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
299
+ Linear Models. The Annals of Statistics, 21(1), 255-285.
300
+ """
301
+ sqrt5 = np.sqrt(5)
302
+ # Values from Mammen (1993)
303
+ val1 = -(sqrt5 - 1) / 2 # approximately -0.618
304
+ val2 = (sqrt5 + 1) / 2 # approximately 1.618 (golden ratio)
305
+
306
+ # Probability of val1
307
+ p1 = (sqrt5 + 1) / (2 * sqrt5) # approximately 0.724
308
+
309
+ return np.asarray(rng.choice([val1, val2], size=n_clusters, p=[p1, 1 - p1]))
310
+
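All three weight generators are designed to have mean zero and unit variance at the cluster level. A quick illustrative check of those moments; this calls the module's private helpers directly, which is for demonstration only and not a supported API:

```python
import numpy as np
from diff_diff.utils import (
    _generate_rademacher_weights,
    _generate_webb_weights,
    _generate_mammen_weights,
)

rng = np.random.default_rng(0)
for name, gen in [
    ("rademacher", _generate_rademacher_weights),
    ("webb", _generate_webb_weights),
    ("mammen", _generate_mammen_weights),
]:
    w = gen(100_000, rng)  # large draw so sample moments approach E[w]=0, Var(w)=1
    print(f"{name:10s} mean={w.mean():+.4f} var={w.var():.4f}")
```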
311
+
312
+ def wild_bootstrap_se(
313
+ X: np.ndarray,
314
+ y: np.ndarray,
315
+ residuals: np.ndarray,
316
+ cluster_ids: np.ndarray,
317
+ coefficient_index: int,
318
+ n_bootstrap: int = 999,
319
+ weight_type: str = "rademacher",
320
+ null_hypothesis: float = 0.0,
321
+ alpha: float = 0.05,
322
+ seed: Optional[int] = None,
323
+ return_distribution: bool = False
324
+ ) -> WildBootstrapResults:
325
+ """
326
+ Compute wild cluster bootstrap standard errors and p-values.
327
+
328
+ Implements the Wild Cluster Residual (WCR) bootstrap procedure from
329
+ Cameron, Gelbach, and Miller (2008). Uses the restricted residuals
330
+ approach (imposing H0: coefficient = null_hypothesis) for more accurate
331
+ p-value computation.
332
+
333
+ Parameters
334
+ ----------
335
+ X : np.ndarray
336
+ Design matrix of shape (n, k).
337
+ y : np.ndarray
338
+ Outcome vector of shape (n,).
339
+ residuals : np.ndarray
340
+ OLS residuals from unrestricted regression, shape (n,).
341
+ cluster_ids : np.ndarray
342
+ Cluster identifiers of shape (n,).
343
+ coefficient_index : int
344
+ Index of the coefficient for which to compute bootstrap inference.
345
+ For DiD, this is typically 3 (the treatment*post interaction term).
346
+ n_bootstrap : int, default=999
347
+ Number of bootstrap replications. Odd numbers are recommended for
348
+ exact p-value computation.
349
+ weight_type : str, default="rademacher"
350
+ Type of bootstrap weights:
351
+ - "rademacher": +1 or -1 with equal probability (standard choice)
352
+ - "webb": 6-point distribution (recommended for <10 clusters)
353
+ - "mammen": Two-point distribution with skewness correction
354
+ null_hypothesis : float, default=0.0
355
+ Value of the null hypothesis for p-value computation.
356
+ alpha : float, default=0.05
357
+ Significance level for confidence interval.
358
+ seed : int, optional
359
+ Random seed for reproducibility. If None (default), results
360
+ will vary between runs.
361
+ return_distribution : bool, default=False
362
+ If True, include full bootstrap distribution in results.
363
+
364
+ Returns
365
+ -------
366
+ WildBootstrapResults
367
+ Dataclass containing bootstrap SE, p-value, confidence interval,
368
+ and other inference results.
369
+
370
+ Raises
371
+ ------
372
+ ValueError
373
+ If weight_type is not recognized or if there are fewer than 2 clusters.
374
+
375
+ Warns
376
+ -----
377
+ UserWarning
378
+ If the number of clusters is less than 5, as bootstrap inference
379
+ may be unreliable.
380
+
381
+ Examples
382
+ --------
383
+ >>> from diff_diff.utils import wild_bootstrap_se
384
+ >>> results = wild_bootstrap_se(
385
+ ... X, y, residuals, cluster_ids,
386
+ ... coefficient_index=3, # ATT coefficient
387
+ ... n_bootstrap=999,
388
+ ... weight_type="rademacher",
389
+ ... seed=42
390
+ ... )
391
+ >>> print(f"Bootstrap SE: {results.se:.4f}")
392
+ >>> print(f"Bootstrap p-value: {results.p_value:.4f}")
393
+
394
+ References
395
+ ----------
396
+ Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
397
+ Bootstrap-Based Improvements for Inference with Clustered Errors.
398
+ The Review of Economics and Statistics, 90(3), 414-427.
399
+
400
+ MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for
401
+ few (treated) clusters. The Econometrics Journal, 21(2), 114-135.
402
+ """
403
+ # Validate inputs
404
+ valid_weight_types = ["rademacher", "webb", "mammen"]
405
+ if weight_type not in valid_weight_types:
406
+ raise ValueError(
407
+ f"weight_type must be one of {valid_weight_types}, got '{weight_type}'"
408
+ )
409
+
410
+ unique_clusters = np.unique(cluster_ids)
411
+ n_clusters = len(unique_clusters)
412
+
413
+ if n_clusters < 2:
414
+ raise ValueError(
415
+ f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}"
416
+ )
417
+
418
+ if n_clusters < 5:
419
+ warnings.warn(
420
+ f"Only {n_clusters} clusters detected. Wild bootstrap inference may be "
421
+ "unreliable with fewer than 5 clusters. Consider using Webb weights "
422
+ "(weight_type='webb') for improved finite-sample properties.",
423
+ UserWarning
424
+ )
425
+
426
+ # Initialize RNG
427
+ rng = np.random.default_rng(seed)
428
+
429
+ # Select weight generator
430
+ weight_generators = {
431
+ "rademacher": _generate_rademacher_weights,
432
+ "webb": _generate_webb_weights,
433
+ "mammen": _generate_mammen_weights,
434
+ }
435
+ generate_weights = weight_generators[weight_type]
436
+
437
+ n = X.shape[0]
438
+
439
+ # Step 1: Compute original coefficient and cluster-robust SE
440
+ beta_hat, _, vcov_original = _solve_ols_linalg(
441
+ X, y, cluster_ids=cluster_ids, return_vcov=True
442
+ )
443
+ original_coef = beta_hat[coefficient_index]
444
+ se_original = np.sqrt(vcov_original[coefficient_index, coefficient_index])
445
+ t_stat_original = (original_coef - null_hypothesis) / se_original
446
+
447
+ # Step 2: Impose null hypothesis (restricted estimation)
448
+ # Create restricted y: y_restricted = y - X[:, coef_index] * null_hypothesis
449
+ # This imposes the null that the coefficient equals null_hypothesis
450
+ y_restricted = y - X[:, coefficient_index] * null_hypothesis
451
+
452
+ # Fit the restricted model: re-estimate on the full design matrix using the
453
+ # restricted dependent variable. The restricted fitted values and residuals
454
+ # are used to construct the bootstrap samples below.
455
+ beta_restricted, residuals_restricted, _ = _solve_ols_linalg(
456
+ X, y_restricted, return_vcov=False
457
+ )
458
+
459
+ # Create cluster-to-observation mapping for efficiency
460
+ cluster_map = {c: np.where(cluster_ids == c)[0] for c in unique_clusters}
461
+ cluster_indices = [cluster_map[c] for c in unique_clusters]
462
+
463
+ # Step 3: Bootstrap loop
464
+ bootstrap_t_stats = np.zeros(n_bootstrap)
465
+ bootstrap_coefs = np.zeros(n_bootstrap)
466
+
467
+ for b in range(n_bootstrap):
468
+ # Generate cluster-level weights
469
+ cluster_weights = generate_weights(n_clusters, rng)
470
+
471
+ # Map cluster weights to observations
472
+ obs_weights = np.zeros(n)
473
+ for g, indices in enumerate(cluster_indices):
474
+ obs_weights[indices] = cluster_weights[g]
475
+
476
+ # Construct bootstrap sample: y* = X @ beta_restricted + e_restricted * weights
477
+ y_star = X @ beta_restricted + residuals_restricted * obs_weights
478
+
479
+ # Estimate bootstrap coefficients with cluster-robust SE
480
+ beta_star, residuals_star, vcov_star = _solve_ols_linalg(
481
+ X, y_star, cluster_ids=cluster_ids, return_vcov=True
482
+ )
483
+ bootstrap_coefs[b] = beta_star[coefficient_index]
484
+ se_star = np.sqrt(vcov_star[coefficient_index, coefficient_index])
485
+
486
+ # Compute bootstrap t-statistic (under null hypothesis)
487
+ if se_star > 0:
488
+ bootstrap_t_stats[b] = (beta_star[coefficient_index] - null_hypothesis) / se_star
489
+ else:
490
+ bootstrap_t_stats[b] = 0.0
491
+
492
+ # Step 4: Compute bootstrap p-value
493
+ # P-value is proportion of |t*| >= |t_original|
494
+ p_value = np.mean(np.abs(bootstrap_t_stats) >= np.abs(t_stat_original))
495
+
496
+ # Ensure p-value is at least 1/(n_bootstrap+1) to avoid exact zero
497
+ p_value = float(max(float(p_value), 1 / (n_bootstrap + 1)))
498
+
499
+ # Step 5: Compute bootstrap SE and confidence interval
500
+ # SE from standard deviation of bootstrap coefficient distribution
501
+ se_bootstrap = float(np.std(bootstrap_coefs, ddof=1))
502
+
503
+ # Percentile confidence interval from bootstrap distribution
504
+ lower_percentile = alpha / 2 * 100
505
+ upper_percentile = (1 - alpha / 2) * 100
506
+ ci_lower = float(np.percentile(bootstrap_coefs, lower_percentile))
507
+ ci_upper = float(np.percentile(bootstrap_coefs, upper_percentile))
508
+
509
+ return WildBootstrapResults(
510
+ se=se_bootstrap,
511
+ p_value=p_value,
512
+ t_stat_original=t_stat_original,
513
+ ci_lower=ci_lower,
514
+ ci_upper=ci_upper,
515
+ n_clusters=n_clusters,
516
+ n_bootstrap=n_bootstrap,
517
+ weight_type=weight_type,
518
+ alpha=alpha,
519
+ bootstrap_distribution=bootstrap_coefs if return_distribution else None
520
+ )
521
+
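A minimal end-to-end sketch on simulated data. The 2x2 design, cluster sizes, and true effect below are arbitrary illustrative choices; the function itself only requires the design matrix, outcome, unrestricted residuals, and cluster identifiers:

```python
import numpy as np
from diff_diff.utils import wild_bootstrap_se

rng = np.random.default_rng(42)
n_clusters, n_per_cluster = 20, 30
cluster_ids = np.repeat(np.arange(n_clusters), n_per_cluster)
treated = (cluster_ids < n_clusters // 2).astype(float)
post = np.tile(np.r_[np.zeros(n_per_cluster // 2), np.ones(n_per_cluster // 2)], n_clusters)

# Design matrix: intercept, treated, post, treated*post (ATT is column 3)
X = np.column_stack([np.ones(treated.size), treated, post, treated * post])
cluster_noise = rng.normal(0, 0.5, n_clusters)[cluster_ids]
y = (1.0 + 0.5 * treated + 0.3 * post + 2.0 * treated * post
     + cluster_noise + rng.normal(0, 1, treated.size))

# Unrestricted OLS residuals
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
residuals = y - X @ beta_hat

results = wild_bootstrap_se(
    X, y, residuals, cluster_ids,
    coefficient_index=3,          # treated*post interaction
    n_bootstrap=999,
    weight_type="rademacher",
    seed=42,
)
results.print_summary()
```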
522
+
523
+ def check_parallel_trends(
524
+ data: pd.DataFrame,
525
+ outcome: str,
526
+ time: str,
527
+ treatment_group: str,
528
+ pre_periods: Optional[List[Any]] = None
529
+ ) -> Dict[str, Any]:
530
+ """
531
+ Perform a simple check of the parallel trends assumption.
532
+
533
+ This computes the trend (slope) in the outcome variable for both
534
+ treatment and control groups during pre-treatment periods.
535
+
536
+ Parameters
537
+ ----------
538
+ data : pd.DataFrame
539
+ Panel data.
540
+ outcome : str
541
+ Name of outcome variable column.
542
+ time : str
543
+ Name of time period column.
544
+ treatment_group : str
545
+ Name of treatment group indicator column.
546
+ pre_periods : list, optional
547
+ List of pre-treatment time periods. If None, the first half of the sorted periods are treated as pre-treatment.
548
+
549
+ Returns
550
+ -------
551
+ dict
552
+ Dictionary with trend statistics and test results.
553
+ """
554
+ if pre_periods is None:
555
+ # Default: treat the first half of the sorted periods as pre-treatment
556
+ all_periods = sorted(data[time].unique())
557
+ mid_point = len(all_periods) // 2
558
+ pre_periods = all_periods[:mid_point]
559
+
560
+ pre_data = data[data[time].isin(pre_periods)]
561
+
562
+ # Compute trends for each group
563
+ treated_data = pre_data[pre_data[treatment_group] == 1]
564
+ control_data = pre_data[pre_data[treatment_group] == 0]
565
+
566
+ # Simple linear regression for trends
567
+ def compute_trend(group_data: pd.DataFrame) -> Tuple[float, float]:
568
+ time_values = group_data[time].values
569
+ outcome_values = group_data[outcome].values
570
+
571
+ # Normalize time to start at 0
572
+ time_norm = time_values - time_values.min()
573
+
574
+ # Compute slope using least squares
575
+ n = len(time_norm)
576
+ if n < 3:  # need at least 3 points for a slope standard error (avoids 0/0 when n == 2)
577
+ return np.nan, np.nan
578
+
579
+ mean_t = np.mean(time_norm)
580
+ mean_y = np.mean(outcome_values)
581
+
582
+ # Check for zero variance in time (all same time period)
583
+ time_var = np.sum((time_norm - mean_t) ** 2)
584
+ if time_var == 0:
585
+ return np.nan, np.nan
586
+
587
+ slope = np.sum((time_norm - mean_t) * (outcome_values - mean_y)) / time_var
588
+
589
+ # Compute standard error of slope
590
+ y_hat = mean_y + slope * (time_norm - mean_t)
591
+ residuals = outcome_values - y_hat
592
+ mse = np.sum(residuals ** 2) / (n - 2)
593
+ se_slope = np.sqrt(mse / time_var)
594
+
595
+ return slope, se_slope
596
+
597
+ treated_slope, treated_se = compute_trend(treated_data)
598
+ control_slope, control_se = compute_trend(control_data)
599
+
600
+ # Test for difference in trends
601
+ slope_diff = treated_slope - control_slope
602
+ se_diff = np.sqrt(treated_se ** 2 + control_se ** 2)
603
+ t_stat = slope_diff / se_diff if se_diff > 0 else np.nan
604
+ p_value = compute_p_value(t_stat) if not np.isnan(t_stat) else np.nan
605
+
606
+ return {
607
+ "treated_trend": treated_slope,
608
+ "treated_trend_se": treated_se,
609
+ "control_trend": control_slope,
610
+ "control_trend_se": control_se,
611
+ "trend_difference": slope_diff,
612
+ "trend_difference_se": se_diff,
613
+ "t_statistic": t_stat,
614
+ "p_value": p_value,
615
+ "parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
616
+ }
617
+
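A short usage sketch on a simulated balanced panel; the column names and data-generating process are illustrative only:

```python
import numpy as np
import pandas as pd
from diff_diff.utils import check_parallel_trends

rng = np.random.default_rng(0)
n_units, n_periods = 50, 6
panel = pd.DataFrame({
    "unit": np.repeat(np.arange(n_units), n_periods),
    "year": np.tile(np.arange(n_periods), n_units),
})
panel["treated"] = (panel["unit"] < n_units // 2).astype(int)
# Both groups share the same linear trend, so the check should not flag a difference
panel["sales"] = 10 + 0.5 * panel["year"] + 2.0 * panel["treated"] + rng.normal(0, 1, len(panel))

result = check_parallel_trends(
    panel, outcome="sales", time="year", treatment_group="treated",
    pre_periods=[0, 1, 2],
)
print(result["trend_difference"], result["p_value"], result["parallel_trends_plausible"])
```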
618
+
619
+ def check_parallel_trends_robust(
620
+ data: pd.DataFrame,
621
+ outcome: str,
622
+ time: str,
623
+ treatment_group: str,
624
+ unit: Optional[str] = None,
625
+ pre_periods: Optional[List[Any]] = None,
626
+ n_permutations: int = 1000,
627
+ seed: Optional[int] = None,
628
+ wasserstein_threshold: float = 0.2
629
+ ) -> Dict[str, Any]:
630
+ """
631
+ Perform robust parallel trends testing using distributional comparisons.
632
+
633
+ Uses the Wasserstein (Earth Mover's) distance to compare the full
634
+ distribution of outcome changes between treated and control groups,
635
+ with permutation-based inference.
636
+
637
+ Parameters
638
+ ----------
639
+ data : pd.DataFrame
640
+ Panel data with repeated observations over time.
641
+ outcome : str
642
+ Name of outcome variable column.
643
+ time : str
644
+ Name of time period column.
645
+ treatment_group : str
646
+ Name of treatment group indicator column (0/1).
647
+ unit : str, optional
648
+ Name of unit identifier column. If provided, computes unit-level
649
+ changes. Otherwise uses changes in period-level group means.
650
+ pre_periods : list, optional
651
+ List of pre-treatment time periods. If None, uses first half of periods.
652
+ n_permutations : int, default=1000
653
+ Number of permutations for computing p-value.
654
+ seed : int, optional
655
+ Random seed for reproducibility.
656
+ wasserstein_threshold : float, default=0.2
657
+ Threshold for normalized Wasserstein distance. Values below this
658
+ threshold (combined with p > 0.05) suggest parallel trends are plausible.
659
+
660
+ Returns
661
+ -------
662
+ dict
663
+ Dictionary containing:
664
+ - wasserstein_distance: Wasserstein distance between group distributions
665
+ - wasserstein_p_value: Permutation-based p-value
666
+ - ks_statistic: Kolmogorov-Smirnov test statistic
667
+ - ks_p_value: KS test p-value
668
+ - mean_difference: Difference in mean changes
669
+ - variance_ratio: Ratio of variances in changes
670
+ - treated_changes: Array of outcome changes for treated
671
+ - control_changes: Array of outcome changes for control
672
+ - parallel_trends_plausible: Boolean assessment
673
+
674
+ Examples
675
+ --------
676
+ >>> results = check_parallel_trends_robust(
677
+ ... data, outcome='sales', time='year',
678
+ ... treatment_group='treated', unit='firm_id'
679
+ ... )
680
+ >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
681
+ >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")
682
+
683
+ Notes
684
+ -----
685
+ The Wasserstein distance (Earth Mover's Distance) measures the minimum
686
+ "cost" of transforming one distribution into another. Unlike simple
687
+ mean comparisons, it captures differences in the entire distribution
688
+ shape, making it more robust to non-normal data and heterogeneous effects.
689
+
690
+ A small Wasserstein distance and high p-value suggest the distributions
691
+ of pre-treatment changes are similar, supporting the parallel trends
692
+ assumption.
693
+ """
694
+ # Use local RNG to avoid affecting global random state
695
+ rng = np.random.default_rng(seed)
696
+
697
+ # Identify pre-treatment periods
698
+ if pre_periods is None:
699
+ all_periods = sorted(data[time].unique())
700
+ mid_point = len(all_periods) // 2
701
+ pre_periods = all_periods[:mid_point]
702
+
703
+ pre_data = data[data[time].isin(pre_periods)].copy()
704
+
705
+ # Compute outcome changes
706
+ treated_changes, control_changes = _compute_outcome_changes(
707
+ pre_data, outcome, time, treatment_group, unit
708
+ )
709
+
710
+ if len(treated_changes) < 2 or len(control_changes) < 2:
711
+ return {
712
+ "wasserstein_distance": np.nan,
713
+ "wasserstein_p_value": np.nan,
714
+ "ks_statistic": np.nan,
715
+ "ks_p_value": np.nan,
716
+ "mean_difference": np.nan,
717
+ "variance_ratio": np.nan,
718
+ "treated_changes": treated_changes,
719
+ "control_changes": control_changes,
720
+ "parallel_trends_plausible": None,
721
+ "error": "Insufficient data for comparison",
722
+ }
723
+
724
+ # Compute Wasserstein distance
725
+ wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)
726
+
727
+ # Permutation test for Wasserstein distance
728
+ all_changes = np.concatenate([treated_changes, control_changes])
729
+ n_treated = len(treated_changes)
730
+ n_total = len(all_changes)
731
+
732
+ permuted_distances = np.zeros(n_permutations)
733
+ for i in range(n_permutations):
734
+ perm_idx = rng.permutation(n_total)
735
+ perm_treated = all_changes[perm_idx[:n_treated]]
736
+ perm_control = all_changes[perm_idx[n_treated:]]
737
+ permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)
738
+
739
+ # P-value: proportion of permuted distances >= observed
740
+ wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)
741
+
742
+ # Kolmogorov-Smirnov test
743
+ ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)
744
+
745
+ # Additional summary statistics
746
+ mean_diff = np.mean(treated_changes) - np.mean(control_changes)
747
+ var_treated = np.var(treated_changes, ddof=1)
748
+ var_control = np.var(control_changes, ddof=1)
749
+ var_ratio = var_treated / var_control if var_control > 0 else np.nan
750
+
751
+ # Normalized Wasserstein (relative to pooled std)
752
+ pooled_std = np.std(all_changes, ddof=1)
753
+ wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan
754
+
755
+ # Assessment: parallel trends plausible if p-value > 0.05
756
+ # and normalized Wasserstein is small (below threshold)
757
+ plausible = bool(
758
+ wasserstein_p > 0.05 and
759
+ (wasserstein_normalized < wasserstein_threshold if not np.isnan(wasserstein_normalized) else True)
760
+ )
761
+
762
+ return {
763
+ "wasserstein_distance": wasserstein_dist,
764
+ "wasserstein_normalized": wasserstein_normalized,
765
+ "wasserstein_p_value": wasserstein_p,
766
+ "ks_statistic": ks_stat,
767
+ "ks_p_value": ks_p,
768
+ "mean_difference": mean_diff,
769
+ "variance_ratio": var_ratio,
770
+ "n_treated": len(treated_changes),
771
+ "n_control": len(control_changes),
772
+ "treated_changes": treated_changes,
773
+ "control_changes": control_changes,
774
+ "parallel_trends_plausible": plausible,
775
+ }
776
+
777
+
778
+ def _compute_outcome_changes(
779
+ data: pd.DataFrame,
780
+ outcome: str,
781
+ time: str,
782
+ treatment_group: str,
783
+ unit: Optional[str] = None
784
+ ) -> Tuple[np.ndarray, np.ndarray]:
785
+ """
786
+ Compute period-to-period outcome changes for treated and control groups.
787
+
788
+ Parameters
789
+ ----------
790
+ data : pd.DataFrame
791
+ Panel data.
792
+ outcome : str
793
+ Outcome variable column.
794
+ time : str
795
+ Time period column.
796
+ treatment_group : str
797
+ Treatment group indicator column.
798
+ unit : str, optional
799
+ Unit identifier column.
800
+
801
+ Returns
802
+ -------
803
+ tuple
804
+ (treated_changes, control_changes) as numpy arrays.
805
+ """
806
+ if unit is not None:
807
+ # Unit-level changes: compute change for each unit across periods
808
+ data_sorted = data.sort_values([unit, time])
809
+ data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()
810
+
811
+ # Remove NaN from first period of each unit
812
+ changes_data = data_sorted.dropna(subset=["_outcome_change"])
813
+
814
+ treated_changes = changes_data[
815
+ changes_data[treatment_group] == 1
816
+ ]["_outcome_change"].values
817
+
818
+ control_changes = changes_data[
819
+ changes_data[treatment_group] == 0
820
+ ]["_outcome_change"].values
821
+ else:
822
+ # Aggregate changes: compute mean change per period per group
823
+ treated_data = data[data[treatment_group] == 1]
824
+ control_data = data[data[treatment_group] == 0]
825
+
826
+ # Compute period means
827
+ treated_means = treated_data.groupby(time)[outcome].mean()
828
+ control_means = control_data.groupby(time)[outcome].mean()
829
+
830
+ # Compute changes between consecutive periods
831
+ treated_changes = np.diff(treated_means.values)
832
+ control_changes = np.diff(control_means.values)
833
+
834
+ return treated_changes.astype(float), control_changes.astype(float)
835
+
836
+
837
+ def equivalence_test_trends(
838
+ data: pd.DataFrame,
839
+ outcome: str,
840
+ time: str,
841
+ treatment_group: str,
842
+ unit: Optional[str] = None,
843
+ pre_periods: Optional[List[Any]] = None,
844
+ equivalence_margin: Optional[float] = None
845
+ ) -> Dict[str, Any]:
846
+ """
847
+ Perform equivalence testing (TOST) for parallel trends.
848
+
849
+ Tests whether the difference in trends is practically equivalent to zero
850
+ using Two One-Sided Tests (TOST) procedure.
851
+
852
+ Parameters
853
+ ----------
854
+ data : pd.DataFrame
855
+ Panel data.
856
+ outcome : str
857
+ Name of outcome variable column.
858
+ time : str
859
+ Name of time period column.
860
+ treatment_group : str
861
+ Name of treatment group indicator column.
862
+ unit : str, optional
863
+ Name of unit identifier column.
864
+ pre_periods : list, optional
865
+ List of pre-treatment time periods.
866
+ equivalence_margin : float, optional
867
+ The margin for equivalence (delta). If None, uses 0.5 * pooled SD
868
+ of outcome changes as a default.
869
+
870
+ Returns
871
+ -------
872
+ dict
873
+ Dictionary containing:
874
+ - mean_difference: Difference in mean changes
875
+ - equivalence_margin: The margin used
876
+ - lower_p_value: P-value for lower bound test
877
+ - upper_p_value: P-value for upper bound test
878
+ - tost_p_value: Maximum of the two p-values
879
+ - equivalent: Boolean indicating equivalence at alpha=0.05
880
+ """
881
+ # Get pre-treatment periods
882
+ if pre_periods is None:
883
+ all_periods = sorted(data[time].unique())
884
+ mid_point = len(all_periods) // 2
885
+ pre_periods = all_periods[:mid_point]
886
+
887
+ pre_data = data[data[time].isin(pre_periods)].copy()
888
+
889
+ # Compute outcome changes
890
+ treated_changes, control_changes = _compute_outcome_changes(
891
+ pre_data, outcome, time, treatment_group, unit
892
+ )
893
+
894
+ # Need at least 2 observations per group to compute variance
895
+ # and at least 3 total for meaningful df calculation
896
+ if len(treated_changes) < 2 or len(control_changes) < 2:
897
+ return {
898
+ "mean_difference": np.nan,
899
+ "se_difference": np.nan,
900
+ "equivalence_margin": np.nan,
901
+ "lower_t_stat": np.nan,
902
+ "upper_t_stat": np.nan,
903
+ "lower_p_value": np.nan,
904
+ "upper_p_value": np.nan,
905
+ "tost_p_value": np.nan,
906
+ "degrees_of_freedom": np.nan,
907
+ "equivalent": None,
908
+ "error": "Insufficient data (need at least 2 observations per group)",
909
+ }
910
+
911
+ # Compute statistics
912
+ var_t = np.var(treated_changes, ddof=1)
913
+ var_c = np.var(control_changes, ddof=1)
914
+ n_t = len(treated_changes)
915
+ n_c = len(control_changes)
916
+
917
+ mean_diff = np.mean(treated_changes) - np.mean(control_changes)
918
+
919
+ # Handle zero variance case
920
+ if var_t == 0 and var_c == 0:
921
+ return {
922
+ "mean_difference": mean_diff,
923
+ "se_difference": 0.0,
924
+ "equivalence_margin": np.nan,
925
+ "lower_t_stat": np.nan,
926
+ "upper_t_stat": np.nan,
927
+ "lower_p_value": np.nan,
928
+ "upper_p_value": np.nan,
929
+ "tost_p_value": np.nan,
930
+ "degrees_of_freedom": np.nan,
931
+ "equivalent": None,
932
+ "error": "Zero variance in both groups - cannot perform t-test",
933
+ }
934
+
935
+ se_diff = np.sqrt(var_t / n_t + var_c / n_c)
936
+
937
+ # Handle zero SE case (cannot divide by zero in t-stat calculation)
938
+ if se_diff == 0:
939
+ return {
940
+ "mean_difference": mean_diff,
941
+ "se_difference": 0.0,
942
+ "equivalence_margin": np.nan,
943
+ "lower_t_stat": np.nan,
944
+ "upper_t_stat": np.nan,
945
+ "lower_p_value": np.nan,
946
+ "upper_p_value": np.nan,
947
+ "tost_p_value": np.nan,
948
+ "degrees_of_freedom": np.nan,
949
+ "equivalent": None,
950
+ "error": "Zero standard error - cannot perform t-test",
951
+ }
952
+
953
+ # Set equivalence margin if not provided
954
+ if equivalence_margin is None:
955
+ pooled_changes = np.concatenate([treated_changes, control_changes])
956
+ equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)
957
+
958
+ # Degrees of freedom (Welch-Satterthwaite approximation)
959
+ # Guard against division by zero when one group has zero variance
960
+ numerator = (var_t/n_t + var_c/n_c)**2
961
+ denom_t = (var_t/n_t)**2/(n_t-1) if var_t > 0 else 0
962
+ denom_c = (var_c/n_c)**2/(n_c-1) if var_c > 0 else 0
963
+ denominator = denom_t + denom_c
964
+
965
+ if denominator == 0:
966
+ # Fall back to minimum of n_t-1 and n_c-1 when one variance is zero
967
+ df = min(n_t - 1, n_c - 1)
968
+ else:
969
+ df = numerator / denominator
970
+
971
+ # TOST: Two one-sided tests
972
+ # Test 1: H0: diff <= -margin vs H1: diff > -margin
973
+ t_lower = (mean_diff - (-equivalence_margin)) / se_diff
974
+ p_lower = stats.t.sf(t_lower, df)
975
+
976
+ # Test 2: H0: diff >= margin vs H1: diff < margin
977
+ t_upper = (mean_diff - equivalence_margin) / se_diff
978
+ p_upper = stats.t.cdf(t_upper, df)
979
+
980
+ # TOST p-value is the maximum of the two
981
+ tost_p = max(p_lower, p_upper)
982
+
983
+ return {
984
+ "mean_difference": mean_diff,
985
+ "se_difference": se_diff,
986
+ "equivalence_margin": equivalence_margin,
987
+ "lower_t_stat": t_lower,
988
+ "upper_t_stat": t_upper,
989
+ "lower_p_value": p_lower,
990
+ "upper_p_value": p_upper,
991
+ "tost_p_value": tost_p,
992
+ "degrees_of_freedom": df,
993
+ "equivalent": bool(tost_p < 0.05),
994
+ }
995
+
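The TOST arithmetic can be reproduced by hand, which makes the returned quantities easier to interpret. A standalone sketch on two hypothetical arrays of pre-treatment outcome changes (not library code):

```python
import numpy as np
from scipy import stats

treated_changes = np.array([0.10, 0.05, -0.02, 0.08, 0.03])
control_changes = np.array([0.07, 0.01, 0.04, -0.01, 0.06])

# Default margin: half the pooled standard deviation of the changes
margin = 0.5 * np.std(np.concatenate([treated_changes, control_changes]), ddof=1)

n_t, n_c = len(treated_changes), len(control_changes)
var_t, var_c = np.var(treated_changes, ddof=1), np.var(control_changes, ddof=1)
mean_diff = treated_changes.mean() - control_changes.mean()
se_diff = np.sqrt(var_t / n_t + var_c / n_c)

# Welch-Satterthwaite degrees of freedom
df = (var_t / n_t + var_c / n_c) ** 2 / (
    (var_t / n_t) ** 2 / (n_t - 1) + (var_c / n_c) ** 2 / (n_c - 1)
)

p_lower = stats.t.sf((mean_diff + margin) / se_diff, df)   # H0: diff <= -margin
p_upper = stats.t.cdf((mean_diff - margin) / se_diff, df)  # H0: diff >= +margin
print("TOST p-value:", max(p_lower, p_upper))
```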
996
+
997
+ def compute_synthetic_weights(
998
+ Y_control: np.ndarray,
999
+ Y_treated: np.ndarray,
1000
+ lambda_reg: float = 0.0,
1001
+ min_weight: float = 1e-6
1002
+ ) -> np.ndarray:
1003
+ """
1004
+ Compute synthetic control unit weights using constrained optimization.
1005
+
1006
+ Finds weights ω that minimize the squared difference between the
1007
+ weighted average of control unit outcomes and the treated unit outcomes
1008
+ during pre-treatment periods.
1009
+
1010
+ Parameters
1011
+ ----------
1012
+ Y_control : np.ndarray
1013
+ Control unit outcomes matrix of shape (n_pre_periods, n_control_units).
1014
+ Each column is a control unit, each row is a pre-treatment period.
1015
+ Y_treated : np.ndarray
1016
+ Treated unit mean outcomes of shape (n_pre_periods,).
1017
+ Average across treated units for each pre-treatment period.
1018
+ lambda_reg : float, default=0.0
1019
+ L2 regularization parameter. Larger values shrink weights toward
1020
+ uniform (1/n_control). Helps prevent overfitting when n_pre < n_control.
1021
+ min_weight : float, default=1e-6
1022
+ Minimum weight threshold. Weights below this are set to zero.
1023
+
1024
+ Returns
1025
+ -------
1026
+ np.ndarray
1027
+ Unit weights of shape (n_control_units,) that sum to 1.
1028
+
1029
+ Notes
1030
+ -----
1031
+ Solves the quadratic program:
1032
+
1033
+ min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
1034
+ s.t. ω >= 0, sum(ω) = 1
1035
+
1036
+ Uses a simplified coordinate descent approach with projection onto simplex.
1037
+ """
1038
+ n_pre, n_control = Y_control.shape
1039
+
1040
+ if n_control == 0:
1041
+ return np.asarray([])
1042
+
1043
+ if n_control == 1:
1044
+ return np.asarray([1.0])
1045
+
1046
+ # Use Rust backend if available
1047
+ if HAS_RUST_BACKEND:
1048
+ Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
1049
+ Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64)
1050
+ weights = _rust_synthetic_weights(
1051
+ Y_control, Y_treated, lambda_reg,
1052
+ _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL
1053
+ )
1054
+ else:
1055
+ # Fallback to NumPy implementation
1056
+ weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)
1057
+
1058
+ # Set small weights to zero for interpretability
1059
+ weights[weights < min_weight] = 0
1060
+ if np.sum(weights) > 0:
1061
+ weights = weights / np.sum(weights)
1062
+ else:
1063
+ # Fallback to uniform if all weights are zeroed
1064
+ weights = np.ones(n_control) / n_control
1065
+
1066
+ return np.asarray(weights)
1067
+
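A quick usage sketch on a toy pre-treatment panel (the dimensions and data-generating process are arbitrary); the returned weights lie on the probability simplex:

```python
import numpy as np
from diff_diff.utils import compute_synthetic_weights

rng = np.random.default_rng(1)
Y_control = rng.normal(size=(8, 5))                       # 8 pre-periods, 5 control units
Y_treated = Y_control[:, :2].mean(axis=1) + rng.normal(0, 0.1, 8)

weights = compute_synthetic_weights(Y_control, Y_treated, lambda_reg=0.01)
print(weights, weights.sum())   # non-negative, sums to 1
```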
1068
+
1069
+ def _compute_synthetic_weights_numpy(
1070
+ Y_control: np.ndarray,
1071
+ Y_treated: np.ndarray,
1072
+ lambda_reg: float = 0.0,
1073
+ ) -> np.ndarray:
1074
+ """NumPy fallback implementation of compute_synthetic_weights."""
1075
+ n_pre, n_control = Y_control.shape
1076
+
1077
+ # Initialize with uniform weights
1078
+ weights = np.ones(n_control) / n_control
1079
+
1080
+ # Precompute matrices for optimization
1081
+ # Objective: ||Y_treated - Y_control @ w||^2 + lambda * ||w - w_uniform||^2
1082
+ # = w' @ (Y_control' @ Y_control + lambda * I) @ w - 2 * (Y_control' @ Y_treated + lambda * w_uniform)' @ w + const
1083
+ YtY = Y_control.T @ Y_control
1084
+ YtT = Y_control.T @ Y_treated
1085
+ w_uniform = np.ones(n_control) / n_control
1086
+
1087
+ # Add regularization
1088
+ H = YtY + lambda_reg * np.eye(n_control)
1089
+ f = YtT + lambda_reg * w_uniform
1090
+
1091
+ # Solve with projected gradient descent
1092
+ # Project onto probability simplex
1093
+ step_size = 1.0 / (np.linalg.norm(H, 2) + _NUMERICAL_EPS)
1094
+
1095
+ for _ in range(_OPTIMIZATION_MAX_ITER):
1096
+ weights_old = weights.copy()
1097
+
1098
+ # Gradient step on the regularized objective (H and f include the L2 penalty)
1099
+ grad = H @ weights - f
1100
+ weights = weights - step_size * grad
1101
+
1102
+ # Project onto simplex (sum to 1, non-negative)
1103
+ weights = _project_simplex(weights)
1104
+
1105
+ # Check convergence
1106
+ if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL:
1107
+ break
1108
+
1109
+ return weights
1110
+
1111
+
1112
+ def _project_simplex(v: np.ndarray) -> np.ndarray:
1113
+ """
1114
+ Project vector onto probability simplex (sum to 1, non-negative).
1115
+
1116
+ Uses the algorithm from Duchi et al. (2008).
1117
+
1118
+ Parameters
1119
+ ----------
1120
+ v : np.ndarray
1121
+ Vector to project.
1122
+
1123
+ Returns
1124
+ -------
1125
+ np.ndarray
1126
+ Projected vector on the simplex.
1127
+ """
1128
+ n = len(v)
1129
+ if n == 0:
1130
+ return v
1131
+
1132
+ # Sort in descending order
1133
+ u = np.sort(v)[::-1]
1134
+
1135
+ # Find the threshold
1136
+ cssv = np.cumsum(u)
1137
+ rho = np.where(u > (cssv - 1) / np.arange(1, n + 1))[0]
1138
+
1139
+ if len(rho) == 0:
1140
+ # All elements are negative or zero
1141
+ rho_val = 0
1142
+ else:
1143
+ rho_val = rho[-1]
1144
+
1145
+ theta = (cssv[rho_val] - 1) / (rho_val + 1)
1146
+
1147
+ return np.asarray(np.maximum(v - theta, 0))
1148
+
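For intuition, the projection shifts and clips the vector so it becomes a proper weight vector. A tiny illustration, again calling a private helper purely for demonstration:

```python
import numpy as np
from diff_diff.utils import _project_simplex

v = np.array([0.8, -0.2, 0.5, 0.1])
w = _project_simplex(v)
print(w, w.sum())   # -> [0.65, 0.0, 0.35, 0.0], sum == 1.0
```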
1149
+
1150
+ def compute_time_weights(
1151
+ Y_control: np.ndarray,
1152
+ Y_treated: np.ndarray,
1153
+ zeta: float = 1.0
1154
+ ) -> np.ndarray:
1155
+ """
1156
+ Compute time weights for synthetic DiD.
1157
+
1158
+ Time weights emphasize pre-treatment periods where the outcome
1159
+ is more informative for constructing the synthetic control.
1160
+ Based on the SDID approach from Arkhangelsky et al. (2021).
1161
+
1162
+ Parameters
1163
+ ----------
1164
+ Y_control : np.ndarray
1165
+ Control unit outcomes of shape (n_pre_periods, n_control_units).
1166
+ Y_treated : np.ndarray
1167
+ Treated unit mean outcomes of shape (n_pre_periods,).
1168
+ zeta : float, default=1.0
1169
+ Regularization parameter for time weights. Higher values
1170
+ give more uniform weights.
1171
+
1172
+ Returns
1173
+ -------
1174
+ np.ndarray
1175
+ Time weights of shape (n_pre_periods,) that sum to 1.
1176
+
1177
+ Notes
1178
+ -----
1179
+ The time weights help interpolate between DiD (uniform weights)
1180
+ and synthetic control (weights concentrated on similar periods).
1181
+ """
1182
+ n_pre = len(Y_treated)
1183
+
1184
+ if n_pre <= 1:
1185
+ return np.asarray(np.ones(n_pre))
1186
+
1187
+ # Compute mean control outcomes per period
1188
+ control_means = np.mean(Y_control, axis=1)
1189
+
1190
+ # Compute differences from treated
1191
+ diffs = np.abs(Y_treated - control_means)
1192
+
1193
+ # Inverse weighting: periods with smaller differences get higher weight
1194
+ # Add regularization to prevent extreme weights
1195
+ inv_diffs = 1.0 / (diffs + zeta * np.std(diffs) + _NUMERICAL_EPS)
1196
+
1197
+ # Normalize to sum to 1
1198
+ weights = inv_diffs / np.sum(inv_diffs)
1199
+
1200
+ return np.asarray(weights)
1201
+
1202
+
1203
+ def compute_sdid_estimator(
1204
+ Y_pre_control: np.ndarray,
1205
+ Y_post_control: np.ndarray,
1206
+ Y_pre_treated: np.ndarray,
1207
+ Y_post_treated: np.ndarray,
1208
+ unit_weights: np.ndarray,
1209
+ time_weights: np.ndarray
1210
+ ) -> float:
1211
+ """
1212
+ Compute the Synthetic DiD estimator.
1213
+
1214
+ Parameters
1215
+ ----------
1216
+ Y_pre_control : np.ndarray
1217
+ Control outcomes in pre-treatment periods, shape (n_pre, n_control).
1218
+ Y_post_control : np.ndarray
1219
+ Control outcomes in post-treatment periods, shape (n_post, n_control).
1220
+ Y_pre_treated : np.ndarray
1221
+ Treated unit outcomes in pre-treatment periods, shape (n_pre,).
1222
+ Y_post_treated : np.ndarray
1223
+ Treated unit outcomes in post-treatment periods, shape (n_post,).
1224
+ unit_weights : np.ndarray
1225
+ Weights for control units, shape (n_control,).
1226
+ time_weights : np.ndarray
1227
+ Weights for pre-treatment periods, shape (n_pre,).
1228
+
1229
+ Returns
1230
+ -------
1231
+ float
1232
+ The synthetic DiD treatment effect estimate.
1233
+
1234
+ Notes
1235
+ -----
1236
+ The SDID estimator is:
1237
+
1238
+ τ̂ = (Ȳ_treated,post - Σ_t λ_t * Y_treated,t)
1239
+ - Σ_j ω_j * (Ȳ_j,post - Σ_t λ_t * Y_j,t)
1240
+
1241
+ Where:
1242
+ - ω_j are unit weights
1243
+ - λ_t are time weights
1244
+ - Ȳ denotes average over post periods
1245
+ """
1246
+ # Weighted pre-treatment averages
1247
+ weighted_pre_control = time_weights @ Y_pre_control # shape: (n_control,)
1248
+ weighted_pre_treated = time_weights @ Y_pre_treated # scalar
1249
+
1250
+ # Post-treatment averages
1251
+ mean_post_control = np.mean(Y_post_control, axis=0) # shape: (n_control,)
1252
+ mean_post_treated = np.mean(Y_post_treated) # scalar
1253
+
1254
+ # DiD for treated: post - weighted pre
1255
+ did_treated = mean_post_treated - weighted_pre_treated
1256
+
1257
+ # Weighted DiD for controls: sum over j of omega_j * (post_j - weighted_pre_j)
1258
+ did_control = unit_weights @ (mean_post_control - weighted_pre_control)
1259
+
1260
+ # SDID estimator
1261
+ tau = did_treated - did_control
1262
+
1263
+ return float(tau)
1264
+
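Putting the pieces together: a toy synthetic-DiD sketch that computes unit weights, time weights, and the estimator on simulated arrays with a known effect of +2 (all data below are illustrative):

```python
import numpy as np
from diff_diff.utils import (
    compute_synthetic_weights,
    compute_time_weights,
    compute_sdid_estimator,
)

rng = np.random.default_rng(7)
n_pre, n_post, n_control = 6, 3, 10
Y_pre_control = rng.normal(5.0, 1.0, size=(n_pre, n_control))
Y_post_control = rng.normal(5.0, 1.0, size=(n_post, n_control))
Y_pre_treated = Y_pre_control.mean(axis=1) + rng.normal(0, 0.1, n_pre)
Y_post_treated = Y_post_control.mean(axis=1) + 2.0        # true effect of +2

omega = compute_synthetic_weights(Y_pre_control, Y_pre_treated)   # unit weights
lam = compute_time_weights(Y_pre_control, Y_pre_treated)          # time weights
tau = compute_sdid_estimator(
    Y_pre_control, Y_post_control, Y_pre_treated, Y_post_treated, omega, lam
)
print(f"SDID estimate: {tau:.3f}")   # should land near 2
```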
1265
+
1266
+ def compute_placebo_effects(
1267
+ Y_pre_control: np.ndarray,
1268
+ Y_post_control: np.ndarray,
1269
+ Y_pre_treated: np.ndarray,
1270
+ unit_weights: np.ndarray,
1271
+ time_weights: np.ndarray,
1272
+ control_unit_ids: List[Any],
1273
+ n_placebo: Optional[int] = None
1274
+ ) -> np.ndarray:
1275
+ """
1276
+ Compute placebo treatment effects by treating each control as treated.
1277
+
1278
+ Used for inference in synthetic DiD when bootstrap is not appropriate.
1279
+
1280
+ Parameters
1281
+ ----------
1282
+ Y_pre_control : np.ndarray
1283
+ Control outcomes in pre-treatment periods, shape (n_pre, n_control).
1284
+ Y_post_control : np.ndarray
1285
+ Control outcomes in post-treatment periods, shape (n_post, n_control).
1286
+ Y_pre_treated : np.ndarray
1287
+ Treated outcomes in pre-treatment periods, shape (n_pre,).
1288
+ unit_weights : np.ndarray
1289
+ Unit weights, shape (n_control,).
1290
+ time_weights : np.ndarray
1291
+ Time weights, shape (n_pre,).
1292
+ control_unit_ids : list
1293
+ List of control unit identifiers.
1294
+ n_placebo : int, optional
1295
+ Number of placebo tests. If None, uses all control units.
1296
+
1297
+ Returns
1298
+ -------
1299
+ np.ndarray
1300
+ Array of placebo treatment effects.
1301
+
1302
+ Notes
1303
+ -----
1304
+ For each control unit j, we pretend it was treated and compute
1305
+ the SDID estimate using the remaining controls. The distribution
1306
+ of these placebo effects provides a reference for inference.
1307
+ """
1308
+ n_pre, n_control = Y_pre_control.shape
1309
+
1310
+ if n_placebo is None:
1311
+ n_placebo = n_control
1312
+
1313
+ placebo_effects = []
1314
+
1315
+ for j in range(min(n_placebo, n_control)):
1316
+ # Treat unit j as the "treated" unit
1317
+ Y_pre_placebo_treated = Y_pre_control[:, j]
1318
+ Y_post_placebo_treated = Y_post_control[:, j]
1319
+
1320
+ # Use remaining units as controls
1321
+ remaining_idx = [i for i in range(n_control) if i != j]
1322
+
1323
+ if len(remaining_idx) == 0:
1324
+ continue
1325
+
1326
+ Y_pre_remaining = Y_pre_control[:, remaining_idx]
1327
+ Y_post_remaining = Y_post_control[:, remaining_idx]
1328
+
1329
+ # Recompute weights for remaining controls
1330
+ remaining_weights = compute_synthetic_weights(
1331
+ Y_pre_remaining,
1332
+ Y_pre_placebo_treated
1333
+ )
1334
+
1335
+ # Compute placebo effect
1336
+ placebo_tau = compute_sdid_estimator(
1337
+ Y_pre_remaining,
1338
+ Y_post_remaining,
1339
+ Y_pre_placebo_treated,
1340
+ Y_post_placebo_treated,
1341
+ remaining_weights,
1342
+ time_weights
1343
+ )
1344
+
1345
+ placebo_effects.append(placebo_tau)
1346
+
1347
+ return np.asarray(placebo_effects)
1348
+
1349
+
1350
+ def demean_by_group(
1351
+ data: pd.DataFrame,
1352
+ variables: List[str],
1353
+ group_var: str,
1354
+ inplace: bool = False,
1355
+ suffix: str = "",
1356
+ ) -> Tuple[pd.DataFrame, int]:
1357
+ """
1358
+ Demean variables by a grouping variable (one-way within transformation).
1359
+
1360
+ For each variable, computes: x_ig - mean(x_g) where g is the group.
1361
+
1362
+ Parameters
1363
+ ----------
1364
+ data : pd.DataFrame
1365
+ DataFrame containing the variables to demean.
1366
+ variables : list of str
1367
+ Column names to demean.
1368
+ group_var : str
1369
+ Column name for the grouping variable.
1370
+ inplace : bool, default False
1371
+ If True, modifies the input DataFrame in place. If False, the input
1372
+ DataFrame is left unchanged and a demeaned copy is returned.
1373
+ suffix : str, default ""
1374
+ Suffix to add to demeaned column names. Leave empty to overwrite the
1375
+ existing columns; set a suffix to keep both original and demeaned columns.
1376
+
1377
+ Returns
1378
+ -------
1379
+ data : pd.DataFrame
1380
+ DataFrame with demeaned variables.
1381
+ n_effects : int
1382
+ Number of absorbed fixed effects (nunique - 1).
1383
+
1384
+ Examples
1385
+ --------
1386
+ >>> df, n_fe = demean_by_group(df, ['y', 'x1', 'x2'], 'unit')
1387
+ >>> # df['y'], df['x1'], df['x2'] are now demeaned by unit
1388
+ """
1389
+ if not inplace:
1390
+ data = data.copy()
1391
+
1392
+ # Count fixed effects (categories - 1 for identification)
1393
+ n_effects = data[group_var].nunique() - 1
1394
+
1395
+ # Cache the groupby object for efficiency
1396
+ grouper = data.groupby(group_var, sort=False)
1397
+
1398
+ for var in variables:
1399
+ col_name = var if not suffix else f"{var}{suffix}"
1400
+ group_means = grouper[var].transform("mean")
1401
+ data[col_name] = data[var] - group_means
1402
+
1403
+ return data, n_effects
1404
+
1405
+
1406
+ def within_transform(
1407
+ data: pd.DataFrame,
1408
+ variables: List[str],
1409
+ unit: str,
1410
+ time: str,
1411
+ inplace: bool = False,
1412
+ suffix: str = "_demeaned",
1413
+ ) -> pd.DataFrame:
1414
+ """
1415
+ Apply two-way within transformation to remove unit and time fixed effects.
1416
+
1417
+ Computes: y_it - y_i. - y_.t + y_.. for each variable.
1418
+
1419
+ This is the standard fixed effects transformation for panel data that
1420
+ removes both unit-specific and time-specific effects.
1421
+
1422
+ Parameters
1423
+ ----------
1424
+ data : pd.DataFrame
1425
+ Panel data containing the variables to transform.
1426
+ variables : list of str
1427
+ Column names to transform.
1428
+ unit : str
1429
+ Column name for unit identifier.
1430
+ time : str
1431
+ Column name for time period identifier.
1432
+ inplace : bool, default False
1433
+ If True, modifies the original columns. If False, creates new columns
1434
+ with the specified suffix.
1435
+ suffix : str, default "_demeaned"
1436
+ Suffix for new column names when inplace=False.
1437
+
1438
+ Returns
1439
+ -------
1440
+ pd.DataFrame
1441
+ DataFrame with within-transformed variables.
1442
+
1443
+ Notes
1444
+ -----
1445
+ The within transformation removes variation that is constant within units
1446
+ (unit fixed effects) and constant within time periods (time fixed effects).
1447
+ The resulting estimates are equivalent to including unit and time dummies,
1448
+ but the transformation is computationally more efficient for large panels.
1449
+
1450
+ Examples
1451
+ --------
1452
+ >>> df = within_transform(df, ['y', 'x'], 'unit_id', 'year')
1453
+ >>> # df now has 'y_demeaned' and 'x_demeaned' columns
1454
+ """
1455
+ if not inplace:
1456
+ data = data.copy()
1457
+
1458
+ # Cache groupby objects for efficiency
1459
+ unit_grouper = data.groupby(unit, sort=False)
1460
+ time_grouper = data.groupby(time, sort=False)
1461
+
1462
+ if inplace:
1463
+ # Modify columns in place
1464
+ for var in variables:
1465
+ unit_means = unit_grouper[var].transform("mean")
1466
+ time_means = time_grouper[var].transform("mean")
1467
+ grand_mean = data[var].mean()
1468
+ data[var] = data[var] - unit_means - time_means + grand_mean
1469
+ else:
1470
+ # Build all demeaned columns at once to avoid DataFrame fragmentation
1471
+ demeaned_data = {}
1472
+ for var in variables:
1473
+ unit_means = unit_grouper[var].transform("mean")
1474
+ time_means = time_grouper[var].transform("mean")
1475
+ grand_mean = data[var].mean()
1476
+ demeaned_data[f"{var}{suffix}"] = (
1477
+ data[var] - unit_means - time_means + grand_mean
1478
+ ).values
1479
+
1480
+ # Add all columns at once
1481
+ demeaned_df = pd.DataFrame(demeaned_data, index=data.index)
1482
+ data = pd.concat([data, demeaned_df], axis=1)
1483
+
1484
+ return data
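A short sketch of the two-way within transformation on a small balanced panel (column names are illustrative):

```python
import numpy as np
import pandas as pd
from diff_diff.utils import within_transform

rng = np.random.default_rng(3)
panel = pd.DataFrame({
    "unit_id": np.repeat([1, 2, 3], 4),
    "year": np.tile([2018, 2019, 2020, 2021], 3),
})
panel["y"] = rng.normal(size=len(panel))
panel["x"] = rng.normal(size=len(panel))

out = within_transform(panel, ["y", "x"], unit="unit_id", time="year")
print(out[["y_demeaned", "x_demeaned"]].head())
# Each demeaned column has (approximately) zero mean within every unit and every year
```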