diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/prep.py
ADDED
|
@@ -0,0 +1,1738 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data preparation utilities for difference-in-differences analysis.
|
|
3
|
+
|
|
4
|
+
This module provides helper functions to prepare data for DiD estimation,
|
|
5
|
+
including creating treatment indicators, reshaping panel data, and
|
|
6
|
+
generating synthetic datasets for testing.
|
|
7
|
+
|
|
8
|
+
Data generation functions (generate_*) are defined in prep_dgp.py and
|
|
9
|
+
re-exported here for backward compatibility.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
# Re-export data generation functions from prep_dgp for backward compatibility
|
|
18
|
+
from diff_diff.prep_dgp import ( # noqa: F401
|
|
19
|
+
generate_continuous_did_data,
|
|
20
|
+
generate_ddd_data,
|
|
21
|
+
generate_did_data,
|
|
22
|
+
generate_event_study_data,
|
|
23
|
+
generate_factor_data,
|
|
24
|
+
generate_panel_data,
|
|
25
|
+
generate_staggered_data,
|
|
26
|
+
generate_staggered_ddd_data,
|
|
27
|
+
generate_survey_did_data,
|
|
28
|
+
)
|
|
29
|
+
from diff_diff.survey import (
|
|
30
|
+
ResolvedSurveyDesign,
|
|
31
|
+
SurveyDesign,
|
|
32
|
+
compute_replicate_if_variance,
|
|
33
|
+
compute_survey_if_variance,
|
|
34
|
+
)
|
|
35
|
+
from diff_diff.utils import compute_synthetic_weights
|
|
36
|
+
|
|
37
|
+
# Constants for rank_control_units
|
|
38
|
+
_SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar"
|
|
39
|
+
_OUTLIER_PENALTY_WEIGHT = 0.3 # Penalty weight for outcome outliers in treatment candidate scoring
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated",
) -> pd.DataFrame:
    """
    Create a binary treatment indicator column from various input types.

    Exactly one of ``treated_values`` and ``threshold`` must be given.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    column : str
        Name of the column to use for creating the treatment indicator.
    treated_values : Any or list-like, optional
        Value(s) that indicate treatment. Rows whose ``column`` value is among
        these get 1, all other rows get 0. A single scalar (including a
        string) is treated as one value; lists, tuples, sets, frozensets,
        NumPy arrays, pandas Series and pandas Index objects are treated as
        collections of values.
    threshold : float, optional
        Numeric threshold for creating treatment. Used when the treatment
        is based on a continuous variable (e.g., treat firms above median size).
    above_threshold : bool, default=True
        If True, values >= threshold are treated. If False, values <= threshold
        are treated. Only used when threshold is specified.
    new_column : str, default="treated"
        Name of the new treatment indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``treated_values`` and ``threshold`` are given,
        or if ``column`` is not a column of ``data``.

    Examples
    --------
    Create treatment from categorical variable:

    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'group', treated_values='A')
    >>> df['treated'].tolist()
    [1, 1, 0, 0]

    Create treatment from numeric threshold:

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'size', threshold=75)
    >>> df['treated'].tolist()
    [0, 0, 1, 1]

    Treat units below a threshold:

    >>> df = make_treatment_indicator(df, 'size', threshold=75, above_threshold=False)
    >>> df['treated'].tolist()
    [1, 1, 0, 0]
    """
    # Validate arguments before paying for the copy.
    if treated_values is not None and threshold is not None:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")

    if treated_values is None and threshold is None:
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")

    if column not in data.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    df = data.copy()

    if treated_values is not None:
        # Array-likes must be passed to isin() as-is; wrapping them in a list
        # would turn the whole collection into a single (unmatchable) value.
        collection_types = (list, tuple, set, frozenset, np.ndarray, pd.Series, pd.Index)
        if not isinstance(treated_values, collection_types):
            treated_values = [treated_values]
        df[new_column] = df[column].isin(treated_values).astype(int)
    elif above_threshold:
        df[new_column] = (df[column] >= threshold).astype(int)
    else:
        df[new_column] = (df[column] <= threshold).astype(int)

    return df
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post",
) -> pd.DataFrame:
    """
    Create a binary post-treatment indicator column.

    Exactly one of ``post_periods`` and ``treatment_start`` must be given.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list-like, optional
        Specific period value(s) that are post-treatment. Periods matching
        these values get post=1, others get post=0. A single scalar is
        treated as one period; lists, tuples, sets, frozensets, NumPy arrays,
        pandas Series and pandas Index objects are treated as collections.
    treatment_start : Any, optional
        The first post-treatment period. All periods >= this value get post=1.
        Works with numeric periods, strings (sorted alphabetically), or dates.
        When ``time_column`` has a datetime dtype the value is converted with
        ``pd.to_datetime`` before comparing.
    new_column : str, default="post"
        Name of the new post indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``post_periods`` and ``treatment_start`` are
        given, or if ``time_column`` is not a column of ``data``.

    Examples
    --------
    Using specific post periods:

    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021], 'y': [1, 2, 3, 4]})
    >>> df = make_post_indicator(df, 'year', post_periods=[2020, 2021])
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Using treatment start:

    >>> df = make_post_indicator(df, 'year', treatment_start=2020)
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Works with date columns:

    >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
    >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
    >>> df['post'].tolist()
    [0, 1, 1]
    """
    # Validate arguments before paying for the copy.
    if post_periods is not None and treatment_start is not None:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")

    if post_periods is None and treatment_start is None:
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")

    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    df = data.copy()

    if post_periods is not None:
        # Array-likes must be passed to isin() as-is; wrapping them in a list
        # would turn the whole collection into a single (unmatchable) value.
        collection_types = (list, tuple, set, frozenset, np.ndarray, pd.Series, pd.Index)
        if not isinstance(post_periods, collection_types):
            post_periods = [post_periods]
        df[new_column] = df[time_column].isin(post_periods).astype(int)
    else:
        # Align the cutoff with datetime columns so that e.g. an ISO date
        # string can be compared against timestamps.
        if pd.api.types.is_datetime64_any_dtype(df[time_column].dtype):
            treatment_start = pd.to_datetime(treatment_start)
        df[new_column] = (df[time_column] >= treatment_start).astype(int)

    return df
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None,
) -> pd.DataFrame:
    """
    Convert wide-format panel data to long format for DiD analysis.

    Wide format has one row per unit with multiple columns for each time period.
    Long format has one row per unit-period combination.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame with one row per unit.
    value_columns : list of str
        Column names containing the outcome values for each period.
        These should be in chronological order.
    id_column : str
        Column name for the unit identifier.
    time_name : str, default="period"
        Name for the new time period column.
    value_name : str, default="value"
        Name for the new value/outcome column.
    time_values : list, optional
        Values to use for time periods. If None, uses 0, 1, 2, ...
        Must have same length as value_columns.

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame with one row per unit-period, sorted by
        ``id_column`` then ``time_name``. Columns are ordered as
        ``[id_column, time_name, value_name]`` followed by every other
        column of ``data`` that is not a value column.

    Raises
    ------
    ValueError
        If ``value_columns`` is empty, a referenced column is missing,
        ``time_values`` has the wrong length, ``time_name`` equals
        ``value_name``, or either new column name matches a column that is
        carried over from ``data``.

    Examples
    --------
    >>> wide_df = pd.DataFrame({
    ...     'firm_id': [1, 2, 3],
    ...     'sales_2019': [100, 150, 200],
    ...     'sales_2020': [110, 160, 210],
    ...     'sales_2021': [120, 170, 220]
    ... })
    >>> long_df = wide_to_long(
    ...     wide_df,
    ...     value_columns=['sales_2019', 'sales_2020', 'sales_2021'],
    ...     id_column='firm_id',
    ...     time_name='year',
    ...     value_name='sales',
    ...     time_values=[2019, 2020, 2021]
    ... )
    >>> len(long_df)
    9
    >>> long_df.columns.tolist()
    ['firm_id', 'year', 'sales']
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")

    if id_column not in data.columns:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")

    for col in value_columns:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if time_values is None:
        time_values = list(range(len(value_columns)))

    if len(time_values) != len(value_columns):
        raise ValueError(
            f"time_values length ({len(time_values)}) must match "
            f"value_columns length ({len(value_columns)})."
        )

    # Columns carried through unchanged (everything except id and value columns).
    value_set = set(value_columns)
    other_cols = [c for c in data.columns if c != id_column and c not in value_set]
    preserved_cols = [id_column] + other_cols

    # A new column that reuses a carried-over name would silently overwrite
    # that column and leave duplicate labels in the result.
    if time_name == value_name:
        raise ValueError(
            f"time_name and value_name must be different, got '{time_name}' for both."
        )
    for arg_name, new_name in (("time_name", time_name), ("value_name", value_name)):
        if new_name in preserved_cols:
            raise ValueError(
                f"{arg_name} '{new_name}' conflicts with an existing column in the DataFrame."
            )

    # Melt straight into the final time column; no scratch column is created,
    # so no user column name can be clobbered along the way.
    long_df = pd.melt(
        data,
        id_vars=preserved_cols,
        value_vars=value_columns,
        var_name=time_name,
        value_name=value_name,
    )

    # Replace the melted column labels with their time values.
    col_to_time = dict(zip(value_columns, time_values))
    long_df[time_name] = long_df[time_name].map(col_to_time)

    # Reorder columns and sort
    cols = [id_column, time_name, value_name] + other_cols
    return long_df[cols].sort_values([id_column, time_name]).reset_index(drop=True)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None,
) -> pd.DataFrame:
    """
    Balance a panel dataset to ensure all units have all time periods.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time period.
    method : str, default="inner"
        Balancing method:
        - "inner": Keep only units that appear in all periods (drops units)
        - "outer": Include all unit-period combinations (creates NaN)
        - "fill": Include all combinations and fill missing values
    fill_value : float, optional
        Value to fill missing observations when method="fill". Only numeric
        non-key columns are filled with it; other columns keep their NaN.
        If None with method="fill", each non-key column is forward filled and
        then backward filled within its unit.

    Returns
    -------
    pd.DataFrame
        Balanced panel DataFrame. For "inner" this is a copy of the matching
        rows of ``data``; for "outer" and "fill" it is a new frame built from
        the full unit-by-period grid.

    Raises
    ------
    ValueError
        If ``unit_column`` or ``time_column`` is missing from ``data``, or if
        ``method`` is not one of "inner", "outer", "fill".

    Examples
    --------
    Keep only complete units:

    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 3, 3, 3],
    ...     'period': [1, 2, 3, 1, 2, 1, 2, 3],
    ...     'y': [10, 11, 12, 20, 21, 30, 31, 32]
    ... })
    >>> balanced = balance_panel(df, 'unit', 'period', method='inner')
    >>> balanced['unit'].unique().tolist()
    [1, 3]

    Include all combinations:

    >>> balanced = balance_panel(df, 'unit', 'period', method='outer')
    >>> len(balanced)
    9
    """
    if unit_column not in data.columns:
        raise ValueError(f"Column '{unit_column}' not found in DataFrame.")
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    if method not in ("inner", "outer", "fill"):
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    all_periods = sorted(data[time_column].unique())

    if method == "inner":
        # Keep only units observed in every period present in the data.
        periods_per_unit = data.groupby(unit_column)[time_column].nunique()
        complete_units = periods_per_unit[periods_per_unit == len(all_periods)].index
        return data[data[unit_column].isin(complete_units)].copy()

    # "outer" and "fill" both start from the full unit-by-period grid.
    all_units = data[unit_column].unique()
    full_index = pd.MultiIndex.from_product(
        [all_units, all_periods], names=[unit_column, time_column]
    )
    full_df = pd.DataFrame(index=full_index).reset_index()

    # Left-merge so grid cells without an observation show up as NaN rows.
    result = full_df.merge(data, on=[unit_column, time_column], how="left")

    if method == "outer":
        return result

    # method == "fill": only non-key columns are candidates for filling.
    cols_to_fill = [c for c in result.columns if c not in (unit_column, time_column)]

    if fill_value is not None:
        # A constant is only meaningful for numeric columns.
        numeric_cols = result.select_dtypes(include=[np.number]).columns
        for col in numeric_cols:
            if col in cols_to_fill:
                result[col] = result[col].fillna(fill_value)
    else:
        result = result.sort_values([unit_column, time_column])
        if cols_to_fill:
            # Forward fill within each unit, then backward fill whatever is
            # still missing at the start of a unit's series.
            result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill()
            result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill()

    return result
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def _sorted_for_message(values: Any) -> List[Any]:
    """Sort values for display, falling back to repr order when they are not mutually comparable."""
    try:
        return sorted(values)
    except TypeError:
        # Mixed types (e.g. str and int) cannot be compared directly.
        return sorted(values, key=repr)


def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True,
) -> Dict[str, Any]:
    """
    Validate that data is properly formatted for DiD analysis.

    Checks for common data issues and provides informative error messages.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/post indicator column.
    unit : str, optional
        Name of unit identifier column (for panel data validation).
    raise_on_error : bool, default=True
        If True, raises ValueError on validation failures.
        If False, returns validation results without raising.

    Returns
    -------
    dict
        Validation results with keys:
        - valid: bool indicating if data passed all checks
        - errors: list of error messages
        - warnings: list of warning messages
        - summary: dict with data summary statistics (empty when the basic
          column, type, or missing-value checks fail)

    Raises
    ------
    ValueError
        If ``raise_on_error`` is True and any check fails.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [1, 2, 3, 4],
    ...     'treated': [0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1]
    ... })
    >>> result = validate_did_data(df, 'y', 'treated', 'post', raise_on_error=False)
    >>> result['valid']
    True
    """
    errors: List[str] = []
    warning_messages: List[str] = []

    # Check columns exist
    required_cols = [outcome, treatment, time]
    if unit is not None:
        required_cols.append(unit)

    for col in required_cols:
        if col not in data.columns:
            errors.append(f"Required column '{col}' not found in DataFrame.")

    # Nothing else can be checked without the columns, so stop here.
    if errors:
        if raise_on_error:
            raise ValueError("\n".join(errors))
        return {"valid": False, "errors": errors, "warnings": warning_messages, "summary": {}}

    # Check outcome is numeric
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        errors.append(
            f"Outcome column '{outcome}' must be numeric. " f"Got type: {data[outcome].dtype}"
        )

    # Check treatment is binary
    treatment_vals = data[treatment].dropna().unique()
    if not set(treatment_vals).issubset({0, 1}):
        errors.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {_sorted_for_message(treatment_vals)}"
        )

    # Check time is binary for simple DiD
    time_vals = data[time].dropna().unique()
    if len(time_vals) == 2 and not set(time_vals).issubset({0, 1}):
        warning_messages.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: "
            f"{_sorted_for_message(time_vals)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    # Check for missing values
    for col in required_cols:
        n_missing = data[col].isna().sum()
        if n_missing > 0:
            errors.append(
                f"Column '{col}' has {n_missing} missing values. "
                "Please handle missing data before fitting."
            )

    # Summary statistics and cell checks only make sense on clean data.
    summary: Dict[str, Any] = {}
    if not errors:
        summary["n_obs"] = len(data)
        summary["n_treated"] = int((data[treatment] == 1).sum())
        summary["n_control"] = int((data[treatment] == 0).sum())
        summary["n_periods"] = len(time_vals)

        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        # Check for sufficient variation
        if summary["n_treated"] == 0:
            errors.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            errors.append("No control observations found (treatment column is all 1).")

        # Check for each treatment-time combination
        if len(time_vals) == 2:
            # For 2-period DiD, check all four cells
            for t_val in [0, 1]:
                for p_val in time_vals:
                    count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
                    if count == 0:
                        errors.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # For multi-period, check that both treatment groups exist in multiple periods
            for t_val in [0, 1]:
                n_periods_with_obs = data[data[treatment] == t_val][time].nunique()
                if n_periods_with_obs < 2:
                    group_name = "Treated" if t_val == 1 else "Control"
                    errors.append(
                        f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # Panel-specific validation
    if unit is not None and not errors:
        # Check treatment is constant within units
        unit_treatment_var = data.groupby(unit)[treatment].nunique()
        units_with_varying_treatment = unit_treatment_var[unit_treatment_var > 1]
        if len(units_with_varying_treatment) > 0:
            warning_messages.append(
                f"Treatment varies within {len(units_with_varying_treatment)} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        # Check panel balance
        periods_per_unit = data.groupby(unit)[time].nunique()
        if periods_per_unit.min() != periods_per_unit.max():
            warning_messages.append(
                f"Unbalanced panel detected. Units have between "
                f"{periods_per_unit.min()} and {periods_per_unit.max()} periods. "
                "Consider using balance_panel() to balance the data."
            )

    valid = len(errors) == 0

    if raise_on_error and not valid:
        raise ValueError("Data validation failed:\n" + "\n".join(errors))

    return {"valid": valid, "errors": errors, "warnings": warning_messages, "summary": summary}
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def summarize_did_data(
    data: pd.DataFrame, outcome: str, treatment: str, time: str, unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Generate summary statistics by treatment group and time period.

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/period column.
    unit : str, optional
        Name of unit identifier column. Accepted for interface compatibility;
        it is not used in the computation.

    Returns
    -------
    pd.DataFrame
        One row per treatment-time cell with columns ``n``, ``mean``, ``std``,
        ``min`` and ``max`` (rounded to 4 decimals). When the time column has
        exactly two distinct non-missing values, rows are labelled
        "Control/Treated - Pre/Post" (the smaller time value is "Pre") and a
        final "DiD Estimate" row is appended whose ``mean`` holds the
        difference-in-differences of the cell means and whose other columns
        hold "-". Otherwise rows are labelled
        "Control/Treated - Period <value>".

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [10, 11, 12, 13, 20, 21, 22, 23],
    ...     'treated': [0, 0, 1, 1, 0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1, 0, 1, 0, 1]
    ... })
    >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
    >>> print(summary)  # doctest: +SKIP
    """
    # Group by treatment and time
    summary = (
        data.groupby([treatment, time])[outcome]
        .agg([("n", "count"), ("mean", "mean"), ("std", "std"), ("min", "min"), ("max", "max")])
        .round(4)
    )

    # Distinct periods for labeling. Missing periods are dropped: groupby
    # already excludes them, and NaN would both inflate the period count and
    # make the sort order unreliable.
    time_vals = sorted(data[time].dropna().unique())

    # Add group labels based on sorted time values (not literal 0/1)
    if len(time_vals) == 2:
        pre_val, post_val = time_vals[0], time_vals[1]

        def format_label(x: tuple) -> str:
            treatment_label = "Treated" if x[0] == 1 else "Control"
            time_label = "Post" if x[1] == post_val else "Pre"
            return f"{treatment_label} - {time_label}"

        summary.index = summary.index.map(format_label)

        def cell_mean(treatment_value: int, period: Any) -> float:
            """Unrounded outcome mean of one treatment-by-period cell."""
            in_cell = (data[treatment] == treatment_value) & (data[time] == period)
            return data[in_cell][outcome].mean()

        # Calculate DiD from the unrounded cell means
        treated_diff = cell_mean(1, post_val) - cell_mean(1, pre_val)
        control_diff = cell_mean(0, post_val) - cell_mean(0, pre_val)
        did_estimate = treated_diff - control_diff

        # Add to summary as a new row
        did_row = pd.DataFrame(
            {"n": ["-"], "mean": [did_estimate], "std": ["-"], "min": ["-"], "max": ["-"]},
            index=["DiD Estimate"],
        )
        summary = pd.concat([summary, did_row])
    else:
        summary.index = summary.index.map(
            lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}"
        )

    return summary
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def create_event_time(
    data: pd.DataFrame, time_column: str, treatment_time_column: str, new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Add a column measuring each row's time relative to its unit's treatment date.

    Intended for event-study designs in which units are treated at
    different calendar times.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data. It is not modified; a copy is returned.
    time_column : str
        Name of the calendar time column.
    treatment_time_column : str
        Name of the column holding each unit's treatment time. Missing
        values, and for numeric columns infinite values, mark never-treated
        units.
    new_column : str, default="event_time"
        Name of the event-time column to create.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with ``new_column`` set to
        ``time_column - treatment_time_column``:
        - Negative for pre-treatment periods
        - 0 for the treatment period
        - Positive for post-treatment periods
        - NaN for never-treated units

    Raises
    ------
    ValueError
        If ``time_column`` or ``treatment_time_column`` is not in ``data``.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2],
    ...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
    ...     'treatment_year': [2019, 2019, 2019, 2020, 2020, 2020]
    ... })
    >>> df = create_event_time(df, 'year', 'treatment_year')
    >>> df['event_time'].tolist()
    [-1, 0, 1, -2, -1, 0]
    """
    out = data.copy()

    # Both source columns must exist; calendar time is checked first.
    for required in (time_column, treatment_time_column):
        if required not in out.columns:
            raise ValueError(f"Column '{required}' not found in DataFrame.")

    # Relative time = calendar time minus treatment time.
    out[new_column] = out[time_column] - out[treatment_time_column]

    # Read the treatment-time column only after the assignment above.
    treat_time = out[treatment_time_column]
    is_never_treated = treat_time.isna()
    if pd.api.types.is_numeric_dtype(treat_time):
        # Infinity is only a possible never-treated marker for numeric data.
        is_never_treated = is_never_treated | np.isinf(treat_time)

    # Never-treated rows carry no meaningful event time.
    out.loc[is_never_treated, new_column] = np.nan
    return out
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Collapse unit-level panel data into treatment-group-by-period means.

    Rows are grouped by ``(treatment_column, time_column)``. Within each
    group the outcome (and every covariate) is averaged and the number of
    distinct units is counted.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Name of unit identifier column.
    time_column : str
        Name of time period column.
    treatment_column : str
        Name of treatment indicator column.
    outcome : str
        Name of outcome variable column.
    covariates : list of str, optional
        Additional columns to average within each group.

    Returns
    -------
    pd.DataFrame
        One row per (treatment status, period) combination present in
        ``data``. The outcome mean is stored in ``mean_<outcome>``, the
        distinct-unit count in ``n_units``; covariate means keep their
        original column names.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 2, 2, 3, 3, 4, 4],
    ...     'period': [0, 1, 0, 1, 0, 1, 0, 1],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'y': [10, 15, 12, 17, 8, 10, 9, 11]
    ... })
    >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
    >>> len(cohort_df)
    4
    """
    # Outcome and unit come first so they lead the aggregated columns;
    # covariates follow in the order given.
    aggregations = {outcome: "mean", unit_column: "nunique"}
    aggregations.update({name: "mean" for name in (covariates or [])})

    grouped = data.groupby([treatment_column, time_column])
    cohort_means = grouped.agg(aggregations).reset_index()

    return cohort_means.rename(
        columns={unit_column: "n_units", outcome: f"mean_{outcome}"}
    )
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
def rank_control_units(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    treatment_column: Optional[str] = None,
    treated_units: Optional[List[Any]] = None,
    pre_periods: Optional[List[Any]] = None,
    covariates: Optional[List[str]] = None,
    outcome_weight: float = 0.7,
    covariate_weight: float = 0.3,
    exclude_units: Optional[List[Any]] = None,
    require_units: Optional[List[Any]] = None,
    n_top: Optional[int] = None,
    suggest_treatment_candidates: bool = False,
    n_treatment_candidates: int = 5,
    lambda_reg: float = 0.0,
) -> pd.DataFrame:
    """
    Rank potential control units by their suitability for DiD analysis.

    Evaluates control units based on pre-treatment outcome trend similarity
    and optional covariate matching to treated units. Returns a ranked list
    with quality scores.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time periods.
    outcome_column : str
        Column name for outcome variable.
    treatment_column : str, optional
        Column with binary treatment indicator (0/1). Used to identify
        treated units from data; the first observed value per unit is used.
    treated_units : list, optional
        Explicit list of treated unit IDs. Alternative to treatment_column.
    pre_periods : list, optional
        Pre-treatment periods for comparison. If None, uses first half of periods.
    covariates : list of str, optional
        Covariate columns for matching. Similarity is based on pre-treatment means.
    outcome_weight : float, default=0.7
        Weight for pre-treatment outcome trend similarity (0-1).
    covariate_weight : float, default=0.3
        Weight for covariate distance (0-1). Ignored if no covariates.
    exclude_units : list, optional
        Units that cannot be in control group.
    require_units : list, optional
        Units flagged in the ``is_required`` column. They are always kept
        when ``n_top`` truncates the ranking.
    n_top : int, optional
        Return only top N control units. If None, return all.
    suggest_treatment_candidates : bool, default=False
        If True and no treated units specified, identify potential treatment
        candidates instead of ranking controls.
    n_treatment_candidates : int, default=5
        Number of treatment candidates to suggest.
    lambda_reg : float, default=0.0
        Regularization passed to the synthetic-weight computation.

    Returns
    -------
    pd.DataFrame
        Ranked control units with columns:
        - unit: Unit identifier
        - quality_score: Combined quality score (higher is better)
        - outcome_trend_score: Pre-treatment outcome trend similarity
        - covariate_score: Covariate match score (NaN if no covariates)
        - synthetic_weight: Weight from synthetic control optimization
        - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
        - is_required: Whether unit was in require_units

        If suggest_treatment_candidates=True (and no treated units):
        - unit: Unit identifier
        - treatment_candidate_score: Suitability as treatment unit
        - avg_outcome_level: Pre-treatment outcome mean
        - outcome_trend: Pre-treatment trend slope
        - n_similar_controls: Count of similar potential controls

    Raises
    ------
    ValueError
        If a named column is missing, a weight is outside [0, 1], both
        ``treated_units`` and ``treatment_column`` are given, a unit is both
        required and excluded, or no pre-periods / treated units / control
        units remain after filtering.

    Examples
    --------
    Rank controls against treated units:

    >>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
    >>> ranking = rank_control_units(
    ...     data,
    ...     unit_column='unit',
    ...     time_column='period',
    ...     outcome_column='outcome',
    ...     treatment_column='treated',
    ...     n_top=10
    ... )
    >>> ranking['quality_score'].is_monotonic_decreasing
    True

    With covariates:

    >>> data['size'] = np.random.randn(len(data))
    >>> ranking = rank_control_units(
    ...     data,
    ...     unit_column='unit',
    ...     time_column='period',
    ...     outcome_column='outcome',
    ...     treatment_column='treated',
    ...     covariates=['size']
    ... )

    Filter data for SyntheticDiD:

    >>> top_controls = ranking['unit'].tolist()
    >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
    """
    # -------------------------------------------------------------------------
    # Input validation
    # -------------------------------------------------------------------------
    for col in [unit_column, time_column, outcome_column]:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if treatment_column is not None and treatment_column not in data.columns:
        raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")

    if covariates:
        for cov in covariates:
            if cov not in data.columns:
                raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")

    if not 0 <= outcome_weight <= 1:
        raise ValueError("outcome_weight must be between 0 and 1")
    if not 0 <= covariate_weight <= 1:
        raise ValueError("covariate_weight must be between 0 and 1")

    if treated_units is not None and treatment_column is not None:
        raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")

    if require_units and exclude_units:
        invalid_required = [u for u in require_units if u in exclude_units]
        if invalid_required:
            raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")

    # -------------------------------------------------------------------------
    # Determine pre-treatment periods
    # -------------------------------------------------------------------------
    all_periods = sorted(data[time_column].unique())
    if pre_periods is None:
        # Default: treat the first half of the observed periods as "pre".
        pre_periods = all_periods[: len(all_periods) // 2]
    else:
        pre_periods = list(pre_periods)

    if len(pre_periods) == 0:
        raise ValueError("No pre-treatment periods specified or inferred.")

    # -------------------------------------------------------------------------
    # Identify treated and control units
    # -------------------------------------------------------------------------
    all_units = list(data[unit_column].unique())

    if treated_units is not None:
        treated_set = set(treated_units)
    elif treatment_column is not None:
        # NOTE(review): only the first observed value per unit is inspected,
        # so a time-varying indicator that starts at 0 would classify the
        # unit as a control — confirm this is the intended contract.
        unit_treatment = data.groupby(unit_column)[treatment_column].first()
        treated_set = set(unit_treatment[unit_treatment == 1].index)
    elif suggest_treatment_candidates:
        # Treatment candidate discovery mode - no treated units
        treated_set = set()
    else:
        raise ValueError(
            "Must specify treated_units, treatment_column, or set "
            "suggest_treatment_candidates=True"
        )

    # -------------------------------------------------------------------------
    # Treatment candidate discovery mode
    # -------------------------------------------------------------------------
    if suggest_treatment_candidates and len(treated_set) == 0:
        return _suggest_treatment_candidates(
            data, unit_column, time_column, outcome_column, pre_periods, n_treatment_candidates
        )

    if len(treated_set) == 0:
        raise ValueError("No treated units found.")

    # Determine control candidates
    control_candidates = [u for u in all_units if u not in treated_set]

    if exclude_units:
        excluded = set(exclude_units)
        control_candidates = [u for u in control_candidates if u not in excluded]

    if len(control_candidates) == 0:
        raise ValueError("No control units available after exclusions.")

    # -------------------------------------------------------------------------
    # Create outcome matrices (pre-treatment)
    # -------------------------------------------------------------------------
    pre_data = data[data[time_column].isin(pre_periods)]
    pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)

    # Filter to pre_periods that exist in data
    valid_pre_periods = [p for p in pre_periods if p in pivot.index]
    if len(valid_pre_periods) == 0:
        raise ValueError("No data found for specified pre-treatment periods.")

    # Filter control_candidates to those present in pivot (handles unbalanced panels)
    control_candidates = [c for c in control_candidates if c in pivot.columns]
    if len(control_candidates) == 0:
        raise ValueError("No control units found in pre-treatment data.")

    # Control outcomes: shape (n_pre_periods, n_control_candidates)
    Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)

    # Treated outcomes mean: shape (n_pre_periods,)
    # Only treated units that actually have pre-period rows can contribute.
    treated_list = [u for u in treated_set if u in pivot.columns]
    if len(treated_list) == 0:
        raise ValueError("Treated units not found in pre-treatment data.")
    Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)

    # -------------------------------------------------------------------------
    # Compute outcome trend scores
    # -------------------------------------------------------------------------
    # Presumably returns one weight per control column (it is stored as a
    # DataFrame column alongside control_candidates below).
    synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg)

    # RMSE for each control vs treated mean (nanmean tolerates missing cells)
    rmse_scores = [
        np.sqrt(np.nanmean((Y_control[:, j] - Y_treated_mean) ** 2))
        for j in range(len(control_candidates))
    ]

    # Convert RMSE to similarity score (lower RMSE = higher score).
    # control_candidates is non-empty here, so min/max are always defined.
    max_rmse = max(rmse_scores)
    min_rmse = min(rmse_scores)
    rmse_range = max_rmse - min_rmse

    if rmse_range < 1e-10:
        # All controls have identical/similar pre-trends (includes single control case)
        outcome_trend_scores = [1.0] * len(rmse_scores)
    else:
        # Normalize so best control gets 1.0, worst gets 0.0
        outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]

    # -------------------------------------------------------------------------
    # Compute covariate scores (if covariates provided)
    # -------------------------------------------------------------------------
    if covariates:
        # Unit-level covariate values (pre-treatment mean)
        cov_data = pre_data.groupby(unit_column)[covariates].mean()

        # Treated covariate profile (mean across treated units). Use
        # treated_list, not treated_set: treated units without pre-period
        # rows are absent from cov_data and would make .loc raise KeyError.
        treated_cov = cov_data.loc[treated_list].mean()

        # Standardize covariates
        cov_mean = cov_data.mean()
        cov_std = cov_data.std().replace(0, 1)  # Avoid division by zero
        cov_standardized = (cov_data - cov_mean) / cov_std
        treated_cov_std = (treated_cov - cov_mean) / cov_std

        # Euclidean distance in standardized space (vectorized)
        control_cov_matrix = cov_standardized.loc[control_candidates].values
        treated_cov_vector = treated_cov_std.values
        covariate_distances = np.sqrt(
            np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
        )

        # Convert distance to similarity score (min-max normalization)
        max_dist = covariate_distances.max()
        min_dist = covariate_distances.min()
        dist_range = max_dist - min_dist

        if dist_range < 1e-10:
            # All controls have identical/similar covariate profiles
            covariate_scores = [1.0] * len(covariate_distances)
        else:
            # Normalize so best control (closest) gets 1.0, worst gets 0.0
            covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
    else:
        covariate_scores = [np.nan] * len(control_candidates)

    # -------------------------------------------------------------------------
    # Compute combined quality score
    # -------------------------------------------------------------------------
    # Rescale the two weights so they sum to one.
    total_weight = outcome_weight + covariate_weight
    if total_weight > 0:
        norm_outcome_weight = outcome_weight / total_weight
        norm_covariate_weight = covariate_weight / total_weight
    else:
        norm_outcome_weight = 1.0
        norm_covariate_weight = 0.0

    # A NaN covariate score (no covariates, or missing covariate data for a
    # unit) falls back to the outcome score alone.
    quality_scores = [
        outcome_score
        if np.isnan(cov_score)
        else norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score
        for outcome_score, cov_score in zip(outcome_trend_scores, covariate_scores)
    ]

    # -------------------------------------------------------------------------
    # Build result DataFrame
    # -------------------------------------------------------------------------
    require_set = set(require_units) if require_units else set()

    result = pd.DataFrame(
        {
            "unit": control_candidates,
            "quality_score": quality_scores,
            "outcome_trend_score": outcome_trend_scores,
            "covariate_score": covariate_scores,
            "synthetic_weight": synthetic_weights,
            "pre_trend_rmse": rmse_scores,
            "is_required": [u in require_set for u in control_candidates],
        }
    )

    # Sort by quality score (descending)
    result = result.sort_values("quality_score", ascending=False)

    # Apply n_top limit if specified
    if n_top is not None and n_top < len(result):
        # Required units are always kept; the best non-required units fill
        # whatever slots remain.
        required_df = result[result["is_required"]]
        non_required_df = result[~result["is_required"]]

        remaining_slots = max(0, n_top - len(required_df))
        top_non_required = non_required_df.head(remaining_slots)

        result = pd.concat([required_df, top_non_required])
        result = result.sort_values("quality_score", ascending=False)

    return result.reset_index(drop=True)
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
def _suggest_treatment_candidates(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    pre_periods: List[Any],
    n_candidates: int,
) -> pd.DataFrame:
    """
    Identify units that would make good treatment candidates.

    A good treatment candidate:
    1. Has many similar control units available (for matching)
    2. Has stable pre-treatment trends (predictable counterfactual)
    3. Is not an extreme outlier

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    unit_column : str
        Unit identifier column.
    time_column : str
        Time period column.
    outcome_column : str
        Outcome variable column.
    pre_periods : list
        Pre-treatment periods.
    n_candidates : int
        Number of candidates to return.

    Returns
    -------
    pd.DataFrame
        Up to ``n_candidates`` rows, highest ``treatment_candidate_score``
        first, with columns ``unit``, ``avg_outcome_level``,
        ``outcome_trend``, ``n_similar_controls`` and
        ``treatment_candidate_score``. An empty frame with those column
        names is returned when no unit has pre-period data.

    Notes
    -----
    The trend slope is fitted against the chronological position of each
    observation (0, 1, 2, ...), so each unit's rows are sorted by
    ``time_column`` before fitting. Without that sort the slope would
    depend on the row order of ``data``.
    """
    all_units = list(data[unit_column].unique())
    pre_data = data[data[time_column].isin(pre_periods)]

    # Pre-period mean outcome for every unit, computed once. Each candidate
    # is compared against all the others, so recomputing this inside the
    # loop would repeat the same groupby for every unit.
    unit_means = pre_data.groupby(unit_column)[outcome_column].mean()

    candidate_info = []

    for unit in all_units:
        unit_data = pre_data[pre_data[unit_column] == unit]

        if len(unit_data) == 0:
            continue

        # Average outcome level
        avg_outcome = unit_data[outcome_column].mean()

        # Trend (simple linear regression slope over chronological position).
        # A stable sort keeps the original row order among tied time values.
        outcomes = unit_data.sort_values(time_column, kind="stable")[outcome_column].values
        if len(outcomes) > 1:
            try:
                slope = np.polyfit(np.arange(len(outcomes)), outcomes, 1)[0]
            except (np.linalg.LinAlgError, ValueError):
                slope = 0.0
        else:
            slope = 0.0

        # Count similar potential controls: other units whose pre-period
        # mean lies within the similarity threshold (in SDs of those means).
        other_means = unit_means[unit_means.index != unit]

        if len(other_means) > 0:
            sd = other_means.std()
            if sd > 0:
                n_similar = int(
                    np.sum(np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd)
                )
            else:
                # Zero or undefined spread: every other unit counts as similar.
                n_similar = len(other_means)
        else:
            n_similar = 0

        candidate_info.append(
            {
                "unit": unit,
                "avg_outcome_level": avg_outcome,
                "outcome_trend": slope,
                "n_similar_controls": n_similar,
            }
        )

    if len(candidate_info) == 0:
        return pd.DataFrame(
            columns=[
                "unit",
                "treatment_candidate_score",
                "avg_outcome_level",
                "outcome_trend",
                "n_similar_controls",
            ]
        )

    result = pd.DataFrame(candidate_info)

    # Score: prefer units with many similar controls and moderate outcome levels
    max_similar = result["n_similar_controls"].max()
    if max_similar > 0:
        similarity_score = result["n_similar_controls"] / max_similar
    else:
        similarity_score = pd.Series(0.0, index=result.index)

    # Penalty for outliers in outcome level
    outcome_mean = result["avg_outcome_level"].mean()
    outcome_std = result["avg_outcome_level"].std()
    if outcome_std > 0:
        outcome_z = np.abs((result["avg_outcome_level"] - outcome_mean) / outcome_std)
    else:
        outcome_z = pd.Series(0.0, index=result.index)

    result["treatment_candidate_score"] = (
        similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
    ).clip(0, 1)

    # Return top candidates
    result = result.nlargest(n_candidates, "treatment_candidate_score")
    return result.reset_index(drop=True)
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
def trim_weights(
    data: pd.DataFrame,
    weight_col: str,
    upper: Optional[float] = None,
    quantile: Optional[float] = None,
    lower: Optional[float] = None,
) -> pd.DataFrame:
    """Winsorize a weight column by capping it from above and/or below.

    The input frame is never modified; a copy with the adjusted weight
    column is returned. When ``quantile`` is given, the upper cap is the
    NaN-ignoring quantile of the weights.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    weight_col : str
        Name of the weight column.
    upper : float, optional
        Absolute upper cap. Weights above this value are set to it.
        Mutually exclusive with ``quantile``.
    quantile : float, optional
        Quantile-based upper cap (e.g., 0.99), strictly between 0 and 1.
        Mutually exclusive with ``upper``.
    lower : float, optional
        Absolute lower floor. Weights below this value are set to it.
        Can be combined with either ``upper`` or ``quantile``.

    Returns
    -------
    pd.DataFrame
        Copy of data with trimmed weights.

    Raises
    ------
    ValueError
        If both ``upper`` and ``quantile`` are provided, ``weight_col`` is
        not in the DataFrame, ``quantile`` is outside (0, 1), a cap is
        negative or non-finite, or ``lower`` exceeds the resolved upper cap.
    """
    if quantile is not None and upper is not None:
        raise ValueError("Specify either 'upper' or 'quantile', not both.")
    if weight_col not in data.columns:
        raise ValueError(f"Column '{weight_col}' not found in DataFrame.")

    trimmed = data.copy()
    weights = trimmed[weight_col].values.copy()

    # Resolve a quantile request into an absolute upper cap.
    if quantile is not None:
        if not 0 < quantile < 1:
            raise ValueError(f"quantile must be in (0, 1), got {quantile}")
        upper = float(np.nanquantile(weights, quantile))

    # Both caps must be finite and non-negative, and must not cross.
    if upper is not None and (not np.isfinite(upper) or upper < 0):
        raise ValueError(f"upper must be finite and >= 0, got {upper}")
    if lower is not None and (not np.isfinite(lower) or lower < 0):
        raise ValueError(f"lower must be finite and >= 0, got {lower}")
    if lower is not None and upper is not None and lower > upper:
        raise ValueError(
            f"lower ({lower}) must be <= upper ({upper}). "
            f"When using quantile, the resolved upper cap may be below lower."
        )

    # Cap from above first, then raise to the floor.
    if upper is not None:
        weights = np.minimum(weights, upper)
    if lower is not None:
        weights = np.maximum(weights, lower)

    trimmed[weight_col] = weights
    return trimmed
|
|
1308
|
+
|
|
1309
|
+
|
|
1310
|
+
# ---------------------------------------------------------------------------
|
|
1311
|
+
# Survey aggregation helpers
|
|
1312
|
+
# ---------------------------------------------------------------------------
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
def _cell_mean_variance(
    y_full: np.ndarray,
    full_resolved: ResolvedSurveyDesign,
    cell_mask: np.ndarray,
    min_n: int,
) -> Tuple[float, float, int, bool]:
    """Return the weighted mean of one cell and the variance of that mean.

    The cell is handled as a domain of the full design: an influence
    function of full length is built that is zero outside the cell, and it
    is handed to the design-based variance routine. The original author
    attributes this approach to Lumley (2004), Section 3.4. A simple
    weighted (SRS-style) variance is used instead when the cell has fewer
    than ``min_n`` valid rows or the design-based estimate comes back NaN.

    Parameters
    ----------
    y_full : np.ndarray
        Outcome values for the full dataset (may contain NaN).
    full_resolved : ResolvedSurveyDesign
        Full-sample resolved survey design; ``weights`` and
        ``uses_replicate_variance`` are read from it.
    cell_mask : np.ndarray
        Boolean mask identifying cell members in the full dataset.
    min_n : int
        Minimum valid observations for design-based variance. Below this
        threshold, SRS fallback is used.

    Returns
    -------
    mean : float
        Weighted cell mean (NaN when the cell has no valid rows).
    variance : float
        Variance of the cell mean, floored at 0. NaN when fewer than two
        valid rows exist.
    n_valid : int
        Number of rows in the cell with a non-missing outcome and a
        positive weight.
    used_srs_fallback : bool
        True if SRS variance was used instead of design-based.
    """
    cell_y = y_full[cell_mask]
    cell_w = full_resolved.weights[cell_mask]

    # A row counts only if its outcome is observed and its weight is positive.
    usable = ~np.isnan(cell_y) & (cell_w > 0)
    n_valid = int(np.sum(usable))

    if n_valid == 0:
        return np.nan, np.nan, 0, False
    if n_valid < 2:
        # One observation: the mean is that value, the variance is undefined.
        return float(cell_y[usable][0]), np.nan, 1, False

    # Zero out weight and outcome on unusable rows so NaNs cannot leak into sums.
    w_eff = cell_w * usable.astype(np.float64)
    y_filled = np.where(usable, cell_y, 0.0)
    sum_w = float(np.sum(w_eff))

    if sum_w <= 0:
        return np.nan, np.nan, n_valid, False

    y_bar = float(np.sum(w_eff * y_filled) / sum_w)

    def _srs_variance() -> float:
        # Rescale positive weights to mean 1 so the result does not depend
        # on the absolute scale of the weights.
        scaled = w_eff.copy()
        positive = scaled[scaled > 0]
        if len(positive) > 0:
            scaled[scaled > 0] = positive / positive.mean()
        total = float(np.sum(scaled))
        weighted_sq_dev = scaled * (y_filled - y_bar) ** 2
        return float(np.sum(weighted_sq_dev) / (total**2) * n_valid / (n_valid - 1))

    # Too few observations for the design-based estimator.
    if n_valid < min_n:
        return y_bar, max(_srs_variance(), 0.0), n_valid, True

    # Full-length influence function: non-zero only at the cell's usable rows.
    influence = np.zeros(len(y_full))
    positions = np.where(cell_mask)[0][usable]
    influence[positions] = w_eff[usable] * (y_filled[usable] - y_bar) / sum_w

    # Dispatch on the variance method configured in the design.
    if full_resolved.uses_replicate_variance:
        variance, _ = compute_replicate_if_variance(influence, full_resolved)
    else:
        variance = compute_survey_if_variance(influence, full_resolved)

    # The design-based routine signals an unidentifiable variance with NaN.
    used_srs = False
    if np.isnan(variance):
        variance = _srs_variance()
        used_srs = True

    return y_bar, max(float(variance), 0.0), n_valid, used_srs
|
|
1416
|
+
|
|
1417
|
+
|
|
1418
|
+
def aggregate_survey(
    data: pd.DataFrame,
    by: Union[str, List[str]],
    outcomes: Union[str, List[str]],
    survey_design: SurveyDesign,
    covariates: Optional[Union[str, List[str]]] = None,
    min_n: int = 2,
    lonely_psu: Optional[str] = None,
) -> Tuple[pd.DataFrame, SurveyDesign]:
    """Aggregate survey microdata to geographic-period cells with design-based precision.

    Computes design-weighted cell means and their Taylor-linearized (or
    replicate-based) standard errors for each cell defined by the ``by``
    columns. Returns a panel-ready DataFrame with precision weights and a
    pre-configured :class:`SurveyDesign` for second-stage DiD estimation.

    Each cell is treated as a subpopulation/domain of the full survey
    design: influence function values are zero-padded outside the cell,
    preserving full strata/PSU structure for variance estimation per
    Lumley (2004) Section 3.4.

    Parameters
    ----------
    data : pd.DataFrame
        Individual-level microdata.
    by : str or list of str
        Columns defining cells (e.g., ``["state", "year"]``). The first
        element is used as the clustering variable in the returned
        SurveyDesign (geographic unit for second-stage inference).
    outcomes : str or list of str
        Outcome variable(s) to aggregate with full precision tracking.
        Each outcome produces ``{name}_mean``, ``{name}_se``,
        ``{name}_n``, and ``{name}_precision`` columns. When multiple
        outcomes are given, panel filtering (non-estimable cell
        removal, zero-weight PSU pruning) is based on the **first**
        outcome only, consistent with the returned SurveyDesign. For
        independent per-outcome support, call once per outcome.
    survey_design : SurveyDesign
        Survey design specification for the microdata.
    covariates : str or list of str, optional
        Additional variables to aggregate as design-weighted means only
        (no SE/precision columns).
    min_n : int, default 2
        Minimum respondents per cell. Cells below this threshold use
        simple random sampling variance as a fallback.
    lonely_psu : str, optional
        Override the survey design's ``lonely_psu`` setting for within-cell
        computation. One of ``"remove"``, ``"certainty"``, ``"adjust"``.

    Returns
    -------
    panel_df : pd.DataFrame
        Aggregated panel with columns: grouping variables,
        ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``,
        ``{outcome}_precision``, ``{outcome}_weight``,
        ``{covariate}_mean``, ``cell_n``, ``cell_n_eff``,
        ``srs_fallback``. The ``_weight`` column is a fit-ready
        version of ``_precision`` with NaN/Inf mapped to 0.0 and is
        only created for the first outcome.
    second_stage_design : SurveyDesign
        Pre-configured for second-stage estimation with
        ``weight_type="aweight"``, precision weights from the first
        outcome, and geographic clustering via ``psu``.

    Raises
    ------
    ValueError
        If ``by`` or ``outcomes`` is empty, a referenced column is missing,
        a column appears in both ``by`` and outcomes/covariates, ``min_n``
        is below 1, ``lonely_psu`` is not a recognised option, ``data`` is
        empty, a grouping column contains missing values, an outcome or
        covariate column is non-numeric, or no estimable cell remains.
    TypeError
        If ``survey_design`` is not a :class:`SurveyDesign`.

    Warns
    -----
    UserWarning
        When cells fall back to SRS variance, when cells have zero variance
        (each such cell is reported once, regardless of how many outcomes
        are affected), when non-estimable cells are dropped, and when
        geographic units with zero total weight are dropped.

    Examples
    --------
    >>> design = SurveyDesign(weights="finalwt", strata="strat", psu="psu")
    >>> panel, stage2 = aggregate_survey(
    ...     microdata, by=["state", "year"],
    ...     outcomes="smoking_rate", survey_design=design,
    ... )
    >>> # Add treatment/time indicators at the panel level, then fit:
    >>> # panel["treated"] = ...  # e.g., from policy adoption data
    >>> # panel["post"] = (panel["year"] >= treatment_year).astype(int)
    >>> # result = DifferenceInDifferences().fit(
    >>> #     panel, outcome="smoking_rate_mean",
    >>> #     treatment="treated", time="post", survey_design=stage2,
    >>> # )
    """
    import warnings
    from dataclasses import replace

    def _preview(items: List[Any]) -> str:
        # Show at most five entries, then summarise how many were omitted.
        extra = len(items) - 5
        return f"{items[:5]}" + (f" ... and {extra} more" if extra > 0 else "")

    # --- Normalize inputs ---
    by_cols = [by] if isinstance(by, str) else list(by)
    outcome_cols = [outcomes] if isinstance(outcomes, str) else list(outcomes)
    cov_cols = (
        [covariates] if isinstance(covariates, str) else list(covariates) if covariates else []
    )

    # --- Validate ---
    if not by_cols:
        raise ValueError("'by' must specify at least one grouping column")
    if not outcome_cols:
        raise ValueError("'outcomes' must specify at least one outcome variable")

    all_cols = by_cols + outcome_cols + cov_cols
    missing = [c for c in all_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Columns not found in DataFrame: {missing}")

    overlap = set(by_cols) & (set(outcome_cols) | set(cov_cols))
    if overlap:
        raise ValueError(f"Columns appear in both 'by' and outcomes/covariates: {overlap}")

    if not isinstance(survey_design, SurveyDesign):
        raise TypeError(
            f"survey_design must be a SurveyDesign instance, got {type(survey_design).__name__}"
        )

    if min_n < 1:
        raise ValueError(f"min_n must be >= 1, got {min_n}")

    if lonely_psu is not None and lonely_psu not in ("remove", "certainty", "adjust"):
        raise ValueError(
            f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'"
        )

    # --- Empty-input guard ---
    if data.empty:
        raise ValueError("data must be non-empty")

    # --- Validate grouping columns have no missing values ---
    by_missing = data[by_cols].isna().any()
    cols_with_na = list(by_missing[by_missing].index)
    if cols_with_na:
        raise ValueError(
            f"Missing values in grouping column(s): {cols_with_na}. "
            f"Drop or fill NaN values before calling aggregate_survey()."
        )

    # --- Resolve design once on full data ---
    effective_design = (
        replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design
    )
    full_resolved = effective_design.resolve(data)

    # --- Precompute full-length outcome/covariate arrays ---
    n_total = len(data)
    all_vars = outcome_cols + cov_cols
    non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])]
    if non_numeric:
        raise ValueError(
            f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. "
            f"All outcome and covariate columns must be numeric."
        )
    y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars}

    # --- Per-cell computation ---
    # Use groupby().indices for position-based cell membership (safe with
    # duplicate DataFrame indices, no column injection into user data).
    # With a single grouping column, group by the bare column name so the
    # cell keys are scalars no matter how pandas shapes keys for length-1
    # list groupers (NOTE(review): that shape appears to be version-dependent).
    single_by = len(by_cols) == 1
    grouped = data.groupby(by_cols[0] if single_by else by_cols, sort=True)
    cell_indices = grouped.indices  # dict of cell_key → positional indices
    rows: List[Dict[str, Any]] = []
    srs_cells: List[str] = []
    zero_var_cells: List[str] = []

    for cell_key, pos_idx in cell_indices.items():
        # Boolean mask for full-design domain estimation
        cell_mask = np.zeros(n_total, dtype=bool)
        cell_mask[pos_idx] = True

        cell_n = int(len(pos_idx))
        cell_key_str = str(cell_key)

        # Cell-level statistics (Kish ESS is a property of the cell)
        cell_w = full_resolved.weights[cell_mask]
        sum_w = float(np.sum(cell_w))
        sum_w2 = float(np.sum(cell_w**2))
        cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0

        # Build row dict with grouping columns
        row: Dict[str, Any] = {}
        if single_by:
            row[by_cols[0]] = cell_key
        else:
            for i, col in enumerate(by_cols):
                row[col] = cell_key[i]

        row["cell_n"] = cell_n
        row["cell_n_eff"] = cell_n_eff

        # Per-cell flags so each cell is reported at most once in warnings,
        # even when several outcomes trigger the same condition.
        cell_srs_fallback = False
        cell_zero_variance = False

        # Outcomes: mean + SE + n + precision (full-design domain estimation)
        for var in outcome_cols:
            y_bar, variance, n_valid, used_srs = _cell_mean_variance(
                y_arrays[var],
                full_resolved,
                cell_mask,
                min_n,
            )
            se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan

            if used_srs:
                cell_srs_fallback = True

            # Zero variance → precision NaN
            if se == 0.0:
                precision = np.nan
                cell_zero_variance = True
            elif np.isnan(se):
                precision = np.nan
            else:
                precision = 1.0 / variance

            row[f"{var}_mean"] = y_bar
            row[f"{var}_se"] = se
            row[f"{var}_n"] = n_valid
            row[f"{var}_precision"] = precision

        # Covariates: design-weighted mean only
        for var in cov_cols:
            y_cell = y_arrays[var][cell_mask]
            valid = ~np.isnan(y_cell)
            # Zero out weights of missing observations so they drop from
            # both numerator and denominator.
            w_valid = cell_w * valid.astype(np.float64)
            sw = float(np.sum(w_valid))
            if sw > 0:
                row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw)
            else:
                row[f"{var}_mean"] = np.nan

        row["srs_fallback"] = cell_srs_fallback
        if cell_srs_fallback:
            srs_cells.append(cell_key_str)
        if cell_zero_variance:
            zero_var_cells.append(cell_key_str)

        rows.append(row)

    # --- Warnings ---
    if srs_cells:
        warnings.warn(
            f"Design-based variance not estimable for {len(srs_cells)} cell(s); "
            f"using SRS fallback: " + _preview(srs_cells),
            UserWarning,
            stacklevel=2,
        )
    if zero_var_cells:
        warnings.warn(
            f"Zero variance in {len(zero_var_cells)} cell(s) (precision set to NaN): "
            + _preview(zero_var_cells),
            UserWarning,
            stacklevel=2,
        )

    # --- Assemble output ---
    panel_df = pd.DataFrame(rows)

    # Sort by grouping columns
    panel_df = panel_df.sort_values(by_cols).reset_index(drop=True)

    # --- Drop non-estimable cells ---
    # Cells with non-finite mean (n_valid==0 or all-missing) cannot contribute
    # to second-stage estimation and would cause fit() to reject NaN outcomes.
    # Dropping them also removes all-zero-weight PSUs from the panel.
    first_outcome = outcome_cols[0]
    mean_col = f"{first_outcome}_mean"
    nonestimable = ~np.isfinite(panel_df[mean_col].values)
    if np.any(nonestimable):
        n_dropped = int(np.sum(nonestimable))
        dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist()
        # Warn about secondary outcomes losing valid data in dropped cells
        secondary_loss = [
            var
            for var in outcome_cols[1:]
            if np.any(np.isfinite(panel_df.loc[nonestimable, f"{var}_mean"].values))
        ]
        msg = (
            f"Dropped {n_dropped} non-estimable cell(s) (based on first outcome "
            f"'{first_outcome}'): " + _preview(dropped_keys)
        )
        if secondary_loss:
            msg += (
                f". Note: {secondary_loss} had valid data in dropped cells. "
                f"For independent per-outcome support, call once per outcome."
            )
        warnings.warn(msg, UserWarning, stacklevel=2)
        panel_df = panel_df[~nonestimable].reset_index(drop=True)

    # --- Construct second-stage SurveyDesign ---
    # Create a fit-ready weight column: NaN/Inf precision → 0.0 so downstream
    # resolve() doesn't reject missing weights. Diagnostic *_precision is kept.
    weight_col = f"{first_outcome}_weight"
    panel_df[weight_col] = np.where(
        np.isfinite(panel_df[f"{first_outcome}_precision"]),
        panel_df[f"{first_outcome}_precision"],
        0.0,
    )

    # Drop geographic units (PSUs) with zero total weight — they would
    # inflate survey df and distort second-stage variance estimation.
    geo_col = by_cols[0]
    geo_weight = panel_df.groupby(geo_col)[weight_col].sum()
    zero_geos = geo_weight[geo_weight == 0].index
    if len(zero_geos) > 0:
        n_before = len(panel_df)
        panel_df = panel_df[~panel_df[geo_col].isin(zero_geos)].reset_index(drop=True)
        n_after = len(panel_df)
        warnings.warn(
            f"Dropped {n_before - n_after} cell(s) from {len(zero_geos)} "
            f"geographic unit(s) with zero total weight: " + _preview(list(zero_geos)),
            UserWarning,
            stacklevel=2,
        )

    # Guard: all cells dropped
    if panel_df.empty:
        raise ValueError(
            "No estimable cells remain after aggregation. "
            "All cells had missing outcomes or zero effective weight."
        )

    second_stage_design = SurveyDesign(
        weights=weight_col,
        weight_type="aweight",
        psu=geo_col,
    )

    return panel_df, second_stage_design
|