diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/prep.py ADDED
@@ -0,0 +1,1738 @@
1
+ """
2
+ Data preparation utilities for difference-in-differences analysis.
3
+
4
+ This module provides helper functions to prepare data for DiD estimation,
5
+ including creating treatment indicators, reshaping panel data, and
6
+ generating synthetic datasets for testing.
7
+
8
+ Data generation functions (generate_*) are defined in prep_dgp.py and
9
+ re-exported here for backward compatibility.
10
+ """
11
+
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ # Re-export data generation functions from prep_dgp for backward compatibility
18
+ from diff_diff.prep_dgp import ( # noqa: F401
19
+ generate_continuous_did_data,
20
+ generate_ddd_data,
21
+ generate_did_data,
22
+ generate_event_study_data,
23
+ generate_factor_data,
24
+ generate_panel_data,
25
+ generate_staggered_data,
26
+ generate_staggered_ddd_data,
27
+ generate_survey_did_data,
28
+ )
29
+ from diff_diff.survey import (
30
+ ResolvedSurveyDesign,
31
+ SurveyDesign,
32
+ compute_replicate_if_variance,
33
+ compute_survey_if_variance,
34
+ )
35
+ from diff_diff.utils import compute_synthetic_weights
36
+
37
+ # Constants for rank_control_units
38
+ _SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar"
39
+ _OUTLIER_PENALTY_WEIGHT = 0.3 # Penalty weight for outcome outliers in treatment candidate scoring
40
+
41
+
42
def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated",
) -> pd.DataFrame:
    """
    Create a binary treatment indicator column from various input types.

    Exactly one of ``treated_values`` or ``threshold`` must be supplied.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    column : str
        Name of the column to use for creating the treatment indicator.
    treated_values : Any or list-like, optional
        Value(s) that indicate treatment. Rows whose ``column`` value is in
        this collection get 1, all others get 0. A scalar (including a
        string) is treated as a single value; any list-like container
        (list, tuple, set, NumPy array, pandas Series/Index, ...) is treated
        as a collection of values.
    threshold : float, optional
        Numeric threshold for creating treatment. Used when the treatment
        is based on a continuous variable (e.g., treat firms above median size).
    above_threshold : bool, default=True
        If True, values >= threshold are treated. If False, values <= threshold
        are treated. Only used when threshold is specified.
    new_column : str, default="treated"
        Name of the new treatment indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``treated_values`` and ``threshold`` are given,
        or if ``column`` is not a column of ``data``.

    Examples
    --------
    Create treatment from categorical variable:

    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'group', treated_values='A')
    >>> df['treated'].tolist()
    [1, 1, 0, 0]

    Create treatment from numeric threshold:

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'size', threshold=75)
    >>> df['treated'].tolist()
    [0, 0, 1, 1]

    Treat units below a threshold:

    >>> df = make_treatment_indicator(df, 'size', threshold=75, above_threshold=False)
    >>> df['treated'].tolist()
    [1, 1, 0, 0]
    """
    # Validate before copying so invalid calls do not pay for a full copy.
    if treated_values is not None and threshold is not None:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")

    if treated_values is None and threshold is None:
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")

    if column not in data.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    df = data.copy()

    if treated_values is not None:
        # is_list_like is False for scalars and strings, True for lists,
        # tuples, sets, arrays, Series, Index. A bare isinstance check on
        # (list, tuple, set) would wrap arrays/Series in another list and
        # hand isin() a nested container.
        if not pd.api.types.is_list_like(treated_values):
            treated_values = [treated_values]
        df[new_column] = df[column].isin(treated_values).astype(int)
    elif above_threshold:
        df[new_column] = (df[column] >= threshold).astype(int)
    else:
        df[new_column] = (df[column] <= threshold).astype(int)

    return df
122
+
123
+
124
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post",
) -> pd.DataFrame:
    """
    Create a binary post-treatment indicator column.

    Exactly one of ``post_periods`` or ``treatment_start`` must be supplied.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list-like, optional
        Specific period value(s) that are post-treatment. Periods matching
        these values get post=1, others get post=0. A scalar (including a
        string) is treated as a single period; any list-like container
        (list, tuple, set, NumPy array, pandas Series/Index, ...) is treated
        as a collection of periods.
    treatment_start : Any, optional
        The first post-treatment period. All periods >= this value get post=1.
        Works with numeric periods, strings (sorted alphabetically), or dates.
        When ``time_column`` has a datetime dtype the value is converted with
        ``pd.to_datetime`` before comparison.
    new_column : str, default="post"
        Name of the new post indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``post_periods`` and ``treatment_start`` are
        given, or if ``time_column`` is not a column of ``data``.

    Examples
    --------
    Using specific post periods:

    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021], 'y': [1, 2, 3, 4]})
    >>> df = make_post_indicator(df, 'year', post_periods=[2020, 2021])
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Using treatment start:

    >>> df = make_post_indicator(df, 'year', treatment_start=2020)
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Works with date columns:

    >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
    >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
    """
    # Validate before copying so invalid calls do not pay for a full copy.
    if post_periods is not None and treatment_start is not None:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")

    if post_periods is None and treatment_start is None:
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")

    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    df = data.copy()

    if post_periods is not None:
        # is_list_like is False for scalars and strings, True for lists,
        # tuples, sets, arrays, Series, Index. A bare isinstance check on
        # (list, tuple, set) would wrap arrays/Series in another list and
        # hand isin() a nested container.
        if not pd.api.types.is_list_like(post_periods):
            post_periods = [post_periods]
        df[new_column] = df[time_column].isin(post_periods).astype(int)
    else:
        # Coerce the start value so e.g. an ISO date string can be compared
        # against a datetime64 column.
        if pd.api.types.is_datetime64_any_dtype(df[time_column].dtype):
            treatment_start = pd.to_datetime(treatment_start)
        df[new_column] = (df[time_column] >= treatment_start).astype(int)

    return df
198
+
199
+
200
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None,
) -> pd.DataFrame:
    """
    Reshape a wide panel (one row per unit) into long format (one row per
    unit-period) for DiD analysis.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame with one row per unit.
    value_columns : list of str
        Columns holding the outcome for each period, in chronological order.
    id_column : str
        Column name for the unit identifier.
    time_name : str, default="period"
        Name of the time column created in the output.
    value_name : str, default="value"
        Name of the outcome column created in the output.
    time_values : list, optional
        Period labels, one per entry of ``value_columns``. Defaults to
        0, 1, 2, ... when omitted.

    Returns
    -------
    pd.DataFrame
        Long-format frame whose columns are the id, the time column, the value
        column, then every remaining column of ``data``; rows are sorted by
        id and time with a fresh integer index.

    Raises
    ------
    ValueError
        If ``value_columns`` is empty, if the id column or any value column is
        missing from ``data``, or if ``time_values`` has the wrong length.

    Examples
    --------
    >>> wide_df = pd.DataFrame({
    ...     'firm_id': [1, 2, 3],
    ...     'sales_2019': [100, 150, 200],
    ...     'sales_2020': [110, 160, 210],
    ...     'sales_2021': [120, 170, 220]
    ... })
    >>> long_df = wide_to_long(
    ...     wide_df,
    ...     value_columns=['sales_2019', 'sales_2020', 'sales_2021'],
    ...     id_column='firm_id',
    ...     time_name='year',
    ...     value_name='sales',
    ...     time_values=[2019, 2020, 2021]
    ... )
    >>> len(long_df)
    9
    >>> long_df.columns.tolist()
    ['firm_id', 'year', 'sales']
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")

    if id_column not in data.columns:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")

    # Report the first requested value column that is absent.
    for name in value_columns:
        if name in data.columns:
            continue
        raise ValueError(f"Column '{name}' not found in DataFrame.")

    n_values = len(value_columns)
    if time_values is None:
        time_values = list(range(n_values))

    if len(time_values) != n_values:
        raise ValueError(
            f"time_values length ({len(time_values)}) must match "
            f"value_columns length ({len(value_columns)})."
        )

    # Everything that is neither the id nor a value column rides along as an
    # identifier variable in the melt.
    passthrough = [c for c in data.columns if not (c == id_column or c in value_columns)]

    melted = pd.melt(
        data,
        id_vars=[id_column] + passthrough,
        value_vars=value_columns,
        var_name="_temp_var",
        value_name=value_name,
    )

    # Translate the original column names into their period labels.
    period_lookup = dict(zip(value_columns, time_values))
    melted[time_name] = melted["_temp_var"].map(period_lookup)
    melted = melted.drop("_temp_var", axis=1)

    ordered_columns = [id_column, time_name, value_name] + passthrough
    reshaped = melted[ordered_columns]
    reshaped = reshaped.sort_values([id_column, time_name])
    return reshaped.reset_index(drop=True)
296
+
297
+
298
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None,
) -> pd.DataFrame:
    """
    Balance a panel so that every unit is observed in every time period.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time period.
    method : str, default="inner"
        Balancing strategy:
        - "inner": keep only units observed in all periods (drops units)
        - "outer": build every unit-period combination (missing cells are NaN)
        - "fill": build every combination, then fill the missing cells
    fill_value : float, optional
        Only used when method="fill". When given, missing entries of numeric
        non-key columns are replaced by this value. When None, non-key columns
        are forward-filled and then backward-filled within each unit.

    Returns
    -------
    pd.DataFrame
        Balanced panel DataFrame.

    Raises
    ------
    ValueError
        If either key column is missing from ``data`` or ``method`` is not one
        of the three supported strings.

    Examples
    --------
    Keep only complete units:

    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 3, 3, 3],
    ...     'period': [1, 2, 3, 1, 2, 1, 2, 3],
    ...     'y': [10, 11, 12, 20, 21, 30, 31, 32]
    ... })
    >>> balanced = balance_panel(df, 'unit', 'period', method='inner')
    >>> balanced['unit'].unique().tolist()
    [1, 3]

    Include all combinations:

    >>> balanced = balance_panel(df, 'unit', 'period', method='outer')
    >>> len(balanced)
    9
    """
    for key in (unit_column, time_column):
        if key not in data.columns:
            raise ValueError(f"Column '{key}' not found in DataFrame.")

    if method not in ("inner", "outer", "fill"):
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    units = data[unit_column].unique()
    periods = sorted(data[time_column].unique())

    if method == "inner":
        # A unit is complete when it shows as many distinct periods as the
        # panel has overall.
        periods_seen = data.groupby(unit_column)[time_column].nunique()
        complete = periods_seen.loc[periods_seen == len(periods)].index
        return data[data[unit_column].isin(complete)].copy()

    # "outer" and "fill" both start from the full unit x period grid, onto
    # which the observed rows are left-joined.
    grid_index = pd.MultiIndex.from_product(
        [units, periods], names=[unit_column, time_column]
    )
    grid = pd.DataFrame(index=grid_index).reset_index()
    balanced = grid.merge(data, on=[unit_column, time_column], how="left")

    if method != "fill":
        return balanced

    keys = (unit_column, time_column)
    fillable = [c for c in balanced.columns if c not in keys]

    if fill_value is None:
        # Carry the last observation forward within each unit, then backfill
        # whatever is still missing at the start of a unit's series.
        balanced = balanced.sort_values([unit_column, time_column])
        balanced[fillable] = balanced.groupby(unit_column)[fillable].ffill()
        balanced[fillable] = balanced.groupby(unit_column)[fillable].bfill()
        return balanced

    # A constant fill only touches numeric non-key columns.
    numeric_columns = balanced.select_dtypes(include=[np.number]).columns
    for name in (c for c in numeric_columns if c in fillable):
        balanced[name] = balanced[name].fillna(fill_value)
    return balanced
397
+
398
+
399
def _native_sorted(values: Any) -> List[Any]:
    """Return ``values`` as a sorted list of plain Python objects for messages.

    Going through ``pd.Series.tolist`` converts NumPy scalars to Python
    scalars, so list reprs read ``[0, 1, 2]`` rather than
    ``[np.int64(0), ...]``. If the values are not mutually orderable (mixed
    types), they are ordered by their ``repr`` instead of raising TypeError.
    """
    native = pd.Series(values).tolist()
    try:
        return sorted(native)
    except TypeError:
        return sorted(native, key=repr)


def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True,
) -> Dict[str, Any]:
    """
    Validate that data is properly formatted for DiD analysis.

    Checks for common data issues and provides informative error messages.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/post indicator column.
    unit : str, optional
        Name of unit identifier column (for panel data validation).
    raise_on_error : bool, default=True
        If True, raises ValueError on validation failures.
        If False, returns validation results without raising.

    Returns
    -------
    dict
        Validation results with keys:
        - valid: bool indicating if data passed all checks
        - errors: list of error messages
        - warnings: list of warning messages
        - summary: dict with data summary statistics (empty when an error
          was found before the summary could be computed)

    Raises
    ------
    ValueError
        If ``raise_on_error`` is True and any check fails.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [1, 2, 3, 4],
    ...     'treated': [0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1]
    ... })
    >>> result = validate_did_data(df, 'y', 'treated', 'post', raise_on_error=False)
    >>> result['valid']
    True
    """
    errors: List[str] = []
    warnings: List[str] = []

    # Check columns exist
    required_cols = [outcome, treatment, time]
    if unit is not None:
        required_cols.append(unit)

    errors.extend(
        f"Required column '{col}' not found in DataFrame."
        for col in required_cols
        if col not in data.columns
    )

    # Nothing else can be checked without the columns, so stop here.
    if errors:
        if raise_on_error:
            raise ValueError("\n".join(errors))
        return {"valid": False, "errors": errors, "warnings": warnings, "summary": {}}

    # Check outcome is numeric
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        errors.append(
            f"Outcome column '{outcome}' must be numeric. " f"Got type: {data[outcome].dtype}"
        )

    # Check treatment is binary
    treatment_vals = data[treatment].dropna().unique()
    if not set(treatment_vals).issubset({0, 1}):
        errors.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {_native_sorted(treatment_vals)}"
        )

    # Check time is binary for simple DiD
    time_vals = data[time].dropna().unique()
    if len(time_vals) == 2 and not set(time_vals).issubset({0, 1}):
        warnings.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: "
            f"{_native_sorted(time_vals)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    # Check for missing values
    for col in required_cols:
        n_missing = data[col].isna().sum()
        if n_missing > 0:
            errors.append(
                f"Column '{col}' has {n_missing} missing values. "
                "Please handle missing data before fitting."
            )

    # Summary statistics and cell-coverage checks only make sense on data
    # that passed the structural checks above.
    summary: Dict[str, Any] = {}
    if not errors:
        summary["n_obs"] = len(data)
        summary["n_treated"] = int((data[treatment] == 1).sum())
        summary["n_control"] = int((data[treatment] == 0).sum())
        summary["n_periods"] = len(time_vals)

        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        # Check for sufficient variation
        if summary["n_treated"] == 0:
            errors.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            errors.append("No control observations found (treatment column is all 1).")

        # Check for each treatment-time combination
        if len(time_vals) == 2:
            # For 2-period DiD, check all four cells
            for t_val in [0, 1]:
                for p_val in time_vals:
                    count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
                    if count == 0:
                        errors.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # For multi-period, check that both treatment groups exist in multiple periods
            for t_val in [0, 1]:
                n_periods_with_obs = data[data[treatment] == t_val][time].nunique()
                if n_periods_with_obs < 2:
                    group_name = "Treated" if t_val == 1 else "Control"
                    errors.append(
                        f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # Panel-specific validation
    if unit is not None and not errors:
        # Check treatment is constant within units
        unit_treatment_var = data.groupby(unit)[treatment].nunique()
        units_with_varying_treatment = unit_treatment_var[unit_treatment_var > 1]
        if len(units_with_varying_treatment) > 0:
            warnings.append(
                f"Treatment varies within {len(units_with_varying_treatment)} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        # Check panel balance
        periods_per_unit = data.groupby(unit)[time].nunique()
        if periods_per_unit.min() != periods_per_unit.max():
            warnings.append(
                f"Unbalanced panel detected. Units have between "
                f"{periods_per_unit.min()} and {periods_per_unit.max()} periods. "
                "Consider using balance_panel() to balance the data."
            )

    valid = len(errors) == 0

    if raise_on_error and not valid:
        raise ValueError("Data validation failed:\n" + "\n".join(errors))

    return {"valid": valid, "errors": errors, "warnings": warnings, "summary": summary}
562
+
563
+
564
def summarize_did_data(
    data: pd.DataFrame, outcome: str, treatment: str, time: str, unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Tabulate outcome statistics for every treatment-group / period cell.

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column (1 is labelled "Treated", any
        other value "Control").
    time : str
        Name of time/period column.
    unit : str, optional
        Name of unit identifier column. Accepted but not used by this function.

    Returns
    -------
    pd.DataFrame
        One row per treatment-time cell with columns ``n``, ``mean``, ``std``,
        ``min`` and ``max`` (rounded to 4 decimals). When the time column has
        exactly two distinct values, rows are labelled "<group> - Pre/Post"
        and a final "DiD Estimate" row is appended whose ``mean`` is the
        difference-in-differences of cell means and whose other entries are
        "-". Otherwise rows are labelled "<group> - Period <value>".

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [10, 11, 12, 13, 20, 21, 22, 23],
    ...     'treated': [0, 0, 1, 1, 0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1, 0, 1, 0, 1]
    ... })
    >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
    >>> print(summary)
    """
    cell_stats = data.groupby([treatment, time])[outcome].agg(
        [("n", "count"), ("mean", "mean"), ("std", "std"), ("min", "min"), ("max", "max")]
    )
    cell_stats = cell_stats.round(4)

    periods = sorted(data[time].unique())

    # Anything other than a two-period design just gets per-period labels.
    if len(periods) != 2:
        cell_stats.index = cell_stats.index.map(
            lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}"
        )
        return cell_stats

    # Pre/post are the smaller/larger of the two observed time values, not
    # necessarily literal 0/1.
    pre_val, post_val = periods[0], periods[1]

    def _cell_label(key: tuple) -> str:
        group = "Treated" if key[0] == 1 else "Control"
        phase = "Post" if key[1] == post_val else "Pre"
        return f"{group} - {phase}"

    cell_stats.index = cell_stats.index.map(_cell_label)

    def _cell_mean(group_value: int, period_value: Any) -> float:
        in_cell = (data[treatment] == group_value) & (data[time] == period_value)
        return data.loc[in_cell, outcome].mean()

    treated_change = _cell_mean(1, post_val) - _cell_mean(1, pre_val)
    control_change = _cell_mean(0, post_val) - _cell_mean(0, pre_val)
    did_estimate = treated_change - control_change

    # Append the estimate as an extra row; only its "mean" entry is a number.
    estimate_row = pd.DataFrame(
        {"n": ["-"], "mean": [did_estimate], "std": ["-"], "min": ["-"], "max": ["-"]},
        index=["DiD Estimate"],
    )
    return pd.concat([cell_stats, estimate_row])
642
+
643
+
644
def create_event_time(
    data: pd.DataFrame, time_column: str, treatment_time_column: str, new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Add a column measuring calendar time relative to each row's treatment time.

    Handy for event-study designs in which units are treated at different
    dates.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data. It is not modified; a copy is returned.
    time_column : str
        Name of the calendar time column.
    treatment_time_column : str
        Name of the column holding each unit's treatment time. Rows where it
        is NaN (or, for numeric columns, infinite) count as never-treated.
    new_column : str, default="event_time"
        Name of the new event-time column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the event-time column added. Values are:
        - Negative for pre-treatment periods
        - 0 for the treatment period
        - Positive for post-treatment periods
        - NaN for never-treated units

    Raises
    ------
    ValueError
        If either named column is missing from ``data``.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2],
    ...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
    ...     'treatment_year': [2019, 2019, 2019, 2020, 2020, 2020]
    ... })
    >>> df = create_event_time(df, 'year', 'treatment_year')
    >>> df['event_time'].tolist()
    [-1, 0, 1, -2, -1, 0]
    """
    for required in (time_column, treatment_time_column):
        if required not in data.columns:
            raise ValueError(f"Column '{required}' not found in DataFrame.")

    result = data.copy()

    # Event time = calendar time minus treatment time.
    result[new_column] = result[time_column] - result[treatment_time_column]

    # Blank out never-treated rows. The infinity test is only meaningful (and
    # only attempted) on numeric treatment-time columns.
    treat_time = result[treatment_time_column]
    unassigned = treat_time.isna()
    if pd.api.types.is_numeric_dtype(treat_time):
        unassigned = unassigned | np.isinf(treat_time)
    result.loc[unassigned, new_column] = np.nan

    return result
704
+
705
+
706
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Collapse unit-level rows into one row per treatment-status / period cell.

    Useful for visualization and cohort-level analysis.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Name of unit identifier column.
    time_column : str
        Name of time period column.
    treatment_column : str
        Name of treatment indicator column.
    outcome : str
        Name of outcome variable column.
    covariates : list of str, optional
        Extra columns to average within each cell.

    Returns
    -------
    pd.DataFrame
        One row per (treatment, period) combination with the grouping columns,
        ``mean_<outcome>`` (cell mean of the outcome), ``n_units`` (number of
        distinct units in the cell) and the mean of each covariate.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 2, 2, 3, 3, 4, 4],
    ...     'period': [0, 1, 0, 1, 0, 1, 0, 1],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'y': [10, 15, 12, 17, 8, 10, 9, 11]
    ... })
    >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
    >>> len(cohort_df)
    4
    """
    # Outcome and covariates are averaged; units are counted distinctly.
    aggregations = {outcome: "mean", unit_column: "nunique"}
    aggregations.update((name, "mean") for name in covariates or ())

    by_cell = data.groupby([treatment_column, time_column]).agg(aggregations)

    renamed = {unit_column: "n_units", outcome: f"mean_{outcome}"}
    return by_cell.reset_index().rename(columns=renamed)
763
+
764
+
765
+ def rank_control_units(
766
+ data: pd.DataFrame,
767
+ unit_column: str,
768
+ time_column: str,
769
+ outcome_column: str,
770
+ treatment_column: Optional[str] = None,
771
+ treated_units: Optional[List[Any]] = None,
772
+ pre_periods: Optional[List[Any]] = None,
773
+ covariates: Optional[List[str]] = None,
774
+ outcome_weight: float = 0.7,
775
+ covariate_weight: float = 0.3,
776
+ exclude_units: Optional[List[Any]] = None,
777
+ require_units: Optional[List[Any]] = None,
778
+ n_top: Optional[int] = None,
779
+ suggest_treatment_candidates: bool = False,
780
+ n_treatment_candidates: int = 5,
781
+ lambda_reg: float = 0.0,
782
+ ) -> pd.DataFrame:
783
+ """
784
+ Rank potential control units by their suitability for DiD analysis.
785
+
786
+ Evaluates control units based on pre-treatment outcome trend similarity
787
+ and optional covariate matching to treated units. Returns a ranked list
788
+ with quality scores.
789
+
790
+ Parameters
791
+ ----------
792
+ data : pd.DataFrame
793
+ Panel data in long format.
794
+ unit_column : str
795
+ Column name for unit identifier.
796
+ time_column : str
797
+ Column name for time periods.
798
+ outcome_column : str
799
+ Column name for outcome variable.
800
+ treatment_column : str, optional
801
+ Column with binary treatment indicator (0/1). Used to identify
802
+ treated units from data.
803
+ treated_units : list, optional
804
+ Explicit list of treated unit IDs. Alternative to treatment_column.
805
+ pre_periods : list, optional
806
+ Pre-treatment periods for comparison. If None, uses first half of periods.
807
+ covariates : list of str, optional
808
+ Covariate columns for matching. Similarity is based on pre-treatment means.
809
+ outcome_weight : float, default=0.7
810
+ Weight for pre-treatment outcome trend similarity (0-1).
811
+ covariate_weight : float, default=0.3
812
+ Weight for covariate distance (0-1). Ignored if no covariates.
813
+ exclude_units : list, optional
814
+ Units that cannot be in control group.
815
+ require_units : list, optional
816
+ Units that must be in control group (will always appear in output).
817
+ n_top : int, optional
818
+ Return only top N control units. If None, return all.
819
+ suggest_treatment_candidates : bool, default=False
820
+ If True and no treated units specified, identify potential treatment
821
+ candidates instead of ranking controls.
822
+ n_treatment_candidates : int, default=5
823
+ Number of treatment candidates to suggest.
824
+ lambda_reg : float, default=0.0
825
+ Regularization for synthetic weights. Higher values give more uniform
826
+ weights across controls.
827
+
828
+ Returns
829
+ -------
830
+ pd.DataFrame
831
+ Ranked control units with columns:
832
+ - unit: Unit identifier
833
+ - quality_score: Combined quality score (0-1, higher is better)
834
+ - outcome_trend_score: Pre-treatment outcome trend similarity
835
+ - covariate_score: Covariate match score (NaN if no covariates)
836
+ - synthetic_weight: Weight from synthetic control optimization
837
+ - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
838
+ - is_required: Whether unit was in require_units
839
+
840
+ If suggest_treatment_candidates=True (and no treated units):
841
+ - unit: Unit identifier
842
+ - treatment_candidate_score: Suitability as treatment unit
843
+ - avg_outcome_level: Pre-treatment outcome mean
844
+ - outcome_trend: Pre-treatment trend slope
845
+ - n_similar_controls: Count of similar potential controls
846
+
847
+ Examples
848
+ --------
849
+ Rank controls against treated units:
850
+
851
+ >>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
852
+ >>> ranking = rank_control_units(
853
+ ... data,
854
+ ... unit_column='unit',
855
+ ... time_column='period',
856
+ ... outcome_column='outcome',
857
+ ... treatment_column='treated',
858
+ ... n_top=10
859
+ ... )
860
+ >>> ranking['quality_score'].is_monotonic_decreasing
861
+ True
862
+
863
+ With covariates:
864
+
865
+ >>> data['size'] = np.random.randn(len(data))
866
+ >>> ranking = rank_control_units(
867
+ ... data,
868
+ ... unit_column='unit',
869
+ ... time_column='period',
870
+ ... outcome_column='outcome',
871
+ ... treatment_column='treated',
872
+ ... covariates=['size']
873
+ ... )
874
+
875
+ Filter data for SyntheticDiD:
876
+
877
+ >>> top_controls = ranking['unit'].tolist()
878
+ >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
879
+ """
880
+ # -------------------------------------------------------------------------
881
+ # Input validation
882
+ # -------------------------------------------------------------------------
883
+ for col in [unit_column, time_column, outcome_column]:
884
+ if col not in data.columns:
885
+ raise ValueError(f"Column '{col}' not found in DataFrame.")
886
+
887
+ if treatment_column is not None and treatment_column not in data.columns:
888
+ raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")
889
+
890
+ if covariates:
891
+ for cov in covariates:
892
+ if cov not in data.columns:
893
+ raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")
894
+
895
+ if not 0 <= outcome_weight <= 1:
896
+ raise ValueError("outcome_weight must be between 0 and 1")
897
+ if not 0 <= covariate_weight <= 1:
898
+ raise ValueError("covariate_weight must be between 0 and 1")
899
+
900
+ if treated_units is not None and treatment_column is not None:
901
+ raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")
902
+
903
+ if require_units and exclude_units:
904
+ invalid_required = [u for u in require_units if u in exclude_units]
905
+ if invalid_required:
906
+ raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")
907
+
908
+ # -------------------------------------------------------------------------
909
+ # Determine pre-treatment periods
910
+ # -------------------------------------------------------------------------
911
+ all_periods = sorted(data[time_column].unique())
912
+ if pre_periods is None:
913
+ mid_point = len(all_periods) // 2
914
+ pre_periods = all_periods[:mid_point]
915
+ else:
916
+ pre_periods = list(pre_periods)
917
+
918
+ if len(pre_periods) == 0:
919
+ raise ValueError("No pre-treatment periods specified or inferred.")
920
+
921
+ # -------------------------------------------------------------------------
922
+ # Identify treated and control units
923
+ # -------------------------------------------------------------------------
924
+ all_units = list(data[unit_column].unique())
925
+
926
+ if treated_units is not None:
927
+ treated_set = set(treated_units)
928
+ elif treatment_column is not None:
929
+ unit_treatment = data.groupby(unit_column)[treatment_column].first()
930
+ treated_set = set(unit_treatment[unit_treatment == 1].index)
931
+ elif suggest_treatment_candidates:
932
+ # Treatment candidate discovery mode - no treated units
933
+ treated_set = set()
934
+ else:
935
+ raise ValueError(
936
+ "Must specify treated_units, treatment_column, or set "
937
+ "suggest_treatment_candidates=True"
938
+ )
939
+
940
+ # -------------------------------------------------------------------------
941
+ # Treatment candidate discovery mode
942
+ # -------------------------------------------------------------------------
943
+ if suggest_treatment_candidates and len(treated_set) == 0:
944
+ return _suggest_treatment_candidates(
945
+ data, unit_column, time_column, outcome_column, pre_periods, n_treatment_candidates
946
+ )
947
+
948
+ if len(treated_set) == 0:
949
+ raise ValueError("No treated units found.")
950
+
951
+ # Determine control candidates
952
+ control_candidates = [u for u in all_units if u not in treated_set]
953
+
954
+ if exclude_units:
955
+ control_candidates = [u for u in control_candidates if u not in exclude_units]
956
+
957
+ if len(control_candidates) == 0:
958
+ raise ValueError("No control units available after exclusions.")
959
+
960
+ # -------------------------------------------------------------------------
961
+ # Create outcome matrices (pre-treatment)
962
+ # -------------------------------------------------------------------------
963
+ pre_data = data[data[time_column].isin(pre_periods)]
964
+ pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)
965
+
966
+ # Filter to pre_periods that exist in data
967
+ valid_pre_periods = [p for p in pre_periods if p in pivot.index]
968
+ if len(valid_pre_periods) == 0:
969
+ raise ValueError("No data found for specified pre-treatment periods.")
970
+
971
+ # Filter control_candidates to those present in pivot (handles unbalanced panels)
972
+ control_candidates = [c for c in control_candidates if c in pivot.columns]
973
+ if len(control_candidates) == 0:
974
+ raise ValueError("No control units found in pre-treatment data.")
975
+
976
+ # Control outcomes: shape (n_pre_periods, n_control_candidates)
977
+ Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)
978
+
979
+ # Treated outcomes mean: shape (n_pre_periods,)
980
+ treated_list = [u for u in treated_set if u in pivot.columns]
981
+ if len(treated_list) == 0:
982
+ raise ValueError("Treated units not found in pre-treatment data.")
983
+ Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)
984
+
985
+ # -------------------------------------------------------------------------
986
+ # Compute outcome trend scores
987
+ # -------------------------------------------------------------------------
988
+ # Synthetic weights (higher = better match)
989
+ synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg)
990
+
991
+ # RMSE for each control vs treated mean (use nanmean to handle missing data)
992
+ rmse_scores = []
993
+ for j in range(len(control_candidates)):
994
+ y_c = Y_control[:, j]
995
+ rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
996
+ rmse_scores.append(rmse)
997
+
998
+ # Convert RMSE to similarity score (lower RMSE = higher score)
999
+ max_rmse = max(rmse_scores) if rmse_scores else 1.0
1000
+ min_rmse = min(rmse_scores) if rmse_scores else 0.0
1001
+ rmse_range = max_rmse - min_rmse
1002
+
1003
+ if rmse_range < 1e-10:
1004
+ # All controls have identical/similar pre-trends (includes single control case)
1005
+ outcome_trend_scores = [1.0] * len(rmse_scores)
1006
+ else:
1007
+ # Normalize so best control gets 1.0, worst gets 0.0
1008
+ outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]
1009
+
1010
+ # -------------------------------------------------------------------------
1011
+ # Compute covariate scores (if covariates provided)
1012
+ # -------------------------------------------------------------------------
1013
+ if covariates and len(covariates) > 0:
1014
+ # Get unit-level covariate values (pre-treatment mean)
1015
+ cov_data = pre_data.groupby(unit_column)[covariates].mean()
1016
+
1017
+ # Treated covariate profile (mean across treated units)
1018
+ treated_cov = cov_data.loc[list(treated_set)].mean()
1019
+
1020
+ # Standardize covariates
1021
+ cov_mean = cov_data.mean()
1022
+ cov_std = cov_data.std().replace(0, 1) # Avoid division by zero
1023
+ cov_standardized = (cov_data - cov_mean) / cov_std
1024
+ treated_cov_std = (treated_cov - cov_mean) / cov_std
1025
+
1026
+ # Euclidean distance in standardized space (vectorized)
1027
+ control_cov_matrix = cov_standardized.loc[control_candidates].values
1028
+ treated_cov_vector = treated_cov_std.values
1029
+ covariate_distances = np.sqrt(
1030
+ np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
1031
+ )
1032
+
1033
+ # Convert distance to similarity score (min-max normalization)
1034
+ max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
1035
+ min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
1036
+ dist_range = max_dist - min_dist
1037
+
1038
+ if dist_range < 1e-10:
1039
+ # All controls have identical/similar covariate profiles
1040
+ covariate_scores = [1.0] * len(covariate_distances)
1041
+ else:
1042
+ # Normalize so best control (closest) gets 1.0, worst gets 0.0
1043
+ covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
1044
+ else:
1045
+ covariate_scores = [np.nan] * len(control_candidates)
1046
+
1047
+ # -------------------------------------------------------------------------
1048
+ # Compute combined quality score
1049
+ # -------------------------------------------------------------------------
1050
+ # Normalize weights
1051
+ total_weight = outcome_weight + covariate_weight
1052
+ if total_weight > 0:
1053
+ norm_outcome_weight = outcome_weight / total_weight
1054
+ norm_covariate_weight = covariate_weight / total_weight
1055
+ else:
1056
+ norm_outcome_weight = 1.0
1057
+ norm_covariate_weight = 0.0
1058
+
1059
+ quality_scores = []
1060
+ for i in range(len(control_candidates)):
1061
+ outcome_score = outcome_trend_scores[i]
1062
+ cov_score = covariate_scores[i]
1063
+
1064
+ if np.isnan(cov_score):
1065
+ # No covariates - use only outcome score
1066
+ combined = outcome_score
1067
+ else:
1068
+ combined = norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score
1069
+
1070
+ quality_scores.append(combined)
1071
+
1072
+ # -------------------------------------------------------------------------
1073
+ # Build result DataFrame
1074
+ # -------------------------------------------------------------------------
1075
+ require_set = set(require_units) if require_units else set()
1076
+
1077
+ result = pd.DataFrame(
1078
+ {
1079
+ "unit": control_candidates,
1080
+ "quality_score": quality_scores,
1081
+ "outcome_trend_score": outcome_trend_scores,
1082
+ "covariate_score": covariate_scores,
1083
+ "synthetic_weight": synthetic_weights,
1084
+ "pre_trend_rmse": rmse_scores,
1085
+ "is_required": [u in require_set for u in control_candidates],
1086
+ }
1087
+ )
1088
+
1089
+ # Sort by quality score (descending)
1090
+ result = result.sort_values("quality_score", ascending=False)
1091
+
1092
+ # Apply n_top limit if specified
1093
+ if n_top is not None and n_top < len(result):
1094
+ # Always include required units
1095
+ required_df = result[result["is_required"]]
1096
+ non_required_df = result[~result["is_required"]]
1097
+
1098
+ # Take top from non-required to fill remaining slots
1099
+ remaining_slots = max(0, n_top - len(required_df))
1100
+ top_non_required = non_required_df.head(remaining_slots)
1101
+
1102
+ result = pd.concat([required_df, top_non_required])
1103
+ result = result.sort_values("quality_score", ascending=False)
1104
+
1105
+ return result.reset_index(drop=True)
1106
+
1107
+
1108
+ def _suggest_treatment_candidates(
1109
+ data: pd.DataFrame,
1110
+ unit_column: str,
1111
+ time_column: str,
1112
+ outcome_column: str,
1113
+ pre_periods: List[Any],
1114
+ n_candidates: int,
1115
+ ) -> pd.DataFrame:
1116
+ """
1117
+ Identify units that would make good treatment candidates.
1118
+
1119
+ A good treatment candidate:
1120
+ 1. Has many similar control units available (for matching)
1121
+ 2. Has stable pre-treatment trends (predictable counterfactual)
1122
+ 3. Is not an extreme outlier
1123
+
1124
+ Parameters
1125
+ ----------
1126
+ data : pd.DataFrame
1127
+ Panel data.
1128
+ unit_column : str
1129
+ Unit identifier column.
1130
+ time_column : str
1131
+ Time period column.
1132
+ outcome_column : str
1133
+ Outcome variable column.
1134
+ pre_periods : list
1135
+ Pre-treatment periods.
1136
+ n_candidates : int
1137
+ Number of candidates to return.
1138
+
1139
+ Returns
1140
+ -------
1141
+ pd.DataFrame
1142
+ Treatment candidates with scores.
1143
+ """
1144
+ all_units = list(data[unit_column].unique())
1145
+ pre_data = data[data[time_column].isin(pre_periods)]
1146
+
1147
+ candidate_info = []
1148
+
1149
+ for unit in all_units:
1150
+ unit_data = pre_data[pre_data[unit_column] == unit]
1151
+
1152
+ if len(unit_data) == 0:
1153
+ continue
1154
+
1155
+ # Average outcome level
1156
+ avg_outcome = unit_data[outcome_column].mean()
1157
+
1158
+ # Trend (simple linear regression slope)
1159
+ times = unit_data[time_column].values
1160
+ outcomes = unit_data[outcome_column].values
1161
+ if len(times) > 1:
1162
+ times_norm = np.arange(len(times))
1163
+ try:
1164
+ slope = np.polyfit(times_norm, outcomes, 1)[0]
1165
+ except (np.linalg.LinAlgError, ValueError):
1166
+ slope = 0.0
1167
+ else:
1168
+ slope = 0.0
1169
+
1170
+ # Count similar potential controls
1171
+ other_units = [u for u in all_units if u != unit]
1172
+ other_means = (
1173
+ pre_data[pre_data[unit_column].isin(other_units)]
1174
+ .groupby(unit_column)[outcome_column]
1175
+ .mean()
1176
+ )
1177
+
1178
+ if len(other_means) > 0:
1179
+ sd = other_means.std()
1180
+ if sd > 0:
1181
+ n_similar = int(
1182
+ np.sum(np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd)
1183
+ )
1184
+ else:
1185
+ n_similar = len(other_means)
1186
+ else:
1187
+ n_similar = 0
1188
+
1189
+ candidate_info.append(
1190
+ {
1191
+ "unit": unit,
1192
+ "avg_outcome_level": avg_outcome,
1193
+ "outcome_trend": slope,
1194
+ "n_similar_controls": n_similar,
1195
+ }
1196
+ )
1197
+
1198
+ if len(candidate_info) == 0:
1199
+ return pd.DataFrame(
1200
+ columns=[
1201
+ "unit",
1202
+ "treatment_candidate_score",
1203
+ "avg_outcome_level",
1204
+ "outcome_trend",
1205
+ "n_similar_controls",
1206
+ ]
1207
+ )
1208
+
1209
+ result = pd.DataFrame(candidate_info)
1210
+
1211
+ # Score: prefer units with many similar controls and moderate outcome levels
1212
+ max_similar = result["n_similar_controls"].max()
1213
+ if max_similar > 0:
1214
+ similarity_score = result["n_similar_controls"] / max_similar
1215
+ else:
1216
+ similarity_score = pd.Series([0.0] * len(result))
1217
+
1218
+ # Penalty for outliers in outcome level
1219
+ outcome_mean = result["avg_outcome_level"].mean()
1220
+ outcome_std = result["avg_outcome_level"].std()
1221
+ if outcome_std > 0:
1222
+ outcome_z = np.abs((result["avg_outcome_level"] - outcome_mean) / outcome_std)
1223
+ else:
1224
+ outcome_z = pd.Series([0.0] * len(result))
1225
+
1226
+ result["treatment_candidate_score"] = (
1227
+ similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
1228
+ ).clip(0, 1)
1229
+
1230
+ # Return top candidates
1231
+ result = result.nlargest(n_candidates, "treatment_candidate_score")
1232
+ return result.reset_index(drop=True)
1233
+
1234
+
1235
+ def trim_weights(
1236
+ data: pd.DataFrame,
1237
+ weight_col: str,
1238
+ upper: Optional[float] = None,
1239
+ quantile: Optional[float] = None,
1240
+ lower: Optional[float] = None,
1241
+ ) -> pd.DataFrame:
1242
+ """Trim (winsorize) survey weights to reduce influence of extreme values.
1243
+
1244
+ Caps weights at specified thresholds. Useful for reducing variance from
1245
+ extreme survey weights before DiD estimation. Federal agencies (e.g., NCHS)
1246
+ recommend reviewing weights with CV > 30%.
1247
+
1248
+ Parameters
1249
+ ----------
1250
+ data : pd.DataFrame
1251
+ Input DataFrame.
1252
+ weight_col : str
1253
+ Name of the weight column.
1254
+ upper : float, optional
1255
+ Absolute upper cap. Weights above this value are set to it.
1256
+ Mutually exclusive with ``quantile``.
1257
+ quantile : float, optional
1258
+ Quantile-based upper cap (e.g., 0.99). Weights above the quantile
1259
+ value are capped at it. Mutually exclusive with ``upper``.
1260
+ lower : float, optional
1261
+ Absolute lower floor. Weights below this value are set to it.
1262
+ Can be combined with either ``upper`` or ``quantile``.
1263
+
1264
+ Returns
1265
+ -------
1266
+ pd.DataFrame
1267
+ Copy of data with trimmed weights.
1268
+
1269
+ Raises
1270
+ ------
1271
+ ValueError
1272
+ If both ``upper`` and ``quantile`` are provided, or if ``weight_col``
1273
+ is not in the DataFrame.
1274
+ """
1275
+ if upper is not None and quantile is not None:
1276
+ raise ValueError("Specify either 'upper' or 'quantile', not both.")
1277
+ if weight_col not in data.columns:
1278
+ raise ValueError(f"Column '{weight_col}' not found in DataFrame.")
1279
+
1280
+ result = data.copy()
1281
+ w = result[weight_col].values.copy()
1282
+
1283
+ if quantile is not None:
1284
+ if not (0 < quantile < 1):
1285
+ raise ValueError(f"quantile must be in (0, 1), got {quantile}")
1286
+ upper = float(np.nanquantile(w, quantile))
1287
+
1288
+ # Validate cap values are finite and non-negative
1289
+ if upper is not None:
1290
+ if not np.isfinite(upper) or upper < 0:
1291
+ raise ValueError(f"upper must be finite and >= 0, got {upper}")
1292
+ if lower is not None:
1293
+ if not np.isfinite(lower) or lower < 0:
1294
+ raise ValueError(f"lower must be finite and >= 0, got {lower}")
1295
+ if upper is not None and lower is not None and lower > upper:
1296
+ raise ValueError(
1297
+ f"lower ({lower}) must be <= upper ({upper}). "
1298
+ f"When using quantile, the resolved upper cap may be below lower."
1299
+ )
1300
+
1301
+ if upper is not None:
1302
+ w = np.minimum(w, upper)
1303
+ if lower is not None:
1304
+ w = np.maximum(w, lower)
1305
+
1306
+ result[weight_col] = w
1307
+ return result
1308
+
1309
+
1310
+ # ---------------------------------------------------------------------------
1311
+ # Survey aggregation helpers
1312
+ # ---------------------------------------------------------------------------
1313
+
1314
+
1315
+ def _cell_mean_variance(
1316
+ y_full: np.ndarray,
1317
+ full_resolved: ResolvedSurveyDesign,
1318
+ cell_mask: np.ndarray,
1319
+ min_n: int,
1320
+ ) -> Tuple[float, float, int, bool]:
1321
+ """Compute design-based mean and variance of the weighted mean for one cell.
1322
+
1323
+ Uses full-design domain estimation: the influence function is zero-padded
1324
+ outside the cell, preserving the full strata/PSU structure for variance
1325
+ estimation. This is the methodologically correct approach for domain
1326
+ estimation under complex survey designs (Lumley 2004, Section 3.4).
1327
+
1328
+ Parameters
1329
+ ----------
1330
+ y_full : np.ndarray
1331
+ Outcome values for the full dataset (may contain NaN).
1332
+ full_resolved : ResolvedSurveyDesign
1333
+ Full-sample resolved survey design.
1334
+ cell_mask : np.ndarray
1335
+ Boolean mask identifying cell members in the full dataset.
1336
+ min_n : int
1337
+ Minimum valid observations for design-based variance. Below this
1338
+ threshold, SRS fallback is used.
1339
+
1340
+ Returns
1341
+ -------
1342
+ mean : float
1343
+ Design-weighted cell mean.
1344
+ variance : float
1345
+ Design-based variance of the cell mean (>= 0). Uses SRS fallback
1346
+ when the design-based estimate is unidentifiable or n_valid < min_n.
1347
+ n_valid : int
1348
+ Number of non-missing observations in the cell.
1349
+ used_srs_fallback : bool
1350
+ True if SRS variance was used instead of design-based.
1351
+ """
1352
+ y_cell = y_full[cell_mask]
1353
+ w_cell = full_resolved.weights[cell_mask]
1354
+ # Valid = non-missing AND positive weight (zero-weight rows are padding)
1355
+ valid = ~np.isnan(y_cell) & (w_cell > 0)
1356
+ n_valid = int(np.sum(valid))
1357
+
1358
+ if n_valid == 0:
1359
+ return np.nan, np.nan, 0, False
1360
+
1361
+ if n_valid < 2:
1362
+ y_bar = float(y_cell[valid][0])
1363
+ return y_bar, np.nan, 1, False
1364
+
1365
+ # Weighted mean from cell members (NaN-safe)
1366
+ w_valid = w_cell * valid.astype(np.float64)
1367
+ y_clean = np.where(valid, y_cell, 0.0)
1368
+ sum_w = float(np.sum(w_valid))
1369
+
1370
+ if sum_w <= 0:
1371
+ return np.nan, np.nan, n_valid, False
1372
+
1373
+ y_bar = float(np.sum(w_valid * y_clean) / sum_w)
1374
+
1375
+ # SRS fallback if below min_n threshold
1376
+ # Normalize positive weights to mean=1 so fallback is scale-invariant
1377
+ # (replicate designs preserve raw weight scale per survey.py:L189-240)
1378
+ used_srs = False
1379
+ if n_valid < min_n:
1380
+ w_norm = w_valid.copy()
1381
+ w_pos = w_norm[w_norm > 0]
1382
+ if len(w_pos) > 0:
1383
+ w_norm[w_norm > 0] = w_pos / w_pos.mean()
1384
+ sum_wn = float(np.sum(w_norm))
1385
+ resid_sq = w_norm * (y_clean - y_bar) ** 2
1386
+ variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1))
1387
+ return y_bar, max(variance, 0.0), n_valid, True
1388
+
1389
+ # Full-design domain estimation: construct full-length psi with zeros
1390
+ # outside the cell, preserving full strata/PSU structure for variance
1391
+ n_total = len(y_full)
1392
+ psi = np.zeros(n_total)
1393
+ # Positions in full array where cell member has valid data
1394
+ cell_indices = np.where(cell_mask)[0]
1395
+ valid_positions = cell_indices[valid]
1396
+ psi[valid_positions] = w_valid[valid] * (y_clean[valid] - y_bar) / sum_w
1397
+
1398
+ # Route to TSL or replicate variance using the full design
1399
+ if full_resolved.uses_replicate_variance:
1400
+ variance, _ = compute_replicate_if_variance(psi, full_resolved)
1401
+ else:
1402
+ variance = compute_survey_if_variance(psi, full_resolved)
1403
+
1404
+ # SRS fallback when design-based variance is unidentifiable
1405
+ if np.isnan(variance):
1406
+ w_norm = w_valid.copy()
1407
+ w_pos = w_norm[w_norm > 0]
1408
+ if len(w_pos) > 0:
1409
+ w_norm[w_norm > 0] = w_pos / w_pos.mean()
1410
+ sum_wn = float(np.sum(w_norm))
1411
+ resid_sq = w_norm * (y_clean - y_bar) ** 2
1412
+ variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1))
1413
+ used_srs = True
1414
+
1415
+ return y_bar, max(float(variance), 0.0), n_valid, used_srs
1416
+
1417
+
1418
+ def aggregate_survey(
1419
+ data: pd.DataFrame,
1420
+ by: Union[str, List[str]],
1421
+ outcomes: Union[str, List[str]],
1422
+ survey_design: SurveyDesign,
1423
+ covariates: Optional[Union[str, List[str]]] = None,
1424
+ min_n: int = 2,
1425
+ lonely_psu: Optional[str] = None,
1426
+ ) -> Tuple[pd.DataFrame, SurveyDesign]:
1427
+ """Aggregate survey microdata to geographic-period cells with design-based precision.
1428
+
1429
+ Computes design-weighted cell means and their Taylor-linearized (or
1430
+ replicate-based) standard errors for each cell defined by the ``by``
1431
+ columns. Returns a panel-ready DataFrame with precision weights and a
1432
+ pre-configured :class:`SurveyDesign` for second-stage DiD estimation.
1433
+
1434
+ Each cell is treated as a subpopulation/domain of the full survey
1435
+ design: influence function values are zero-padded outside the cell,
1436
+ preserving full strata/PSU structure for variance estimation per
1437
+ Lumley (2004) Section 3.4.
1438
+
1439
+ Parameters
1440
+ ----------
1441
+ data : pd.DataFrame
1442
+ Individual-level microdata.
1443
+ by : str or list of str
1444
+ Columns defining cells (e.g., ``["state", "year"]``). The first
1445
+ element is used as the clustering variable in the returned
1446
+ SurveyDesign (geographic unit for second-stage inference).
1447
+ outcomes : str or list of str
1448
+ Outcome variable(s) to aggregate with full precision tracking.
1449
+ Each outcome produces ``{name}_mean``, ``{name}_se``,
1450
+ ``{name}_n``, and ``{name}_precision`` columns. When multiple
1451
+ outcomes are given, panel filtering (non-estimable cell
1452
+ removal, zero-weight PSU pruning) is based on the **first**
1453
+ outcome only, consistent with the returned SurveyDesign. For
1454
+ independent per-outcome support, call once per outcome.
1455
+ survey_design : SurveyDesign
1456
+ Survey design specification for the microdata.
1457
+ covariates : str or list of str, optional
1458
+ Additional variables to aggregate as design-weighted means only
1459
+ (no SE/precision columns).
1460
+ min_n : int, default 2
1461
+ Minimum respondents per cell. Cells below this threshold use
1462
+ simple random sampling variance as a fallback.
1463
+ lonely_psu : str, optional
1464
+ Override the survey design's ``lonely_psu`` setting for within-cell
1465
+ computation. One of ``"remove"``, ``"certainty"``, ``"adjust"``.
1466
+
1467
+ Returns
1468
+ -------
1469
+ panel_df : pd.DataFrame
1470
+ Aggregated panel with columns: grouping variables,
1471
+ ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``,
1472
+ ``{outcome}_precision``, ``{outcome}_weight``,
1473
+ ``{covariate}_mean``, ``cell_n``, ``cell_n_eff``,
1474
+ ``srs_fallback``. The ``_weight`` column is a fit-ready
1475
+ version of ``_precision`` with NaN/Inf mapped to 0.0.
1476
+ second_stage_design : SurveyDesign
1477
+ Pre-configured for second-stage estimation with
1478
+ ``weight_type="aweight"``, precision weights from the first
1479
+ outcome, and geographic clustering via ``psu``.
1480
+
1481
+ Examples
1482
+ --------
1483
+ >>> design = SurveyDesign(weights="finalwt", strata="strat", psu="psu")
1484
+ >>> panel, stage2 = aggregate_survey(
1485
+ ... microdata, by=["state", "year"],
1486
+ ... outcomes="smoking_rate", survey_design=design,
1487
+ ... )
1488
+ >>> # Add treatment/time indicators at the panel level, then fit:
1489
+ >>> # panel["treated"] = ... # e.g., from policy adoption data
1490
+ >>> # panel["post"] = (panel["year"] >= treatment_year).astype(int)
1491
+ >>> # result = DifferenceInDifferences().fit(
1492
+ >>> # panel, outcome="smoking_rate_mean",
1493
+ >>> # treatment="treated", time="post", survey_design=stage2,
1494
+ >>> # )
1495
+ """
1496
+ import warnings
1497
+ from dataclasses import replace
1498
+
1499
+ # --- Normalize inputs ---
1500
+ by_cols = [by] if isinstance(by, str) else list(by)
1501
+ outcome_cols = [outcomes] if isinstance(outcomes, str) else list(outcomes)
1502
+ cov_cols = (
1503
+ [covariates] if isinstance(covariates, str) else list(covariates) if covariates else []
1504
+ )
1505
+
1506
+ # --- Validate ---
1507
+ if not by_cols:
1508
+ raise ValueError("'by' must specify at least one grouping column")
1509
+ if not outcome_cols:
1510
+ raise ValueError("'outcomes' must specify at least one outcome variable")
1511
+
1512
+ all_cols = by_cols + outcome_cols + cov_cols
1513
+ missing = [c for c in all_cols if c not in data.columns]
1514
+ if missing:
1515
+ raise ValueError(f"Columns not found in DataFrame: {missing}")
1516
+
1517
+ overlap = set(by_cols) & (set(outcome_cols) | set(cov_cols))
1518
+ if overlap:
1519
+ raise ValueError(f"Columns appear in both 'by' and outcomes/covariates: {overlap}")
1520
+
1521
+ if not isinstance(survey_design, SurveyDesign):
1522
+ raise TypeError(
1523
+ f"survey_design must be a SurveyDesign instance, got {type(survey_design).__name__}"
1524
+ )
1525
+
1526
+ if min_n < 1:
1527
+ raise ValueError(f"min_n must be >= 1, got {min_n}")
1528
+
1529
+ if lonely_psu is not None and lonely_psu not in ("remove", "certainty", "adjust"):
1530
+ raise ValueError(
1531
+ f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'"
1532
+ )
1533
+
1534
+ # --- Empty-input guard ---
1535
+ if data.empty:
1536
+ raise ValueError("data must be non-empty")
1537
+
1538
+ # --- Validate grouping columns have no missing values ---
1539
+ by_missing = data[by_cols].isna().any()
1540
+ cols_with_na = list(by_missing[by_missing].index)
1541
+ if cols_with_na:
1542
+ raise ValueError(
1543
+ f"Missing values in grouping column(s): {cols_with_na}. "
1544
+ f"Drop or fill NaN values before calling aggregate_survey()."
1545
+ )
1546
+
1547
+ # --- Resolve design once on full data ---
1548
+ effective_design = (
1549
+ replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design
1550
+ )
1551
+ full_resolved = effective_design.resolve(data)
1552
+
1553
+ # --- Precompute full-length outcome/covariate arrays ---
1554
+ n_total = len(data)
1555
+ all_vars = outcome_cols + cov_cols
1556
+ non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])]
1557
+ if non_numeric:
1558
+ raise ValueError(
1559
+ f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. "
1560
+ f"All outcome and covariate columns must be numeric."
1561
+ )
1562
+ y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars}
1563
+
1564
+ # --- Per-cell computation ---
1565
+ # Use groupby().indices for position-based cell membership (safe with
1566
+ # duplicate DataFrame indices, no column injection into user data)
1567
+ grouped = data.groupby(by_cols, sort=True)
1568
+ cell_indices = grouped.indices # dict of cell_key → positional indices
1569
+ rows: List[Dict[str, Any]] = []
1570
+ srs_cells: List[str] = []
1571
+ zero_var_cells: List[str] = []
1572
+
1573
+ for cell_key, pos_idx in cell_indices.items():
1574
+ # Boolean mask for full-design domain estimation
1575
+ cell_mask = np.zeros(n_total, dtype=bool)
1576
+ cell_mask[pos_idx] = True
1577
+
1578
+ cell_n = int(np.sum(cell_mask))
1579
+ cell_key_str = str(cell_key)
1580
+
1581
+ # Cell-level statistics (Kish ESS is a property of the cell)
1582
+ cell_w = full_resolved.weights[cell_mask]
1583
+ sum_w = float(np.sum(cell_w))
1584
+ sum_w2 = float(np.sum(cell_w**2))
1585
+ cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0
1586
+
1587
+ # Build row dict with grouping columns
1588
+ row: Dict[str, Any] = {}
1589
+ if len(by_cols) == 1:
1590
+ row[by_cols[0]] = cell_key
1591
+ else:
1592
+ for i, col in enumerate(by_cols):
1593
+ row[col] = cell_key[i]
1594
+
1595
+ row["cell_n"] = cell_n
1596
+ row["cell_n_eff"] = cell_n_eff
1597
+
1598
+ cell_srs_fallback = False
1599
+
1600
+ # Outcomes: mean + SE + n + precision (full-design domain estimation)
1601
+ for var in outcome_cols:
1602
+ y_bar, variance, n_valid, used_srs = _cell_mean_variance(
1603
+ y_arrays[var],
1604
+ full_resolved,
1605
+ cell_mask,
1606
+ min_n,
1607
+ )
1608
+ se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan
1609
+
1610
+ if used_srs:
1611
+ cell_srs_fallback = True
1612
+
1613
+ # Zero variance → precision NaN
1614
+ if se == 0.0:
1615
+ precision = np.nan
1616
+ zero_var_cells.append(cell_key_str)
1617
+ elif np.isnan(se):
1618
+ precision = np.nan
1619
+ else:
1620
+ precision = 1.0 / variance
1621
+
1622
+ row[f"{var}_mean"] = y_bar
1623
+ row[f"{var}_se"] = se
1624
+ row[f"{var}_n"] = n_valid
1625
+ row[f"{var}_precision"] = precision
1626
+
1627
+ # Covariates: design-weighted mean only
1628
+ for var in cov_cols:
1629
+ y_cell = y_arrays[var][cell_mask]
1630
+ valid = ~np.isnan(y_cell)
1631
+ w_valid = cell_w * valid.astype(np.float64)
1632
+ sw = float(np.sum(w_valid))
1633
+ if sw > 0:
1634
+ row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw)
1635
+ else:
1636
+ row[f"{var}_mean"] = np.nan
1637
+
1638
+ row["srs_fallback"] = cell_srs_fallback
1639
+ if cell_srs_fallback:
1640
+ srs_cells.append(cell_key_str)
1641
+
1642
+ rows.append(row)
1643
+
1644
+ # --- Warnings ---
1645
+ if srs_cells:
1646
+ warnings.warn(
1647
+ f"Design-based variance not estimable for {len(srs_cells)} cell(s); "
1648
+ f"using SRS fallback: {srs_cells[:5]}"
1649
+ + (f" ... and {len(srs_cells) - 5} more" if len(srs_cells) > 5 else ""),
1650
+ UserWarning,
1651
+ stacklevel=2,
1652
+ )
1653
+ if zero_var_cells:
1654
+ warnings.warn(
1655
+ f"Zero variance in {len(zero_var_cells)} cell(s) (precision set to NaN): "
1656
+ f"{zero_var_cells[:5]}"
1657
+ + (f" ... and {len(zero_var_cells) - 5} more" if len(zero_var_cells) > 5 else ""),
1658
+ UserWarning,
1659
+ stacklevel=2,
1660
+ )
1661
+
1662
+ # --- Assemble output ---
1663
+ panel_df = pd.DataFrame(rows)
1664
+
1665
+ # Sort by grouping columns
1666
+ panel_df = panel_df.sort_values(by_cols).reset_index(drop=True)
1667
+
1668
+ # --- Drop non-estimable cells ---
1669
+ # Cells with non-finite mean (n_valid==0 or all-missing) cannot contribute
1670
+ # to second-stage estimation and would cause fit() to reject NaN outcomes.
1671
+ # Dropping them also removes all-zero-weight PSUs from the panel.
1672
+ first_outcome = outcome_cols[0]
1673
+ mean_col = f"{first_outcome}_mean"
1674
+ nonestimable = ~np.isfinite(panel_df[mean_col].values)
1675
+ if np.any(nonestimable):
1676
+ n_dropped = int(np.sum(nonestimable))
1677
+ dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist()
1678
+ # Warn about secondary outcomes losing valid data in dropped cells
1679
+ secondary_loss = []
1680
+ for var in outcome_cols[1:]:
1681
+ valid_secondary = np.isfinite(panel_df.loc[nonestimable, f"{var}_mean"].values)
1682
+ if np.any(valid_secondary):
1683
+ secondary_loss.append(var)
1684
+ msg = (
1685
+ f"Dropped {n_dropped} non-estimable cell(s) (based on first outcome "
1686
+ f"'{first_outcome}'): {dropped_keys[:5]}"
1687
+ + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else "")
1688
+ )
1689
+ if secondary_loss:
1690
+ msg += (
1691
+ f". Note: {secondary_loss} had valid data in dropped cells. "
1692
+ f"For independent per-outcome support, call once per outcome."
1693
+ )
1694
+ warnings.warn(msg, UserWarning, stacklevel=2)
1695
+ panel_df = panel_df[~nonestimable].reset_index(drop=True)
1696
+
1697
+ # --- Construct second-stage SurveyDesign ---
1698
+ # Create a fit-ready weight column: NaN/Inf precision → 0.0 so downstream
1699
+ # resolve() doesn't reject missing weights. Diagnostic *_precision is kept.
1700
+ weight_col = f"{first_outcome}_weight"
1701
+ panel_df[weight_col] = np.where(
1702
+ np.isfinite(panel_df[f"{first_outcome}_precision"]),
1703
+ panel_df[f"{first_outcome}_precision"],
1704
+ 0.0,
1705
+ )
1706
+
1707
+ # Drop geographic units (PSUs) with zero total weight — they would
1708
+ # inflate survey df and distort second-stage variance estimation.
1709
+ geo_col = by_cols[0]
1710
+ geo_weight = panel_df.groupby(geo_col)[weight_col].sum()
1711
+ zero_geos = geo_weight[geo_weight == 0].index
1712
+ if len(zero_geos) > 0:
1713
+ n_before = len(panel_df)
1714
+ panel_df = panel_df[~panel_df[geo_col].isin(zero_geos)].reset_index(drop=True)
1715
+ n_after = len(panel_df)
1716
+ warnings.warn(
1717
+ f"Dropped {n_before - n_after} cell(s) from {len(zero_geos)} "
1718
+ f"geographic unit(s) with zero total weight: "
1719
+ f"{list(zero_geos[:5])}"
1720
+ + (f" ... and {len(zero_geos) - 5} more" if len(zero_geos) > 5 else ""),
1721
+ UserWarning,
1722
+ stacklevel=2,
1723
+ )
1724
+
1725
+ # Guard: all cells dropped
1726
+ if panel_df.empty:
1727
+ raise ValueError(
1728
+ "No estimable cells remain after aggregation. "
1729
+ "All cells had missing outcomes or zero effective weight."
1730
+ )
1731
+
1732
+ second_stage_design = SurveyDesign(
1733
+ weights=weight_col,
1734
+ weight_type="aweight",
1735
+ psu=geo_col,
1736
+ )
1737
+
1738
+ return panel_df, second_stage_design