diff-diff 2.9.1__tar.gz → 3.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. {diff_diff-2.9.1 → diff_diff-3.0.1}/PKG-INFO +4 -3
  2. {diff_diff-2.9.1 → diff_diff-3.0.1}/README.md +1 -1
  3. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/__init__.py +3 -1
  4. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/linalg.py +64 -2
  5. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/prep.py +443 -7
  6. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/staggered.py +1 -20
  7. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/staggered_bootstrap.py +4 -4
  8. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/staggered_triple_diff.py +1 -2
  9. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/synthetic_did.py +4 -3
  10. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/trop.py +4 -37
  11. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/wooldridge.py +267 -72
  12. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/wooldridge_results.py +9 -1
  13. {diff_diff-2.9.1 → diff_diff-3.0.1}/pyproject.toml +3 -2
  14. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/Cargo.lock +20 -48
  15. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/Cargo.toml +6 -5
  16. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/src/bootstrap.rs +1 -1
  17. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/src/linalg.rs +5 -5
  18. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/src/trop.rs +3 -3
  19. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/src/weights.rs +5 -5
  20. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/_backend.py +0 -0
  21. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/bacon.py +0 -0
  22. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/bootstrap_utils.py +0 -0
  23. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/continuous_did.py +0 -0
  24. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/continuous_did_bspline.py +0 -0
  25. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/continuous_did_results.py +0 -0
  26. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/datasets.py +0 -0
  27. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/diagnostics.py +0 -0
  28. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/efficient_did.py +0 -0
  29. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/efficient_did_bootstrap.py +0 -0
  30. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/efficient_did_covariates.py +0 -0
  31. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/efficient_did_results.py +0 -0
  32. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/efficient_did_weights.py +0 -0
  33. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/estimators.py +0 -0
  34. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/honest_did.py +0 -0
  35. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/imputation.py +0 -0
  36. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/imputation_bootstrap.py +0 -0
  37. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/imputation_results.py +0 -0
  38. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/power.py +0 -0
  39. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/practitioner.py +0 -0
  40. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/prep_dgp.py +0 -0
  41. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/pretrends.py +0 -0
  42. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/results.py +0 -0
  43. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/stacked_did.py +0 -0
  44. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/stacked_did_results.py +0 -0
  45. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/staggered_aggregation.py +0 -0
  46. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/staggered_results.py +0 -0
  47. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/staggered_triple_diff_results.py +0 -0
  48. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/sun_abraham.py +0 -0
  49. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/survey.py +0 -0
  50. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/triple_diff.py +0 -0
  51. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/trop_global.py +0 -0
  52. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/trop_local.py +0 -0
  53. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/trop_results.py +0 -0
  54. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/twfe.py +0 -0
  55. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/two_stage.py +0 -0
  56. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/two_stage_bootstrap.py +0 -0
  57. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/two_stage_results.py +0 -0
  58. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/utils.py +0 -0
  59. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/__init__.py +0 -0
  60. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_common.py +0 -0
  61. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_continuous.py +0 -0
  62. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_diagnostic.py +0 -0
  63. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_event_study.py +0 -0
  64. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_power.py +0 -0
  65. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_staggered.py +0 -0
  66. {diff_diff-2.9.1 → diff_diff-3.0.1}/diff_diff/visualization/_synthetic.py +0 -0
  67. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/build.rs +0 -0
  68. {diff_diff-2.9.1 → diff_diff-3.0.1}/rust/src/lib.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.9.1
3
+ Version: 3.0.1
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.10
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
12
  Classifier: Programming Language :: Python :: 3.13
13
+ Classifier: Programming Language :: Python :: 3.14
13
14
  Classifier: Topic :: Scientific/Engineering :: Mathematics
14
15
  Classifier: Topic :: Scientific/Engineering :: Information Analysis
15
16
  Classifier: Topic :: Scientific/Engineering
@@ -40,7 +41,7 @@ Summary: Difference-in-Differences causal inference with sklearn-like API. Calla
40
41
  Keywords: causal-inference,difference-in-differences,econometrics,statistics,treatment-effects,event-study,staggered-adoption,parallel-trends,synthetic-control,panel-data,did,twfe,callaway-santanna,honest-did,sensitivity-analysis
41
42
  Author: diff-diff contributors
42
43
  License-Expression: MIT
43
- Requires-Python: >=3.9, <3.14
44
+ Requires-Python: >=3.9, <3.15
44
45
  Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
45
46
  Project-URL: Documentation, https://diff-diff.readthedocs.io
46
47
  Project-URL: Homepage, https://github.com/igerber/diff-diff
@@ -2819,7 +2820,7 @@ Returns DataFrame with columns: `unit`, `quality_score`, `outcome_trend_score`,
2819
2820
 
2820
2821
  ## Requirements
2821
2822
 
2822
- - Python 3.9 - 3.13
2823
+ - Python 3.9 - 3.14
2823
2824
  - numpy >= 1.20
2824
2825
  - pandas >= 1.3
2825
2826
  - scipy >= 1.7
@@ -2769,7 +2769,7 @@ Returns DataFrame with columns: `unit`, `quality_score`, `outcome_trend_score`,
2769
2769
 
2770
2770
  ## Requirements
2771
2771
 
2772
- - Python 3.9 - 3.13
2772
+ - Python 3.9 - 3.14
2773
2773
  - numpy >= 1.20
2774
2774
  - pandas >= 1.3
2775
2775
  - scipy >= 1.7
@@ -78,6 +78,7 @@ from diff_diff.pretrends import (
78
78
  compute_pretrends_power,
79
79
  )
80
80
  from diff_diff.prep import (
81
+ aggregate_survey,
81
82
  aggregate_to_cohorts,
82
83
  balance_panel,
83
84
  create_event_time,
@@ -214,7 +215,7 @@ Bacon = BaconDecomposition
214
215
  EDiD = EfficientDiD
215
216
  ETWFE = WooldridgeDiD
216
217
 
217
- __version__ = "2.9.1"
218
+ __version__ = "3.0.1"
218
219
  __all__ = [
219
220
  # Estimators
220
221
  "DifferenceInDifferences",
@@ -328,6 +329,7 @@ __all__ = [
328
329
  "generate_survey_did_data",
329
330
  "generate_continuous_did_data",
330
331
  "create_event_time",
332
+ "aggregate_survey",
331
333
  "aggregate_to_cohorts",
332
334
  "rank_control_units",
333
335
  # Honest DiD sensitivity analysis
@@ -2372,6 +2372,7 @@ def solve_poisson(
2372
2372
  tol: float = 1e-8,
2373
2373
  init_beta: Optional[np.ndarray] = None,
2374
2374
  rank_deficient_action: str = "warn",
2375
+ weights: Optional[np.ndarray] = None,
2375
2376
  ) -> Tuple[np.ndarray, np.ndarray]:
2376
2377
  """Poisson IRLS (Newton-Raphson with log link).
2377
2378
 
@@ -2389,6 +2390,9 @@ def solve_poisson(
2389
2390
  log(mean(y)) to improve convergence for large-scale outcomes.
2390
2391
  rank_deficient_action : {"warn", "error", "silent"}
2391
2392
  How to handle rank-deficient design matrices. Mirrors solve_ols/solve_logit.
2393
+ weights : (n,) optional observation weights (e.g. survey sampling weights).
2394
+ When provided, the weighted pseudo-log-likelihood is maximised:
2395
+ score = X'(w*(y - mu)), Hessian = X'diag(w*mu)X.
2392
2396
 
2393
2397
  Returns
2394
2398
  -------
@@ -2397,6 +2401,20 @@ def solve_poisson(
2397
2401
  """
2398
2402
  n, k_orig = X.shape
2399
2403
 
2404
+ # Validate weights (mirrors solve_logit validation)
2405
+ if weights is not None:
2406
+ weights = np.asarray(weights, dtype=np.float64)
2407
+ if weights.shape != (n,):
2408
+ raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
2409
+ if np.any(np.isnan(weights)):
2410
+ raise ValueError("weights contain NaN values")
2411
+ if np.any(~np.isfinite(weights)):
2412
+ raise ValueError("weights contain Inf values")
2413
+ if np.any(weights < 0):
2414
+ raise ValueError("weights must be non-negative")
2415
+ if np.sum(weights) <= 0:
2416
+ raise ValueError("weights sum to zero — no observations have positive weight")
2417
+
2400
2418
  # Validate rank_deficient_action (same as solve_logit/solve_ols)
2401
2419
  valid_actions = ("warn", "error", "silent")
2402
2420
  if rank_deficient_action not in valid_actions:
@@ -2425,6 +2443,46 @@ def solve_poisson(
2425
2443
  X = X[:, kept_cols]
2426
2444
 
2427
2445
  n, k = X.shape
2446
+
2447
+ # Validate effective weighted sample when weights have zeros
2448
+ # (mirrors solve_logit's positive-weight safeguards)
2449
+ if weights is not None and np.any(weights == 0):
2450
+ pos_mask = weights > 0
2451
+ n_pos = int(np.sum(pos_mask))
2452
+ X_eff = X[pos_mask]
2453
+ eff_rank_info = _detect_rank_deficiency(X_eff)
2454
+ if len(eff_rank_info[1]) > 0:
2455
+ n_dropped_eff = len(eff_rank_info[1])
2456
+ if rank_deficient_action == "error":
2457
+ raise ValueError(
2458
+ f"Effective (positive-weight) sample is rank-deficient: "
2459
+ f"{n_dropped_eff} linearly dependent column(s). "
2460
+ f"Cannot identify Poisson model on this subpopulation."
2461
+ )
2462
+ elif rank_deficient_action == "warn":
2463
+ warnings.warn(
2464
+ f"Effective (positive-weight) sample is rank-deficient: "
2465
+ f"dropping {n_dropped_eff} column(s). Poisson estimates "
2466
+ f"may be unreliable on this subpopulation.",
2467
+ UserWarning,
2468
+ stacklevel=2,
2469
+ )
2470
+ eff_dropped = set(int(d) for d in eff_rank_info[1])
2471
+ eff_kept = np.array([i for i in range(k) if i not in eff_dropped])
2472
+ X = X[:, eff_kept]
2473
+ if len(dropped_cols) > 0:
2474
+ kept_cols = kept_cols[eff_kept]
2475
+ else:
2476
+ kept_cols = eff_kept
2477
+ dropped_cols = list(eff_dropped)
2478
+ n, k = X.shape
2479
+ if n_pos <= k:
2480
+ raise ValueError(
2481
+ f"Only {n_pos} positive-weight observation(s) for "
2482
+ f"{k} parameters (after rank reduction). "
2483
+ f"Cannot identify Poisson model."
2484
+ )
2485
+
2428
2486
  if init_beta is not None:
2429
2487
  beta = init_beta[kept_cols].copy() if len(dropped_cols) > 0 else init_beta.copy()
2430
2488
  else:
@@ -2438,8 +2496,12 @@ def solve_poisson(
2438
2496
  for _ in range(max_iter):
2439
2497
  eta = np.clip(X @ beta, -500, 500)
2440
2498
  mu = np.exp(eta)
2441
- score = X.T @ (y - mu) # gradient of log-likelihood
2442
- hess = X.T @ (mu[:, None] * X) # -Hessian = X'WX, W=diag(mu)
2499
+ if weights is not None:
2500
+ score = X.T @ (weights * (y - mu))
2501
+ hess = X.T @ ((weights * mu)[:, None] * X)
2502
+ else:
2503
+ score = X.T @ (y - mu)
2504
+ hess = X.T @ (mu[:, None] * X)
2443
2505
  try:
2444
2506
  delta = np.linalg.solve(hess + 1e-12 * np.eye(k), score)
2445
2507
  except np.linalg.LinAlgError:
@@ -9,25 +9,30 @@ Data generation functions (generate_*) are defined in prep_dgp.py and
9
9
  re-exported here for backward compatibility.
10
10
  """
11
11
 
12
- from typing import Any, Dict, List, Optional, Union
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
13
 
14
14
  import numpy as np
15
15
  import pandas as pd
16
16
 
17
- from diff_diff.utils import compute_synthetic_weights
18
-
19
17
  # Re-export data generation functions from prep_dgp for backward compatibility
20
- from diff_diff.prep_dgp import (
18
+ from diff_diff.prep_dgp import ( # noqa: F401
21
19
  generate_continuous_did_data,
20
+ generate_ddd_data,
22
21
  generate_did_data,
23
- generate_staggered_data,
22
+ generate_event_study_data,
24
23
  generate_factor_data,
25
- generate_ddd_data,
26
24
  generate_panel_data,
27
- generate_event_study_data,
25
+ generate_staggered_data,
28
26
  generate_staggered_ddd_data,
29
27
  generate_survey_did_data,
30
28
  )
29
+ from diff_diff.survey import (
30
+ ResolvedSurveyDesign,
31
+ SurveyDesign,
32
+ compute_replicate_if_variance,
33
+ compute_survey_if_variance,
34
+ )
35
+ from diff_diff.utils import compute_synthetic_weights
31
36
 
32
37
  # Constants for rank_control_units
33
38
  _SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar"
@@ -1300,3 +1305,434 @@ def trim_weights(
1300
1305
 
1301
1306
  result[weight_col] = w
1302
1307
  return result
1308
+
1309
+
1310
+ # ---------------------------------------------------------------------------
1311
+ # Survey aggregation helpers
1312
+ # ---------------------------------------------------------------------------
1313
+
1314
+
1315
def _srs_fallback_variance(
    w_valid: np.ndarray,
    y_clean: np.ndarray,
    y_bar: float,
    n_valid: int,
) -> float:
    """Simple-random-sampling variance of a weighted mean (fallback path).

    Positive weights are rescaled to mean 1 before use so the result does
    not depend on the overall scale of the weights.

    Parameters
    ----------
    w_valid : np.ndarray
        Cell weights with invalid rows already set to 0.
    y_clean : np.ndarray
        Cell outcomes with invalid rows already set to 0.
    y_bar : float
        Weighted cell mean.
    n_valid : int
        Number of valid observations (must be >= 2).

    Returns
    -------
    float
        ``sum(w * (y - y_bar)^2) / sum(w)^2 * n / (n - 1)`` on the
        mean-1 normalized weights.
    """
    w_norm = w_valid.copy()
    positive = w_norm > 0
    if np.any(positive):
        w_norm[positive] = w_norm[positive] / w_norm[positive].mean()
    sum_wn = float(np.sum(w_norm))
    # Rows with zero weight contribute nothing, whatever y_clean holds there.
    resid_sq = w_norm * (y_clean - y_bar) ** 2
    return float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1))


def _cell_mean_variance(
    y_full: np.ndarray,
    full_resolved: "ResolvedSurveyDesign",
    cell_mask: np.ndarray,
    min_n: int,
) -> Tuple[float, float, int, bool]:
    """Compute design-based mean and variance of the weighted mean for one cell.

    Uses full-design domain estimation: the influence function is zero-padded
    outside the cell, preserving the full strata/PSU structure for variance
    estimation (Lumley 2004, Section 3.4).

    Parameters
    ----------
    y_full : np.ndarray
        Outcome values for the full dataset (may contain NaN).
    full_resolved : ResolvedSurveyDesign
        Full-sample resolved survey design. Only ``weights`` and
        ``uses_replicate_variance`` are read here; the object is otherwise
        passed through to the variance routines.
    cell_mask : np.ndarray
        Boolean mask identifying cell members in the full dataset.
    min_n : int
        Minimum valid observations for design-based variance. Below this
        threshold, SRS fallback is used.

    Returns
    -------
    mean : float
        Design-weighted cell mean (NaN when the cell has no valid rows).
    variance : float
        Variance of the cell mean (>= 0), or NaN when fewer than two valid
        observations exist. Uses SRS fallback when ``n_valid < min_n`` or
        when the design-based estimate comes back NaN.
    n_valid : int
        Number of non-missing, positive-weight observations in the cell.
    used_srs_fallback : bool
        True if SRS variance was used instead of design-based.
    """
    y_cell = y_full[cell_mask]
    w_cell = full_resolved.weights[cell_mask]
    # Valid = non-missing AND positive weight (zero-weight rows are padding)
    valid = ~np.isnan(y_cell) & (w_cell > 0)
    n_valid = int(np.sum(valid))

    if n_valid == 0:
        return np.nan, np.nan, 0, False

    if n_valid < 2:
        # A single observation has a mean but no estimable variance.
        return float(y_cell[valid][0]), np.nan, 1, False

    # Mask with np.where (not multiplication) so a non-finite weight or
    # outcome on an excluded row cannot leak into the sums as NaN.
    w_valid = np.where(valid, w_cell, 0.0).astype(np.float64)
    y_clean = np.where(valid, y_cell, 0.0)
    sum_w = float(np.sum(w_valid))

    if sum_w <= 0:
        # Defensive guard: valid rows have positive weight, so this should
        # not trigger in practice.
        return np.nan, np.nan, n_valid, False

    y_bar = float(np.sum(w_valid * y_clean) / sum_w)

    # Too few observations for design-based variance: SRS fallback.
    if n_valid < min_n:
        variance = _srs_fallback_variance(w_valid, y_clean, y_bar, n_valid)
        return y_bar, max(variance, 0.0), n_valid, True

    # Full-design domain estimation: full-length psi, zero outside the cell.
    psi = np.zeros(len(y_full))
    valid_positions = np.where(cell_mask)[0][valid]
    psi[valid_positions] = w_valid[valid] * (y_clean[valid] - y_bar) / sum_w

    # Route to replicate or Taylor-linearized variance using the full design.
    if full_resolved.uses_replicate_variance:
        variance, _ = compute_replicate_if_variance(psi, full_resolved)
    else:
        variance = compute_survey_if_variance(psi, full_resolved)

    # SRS fallback when the design-based variance is unidentifiable (NaN).
    used_srs = False
    if np.isnan(variance):
        variance = _srs_fallback_variance(w_valid, y_clean, y_bar, n_valid)
        used_srs = True

    return y_bar, max(float(variance), 0.0), n_valid, used_srs
1416
+
1417
+
1418
def aggregate_survey(
    data: pd.DataFrame,
    by: Union[str, List[str]],
    outcomes: Union[str, List[str]],
    survey_design: SurveyDesign,
    covariates: Optional[Union[str, List[str]]] = None,
    min_n: int = 2,
    lonely_psu: Optional[str] = None,
) -> Tuple[pd.DataFrame, SurveyDesign]:
    """Aggregate survey microdata to geographic-period cells with design-based precision.

    Computes design-weighted cell means and their Taylor-linearized (or
    replicate-based) standard errors for each cell defined by the ``by``
    columns. Returns a panel-ready DataFrame with precision weights and a
    pre-configured :class:`SurveyDesign` for second-stage DiD estimation.

    Each cell is treated as a subpopulation/domain of the full survey
    design: influence function values are zero-padded outside the cell,
    preserving full strata/PSU structure for variance estimation per
    Lumley (2004) Section 3.4.

    Parameters
    ----------
    data : pd.DataFrame
        Individual-level microdata.
    by : str or list of str
        Columns defining cells (e.g., ``["state", "year"]``). The first
        element is used as the clustering variable in the returned
        SurveyDesign (geographic unit for second-stage inference).
    outcomes : str or list of str
        Outcome variable(s) to aggregate with full precision tracking.
        Each outcome produces ``{name}_mean``, ``{name}_se``,
        ``{name}_n``, and ``{name}_precision`` columns. When multiple
        outcomes are given, panel filtering (non-estimable cell
        removal, zero-weight PSU pruning) is based on the **first**
        outcome only, consistent with the returned SurveyDesign. For
        independent per-outcome support, call once per outcome.
    survey_design : SurveyDesign
        Survey design specification for the microdata.
    covariates : str or list of str, optional
        Additional variables to aggregate as design-weighted means only
        (no SE/precision columns).
    min_n : int, default 2
        Minimum respondents per cell. Cells below this threshold use
        simple random sampling variance as a fallback.
    lonely_psu : str, optional
        Override the survey design's ``lonely_psu`` setting for within-cell
        computation. One of ``"remove"``, ``"certainty"``, ``"adjust"``.

    Returns
    -------
    panel_df : pd.DataFrame
        Aggregated panel with columns: grouping variables,
        ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``,
        ``{outcome}_precision``, ``{covariate}_mean``, ``cell_n``,
        ``cell_n_eff``, ``srs_fallback``, plus ``{first_outcome}_weight``
        (a fit-ready version of the first outcome's ``_precision`` with
        NaN/Inf mapped to 0.0).
    second_stage_design : SurveyDesign
        Pre-configured for second-stage estimation with
        ``weight_type="aweight"``, precision weights from the first
        outcome, and geographic clustering via ``psu``.

    Raises
    ------
    ValueError
        If ``by``/``outcomes`` are empty, columns are missing or overlap,
        ``min_n < 1``, ``lonely_psu`` is not a recognised option, ``data``
        is empty, grouping columns contain missing values, an
        outcome/covariate column is non-numeric, or no estimable cells
        remain after aggregation.
    TypeError
        If ``survey_design`` is not a :class:`SurveyDesign`.

    Warns
    -----
    UserWarning
        When cells fall back to SRS variance, have zero variance, are
        dropped as non-estimable, or belong to geographic units with zero
        total weight.

    Examples
    --------
    >>> design = SurveyDesign(weights="finalwt", strata="strat", psu="psu")
    >>> panel, stage2 = aggregate_survey(
    ...     microdata, by=["state", "year"],
    ...     outcomes="smoking_rate", survey_design=design,
    ... )
    >>> # Add treatment/time indicators at the panel level, then fit:
    >>> # panel["treated"] = ...  # e.g., from policy adoption data
    >>> # panel["post"] = (panel["year"] >= treatment_year).astype(int)
    >>> # result = DifferenceInDifferences().fit(
    >>> #     panel, outcome="smoking_rate_mean",
    >>> #     treatment="treated", time="post", survey_design=stage2,
    >>> # )
    """
    import warnings
    from dataclasses import replace

    # --- Normalize inputs ---
    by_cols = [by] if isinstance(by, str) else list(by)
    outcome_cols = [outcomes] if isinstance(outcomes, str) else list(outcomes)
    cov_cols = (
        [covariates] if isinstance(covariates, str) else list(covariates) if covariates else []
    )

    # --- Validate ---
    if not by_cols:
        raise ValueError("'by' must specify at least one grouping column")
    if not outcome_cols:
        raise ValueError("'outcomes' must specify at least one outcome variable")

    all_cols = by_cols + outcome_cols + cov_cols
    missing = [c for c in all_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Columns not found in DataFrame: {missing}")

    overlap = set(by_cols) & (set(outcome_cols) | set(cov_cols))
    if overlap:
        raise ValueError(f"Columns appear in both 'by' and outcomes/covariates: {overlap}")

    if not isinstance(survey_design, SurveyDesign):
        raise TypeError(
            f"survey_design must be a SurveyDesign instance, got {type(survey_design).__name__}"
        )

    if min_n < 1:
        raise ValueError(f"min_n must be >= 1, got {min_n}")

    if lonely_psu is not None and lonely_psu not in ("remove", "certainty", "adjust"):
        raise ValueError(
            f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'"
        )

    # --- Empty-input guard ---
    if data.empty:
        raise ValueError("data must be non-empty")

    # --- Validate grouping columns have no missing values ---
    by_missing = data[by_cols].isna().any()
    cols_with_na = list(by_missing[by_missing].index)
    if cols_with_na:
        raise ValueError(
            f"Missing values in grouping column(s): {cols_with_na}. "
            f"Drop or fill NaN values before calling aggregate_survey()."
        )

    # --- Resolve design once on full data ---
    effective_design = (
        replace(survey_design, lonely_psu=lonely_psu) if lonely_psu is not None else survey_design
    )
    full_resolved = effective_design.resolve(data)

    # --- Precompute full-length outcome/covariate arrays ---
    n_total = len(data)
    all_vars = outcome_cols + cov_cols
    non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])]
    if non_numeric:
        raise ValueError(
            f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. "
            f"All outcome and covariate columns must be numeric."
        )
    y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars}

    # --- Per-cell computation ---
    # Use groupby().indices for position-based cell membership (safe with
    # duplicate DataFrame indices, no column injection into user data).
    # A single grouping column is passed as a bare name rather than a
    # length-1 list: some pandas versions key list groupers of length one
    # by 1-tuples, which would store tuples in the output column.
    single_by = len(by_cols) == 1
    grouper = by_cols[0] if single_by else by_cols
    cell_indices = data.groupby(grouper, sort=True).indices  # cell_key -> positions
    rows: List[Dict[str, Any]] = []
    srs_cells: List[str] = []
    zero_var_cells: List[str] = []

    for cell_key, pos_idx in cell_indices.items():
        # Boolean mask for full-design domain estimation
        cell_mask = np.zeros(n_total, dtype=bool)
        cell_mask[pos_idx] = True

        cell_n = int(np.sum(cell_mask))
        cell_key_str = str(cell_key)

        # Cell-level statistics (Kish ESS is a property of the cell)
        cell_w = full_resolved.weights[cell_mask]
        sum_w = float(np.sum(cell_w))
        sum_w2 = float(np.sum(cell_w**2))
        cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0

        # Build row dict with grouping columns
        row: Dict[str, Any] = {}
        if single_by:
            row[by_cols[0]] = cell_key
        else:
            for i, col in enumerate(by_cols):
                row[col] = cell_key[i]

        row["cell_n"] = cell_n
        row["cell_n_eff"] = cell_n_eff

        # Per-cell flags so each cell is reported at most once in the
        # warnings below, however many outcomes trigger the condition.
        cell_srs_fallback = False
        cell_zero_var = False

        # Outcomes: mean + SE + n + precision (full-design domain estimation)
        for var in outcome_cols:
            y_bar, variance, n_valid, used_srs = _cell_mean_variance(
                y_arrays[var],
                full_resolved,
                cell_mask,
                min_n,
            )
            se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan

            if used_srs:
                cell_srs_fallback = True

            # Zero variance -> precision NaN (1/0 is not a usable weight)
            if se == 0.0:
                precision = np.nan
                cell_zero_var = True
            elif np.isnan(se):
                precision = np.nan
            else:
                precision = 1.0 / variance

            row[f"{var}_mean"] = y_bar
            row[f"{var}_se"] = se
            row[f"{var}_n"] = n_valid
            row[f"{var}_precision"] = precision

        # Covariates: design-weighted mean only
        for var in cov_cols:
            y_cell = y_arrays[var][cell_mask]
            valid = ~np.isnan(y_cell)
            w_valid = cell_w * valid.astype(np.float64)
            sw = float(np.sum(w_valid))
            if sw > 0:
                row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw)
            else:
                row[f"{var}_mean"] = np.nan

        row["srs_fallback"] = cell_srs_fallback
        if cell_srs_fallback:
            srs_cells.append(cell_key_str)
        if cell_zero_var:
            zero_var_cells.append(cell_key_str)

        rows.append(row)

    # --- Warnings ---
    if srs_cells:
        warnings.warn(
            f"Design-based variance not estimable for {len(srs_cells)} cell(s); "
            f"using SRS fallback: {srs_cells[:5]}"
            + (f" ... and {len(srs_cells) - 5} more" if len(srs_cells) > 5 else ""),
            UserWarning,
            stacklevel=2,
        )
    if zero_var_cells:
        warnings.warn(
            f"Zero variance in {len(zero_var_cells)} cell(s) (precision set to NaN): "
            f"{zero_var_cells[:5]}"
            + (f" ... and {len(zero_var_cells) - 5} more" if len(zero_var_cells) > 5 else ""),
            UserWarning,
            stacklevel=2,
        )

    # --- Assemble output ---
    panel_df = pd.DataFrame(rows)

    # Sort by grouping columns
    panel_df = panel_df.sort_values(by_cols).reset_index(drop=True)

    # --- Drop non-estimable cells ---
    # Cells with non-finite mean (n_valid==0 or all-missing) cannot contribute
    # to second-stage estimation and would cause fit() to reject NaN outcomes.
    # Dropping them also removes all-zero-weight PSUs from the panel.
    first_outcome = outcome_cols[0]
    mean_col = f"{first_outcome}_mean"
    nonestimable = ~np.isfinite(panel_df[mean_col].values)
    if np.any(nonestimable):
        n_dropped = int(np.sum(nonestimable))
        dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist()
        # Warn about secondary outcomes losing valid data in dropped cells
        secondary_loss = []
        for var in outcome_cols[1:]:
            valid_secondary = np.isfinite(panel_df.loc[nonestimable, f"{var}_mean"].values)
            if np.any(valid_secondary):
                secondary_loss.append(var)
        msg = (
            f"Dropped {n_dropped} non-estimable cell(s) (based on first outcome "
            f"'{first_outcome}'): {dropped_keys[:5]}"
            + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else "")
        )
        if secondary_loss:
            msg += (
                f". Note: {secondary_loss} had valid data in dropped cells. "
                f"For independent per-outcome support, call once per outcome."
            )
        warnings.warn(msg, UserWarning, stacklevel=2)
        panel_df = panel_df[~nonestimable].reset_index(drop=True)

    # --- Construct second-stage SurveyDesign ---
    # Create a fit-ready weight column: NaN/Inf precision -> 0.0 so downstream
    # resolve() doesn't reject missing weights. Diagnostic *_precision is kept.
    weight_col = f"{first_outcome}_weight"
    panel_df[weight_col] = np.where(
        np.isfinite(panel_df[f"{first_outcome}_precision"]),
        panel_df[f"{first_outcome}_precision"],
        0.0,
    )

    # Drop geographic units (PSUs) with zero total weight — they would
    # inflate survey df and distort second-stage variance estimation.
    geo_col = by_cols[0]
    geo_weight = panel_df.groupby(geo_col)[weight_col].sum()
    zero_geos = geo_weight[geo_weight == 0].index
    if len(zero_geos) > 0:
        n_before = len(panel_df)
        panel_df = panel_df[~panel_df[geo_col].isin(zero_geos)].reset_index(drop=True)
        n_after = len(panel_df)
        warnings.warn(
            f"Dropped {n_before - n_after} cell(s) from {len(zero_geos)} "
            f"geographic unit(s) with zero total weight: "
            f"{list(zero_geos[:5])}"
            + (f" ... and {len(zero_geos) - 5} more" if len(zero_geos) > 5 else ""),
            UserWarning,
            stacklevel=2,
        )

    # Guard: all cells dropped
    if panel_df.empty:
        raise ValueError(
            "No estimable cells remain after aggregation. "
            "All cells had missing outcomes or zero effective weight."
        )

    second_stage_design = SurveyDesign(
        weights=weight_col,
        weight_type="aweight",
        psu=geo_col,
    )

    return panel_df, second_stage_design