diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/utils.py ADDED
@@ -0,0 +1,1902 @@
1
+ """
2
+ Utility functions for difference-in-differences estimation.
3
+ """
4
+
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from scipy import stats
12
+
13
+ from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
14
+ from diff_diff.linalg import solve_ols as _solve_ols_linalg
15
+
16
+ # Import Rust backend if available (from _backend to avoid circular imports)
17
+ from diff_diff._backend import (
18
+ HAS_RUST_BACKEND,
19
+ _rust_project_simplex,
20
+ _rust_synthetic_weights,
21
+ _rust_sdid_unit_weights,
22
+ _rust_compute_time_weights,
23
+ _rust_compute_noise_level,
24
+ _rust_sc_weight_fw,
25
+ )
26
+
27
# Numerical constants for optimization algorithms
_OPTIMIZATION_MAX_ITER = 1000  # Maximum iterations for weight optimization
_OPTIMIZATION_TOL = 1e-8  # Convergence tolerance for optimization
_NUMERICAL_EPS = 1e-10  # Small constant to prevent division by zero

# Memo of two-sided critical values keyed by (alpha, df), so inference loops
# do not re-evaluate scipy quantile functions for the same inputs.
_critical_value_cache: Dict[Tuple[float, Optional[int]], float] = {}


def _get_critical_value(alpha: float, df: Optional[int] = None) -> float:
    """Look up (computing on first use) a two-sided critical value.

    Parameters
    ----------
    alpha : float
        Significance level; the returned value is the ``1 - alpha / 2``
        quantile of the reference distribution.
    df : int, optional
        Degrees of freedom of a Student t reference distribution. When None,
        the standard normal distribution is used.

    Returns
    -------
    float
        The critical value, memoised in ``_critical_value_cache``.
    """
    cache_key = (alpha, df)
    cached = _critical_value_cache.get(cache_key)
    if cached is None:
        quantile = 1 - alpha / 2
        if df is None:
            cached = float(stats.norm.ppf(quantile))
        else:
            cached = float(stats.t.ppf(quantile, df))
        _critical_value_cache[cache_key] = cached
    return cached
45
+
46
+
47
def validate_binary(arr: np.ndarray, name: str) -> None:
    """
    Validate that an array contains only binary values (0 or 1).

    Missing values (NaN, None, pd.NA) are ignored. Any array-like is
    accepted, including integer, boolean, float and object dtypes.

    Parameters
    ----------
    arr : np.ndarray
        Array to validate.
    name : str
        Name of the variable (for error messages).

    Raises
    ------
    ValueError
        If array contains non-binary values.
    """
    values = np.asarray(arr)
    # pd.isna works for every dtype; np.isnan raises TypeError on object arrays
    # (e.g. a pandas object column holding 0/1 values or None).
    observed = values[~pd.isna(values)]
    try:
        unique_values = np.unique(observed)
    except TypeError:
        # Mixed, unorderable object values (e.g. 0 and "a") cannot be sorted;
        # fall back to first-seen order so the error below is still reported.
        unique_values = pd.unique(observed)
    if not np.all(np.isin(unique_values, [0, 1])):
        raise ValueError(f"{name} must be binary (0 or 1). " f"Found values: {unique_values}")
66
+
67
+
68
def compute_robust_se(
    X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Return the heteroskedasticity-robust (HC1) or cluster-robust covariance.

    Despite the historical name, the result is the full variance-covariance
    matrix, not a vector of standard errors; take ``np.sqrt(np.diag(v))`` of
    it to obtain SEs. Kept for backwards compatibility; the computation is
    delegated to ``diff_diff.linalg.compute_robust_vcov``.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        Regression residuals of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers; when given, a cluster-robust estimator is used.

    Returns
    -------
    np.ndarray
        Variance-covariance matrix of shape (k, k).
    """
    vcov = _compute_robust_vcov_linalg(X, residuals, cluster_ids)
    return vcov
92
+
93
+
94
def compute_confidence_interval(
    estimate: float, se: float, alpha: float = 0.05, df: Optional[int] = None
) -> Tuple[float, float]:
    """
    Build a symmetric Wald confidence interval ``estimate +/- crit * se``.

    Parameters
    ----------
    estimate : float
        Point estimate.
    se : float
        Standard error of the estimate.
    alpha : float
        Significance level (default 0.05, i.e. a 95% interval).
    df : int, optional
        Degrees of freedom for a Student t critical value; when None the
        normal critical value is used.

    Returns
    -------
    tuple
        ``(lower_bound, upper_bound)``.
    """
    half_width = _get_critical_value(alpha, df) * se
    return (estimate - half_width, estimate + half_width)
121
+
122
+
123
def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
    """
    Convert a t-statistic into a p-value.

    The tail probability is evaluated at ``|t_stat|`` under a Student t
    distribution (when ``df`` is given) or the standard normal (when ``df``
    is None), and doubled when ``two_sided`` is True.

    Parameters
    ----------
    t_stat : float
        Test statistic.
    df : int, optional
        Degrees of freedom; None selects the normal distribution.
    two_sided : bool
        Whether to return the two-sided p-value (default True).

    Returns
    -------
    float
        The p-value.
    """
    magnitude = np.abs(t_stat)
    tail = stats.norm.sf(magnitude) if df is None else stats.t.sf(magnitude, df)
    return float(tail * 2 if two_sided else tail)
150
+
151
+
152
def safe_inference(effect, se, alpha=0.05, df=None):
    """Return (t_stat, p_value, conf_int), or all-NaN when inference is undefined.

    Inference is treated as undefined when the SE is non-finite, zero or
    negative, or when ``df`` is supplied and is not positive (for example a
    rank-deficient replicate design). In those cases every field is NaN so
    no misleading statistic is reported.

    Only scalar inputs are supported (not numpy arrays); see
    ``safe_inference_batch`` for the vectorised variant.

    Parameters
    ----------
    effect : float
        Point estimate (treatment effect or coefficient).
    se : float
        Standard error of the estimate.
    alpha : float, optional
        Significance level for the confidence interval (default 0.05).
    df : int, optional
        Degrees of freedom. If None, the normal distribution is used.

    Returns
    -------
    tuple
        ``(t_stat, p_value, (ci_lower, ci_upper))``.
    """
    se_usable = np.isfinite(se) and se > 0
    if not se_usable or (df is not None and df <= 0):
        return np.nan, np.nan, (np.nan, np.nan)

    t_stat = effect / se
    return (
        t_stat,
        compute_p_value(t_stat, df=df),
        compute_confidence_interval(effect, se, alpha, df=df),
    )
187
+
188
+
189
def safe_inference_batch(effects, ses, alpha=0.05, df=None):
    """Vectorized, NaN-safe inference for arrays of effects and SEs.

    Elementwise counterpart of ``safe_inference``: entries whose SE is
    non-finite, zero or negative get NaN for every output, and all outputs
    are NaN when ``df`` is given and ``df <= 0``.

    Parameters
    ----------
    effects : array_like
        Point estimates.
    ses : array_like
        Standard errors. Must be broadcast-compatible with ``effects``
        (e.g. a single scalar SE shared by every effect).
    alpha : float, optional
        Significance level (default 0.05).
    df : int, optional
        Degrees of freedom. If None, uses normal distribution.

    Returns
    -------
    t_stats : np.ndarray
    p_values : np.ndarray
    ci_lowers : np.ndarray
    ci_uppers : np.ndarray
        Each has the broadcast shape of ``effects`` and ``ses``; equal-length
        1-D inputs give 1-D outputs of that length.
    """
    # Broadcasting (instead of sizing outputs with len()) supports scalars,
    # a scalar SE with an effects vector, and multi-dimensional inputs.
    effects, ses = np.broadcast_arrays(
        np.asarray(effects, dtype=float), np.asarray(ses, dtype=float)
    )
    shape = effects.shape

    t_stats = np.full(shape, np.nan)
    p_values = np.full(shape, np.nan)
    ci_lowers = np.full(shape, np.nan)
    ci_uppers = np.full(shape, np.nan)

    # Undefined df (e.g., rank-deficient replicate design) -> all NaN
    if df is not None and df <= 0:
        return t_stats, p_values, ci_lowers, ci_uppers

    valid = np.isfinite(ses) & (ses > 0)
    if not np.any(valid):
        return t_stats, p_values, ci_lowers, ci_uppers

    eff_v = effects[valid]
    se_v = ses[valid]
    t_v = eff_v / se_v
    t_stats[valid] = t_v

    if df is not None:
        p_values[valid] = 2.0 * stats.t.sf(np.abs(t_v), df)
    else:
        p_values[valid] = 2.0 * stats.norm.sf(np.abs(t_v))

    crit = _get_critical_value(alpha, df)
    ci_lowers[valid] = eff_v - crit * se_v
    ci_uppers[valid] = eff_v + crit * se_v

    return t_stats, p_values, ci_lowers, ci_uppers
239
+
240
+
241
+ # =============================================================================
242
+ # Wild Cluster Bootstrap
243
+ # =============================================================================
244
+
245
+
246
@dataclass
class WildBootstrapResults:
    """
    Results from wild cluster bootstrap inference.

    Attributes
    ----------
    se : float
        Bootstrap standard error of the coefficient.
    p_value : float
        Bootstrap p-value (two-sided).
    t_stat_original : float
        Original t-statistic from the data.
    ci_lower : float
        Lower bound of the confidence interval.
    ci_upper : float
        Upper bound of the confidence interval.
    n_clusters : int
        Number of clusters in the data.
    n_bootstrap : int
        Number of bootstrap replications.
    weight_type : str
        Type of bootstrap weights used ("rademacher", "webb", or "mammen").
    alpha : float
        Significance level used for confidence interval.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap distribution of coefficients (if requested). Excluded
        from ``repr``.

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.
    """

    se: float
    p_value: float
    t_stat_original: float
    ci_lower: float
    ci_upper: float
    n_clusters: int
    n_bootstrap: int
    weight_type: str
    alpha: float = 0.05
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def summary(self) -> str:
        """Generate formatted summary of bootstrap results."""
        # Format the coverage level with ``:g`` rather than truncating with
        # ``int()``: (1 - 0.07) * 100 evaluates to 92.99999999999999, which
        # ``int`` mislabels as 92%, and alpha=0.005 needs "99.5".
        level = f"{(1 - self.alpha) * 100:g}"
        lines = [
            "Wild Cluster Bootstrap Results",
            "=" * 40,
            f"Bootstrap SE: {self.se:.6f}",
            f"Bootstrap p-value: {self.p_value:.4f}",
            f"Original t-stat: {self.t_stat_original:.4f}",
            f"CI ({level}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
            f"Number of clusters: {self.n_clusters}",
            f"Bootstrap reps: {self.n_bootstrap}",
            f"Weight type: {self.weight_type}",
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print formatted summary to stdout."""
        print(self.summary())
310
+
311
+
312
def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw one Rademacher weight per cluster: -1.0 or +1.0, each with probability 1/2.

    Parameters
    ----------
    n_clusters : int
        How many weights to draw.
    rng : np.random.Generator
        Source of randomness (advanced by this call).

    Returns
    -------
    np.ndarray
        Float array of length ``n_clusters`` with entries in {-1.0, 1.0}.
    """
    signs = [-1.0, 1.0]
    draws = rng.choice(signs, size=n_clusters)
    return np.asarray(draws)
329
+
330
+
331
def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw Webb's six-point wild-bootstrap weights, one per cluster.

    The support is ``{-sqrt(3/2), -1, -sqrt(1/2), sqrt(1/2), 1, sqrt(3/2)}``,
    each point drawn with probability 1/6, so E[w] = 0 and Var(w) = 1.
    The equal probabilities match R's ``did`` package and give unit variance
    consistent with the other weight distributions. Suggested for very few
    clusters (G < 10), where it behaves better than Rademacher weights.

    Parameters
    ----------
    n_clusters : int
        How many weights to draw.
    rng : np.random.Generator
        Source of randomness (advanced by this call).

    Returns
    -------
    np.ndarray
        Array of ``n_clusters`` Webb weights.

    References
    ----------
    Webb, M. D. (2014). Reworking wild bootstrap based inference for
    clustered errors. Queen's Economics Department Working Paper No. 1315.
    """
    # Positive half of the support, ascending: sqrt(1/2), sqrt(2/2), sqrt(3/2).
    magnitudes = np.sqrt(np.array([1, 2, 3]) / 2)
    # Mirror it to obtain the symmetric, ascending six-point support.
    support = np.concatenate((-magnitudes[::-1], magnitudes))
    # Without p=, rng.choice samples uniformly: probability 1/6 per point.
    return np.asarray(rng.choice(support, size=n_clusters))
373
+
374
+
375
def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw Mammen's two-point wild-bootstrap weights, one per cluster.

    The weight takes ``-(sqrt(5) - 1) / 2`` (about -0.618) with probability
    ``(sqrt(5) + 1) / (2 * sqrt(5))`` (about 0.724), and ``(sqrt(5) + 1) / 2``
    (about 1.618) otherwise. This gives E[v] = 0, E[v^2] = 1 and E[v^3] = 1,
    which provides asymptotic refinement for skewed error distributions.

    Parameters
    ----------
    n_clusters : int
        How many weights to draw.
    rng : np.random.Generator
        Source of randomness (advanced by this call).

    Returns
    -------
    np.ndarray
        Array of ``n_clusters`` Mammen weights.

    References
    ----------
    Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
    Linear Models. The Annals of Statistics, 21(1), 255-285.
    """
    root5 = np.sqrt(5)
    support = [-(root5 - 1) / 2, (root5 + 1) / 2]
    prob_low = (root5 + 1) / (2 * root5)
    probabilities = [prob_low, 1 - prob_low]
    draws = rng.choice(support, size=n_clusters, p=probabilities)
    return np.asarray(draws)
411
+
412
+
413
def wild_bootstrap_se(
    X: np.ndarray,
    y: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: np.ndarray,
    coefficient_index: int,
    n_bootstrap: int = 999,
    weight_type: str = "rademacher",
    null_hypothesis: float = 0.0,
    alpha: float = 0.05,
    seed: Optional[int] = None,
    return_distribution: bool = False,
) -> WildBootstrapResults:
    """
    Compute wild cluster bootstrap standard errors and p-values.

    Implements the Wild Cluster Restricted (WCR) bootstrap of Cameron,
    Gelbach, and Miller (2008). Bootstrap samples are generated from a model
    that imposes H0: coefficient = null_hypothesis, which gives more
    accurate p-values.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Outcome vector of shape (n,).
    residuals : np.ndarray
        OLS residuals from the unrestricted regression, shape (n,). Kept for
        backwards compatibility; not used, because the bootstrap uses
        residuals re-estimated under the null.
    cluster_ids : np.ndarray
        Cluster identifiers of shape (n,).
    coefficient_index : int
        Index of the coefficient for which to compute bootstrap inference.
        For DiD, this is typically 3 (the treatment*post interaction term).
    n_bootstrap : int, default=999
        Number of bootstrap replications. Odd numbers are recommended for
        exact p-value computation.
    weight_type : str, default="rademacher"
        Type of bootstrap weights:
        - "rademacher": +1 or -1 with equal probability (standard choice)
        - "webb": 6-point distribution (recommended for <10 clusters)
        - "mammen": Two-point distribution with skewness correction
    null_hypothesis : float, default=0.0
        Value of the null hypothesis for p-value computation.
    alpha : float, default=0.05
        Significance level for confidence interval.
    seed : int, optional
        Random seed for reproducibility. If None (default), results
        will vary between runs.
    return_distribution : bool, default=False
        If True, include full bootstrap distribution in results.

    Returns
    -------
    WildBootstrapResults
        Dataclass containing bootstrap SE, p-value, confidence interval,
        and other inference results.

    Raises
    ------
    ValueError
        If weight_type is not recognized, if there are fewer than 2 clusters,
        or if n_bootstrap < 1.

    Warns
    -----
    UserWarning
        If the number of clusters is less than 5, as bootstrap inference
        may be unreliable.

    Notes
    -----
    The restricted model drops column ``coefficient_index`` and regresses
    ``y - X[:, coefficient_index] * null_hypothesis`` on the remaining
    columns. Bootstrap outcomes are the restricted fitted values plus the
    restricted residuals multiplied by cluster-level weights, so the
    bootstrap coefficients (and ``bootstrap_distribution``) are centred at
    ``null_hypothesis``.

    * The p-value compares ``|t*|`` with ``|t_original|``.
    * The SE is the standard deviation of the bootstrap coefficients.
    * The confidence interval is the basic (reflected) bootstrap interval
      around the original estimate.

    Examples
    --------
    >>> from diff_diff.utils import wild_bootstrap_se
    >>> results = wild_bootstrap_se(
    ...     X, y, residuals, cluster_ids,
    ...     coefficient_index=3,  # ATT coefficient
    ...     n_bootstrap=999,
    ...     weight_type="rademacher",
    ...     seed=42
    ... )
    >>> print(f"Bootstrap SE: {results.se:.4f}")
    >>> print(f"Bootstrap p-value: {results.p_value:.4f}")

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.

    MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for
    few (treated) clusters. The Econometrics Journal, 21(2), 114-135.
    """
    # Validate inputs
    valid_weight_types = ["rademacher", "webb", "mammen"]
    if weight_type not in valid_weight_types:
        raise ValueError(f"weight_type must be one of {valid_weight_types}, got '{weight_type}'")

    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap}")

    # obs_to_cluster[i] is the position of observation i's cluster within
    # unique_clusters; it lets us broadcast cluster weights in one step.
    unique_clusters, obs_to_cluster = np.unique(cluster_ids, return_inverse=True)
    obs_to_cluster = np.asarray(obs_to_cluster).reshape(-1)
    n_clusters = len(unique_clusters)

    if n_clusters < 2:
        raise ValueError(f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}")

    if n_clusters < 5:
        warnings.warn(
            f"Only {n_clusters} clusters detected. Wild bootstrap inference may be "
            "unreliable with fewer than 5 clusters. Consider using Webb weights "
            "(weight_type='webb') for improved finite-sample properties.",
            UserWarning,
        )

    # Initialize RNG
    rng = np.random.default_rng(seed)

    # Select weight generator
    weight_generators = {
        "rademacher": _generate_rademacher_weights,
        "webb": _generate_webb_weights,
        "mammen": _generate_mammen_weights,
    }
    generate_weights = weight_generators[weight_type]

    # Step 1: Compute original coefficient and cluster-robust SE
    beta_hat, _, vcov_original = _solve_ols_linalg(X, y, cluster_ids=cluster_ids, return_vcov=True)
    original_coef = beta_hat[coefficient_index]
    assert vcov_original is not None
    se_original = np.sqrt(vcov_original[coefficient_index, coefficient_index])
    t_stat_original = (original_coef - null_hypothesis) / se_original

    # Step 2: Restricted estimation imposing H0.
    # The tested column must be removed: regressing y - X_k * b0 on the full X
    # would leave the coefficient free and reproduce the unrestricted fit.
    y_arr = np.asarray(y, dtype=float)
    y_restricted = y_arr - X[:, coefficient_index] * null_hypothesis
    X_restricted = np.delete(X, coefficient_index, axis=1)
    if X_restricted.shape[1] > 0:
        _, residuals_restricted, _ = _solve_ols_linalg(
            X_restricted, y_restricted, return_vcov=False
        )
    else:
        # Only the tested column existed: the restricted model has no regressors.
        residuals_restricted = y_restricted
    # Restricted fitted values on the scale of y (X_{-k} b_r + X_k * b0),
    # obtained without relying on the restricted coefficient vector.
    fitted_null = y_arr - residuals_restricted

    # Step 3: Bootstrap loop
    bootstrap_t_stats = np.zeros(n_bootstrap)
    bootstrap_coefs = np.zeros(n_bootstrap)

    for b in range(n_bootstrap):
        # Cluster-level weights, broadcast to observations
        cluster_weights = generate_weights(n_clusters, rng)
        obs_weights = cluster_weights[obs_to_cluster]

        # Bootstrap sample generated under H0
        y_star = fitted_null + residuals_restricted * obs_weights

        # Estimate bootstrap coefficients with cluster-robust SE
        beta_star, _, vcov_star = _solve_ols_linalg(
            X, y_star, cluster_ids=cluster_ids, return_vcov=True
        )
        coef_star = beta_star[coefficient_index]
        bootstrap_coefs[b] = coef_star
        assert vcov_star is not None
        se_star = np.sqrt(vcov_star[coefficient_index, coefficient_index])

        # Bootstrap t-statistic, centred at the imposed null
        if se_star > 0:
            bootstrap_t_stats[b] = (coef_star - null_hypothesis) / se_star
        else:
            bootstrap_t_stats[b] = 0.0

    # Step 4: Bootstrap p-value = share of |t*| >= |t_original|, floored at
    # 1/(n_bootstrap+1) to avoid an exact zero. Undefined when t_original is NaN.
    if np.isnan(t_stat_original):
        p_value = np.nan
    else:
        p_value = float(np.mean(np.abs(bootstrap_t_stats) >= np.abs(t_stat_original)))
        p_value = float(max(p_value, 1 / (n_bootstrap + 1)))

    # Step 5: Bootstrap SE from the spread of the bootstrap coefficients
    se_bootstrap = float(np.std(bootstrap_coefs, ddof=1))

    # Basic bootstrap CI: the bootstrap coefficients are centred at the null,
    # so reflect their deviations from it around the original estimate.
    deviations = bootstrap_coefs - null_hypothesis
    dev_low, dev_high = np.percentile(deviations, [alpha / 2 * 100, (1 - alpha / 2) * 100])
    ci_lower = float(original_coef - dev_high)
    ci_upper = float(original_coef - dev_low)

    return WildBootstrapResults(
        se=se_bootstrap,
        p_value=p_value,
        t_stat_original=t_stat_original,
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        n_clusters=n_clusters,
        n_bootstrap=n_bootstrap,
        weight_type=weight_type,
        alpha=alpha,
        bootstrap_distribution=bootstrap_coefs if return_distribution else None,
    )
616
+
617
+
618
def check_parallel_trends(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    pre_periods: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    """
    Perform a simple check for parallel trends assumption.

    Fits a separate least-squares line of ``outcome`` on ``time`` for the
    treated (``treatment_group == 1``) and control (``treatment_group == 0``)
    observations in the pre-treatment periods. It then tests whether the two
    slopes differ using a normal-approximation test.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column (must be numeric).
    treatment_group : str
        Name of treatment group indicator column.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, the first half of the
        sorted distinct periods is used.

    Returns
    -------
    dict
        Keys ``treated_trend``, ``treated_trend_se``, ``control_trend``,
        ``control_trend_se``, ``trend_difference``, ``trend_difference_se``,
        ``t_statistic``, ``p_value`` and ``parallel_trends_plausible``
        (``p_value > 0.05``, or None when the test is undefined).

        A slope is NaN when its group has fewer than two observations or a
        single distinct period. A slope SE is NaN when its group has fewer
        than three observations.
    """
    if pre_periods is None:
        # Assume treatment happens at median time period
        all_periods = sorted(data[time].unique())
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]

    pre_data = data[data[time].isin(pre_periods)]

    # Compute trends for each group
    treated_data = pre_data[pre_data[treatment_group] == 1]
    control_data = pre_data[pre_data[treatment_group] == 0]

    def compute_trend(group_data: pd.DataFrame) -> Tuple[float, float]:
        """Return (slope, slope_se) of a simple OLS fit of outcome on time."""
        time_values = group_data[time].values
        outcome_values = group_data[outcome].values

        n = len(time_values)
        # Checked before .min(): reducing an empty array raises ValueError
        # (e.g. a group with no rows in the pre-treatment periods).
        if n < 2:
            return np.nan, np.nan

        # Normalize time to start at 0 (the slope is shift-invariant)
        time_norm = time_values - time_values.min()

        mean_t = np.mean(time_norm)
        mean_y = np.mean(outcome_values)

        # Check for zero variance in time (all same time period)
        time_var = np.sum((time_norm - mean_t) ** 2)
        if time_var == 0:
            return np.nan, np.nan

        slope = np.sum((time_norm - mean_t) * (outcome_values - mean_y)) / time_var

        # Residual variance has n - 2 degrees of freedom: undefined for n == 2
        if n < 3:
            return slope, np.nan

        y_hat = mean_y + slope * (time_norm - mean_t)
        residuals = outcome_values - y_hat
        mse = np.sum(residuals**2) / (n - 2)
        se_slope = np.sqrt(mse / time_var)

        return slope, se_slope

    treated_slope, treated_se = compute_trend(treated_data)
    control_slope, control_se = compute_trend(control_data)

    # Test for difference in trends (independent groups)
    slope_diff = treated_slope - control_slope
    se_diff = np.sqrt(treated_se**2 + control_se**2)
    t_stat, p_value, _ = safe_inference(slope_diff, se_diff)

    return {
        "treated_trend": treated_slope,
        "treated_trend_se": treated_se,
        "control_trend": control_slope,
        "control_trend_se": control_se,
        "trend_difference": slope_diff,
        "trend_difference_se": se_diff,
        "t_statistic": t_stat,
        "p_value": p_value,
        "parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
    }
711
+
712
+
713
def check_parallel_trends_robust(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    unit: Optional[str] = None,
    pre_periods: Optional[List[Any]] = None,
    n_permutations: int = 1000,
    seed: Optional[int] = None,
    wasserstein_threshold: float = 0.2,
) -> Dict[str, Any]:
    """
    Perform robust parallel trends testing using distributional comparisons.

    Uses the Wasserstein (Earth Mover's) distance to compare the full
    distribution of outcome changes between treated and control groups,
    with permutation-based inference.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with repeated observations over time.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column (0/1).
    unit : str, optional
        Name of unit identifier column. If provided, computes unit-level
        changes. Otherwise uses observation-level data.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, uses first half of periods.
    n_permutations : int, default=1000
        Number of permutations for computing p-value. Must be at least 1.
    seed : int, optional
        Random seed for reproducibility.
    wasserstein_threshold : float, default=0.2
        Threshold for normalized Wasserstein distance. Values below this
        threshold (combined with p > 0.05) suggest parallel trends are plausible.

    Returns
    -------
    dict
        Dictionary containing:
        - wasserstein_distance: Wasserstein distance between group distributions
        - wasserstein_normalized: Distance divided by the pooled SD of changes
        - wasserstein_p_value: Permutation p-value, (1 + #{d_perm >= d_obs}) /
          (1 + n_permutations), which is never exactly zero
        - ks_statistic: Kolmogorov-Smirnov test statistic
        - ks_p_value: KS test p-value
        - mean_difference: Difference in mean changes
        - variance_ratio: Ratio of variances in changes
        - n_treated / n_control: Number of changes in each group
        - treated_changes: Array of outcome changes for treated
        - control_changes: Array of outcome changes for control
        - parallel_trends_plausible: Boolean assessment (None if undetermined)
        - error: Only present when there is too little data to compare

    Raises
    ------
    ValueError
        If ``n_permutations`` is less than 1.

    Examples
    --------
    >>> results = check_parallel_trends_robust(
    ...     data, outcome='sales', time='year',
    ...     treatment_group='treated', unit='firm_id'
    ... )
    >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
    >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")

    Notes
    -----
    The Wasserstein distance (Earth Mover's Distance) measures the minimum
    "cost" of transforming one distribution into another. Unlike simple
    mean comparisons, it captures differences in the entire distribution
    shape, making it more robust to non-normal data and heterogeneous effects.

    A small Wasserstein distance and high p-value suggest the distributions
    of pre-treatment changes are similar, supporting the parallel trends
    assumption.
    """
    if n_permutations < 1:
        raise ValueError(f"n_permutations must be at least 1, got {n_permutations}")

    # Use local RNG to avoid affecting global random state
    rng = np.random.default_rng(seed)

    # Identify pre-treatment periods
    if pre_periods is None:
        all_periods = sorted(data[time].unique())
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]

    pre_data = data[data[time].isin(pre_periods)].copy()

    # Compute outcome changes
    treated_changes, control_changes = _compute_outcome_changes(
        pre_data, outcome, time, treatment_group, unit
    )

    if len(treated_changes) < 2 or len(control_changes) < 2:
        # Same keys as the success path (plus "error") so callers can index uniformly.
        return {
            "wasserstein_distance": np.nan,
            "wasserstein_normalized": np.nan,
            "wasserstein_p_value": np.nan,
            "ks_statistic": np.nan,
            "ks_p_value": np.nan,
            "mean_difference": np.nan,
            "variance_ratio": np.nan,
            "n_treated": len(treated_changes),
            "n_control": len(control_changes),
            "treated_changes": treated_changes,
            "control_changes": control_changes,
            "parallel_trends_plausible": None,
            "error": "Insufficient data for comparison",
        }

    # Compute Wasserstein distance
    wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)

    # Permutation test: reshuffle group labels and recompute the distance
    all_changes = np.concatenate([treated_changes, control_changes])
    n_treated = len(treated_changes)
    n_total = len(all_changes)

    permuted_distances = np.zeros(n_permutations)
    for i in range(n_permutations):
        perm_idx = rng.permutation(n_total)
        perm_treated = all_changes[perm_idx[:n_treated]]
        perm_control = all_changes[perm_idx[n_treated:]]
        permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)

    # Count the observed statistic as one of the permutations so the p-value
    # is never exactly zero (consistent with the 1/(B+1) floor used by the
    # wild bootstrap in this module).
    n_extreme = int(np.sum(permuted_distances >= wasserstein_dist))
    wasserstein_p = (n_extreme + 1) / (n_permutations + 1)

    # Kolmogorov-Smirnov test
    ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)

    # Additional summary statistics
    mean_diff = np.mean(treated_changes) - np.mean(control_changes)
    var_treated = np.var(treated_changes, ddof=1)
    var_control = np.var(control_changes, ddof=1)
    var_ratio = var_treated / var_control if var_control > 0 else np.nan

    # Normalized Wasserstein (relative to pooled std)
    pooled_std = np.std(all_changes, ddof=1)
    wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan

    # Assessment: parallel trends plausible if p-value > 0.05 and normalized
    # Wasserstein is below threshold. A NaN normalized distance (zero pooled
    # spread: all changes identical) does not count against plausibility.
    plausible = bool(
        wasserstein_p > 0.05
        and (
            wasserstein_normalized < wasserstein_threshold
            if not np.isnan(wasserstein_normalized)
            else True
        )
    )

    return {
        "wasserstein_distance": wasserstein_dist,
        "wasserstein_normalized": wasserstein_normalized,
        "wasserstein_p_value": wasserstein_p,
        "ks_statistic": ks_stat,
        "ks_p_value": ks_p,
        "mean_difference": mean_diff,
        "variance_ratio": var_ratio,
        "n_treated": len(treated_changes),
        "n_control": len(control_changes),
        "treated_changes": treated_changes,
        "control_changes": control_changes,
        "parallel_trends_plausible": plausible,
    }
874
+
875
+
876
def _compute_outcome_changes(
    data: pd.DataFrame, outcome: str, time: str, treatment_group: str, unit: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build arrays of outcome changes for the treated and control groups.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column.
    time : str
        Time period column.
    treatment_group : str
        Treatment group indicator column (1 = treated, 0 = control).
    unit : str, optional
        Unit identifier column. When given, changes are first differences of
        each unit's own outcome path (ordered by ``time``). When omitted,
        changes are first differences of the per-period group means.

    Returns
    -------
    tuple
        ``(treated_changes, control_changes)`` as float numpy arrays.
    """
    if unit is None:
        # Group-level path: average the outcome within each period, then
        # difference consecutive period means (groupby sorts by time).
        def _mean_path_changes(group_value):
            members = data[data[treatment_group] == group_value]
            period_means = members.groupby(time)[outcome].mean()
            return np.diff(period_means.values)

        treated_diffs = _mean_path_changes(1)
        control_diffs = _mean_path_changes(0)
    else:
        # Unit-level path: difference each unit's outcomes over time.
        ordered = data.sort_values([unit, time])
        ordered["_outcome_change"] = ordered.groupby(unit)[outcome].diff()
        # The first observation of every unit has no predecessor -> NaN; drop it.
        with_changes = ordered.dropna(subset=["_outcome_change"])
        is_treated = with_changes[treatment_group] == 1
        is_control = with_changes[treatment_group] == 0
        treated_diffs = with_changes.loc[is_treated, "_outcome_change"].values
        control_diffs = with_changes.loc[is_control, "_outcome_change"].values

    return treated_diffs.astype(float), control_diffs.astype(float)
925
+
926
+
927
def equivalence_test_trends(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    unit: Optional[str] = None,
    pre_periods: Optional[List[Any]] = None,
    equivalence_margin: Optional[float] = None,
    alpha: float = 0.05,
) -> Dict[str, Any]:
    """
    Perform equivalence testing (TOST) for parallel trends.

    Tests whether the difference in pre-treatment trends between treated and
    control groups is practically equivalent to zero using the Two One-Sided
    Tests (TOST) procedure with a Welch (unequal-variance) standard error.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column.
    unit : str, optional
        Name of unit identifier column.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, the first half of the
        sorted unique time periods is used.
    equivalence_margin : float, optional
        The margin for equivalence (delta). If None, uses 0.5 * pooled SD
        of outcome changes as a default. Must be non-negative.
    alpha : float, default=0.05
        Significance level for the ``equivalent`` decision. Must lie strictly
        between 0 and 1.

    Returns
    -------
    dict
        Dictionary containing:
        - mean_difference: Mean treated change minus mean control change
        - se_difference: Welch standard error of that difference
        - equivalence_margin: The margin used
        - lower_t_stat / upper_t_stat: t statistics of the two one-sided tests
        - lower_p_value: P-value for lower bound test (H0: diff <= -margin)
        - upper_p_value: P-value for upper bound test (H0: diff >= margin)
        - tost_p_value: Maximum of the two p-values
        - degrees_of_freedom: Welch-Satterthwaite degrees of freedom
        - equivalent: True if tost_p_value < alpha; None if the test could
          not be performed
        - error: Present only when the test could not be performed

    Raises
    ------
    ValueError
        If ``alpha`` is not in (0, 1) or ``equivalence_margin`` is negative.
    """
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha!r}")
    if equivalence_margin is not None and equivalence_margin < 0:
        raise ValueError(f"equivalence_margin must be non-negative, got {equivalence_margin!r}")

    def _not_testable(
        error: str, mean_diff: float = np.nan, se_diff: float = np.nan
    ) -> Dict[str, Any]:
        # Uniform result for inputs on which the t-test cannot be carried out.
        return {
            "mean_difference": mean_diff,
            "se_difference": se_diff,
            "equivalence_margin": np.nan,
            "lower_t_stat": np.nan,
            "upper_t_stat": np.nan,
            "lower_p_value": np.nan,
            "upper_p_value": np.nan,
            "tost_p_value": np.nan,
            "degrees_of_freedom": np.nan,
            "equivalent": None,
            "error": error,
        }

    # Default pre-treatment window: first half of the observed periods
    if pre_periods is None:
        all_periods = sorted(data[time].unique())
        pre_periods = all_periods[: len(all_periods) // 2]

    pre_data = data[data[time].isin(pre_periods)].copy()

    treated_changes, control_changes = _compute_outcome_changes(
        pre_data, outcome, time, treatment_group, unit
    )

    # Need at least 2 observations per group to compute a sample variance
    if len(treated_changes) < 2 or len(control_changes) < 2:
        return _not_testable("Insufficient data (need at least 2 observations per group)")

    var_t = np.var(treated_changes, ddof=1)
    var_c = np.var(control_changes, ddof=1)
    n_t = len(treated_changes)
    n_c = len(control_changes)

    mean_diff = np.mean(treated_changes) - np.mean(control_changes)

    if var_t == 0 and var_c == 0:
        return _not_testable(
            "Zero variance in both groups - cannot perform t-test", mean_diff, 0.0
        )

    se_diff = np.sqrt(var_t / n_t + var_c / n_c)

    # Defensive: the t statistics below divide by se_diff
    if se_diff == 0:
        return _not_testable("Zero standard error - cannot perform t-test", mean_diff, 0.0)

    if equivalence_margin is None:
        pooled_changes = np.concatenate([treated_changes, control_changes])
        equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)

    # Welch-Satterthwaite degrees of freedom; a zero-variance group contributes
    # nothing to the denominator (the formula then reduces to n - 1 of the other).
    numerator = (var_t / n_t + var_c / n_c) ** 2
    denom_t = (var_t / n_t) ** 2 / (n_t - 1) if var_t > 0 else 0
    denom_c = (var_c / n_c) ** 2 / (n_c - 1) if var_c > 0 else 0
    denominator = denom_t + denom_c
    df = min(n_t - 1, n_c - 1) if denominator == 0 else numerator / denominator

    # Test 1: H0: diff <= -margin vs H1: diff > -margin
    t_lower = (mean_diff + equivalence_margin) / se_diff
    p_lower = stats.t.sf(t_lower, df)

    # Test 2: H0: diff >= margin vs H1: diff < margin
    t_upper = (mean_diff - equivalence_margin) / se_diff
    p_upper = stats.t.cdf(t_upper, df)

    # Equivalence is declared only if BOTH one-sided nulls are rejected
    tost_p = max(p_lower, p_upper)

    return {
        "mean_difference": mean_diff,
        "se_difference": se_diff,
        "equivalence_margin": equivalence_margin,
        "lower_t_stat": t_lower,
        "upper_t_stat": t_upper,
        "lower_p_value": p_lower,
        "upper_p_value": p_upper,
        "tost_p_value": tost_p,
        "degrees_of_freedom": df,
        "equivalent": bool(tost_p < alpha),
    }
1085
+
1086
+
1087
def compute_synthetic_weights(
    Y_control: np.ndarray, Y_treated: np.ndarray, lambda_reg: float = 0.0, min_weight: float = 1e-6
) -> np.ndarray:
    """
    Compute synthetic control unit weights using constrained optimization.

    Finds weights ω that minimize the squared difference between the
    weighted average of control unit outcomes and the treated unit outcomes
    during pre-treatment periods.

    Parameters
    ----------
    Y_control : np.ndarray
        Control unit outcomes matrix of shape (n_pre_periods, n_control_units).
        Each column is a control unit, each row is a pre-treatment period.
    Y_treated : np.ndarray
        Treated unit mean outcomes of shape (n_pre_periods,).
        Average across treated units for each pre-treatment period.
    lambda_reg : float, default=0.0
        L2 regularization parameter. Larger values shrink weights toward
        uniform (1/n_control). Helps prevent overfitting when n_pre < n_control.
    min_weight : float, default=1e-6
        Minimum weight threshold. Weights below this are set to zero and the
        remaining weights are renormalized.

    Returns
    -------
    np.ndarray
        Unit weights of shape (n_control_units,) that sum to 1. With zero
        control units an empty array is returned; with one, ``[1.0]``.

    Notes
    -----
    Solves the quadratic program:

        min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
        s.t. ω >= 0, sum(ω) = 1

    The Rust backend is used when available; otherwise a NumPy projected
    gradient descent with Euclidean projection onto the simplex is used.
    If thresholding removes every weight, uniform weights are returned.
    """
    _, n_control = Y_control.shape

    if n_control == 0:
        return np.asarray([])

    if n_control == 1:
        return np.asarray([1.0])

    if HAS_RUST_BACKEND:
        raw = _rust_synthetic_weights(
            np.ascontiguousarray(Y_control, dtype=np.float64),
            np.ascontiguousarray(Y_treated, dtype=np.float64),
            lambda_reg,
            _OPTIMIZATION_MAX_ITER,
            _OPTIMIZATION_TOL,
        )
    else:
        raw = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)

    # Work on a private float64 copy: the in-place thresholding below must not
    # depend on whether the backend returned a writable ndarray.
    weights = np.array(raw, dtype=np.float64)

    # Zero out negligible weights for interpretability, then renormalize
    weights[weights < min_weight] = 0.0
    total = np.sum(weights)
    if total > 0:
        return weights / total

    # Every weight was below the threshold: fall back to uniform weights
    return np.ones(n_control) / n_control
1153
+
1154
+
1155
def _compute_synthetic_weights_numpy(
    Y_control: np.ndarray,
    Y_treated: np.ndarray,
    lambda_reg: float = 0.0,
) -> np.ndarray:
    """NumPy fallback implementation of compute_synthetic_weights.

    Minimizes ``||Y_treated - Y_control @ w||^2 + lambda_reg * ||w - u||^2``
    (``u`` = uniform weights) over the probability simplex by projected
    gradient descent, starting from ``u``. Iteration stops after
    ``_OPTIMIZATION_MAX_ITER`` steps or once an update moves ``w`` by less
    than ``_OPTIMIZATION_TOL`` in Euclidean norm.
    """
    _, n_units = Y_control.shape
    uniform = np.ones(n_units) / n_units

    # Quadratic form of the (halved) objective: 0.5 * w'Qw - c'w + const
    Q = Y_control.T @ Y_control + lambda_reg * np.eye(n_units)
    c = Y_control.T @ Y_treated + lambda_reg * uniform

    # Step 1/L with L the spectral norm of Q (eps guards against Q == 0)
    lr = 1.0 / (np.linalg.norm(Q, 2) + _NUMERICAL_EPS)

    w = uniform.copy()
    for _ in range(_OPTIMIZATION_MAX_ITER):
        candidate = _project_simplex(w - lr * (Q @ w - c))
        moved = np.linalg.norm(candidate - w)
        w = candidate
        if moved < _OPTIMIZATION_TOL:
            break

    return w
1196
+
1197
+
1198
def _project_simplex(v: np.ndarray) -> np.ndarray:
    """
    Euclidean projection of ``v`` onto the probability simplex.

    The result is non-negative and sums to 1. Implements the sort-based
    algorithm of Duchi et al. (2008).

    Parameters
    ----------
    v : np.ndarray
        Vector to project.

    Returns
    -------
    np.ndarray
        Projected vector on the simplex (``v`` itself if it is empty).
    """
    size = len(v)
    if not size:
        return v

    u_desc = np.sort(v)[::-1]
    running = np.cumsum(u_desc)
    ranks = np.arange(1, size + 1)

    # Largest index whose sorted value stays positive after the shift
    active = np.nonzero(u_desc > (running - 1) / ranks)[0]
    # For finite input index 0 always qualifies; an empty result only
    # arises with NaN entries, where index 0 is used.
    k = active[-1] if active.size else 0

    threshold = (running[k] - 1) / (k + 1)
    return np.asarray(np.maximum(v - threshold, 0))
1234
+
1235
+
1236
+ # =============================================================================
1237
+ # SDID Weight Optimization (Frank-Wolfe, matching R's synthdid)
1238
+ # =============================================================================
1239
+
1240
+
1241
def _sum_normalize(v: np.ndarray) -> np.ndarray:
    """Scale ``v`` so its entries sum to 1.

    When the total is not strictly positive (zero, negative or NaN), uniform
    weights of the same length are returned instead. Matches R's synthdid
    ``sum_normalize()`` helper.
    """
    total = np.sum(v)
    if not total > 0:
        return np.ones(len(v)) / len(v)
    return v / total
1250
+
1251
+
1252
def _compute_noise_level(Y_pre_control: np.ndarray) -> float:
    """Compute noise level from first-differences of control outcomes.

    Matches R's ``sd(apply(Y[1:N0, 1:T0], 1, diff))`` which computes
    first-differences across time for each control unit, then takes the
    pooled standard deviation.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control unit pre-treatment outcomes, shape (n_pre, n_control).

    Returns
    -------
    float
        Noise level (standard deviation of first-differences).
    """
    if HAS_RUST_BACKEND:
        # Cast to contiguous float64 exactly as the other Rust dispatchers in
        # this module do, so integer/float32 inputs are handed over in the
        # dtype those bindings are given everywhere else.
        y_f64 = np.ascontiguousarray(Y_pre_control, dtype=np.float64)
        return float(_rust_compute_noise_level(y_f64))
    return _compute_noise_level_numpy(Y_pre_control)
1272
+
1273
+
1274
def _compute_noise_level_numpy(Y_pre_control: np.ndarray) -> float:
    """Pure NumPy noise level: sample SD of per-unit first differences.

    The input is laid out (time, unit), so differencing along axis 0 mirrors
    R's per-unit ``apply(..., 1, diff)``. Returns 0.0 when there are fewer
    than two periods or fewer than two differences in total.
    """
    n_periods = Y_pre_control.shape[0]
    if n_periods >= 2:
        deltas = np.diff(Y_pre_control, axis=0)  # (T_pre - 1, N_co)
        if deltas.size > 1:
            return float(np.std(deltas, ddof=1))
    return 0.0
1284
+
1285
+
1286
def _compute_regularization(
    Y_pre_control: np.ndarray,
    n_treated: int,
    n_post: int,
) -> tuple:
    """Derive the SDID regularization strengths the way R's synthdid does.

    Both parameters scale the noise level ``sigma`` of the control
    pre-period first differences:

    - ``zeta_omega = (n_treated * n_post) ** 0.25 * sigma``
    - ``zeta_lambda = 1e-6 * sigma``

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control unit pre-treatment outcomes, shape (n_pre, n_control).
    n_treated : int
        Number of treated units.
    n_post : int
        Number of post-treatment periods.

    Returns
    -------
    tuple
        (zeta_omega, zeta_lambda) regularization parameters.
    """
    noise = _compute_noise_level(Y_pre_control)
    omega_scale = (n_treated * n_post) ** 0.25
    lambda_scale = 1e-6
    zeta_omega = omega_scale * noise
    zeta_lambda = lambda_scale * noise
    return zeta_omega, zeta_lambda
1311
+
1312
+
1313
def _fw_step(
    A: np.ndarray,
    x: np.ndarray,
    b: np.ndarray,
    eta: float,
) -> np.ndarray:
    """Take one Frank-Wolfe step for the regularized least-squares problem.

    The objective is ``||A x - b||^2 + eta * ||x||^2`` over the simplex. The
    step moves toward the vertex with the smallest (half) gradient entry,
    using the exact line-search step clipped to [0, 1]. Matches R's
    ``fw.step()`` in synthdid's ``sc.weight.fw()``.

    Parameters
    ----------
    A : np.ndarray
        Matrix of shape (N, T0).
    x : np.ndarray
        Current weight vector of shape (T0,).
    b : np.ndarray
        Target vector of shape (N,).
    eta : float
        Regularization strength (N * zeta^2).

    Returns
    -------
    np.ndarray
        Updated weight vector on the simplex (a new array; ``x`` is not
        modified).
    """
    fitted = A @ x
    half_grad = A.T @ (fitted - b) + eta * x
    vertex = int(np.argmin(half_grad))

    # Direction from x toward the chosen simplex vertex e_vertex
    direction = -x
    direction[vertex] += 1.0
    if np.allclose(direction, 0.0):
        # Already (numerically) at the vertex
        return x.copy()

    err_dir = A[:, vertex] - fitted
    curvature = err_dir @ err_dir + eta * (direction @ direction)
    if curvature <= 0:
        return x.copy()

    gamma = float(np.clip(-(half_grad @ direction) / curvature, 0.0, 1.0))
    return x + gamma * direction
1353
+
1354
+
1355
def _sc_weight_fw(
    Y: np.ndarray,
    zeta: float,
    intercept: bool = True,
    init_weights: Optional[np.ndarray] = None,
    min_decrease: float = 1e-5,
    max_iter: int = 10000,
) -> np.ndarray:
    """Frank-Wolfe synthetic-control weights, dispatched to Rust or NumPy.

    Matches R's ``sc.weight.fw()`` from the synthdid package. Solves::

        min_{lambda on simplex} zeta^2 * ||lambda||^2
                                + (1/N) * ||A_centered @ lambda - b_centered||^2

    Parameters
    ----------
    Y : np.ndarray
        Matrix of shape (N, T0+1). Last column is the target (post-period
        mean or treated pre-period mean depending on context).
    zeta : float
        Regularization strength.
    intercept : bool, default True
        If True, column-center Y before optimization.
    init_weights : np.ndarray, optional
        Initial weights. If None, starts with uniform weights.
    min_decrease : float, default 1e-5
        Stop once the objective decreases by less than ``min_decrease**2``.
        R uses ``1e-5 * noise_level``; callers should pass that value.
    max_iter : int, default 10000
        Maximum number of iterations. Matches R's default.

    Returns
    -------
    np.ndarray
        Weights of shape (T0,) on the simplex.
    """
    if not HAS_RUST_BACKEND:
        return _sc_weight_fw_numpy(Y, zeta, intercept, init_weights, min_decrease, max_iter)

    y_contig = np.ascontiguousarray(Y, dtype=np.float64)
    init_contig = (
        None if init_weights is None else np.ascontiguousarray(init_weights, dtype=np.float64)
    )
    raw = _rust_sc_weight_fw(y_contig, zeta, intercept, init_contig, min_decrease, max_iter)
    return np.asarray(raw)
1409
+
1410
+
1411
def _sc_weight_fw_numpy(
    Y: np.ndarray,
    zeta: float,
    intercept: bool = True,
    init_weights: Optional[np.ndarray] = None,
    min_decrease: float = 1e-5,
    max_iter: int = 10000,
) -> np.ndarray:
    """Pure NumPy implementation of the Frank-Wolfe SC weight solver.

    Repeats ``_fw_step`` until the objective
    ``zeta^2 * ||lam||^2 + ||Y @ [lam, -1]||^2 / N`` decreases by less than
    ``min_decrease**2`` between consecutive iterations, or ``max_iter``
    iterations have run.

    Parameters
    ----------
    Y : np.ndarray
        Matrix of shape (N, T0+1); the last column is the target.
    zeta : float
        Regularization strength.
    intercept : bool, default True
        If True, column-center Y before optimization.
    init_weights : np.ndarray, optional
        Starting weights (copied and cast to float64, as the Rust dispatcher
        does). If None, uniform weights are used.
    min_decrease : float, default 1e-5
        Convergence threshold (compared against its square).
    max_iter : int, default 10000
        Maximum number of Frank-Wolfe iterations.

    Returns
    -------
    np.ndarray
        Weights of shape (T0,). For degenerate input with T0 <= 0 a length-1
        vector of ones is returned (kept for backward compatibility).
    """
    T0 = Y.shape[1] - 1
    N = Y.shape[0]

    if T0 <= 0:
        return np.ones(max(T0, 1))

    # Column-center if using intercept (matches R's intercept=TRUE default);
    # rebinding Y leaves the caller's array untouched.
    if intercept:
        Y = Y - Y.mean(axis=0)

    A = Y[:, :T0]
    b = Y[:, T0]
    eta = N * zeta**2

    if init_weights is not None:
        # Private float64 copy: _fw_step returns new arrays, but the caller's
        # init vector must never alias the iterate.
        lam = np.array(init_weights, dtype=np.float64)
    else:
        lam = np.ones(T0) / T0

    # Only the previous objective value is needed for the stopping rule,
    # so no max_iter-sized history buffer is kept.
    prev_val = np.nan
    for t in range(max_iter):
        lam = _fw_step(A, lam, b, eta)
        err = Y @ np.append(lam, -1.0)
        val = zeta**2 * np.sum(lam**2) + np.sum(err**2) / N
        if t >= 1 and prev_val - val < min_decrease**2:
            break
        prev_val = val

    return lam
1448
+
1449
+
1450
def _sparsify(v: np.ndarray) -> np.ndarray:
    """Drop weak weights and renormalize.

    Entries at or below a quarter of the largest entry are set to zero and the
    result is rescaled to sum to 1. If the largest entry is not positive,
    uniform weights are returned. Matches R's synthdid ``sparsify_function``:
    ``v[v <= max(v)/4] = 0; v = v / sum(v)``.

    Parameters
    ----------
    v : np.ndarray
        Weight vector (not modified; a copy is sparsified).

    Returns
    -------
    np.ndarray
        Sparsified weight vector summing to 1.
    """
    out = v.copy()
    peak = np.max(out)
    if peak <= 0:
        return np.ones(len(out)) / len(out)
    cutoff = peak / 4
    out[out <= cutoff] = 0.0
    return _sum_normalize(out)
1472
+
1473
+
1474
def compute_time_weights(
    Y_pre_control: np.ndarray,
    Y_post_control: np.ndarray,
    zeta_lambda: float,
    intercept: bool = True,
    min_decrease: float = 1e-5,
    max_iter_pre_sparsify: int = 100,
    max_iter: int = 10000,
) -> np.ndarray:
    """Compute SDID time weights via Frank-Wolfe optimization.

    Matches R's ``synthdid::sc.weight.fw(Yc[1:N0, ], zeta=zeta.lambda,
    intercept=TRUE)`` where ``Yc`` is the collapsed-form matrix. Uses
    two-pass optimization with sparsification (same as unit weights),
    matching R's default ``sparsify=sparsify_function``.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_post_control : np.ndarray
        Control outcomes in post-treatment periods, shape (n_post, n_control).
    zeta_lambda : float
        Regularization parameter for time weights.
    intercept : bool, default True
        If True, column-center the optimization matrix.
    min_decrease : float, default 1e-5
        Convergence criterion for Frank-Wolfe. R uses ``1e-5 * noise_level``.
    max_iter_pre_sparsify : int, default 100
        Iterations for first pass (before sparsification).
    max_iter : int, default 10000
        Maximum iterations for second pass (after sparsification).
        Matches R's default.

    Returns
    -------
    np.ndarray
        Time weights of shape (n_pre,) on the simplex. With zero pre periods
        an empty array is returned; with one, ``[1.0]``.

    Raises
    ------
    ValueError
        If ``Y_post_control`` has no rows.
    """
    if Y_post_control.shape[0] == 0:
        raise ValueError(
            "Y_post_control has no rows. At least one post-treatment period "
            "is required for time weight computation."
        )

    n_pre = Y_pre_control.shape[0]

    # Trivial cases are resolved before backend dispatch (as
    # compute_sdid_unit_weights does for n_control <= 1), so both backends
    # agree: no pre periods -> empty, one pre period -> weight 1.
    if n_pre <= 1:
        return np.ones(n_pre)

    if HAS_RUST_BACKEND:
        return np.asarray(
            _rust_compute_time_weights(
                np.ascontiguousarray(Y_pre_control, dtype=np.float64),
                np.ascontiguousarray(Y_post_control, dtype=np.float64),
                zeta_lambda,
                intercept,
                min_decrease,
                max_iter_pre_sparsify,
                max_iter,
            )
        )

    # Build collapsed form: (N_co, T_pre + 1), last col = per-control post mean
    post_means = np.mean(Y_post_control, axis=0)  # (N_co,)
    Y_time = np.column_stack([Y_pre_control.T, post_means])  # (N_co, T_pre+1)

    # First pass: limited iterations (matching R's max.iter.pre.sparsify)
    lam = _sc_weight_fw(
        Y_time,
        zeta=zeta_lambda,
        intercept=intercept,
        min_decrease=min_decrease,
        max_iter=max_iter_pre_sparsify,
    )

    # Sparsify: zero out small weights, renormalize (R's sparsify_function)
    lam = _sparsify(lam)

    # Second pass: from sparsified initialization (matching R's max.iter)
    return _sc_weight_fw(
        Y_time,
        zeta=zeta_lambda,
        intercept=intercept,
        init_weights=lam,
        min_decrease=min_decrease,
        max_iter=max_iter,
    )
1564
+
1565
+
1566
def compute_sdid_unit_weights(
    Y_pre_control: np.ndarray,
    Y_pre_treated_mean: np.ndarray,
    zeta_omega: float,
    intercept: bool = True,
    min_decrease: float = 1e-5,
    max_iter_pre_sparsify: int = 100,
    max_iter: int = 10000,
) -> np.ndarray:
    """Compute SDID unit weights via Frank-Wolfe with two-pass sparsification.

    Matches R's ``synthdid::sc.weight.fw(t(Yc[, 1:T0]), zeta=zeta.omega,
    intercept=TRUE)`` followed by the sparsify/re-optimize pass. A short
    first solve is sparsified and then used to warm-start a full solve.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_pre_treated_mean : np.ndarray
        Mean treated outcomes in pre-treatment periods, shape (n_pre,).
    zeta_omega : float
        Regularization parameter for unit weights.
    intercept : bool, default True
        If True, column-center the optimization matrix.
    min_decrease : float, default 1e-5
        Convergence criterion for Frank-Wolfe. R uses ``1e-5 * noise_level``.
    max_iter_pre_sparsify : int, default 100
        Iterations for first pass (before sparsification).
    max_iter : int, default 10000
        Iterations for second pass (after sparsification). Matches R's default.

    Returns
    -------
    np.ndarray
        Unit weights of shape (n_control,) on the simplex.
    """
    n_control = Y_pre_control.shape[1]

    # Zero or one control unit: the answer is forced, no optimization needed
    if n_control < 2:
        return np.asarray([1.0] if n_control == 1 else [])

    if HAS_RUST_BACKEND:
        raw = _rust_sdid_unit_weights(
            np.ascontiguousarray(Y_pre_control, dtype=np.float64),
            np.ascontiguousarray(Y_pre_treated_mean, dtype=np.float64),
            zeta_omega,
            intercept,
            min_decrease,
            max_iter_pre_sparsify,
            max_iter,
        )
        return np.asarray(raw)

    # Collapsed form (T_pre, N_co + 1): last column is the treated pre-period mean
    target_col = Y_pre_treated_mean.reshape(-1, 1)
    Y_unit = np.column_stack([Y_pre_control, target_col])

    solver_opts = dict(zeta=zeta_omega, intercept=intercept, min_decrease=min_decrease)
    rough = _sc_weight_fw(Y_unit, max_iter=max_iter_pre_sparsify, **solver_opts)
    warm_start = _sparsify(rough)
    return _sc_weight_fw(Y_unit, init_weights=warm_start, max_iter=max_iter, **solver_opts)
1648
+
1649
+
1650
def compute_sdid_estimator(
    Y_pre_control: np.ndarray,
    Y_post_control: np.ndarray,
    Y_pre_treated: np.ndarray,
    Y_post_treated: np.ndarray,
    unit_weights: np.ndarray,
    time_weights: np.ndarray,
) -> float:
    """
    Compute the Synthetic DiD estimator.

    Parameters
    ----------
    Y_pre_control : np.ndarray
        Control outcomes in pre-treatment periods, shape (n_pre, n_control).
    Y_post_control : np.ndarray
        Control outcomes in post-treatment periods, shape (n_post, n_control).
    Y_pre_treated : np.ndarray
        Treated outcomes in pre-treatment periods, shape (n_pre,). A 2-D
        array of shape (n_pre, n_treated) is also accepted and is averaged
        across treated units first.
    Y_post_treated : np.ndarray
        Treated outcomes in post-treatment periods, shape (n_post,) (or
        (n_post, n_treated); all entries are averaged).
    unit_weights : np.ndarray
        Weights for control units, shape (n_control,).
    time_weights : np.ndarray
        Weights for pre-treatment periods, shape (n_pre,).

    Returns
    -------
    float
        The synthetic DiD treatment effect estimate.

    Notes
    -----
    The SDID estimator is:

        τ̂ = (Ȳ_treated,post - Σ_t λ_t * Y_treated,t)
            - Σ_j ω_j * (Ȳ_j,post - Σ_t λ_t * Y_j,t)

    Where:
    - ω_j are unit weights
    - λ_t are time weights
    - Ȳ denotes average over post periods
    """
    Y_pre_treated = np.asarray(Y_pre_treated)
    if Y_pre_treated.ndim == 2:
        # (n_pre, n_treated) -> treated mean path. By linearity this equals
        # averaging each treated unit's time-weighted pre-period value, and it
        # keeps the treated term scalar (a vector would break float(tau)).
        Y_pre_treated = Y_pre_treated.mean(axis=1)

    # Time-weighted pre-treatment levels
    weighted_pre_control = time_weights @ Y_pre_control  # (n_control,)
    weighted_pre_treated = time_weights @ Y_pre_treated  # scalar

    # Post-treatment averages
    mean_post_control = np.mean(Y_post_control, axis=0)  # (n_control,)
    mean_post_treated = np.mean(Y_post_treated)  # scalar

    # Treated change: post mean minus weighted pre level
    did_treated = mean_post_treated - weighted_pre_treated

    # Synthetic-control change: sum over j of omega_j * (post_j - weighted_pre_j)
    did_control = unit_weights @ (mean_post_control - weighted_pre_control)

    return float(did_treated - did_control)
1711
+
1712
+
1713
def demean_by_group(
    data: pd.DataFrame,
    variables: List[str],
    group_var: str,
    inplace: bool = False,
    suffix: str = "",
    weights: Optional[np.ndarray] = None,
) -> Tuple[pd.DataFrame, int]:
    """
    Subtract group means from each variable (one-way within transformation).

    Each value becomes ``x_ig - mean(x_g)``, where ``g`` is the row's group.
    With ``weights``, the group mean is ``sum(w_i * x_i) / sum(w_i)`` over
    the rows of group ``g``.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame containing the variables to demean.
    variables : list of str
        Column names to demean.
    group_var : str
        Column name for the grouping variable.
    inplace : bool, default False
        If True, ``data`` itself is modified. If False, a copy is modified
        and returned, leaving ``data`` unchanged.
    suffix : str, default ""
        When non-empty, demeaned values are written to ``f"{var}{suffix}"``
        and the original columns are kept; when empty, the original columns
        are overwritten.
    weights : np.ndarray, optional
        Observation weights for weighted group means (positionally aligned
        with the rows of ``data``).

    Returns
    -------
    data : pd.DataFrame
        DataFrame with demeaned variables.
    n_effects : int
        Number of absorbed fixed effects (nunique - 1).

    Examples
    --------
    >>> df, n_fe = demean_by_group(df, ['y', 'x1', 'x2'], 'unit')
    >>> # df['y'], df['x1'], df['x2'] are now demeaned by unit
    """
    frame = data if inplace else data.copy()

    # One category is the reference level, hence the -1
    n_effects = frame[group_var].nunique() - 1

    def _target(name: str) -> str:
        return f"{name}{suffix}" if suffix else name

    if weights is None:
        # Reuse a single groupby object across all variables
        by_group = frame.groupby(group_var, sort=False)
        for name in variables:
            frame[_target(name)] = frame[name] - by_group[name].transform("mean")
    else:
        keys = frame[group_var].values
        w_arr = np.asarray(weights, dtype=np.float64)
        # Per-group weight totals do not depend on the variable; compute once
        denom = pd.Series(w_arr).groupby(keys).transform("sum")
        for name in variables:
            values = frame[name].values.astype(np.float64)
            numer = pd.Series(w_arr * values).groupby(keys).transform("sum")
            frame[_target(name)] = values - (numer / denom).values

    return frame, n_effects
1784
+
1785
+
1786
def within_transform(
    data: pd.DataFrame,
    variables: List[str],
    unit: str,
    time: str,
    inplace: bool = False,
    suffix: str = "_demeaned",
    weights: Optional[np.ndarray] = None,
) -> pd.DataFrame:
    """
    Apply two-way within transformation to remove unit and time fixed effects.

    Computes: y_it - y_i. - y_.t + y_.. for each variable.
    When weights are provided, uses weighted group means at each step.

    This is the standard fixed effects transformation for panel data that
    removes both unit-specific and time-specific effects.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data containing the variables to transform.
    variables : list of str
        Column names to transform.
    unit : str
        Column name for unit identifier.
    time : str
        Column name for time period identifier.
    inplace : bool, default False
        If True, overwrites the original columns of ``data`` (the passed
        DataFrame is mutated and returned). If False, works on a copy and
        writes the results to columns named ``f"{var}{suffix}"``.
    suffix : str, default "_demeaned"
        Suffix for new column names when inplace=False. If a resulting
        column name already exists (e.g. ``suffix=""`` or a repeated call),
        that column is overwritten rather than duplicated.
    weights : np.ndarray, optional
        Observation weights for weighted group means, one per row of
        ``data``.

    Returns
    -------
    pd.DataFrame
        DataFrame with within-transformed variables.

    Notes
    -----
    The within transformation removes variation that is constant within units
    (unit fixed effects) and constant within time periods (time fixed effects).
    The resulting estimates are equivalent to including unit and time dummies,
    but the transformation is computationally more efficient for large panels.

    Without weights, the closed form above is applied in a single pass; it
    removes both sets of effects exactly only when the panel is balanced.
    With weights, weighted unit and time demeaning are alternated until the
    largest absolute change in a sweep falls below 1e-8, for at most 100
    sweeps.

    Examples
    --------
    >>> df = within_transform(df, ['y', 'x'], 'unit_id', 'year')
    >>> # df now has 'y_demeaned' and 'x_demeaned' columns
    """
    if not inplace:
        data = data.copy()

    max_iter = 100  # upper bound on alternating-projection sweeps
    tol = 1e-8  # absolute tolerance on the largest per-sweep change

    if weights is not None:
        # Weighted within-transformation via iterative alternating projections
        w = np.asarray(weights, dtype=np.float64)
        unit_groups = data[unit].values
        time_groups = data[time].values

        # Weight sums per group do not depend on the variable; compute once.
        unit_w_sum = pd.Series(w).groupby(unit_groups).transform("sum").values
        time_w_sum = pd.Series(w).groupby(time_groups).transform("sum").values

        def _weighted_group_demean(x, groups, w_sum):
            # Subtract each row's weighted group mean sum(w*x)/sum(w).
            wx_sum = pd.Series(w * x).groupby(groups).transform("sum").values
            return x - wx_sum / w_sum

        def _demean(var):
            # Alternate unit and time demeaning until the sweep stops moving.
            x = data[var].values.astype(np.float64)
            for _ in range(max_iter):
                # Safe without copying: each step returns a new array.
                x_old = x
                x = _weighted_group_demean(x, unit_groups, unit_w_sum)
                x = _weighted_group_demean(x, time_groups, time_w_sum)
                # np.max raises on a zero-size array, so an empty input
                # must be treated as already converged.
                if x.size == 0 or np.max(np.abs(x - x_old)) < tol:
                    break
            return x

    else:
        # Cache groupby objects for efficiency
        unit_grouper = data.groupby(unit, sort=False)
        time_grouper = data.groupby(time, sort=False)

        def _demean(var):
            # Closed form y_it - y_i. - y_.t + y_.. (one pass).
            unit_means = unit_grouper[var].transform("mean")
            time_means = time_grouper[var].transform("mean")
            grand_mean = data[var].mean()
            return (data[var] - unit_means - time_means + grand_mean).values

    if inplace:
        # Assign sequentially so each variable reads the frame as it stands.
        for var in variables:
            data[var] = _demean(var)
        return data

    demeaned_df = pd.DataFrame(
        {f"{var}{suffix}": _demean(var) for var in variables},
        index=data.index,
    )
    # Overwrite target columns that already exist instead of letting concat
    # produce duplicate column labels.
    existing = [col for col in demeaned_df.columns if col in data.columns]
    for col in existing:
        data[col] = demeaned_df[col].values
    new_columns = demeaned_df.drop(columns=existing)
    if len(new_columns.columns) > 0:
        # A single concat avoids fragmenting the frame column by column.
        data = pd.concat([data, new_columns], axis=1)
    return data