diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/utils.py ADDED
@@ -0,0 +1,1845 @@
1
+ """
2
+ Utility functions for difference-in-differences estimation.
3
+ """
4
+
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from scipy import stats
12
+
13
+ from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
14
+ from diff_diff.linalg import solve_ols as _solve_ols_linalg
15
+
16
+ # Import Rust backend if available (from _backend to avoid circular imports)
17
+ from diff_diff._backend import (
18
+ HAS_RUST_BACKEND,
19
+ _rust_project_simplex,
20
+ _rust_synthetic_weights,
21
+ _rust_sdid_unit_weights,
22
+ _rust_compute_time_weights,
23
+ _rust_compute_noise_level,
24
+ _rust_sc_weight_fw,
25
+ )
26
+
27
+ # Numerical constants for optimization algorithms
28
+ _OPTIMIZATION_MAX_ITER = 1000 # Maximum iterations for weight optimization
29
+ _OPTIMIZATION_TOL = 1e-8 # Convergence tolerance for optimization
30
+ _NUMERICAL_EPS = 1e-10 # Small constant to prevent division by zero
31
+
32
+
33
def validate_binary(arr: np.ndarray, name: str) -> None:
    """
    Ensure every non-NaN entry of *arr* is either 0 or 1.

    Parameters
    ----------
    arr : np.ndarray
        Array to validate. NaN entries are ignored.
    name : str
        Variable name used in the error message.

    Raises
    ------
    ValueError
        If any non-NaN value is neither 0 nor 1.
    """
    # Drop NaNs before collecting the distinct values present
    unique_values = np.unique(arr[~np.isnan(arr)])
    is_binary = np.isin(unique_values, [0, 1]).all()
    if not is_binary:
        raise ValueError(
            f"{name} must be binary (0 or 1). "
            f"Found values: {unique_values}"
        )
55
+
56
+
57
def compute_robust_se(
    X: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Compute an HC1 heteroskedasticity-robust or cluster-robust variance matrix.

    Thin backwards-compatibility wrapper that delegates to the optimized
    implementation in ``diff_diff.linalg``.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        Regression residuals of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers; when supplied, cluster-robust SEs are computed.

    Returns
    -------
    np.ndarray
        Variance-covariance matrix of shape (k, k).
    """
    vcov = _compute_robust_vcov_linalg(X, residuals, cluster_ids)
    return vcov
83
+
84
+
85
def compute_confidence_interval(
    estimate: float,
    se: float,
    alpha: float = 0.05,
    df: Optional[int] = None
) -> Tuple[float, float]:
    """
    Build a two-sided confidence interval around a point estimate.

    Parameters
    ----------
    estimate : float
        Point estimate.
    se : float
        Standard error of the estimate.
    alpha : float
        Significance level (default 0.05 for a 95% interval).
    df : int, optional
        Degrees of freedom for a t-based critical value; when None,
        the standard normal critical value is used.

    Returns
    -------
    tuple
        (lower_bound, upper_bound) of the confidence interval.
    """
    # Frozen distribution: t with df degrees of freedom, or standard normal
    dist = stats.t(df) if df is not None else stats.norm
    critical_value = dist.ppf(1 - alpha / 2)

    half_width = critical_value * se
    return (estimate - half_width, estimate + half_width)
119
+
120
+
121
def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
    """
    Convert a t-statistic to a p-value.

    Parameters
    ----------
    t_stat : float
        T-statistic.
    df : int, optional
        Degrees of freedom for the t reference distribution; when None,
        the standard normal distribution is used.
    two_sided : bool
        If True (default), return the two-sided p-value.

    Returns
    -------
    float
        P-value.
    """
    abs_t = np.abs(t_stat)
    # Upper-tail survival probability under the reference distribution
    tail = stats.t.sf(abs_t, df) if df is not None else stats.norm.sf(abs_t)
    return float(2 * tail) if two_sided else float(tail)
148
+
149
+
150
+ # =============================================================================
151
+ # Wild Cluster Bootstrap
152
+ # =============================================================================
153
+
154
+
155
@dataclass
class WildBootstrapResults:
    """
    Container for wild cluster bootstrap inference results.

    Attributes
    ----------
    se : float
        Bootstrap standard error of the coefficient.
    p_value : float
        Two-sided bootstrap p-value.
    t_stat_original : float
        T-statistic computed on the original sample.
    ci_lower : float
        Lower bound of the confidence interval.
    ci_upper : float
        Upper bound of the confidence interval.
    n_clusters : int
        Number of clusters in the data.
    n_bootstrap : int
        Number of bootstrap replications.
    weight_type : str
        Bootstrap weight distribution ("rademacher", "webb", or "mammen").
    alpha : float
        Significance level used for the confidence interval.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap distribution of coefficients (only when requested).

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.
    """

    se: float
    p_value: float
    t_stat_original: float
    ci_lower: float
    ci_upper: float
    n_clusters: int
    n_bootstrap: int
    weight_type: str
    alpha: float = 0.05
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def summary(self) -> str:
        """Return a human-readable multi-line summary of the results."""
        header = "Wild Cluster Bootstrap Results"
        separator = "=" * 40
        body = (
            f"Bootstrap SE: {self.se:.6f}",
            f"Bootstrap p-value: {self.p_value:.4f}",
            f"Original t-stat: {self.t_stat_original:.4f}",
            f"CI ({int((1-self.alpha)*100)}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
            f"Number of clusters: {self.n_clusters}",
            f"Bootstrap reps: {self.n_bootstrap}",
            f"Weight type: {self.weight_type}",
        )
        return "\n".join((header, separator) + body)

    def print_summary(self) -> None:
        """Print the formatted summary to stdout."""
        print(self.summary())
219
+
220
+
221
+ def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
222
+ """
223
+ Generate Rademacher weights: +1 or -1 with probability 0.5.
224
+
225
+ Parameters
226
+ ----------
227
+ n_clusters : int
228
+ Number of clusters.
229
+ rng : np.random.Generator
230
+ Random number generator.
231
+
232
+ Returns
233
+ -------
234
+ np.ndarray
235
+ Array of Rademacher weights.
236
+ """
237
+ return np.asarray(rng.choice([-1.0, 1.0], size=n_clusters))
238
+
239
+
240
+ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
241
+ """
242
+ Generate Webb's 6-point distribution weights.
243
+
244
+ Values: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)}
245
+ with equal probabilities (1/6 each), giving E[w]=0 and Var(w)=1.0.
246
+
247
+ This distribution is recommended for very few clusters (G < 10) as it
248
+ provides better finite-sample properties than Rademacher weights.
249
+
250
+ Parameters
251
+ ----------
252
+ n_clusters : int
253
+ Number of clusters.
254
+ rng : np.random.Generator
255
+ Random number generator.
256
+
257
+ Returns
258
+ -------
259
+ np.ndarray
260
+ Array of Webb weights.
261
+
262
+ References
263
+ ----------
264
+ Webb, M. D. (2014). Reworking wild bootstrap based inference for
265
+ clustered errors. Queen's Economics Department Working Paper No. 1315.
266
+
267
+ Note: Uses equal probabilities (1/6 each) matching R's `did` package,
268
+ which gives unit variance for consistency with other weight distributions.
269
+ """
270
+ values = np.array([
271
+ -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
272
+ np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
273
+ ])
274
+ # Equal probabilities (1/6 each) matching R's did package, giving Var(w) = 1.0
275
+ return np.asarray(rng.choice(values, size=n_clusters))
276
+
277
+
278
+ def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
279
+ """
280
+ Generate Mammen's two-point distribution weights.
281
+
282
+ Values: {-(sqrt(5)-1)/2, (sqrt(5)+1)/2}
283
+ with probabilities {(sqrt(5)+1)/(2*sqrt(5)), (sqrt(5)-1)/(2*sqrt(5))}.
284
+
285
+ This distribution satisfies E[v]=0, E[v^2]=1, E[v^3]=1, which provides
286
+ asymptotic refinement for skewed error distributions.
287
+
288
+ Parameters
289
+ ----------
290
+ n_clusters : int
291
+ Number of clusters.
292
+ rng : np.random.Generator
293
+ Random number generator.
294
+
295
+ Returns
296
+ -------
297
+ np.ndarray
298
+ Array of Mammen weights.
299
+
300
+ References
301
+ ----------
302
+ Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
303
+ Linear Models. The Annals of Statistics, 21(1), 255-285.
304
+ """
305
+ sqrt5 = np.sqrt(5)
306
+ # Values from Mammen (1993)
307
+ val1 = -(sqrt5 - 1) / 2 # approximately -0.618
308
+ val2 = (sqrt5 + 1) / 2 # approximately 1.618 (golden ratio)
309
+
310
+ # Probability of val1
311
+ p1 = (sqrt5 + 1) / (2 * sqrt5) # approximately 0.724
312
+
313
+ return np.asarray(rng.choice([val1, val2], size=n_clusters, p=[p1, 1 - p1]))
314
+
315
+
316
def wild_bootstrap_se(
    X: np.ndarray,
    y: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: np.ndarray,
    coefficient_index: int,
    n_bootstrap: int = 999,
    weight_type: str = "rademacher",
    null_hypothesis: float = 0.0,
    alpha: float = 0.05,
    seed: Optional[int] = None,
    return_distribution: bool = False
) -> WildBootstrapResults:
    """
    Compute wild cluster bootstrap standard errors and p-values.

    Implements the Wild Cluster Residual (WCR) bootstrap procedure from
    Cameron, Gelbach, and Miller (2008). Uses the restricted residuals
    approach (imposing H0: coefficient = null_hypothesis) for more accurate
    p-value computation.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Outcome vector of shape (n,).
    residuals : np.ndarray
        OLS residuals from unrestricted regression, shape (n,).
        NOTE(review): accepted for API compatibility but not used below —
        the WCR procedure re-derives restricted residuals itself.
    cluster_ids : np.ndarray
        Cluster identifiers of shape (n,).
    coefficient_index : int
        Index of the coefficient for which to compute bootstrap inference.
        For DiD, this is typically 3 (the treatment*post interaction term).
    n_bootstrap : int, default=999
        Number of bootstrap replications. Odd numbers are recommended for
        exact p-value computation.
    weight_type : str, default="rademacher"
        Type of bootstrap weights:
        - "rademacher": +1 or -1 with equal probability (standard choice)
        - "webb": 6-point distribution (recommended for <10 clusters)
        - "mammen": Two-point distribution with skewness correction
    null_hypothesis : float, default=0.0
        Value of the null hypothesis for p-value computation.
    alpha : float, default=0.05
        Significance level for confidence interval.
    seed : int, optional
        Random seed for reproducibility. If None (default), results
        will vary between runs.
    return_distribution : bool, default=False
        If True, include full bootstrap distribution in results.

    Returns
    -------
    WildBootstrapResults
        Dataclass containing bootstrap SE, p-value, confidence interval,
        and other inference results.

    Raises
    ------
    ValueError
        If weight_type is not recognized or if there are fewer than 2 clusters.

    Warns
    -----
    UserWarning
        If the number of clusters is less than 5, as bootstrap inference
        may be unreliable.

    Examples
    --------
    >>> from diff_diff.utils import wild_bootstrap_se
    >>> results = wild_bootstrap_se(
    ...     X, y, residuals, cluster_ids,
    ...     coefficient_index=3,  # ATT coefficient
    ...     n_bootstrap=999,
    ...     weight_type="rademacher",
    ...     seed=42
    ... )
    >>> print(f"Bootstrap SE: {results.se:.4f}")
    >>> print(f"Bootstrap p-value: {results.p_value:.4f}")

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.

    MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for
    few (treated) clusters. The Econometrics Journal, 21(2), 114-135.
    """
    # Validate inputs
    valid_weight_types = ["rademacher", "webb", "mammen"]
    if weight_type not in valid_weight_types:
        raise ValueError(
            f"weight_type must be one of {valid_weight_types}, got '{weight_type}'"
        )

    unique_clusters = np.unique(cluster_ids)
    n_clusters = len(unique_clusters)

    if n_clusters < 2:
        raise ValueError(
            f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}"
        )

    if n_clusters < 5:
        warnings.warn(
            f"Only {n_clusters} clusters detected. Wild bootstrap inference may be "
            "unreliable with fewer than 5 clusters. Consider using Webb weights "
            "(weight_type='webb') for improved finite-sample properties.",
            UserWarning
        )

    # Initialize RNG (local generator: global numpy random state is untouched)
    rng = np.random.default_rng(seed)

    # Select weight generator via dispatch table
    weight_generators = {
        "rademacher": _generate_rademacher_weights,
        "webb": _generate_webb_weights,
        "mammen": _generate_mammen_weights,
    }
    generate_weights = weight_generators[weight_type]

    n = X.shape[0]

    # Step 1: Compute original coefficient and cluster-robust SE
    beta_hat, _, vcov_original = _solve_ols_linalg(
        X, y, cluster_ids=cluster_ids, return_vcov=True
    )
    original_coef = beta_hat[coefficient_index]
    se_original = np.sqrt(vcov_original[coefficient_index, coefficient_index])
    t_stat_original = (original_coef - null_hypothesis) / se_original

    # Step 2: Impose null hypothesis (restricted estimation)
    # Create restricted y: y_restricted = y - X[:, coef_index] * null_hypothesis
    # This imposes the null that the coefficient equals null_hypothesis
    y_restricted = y - X[:, coefficient_index] * null_hypothesis

    # Fit restricted model (but we need to drop the column for the restricted coef)
    # Actually, for WCR bootstrap we keep all columns but impose the null via residuals
    # Re-estimate with the restricted dependent variable
    beta_restricted, residuals_restricted, _ = _solve_ols_linalg(
        X, y_restricted, return_vcov=False
    )

    # Create cluster-to-observation mapping for efficiency
    # (computed once, reused in every bootstrap iteration)
    cluster_map = {c: np.where(cluster_ids == c)[0] for c in unique_clusters}
    cluster_indices = [cluster_map[c] for c in unique_clusters]

    # Step 3: Bootstrap loop
    bootstrap_t_stats = np.zeros(n_bootstrap)
    bootstrap_coefs = np.zeros(n_bootstrap)

    for b in range(n_bootstrap):
        # Generate cluster-level weights (one weight per cluster)
        cluster_weights = generate_weights(n_clusters, rng)

        # Map cluster weights to observations (all obs in a cluster share a weight)
        obs_weights = np.zeros(n)
        for g, indices in enumerate(cluster_indices):
            obs_weights[indices] = cluster_weights[g]

        # Construct bootstrap sample: y* = X @ beta_restricted + e_restricted * weights
        y_star = X @ beta_restricted + residuals_restricted * obs_weights

        # Estimate bootstrap coefficients with cluster-robust SE
        beta_star, residuals_star, vcov_star = _solve_ols_linalg(
            X, y_star, cluster_ids=cluster_ids, return_vcov=True
        )
        bootstrap_coefs[b] = beta_star[coefficient_index]
        se_star = np.sqrt(vcov_star[coefficient_index, coefficient_index])

        # Compute bootstrap t-statistic (under null hypothesis);
        # degenerate SE maps to a zero t-stat rather than a division error
        if se_star > 0:
            bootstrap_t_stats[b] = (beta_star[coefficient_index] - null_hypothesis) / se_star
        else:
            bootstrap_t_stats[b] = 0.0

    # Step 4: Compute bootstrap p-value
    # P-value is proportion of |t*| >= |t_original|
    p_value = np.mean(np.abs(bootstrap_t_stats) >= np.abs(t_stat_original))

    # Ensure p-value is at least 1/(n_bootstrap+1) to avoid exact zero
    p_value = float(max(float(p_value), 1 / (n_bootstrap + 1)))

    # Step 5: Compute bootstrap SE and confidence interval
    # SE from standard deviation of bootstrap coefficient distribution
    se_bootstrap = float(np.std(bootstrap_coefs, ddof=1))

    # Percentile confidence interval from bootstrap distribution
    lower_percentile = alpha / 2 * 100
    upper_percentile = (1 - alpha / 2) * 100
    ci_lower = float(np.percentile(bootstrap_coefs, lower_percentile))
    ci_upper = float(np.percentile(bootstrap_coefs, upper_percentile))

    return WildBootstrapResults(
        se=se_bootstrap,
        p_value=p_value,
        t_stat_original=t_stat_original,
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        n_clusters=n_clusters,
        n_bootstrap=n_bootstrap,
        weight_type=weight_type,
        alpha=alpha,
        bootstrap_distribution=bootstrap_coefs if return_distribution else None
    )
525
+
526
+
527
def check_parallel_trends(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    pre_periods: Optional[List[Any]] = None
) -> Dict[str, Any]:
    """
    Run a simple slope-based diagnostic for the parallel trends assumption.

    Fits a separate linear time trend to the pre-treatment outcomes of the
    treated and control groups, then tests whether the two slopes differ.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column.
    pre_periods : list, optional
        Pre-treatment time periods. Defaults to the first half of all
        observed periods (i.e. treatment assumed at the median period).

    Returns
    -------
    dict
        Per-group trend estimates, their difference, and test results.
    """
    if pre_periods is None:
        # Default: first half of the observed periods is "pre-treatment"
        periods = sorted(data[time].unique())
        pre_periods = periods[: len(periods) // 2]

    subset = data[data[time].isin(pre_periods)]

    def _fit_trend(group: pd.DataFrame) -> Tuple[float, float]:
        """Least-squares slope of outcome on normalized time, with its SE."""
        t_vals = group[time].values
        y_vals = group[outcome].values

        # Shift time so it starts at zero
        t_norm = t_vals - t_vals.min()

        n_obs = len(t_norm)
        if n_obs < 2:
            return np.nan, np.nan

        t_bar = np.mean(t_norm)
        y_bar = np.mean(y_vals)

        # Degenerate case: every observation shares one time period
        sxx = np.sum((t_norm - t_bar) ** 2)
        if sxx == 0:
            return np.nan, np.nan

        beta = np.sum((t_norm - t_bar) * (y_vals - y_bar)) / sxx

        # Standard error of the slope from the residual mean square
        fitted = y_bar + beta * (t_norm - t_bar)
        resid = y_vals - fitted
        sigma2 = np.sum(resid ** 2) / (n_obs - 2)
        return beta, np.sqrt(sigma2 / sxx)

    treated_slope, treated_se = _fit_trend(subset[subset[treatment_group] == 1])
    control_slope, control_se = _fit_trend(subset[subset[treatment_group] == 0])

    # Difference-in-slopes test (NaN-safe: NaN propagates to the t-stat)
    slope_diff = treated_slope - control_slope
    se_diff = np.sqrt(treated_se ** 2 + control_se ** 2)
    t_stat = slope_diff / se_diff if se_diff > 0 else np.nan
    p_value = compute_p_value(t_stat) if not np.isnan(t_stat) else np.nan

    return {
        "treated_trend": treated_slope,
        "treated_trend_se": treated_se,
        "control_trend": control_slope,
        "control_trend_se": control_se,
        "trend_difference": slope_diff,
        "trend_difference_se": se_diff,
        "t_statistic": t_stat,
        "p_value": p_value,
        "parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
    }
621
+
622
+
623
def check_parallel_trends_robust(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    unit: Optional[str] = None,
    pre_periods: Optional[List[Any]] = None,
    n_permutations: int = 1000,
    seed: Optional[int] = None,
    wasserstein_threshold: float = 0.2
) -> Dict[str, Any]:
    """
    Perform robust parallel trends testing using distributional comparisons.

    Uses the Wasserstein (Earth Mover's) distance to compare the full
    distribution of outcome changes between treated and control groups,
    with permutation-based inference.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with repeated observations over time.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column (0/1).
    unit : str, optional
        Name of unit identifier column. If provided, computes unit-level
        changes. Otherwise uses observation-level data.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, uses first half of periods.
    n_permutations : int, default=1000
        Number of permutations for computing p-value.
    seed : int, optional
        Random seed for reproducibility.
    wasserstein_threshold : float, default=0.2
        Threshold for normalized Wasserstein distance. Values below this
        threshold (combined with p > 0.05) suggest parallel trends are plausible.

    Returns
    -------
    dict
        Dictionary containing:
        - wasserstein_distance: Wasserstein distance between group distributions
        - wasserstein_p_value: Permutation-based p-value
        - ks_statistic: Kolmogorov-Smirnov test statistic
        - ks_p_value: KS test p-value
        - mean_difference: Difference in mean changes
        - variance_ratio: Ratio of variances in changes
        - treated_changes: Array of outcome changes for treated
        - control_changes: Array of outcome changes for control
        - parallel_trends_plausible: Boolean assessment

    Examples
    --------
    >>> results = check_parallel_trends_robust(
    ...     data, outcome='sales', time='year',
    ...     treatment_group='treated', unit='firm_id'
    ... )
    >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
    >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")

    Notes
    -----
    The Wasserstein distance (Earth Mover's Distance) measures the minimum
    "cost" of transforming one distribution into another. Unlike simple
    mean comparisons, it captures differences in the entire distribution
    shape, making it more robust to non-normal data and heterogeneous effects.

    A small Wasserstein distance and high p-value suggest the distributions
    of pre-treatment changes are similar, supporting the parallel trends
    assumption.
    """
    # Use local RNG to avoid affecting global random state
    rng = np.random.default_rng(seed)

    # Identify pre-treatment periods (default: first half of all periods)
    if pre_periods is None:
        all_periods = sorted(data[time].unique())
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]

    pre_data = data[data[time].isin(pre_periods)].copy()

    # Compute outcome changes (unit-level diffs when `unit` given,
    # otherwise period-to-period changes in group means)
    treated_changes, control_changes = _compute_outcome_changes(
        pre_data, outcome, time, treatment_group, unit
    )

    # Too few changes to compare distributions: return an all-NaN result
    # with an explanatory "error" key instead of raising
    if len(treated_changes) < 2 or len(control_changes) < 2:
        return {
            "wasserstein_distance": np.nan,
            "wasserstein_p_value": np.nan,
            "ks_statistic": np.nan,
            "ks_p_value": np.nan,
            "mean_difference": np.nan,
            "variance_ratio": np.nan,
            "treated_changes": treated_changes,
            "control_changes": control_changes,
            "parallel_trends_plausible": None,
            "error": "Insufficient data for comparison",
        }

    # Compute Wasserstein distance
    wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)

    # Permutation test for Wasserstein distance: reshuffle group labels
    # while keeping group sizes fixed
    all_changes = np.concatenate([treated_changes, control_changes])
    n_treated = len(treated_changes)
    n_total = len(all_changes)

    permuted_distances = np.zeros(n_permutations)
    for i in range(n_permutations):
        perm_idx = rng.permutation(n_total)
        perm_treated = all_changes[perm_idx[:n_treated]]
        perm_control = all_changes[perm_idx[n_treated:]]
        permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)

    # P-value: proportion of permuted distances >= observed
    wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)

    # Kolmogorov-Smirnov test (secondary distributional comparison)
    ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)

    # Additional summary statistics
    mean_diff = np.mean(treated_changes) - np.mean(control_changes)
    var_treated = np.var(treated_changes, ddof=1)
    var_control = np.var(control_changes, ddof=1)
    var_ratio = var_treated / var_control if var_control > 0 else np.nan

    # Normalized Wasserstein (relative to pooled std) for a scale-free measure
    pooled_std = np.std(all_changes, ddof=1)
    wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan

    # Assessment: parallel trends plausible if p-value > 0.05
    # and normalized Wasserstein is small (below threshold);
    # a NaN normalized distance is treated as passing the threshold check
    plausible = bool(
        wasserstein_p > 0.05 and
        (wasserstein_normalized < wasserstein_threshold if not np.isnan(wasserstein_normalized) else True)
    )

    return {
        "wasserstein_distance": wasserstein_dist,
        "wasserstein_normalized": wasserstein_normalized,
        "wasserstein_p_value": wasserstein_p,
        "ks_statistic": ks_stat,
        "ks_p_value": ks_p,
        "mean_difference": mean_diff,
        "variance_ratio": var_ratio,
        "n_treated": len(treated_changes),
        "n_control": len(control_changes),
        "treated_changes": treated_changes,
        "control_changes": control_changes,
        "parallel_trends_plausible": plausible,
    }
780
+
781
+
782
+ def _compute_outcome_changes(
783
+ data: pd.DataFrame,
784
+ outcome: str,
785
+ time: str,
786
+ treatment_group: str,
787
+ unit: Optional[str] = None
788
+ ) -> Tuple[np.ndarray, np.ndarray]:
789
+ """
790
+ Compute period-to-period outcome changes for treated and control groups.
791
+
792
+ Parameters
793
+ ----------
794
+ data : pd.DataFrame
795
+ Panel data.
796
+ outcome : str
797
+ Outcome variable column.
798
+ time : str
799
+ Time period column.
800
+ treatment_group : str
801
+ Treatment group indicator column.
802
+ unit : str, optional
803
+ Unit identifier column.
804
+
805
+ Returns
806
+ -------
807
+ tuple
808
+ (treated_changes, control_changes) as numpy arrays.
809
+ """
810
+ if unit is not None:
811
+ # Unit-level changes: compute change for each unit across periods
812
+ data_sorted = data.sort_values([unit, time])
813
+ data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()
814
+
815
+ # Remove NaN from first period of each unit
816
+ changes_data = data_sorted.dropna(subset=["_outcome_change"])
817
+
818
+ treated_changes = changes_data[
819
+ changes_data[treatment_group] == 1
820
+ ]["_outcome_change"].values
821
+
822
+ control_changes = changes_data[
823
+ changes_data[treatment_group] == 0
824
+ ]["_outcome_change"].values
825
+ else:
826
+ # Aggregate changes: compute mean change per period per group
827
+ treated_data = data[data[treatment_group] == 1]
828
+ control_data = data[data[treatment_group] == 0]
829
+
830
+ # Compute period means
831
+ treated_means = treated_data.groupby(time)[outcome].mean()
832
+ control_means = control_data.groupby(time)[outcome].mean()
833
+
834
+ # Compute changes between consecutive periods
835
+ treated_changes = np.diff(treated_means.values)
836
+ control_changes = np.diff(control_means.values)
837
+
838
+ return treated_changes.astype(float), control_changes.astype(float)
839
+
840
+
841
+ def equivalence_test_trends(
842
+ data: pd.DataFrame,
843
+ outcome: str,
844
+ time: str,
845
+ treatment_group: str,
846
+ unit: Optional[str] = None,
847
+ pre_periods: Optional[List[Any]] = None,
848
+ equivalence_margin: Optional[float] = None
849
+ ) -> Dict[str, Any]:
850
+ """
851
+ Perform equivalence testing (TOST) for parallel trends.
852
+
853
+ Tests whether the difference in trends is practically equivalent to zero
854
+ using Two One-Sided Tests (TOST) procedure.
855
+
856
+ Parameters
857
+ ----------
858
+ data : pd.DataFrame
859
+ Panel data.
860
+ outcome : str
861
+ Name of outcome variable column.
862
+ time : str
863
+ Name of time period column.
864
+ treatment_group : str
865
+ Name of treatment group indicator column.
866
+ unit : str, optional
867
+ Name of unit identifier column.
868
+ pre_periods : list, optional
869
+ List of pre-treatment time periods.
870
+ equivalence_margin : float, optional
871
+ The margin for equivalence (delta). If None, uses 0.5 * pooled SD
872
+ of outcome changes as a default.
873
+
874
+ Returns
875
+ -------
876
+ dict
877
+ Dictionary containing:
878
+ - mean_difference: Difference in mean changes
879
+ - equivalence_margin: The margin used
880
+ - lower_p_value: P-value for lower bound test
881
+ - upper_p_value: P-value for upper bound test
882
+ - tost_p_value: Maximum of the two p-values
883
+ - equivalent: Boolean indicating equivalence at alpha=0.05
884
+ """
885
+ # Get pre-treatment periods
886
+ if pre_periods is None:
887
+ all_periods = sorted(data[time].unique())
888
+ mid_point = len(all_periods) // 2
889
+ pre_periods = all_periods[:mid_point]
890
+
891
+ pre_data = data[data[time].isin(pre_periods)].copy()
892
+
893
+ # Compute outcome changes
894
+ treated_changes, control_changes = _compute_outcome_changes(
895
+ pre_data, outcome, time, treatment_group, unit
896
+ )
897
+
898
+ # Need at least 2 observations per group to compute variance
899
+ # and at least 3 total for meaningful df calculation
900
+ if len(treated_changes) < 2 or len(control_changes) < 2:
901
+ return {
902
+ "mean_difference": np.nan,
903
+ "se_difference": np.nan,
904
+ "equivalence_margin": np.nan,
905
+ "lower_t_stat": np.nan,
906
+ "upper_t_stat": np.nan,
907
+ "lower_p_value": np.nan,
908
+ "upper_p_value": np.nan,
909
+ "tost_p_value": np.nan,
910
+ "degrees_of_freedom": np.nan,
911
+ "equivalent": None,
912
+ "error": "Insufficient data (need at least 2 observations per group)",
913
+ }
914
+
915
+ # Compute statistics
916
+ var_t = np.var(treated_changes, ddof=1)
917
+ var_c = np.var(control_changes, ddof=1)
918
+ n_t = len(treated_changes)
919
+ n_c = len(control_changes)
920
+
921
+ mean_diff = np.mean(treated_changes) - np.mean(control_changes)
922
+
923
+ # Handle zero variance case
924
+ if var_t == 0 and var_c == 0:
925
+ return {
926
+ "mean_difference": mean_diff,
927
+ "se_difference": 0.0,
928
+ "equivalence_margin": np.nan,
929
+ "lower_t_stat": np.nan,
930
+ "upper_t_stat": np.nan,
931
+ "lower_p_value": np.nan,
932
+ "upper_p_value": np.nan,
933
+ "tost_p_value": np.nan,
934
+ "degrees_of_freedom": np.nan,
935
+ "equivalent": None,
936
+ "error": "Zero variance in both groups - cannot perform t-test",
937
+ }
938
+
939
+ se_diff = np.sqrt(var_t / n_t + var_c / n_c)
940
+
941
+ # Handle zero SE case (cannot divide by zero in t-stat calculation)
942
+ if se_diff == 0:
943
+ return {
944
+ "mean_difference": mean_diff,
945
+ "se_difference": 0.0,
946
+ "equivalence_margin": np.nan,
947
+ "lower_t_stat": np.nan,
948
+ "upper_t_stat": np.nan,
949
+ "lower_p_value": np.nan,
950
+ "upper_p_value": np.nan,
951
+ "tost_p_value": np.nan,
952
+ "degrees_of_freedom": np.nan,
953
+ "equivalent": None,
954
+ "error": "Zero standard error - cannot perform t-test",
955
+ }
956
+
957
+ # Set equivalence margin if not provided
958
+ if equivalence_margin is None:
959
+ pooled_changes = np.concatenate([treated_changes, control_changes])
960
+ equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)
961
+
962
+ # Degrees of freedom (Welch-Satterthwaite approximation)
963
+ # Guard against division by zero when one group has zero variance
964
+ numerator = (var_t/n_t + var_c/n_c)**2
965
+ denom_t = (var_t/n_t)**2/(n_t-1) if var_t > 0 else 0
966
+ denom_c = (var_c/n_c)**2/(n_c-1) if var_c > 0 else 0
967
+ denominator = denom_t + denom_c
968
+
969
+ if denominator == 0:
970
+ # Fall back to minimum of n_t-1 and n_c-1 when one variance is zero
971
+ df = min(n_t - 1, n_c - 1)
972
+ else:
973
+ df = numerator / denominator
974
+
975
+ # TOST: Two one-sided tests
976
+ # Test 1: H0: diff <= -margin vs H1: diff > -margin
977
+ t_lower = (mean_diff - (-equivalence_margin)) / se_diff
978
+ p_lower = stats.t.sf(t_lower, df)
979
+
980
+ # Test 2: H0: diff >= margin vs H1: diff < margin
981
+ t_upper = (mean_diff - equivalence_margin) / se_diff
982
+ p_upper = stats.t.cdf(t_upper, df)
983
+
984
+ # TOST p-value is the maximum of the two
985
+ tost_p = max(p_lower, p_upper)
986
+
987
+ return {
988
+ "mean_difference": mean_diff,
989
+ "se_difference": se_diff,
990
+ "equivalence_margin": equivalence_margin,
991
+ "lower_t_stat": t_lower,
992
+ "upper_t_stat": t_upper,
993
+ "lower_p_value": p_lower,
994
+ "upper_p_value": p_upper,
995
+ "tost_p_value": tost_p,
996
+ "degrees_of_freedom": df,
997
+ "equivalent": bool(tost_p < 0.05),
998
+ }
999
+
1000
+
1001
def compute_synthetic_weights(
    Y_control: np.ndarray,
    Y_treated: np.ndarray,
    lambda_reg: float = 0.0,
    min_weight: float = 1e-6
) -> np.ndarray:
    """
    Compute synthetic control unit weights using constrained optimization.

    Finds weights ω minimizing the squared gap between the weighted average
    of control-unit outcomes and the treated-unit outcomes over the
    pre-treatment periods.

    Parameters
    ----------
    Y_control : np.ndarray
        Control unit outcomes, shape (n_pre_periods, n_control_units);
        one column per control unit, one row per pre-treatment period.
    Y_treated : np.ndarray
        Mean treated-unit outcomes, shape (n_pre_periods,).
    lambda_reg : float, default=0.0
        L2 regularization strength pulling weights toward uniform
        (1/n_control). Useful when n_pre < n_control.
    min_weight : float, default=1e-6
        Weights below this threshold are zeroed for interpretability.

    Returns
    -------
    np.ndarray
        Unit weights of shape (n_control_units,) summing to 1.

    Notes
    -----
    Solves the quadratic program:

        min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
        s.t. ω >= 0, sum(ω) = 1

    Dispatches to the Rust backend when available; otherwise falls back to
    a projected-gradient NumPy implementation.
    """
    _, n_units = Y_control.shape

    # Degenerate donor pools need no optimization.
    if n_units == 0:
        return np.asarray([])
    if n_units == 1:
        return np.asarray([1.0])

    if HAS_RUST_BACKEND:
        weights = _rust_synthetic_weights(
            np.ascontiguousarray(Y_control, dtype=np.float64),
            np.ascontiguousarray(Y_treated, dtype=np.float64),
            lambda_reg,
            _OPTIMIZATION_MAX_ITER,
            _OPTIMIZATION_TOL,
        )
    else:
        weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)

    # Zero out negligible weights, then renormalize to the simplex.
    weights = np.where(weights < min_weight, 0.0, weights)
    total = np.sum(weights)
    if total > 0:
        return np.asarray(weights / total)

    # Everything was zeroed out: fall back to uniform weights.
    return np.ones(n_units) / n_units
1071
+
1072
+
1073
def _compute_synthetic_weights_numpy(
    Y_control: np.ndarray,
    Y_treated: np.ndarray,
    lambda_reg: float = 0.0,
) -> np.ndarray:
    """NumPy fallback implementation of compute_synthetic_weights.

    Minimizes ``||Y_treated - Y_control @ w||^2 + lambda * ||w - w_uniform||^2``
    over the probability simplex via projected gradient descent.
    """
    _, n_control = Y_control.shape
    uniform = np.ones(n_control) / n_control

    # Quadratic form: w'Hw - 2 f'w + const, with
    # H = Y'Y + lambda*I and f = Y'y_treated + lambda*w_uniform.
    H = Y_control.T @ Y_control + lambda_reg * np.eye(n_control)
    f = Y_control.T @ Y_treated + lambda_reg * uniform

    # Step size 1/||H||_2 guarantees a non-expansive gradient step.
    step = 1.0 / (np.linalg.norm(H, 2) + _NUMERICAL_EPS)

    weights = uniform.copy()
    for _ in range(_OPTIMIZATION_MAX_ITER):
        previous = weights
        # Gradient step followed by projection back onto the simplex.
        weights = _project_simplex(weights - step * (H @ weights - f))
        if np.linalg.norm(weights - previous) < _OPTIMIZATION_TOL:
            break

    return weights
1114
+
1115
+
1116
+ def _project_simplex(v: np.ndarray) -> np.ndarray:
1117
+ """
1118
+ Project vector onto probability simplex (sum to 1, non-negative).
1119
+
1120
+ Uses the algorithm from Duchi et al. (2008).
1121
+
1122
+ Parameters
1123
+ ----------
1124
+ v : np.ndarray
1125
+ Vector to project.
1126
+
1127
+ Returns
1128
+ -------
1129
+ np.ndarray
1130
+ Projected vector on the simplex.
1131
+ """
1132
+ n = len(v)
1133
+ if n == 0:
1134
+ return v
1135
+
1136
+ # Sort in descending order
1137
+ u = np.sort(v)[::-1]
1138
+
1139
+ # Find the threshold
1140
+ cssv = np.cumsum(u)
1141
+ rho = np.where(u > (cssv - 1) / np.arange(1, n + 1))[0]
1142
+
1143
+ if len(rho) == 0:
1144
+ # All elements are negative or zero
1145
+ rho_val = 0
1146
+ else:
1147
+ rho_val = rho[-1]
1148
+
1149
+ theta = (cssv[rho_val] - 1) / (rho_val + 1)
1150
+
1151
+ return np.asarray(np.maximum(v - theta, 0))
1152
+
1153
+
1154
+ # =============================================================================
1155
+ # SDID Weight Optimization (Frank-Wolfe, matching R's synthdid)
1156
+ # =============================================================================
1157
+
1158
+
1159
+ def _sum_normalize(v: np.ndarray) -> np.ndarray:
1160
+ """Normalize vector to sum to 1. Fallback to uniform if sum is zero.
1161
+
1162
+ Matches R's synthdid ``sum_normalize()`` helper.
1163
+ """
1164
+ s = np.sum(v)
1165
+ if s > 0:
1166
+ return v / s
1167
+ return np.ones(len(v)) / len(v)
1168
+
1169
+
1170
def _compute_noise_level(Y_pre_control: np.ndarray) -> float:
    """Compute the noise level from first-differences of control outcomes.

    Matches R's ``sd(apply(Y[1:N0, 1:T0], 1, diff))``: first-difference each
    control unit across time, then take the pooled standard deviation.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control unit pre-treatment outcomes, shape (n_pre, n_control).

    Returns
    -------
    float
        Noise level (standard deviation of first-differences).
    """
    # Prefer the Rust backend when compiled in; otherwise use pure NumPy.
    if not HAS_RUST_BACKEND:
        return _compute_noise_level_numpy(Y_pre_control)
    return float(_rust_compute_noise_level(np.ascontiguousarray(Y_pre_control)))
1190
+
1191
+
1192
+ def _compute_noise_level_numpy(Y_pre_control: np.ndarray) -> float:
1193
+ """Pure NumPy implementation of noise level computation."""
1194
+ if Y_pre_control.shape[0] < 2:
1195
+ return 0.0
1196
+ # R: apply(Y[1:N0, 1:T0], 1, diff) computes diff per row (unit).
1197
+ # Our matrix is (T, N) so diff along axis=0 gives (T-1, N).
1198
+ first_diffs = np.diff(Y_pre_control, axis=0) # (T_pre-1, N_co)
1199
+ if first_diffs.size <= 1:
1200
+ return 0.0
1201
+ return float(np.std(first_diffs, ddof=1))
1202
+
1203
+
1204
def _compute_regularization(
    Y_pre_control: np.ndarray,
    n_treated: int,
    n_post: int,
) -> tuple:
    """Compute auto-regularization parameters matching R's synthdid.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control unit pre-treatment outcomes, shape (n_pre, n_control).
    n_treated : int
        Number of treated units.
    n_post : int
        Number of post-treatment periods.

    Returns
    -------
    tuple
        (zeta_omega, zeta_lambda) regularization parameters.
    """
    sigma = _compute_noise_level(Y_pre_control)
    # R defaults: eta.omega = (N_tr * T_post)^(1/4), eta.lambda = 1e-6.
    zeta_omega = ((n_treated * n_post) ** 0.25) * sigma
    zeta_lambda = 1e-6 * sigma
    return zeta_omega, zeta_lambda
1229
+
1230
+
1231
+ def _fw_step(
1232
+ A: np.ndarray,
1233
+ x: np.ndarray,
1234
+ b: np.ndarray,
1235
+ eta: float,
1236
+ ) -> np.ndarray:
1237
+ """Single Frank-Wolfe step on the simplex.
1238
+
1239
+ Matches R's ``fw.step()`` in synthdid's ``sc.weight.fw()``.
1240
+
1241
+ Parameters
1242
+ ----------
1243
+ A : np.ndarray
1244
+ Matrix of shape (N, T0).
1245
+ x : np.ndarray
1246
+ Current weight vector of shape (T0,).
1247
+ b : np.ndarray
1248
+ Target vector of shape (N,).
1249
+ eta : float
1250
+ Regularization strength (N * zeta^2).
1251
+
1252
+ Returns
1253
+ -------
1254
+ np.ndarray
1255
+ Updated weight vector on the simplex.
1256
+ """
1257
+ Ax = A @ x
1258
+ half_grad = A.T @ (Ax - b) + eta * x
1259
+ i = int(np.argmin(half_grad))
1260
+ d_x = -x.copy()
1261
+ d_x[i] += 1.0
1262
+ if np.allclose(d_x, 0.0):
1263
+ return x.copy()
1264
+ d_err = A[:, i] - Ax
1265
+ denom = d_err @ d_err + eta * (d_x @ d_x)
1266
+ if denom <= 0:
1267
+ return x.copy()
1268
+ step = -(half_grad @ d_x) / denom
1269
+ step = float(np.clip(step, 0.0, 1.0))
1270
+ return x + step * d_x
1271
+
1272
+
1273
def _sc_weight_fw(
    Y: np.ndarray,
    zeta: float,
    intercept: bool = True,
    init_weights: Optional[np.ndarray] = None,
    min_decrease: float = 1e-5,
    max_iter: int = 10000,
) -> np.ndarray:
    """Compute synthetic control weights via Frank-Wolfe optimization.

    Matches R's ``sc.weight.fw()`` from the synthdid package. Solves::

        min_{lambda on simplex} zeta^2 * ||lambda||^2
            + (1/N) * ||A_centered @ lambda - b_centered||^2

    Parameters
    ----------
    Y : np.ndarray
        Matrix of shape (N, T0+1); the last column is the target vector
        (post-period mean or treated pre-period mean depending on context).
    zeta : float
        Regularization strength.
    intercept : bool, default True
        If True, column-center Y before optimization.
    init_weights : np.ndarray, optional
        Initial weights; uniform when None.
    min_decrease : float, default 1e-5
        Stop once the objective decreases by less than ``min_decrease**2``.
        R uses ``1e-5 * noise_level``; callers should pass the
        data-dependent value for best results.
    max_iter : int, default 10000
        Maximum number of iterations (R's default).

    Returns
    -------
    np.ndarray
        Weights of shape (T0,) on the simplex.
    """
    if not HAS_RUST_BACKEND:
        return _sc_weight_fw_numpy(
            Y, zeta, intercept, init_weights, min_decrease, max_iter
        )

    init = (
        np.ascontiguousarray(init_weights, dtype=np.float64)
        if init_weights is not None
        else None
    )
    return np.asarray(_rust_sc_weight_fw(
        np.ascontiguousarray(Y, dtype=np.float64),
        zeta,
        intercept,
        init,
        min_decrease,
        max_iter,
    ))
1319
+
1320
+
1321
def _sc_weight_fw_numpy(
    Y: np.ndarray,
    zeta: float,
    intercept: bool = True,
    init_weights: Optional[np.ndarray] = None,
    min_decrease: float = 1e-5,
    max_iter: int = 10000,
) -> np.ndarray:
    """Pure NumPy implementation of the Frank-Wolfe SC weight solver."""
    n_rows, n_cols = Y.shape
    T0 = n_cols - 1

    # Degenerate input: no candidate columns to weight.
    if T0 <= 0:
        return np.ones(max(T0, 1))

    # Column-center when using an intercept (R's intercept=TRUE default).
    if intercept:
        Y = Y - Y.mean(axis=0)

    A = Y[:, :T0]
    b = Y[:, T0]
    eta = n_rows * zeta ** 2

    lam = np.ones(T0) / T0 if init_weights is None else init_weights.copy()

    # Iterate Frank-Wolfe steps until the penalized objective plateaus.
    previous_obj = np.inf
    for _ in range(max_iter):
        lam = _fw_step(A, lam, b, eta)
        err = Y @ np.append(lam, -1.0)
        obj = zeta ** 2 * np.sum(lam ** 2) + np.sum(err ** 2) / n_rows
        if previous_obj - obj < min_decrease ** 2:
            break
        previous_obj = obj

    return lam
1358
+
1359
+
1360
def _sparsify(v: np.ndarray) -> np.ndarray:
    """Sparsify weight vector by zeroing out small entries.

    Matches R's synthdid ``sparsify_function``:
    ``v[v <= max(v)/4] = 0; v = v / sum(v)``

    Parameters
    ----------
    v : np.ndarray
        Weight vector.

    Returns
    -------
    np.ndarray
        Sparsified weight vector summing to 1. An empty input is returned
        unchanged.
    """
    v = v.copy()
    # Guard: np.max raises ValueError on an empty array, and the uniform
    # fallback below would divide by zero. Nothing to sparsify anyway.
    if v.size == 0:
        return v
    max_v = np.max(v)
    if max_v <= 0:
        # Degenerate (non-positive) weights: fall back to uniform.
        return np.ones(len(v)) / len(v)
    v[v <= max_v / 4] = 0.0
    return _sum_normalize(v)
1382
+
1383
+
1384
def compute_time_weights(
    Y_pre_control: np.ndarray,
    Y_post_control: np.ndarray,
    zeta_lambda: float,
    intercept: bool = True,
    min_decrease: float = 1e-5,
    max_iter_pre_sparsify: int = 100,
    max_iter: int = 10000,
) -> np.ndarray:
    """Compute SDID time weights via Frank-Wolfe optimization.

    Matches R's ``synthdid::sc.weight.fw(Yc[1:N0, ], zeta=zeta.lambda,
    intercept=TRUE)`` on the collapsed-form matrix, using the same two-pass
    optimize/sparsify/re-optimize scheme as the unit weights (R's default
    ``sparsify=sparsify_function``).

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_post_control : np.ndarray
        Control outcomes in post-treatment periods, shape (n_post, n_control).
    zeta_lambda : float
        Regularization parameter for time weights.
    intercept : bool, default True
        If True, column-center the optimization matrix.
    min_decrease : float, default 1e-5
        Frank-Wolfe convergence criterion. R uses ``1e-5 * noise_level``.
    max_iter_pre_sparsify : int, default 100
        Iterations for the first pass (before sparsification).
    max_iter : int, default 10000
        Maximum iterations for the second pass (R's default).

    Returns
    -------
    np.ndarray
        Time weights of shape (n_pre,) on the simplex.

    Raises
    ------
    ValueError
        If ``Y_post_control`` contains no rows.
    """
    if Y_post_control.shape[0] == 0:
        raise ValueError(
            "Y_post_control has no rows. At least one post-treatment period "
            "is required for time weight computation."
        )

    if HAS_RUST_BACKEND:
        return np.asarray(_rust_compute_time_weights(
            np.ascontiguousarray(Y_pre_control, dtype=np.float64),
            np.ascontiguousarray(Y_post_control, dtype=np.float64),
            zeta_lambda, intercept, min_decrease,
            max_iter_pre_sparsify, max_iter,
        ))

    n_pre = Y_pre_control.shape[0]
    # With at most one pre period, the weight problem is trivial.
    if n_pre <= 1:
        return np.ones(n_pre)

    # Collapsed form (N_co, T_pre + 1): rows are control units, last column
    # is each control's post-period mean.
    post_means = np.mean(Y_post_control, axis=0)
    Y_time = np.column_stack([Y_pre_control.T, post_means])

    # Pass 1: short run from uniform start (R's max.iter.pre.sparsify).
    lam = _sc_weight_fw(
        Y_time,
        zeta=zeta_lambda,
        intercept=intercept,
        min_decrease=min_decrease,
        max_iter=max_iter_pre_sparsify,
    )

    # Drop small weights and renormalize (R's sparsify_function).
    lam = _sparsify(lam)

    # Pass 2: refine from the sparsified start (R's max.iter).
    return _sc_weight_fw(
        Y_time,
        zeta=zeta_lambda,
        intercept=intercept,
        init_weights=lam,
        min_decrease=min_decrease,
        max_iter=max_iter,
    )
1469
+
1470
+
1471
def compute_sdid_unit_weights(
    Y_pre_control: np.ndarray,
    Y_pre_treated_mean: np.ndarray,
    zeta_omega: float,
    intercept: bool = True,
    min_decrease: float = 1e-5,
    max_iter_pre_sparsify: int = 100,
    max_iter: int = 10000,
) -> np.ndarray:
    """Compute SDID unit weights via Frank-Wolfe with two-pass sparsification.

    Matches R's ``synthdid::sc.weight.fw(t(Yc[, 1:T0]), zeta=zeta.omega,
    intercept=TRUE)`` followed by the sparsify/re-optimize pass.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_pre_treated_mean : np.ndarray
        Mean treated outcomes in pre-treatment periods, shape (n_pre,).
    zeta_omega : float
        Regularization parameter for unit weights.
    intercept : bool, default True
        If True, column-center the optimization matrix.
    min_decrease : float, default 1e-5
        Frank-Wolfe convergence criterion. R uses ``1e-5 * noise_level``.
    max_iter_pre_sparsify : int, default 100
        Iterations for the first pass (before sparsification).
    max_iter : int, default 10000
        Iterations for the second pass (R's default).

    Returns
    -------
    np.ndarray
        Unit weights of shape (n_control,) on the simplex.
    """
    n_control = Y_pre_control.shape[1]

    # Degenerate donor pools need no optimization.
    if n_control == 0:
        return np.asarray([])
    if n_control == 1:
        return np.asarray([1.0])

    if HAS_RUST_BACKEND:
        return np.asarray(_rust_sdid_unit_weights(
            np.ascontiguousarray(Y_pre_control, dtype=np.float64),
            np.ascontiguousarray(Y_pre_treated_mean, dtype=np.float64),
            zeta_omega, intercept, min_decrease,
            max_iter_pre_sparsify, max_iter,
        ))

    # Collapsed form (T_pre, N_co + 1): last column is the treated units'
    # mean pre-treatment path.
    Y_unit = np.column_stack([Y_pre_control, Y_pre_treated_mean.reshape(-1, 1)])

    # Pass 1: short run from uniform start.
    omega = _sc_weight_fw(
        Y_unit,
        zeta=zeta_omega,
        intercept=intercept,
        min_decrease=min_decrease,
        max_iter=max_iter_pre_sparsify,
    )

    # Zero out weights <= max/4 and renormalize.
    omega = _sparsify(omega)

    # Pass 2: refine from the sparsified start.
    return _sc_weight_fw(
        Y_unit,
        zeta=zeta_omega,
        intercept=intercept,
        init_weights=omega,
        min_decrease=min_decrease,
        max_iter=max_iter,
    )
1548
+
1549
+
1550
def compute_sdid_estimator(
    Y_pre_control: np.ndarray,
    Y_post_control: np.ndarray,
    Y_pre_treated: np.ndarray,
    Y_post_treated: np.ndarray,
    unit_weights: np.ndarray,
    time_weights: np.ndarray
) -> float:
    """
    Compute the Synthetic DiD estimator.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_post_control : np.ndarray
        Control outcomes in post-treatment periods, shape (n_post, n_control).
    Y_pre_treated : np.ndarray
        Treated unit outcomes in pre-treatment periods, shape (n_pre,).
    Y_post_treated : np.ndarray
        Treated unit outcomes in post-treatment periods, shape (n_post,).
    unit_weights : np.ndarray
        Weights for control units, shape (n_control,).
    time_weights : np.ndarray
        Weights for pre-treatment periods, shape (n_pre,).

    Returns
    -------
    float
        The synthetic DiD treatment effect estimate.

    Notes
    -----
    The SDID estimator is:

        τ̂ = (Ȳ_treated,post - Σ_t λ_t * Y_treated,t)
            - Σ_j ω_j * (Ȳ_j,post - Σ_t λ_t * Y_j,t)

    where ω_j are unit weights, λ_t are time weights, and Ȳ denotes the
    average over post periods.
    """
    # Time-weighted pre-period baselines.
    baseline_control = time_weights @ Y_pre_control   # (n_control,)
    baseline_treated = time_weights @ Y_pre_treated   # scalar

    # Plain post-period means.
    post_control = np.mean(Y_post_control, axis=0)    # (n_control,)
    post_treated = np.mean(Y_post_treated)            # scalar

    # Treated change minus the unit-weighted control change.
    change_treated = post_treated - baseline_treated
    change_control = unit_weights @ (post_control - baseline_control)

    return float(change_treated - change_control)
1611
+
1612
+
1613
def compute_placebo_effects(
    Y_pre_control: np.ndarray,
    Y_post_control: np.ndarray,
    Y_pre_treated: np.ndarray,
    unit_weights: np.ndarray,
    time_weights: np.ndarray,
    control_unit_ids: List[Any],
    n_placebo: Optional[int] = None
) -> np.ndarray:
    """
    Compute placebo treatment effects by treating each control as treated.

    Used for inference in synthetic DiD when bootstrap is not appropriate.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_post_control : np.ndarray
        Control outcomes in post-treatment periods, shape (n_post, n_control).
    Y_pre_treated : np.ndarray
        Treated outcomes in pre-treatment periods, shape (n_pre,).
    unit_weights : np.ndarray
        Unit weights, shape (n_control,).
    time_weights : np.ndarray
        Time weights, shape (n_pre,).
    control_unit_ids : list
        List of control unit identifiers.
    n_placebo : int, optional
        Number of placebo tests. If None, uses all control units.

    Returns
    -------
    np.ndarray
        Array of placebo treatment effects.

    Notes
    -----
    Each control unit j is treated as if it had received treatment; the
    SDID estimate is recomputed with the remaining controls as donors.
    The distribution of these placebo effects is the inference reference.

    .. deprecated::
        This function uses legacy projected-gradient weights and fixed time
        weights. Use ``SyntheticDiD(variance_method='placebo')`` instead.
    """
    warnings.warn(
        "compute_placebo_effects uses legacy projected-gradient weights and "
        "fixed time weights, which does not match the SDID methodology "
        "(Algorithm 4 requires re-estimating both omega and lambda per "
        "placebo iteration). Use SyntheticDiD(variance_method='placebo') "
        "for correct placebo inference. This function will be removed in a "
        "future version.",
        DeprecationWarning,
        stacklevel=2,
    )
    n_control = Y_pre_control.shape[1]
    n_tests = n_control if n_placebo is None else min(n_placebo, n_control)

    effects = []
    for j in range(n_tests):
        # Pretend unit j is treated; the rest form the donor pool.
        donor_idx = [i for i in range(n_control) if i != j]
        if not donor_idx:
            continue

        pseudo_pre = Y_pre_control[:, j]
        pseudo_post = Y_post_control[:, j]
        donors_pre = Y_pre_control[:, donor_idx]
        donors_post = Y_post_control[:, donor_idx]

        # Re-fit unit weights against the placebo-treated unit.
        donor_weights = compute_synthetic_weights(donors_pre, pseudo_pre)

        effects.append(compute_sdid_estimator(
            donors_pre,
            donors_post,
            pseudo_pre,
            pseudo_post,
            donor_weights,
            time_weights
        ))

    return np.asarray(effects)
1709
+
1710
+
1711
def demean_by_group(
    data: pd.DataFrame,
    variables: List[str],
    group_var: str,
    inplace: bool = False,
    suffix: str = "",
) -> Tuple[pd.DataFrame, int]:
    """
    Demean variables by a grouping variable (one-way within transformation).

    For each variable, computes x_ig - mean(x_g) where g indexes the group.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame containing the variables to demean.
    variables : list of str
        Column names to demean.
    group_var : str
        Column name for the grouping variable.
    inplace : bool, default False
        If True, modifies the original columns. If False, leaves the
        caller's DataFrame unchanged (the returned copy is still demeaned).
    suffix : str, default ""
        Suffix added to demeaned column names (useful when inplace=False
        and both original and demeaned columns should be kept).

    Returns
    -------
    data : pd.DataFrame
        DataFrame with demeaned variables.
    n_effects : int
        Number of absorbed fixed effects (nunique - 1).

    Examples
    --------
    >>> df, n_fe = demean_by_group(df, ['y', 'x1', 'x2'], 'unit')
    >>> # df['y'], df['x1'], df['x2'] are now demeaned by unit
    """
    if not inplace:
        data = data.copy()

    # One category serves as the reference level, hence nunique - 1.
    n_effects = data[group_var].nunique() - 1

    # Reuse a single groupby object across all variables.
    grouped = data.groupby(group_var, sort=False)

    for var in variables:
        target = f"{var}{suffix}" if suffix else var
        data[target] = data[var] - grouped[var].transform("mean")

    return data, n_effects
1765
+
1766
+
1767
def within_transform(
    data: pd.DataFrame,
    variables: List[str],
    unit: str,
    time: str,
    inplace: bool = False,
    suffix: str = "_demeaned",
) -> pd.DataFrame:
    """
    Apply two-way within transformation to remove unit and time fixed effects.

    Computes y_it - y_i. - y_.t + y_.. for each variable, the standard
    fixed-effects transformation for panel data.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data containing the variables to transform.
    variables : list of str
        Column names to transform.
    unit : str
        Column name for the unit identifier.
    time : str
        Column name for the time period identifier.
    inplace : bool, default False
        If True, overwrites the original columns. If False, adds new
        columns with the given suffix.
    suffix : str, default "_demeaned"
        Suffix for new column names when inplace=False.

    Returns
    -------
    pd.DataFrame
        DataFrame with within-transformed variables.

    Notes
    -----
    Removes variation constant within units (unit fixed effects) and
    constant within periods (time fixed effects). Equivalent to including
    unit and time dummies, but cheaper for large panels.

    Examples
    --------
    >>> df = within_transform(df, ['y', 'x'], 'unit_id', 'year')
    >>> # df now has 'y_demeaned' and 'x_demeaned' columns
    """
    if not inplace:
        data = data.copy()

    # Shared groupby objects for all variables.
    by_unit = data.groupby(unit, sort=False)
    by_time = data.groupby(time, sort=False)

    def _two_way_demeaned(var: str) -> pd.Series:
        # y_it - y_i. - y_.t + y_..
        return (
            data[var]
            - by_unit[var].transform("mean")
            - by_time[var].transform("mean")
            + data[var].mean()
        )

    if inplace:
        for var in variables:
            data[var] = _two_way_demeaned(var)
        return data

    # Assemble all new columns first, then concat once, to avoid
    # DataFrame fragmentation from repeated column inserts.
    new_columns = {
        f"{var}{suffix}": _two_way_demeaned(var).values for var in variables
    }
    return pd.concat(
        [data, pd.DataFrame(new_columns, index=data.index)], axis=1
    )