diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/survey.py ADDED
@@ -0,0 +1,1981 @@
1
+ """
2
+ Survey data support for diff-diff.
3
+
4
+ Provides SurveyDesign for specifying complex survey structures (stratification,
5
+ clustering, weights, FPC) and Taylor Series Linearization (TSL) variance
6
+ estimation for design-based inference.
7
+
8
+ References
9
+ ----------
10
+ - Lumley (2004) "Analysis of Complex Survey Samples", JSS 9(8).
11
+ - Binder (1983) "On the Variances of Asymptotically Normal Estimators
12
+ from Complex Surveys", International Statistical Review 51(3).
13
+ - Solon, Haider, & Wooldridge (2015) "What Are We Weighting For?",
14
+ Journal of Human Resources 50(2).
15
+ """
16
+
17
+ import warnings
18
+ from dataclasses import dataclass, field, replace
19
+ from typing import Callable, List, Optional, Tuple
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+
24
+ from diff_diff.linalg import _factorize_cluster_ids
25
+
26
+
27
@dataclass
class SurveyDesign:
    """
    User-facing class specifying complex survey design structure.

    Column names are resolved against the DataFrame at fit-time via
    :meth:`resolve`.

    Parameters
    ----------
    weights : str, optional
        Column name for observation weights (sampling weights). Required
        when ``replicate_weights`` is given.
    strata : str, optional
        Column name for stratification variable.
    psu : str, optional
        Column name for primary sampling unit (cluster).
    fpc : str, optional
        Column name for finite population size (N_h per stratum). Values
        must be finite, positive, and constant within each stratum (or
        constant overall when no strata are given).
    weight_type : str, default "pweight"
        Weight type: "pweight" (inverse selection probability),
        "fweight" (frequency/expansion), or "aweight" (inverse variance).
    nest : bool, default False
        Whether PSU IDs are nested within strata (i.e., PSU IDs may repeat
        across strata). If True, PSU IDs are made unique by combining with
        strata.
    lonely_psu : str, default "remove"
        How to handle singleton strata (strata with only one PSU):
        "remove" (skip, emit warning), "certainty" (set f_h=1, zero
        variance contribution), or "adjust" (center around grand mean).
    replicate_weights : list of str, optional
        Column names of replicate weight columns (at least 2 are required
        at resolve time). Cannot be combined with ``strata``/``psu``/``fpc``
        because the replicates encode the design implicitly.
    replicate_method : str, optional
        One of "BRR", "Fay", "JK1", "JKn", "SDR". Must be given if and only
        if ``replicate_weights`` is given.
    fay_rho : float, default 0.0
        Fay's perturbation factor. Must lie in (0, 1) for "Fay" and be 0
        for every other replicate method.
    replicate_strata : list of int, optional
        Stratum assignment for each replicate column. Required for "JKn";
        must have the same length as ``replicate_weights``.
    combined_weights : bool, default True
        Whether the replicate columns already include the full-sample
        weight. When True, a replicate weight may not be positive where the
        full-sample weight is zero.
    replicate_scale : float, optional
        Optional overall scale factor for replicate variance; must be a
        positive finite number. Passed through unchanged to the resolved
        design.
    replicate_rscales : list of float, optional
        Optional per-replicate scale factors (one per replicate column);
        must be finite and non-negative.
    mse : bool, default False
        Passed through unchanged to the resolved design.
        NOTE(review): presumably selects the MSE form of replicate variance
        (centering on the full-sample estimate); confirm in the variance code.
    """

    weights: Optional[str] = None
    strata: Optional[str] = None
    psu: Optional[str] = None
    fpc: Optional[str] = None
    weight_type: str = "pweight"
    nest: bool = False
    lonely_psu: str = "remove"
    replicate_weights: Optional[List[str]] = None
    replicate_method: Optional[str] = None
    fay_rho: float = 0.0
    replicate_strata: Optional[List[int]] = None
    combined_weights: bool = True
    replicate_scale: Optional[float] = None
    replicate_rscales: Optional[List[float]] = None
    mse: bool = False

    def __post_init__(self):
        """Validate option values and cross-field constraints at construction.

        Raises
        ------
        ValueError
            If an option is out of its allowed set/range, or options are
            mutually inconsistent (e.g. replicate weights with strata).
        """
        valid_weight_types = {"pweight", "fweight", "aweight"}
        if self.weight_type not in valid_weight_types:
            raise ValueError(
                f"weight_type must be one of {valid_weight_types}, " f"got '{self.weight_type}'"
            )
        valid_lonely = {"remove", "certainty", "adjust"}
        if self.lonely_psu not in valid_lonely:
            raise ValueError(
                f"lonely_psu must be one of {valid_lonely}, " f"got '{self.lonely_psu}'"
            )
        # Replicate weight validation
        valid_rep_methods = {"BRR", "Fay", "JK1", "JKn", "SDR"}
        if self.replicate_method is not None:
            if self.replicate_method not in valid_rep_methods:
                raise ValueError(
                    f"replicate_method must be one of {valid_rep_methods}, "
                    f"got '{self.replicate_method}'"
                )
            if self.replicate_weights is None:
                raise ValueError("replicate_weights must be provided when replicate_method is set")
        if self.replicate_weights is not None and self.replicate_method is None:
            raise ValueError("replicate_method must be provided when replicate_weights is set")
        if self.replicate_method == "Fay":
            if not (0 < self.fay_rho < 1):
                raise ValueError(f"fay_rho must be in (0, 1) for Fay's method, got {self.fay_rho}")
        elif self.replicate_method is not None and self.fay_rho != 0.0:
            raise ValueError(
                f"fay_rho must be 0 for method '{self.replicate_method}', " f"got {self.fay_rho}"
            )
        # Replicate weights are mutually exclusive with strata/psu/fpc
        if self.replicate_weights is not None:
            if self.strata is not None or self.psu is not None or self.fpc is not None:
                raise ValueError(
                    "replicate_weights cannot be combined with strata/psu/fpc. "
                    "Replicate weights encode the design structure implicitly."
                )
            if self.weights is None:
                raise ValueError("Full-sample weights must be provided alongside replicate_weights")
        # JKn requires replicate_strata
        if self.replicate_method == "JKn":
            if self.replicate_strata is None:
                raise ValueError(
                    "replicate_strata is required for JKn method. Provide a list "
                    "of stratum assignments (one per replicate weight column)."
                )
            if len(self.replicate_strata) != len(self.replicate_weights):
                raise ValueError(
                    f"replicate_strata length ({len(self.replicate_strata)}) must "
                    f"match replicate_weights length ({len(self.replicate_weights)})"
                )
        # Validate scale/rscales values and length
        if self.replicate_scale is not None:
            if not (np.isfinite(self.replicate_scale) and self.replicate_scale > 0):
                raise ValueError(
                    f"replicate_scale must be a positive finite number, "
                    f"got {self.replicate_scale}"
                )
        if self.replicate_rscales is not None and self.replicate_weights is not None:
            if len(self.replicate_rscales) != len(self.replicate_weights):
                raise ValueError(
                    f"replicate_rscales length ({len(self.replicate_rscales)}) must "
                    f"match replicate_weights length ({len(self.replicate_weights)})"
                )
            rscales_arr = np.asarray(self.replicate_rscales, dtype=float)
            if not np.all(np.isfinite(rscales_arr)):
                raise ValueError("replicate_rscales must be finite")
            if np.any(rscales_arr < 0):
                raise ValueError("replicate_rscales must be non-negative")

    def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
        """
        Validate column names and extract numpy arrays from DataFrame.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing survey design columns.

        Returns
        -------
        ResolvedSurveyDesign
            Internal representation with extracted numpy arrays.

        Raises
        ------
        ValueError
            If a referenced column is missing, or its values are invalid:
            NaN/Inf/negative or all-zero weights, fractional fweights,
            invalid replicate weights, missing strata/PSU labels, PSU labels
            shared across strata with ``nest=False``, non-finite,
            non-positive or non-constant FPC, or FPC below the number of
            declared PSUs.

        Warns
        -----
        UserWarning
            When pweights/aweights are rescaled to mean 1, or when a
            stratum (or an unstratified design) has fewer than 2 PSUs.
        """
        n = len(data)

        # --- Weights ---
        if self.weights is not None:
            if self.weights not in data.columns:
                raise ValueError(f"Weight column '{self.weights}' not found in data")
            raw_weights = data[self.weights].values.astype(np.float64)

            # Validate weights
            if np.any(np.isnan(raw_weights)):
                raise ValueError("Weights contain NaN values")
            if np.any(~np.isfinite(raw_weights)):
                raise ValueError("Weights contain Inf values")
            if np.any(raw_weights < 0):
                raise ValueError("Weights must be non-negative")
            # An empty weight column is not treated as "all zero".
            if raw_weights.size > 0 and np.all(raw_weights == 0):
                raise ValueError(
                    "All weights are zero. At least one observation must " "have a positive weight."
                )

            # fweight validation: must be non-negative integers
            if self.weight_type == "fweight":
                pos_mask = raw_weights > 0
                if np.any(pos_mask):
                    fractional = raw_weights[pos_mask] - np.round(raw_weights[pos_mask])
                    if np.any(np.abs(fractional) > 1e-10):
                        raise ValueError(
                            "Frequency weights (fweight) must be non-negative integers. "
                            "Fractional values detected. Use pweight for non-integer weights."
                        )

            # Normalize: pweights/aweights to sum=n (mean=1); fweights unchanged
            # Skip normalization for replicate designs — the IF path uses
            # w_r / w_full ratios that must be on the same raw scale
            if self.replicate_weights is not None:
                weights = raw_weights.copy()
            elif self.weight_type in ("pweight", "aweight"):
                raw_sum = float(np.sum(raw_weights))
                weights = raw_weights * (n / raw_sum)
                if not np.isclose(raw_sum, n):
                    warnings.warn(
                        f"{self.weight_type} weights normalized to mean=1 "
                        f"(sum={n}). Original sum was {raw_sum:.4g}.",
                        UserWarning,
                        stacklevel=2,
                    )
            else:
                weights = raw_weights.copy()
        else:
            weights = np.ones(n, dtype=np.float64)

        # --- Replicate weights (short-circuit strata/psu/fpc) ---
        if self.replicate_weights is not None:
            rep_cols = self.replicate_weights
            for col in rep_cols:
                if col not in data.columns:
                    raise ValueError(f"Replicate weight column '{col}' not found in data")
            rep_arr = np.column_stack([data[col].values.astype(np.float64) for col in rep_cols])
            if np.any(np.isnan(rep_arr)):
                raise ValueError("Replicate weights contain NaN values")
            if np.any(~np.isfinite(rep_arr)):
                raise ValueError("Replicate weights contain Inf values")
            if np.any(rep_arr < 0):
                raise ValueError("Replicate weights must be non-negative")
            # Validate combined_weights contract: when True, replicate columns
            # include the full-sample weight, so w_r > 0 with w_full == 0 is
            # malformed (observation excluded from full sample but included in
            # a replicate).
            combined = self.combined_weights if self.combined_weights is not None else True
            if combined:
                zero_full = weights == 0
                if np.any(zero_full):
                    rep_positive_on_zero = np.any(rep_arr[zero_full] > 0, axis=1)
                    if np.any(rep_positive_on_zero):
                        raise ValueError(
                            "Malformed combined_weights=True design: some "
                            "replicate columns have positive weight where "
                            "full-sample weight is zero. Either fix the "
                            "replicate columns or use combined_weights=False."
                        )
            # Do NOT normalize replicate columns — the IF path uses w_r/w_full
            # ratios that must reflect the true replicate design, not rescaled sums
            n_rep = rep_arr.shape[1]
            if n_rep < 2:
                raise ValueError("At least 2 replicate weight columns are required")
            return ResolvedSurveyDesign(
                weights=weights,
                weight_type=self.weight_type,
                strata=None,
                psu=None,
                fpc=None,
                n_strata=0,
                n_psu=0,
                lonely_psu=self.lonely_psu,
                replicate_weights=rep_arr,
                replicate_method=self.replicate_method,
                fay_rho=self.fay_rho,
                n_replicates=n_rep,
                replicate_strata=(
                    np.asarray(self.replicate_strata, dtype=int)
                    if self.replicate_strata is not None
                    else None
                ),
                combined_weights=self.combined_weights,
                replicate_scale=self.replicate_scale,
                replicate_rscales=(
                    np.asarray(self.replicate_rscales, dtype=float)
                    if self.replicate_rscales is not None
                    else None
                ),
                mse=self.mse,
            )

        # --- Strata ---
        strata_arr = None
        n_strata = 0
        if self.strata is not None:
            if self.strata not in data.columns:
                raise ValueError(f"Strata column '{self.strata}' not found in data")
            strata_vals = data[self.strata].values
            if pd.isna(strata_vals).any():
                raise ValueError(
                    f"Strata column '{self.strata}' contains missing values. "
                    "All observations must have valid strata identifiers."
                )
            strata_arr = _factorize_cluster_ids(strata_vals)
            n_strata = len(np.unique(strata_arr))

        # --- PSU ---
        psu_arr = None
        n_psu = 0
        if self.psu is not None:
            if self.psu not in data.columns:
                raise ValueError(f"PSU column '{self.psu}' not found in data")
            psu_raw = data[self.psu].values
            if pd.isna(psu_raw).any():
                raise ValueError(
                    f"PSU column '{self.psu}' contains missing values. "
                    "All observations must have valid PSU identifiers."
                )

            if self.nest and strata_arr is not None:
                # Make PSU IDs unique within strata by combining
                combined = np.array([f"{s}_{p}" for s, p in zip(strata_arr, psu_raw)])
                psu_arr = _factorize_cluster_ids(combined)
            else:
                psu_arr = _factorize_cluster_ids(psu_raw)
                # Validate PSU labels are globally unique when nest=False
                # and strata are present. Repeated labels cause wrong n_psu,
                # df_survey, and lonely_psu="adjust" global mean.
                if strata_arr is not None:
                    seen_psus: set = set()
                    for h in np.unique(strata_arr):
                        psu_in_h = set(psu_raw[strata_arr == h])
                        overlap = seen_psus & psu_in_h
                        if overlap:
                            raise ValueError(
                                f"PSU labels {overlap} appear in multiple strata. "
                                "Set nest=True in SurveyDesign to make PSU IDs "
                                "unique within strata, or use globally unique "
                                "PSU labels."
                            )
                        seen_psus |= psu_in_h

            n_psu = len(np.unique(psu_arr))

        # --- FPC ---
        fpc_arr = None
        if self.fpc is not None:
            if self.fpc not in data.columns:
                raise ValueError(f"FPC column '{self.fpc}' not found in data")
            fpc_arr = data[self.fpc].values.astype(np.float64)

            if np.any(np.isnan(fpc_arr)) or np.any(~np.isfinite(fpc_arr)):
                raise ValueError("FPC values must be finite and non-NaN")
            # A population size must be positive. Without an explicit PSU
            # column the FPC >= n_PSU checks below never run, so zero or
            # negative values would otherwise pass resolve().
            if np.any(fpc_arr <= 0):
                raise ValueError("FPC values must be positive")

            # Validate FPC structure (constant within strata). FPC >= n_PSU
            # is checked here only when an explicit PSU column is declared;
            # otherwise it is deferred to compute_survey_vcov(), where the
            # final effective PSU structure is known (after cluster-as-PSU
            # injection and implicit per-obs PSU fallback).
            if strata_arr is not None:
                for h in np.unique(strata_arr):
                    mask_h = strata_arr == h
                    fpc_vals = fpc_arr[mask_h]
                    # Enforce FPC is constant within stratum
                    if len(np.unique(fpc_vals)) > 1:
                        raise ValueError(
                            f"FPC values must be constant within each stratum. "
                            f"Stratum {h} has values: {np.unique(fpc_vals)}"
                        )
                    fpc_h = fpc_vals[0]
                    # Validate FPC >= n_PSU when explicit PSU is declared
                    if psu_arr is not None:
                        n_psu_h = len(np.unique(psu_arr[mask_h]))
                        if fpc_h < n_psu_h:
                            raise ValueError(
                                f"FPC ({fpc_h}) is less than the number of PSUs "
                                f"({n_psu_h}) in stratum {h}. FPC must be >= n_PSU."
                            )
            else:
                # No strata: require FPC is a single constant value
                if len(np.unique(fpc_arr)) > 1:
                    raise ValueError(
                        "FPC values must be constant when no strata are specified. "
                        f"Found {len(np.unique(fpc_arr))} distinct values."
                    )
                # Validate FPC >= n_PSU when explicit PSU is declared
                if psu_arr is not None and fpc_arr[0] < n_psu:
                    raise ValueError(
                        f"FPC ({fpc_arr[0]}) is less than the number of PSUs "
                        f"({n_psu}). FPC must be >= number of PSUs."
                    )

        # --- Validate PSU counts per stratum ---
        # Only lonely_psu="remove" needs a warning here; "certainty" and
        # "adjust" singleton strata are handled in compute_survey_vcov.
        if psu_arr is not None and strata_arr is not None and self.lonely_psu == "remove":
            for h in np.unique(strata_arr):
                n_psu_h = len(np.unique(psu_arr[strata_arr == h]))
                if n_psu_h < 2:
                    warnings.warn(
                        f"Stratum {h} has only {n_psu_h} PSU(s). "
                        "It will be excluded from variance estimation "
                        "(lonely_psu='remove').",
                        UserWarning,
                        stacklevel=3,
                    )

        # Validate PSU count for unstratified designs
        if psu_arr is not None and strata_arr is None:
            if n_psu < 2:
                if self.lonely_psu == "remove":
                    msg = (
                        f"Only {n_psu} PSU(s) found (unstratified design). "
                        "Variance cannot be estimated (lonely_psu='remove')."
                    )
                elif self.lonely_psu == "certainty":
                    msg = (
                        f"Only {n_psu} PSU(s) found (unstratified design). "
                        "Treated as certainty PSU; zero variance contribution."
                    )
                else:
                    msg = (
                        f"Only {n_psu} PSU(s) found (unstratified design). "
                        "Cannot adjust with a single cluster and no strata; "
                        "variance will be NaN."
                    )
                warnings.warn(msg, UserWarning, stacklevel=3)

        return ResolvedSurveyDesign(
            weights=weights,
            weight_type=self.weight_type,
            strata=strata_arr,
            psu=psu_arr,
            fpc=fpc_arr,
            n_strata=n_strata,
            n_psu=n_psu,
            lonely_psu=self.lonely_psu,
        )

    def subpopulation(
        self,
        data: pd.DataFrame,
        mask,
    ) -> Tuple["SurveyDesign", pd.DataFrame]:
        """Create a subpopulation design by zeroing out excluded observations.

        Preserves the full survey design structure (strata, PSU) while setting
        weights to zero for observations outside the subpopulation. This is
        the correct approach for subpopulation analysis — unlike naive
        subsetting, it retains design information for variance estimation.

        Parameters
        ----------
        data : pd.DataFrame
            Full-sample data containing the design columns. It is not
            modified; a copy is returned.
        mask : array-like of bool, str, or callable
            Defines the subpopulation:
            - bool array/Series of length ``len(data)`` — True = included
            - str — column name in ``data`` containing boolean values
            - callable — applied to ``data``, must return bool array
            Numeric masks must contain only 0/1; string values and missing
            values are rejected.

        Returns
        -------
        (SurveyDesign, pd.DataFrame)
            A new SurveyDesign pointing to a ``_subpop_weight`` column in the
            returned DataFrame copy. The pair should be used together: pass
            the returned DataFrame to ``fit()`` with the returned SurveyDesign.
            Replicate weight columns (if any) are zeroed outside the mask in
            the copy as well.

        Raises
        ------
        ValueError
            If the mask column is missing, the mask has missing, string or
            non-binary values, its length differs from ``data``, it excludes
            every observation, or the weight column is missing.
        """
        # Resolve mask to boolean array
        if callable(mask):
            raw_mask = np.asarray(mask(data))
        elif isinstance(mask, str):
            if mask not in data.columns:
                raise ValueError(f"Mask column '{mask}' not found in data")
            raw_mask = np.asarray(data[mask].values)
        else:
            raw_mask = np.asarray(mask)

        # Validate: reject pd.NA/pd.NaT/None before bool coercion
        try:
            if pd.isna(raw_mask).any():
                raise ValueError(
                    "Subpopulation mask contains NA/missing values. "
                    "Provide a boolean mask with no missing values."
                )
        except (TypeError, ValueError) as e:
            if "NA/missing" in str(e):
                raise
            # pd.isna can't handle some dtypes — fall through to specific checks
            if raw_mask.dtype.kind == "f" and np.any(np.isnan(raw_mask)):
                raise ValueError(
                    "Subpopulation mask contains NaN values. "
                    "Provide a boolean mask with no missing values."
                )
        if hasattr(raw_mask, "dtype") and raw_mask.dtype == object:
            # Check for None values (pd.NA, None, etc.)
            if any(v is None for v in raw_mask):
                raise ValueError(
                    "Subpopulation mask contains None/NA values. "
                    "Provide a boolean mask with no missing values."
                )
            # Reject string/object masks — non-empty strings coerce to True
            # which silently defines the wrong domain
            if any(isinstance(v, str) for v in raw_mask):
                raise ValueError(
                    "Subpopulation mask has object dtype with string values. "
                    "Provide a boolean or numeric (0/1) mask, not strings."
                )
        if hasattr(raw_mask, "dtype") and raw_mask.dtype.kind in ("U", "S"):
            raise ValueError(
                "Subpopulation mask contains string values. "
                "Provide a boolean or numeric (0/1) mask."
            )
        # Validate numeric masks: only {0, 1} allowed (not {1, 2}, etc.)
        if hasattr(raw_mask, "dtype") and raw_mask.dtype.kind in ("i", "u", "f"):
            unique_vals = set(np.unique(raw_mask[np.isfinite(raw_mask)]).tolist())
            if not unique_vals.issubset({0, 1, 0.0, 1.0, True, False}):
                raise ValueError(
                    f"Subpopulation mask contains non-binary numeric values "
                    f"{unique_vals - {0, 1, 0.0, 1.0}}. "
                    f"Provide a boolean or numeric (0/1) mask."
                )
        mask_arr = raw_mask.astype(bool)

        if len(mask_arr) != len(data):
            raise ValueError(
                f"Mask length ({len(mask_arr)}) does not match data " f"length ({len(data)})"
            )

        if not np.any(mask_arr):
            raise ValueError(
                "Subpopulation mask excludes all observations. "
                "At least one observation must be included."
            )

        # Build subpopulation weights
        if self.weights is not None:
            if self.weights not in data.columns:
                raise ValueError(f"Weight column '{self.weights}' not found in data")
            base_weights = data[self.weights].values.astype(np.float64)
        else:
            base_weights = np.ones(len(data), dtype=np.float64)

        subpop_weights = np.where(mask_arr, base_weights, 0.0)

        # Create data copy with synthetic weight column
        data_out = data.copy()
        data_out["_subpop_weight"] = subpop_weights

        # Zero out replicate weight columns for excluded observations
        if self.replicate_weights is not None:
            for col in self.replicate_weights:
                if col in data.columns:
                    data_out[col] = np.where(mask_arr, data[col].values, 0.0)

        # Return new SurveyDesign using the synthetic column
        new_design = SurveyDesign(
            weights="_subpop_weight",
            strata=self.strata,
            psu=self.psu,
            fpc=self.fpc,
            weight_type=self.weight_type,
            nest=self.nest,
            lonely_psu=self.lonely_psu,
            replicate_weights=self.replicate_weights,
            replicate_method=self.replicate_method,
            fay_rho=self.fay_rho,
            replicate_strata=self.replicate_strata,
            combined_weights=self.combined_weights,
            replicate_scale=self.replicate_scale,
            replicate_rscales=self.replicate_rscales,
            mse=self.mse,
        )

        return new_design, data_out
556
+
557
+
558
@dataclass
class ResolvedSurveyDesign:
    """
    Internal class with extracted numpy arrays from SurveyDesign.resolve().

    Not intended for direct construction by users.

    Attributes
    ----------
    weights : np.ndarray
        Per-observation weights (as prepared by ``SurveyDesign.resolve`` or
        supplied by ``subset_to_units`` callers).
    weight_type : str
        "pweight", "fweight" or "aweight".
    strata, psu : np.ndarray or None
        Integer stratum / PSU codes per observation, or None if absent.
    fpc : np.ndarray or None
        Per-observation finite population size, or None.
    n_strata, n_psu : int
        Number of distinct strata / PSUs (0 when absent).
    lonely_psu : str
        Singleton-stratum handling option copied from the design.
    replicate_weights : np.ndarray or None
        ``(n, R)`` replicate weight matrix for replicate designs.
    replicate_method, fay_rho, n_replicates, replicate_strata,
    combined_weights, replicate_scale, replicate_rscales, mse
        Replicate-design options copied from ``SurveyDesign``.
    """

    weights: np.ndarray
    weight_type: str
    strata: Optional[np.ndarray]
    psu: Optional[np.ndarray]
    fpc: Optional[np.ndarray]
    n_strata: int
    n_psu: int
    lonely_psu: str
    replicate_weights: Optional[np.ndarray] = None  # (n, R) array
    replicate_method: Optional[str] = None
    fay_rho: float = 0.0
    n_replicates: int = 0
    replicate_strata: Optional[np.ndarray] = None  # (R,) for JKn
    combined_weights: bool = True
    replicate_scale: Optional[float] = None
    replicate_rscales: Optional[np.ndarray] = None  # (R,) per-replicate scales
    mse: bool = False

    @property
    def uses_replicate_variance(self) -> bool:
        """Whether replicate-based variance should be used instead of TSL."""
        return self.replicate_method is not None

    @property
    def df_survey(self) -> Optional[int]:
        """Survey degrees of freedom.

        For replicate designs: QR-rank of the analysis-weight matrix minus 1,
        matching R's ``survey::degf()`` which uses ``qr(..., tol=1e-5)$rank``.
        Returns ``None`` when rank <= 1 (insufficient for t-based inference),
        including when the analysis-weight matrix is empty.
        For TSL: n_PSU - n_strata (n_PSU - 1 when unstratified); without an
        explicit PSU each observation counts as its own PSU.
        """
        if self.uses_replicate_variance:
            if self.replicate_weights is None or self.n_replicates < 2:
                return None
            # QR-rank of analysis-weight matrix, matching R's survey::degf()
            # which uses qr(weights(design, "analysis"), tol=1e-5)$rank.
            # For combined_weights=True, replicate cols ARE analysis weights.
            # For combined_weights=False, analysis weights = rep * full-sample.
            if self.combined_weights:
                analysis_weights = self.replicate_weights
            else:
                analysis_weights = self.replicate_weights * self.weights[:, np.newaxis]
            # An empty matrix has rank 0; pivoted QR / max() of an empty
            # diagonal would otherwise raise.
            if analysis_weights.size == 0:
                return None
            # Pivoted QR with R-compatible tolerance, matching R's
            # qr(..., tol=1e-5) which uses column pivoting (LAPACK dgeqp3)
            from scipy.linalg import qr as scipy_qr

            _, R_mat, _ = scipy_qr(analysis_weights, pivoting=True, mode="economic")
            diag_abs = np.abs(np.diag(R_mat))
            tol = 1e-5
            max_diag = float(diag_abs.max()) if diag_abs.size else 0.0
            rank = int(np.sum(diag_abs > tol * max_diag)) if max_diag > 0 else 0
            df = rank - 1
            return df if df > 0 else None
        if self.psu is not None and self.n_psu > 0:
            if self.strata is not None and self.n_strata > 0:
                return self.n_psu - self.n_strata
            return self.n_psu - 1
        # Implicit PSU: each observation is its own PSU
        n_obs = len(self.weights)
        if self.strata is not None and self.n_strata > 0:
            return n_obs - self.n_strata
        return n_obs - 1

    def subset_to_units(
        self,
        row_idx: np.ndarray,
        weights: np.ndarray,
        strata: Optional[np.ndarray],
        psu: Optional[np.ndarray],
        fpc: Optional[np.ndarray],
        n_strata: int,
        n_psu: int,
    ) -> "ResolvedSurveyDesign":
        """Create a unit-level copy preserving replicate metadata.

        Used by panel estimators (ContinuousDiD, EfficientDiD) that collapse
        panel-level survey info to one row per unit.

        Parameters
        ----------
        row_idx : np.ndarray
            Indices into the panel-level arrays to select one row per unit.
            Only used to subset the replicate weight matrix (rows).
        weights, strata, psu, fpc, n_strata, n_psu
            Already-subsetted TSL fields (computed by the caller).

        Returns
        -------
        ResolvedSurveyDesign
            New design with the given TSL fields, replicate rows taken at
            ``row_idx``, and all other options copied from ``self``.
        """
        rep_weights_sub = None
        if self.replicate_weights is not None:
            rep_weights_sub = self.replicate_weights[row_idx, :]

        return ResolvedSurveyDesign(
            weights=weights,
            weight_type=self.weight_type,
            strata=strata,
            psu=psu,
            fpc=fpc,
            n_strata=n_strata,
            n_psu=n_psu,
            lonely_psu=self.lonely_psu,
            replicate_weights=rep_weights_sub,
            replicate_method=self.replicate_method,
            fay_rho=self.fay_rho,
            n_replicates=self.n_replicates,
            replicate_strata=self.replicate_strata,
            combined_weights=self.combined_weights,
            replicate_scale=self.replicate_scale,
            replicate_rscales=self.replicate_rscales,
            mse=self.mse,
        )

    @property
    def needs_survey_vcov(self) -> bool:
        """Whether survey vcov (not generic sandwich) should be used."""
        return True  # Any resolved survey design uses the survey vcov path
679
+
680
+
681
@dataclass
class SurveyMetadata:
    """
    Survey design metadata stored in results objects.

    Attributes
    ----------
    weight_type : str
        Type of weights used.
    effective_n : float
        Kish effective sample size: (sum(w))^2 / sum(w^2).
    design_effect : float
        DEFF: n * sum(w^2) / (sum(w))^2.
    sum_weights : float
        Sum of original (pre-normalization) weights.
    n_strata : int or None
        Number of strata (None if unstratified or a replicate design).
    n_psu : int or None
        Number of PSUs (None if not set, e.g. for replicate designs).
    weight_range : tuple of (float, float)
        (min, max) of original weights.
    df_survey : int or None
        Survey degrees of freedom: n_psu - n_strata for TSL designs, or the
        analysis-weight QR rank minus 1 for replicate designs.
    replicate_method : str or None
        Replicate method name for replicate designs, otherwise None.
    n_replicates : int or None
        Number of replicate weight columns for replicate designs,
        otherwise None.
    deff_diagnostics : DEFFDiagnostics or None
        Optional per-coefficient design-effect diagnostics.
    """

    weight_type: str
    effective_n: float
    design_effect: float
    sum_weights: float
    n_strata: Optional[int] = None
    n_psu: Optional[int] = None
    weight_range: Tuple[float, float] = field(default=(0.0, 0.0))
    df_survey: Optional[int] = None
    replicate_method: Optional[str] = None
    n_replicates: Optional[int] = None
    deff_diagnostics: Optional["DEFFDiagnostics"] = None
717
+
718
+
719
def compute_survey_metadata(
    resolved: "ResolvedSurveyDesign",
    raw_weights: np.ndarray,
) -> SurveyMetadata:
    """
    Summarize a resolved survey design for storage on results objects.

    Parameters
    ----------
    resolved : ResolvedSurveyDesign
        Resolved survey design.
    raw_weights : np.ndarray
        Original (pre-normalization) weights.

    Returns
    -------
    SurveyMetadata
        Kish effective sample size, design effect, weight sum and range
        computed from ``raw_weights``, plus strata/PSU counts, survey
        degrees of freedom and replicate info taken from ``resolved``.
    """
    n_obs = len(raw_weights)
    total = float(np.sum(raw_weights))
    total_sq = float(np.sum(raw_weights**2))

    # Kish approximations, with neutral fallbacks for degenerate weights.
    if total_sq > 0:
        kish_n = total**2 / total_sq
    else:
        kish_n = float(n_obs)
    if total > 0:
        kish_deff = n_obs * total_sq / (total**2)
    else:
        kish_deff = 1.0

    is_replicate = resolved.uses_replicate_variance
    if is_replicate:
        # Replicate designs don't have meaningful PSU/strata counts
        strata_count = None
        psu_count = None
    else:
        strata_count = resolved.n_strata if resolved.strata is not None else None
        # Without an explicit PSU, every observation is its own PSU.
        psu_count = resolved.n_psu if resolved.psu is not None else len(resolved.weights)
    survey_df = resolved.df_survey

    return SurveyMetadata(
        weight_type=resolved.weight_type,
        effective_n=kish_n,
        design_effect=kish_deff,
        sum_weights=total,
        n_strata=strata_count,
        n_psu=psu_count,
        weight_range=(float(np.min(raw_weights)), float(np.max(raw_weights))),
        df_survey=survey_df,
        replicate_method=resolved.replicate_method if is_replicate else None,
        n_replicates=resolved.n_replicates if is_replicate else None,
    )
772
+
773
+
774
@dataclass
class DEFFDiagnostics:
    """Design-effect (DEFF) diagnostics, one entry per coefficient.

    For every coefficient, the variance under the survey design
    (clustering, stratification, weighting) is compared with the variance
    a simple random sample (SRS) would give. The ratio is the variance
    inflation attributable to the design.

    Attributes
    ----------
    deff : np.ndarray
        Ratio survey_var / srs_var for each coefficient, shape (k,).
    effective_n : np.ndarray
        Sample size divided by DEFF for each coefficient, shape (k,).
    srs_se : np.ndarray
        Standard errors under the SRS (HC1) baseline, shape (k,).
    survey_se : np.ndarray
        Standard errors under the survey design, shape (k,).
    coefficient_names : list of str or None
        Optional labels used when displaying the diagnostics.
    """

    # Variance ratio survey / SRS, one value per coefficient.
    deff: np.ndarray
    # Sample size deflated by the per-coefficient DEFF.
    effective_n: np.ndarray
    # Baseline (SRS, HC1) standard errors.
    srs_se: np.ndarray
    # Design-based standard errors.
    survey_se: np.ndarray
    # Display labels; None when not supplied.
    coefficient_names: Optional[List[str]] = None
801
+
802
+
803
def compute_deff_diagnostics(
    X: np.ndarray,
    residuals: np.ndarray,
    survey_vcov: np.ndarray,
    weights: np.ndarray,
    weight_type: str = "pweight",
    coefficient_names: Optional[List[str]] = None,
) -> DEFFDiagnostics:
    """Compute per-coefficient design effects against an SRS baseline.

    The baseline is the HC1 weighted sandwich returned by
    ``diff_diff.linalg.compute_robust_vcov`` without clustering, so it
    ignores strata, PSUs and FPC. DEFF is the ratio of the diagonal of
    ``survey_vcov`` to the diagonal of that baseline.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        Residuals from the WLS fit, shape (n,).
    survey_vcov : np.ndarray
        Survey variance-covariance matrix, shape (k, k).
    weights : np.ndarray
        Observation weights (normalized), shape (n,).
    weight_type : str, default "pweight"
        Weight type passed to the SRS baseline computation.
    coefficient_names : list of str, optional
        Names for display.

    Returns
    -------
    DEFFDiagnostics
        ``deff`` is NaN where the baseline variance is not positive, and
        ``effective_n`` is NaN where DEFF is not positive (or is NaN).
    """
    from diff_diff.linalg import compute_robust_vcov

    # Zero-weight rows (e.g. from a subpopulation) are excluded from the
    # sample size used for the effective-n calculation.
    has_zero_weight = np.any(weights == 0)
    n_eff = int(np.count_nonzero(weights > 0)) if has_zero_weight else X.shape[0]

    baseline_vcov = compute_robust_vcov(
        X,
        residuals,
        cluster_ids=None,
        weights=weights,
        weight_type=weight_type,
    )

    var_survey = np.diag(survey_vcov)
    var_srs = np.diag(baseline_vcov)

    with np.errstate(divide="ignore", invalid="ignore"):
        deff = np.where(var_srs > 0, var_survey / var_srs, np.nan)
        eff_n = np.where(deff > 0, n_eff / deff, np.nan)

    return DEFFDiagnostics(
        deff=deff,
        effective_n=eff_n,
        srs_se=np.sqrt(np.maximum(var_srs, 0.0)),
        survey_se=np.sqrt(np.maximum(var_survey, 0.0)),
        coefficient_names=coefficient_names,
    )
869
+
870
+
871
+ def _validate_unit_constant_survey(data, unit_col, survey_design):
872
+ """Validate that survey design columns are constant within units.
873
+
874
+ Panel estimators (ContinuousDiD, EfficientDiD) collapse panel-level
875
+ survey info to one row per unit. This requires that survey columns
876
+ do not vary across time periods within a unit.
877
+
878
+ Parameters
879
+ ----------
880
+ data : pd.DataFrame
881
+ Panel data.
882
+ unit_col : str
883
+ Unit identifier column name.
884
+ survey_design : SurveyDesign
885
+ Survey design specification (uses attribute names, not resolved arrays).
886
+
887
+ Raises
888
+ ------
889
+ ValueError
890
+ If any survey column varies within units.
891
+ """
892
+ cols_to_check = [
893
+ survey_design.weights,
894
+ survey_design.strata,
895
+ survey_design.psu,
896
+ survey_design.fpc,
897
+ ]
898
+ # Also validate replicate weight columns for within-unit constancy
899
+ if survey_design.replicate_weights is not None:
900
+ cols_to_check.extend(survey_design.replicate_weights)
901
+ for col in cols_to_check:
902
+ if col is not None and col in data.columns:
903
+ n_unique = data.groupby(unit_col)[col].nunique()
904
+ varying_units = n_unique[n_unique > 1]
905
+ if len(varying_units) > 0:
906
+ raise ValueError(
907
+ f"Survey column '{col}' varies within units "
908
+ f"(found {len(varying_units)} units with multiple values). "
909
+ f"Panel estimators require survey design columns to be "
910
+ f"constant within units."
911
+ )
912
+
913
+
914
+ def _resolve_pweight_only(resolved_survey, estimator_name):
915
+ """Guard: reject non-pweight and strata/PSU/FPC for pweight-only estimators.
916
+
917
+ Parameters
918
+ ----------
919
+ resolved_survey : ResolvedSurveyDesign or None
920
+ Resolved survey design. If None, returns immediately.
921
+ estimator_name : str
922
+ Estimator name for error messages.
923
+
924
+ Raises
925
+ ------
926
+ ValueError
927
+ If weight_type is not 'pweight'.
928
+ NotImplementedError
929
+ If strata, PSU, or FPC are present.
930
+ """
931
+ if resolved_survey is None:
932
+ return
933
+ if resolved_survey.weight_type != "pweight":
934
+ raise ValueError(
935
+ f"{estimator_name} survey support requires weight_type='pweight'. "
936
+ f"Got '{resolved_survey.weight_type}'."
937
+ )
938
+ if (
939
+ resolved_survey.strata is not None
940
+ or resolved_survey.psu is not None
941
+ or resolved_survey.fpc is not None
942
+ ):
943
+ raise NotImplementedError(
944
+ f"{estimator_name} does not yet support strata/PSU/FPC in "
945
+ "SurveyDesign. Use SurveyDesign(weights=...) only. Full "
946
+ "design-based bootstrap is planned for the Bootstrap + "
947
+ "Survey Interaction phase."
948
+ )
949
+
950
+
951
def collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
    """Collapse observation-level ResolvedSurveyDesign to unit level.

    Panel estimators with influence-function-based variance need a
    unit-level survey design (one row per unit) rather than the
    observation-level design (one row per unit-time). Survey design
    columns must be constant within units (validated upstream via
    ``_validate_unit_constant_survey``).

    Each per-observation array is reduced with pandas ``GroupBy.first``
    (first non-null value per unit) and then reindexed to ``all_units``.
    Units in ``all_units`` that have no rows in ``df`` therefore get NaN.

    Parameters
    ----------
    resolved_survey : ResolvedSurveyDesign
        Observation-level resolved survey design.
    df : pd.DataFrame
        Panel data used for groupby operations; its row order must match
        the arrays in ``resolved_survey``.
    unit_col : str
        Unit identifier column name.
    all_units : array-like
        Ordered sequence of unique unit identifiers.

    Returns
    -------
    ResolvedSurveyDesign
        Unit-level design with arrays indexed by ``all_units``. Scalar
        attributes (counts, replicate settings, lonely_psu) are copied
        unchanged.
    """
    grouper = df[unit_col]

    def _first_per_unit(values):
        # First non-null value per unit, realigned to ``all_units``. 2-D
        # inputs (replicate weights) are collapsed column-wise in one pass.
        container = pd.DataFrame if np.ndim(values) > 1 else pd.Series
        return (
            container(values, index=df.index)
            .groupby(grouper)
            .first()
            .reindex(all_units)
            .values
        )

    weights_unit = _first_per_unit(resolved_survey.weights)

    strata_unit = None
    if resolved_survey.strata is not None:
        strata_unit = _first_per_unit(resolved_survey.strata)

    psu_unit = None
    if resolved_survey.psu is not None:
        psu_unit = _first_per_unit(resolved_survey.psu)

    fpc_unit = None
    if resolved_survey.fpc is not None:
        fpc_unit = _first_per_unit(resolved_survey.fpc)

    # Replicate weights: one groupby over all R columns instead of R groupbys.
    rep_weights_unit = None
    if resolved_survey.replicate_weights is not None:
        rep_weights_unit = _first_per_unit(resolved_survey.replicate_weights).astype(
            np.float64
        )

    return ResolvedSurveyDesign(
        weights=weights_unit.astype(np.float64),
        weight_type=resolved_survey.weight_type,
        strata=strata_unit,
        psu=psu_unit,
        fpc=fpc_unit,
        n_strata=resolved_survey.n_strata,
        n_psu=resolved_survey.n_psu,
        lonely_psu=resolved_survey.lonely_psu,
        replicate_weights=rep_weights_unit,
        replicate_method=resolved_survey.replicate_method,
        fay_rho=resolved_survey.fay_rho,
        n_replicates=resolved_survey.n_replicates,
        replicate_strata=resolved_survey.replicate_strata,
        combined_weights=resolved_survey.combined_weights,
        replicate_scale=resolved_survey.replicate_scale,
        replicate_rscales=resolved_survey.replicate_rscales,
        mse=resolved_survey.mse,
    )
1049
+
1050
+
1051
+ def _extract_unit_survey_weights(data, unit_col, survey_design, unit_order):
1052
+ """Extract unit-level survey weights aligned to a given unit ordering.
1053
+
1054
+ Parameters
1055
+ ----------
1056
+ data : pd.DataFrame
1057
+ Panel data with survey weight column.
1058
+ unit_col : str
1059
+ Unit identifier column name.
1060
+ survey_design : SurveyDesign
1061
+ Survey design (uses ``weights`` column name).
1062
+ unit_order : array-like
1063
+ Ordered sequence of unit identifiers to align weights to.
1064
+
1065
+ Returns
1066
+ -------
1067
+ np.ndarray
1068
+ Float64 array of unit-level weights, one per unit in ``unit_order``.
1069
+ """
1070
+ unit_w = data.groupby(unit_col)[survey_design.weights].first()
1071
+ return np.array([unit_w[u] for u in unit_order], dtype=np.float64)
1072
+
1073
+
1074
+ def _resolve_survey_for_fit(survey_design, data, inference_mode="analytical"):
1075
+ """
1076
+ Shared helper: validate and resolve a SurveyDesign for an estimator fit() call.
1077
+
1078
+ Returns (resolved, weights, weight_type, metadata) or all-None tuple if
1079
+ survey_design is None.
1080
+ """
1081
+ if survey_design is None:
1082
+ return None, None, "pweight", None
1083
+
1084
+ if not isinstance(survey_design, SurveyDesign):
1085
+ raise TypeError("survey_design must be a SurveyDesign instance")
1086
+
1087
+ if inference_mode == "wild_bootstrap":
1088
+ raise NotImplementedError(
1089
+ "Wild bootstrap with survey weights is not yet supported. "
1090
+ "Use analytical survey inference (the default) instead."
1091
+ )
1092
+
1093
+ resolved = survey_design.resolve(data)
1094
+ raw_w = (
1095
+ data[survey_design.weights].values.astype(np.float64)
1096
+ if survey_design.weights
1097
+ else np.ones(len(data), dtype=np.float64)
1098
+ )
1099
+ metadata = compute_survey_metadata(resolved, raw_w)
1100
+ return resolved, resolved.weights, resolved.weight_type, metadata
1101
+
1102
+
1103
+ def _resolve_effective_cluster(resolved_survey, cluster_ids, cluster_name=None):
1104
+ """
1105
+ Shared helper: determine effective cluster IDs for variance estimation.
1106
+
1107
+ When survey PSU is present, it overrides the user-specified cluster.
1108
+ Warns if both are specified with different groupings.
1109
+ """
1110
+ if resolved_survey is None or resolved_survey.psu is None:
1111
+ return cluster_ids
1112
+
1113
+ if cluster_ids is not None and cluster_name is not None:
1114
+ # Compare partition equivalence (not label equality)
1115
+ psu_codes, _ = pd.factorize(resolved_survey.psu)
1116
+ cluster_codes, _ = pd.factorize(cluster_ids)
1117
+ if not np.array_equal(psu_codes, cluster_codes):
1118
+ warnings.warn(
1119
+ f"Both survey_design.psu and cluster='{cluster_name}' specified "
1120
+ "with different groupings. PSU will be used for variance "
1121
+ "estimation (survey design-based inference).",
1122
+ UserWarning,
1123
+ stacklevel=3,
1124
+ )
1125
+ return resolved_survey.psu
1126
+
1127
+
1128
+ def _inject_cluster_as_psu(resolved, cluster_ids):
1129
+ """
1130
+ When survey design has no PSU but cluster_ids are provided,
1131
+ inject cluster_ids as the effective PSU for TSL variance estimation.
1132
+
1133
+ Returns a new ResolvedSurveyDesign (no mutation) or the original unchanged.
1134
+ """
1135
+ if resolved is None or cluster_ids is None:
1136
+ return resolved
1137
+ if resolved.psu is not None:
1138
+ return resolved # PSU already present; _resolve_effective_cluster handles this
1139
+
1140
+ # Validate no missing cluster IDs before factorization
1141
+ if pd.isna(cluster_ids).any():
1142
+ raise ValueError(
1143
+ "Cluster IDs contain missing values. "
1144
+ "All observations must have valid cluster identifiers "
1145
+ "when used as effective PSUs for survey variance estimation."
1146
+ )
1147
+
1148
+ # When strata are present, make cluster IDs unique within strata
1149
+ # (same nesting logic as SurveyDesign.resolve() with nest=True)
1150
+ if resolved.strata is not None:
1151
+ combined = np.array([f"{s}_{c}" for s, c in zip(resolved.strata, cluster_ids)])
1152
+ codes, uniques = pd.factorize(combined)
1153
+ else:
1154
+ codes, uniques = pd.factorize(cluster_ids)
1155
+ n_clusters = len(uniques)
1156
+
1157
+ return replace(resolved, psu=codes, n_psu=n_clusters)
1158
+
1159
+
1160
+ def _compute_stratified_psu_meat(
1161
+ scores: np.ndarray,
1162
+ resolved: "ResolvedSurveyDesign",
1163
+ ) -> tuple:
1164
+ """Compute the stratified PSU-level meat matrix for TSL variance.
1165
+
1166
+ This is the core computation shared by :func:`compute_survey_vcov`
1167
+ (which wraps it in a sandwich with the bread matrix) and
1168
+ :func:`compute_survey_if_variance` (which uses it directly for
1169
+ influence-function-based estimators).
1170
+
1171
+ Parameters
1172
+ ----------
1173
+ scores : np.ndarray
1174
+ Score matrix of shape (n, k). For OLS-based estimators these are
1175
+ the weighted score contributions X_i * w_i * u_i. For IF-based
1176
+ estimators these are the per-unit influence function values
1177
+ (reshaped to (n, 1) for scalar estimators).
1178
+ resolved : ResolvedSurveyDesign
1179
+ Resolved survey design with weights, strata, PSU arrays.
1180
+
1181
+ Returns
1182
+ -------
1183
+ meat : np.ndarray
1184
+ Meat matrix of shape (k, k).
1185
+ variance_computed : bool
1186
+ Whether any actual variance computation happened.
1187
+ legitimate_zero_count : int
1188
+ Number of strata/sources that legitimately contribute zero variance.
1189
+ """
1190
+ n = scores.shape[0]
1191
+ k = scores.shape[1] if scores.ndim > 1 else 1
1192
+ if scores.ndim == 1:
1193
+ scores = scores[:, np.newaxis]
1194
+
1195
+ strata = resolved.strata
1196
+ psu = resolved.psu
1197
+
1198
+ legitimate_zero_count = 0
1199
+ _variance_computed = False
1200
+
1201
+ if strata is None and psu is None:
1202
+ # No survey structure beyond weights — implicit per-observation PSUs
1203
+ psu_mean = scores.mean(axis=0, keepdims=True)
1204
+ centered = scores - psu_mean
1205
+ f_h = 0.0
1206
+ if resolved.fpc is not None:
1207
+ N_h = resolved.fpc[0]
1208
+ if N_h < n:
1209
+ raise ValueError(
1210
+ f"FPC ({N_h}) is less than the number of observations "
1211
+ f"({n}). FPC must be >= n_obs for implicit per-observation PSUs."
1212
+ )
1213
+ f_h = n / N_h
1214
+ if f_h >= 1.0:
1215
+ legitimate_zero_count += 1
1216
+ adjustment = (1.0 - f_h) * (n / (n - 1))
1217
+ meat = adjustment * (centered.T @ centered)
1218
+ _variance_computed = True
1219
+ elif strata is None and psu is not None:
1220
+ # No strata, but PSU present — single-stratum cluster-robust
1221
+ psu_scores = pd.DataFrame(scores).groupby(psu).sum().values
1222
+ n_psu = psu_scores.shape[0]
1223
+
1224
+ if n_psu < 2:
1225
+ meat = np.zeros((k, k))
1226
+ else:
1227
+ psu_mean = psu_scores.mean(axis=0, keepdims=True)
1228
+ centered = psu_scores - psu_mean
1229
+ f_h = 0.0
1230
+ if resolved.fpc is not None:
1231
+ N_h = resolved.fpc[0]
1232
+ if N_h < n_psu:
1233
+ raise ValueError(
1234
+ f"FPC ({N_h}) is less than the number of effective PSUs "
1235
+ f"({n_psu}). FPC must be >= n_PSU."
1236
+ )
1237
+ f_h = n_psu / N_h
1238
+ if f_h >= 1.0:
1239
+ legitimate_zero_count += 1
1240
+ adjustment = (1.0 - f_h) * (n_psu / (n_psu - 1))
1241
+ meat = adjustment * (centered.T @ centered)
1242
+ _variance_computed = True
1243
+ else:
1244
+ # Stratified with or without PSU
1245
+ unique_strata = np.unique(strata)
1246
+ meat = np.zeros((k, k))
1247
+
1248
+ _global_psu_mean = None
1249
+ if resolved.lonely_psu == "adjust":
1250
+ if psu is not None:
1251
+ _global_psu_mean = (
1252
+ pd.DataFrame(scores).groupby(psu).sum().values.mean(axis=0, keepdims=True)
1253
+ )
1254
+ else:
1255
+ _global_psu_mean = scores.mean(axis=0, keepdims=True)
1256
+
1257
+ for h in unique_strata:
1258
+ mask_h = strata == h
1259
+
1260
+ if psu is not None:
1261
+ psu_h = psu[mask_h]
1262
+ scores_h = scores[mask_h]
1263
+ psu_scores_h = pd.DataFrame(scores_h).groupby(psu_h).sum().values
1264
+ n_psu_h = psu_scores_h.shape[0]
1265
+ else:
1266
+ psu_scores_h = scores[mask_h]
1267
+ n_psu_h = psu_scores_h.shape[0]
1268
+
1269
+ # Handle singleton strata
1270
+ if n_psu_h < 2:
1271
+ if resolved.lonely_psu == "remove":
1272
+ continue
1273
+ elif resolved.lonely_psu == "certainty":
1274
+ legitimate_zero_count += 1
1275
+ continue
1276
+ elif resolved.lonely_psu == "adjust":
1277
+ centered = psu_scores_h - _global_psu_mean
1278
+ V_h = centered.T @ centered
1279
+ meat += V_h
1280
+ _variance_computed = True
1281
+ continue
1282
+
1283
+ # FPC
1284
+ f_h = 0.0
1285
+ if resolved.fpc is not None:
1286
+ N_h = resolved.fpc[mask_h][0]
1287
+ if N_h < n_psu_h:
1288
+ raise ValueError(
1289
+ f"FPC ({N_h}) is less than the number of effective PSUs "
1290
+ f"({n_psu_h}) in stratum. FPC must be >= n_PSU."
1291
+ )
1292
+ f_h = n_psu_h / N_h
1293
+ if f_h >= 1.0:
1294
+ legitimate_zero_count += 1
1295
+
1296
+ psu_mean_h = psu_scores_h.mean(axis=0, keepdims=True)
1297
+ centered = psu_scores_h - psu_mean_h
1298
+
1299
+ adjustment = (1.0 - f_h) * (n_psu_h / (n_psu_h - 1))
1300
+ V_h = adjustment * (centered.T @ centered)
1301
+ meat += V_h
1302
+ _variance_computed = True
1303
+
1304
+ return meat, _variance_computed, legitimate_zero_count
1305
+
1306
+
1307
+ def _compute_stratified_meat_from_psu_scores(
1308
+ psu_scores: np.ndarray,
1309
+ psu_strata: np.ndarray,
1310
+ fpc_per_psu: "Optional[np.ndarray]" = None,
1311
+ lonely_psu: str = "remove",
1312
+ ) -> np.ndarray:
1313
+ """Compute stratified meat matrix from pre-aggregated PSU-level scores.
1314
+
1315
+ Like :func:`_compute_stratified_psu_meat`, but accepts scores that are
1316
+ already aggregated to the PSU level (one row per PSU). Used by
1317
+ TwoStageDiD's GMM sandwich where the score matrix ``S`` is built at
1318
+ the cluster/PSU level.
1319
+
1320
+ Parameters
1321
+ ----------
1322
+ psu_scores : np.ndarray
1323
+ Score matrix of shape (G, k) — one row per PSU.
1324
+ psu_strata : np.ndarray
1325
+ Stratum assignment per PSU, shape (G,).
1326
+ fpc_per_psu : np.ndarray, optional
1327
+ FPC population size per PSU, shape (G,). All PSUs in the same
1328
+ stratum should share the same FPC value (first occurrence used).
1329
+ lonely_psu : str
1330
+ How to handle singleton strata: "remove", "certainty", or "adjust".
1331
+
1332
+ Returns
1333
+ -------
1334
+ meat : np.ndarray
1335
+ Meat matrix of shape (k, k).
1336
+ variance_computed : bool
1337
+ Whether any actual variance computation happened.
1338
+ legitimate_zero_count : int
1339
+ Number of strata that legitimately contribute zero variance.
1340
+ """
1341
+ if psu_scores.ndim == 1:
1342
+ psu_scores = psu_scores[:, np.newaxis]
1343
+ k = psu_scores.shape[1]
1344
+ meat = np.zeros((k, k))
1345
+
1346
+ unique_strata = np.unique(psu_strata)
1347
+ _variance_computed = False
1348
+ legitimate_zero_count = 0
1349
+
1350
+ # Pre-compute global mean for lonely_psu="adjust"
1351
+ _global_psu_mean = None
1352
+ if lonely_psu == "adjust":
1353
+ _global_psu_mean = psu_scores.mean(axis=0, keepdims=True)
1354
+
1355
+ for h in unique_strata:
1356
+ mask_h = psu_strata == h
1357
+ scores_h = psu_scores[mask_h]
1358
+ n_psu_h = scores_h.shape[0]
1359
+
1360
+ # Handle singleton strata
1361
+ if n_psu_h < 2:
1362
+ if lonely_psu == "remove":
1363
+ continue
1364
+ elif lonely_psu == "certainty":
1365
+ legitimate_zero_count += 1
1366
+ continue
1367
+ elif lonely_psu == "adjust":
1368
+ centered = scores_h - _global_psu_mean
1369
+ with np.errstate(invalid="ignore", over="ignore"):
1370
+ meat += centered.T @ centered
1371
+ _variance_computed = True
1372
+ continue
1373
+
1374
+ # FPC
1375
+ f_h = 0.0
1376
+ if fpc_per_psu is not None:
1377
+ N_h = fpc_per_psu[mask_h][0]
1378
+ if N_h < n_psu_h:
1379
+ raise ValueError(
1380
+ f"FPC ({N_h}) is less than the number of PSUs "
1381
+ f"({n_psu_h}) in stratum. FPC must be >= n_PSU."
1382
+ )
1383
+ f_h = n_psu_h / N_h
1384
+ if f_h >= 1.0:
1385
+ legitimate_zero_count += 1
1386
+
1387
+ psu_mean_h = scores_h.mean(axis=0, keepdims=True)
1388
+ centered = scores_h - psu_mean_h
1389
+
1390
+ adjustment = (1.0 - f_h) * (n_psu_h / (n_psu_h - 1))
1391
+ with np.errstate(invalid="ignore", over="ignore"):
1392
+ meat += adjustment * (centered.T @ centered)
1393
+ _variance_computed = True
1394
+
1395
+ return meat, _variance_computed, legitimate_zero_count
1396
+
1397
+
1398
def compute_survey_vcov(
    X: np.ndarray,
    residuals: np.ndarray,
    resolved: "ResolvedSurveyDesign",
) -> np.ndarray:
    """
    Taylor Series Linearization (TSL) variance-covariance matrix.

    Stratified cluster sandwich estimator with optional finite population
    correction (FPC)::

        V_TSL = (X'WX)^{-1} [sum_h V_h] (X'WX)^{-1}

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        Residuals from WLS fit (y - X @ beta, on ORIGINAL scale).
    resolved : ResolvedSurveyDesign
        Resolved survey design with weights, strata, PSU arrays.

    Returns
    -------
    vcov : np.ndarray
        Variance-covariance matrix of shape (k, k). If the meat is exactly
        zero, this is all zeros when some variance was computed or a
        stratum legitimately contributes zero, and all NaN otherwise
        (variance not identified).

    Raises
    ------
    ValueError
        If X'WX is singular (rank-deficient design), or propagated from the
        meat computation (e.g. an FPC smaller than the PSU count).
    """
    _, k = X.shape
    w = resolved.weights

    # Bread matrix X'WX (inverted implicitly via two linear solves below).
    xtwx = X.T @ (X * w[:, np.newaxis])

    # Per-observation score contributions.
    if resolved.weight_type != "aweight":
        scores = X * (w * residuals)[:, np.newaxis]
    else:
        # aweight: unweighted X_i * u_i, with zero-weight rows contributing nothing.
        scores = X * residuals[:, np.newaxis]
        zero_rows = w == 0
        if np.any(zero_rows):
            scores[zero_rows] = 0.0

    meat, computed, n_legit_zero = _compute_stratified_psu_meat(scores, resolved)

    # A zero meat is either a legitimate zero or an unidentified variance.
    if np.all(meat == 0):
        if computed or n_legit_zero > 0:
            return np.zeros((k, k))
        return np.full((k, k), np.nan)

    try:
        half = np.linalg.solve(xtwx, meat)
        vcov = np.linalg.solve(xtwx, half.T).T
    except np.linalg.LinAlgError as err:
        if "Singular" not in str(err):
            raise
        raise ValueError(
            "Design matrix is rank-deficient (singular X'WX matrix). "
            "This indicates perfect multicollinearity."
        ) from err

    return vcov
1461
+
1462
+
1463
def compute_survey_if_variance(
    psi: np.ndarray,
    resolved: "ResolvedSurveyDesign",
) -> float:
    """Design-based variance of a scalar estimator from influence-function values.

    For influence-function-based estimators (e.g., CallawaySantAnna), the
    per-unit values ``psi_i`` are each unit's contribution to the
    estimating equation; under simple random sampling the variance is
    ``sum(psi_i^2)``. This is the design-based counterpart, accounting for
    PSU clustering, stratification and finite population correction:

        V_design = sum_h (1-f_h) * (n_h/(n_h-1)) * sum_j (psi_hj - psi_h_bar)^2

    with psi_hj = sum_{i in PSU j, stratum h} psi_i.

    Parameters
    ----------
    psi : np.ndarray
        Per-unit influence function values, shape (n,). Flattened to 1-D.
    resolved : ResolvedSurveyDesign
        Resolved survey design.

    Returns
    -------
    float
        Design-based variance. A zero is returned as ``0.0`` only when some
        variance was computed or a stratum legitimately contributes zero;
        otherwise (e.g. all strata removed by lonely_psu='remove') the
        result is ``np.nan``.
    """
    flat_psi = np.asarray(psi, dtype=np.float64).ravel()
    meat, computed, n_legit_zero = _compute_stratified_psu_meat(
        flat_psi.reshape(-1, 1), resolved
    )

    # A single score column gives a 1x1 meat matrix.
    variance = float(meat[0, 0])
    if variance != 0.0:
        return variance
    return 0.0 if (computed or n_legit_zero > 0) else np.nan
1508
+
1509
+
1510
+ def _replicate_variance_factor(
1511
+ method: str,
1512
+ n_replicates: int,
1513
+ fay_rho: float,
1514
+ ) -> float:
1515
+ """Compute the scalar variance factor for replicate methods."""
1516
+ if method == "BRR":
1517
+ return 1.0 / n_replicates
1518
+ elif method == "Fay":
1519
+ return 1.0 / (n_replicates * (1.0 - fay_rho) ** 2)
1520
+ elif method == "SDR":
1521
+ return 4.0 / n_replicates
1522
+ elif method == "JK1":
1523
+ return (n_replicates - 1.0) / n_replicates
1524
+ # JKn handled separately (per-stratum factors)
1525
+ raise ValueError(f"Unknown replicate method: {method}")
1526
+
1527
+
1528
def compute_replicate_vcov(
    X: np.ndarray,
    y: np.ndarray,
    full_sample_coef: np.ndarray,
    resolved: "ResolvedSurveyDesign",
    weight_type: str = "pweight",
) -> Tuple[np.ndarray, int]:
    """Compute replicate-weight variance-covariance matrix.

    Re-runs WLS for each replicate weight column and computes variance
    from the distribution of replicate coefficient vectors.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Response vector of shape (n,).
    full_sample_coef : np.ndarray
        Coefficients from the full-sample fit, shape (k,).
    resolved : ResolvedSurveyDesign
        Must have ``uses_replicate_variance == True``.
    weight_type : str, default "pweight"
        Weight type for per-replicate WLS.

    Returns
    -------
    vcov : np.ndarray
        Variance-covariance matrix of shape (k, k); all NaN when fewer than
        two replicate solves produced finite coefficients.
    n_valid : int
        Number of replicates with finite coefficients.

    Raises
    ------
    ValueError
        If the replicate method is unknown, or JKn is used without
        ``replicate_strata``.
    """
    from diff_diff.linalg import solve_ols

    rep_weights = resolved.replicate_weights
    method = resolved.replicate_method
    R = resolved.n_replicates
    k = X.shape[1]

    # Collect replicate coefficient vectors (NaN rows = skipped/failed replicates)
    coef_reps = np.full((R, k), np.nan)
    for r in range(R):
        w_r = rep_weights[:, r]
        # For non-combined weights, multiply by full-sample weights
        if not resolved.combined_weights:
            w_r = w_r * resolved.weights
        # Skip replicates where all weights are zero
        if np.sum(w_r) == 0:
            continue
        try:
            coef_r, _, _ = solve_ols(
                X,
                y,
                weights=w_r,
                weight_type=weight_type,
                rank_deficient_action="silent",
                return_vcov=False,
                check_finite=False,
            )
            coef_reps[r] = coef_r
        except (np.linalg.LinAlgError, ValueError):
            pass  # NaN row for singular/degenerate replicate solve

    # Remove replicates with NaN coefficients
    valid = np.all(np.isfinite(coef_reps), axis=1)
    n_invalid = int(R - np.sum(valid))
    if n_invalid > 0:
        warnings.warn(
            f"{n_invalid} of {R} replicate solves failed (singular or degenerate). "
            f"Variance computed from {int(np.sum(valid))} valid replicates.",
            UserWarning,
            stacklevel=2,
        )
    n_valid = int(np.sum(valid))
    if n_valid < 2:
        if n_valid == 0:
            warnings.warn(
                "All replicate solves failed. Returning NaN variance.",
                UserWarning,
                stacklevel=2,
            )
        else:
            warnings.warn(
                f"Only {n_valid} valid replicate(s) — variance is not estimable "
                f"with fewer than 2. Returning NaN.",
                UserWarning,
                stacklevel=2,
            )
        return np.full((k, k), np.nan), n_valid
    coef_valid = coef_reps[valid]
    c = full_sample_coef

    # Compute variance by method
    # Support mse=False: center on replicate mean instead of full-sample estimate
    # When rscales present and mse=False, center only over rscales > 0
    # (R's svrVar convention — zero-scaled replicates should not shift center)
    if resolved.mse:
        center = c
    else:
        if resolved.replicate_rscales is not None:
            pos_scale = resolved.replicate_rscales[valid] > 0
            if np.any(pos_scale):
                center = np.mean(coef_valid[pos_scale], axis=0)
            else:
                center = np.mean(coef_valid, axis=0)
        else:
            center = np.mean(coef_valid, axis=0)
    diffs = coef_valid - center[np.newaxis, :]

    outer_sum = diffs.T @ diffs  # (k, k)

    # BRR/Fay/SDR: use fixed scaling, ignore user-supplied scale/rscales (R convention)
    if method in ("BRR", "Fay", "SDR"):
        if resolved.replicate_scale is not None or resolved.replicate_rscales is not None:
            warnings.warn(
                f"Custom replicate_scale/replicate_rscales ignored for {method} "
                f"(BRR/Fay/SDR use fixed scaling).",
                UserWarning,
                stacklevel=2,
            )
        factor = _replicate_variance_factor(method, R, resolved.fay_rho)
        return factor * outer_sum, n_valid

    # JK1/JKn: apply scale * rscales multiplicatively (R's svrVar contract)
    scale = resolved.replicate_scale if resolved.replicate_scale is not None else 1.0

    if resolved.replicate_rscales is not None:
        valid_rscales = resolved.replicate_rscales[valid]
        # sum_i rscale_i * d_i d_i', as a single weighted cross-product
        V = (diffs * valid_rscales[:, np.newaxis]).T @ diffs
        return scale * V, n_valid

    if method == "JK1":
        factor = _replicate_variance_factor(method, R, resolved.fay_rho)
        return scale * factor * outer_sum, n_valid
    elif method == "JKn":
        # JKn: V = sum_h ((n_h-1)/n_h) * sum_{r in h} (c_r - c)(c_r - c)^T
        # n_h counts all replicates in stratum h, including failed ones.
        rep_strata = resolved.replicate_strata
        if rep_strata is None:
            raise ValueError("JKn requires replicate_strata")
        valid_strata = rep_strata[valid]
        V = np.zeros((k, k))
        for h in np.unique(rep_strata):
            n_h_original = int(np.sum(rep_strata == h))
            mask_h = valid_strata == h
            if not np.any(mask_h):
                continue
            diffs_h = diffs[mask_h]
            V += ((n_h_original - 1.0) / n_h_original) * (diffs_h.T @ diffs_h)
        return scale * V, n_valid
    else:
        raise ValueError(f"Unknown replicate method: {method}")
1679
+
1680
+
1681
def compute_replicate_if_variance(
    psi: np.ndarray,
    resolved: "ResolvedSurveyDesign",
) -> Tuple[float, int]:
    """Compute replicate-based variance for influence-function estimators.

    Instead of re-running the full estimator, reweights the influence
    function under each replicate weight set. Each replicate estimate is
    ``sum(ratio_r * psi)``. Here ``ratio_r`` is ``w_r / w_full`` for combined
    replicate weights and ``w_r`` itself otherwise. Replicates with no
    positive weight yield no estimate and are dropped.

    Parameters
    ----------
    psi : np.ndarray
        Per-unit influence function values, shape (n,).
    resolved : ResolvedSurveyDesign
        Must have ``uses_replicate_variance == True``.

    Returns
    -------
    tuple of (float, int)
        ``(variance, n_valid)``.

        - *variance* is the replicate-based variance estimate. It is
          ``nan`` when ``n_valid < 2``.
        - *n_valid* is the number of replicates that produced a finite
          estimate.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``combined_weights`` is set and a replicate column has positive
          weight where the full-sample weight is not positive.
        - ``JKn`` is used without ``replicate_strata``.
        - The method is not one of BRR/Fay/SDR/JK1/JKn. This is only
          checked when ``replicate_rscales`` is None.

    Warns
    -----
    UserWarning
        If ``replicate_scale``/``replicate_rscales`` are supplied for
        BRR/Fay/SDR, where they are ignored.
    """
    psi = np.asarray(psi, dtype=np.float64).ravel()
    rep_weights = resolved.replicate_weights
    method = resolved.replicate_method
    R = resolved.n_replicates

    # Match the contract of compute_survey_if_variance(): psi is accepted
    # as-is (the combined IF/WIF object), with NO extra weight multiplication.
    # Replicate contrasts are formed by rescaling each unit's contribution
    # by the ratio w_r/w_full (Rao-Wu reweighting).
    full_weights = resolved.weights
    theta_full = float(np.sum(psi))
    # Loop-invariant masks, hoisted out of the per-replicate loops below.
    full_positive = full_weights > 0
    full_nonpositive = full_weights <= 0

    # Validate: combined_weights=True requires w_full > 0 wherever w_r > 0,
    # otherwise the ratio w_r / w_full is undefined.
    if resolved.combined_weights:
        for r in range(R):
            if np.any((rep_weights[:, r] > 0) & full_nonpositive):
                raise ValueError(
                    f"Replicate column {r} has positive weight where full-sample "
                    f"weight is zero. With combined_weights=True, every "
                    f"replicate-positive observation must have a positive "
                    f"full-sample weight."
                )

    # Compute replicate estimates via weight-ratio rescaling. Replicates
    # with no positive weight stay NaN and are excluded below.
    theta_reps = np.full(R, np.nan)
    for r in range(R):
        w_r = rep_weights[:, r]
        if not np.any(w_r > 0):
            continue
        if resolved.combined_weights:
            # Combined: w_r already includes the full-sample weight
            ratio = np.divide(
                w_r,
                full_weights,
                out=np.zeros_like(w_r, dtype=np.float64),
                where=full_positive,
            )
        else:
            # Non-combined: w_r is the perturbation factor directly
            ratio = w_r
        theta_reps[r] = np.sum(ratio * psi)

    valid = np.isfinite(theta_reps)
    n_valid = int(np.sum(valid))
    if n_valid < 2:
        return np.nan, n_valid
    theta_valid = theta_reps[valid]

    # Support mse=False: center on the replicate mean. When rscales are
    # present, center only over replicates with rscales > 0 (R's svrVar
    # convention: zero-scaled replicates should not shift the center).
    if resolved.mse:
        center = theta_full
    else:
        center_pool = theta_valid
        if resolved.replicate_rscales is not None:
            pos_scale = resolved.replicate_rscales[valid] > 0
            if np.any(pos_scale):
                center_pool = theta_valid[pos_scale]
        center = float(np.mean(center_pool))
    diffs = theta_valid - center

    ss = float(np.sum(diffs**2))

    # BRR/Fay/SDR: use fixed scaling, ignore user-supplied scale/rscales
    # (R convention).
    if method in ("BRR", "Fay", "SDR"):
        if resolved.replicate_scale is not None or resolved.replicate_rscales is not None:
            warnings.warn(
                f"Custom replicate_scale/replicate_rscales ignored for {method} "
                f"(BRR/Fay/SDR use fixed scaling).",
                UserWarning,
                stacklevel=2,
            )
        factor = _replicate_variance_factor(method, R, resolved.fay_rho)
        return factor * ss, n_valid

    # JK1/JKn: apply scale * rscales multiplicatively (R's svrVar contract)
    scale = resolved.replicate_scale if resolved.replicate_scale is not None else 1.0

    if resolved.replicate_rscales is not None:
        valid_rscales = resolved.replicate_rscales[valid]
        return scale * float(np.sum(valid_rscales * diffs**2)), n_valid

    if method == "JK1":
        factor = _replicate_variance_factor(method, R, resolved.fay_rho)
        return scale * factor * ss, n_valid
    elif method == "JKn":
        # JKn: V = sum_h ((n_h - 1) / n_h) * sum_{r in h} (theta_r - center)^2,
        # where n_h counts all replicates in stratum h, valid or not.
        rep_strata = resolved.replicate_strata
        if rep_strata is None:
            raise ValueError("JKn requires replicate_strata")
        valid_strata = rep_strata[valid]
        result = 0.0
        for h in np.unique(rep_strata):
            n_h_original = int(np.sum(rep_strata == h))
            mask_h = valid_strata == h
            if not np.any(mask_h):
                continue
            result += ((n_h_original - 1.0) / n_h_original) * float(np.sum(diffs[mask_h] ** 2))
        return scale * result, n_valid
    else:
        raise ValueError(f"Unknown replicate method: {method}")
1804
+
1805
+
1806
def compute_replicate_refit_variance(
    refit_fn: Callable[[np.ndarray], np.ndarray],
    full_sample_estimate: np.ndarray,
    resolved: "ResolvedSurveyDesign",
) -> Tuple[np.ndarray, int]:
    """Compute replicate variance by re-running an arbitrary estimation function.

    For each replicate weight column, calls ``refit_fn(w_r)`` and collects
    the resulting estimate vector. Variance is computed from the distribution
    of replicate estimates using method-specific scaling.

    This generalises :func:`compute_replicate_vcov` (which hard-codes
    ``solve_ols`` as the refit) for estimators whose estimation procedure
    is more complex than a single OLS call (e.g. within-transformation,
    two-stage imputation, stacked regression).

    Parameters
    ----------
    refit_fn : callable
        ``(n,) weight array -> (k,) estimate array``. Must return the same
        length *k* on every call. Should return all-NaN when the estimation
        fails for that replicate.

        - What it receives: ``replicate_weights[:, r] * weights`` when
          ``combined_weights`` is false. Otherwise it receives a copy of
          the replicate column.
        - Zero-weight replicates: replicates whose weights sum to zero are
          skipped without calling it.
        - Failures: a replicate is marked as failed if ``refit_fn`` raises
          ``LinAlgError``, ``ValueError`` or ``RuntimeError``, or returns a
          result of the wrong length. Other exceptions propagate.
    full_sample_estimate : np.ndarray
        Estimate vector from the full-sample weights, shape ``(k,)``.
    resolved : ResolvedSurveyDesign
        Must have ``uses_replicate_variance == True``.

    Returns
    -------
    tuple of (np.ndarray, int)
        ``(vcov, n_valid)`` where *vcov* has shape ``(k, k)`` and *n_valid*
        is the number of replicates that produced finite estimates. *vcov*
        is all-NaN when ``n_valid < 2``.

    Raises
    ------
    ValueError
        If ``JKn`` is used without ``replicate_strata``, or if the method is
        not one of BRR/Fay/SDR/JK1/JKn. The method check only happens when
        ``replicate_rscales`` is None.

    Warns
    -----
    UserWarning
        Emitted in each of these cases:

        - any replicate is skipped or fails;
        - fewer than two valid replicates remain;
        - ``replicate_scale``/``replicate_rscales`` are supplied for
          BRR/Fay/SDR, where they are ignored.
    """
    full_sample_estimate = np.asarray(full_sample_estimate, dtype=np.float64).ravel()
    k = len(full_sample_estimate)
    rep_weights = resolved.replicate_weights
    method = resolved.replicate_method
    R = resolved.n_replicates

    # Collect replicate estimate vectors. A row stays NaN when its replicate
    # is skipped (zero total weight) or its refit fails.
    est_reps = np.full((R, k), np.nan)
    for r in range(R):
        # Copy so refit_fn cannot mutate the stored replicate weights.
        w_r = rep_weights[:, r].copy()
        if not resolved.combined_weights:
            w_r = w_r * resolved.weights
        if np.sum(w_r) == 0:
            continue
        try:
            est_r = np.asarray(refit_fn(w_r), dtype=np.float64).ravel()
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            continue  # NaN row for failed replicate
        if len(est_r) == k:
            est_reps[r] = est_r

    # Remove replicates with non-finite estimates
    valid = np.all(np.isfinite(est_reps), axis=1)
    n_valid = int(np.sum(valid))
    n_invalid = R - n_valid
    if n_invalid > 0:
        warnings.warn(
            f"{n_invalid} of {R} replicate refits failed. "
            f"Variance computed from {n_valid} valid replicates.",
            UserWarning,
            stacklevel=2,
        )
    if n_valid < 2:
        if n_valid == 0:
            warnings.warn(
                "All replicate refits failed. Returning NaN variance.",
                UserWarning,
                stacklevel=2,
            )
        else:
            warnings.warn(
                f"Only {n_valid} valid replicate(s) — variance is not estimable "
                f"with fewer than 2. Returning NaN.",
                UserWarning,
                stacklevel=2,
            )
        return np.full((k, k), np.nan), n_valid

    est_valid = est_reps[valid]

    # --- Centering (mse flag) ---
    # mse=True centers on the full-sample estimate. Otherwise center on the
    # replicate mean, restricted to rscales > 0 when rscales are present and
    # at least one is positive.
    if resolved.mse:
        center = full_sample_estimate
    else:
        center_pool = est_valid
        if resolved.replicate_rscales is not None:
            pos_scale = resolved.replicate_rscales[valid] > 0
            if np.any(pos_scale):
                center_pool = est_valid[pos_scale]
        center = np.mean(center_pool, axis=0)
    diffs = est_valid - center[np.newaxis, :]

    outer_sum = diffs.T @ diffs  # (k, k)

    # --- Method-specific scaling ---
    # BRR/Fay/SDR: fixed scaling, ignore user-supplied scale/rscales
    if method in ("BRR", "Fay", "SDR"):
        if resolved.replicate_scale is not None or resolved.replicate_rscales is not None:
            warnings.warn(
                f"Custom replicate_scale/replicate_rscales ignored for {method} "
                f"(BRR/Fay/SDR use fixed scaling).",
                UserWarning,
                stacklevel=2,
            )
        factor = _replicate_variance_factor(method, R, resolved.fay_rho)
        return factor * outer_sum, n_valid

    # JK1/JKn: apply scale * rscales multiplicatively
    scale = resolved.replicate_scale if resolved.replicate_scale is not None else 1.0

    if resolved.replicate_rscales is not None:
        valid_rscales = resolved.replicate_rscales[valid]
        # sum_r rscale_r * d_r d_r^T as a single weighted cross-product
        # instead of accumulating one outer product per replicate.
        weighted_diffs = diffs * valid_rscales[:, np.newaxis]
        return scale * (weighted_diffs.T @ diffs), n_valid

    if method == "JK1":
        factor = _replicate_variance_factor(method, R, resolved.fay_rho)
        return scale * factor * outer_sum, n_valid
    elif method == "JKn":
        # n_h counts all replicates in stratum h, including failed ones.
        rep_strata = resolved.replicate_strata
        if rep_strata is None:
            raise ValueError("JKn requires replicate_strata")
        valid_strata = rep_strata[valid]
        V = np.zeros((k, k))
        for h in np.unique(rep_strata):
            n_h_original = int(np.sum(rep_strata == h))
            mask_h = valid_strata == h
            if not np.any(mask_h):
                continue
            diffs_h = diffs[mask_h]
            V += ((n_h_original - 1.0) / n_h_original) * (diffs_h.T @ diffs_h)
        return scale * V, n_valid
    else:
        raise ValueError(f"Unknown replicate method: {method}")
1949
+
1950
+
1951
def aggregate_to_psu(
    values: np.ndarray,
    resolved: "ResolvedSurveyDesign",
) -> tuple:
    """Sum values within PSUs for PSU-level bootstrap perturbation.

    Parameters
    ----------
    values : np.ndarray
        Per-observation values, shape (n,) or (n, k).
    resolved : ResolvedSurveyDesign
        Resolved survey design. Only ``resolved.psu`` is read.

    Returns
    -------
    psu_sums : np.ndarray
        Aggregated values, shape (n_psu,) or (n_psu, k). The dtype matches
        what ``values.sum(axis=0)`` would produce.
    psu_ids : np.ndarray
        Unique PSU identifiers (sorted, as from ``np.unique``) in the same
        order as ``psu_sums``. When ``resolved.psu`` is None, each
        observation is its own PSU and the ids are ``0..n-1``.
    """
    if resolved.psu is None:
        # Each observation is its own PSU — return as-is
        return values.copy(), np.arange(len(values))

    # Single pass: map every observation to the position of its PSU in the
    # sorted unique array, then scatter-add rows into place. This replaces
    # one boolean mask per PSU (O(n_psu * n)) with O(n log n) work.
    unique_psu, psu_index = np.unique(resolved.psu, return_inverse=True)
    # Reproduce the dtype ndarray.sum() yields (e.g. bool/small int -> int)
    # without summing the data twice.
    sum_dtype = values[:0].sum(axis=0).dtype
    psu_sums = np.zeros((len(unique_psu),) + values.shape[1:], dtype=sum_dtype)
    # np.add.at is unbuffered, so repeated PSU indices accumulate correctly.
    np.add.at(psu_sums, psu_index.ravel(), values)
    return psu_sums, unique_psu