diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/trop.py ADDED
@@ -0,0 +1,1348 @@
+ """
+ Triply Robust Panel (TROP) estimator.
+
+ Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
+ TROP combines three robustness components:
+
+ 1. Nuclear norm regularized factor model (interactive fixed effects)
+ 2. Exponential distance-based unit weights
+ 3. Exponential time decay weights
+
+ The estimator uses leave-one-out cross-validation for tuning parameter
+ selection and provides robust treatment effect estimates under factor
+ confounding.
+
+ References
+ ----------
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
+ Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
+ """
+
+ import warnings
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from scipy import stats
+
+ from diff_diff.results import _get_significance_stars
+ from diff_diff.utils import compute_confidence_interval
+
+
+ @dataclass
+ class TROPResults:
+     """
+     Results from a Triply Robust Panel (TROP) estimation.
+
+     TROP combines nuclear norm regularized factor estimation with
+     exponential distance-based unit weights and time decay weights.
+
+     Attributes
+     ----------
+     att : float
+         Average treatment effect on the treated (ATT).
+     se : float
+         Standard error of the ATT estimate.
+     t_stat : float
+         T-statistic for the ATT estimate.
+     p_value : float
+         P-value for the null hypothesis that ATT = 0.
+     conf_int : tuple[float, float]
+         Confidence interval for the ATT.
+     n_obs : int
+         Number of observations used in estimation.
+     n_treated : int
+         Number of treated units.
+     n_control : int
+         Number of control units.
+     n_treated_obs : int
+         Number of treated unit-time observations.
+     unit_effects : dict
+         Estimated unit fixed effects (alpha_i).
+     time_effects : dict
+         Estimated time fixed effects (beta_t).
+     treatment_effects : dict
+         Individual treatment effects for each treated (unit, time) pair.
+     lambda_time : float
+         Selected time weight decay parameter.
+     lambda_unit : float
+         Selected unit weight decay parameter.
+     lambda_nn : float
+         Selected nuclear norm regularization parameter.
+     factor_matrix : np.ndarray
+         Estimated low-rank factor matrix L (n_periods x n_units).
+     effective_rank : float
+         Effective rank of the factor matrix (sum of singular values
+         divided by the largest singular value).
+     loocv_score : float
+         Leave-one-out cross-validation score for selected parameters.
+     variance_method : str
+         Method used for variance estimation.
+     alpha : float
+         Significance level for confidence interval.
+     pre_periods : list
+         List of pre-treatment period identifiers.
+     post_periods : list
+         List of post-treatment period identifiers.
+     n_bootstrap : int, optional
+         Number of bootstrap replications (if bootstrap variance).
+     bootstrap_distribution : np.ndarray, optional
+         Bootstrap distribution of estimates.
+     """
+
+     att: float
+     se: float
+     t_stat: float
+     p_value: float
+     conf_int: Tuple[float, float]
+     n_obs: int
+     n_treated: int
+     n_control: int
+     n_treated_obs: int
+     unit_effects: Dict[Any, float]
+     time_effects: Dict[Any, float]
+     treatment_effects: Dict[Tuple[Any, Any], float]
+     lambda_time: float
+     lambda_unit: float
+     lambda_nn: float
+     factor_matrix: np.ndarray
+     effective_rank: float
+     loocv_score: float
+     variance_method: str
+     alpha: float = 0.05
+     pre_periods: List[Any] = field(default_factory=list)
+     post_periods: List[Any] = field(default_factory=list)
+     n_bootstrap: Optional[int] = field(default=None)
+     bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
+
+     def __repr__(self) -> str:
+         """Concise string representation."""
+         sig = _get_significance_stars(self.p_value)
+         return (
+             f"TROPResults(ATT={self.att:.4f}{sig}, "
+             f"SE={self.se:.4f}, "
+             f"eff_rank={self.effective_rank:.1f}, "
+             f"p={self.p_value:.4f})"
+         )
+
+     def summary(self, alpha: Optional[float] = None) -> str:
+         """
+         Generate a formatted summary of the estimation results.
+
+         Parameters
+         ----------
+         alpha : float, optional
+             Significance level for confidence intervals. Defaults to the
+             alpha used during estimation.
+
+         Returns
+         -------
+         str
+             Formatted summary table.
+         """
+         alpha = alpha if alpha is not None else self.alpha
+         conf_level = int((1 - alpha) * 100)
+         # Recompute the interval when a different alpha is requested than
+         # the one used at estimation time, so the label stays accurate.
+         if alpha == self.alpha:
+             conf_int = self.conf_int
+         else:
+             conf_int = compute_confidence_interval(self.att, self.se, alpha)
+
+         lines = [
+             "=" * 75,
+             "Triply Robust Panel (TROP) Estimation Results".center(75),
+             "Athey, Imbens, Qu & Viviano (2025)".center(75),
+             "=" * 75,
+             "",
+             f"{'Observations:':<25} {self.n_obs:>10}",
+             f"{'Treated units:':<25} {self.n_treated:>10}",
+             f"{'Control units:':<25} {self.n_control:>10}",
+             f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
+             f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
+             f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
+             "",
+             "-" * 75,
+             "Tuning Parameters (selected via LOOCV)".center(75),
+             "-" * 75,
+             f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}",
+             f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}",
+             f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}",
+             f"{'Effective rank:':<25} {self.effective_rank:>10.2f}",
+             f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}",
+         ]
+
+         # Variance method info
+         lines.append(f"{'Variance method:':<25} {self.variance_method:>10}")
+         if self.variance_method == "bootstrap" and self.n_bootstrap is not None:
+             lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")
+
+         lines.extend([
+             "",
+             "-" * 75,
+             f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
+             f"{'t-stat':>10} {'P>|t|':>10} {'':>5}",
+             "-" * 75,
+             f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} "
+             f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
+             "-" * 75,
+             "",
+             f"{conf_level}% Confidence Interval: [{conf_int[0]:.4f}, {conf_int[1]:.4f}]",
+         ])
+
+         # Add significance codes
+         lines.extend([
+             "",
+             "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
+             "=" * 75,
+         ])
+
+         return "\n".join(lines)
+
+     def print_summary(self, alpha: Optional[float] = None) -> None:
+         """Print the summary to stdout."""
+         print(self.summary(alpha))
+
+     def to_dict(self) -> Dict[str, Any]:
+         """
+         Convert results to a dictionary.
+
+         Returns
+         -------
+         Dict[str, Any]
+             Dictionary containing all estimation results.
+         """
+         return {
+             "att": self.att,
+             "se": self.se,
+             "t_stat": self.t_stat,
+             "p_value": self.p_value,
+             "conf_int_lower": self.conf_int[0],
+             "conf_int_upper": self.conf_int[1],
+             "n_obs": self.n_obs,
+             "n_treated": self.n_treated,
+             "n_control": self.n_control,
+             "n_treated_obs": self.n_treated_obs,
+             "n_pre_periods": len(self.pre_periods),
+             "n_post_periods": len(self.post_periods),
+             "lambda_time": self.lambda_time,
+             "lambda_unit": self.lambda_unit,
+             "lambda_nn": self.lambda_nn,
+             "effective_rank": self.effective_rank,
+             "loocv_score": self.loocv_score,
+             "variance_method": self.variance_method,
+         }
+
+     def to_dataframe(self) -> pd.DataFrame:
+         """
+         Convert results to a pandas DataFrame.
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with estimation results.
+         """
+         return pd.DataFrame([self.to_dict()])
+
+     def get_treatment_effects_df(self) -> pd.DataFrame:
+         """
+         Get individual treatment effects as a DataFrame.
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with unit, time, and treatment effect columns.
+         """
+         return pd.DataFrame([
+             {"unit": unit, "time": time, "effect": effect}
+             for (unit, time), effect in self.treatment_effects.items()
+         ])
+
+     def get_unit_effects_df(self) -> pd.DataFrame:
+         """
+         Get unit fixed effects as a DataFrame.
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with unit and effect columns.
+         """
+         return pd.DataFrame([
+             {"unit": unit, "effect": effect}
+             for unit, effect in self.unit_effects.items()
+         ])
+
+     def get_time_effects_df(self) -> pd.DataFrame:
+         """
+         Get time fixed effects as a DataFrame.
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with time and effect columns.
+         """
+         return pd.DataFrame([
+             {"time": time, "effect": effect}
+             for time, effect in self.time_effects.items()
+         ])
+
+     @property
+     def is_significant(self) -> bool:
+         """Check if the ATT is statistically significant at the alpha level."""
+         return bool(self.p_value < self.alpha)
+
+     @property
+     def significance_stars(self) -> str:
+         """Return significance stars based on p-value."""
+         return _get_significance_stars(self.p_value)
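
Since TROPResults is a plain dataclass with DataFrame accessors, downstream code can consume it directly. A minimal sketch of that pattern (editor's illustration: `report` is a hypothetical helper, not part of the package):

```python
import pandas as pd

def report(results) -> pd.DataFrame:
    """Print headline stats from a fitted TROPResults and return the
    per-period average effects (illustrative helper, not part of the package)."""
    effects = results.get_treatment_effects_df()
    by_period = effects.groupby("time")["effect"].mean().rename("avg_effect")
    if results.is_significant:
        print(f"ATT = {results.att:.3f}{results.significance_stars}, "
              f"CI = [{results.conf_int[0]:.3f}, {results.conf_int[1]:.3f}]")
    return by_period.reset_index()
```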
+
+
+ class TROP:
+     """
+     Triply Robust Panel (TROP) estimator.
+
+     Implements the methodology of Athey, Imbens, Qu & Viviano (2025).
+     TROP combines three robustness components:
+
+     1. **Nuclear norm regularized factor model**: Estimates interactive fixed
+        effects L_it via matrix completion with nuclear norm penalty ||L||_*
+
+     2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
+        where d(j,i) is the RMSE of outcome differences between units
+
+     3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
+        weighting pre-treatment periods by proximity to treatment
+
+     Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
+     cross-validation on control observations.
+
+     Parameters
+     ----------
+     lambda_time_grid : list, optional
+         Grid of time weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
+     lambda_unit_grid : list, optional
+         Grid of unit weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
+     lambda_nn_grid : list, optional
+         Grid of nuclear norm regularization parameters.
+         Default: [0, 0.01, 0.1, 1, 10].
+     max_iter : int, default=100
+         Maximum iterations for nuclear norm optimization.
+     tol : float, default=1e-6
+         Convergence tolerance for optimization.
+     alpha : float, default=0.05
+         Significance level for confidence intervals.
+     variance_method : str, default='bootstrap'
+         Method for variance estimation: 'bootstrap' or 'jackknife'.
+     n_bootstrap : int, default=200
+         Number of bootstrap replications (used when
+         variance_method='bootstrap').
+     seed : int, optional
+         Random seed for reproducibility.
+
+     Attributes
+     ----------
+     results_ : TROPResults
+         Estimation results after calling fit().
+     is_fitted_ : bool
+         Whether the model has been fitted.
+
+     Examples
+     --------
+     >>> from diff_diff import TROP
+     >>> trop = TROP()
+     >>> results = trop.fit(
+     ...     data,
+     ...     outcome='outcome',
+     ...     treatment='treated',
+     ...     unit='unit',
+     ...     time='period',
+     ...     post_periods=[5, 6, 7, 8]
+     ... )
+     >>> results.print_summary()
+
+     References
+     ----------
+     Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
+     Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
+     """
+
+     def __init__(
+         self,
+         lambda_time_grid: Optional[List[float]] = None,
+         lambda_unit_grid: Optional[List[float]] = None,
+         lambda_nn_grid: Optional[List[float]] = None,
+         max_iter: int = 100,
+         tol: float = 1e-6,
+         alpha: float = 0.05,
+         variance_method: str = 'bootstrap',
+         n_bootstrap: int = 200,
+         seed: Optional[int] = None,
+     ):
+         # Default grids from paper
+         self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
+         self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
+         self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
+
+         self.max_iter = max_iter
+         self.tol = tol
+         self.alpha = alpha
+         self.variance_method = variance_method
+         self.n_bootstrap = n_bootstrap
+         self.seed = seed
+
+         # Validate parameters
+         valid_variance_methods = ("bootstrap", "jackknife")
+         if variance_method not in valid_variance_methods:
+             raise ValueError(
+                 f"variance_method must be one of {valid_variance_methods}, "
+                 f"got '{variance_method}'"
+             )
+
+         # Internal state
+         self.results_: Optional[TROPResults] = None
+         self.is_fitted_: bool = False
+         self._optimal_lambda: Optional[Tuple[float, float, float]] = None
+
+     def fit(
+         self,
+         data: pd.DataFrame,
+         outcome: str,
+         treatment: str,
+         unit: str,
+         time: str,
+         post_periods: Optional[List[Any]] = None,
+     ) -> TROPResults:
+         """
+         Fit the TROP model.
+
+         Parameters
+         ----------
+         data : pd.DataFrame
+             Panel data with observations for multiple units over multiple
+             time periods.
+         outcome : str
+             Name of the outcome variable column.
+         treatment : str
+             Name of the treatment indicator column (0/1).
+             Should be 1 for treated unit-time observations.
+         unit : str
+             Name of the unit identifier column.
+         time : str
+             Name of the time period column.
+         post_periods : list, optional
+             List of time period values that are post-treatment.
+             If None, inferred from the treatment indicator.
+
+         Returns
+         -------
+         TROPResults
+             Object containing the ATT estimate, standard error,
+             factor estimates, and tuning parameters.
+         """
+         # Validate inputs
+         required_cols = [outcome, treatment, unit, time]
+         missing = [c for c in required_cols if c not in data.columns]
+         if missing:
+             raise ValueError(f"Missing columns: {missing}")
+
+         # Get unique units and periods
+         all_units = sorted(data[unit].unique())
+         all_periods = sorted(data[time].unique())
+
+         n_units = len(all_units)
+         n_periods = len(all_periods)
+
+         # Create mappings
+         unit_to_idx = {u: i for i, u in enumerate(all_units)}
+         period_to_idx = {p: i for i, p in enumerate(all_periods)}
+         idx_to_unit = {i: u for u, i in unit_to_idx.items()}
+         idx_to_period = {i: p for p, i in period_to_idx.items()}
+
+         # Create outcome matrix Y (n_periods x n_units) and treatment matrix D
+         Y = np.full((n_periods, n_units), np.nan)
+         D = np.zeros((n_periods, n_units), dtype=int)
+
+         for _, row in data.iterrows():
+             i = unit_to_idx[row[unit]]
+             t = period_to_idx[row[time]]
+             Y[t, i] = row[outcome]
+             D[t, i] = int(row[treatment])
+
+         # Identify treated observations
+         treated_mask = D == 1
+         n_treated_obs = int(np.sum(treated_mask))
+
+         if n_treated_obs == 0:
+             raise ValueError("No treated observations found")
+
+         # Identify treated and control units
+         unit_ever_treated = np.any(D == 1, axis=0)
+         treated_unit_idx = np.where(unit_ever_treated)[0]
+         control_unit_idx = np.where(~unit_ever_treated)[0]
+
+         if len(control_unit_idx) == 0:
+             raise ValueError("No control units found")
+
+         # Determine pre/post periods
+         if post_periods is None:
+             # Infer from first treatment time
+             first_treat_period = None
+             for t in range(n_periods):
+                 if np.any(D[t, :] == 1):
+                     first_treat_period = t
+                     break
+             if first_treat_period is None:
+                 raise ValueError("Could not infer post-treatment periods")
+             pre_period_idx = list(range(first_treat_period))
+             post_period_idx = list(range(first_treat_period, n_periods))
+         else:
+             post_period_idx = [period_to_idx[p] for p in post_periods if p in period_to_idx]
+             pre_period_idx = [i for i in range(n_periods) if i not in post_period_idx]
+
+         if len(pre_period_idx) < 2:
+             raise ValueError("Need at least 2 pre-treatment periods")
+
+         pre_periods_list = [idx_to_period[i] for i in pre_period_idx]
+         post_periods_list = [idx_to_period[i] for i in post_period_idx]
+
+         # Step 1: Grid search with LOOCV for tuning parameters
+         best_lambda = None
+         best_score = np.inf
+
+         # Control observations mask (for LOOCV)
+         control_mask = D == 0
+
+         for lambda_time in self.lambda_time_grid:
+             for lambda_unit in self.lambda_unit_grid:
+                 for lambda_nn in self.lambda_nn_grid:
+                     try:
+                         score = self._loocv_score_obs_specific(
+                             Y, D, control_mask, control_unit_idx,
+                             lambda_time, lambda_unit, lambda_nn,
+                             n_units, n_periods
+                         )
+                         if score < best_score:
+                             best_score = score
+                             best_lambda = (lambda_time, lambda_unit, lambda_nn)
+                     except (np.linalg.LinAlgError, ValueError):
+                         continue
+
+         if best_lambda is None:
+             warnings.warn(
+                 "All tuning parameter combinations failed. Using defaults.",
+                 UserWarning
+             )
+             best_lambda = (1.0, 1.0, 0.1)
+             best_score = np.nan
+
+         self._optimal_lambda = best_lambda
+         lambda_time, lambda_unit, lambda_nn = best_lambda
+
+         # Step 2: Final estimation - per-observation model fitting following
+         # Algorithm 2. For each treated (i, t): compute observation-specific
+         # weights, fit the model, compute τ̂_{it}
+         treatment_effects = {}
+         tau_values = []
+         alpha_estimates = []
+         beta_estimates = []
+         L_estimates = []
+
+         # Get list of treated observations
+         treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
+                                 if D[t, i] == 1]
+
+         for t, i in treated_observations:
+             # Compute observation-specific weights for this (i, t)
+             weight_matrix = self._compute_observation_weights(
+                 Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
+                 n_units, n_periods
+             )
+
+             # Fit model with these weights
+             alpha_hat, beta_hat, L_hat = self._estimate_model(
+                 Y, control_mask, weight_matrix, lambda_nn,
+                 n_units, n_periods
+             )
+
+             # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
+             tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]
+
+             unit_id = idx_to_unit[i]
+             time_id = idx_to_period[t]
+             treatment_effects[(unit_id, time_id)] = tau_it
+             tau_values.append(tau_it)
+
+             # Store for averaging
+             alpha_estimates.append(alpha_hat)
+             beta_estimates.append(beta_hat)
+             L_estimates.append(L_hat)
+
+         # Average ATT
+         att = float(np.mean(tau_values))
+
+         # Average parameter estimates for output (representative)
+         alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
+         beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
+         L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))
+
+         # Step 3: Compute effective rank of the averaged factor matrix
+         _, s, _ = np.linalg.svd(L_hat, full_matrices=False)
+         if s[0] > 0:
+             effective_rank = np.sum(s) / s[0]
+         else:
+             effective_rank = 0.0
+
+         # Step 4: Variance estimation
+         if self.variance_method == "bootstrap":
+             se, bootstrap_dist = self._bootstrap_variance(
+                 data, outcome, treatment, unit, time, post_periods_list,
+                 best_lambda
+             )
+         else:
+             se, bootstrap_dist = self._jackknife_variance(
+                 Y, D, control_mask, control_unit_idx, best_lambda,
+                 n_units, n_periods
+             )
+
+         # Compute test statistics
+         if se > 0:
+             t_stat = att / se
+             p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
+         else:
+             t_stat = 0.0
+             p_value = 1.0
+
+         conf_int = compute_confidence_interval(att, se, self.alpha)
+
+         # Create results dictionaries
+         unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
+         time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}
+
+         # Store results
+         self.results_ = TROPResults(
+             att=att,
+             se=se,
+             t_stat=t_stat,
+             p_value=p_value,
+             conf_int=conf_int,
+             n_obs=len(data),
+             n_treated=len(treated_unit_idx),
+             n_control=len(control_unit_idx),
+             n_treated_obs=n_treated_obs,
+             unit_effects=unit_effects_dict,
+             time_effects=time_effects_dict,
+             treatment_effects=treatment_effects,
+             lambda_time=lambda_time,
+             lambda_unit=lambda_unit,
+             lambda_nn=lambda_nn,
+             factor_matrix=L_hat,
+             effective_rank=effective_rank,
+             loocv_score=best_score,
+             variance_method=self.variance_method,
+             alpha=self.alpha,
+             pre_periods=pre_periods_list,
+             post_periods=post_periods_list,
+             n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None,
+             bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
+         )
+
+         self.is_fitted_ = True
+         return self.results_
+
+     def _compute_unit_distance_pairwise(
+         self,
+         Y: np.ndarray,
+         D: np.ndarray,
+         j: int,
+         i: int,
+         target_period: int,
+     ) -> float:
+         """
+         Compute pairwise distance from control unit j to treated unit i.
+
+         Following the paper's Equation 3 (page 7):
+             dist_unit_{-t}(j, i) = sqrt(
+                 Σ_u 1{u≠t}(1-W_{iu})(1-W_{ju})(Y_{iu} - Y_{ju})²
+                 / Σ_u 1{u≠t}(1-W_{iu})(1-W_{ju})
+             )
+
+         This computes the RMSE between units j and i over periods where
+         both are untreated, excluding the target period t.
+
+         Parameters
+         ----------
+         Y : np.ndarray
+             Outcome matrix (n_periods x n_units).
+         D : np.ndarray
+             Treatment indicator matrix (n_periods x n_units).
+         j : int
+             Index of control unit.
+         i : int
+             Index of treated unit.
+         target_period : int
+             Target treatment period t (excluded from distance computation).
+
+         Returns
+         -------
+         float
+             Pairwise RMSE distance between units j and i.
+         """
+         n_periods = Y.shape[0]
+
+         sq_diffs = []
+         for u in range(n_periods):
+             # Exclude target period and periods where either unit is treated
+             if u == target_period:
+                 continue
+             # (1 - W_{iu})(1 - W_{ju}) means both must be untreated
+             if D[u, i] == 1 or D[u, j] == 1:
+                 continue
+             if np.isnan(Y[u, i]) or np.isnan(Y[u, j]):
+                 continue
+
+             sq_diffs.append((Y[u, i] - Y[u, j]) ** 2)
+
+         if len(sq_diffs) > 0:
+             return np.sqrt(np.mean(sq_diffs))
+         else:
+             return np.inf
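
As a concrete check of this distance, a self-contained sketch with hypothetical numbers; only periods where both units are untreated enter the RMSE, and the target period is dropped:

```python
import numpy as np

# Hypothetical 4-period panel: column 0 is the treated unit i,
# column 1 a control unit j. Unit i is treated in the last period.
Y = np.array([[1.0, 1.2],
              [2.0, 2.1],
              [3.0, 3.4],
              [9.0, 4.0]])
D = np.array([[0, 0],
              [0, 0],
              [0, 0],
              [1, 0]])
target = 3

both_untreated = (D[:, 0] == 0) & (D[:, 1] == 0)
both_untreated[target] = False                      # enforce 1{u != t}
diffs = Y[both_untreated, 0] - Y[both_untreated, 1]
dist = np.sqrt(np.mean(diffs ** 2))                 # RMSE over periods 0-2
print(dist)  # sqrt((0.2**2 + 0.1**2 + 0.4**2) / 3) ≈ 0.2646
```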
+
+     def _compute_observation_weights(
+         self,
+         Y: np.ndarray,
+         D: np.ndarray,
+         i: int,
+         t: int,
+         lambda_time: float,
+         lambda_unit: float,
+         control_unit_idx: np.ndarray,
+         n_units: int,
+         n_periods: int,
+     ) -> np.ndarray:
+         """
+         Compute observation-specific weight matrix for treated observation (i, t).
+
+         Following the paper's Algorithm 2 (page 27):
+         - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
+         - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
+
+         Parameters
+         ----------
+         Y : np.ndarray
+             Outcome matrix (n_periods x n_units).
+         D : np.ndarray
+             Treatment indicator matrix (n_periods x n_units).
+         i : int
+             Treated unit index.
+         t : int
+             Treatment period index.
+         lambda_time : float
+             Time weight decay parameter.
+         lambda_unit : float
+             Unit weight decay parameter.
+         control_unit_idx : np.ndarray
+             Indices of control units.
+         n_units : int
+             Number of units.
+         n_periods : int
+             Number of periods.
+
+         Returns
+         -------
+         np.ndarray
+             Weight matrix (n_periods x n_units) for observation (i, t).
+         """
+         # Time distance: |t - s| following paper's Equation 3 (page 7)
+         dist_time = np.array([abs(t - s) for s in range(n_periods)])
+         time_weights = np.exp(-lambda_time * dist_time)
+
+         # Unit distance: pairwise RMSE from each control j to treated i
+         unit_weights = np.zeros(n_units)
+
+         if lambda_unit == 0:
+             # λ_unit = 0 is the limit exp(0) = 1: uniform weights over
+             # control units
+             unit_weights[control_unit_idx] = 1.0
+         else:
+             for j in control_unit_idx:
+                 dist = self._compute_unit_distance_pairwise(Y, D, j, i, t)
+                 if np.isinf(dist):
+                     unit_weights[j] = 0.0
+                 else:
+                     unit_weights[j] = np.exp(-lambda_unit * dist)
+
+         # The treated unit i itself gets weight 1 so its untreated periods
+         # inform the fit (treated cells are excluded by the control mask)
+         unit_weights[i] = 1.0
+
+         # Weight matrix: outer product (n_periods x n_units)
+         W = np.outer(time_weights, unit_weights)
+
+         return W
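
The returned matrix is simply the outer product of the two exponential kernels. A minimal standalone sketch with hypothetical distances and parameters:

```python
import numpy as np

lambda_time, lambda_unit = 0.5, 1.0
n_periods, t = 5, 4                       # target period t = 4

# Time kernel: theta_s = exp(-lambda_time * |t - s|)
time_weights = np.exp(-lambda_time * np.abs(t - np.arange(n_periods)))

# Unit kernel: omega_j = exp(-lambda_unit * dist(j, i)) for two hypothetical
# control distances, plus weight 1 for the treated unit itself.
unit_dists = np.array([0.26, 0.80])
unit_weights = np.append(np.exp(-lambda_unit * unit_dists), 1.0)

W = np.outer(time_weights, unit_weights)  # (n_periods x n_units)
print(W.round(3))
```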
+
+     def _soft_threshold_svd(
+         self,
+         M: np.ndarray,
+         threshold: float,
+     ) -> np.ndarray:
+         """
+         Apply soft-thresholding to singular values (proximal operator for nuclear norm).
+
+         Parameters
+         ----------
+         M : np.ndarray
+             Input matrix.
+         threshold : float
+             Soft-thresholding parameter.
+
+         Returns
+         -------
+         np.ndarray
+             Matrix with soft-thresholded singular values.
+         """
+         if threshold <= 0:
+             return M
+
+         # Handle NaN/Inf values in input
+         if not np.isfinite(M).all():
+             M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)
+
+         try:
+             U, s, Vt = np.linalg.svd(M, full_matrices=False)
+         except np.linalg.LinAlgError:
+             # SVD failed, return zero matrix
+             return np.zeros_like(M)
+
+         # Check for numerical issues in SVD output
+         if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()):
+             # SVD produced non-finite values, return zero matrix
+             return np.zeros_like(M)
+
+         s_thresh = np.maximum(s - threshold, 0)
+
+         # Use truncated reconstruction with only non-zero singular values
+         nonzero_mask = s_thresh > 1e-10
+         if not np.any(nonzero_mask):
+             return np.zeros_like(M)
+
+         # Truncate to non-zero components for numerical stability
+         U_trunc = U[:, nonzero_mask]
+         s_trunc = s_thresh[nonzero_mask]
+         Vt_trunc = Vt[nonzero_mask, :]
+
+         # Compute result, suppressing expected numerical warnings from
+         # ill-conditioned matrices during alternating minimization
+         with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
+             result = (U_trunc * s_trunc) @ Vt_trunc
+
+         # Replace any NaN/Inf in result with zeros
+         if not np.isfinite(result).all():
+             result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
+
+         return result
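
Soft-thresholding the singular values is the proximal operator of the nuclear norm: every singular value shrinks by the threshold and small ones are zeroed, which is what makes the recovered factor matrix low-rank (it is also the workhorse of the alternating minimization in `_estimate_model` below). A standalone sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
# Rank-1 signal plus small noise: nuclear-norm shrinkage should recover ~rank 1.
M = np.outer(rng.normal(size=8), rng.normal(size=6)) + 0.05 * rng.normal(size=(8, 6))

U, s, Vt = np.linalg.svd(M, full_matrices=False)
threshold = 0.5
s_shrunk = np.maximum(s - threshold, 0.0)   # shrink, then zero out small values

L = (U * s_shrunk) @ Vt
print(np.linalg.matrix_rank(L))  # typically 1: the noise singular values vanish
```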
+
+     def _estimate_model(
+         self,
+         Y: np.ndarray,
+         control_mask: np.ndarray,
+         weight_matrix: np.ndarray,
+         lambda_nn: float,
+         n_units: int,
+         n_periods: int,
+         exclude_obs: Optional[Tuple[int, int]] = None,
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """
+         Estimate the model: Y = α + β + L + τD + ε with nuclear norm penalty on L.
+
+         Uses alternating minimization:
+         1. Fix L, solve for α, β
+         2. Fix α, β, solve for L via soft-thresholding
+
+         Parameters
+         ----------
+         Y : np.ndarray
+             Outcome matrix (n_periods x n_units).
+         control_mask : np.ndarray
+             Boolean mask for control observations.
+         weight_matrix : np.ndarray
+             Pre-computed observation-specific weight matrix (n_periods x n_units).
+         lambda_nn : float
+             Nuclear norm regularization parameter.
+         n_units : int
+             Number of units.
+         n_periods : int
+             Number of periods.
+         exclude_obs : tuple, optional
+             (t, i) observation to exclude (for LOOCV).
+
+         Returns
+         -------
+         tuple
+             (alpha, beta, L) estimated parameters.
+         """
+         W = weight_matrix
+
+         # Mask for estimation (control obs only, excluding LOOCV obs if specified)
+         est_mask = control_mask.copy()
+         if exclude_obs is not None:
+             t_ex, i_ex = exclude_obs
+             est_mask[t_ex, i_ex] = False
+
+         # Handle missing values
+         valid_mask = ~np.isnan(Y) & est_mask
+
+         # Initialize
+         alpha = np.zeros(n_units)
+         beta = np.zeros(n_periods)
+         L = np.zeros((n_periods, n_units))
+
+         # Alternating minimization
+         for _ in range(self.max_iter):
+             alpha_old = alpha.copy()
+             beta_old = beta.copy()
+             L_old = L.copy()
+
+             # Step 1: Update α and β (weighted means)
+             R = Y - L  # Residual without fixed effects
+
+             # Weighted mean for alpha (unit effects)
+             for i in range(n_units):
+                 mask_i = valid_mask[:, i]
+                 if np.any(mask_i):
+                     weights_i = W[mask_i, i]
+                     # Handle case where weights sum to zero (unit not in weight computation)
+                     weight_sum = np.sum(weights_i)
+                     if weight_sum > 0:
+                         alpha[i] = np.average(R[mask_i, i] - beta[mask_i], weights=weights_i)
+                     else:
+                         # Use unweighted mean for units with zero total weight
+                         alpha[i] = np.mean(R[mask_i, i] - beta[mask_i])
+                 else:
+                     alpha[i] = 0.0
+
+             # Weighted mean for beta (time effects)
+             for t in range(n_periods):
+                 mask_t = valid_mask[t, :]
+                 if np.any(mask_t):
+                     weights_t = W[t, mask_t]
+                     # Handle case where weights sum to zero
+                     weight_sum = np.sum(weights_t)
+                     if weight_sum > 0:
+                         beta[t] = np.average(R[t, mask_t] - alpha[mask_t], weights=weights_t)
+                     else:
+                         # Use unweighted mean for periods with zero total weight
+                         beta[t] = np.mean(R[t, mask_t] - alpha[mask_t])
+                 else:
+                     beta[t] = 0.0
+
+             # Step 2: Update L with nuclear norm penalty:
+             # L = soft_threshold(Y - α - β, λ_nn). Cells outside the
+             # estimation mask are imputed with the current L
+             # (the matrix completion step).
+             R_for_L = np.where(
+                 valid_mask,
+                 Y - alpha[np.newaxis, :] - beta[:, np.newaxis],
+                 L,
+             )
+
+             L = self._soft_threshold_svd(R_for_L, lambda_nn)
+
+             # Check convergence
+             alpha_diff = np.max(np.abs(alpha - alpha_old))
+             beta_diff = np.max(np.abs(beta - beta_old))
+             L_diff = np.max(np.abs(L - L_old))
+
+             if max(alpha_diff, beta_diff, L_diff) < self.tol:
+                 break
+
+         return alpha, beta, L
+
+     def _loocv_score_obs_specific(
+         self,
+         Y: np.ndarray,
+         D: np.ndarray,
+         control_mask: np.ndarray,
+         control_unit_idx: np.ndarray,
+         lambda_time: float,
+         lambda_unit: float,
+         lambda_nn: float,
+         n_units: int,
+         n_periods: int,
+     ) -> float:
+         """
+         Compute leave-one-out cross-validation score with observation-specific weights.
+
+         Following the paper's Equation 5 (page 8):
+             Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
+
+         For each control observation (j, s), treat it as pseudo-treated,
+         compute observation-specific weights, fit the model excluding (j, s),
+         and accumulate the squared pseudo-treatment effects. Because control
+         observations may be subsampled for tractability, the score returned
+         here is the mean rather than the sum, which keeps scores comparable
+         across subsamples.
+
+         Parameters
+         ----------
+         Y : np.ndarray
+             Outcome matrix (n_periods x n_units).
+         D : np.ndarray
+             Treatment indicator matrix (n_periods x n_units).
+         control_mask : np.ndarray
+             Boolean mask for control observations.
+         control_unit_idx : np.ndarray
+             Indices of control units.
+         lambda_time : float
+             Time weight decay parameter.
+         lambda_unit : float
+             Unit weight decay parameter.
+         lambda_nn : float
+             Nuclear norm regularization parameter.
+         n_units : int
+             Number of units.
+         n_periods : int
+             Number of periods.
+
+         Returns
+         -------
+         float
+             Mean squared pseudo-treatment effect over the evaluated control
+             observations (lower is better).
+         """
+         # Get all control observations
+         control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
+                        if control_mask[t, i] and not np.isnan(Y[t, i])]
+
+         # Subsample for computational tractability (as noted in paper's footnote)
+         rng = np.random.default_rng(self.seed)
+         max_loocv = min(100, len(control_obs))
+         if len(control_obs) > max_loocv:
+             indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
+             control_obs = [control_obs[idx] for idx in indices]
+
+         tau_squared_sum = 0.0
+         n_valid = 0
+
+         for t, i in control_obs:
+             try:
+                 # Compute observation-specific weights for pseudo-treated (i, t)
+                 weight_matrix = self._compute_observation_weights(
+                     Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
+                     n_units, n_periods
+                 )
+
+                 # Estimate model excluding observation (t, i)
+                 alpha, beta, L = self._estimate_model(
+                     Y, control_mask, weight_matrix, lambda_nn,
+                     n_units, n_periods, exclude_obs=(t, i)
+                 )
+
+                 # Pseudo treatment effect
+                 tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
+                 tau_squared_sum += tau_ti ** 2
+                 n_valid += 1
+
+             except (np.linalg.LinAlgError, ValueError):
+                 continue
+
+         if n_valid == 0:
+             return np.inf
+
+         return tau_squared_sum / n_valid
+
+     def _bootstrap_variance(
+         self,
+         data: pd.DataFrame,
+         outcome: str,
+         treatment: str,
+         unit: str,
+         time: str,
+         post_periods: List[Any],
+         optimal_lambda: Tuple[float, float, float],
+     ) -> Tuple[float, np.ndarray]:
+         """
+         Compute bootstrap standard error using unit-level block bootstrap.
+
+         Parameters
+         ----------
+         data : pd.DataFrame
+             Original data.
+         outcome : str
+             Outcome column name.
+         treatment : str
+             Treatment column name.
+         unit : str
+             Unit column name.
+         time : str
+             Time column name.
+         post_periods : list
+             Post-treatment periods.
+         optimal_lambda : tuple
+             Optimal (lambda_time, lambda_unit, lambda_nn).
+
+         Returns
+         -------
+         tuple
+             (se, bootstrap_estimates).
+         """
+         rng = np.random.default_rng(self.seed)
+         all_units = data[unit].unique()
+         n_units = len(all_units)
+
+         bootstrap_estimates = []
+
+         for _ in range(self.n_bootstrap):
+             # Sample units with replacement
+             sampled_units = rng.choice(all_units, size=n_units, replace=True)
+
+             # Create bootstrap sample with unique unit IDs
+             boot_data = pd.concat([
+                 data[data[unit] == u].assign(**{unit: f"{u}_{idx}"})
+                 for idx, u in enumerate(sampled_units)
+             ], ignore_index=True)
+
+             try:
+                 # Fit with fixed lambda (skip LOOCV for speed)
+                 att = self._fit_with_fixed_lambda(
+                     boot_data, outcome, treatment, unit, time,
+                     post_periods, optimal_lambda
+                 )
+                 bootstrap_estimates.append(att)
+             except (ValueError, np.linalg.LinAlgError, KeyError):
+                 continue
+
+         bootstrap_estimates = np.array(bootstrap_estimates)
+
+         if len(bootstrap_estimates) < 10:
+             warnings.warn(
+                 f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
+                 "Standard errors may be unreliable.",
+                 UserWarning
+             )
+         if len(bootstrap_estimates) < 2:
+             # Fewer than two successes: a sample standard deviation is undefined
+             return 0.0, bootstrap_estimates
+
+         se = np.std(bootstrap_estimates, ddof=1)
+         return se, bootstrap_estimates
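
The resampling step draws whole units with replacement and relabels each draw so that duplicated units enter as distinct panels. A minimal pandas sketch of just that step, with hypothetical column names:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
data = pd.DataFrame({
    "unit":   ["a", "a", "b", "b", "c", "c"],
    "period": [1, 2, 1, 2, 1, 2],
    "y":      [1.0, 1.1, 0.9, 1.3, 1.2, 0.8],
})

units = data["unit"].unique()
sampled = rng.choice(units, size=len(units), replace=True)

# Relabel each draw ("a_0", "a_1", ...) so repeated units stay distinct panels.
boot = pd.concat(
    [data[data["unit"] == u].assign(unit=f"{u}_{k}") for k, u in enumerate(sampled)],
    ignore_index=True,
)
print(boot)
```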
+
+     def _jackknife_variance(
+         self,
+         Y: np.ndarray,
+         D: np.ndarray,
+         control_mask: np.ndarray,
+         control_unit_idx: np.ndarray,
+         optimal_lambda: Tuple[float, float, float],
+         n_units: int,
+         n_periods: int,
+     ) -> Tuple[float, np.ndarray]:
+         """
+         Compute jackknife standard error (leave-one-treated-unit-out).
+
+         Uses observation-specific weights following Algorithm 2.
+
+         Parameters
+         ----------
+         Y : np.ndarray
+             Outcome matrix.
+         D : np.ndarray
+             Treatment matrix.
+         control_mask : np.ndarray
+             Control observation mask.
+         control_unit_idx : np.ndarray
+             Indices of control units.
+         optimal_lambda : tuple
+             Optimal tuning parameters.
+         n_units : int
+             Number of units.
+         n_periods : int
+             Number of periods.
+
+         Returns
+         -------
+         tuple
+             (se, jackknife_estimates).
+         """
+         lambda_time, lambda_unit, lambda_nn = optimal_lambda
+         jackknife_estimates = []
+
+         # Get treated unit indices
+         treated_unit_idx = np.where(np.any(D == 1, axis=0))[0]
+
+         for leave_out in treated_unit_idx:
+             # Drop this unit's outcomes and treatment indicators
+             Y_jack = Y.copy()
+             D_jack = D.copy()
+             Y_jack[:, leave_out] = np.nan
+             D_jack[:, leave_out] = 0
+
+             control_mask_jack = D_jack == 0
+
+             # Get remaining treated observations
+             treated_obs_jack = [(t, i) for t in range(n_periods) for i in range(n_units)
+                                 if D_jack[t, i] == 1]
+
+             if not treated_obs_jack:
+                 continue
+
+             try:
+                 # Compute ATT using observation-specific weights (Algorithm 2)
+                 tau_values = []
+                 for t, i in treated_obs_jack:
+                     # Compute observation-specific weights for this (i, t)
+                     weight_matrix = self._compute_observation_weights(
+                         Y_jack, D_jack, i, t, lambda_time, lambda_unit,
+                         control_unit_idx, n_units, n_periods
+                     )
+
+                     # Fit model with these weights
+                     alpha, beta, L = self._estimate_model(
+                         Y_jack, control_mask_jack, weight_matrix, lambda_nn,
+                         n_units, n_periods
+                     )
+
+                     # Compute treatment effect
+                     tau = Y_jack[t, i] - alpha[i] - beta[t] - L[t, i]
+                     tau_values.append(tau)
+
+                 if tau_values:
+                     jackknife_estimates.append(np.mean(tau_values))
+
+             except (np.linalg.LinAlgError, ValueError):
+                 continue
+
+         jackknife_estimates = np.array(jackknife_estimates)
+
+         if len(jackknife_estimates) < 2:
+             return 0.0, jackknife_estimates
+
+         # Jackknife SE formula
+         n = len(jackknife_estimates)
+         mean_est = np.mean(jackknife_estimates)
+         se = np.sqrt((n - 1) / n * np.sum((jackknife_estimates - mean_est) ** 2))
+
+         return se, jackknife_estimates
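
The leave-one-treated-unit-out estimates are combined with the standard jackknife variance formula, SE = sqrt((n-1)/n × Σ(θ̂₍ᵢ₎ − θ̄)²). A tiny numeric check with made-up estimates:

```python
import numpy as np

# Hypothetical leave-one-treated-unit-out ATT estimates.
theta = np.array([1.10, 1.30, 0.95, 1.25])
n = len(theta)

se = np.sqrt((n - 1) / n * np.sum((theta - theta.mean()) ** 2))
print(se)  # ≈ 0.237
```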
+
+     def _fit_with_fixed_lambda(
+         self,
+         data: pd.DataFrame,
+         outcome: str,
+         treatment: str,
+         unit: str,
+         time: str,
+         post_periods: List[Any],
+         fixed_lambda: Tuple[float, float, float],
+     ) -> float:
+         """
+         Fit model with fixed tuning parameters (for bootstrap).
+
+         Uses observation-specific weights following Algorithm 2.
+         Returns only the ATT estimate.
+         """
+         lambda_time, lambda_unit, lambda_nn = fixed_lambda
+
+         # Setup matrices
+         all_units = sorted(data[unit].unique())
+         all_periods = sorted(data[time].unique())
+
+         n_units = len(all_units)
+         n_periods = len(all_periods)
+
+         unit_to_idx = {u: i for i, u in enumerate(all_units)}
+         period_to_idx = {p: i for i, p in enumerate(all_periods)}
+
+         Y = np.full((n_periods, n_units), np.nan)
+         D = np.zeros((n_periods, n_units), dtype=int)
+
+         for _, row in data.iterrows():
+             i = unit_to_idx[row[unit]]
+             t = period_to_idx[row[time]]
+             Y[t, i] = row[outcome]
+             D[t, i] = int(row[treatment])
+
+         control_mask = D == 0
+
+         # Get control unit indices
+         unit_ever_treated = np.any(D == 1, axis=0)
+         control_unit_idx = np.where(~unit_ever_treated)[0]
+
+         # Get list of treated observations
+         treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
+                                 if D[t, i] == 1]
+
+         if not treated_observations:
+             raise ValueError("No treated observations")
+
+         # Compute ATT using observation-specific weights (Algorithm 2)
+         tau_values = []
+         for t, i in treated_observations:
+             # Compute observation-specific weights for this (i, t)
+             weight_matrix = self._compute_observation_weights(
+                 Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
+                 n_units, n_periods
+             )
+
+             # Fit model with these weights
+             alpha, beta, L = self._estimate_model(
+                 Y, control_mask, weight_matrix, lambda_nn,
+                 n_units, n_periods
+             )
+
+             # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
+             tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
+             tau_values.append(tau)
+
+         return np.mean(tau_values)
+
+     def get_params(self) -> Dict[str, Any]:
+         """Get estimator parameters."""
+         return {
+             "lambda_time_grid": self.lambda_time_grid,
+             "lambda_unit_grid": self.lambda_unit_grid,
+             "lambda_nn_grid": self.lambda_nn_grid,
+             "max_iter": self.max_iter,
+             "tol": self.tol,
+             "alpha": self.alpha,
+             "variance_method": self.variance_method,
+             "n_bootstrap": self.n_bootstrap,
+             "seed": self.seed,
+         }
+
+     def set_params(self, **params) -> "TROP":
+         """Set estimator parameters."""
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Unknown parameter: {key}")
+         return self
+
+
+ def trop(
+     data: pd.DataFrame,
+     outcome: str,
+     treatment: str,
+     unit: str,
+     time: str,
+     post_periods: Optional[List[Any]] = None,
+     **kwargs,
+ ) -> TROPResults:
+     """
+     Convenience function for TROP estimation.
+
+     Parameters
+     ----------
+     data : pd.DataFrame
+         Panel data.
+     outcome : str
+         Outcome variable column name.
+     treatment : str
+         Treatment indicator column name.
+     unit : str
+         Unit identifier column name.
+     time : str
+         Time period column name.
+     post_periods : list, optional
+         Post-treatment periods.
+     **kwargs
+         Additional arguments passed to the TROP constructor.
+
+     Returns
+     -------
+     TROPResults
+         Estimation results.
+
+     Examples
+     --------
+     >>> from diff_diff import trop
+     >>> results = trop(data, 'y', 'treated', 'unit', 'time', post_periods=[5,6,7])
+     >>> print(f"ATT: {results.att:.3f}")
+     """
+     estimator = TROP(**kwargs)
+     return estimator.fit(data, outcome, treatment, unit, time, post_periods)
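
An end-to-end sketch on synthetic data (editor's illustration: the data-generating process and the true effect of 2.0 are made up, and deliberately small tuning grids are passed so the LOOCV search stays fast):

```python
import numpy as np
import pandas as pd

from diff_diff import trop

rng = np.random.default_rng(42)
rows = []
for i in range(12):                  # 12 units; units 0-2 treated from period 5
    for t in range(8):               # 8 periods
        treated = int(i < 3 and t >= 5)
        y = 0.5 * i + 0.3 * t + 2.0 * treated + rng.normal(scale=0.1)
        rows.append({"unit": i, "period": t, "y": y, "treated": treated})
data = pd.DataFrame(rows)

results = trop(
    data, outcome="y", treatment="treated", unit="unit", time="period",
    post_periods=[5, 6, 7],
    lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0],
    lambda_nn_grid=[0.0, 0.1],
    n_bootstrap=50, seed=0,
)
results.print_summary()              # ATT should land near the simulated 2.0
```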