diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/prep.py ADDED
@@ -0,0 +1,1241 @@
1
+ """
2
+ Data preparation utilities for difference-in-differences analysis.
3
+
4
+ This module provides helper functions to prepare data for DiD estimation,
5
+ including creating treatment indicators, reshaping panel data, and
6
+ generating synthetic datasets for testing.
7
+
8
+ Data generation functions (generate_*) are defined in prep_dgp.py and
9
+ re-exported here for backward compatibility.
10
+ """
11
+
12
+ from typing import Any, Dict, List, Optional, Union
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ from diff_diff.utils import compute_synthetic_weights
18
+
19
+ # Re-export data generation functions from prep_dgp for backward compatibility
20
+ from diff_diff.prep_dgp import (
21
+ generate_did_data,
22
+ generate_staggered_data,
23
+ generate_factor_data,
24
+ generate_ddd_data,
25
+ generate_panel_data,
26
+ generate_event_study_data,
27
+ )
28
+
29
# Module-private tuning constants for rank_control_units.
_SIMILARITY_THRESHOLD_SD = 0.5  # Controls within this many SDs are "similar"
_OUTLIER_PENALTY_WEIGHT = 0.3  # Penalty weight for outcome outliers in treatment candidate scoring
32
+
33
+
34
def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated"
) -> pd.DataFrame:
    """
    Create a binary treatment indicator column from various input types.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    column : str
        Name of the column to use for creating the treatment indicator.
    treated_values : Any or list-like, optional
        Value(s) that indicate treatment. Units with these values get
        treatment=1, others get treatment=0. Accepts a scalar or any
        list-like container (list, tuple, set, numpy array, Series).
    threshold : float, optional
        Numeric threshold for creating treatment. Used when the treatment
        is based on a continuous variable (e.g., treat firms above median size).
    above_threshold : bool, default=True
        If True, values >= threshold are treated. If False, values <= threshold
        are treated. Only used when threshold is specified.
    new_column : str, default="treated"
        Name of the new treatment indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of the input DataFrame with the new treatment indicator column.

    Raises
    ------
    ValueError
        If both or neither of `treated_values`/`threshold` are given, or
        if `column` is not present in the DataFrame.

    Examples
    --------
    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> make_treatment_indicator(df, 'group', treated_values='A')['treated'].tolist()
    [1, 1, 0, 0]

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200]})
    >>> make_treatment_indicator(df, 'size', threshold=75)['treated'].tolist()
    [0, 0, 1, 1]
    """
    df = data.copy()

    if treated_values is not None and threshold is not None:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")

    if treated_values is None and threshold is None:
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")

    if column not in df.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    if treated_values is not None:
        # Accept any list-like container (list, tuple, set, ndarray, Series).
        # A plain isinstance check against (list, tuple, set) would wrap a
        # numpy array in a one-element list and .isin() would silently match
        # nothing; is_list_like also correctly treats strings as scalars.
        if not pd.api.types.is_list_like(treated_values):
            treated_values = [treated_values]
        df[new_column] = df[column].isin(treated_values).astype(int)
    else:
        # Threshold mode: direction is controlled by above_threshold.
        if above_threshold:
            df[new_column] = (df[column] >= threshold).astype(int)
        else:
            df[new_column] = (df[column] <= threshold).astype(int)

    return df
114
+
115
+
116
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post"
) -> pd.DataFrame:
    """
    Create a binary post-treatment indicator column.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list-like, optional
        Specific period value(s) that are post-treatment. Accepts a scalar
        or any list-like container (list, tuple, set, numpy array, Series).
        Periods matching these values get post=1, others get post=0.
    treatment_start : Any, optional
        The first post-treatment period. All periods >= this value get post=1.
        Works with numeric periods, strings (sorted alphabetically), or dates.
    new_column : str, default="post"
        Name of the new post indicator column.

    Returns
    -------
    pd.DataFrame
        Copy of the input DataFrame with the new post indicator column.

    Raises
    ------
    ValueError
        If both or neither of `post_periods`/`treatment_start` are given,
        or if `time_column` is not present in the DataFrame.

    Examples
    --------
    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021]})
    >>> make_post_indicator(df, 'year', post_periods=[2020, 2021])['post'].tolist()
    [0, 0, 1, 1]
    >>> make_post_indicator(df, 'year', treatment_start=2020)['post'].tolist()
    [0, 0, 1, 1]
    """
    df = data.copy()

    if post_periods is not None and treatment_start is not None:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")

    if post_periods is None and treatment_start is None:
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")

    if time_column not in df.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    if post_periods is not None:
        # Accept any list-like container (list, tuple, set, ndarray, Series).
        # A plain isinstance check against (list, tuple, set) would wrap a
        # numpy array in a one-element list and .isin() would silently match
        # nothing; is_list_like also correctly treats strings as scalars.
        if not pd.api.types.is_list_like(post_periods):
            post_periods = [post_periods]
        df[new_column] = df[time_column].isin(post_periods).astype(int)
    else:
        # treatment_start mode: coerce to Timestamp when the column holds
        # datetimes so string inputs like '2020-06-01' compare correctly.
        if pd.api.types.is_datetime64_any_dtype(df[time_column].dtype):
            treatment_start = pd.to_datetime(treatment_start)
        df[new_column] = (df[time_column] >= treatment_start).astype(int)

    return df
190
+
191
+
192
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None
) -> pd.DataFrame:
    """
    Reshape wide-format panel data into long format for DiD analysis.

    Wide format holds one row per unit with one column per period; long
    format holds one row per unit-period pair. Any columns that are neither
    the id nor a value column are carried through unchanged.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame (one row per unit).
    value_columns : list of str
        Outcome columns, one per period, in chronological order.
    id_column : str
        Unit identifier column.
    time_name : str, default="period"
        Name of the generated time column.
    value_name : str, default="value"
        Name of the generated outcome column.
    time_values : list, optional
        Period labels matching `value_columns`; defaults to 0, 1, 2, ...

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame sorted by unit then period.

    Raises
    ------
    ValueError
        If `value_columns` is empty, a named column is missing, or
        `time_values` and `value_columns` differ in length.
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")

    if id_column not in data.columns:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")

    for col in value_columns:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    periods = list(range(len(value_columns))) if time_values is None else time_values

    if len(periods) != len(value_columns):
        raise ValueError(
            f"time_values length ({len(periods)}) must match "
            f"value_columns length ({len(value_columns)})."
        )

    # Columns carried through unchanged (neither id nor outcome columns).
    carried = [c for c in data.columns if c != id_column and c not in value_columns]

    # Vectorized reshape via melt, then translate column names into periods.
    melted = data.melt(
        id_vars=[id_column, *carried],
        value_vars=value_columns,
        var_name="_temp_var",
        value_name=value_name,
    )
    melted[time_name] = melted["_temp_var"].map(dict(zip(value_columns, periods)))
    melted = melted.drop(columns="_temp_var")

    ordered = [id_column, time_name, value_name, *carried]
    return (
        melted[ordered]
        .sort_values([id_column, time_name])
        .reset_index(drop=True)
    )
288
+
289
+
290
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None
) -> pd.DataFrame:
    """
    Balance a panel so every unit is observed in every time period.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data.
    unit_column : str
        Unit identifier column.
    time_column : str
        Time period column.
    method : str, default="inner"
        - "inner": keep only units observed in every period (drops units)
        - "outer": expand to all unit-period pairs (introduces NaN)
        - "fill": expand to all pairs and fill the gaps
    fill_value : float, optional
        With method="fill", the constant used for missing numeric values.
        When None, gaps are forward-filled (then backward-filled) per unit.

    Returns
    -------
    pd.DataFrame
        Balanced panel.

    Raises
    ------
    ValueError
        If a key column is missing or `method` is not recognized.
    """
    for col in (unit_column, time_column):
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if method not in ("inner", "outer", "fill"):
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    units = data[unit_column].unique()
    periods = sorted(data[time_column].unique())

    if method == "inner":
        # Retain only units that appear in every period.
        counts = data.groupby(unit_column)[time_column].nunique()
        complete = counts.index[counts == len(periods)]
        return data[data[unit_column].isin(complete)].copy()

    # "outer" and "fill" both start from the full unit x period grid.
    grid = pd.MultiIndex.from_product(
        [units, periods], names=[unit_column, time_column]
    ).to_frame(index=False)
    result = grid.merge(data, on=[unit_column, time_column], how="left")

    if method == "fill":
        value_cols = [c for c in result.columns if c not in (unit_column, time_column)]

        if fill_value is not None:
            # Constant fill applies to numeric columns only.
            for col in result.select_dtypes(include=[np.number]).columns:
                if col in value_cols:
                    result[col] = result[col].fillna(fill_value)
        else:
            # Forward-fill within each unit, then backward-fill any leading gaps.
            result = result.sort_values([unit_column, time_column])
            result[value_cols] = result.groupby(unit_column)[value_cols].ffill()
            result[value_cols] = result.groupby(unit_column)[value_cols].bfill()

    return result
390
+
391
+
392
def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True
) -> Dict[str, Any]:
    """
    Validate that data is properly formatted for DiD analysis.

    Runs a sequence of checks (column presence, dtypes, missing values,
    group/period coverage, and panel structure) and collects informative
    messages for anything that looks wrong.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Outcome variable column.
    treatment : str
        Binary treatment indicator column.
    time : str
        Time/post indicator column.
    unit : str, optional
        Unit identifier column; enables panel-specific checks.
    raise_on_error : bool, default=True
        Raise ValueError on failure instead of returning quietly.

    Returns
    -------
    dict
        Keys: ``valid`` (bool), ``errors`` (list of str),
        ``warnings`` (list of str), ``summary`` (dict of statistics).

    Raises
    ------
    ValueError
        When validation fails and `raise_on_error` is True.
    """
    problems: List[str] = []
    cautions: List[str] = []

    # --- Column presence (fatal: later checks need these columns) ---
    needed = [outcome, treatment, time]
    if unit is not None:
        needed.append(unit)

    problems.extend(
        f"Required column '{col}' not found in DataFrame."
        for col in needed
        if col not in data.columns
    )
    if problems:
        if raise_on_error:
            raise ValueError("\n".join(problems))
        return {"valid": False, "errors": problems, "warnings": cautions, "summary": {}}

    # --- Dtype checks ---
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        problems.append(
            f"Outcome column '{outcome}' must be numeric. "
            f"Got type: {data[outcome].dtype}"
        )

    treat_levels = data[treatment].dropna().unique()
    if not set(treat_levels).issubset({0, 1}):
        problems.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {sorted(treat_levels)}"
        )

    period_levels = data[time].dropna().unique()
    if len(period_levels) == 2 and not set(period_levels).issubset({0, 1}):
        cautions.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: {sorted(period_levels)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    # --- Missing values in any required column ---
    for col in needed:
        gaps = data[col].isna().sum()
        if gaps > 0:
            problems.append(
                f"Column '{col}' has {gaps} missing values. "
                "Please handle missing data before fitting."
            )

    # --- Summary statistics + variation checks (only on clean data) ---
    summary: Dict[str, Any] = {}
    if not problems:
        summary["n_obs"] = len(data)
        summary["n_treated"] = int((data[treatment] == 1).sum())
        summary["n_control"] = int((data[treatment] == 0).sum())
        summary["n_periods"] = len(period_levels)
        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        if summary["n_treated"] == 0:
            problems.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            problems.append("No control observations found (treatment column is all 1).")

        if len(period_levels) == 2:
            # 2x2 design: every treatment-by-time cell needs observations.
            for t_val in (0, 1):
                for p_val in period_levels:
                    cell = (data[treatment] == t_val) & (data[time] == p_val)
                    if not cell.any():
                        problems.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # Multi-period: each group must span at least two periods.
            for t_val in (0, 1):
                covered = data[data[treatment] == t_val][time].nunique()
                if covered < 2:
                    label = "Treated" if t_val == 1 else "Control"
                    problems.append(
                        f"{label} group has observations in only {covered} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # --- Panel-specific checks (warnings only) ---
    if unit is not None and not problems:
        switching = data.groupby(unit)[treatment].nunique()
        switching = switching[switching > 1]
        if len(switching) > 0:
            cautions.append(
                f"Treatment varies within {len(switching)} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        spans = data.groupby(unit)[time].nunique()
        if spans.min() != spans.max():
            cautions.append(
                f"Unbalanced panel detected. Units have between "
                f"{spans.min()} and {spans.max()} periods. "
                "Consider using balance_panel() to balance the data."
            )

    ok = len(problems) == 0
    if raise_on_error and not ok:
        raise ValueError("Data validation failed:\n" + "\n".join(problems))

    return {
        "valid": ok,
        "errors": problems,
        "warnings": cautions,
        "summary": summary
    }
561
+
562
+
563
def summarize_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Summarize the outcome by treatment group and time period.

    For two-period data the rows are labeled "Treated/Control - Pre/Post"
    and a "DiD Estimate" row with the raw difference-in-differences is
    appended; otherwise rows are labeled by group and period.

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Outcome variable column.
    treatment : str
        Binary treatment indicator column.
    time : str
        Time/period column.
    unit : str, optional
        Unit identifier column (accepted for API symmetry; not used here).

    Returns
    -------
    pd.DataFrame
        Summary table with columns n, mean, std, min, max (rounded to 4).
    """
    stats = (
        data.groupby([treatment, time])[outcome]
        .agg(n="count", mean="mean", std="std", min="min", max="max")
        .round(4)
    )

    periods = sorted(data[time].unique())

    if len(periods) != 2:
        # Multi-period: label rows by group and literal period value.
        stats.index = stats.index.map(
            lambda key: f"{'Treated' if key[0] == 1 else 'Control'} - Period {key[1]}"
        )
        return stats

    first, last = periods

    def _label(key: tuple) -> str:
        # Pre/Post is derived from the sorted period values, not literal 0/1.
        group = 'Treated' if key[0] == 1 else 'Control'
        phase = 'Post' if key[1] == last else 'Pre'
        return f"{group} - {phase}"

    stats.index = stats.index.map(_label)

    def _cell_mean(t_val: int, period: Any) -> float:
        # Mean outcome for one treatment-by-time cell.
        sel = (data[treatment] == t_val) & (data[time] == period)
        return data[sel][outcome].mean()

    # Raw DiD: (treated post - pre) minus (control post - pre).
    did = (_cell_mean(1, last) - _cell_mean(1, first)) - (
        _cell_mean(0, last) - _cell_mean(0, first)
    )

    did_row = pd.DataFrame(
        {
            "n": ["-"],
            "mean": [did],
            "std": ["-"],
            "min": ["-"],
            "max": ["-"]
        },
        index=["DiD Estimate"]
    )
    return pd.concat([stats, did_row])
653
+
654
+
655
def create_event_time(
    data: pd.DataFrame,
    time_column: str,
    treatment_time_column: str,
    new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Add an event-time column measured relative to each unit's treatment date.

    Useful for event-study designs with staggered treatment timing.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    time_column : str
        Calendar time column.
    treatment_time_column : str
        Column holding each unit's treatment period. NaN (and, for numeric
        columns, +/-inf) marks never-treated units.
    new_column : str, default="event_time"
        Name of the generated column.

    Returns
    -------
    pd.DataFrame
        Copy of the input with the event-time column: negative before
        treatment, 0 at treatment, positive after, NaN for never-treated.

    Raises
    ------
    ValueError
        If either named column is missing.
    """
    out = data.copy()

    for col in (time_column, treatment_time_column):
        if col not in out.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    # Event time = calendar time minus the unit's adoption time.
    out[new_column] = out[time_column] - out[treatment_time_column]

    # Never-treated units get NaN; inf only makes sense for numeric columns.
    adoption = out[treatment_time_column]
    if pd.api.types.is_numeric_dtype(adoption):
        never = adoption.isna() | np.isinf(adoption)
    else:
        never = adoption.isna()
    out.loc[never, new_column] = np.nan

    return out
718
+
719
+
720
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Collapse unit-level panel data to treatment-cohort means per period.

    Handy for plotting group trajectories or cohort-level analysis.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Unit identifier column (counted per cell as ``n_units``).
    time_column : str
        Time period column.
    treatment_column : str
        Treatment indicator column.
    outcome : str
        Outcome variable column (averaged as ``mean_<outcome>``).
    covariates : list of str, optional
        Extra columns to average within each cell.

    Returns
    -------
    pd.DataFrame
        One row per treatment-by-period cell with mean outcome,
        unit counts, and any covariate means.
    """
    # Aggregation plan: mean outcome, distinct-unit count, covariate means.
    spec: Dict[str, str] = {outcome: "mean", unit_column: "nunique"}
    for cov in covariates or []:
        spec[cov] = "mean"

    return (
        data.groupby([treatment_column, time_column])
        .agg(spec)
        .reset_index()
        .rename(columns={unit_column: "n_units", outcome: f"mean_{outcome}"})
    )
780
+
781
+
782
+ def rank_control_units(
783
+ data: pd.DataFrame,
784
+ unit_column: str,
785
+ time_column: str,
786
+ outcome_column: str,
787
+ treatment_column: Optional[str] = None,
788
+ treated_units: Optional[List[Any]] = None,
789
+ pre_periods: Optional[List[Any]] = None,
790
+ covariates: Optional[List[str]] = None,
791
+ outcome_weight: float = 0.7,
792
+ covariate_weight: float = 0.3,
793
+ exclude_units: Optional[List[Any]] = None,
794
+ require_units: Optional[List[Any]] = None,
795
+ n_top: Optional[int] = None,
796
+ suggest_treatment_candidates: bool = False,
797
+ n_treatment_candidates: int = 5,
798
+ lambda_reg: float = 0.0,
799
+ ) -> pd.DataFrame:
800
+ """
801
+ Rank potential control units by their suitability for DiD analysis.
802
+
803
+ Evaluates control units based on pre-treatment outcome trend similarity
804
+ and optional covariate matching to treated units. Returns a ranked list
805
+ with quality scores.
806
+
807
+ Parameters
808
+ ----------
809
+ data : pd.DataFrame
810
+ Panel data in long format.
811
+ unit_column : str
812
+ Column name for unit identifier.
813
+ time_column : str
814
+ Column name for time periods.
815
+ outcome_column : str
816
+ Column name for outcome variable.
817
+ treatment_column : str, optional
818
+ Column with binary treatment indicator (0/1). Used to identify
819
+ treated units from data.
820
+ treated_units : list, optional
821
+ Explicit list of treated unit IDs. Alternative to treatment_column.
822
+ pre_periods : list, optional
823
+ Pre-treatment periods for comparison. If None, uses first half of periods.
824
+ covariates : list of str, optional
825
+ Covariate columns for matching. Similarity is based on pre-treatment means.
826
+ outcome_weight : float, default=0.7
827
+ Weight for pre-treatment outcome trend similarity (0-1).
828
+ covariate_weight : float, default=0.3
829
+ Weight for covariate distance (0-1). Ignored if no covariates.
830
+ exclude_units : list, optional
831
+ Units that cannot be in control group.
832
+ require_units : list, optional
833
+ Units that must be in control group (will always appear in output).
834
+ n_top : int, optional
835
+ Return only top N control units. If None, return all.
836
+ suggest_treatment_candidates : bool, default=False
837
+ If True and no treated units specified, identify potential treatment
838
+ candidates instead of ranking controls.
839
+ n_treatment_candidates : int, default=5
840
+ Number of treatment candidates to suggest.
841
+ lambda_reg : float, default=0.0
842
+ Regularization for synthetic weights. Higher values give more uniform
843
+ weights across controls.
844
+
845
+ Returns
846
+ -------
847
+ pd.DataFrame
848
+ Ranked control units with columns:
849
+ - unit: Unit identifier
850
+ - quality_score: Combined quality score (0-1, higher is better)
851
+ - outcome_trend_score: Pre-treatment outcome trend similarity
852
+ - covariate_score: Covariate match score (NaN if no covariates)
853
+ - synthetic_weight: Weight from synthetic control optimization
854
+ - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
855
+ - is_required: Whether unit was in require_units
856
+
857
+ If suggest_treatment_candidates=True (and no treated units):
858
+ - unit: Unit identifier
859
+ - treatment_candidate_score: Suitability as treatment unit
860
+ - avg_outcome_level: Pre-treatment outcome mean
861
+ - outcome_trend: Pre-treatment trend slope
862
+ - n_similar_controls: Count of similar potential controls
863
+
864
+ Examples
865
+ --------
866
+ Rank controls against treated units:
867
+
868
+ >>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
869
+ >>> ranking = rank_control_units(
870
+ ... data,
871
+ ... unit_column='unit',
872
+ ... time_column='period',
873
+ ... outcome_column='outcome',
874
+ ... treatment_column='treated',
875
+ ... n_top=10
876
+ ... )
877
+ >>> ranking['quality_score'].is_monotonic_decreasing
878
+ True
879
+
880
+ With covariates:
881
+
882
+ >>> data['size'] = np.random.randn(len(data))
883
+ >>> ranking = rank_control_units(
884
+ ... data,
885
+ ... unit_column='unit',
886
+ ... time_column='period',
887
+ ... outcome_column='outcome',
888
+ ... treatment_column='treated',
889
+ ... covariates=['size']
890
+ ... )
891
+
892
+ Filter data for SyntheticDiD:
893
+
894
+ >>> top_controls = ranking['unit'].tolist()
895
+ >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
896
+ """
897
+ # -------------------------------------------------------------------------
898
+ # Input validation
899
+ # -------------------------------------------------------------------------
900
+ for col in [unit_column, time_column, outcome_column]:
901
+ if col not in data.columns:
902
+ raise ValueError(f"Column '{col}' not found in DataFrame.")
903
+
904
+ if treatment_column is not None and treatment_column not in data.columns:
905
+ raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")
906
+
907
+ if covariates:
908
+ for cov in covariates:
909
+ if cov not in data.columns:
910
+ raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")
911
+
912
+ if not 0 <= outcome_weight <= 1:
913
+ raise ValueError("outcome_weight must be between 0 and 1")
914
+ if not 0 <= covariate_weight <= 1:
915
+ raise ValueError("covariate_weight must be between 0 and 1")
916
+
917
+ if treated_units is not None and treatment_column is not None:
918
+ raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")
919
+
920
+ if require_units and exclude_units:
921
+ invalid_required = [u for u in require_units if u in exclude_units]
922
+ if invalid_required:
923
+ raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")
924
+
925
+ # -------------------------------------------------------------------------
926
+ # Determine pre-treatment periods
927
+ # -------------------------------------------------------------------------
928
+ all_periods = sorted(data[time_column].unique())
929
+ if pre_periods is None:
930
+ mid_point = len(all_periods) // 2
931
+ pre_periods = all_periods[:mid_point]
932
+ else:
933
+ pre_periods = list(pre_periods)
934
+
935
+ if len(pre_periods) == 0:
936
+ raise ValueError("No pre-treatment periods specified or inferred.")
937
+
938
+ # -------------------------------------------------------------------------
939
+ # Identify treated and control units
940
+ # -------------------------------------------------------------------------
941
+ all_units = list(data[unit_column].unique())
942
+
943
+ if treated_units is not None:
944
+ treated_set = set(treated_units)
945
+ elif treatment_column is not None:
946
+ unit_treatment = data.groupby(unit_column)[treatment_column].first()
947
+ treated_set = set(unit_treatment[unit_treatment == 1].index)
948
+ elif suggest_treatment_candidates:
949
+ # Treatment candidate discovery mode - no treated units
950
+ treated_set = set()
951
+ else:
952
+ raise ValueError(
953
+ "Must specify treated_units, treatment_column, or set "
954
+ "suggest_treatment_candidates=True"
955
+ )
956
+
957
+ # -------------------------------------------------------------------------
958
+ # Treatment candidate discovery mode
959
+ # -------------------------------------------------------------------------
960
+ if suggest_treatment_candidates and len(treated_set) == 0:
961
+ return _suggest_treatment_candidates(
962
+ data, unit_column, time_column, outcome_column,
963
+ pre_periods, n_treatment_candidates
964
+ )
965
+
966
+ if len(treated_set) == 0:
967
+ raise ValueError("No treated units found.")
968
+
969
+ # Determine control candidates
970
+ control_candidates = [u for u in all_units if u not in treated_set]
971
+
972
+ if exclude_units:
973
+ control_candidates = [u for u in control_candidates if u not in exclude_units]
974
+
975
+ if len(control_candidates) == 0:
976
+ raise ValueError("No control units available after exclusions.")
977
+
978
+ # -------------------------------------------------------------------------
979
+ # Create outcome matrices (pre-treatment)
980
+ # -------------------------------------------------------------------------
981
+ pre_data = data[data[time_column].isin(pre_periods)]
982
+ pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)
983
+
984
+ # Filter to pre_periods that exist in data
985
+ valid_pre_periods = [p for p in pre_periods if p in pivot.index]
986
+ if len(valid_pre_periods) == 0:
987
+ raise ValueError("No data found for specified pre-treatment periods.")
988
+
989
+ # Filter control_candidates to those present in pivot (handles unbalanced panels)
990
+ control_candidates = [c for c in control_candidates if c in pivot.columns]
991
+ if len(control_candidates) == 0:
992
+ raise ValueError("No control units found in pre-treatment data.")
993
+
994
+ # Control outcomes: shape (n_pre_periods, n_control_candidates)
995
+ Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)
996
+
997
+ # Treated outcomes mean: shape (n_pre_periods,)
998
+ treated_list = [u for u in treated_set if u in pivot.columns]
999
+ if len(treated_list) == 0:
1000
+ raise ValueError("Treated units not found in pre-treatment data.")
1001
+ Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)
1002
+
1003
+ # -------------------------------------------------------------------------
1004
+ # Compute outcome trend scores
1005
+ # -------------------------------------------------------------------------
1006
+ # Synthetic weights (higher = better match)
1007
+ synthetic_weights = compute_synthetic_weights(
1008
+ Y_control, Y_treated_mean, lambda_reg=lambda_reg
1009
+ )
1010
+
1011
+ # RMSE for each control vs treated mean (use nanmean to handle missing data)
1012
+ rmse_scores = []
1013
+ for j in range(len(control_candidates)):
1014
+ y_c = Y_control[:, j]
1015
+ rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
1016
+ rmse_scores.append(rmse)
1017
+
1018
+ # Convert RMSE to similarity score (lower RMSE = higher score)
1019
+ max_rmse = max(rmse_scores) if rmse_scores else 1.0
1020
+ min_rmse = min(rmse_scores) if rmse_scores else 0.0
1021
+ rmse_range = max_rmse - min_rmse
1022
+
1023
+ if rmse_range < 1e-10:
1024
+ # All controls have identical/similar pre-trends (includes single control case)
1025
+ outcome_trend_scores = [1.0] * len(rmse_scores)
1026
+ else:
1027
+ # Normalize so best control gets 1.0, worst gets 0.0
1028
+ outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]
1029
+
1030
+ # -------------------------------------------------------------------------
1031
+ # Compute covariate scores (if covariates provided)
1032
+ # -------------------------------------------------------------------------
1033
+ if covariates and len(covariates) > 0:
1034
+ # Get unit-level covariate values (pre-treatment mean)
1035
+ cov_data = pre_data.groupby(unit_column)[covariates].mean()
1036
+
1037
+ # Treated covariate profile (mean across treated units)
1038
+ treated_cov = cov_data.loc[list(treated_set)].mean()
1039
+
1040
+ # Standardize covariates
1041
+ cov_mean = cov_data.mean()
1042
+ cov_std = cov_data.std().replace(0, 1) # Avoid division by zero
1043
+ cov_standardized = (cov_data - cov_mean) / cov_std
1044
+ treated_cov_std = (treated_cov - cov_mean) / cov_std
1045
+
1046
+ # Euclidean distance in standardized space (vectorized)
1047
+ control_cov_matrix = cov_standardized.loc[control_candidates].values
1048
+ treated_cov_vector = treated_cov_std.values
1049
+ covariate_distances = np.sqrt(
1050
+ np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
1051
+ )
1052
+
1053
+ # Convert distance to similarity score (min-max normalization)
1054
+ max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
1055
+ min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
1056
+ dist_range = max_dist - min_dist
1057
+
1058
+ if dist_range < 1e-10:
1059
+ # All controls have identical/similar covariate profiles
1060
+ covariate_scores = [1.0] * len(covariate_distances)
1061
+ else:
1062
+ # Normalize so best control (closest) gets 1.0, worst gets 0.0
1063
+ covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
1064
+ else:
1065
+ covariate_scores = [np.nan] * len(control_candidates)
1066
+
1067
+ # -------------------------------------------------------------------------
1068
+ # Compute combined quality score
1069
+ # -------------------------------------------------------------------------
1070
+ # Normalize weights
1071
+ total_weight = outcome_weight + covariate_weight
1072
+ if total_weight > 0:
1073
+ norm_outcome_weight = outcome_weight / total_weight
1074
+ norm_covariate_weight = covariate_weight / total_weight
1075
+ else:
1076
+ norm_outcome_weight = 1.0
1077
+ norm_covariate_weight = 0.0
1078
+
1079
+ quality_scores = []
1080
+ for i in range(len(control_candidates)):
1081
+ outcome_score = outcome_trend_scores[i]
1082
+ cov_score = covariate_scores[i]
1083
+
1084
+ if np.isnan(cov_score):
1085
+ # No covariates - use only outcome score
1086
+ combined = outcome_score
1087
+ else:
1088
+ combined = norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score
1089
+
1090
+ quality_scores.append(combined)
1091
+
1092
+ # -------------------------------------------------------------------------
1093
+ # Build result DataFrame
1094
+ # -------------------------------------------------------------------------
1095
+ require_set = set(require_units) if require_units else set()
1096
+
1097
+ result = pd.DataFrame({
1098
+ 'unit': control_candidates,
1099
+ 'quality_score': quality_scores,
1100
+ 'outcome_trend_score': outcome_trend_scores,
1101
+ 'covariate_score': covariate_scores,
1102
+ 'synthetic_weight': synthetic_weights,
1103
+ 'pre_trend_rmse': rmse_scores,
1104
+ 'is_required': [u in require_set for u in control_candidates]
1105
+ })
1106
+
1107
+ # Sort by quality score (descending)
1108
+ result = result.sort_values('quality_score', ascending=False)
1109
+
1110
+ # Apply n_top limit if specified
1111
+ if n_top is not None and n_top < len(result):
1112
+ # Always include required units
1113
+ required_df = result[result['is_required']]
1114
+ non_required_df = result[~result['is_required']]
1115
+
1116
+ # Take top from non-required to fill remaining slots
1117
+ remaining_slots = max(0, n_top - len(required_df))
1118
+ top_non_required = non_required_df.head(remaining_slots)
1119
+
1120
+ result = pd.concat([required_df, top_non_required])
1121
+ result = result.sort_values('quality_score', ascending=False)
1122
+
1123
+ return result.reset_index(drop=True)
1124
+
1125
+
1126
+ def _suggest_treatment_candidates(
1127
+ data: pd.DataFrame,
1128
+ unit_column: str,
1129
+ time_column: str,
1130
+ outcome_column: str,
1131
+ pre_periods: List[Any],
1132
+ n_candidates: int
1133
+ ) -> pd.DataFrame:
1134
+ """
1135
+ Identify units that would make good treatment candidates.
1136
+
1137
+ A good treatment candidate:
1138
+ 1. Has many similar control units available (for matching)
1139
+ 2. Has stable pre-treatment trends (predictable counterfactual)
1140
+ 3. Is not an extreme outlier
1141
+
1142
+ Parameters
1143
+ ----------
1144
+ data : pd.DataFrame
1145
+ Panel data.
1146
+ unit_column : str
1147
+ Unit identifier column.
1148
+ time_column : str
1149
+ Time period column.
1150
+ outcome_column : str
1151
+ Outcome variable column.
1152
+ pre_periods : list
1153
+ Pre-treatment periods.
1154
+ n_candidates : int
1155
+ Number of candidates to return.
1156
+
1157
+ Returns
1158
+ -------
1159
+ pd.DataFrame
1160
+ Treatment candidates with scores.
1161
+ """
1162
+ all_units = list(data[unit_column].unique())
1163
+ pre_data = data[data[time_column].isin(pre_periods)]
1164
+
1165
+ candidate_info = []
1166
+
1167
+ for unit in all_units:
1168
+ unit_data = pre_data[pre_data[unit_column] == unit]
1169
+
1170
+ if len(unit_data) == 0:
1171
+ continue
1172
+
1173
+ # Average outcome level
1174
+ avg_outcome = unit_data[outcome_column].mean()
1175
+
1176
+ # Trend (simple linear regression slope)
1177
+ times = unit_data[time_column].values
1178
+ outcomes = unit_data[outcome_column].values
1179
+ if len(times) > 1:
1180
+ times_norm = np.arange(len(times))
1181
+ try:
1182
+ slope = np.polyfit(times_norm, outcomes, 1)[0]
1183
+ except (np.linalg.LinAlgError, ValueError):
1184
+ slope = 0.0
1185
+ else:
1186
+ slope = 0.0
1187
+
1188
+ # Count similar potential controls
1189
+ other_units = [u for u in all_units if u != unit]
1190
+ other_means = pre_data[
1191
+ pre_data[unit_column].isin(other_units)
1192
+ ].groupby(unit_column)[outcome_column].mean()
1193
+
1194
+ if len(other_means) > 0:
1195
+ sd = other_means.std()
1196
+ if sd > 0:
1197
+ n_similar = int(np.sum(
1198
+ np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd
1199
+ ))
1200
+ else:
1201
+ n_similar = len(other_means)
1202
+ else:
1203
+ n_similar = 0
1204
+
1205
+ candidate_info.append({
1206
+ 'unit': unit,
1207
+ 'avg_outcome_level': avg_outcome,
1208
+ 'outcome_trend': slope,
1209
+ 'n_similar_controls': n_similar
1210
+ })
1211
+
1212
+ if len(candidate_info) == 0:
1213
+ return pd.DataFrame(columns=[
1214
+ 'unit', 'treatment_candidate_score', 'avg_outcome_level',
1215
+ 'outcome_trend', 'n_similar_controls'
1216
+ ])
1217
+
1218
+ result = pd.DataFrame(candidate_info)
1219
+
1220
+ # Score: prefer units with many similar controls and moderate outcome levels
1221
+ max_similar = result['n_similar_controls'].max()
1222
+ if max_similar > 0:
1223
+ similarity_score = result['n_similar_controls'] / max_similar
1224
+ else:
1225
+ similarity_score = pd.Series([0.0] * len(result))
1226
+
1227
+ # Penalty for outliers in outcome level
1228
+ outcome_mean = result['avg_outcome_level'].mean()
1229
+ outcome_std = result['avg_outcome_level'].std()
1230
+ if outcome_std > 0:
1231
+ outcome_z = np.abs((result['avg_outcome_level'] - outcome_mean) / outcome_std)
1232
+ else:
1233
+ outcome_z = pd.Series([0.0] * len(result))
1234
+
1235
+ result['treatment_candidate_score'] = (
1236
+ similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
1237
+ ).clip(0, 1)
1238
+
1239
+ # Return top candidates
1240
+ result = result.nlargest(n_candidates, 'treatment_candidate_score')
1241
+ return result.reset_index(drop=True)