diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/prep.py ADDED
@@ -0,0 +1,1338 @@
1
+ """
2
+ Data preparation utilities for difference-in-differences analysis.
3
+
4
+ This module provides helper functions to prepare data for DiD estimation,
5
+ including creating treatment indicators, reshaping panel data, and
6
+ generating synthetic datasets for testing.
7
+ """
8
+
9
+ from typing import Any, Dict, List, Optional, Union
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ from diff_diff.utils import compute_synthetic_weights
15
+
16
+ # Tuning constants used by rank_control_units (module-private by naming convention).
17
+ _SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many standard deviations (SDs) count as "similar"
18
+ _OUTLIER_PENALTY_WEIGHT = 0.3 # Weight of the penalty applied to outcome outliers when scoring treatment candidates
19
+
20
+
21
def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated"
) -> pd.DataFrame:
    """
    Create a binary treatment indicator column from various input types.

    Exactly one of ``treated_values`` and ``threshold`` must be given.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    column : str
        Name of the column to use for creating the treatment indicator.
    treated_values : Any or list-like, optional
        Value(s) that indicate treatment. Rows whose ``column`` value is one
        of these get 1, all others get 0. A scalar (including a string) is
        treated as a single value; any list-like (list, tuple, set, NumPy
        array, pandas Series/Index, ...) is treated as a collection of values.
    threshold : float, optional
        Numeric threshold for creating treatment. Used when the treatment
        is based on a continuous variable (e.g., treat firms above median size).
    above_threshold : bool, default=True
        If True, values >= threshold are treated. If False, values <= threshold
        are treated. Only used when threshold is specified.
    new_column : str, default="treated"
        Name of the new treatment indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``treated_values`` and ``threshold`` are given,
        or if ``column`` is not a column of ``data``.

    Examples
    --------
    Create treatment from categorical variable:

    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'group', treated_values='A')
    >>> df['treated'].tolist()
    [1, 1, 0, 0]

    Create treatment from numeric threshold:

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'size', threshold=75)
    >>> df['treated'].tolist()
    [0, 0, 1, 1]

    Treat units below a threshold:

    >>> df = make_treatment_indicator(df, 'size', threshold=75, above_threshold=False)
    >>> df['treated'].tolist()
    [1, 1, 0, 0]
    """
    # Validate first so invalid calls fail before paying for a full copy.
    if treated_values is not None and threshold is not None:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")

    if treated_values is None and threshold is None:
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")

    if column not in data.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    df = data.copy()

    if treated_values is not None:
        # Wrap scalars only. is_list_like() is False for str/bytes, so a single
        # string label is still one value, while arrays/Series/Index are passed
        # to isin() as collections instead of being nested inside a list.
        if not pd.api.types.is_list_like(treated_values):
            treated_values = [treated_values]
        df[new_column] = df[column].isin(treated_values).astype(int)
    elif above_threshold:
        df[new_column] = (df[column] >= threshold).astype(int)
    else:
        df[new_column] = (df[column] <= threshold).astype(int)

    return df
101
+
102
+
103
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post"
) -> pd.DataFrame:
    """
    Create a binary post-treatment indicator column.

    Exactly one of ``post_periods`` and ``treatment_start`` must be given.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list-like, optional
        Specific period value(s) that are post-treatment. Periods matching
        these values get post=1, others get post=0. A scalar (including a
        string) is one period; any list-like (list, tuple, set, NumPy array,
        pandas Series/Index, ...) is a collection of periods.
    treatment_start : Any, optional
        The first post-treatment period. All periods >= this value get post=1.
        Works with numeric periods, strings (sorted alphabetically), or dates.
        When ``time_column`` has a datetime dtype the value is converted with
        ``pd.to_datetime`` before comparison.
    new_column : str, default="post"
        Name of the new post indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``post_periods`` and ``treatment_start`` are
        given, or if ``time_column`` is not a column of ``data``.

    Examples
    --------
    Using specific post periods:

    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021], 'y': [1, 2, 3, 4]})
    >>> df = make_post_indicator(df, 'year', post_periods=[2020, 2021])
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Using treatment start:

    >>> df = make_post_indicator(df, 'year', treatment_start=2020)
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Works with date columns:

    >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
    >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
    """
    # Validate first so invalid calls fail before paying for a full copy.
    if post_periods is not None and treatment_start is not None:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")

    if post_periods is None and treatment_start is None:
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")

    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    df = data.copy()

    if post_periods is not None:
        # Wrap scalars only. is_list_like() is False for str/bytes, so a single
        # string period is still one value, while arrays/Series/Index are
        # passed to isin() as collections instead of being nested in a list.
        if not pd.api.types.is_list_like(post_periods):
            post_periods = [post_periods]
        df[new_column] = df[time_column].isin(post_periods).astype(int)
    else:
        # Align the cutoff with a datetime column so that string dates compare
        # correctly against timestamps.
        if pd.api.types.is_datetime64_any_dtype(df[time_column].dtype):
            treatment_start = pd.to_datetime(treatment_start)
        df[new_column] = (df[time_column] >= treatment_start).astype(int)

    return df
177
+
178
+
179
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None
) -> pd.DataFrame:
    """
    Reshape a wide panel (one row per unit) into long form (one row per unit-period).

    Each column listed in ``value_columns`` holds the outcome for one period.
    The result stacks those columns into a single ``value_name`` column and
    records the originating period in ``time_name``. Every other column of
    ``data`` is carried along unchanged.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame with one row per unit.
    value_columns : list of str
        Columns holding the per-period outcome values, in chronological order.
    id_column : str
        Column holding the unit identifier.
    time_name : str, default="period"
        Name of the period column in the output.
    value_name : str, default="value"
        Name of the outcome column in the output.
    time_values : list, optional
        Period labels, one per entry of ``value_columns``. Defaults to
        0, 1, 2, ... when omitted.

    Returns
    -------
    pd.DataFrame
        Long-format frame with columns ``id_column``, ``time_name``,
        ``value_name`` followed by the remaining columns of ``data``, sorted
        by unit and period with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If ``value_columns`` is empty, if ``id_column`` or any value column is
        missing from ``data``, or if ``time_values`` has the wrong length.

    Examples
    --------
    >>> wide_df = pd.DataFrame({
    ...     'firm_id': [1, 2, 3],
    ...     'sales_2019': [100, 150, 200],
    ...     'sales_2020': [110, 160, 210],
    ...     'sales_2021': [120, 170, 220]
    ... })
    >>> long_df = wide_to_long(
    ...     wide_df,
    ...     value_columns=['sales_2019', 'sales_2020', 'sales_2021'],
    ...     id_column='firm_id',
    ...     time_name='year',
    ...     value_name='sales',
    ...     time_values=[2019, 2020, 2021]
    ... )
    >>> len(long_df)
    9
    >>> long_df.columns.tolist()
    ['firm_id', 'year', 'sales']
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")

    available = data.columns
    if id_column not in available:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")

    # Report the first requested value column that is absent, if any.
    absent = [name for name in value_columns if name not in available]
    if absent:
        raise ValueError(f"Column '{absent[0]}' not found in DataFrame.")

    n_values = len(value_columns)
    periods = list(range(n_values)) if time_values is None else time_values
    if len(periods) != n_values:
        raise ValueError(
            f"time_values length ({len(periods)}) must match "
            f"value_columns length ({n_values})."
        )

    # Everything that is neither the id nor an outcome column rides along.
    passthrough = [
        name for name in available
        if not (name == id_column or name in value_columns)
    ]

    stacked = pd.melt(
        data,
        id_vars=[id_column, *passthrough],
        value_vars=value_columns,
        var_name='_temp_var',
        value_name=value_name
    )

    # Translate the source column name of each stacked row into its period label.
    period_lookup = {name: label for name, label in zip(value_columns, periods)}
    stacked[time_name] = stacked['_temp_var'].map(period_lookup)
    stacked = stacked.drop(columns='_temp_var')

    ordered = stacked[[id_column, time_name, value_name, *passthrough]]
    ordered = ordered.sort_values(by=[id_column, time_name])
    return ordered.reset_index(drop=True)
275
+
276
+
277
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None
) -> pd.DataFrame:
    """
    Balance a panel dataset to ensure all units have all time periods.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data. It is not modified.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time period.
    method : str, default="inner"
        Balancing method:
        - "inner": Keep only units that appear in all periods (drops units)
        - "outer": Include all unit-period combinations (creates NaN)
        - "fill": Include all combinations and fill missing values
    fill_value : float, optional
        Value to fill missing observations when method="fill". Only numeric
        columns are filled with it. If None with method="fill", each non-key
        column is forward-filled and then backward-filled within its unit.
        Note that both fill strategies operate on whole columns, so NaN
        values already present in ``data`` are filled as well.

    Returns
    -------
    pd.DataFrame
        Balanced panel DataFrame. For "inner" this is a filtered copy of
        ``data``; for "outer"/"fill" it is the full unit x period grid
        left-joined with ``data``.

    Raises
    ------
    ValueError
        If ``unit_column`` or ``time_column`` is missing from ``data`` or
        ``method`` is not one of "inner", "outer", "fill".

    Examples
    --------
    Keep only complete units:

    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 3, 3, 3],
    ...     'period': [1, 2, 3, 1, 2, 1, 2, 3],
    ...     'y': [10, 11, 12, 20, 21, 30, 31, 32]
    ... })
    >>> balanced = balance_panel(df, 'unit', 'period', method='inner')
    >>> balanced['unit'].unique().tolist()
    [1, 3]

    Include all combinations:

    >>> balanced = balance_panel(df, 'unit', 'period', method='outer')
    >>> len(balanced)
    9
    """
    if unit_column not in data.columns:
        raise ValueError(f"Column '{unit_column}' not found in DataFrame.")
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    if method not in ("inner", "outer", "fill"):
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    if method == "inner":
        # Count periods with nunique() on both sides of the comparison.
        # unique() would count a missing time value as an extra period that no
        # unit can ever "have", which would drop every unit from the panel.
        n_periods = data[time_column].nunique()
        periods_per_unit = data.groupby(unit_column)[time_column].nunique()
        complete_units = periods_per_unit[periods_per_unit == n_periods].index
        return data[data[unit_column].isin(complete_units)].copy()

    # "outer" and "fill": build the full grid of unit-period combinations.
    all_units = data[unit_column].unique()
    all_periods = sorted(data[time_column].unique())
    full_index = pd.MultiIndex.from_product(
        [all_units, all_periods],
        names=[unit_column, time_column]
    )
    full_df = pd.DataFrame(index=full_index).reset_index()

    # Left-join so that combinations absent from the data appear as NaN rows.
    result = full_df.merge(data, on=[unit_column, time_column], how="left")

    if method == "fill":
        # Only non-key columns are candidates for filling.
        cols_to_fill = [c for c in result.columns if c not in (unit_column, time_column)]

        if fill_value is not None:
            # A scalar fill only makes sense for numeric columns.
            numeric_cols = result.select_dtypes(include=[np.number]).columns
            for col in numeric_cols:
                if col in cols_to_fill:
                    result[col] = result[col].fillna(fill_value)
        else:
            # Filling must follow time order within each unit.
            result = result.sort_values([unit_column, time_column])
            # Skip the assignment when the frame holds only the key columns.
            if cols_to_fill:
                result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill()
                # Backward fill covers gaps at the start of a unit's series.
                result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill()

    return result
377
+
378
+
379
def _sorted_for_message(values: Any) -> List[Any]:
    """Return ``values`` as a sorted list for embedding in a validation message.

    ``sorted`` raises ``TypeError`` when the values cannot be ordered against
    each other (for example a column mixing strings and integers). In that
    case fall back to ordering by type name and string form, so the caller can
    still report the offending values instead of crashing.
    """
    items = list(values)
    try:
        return sorted(items)
    except TypeError:
        return sorted(items, key=lambda v: (type(v).__name__, str(v)))


def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True
) -> Dict[str, Any]:
    """
    Validate that data is properly formatted for DiD analysis.

    Checks for common data issues and provides informative error messages.
    The checks run in this order: required columns exist, outcome is numeric,
    treatment is binary, time coding, missing values, sufficient variation
    across treatment/time cells, and (when ``unit`` is given and no errors
    were found) within-unit treatment constancy and panel balance.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/post indicator column.
    unit : str, optional
        Name of unit identifier column (for panel data validation).
    raise_on_error : bool, default=True
        If True, raises ValueError on validation failures.
        If False, returns validation results without raising.

    Returns
    -------
    dict
        Validation results with keys:
        - valid: bool indicating if data passed all checks
        - errors: list of error messages
        - warnings: list of warning messages
        - summary: dict with data summary statistics (empty when an error
          was found before the summary stage)

    Raises
    ------
    ValueError
        If ``raise_on_error`` is True and at least one check failed.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [1, 2, 3, 4],
    ...     'treated': [0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1]
    ... })
    >>> result = validate_did_data(df, 'y', 'treated', 'post', raise_on_error=False)
    >>> result['valid']
    True
    """
    errors: List[str] = []
    # Named to avoid shadowing the standard-library ``warnings`` module.
    warning_messages: List[str] = []

    # Check columns exist
    required_cols = [outcome, treatment, time]
    if unit is not None:
        required_cols.append(unit)

    errors.extend(
        f"Required column '{col}' not found in DataFrame."
        for col in required_cols
        if col not in data.columns
    )

    # Nothing else can be checked without the columns, so stop here.
    if errors:
        if raise_on_error:
            raise ValueError("\n".join(errors))
        return {"valid": False, "errors": errors, "warnings": warning_messages, "summary": {}}

    # Check outcome is numeric
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        errors.append(
            f"Outcome column '{outcome}' must be numeric. "
            f"Got type: {data[outcome].dtype}"
        )

    # Check treatment is binary
    treatment_vals = data[treatment].dropna().unique()
    if not set(treatment_vals).issubset({0, 1}):
        errors.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {_sorted_for_message(treatment_vals)}"
        )

    # Check time is binary for simple DiD
    time_vals = data[time].dropna().unique()
    if len(time_vals) == 2 and not set(time_vals).issubset({0, 1}):
        warning_messages.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: "
            f"{_sorted_for_message(time_vals)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    # Check for missing values
    for col in required_cols:
        n_missing = data[col].isna().sum()
        if n_missing > 0:
            errors.append(
                f"Column '{col}' has {n_missing} missing values. "
                "Please handle missing data before fitting."
            )

    # Calculate summary statistics (only meaningful once the basic checks pass)
    summary: Dict[str, Any] = {}
    if not errors:
        summary["n_obs"] = len(data)
        summary["n_treated"] = int((data[treatment] == 1).sum())
        summary["n_control"] = int((data[treatment] == 0).sum())
        summary["n_periods"] = len(time_vals)

        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        # Check for sufficient variation
        if summary["n_treated"] == 0:
            errors.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            errors.append("No control observations found (treatment column is all 1).")

        # Check for each treatment-time combination
        if len(time_vals) == 2:
            # For 2-period DiD, check all four cells
            for t_val in [0, 1]:
                for p_val in time_vals:
                    count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
                    if count == 0:
                        errors.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # For multi-period, check that both treatment groups exist in multiple periods
            for t_val in [0, 1]:
                n_periods_with_obs = data[data[treatment] == t_val][time].nunique()
                if n_periods_with_obs < 2:
                    group_name = "Treated" if t_val == 1 else "Control"
                    errors.append(
                        f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # Panel-specific validation
    if unit is not None and not errors:
        # Check treatment is constant within units
        unit_treatment_var = data.groupby(unit)[treatment].nunique()
        units_with_varying_treatment = unit_treatment_var[unit_treatment_var > 1]
        if len(units_with_varying_treatment) > 0:
            warning_messages.append(
                f"Treatment varies within {len(units_with_varying_treatment)} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        # Check panel balance
        periods_per_unit = data.groupby(unit)[time].nunique()
        if periods_per_unit.min() != periods_per_unit.max():
            warning_messages.append(
                f"Unbalanced panel detected. Units have between "
                f"{periods_per_unit.min()} and {periods_per_unit.max()} periods. "
                "Consider using balance_panel() to balance the data."
            )

    valid = len(errors) == 0

    if raise_on_error and not valid:
        raise ValueError("Data validation failed:\n" + "\n".join(errors))

    return {
        "valid": valid,
        "errors": errors,
        "warnings": warning_messages,
        "summary": summary
    }
548
+
549
+
550
def summarize_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Tabulate outcome statistics for every treatment-group / period cell.

    For each (treatment, time) combination the count, mean, standard
    deviation, minimum and maximum of ``outcome`` are reported, rounded to
    four decimals. When the time column has exactly two distinct values, the
    smaller one is labelled "Pre" and the larger "Post", and an extra
    "DiD Estimate" row is appended whose ``mean`` entry is
    (treated post - treated pre) - (control post - control pre); the other
    entries of that row are the placeholder string "-".

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column (1 is labelled "Treated", anything
        else "Control").
    time : str
        Name of time/period column.
    unit : str, optional
        Name of unit identifier column. Accepted for API symmetry; the
        summary itself does not use it.

    Returns
    -------
    pd.DataFrame
        One row per treatment-time cell (plus the DiD row in the two-period
        case) with columns ``n``, ``mean``, ``std``, ``min``, ``max``.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [10, 11, 12, 13, 20, 21, 22, 23],
    ...     'treated': [0, 0, 1, 1, 0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1, 0, 1, 0, 1]
    ... })
    >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
    >>> print(summary)
    """
    stat_specs = [
        ("n", "count"),
        ("mean", "mean"),
        ("std", "std"),
        ("min", "min"),
        ("max", "max"),
    ]
    cell_stats = data.groupby([treatment, time])[outcome].agg(stat_specs)
    summary = cell_stats.round(4)

    # Distinct period values in ascending order; used only for labelling.
    periods = sorted(data[time].unique())

    def group_label(flag: Any) -> str:
        return 'Treated' if flag == 1 else 'Control'

    if len(periods) != 2:
        # General case: label each row with its raw period value.
        def period_label(key: tuple) -> str:
            return f"{group_label(key[0])} - Period {key[1]}"

        summary.index = summary.index.map(period_label)
        return summary

    # Two-period case: the ordering of the values (not literal 0/1) decides
    # which one is "Pre" and which is "Post".
    pre_val, post_val = periods[0], periods[1]

    def pre_post_label(key: tuple) -> str:
        phase = 'Post' if key[1] == post_val else 'Pre'
        return f"{group_label(key[0])} - {phase}"

    summary.index = summary.index.map(pre_post_label)

    def cell_mean(group_flag: int, period_value: Any) -> Any:
        in_cell = (data[treatment] == group_flag) & (data[time] == period_value)
        return data[in_cell][outcome].mean()

    treated_pre = cell_mean(1, pre_val)
    treated_post = cell_mean(1, post_val)
    control_pre = cell_mean(0, pre_val)
    control_post = cell_mean(0, post_val)

    # Difference-in-differences of the four (unrounded) cell means.
    treated_change = treated_post - treated_pre
    control_change = control_post - control_pre
    did_estimate = treated_change - control_change

    did_row = pd.DataFrame(
        {
            "n": ["-"],
            "mean": [did_estimate],
            "std": ["-"],
            "min": ["-"],
            "max": ["-"]
        },
        index=["DiD Estimate"]
    )
    return pd.concat([summary, did_row])
640
+
641
+
642
def generate_did_data(
    n_units: int = 100,
    n_periods: int = 4,
    treatment_effect: float = 5.0,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    unit_fe_sd: float = 2.0,
    time_trend: float = 0.5,
    noise_sd: float = 1.0,
    seed: Optional[int] = None
) -> pd.DataFrame:
    """
    Simulate a balanced DiD panel whose true treatment effect is known.

    Each outcome is built as
    ``10 + unit fixed effect + time_trend * period + effect + noise``, where
    ``effect`` equals ``treatment_effect`` for treated units in post periods
    and 0 otherwise. The first ``int(n_units * treatment_fraction)`` units
    (by identifier) are the treated ones.

    Parameters
    ----------
    n_units : int, default=100
        Number of units in the panel.
    n_periods : int, default=4
        Number of time periods.
    treatment_effect : float, default=5.0
        True average treatment effect on the treated.
    treatment_fraction : float, default=0.5
        Fraction of units that receive treatment.
    treatment_period : int, default=2
        First post-treatment period (0-indexed). Periods >= this are post.
    unit_fe_sd : float, default=2.0
        Standard deviation of unit fixed effects.
    time_trend : float, default=0.5
        Linear time trend coefficient.
    noise_sd : float, default=1.0
        Standard deviation of idiosyncratic noise.
    seed : int, optional
        Random seed for reproducibility.

    Returns
    -------
    pd.DataFrame
        Synthetic panel data, one row per unit-period, with columns:
        - unit: Unit identifier
        - period: Time period
        - treated: Treatment indicator (0/1)
        - post: Post-treatment indicator (0/1)
        - outcome: Outcome variable
        - true_effect: The true treatment effect (for validation)

    Examples
    --------
    Generate simple data for testing:

    >>> data = generate_did_data(n_units=50, n_periods=4, treatment_effect=3.0, seed=42)
    >>> len(data)
    200
    >>> data.columns.tolist()
    ['unit', 'period', 'treated', 'post', 'outcome', 'true_effect']

    Verify treatment effect recovery:

    >>> from diff_diff import DifferenceInDifferences
    >>> did = DifferenceInDifferences()
    >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')
    >>> abs(results.att - 3.0) < 1.0  # Close to true effect
    True
    """
    rng = np.random.default_rng(seed)

    # Units 0 .. n_treated-1 form the treated group.
    n_treated = int(n_units * treatment_fraction)

    # All fixed effects are drawn up front, before any noise draw, so the
    # random stream is consumed in a fixed order for a given seed.
    fixed_effects = rng.normal(0, unit_fe_sd, n_units)

    def simulate_row(unit_id: int, period: int) -> Dict[str, Any]:
        """Build one unit-period record, drawing exactly one noise value."""
        in_treated_group = unit_id < n_treated
        after_start = period >= treatment_period
        effect = treatment_effect if (in_treated_group and after_start) else 0.0
        noise = rng.normal(0, noise_sd)
        # Baseline + unit effect + trend + treatment effect + noise, summed
        # left to right in that order.
        value = 10.0 + fixed_effects[unit_id] + time_trend * period + effect + noise
        return {
            "unit": unit_id,
            "period": period,
            "treated": int(in_treated_group),
            "post": int(after_start),
            "outcome": value,
            "true_effect": effect
        }

    # Unit-major, period-minor ordering determines both row order and the
    # order in which noise values are drawn.
    rows = [
        simulate_row(unit_id, period)
        for unit_id in range(n_units)
        for period in range(n_periods)
    ]
    return pd.DataFrame(rows)
750
+
751
+
752
def create_event_time(
    data: pd.DataFrame,
    time_column: str,
    treatment_time_column: str,
    new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Create an event-time column relative to treatment timing.

    Useful for event study designs where treatment occurs at different
    times for different units. Event time is computed as
    ``time_column - treatment_time_column``.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data. It is not modified; a copy is returned.
    time_column : str
        Name of the calendar time column.
    treatment_time_column : str
        Name of the column indicating when each unit was treated.
        Units with NaN or infinity are considered never-treated.
    new_column : str, default="event_time"
        Name of the new event-time column.

    Returns
    -------
    pd.DataFrame
        DataFrame with event-time column added. Values are:
        - Negative for pre-treatment periods
        - 0 for the treatment period
        - Positive for post-treatment periods
        - NaN for never-treated units

    Raises
    ------
    ValueError
        If ``time_column`` or ``treatment_time_column`` is missing from ``data``.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2],
    ...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
    ...     'treatment_year': [2019, 2019, 2019, 2020, 2020, 2020]
    ... })
    >>> df = create_event_time(df, 'year', 'treatment_year')
    >>> df['event_time'].tolist()
    [-1, 0, 1, -2, -1, 0]
    """
    # Validate first so invalid calls fail before paying for a full copy.
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")
    if treatment_time_column not in data.columns:
        raise ValueError(f"Column '{treatment_time_column}' not found in DataFrame.")

    df = data.copy()

    # Capture treatment timing before writing the result, so the never-treated
    # mask is based on the original values even if new_column overwrites it.
    treatment_time = df[treatment_time_column]
    df[new_column] = df[time_column] - treatment_time

    # Never-treated units are marked by NaN (any dtype) or +/-inf (numeric only).
    never_treated = treatment_time.isna()
    if pd.api.types.is_numeric_dtype(treatment_time):
        never_treated = never_treated | np.isinf(treatment_time)

    # Assign NaN only when needed: an empty assignment must not be allowed to
    # upcast an all-integer event-time column to float.
    if never_treated.any():
        df.loc[never_treated, new_column] = np.nan

    return df
815
+
816
+
817
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Collapse unit-level rows into one row per treatment-status / period cell.

    Within each (treatment, period) cell the outcome and any requested
    covariates are averaged and the number of distinct units is counted.
    Handy for plotting group trajectories and for cohort-level analysis.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Name of unit identifier column.
    time_column : str
        Name of time period column.
    treatment_column : str
        Name of treatment indicator column.
    outcome : str
        Name of outcome variable column.
    covariates : list of str, optional
        Additional columns to average within each cell. They keep their
        original names in the output.

    Returns
    -------
    pd.DataFrame
        One row per treatment-period cell with the grouping columns,
        ``mean_<outcome>``, ``n_units`` and any covariate means.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 2, 2, 3, 3, 4, 4],
    ...     'period': [0, 1, 0, 1, 0, 1, 0, 1],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'y': [10, 15, 12, 17, 8, 10, 9, 11]
    ... })
    >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
    >>> len(cohort_df)
    4
    """
    # Outcome mean and distinct-unit count first, then one mean per covariate.
    aggregations = {outcome: "mean", unit_column: "nunique"}
    aggregations.update({name: "mean" for name in (covariates or [])})

    cells = data.groupby([treatment_column, time_column])
    cohort_data = cells.agg(aggregations).reset_index()

    # Give the two core aggregates self-describing names.
    new_names = {outcome: f"mean_{outcome}", unit_column: "n_units"}
    return cohort_data.rename(columns=new_names)
877
+
878
+
879
+ def rank_control_units(
880
+ data: pd.DataFrame,
881
+ unit_column: str,
882
+ time_column: str,
883
+ outcome_column: str,
884
+ treatment_column: Optional[str] = None,
885
+ treated_units: Optional[List[Any]] = None,
886
+ pre_periods: Optional[List[Any]] = None,
887
+ covariates: Optional[List[str]] = None,
888
+ outcome_weight: float = 0.7,
889
+ covariate_weight: float = 0.3,
890
+ exclude_units: Optional[List[Any]] = None,
891
+ require_units: Optional[List[Any]] = None,
892
+ n_top: Optional[int] = None,
893
+ suggest_treatment_candidates: bool = False,
894
+ n_treatment_candidates: int = 5,
895
+ lambda_reg: float = 0.0,
896
+ ) -> pd.DataFrame:
897
+ """
898
+ Rank potential control units by their suitability for DiD analysis.
899
+
900
+ Evaluates control units based on pre-treatment outcome trend similarity
901
+ and optional covariate matching to treated units. Returns a ranked list
902
+ with quality scores.
903
+
904
+ Parameters
905
+ ----------
906
+ data : pd.DataFrame
907
+ Panel data in long format.
908
+ unit_column : str
909
+ Column name for unit identifier.
910
+ time_column : str
911
+ Column name for time periods.
912
+ outcome_column : str
913
+ Column name for outcome variable.
914
+ treatment_column : str, optional
915
+ Column with binary treatment indicator (0/1). Used to identify
916
+ treated units from data.
917
+ treated_units : list, optional
918
+ Explicit list of treated unit IDs. Alternative to treatment_column.
919
+ pre_periods : list, optional
920
+ Pre-treatment periods for comparison. If None, uses first half of periods.
921
+ covariates : list of str, optional
922
+ Covariate columns for matching. Similarity is based on pre-treatment means.
923
+ outcome_weight : float, default=0.7
924
+ Weight for pre-treatment outcome trend similarity (0-1).
925
+ covariate_weight : float, default=0.3
926
+ Weight for covariate distance (0-1). Ignored if no covariates.
927
+ exclude_units : list, optional
928
+ Units that cannot be in control group.
929
+ require_units : list, optional
930
+ Units that must be in control group (will always appear in output).
931
+ n_top : int, optional
932
+ Return only top N control units. If None, return all.
933
+ suggest_treatment_candidates : bool, default=False
934
+ If True and no treated units specified, identify potential treatment
935
+ candidates instead of ranking controls.
936
+ n_treatment_candidates : int, default=5
937
+ Number of treatment candidates to suggest.
938
+ lambda_reg : float, default=0.0
939
+ Regularization for synthetic weights. Higher values give more uniform
940
+ weights across controls.
941
+
942
+ Returns
943
+ -------
944
+ pd.DataFrame
945
+ Ranked control units with columns:
946
+ - unit: Unit identifier
947
+ - quality_score: Combined quality score (0-1, higher is better)
948
+ - outcome_trend_score: Pre-treatment outcome trend similarity
949
+ - covariate_score: Covariate match score (NaN if no covariates)
950
+ - synthetic_weight: Weight from synthetic control optimization
951
+ - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
952
+ - is_required: Whether unit was in require_units
953
+
954
+ If suggest_treatment_candidates=True (and no treated units):
955
+ - unit: Unit identifier
956
+ - treatment_candidate_score: Suitability as treatment unit
957
+ - avg_outcome_level: Pre-treatment outcome mean
958
+ - outcome_trend: Pre-treatment trend slope
959
+ - n_similar_controls: Count of similar potential controls
960
+
961
+ Examples
962
+ --------
963
+ Rank controls against treated units:
964
+
965
+ >>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
966
+ >>> ranking = rank_control_units(
967
+ ... data,
968
+ ... unit_column='unit',
969
+ ... time_column='period',
970
+ ... outcome_column='outcome',
971
+ ... treatment_column='treated',
972
+ ... n_top=10
973
+ ... )
974
+ >>> ranking['quality_score'].is_monotonic_decreasing
975
+ True
976
+
977
+ With covariates:
978
+
979
+ >>> data['size'] = np.random.randn(len(data))
980
+ >>> ranking = rank_control_units(
981
+ ... data,
982
+ ... unit_column='unit',
983
+ ... time_column='period',
984
+ ... outcome_column='outcome',
985
+ ... treatment_column='treated',
986
+ ... covariates=['size']
987
+ ... )
988
+
989
+ Filter data for SyntheticDiD:
990
+
991
+ >>> top_controls = ranking['unit'].tolist()
992
+ >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
993
+ """
994
+ # -------------------------------------------------------------------------
995
+ # Input validation
996
+ # -------------------------------------------------------------------------
997
+ for col in [unit_column, time_column, outcome_column]:
998
+ if col not in data.columns:
999
+ raise ValueError(f"Column '{col}' not found in DataFrame.")
1000
+
1001
+ if treatment_column is not None and treatment_column not in data.columns:
1002
+ raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")
1003
+
1004
+ if covariates:
1005
+ for cov in covariates:
1006
+ if cov not in data.columns:
1007
+ raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")
1008
+
1009
+ if not 0 <= outcome_weight <= 1:
1010
+ raise ValueError("outcome_weight must be between 0 and 1")
1011
+ if not 0 <= covariate_weight <= 1:
1012
+ raise ValueError("covariate_weight must be between 0 and 1")
1013
+
1014
+ if treated_units is not None and treatment_column is not None:
1015
+ raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")
1016
+
1017
+ if require_units and exclude_units:
1018
+ invalid_required = [u for u in require_units if u in exclude_units]
1019
+ if invalid_required:
1020
+ raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")
1021
+
1022
+ # -------------------------------------------------------------------------
1023
+ # Determine pre-treatment periods
1024
+ # -------------------------------------------------------------------------
1025
+ all_periods = sorted(data[time_column].unique())
1026
+ if pre_periods is None:
1027
+ mid_point = len(all_periods) // 2
1028
+ pre_periods = all_periods[:mid_point]
1029
+ else:
1030
+ pre_periods = list(pre_periods)
1031
+
1032
+ if len(pre_periods) == 0:
1033
+ raise ValueError("No pre-treatment periods specified or inferred.")
1034
+
1035
+ # -------------------------------------------------------------------------
1036
+ # Identify treated and control units
1037
+ # -------------------------------------------------------------------------
1038
+ all_units = list(data[unit_column].unique())
1039
+
1040
+ if treated_units is not None:
1041
+ treated_set = set(treated_units)
1042
+ elif treatment_column is not None:
1043
+ unit_treatment = data.groupby(unit_column)[treatment_column].first()
1044
+ treated_set = set(unit_treatment[unit_treatment == 1].index)
1045
+ elif suggest_treatment_candidates:
1046
+ # Treatment candidate discovery mode - no treated units
1047
+ treated_set = set()
1048
+ else:
1049
+ raise ValueError(
1050
+ "Must specify treated_units, treatment_column, or set "
1051
+ "suggest_treatment_candidates=True"
1052
+ )
1053
+
1054
+ # -------------------------------------------------------------------------
1055
+ # Treatment candidate discovery mode
1056
+ # -------------------------------------------------------------------------
1057
+ if suggest_treatment_candidates and len(treated_set) == 0:
1058
+ return _suggest_treatment_candidates(
1059
+ data, unit_column, time_column, outcome_column,
1060
+ pre_periods, n_treatment_candidates
1061
+ )
1062
+
1063
+ if len(treated_set) == 0:
1064
+ raise ValueError("No treated units found.")
1065
+
1066
+ # Determine control candidates
1067
+ control_candidates = [u for u in all_units if u not in treated_set]
1068
+
1069
+ if exclude_units:
1070
+ control_candidates = [u for u in control_candidates if u not in exclude_units]
1071
+
1072
+ if len(control_candidates) == 0:
1073
+ raise ValueError("No control units available after exclusions.")
1074
+
1075
+ # -------------------------------------------------------------------------
1076
+ # Create outcome matrices (pre-treatment)
1077
+ # -------------------------------------------------------------------------
1078
+ pre_data = data[data[time_column].isin(pre_periods)]
1079
+ pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)
1080
+
1081
+ # Filter to pre_periods that exist in data
1082
+ valid_pre_periods = [p for p in pre_periods if p in pivot.index]
1083
+ if len(valid_pre_periods) == 0:
1084
+ raise ValueError("No data found for specified pre-treatment periods.")
1085
+
1086
+ # Filter control_candidates to those present in pivot (handles unbalanced panels)
1087
+ control_candidates = [c for c in control_candidates if c in pivot.columns]
1088
+ if len(control_candidates) == 0:
1089
+ raise ValueError("No control units found in pre-treatment data.")
1090
+
1091
+ # Control outcomes: shape (n_pre_periods, n_control_candidates)
1092
+ Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)
1093
+
1094
+ # Treated outcomes mean: shape (n_pre_periods,)
1095
+ treated_list = [u for u in treated_set if u in pivot.columns]
1096
+ if len(treated_list) == 0:
1097
+ raise ValueError("Treated units not found in pre-treatment data.")
1098
+ Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)
1099
+
1100
+ # -------------------------------------------------------------------------
1101
+ # Compute outcome trend scores
1102
+ # -------------------------------------------------------------------------
1103
+ # Synthetic weights (higher = better match)
1104
+ synthetic_weights = compute_synthetic_weights(
1105
+ Y_control, Y_treated_mean, lambda_reg=lambda_reg
1106
+ )
1107
+
1108
+ # RMSE for each control vs treated mean (use nanmean to handle missing data)
1109
+ rmse_scores = []
1110
+ for j in range(len(control_candidates)):
1111
+ y_c = Y_control[:, j]
1112
+ rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
1113
+ rmse_scores.append(rmse)
1114
+
1115
+ # Convert RMSE to similarity score (lower RMSE = higher score)
1116
+ max_rmse = max(rmse_scores) if rmse_scores else 1.0
1117
+ min_rmse = min(rmse_scores) if rmse_scores else 0.0
1118
+ rmse_range = max_rmse - min_rmse
1119
+
1120
+ if rmse_range < 1e-10:
1121
+ # All controls have identical/similar pre-trends (includes single control case)
1122
+ outcome_trend_scores = [1.0] * len(rmse_scores)
1123
+ else:
1124
+ # Normalize so best control gets 1.0, worst gets 0.0
1125
+ outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]
1126
+
1127
+ # -------------------------------------------------------------------------
1128
+ # Compute covariate scores (if covariates provided)
1129
+ # -------------------------------------------------------------------------
1130
+ if covariates and len(covariates) > 0:
1131
+ # Get unit-level covariate values (pre-treatment mean)
1132
+ cov_data = pre_data.groupby(unit_column)[covariates].mean()
1133
+
1134
+ # Treated covariate profile (mean across treated units)
1135
+ treated_cov = cov_data.loc[list(treated_set)].mean()
1136
+
1137
+ # Standardize covariates
1138
+ cov_mean = cov_data.mean()
1139
+ cov_std = cov_data.std().replace(0, 1) # Avoid division by zero
1140
+ cov_standardized = (cov_data - cov_mean) / cov_std
1141
+ treated_cov_std = (treated_cov - cov_mean) / cov_std
1142
+
1143
+ # Euclidean distance in standardized space (vectorized)
1144
+ control_cov_matrix = cov_standardized.loc[control_candidates].values
1145
+ treated_cov_vector = treated_cov_std.values
1146
+ covariate_distances = np.sqrt(
1147
+ np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
1148
+ )
1149
+
1150
+ # Convert distance to similarity score (min-max normalization)
1151
+ max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
1152
+ min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
1153
+ dist_range = max_dist - min_dist
1154
+
1155
+ if dist_range < 1e-10:
1156
+ # All controls have identical/similar covariate profiles
1157
+ covariate_scores = [1.0] * len(covariate_distances)
1158
+ else:
1159
+ # Normalize so best control (closest) gets 1.0, worst gets 0.0
1160
+ covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
1161
+ else:
1162
+ covariate_scores = [np.nan] * len(control_candidates)
1163
+
1164
+ # -------------------------------------------------------------------------
1165
+ # Compute combined quality score
1166
+ # -------------------------------------------------------------------------
1167
+ # Normalize weights
1168
+ total_weight = outcome_weight + covariate_weight
1169
+ if total_weight > 0:
1170
+ norm_outcome_weight = outcome_weight / total_weight
1171
+ norm_covariate_weight = covariate_weight / total_weight
1172
+ else:
1173
+ norm_outcome_weight = 1.0
1174
+ norm_covariate_weight = 0.0
1175
+
1176
+ quality_scores = []
1177
+ for i in range(len(control_candidates)):
1178
+ outcome_score = outcome_trend_scores[i]
1179
+ cov_score = covariate_scores[i]
1180
+
1181
+ if np.isnan(cov_score):
1182
+ # No covariates - use only outcome score
1183
+ combined = outcome_score
1184
+ else:
1185
+ combined = norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score
1186
+
1187
+ quality_scores.append(combined)
1188
+
1189
+ # -------------------------------------------------------------------------
1190
+ # Build result DataFrame
1191
+ # -------------------------------------------------------------------------
1192
+ require_set = set(require_units) if require_units else set()
1193
+
1194
+ result = pd.DataFrame({
1195
+ 'unit': control_candidates,
1196
+ 'quality_score': quality_scores,
1197
+ 'outcome_trend_score': outcome_trend_scores,
1198
+ 'covariate_score': covariate_scores,
1199
+ 'synthetic_weight': synthetic_weights,
1200
+ 'pre_trend_rmse': rmse_scores,
1201
+ 'is_required': [u in require_set for u in control_candidates]
1202
+ })
1203
+
1204
+ # Sort by quality score (descending)
1205
+ result = result.sort_values('quality_score', ascending=False)
1206
+
1207
+ # Apply n_top limit if specified
1208
+ if n_top is not None and n_top < len(result):
1209
+ # Always include required units
1210
+ required_df = result[result['is_required']]
1211
+ non_required_df = result[~result['is_required']]
1212
+
1213
+ # Take top from non-required to fill remaining slots
1214
+ remaining_slots = max(0, n_top - len(required_df))
1215
+ top_non_required = non_required_df.head(remaining_slots)
1216
+
1217
+ result = pd.concat([required_df, top_non_required])
1218
+ result = result.sort_values('quality_score', ascending=False)
1219
+
1220
+ return result.reset_index(drop=True)
1221
+
1222
+
1223
def _suggest_treatment_candidates(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    pre_periods: List[Any],
    n_candidates: int
) -> pd.DataFrame:
    """
    Identify units that would make good treatment candidates.

    A good treatment candidate:
    1. Has many similar control units available (for matching)
    2. Is not an extreme outlier in its pre-treatment outcome level

    The pre-treatment trend slope is reported for each unit but does not
    enter the score.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    unit_column : str
        Unit identifier column.
    time_column : str
        Time period column.
    outcome_column : str
        Outcome variable column.
    pre_periods : list
        Pre-treatment periods. Only rows whose time value is in this list
        are used.
    n_candidates : int
        Number of candidates to return.

    Returns
    -------
    pd.DataFrame
        At most ``n_candidates`` rows, sorted by descending
        ``treatment_candidate_score``, with columns:

        - unit: Unit identifier
        - avg_outcome_level: Mean pre-treatment outcome of the unit
        - outcome_trend: Slope of a linear fit of the outcome against the
          position of each observation in chronological order
        - n_similar_controls: Number of other units whose pre-treatment
          mean lies within ``_SIMILARITY_THRESHOLD_SD`` standard deviations
          of this unit's mean
        - treatment_candidate_score: Score clipped to [0, 1]

        An empty DataFrame with these column names is returned when no unit
        has any observation in ``pre_periods``.
    """
    all_units = list(data[unit_column].unique())
    pre_data = data[data[time_column].isin(pre_periods)]

    # Per-unit pre-treatment means do not depend on which unit is currently
    # being scored, so compute them once instead of once per unit.
    unit_means = pre_data.groupby(unit_column)[outcome_column].mean()

    candidate_info = []

    # Iterate in order of first appearance in `data` so that ties in the
    # final score are broken the same way regardless of the hoisting above.
    for unit in all_units:
        unit_data = pre_data[pre_data[unit_column] == unit]

        if len(unit_data) == 0:
            continue

        # The slope below is fitted against observation positions, so the
        # rows must be in chronological order; a stable sort keeps the
        # original order of rows that share a time value.
        unit_data = unit_data.sort_values(time_column, kind="mergesort")

        # Average outcome level (NaN outcomes are skipped by pandas)
        avg_outcome = unit_data[outcome_column].mean()

        # Trend (simple linear regression slope over observation positions).
        # Missing outcomes are excluded from the fit but keep their position,
        # so a gap does not compress the time axis.
        outcomes = unit_data[outcome_column].astype(float).values
        positions = np.arange(len(outcomes))
        observed = np.isfinite(outcomes)
        slope = 0.0
        if observed.sum() > 1:
            try:
                slope = np.polyfit(positions[observed], outcomes[observed], 1)[0]
            except (np.linalg.LinAlgError, ValueError):
                slope = 0.0

        # Count similar potential controls among all *other* units
        other_means = unit_means.drop(unit, errors="ignore")

        if len(other_means) > 0:
            sd = other_means.std()
            if sd > 0:
                n_similar = int(np.sum(
                    np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd
                ))
            else:
                # Zero or undefined spread (e.g. a single other unit):
                # every other unit counts as similar.
                n_similar = len(other_means)
        else:
            n_similar = 0

        candidate_info.append({
            'unit': unit,
            'avg_outcome_level': avg_outcome,
            'outcome_trend': slope,
            'n_similar_controls': n_similar
        })

    if len(candidate_info) == 0:
        return pd.DataFrame(columns=[
            'unit', 'treatment_candidate_score', 'avg_outcome_level',
            'outcome_trend', 'n_similar_controls'
        ])

    result = pd.DataFrame(candidate_info)

    # Score: prefer units with many similar controls and moderate outcome levels
    max_similar = result['n_similar_controls'].max()
    if max_similar > 0:
        similarity_score = result['n_similar_controls'] / max_similar
    else:
        similarity_score = pd.Series([0.0] * len(result))

    # Penalty for outliers in outcome level (absolute z-score across candidates)
    outcome_mean = result['avg_outcome_level'].mean()
    outcome_std = result['avg_outcome_level'].std()
    if outcome_std > 0:
        outcome_z = np.abs((result['avg_outcome_level'] - outcome_mean) / outcome_std)
    else:
        # Single candidate or identical levels: no outlier penalty
        outcome_z = pd.Series([0.0] * len(result))

    result['treatment_candidate_score'] = (
        similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
    ).clip(0, 1)

    # Return top candidates
    result = result.nlargest(n_candidates, 'treatment_candidate_score')
    return result.reset_index(drop=True)