diff-diff 2.8.4__tar.gz → 2.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {diff_diff-2.8.4 → diff_diff-2.9.0}/PKG-INFO +3 -2
  2. {diff_diff-2.8.4 → diff_diff-2.9.0}/README.md +2 -1
  3. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/__init__.py +8 -1
  4. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/efficient_did.py +107 -25
  5. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/efficient_did_bootstrap.py +9 -1
  6. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/efficient_did_covariates.py +102 -17
  7. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/linalg.py +156 -24
  8. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/survey.py +1 -2
  9. diff_diff-2.9.0/diff_diff/wooldridge.py +1090 -0
  10. diff_diff-2.9.0/diff_diff/wooldridge_results.py +341 -0
  11. {diff_diff-2.8.4 → diff_diff-2.9.0}/pyproject.toml +2 -1
  12. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/Cargo.lock +1 -1
  13. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/Cargo.toml +1 -1
  14. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/_backend.py +0 -0
  15. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/bacon.py +0 -0
  16. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/bootstrap_utils.py +0 -0
  17. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/continuous_did.py +0 -0
  18. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/continuous_did_bspline.py +0 -0
  19. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/continuous_did_results.py +0 -0
  20. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/datasets.py +0 -0
  21. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/diagnostics.py +0 -0
  22. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/efficient_did_results.py +0 -0
  23. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/efficient_did_weights.py +0 -0
  24. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/estimators.py +0 -0
  25. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/honest_did.py +0 -0
  26. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/imputation.py +0 -0
  27. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/imputation_bootstrap.py +0 -0
  28. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/imputation_results.py +0 -0
  29. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/power.py +0 -0
  30. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/practitioner.py +0 -0
  31. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/prep.py +0 -0
  32. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/prep_dgp.py +0 -0
  33. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/pretrends.py +0 -0
  34. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/results.py +0 -0
  35. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/stacked_did.py +0 -0
  36. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/stacked_did_results.py +0 -0
  37. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/staggered.py +0 -0
  38. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/staggered_aggregation.py +0 -0
  39. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/staggered_bootstrap.py +0 -0
  40. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/staggered_results.py +0 -0
  41. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/staggered_triple_diff.py +0 -0
  42. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/staggered_triple_diff_results.py +0 -0
  43. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/sun_abraham.py +0 -0
  44. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/synthetic_did.py +0 -0
  45. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/triple_diff.py +0 -0
  46. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/trop.py +0 -0
  47. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/trop_global.py +0 -0
  48. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/trop_local.py +0 -0
  49. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/trop_results.py +0 -0
  50. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/twfe.py +0 -0
  51. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/two_stage.py +0 -0
  52. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/two_stage_bootstrap.py +0 -0
  53. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/two_stage_results.py +0 -0
  54. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/utils.py +0 -0
  55. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/__init__.py +0 -0
  56. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_common.py +0 -0
  57. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_continuous.py +0 -0
  58. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_diagnostic.py +0 -0
  59. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_event_study.py +0 -0
  60. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_power.py +0 -0
  61. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_staggered.py +0 -0
  62. {diff_diff-2.8.4 → diff_diff-2.9.0}/diff_diff/visualization/_synthetic.py +0 -0
  63. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/build.rs +0 -0
  64. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/src/bootstrap.rs +0 -0
  65. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/src/lib.rs +0 -0
  66. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/src/linalg.rs +0 -0
  67. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/src/trop.rs +0 -0
  68. {diff_diff-2.8.4 → diff_diff-2.9.0}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.8.4
3
+ Version: 2.9.0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -134,7 +134,7 @@ Detailed guide: [`docs/llms-practitioner.txt`](docs/llms-practitioner.txt)
134
134
  - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights
135
135
  - **Panel data support**: Two-way fixed effects estimator for panel designs
136
136
  - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects
137
- - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), and Efficient DiD (Chen, Sant'Anna & Xie 2025) estimators for heterogeneous treatment timing
137
+ - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), Efficient DiD (Chen, Sant'Anna & Xie 2025), and Wooldridge ETWFE (2021/2023) estimators for heterogeneous treatment timing
138
138
  - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling
139
139
  - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness
140
140
  - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025)
@@ -167,6 +167,7 @@ All estimators have short aliases for convenience:
167
167
  | `Stacked` | `StackedDiD` | Stacked DiD |
168
168
  | `Bacon` | `BaconDecomposition` | Goodman-Bacon decomposition |
169
169
  | `EDiD` | `EfficientDiD` | Efficient DiD |
170
+ | `ETWFE` | `WooldridgeDiD` | Wooldridge ETWFE (2021/2023) |
170
171
 
171
172
  `TROP` already uses its short canonical name and needs no alias.
172
173
 
@@ -84,7 +84,7 @@ Detailed guide: [`docs/llms-practitioner.txt`](docs/llms-practitioner.txt)
84
84
  - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights
85
85
  - **Panel data support**: Two-way fixed effects estimator for panel designs
86
86
  - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects
87
- - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), and Efficient DiD (Chen, Sant'Anna & Xie 2025) estimators for heterogeneous treatment timing
87
+ - **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), Efficient DiD (Chen, Sant'Anna & Xie 2025), and Wooldridge ETWFE (2021/2023) estimators for heterogeneous treatment timing
88
88
  - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling
89
89
  - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness
90
90
  - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025)
@@ -117,6 +117,7 @@ All estimators have short aliases for convenience:
117
117
  | `Stacked` | `StackedDiD` | Stacked DiD |
118
118
  | `Bacon` | `BaconDecomposition` | Goodman-Bacon decomposition |
119
119
  | `EDiD` | `EfficientDiD` | Efficient DiD |
120
+ | `ETWFE` | `WooldridgeDiD` | Wooldridge ETWFE (2021/2023) |
120
121
 
121
122
  `TROP` already uses its short canonical name and needs no alias.
122
123
 
@@ -164,6 +164,8 @@ from diff_diff.trop import (
164
164
  TROPResults,
165
165
  trop,
166
166
  )
167
+ from diff_diff.wooldridge import WooldridgeDiD
168
+ from diff_diff.wooldridge_results import WooldridgeDiDResults
167
169
  from diff_diff.utils import (
168
170
  WildBootstrapResults,
169
171
  check_parallel_trends,
@@ -210,8 +212,9 @@ SDDD = StaggeredTripleDifference
210
212
  Stacked = StackedDiD
211
213
  Bacon = BaconDecomposition
212
214
  EDiD = EfficientDiD
215
+ ETWFE = WooldridgeDiD
213
216
 
214
- __version__ = "2.8.4"
217
+ __version__ = "2.9.0"
215
218
  __all__ = [
216
219
  # Estimators
217
220
  "DifferenceInDifferences",
@@ -276,6 +279,10 @@ __all__ = [
276
279
  "EfficientDiDResults",
277
280
  "EDiDBootstrapResults",
278
281
  "EDiD",
282
+ # WooldridgeDiD (ETWFE)
283
+ "WooldridgeDiD",
284
+ "WooldridgeDiDResults",
285
+ "ETWFE",
279
286
  # Visualization
280
287
  "plot_bacon",
281
288
  "plot_event_study",
@@ -347,8 +347,6 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
347
347
  ValueError
348
348
  Missing columns, unbalanced panel, non-absorbing treatment,
349
349
  or PT-Post without a never-treated group.
350
- NotImplementedError
351
- If ``covariates`` and ``survey_design`` are both set.
352
350
  """
353
351
  self._validate_params()
354
352
 
@@ -381,16 +379,6 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
381
379
 
382
380
  # Bootstrap + survey supported via PSU-level multiplier bootstrap.
383
381
 
384
- # Guard covariates + survey (DR path does not yet thread survey weights)
385
- if covariates is not None and len(covariates) > 0 and resolved_survey is not None:
386
- raise NotImplementedError(
387
- "Survey weights with covariates are not yet supported for "
388
- "EfficientDiD. The doubly robust covariate path does not "
389
- "thread survey weights through nuisance estimation. "
390
- "Use covariates=None with survey_design, or drop survey_design "
391
- "when using covariates."
392
- )
393
-
394
382
  # Normalize empty covariates list to None (use nocov path)
395
383
  if covariates is not None and len(covariates) == 0:
396
384
  covariates = None
@@ -583,6 +571,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
583
571
  # Use the resolved survey's weights (already normalized per weight_type)
584
572
  # subset to unit level via _unit_first_panel_row (aligned to all_units)
585
573
  unit_level_weights = self._unit_resolved_survey.weights
574
+ self._unit_level_weights = unit_level_weights
586
575
 
587
576
  cohort_fractions: Dict[float, float] = {}
588
577
  if unit_level_weights is not None:
@@ -617,6 +606,15 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
617
606
  stacklevel=2,
618
607
  )
619
608
 
609
+ # Guard: never-treated with zero survey weight → no valid comparisons
610
+ # Applies to both covariates (DR nuisance) and nocov (weighted means) paths
611
+ if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
612
+ raise ValueError(
613
+ "Never-treated group has zero survey weight. EfficientDiD "
614
+ "requires a never-treated control group with positive "
615
+ "survey weight for estimation."
616
+ )
617
+
620
618
  # ----- Covariate preparation (if provided) -----
621
619
  covariate_matrix: Optional[np.ndarray] = None
622
620
  m_hat_cache: Dict[Tuple, np.ndarray] = {}
@@ -686,6 +684,15 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
686
684
  else:
687
685
  effective_p1_col = period_1_col
688
686
 
687
+ # Guard: skip cohorts with zero survey weight (all units zero-weighted)
688
+ if cohort_fractions[g] <= 0:
689
+ warnings.warn(
690
+ f"Cohort {g} has zero survey weight; skipping.",
691
+ UserWarning,
692
+ stacklevel=2,
693
+ )
694
+ continue
695
+
689
696
  # Estimate all (g, t) cells including pre-treatment. Under PT-Post,
690
697
  # pre-treatment cells serve as placebo/pre-trend diagnostics, matching
691
698
  # the CallawaySantAnna implementation. Users filter to t >= g for
@@ -707,6 +714,15 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
707
714
  anticipation=self.anticipation,
708
715
  )
709
716
 
717
+ # Filter out comparison pairs with zero survey weight
718
+ if unit_level_weights is not None and pairs:
719
+ pairs = [
720
+ (gp, tpre) for gp, tpre in pairs
721
+ if np.sum(unit_level_weights[
722
+ never_treated_mask if np.isinf(gp) else cohort_masks[gp]
723
+ ]) > 0
724
+ ]
725
+
710
726
  if not pairs:
711
727
  warnings.warn(
712
728
  f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
@@ -742,6 +758,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
742
758
  never_treated_mask,
743
759
  t_col_val,
744
760
  tpre_col_val,
761
+ unit_weights=unit_level_weights,
745
762
  )
746
763
  # m_{g', tpre, 1}(X)
747
764
  key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
@@ -755,6 +772,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
755
772
  gp_mask_for_reg,
756
773
  tpre_col_val,
757
774
  effective_p1_col,
775
+ unit_weights=unit_level_weights,
758
776
  )
759
777
  # r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
760
778
  for comp in {np.inf, gp}:
@@ -770,6 +788,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
770
788
  k_max=self.sieve_k_max,
771
789
  criterion=self.sieve_criterion,
772
790
  ratio_clip=self.ratio_clip,
791
+ unit_weights=unit_level_weights,
773
792
  )
774
793
 
775
794
  # Per-unit DR generated outcomes: shape (n_units, H)
@@ -801,6 +820,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
801
820
  group_mask_s,
802
821
  k_max=self.sieve_k_max,
803
822
  criterion=self.sieve_criterion,
823
+ unit_weights=unit_level_weights,
804
824
  )
805
825
 
806
826
  # Conditional Omega*(X) with per-unit propensities (Eq 3.12)
@@ -817,14 +837,19 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
817
837
  covariate_matrix=covariate_matrix,
818
838
  s_hat_cache=s_hat_cache,
819
839
  bandwidth=self.kernel_bandwidth,
840
+ unit_weights=unit_level_weights,
820
841
  )
821
842
 
822
843
  # Per-unit weights: (n_units, H)
823
844
  per_unit_w = compute_per_unit_weights(omega_cond)
824
845
 
825
- # ATT = mean_i( w(X_i) @ gen_out[i] )
846
+ # ATT = (survey-)weighted mean of per-unit DR scores
826
847
  if per_unit_w.shape[1] > 0:
827
- att_gt = float(np.mean(np.sum(per_unit_w * gen_out, axis=1)))
848
+ per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
849
+ if unit_level_weights is not None:
850
+ att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
851
+ else:
852
+ att_gt = float(np.mean(per_unit_scores))
828
853
  else:
829
854
  att_gt = np.nan
830
855
 
@@ -979,6 +1004,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
979
1004
  cluster_indices=unit_cluster_indices,
980
1005
  n_clusters=n_clusters,
981
1006
  resolved_survey=self._unit_resolved_survey,
1007
+ unit_level_weights=self._unit_level_weights,
982
1008
  )
983
1009
  # Update estimates with bootstrap inference
984
1010
  overall_se = bootstrap_results.overall_att_se
@@ -1140,6 +1166,7 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1140
1166
  unit_cohorts: np.ndarray,
1141
1167
  cohort_fractions: Dict[float, float],
1142
1168
  n_units: int,
1169
+ unit_weights: Optional[np.ndarray] = None,
1143
1170
  ) -> np.ndarray:
1144
1171
  """Compute weight influence function correction (O(1) scale, matching EIF).
1145
1172
 
@@ -1159,6 +1186,9 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1159
1186
  ``{cohort: n_cohort / n}`` for each cohort.
1160
1187
  n_units : int
1161
1188
  Total number of units.
1189
+ unit_weights : ndarray, shape (n_units,), optional
1190
+ Survey weights at the unit level. When provided, uses the
1191
+ survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).
1162
1192
 
1163
1193
  Returns
1164
1194
  -------
@@ -1172,10 +1202,19 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1172
1202
  return np.zeros(n_units)
1173
1203
 
1174
1204
  indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)
1175
- indicator_sum = np.sum(indicator - pg_keepers, axis=1)
1205
+
1206
+ if unit_weights is not None:
1207
+ # Survey-weighted WIF (matches staggered_aggregation.py:392-401):
1208
+ # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
1209
+ weighted_indicator = indicator * unit_weights[:, None]
1210
+ indicator_diff = weighted_indicator - pg_keepers
1211
+ indicator_sum = np.sum(indicator_diff, axis=1)
1212
+ else:
1213
+ indicator_diff = indicator - pg_keepers
1214
+ indicator_sum = np.sum(indicator_diff, axis=1)
1176
1215
 
1177
1216
  with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
1178
- if1 = (indicator - pg_keepers) / sum_pg
1217
+ if1 = indicator_diff / sum_pg
1179
1218
  if2 = np.outer(indicator_sum, pg_keepers) / sum_pg**2
1180
1219
  wif_matrix = if1 - if2
1181
1220
  wif_contrib = wif_matrix @ effects
@@ -1232,13 +1271,34 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1232
1271
 
1233
1272
  # WIF correction: accounts for uncertainty in cohort-size weights
1234
1273
  wif = self._compute_wif_contribution(
1235
- keepers, effects, unit_cohorts, cohort_fractions, n_units
1274
+ keepers, effects, unit_cohorts, cohort_fractions, n_units,
1275
+ unit_weights=self._unit_level_weights,
1236
1276
  )
1237
- agg_eif_total = agg_eif + wif # both O(1) scale
1277
+ # Compute SE: survey path uses score-level psi to avoid double-weighting
1278
+ # (compute_survey_vcov applies w_i internally, which would double-weight
1279
+ # the survey-weighted WIF term). Dispatch replicate vs TSL.
1280
+ if self._unit_resolved_survey is not None:
1281
+ uw = self._unit_level_weights
1282
+ total_w = float(np.sum(uw))
1283
+ psi_total = uw * agg_eif / total_w + wif / total_w
1284
+
1285
+ if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1286
+ and self._unit_resolved_survey.uses_replicate_variance):
1287
+ from diff_diff.survey import compute_replicate_if_variance
1288
+
1289
+ variance, _ = compute_replicate_if_variance(
1290
+ psi_total, self._unit_resolved_survey
1291
+ )
1292
+ else:
1293
+ from diff_diff.survey import compute_survey_if_variance
1238
1294
 
1239
- # SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
1240
- # (dispatches to survey TSL or cluster-robust when active)
1241
- se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
1295
+ variance = compute_survey_if_variance(
1296
+ psi_total, self._unit_resolved_survey
1297
+ )
1298
+ se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
1299
+ else:
1300
+ agg_eif_total = agg_eif + wif
1301
+ se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
1242
1302
 
1243
1303
  return overall_att, se
1244
1304
 
@@ -1324,15 +1384,37 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
1324
1384
  agg_eif += w[k] * eif_by_gt[gt]
1325
1385
 
1326
1386
  # WIF correction for event-study aggregation
1387
+ wif_e = np.zeros(n_units)
1327
1388
  if unit_cohorts is not None:
1328
1389
  es_keepers = [(g, t) for (g, t) in gt_pairs]
1329
1390
  es_effects = effs
1330
- wif = self._compute_wif_contribution(
1331
- es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units
1391
+ wif_e = self._compute_wif_contribution(
1392
+ es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units,
1393
+ unit_weights=self._unit_level_weights,
1332
1394
  )
1333
- agg_eif = agg_eif + wif
1334
1395
 
1335
- agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
1396
+ if self._unit_resolved_survey is not None:
1397
+ uw = self._unit_level_weights
1398
+ total_w = float(np.sum(uw))
1399
+ psi_total = uw * agg_eif / total_w + wif_e / total_w
1400
+
1401
+ if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1402
+ and self._unit_resolved_survey.uses_replicate_variance):
1403
+ from diff_diff.survey import compute_replicate_if_variance
1404
+
1405
+ variance, _ = compute_replicate_if_variance(
1406
+ psi_total, self._unit_resolved_survey
1407
+ )
1408
+ else:
1409
+ from diff_diff.survey import compute_survey_if_variance
1410
+
1411
+ variance = compute_survey_if_variance(
1412
+ psi_total, self._unit_resolved_survey
1413
+ )
1414
+ agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
1415
+ else:
1416
+ agg_eif = agg_eif + wif_e
1417
+ agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
1336
1418
 
1337
1419
  t_stat, p_val, ci = safe_inference(
1338
1420
  agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
@@ -63,6 +63,7 @@ class EfficientDiDBootstrapMixin:
63
63
  cluster_indices: Optional[np.ndarray] = None,
64
64
  n_clusters: Optional[int] = None,
65
65
  resolved_survey: object = None,
66
+ unit_level_weights: Optional[np.ndarray] = None,
66
67
  ) -> EDiDBootstrapResults:
67
68
  """Run multiplier bootstrap on stored EIF values.
68
69
 
@@ -136,11 +137,18 @@ class EfficientDiDBootstrapMixin:
136
137
  original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])
137
138
 
138
139
  # Perturbed ATTs: (n_bootstrap, n_gt)
140
+ # Under survey design, perturb survey-score object w_i * eif_i / sum(w)
141
+ # to match the analytical variance convention (compute_survey_if_variance).
139
142
  bootstrap_atts = np.zeros((self.n_bootstrap, n_gt))
140
143
  for j, gt in enumerate(gt_pairs):
141
144
  eif_gt = eif_by_gt[gt] # shape (n_units,)
142
145
  with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
143
- perturbation = (all_weights @ eif_gt) / n_units
146
+ if unit_level_weights is not None:
147
+ total_w = float(np.sum(unit_level_weights))
148
+ eif_scaled = unit_level_weights * eif_gt / total_w
149
+ perturbation = all_weights @ eif_scaled
150
+ else:
151
+ perturbation = (all_weights @ eif_gt) / n_units
144
152
  bootstrap_atts[:, j] = original_atts[j] + perturbation
145
153
 
146
154
  # Post-treatment mask — also exclude NaN effects
@@ -37,6 +37,7 @@ def estimate_outcome_regression(
37
37
  group_mask: np.ndarray,
38
38
  t_col: int,
39
39
  tpre_col: int,
40
+ unit_weights: Optional[np.ndarray] = None,
40
41
  ) -> np.ndarray:
41
42
  """Estimate conditional mean outcome change m_hat(X) for a comparison group.
42
43
 
@@ -56,6 +57,9 @@ def estimate_outcome_regression(
56
57
  Mask selecting units in the comparison group.
57
58
  t_col, tpre_col : int
58
59
  Column indices in ``outcome_wide`` for the two time periods.
60
+ unit_weights : ndarray, shape (n_units,), optional
61
+ Survey weights at the unit level. When provided, uses WLS
62
+ instead of OLS for the within-group regression.
59
63
 
60
64
  Returns
61
65
  -------
@@ -68,9 +72,13 @@ def estimate_outcome_regression(
68
72
  X_group = covariate_matrix[group_mask]
69
73
  X_design = np.column_stack([np.ones(len(X_group)), X_group])
70
74
 
75
+ w_group = unit_weights[group_mask] if unit_weights is not None else None
76
+
71
77
  coef, _, _ = solve_ols(
72
78
  X_design,
73
79
  delta_y,
80
+ weights=w_group,
81
+ weight_type="pweight" if w_group is not None else None,
74
82
  return_vcov=False,
75
83
  rank_deficient_action="warn",
76
84
  )
@@ -121,7 +129,9 @@ def _polynomial_sieve_basis(X: np.ndarray, degree: int) -> np.ndarray:
121
129
  """
122
130
  n, d = X.shape
123
131
 
124
- # Standardize for numerical stability
132
+ # Standardize for numerical stability (unweighted mean/std intentional —
133
+ # this is only for conditioning, not for the statistical estimand; with
134
+ # survey weights the sieve basis is the same, only the objective changes)
125
135
  X_mean = X.mean(axis=0)
126
136
  X_std = X.std(axis=0)
127
137
  X_std[X_std < 1e-10] = 1.0 # avoid division by zero for constant columns
@@ -146,6 +156,7 @@ def estimate_propensity_ratio_sieve(
146
156
  k_max: Optional[int] = None,
147
157
  criterion: str = "bic",
148
158
  ratio_clip: float = 20.0,
159
+ unit_weights: Optional[np.ndarray] = None,
149
160
  ) -> np.ndarray:
150
161
  r"""Estimate propensity ratio via sieve convex minimization (Eq 4.1-4.2).
151
162
 
@@ -176,6 +187,9 @@ def estimate_propensity_ratio_sieve(
176
187
  ``"aic"`` or ``"bic"``.
177
188
  ratio_clip : float
178
189
  Clip ratios to ``[1/ratio_clip, ratio_clip]``.
190
+ unit_weights : ndarray, shape (n_units,), optional
191
+ Survey weights at the unit level. When provided, uses weighted
192
+ normal equations for the sieve estimation.
179
193
 
180
194
  Returns
181
195
  -------
@@ -197,9 +211,20 @@ def estimate_propensity_ratio_sieve(
197
211
  k_max = max(k_max, 1)
198
212
 
199
213
  # Penalty multiplier for IC
214
+ # BIC penalty uses observation count (not weighted) — complexity vs distinct obs
200
215
  n_total = int(np.sum(mask_g)) + n_gp
201
216
  c_n = 2.0 if criterion == "aic" else np.log(max(n_total, 2))
202
217
 
218
+ # Weighted totals for loss normalization (raw probability weights)
219
+ if unit_weights is not None:
220
+ w_g = unit_weights[mask_g]
221
+ w_gp = unit_weights[mask_gp]
222
+ n_total_w = float(np.sum(w_g)) + float(np.sum(w_gp))
223
+ else:
224
+ w_g = None
225
+ w_gp = None
226
+ n_total_w = float(n_total)
227
+
203
228
  best_ic = np.inf
204
229
  best_ratio = np.ones(n_units) # fallback: constant ratio 1
205
230
 
@@ -214,9 +239,15 @@ def estimate_propensity_ratio_sieve(
214
239
  Psi_gp = basis_all[mask_gp] # (n_gp, n_basis)
215
240
  Psi_g = basis_all[mask_g] # (n_g, n_basis)
216
241
 
217
- # Normal equations: (Psi_gp' Psi_gp) beta = Psi_g.sum(axis=0)
218
- A = Psi_gp.T @ Psi_gp
219
- b = Psi_g.sum(axis=0)
242
+ # Normal equations (weighted when survey weights present):
243
+ # Unweighted: (Psi_gp' Psi_gp) beta = Psi_g.sum(axis=0)
244
+ # Weighted: (Psi_gp' W_gp Psi_gp) beta = (w_g * Psi_g).sum(axis=0)
245
+ if w_gp is not None:
246
+ A = Psi_gp.T @ (w_gp[:, None] * Psi_gp)
247
+ b = (w_g[:, None] * Psi_g).sum(axis=0)
248
+ else:
249
+ A = Psi_gp.T @ Psi_gp
250
+ b = Psi_g.sum(axis=0)
220
251
 
221
252
  try:
222
253
  beta = np.linalg.solve(A, b)
@@ -230,11 +261,12 @@ def estimate_propensity_ratio_sieve(
230
261
  # Predicted ratio for all units
231
262
  r_hat = basis_all @ beta
232
263
 
233
- # IC selection: loss at optimum = -(1/n) * b'beta
234
- # Derivation: L(beta) = (1/n)(beta'A*beta - 2*b'beta).
264
+ # IC selection: loss at optimum = -(1/n_w) * b'beta
265
+ # Derivation: L(beta) = (1/n_w)(beta'A*beta - 2*b'beta).
235
266
  # At optimum A*beta = b, so beta'A*beta = b'beta.
236
- # Therefore L = (1/n)(b'beta - 2*b'beta) = -(1/n)*b'beta.
237
- loss = -float(b @ beta) / n_total
267
+ # Therefore L = (1/n_w)(b'beta - 2*b'beta) = -(1/n_w)*b'beta.
268
+ # Loss uses weighted totals; BIC penalty uses observation count.
269
+ loss = -float(b @ beta) / n_total_w
238
270
  ic_val = 2.0 * loss + c_n * n_basis / n_total
239
271
 
240
272
  if ic_val < best_ic:
@@ -280,6 +312,7 @@ def estimate_inverse_propensity_sieve(
280
312
  group_mask: np.ndarray,
281
313
  k_max: Optional[int] = None,
282
314
  criterion: str = "bic",
315
+ unit_weights: Optional[np.ndarray] = None,
283
316
  ) -> np.ndarray:
284
317
  r"""Estimate s_{g'}(X) = 1/p_{g'}(X) via sieve convex minimization.
285
318
 
@@ -305,6 +338,9 @@ def estimate_inverse_propensity_sieve(
305
338
  Maximum polynomial degree. None = auto.
306
339
  criterion : str
307
340
  ``"aic"`` or ``"bic"``.
341
+ unit_weights : ndarray, shape (n_units,), optional
342
+ Survey weights at the unit level. When provided, uses weighted
343
+ normal equations for the sieve estimation.
308
344
 
309
345
  Returns
310
346
  -------
@@ -322,10 +358,25 @@ def estimate_inverse_propensity_sieve(
322
358
  k_max = min(int(n_group**0.2), 5)
323
359
  k_max = max(k_max, 1)
324
360
 
361
+ # BIC penalty uses observation count (not weighted)
325
362
  c_n = 2.0 if criterion == "aic" else np.log(max(n_units, 2))
326
363
 
364
+ # Weighted loss normalization and fallback
365
+ if unit_weights is not None:
366
+ w_group = unit_weights[group_mask]
367
+ sum_w_group = float(np.sum(w_group))
368
+ if sum_w_group <= 0:
369
+ # Zero survey weight for this group — return unconditional fallback
370
+ return np.ones(n_units)
371
+ n_units_w = float(np.sum(unit_weights))
372
+ fallback_ratio = n_units_w / sum_w_group
373
+ else:
374
+ w_group = None
375
+ n_units_w = float(n_units)
376
+ fallback_ratio = n_units / n_group
377
+
327
378
  best_ic = np.inf
328
- best_s = np.full(n_units, n_units / n_group) # fallback: unconditional
379
+ best_s = np.full(n_units, fallback_ratio) # fallback: unconditional
329
380
 
330
381
  for K in range(1, k_max + 1):
331
382
  n_basis = comb(K + d, d)
@@ -335,9 +386,16 @@ def estimate_inverse_propensity_sieve(
335
386
  basis_all = _polynomial_sieve_basis(covariate_matrix, K)
336
387
  Psi_gp = basis_all[group_mask]
337
388
 
338
- A = Psi_gp.T @ Psi_gp
339
- # RHS: sum of basis over ALL units (not just one group)
340
- b = basis_all.sum(axis=0)
389
+ # Normal equations (weighted when survey weights present):
390
+ # Unweighted: (Psi_gp' Psi_gp) beta = Psi_all.sum(axis=0)
391
+ # Weighted: (Psi_gp' W_group Psi_gp) beta = (w_all * Psi_all).sum(axis=0)
392
+ if w_group is not None:
393
+ A = Psi_gp.T @ (w_group[:, None] * Psi_gp)
394
+ b = (unit_weights[:, None] * basis_all).sum(axis=0)
395
+ else:
396
+ A = Psi_gp.T @ Psi_gp
397
+ # RHS: sum of basis over ALL units (not just one group)
398
+ b = basis_all.sum(axis=0)
341
399
 
342
400
  try:
343
401
  beta = np.linalg.solve(A, b)
@@ -348,8 +406,9 @@ def estimate_inverse_propensity_sieve(
348
406
 
349
407
  s_hat = basis_all @ beta
350
408
 
351
- # IC: loss = -(1/n) * b'beta (same derivation as ratio estimator)
352
- loss = -float(b @ beta) / n_units
409
+ # IC: loss = -(1/n_w) * b'beta (same derivation as ratio estimator)
410
+ # Loss uses weighted totals; BIC penalty uses observation count.
411
+ loss = -float(b @ beta) / n_units_w
353
412
  ic_val = 2.0 * loss + c_n * n_basis / n_units
354
413
 
355
414
  if ic_val < best_ic:
@@ -433,6 +492,10 @@ def compute_generated_outcomes_cov(
433
492
  g_mask = cohort_masks[target_g]
434
493
  pi_g = cohort_fractions[target_g]
435
494
 
495
+ # Guard: zero survey weight for the target cohort → no DR estimation possible
496
+ if pi_g <= 0:
497
+ return np.zeros((n_units, H))
498
+
436
499
  gen_out = np.zeros((n_units, H))
437
500
 
438
501
  for j, (gp, tpre) in enumerate(valid_pairs):
@@ -496,6 +559,7 @@ def _kernel_weights_matrix(
496
559
  X_all: np.ndarray,
497
560
  X_group: np.ndarray,
498
561
  bandwidth: float,
562
+ group_weights: Optional[np.ndarray] = None,
499
563
  ) -> np.ndarray:
500
564
  """Gaussian kernel weight matrix.
501
565
 
@@ -503,11 +567,21 @@ def _kernel_weights_matrix(
503
567
  normalized kernel weight ``K_h(X_group[j], X_all[i])``.
504
568
 
505
569
  Each row sums to 1 (Nadaraya-Watson normalization).
570
+
571
+ Parameters
572
+ ----------
573
+ group_weights : ndarray, shape (n_group,), optional
574
+ Survey weights for the group units. When provided, kernel
575
+ weights are multiplied by survey weights before row-normalization,
576
+ making the Nadaraya-Watson estimator survey-weighted.
506
577
  """
507
578
  # Squared distances: (n_all, n_group)
508
579
  dist_sq = cdist(X_all, X_group, metric="sqeuclidean")
509
580
  # Gaussian kernel
510
581
  raw = np.exp(-dist_sq / (2.0 * bandwidth**2))
582
+ # Survey-weight: each group unit j contributes ∝ w_j * K_h(X_i, X_j)
583
+ if group_weights is not None:
584
+ raw = raw * group_weights[np.newaxis, :]
511
585
  # Normalize each row
512
586
  row_sums = raw.sum(axis=1, keepdims=True)
513
587
  row_sums[row_sums < 1e-15] = 1.0 # avoid division by zero
@@ -559,6 +633,7 @@ def compute_omega_star_conditional(
559
633
  covariate_matrix: np.ndarray,
560
634
  s_hat_cache: Dict[float, np.ndarray],
561
635
  bandwidth: Optional[float] = None,
636
+ unit_weights: Optional[np.ndarray] = None,
562
637
  never_treated_val: float = np.inf,
563
638
  ) -> np.ndarray:
564
639
  r"""Kernel-smoothed conditional Omega\*(X_i) for each unit (Eq 3.12).
@@ -583,6 +658,9 @@ def compute_omega_star_conditional(
583
658
  value is shape ``(n_units,)``. Keyed by group identifier.
584
659
  bandwidth : float or None
585
660
  Kernel bandwidth. None = Silverman's rule.
661
+ unit_weights : ndarray, shape (n_units,), optional
662
+ Survey weights at the unit level. When provided, kernel-smoothed
663
+ covariances use survey-weighted Nadaraya-Watson regression.
586
664
  never_treated_val : float
587
665
 
588
666
  Returns
@@ -622,13 +700,17 @@ def compute_omega_star_conditional(
622
700
  stacklevel=2,
623
701
  )
624
702
 
703
+ # Per-group survey weights for kernel smoothing
704
+ w_g = unit_weights[g_mask] if unit_weights is not None else None
705
+ w_inf = unit_weights[never_treated_mask] if unit_weights is not None else None
706
+
625
707
  # Pre-compute kernel weight matrices per group
626
708
  Y_g = outcome_wide[g_mask]
627
709
  X_g = covariate_matrix[g_mask]
628
710
  Yg_t_minus_1 = Y_g[:, t_col] - Y_g[:, y1_col]
629
711
 
630
- W_g = _kernel_weights_matrix(covariate_matrix, X_g, bandwidth)
631
- W_inf = _kernel_weights_matrix(covariate_matrix, X_inf, bandwidth)
712
+ W_g = _kernel_weights_matrix(covariate_matrix, X_g, bandwidth, group_weights=w_g)
713
+ W_inf = _kernel_weights_matrix(covariate_matrix, X_inf, bandwidth, group_weights=w_inf)
632
714
 
633
715
  inf_t_minus_tpre = {}
634
716
  for _, tpre in valid_pairs:
@@ -683,7 +765,10 @@ def compute_omega_star_conditional(
683
765
  )
684
766
  if gp_j not in W_gp_cache:
685
767
  X_gp = covariate_matrix[cohort_masks[gp_j]]
686
- W_gp_cache[gp_j] = _kernel_weights_matrix(covariate_matrix, X_gp, bandwidth)
768
+ w_gp_j = unit_weights[cohort_masks[gp_j]] if unit_weights is not None else None
769
+ W_gp_cache[gp_j] = _kernel_weights_matrix(
770
+ covariate_matrix, X_gp, bandwidth, group_weights=w_gp_j
771
+ )
687
772
  gp_outcomes_cache[gp_j] = outcome_wide[cohort_masks[gp_j]]
688
773
  W_gp = W_gp_cache[gp_j]
689
774
  Y_gp = gp_outcomes_cache[gp_j]