diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +226 -0
- diff_diff/_backend.py +64 -0
- diff_diff/_rust_backend.cpython-312-darwin.so +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1000 -0
- diff_diff/honest_did.py +1493 -0
- diff_diff/linalg.py +980 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1338 -0
- diff_diff/pretrends.py +1067 -0
- diff_diff/results.py +703 -0
- diff_diff/staggered.py +2297 -0
- diff_diff/sun_abraham.py +1176 -0
- diff_diff/synthetic_did.py +738 -0
- diff_diff/triple_diff.py +1291 -0
- diff_diff/twfe.py +344 -0
- diff_diff/utils.py +1481 -0
- diff_diff/visualization.py +1627 -0
- diff_diff-2.0.4.dist-info/METADATA +2257 -0
- diff_diff-2.0.4.dist-info/RECORD +23 -0
- diff_diff-2.0.4.dist-info/WHEEL +4 -0
diff_diff/prep.py
ADDED
@@ -0,0 +1,1338 @@
"""
Data preparation utilities for difference-in-differences analysis.

This module provides helper functions to prepare data for DiD estimation,
including creating treatment indicators, reshaping panel data, and
generating synthetic datasets for testing.
"""

from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd

from diff_diff.utils import compute_synthetic_weights

# Constants for rank_control_units
_SIMILARITY_THRESHOLD_SD = 0.5  # Controls within this many SDs are "similar"
_OUTLIER_PENALTY_WEIGHT = 0.3  # Penalty weight for outcome outliers in treatment candidate scoring


def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated"
) -> pd.DataFrame:
    """
    Create a binary treatment indicator column from various input types.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    column : str
        Name of the column to use for creating the treatment indicator.
    treated_values : Any or list, optional
        Value(s) that indicate treatment. Units with these values get
        treatment=1, others get treatment=0.
    threshold : float, optional
        Numeric threshold for creating treatment. Used when the treatment
        is based on a continuous variable (e.g., treat firms above median size).
    above_threshold : bool, default=True
        If True, values >= threshold are treated. If False, values <= threshold
        are treated. Only used when threshold is specified.
    new_column : str, default="treated"
        Name of the new treatment indicator column.

    Returns
    -------
    pd.DataFrame
        DataFrame with the new treatment indicator column added.

    Examples
    --------
    Create treatment from categorical variable:

    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'group', treated_values='A')
    >>> df['treated'].tolist()
    [1, 1, 0, 0]

    Create treatment from numeric threshold:

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'size', threshold=75)
    >>> df['treated'].tolist()
    [0, 0, 1, 1]

    Treat units below a threshold:

    >>> df = make_treatment_indicator(df, 'size', threshold=75, above_threshold=False)
    >>> df['treated'].tolist()
    [1, 1, 0, 0]
    """
    df = data.copy()

    if treated_values is not None and threshold is not None:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")

    if treated_values is None and threshold is None:
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")

    if column not in df.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    if treated_values is not None:
        # Convert single value to list
        if not isinstance(treated_values, (list, tuple, set)):
            treated_values = [treated_values]
        df[new_column] = df[column].isin(treated_values).astype(int)
    else:
        # Use threshold
        if above_threshold:
            df[new_column] = (df[column] >= threshold).astype(int)
        else:
            df[new_column] = (df[column] <= threshold).astype(int)

    return df

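# A minimal usage sketch for multi-valued treatment assignment, in the same
# doctest style as the docstrings above (pd is the module-level pandas import;
# the output follows from `Series.isin`):
#
# >>> df = pd.DataFrame({'state': ['CA', 'NY', 'TX', 'WA'], 'y': [1, 2, 3, 4]})
# >>> df = make_treatment_indicator(df, 'state', treated_values=['CA', 'NY'])
# >>> df['treated'].tolist()
# [1, 1, 0, 0]
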
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post"
) -> pd.DataFrame:
    """
    Create a binary post-treatment indicator column.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list, optional
        Specific period value(s) that are post-treatment. Periods matching
        these values get post=1, others get post=0.
    treatment_start : Any, optional
        The first post-treatment period. All periods >= this value get post=1.
        Works with numeric periods, strings (sorted alphabetically), or dates.
    new_column : str, default="post"
        Name of the new post indicator column.

    Returns
    -------
    pd.DataFrame
        DataFrame with the new post indicator column added.

    Examples
    --------
    Using specific post periods:

    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021], 'y': [1, 2, 3, 4]})
    >>> df = make_post_indicator(df, 'year', post_periods=[2020, 2021])
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Using treatment start:

    >>> df = make_post_indicator(df, 'year', treatment_start=2020)
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Works with date columns:

    >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
    >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
    """
    df = data.copy()

    if post_periods is not None and treatment_start is not None:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")

    if post_periods is None and treatment_start is None:
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")

    if time_column not in df.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    if post_periods is not None:
        # Convert single value to list
        if not isinstance(post_periods, (list, tuple, set)):
            post_periods = [post_periods]
        df[new_column] = df[time_column].isin(post_periods).astype(int)
    else:
        # Use treatment_start - convert to same type as column if needed
        col_dtype = df[time_column].dtype
        if pd.api.types.is_datetime64_any_dtype(col_dtype):
            treatment_start = pd.to_datetime(treatment_start)
        df[new_column] = (df[time_column] >= treatment_start).astype(int)

    return df

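# A short sketch of the datetime branch, continuing the date example from the
# docstring above; the string '2020-06-01' is coerced with pd.to_datetime
# before the >= comparison:
#
# >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
# >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
# >>> df['post'].tolist()
# [0, 1, 1]
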
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None
) -> pd.DataFrame:
    """
    Convert wide-format panel data to long format for DiD analysis.

    Wide format has one row per unit with multiple columns for each time period.
    Long format has one row per unit-period combination.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame with one row per unit.
    value_columns : list of str
        Column names containing the outcome values for each period.
        These should be in chronological order.
    id_column : str
        Column name for the unit identifier.
    time_name : str, default="period"
        Name for the new time period column.
    value_name : str, default="value"
        Name for the new value/outcome column.
    time_values : list, optional
        Values to use for time periods. If None, uses 0, 1, 2, ...
        Must have same length as value_columns.

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame with one row per unit-period.

    Examples
    --------
    >>> wide_df = pd.DataFrame({
    ...     'firm_id': [1, 2, 3],
    ...     'sales_2019': [100, 150, 200],
    ...     'sales_2020': [110, 160, 210],
    ...     'sales_2021': [120, 170, 220]
    ... })
    >>> long_df = wide_to_long(
    ...     wide_df,
    ...     value_columns=['sales_2019', 'sales_2020', 'sales_2021'],
    ...     id_column='firm_id',
    ...     time_name='year',
    ...     value_name='sales',
    ...     time_values=[2019, 2020, 2021]
    ... )
    >>> len(long_df)
    9
    >>> long_df.columns.tolist()
    ['firm_id', 'year', 'sales']
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")

    if id_column not in data.columns:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")

    for col in value_columns:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if time_values is None:
        time_values = list(range(len(value_columns)))

    if len(time_values) != len(value_columns):
        raise ValueError(
            f"time_values length ({len(time_values)}) must match "
            f"value_columns length ({len(value_columns)})."
        )

    # Get other columns to preserve (not id or value columns)
    other_cols = [c for c in data.columns if c != id_column and c not in value_columns]

    # Use pd.melt for better performance (vectorized)
    long_df = pd.melt(
        data,
        id_vars=[id_column] + other_cols,
        value_vars=value_columns,
        var_name='_temp_var',
        value_name=value_name
    )

    # Map column names to time values
    col_to_time = dict(zip(value_columns, time_values))
    long_df[time_name] = long_df['_temp_var'].map(col_to_time)
    long_df = long_df.drop('_temp_var', axis=1)

    # Reorder columns and sort
    cols = [id_column, time_name, value_name] + other_cols
    return long_df[cols].sort_values([id_column, time_name]).reset_index(drop=True)

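# A sketch of how non-value columns ride along during the melt: any column that
# is neither the id nor a value column goes into id_vars, so it repeats once
# per period in the long output (the 'region' column here is hypothetical):
#
# >>> wide_df = pd.DataFrame({
# ...     'firm_id': [1, 2],
# ...     'region': ['west', 'east'],
# ...     'sales_2019': [100, 150],
# ...     'sales_2020': [110, 160]
# ... })
# >>> long_df = wide_to_long(wide_df, ['sales_2019', 'sales_2020'], 'firm_id')
# >>> long_df.columns.tolist()
# ['firm_id', 'period', 'value', 'region']
# >>> long_df['region'].tolist()
# ['west', 'west', 'east', 'east']
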
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None
) -> pd.DataFrame:
    """
    Balance a panel dataset to ensure all units have all time periods.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time period.
    method : str, default="inner"
        Balancing method:
        - "inner": Keep only units that appear in all periods (drops units)
        - "outer": Include all unit-period combinations (creates NaN)
        - "fill": Include all combinations and fill missing values
    fill_value : float, optional
        Value to fill missing observations when method="fill".
        If None with method="fill", uses column-specific forward fill.

    Returns
    -------
    pd.DataFrame
        Balanced panel DataFrame.

    Examples
    --------
    Keep only complete units:

    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 3, 3, 3],
    ...     'period': [1, 2, 3, 1, 2, 1, 2, 3],
    ...     'y': [10, 11, 12, 20, 21, 30, 31, 32]
    ... })
    >>> balanced = balance_panel(df, 'unit', 'period', method='inner')
    >>> balanced['unit'].unique().tolist()
    [1, 3]

    Include all combinations:

    >>> balanced = balance_panel(df, 'unit', 'period', method='outer')
    >>> len(balanced)
    9
    """
    if unit_column not in data.columns:
        raise ValueError(f"Column '{unit_column}' not found in DataFrame.")
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    if method not in ["inner", "outer", "fill"]:
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    all_units = data[unit_column].unique()
    all_periods = sorted(data[time_column].unique())
    n_periods = len(all_periods)

    if method == "inner":
        # Keep only units that have all periods
        unit_counts = data.groupby(unit_column)[time_column].nunique()
        complete_units = unit_counts[unit_counts == n_periods].index
        return data[data[unit_column].isin(complete_units)].copy()

    elif method in ["outer", "fill"]:
        # Create full grid of unit-period combinations
        full_index = pd.MultiIndex.from_product(
            [all_units, all_periods],
            names=[unit_column, time_column]
        )
        full_df = pd.DataFrame(index=full_index).reset_index()

        # Merge with original data
        result = full_df.merge(data, on=[unit_column, time_column], how="left")

        if method == "fill":
            # Identify columns to fill (exclude unit and time columns)
            cols_to_fill = [c for c in result.columns if c not in [unit_column, time_column]]

            if fill_value is not None:
                # Fill specified columns with fill_value
                numeric_cols = result.select_dtypes(include=[np.number]).columns
                for col in numeric_cols:
                    if col in cols_to_fill:
                        result[col] = result[col].fillna(fill_value)
            else:
                # Forward fill within each unit for non-key columns
                result = result.sort_values([unit_column, time_column])
                result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill()
                # Backward fill any remaining NaN at start
                result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill()

        return result

    return data

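# A sketch of method="fill" without a fill_value, assuming the docstring's df:
# the missing (unit=2, period=3) row is created by the full grid, then its 'y'
# is forward-filled within unit 2 from period 2 (the merge makes 'y' float):
#
# >>> filled = balance_panel(df, 'unit', 'period', method='fill')
# >>> filled[filled['unit'] == 2]['y'].tolist()
# [20.0, 21.0, 21.0]
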
def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True
) -> Dict[str, Any]:
    """
    Validate that data is properly formatted for DiD analysis.

    Checks for common data issues and provides informative error messages.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/post indicator column.
    unit : str, optional
        Name of unit identifier column (for panel data validation).
    raise_on_error : bool, default=True
        If True, raises ValueError on validation failures.
        If False, returns validation results without raising.

    Returns
    -------
    dict
        Validation results with keys:
        - valid: bool indicating if data passed all checks
        - errors: list of error messages
        - warnings: list of warning messages
        - summary: dict with data summary statistics

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [1, 2, 3, 4],
    ...     'treated': [0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1]
    ... })
    >>> result = validate_did_data(df, 'y', 'treated', 'post', raise_on_error=False)
    >>> result['valid']
    True
    """
    errors = []
    warnings = []

    # Check columns exist
    required_cols = [outcome, treatment, time]
    if unit is not None:
        required_cols.append(unit)

    for col in required_cols:
        if col not in data.columns:
            errors.append(f"Required column '{col}' not found in DataFrame.")

    if errors:
        if raise_on_error:
            raise ValueError("\n".join(errors))
        return {"valid": False, "errors": errors, "warnings": warnings, "summary": {}}

    # Check outcome is numeric
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        errors.append(
            f"Outcome column '{outcome}' must be numeric. "
            f"Got type: {data[outcome].dtype}"
        )

    # Check treatment is binary
    treatment_vals = data[treatment].dropna().unique()
    if not set(treatment_vals).issubset({0, 1}):
        errors.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {sorted(treatment_vals)}"
        )

    # Check time is binary for simple DiD
    time_vals = data[time].dropna().unique()
    if len(time_vals) == 2 and not set(time_vals).issubset({0, 1}):
        warnings.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: {sorted(time_vals)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    # Check for missing values
    for col in required_cols:
        n_missing = data[col].isna().sum()
        if n_missing > 0:
            errors.append(
                f"Column '{col}' has {n_missing} missing values. "
                "Please handle missing data before fitting."
            )

    # Calculate summary statistics
    summary = {}
    if not errors:
        summary["n_obs"] = len(data)
        summary["n_treated"] = int((data[treatment] == 1).sum())
        summary["n_control"] = int((data[treatment] == 0).sum())
        summary["n_periods"] = len(time_vals)

        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        # Check for sufficient variation
        if summary["n_treated"] == 0:
            errors.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            errors.append("No control observations found (treatment column is all 1).")

        # Check for each treatment-time combination
        if len(time_vals) == 2:
            # For 2-period DiD, check all four cells
            for t_val in [0, 1]:
                for p_val in time_vals:
                    count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
                    if count == 0:
                        errors.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # For multi-period, check that both treatment groups exist in multiple periods
            for t_val in [0, 1]:
                n_periods_with_obs = data[data[treatment] == t_val][time].nunique()
                if n_periods_with_obs < 2:
                    group_name = "Treated" if t_val == 1 else "Control"
                    errors.append(
                        f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # Panel-specific validation
    if unit is not None and not errors:
        # Check treatment is constant within units
        unit_treatment_var = data.groupby(unit)[treatment].nunique()
        units_with_varying_treatment = unit_treatment_var[unit_treatment_var > 1]
        if len(units_with_varying_treatment) > 0:
            warnings.append(
                f"Treatment varies within {len(units_with_varying_treatment)} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        # Check panel balance
        periods_per_unit = data.groupby(unit)[time].nunique()
        if periods_per_unit.min() != periods_per_unit.max():
            warnings.append(
                f"Unbalanced panel detected. Units have between "
                f"{periods_per_unit.min()} and {periods_per_unit.max()} periods. "
                "Consider using balance_panel() to balance the data."
            )

    valid = len(errors) == 0

    if raise_on_error and not valid:
        raise ValueError("Data validation failed:\n" + "\n".join(errors))

    return {
        "valid": valid,
        "errors": errors,
        "warnings": warnings,
        "summary": summary
    }

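# A sketch of the failure path, assuming an all-treated sample: with
# raise_on_error=False the problems come back as messages instead of a raise
# (here: no control observations, plus empty treatment-time cells):
#
# >>> bad = pd.DataFrame({'y': [1.0, 2.0], 'treated': [1, 1], 'post': [0, 1]})
# >>> report = validate_did_data(bad, 'y', 'treated', 'post', raise_on_error=False)
# >>> report['valid']
# False
# >>> len(report['errors']) > 0
# True
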
def summarize_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Generate summary statistics by treatment group and time period.

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/period column.
    unit : str, optional
        Name of unit identifier column.

    Returns
    -------
    pd.DataFrame
        Summary statistics with columns for each treatment-time combination.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [10, 11, 12, 13, 20, 21, 22, 23],
    ...     'treated': [0, 0, 1, 1, 0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1, 0, 1, 0, 1]
    ... })
    >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
    >>> print(summary)
    """
    # Group by treatment and time
    summary = data.groupby([treatment, time])[outcome].agg([
        ("n", "count"),
        ("mean", "mean"),
        ("std", "std"),
        ("min", "min"),
        ("max", "max")
    ]).round(4)

    # Calculate time values for labeling
    time_vals = sorted(data[time].unique())

    # Add group labels based on sorted time values (not literal 0/1)
    if len(time_vals) == 2:
        pre_val, post_val = time_vals[0], time_vals[1]

        def format_label(x: tuple) -> str:
            treatment_label = 'Treated' if x[0] == 1 else 'Control'
            time_label = 'Post' if x[1] == post_val else 'Pre'
            return f"{treatment_label} - {time_label}"

        summary.index = summary.index.map(format_label)

        # Calculate means for each cell
        treated_pre = data[(data[treatment] == 1) & (data[time] == pre_val)][outcome].mean()
        treated_post = data[(data[treatment] == 1) & (data[time] == post_val)][outcome].mean()
        control_pre = data[(data[treatment] == 0) & (data[time] == pre_val)][outcome].mean()
        control_post = data[(data[treatment] == 0) & (data[time] == post_val)][outcome].mean()

        # Calculate DiD
        treated_diff = treated_post - treated_pre
        control_diff = control_post - control_pre
        did_estimate = treated_diff - control_diff

        # Add to summary as a new row
        did_row = pd.DataFrame(
            {
                "n": ["-"],
                "mean": [did_estimate],
                "std": ["-"],
                "min": ["-"],
                "max": ["-"]
            },
            index=["DiD Estimate"]
        )
        summary = pd.concat([summary, did_row])
    else:
        summary.index = summary.index.map(
            lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}"
        )

    return summary

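# A worked check of the DiD arithmetic on the docstring's example data: the
# cell means are control-pre (10+20)/2 = 15, control-post 16, treated-pre 17,
# treated-post 18, so the "DiD Estimate" row holds (18-17) - (16-15) = 0:
#
# >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
# >>> float(summary.loc['DiD Estimate', 'mean'])
# 0.0
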
def generate_did_data(
    n_units: int = 100,
    n_periods: int = 4,
    treatment_effect: float = 5.0,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    unit_fe_sd: float = 2.0,
    time_trend: float = 0.5,
    noise_sd: float = 1.0,
    seed: Optional[int] = None
) -> pd.DataFrame:
    """
    Generate synthetic data for DiD analysis with known treatment effect.

    Creates a balanced panel dataset with realistic features including
    unit fixed effects, time trends, and a known treatment effect.

    Parameters
    ----------
    n_units : int, default=100
        Number of units in the panel.
    n_periods : int, default=4
        Number of time periods.
    treatment_effect : float, default=5.0
        True average treatment effect on the treated.
    treatment_fraction : float, default=0.5
        Fraction of units that receive treatment.
    treatment_period : int, default=2
        First post-treatment period (0-indexed). Periods >= this are post.
    unit_fe_sd : float, default=2.0
        Standard deviation of unit fixed effects.
    time_trend : float, default=0.5
        Linear time trend coefficient.
    noise_sd : float, default=1.0
        Standard deviation of idiosyncratic noise.
    seed : int, optional
        Random seed for reproducibility.

    Returns
    -------
    pd.DataFrame
        Synthetic panel data with columns:
        - unit: Unit identifier
        - period: Time period
        - treated: Treatment indicator (0/1)
        - post: Post-treatment indicator (0/1)
        - outcome: Outcome variable
        - true_effect: The true treatment effect (for validation)

    Examples
    --------
    Generate simple data for testing:

    >>> data = generate_did_data(n_units=50, n_periods=4, treatment_effect=3.0, seed=42)
    >>> len(data)
    200
    >>> data.columns.tolist()
    ['unit', 'period', 'treated', 'post', 'outcome', 'true_effect']

    Verify treatment effect recovery:

    >>> from diff_diff import DifferenceInDifferences
    >>> did = DifferenceInDifferences()
    >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')
    >>> abs(results.att - 3.0) < 1.0  # Close to true effect
    True
    """
    rng = np.random.default_rng(seed)

    # Determine treated units
    n_treated = int(n_units * treatment_fraction)
    treated_units = set(range(n_treated))

    # Generate unit fixed effects
    unit_fe = rng.normal(0, unit_fe_sd, n_units)

    # Build data
    records = []
    for unit in range(n_units):
        is_treated = unit in treated_units

        for period in range(n_periods):
            is_post = period >= treatment_period

            # Base outcome
            y = 10.0  # Baseline
            y += unit_fe[unit]  # Unit fixed effect
            y += time_trend * period  # Time trend

            # Treatment effect (only for treated units in post-period)
            effect = 0.0
            if is_treated and is_post:
                effect = treatment_effect
                y += effect

            # Add noise
            y += rng.normal(0, noise_sd)

            records.append({
                "unit": unit,
                "period": period,
                "treated": int(is_treated),
                "post": int(is_post),
                "outcome": y,
                "true_effect": effect
            })

    return pd.DataFrame(records)

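# A sketch that recovers the planted effect with plain cell means, so it does
# not depend on the estimator class referenced in the docstring; since the
# unit fixed effects cancel within groups, the 2x2 DiD of group-period means
# should land near treatment_effect:
#
# >>> data = generate_did_data(n_units=200, treatment_effect=5.0, seed=0)
# >>> m = data.groupby(['treated', 'post'])['outcome'].mean()
# >>> did = (m.loc[(1, 1)] - m.loc[(1, 0)]) - (m.loc[(0, 1)] - m.loc[(0, 0)])
# >>> bool(abs(did - 5.0) < 0.5)  # close to the true effect at this sample size
# True
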
def create_event_time(
    data: pd.DataFrame,
    time_column: str,
    treatment_time_column: str,
    new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Create an event-time column relative to treatment timing.

    Useful for event study designs where treatment occurs at different
    times for different units.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    time_column : str
        Name of the calendar time column.
    treatment_time_column : str
        Name of the column indicating when each unit was treated.
        Units with NaN or infinity are considered never-treated.
    new_column : str, default="event_time"
        Name of the new event-time column.

    Returns
    -------
    pd.DataFrame
        DataFrame with event-time column added. Values are:
        - Negative for pre-treatment periods
        - 0 for the treatment period
        - Positive for post-treatment periods
        - NaN for never-treated units

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2],
    ...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
    ...     'treatment_year': [2019, 2019, 2019, 2020, 2020, 2020]
    ... })
    >>> df = create_event_time(df, 'year', 'treatment_year')
    >>> df['event_time'].tolist()
    [-1, 0, 1, -2, -1, 0]
    """
    df = data.copy()

    if time_column not in df.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")
    if treatment_time_column not in df.columns:
        raise ValueError(f"Column '{treatment_time_column}' not found in DataFrame.")

    # Calculate event time
    df[new_column] = df[time_column] - df[treatment_time_column]

    # Handle never-treated (inf or NaN in treatment time)
    col = df[treatment_time_column]
    if pd.api.types.is_numeric_dtype(col):
        never_treated = col.isna() | np.isinf(col)
    else:
        never_treated = col.isna()
    df.loc[never_treated, new_column] = np.nan

    return df

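# A sketch of the never-treated branch, assuming np.inf marks a unit that never
# adopts: its event_time rows become NaN rather than -inf (the inf makes the
# column float, so treated rows print as floats too):
#
# >>> df = pd.DataFrame({
# ...     'year': [2018, 2019, 2018, 2019],
# ...     'treatment_year': [2019, 2019, np.inf, np.inf]
# ... })
# >>> df = create_event_time(df, 'year', 'treatment_year')
# >>> df['event_time'].tolist()
# [-1.0, 0.0, nan, nan]
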
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Aggregate unit-level data to treatment cohort means.

    Useful for visualization and cohort-level analysis.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Name of unit identifier column.
    time_column : str
        Name of time period column.
    treatment_column : str
        Name of treatment indicator column.
    outcome : str
        Name of outcome variable column.
    covariates : list of str, optional
        Additional columns to aggregate (will compute means).

    Returns
    -------
    pd.DataFrame
        Cohort-level data with mean outcomes by treatment status and period.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 2, 2, 3, 3, 4, 4],
    ...     'period': [0, 1, 0, 1, 0, 1, 0, 1],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'y': [10, 15, 12, 17, 8, 10, 9, 11]
    ... })
    >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
    >>> len(cohort_df)
    4
    """
    agg_cols = {outcome: "mean", unit_column: "nunique"}

    if covariates:
        for cov in covariates:
            agg_cols[cov] = "mean"

    cohort_data = data.groupby([treatment_column, time_column]).agg(agg_cols).reset_index()

    # Rename columns
    cohort_data = cohort_data.rename(columns={
        unit_column: "n_units",
        outcome: f"mean_{outcome}"
    })

    return cohort_data

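# A sketch of the output shape on the docstring's example df: one row per
# (treated, period) cell, with the outcome renamed and a unit count attached:
#
# >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
# >>> cohort_df.columns.tolist()
# ['treated', 'period', 'mean_y', 'n_units']
# >>> cohort_df['n_units'].tolist()
# [2, 2, 2, 2]
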
def rank_control_units(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    treatment_column: Optional[str] = None,
    treated_units: Optional[List[Any]] = None,
    pre_periods: Optional[List[Any]] = None,
    covariates: Optional[List[str]] = None,
    outcome_weight: float = 0.7,
    covariate_weight: float = 0.3,
    exclude_units: Optional[List[Any]] = None,
    require_units: Optional[List[Any]] = None,
    n_top: Optional[int] = None,
    suggest_treatment_candidates: bool = False,
    n_treatment_candidates: int = 5,
    lambda_reg: float = 0.0,
) -> pd.DataFrame:
    """
    Rank potential control units by their suitability for DiD analysis.

    Evaluates control units based on pre-treatment outcome trend similarity
    and optional covariate matching to treated units. Returns a ranked list
    with quality scores.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time periods.
    outcome_column : str
        Column name for outcome variable.
    treatment_column : str, optional
        Column with binary treatment indicator (0/1). Used to identify
        treated units from data.
    treated_units : list, optional
        Explicit list of treated unit IDs. Alternative to treatment_column.
    pre_periods : list, optional
        Pre-treatment periods for comparison. If None, uses first half of periods.
    covariates : list of str, optional
        Covariate columns for matching. Similarity is based on pre-treatment means.
    outcome_weight : float, default=0.7
        Weight for pre-treatment outcome trend similarity (0-1).
    covariate_weight : float, default=0.3
        Weight for covariate distance (0-1). Ignored if no covariates.
    exclude_units : list, optional
        Units that cannot be in the control group.
    require_units : list, optional
        Units that must be in the control group (will always appear in output).
    n_top : int, optional
        Return only the top N control units. If None, return all.
    suggest_treatment_candidates : bool, default=False
        If True and no treated units specified, identify potential treatment
        candidates instead of ranking controls.
    n_treatment_candidates : int, default=5
        Number of treatment candidates to suggest.
    lambda_reg : float, default=0.0
        Regularization for synthetic weights. Higher values give more uniform
        weights across controls.

    Returns
    -------
    pd.DataFrame
        Ranked control units with columns:
        - unit: Unit identifier
        - quality_score: Combined quality score (0-1, higher is better)
        - outcome_trend_score: Pre-treatment outcome trend similarity
        - covariate_score: Covariate match score (NaN if no covariates)
        - synthetic_weight: Weight from synthetic control optimization
        - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
        - is_required: Whether unit was in require_units

        If suggest_treatment_candidates=True (and no treated units):
        - unit: Unit identifier
        - treatment_candidate_score: Suitability as treatment unit
        - avg_outcome_level: Pre-treatment outcome mean
        - outcome_trend: Pre-treatment trend slope
        - n_similar_controls: Count of similar potential controls

    Examples
    --------
    Rank controls against treated units:

    >>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
    >>> ranking = rank_control_units(
    ...     data,
    ...     unit_column='unit',
    ...     time_column='period',
    ...     outcome_column='outcome',
    ...     treatment_column='treated',
    ...     n_top=10
    ... )
    >>> ranking['quality_score'].is_monotonic_decreasing
    True

    With covariates:

    >>> data['size'] = np.random.randn(len(data))
    >>> ranking = rank_control_units(
    ...     data,
    ...     unit_column='unit',
    ...     time_column='period',
    ...     outcome_column='outcome',
    ...     treatment_column='treated',
    ...     covariates=['size']
    ... )

    Filter data for SyntheticDiD:

    >>> top_controls = ranking['unit'].tolist()
    >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
    """
    # -------------------------------------------------------------------------
    # Input validation
    # -------------------------------------------------------------------------
    for col in [unit_column, time_column, outcome_column]:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if treatment_column is not None and treatment_column not in data.columns:
        raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")

    if covariates:
        for cov in covariates:
            if cov not in data.columns:
                raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")

    if not 0 <= outcome_weight <= 1:
        raise ValueError("outcome_weight must be between 0 and 1")
    if not 0 <= covariate_weight <= 1:
        raise ValueError("covariate_weight must be between 0 and 1")

    if treated_units is not None and treatment_column is not None:
        raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")

    if require_units and exclude_units:
        invalid_required = [u for u in require_units if u in exclude_units]
        if invalid_required:
            raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")

    # -------------------------------------------------------------------------
    # Determine pre-treatment periods
    # -------------------------------------------------------------------------
    all_periods = sorted(data[time_column].unique())
    if pre_periods is None:
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]
    else:
        pre_periods = list(pre_periods)

    if len(pre_periods) == 0:
        raise ValueError("No pre-treatment periods specified or inferred.")

    # -------------------------------------------------------------------------
    # Identify treated and control units
    # -------------------------------------------------------------------------
    all_units = list(data[unit_column].unique())

    if treated_units is not None:
        treated_set = set(treated_units)
    elif treatment_column is not None:
        unit_treatment = data.groupby(unit_column)[treatment_column].first()
        treated_set = set(unit_treatment[unit_treatment == 1].index)
    elif suggest_treatment_candidates:
        # Treatment candidate discovery mode - no treated units
        treated_set = set()
    else:
        raise ValueError(
            "Must specify treated_units, treatment_column, or set "
            "suggest_treatment_candidates=True"
        )

    # -------------------------------------------------------------------------
    # Treatment candidate discovery mode
    # -------------------------------------------------------------------------
    if suggest_treatment_candidates and len(treated_set) == 0:
        return _suggest_treatment_candidates(
            data, unit_column, time_column, outcome_column,
            pre_periods, n_treatment_candidates
        )

    if len(treated_set) == 0:
        raise ValueError("No treated units found.")

    # Determine control candidates
    control_candidates = [u for u in all_units if u not in treated_set]

    if exclude_units:
        control_candidates = [u for u in control_candidates if u not in exclude_units]

    if len(control_candidates) == 0:
        raise ValueError("No control units available after exclusions.")

    # -------------------------------------------------------------------------
    # Create outcome matrices (pre-treatment)
    # -------------------------------------------------------------------------
    pre_data = data[data[time_column].isin(pre_periods)]
    pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)

    # Filter to pre_periods that exist in data
    valid_pre_periods = [p for p in pre_periods if p in pivot.index]
    if len(valid_pre_periods) == 0:
        raise ValueError("No data found for specified pre-treatment periods.")

    # Filter control_candidates to those present in pivot (handles unbalanced panels)
    control_candidates = [c for c in control_candidates if c in pivot.columns]
    if len(control_candidates) == 0:
        raise ValueError("No control units found in pre-treatment data.")

    # Control outcomes: shape (n_pre_periods, n_control_candidates)
    Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)

    # Treated outcomes mean: shape (n_pre_periods,)
    treated_list = [u for u in treated_set if u in pivot.columns]
    if len(treated_list) == 0:
        raise ValueError("Treated units not found in pre-treatment data.")
    Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)

    # -------------------------------------------------------------------------
    # Compute outcome trend scores
    # -------------------------------------------------------------------------
    # Synthetic weights (higher = better match)
    synthetic_weights = compute_synthetic_weights(
        Y_control, Y_treated_mean, lambda_reg=lambda_reg
    )

    # RMSE for each control vs treated mean (use nanmean to handle missing data)
    rmse_scores = []
    for j in range(len(control_candidates)):
        y_c = Y_control[:, j]
        rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
        rmse_scores.append(rmse)

    # Convert RMSE to similarity score (lower RMSE = higher score)
    max_rmse = max(rmse_scores) if rmse_scores else 1.0
    min_rmse = min(rmse_scores) if rmse_scores else 0.0
    rmse_range = max_rmse - min_rmse

    if rmse_range < 1e-10:
        # All controls have identical/similar pre-trends (includes single control case)
        outcome_trend_scores = [1.0] * len(rmse_scores)
    else:
        # Normalize so best control gets 1.0, worst gets 0.0
        outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]

    # -------------------------------------------------------------------------
    # Compute covariate scores (if covariates provided)
    # -------------------------------------------------------------------------
    if covariates and len(covariates) > 0:
        # Get unit-level covariate values (pre-treatment mean)
        cov_data = pre_data.groupby(unit_column)[covariates].mean()

        # Treated covariate profile (mean across treated units)
        treated_cov = cov_data.loc[list(treated_set)].mean()

        # Standardize covariates
        cov_mean = cov_data.mean()
        cov_std = cov_data.std().replace(0, 1)  # Avoid division by zero
        cov_standardized = (cov_data - cov_mean) / cov_std
        treated_cov_std = (treated_cov - cov_mean) / cov_std

        # Euclidean distance in standardized space (vectorized)
        control_cov_matrix = cov_standardized.loc[control_candidates].values
        treated_cov_vector = treated_cov_std.values
        covariate_distances = np.sqrt(
            np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
        )

        # Convert distance to similarity score (min-max normalization)
        max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
        min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
        dist_range = max_dist - min_dist

        if dist_range < 1e-10:
            # All controls have identical/similar covariate profiles
            covariate_scores = [1.0] * len(covariate_distances)
        else:
            # Normalize so best control (closest) gets 1.0, worst gets 0.0
            covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
    else:
        covariate_scores = [np.nan] * len(control_candidates)

    # -------------------------------------------------------------------------
    # Compute combined quality score
    # -------------------------------------------------------------------------
    # Normalize weights
    total_weight = outcome_weight + covariate_weight
    if total_weight > 0:
        norm_outcome_weight = outcome_weight / total_weight
        norm_covariate_weight = covariate_weight / total_weight
    else:
        norm_outcome_weight = 1.0
        norm_covariate_weight = 0.0

    quality_scores = []
    for i in range(len(control_candidates)):
        outcome_score = outcome_trend_scores[i]
        cov_score = covariate_scores[i]

        if np.isnan(cov_score):
            # No covariates - use only outcome score
            combined = outcome_score
        else:
            combined = norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score

        quality_scores.append(combined)

    # -------------------------------------------------------------------------
    # Build result DataFrame
    # -------------------------------------------------------------------------
    require_set = set(require_units) if require_units else set()

    result = pd.DataFrame({
        'unit': control_candidates,
        'quality_score': quality_scores,
        'outcome_trend_score': outcome_trend_scores,
        'covariate_score': covariate_scores,
        'synthetic_weight': synthetic_weights,
        'pre_trend_rmse': rmse_scores,
        'is_required': [u in require_set for u in control_candidates]
    })

    # Sort by quality score (descending)
    result = result.sort_values('quality_score', ascending=False)

    # Apply n_top limit if specified
    if n_top is not None and n_top < len(result):
        # Always include required units
        required_df = result[result['is_required']]
        non_required_df = result[~result['is_required']]

        # Take top from non-required to fill remaining slots
        remaining_slots = max(0, n_top - len(required_df))
        top_non_required = non_required_df.head(remaining_slots)

        result = pd.concat([required_df, top_non_required])
        result = result.sort_values('quality_score', ascending=False)

    return result.reset_index(drop=True)

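# A worked check of the min-max step above, under the assumption of three
# controls with pre-trend RMSEs [0.0, 0.5, 1.0]: min_rmse=0.0, range=1.0, so
# the trend scores are 1 - (rmse - min) / range; with no covariates the
# quality_score equals the trend score directly:
#
# >>> rmse_scores = [0.0, 0.5, 1.0]
# >>> lo, hi = min(rmse_scores), max(rmse_scores)
# >>> [1 - (r - lo) / (hi - lo) for r in rmse_scores]
# [1.0, 0.5, 0.0]
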
def _suggest_treatment_candidates(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    pre_periods: List[Any],
    n_candidates: int
) -> pd.DataFrame:
    """
    Identify units that would make good treatment candidates.

    A good treatment candidate:
    1. Has many similar control units available (for matching)
    2. Has stable pre-treatment trends (predictable counterfactual)
    3. Is not an extreme outlier

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    unit_column : str
        Unit identifier column.
    time_column : str
        Time period column.
    outcome_column : str
        Outcome variable column.
    pre_periods : list
        Pre-treatment periods.
    n_candidates : int
        Number of candidates to return.

    Returns
    -------
    pd.DataFrame
        Treatment candidates with scores.
    """
    all_units = list(data[unit_column].unique())
    pre_data = data[data[time_column].isin(pre_periods)]

    candidate_info = []

    for unit in all_units:
        unit_data = pre_data[pre_data[unit_column] == unit]

        if len(unit_data) == 0:
            continue

        # Average outcome level
        avg_outcome = unit_data[outcome_column].mean()

        # Trend (simple linear regression slope)
        times = unit_data[time_column].values
        outcomes = unit_data[outcome_column].values
        if len(times) > 1:
            times_norm = np.arange(len(times))
            try:
                slope = np.polyfit(times_norm, outcomes, 1)[0]
            except (np.linalg.LinAlgError, ValueError):
                slope = 0.0
        else:
            slope = 0.0

        # Count similar potential controls
        other_units = [u for u in all_units if u != unit]
        other_means = pre_data[
            pre_data[unit_column].isin(other_units)
        ].groupby(unit_column)[outcome_column].mean()

        if len(other_means) > 0:
            sd = other_means.std()
            if sd > 0:
                n_similar = int(np.sum(
                    np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd
                ))
            else:
                n_similar = len(other_means)
        else:
            n_similar = 0

        candidate_info.append({
            'unit': unit,
            'avg_outcome_level': avg_outcome,
            'outcome_trend': slope,
            'n_similar_controls': n_similar
        })

    if len(candidate_info) == 0:
        return pd.DataFrame(columns=[
            'unit', 'treatment_candidate_score', 'avg_outcome_level',
            'outcome_trend', 'n_similar_controls'
        ])

    result = pd.DataFrame(candidate_info)

    # Score: prefer units with many similar controls and moderate outcome levels
    max_similar = result['n_similar_controls'].max()
    if max_similar > 0:
        similarity_score = result['n_similar_controls'] / max_similar
    else:
        similarity_score = pd.Series([0.0] * len(result))

    # Penalty for outliers in outcome level
    outcome_mean = result['avg_outcome_level'].mean()
    outcome_std = result['avg_outcome_level'].std()
    if outcome_std > 0:
        outcome_z = np.abs((result['avg_outcome_level'] - outcome_mean) / outcome_std)
    else:
        outcome_z = pd.Series([0.0] * len(result))

    result['treatment_candidate_score'] = (
        similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
    ).clip(0, 1)

    # Return top candidates
    result = result.nlargest(n_candidates, 'treatment_candidate_score')
    return result.reset_index(drop=True)
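
# A worked check of the candidate score, assuming a unit whose similar-control
# count is 8 out of a maximum of 10 and whose outcome level sits 1.5 SDs from
# the cross-unit mean: 8/10 - 0.3 * 1.5 = 0.35 before clipping to [0, 1]:
#
# >>> similarity, z = 8 / 10, 1.5
# >>> round(max(0.0, min(1.0, similarity - _OUTLIER_PENALTY_WEIGHT * z)), 2)
# 0.35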