diff-diff 2.9.1__tar.gz → 3.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {diff_diff-2.9.1 → diff_diff-3.0.0}/PKG-INFO +1 -1
  2. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/__init__.py +1 -1
  3. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/linalg.py +64 -2
  4. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/staggered.py +1 -20
  5. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/staggered_bootstrap.py +4 -4
  6. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/staggered_triple_diff.py +1 -2
  7. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/synthetic_did.py +4 -3
  8. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/trop.py +4 -37
  9. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/wooldridge.py +267 -72
  10. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/wooldridge_results.py +9 -1
  11. {diff_diff-2.9.1 → diff_diff-3.0.0}/pyproject.toml +1 -1
  12. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/Cargo.lock +1 -1
  13. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/Cargo.toml +1 -1
  14. {diff_diff-2.9.1 → diff_diff-3.0.0}/README.md +0 -0
  15. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/_backend.py +0 -0
  16. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/bacon.py +0 -0
  17. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/bootstrap_utils.py +0 -0
  18. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/continuous_did.py +0 -0
  19. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/continuous_did_bspline.py +0 -0
  20. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/continuous_did_results.py +0 -0
  21. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/datasets.py +0 -0
  22. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/diagnostics.py +0 -0
  23. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/efficient_did.py +0 -0
  24. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/efficient_did_bootstrap.py +0 -0
  25. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/efficient_did_covariates.py +0 -0
  26. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/efficient_did_results.py +0 -0
  27. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/efficient_did_weights.py +0 -0
  28. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/estimators.py +0 -0
  29. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/honest_did.py +0 -0
  30. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/imputation.py +0 -0
  31. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/imputation_bootstrap.py +0 -0
  32. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/imputation_results.py +0 -0
  33. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/power.py +0 -0
  34. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/practitioner.py +0 -0
  35. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/prep.py +0 -0
  36. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/prep_dgp.py +0 -0
  37. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/pretrends.py +0 -0
  38. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/results.py +0 -0
  39. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/stacked_did.py +0 -0
  40. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/stacked_did_results.py +0 -0
  41. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/staggered_aggregation.py +0 -0
  42. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/staggered_results.py +0 -0
  43. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/staggered_triple_diff_results.py +0 -0
  44. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/sun_abraham.py +0 -0
  45. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/survey.py +0 -0
  46. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/triple_diff.py +0 -0
  47. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/trop_global.py +0 -0
  48. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/trop_local.py +0 -0
  49. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/trop_results.py +0 -0
  50. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/twfe.py +0 -0
  51. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/two_stage.py +0 -0
  52. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/two_stage_bootstrap.py +0 -0
  53. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/two_stage_results.py +0 -0
  54. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/utils.py +0 -0
  55. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/__init__.py +0 -0
  56. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_common.py +0 -0
  57. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_continuous.py +0 -0
  58. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_diagnostic.py +0 -0
  59. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_event_study.py +0 -0
  60. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_power.py +0 -0
  61. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_staggered.py +0 -0
  62. {diff_diff-2.9.1 → diff_diff-3.0.0}/diff_diff/visualization/_synthetic.py +0 -0
  63. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/build.rs +0 -0
  64. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/src/bootstrap.rs +0 -0
  65. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/src/lib.rs +0 -0
  66. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/src/linalg.rs +0 -0
  67. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/src/trop.rs +0 -0
  68. {diff_diff-2.9.1 → diff_diff-3.0.0}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.9.1
3
+ Version: 3.0.0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -214,7 +214,7 @@ Bacon = BaconDecomposition
214
214
  EDiD = EfficientDiD
215
215
  ETWFE = WooldridgeDiD
216
216
 
217
- __version__ = "2.9.1"
217
+ __version__ = "3.0.0"
218
218
  __all__ = [
219
219
  # Estimators
220
220
  "DifferenceInDifferences",
@@ -2372,6 +2372,7 @@ def solve_poisson(
2372
2372
  tol: float = 1e-8,
2373
2373
  init_beta: Optional[np.ndarray] = None,
2374
2374
  rank_deficient_action: str = "warn",
2375
+ weights: Optional[np.ndarray] = None,
2375
2376
  ) -> Tuple[np.ndarray, np.ndarray]:
2376
2377
  """Poisson IRLS (Newton-Raphson with log link).
2377
2378
 
@@ -2389,6 +2390,9 @@ def solve_poisson(
2389
2390
  log(mean(y)) to improve convergence for large-scale outcomes.
2390
2391
  rank_deficient_action : {"warn", "error", "silent"}
2391
2392
  How to handle rank-deficient design matrices. Mirrors solve_ols/solve_logit.
2393
+ weights : (n,) optional observation weights (e.g. survey sampling weights).
2394
+ When provided, the weighted pseudo-log-likelihood is maximised:
2395
+ score = X'(w*(y - mu)), Hessian = X'diag(w*mu)X.
2392
2396
 
2393
2397
  Returns
2394
2398
  -------
@@ -2397,6 +2401,20 @@ def solve_poisson(
2397
2401
  """
2398
2402
  n, k_orig = X.shape
2399
2403
 
2404
+ # Validate weights (mirrors solve_logit validation)
2405
+ if weights is not None:
2406
+ weights = np.asarray(weights, dtype=np.float64)
2407
+ if weights.shape != (n,):
2408
+ raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
2409
+ if np.any(np.isnan(weights)):
2410
+ raise ValueError("weights contain NaN values")
2411
+ if np.any(~np.isfinite(weights)):
2412
+ raise ValueError("weights contain Inf values")
2413
+ if np.any(weights < 0):
2414
+ raise ValueError("weights must be non-negative")
2415
+ if np.sum(weights) <= 0:
2416
+ raise ValueError("weights sum to zero — no observations have positive weight")
2417
+
2400
2418
  # Validate rank_deficient_action (same as solve_logit/solve_ols)
2401
2419
  valid_actions = ("warn", "error", "silent")
2402
2420
  if rank_deficient_action not in valid_actions:
@@ -2425,6 +2443,46 @@ def solve_poisson(
2425
2443
  X = X[:, kept_cols]
2426
2444
 
2427
2445
  n, k = X.shape
2446
+
2447
+ # Validate effective weighted sample when weights have zeros
2448
+ # (mirrors solve_logit's positive-weight safeguards)
2449
+ if weights is not None and np.any(weights == 0):
2450
+ pos_mask = weights > 0
2451
+ n_pos = int(np.sum(pos_mask))
2452
+ X_eff = X[pos_mask]
2453
+ eff_rank_info = _detect_rank_deficiency(X_eff)
2454
+ if len(eff_rank_info[1]) > 0:
2455
+ n_dropped_eff = len(eff_rank_info[1])
2456
+ if rank_deficient_action == "error":
2457
+ raise ValueError(
2458
+ f"Effective (positive-weight) sample is rank-deficient: "
2459
+ f"{n_dropped_eff} linearly dependent column(s). "
2460
+ f"Cannot identify Poisson model on this subpopulation."
2461
+ )
2462
+ elif rank_deficient_action == "warn":
2463
+ warnings.warn(
2464
+ f"Effective (positive-weight) sample is rank-deficient: "
2465
+ f"dropping {n_dropped_eff} column(s). Poisson estimates "
2466
+ f"may be unreliable on this subpopulation.",
2467
+ UserWarning,
2468
+ stacklevel=2,
2469
+ )
2470
+ eff_dropped = set(int(d) for d in eff_rank_info[1])
2471
+ eff_kept = np.array([i for i in range(k) if i not in eff_dropped])
2472
+ X = X[:, eff_kept]
2473
+ if len(dropped_cols) > 0:
2474
+ kept_cols = kept_cols[eff_kept]
2475
+ else:
2476
+ kept_cols = eff_kept
2477
+ dropped_cols = list(eff_dropped)
2478
+ n, k = X.shape
2479
+ if n_pos <= k:
2480
+ raise ValueError(
2481
+ f"Only {n_pos} positive-weight observation(s) for "
2482
+ f"{k} parameters (after rank reduction). "
2483
+ f"Cannot identify Poisson model."
2484
+ )
2485
+
2428
2486
  if init_beta is not None:
2429
2487
  beta = init_beta[kept_cols].copy() if len(dropped_cols) > 0 else init_beta.copy()
2430
2488
  else:
@@ -2438,8 +2496,12 @@ def solve_poisson(
2438
2496
  for _ in range(max_iter):
2439
2497
  eta = np.clip(X @ beta, -500, 500)
2440
2498
  mu = np.exp(eta)
2441
- score = X.T @ (y - mu) # gradient of log-likelihood
2442
- hess = X.T @ (mu[:, None] * X) # -Hessian = X'WX, W=diag(mu)
2499
+ if weights is not None:
2500
+ score = X.T @ (weights * (y - mu))
2501
+ hess = X.T @ ((weights * mu)[:, None] * X)
2502
+ else:
2503
+ score = X.T @ (y - mu)
2504
+ hess = X.T @ (mu[:, None] * X)
2443
2505
  try:
2444
2506
  delta = np.linalg.solve(hess + 1e-12 * np.eye(k), score)
2445
2507
  except np.linalg.LinAlgError:
@@ -153,9 +153,6 @@ class CallawaySantAnna(
153
153
  - "rademacher": +1/-1 with equal probability (standard choice)
154
154
  - "mammen": Two-point distribution (asymptotically valid, matches skewness)
155
155
  - "webb": Six-point distribution (recommended when n_clusters < 20)
156
- bootstrap_weight_type : str, optional
157
- .. deprecated:: 1.0.1
158
- Use ``bootstrap_weights`` instead. Will be removed in v3.0.
159
156
  seed : int, optional
160
157
  Random seed for reproducibility.
161
158
  rank_deficient_action : str, default="warn"
@@ -293,7 +290,6 @@ class CallawaySantAnna(
293
290
  cluster: Optional[str] = None,
294
291
  n_bootstrap: int = 0,
295
292
  bootstrap_weights: Optional[str] = None,
296
- bootstrap_weight_type: Optional[str] = None,
297
293
  seed: Optional[int] = None,
298
294
  rank_deficient_action: str = "warn",
299
295
  base_period: str = "varying",
@@ -323,18 +319,7 @@ class CallawaySantAnna(
323
319
  f"pscore_fallback must be 'error' or 'unconditional', " f"got '{pscore_fallback}'"
324
320
  )
325
321
 
326
- # Handle bootstrap_weight_type deprecation
327
- if bootstrap_weight_type is not None:
328
- warnings.warn(
329
- "bootstrap_weight_type is deprecated and will be removed in v3.0. "
330
- "Use bootstrap_weights instead.",
331
- DeprecationWarning,
332
- stacklevel=2,
333
- )
334
- if bootstrap_weights is None:
335
- bootstrap_weights = bootstrap_weight_type
336
-
337
- # Default to rademacher if neither specified
322
+ # Default to rademacher if not specified
338
323
  if bootstrap_weights is None:
339
324
  bootstrap_weights = "rademacher"
340
325
 
@@ -362,8 +347,6 @@ class CallawaySantAnna(
362
347
  self.cluster = cluster
363
348
  self.n_bootstrap = n_bootstrap
364
349
  self.bootstrap_weights = bootstrap_weights
365
- # Keep bootstrap_weight_type for backward compatibility
366
- self.bootstrap_weight_type = bootstrap_weights
367
350
  self.seed = seed
368
351
  self.rank_deficient_action = rank_deficient_action
369
352
  self.base_period = base_period
@@ -3881,8 +3864,6 @@ class CallawaySantAnna(
3881
3864
  "cluster": self.cluster,
3882
3865
  "n_bootstrap": self.n_bootstrap,
3883
3866
  "bootstrap_weights": self.bootstrap_weights,
3884
- # Deprecated but kept for backward compatibility
3885
- "bootstrap_weight_type": self.bootstrap_weight_type,
3886
3867
  "seed": self.seed,
3887
3868
  "rank_deficient_action": self.rank_deficient_action,
3888
3869
  "base_period": self.base_period,
@@ -118,7 +118,7 @@ class CallawaySantAnnaBootstrapMixin:
118
118
 
119
119
  # Type hints for attributes accessed from the main class
120
120
  n_bootstrap: int
121
- bootstrap_weight_type: str
121
+ bootstrap_weights: str
122
122
  alpha: float
123
123
  seed: Optional[int]
124
124
  anticipation: int
@@ -329,7 +329,7 @@ class CallawaySantAnnaBootstrapMixin:
329
329
  if _use_survey_bootstrap:
330
330
  # PSU-level multiplier weights
331
331
  psu_weights, psu_ids = _generate_survey_multiplier_weights_batch(
332
- self.n_bootstrap, resolved_survey_unit, self.bootstrap_weight_type, rng
332
+ self.n_bootstrap, resolved_survey_unit, self.bootstrap_weights, rng
333
333
  )
334
334
  # Build unit → PSU column map
335
335
  if resolved_survey_unit.psu is not None:
@@ -348,7 +348,7 @@ class CallawaySantAnnaBootstrapMixin:
348
348
  else:
349
349
  # Standard unit-level weights (no survey or weights-only)
350
350
  all_bootstrap_weights = _generate_bootstrap_weights_batch(
351
- self.n_bootstrap, n_units, self.bootstrap_weight_type, rng
351
+ self.n_bootstrap, n_units, self.bootstrap_weights, rng
352
352
  )
353
353
 
354
354
  # Vectorized bootstrap ATT(g,t) computation
@@ -534,7 +534,7 @@ class CallawaySantAnnaBootstrapMixin:
534
534
 
535
535
  return CSBootstrapResults(
536
536
  n_bootstrap=self.n_bootstrap,
537
- weight_type=self.bootstrap_weight_type,
537
+ weight_type=self.bootstrap_weights,
538
538
  alpha=self.alpha,
539
539
  overall_att_se=overall_se,
540
540
  overall_att_ci=overall_ci,
@@ -147,7 +147,6 @@ class StaggeredTripleDifference(
147
147
  self.base_period = base_period
148
148
  self.n_bootstrap = n_bootstrap
149
149
  self.bootstrap_weights = bootstrap_weights
150
- self.bootstrap_weight_type = bootstrap_weights
151
150
  self.seed = seed
152
151
  self.cband = cband
153
152
  self.pscore_trim = pscore_trim
@@ -186,7 +185,7 @@ class StaggeredTripleDifference(
186
185
  raise ValueError(f"Unknown parameter: {key}")
187
186
  setattr(self, key, value)
188
187
  if "bootstrap_weights" in params:
189
- self.bootstrap_weight_type = params["bootstrap_weights"]
188
+ self.bootstrap_weights = params["bootstrap_weights"]
190
189
  return self
191
190
 
192
191
  # ------------------------------------------------------------------
@@ -144,14 +144,14 @@ class SyntheticDiD(DifferenceInDifferences):
144
144
  warnings.warn(
145
145
  "lambda_reg is deprecated and ignored. Regularization is now "
146
146
  "auto-computed from data. Use zeta_omega to override unit weight "
147
- "regularization.",
147
+ "regularization. Will be removed in v3.1.",
148
148
  DeprecationWarning,
149
149
  stacklevel=2,
150
150
  )
151
151
  if zeta is not None:
152
152
  warnings.warn(
153
153
  "zeta is deprecated and ignored. Use zeta_lambda to override "
154
- "time weight regularization.",
154
+ "time weight regularization. Will be removed in v3.1.",
155
155
  DeprecationWarning,
156
156
  stacklevel=2,
157
157
  )
@@ -1124,7 +1124,8 @@ class SyntheticDiD(DifferenceInDifferences):
1124
1124
  for key, value in params.items():
1125
1125
  if key in _deprecated:
1126
1126
  warnings.warn(
1127
- f"{key} is deprecated and ignored. Use zeta_omega/zeta_lambda " f"instead.",
1127
+ f"{key} is deprecated and ignored. Use zeta_omega/zeta_lambda "
1128
+ f"instead. Will be removed in v3.1.",
1128
1129
  DeprecationWarning,
1129
1130
  stacklevel=2,
1130
1131
  )
@@ -77,10 +77,6 @@ class TROP(TROPLocalMixin, TROPGlobalMixin):
77
77
  ATT is the mean of these effects. For the paper's full
78
78
  per-treated-cell estimator, use ``method='local'``.
79
79
 
80
- - 'twostep': Deprecated alias for 'local'. Will be removed in v3.0.
81
-
82
- - 'joint': Deprecated alias for 'global'. Will be removed in v3.0.
83
-
84
80
  lambda_time_grid : list, optional
85
81
  Grid of time weight decay parameters. 0.0 = uniform weights (disabled).
86
82
  Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
@@ -140,26 +136,9 @@ class TROP(TROPLocalMixin, TROPGlobalMixin):
140
136
  seed: Optional[int] = None,
141
137
  ):
142
138
  # Validate method parameter
143
- # 'local'/'global' are preferred; 'twostep'/'joint' are deprecated aliases
144
- valid_methods = ("local", "twostep", "joint", "global")
139
+ valid_methods = ("local", "global")
145
140
  if method not in valid_methods:
146
141
  raise ValueError(f"method must be one of {valid_methods}, got '{method}'")
147
- if method == "twostep":
148
- warnings.warn(
149
- "method='twostep' is deprecated and will be removed in v3.0. "
150
- "Use method='local' instead.",
151
- FutureWarning,
152
- stacklevel=2,
153
- )
154
- method = "local"
155
- if method == "joint":
156
- warnings.warn(
157
- "method='joint' is deprecated and will be removed in v3.0. "
158
- "Use method='global' instead.",
159
- FutureWarning,
160
- stacklevel=2,
161
- )
162
- method = "global"
163
142
  self.method = method
164
143
 
165
144
  # Default grids from paper
@@ -913,22 +892,10 @@ class TROP(TROPLocalMixin, TROPGlobalMixin):
913
892
  def set_params(self, **params) -> "TROP":
914
893
  """Set estimator parameters."""
915
894
  for key, value in params.items():
916
- if key == "method" and value == "twostep":
917
- warnings.warn(
918
- "method='twostep' is deprecated and will be removed in "
919
- "v3.0. Use method='local' instead.",
920
- FutureWarning,
921
- stacklevel=2,
922
- )
923
- value = "local"
924
- if key == "method" and value == "joint":
925
- warnings.warn(
926
- "method='joint' is deprecated and will be removed in "
927
- "v3.0. Use method='global' instead.",
928
- FutureWarning,
929
- stacklevel=2,
895
+ if key == "method" and value not in ("local", "global"):
896
+ raise ValueError(
897
+ f"method must be one of ('local', 'global'), got '{value}'"
930
898
  )
931
- value = "global"
932
899
  if hasattr(self, key):
933
900
  setattr(self, key, value)
934
901
  else:
@@ -42,6 +42,7 @@ def _compute_weighted_agg(
42
42
  gt_keys: List,
43
43
  gt_vcov: Optional[np.ndarray],
44
44
  alpha: float,
45
+ df: Optional[int] = None,
45
46
  ) -> Dict:
46
47
  """Compute simple (overall) weighted average ATT and SE via delta method."""
47
48
  post_keys = [(g, t) for (g, t) in gt_keys if t >= g]
@@ -63,10 +64,54 @@ def _compute_weighted_agg(
63
64
  else:
64
65
  se = float("nan")
65
66
 
66
- t_stat, p_value, conf_int = safe_inference(att, se, alpha=alpha)
67
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=alpha, df=df)
67
68
  return {"att": att, "se": se, "t_stat": t_stat, "p_value": p_value, "conf_int": conf_int}
68
69
 
69
70
 
71
+ def _resolve_survey_for_wooldridge(survey_design, sample, cluster_ids, cluster_name):
72
+ """Resolve survey design, inject cluster as PSU, recompute metadata.
73
+
74
+ Shared helper for all three WooldridgeDiD sub-fitters. Matches the
75
+ resolution chain in DifferenceInDifferences.fit() (estimators.py:344-359).
76
+ """
77
+ from diff_diff.survey import (
78
+ _resolve_survey_for_fit,
79
+ _resolve_effective_cluster,
80
+ _inject_cluster_as_psu,
81
+ compute_survey_metadata,
82
+ )
83
+
84
+ resolved, survey_weights, survey_weight_type, survey_metadata = (
85
+ _resolve_survey_for_fit(survey_design, sample)
86
+ )
87
+ if resolved is not None and resolved.uses_replicate_variance:
88
+ raise NotImplementedError(
89
+ "WooldridgeDiD does not yet support replicate-weight variance. "
90
+ "Use TSL (strata/PSU/FPC) instead."
91
+ )
92
+ if resolved is not None and resolved.weight_type != "pweight":
93
+ raise ValueError(
94
+ f"WooldridgeDiD survey support requires weight_type='pweight', "
95
+ f"got '{resolved.weight_type}'. The survey variance math "
96
+ f"assumes probability weights (pweight)."
97
+ )
98
+ if resolved is not None:
99
+ effective_cluster = _resolve_effective_cluster(
100
+ resolved, cluster_ids, cluster_name
101
+ )
102
+ if effective_cluster is not None:
103
+ resolved = _inject_cluster_as_psu(resolved, effective_cluster)
104
+ if resolved.psu is not None and survey_metadata is not None:
105
+ raw_w = (
106
+ sample[survey_design.weights].values.astype(np.float64)
107
+ if survey_design.weights
108
+ else np.ones(len(sample), dtype=np.float64)
109
+ )
110
+ survey_metadata = compute_survey_metadata(resolved, raw_w)
111
+ df_inf = resolved.df_survey if resolved is not None else None
112
+ return resolved, survey_weights, survey_weight_type, survey_metadata, df_inf
113
+
114
+
70
115
  def _filter_sample(
71
116
  data: pd.DataFrame,
72
117
  unit: str,
@@ -329,6 +374,7 @@ class WooldridgeDiD:
329
374
  exovar: Optional[List[str]] = None,
330
375
  xtvar: Optional[List[str]] = None,
331
376
  xgvar: Optional[List[str]] = None,
377
+ survey_design=None,
332
378
  ) -> WooldridgeDiDResults:
333
379
  """Fit the ETWFE model. See class docstring for parameter details.
334
380
 
@@ -343,6 +389,11 @@ class WooldridgeDiD:
343
389
  xtvar : time-varying covariates (demeaned within cohort×period cells
344
390
  when ``demean_covariates=True``)
345
391
  xgvar : covariates interacted with each cohort indicator
392
+ survey_design : SurveyDesign, optional
393
+ Survey design specification for complex survey data. Supports
394
+ stratified, clustered, and weighted designs via Taylor Series
395
+ Linearization (TSL). Replicate-weight designs raise
396
+ ``NotImplementedError``.
346
397
  """
347
398
  df = data.copy()
348
399
  df[cohort] = df[cohort].fillna(0)
@@ -366,6 +417,13 @@ class WooldridgeDiD:
366
417
  f"Set n_bootstrap=0 for analytic SEs."
367
418
  )
368
419
 
420
+ # 0c. Reject bootstrap + survey (no survey-aware bootstrap variant)
421
+ if self.n_bootstrap > 0 and survey_design is not None:
422
+ raise ValueError(
423
+ "Bootstrap inference is not supported with survey_design. "
424
+ "Set n_bootstrap=0 for analytic survey SEs."
425
+ )
426
+
369
427
  # 1. Filter to analysis sample
370
428
  sample = _filter_sample(df, unit, time, cohort, self.control_group, self.anticipation)
371
429
 
@@ -502,6 +560,7 @@ class WooldridgeDiD:
502
560
  gt_keys,
503
561
  int_col_names,
504
562
  groups,
563
+ survey_design=survey_design,
505
564
  )
506
565
  elif self.method == "logit":
507
566
  n_cov_interact = X_cov.shape[1] if X_cov is not None else 0
@@ -517,6 +576,7 @@ class WooldridgeDiD:
517
576
  int_col_names,
518
577
  groups,
519
578
  n_cov_interact=n_cov_interact,
579
+ survey_design=survey_design,
520
580
  )
521
581
  else: # poisson
522
582
  n_cov_interact = X_cov.shape[1] if X_cov is not None else 0
@@ -532,6 +592,7 @@ class WooldridgeDiD:
532
592
  int_col_names,
533
593
  groups,
534
594
  n_cov_interact=n_cov_interact,
595
+ survey_design=survey_design,
535
596
  )
536
597
 
537
598
  self._results = results
@@ -561,8 +622,21 @@ class WooldridgeDiD:
561
622
  gt_keys: List[Tuple],
562
623
  int_col_names: List[str],
563
624
  groups: List[Any],
625
+ survey_design=None,
564
626
  ) -> WooldridgeDiDResults:
565
627
  """OLS path: within-transform FE, solve_ols, cluster SE."""
628
+ # Reset index so numpy positional indexing matches pandas groupby
629
+ sample = sample.reset_index(drop=True)
630
+ # Cluster IDs (default: unit level) — needed before survey resolution
631
+ cluster_col = self.cluster if self.cluster else unit
632
+ cluster_ids = sample[cluster_col].values
633
+
634
+ # Resolve survey design, inject cluster as PSU only when user explicitly set cluster=
635
+ survey_cluster_ids = cluster_ids if self.cluster else None
636
+ resolved, survey_weights, survey_weight_type, survey_metadata, df_inf = (
637
+ _resolve_survey_for_wooldridge(survey_design, sample, survey_cluster_ids, self.cluster)
638
+ )
639
+
566
640
  # 4. Within-transform: absorb unit + time FE
567
641
  all_vars = [outcome] + [f"_x{i}" for i in range(X_design.shape[1])]
568
642
  tmp = sample[[unit, time]].copy()
@@ -570,32 +644,60 @@ class WooldridgeDiD:
570
644
  for i in range(X_design.shape[1]):
571
645
  tmp[f"_x{i}"] = X_design[:, i]
572
646
 
573
- # Use uniform weights to trigger iterative alternating projections,
574
- # which is exact for both balanced and unbalanced panels.
575
- # The one-pass formula (y - ȳ_i - ȳ_t + ȳ) is only exact for balanced panels.
647
+ # Use iterative alternating projections for demeaning (exact for
648
+ # both balanced and unbalanced panels). Survey weights change the
649
+ # weighted FWL projection — all columns (treatment interactions +
650
+ # covariates) are demeaned together.
651
+ wt_weights = survey_weights if survey_weights is not None else np.ones(len(tmp))
652
+
653
+ # Guard: zero-weight unit/time groups cause 0/0 in within_transform
654
+ if survey_weights is not None and np.any(survey_weights == 0):
655
+ sw_series = pd.Series(survey_weights, index=sample.index)
656
+ for grp_col, grp_label in [(unit, "unit"), (time, "time period")]:
657
+ grp_sums = sw_series.groupby(sample[grp_col]).sum()
658
+ zero_grps = grp_sums[grp_sums == 0].index.tolist()
659
+ if zero_grps:
660
+ raise ValueError(
661
+ f"Survey weights sum to zero for {grp_label}(s) "
662
+ f"{zero_grps[:3]}. Cannot compute weighted "
663
+ f"within-transformation. Remove zero-weight "
664
+ f"{grp_label}s or use non-zero weights."
665
+ )
666
+
576
667
  transformed = within_transform(
577
668
  tmp, all_vars, unit=unit, time=time, suffix="_demeaned",
578
- weights=np.ones(len(tmp)),
669
+ weights=wt_weights,
579
670
  )
580
671
 
581
672
  y = transformed[f"{outcome}_demeaned"].values
582
673
  X_cols = [f"_x{i}_demeaned" for i in range(X_design.shape[1])]
583
674
  X = transformed[X_cols].values
584
675
 
585
- # 5. Cluster IDs (default: unit level)
586
- cluster_col = self.cluster if self.cluster else unit
587
- cluster_ids = sample[cluster_col].values
588
-
589
- # 6. Solve OLS
676
+ # 6. Solve OLS (skip cluster-robust vcov when survey will provide TSL vcov)
590
677
  coefs, resids, vcov = solve_ols(
591
678
  X,
592
679
  y,
593
680
  cluster_ids=cluster_ids,
594
- return_vcov=True,
681
+ return_vcov=(resolved is None),
595
682
  rank_deficient_action=self.rank_deficient_action,
596
683
  column_names=col_names,
684
+ weights=survey_weights,
685
+ weight_type=survey_weight_type,
597
686
  )
598
687
 
688
+ # Survey TSL vcov replaces cluster-robust vcov
689
+ if resolved is not None:
690
+ from diff_diff.survey import compute_survey_vcov
691
+ nan_mask_ols = np.isnan(coefs)
692
+ if np.any(nan_mask_ols):
693
+ kept = ~nan_mask_ols
694
+ vcov_kept = compute_survey_vcov(X[:, kept], resids, resolved)
695
+ vcov = np.full((len(coefs), len(coefs)), np.nan)
696
+ kept_idx = np.where(kept)[0]
697
+ vcov[np.ix_(kept_idx, kept_idx)] = vcov_kept
698
+ else:
699
+ vcov = compute_survey_vcov(X, resids, resolved)
700
+
599
701
  # 7. Extract β_{g,t} and build gt_effects dict
600
702
  gt_effects: Dict[Tuple, Dict] = {}
601
703
  gt_weights: Dict[Tuple, int] = {}
@@ -607,7 +709,7 @@ class WooldridgeDiD:
607
709
  continue
608
710
  att = float(coefs[idx])
609
711
  se = float(np.sqrt(max(vcov[idx, idx], 0.0))) if vcov is not None else float("nan")
610
- t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha)
712
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_inf)
611
713
  gt_effects[(g, t)] = {
612
714
  "att": att,
613
715
  "se": se,
@@ -628,7 +730,7 @@ class WooldridgeDiD:
628
730
 
629
731
  # 8. Simple aggregation (always computed)
630
732
  overall = _compute_weighted_agg(
631
- gt_effects, gt_weights, gt_keys_ordered, gt_vcov, self.alpha
733
+ gt_effects, gt_weights, gt_keys_ordered, gt_vcov, self.alpha, df=df_inf
632
734
  )
633
735
 
634
736
  # Metadata
@@ -652,9 +754,11 @@ class WooldridgeDiD:
652
754
  n_control_units=n_control,
653
755
  alpha=self.alpha,
654
756
  anticipation=self.anticipation,
757
+ survey_metadata=survey_metadata,
655
758
  _gt_weights=gt_weights,
656
759
  _gt_vcov=gt_vcov,
657
760
  _gt_keys=gt_keys_ordered,
761
+ _df_survey=df_inf,
658
762
  )
659
763
 
660
764
  # 9. Optional multiplier bootstrap (overrides analytic SE for overall ATT)
@@ -723,6 +827,7 @@ class WooldridgeDiD:
723
827
  int_col_names: List[str],
724
828
  groups: List[Any],
725
829
  n_cov_interact: int = 0,
830
+ survey_design=None,
726
831
  ) -> WooldridgeDiDResults:
727
832
  """Logit path: cohort + time additive FEs + solve_logit + ASF ATT.
728
833
 
@@ -749,10 +854,18 @@ class WooldridgeDiD:
749
854
  cluster_col = self.cluster if self.cluster else unit
750
855
  cluster_ids = sample[cluster_col].values
751
856
 
857
+ # Resolve survey design, inject cluster as PSU only when user explicitly set cluster=
858
+ survey_cluster_ids = cluster_ids if self.cluster else None
859
+ resolved, survey_weights, survey_weight_type, survey_metadata, df_inf = (
860
+ _resolve_survey_for_wooldridge(survey_design, sample, survey_cluster_ids, self.cluster)
861
+ )
862
+ _has_survey = resolved is not None
863
+
752
864
  beta, probs = solve_logit(
753
865
  X_full,
754
866
  y,
755
867
  rank_deficient_action=self.rank_deficient_action,
868
+ weights=survey_weights,
756
869
  )
757
870
  # solve_logit prepends intercept — beta[0] is intercept, beta[1:] are X_full cols
758
871
  beta_int_cols = beta[1 : n_int + 1] # treatment interaction coefficients
@@ -763,34 +876,65 @@ class WooldridgeDiD:
763
876
  beta_clean = np.where(nan_mask, 0.0, beta)
764
877
  kept_beta = ~nan_mask
765
878
 
766
- # QMLE sandwich vcov via shared linalg backend
879
+ # QMLE sandwich vcov
767
880
  resids = y - probs
768
881
  X_with_intercept = np.column_stack([np.ones(len(y)), X_full])
769
- if np.any(nan_mask):
770
- # Compute vcov on reduced design (only identified columns)
771
- X_reduced = X_with_intercept[:, kept_beta]
772
- vcov_reduced = compute_robust_vcov(
773
- X_reduced,
774
- resids,
775
- cluster_ids=cluster_ids,
776
- weights=probs * (1 - probs),
777
- weight_type="aweight",
778
- )
779
- # Expand back to full size with NaN for dropped columns
780
- k_full = len(beta)
781
- vcov_full = np.full((k_full, k_full), np.nan)
782
- kept_idx = np.where(kept_beta)[0]
783
- vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
882
+
883
+ if _has_survey:
884
+ # X_tilde trick: transform design matrix so compute_survey_vcov
885
+ # produces the correct QMLE sandwich for nonlinear models.
886
+ # Bread: (X_tilde'WX_tilde)^{-1} = (X'diag(w*V)X)^{-1}
887
+ # Scores: w*X_tilde*r_tilde = w*X*(y-mu)
888
+ from diff_diff.survey import compute_survey_vcov
889
+ V = probs * (1 - probs)
890
+ sqrt_V = np.sqrt(np.clip(V, 1e-20, None))
891
+ X_tilde = X_with_intercept * sqrt_V[:, None]
892
+ r_tilde = resids / sqrt_V
893
+ if np.any(nan_mask):
894
+ X_tilde_r = X_tilde[:, kept_beta]
895
+ vcov_reduced = compute_survey_vcov(X_tilde_r, r_tilde, resolved)
896
+ k_full = len(beta)
897
+ vcov_full = np.full((k_full, k_full), np.nan)
898
+ kept_idx = np.where(kept_beta)[0]
899
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
900
+ else:
901
+ vcov_full = compute_survey_vcov(X_tilde, r_tilde, resolved)
784
902
  else:
785
- vcov_full = compute_robust_vcov(
786
- X_with_intercept,
787
- resids,
788
- cluster_ids=cluster_ids,
789
- weights=probs * (1 - probs),
790
- weight_type="aweight",
791
- )
903
+ # Cluster-robust QMLE sandwich (non-survey path)
904
+ if np.any(nan_mask):
905
+ X_reduced = X_with_intercept[:, kept_beta]
906
+ vcov_reduced = compute_robust_vcov(
907
+ X_reduced,
908
+ resids,
909
+ cluster_ids=cluster_ids,
910
+ weights=probs * (1 - probs),
911
+ weight_type="aweight",
912
+ )
913
+ k_full = len(beta)
914
+ vcov_full = np.full((k_full, k_full), np.nan)
915
+ kept_idx = np.where(kept_beta)[0]
916
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
917
+ else:
918
+ vcov_full = compute_robust_vcov(
919
+ X_with_intercept,
920
+ resids,
921
+ cluster_ids=cluster_ids,
922
+ weights=probs * (1 - probs),
923
+ weight_type="aweight",
924
+ )
792
925
  beta = beta_clean
793
926
 
927
+ # Survey-weighted averaging helpers for ASF computation
928
+ def _avg(a, cell_mask):
929
+ if survey_weights is not None:
930
+ return float(np.average(a, weights=survey_weights[cell_mask]))
931
+ return float(np.mean(a))
932
+
933
+ def _avg_ax0(a, cell_mask):
934
+ if survey_weights is not None:
935
+ return np.average(a, weights=survey_weights[cell_mask], axis=0)
936
+ return np.mean(a, axis=0)
937
+
794
938
  # ASF ATT(g,t) for treated units in each cell
795
939
  gt_effects: Dict[Tuple, Dict] = {}
796
940
  gt_weights: Dict[Tuple, int] = {}
@@ -802,6 +946,9 @@ class WooldridgeDiD:
802
946
  if cell_mask.sum() == 0:
803
947
  continue
804
948
  # Skip cells whose interaction coefficient was dropped (rank deficiency)
949
+ # Skip cells where all survey weights are zero (non-estimable)
950
+ if survey_weights is not None and np.sum(survey_weights[cell_mask]) == 0:
951
+ continue
805
952
  delta = beta_int_cols[idx]
806
953
  if np.isnan(delta):
807
954
  continue
@@ -816,26 +963,26 @@ class WooldridgeDiD:
816
963
  x_hat_j = X_with_intercept[cell_mask, coef_pos]
817
964
  delta_total = delta_total + beta[coef_pos] * x_hat_j
818
965
  eta_0 = eta_base - delta_total
819
- att = float(np.mean(_logistic(eta_base) - _logistic(eta_0)))
966
+ att = _avg(_logistic(eta_base) - _logistic(eta_0), cell_mask)
820
967
  # Delta method gradient: d(ATT)/d(β)
821
968
  # for nuisance p: mean_i[(Λ'(η_1) - Λ'(η_0)) * X_p]
822
969
  # for cell intercept: mean_i[Λ'(η_1)]
823
970
  # for cell × cov j: mean_i[Λ'(η_1) * x_hat_j]
824
971
  d_diff = _logistic_deriv(eta_base) - _logistic_deriv(eta_0)
825
- grad = np.mean(X_with_intercept[cell_mask] * d_diff[:, None], axis=0)
826
- grad[1 + idx] = float(np.mean(_logistic_deriv(eta_base)))
972
+ grad = _avg_ax0(X_with_intercept[cell_mask] * d_diff[:, None], cell_mask)
973
+ grad[1 + idx] = _avg(_logistic_deriv(eta_base), cell_mask)
827
974
  for j in range(n_cov_interact):
828
975
  coef_pos = 1 + n_int + idx * n_cov_interact + j
829
976
  if coef_pos < len(beta):
830
977
  x_hat_j = X_with_intercept[cell_mask, coef_pos]
831
- grad[coef_pos] = float(np.mean(_logistic_deriv(eta_base) * x_hat_j))
978
+ grad[coef_pos] = _avg(_logistic_deriv(eta_base) * x_hat_j, cell_mask)
832
979
  # Compute SE in reduced parameter space if rank-deficient
833
980
  if np.any(nan_mask):
834
981
  grad_r = grad[kept_beta]
835
982
  se = float(np.sqrt(max(grad_r @ vcov_reduced @ grad_r, 0.0)))
836
983
  else:
837
984
  se = float(np.sqrt(max(grad @ vcov_full @ grad, 0.0)))
838
- t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha)
985
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_inf)
839
986
  gt_effects[(g, t)] = {
840
987
  "att": att,
841
988
  "se": se,
@@ -864,7 +1011,7 @@ class WooldridgeDiD:
864
1011
  overall_att = sum(gt_weights[k] * gt_effects[k]["att"] for k in post_keys) / w_total
865
1012
  agg_grad = sum((gt_weights[k] / w_total) * gt_grads[k] for k in post_keys)
866
1013
  overall_se = float(np.sqrt(max(agg_grad @ _vcov_se @ agg_grad, 0.0)))
867
- t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha)
1014
+ t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha, df=df_inf)
868
1015
  overall = {
869
1016
  "att": overall_att,
870
1017
  "se": overall_se,
@@ -874,7 +1021,7 @@ class WooldridgeDiD:
874
1021
  }
875
1022
  else:
876
1023
  overall = _compute_weighted_agg(
877
- gt_effects, gt_weights, gt_keys_ordered, None, self.alpha
1024
+ gt_effects, gt_weights, gt_keys_ordered, None, self.alpha, df=df_inf
878
1025
  )
879
1026
 
880
1027
  return WooldridgeDiDResults(
@@ -893,9 +1040,11 @@ class WooldridgeDiD:
893
1040
  n_control_units=self._count_control_units(sample, unit, cohort, time),
894
1041
  alpha=self.alpha,
895
1042
  anticipation=self.anticipation,
1043
+ survey_metadata=survey_metadata,
896
1044
  _gt_weights=gt_weights,
897
1045
  _gt_vcov=gt_vcov,
898
1046
  _gt_keys=gt_keys_ordered,
1047
+ _df_survey=df_inf,
899
1048
  )
900
1049
 
901
1050
  def _fit_poisson(
@@ -911,6 +1060,7 @@ class WooldridgeDiD:
911
1060
  int_col_names: List[str],
912
1061
  groups: List[Any],
913
1062
  n_cov_interact: int = 0,
1063
+ survey_design=None,
914
1064
  ) -> WooldridgeDiDResults:
915
1065
  """Poisson path: cohort + time additive FEs + solve_poisson + ASF ATT.
916
1066
 
@@ -940,7 +1090,18 @@ class WooldridgeDiD:
940
1090
  cluster_col = self.cluster if self.cluster else unit
941
1091
  cluster_ids = sample[cluster_col].values
942
1092
 
943
- beta, mu_hat = solve_poisson(X_full, y, rank_deficient_action=self.rank_deficient_action)
1093
+ # Resolve survey design, inject cluster as PSU only when user explicitly set cluster=
1094
+ survey_cluster_ids = cluster_ids if self.cluster else None
1095
+ resolved, survey_weights, survey_weight_type, survey_metadata, df_inf = (
1096
+ _resolve_survey_for_wooldridge(survey_design, sample, survey_cluster_ids, self.cluster)
1097
+ )
1098
+ _has_survey = resolved is not None
1099
+
1100
+ beta, mu_hat = solve_poisson(
1101
+ X_full, y,
1102
+ rank_deficient_action=self.rank_deficient_action,
1103
+ weights=survey_weights,
1104
+ )
944
1105
 
945
1106
  # Handle rank-deficient designs: compute vcov on reduced design.
946
1107
  # Preserve raw interaction coefficients BEFORE zeroing NaN so the
@@ -950,34 +1111,63 @@ class WooldridgeDiD:
950
1111
  beta_clean = np.where(nan_mask, 0.0, beta)
951
1112
  kept_beta = ~nan_mask
952
1113
 
953
- # QMLE sandwich vcov via shared linalg backend
1114
+ # QMLE sandwich vcov
954
1115
  resids = y - mu_hat
955
- if np.any(nan_mask):
956
- X_reduced = X_full[:, kept_beta]
957
- vcov_reduced = compute_robust_vcov(
958
- X_reduced,
959
- resids,
960
- cluster_ids=cluster_ids,
961
- weights=mu_hat,
962
- weight_type="aweight",
963
- )
964
- k_full = len(beta)
965
- vcov_full = np.full((k_full, k_full), np.nan)
966
- kept_idx = np.where(kept_beta)[0]
967
- vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
1116
+
1117
+ if _has_survey:
1118
+ # X_tilde trick for nonlinear survey vcov (V = mu for Poisson)
1119
+ from diff_diff.survey import compute_survey_vcov
1120
+ sqrt_V = np.sqrt(np.clip(mu_hat, 1e-20, None))
1121
+ X_tilde = X_full * sqrt_V[:, None]
1122
+ r_tilde = resids / sqrt_V
1123
+ if np.any(nan_mask):
1124
+ X_tilde_r = X_tilde[:, kept_beta]
1125
+ vcov_reduced = compute_survey_vcov(X_tilde_r, r_tilde, resolved)
1126
+ k_full = len(beta)
1127
+ vcov_full = np.full((k_full, k_full), np.nan)
1128
+ kept_idx = np.where(kept_beta)[0]
1129
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
1130
+ else:
1131
+ vcov_full = compute_survey_vcov(X_tilde, r_tilde, resolved)
968
1132
  else:
969
- vcov_full = compute_robust_vcov(
970
- X_full,
971
- resids,
972
- cluster_ids=cluster_ids,
973
- weights=mu_hat,
974
- weight_type="aweight",
975
- )
1133
+ # Cluster-robust QMLE sandwich (non-survey path)
1134
+ if np.any(nan_mask):
1135
+ X_reduced = X_full[:, kept_beta]
1136
+ vcov_reduced = compute_robust_vcov(
1137
+ X_reduced,
1138
+ resids,
1139
+ cluster_ids=cluster_ids,
1140
+ weights=mu_hat,
1141
+ weight_type="aweight",
1142
+ )
1143
+ k_full = len(beta)
1144
+ vcov_full = np.full((k_full, k_full), np.nan)
1145
+ kept_idx = np.where(kept_beta)[0]
1146
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
1147
+ else:
1148
+ vcov_full = compute_robust_vcov(
1149
+ X_full,
1150
+ resids,
1151
+ cluster_ids=cluster_ids,
1152
+ weights=mu_hat,
1153
+ weight_type="aweight",
1154
+ )
976
1155
  beta = beta_clean
977
1156
 
978
1157
  # Treatment interaction coefficients (from cleaned beta for computation)
979
1158
  beta_int = beta[1 : 1 + n_int]
980
1159
 
1160
+ # Survey-weighted averaging helpers for ASF computation
1161
+ def _avg(a, cell_mask):
1162
+ if survey_weights is not None:
1163
+ return float(np.average(a, weights=survey_weights[cell_mask]))
1164
+ return float(np.mean(a))
1165
+
1166
+ def _avg_ax0(a, cell_mask):
1167
+ if survey_weights is not None:
1168
+ return np.average(a, weights=survey_weights[cell_mask], axis=0)
1169
+ return np.mean(a, axis=0)
1170
+
981
1171
  # ASF ATT(g,t) for treated units in each cell.
982
1172
  # eta_base = X_full @ beta already includes the treatment effect (D_{g,t}=1).
983
1173
  # Counterfactual: eta_0 = eta_base - delta (treatment switched off).
@@ -995,6 +1185,9 @@ class WooldridgeDiD:
995
1185
  # Use raw coefficients (before NaN->0 zeroing) to detect dropped cells.
996
1186
  if np.isnan(beta_int_raw[idx]):
997
1187
  continue
1188
+ # Skip cells where all survey weights are zero (non-estimable)
1189
+ if survey_weights is not None and np.sum(survey_weights[cell_mask]) == 0:
1190
+ continue
998
1191
  delta = beta_int[idx]
999
1192
  if np.isnan(delta):
1000
1193
  continue
@@ -1009,26 +1202,26 @@ class WooldridgeDiD:
1009
1202
  eta_0 = eta_base - delta_total
1010
1203
  mu_1 = np.exp(eta_base)
1011
1204
  mu_0 = np.exp(eta_0)
1012
- att = float(np.mean(mu_1 - mu_0))
1205
+ att = _avg(mu_1 - mu_0, cell_mask)
1013
1206
  # Delta method gradient:
1014
1207
  # for nuisance p: mean_i[(μ_1 - μ_0) * X_p]
1015
1208
  # for cell intercept: mean_i[μ_1]
1016
1209
  # for cell × cov j: mean_i[μ_1 * x_hat_j]
1017
1210
  diff_mu = mu_1 - mu_0
1018
- grad = np.mean(X_full[cell_mask] * diff_mu[:, None], axis=0)
1019
- grad[1 + idx] = float(np.mean(mu_1))
1211
+ grad = _avg_ax0(X_full[cell_mask] * diff_mu[:, None], cell_mask)
1212
+ grad[1 + idx] = _avg(mu_1, cell_mask)
1020
1213
  for j in range(n_cov_interact):
1021
1214
  coef_pos = 1 + n_int + idx * n_cov_interact + j
1022
1215
  if coef_pos < len(beta):
1023
1216
  x_hat_j = X_full[cell_mask, coef_pos]
1024
- grad[coef_pos] = float(np.mean(mu_1 * x_hat_j))
1217
+ grad[coef_pos] = _avg(mu_1 * x_hat_j, cell_mask)
1025
1218
  # Compute SE in reduced parameter space if rank-deficient
1026
1219
  if np.any(nan_mask):
1027
1220
  grad_r = grad[kept_beta]
1028
1221
  se = float(np.sqrt(max(grad_r @ vcov_reduced @ grad_r, 0.0)))
1029
1222
  else:
1030
1223
  se = float(np.sqrt(max(grad @ vcov_full @ grad, 0.0)))
1031
- t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha)
1224
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_inf)
1032
1225
  gt_effects[(g, t)] = {
1033
1226
  "att": att,
1034
1227
  "se": se,
@@ -1055,7 +1248,7 @@ class WooldridgeDiD:
1055
1248
  overall_att = sum(gt_weights[k] * gt_effects[k]["att"] for k in post_keys) / w_total
1056
1249
  agg_grad = sum((gt_weights[k] / w_total) * gt_grads[k] for k in post_keys)
1057
1250
  overall_se = float(np.sqrt(max(agg_grad @ _vcov_se @ agg_grad, 0.0)))
1058
- t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha)
1251
+ t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha, df=df_inf)
1059
1252
  overall = {
1060
1253
  "att": overall_att,
1061
1254
  "se": overall_se,
@@ -1065,7 +1258,7 @@ class WooldridgeDiD:
1065
1258
  }
1066
1259
  else:
1067
1260
  overall = _compute_weighted_agg(
1068
- gt_effects, gt_weights, gt_keys_ordered, None, self.alpha
1261
+ gt_effects, gt_weights, gt_keys_ordered, None, self.alpha, df=df_inf
1069
1262
  )
1070
1263
 
1071
1264
  return WooldridgeDiDResults(
@@ -1084,7 +1277,9 @@ class WooldridgeDiD:
1084
1277
  n_control_units=self._count_control_units(sample, unit, cohort, time),
1085
1278
  alpha=self.alpha,
1086
1279
  anticipation=self.anticipation,
1280
+ survey_metadata=survey_metadata,
1087
1281
  _gt_weights=gt_weights,
1088
1282
  _gt_vcov=gt_vcov,
1089
1283
  _gt_keys=gt_keys_ordered,
1284
+ _df_survey=df_inf,
1090
1285
  )
@@ -54,6 +54,7 @@ class WooldridgeDiDResults:
54
54
  n_control_units: int = 0
55
55
  alpha: float = 0.05
56
56
  anticipation: int = 0
57
+ survey_metadata: Optional[Any] = field(default=None, repr=False)
57
58
 
58
59
  # ------------------------------------------------------------------ #
59
60
  # Internal — used by aggregate() for delta-method SEs #
@@ -63,6 +64,8 @@ class WooldridgeDiDResults:
63
64
  """Full vcov of all β_{g,t} coefficients (ordered same as sorted group_time_effects keys)."""
64
65
  _gt_keys: List[Tuple[Any, Any]] = field(default_factory=list, repr=False)
65
66
  """Ordered list of (g,t) keys corresponding to _gt_vcov columns."""
67
+ _df_survey: Optional[int] = field(default=None, repr=False)
68
+ """Survey degrees of freedom for t-distribution inference."""
66
69
 
67
70
  # ------------------------------------------------------------------ #
68
71
  # Public methods #
@@ -93,7 +96,7 @@ class WooldridgeDiDResults:
93
96
  return float(np.sqrt(max(w_vec @ vcov @ w_vec, 0.0)))
94
97
 
95
98
  def _build_effect(att: float, se: float) -> Dict[str, Any]:
96
- t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha)
99
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=self._df_survey)
97
100
  return {
98
101
  "att": att,
99
102
  "se": se,
@@ -181,6 +184,11 @@ class WooldridgeDiDResults:
181
184
  "-" * 70,
182
185
  ]
183
186
 
187
+ if self.survey_metadata is not None:
188
+ from diff_diff.results import _format_survey_block
189
+ lines.extend(_format_survey_block(self.survey_metadata, 70))
190
+ lines.append("-" * 70)
191
+
184
192
  def _fmt_row(label: str, att: float, se: float, t: float, p: float, ci: Tuple) -> str:
185
193
  from diff_diff.results import _get_significance_stars # type: ignore
186
194
 
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "diff-diff"
7
- version = "2.9.1"
7
+ version = "3.0.0"
8
8
  description = "Difference-in-Differences causal inference with sklearn-like API. Callaway-Sant'Anna, Synthetic DiD, Honest DiD, event studies, parallel trends."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -197,7 +197,7 @@ checksum = "930c7171c8df9fb1782bdf9b918ed9ed2d33d1d22300abb754f9085bc48bf8e8"
197
197
 
198
198
  [[package]]
199
199
  name = "diff_diff_rust"
200
- version = "2.9.1"
200
+ version = "3.0.0"
201
201
  dependencies = [
202
202
  "blas-src",
203
203
  "faer",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "diff_diff_rust"
3
- version = "2.9.1"
3
+ version = "3.0.0"
4
4
  edition = "2021"
5
5
  description = "Rust backend for diff-diff DiD library"
6
6
  license = "MIT"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes