diff-diff 2.8.3__tar.gz → 2.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {diff_diff-2.8.3 → diff_diff-2.9.0}/PKG-INFO +3 -2
  2. {diff_diff-2.8.3 → diff_diff-2.9.0}/README.md +2 -1
  3. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/__init__.py +10 -1
  4. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/bootstrap_utils.py +84 -14
  5. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/continuous_did_results.py +15 -1
  6. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/efficient_did.py +107 -25
  7. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/efficient_did_bootstrap.py +9 -1
  8. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/efficient_did_covariates.py +102 -17
  9. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/efficient_did_results.py +15 -1
  10. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/imputation.py +247 -72
  11. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/imputation_bootstrap.py +6 -4
  12. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/imputation_results.py +17 -2
  13. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/linalg.py +156 -24
  14. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/prep.py +75 -0
  15. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/prep_dgp.py +16 -1
  16. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/results.py +39 -0
  17. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/stacked_did_results.py +15 -1
  18. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/staggered_results.py +25 -13
  19. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/staggered_triple_diff_results.py +20 -17
  20. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/sun_abraham.py +57 -16
  21. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/survey.py +101 -9
  22. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/trop_results.py +13 -0
  23. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/two_stage.py +69 -15
  24. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/two_stage_results.py +15 -1
  25. diff_diff-2.9.0/diff_diff/wooldridge.py +1090 -0
  26. diff_diff-2.9.0/diff_diff/wooldridge_results.py +341 -0
  27. {diff_diff-2.8.3 → diff_diff-2.9.0}/pyproject.toml +2 -1
  28. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/Cargo.lock +3 -3
  29. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/Cargo.toml +1 -1
  30. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/_backend.py +0 -0
  31. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/bacon.py +0 -0
  32. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/continuous_did.py +0 -0
  33. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/continuous_did_bspline.py +0 -0
  34. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/datasets.py +0 -0
  35. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/diagnostics.py +0 -0
  36. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/efficient_did_weights.py +0 -0
  37. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/estimators.py +0 -0
  38. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/honest_did.py +0 -0
  39. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/power.py +0 -0
  40. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/practitioner.py +0 -0
  41. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/pretrends.py +0 -0
  42. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/stacked_did.py +0 -0
  43. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/staggered.py +0 -0
  44. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/staggered_aggregation.py +0 -0
  45. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/staggered_bootstrap.py +0 -0
  46. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/staggered_triple_diff.py +0 -0
  47. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/synthetic_did.py +0 -0
  48. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/triple_diff.py +0 -0
  49. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/trop.py +0 -0
  50. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/trop_global.py +0 -0
  51. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/trop_local.py +0 -0
  52. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/twfe.py +0 -0
  53. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/two_stage_bootstrap.py +0 -0
  54. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/utils.py +0 -0
  55. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/__init__.py +0 -0
  56. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_common.py +0 -0
  57. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_continuous.py +0 -0
  58. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_diagnostic.py +0 -0
  59. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_event_study.py +0 -0
  60. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_power.py +0 -0
  61. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_staggered.py +0 -0
  62. {diff_diff-2.8.3 → diff_diff-2.9.0}/diff_diff/visualization/_synthetic.py +0 -0
  63. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/build.rs +0 -0
  64. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/src/bootstrap.rs +0 -0
  65. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/src/lib.rs +0 -0
  66. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/src/linalg.rs +0 -0
  67. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/src/trop.rs +0 -0
  68. {diff_diff-2.8.3 → diff_diff-2.9.0}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.8.3
3
+ Version: 2.9.0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -134,7 +134,7 @@ Detailed guide: [`docs/llms-practitioner.txt`](docs/llms-practitioner.txt)
134
134
  - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights
135
135
  - **Panel data support**: Two-way fixed effects estimator for panel designs
136
136
  - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects
137
- - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), and Efficient DiD (Chen, Sant'Anna & Xie 2025) estimators for heterogeneous treatment timing
137
+ - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), Efficient DiD (Chen, Sant'Anna & Xie 2025), and Wooldridge ETWFE (2021/2023) estimators for heterogeneous treatment timing
138
138
  - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling
139
139
  - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness
140
140
  - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025)
@@ -167,6 +167,7 @@ All estimators have short aliases for convenience:
167
167
  | `Stacked` | `StackedDiD` | Stacked DiD |
168
168
  | `Bacon` | `BaconDecomposition` | Goodman-Bacon decomposition |
169
169
  | `EDiD` | `EfficientDiD` | Efficient DiD |
170
+ | `ETWFE` | `WooldridgeDiD` | Wooldridge ETWFE (2021/2023) |
170
171
 
171
172
  `TROP` already uses its short canonical name and needs no alias.
172
173
 
@@ -84,7 +84,7 @@ Detailed guide: [`docs/llms-practitioner.txt`](docs/llms-practitioner.txt)
84
84
  - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights
85
85
  - **Panel data support**: Two-way fixed effects estimator for panel designs
86
86
  - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects
87
- - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), and Efficient DiD (Chen, Sant'Anna & Xie 2025) estimators for heterogeneous treatment timing
87
+ - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), Efficient DiD (Chen, Sant'Anna & Xie 2025), and Wooldridge ETWFE (2021/2023) estimators for heterogeneous treatment timing
88
88
  - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling
89
89
  - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness
90
90
  - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025)
@@ -117,6 +117,7 @@ All estimators have short aliases for convenience:
117
117
  | `Stacked` | `StackedDiD` | Stacked DiD |
118
118
  | `Bacon` | `BaconDecomposition` | Goodman-Bacon decomposition |
119
119
  | `EDiD` | `EfficientDiD` | Efficient DiD |
120
+ | `ETWFE` | `WooldridgeDiD` | Wooldridge ETWFE (2021/2023) |
120
121
 
121
122
  `TROP` already uses its short canonical name and needs no alias.
122
123
 
@@ -94,6 +94,7 @@ from diff_diff.prep import (
94
94
  make_treatment_indicator,
95
95
  rank_control_units,
96
96
  summarize_did_data,
97
+ trim_weights,
97
98
  validate_did_data,
98
99
  wide_to_long,
99
100
  )
@@ -163,6 +164,8 @@ from diff_diff.trop import (
163
164
  TROPResults,
164
165
  trop,
165
166
  )
167
+ from diff_diff.wooldridge import WooldridgeDiD
168
+ from diff_diff.wooldridge_results import WooldridgeDiDResults
166
169
  from diff_diff.utils import (
167
170
  WildBootstrapResults,
168
171
  check_parallel_trends,
@@ -209,8 +212,9 @@ SDDD = StaggeredTripleDifference
209
212
  Stacked = StackedDiD
210
213
  Bacon = BaconDecomposition
211
214
  EDiD = EfficientDiD
215
+ ETWFE = WooldridgeDiD
212
216
 
213
- __version__ = "2.8.3"
217
+ __version__ = "2.9.0"
214
218
  __all__ = [
215
219
  # Estimators
216
220
  "DifferenceInDifferences",
@@ -275,6 +279,10 @@ __all__ = [
275
279
  "EfficientDiDResults",
276
280
  "EDiDBootstrapResults",
277
281
  "EDiD",
282
+ # WooldridgeDiD (ETWFE)
283
+ "WooldridgeDiD",
284
+ "WooldridgeDiDResults",
285
+ "ETWFE",
278
286
  # Visualization
279
287
  "plot_bacon",
280
288
  "plot_event_study",
@@ -307,6 +315,7 @@ __all__ = [
307
315
  "make_post_indicator",
308
316
  "wide_to_long",
309
317
  "balance_panel",
318
+ "trim_weights",
310
319
  "validate_did_data",
311
320
  "summarize_did_data",
312
321
  "generate_did_data",
@@ -433,6 +433,10 @@ def generate_survey_multiplier_weights_batch(
433
433
  is present, weights are scaled by ``sqrt(1 - f_h)`` per stratum so
434
434
  the bootstrap variance matches the TSL variance.
435
435
 
436
+ For ``lonely_psu="adjust"``, singleton PSUs from different strata are
437
+ pooled into a combined pseudo-stratum and weights are generated for
438
+ the pooled group (no FPC scaling on pooled singletons).
439
+
436
440
  Parameters
437
441
  ----------
438
442
  n_bootstrap : int
@@ -454,11 +458,7 @@ def generate_survey_multiplier_weights_batch(
454
458
  psu = resolved_survey.psu
455
459
  strata = resolved_survey.strata
456
460
 
457
- if resolved_survey.lonely_psu == "adjust":
458
- raise NotImplementedError(
459
- "lonely_psu='adjust' is not yet supported for survey-aware bootstrap. "
460
- "Use lonely_psu='remove' or 'certainty', or use analytical inference."
461
- )
461
+ _lonely_psu = resolved_survey.lonely_psu
462
462
 
463
463
  if psu is None:
464
464
  # Each observation is its own PSU
@@ -499,6 +499,7 @@ def generate_survey_multiplier_weights_batch(
499
499
  psu_to_col = {int(p): i for i, p in enumerate(psu_ids)}
500
500
 
501
501
  unique_strata = np.unique(strata)
502
+ _singleton_cols = [] # For lonely_psu="adjust" pooling
502
503
  for h in unique_strata:
503
504
  mask_h = strata == h
504
505
 
@@ -511,8 +512,12 @@ def generate_survey_multiplier_weights_batch(
511
512
  cols = np.array([psu_to_col[int(p)] for p in psus_in_h])
512
513
 
513
514
  if n_h < 2:
514
- # Lonely PSU — zero weight (matches remove/certainty behavior)
515
- weights[:, cols] = 0.0
515
+ if _lonely_psu == "adjust":
516
+ # Collect for pooled pseudo-stratum processing
517
+ _singleton_cols.extend(cols.tolist())
518
+ else:
519
+ # remove / certainty — zero weight
520
+ weights[:, cols] = 0.0
516
521
  continue
517
522
 
518
523
  # Generate weights for this stratum
@@ -536,6 +541,31 @@ def generate_survey_multiplier_weights_batch(
536
541
 
537
542
  weights[:, cols] = stratum_weights
538
543
 
544
+ # Pool singleton PSUs into a pseudo-stratum for "adjust"
545
+ if _singleton_cols:
546
+ n_pooled = len(_singleton_cols)
547
+ if n_pooled >= 2:
548
+ pooled_weights = generate_bootstrap_weights_batch_numpy(
549
+ n_bootstrap, n_pooled, weight_type, rng
550
+ )
551
+ # No FPC scaling for pooled singletons (conservative)
552
+ pooled_cols = np.array(_singleton_cols)
553
+ weights[:, pooled_cols] = pooled_weights
554
+ else:
555
+ # Single singleton — cannot pool, zero weight (library-specific
556
+ # fallback; bootstrap adjust with one singleton = remove).
557
+ import warnings
558
+
559
+ warnings.warn(
560
+ "lonely_psu='adjust' with only 1 singleton stratum in "
561
+ "bootstrap: singleton PSU contributes zero variance "
562
+ "(same as 'remove'). At least 2 singleton strata are "
563
+ "needed for pooled pseudo-stratum bootstrap.",
564
+ UserWarning,
565
+ stacklevel=3,
566
+ )
567
+ weights[:, _singleton_cols[0]] = 0.0
568
+
539
569
  return weights, psu_ids
540
570
 
541
571
 
@@ -553,6 +583,9 @@ def generate_rao_wu_weights(
553
583
  With FPC: ``m_h = max(1, round((1 - f_h) * (n_h - 1)))``
554
584
  (Rao, Wu & Yue 1992, Section 3).
555
585
 
586
+ For ``lonely_psu="adjust"``, singleton PSUs are pooled into a combined
587
+ pseudo-stratum and resampled together (no FPC scaling on pooled group).
588
+
556
589
  Parameters
557
590
  ----------
558
591
  resolved_survey : ResolvedSurveyDesign
@@ -570,11 +603,7 @@ def generate_rao_wu_weights(
570
603
  psu = resolved_survey.psu
571
604
  strata = resolved_survey.strata
572
605
 
573
- if resolved_survey.lonely_psu == "adjust":
574
- raise NotImplementedError(
575
- "lonely_psu='adjust' is not yet supported for survey-aware bootstrap. "
576
- "Use lonely_psu='remove' or 'certainty', or use analytical inference."
577
- )
606
+ _lonely_psu_rw = resolved_survey.lonely_psu
578
607
 
579
608
  rescaled = np.zeros(n_obs, dtype=np.float64)
580
609
 
@@ -589,14 +618,20 @@ def generate_rao_wu_weights(
589
618
  unique_strata = np.unique(strata)
590
619
  strata_masks = [strata == h for h in unique_strata]
591
620
 
621
+ # Collect singleton PSUs for "adjust" pooling
622
+ _singleton_info = [] # list of (mask_h, unique_psu_h) tuples
623
+
592
624
  for mask_h in strata_masks:
593
625
  psu_h = obs_psu[mask_h]
594
626
  unique_psu_h = np.unique(psu_h)
595
627
  n_h = len(unique_psu_h)
596
628
 
597
629
  if n_h < 2:
598
- # Census / lonely PSU — keep original weights (zero variance)
599
- rescaled[mask_h] = base_weights[mask_h]
630
+ if _lonely_psu_rw == "adjust":
631
+ _singleton_info.append((mask_h, unique_psu_h))
632
+ else:
633
+ # remove / certainty — keep original weights (zero variance)
634
+ rescaled[mask_h] = base_weights[mask_h]
600
635
  continue
601
636
 
602
637
  # Compute resample size
@@ -629,6 +664,41 @@ def generate_rao_wu_weights(
629
664
  local_indices = np.array([psu_to_local[int(obs_psu[idx])] for idx in obs_in_h])
630
665
  rescaled[obs_in_h] = base_weights[obs_in_h] * scale_per_psu[local_indices]
631
666
 
667
+ # Pool singleton PSUs into a pseudo-stratum for "adjust"
668
+ if _singleton_info:
669
+ # Combine all singleton PSUs into one group
670
+ pooled_psus = np.concatenate([p for _, p in _singleton_info])
671
+ n_pooled = len(pooled_psus)
672
+
673
+ if n_pooled >= 2:
674
+ m_pooled = n_pooled - 1 # No FPC for pooled singletons
675
+ drawn = rng.choice(n_pooled, size=m_pooled, replace=True)
676
+ counts = np.bincount(drawn, minlength=n_pooled)
677
+ scale_per_psu = (n_pooled / m_pooled) * counts.astype(np.float64)
678
+
679
+ # Build PSU → scale mapping and apply
680
+ psu_scale_map = {int(pooled_psus[i]): scale_per_psu[i] for i in range(n_pooled)}
681
+ for mask_h, _ in _singleton_info:
682
+ obs_in_h = np.where(mask_h)[0]
683
+ for idx in obs_in_h:
684
+ p = int(obs_psu[idx])
685
+ rescaled[idx] = base_weights[idx] * psu_scale_map.get(p, 1.0)
686
+ else:
687
+ # Single singleton — cannot pool, keep base weights (library-specific
688
+ # fallback; bootstrap adjust with one singleton = remove).
689
+ import warnings
690
+
691
+ warnings.warn(
692
+ "lonely_psu='adjust' with only 1 singleton stratum in "
693
+ "bootstrap: singleton PSU contributes zero variance "
694
+ "(same as 'remove'). At least 2 singleton strata are "
695
+ "needed for pooled pseudo-stratum bootstrap.",
696
+ UserWarning,
697
+ stacklevel=2,
698
+ )
699
+ for mask_h, _ in _singleton_info:
700
+ rescaled[mask_h] = base_weights[mask_h]
701
+
632
702
  return rescaled
633
703
 
634
704
 
@@ -154,6 +154,15 @@ class ContinuousDiDResults:
154
154
  f"n_periods={len(self.time_periods)})"
155
155
  )
156
156
 
157
+ @property
158
+ def coef_var(self) -> float:
159
+ """Coefficient of variation: SE / |overall ATT|. NaN when ATT is 0 or SE non-finite."""
160
+ if not (np.isfinite(self.overall_att_se) and self.overall_att_se >= 0):
161
+ return np.nan
162
+ if not np.isfinite(self.overall_att) or self.overall_att == 0:
163
+ return np.nan
164
+ return self.overall_att_se / abs(self.overall_att)
165
+
157
166
  def summary(self, alpha: Optional[float] = None) -> str:
158
167
  """Generate formatted summary."""
159
168
  alpha = alpha or self.alpha
@@ -223,10 +232,15 @@ class ContinuousDiDResults:
223
232
  f"[{self.overall_att_conf_int[0]:.4f}, {self.overall_att_conf_int[1]:.4f}]",
224
233
  f"{conf_level}% CI for ACRT_glob: "
225
234
  f"[{self.overall_acrt_conf_int[0]:.4f}, {self.overall_acrt_conf_int[1]:.4f}]",
226
- "",
227
235
  ]
228
236
  )
229
237
 
238
+ cv = self.coef_var
239
+ if np.isfinite(cv):
240
+ lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")
241
+
242
+ lines.append("")
243
+
230
244
  # Dose-response curve summary (first/mid/last points)
231
245
  if len(self.dose_grid) > 0:
232
246
  lines.extend(
@@ -347,8 +347,6 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
347
347
  ValueError
348
348
  Missing columns, unbalanced panel, non-absorbing treatment,
349
349
  or PT-Post without a never-treated group.
350
- NotImplementedError
351
- If ``covariates`` and ``survey_design`` are both set.
352
350
  """
353
351
  self._validate_params()
354
352
 
@@ -381,16 +379,6 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
381
379
 
382
380
  # Bootstrap + survey supported via PSU-level multiplier bootstrap.
383
381
 
384
- # Guard covariates + survey (DR path does not yet thread survey weights)
385
- if covariates is not None and len(covariates) > 0 and resolved_survey is not None:
386
- raise NotImplementedError(
387
- "Survey weights with covariates are not yet supported for "
388
- "EfficientDiD. The doubly robust covariate path does not "
389
- "thread survey weights through nuisance estimation. "
390
- "Use covariates=None with survey_design, or drop survey_design "
391
- "when using covariates."
392
- )
393
-
394
382
  # Normalize empty covariates list to None (use nocov path)
395
383
  if covariates is not None and len(covariates) == 0:
396
384
  covariates = None
@@ -583,6 +571,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
583
571
  # Use the resolved survey's weights (already normalized per weight_type)
584
572
  # subset to unit level via _unit_first_panel_row (aligned to all_units)
585
573
  unit_level_weights = self._unit_resolved_survey.weights
574
+ self._unit_level_weights = unit_level_weights
586
575
 
587
576
  cohort_fractions: Dict[float, float] = {}
588
577
  if unit_level_weights is not None:
@@ -617,6 +606,15 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
617
606
  stacklevel=2,
618
607
  )
619
608
 
609
+ # Guard: never-treated with zero survey weight → no valid comparisons
610
+ # Applies to both covariates (DR nuisance) and nocov (weighted means) paths
611
+ if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
612
+ raise ValueError(
613
+ "Never-treated group has zero survey weight. EfficientDiD "
614
+ "requires a never-treated control group with positive "
615
+ "survey weight for estimation."
616
+ )
617
+
620
618
  # ----- Covariate preparation (if provided) -----
621
619
  covariate_matrix: Optional[np.ndarray] = None
622
620
  m_hat_cache: Dict[Tuple, np.ndarray] = {}
@@ -686,6 +684,15 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
686
684
  else:
687
685
  effective_p1_col = period_1_col
688
686
 
687
+ # Guard: skip cohorts with zero survey weight (all units zero-weighted)
688
+ if cohort_fractions[g] <= 0:
689
+ warnings.warn(
690
+ f"Cohort {g} has zero survey weight; skipping.",
691
+ UserWarning,
692
+ stacklevel=2,
693
+ )
694
+ continue
695
+
689
696
  # Estimate all (g, t) cells including pre-treatment. Under PT-Post,
690
697
  # pre-treatment cells serve as placebo/pre-trend diagnostics, matching
691
698
  # the CallawaySantAnna implementation. Users filter to t >= g for
@@ -707,6 +714,15 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
707
714
  anticipation=self.anticipation,
708
715
  )
709
716
 
717
+ # Filter out comparison pairs with zero survey weight
718
+ if unit_level_weights is not None and pairs:
719
+ pairs = [
720
+ (gp, tpre) for gp, tpre in pairs
721
+ if np.sum(unit_level_weights[
722
+ never_treated_mask if np.isinf(gp) else cohort_masks[gp]
723
+ ]) > 0
724
+ ]
725
+
710
726
  if not pairs:
711
727
  warnings.warn(
712
728
  f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
@@ -742,6 +758,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
742
758
  never_treated_mask,
743
759
  t_col_val,
744
760
  tpre_col_val,
761
+ unit_weights=unit_level_weights,
745
762
  )
746
763
  # m_{g', tpre, 1}(X)
747
764
  key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
@@ -755,6 +772,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
755
772
  gp_mask_for_reg,
756
773
  tpre_col_val,
757
774
  effective_p1_col,
775
+ unit_weights=unit_level_weights,
758
776
  )
759
777
  # r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
760
778
  for comp in {np.inf, gp}:
@@ -770,6 +788,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
770
788
  k_max=self.sieve_k_max,
771
789
  criterion=self.sieve_criterion,
772
790
  ratio_clip=self.ratio_clip,
791
+ unit_weights=unit_level_weights,
773
792
  )
774
793
 
775
794
  # Per-unit DR generated outcomes: shape (n_units, H)
@@ -801,6 +820,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
801
820
  group_mask_s,
802
821
  k_max=self.sieve_k_max,
803
822
  criterion=self.sieve_criterion,
823
+ unit_weights=unit_level_weights,
804
824
  )
805
825
 
806
826
  # Conditional Omega*(X) with per-unit propensities (Eq 3.12)
@@ -817,14 +837,19 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
817
837
  covariate_matrix=covariate_matrix,
818
838
  s_hat_cache=s_hat_cache,
819
839
  bandwidth=self.kernel_bandwidth,
840
+ unit_weights=unit_level_weights,
820
841
  )
821
842
 
822
843
  # Per-unit weights: (n_units, H)
823
844
  per_unit_w = compute_per_unit_weights(omega_cond)
824
845
 
825
- # ATT = mean_i( w(X_i) @ gen_out[i] )
846
+ # ATT = (survey-)weighted mean of per-unit DR scores
826
847
  if per_unit_w.shape[1] > 0:
827
- att_gt = float(np.mean(np.sum(per_unit_w * gen_out, axis=1)))
848
+ per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
849
+ if unit_level_weights is not None:
850
+ att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
851
+ else:
852
+ att_gt = float(np.mean(per_unit_scores))
828
853
  else:
829
854
  att_gt = np.nan
830
855
 
@@ -979,6 +1004,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
979
1004
  cluster_indices=unit_cluster_indices,
980
1005
  n_clusters=n_clusters,
981
1006
  resolved_survey=self._unit_resolved_survey,
1007
+ unit_level_weights=self._unit_level_weights,
982
1008
  )
983
1009
  # Update estimates with bootstrap inference
984
1010
  overall_se = bootstrap_results.overall_att_se
@@ -1140,6 +1166,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1140
1166
  unit_cohorts: np.ndarray,
1141
1167
  cohort_fractions: Dict[float, float],
1142
1168
  n_units: int,
1169
+ unit_weights: Optional[np.ndarray] = None,
1143
1170
  ) -> np.ndarray:
1144
1171
  """Compute weight influence function correction (O(1) scale, matching EIF).
1145
1172
 
@@ -1159,6 +1186,9 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1159
1186
  ``{cohort: n_cohort / n}`` for each cohort.
1160
1187
  n_units : int
1161
1188
  Total number of units.
1189
+ unit_weights : ndarray, shape (n_units,), optional
1190
+ Survey weights at the unit level. When provided, uses the
1191
+ survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).
1162
1192
 
1163
1193
  Returns
1164
1194
  -------
@@ -1172,10 +1202,19 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1172
1202
  return np.zeros(n_units)
1173
1203
 
1174
1204
  indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)
1175
- indicator_sum = np.sum(indicator - pg_keepers, axis=1)
1205
+
1206
+ if unit_weights is not None:
1207
+ # Survey-weighted WIF (matches staggered_aggregation.py:392-401):
1208
+ # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
1209
+ weighted_indicator = indicator * unit_weights[:, None]
1210
+ indicator_diff = weighted_indicator - pg_keepers
1211
+ indicator_sum = np.sum(indicator_diff, axis=1)
1212
+ else:
1213
+ indicator_diff = indicator - pg_keepers
1214
+ indicator_sum = np.sum(indicator_diff, axis=1)
1176
1215
 
1177
1216
  with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
1178
- if1 = (indicator - pg_keepers) / sum_pg
1217
+ if1 = indicator_diff / sum_pg
1179
1218
  if2 = np.outer(indicator_sum, pg_keepers) / sum_pg**2
1180
1219
  wif_matrix = if1 - if2
1181
1220
  wif_contrib = wif_matrix @ effects
@@ -1232,13 +1271,34 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1232
1271
 
1233
1272
  # WIF correction: accounts for uncertainty in cohort-size weights
1234
1273
  wif = self._compute_wif_contribution(
1235
- keepers, effects, unit_cohorts, cohort_fractions, n_units
1274
+ keepers, effects, unit_cohorts, cohort_fractions, n_units,
1275
+ unit_weights=self._unit_level_weights,
1236
1276
  )
1237
- agg_eif_total = agg_eif + wif # both O(1) scale
1277
+ # Compute SE: survey path uses score-level psi to avoid double-weighting
1278
+ # (compute_survey_vcov applies w_i internally, which would double-weight
1279
+ # the survey-weighted WIF term). Dispatch replicate vs TSL.
1280
+ if self._unit_resolved_survey is not None:
1281
+ uw = self._unit_level_weights
1282
+ total_w = float(np.sum(uw))
1283
+ psi_total = uw * agg_eif / total_w + wif / total_w
1284
+
1285
+ if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1286
+ and self._unit_resolved_survey.uses_replicate_variance):
1287
+ from diff_diff.survey import compute_replicate_if_variance
1288
+
1289
+ variance, _ = compute_replicate_if_variance(
1290
+ psi_total, self._unit_resolved_survey
1291
+ )
1292
+ else:
1293
+ from diff_diff.survey import compute_survey_if_variance
1238
1294
 
1239
- # SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
1240
- # (dispatches to survey TSL or cluster-robust when active)
1241
- se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
1295
+ variance = compute_survey_if_variance(
1296
+ psi_total, self._unit_resolved_survey
1297
+ )
1298
+ se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
1299
+ else:
1300
+ agg_eif_total = agg_eif + wif
1301
+ se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
1242
1302
 
1243
1303
  return overall_att, se
1244
1304
 
@@ -1324,15 +1384,37 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1324
1384
  agg_eif += w[k] * eif_by_gt[gt]
1325
1385
 
1326
1386
  # WIF correction for event-study aggregation
1387
+ wif_e = np.zeros(n_units)
1327
1388
  if unit_cohorts is not None:
1328
1389
  es_keepers = [(g, t) for (g, t) in gt_pairs]
1329
1390
  es_effects = effs
1330
- wif = self._compute_wif_contribution(
1331
- es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units
1391
+ wif_e = self._compute_wif_contribution(
1392
+ es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units,
1393
+ unit_weights=self._unit_level_weights,
1332
1394
  )
1333
- agg_eif = agg_eif + wif
1334
1395
 
1335
- agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
1396
+ if self._unit_resolved_survey is not None:
1397
+ uw = self._unit_level_weights
1398
+ total_w = float(np.sum(uw))
1399
+ psi_total = uw * agg_eif / total_w + wif_e / total_w
1400
+
1401
+ if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1402
+ and self._unit_resolved_survey.uses_replicate_variance):
1403
+ from diff_diff.survey import compute_replicate_if_variance
1404
+
1405
+ variance, _ = compute_replicate_if_variance(
1406
+ psi_total, self._unit_resolved_survey
1407
+ )
1408
+ else:
1409
+ from diff_diff.survey import compute_survey_if_variance
1410
+
1411
+ variance = compute_survey_if_variance(
1412
+ psi_total, self._unit_resolved_survey
1413
+ )
1414
+ agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
1415
+ else:
1416
+ agg_eif = agg_eif + wif_e
1417
+ agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
1336
1418
 
1337
1419
  t_stat, p_val, ci = safe_inference(
1338
1420
  agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
@@ -63,6 +63,7 @@ class EfficientDiDBootstrapMixin:
63
63
  cluster_indices: Optional[np.ndarray] = None,
64
64
  n_clusters: Optional[int] = None,
65
65
  resolved_survey: object = None,
66
+ unit_level_weights: Optional[np.ndarray] = None,
66
67
  ) -> EDiDBootstrapResults:
67
68
  """Run multiplier bootstrap on stored EIF values.
68
69
 
@@ -136,11 +137,18 @@ class EfficientDiDBootstrapMixin:
136
137
  original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])
137
138
 
138
139
  # Perturbed ATTs: (n_bootstrap, n_gt)
140
+ # Under survey design, perturb survey-score object w_i * eif_i / sum(w)
141
+ # to match the analytical variance convention (compute_survey_if_variance).
139
142
  bootstrap_atts = np.zeros((self.n_bootstrap, n_gt))
140
143
  for j, gt in enumerate(gt_pairs):
141
144
  eif_gt = eif_by_gt[gt] # shape (n_units,)
142
145
  with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
143
- perturbation = (all_weights @ eif_gt) / n_units
146
+ if unit_level_weights is not None:
147
+ total_w = float(np.sum(unit_level_weights))
148
+ eif_scaled = unit_level_weights * eif_gt / total_w
149
+ perturbation = all_weights @ eif_scaled
150
+ else:
151
+ perturbation = (all_weights @ eif_gt) / n_units
144
152
  bootstrap_atts[:, j] = original_atts[j] + perturbation
145
153
 
146
154
  # Post-treatment mask — also exclude NaN effects