diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/trop.py ADDED
@@ -0,0 +1,2904 @@
1
+ """
2
+ Triply Robust Panel (TROP) estimator.
3
+
4
+ Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
5
+ TROP combines three robustness components:
6
+ 1. Nuclear norm regularized factor model (interactive fixed effects)
7
+ 2. Exponential distance-based unit weights
8
+ 3. Exponential time decay weights
9
+
10
+ The estimator uses leave-one-out cross-validation for tuning parameter
11
+ selection and provides robust treatment effect estimates under factor
12
+ confounding.
13
+
14
+ References
15
+ ----------
16
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
17
+ Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
18
+ """
19
+
20
+ import logging
21
+ import warnings
22
+ from dataclasses import dataclass, field
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+ from scipy import stats
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ try:
32
+ from typing import TypedDict
33
+ except ImportError:
34
+ from typing_extensions import TypedDict
35
+
36
+ from diff_diff._backend import (
37
+ HAS_RUST_BACKEND,
38
+ _rust_unit_distance_matrix,
39
+ _rust_loocv_grid_search,
40
+ _rust_bootstrap_trop_variance,
41
+ _rust_loocv_grid_search_joint,
42
+ _rust_bootstrap_trop_variance_joint,
43
+ )
44
+ from diff_diff.results import _get_significance_stars
45
+ from diff_diff.utils import compute_confidence_interval, compute_p_value
46
+
47
+
48
+ # Sentinel value for "disabled" λ_nn in LOOCV parameter search.
49
+ # Per paper's footnote 2: λ_nn=∞ disables the factor model (L=0).
50
+ # For λ_time and λ_unit, 0.0 means disabled (uniform weights) per Eq. 3:
51
+ # exp(-0 × dist) = 1 for all distances.
52
+ _LAMBDA_INF: float = float('inf')
53
+
54
+
55
class _PrecomputedStructures(TypedDict):
    """Type definition for pre-computed structures used across LOOCV iterations.

    These structures are computed once in `_precompute_structures()` and reused
    to avoid redundant computation during LOOCV and final estimation. The keys
    below mirror exactly the dictionary returned by that method.
    """

    unit_dist_matrix: np.ndarray
    """Pairwise unit distance matrix (n_units x n_units)."""
    time_dist_matrix: np.ndarray
    """Time distance matrix where [t, s] = |t - s| (n_periods x n_periods)."""
    control_mask: np.ndarray
    """Boolean mask for control observations (D == 0)."""
    treated_mask: np.ndarray
    """Boolean mask for treated observations (D == 1)."""
    treated_observations: List[Tuple[int, int]]
    """List of (t, i) tuples for treated observations."""
    control_obs: List[Tuple[int, int]]
    """List of (t, i) tuples for valid control observations."""
    control_unit_idx: np.ndarray
    """Array of never-treated unit indices (for backward compatibility)."""
    D: np.ndarray
    """Treatment indicator matrix (n_periods x n_units) for dynamic control sets."""
    Y: np.ndarray
    """Outcome matrix (n_periods x n_units)."""
    n_units: int
    """Number of units."""
    n_periods: int
    """Number of time periods."""
84
+
85
+
86
@dataclass
class TROPResults:
    """
    Results from a Triply Robust Panel (TROP) estimation.

    TROP combines nuclear norm regularized factor estimation with
    exponential distance-based unit weights and time decay weights.

    Attributes
    ----------
    att : float
        Average Treatment effect on the Treated (ATT).
    se : float
        Standard error of the ATT estimate.
    t_stat : float
        T-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_treated_obs : int
        Number of treated unit-time observations.
    unit_effects : dict
        Estimated unit fixed effects (alpha_i).
    time_effects : dict
        Estimated time fixed effects (beta_t).
    treatment_effects : dict
        Individual treatment effects for each treated (unit, time) pair.
    lambda_time : float
        Selected time weight decay parameter from grid. 0.0 = uniform time
        weights (disabled) per Eq. 3.
    lambda_unit : float
        Selected unit weight decay parameter from grid. 0.0 = uniform unit
        weights (disabled) per Eq. 3.
    lambda_nn : float
        Selected nuclear norm regularization parameter from grid. inf = factor
        model disabled (L=0); converted to 1e10 internally for computation.
    factor_matrix : np.ndarray
        Estimated low-rank factor matrix L (n_periods x n_units).
    effective_rank : float
        Effective rank of the factor matrix (sum of singular values / max).
    loocv_score : float
        Leave-one-out cross-validation score for selected parameters.
    alpha : float
        Significance level for confidence interval.
    n_pre_periods : int
        Number of pre-treatment periods.
    n_post_periods : int
        Number of post-treatment periods (periods with D=1 observations).
    n_bootstrap : int, optional
        Number of bootstrap replications (if bootstrap variance).
    bootstrap_distribution : np.ndarray, optional
        Bootstrap distribution of estimates.
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    n_treated_obs: int
    unit_effects: Dict[Any, float]
    time_effects: Dict[Any, float]
    treatment_effects: Dict[Tuple[Any, Any], float]
    lambda_time: float
    lambda_unit: float
    lambda_nn: float
    factor_matrix: np.ndarray
    effective_rank: float
    loocv_score: float
    alpha: float = 0.05
    n_pre_periods: int = 0
    n_post_periods: int = 0
    # Bootstrap fields stay None when an analytic variance was used.
    n_bootstrap: Optional[int] = field(default=None)
    # Excluded from repr: can hold n_bootstrap floats.
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.p_value)
        return (
            f"TROPResults(ATT={self.att:.4f}{sig}, "
            f"SE={self.se:.4f}, "
            f"eff_rank={self.effective_rank:.1f}, "
            f"p={self.p_value:.4f})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level for confidence intervals. Defaults to the
            alpha used during estimation.

        Returns
        -------
        str
            Formatted summary table.
        """
        # NOTE(review): `or` falls back to self.alpha for alpha=0 as well as
        # None — assumed callers never pass alpha=0; confirm if that matters.
        alpha = alpha or self.alpha
        conf_level = int((1 - alpha) * 100)

        lines = [
            "=" * 75,
            "Triply Robust Panel (TROP) Estimation Results".center(75),
            "Athey, Imbens, Qu & Viviano (2025)".center(75),
            "=" * 75,
            "",
            f"{'Observations:':<25} {self.n_obs:>10}",
            f"{'Treated units:':<25} {self.n_treated:>10}",
            f"{'Control units:':<25} {self.n_control:>10}",
            f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
            f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}",
            f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}",
            "",
            "-" * 75,
            "Tuning Parameters (selected via LOOCV)".center(75),
            "-" * 75,
            f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}",
            f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}",
            f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}",
            f"{'Effective rank:':<25} {self.effective_rank:>10.2f}",
            f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}",
        ]

        # Variance info
        if self.n_bootstrap is not None:
            lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")

        lines.extend([
            "",
            "-" * 75,
            f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'':>5}",
            "-" * 75,
            f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} "
            f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
            "-" * 75,
            "",
            f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
        ])

        # Add significance codes
        lines.extend([
            "",
            "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
            "=" * 75,
        ])

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Only scalar results are included; matrix/dict-valued attributes
        (factor_matrix, unit_effects, ...) are accessed directly or via the
        dedicated ``get_*_df`` helpers.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all estimation results.
        """
        return {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": self.conf_int[0],
            "conf_int_upper": self.conf_int[1],
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "n_treated_obs": self.n_treated_obs,
            "n_pre_periods": self.n_pre_periods,
            "n_post_periods": self.n_post_periods,
            "lambda_time": self.lambda_time,
            "lambda_unit": self.lambda_unit,
            "lambda_nn": self.lambda_nn,
            "effective_rank": self.effective_rank,
            "loocv_score": self.loocv_score,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with estimation results (single row).
        """
        return pd.DataFrame([self.to_dict()])

    def get_treatment_effects_df(self) -> pd.DataFrame:
        """
        Get individual treatment effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit, time, and treatment effect columns.
        """
        return pd.DataFrame([
            {"unit": unit, "time": time, "effect": effect}
            for (unit, time), effect in self.treatment_effects.items()
        ])

    def get_unit_effects_df(self) -> pd.DataFrame:
        """
        Get unit fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit and effect columns.
        """
        return pd.DataFrame([
            {"unit": unit, "effect": effect}
            for unit, effect in self.unit_effects.items()
        ])

    def get_time_effects_df(self) -> pd.DataFrame:
        """
        Get time fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with time and effect columns.
        """
        return pd.DataFrame([
            {"time": time, "effect": effect}
            for time, effect in self.time_effects.items()
        ])

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)
343
+
344
+
345
+ class TROP:
346
+ """
347
+ Triply Robust Panel (TROP) estimator.
348
+
349
+ Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025).
350
+ TROP combines three robustness components:
351
+
352
+ 1. **Nuclear norm regularized factor model**: Estimates interactive fixed
353
+ effects L_it via matrix completion with nuclear norm penalty ||L||_*
354
+
355
+ 2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
356
+ where d(j,i) is the RMSE of outcome differences between units
357
+
358
+ 3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
359
+ weighting pre-treatment periods by proximity to treatment
360
+
361
+ Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
362
+ cross-validation on control observations.
363
+
364
+ Parameters
365
+ ----------
366
+ method : str, default='twostep'
367
+ Estimation method to use:
368
+
369
+ - 'twostep': Per-observation model fitting following Algorithm 2 of
370
+ Athey et al. (2025). Computes observation-specific weights and fits
371
+ a model for each treated observation, averaging the individual
372
+ treatment effects. More flexible but computationally intensive.
373
+
374
+ - 'joint': Joint weighted least squares optimization. Estimates a
375
+ single scalar treatment effect τ along with fixed effects and
376
+ optional low-rank factor adjustment. Faster but assumes homogeneous
377
+ treatment effects. Uses alternating minimization when nuclear norm
378
+ penalty is finite.
379
+
380
+ lambda_time_grid : list, optional
381
+ Grid of time weight decay parameters. 0.0 = uniform weights (disabled).
382
+ Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
383
+ lambda_unit_grid : list, optional
384
+ Grid of unit weight decay parameters. 0.0 = uniform weights (disabled).
385
+ Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
386
+ lambda_nn_grid : list, optional
387
+ Grid of nuclear norm regularization parameters. inf = factor model
388
+ disabled (L=0). Default: [0, 0.01, 0.1, 1, 10].
389
+ max_iter : int, default=100
390
+ Maximum iterations for nuclear norm optimization.
391
+ tol : float, default=1e-6
392
+ Convergence tolerance for optimization.
393
+ alpha : float, default=0.05
394
+ Significance level for confidence intervals.
395
+ n_bootstrap : int, default=200
396
+ Number of bootstrap replications for variance estimation.
397
+ seed : int, optional
398
+ Random seed for reproducibility.
399
+
400
+ Attributes
401
+ ----------
402
+ results_ : TROPResults
403
+ Estimation results after calling fit().
404
+ is_fitted_ : bool
405
+ Whether the model has been fitted.
406
+
407
+ Examples
408
+ --------
409
+ >>> from diff_diff import TROP
410
+ >>> trop = TROP()
411
+ >>> results = trop.fit(
412
+ ... data,
413
+ ... outcome='outcome',
414
+ ... treatment='treated',
415
+ ... unit='unit',
416
+ ... time='period',
417
+ ... )
418
+ >>> results.print_summary()
419
+
420
+ References
421
+ ----------
422
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
423
+ Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
424
+ """
425
+
426
    # Class constants
    # Numerical floor applied after nuclear-norm soft-thresholding; see the
    # docstring below.
    CONVERGENCE_TOL_SVD: float = 1e-10
    """Tolerance for singular value truncation in soft-thresholding.

    Singular values below this threshold after soft-thresholding are treated
    as zero to improve numerical stability.
    """
433
+
434
+ def __init__(
435
+ self,
436
+ method: str = "twostep",
437
+ lambda_time_grid: Optional[List[float]] = None,
438
+ lambda_unit_grid: Optional[List[float]] = None,
439
+ lambda_nn_grid: Optional[List[float]] = None,
440
+ max_iter: int = 100,
441
+ tol: float = 1e-6,
442
+ alpha: float = 0.05,
443
+ n_bootstrap: int = 200,
444
+ seed: Optional[int] = None,
445
+ ):
446
+ # Validate method parameter
447
+ valid_methods = ("twostep", "joint")
448
+ if method not in valid_methods:
449
+ raise ValueError(
450
+ f"method must be one of {valid_methods}, got '{method}'"
451
+ )
452
+ self.method = method
453
+
454
+ # Default grids from paper
455
+ self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
456
+ self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
457
+ self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
458
+
459
+ self.max_iter = max_iter
460
+ self.tol = tol
461
+ self.alpha = alpha
462
+ self.n_bootstrap = n_bootstrap
463
+ self.seed = seed
464
+
465
+ # Validate that time/unit grids do not contain inf.
466
+ # Per Athey et al. (2025) Eq. 3, λ_time=0 and λ_unit=0 give uniform
467
+ # weights (exp(-0 × dist) = 1). Using inf is a misunderstanding of
468
+ # the paper's convention. Only λ_nn=∞ is valid (disables factor model).
469
+ for grid_name, grid_vals in [
470
+ ("lambda_time_grid", self.lambda_time_grid),
471
+ ("lambda_unit_grid", self.lambda_unit_grid),
472
+ ]:
473
+ if any(np.isinf(v) for v in grid_vals):
474
+ raise ValueError(
475
+ f"{grid_name} must not contain inf. Use 0.0 for uniform "
476
+ f"weights (disabled) per Athey et al. (2025) Eq. 3: "
477
+ f"exp(-0 × dist) = 1 for all distances."
478
+ )
479
+
480
+ # Internal state
481
+ self.results_: Optional[TROPResults] = None
482
+ self.is_fitted_: bool = False
483
+ self._optimal_lambda: Optional[Tuple[float, float, float]] = None
484
+
485
+ # Pre-computed structures (set during fit)
486
+ self._precomputed: Optional[_PrecomputedStructures] = None
487
+
488
+ def _precompute_structures(
489
+ self,
490
+ Y: np.ndarray,
491
+ D: np.ndarray,
492
+ control_unit_idx: np.ndarray,
493
+ n_units: int,
494
+ n_periods: int,
495
+ ) -> _PrecomputedStructures:
496
+ """
497
+ Pre-compute data structures that are reused across LOOCV and estimation.
498
+
499
+ This method computes once what would otherwise be computed repeatedly:
500
+ - Pairwise unit distance matrix
501
+ - Time distance vectors
502
+ - Masks and indices
503
+
504
+ Parameters
505
+ ----------
506
+ Y : np.ndarray
507
+ Outcome matrix (n_periods x n_units).
508
+ D : np.ndarray
509
+ Treatment indicator matrix (n_periods x n_units).
510
+ control_unit_idx : np.ndarray
511
+ Indices of control units.
512
+ n_units : int
513
+ Number of units.
514
+ n_periods : int
515
+ Number of periods.
516
+
517
+ Returns
518
+ -------
519
+ _PrecomputedStructures
520
+ Pre-computed structures for efficient reuse.
521
+ """
522
+ # Compute pairwise unit distances (for all observation-specific weights)
523
+ # Following Equation 3 (page 7): RMSE between units over pre-treatment
524
+ if HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None:
525
+ # Use Rust backend for parallel distance computation (4-8x speedup)
526
+ unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64))
527
+ else:
528
+ unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)
529
+
530
+ # Pre-compute time distance vectors for each target period
531
+ # Time distance: |t - s| for all s and each target t
532
+ time_dist_matrix = np.abs(
533
+ np.arange(n_periods)[:, np.newaxis] - np.arange(n_periods)[np.newaxis, :]
534
+ ) # (n_periods, n_periods) where [t, s] = |t - s|
535
+
536
+ # Control and treatment masks
537
+ control_mask = D == 0
538
+ treated_mask = D == 1
539
+
540
+ # Identify treated observations
541
+ treated_observations = list(zip(*np.where(treated_mask)))
542
+
543
+ # Control observations for LOOCV
544
+ control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
545
+ if control_mask[t, i] and not np.isnan(Y[t, i])]
546
+
547
+ return {
548
+ "unit_dist_matrix": unit_dist_matrix,
549
+ "time_dist_matrix": time_dist_matrix,
550
+ "control_mask": control_mask,
551
+ "treated_mask": treated_mask,
552
+ "treated_observations": treated_observations,
553
+ "control_obs": control_obs,
554
+ "control_unit_idx": control_unit_idx,
555
+ "D": D,
556
+ "Y": Y,
557
+ "n_units": n_units,
558
+ "n_periods": n_periods,
559
+ }
560
+
561
+ def _compute_all_unit_distances(
562
+ self,
563
+ Y: np.ndarray,
564
+ D: np.ndarray,
565
+ n_units: int,
566
+ n_periods: int,
567
+ ) -> np.ndarray:
568
+ """
569
+ Compute pairwise unit distance matrix using vectorized operations.
570
+
571
+ Following Equation 3 (page 7):
572
+ dist_unit_{-t}(j, i) = sqrt(Σ_u (Y_{iu} - Y_{ju})² / n_valid)
573
+
574
+ For efficiency, we compute a base distance matrix excluding all treated
575
+ observations, which provides a good approximation. The exact per-observation
576
+ distances are refined when needed.
577
+
578
+ Uses vectorized numpy operations with masked arrays for O(n²) complexity
579
+ but with highly optimized inner loops via numpy/BLAS.
580
+
581
+ Parameters
582
+ ----------
583
+ Y : np.ndarray
584
+ Outcome matrix (n_periods x n_units).
585
+ D : np.ndarray
586
+ Treatment indicator matrix (n_periods x n_units).
587
+ n_units : int
588
+ Number of units.
589
+ n_periods : int
590
+ Number of periods.
591
+
592
+ Returns
593
+ -------
594
+ np.ndarray
595
+ Pairwise distance matrix (n_units x n_units).
596
+ """
597
+ # Mask for valid observations: control periods only (D=0), non-NaN
598
+ valid_mask = (D == 0) & ~np.isnan(Y)
599
+
600
+ # Replace invalid values with NaN for masked computation
601
+ Y_masked = np.where(valid_mask, Y, np.nan)
602
+
603
+ # Transpose to (n_units, n_periods) for easier broadcasting
604
+ Y_T = Y_masked.T # (n_units, n_periods)
605
+
606
+ # Compute pairwise squared differences using broadcasting
607
+ # Y_T[:, np.newaxis, :] has shape (n_units, 1, n_periods)
608
+ # Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods)
609
+ # diff has shape (n_units, n_units, n_periods)
610
+ diff = Y_T[:, np.newaxis, :] - Y_T[np.newaxis, :, :]
611
+ sq_diff = diff ** 2
612
+
613
+ # Count valid (non-NaN) observations per pair
614
+ # A difference is valid only if both units have valid observations
615
+ valid_diff = ~np.isnan(sq_diff)
616
+ n_valid = np.sum(valid_diff, axis=2) # (n_units, n_units)
617
+
618
+ # Compute sum of squared differences (treating NaN as 0)
619
+ sq_diff_sum = np.nansum(sq_diff, axis=2) # (n_units, n_units)
620
+
621
+ # Compute RMSE distance: sqrt(sum / n_valid)
622
+ # Avoid division by zero
623
+ with np.errstate(divide='ignore', invalid='ignore'):
624
+ dist_matrix = np.sqrt(sq_diff_sum / n_valid)
625
+
626
+ # Set pairs with no valid observations to inf
627
+ dist_matrix = np.where(n_valid > 0, dist_matrix, np.inf)
628
+
629
+ # Ensure diagonal is 0 (same unit distance)
630
+ np.fill_diagonal(dist_matrix, 0.0)
631
+
632
+ return dist_matrix
633
+
634
+ def _compute_unit_distance_for_obs(
635
+ self,
636
+ Y: np.ndarray,
637
+ D: np.ndarray,
638
+ j: int,
639
+ i: int,
640
+ target_period: int,
641
+ ) -> float:
642
+ """
643
+ Compute observation-specific pairwise distance from unit j to unit i.
644
+
645
+ This is the exact computation from Equation 3, excluding the target period.
646
+ Used when the base distance matrix approximation is insufficient.
647
+
648
+ Parameters
649
+ ----------
650
+ Y : np.ndarray
651
+ Outcome matrix (n_periods x n_units).
652
+ D : np.ndarray
653
+ Treatment indicator matrix.
654
+ j : int
655
+ Control unit index.
656
+ i : int
657
+ Treated unit index.
658
+ target_period : int
659
+ Target period to exclude.
660
+
661
+ Returns
662
+ -------
663
+ float
664
+ Pairwise RMSE distance.
665
+ """
666
+ n_periods = Y.shape[0]
667
+
668
+ # Mask: exclude target period, both units must be untreated, non-NaN
669
+ valid = np.ones(n_periods, dtype=bool)
670
+ valid[target_period] = False
671
+ valid &= (D[:, i] == 0) & (D[:, j] == 0)
672
+ valid &= ~np.isnan(Y[:, i]) & ~np.isnan(Y[:, j])
673
+
674
+ if np.any(valid):
675
+ sq_diffs = (Y[valid, i] - Y[valid, j]) ** 2
676
+ return np.sqrt(np.mean(sq_diffs))
677
+ else:
678
+ return np.inf
679
+
680
+ def _univariate_loocv_search(
681
+ self,
682
+ Y: np.ndarray,
683
+ D: np.ndarray,
684
+ control_mask: np.ndarray,
685
+ control_unit_idx: np.ndarray,
686
+ n_units: int,
687
+ n_periods: int,
688
+ param_name: str,
689
+ grid: List[float],
690
+ fixed_params: Dict[str, float],
691
+ ) -> Tuple[float, float]:
692
+ """
693
+ Search over one parameter with others fixed.
694
+
695
+ Following paper's footnote 2, this performs a univariate grid search
696
+ for one tuning parameter while holding others fixed. The fixed_params
697
+ use 0.0 for disabled time/unit weights and _LAMBDA_INF for disabled
698
+ factor model:
699
+ - lambda_nn = inf: Skip nuclear norm regularization (L=0)
700
+ - lambda_time = 0.0: Uniform time weights (exp(-0×dist)=1)
701
+ - lambda_unit = 0.0: Uniform unit weights (exp(-0×dist)=1)
702
+
703
+ Parameters
704
+ ----------
705
+ Y : np.ndarray
706
+ Outcome matrix (n_periods x n_units).
707
+ D : np.ndarray
708
+ Treatment indicator matrix (n_periods x n_units).
709
+ control_mask : np.ndarray
710
+ Boolean mask for control observations.
711
+ control_unit_idx : np.ndarray
712
+ Indices of control units.
713
+ n_units : int
714
+ Number of units.
715
+ n_periods : int
716
+ Number of periods.
717
+ param_name : str
718
+ Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
719
+ grid : List[float]
720
+ Grid of values to search over.
721
+ fixed_params : Dict[str, float]
722
+ Fixed values for other parameters. May include _LAMBDA_INF for lambda_nn.
723
+
724
+ Returns
725
+ -------
726
+ Tuple[float, float]
727
+ (best_value, best_score) for the searched parameter.
728
+ """
729
+ best_score = np.inf
730
+ best_value = grid[0] if grid else 0.0
731
+
732
+ for value in grid:
733
+ params = {**fixed_params, param_name: value}
734
+
735
+ lambda_time = params.get('lambda_time', 0.0)
736
+ lambda_unit = params.get('lambda_unit', 0.0)
737
+ lambda_nn = params.get('lambda_nn', 0.0)
738
+
739
+ # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
740
+ # λ_time and λ_unit use 0.0 for uniform weights per Eq. 3 (no inf conversion needed)
741
+ if np.isinf(lambda_nn):
742
+ lambda_nn = 1e10
743
+
744
+ try:
745
+ score = self._loocv_score_obs_specific(
746
+ Y, D, control_mask, control_unit_idx,
747
+ lambda_time, lambda_unit, lambda_nn,
748
+ n_units, n_periods
749
+ )
750
+ if score < best_score:
751
+ best_score = score
752
+ best_value = value
753
+ except (np.linalg.LinAlgError, ValueError):
754
+ continue
755
+
756
+ return best_value, best_score
757
+
758
+ def _cycling_parameter_search(
759
+ self,
760
+ Y: np.ndarray,
761
+ D: np.ndarray,
762
+ control_mask: np.ndarray,
763
+ control_unit_idx: np.ndarray,
764
+ n_units: int,
765
+ n_periods: int,
766
+ initial_lambda: Tuple[float, float, float],
767
+ max_cycles: int = 10,
768
+ ) -> Tuple[float, float, float]:
769
+ """
770
+ Cycle through parameters until convergence (coordinate descent).
771
+
772
+ Following paper's footnote 2 (Stage 2), this iteratively optimizes
773
+ each tuning parameter while holding the others fixed, until convergence.
774
+
775
+ Parameters
776
+ ----------
777
+ Y : np.ndarray
778
+ Outcome matrix (n_periods x n_units).
779
+ D : np.ndarray
780
+ Treatment indicator matrix (n_periods x n_units).
781
+ control_mask : np.ndarray
782
+ Boolean mask for control observations.
783
+ control_unit_idx : np.ndarray
784
+ Indices of control units.
785
+ n_units : int
786
+ Number of units.
787
+ n_periods : int
788
+ Number of periods.
789
+ initial_lambda : Tuple[float, float, float]
790
+ Initial values (lambda_time, lambda_unit, lambda_nn).
791
+ max_cycles : int, default=10
792
+ Maximum number of coordinate descent cycles.
793
+
794
+ Returns
795
+ -------
796
+ Tuple[float, float, float]
797
+ Optimized (lambda_time, lambda_unit, lambda_nn).
798
+ """
799
+ lambda_time, lambda_unit, lambda_nn = initial_lambda
800
+ prev_score = np.inf
801
+
802
+ for cycle in range(max_cycles):
803
+ # Optimize λ_unit (fix λ_time, λ_nn)
804
+ lambda_unit, _ = self._univariate_loocv_search(
805
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
806
+ 'lambda_unit', self.lambda_unit_grid,
807
+ {'lambda_time': lambda_time, 'lambda_nn': lambda_nn}
808
+ )
809
+
810
+ # Optimize λ_time (fix λ_unit, λ_nn)
811
+ lambda_time, _ = self._univariate_loocv_search(
812
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
813
+ 'lambda_time', self.lambda_time_grid,
814
+ {'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn}
815
+ )
816
+
817
+ # Optimize λ_nn (fix λ_unit, λ_time)
818
+ lambda_nn, score = self._univariate_loocv_search(
819
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
820
+ 'lambda_nn', self.lambda_nn_grid,
821
+ {'lambda_unit': lambda_unit, 'lambda_time': lambda_time}
822
+ )
823
+
824
+ # Check convergence
825
+ if abs(score - prev_score) < 1e-6:
826
+ logger.debug(
827
+ "Cycling search converged after %d cycles with score %.6f",
828
+ cycle + 1, score
829
+ )
830
+ break
831
+ prev_score = score
832
+
833
+ return lambda_time, lambda_unit, lambda_nn
834
+
835
+ # =========================================================================
836
+ # Joint estimation method
837
+ # =========================================================================
838
+
839
+ def _compute_joint_weights(
840
+ self,
841
+ Y: np.ndarray,
842
+ D: np.ndarray,
843
+ lambda_time: float,
844
+ lambda_unit: float,
845
+ treated_periods: int,
846
+ n_units: int,
847
+ n_periods: int,
848
+ ) -> np.ndarray:
849
+ """
850
+ Compute distance-based weights for joint estimation.
851
+
852
+ Following the reference implementation, weights are computed based on:
853
+ - Time distance: distance to center of treated block
854
+ - Unit distance: RMSE to average treated trajectory over pre-periods
855
+
856
+ Parameters
857
+ ----------
858
+ Y : np.ndarray
859
+ Outcome matrix (n_periods x n_units).
860
+ D : np.ndarray
861
+ Treatment indicator matrix (n_periods x n_units).
862
+ lambda_time : float
863
+ Time weight decay parameter.
864
+ lambda_unit : float
865
+ Unit weight decay parameter.
866
+ treated_periods : int
867
+ Number of post-treatment periods.
868
+ n_units : int
869
+ Number of units.
870
+ n_periods : int
871
+ Number of periods.
872
+
873
+ Returns
874
+ -------
875
+ np.ndarray
876
+ Weight matrix (n_periods x n_units).
877
+ """
878
+ # Identify treated units (ever treated)
879
+ treated_mask = np.any(D == 1, axis=0)
880
+ treated_unit_idx = np.where(treated_mask)[0]
881
+
882
+ if len(treated_unit_idx) == 0:
883
+ raise ValueError("No treated units found")
884
+
885
+ # Time weights: distance to center of treated block
886
+ # Following reference: center = T - treated_periods/2
887
+ center = n_periods - treated_periods / 2.0
888
+ dist_time = np.abs(np.arange(n_periods, dtype=float) - center)
889
+ delta_time = np.exp(-lambda_time * dist_time)
890
+
891
+ # Unit weights: RMSE to average treated trajectory over pre-periods
892
+ # Compute average treated trajectory (use nanmean to handle NaN)
893
+ average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1)
894
+
895
+ # Pre-period mask: 1 in pre, 0 in post
896
+ pre_mask = np.ones(n_periods, dtype=float)
897
+ pre_mask[-treated_periods:] = 0.0
898
+
899
+ # Compute RMS distance for each unit
900
+ # dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre)
901
+ # Use NaN-safe operations: treat NaN differences as 0 (excluded)
902
+ diff = average_treated[:, np.newaxis] - Y
903
+ diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis]
904
+
905
+ # Count valid observations per unit in pre-period
906
+ # Must check diff is finite (both Y and average_treated finite)
907
+ # to match the periods contributing to diff_sq
908
+ valid_count = np.sum(
909
+ np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0
910
+ )
911
+ sum_sq = np.sum(diff_sq, axis=0)
912
+ n_pre = np.sum(pre_mask)
913
+
914
+ if n_pre == 0:
915
+ raise ValueError("No pre-treatment periods")
916
+
917
+ # Track units with no valid pre-period data
918
+ no_valid_pre = valid_count == 0
919
+
920
+ # Use valid count per unit (avoid division by zero for calculation)
921
+ valid_count_safe = np.maximum(valid_count, 1)
922
+ dist_unit = np.sqrt(sum_sq / valid_count_safe)
923
+
924
+ # Units with no valid pre-period data get zero weight
925
+ # (dist is undefined, so we set it to inf -> delta_unit = exp(-inf) = 0)
926
+ delta_unit = np.exp(-lambda_unit * dist_unit)
927
+ delta_unit[no_valid_pre] = 0.0
928
+
929
+ # Outer product: (n_periods x n_units)
930
+ delta = np.outer(delta_time, delta_unit)
931
+
932
+ return delta
933
+
934
+ def _loocv_score_joint(
935
+ self,
936
+ Y: np.ndarray,
937
+ D: np.ndarray,
938
+ control_obs: List[Tuple[int, int]],
939
+ lambda_time: float,
940
+ lambda_unit: float,
941
+ lambda_nn: float,
942
+ treated_periods: int,
943
+ n_units: int,
944
+ n_periods: int,
945
+ ) -> float:
946
+ """
947
+ Compute LOOCV score for joint method with specific parameter combination.
948
+
949
+ Following paper's Equation 5:
950
+ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
951
+
952
+ For joint method, we exclude each control observation, fit the joint model
953
+ on remaining data, and compute the pseudo-treatment effect at the excluded obs.
954
+
955
+ Parameters
956
+ ----------
957
+ Y : np.ndarray
958
+ Outcome matrix (n_periods x n_units).
959
+ D : np.ndarray
960
+ Treatment indicator matrix (n_periods x n_units).
961
+ control_obs : List[Tuple[int, int]]
962
+ List of (t, i) control observations for LOOCV.
963
+ lambda_time : float
964
+ Time weight decay parameter.
965
+ lambda_unit : float
966
+ Unit weight decay parameter.
967
+ lambda_nn : float
968
+ Nuclear norm regularization parameter.
969
+ treated_periods : int
970
+ Number of post-treatment periods.
971
+ n_units : int
972
+ Number of units.
973
+ n_periods : int
974
+ Number of periods.
975
+
976
+ Returns
977
+ -------
978
+ float
979
+ LOOCV score (sum of squared pseudo-treatment effects).
980
+ """
981
+ # Compute global weights (same for all LOOCV iterations)
982
+ delta = self._compute_joint_weights(
983
+ Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
984
+ )
985
+
986
+ tau_sq_sum = 0.0
987
+ n_valid = 0
988
+
989
+ for t_ex, i_ex in control_obs:
990
+ # Create modified delta with excluded observation zeroed out
991
+ delta_ex = delta.copy()
992
+ delta_ex[t_ex, i_ex] = 0.0
993
+
994
+ try:
995
+ # Fit joint model excluding this observation
996
+ if lambda_nn >= 1e10:
997
+ mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta_ex)
998
+ L = np.zeros((n_periods, n_units))
999
+ else:
1000
+ mu, alpha, beta, L, tau = self._solve_joint_with_lowrank(
1001
+ Y, D, delta_ex, lambda_nn, self.max_iter, self.tol
1002
+ )
1003
+
1004
+ # Pseudo treatment effect: τ = Y - μ - α - β - L
1005
+ if np.isfinite(Y[t_ex, i_ex]):
1006
+ tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex]
1007
+ tau_sq_sum += tau_loocv ** 2
1008
+ n_valid += 1
1009
+
1010
+ except (np.linalg.LinAlgError, ValueError):
1011
+ # Any failure means this λ combination is invalid per Equation 5
1012
+ return np.inf
1013
+
1014
+ if n_valid == 0:
1015
+ return np.inf
1016
+
1017
+ return tau_sq_sum
1018
+
1019
+ def _solve_joint_no_lowrank(
1020
+ self,
1021
+ Y: np.ndarray,
1022
+ D: np.ndarray,
1023
+ delta: np.ndarray,
1024
+ ) -> Tuple[float, np.ndarray, np.ndarray, float]:
1025
+ """
1026
+ Solve joint TWFE + treatment via weighted least squares (no low-rank).
1027
+
1028
+ Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})²
1029
+
1030
+ Parameters
1031
+ ----------
1032
+ Y : np.ndarray
1033
+ Outcome matrix (n_periods x n_units).
1034
+ D : np.ndarray
1035
+ Treatment indicator matrix (n_periods x n_units).
1036
+ delta : np.ndarray
1037
+ Weight matrix (n_periods x n_units).
1038
+
1039
+ Returns
1040
+ -------
1041
+ Tuple[float, np.ndarray, np.ndarray, float]
1042
+ (mu, alpha, beta, tau) estimated parameters.
1043
+ """
1044
+ n_periods, n_units = Y.shape
1045
+
1046
+ # Flatten matrices for regression
1047
+ y = Y.flatten() # length n_periods * n_units
1048
+ w = D.flatten()
1049
+ weights = delta.flatten()
1050
+
1051
+ # Handle NaN values: zero weight for NaN outcomes/weights, impute with 0
1052
+ # This ensures NaN observations don't contribute to estimation
1053
+ valid_y = np.isfinite(y)
1054
+ valid_w = np.isfinite(weights)
1055
+ valid_mask = valid_y & valid_w
1056
+ weights = np.where(valid_mask, weights, 0.0)
1057
+ y = np.where(valid_mask, y, 0.0)
1058
+
1059
+ sqrt_weights = np.sqrt(np.maximum(weights, 0))
1060
+
1061
+ # Check for all-zero weights (matches Rust's sum_w < 1e-10 check)
1062
+ sum_w = np.sum(weights)
1063
+ if sum_w < 1e-10:
1064
+ raise ValueError("All weights are zero - cannot estimate")
1065
+
1066
+ # Build design matrix: [intercept, unit_dummies, time_dummies, treatment]
1067
+ # Total columns: 1 + n_units + n_periods + 1
1068
+ # But we need to drop one unit and one time dummy for identification
1069
+ # Drop first unit (unit 0) and first time (time 0)
1070
+ n_obs = n_periods * n_units
1071
+ n_params = 1 + (n_units - 1) + (n_periods - 1) + 1
1072
+
1073
+ X = np.zeros((n_obs, n_params))
1074
+ X[:, 0] = 1.0 # intercept
1075
+
1076
+ # Unit dummies (skip unit 0)
1077
+ for i in range(1, n_units):
1078
+ for t in range(n_periods):
1079
+ X[t * n_units + i, i] = 1.0
1080
+
1081
+ # Time dummies (skip time 0)
1082
+ for t in range(1, n_periods):
1083
+ for i in range(n_units):
1084
+ X[t * n_units + i, (n_units - 1) + t] = 1.0
1085
+
1086
+ # Treatment indicator
1087
+ X[:, -1] = w
1088
+
1089
+ # Apply weights
1090
+ X_weighted = X * sqrt_weights[:, np.newaxis]
1091
+ y_weighted = y * sqrt_weights
1092
+
1093
+ # Solve weighted least squares
1094
+ try:
1095
+ coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None)
1096
+ except np.linalg.LinAlgError:
1097
+ # Fallback: use pseudo-inverse
1098
+ coeffs = np.linalg.pinv(X_weighted) @ y_weighted
1099
+
1100
+ # Extract parameters
1101
+ mu = coeffs[0]
1102
+ alpha = np.zeros(n_units)
1103
+ alpha[1:] = coeffs[1:n_units]
1104
+ beta = np.zeros(n_periods)
1105
+ beta[1:] = coeffs[n_units:(n_units + n_periods - 1)]
1106
+ tau = coeffs[-1]
1107
+
1108
+ return float(mu), alpha, beta, float(tau)
1109
+
1110
    def _solve_joint_with_lowrank(
        self,
        Y: np.ndarray,
        D: np.ndarray,
        delta: np.ndarray,
        lambda_nn: float,
        max_iter: int = 100,
        tol: float = 1e-6,
    ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]:
        """
        Solve joint TWFE + treatment + low-rank via alternating minimization.

        Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_*

        Alternates between (1) a weighted least-squares step for the fixed
        effects and treatment effect given L, and (2) a proximal step
        (soft-thresholded SVD) for L given the other parameters.

        Parameters
        ----------
        Y : np.ndarray
            Outcome matrix (n_periods x n_units).
        D : np.ndarray
            Treatment indicator matrix (n_periods x n_units).
        delta : np.ndarray
            Weight matrix (n_periods x n_units).
        lambda_nn : float
            Nuclear norm regularization parameter.
        max_iter : int, default=100
            Maximum iterations for alternating minimization. Must be >= 1:
            with max_iter <= 0 the loop body never runs and the return
            statement would raise NameError on mu/alpha/beta/tau.
        tol : float, default=1e-6
            Convergence tolerance (max absolute change in L between iterations).

        Returns
        -------
        Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]
            (mu, alpha, beta, L, tau) estimated parameters.
        """
        n_periods, n_units = Y.shape

        # Handle NaN values: impute with 0 for computations
        # The solver will also zero weights for NaN observations
        Y_safe = np.where(np.isfinite(Y), Y, 0.0)

        # Mask delta to exclude NaN outcomes from estimation
        # This ensures NaN observations don't contribute to the gradient step
        nan_mask = ~np.isfinite(Y)
        delta_masked = delta.copy()
        delta_masked[nan_mask] = 0.0

        # Initialize L = 0
        L = np.zeros((n_periods, n_units))

        for iteration in range(max_iter):
            L_old = L.copy()

            # Step 1: Fix L, solve for (mu, alpha, beta, tau)
            # Adjusted outcome: Y - L (using NaN-safe Y)
            # Pass masked delta to exclude NaN observations from WLS
            Y_adj = Y_safe - L
            mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta_masked)

            # Step 2: Fix (mu, alpha, beta, tau), update L
            # Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y)
            R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D

            # Weighted proximal step for L (soft-threshold SVD)
            # Normalize weights (using masked delta to exclude NaN observations)
            # so the implicit step size in the gradient step is at most 1.
            delta_max = np.max(delta_masked)
            if delta_max > 0:
                delta_norm = delta_masked / delta_max
            else:
                delta_norm = delta_masked

            # Weighted average between current L and target R
            # L_next = L + delta_norm * (R - L), then soft-threshold
            # NaN observations have delta_norm=0, so they don't influence L update
            gradient_step = L + delta_norm * (R - L)

            # Soft-threshold singular values
            # Use eta * lambda_nn for proper proximal step size (matches Rust)
            eta = 1.0 / delta_max if delta_max > 0 else 1.0
            L = self._soft_threshold_svd(gradient_step, eta * lambda_nn)

            # Check convergence: stop once L has effectively stabilized
            if np.max(np.abs(L - L_old)) < tol:
                break

        return mu, alpha, beta, L, tau
1195
+
1196
    def _fit_joint(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        unit: str,
        time: str,
    ) -> TROPResults:
        """
        Fit TROP using joint weighted least squares method.

        This method estimates a single scalar treatment effect τ along with
        fixed effects and optional low-rank factor adjustment.

        Pipeline: pivot long data to matrices → validate the treatment
        indicator (absorbing state, simultaneous adoption) → select tuning
        parameters by LOOCV (Rust backend when available, Python fallback) →
        refit at the selected parameters → bootstrap the standard error.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data.
        outcome : str
            Outcome variable column name.
        treatment : str
            Treatment indicator column name.
        unit : str
            Unit identifier column name.
        time : str
            Time period column name.

        Returns
        -------
        TROPResults
            Estimation results.

        Raises
        ------
        ValueError
            If the treatment indicator is not an absorbing state, there are no
            treated observations, no control units, fewer than 2 pre-treatment
            periods, or adoption is staggered (joint method requires
            simultaneous adoption).

        Notes
        -----
        Bootstrap variance estimation assumes simultaneous treatment adoption
        (fixed `treated_periods` across resamples). The treatment timing is
        inferred from the data once and held constant for all bootstrap
        iterations. For staggered adoption designs where treatment timing varies
        across units, use `method="twostep"` which computes observation-specific
        weights that naturally handle heterogeneous timing.
        """
        # Data setup (same as twostep method)
        all_units = sorted(data[unit].unique())
        all_periods = sorted(data[time].unique())

        n_units = len(all_units)
        n_periods = len(all_periods)

        idx_to_unit = {i: u for i, u in enumerate(all_units)}
        idx_to_period = {i: p for i, p in enumerate(all_periods)}

        # Create matrices: Y is (n_periods x n_units); missing cells become NaN
        Y = (
            data.pivot(index=time, columns=unit, values=outcome)
            .reindex(index=all_periods, columns=all_units)
            .values
        )

        # Track missing D cells BEFORE fillna so unbalanced panels don't
        # trigger spurious absorbing-state violations
        D_raw = (
            data.pivot(index=time, columns=unit, values=treatment)
            .reindex(index=all_periods, columns=all_units)
        )
        missing_mask = pd.isna(D_raw).values
        D = D_raw.fillna(0).astype(int).values

        # Validate absorbing state: each unit's OBSERVED D sequence must be
        # monotonic non-decreasing (once treated, always treated)
        violating_units = []
        for unit_idx in range(n_units):
            observed_mask = ~missing_mask[:, unit_idx]
            observed_d = D[observed_mask, unit_idx]
            if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
                violating_units.append(all_units[unit_idx])

        if violating_units:
            raise ValueError(
                f"Treatment indicator is not an absorbing state for units: {violating_units}. "
                f"D[t, unit] must be monotonic non-decreasing."
            )

        # Identify treated observations
        treated_mask = D == 1
        n_treated_obs = np.sum(treated_mask)

        if n_treated_obs == 0:
            raise ValueError("No treated observations found")

        # Identify treated and control units
        unit_ever_treated = np.any(D == 1, axis=0)
        treated_unit_idx = np.where(unit_ever_treated)[0]
        control_unit_idx = np.where(~unit_ever_treated)[0]

        if len(control_unit_idx) == 0:
            raise ValueError("No control units found")

        # Determine pre/post periods: first period in which any unit is treated
        first_treat_period = None
        for t in range(n_periods):
            if np.any(D[t, :] == 1):
                first_treat_period = t
                break

        if first_treat_period is None:
            raise ValueError("Could not infer post-treatment periods from D matrix")

        n_pre_periods = first_treat_period
        treated_periods = n_periods - first_treat_period
        # Count post-periods that actually contain a D=1 observation
        n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))

        if n_pre_periods < 2:
            raise ValueError("Need at least 2 pre-treatment periods")

        # Check for staggered adoption (joint method requires simultaneous treatment)
        # Use only observed periods (skip missing) to avoid false positives on unbalanced panels
        first_treat_by_unit = []
        for i in treated_unit_idx:
            observed_mask = ~missing_mask[:, i]
            # Get D values for observed periods only
            observed_d = D[observed_mask, i]
            observed_periods = np.where(observed_mask)[0]
            # Find first treatment among observed periods
            treated_idx = np.where(observed_d == 1)[0]
            if len(treated_idx) > 0:
                first_treat_by_unit.append(observed_periods[treated_idx[0]])

        unique_starts = sorted(set(first_treat_by_unit))
        if len(unique_starts) > 1:
            raise ValueError(
                f"method='joint' requires simultaneous treatment adoption, but your data "
                f"shows staggered adoption (units first treated at periods {unique_starts}). "
                f"Use method='twostep' which properly handles staggered adoption designs."
            )

        # LOOCV grid search for tuning parameters
        # Use Rust backend when available for parallel LOOCV (5-10x speedup)
        best_lambda = None
        best_score = np.inf
        control_mask = D == 0

        if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None:
            try:
                # Prepare inputs for Rust function
                control_mask_u8 = control_mask.astype(np.uint8)

                lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
                lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
                lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)

                result = _rust_loocv_grid_search_joint(
                    Y, D.astype(np.float64), control_mask_u8,
                    lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
                    self.max_iter, self.tol,
                )
                # Unpack result - 7 values including optional first_failed_obs
                best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
                # Only accept finite scores - infinite means all fits failed
                if np.isfinite(best_score):
                    best_lambda = (best_lt, best_lu, best_ln)
                # Emit warnings consistent with Python implementation
                if n_valid == 0:
                    obs_info = ""
                    if first_failed_obs is not None:
                        t_idx, i_idx = first_failed_obs
                        obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                    warnings.warn(
                        f"LOOCV: All {n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"Returning infinite score.{obs_info}",
                        UserWarning
                    )
                elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
                    # More than 10% of fits failed: likely numerical instability
                    n_failed = n_attempted - n_valid
                    obs_info = ""
                    if first_failed_obs is not None:
                        t_idx, i_idx = first_failed_obs
                        obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                    warnings.warn(
                        f"LOOCV: {n_failed}/{n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"This may indicate numerical instability.{obs_info}",
                        UserWarning
                    )
            except Exception as e:
                # Fall back to Python implementation on error
                logger.debug(
                    "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e
                )
                best_lambda = None
                best_score = np.inf

        # Fall back to Python implementation if Rust unavailable or failed
        if best_lambda is None:
            # Get control observations for LOOCV
            control_obs = [
                (t, i) for t in range(n_periods) for i in range(n_units)
                if control_mask[t, i] and not np.isnan(Y[t, i])
            ]

            # Grid search with true LOOCV
            for lambda_time_val in self.lambda_time_grid:
                for lambda_unit_val in self.lambda_unit_grid:
                    for lambda_nn_val in self.lambda_nn_grid:
                        # Convert λ_nn=∞ → large finite value (factor model disabled)
                        lt = lambda_time_val
                        lu = lambda_unit_val
                        ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val

                        try:
                            score = self._loocv_score_joint(
                                Y, D, control_obs, lt, lu, ln,
                                treated_periods, n_units, n_periods
                            )

                            if score < best_score:
                                best_score = score
                                # Store the ORIGINAL grid value (possibly inf);
                                # conversion to the sentinel happens below
                                best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val)

                        except (np.linalg.LinAlgError, ValueError):
                            continue

        if best_lambda is None:
            warnings.warn(
                "All tuning parameter combinations failed. Using defaults.",
                UserWarning
            )
            best_lambda = (1.0, 1.0, 0.1)
            best_score = np.nan

        # Final estimation with best parameters
        lambda_time, lambda_unit, lambda_nn = best_lambda
        original_lambda_nn = lambda_nn

        # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
        # λ_time and λ_unit use 0.0 for uniform weights directly (no conversion needed)
        if np.isinf(lambda_nn):
            lambda_nn = 1e10

        # Compute final weights and fit
        delta = self._compute_joint_weights(
            Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
        )

        if lambda_nn >= 1e10:
            mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta)
            L = np.zeros((n_periods, n_units))
        else:
            mu, alpha, beta, L, tau = self._solve_joint_with_lowrank(
                Y, D, delta, lambda_nn, self.max_iter, self.tol
            )

        # ATT is the scalar treatment effect
        att = tau

        # Compute individual treatment effects for reporting (same τ for all)
        treatment_effects = {}
        for t in range(n_periods):
            for i in range(n_units):
                if D[t, i] == 1:
                    unit_id = idx_to_unit[i]
                    time_id = idx_to_period[t]
                    treatment_effects[(unit_id, time_id)] = tau

        # Compute effective rank of L (nuclear norm / spectral norm)
        _, s, _ = np.linalg.svd(L, full_matrices=False)
        if s[0] > 0:
            effective_rank = np.sum(s) / s[0]
        else:
            effective_rank = 0.0

        # Bootstrap variance estimation at the (sentinel-converted) λ values
        effective_lambda = (lambda_time, lambda_unit, lambda_nn)

        se, bootstrap_dist = self._bootstrap_variance_joint(
            data, outcome, treatment, unit, time,
            effective_lambda, treated_periods
        )

        # Compute test statistics (t-distribution with n_treated_obs - 1 df)
        if se > 0:
            t_stat = att / se
            p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
            conf_int = compute_confidence_interval(att, se, self.alpha)
        else:
            t_stat = np.nan
            p_value = np.nan
            conf_int = (np.nan, np.nan)

        # Create results dictionaries keyed by original unit/period labels
        unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)}
        time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)}

        self.results_ = TROPResults(
            att=float(att),
            se=float(se),
            t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat,
            p_value=float(p_value) if np.isfinite(p_value) else p_value,
            conf_int=conf_int,
            n_obs=len(data),
            n_treated=len(treated_unit_idx),
            n_control=len(control_unit_idx),
            n_treated_obs=int(n_treated_obs),
            unit_effects=unit_effects_dict,
            time_effects=time_effects_dict,
            treatment_effects=treatment_effects,
            lambda_time=lambda_time,
            lambda_unit=lambda_unit,
            # Report the grid value as selected (inf, not the 1e10 sentinel)
            lambda_nn=original_lambda_nn,
            factor_matrix=L,
            effective_rank=effective_rank,
            loocv_score=best_score,
            alpha=self.alpha,
            n_pre_periods=n_pre_periods,
            n_post_periods=n_post_periods,
            n_bootstrap=self.n_bootstrap,
            bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
        )

        self.is_fitted_ = True
        return self.results_
1514
+
1515
+ def _bootstrap_variance_joint(
1516
+ self,
1517
+ data: pd.DataFrame,
1518
+ outcome: str,
1519
+ treatment: str,
1520
+ unit: str,
1521
+ time: str,
1522
+ optimal_lambda: Tuple[float, float, float],
1523
+ treated_periods: int,
1524
+ ) -> Tuple[float, np.ndarray]:
1525
+ """
1526
+ Compute bootstrap standard error for joint method.
1527
+
1528
+ Uses Rust backend when available for parallel bootstrap (5-15x speedup).
1529
+
1530
+ Parameters
1531
+ ----------
1532
+ data : pd.DataFrame
1533
+ Original data.
1534
+ outcome : str
1535
+ Outcome column name.
1536
+ treatment : str
1537
+ Treatment column name.
1538
+ unit : str
1539
+ Unit column name.
1540
+ time : str
1541
+ Time column name.
1542
+ optimal_lambda : tuple
1543
+ Optimal tuning parameters.
1544
+ treated_periods : int
1545
+ Number of post-treatment periods.
1546
+
1547
+ Returns
1548
+ -------
1549
+ Tuple[float, np.ndarray]
1550
+ (se, bootstrap_estimates).
1551
+ """
1552
+ lambda_time, lambda_unit, lambda_nn = optimal_lambda
1553
+
1554
+ # Try Rust backend for parallel bootstrap (5-15x speedup)
1555
+ if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None:
1556
+ try:
1557
+ # Create matrices for Rust function
1558
+ all_units = sorted(data[unit].unique())
1559
+ all_periods = sorted(data[time].unique())
1560
+
1561
+ Y = (
1562
+ data.pivot(index=time, columns=unit, values=outcome)
1563
+ .reindex(index=all_periods, columns=all_units)
1564
+ .values
1565
+ )
1566
+ D = (
1567
+ data.pivot(index=time, columns=unit, values=treatment)
1568
+ .reindex(index=all_periods, columns=all_units)
1569
+ .fillna(0)
1570
+ .astype(np.float64)
1571
+ .values
1572
+ )
1573
+
1574
+ bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint(
1575
+ Y, D,
1576
+ lambda_time, lambda_unit, lambda_nn,
1577
+ self.n_bootstrap, self.max_iter, self.tol,
1578
+ self.seed if self.seed is not None else 0
1579
+ )
1580
+
1581
+ if len(bootstrap_estimates) < 10:
1582
+ warnings.warn(
1583
+ f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
1584
+ UserWarning
1585
+ )
1586
+ if len(bootstrap_estimates) == 0:
1587
+ return 0.0, np.array([])
1588
+
1589
+ return float(se), np.array(bootstrap_estimates)
1590
+
1591
+ except Exception as e:
1592
+ logger.debug(
1593
+ "Rust bootstrap (joint) failed, falling back to Python: %s", e
1594
+ )
1595
+
1596
+ # Python fallback implementation
1597
+ rng = np.random.default_rng(self.seed)
1598
+
1599
+ # Stratified bootstrap sampling
1600
+ unit_ever_treated = data.groupby(unit)[treatment].max()
1601
+ treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist())
1602
+ control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist())
1603
+
1604
+ n_treated_units = len(treated_units)
1605
+ n_control_units = len(control_units)
1606
+
1607
+ bootstrap_estimates_list: List[float] = []
1608
+
1609
+ for _ in range(self.n_bootstrap):
1610
+ # Stratified sampling
1611
+ if n_control_units > 0:
1612
+ sampled_control = rng.choice(
1613
+ control_units, size=n_control_units, replace=True
1614
+ )
1615
+ else:
1616
+ sampled_control = np.array([], dtype=object)
1617
+
1618
+ if n_treated_units > 0:
1619
+ sampled_treated = rng.choice(
1620
+ treated_units, size=n_treated_units, replace=True
1621
+ )
1622
+ else:
1623
+ sampled_treated = np.array([], dtype=object)
1624
+
1625
+ sampled_units = np.concatenate([sampled_control, sampled_treated])
1626
+
1627
+ # Create bootstrap sample
1628
+ boot_data = pd.concat([
1629
+ data[data[unit] == u].assign(**{unit: f"{u}_{idx}"})
1630
+ for idx, u in enumerate(sampled_units)
1631
+ ], ignore_index=True)
1632
+
1633
+ try:
1634
+ tau = self._fit_joint_with_fixed_lambda(
1635
+ boot_data, outcome, treatment, unit, time,
1636
+ optimal_lambda, treated_periods
1637
+ )
1638
+ bootstrap_estimates_list.append(tau)
1639
+ except (ValueError, np.linalg.LinAlgError, KeyError):
1640
+ continue
1641
+
1642
+ bootstrap_estimates = np.array(bootstrap_estimates_list)
1643
+
1644
+ if len(bootstrap_estimates) < 10:
1645
+ warnings.warn(
1646
+ f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
1647
+ UserWarning
1648
+ )
1649
+ if len(bootstrap_estimates) == 0:
1650
+ return 0.0, np.array([])
1651
+
1652
+ se = np.std(bootstrap_estimates, ddof=1)
1653
+ return float(se), bootstrap_estimates
1654
+
1655
+ def _fit_joint_with_fixed_lambda(
1656
+ self,
1657
+ data: pd.DataFrame,
1658
+ outcome: str,
1659
+ treatment: str,
1660
+ unit: str,
1661
+ time: str,
1662
+ fixed_lambda: Tuple[float, float, float],
1663
+ treated_periods: int,
1664
+ ) -> float:
1665
+ """
1666
+ Fit joint model with fixed tuning parameters.
1667
+
1668
+ Returns only the treatment effect τ.
1669
+ """
1670
+ lambda_time, lambda_unit, lambda_nn = fixed_lambda
1671
+
1672
+ all_units = sorted(data[unit].unique())
1673
+ all_periods = sorted(data[time].unique())
1674
+
1675
+ n_units = len(all_units)
1676
+ n_periods = len(all_periods)
1677
+
1678
+ Y = (
1679
+ data.pivot(index=time, columns=unit, values=outcome)
1680
+ .reindex(index=all_periods, columns=all_units)
1681
+ .values
1682
+ )
1683
+ D = (
1684
+ data.pivot(index=time, columns=unit, values=treatment)
1685
+ .reindex(index=all_periods, columns=all_units)
1686
+ .fillna(0)
1687
+ .astype(int)
1688
+ .values
1689
+ )
1690
+
1691
+ # Compute weights
1692
+ delta = self._compute_joint_weights(
1693
+ Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
1694
+ )
1695
+
1696
+ # Fit model
1697
+ if lambda_nn >= 1e10:
1698
+ _, _, _, tau = self._solve_joint_no_lowrank(Y, D, delta)
1699
+ else:
1700
+ _, _, _, _, tau = self._solve_joint_with_lowrank(
1701
+ Y, D, delta, lambda_nn, self.max_iter, self.tol
1702
+ )
1703
+
1704
+ return tau
1705
+
1706
+ def fit(
1707
+ self,
1708
+ data: pd.DataFrame,
1709
+ outcome: str,
1710
+ treatment: str,
1711
+ unit: str,
1712
+ time: str,
1713
+ ) -> TROPResults:
1714
+ """
1715
+ Fit the TROP model.
1716
+
1717
+ Parameters
1718
+ ----------
1719
+ data : pd.DataFrame
1720
+ Panel data with observations for multiple units over multiple
1721
+ time periods.
1722
+ outcome : str
1723
+ Name of the outcome variable column.
1724
+ treatment : str
1725
+ Name of the treatment indicator column (0/1).
1726
+
1727
+ IMPORTANT: This should be an ABSORBING STATE indicator, not a
1728
+ treatment timing indicator. For each unit, D=1 for ALL periods
1729
+ during and after treatment:
1730
+
1731
+ - D[t, i] = 0 for all t < g_i (pre-treatment periods)
1732
+ - D[t, i] = 1 for all t >= g_i (treatment and post-treatment)
1733
+
1734
+ where g_i is the treatment start time for unit i.
1735
+
1736
+ For staggered adoption, different units can have different g_i.
1737
+ The ATT averages over ALL D=1 cells per Equation 1 of the paper.
1738
+ unit : str
1739
+ Name of the unit identifier column.
1740
+ time : str
1741
+ Name of the time period column.
1742
+
1743
+ Returns
1744
+ -------
1745
+ TROPResults
1746
+ Object containing the ATT estimate, standard error,
1747
+ factor estimates, and tuning parameters. The lambda_*
1748
+ attributes show the selected grid values. For λ_time and
1749
+ λ_unit, 0.0 means uniform weights; inf is not accepted.
1750
+ For λ_nn, ∞ is converted to 1e10 (factor model disabled).
1751
+ """
1752
+ # Validate inputs
1753
+ required_cols = [outcome, treatment, unit, time]
1754
+ missing = [c for c in required_cols if c not in data.columns]
1755
+ if missing:
1756
+ raise ValueError(f"Missing columns: {missing}")
1757
+
1758
+ # Dispatch based on estimation method
1759
+ if self.method == "joint":
1760
+ return self._fit_joint(data, outcome, treatment, unit, time)
1761
+
1762
+ # Below is the twostep method (default)
1763
+ # Get unique units and periods
1764
+ all_units = sorted(data[unit].unique())
1765
+ all_periods = sorted(data[time].unique())
1766
+
1767
+ n_units = len(all_units)
1768
+ n_periods = len(all_periods)
1769
+
1770
+ # Create mappings
1771
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
1772
+ period_to_idx = {p: i for i, p in enumerate(all_periods)}
1773
+ idx_to_unit = {i: u for u, i in unit_to_idx.items()}
1774
+ idx_to_period = {i: p for p, i in period_to_idx.items()}
1775
+
1776
+ # Create outcome matrix Y (n_periods x n_units) and treatment matrix D
1777
+ # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
1778
+ Y = (
1779
+ data.pivot(index=time, columns=unit, values=outcome)
1780
+ .reindex(index=all_periods, columns=all_units)
1781
+ .values
1782
+ )
1783
+
1784
+ # For D matrix, track missing values BEFORE fillna to support unbalanced panels
1785
+ # Issue 3 fix: Missing observations should not trigger spurious violations
1786
+ D_raw = (
1787
+ data.pivot(index=time, columns=unit, values=treatment)
1788
+ .reindex(index=all_periods, columns=all_units)
1789
+ )
1790
+ missing_mask = pd.isna(D_raw).values # True where originally missing
1791
+ D = D_raw.fillna(0).astype(int).values
1792
+
1793
+ # Validate D is monotonic non-decreasing per unit (absorbing state)
1794
+ # D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
1795
+ # Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
1796
+ # This catches 1→0 violations that span missing period gaps
1797
+ # Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
1798
+ # adjacent period transitions don't show it (the gap hides the transition)
1799
+ violating_units = []
1800
+ for unit_idx in range(n_units):
1801
+ # Get observed D values for this unit (where not missing)
1802
+ observed_mask = ~missing_mask[:, unit_idx]
1803
+ observed_d = D[observed_mask, unit_idx]
1804
+
1805
+ # Check if observed sequence is monotonically non-decreasing
1806
+ if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
1807
+ violating_units.append(all_units[unit_idx])
1808
+
1809
+ if violating_units:
1810
+ raise ValueError(
1811
+ f"Treatment indicator is not an absorbing state for units: {violating_units}. "
1812
+ f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
1813
+ f"If this is event-study style data, convert to absorbing state: "
1814
+ f"D[t, i] = 1 for all t >= first treatment period."
1815
+ )
1816
+
1817
+ # Identify treated observations
1818
+ treated_mask = D == 1
1819
+ n_treated_obs = np.sum(treated_mask)
1820
+
1821
+ if n_treated_obs == 0:
1822
+ raise ValueError("No treated observations found")
1823
+
1824
+ # Identify treated and control units
1825
+ unit_ever_treated = np.any(D == 1, axis=0)
1826
+ treated_unit_idx = np.where(unit_ever_treated)[0]
1827
+ control_unit_idx = np.where(~unit_ever_treated)[0]
1828
+
1829
+ if len(control_unit_idx) == 0:
1830
+ raise ValueError("No control units found")
1831
+
1832
+ # Determine pre/post periods from treatment indicator D
1833
+ # D matrix is the sole input for treatment timing per the paper
1834
+ first_treat_period = None
1835
+ for t in range(n_periods):
1836
+ if np.any(D[t, :] == 1):
1837
+ first_treat_period = t
1838
+ break
1839
+ if first_treat_period is None:
1840
+ raise ValueError("Could not infer post-treatment periods from D matrix")
1841
+
1842
+ n_pre_periods = first_treat_period
1843
+ # Count periods where D=1 is actually observed (matches docstring)
1844
+ # Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
1845
+ n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
1846
+
1847
+ if n_pre_periods < 2:
1848
+ raise ValueError("Need at least 2 pre-treatment periods")
1849
+
1850
+ # Step 1: Grid search with LOOCV for tuning parameters
1851
+ best_lambda = None
1852
+ best_score = np.inf
1853
+
1854
+ # Control observations mask (for LOOCV)
1855
+ control_mask = D == 0
1856
+
1857
+ # Pre-compute structures that are reused across LOOCV iterations
1858
+ self._precomputed = self._precompute_structures(
1859
+ Y, D, control_unit_idx, n_units, n_periods
1860
+ )
1861
+
1862
+ # Use Rust backend for parallel LOOCV grid search (10-50x speedup)
1863
+ if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None:
1864
+ try:
1865
+ # Prepare inputs for Rust function
1866
+ control_mask_u8 = control_mask.astype(np.uint8)
1867
+ time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
1868
+
1869
+ lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
1870
+ lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
1871
+ lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
1872
+
1873
+ result = _rust_loocv_grid_search(
1874
+ Y, D.astype(np.float64), control_mask_u8,
1875
+ time_dist_matrix,
1876
+ lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
1877
+ self.max_iter, self.tol,
1878
+ )
1879
+ # Unpack result - 7 values including optional first_failed_obs
1880
+ best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
1881
+ # Only accept finite scores - infinite means all fits failed
1882
+ if np.isfinite(best_score):
1883
+ best_lambda = (best_lt, best_lu, best_ln)
1884
+ # else: best_lambda stays None, triggering defaults fallback
1885
+ # Emit warnings consistent with Python implementation
1886
+ if n_valid == 0:
1887
+ # Include failed observation coordinates if available (Issue 2 fix)
1888
+ obs_info = ""
1889
+ if first_failed_obs is not None:
1890
+ t_idx, i_idx = first_failed_obs
1891
+ obs_info = f" First failure at observation ({t_idx}, {i_idx})."
1892
+ warnings.warn(
1893
+ f"LOOCV: All {n_attempted} fits failed for "
1894
+ f"λ=({best_lt}, {best_lu}, {best_ln}). "
1895
+ f"Returning infinite score.{obs_info}",
1896
+ UserWarning
1897
+ )
1898
+ elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
1899
+ n_failed = n_attempted - n_valid
1900
+ # Include failed observation coordinates if available
1901
+ obs_info = ""
1902
+ if first_failed_obs is not None:
1903
+ t_idx, i_idx = first_failed_obs
1904
+ obs_info = f" First failure at observation ({t_idx}, {i_idx})."
1905
+ warnings.warn(
1906
+ f"LOOCV: {n_failed}/{n_attempted} fits failed for "
1907
+ f"λ=({best_lt}, {best_lu}, {best_ln}). "
1908
+ f"This may indicate numerical instability.{obs_info}",
1909
+ UserWarning
1910
+ )
1911
+ except Exception as e:
1912
+ # Fall back to Python implementation on error
1913
+ logger.debug(
1914
+ "Rust LOOCV grid search failed, falling back to Python: %s", e
1915
+ )
1916
+ best_lambda = None
1917
+ best_score = np.inf
1918
+
1919
+ # Fall back to Python implementation if Rust unavailable or failed
1920
+ # Uses two-stage approach per paper's footnote 2:
1921
+ # Stage 1: Univariate searches for initial values
1922
+ # Stage 2: Cycling (coordinate descent) until convergence
1923
+ if best_lambda is None:
1924
+ # Stage 1: Univariate searches with extreme fixed values
1925
+ # Following paper's footnote 2 for initial bounds
1926
+
1927
+ # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
1928
+ lambda_time_init, _ = self._univariate_loocv_search(
1929
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1930
+ 'lambda_time', self.lambda_time_grid,
1931
+ {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF}
1932
+ )
1933
+
1934
+ # λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0
1935
+ lambda_nn_init, _ = self._univariate_loocv_search(
1936
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1937
+ 'lambda_nn', self.lambda_nn_grid,
1938
+ {'lambda_time': 0.0, 'lambda_unit': 0.0}
1939
+ )
1940
+
1941
+ # λ_unit search: fix λ_nn=∞, λ_time=0
1942
+ lambda_unit_init, _ = self._univariate_loocv_search(
1943
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1944
+ 'lambda_unit', self.lambda_unit_grid,
1945
+ {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0}
1946
+ )
1947
+
1948
+ # Stage 2: Cycling refinement (coordinate descent)
1949
+ lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
1950
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1951
+ (lambda_time_init, lambda_unit_init, lambda_nn_init)
1952
+ )
1953
+
1954
+ # Compute final score for the optimized parameters
1955
+ try:
1956
+ best_score = self._loocv_score_obs_specific(
1957
+ Y, D, control_mask, control_unit_idx,
1958
+ lambda_time, lambda_unit, lambda_nn,
1959
+ n_units, n_periods
1960
+ )
1961
+ # Only accept finite scores - infinite means all fits failed
1962
+ if np.isfinite(best_score):
1963
+ best_lambda = (lambda_time, lambda_unit, lambda_nn)
1964
+ # else: best_lambda stays None, triggering defaults fallback
1965
+ except (np.linalg.LinAlgError, ValueError):
1966
+ # If even the optimized parameters fail, best_lambda stays None
1967
+ pass
1968
+
1969
+ if best_lambda is None:
1970
+ warnings.warn(
1971
+ "All tuning parameter combinations failed. Using defaults.",
1972
+ UserWarning
1973
+ )
1974
+ best_lambda = (1.0, 1.0, 0.1)
1975
+ best_score = np.nan
1976
+
1977
+ self._optimal_lambda = best_lambda
1978
+ lambda_time, lambda_unit, lambda_nn = best_lambda
1979
+
1980
+ # Store original λ_nn for results (only λ_nn needs original→effective conversion).
1981
+ # λ_time and λ_unit use 0.0 for uniform weights directly per Eq. 3.
1982
+ original_lambda_nn = lambda_nn
1983
+
1984
+ # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
1985
+ if np.isinf(lambda_nn):
1986
+ lambda_nn = 1e10
1987
+
1988
+ # effective_lambda with converted λ_nn for ALL downstream computation
1989
+ # (variance estimation uses the same parameters as point estimation)
1990
+ effective_lambda = (lambda_time, lambda_unit, lambda_nn)
1991
+
1992
+ # Step 2: Final estimation - per-observation model fitting following Algorithm 2
1993
+ # For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it}
1994
+ treatment_effects = {}
1995
+ tau_values = []
1996
+ alpha_estimates = []
1997
+ beta_estimates = []
1998
+ L_estimates = []
1999
+
2000
+ # Use pre-computed treated observations
2001
+ treated_observations = self._precomputed["treated_observations"]
2002
+
2003
+ for t, i in treated_observations:
2004
+ # Compute observation-specific weights for this (i, t)
2005
+ weight_matrix = self._compute_observation_weights(
2006
+ Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
2007
+ n_units, n_periods
2008
+ )
2009
+
2010
+ # Fit model with these weights
2011
+ alpha_hat, beta_hat, L_hat = self._estimate_model(
2012
+ Y, control_mask, weight_matrix, lambda_nn,
2013
+ n_units, n_periods
2014
+ )
2015
+
2016
+ # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
2017
+ tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]
2018
+
2019
+ unit_id = idx_to_unit[i]
2020
+ time_id = idx_to_period[t]
2021
+ treatment_effects[(unit_id, time_id)] = tau_it
2022
+ tau_values.append(tau_it)
2023
+
2024
+ # Store for averaging
2025
+ alpha_estimates.append(alpha_hat)
2026
+ beta_estimates.append(beta_hat)
2027
+ L_estimates.append(L_hat)
2028
+
2029
+ # Average ATT
2030
+ att = np.mean(tau_values)
2031
+
2032
+ # Average parameter estimates for output (representative)
2033
+ alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
2034
+ beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
2035
+ L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))
2036
+
2037
+ # Compute effective rank
2038
+ _, s, _ = np.linalg.svd(L_hat, full_matrices=False)
2039
+ if s[0] > 0:
2040
+ effective_rank = np.sum(s) / s[0]
2041
+ else:
2042
+ effective_rank = 0.0
2043
+
2044
+ # Step 4: Variance estimation
2045
+ # Use effective_lambda (converted values) to ensure SE is computed with same
2046
+ # parameters as point estimation. This fixes the variance inconsistency issue.
2047
+ se, bootstrap_dist = self._bootstrap_variance(
2048
+ data, outcome, treatment, unit, time,
2049
+ effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
2050
+ )
2051
+
2052
+ # Compute test statistics
2053
+ if se > 0:
2054
+ t_stat = att / se
2055
+ p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
2056
+ conf_int = compute_confidence_interval(att, se, self.alpha)
2057
+ else:
2058
+ # When SE is undefined/zero, ALL inference fields should be NaN
2059
+ t_stat = np.nan
2060
+ p_value = np.nan
2061
+ conf_int = (np.nan, np.nan)
2062
+
2063
+ # Create results dictionaries
2064
+ unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
2065
+ time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}
2066
+
2067
+ # Store results
2068
+ self.results_ = TROPResults(
2069
+ att=att,
2070
+ se=se,
2071
+ t_stat=t_stat,
2072
+ p_value=p_value,
2073
+ conf_int=conf_int,
2074
+ n_obs=len(data),
2075
+ n_treated=len(treated_unit_idx),
2076
+ n_control=len(control_unit_idx),
2077
+ n_treated_obs=n_treated_obs,
2078
+ unit_effects=unit_effects_dict,
2079
+ time_effects=time_effects_dict,
2080
+ treatment_effects=treatment_effects,
2081
+ lambda_time=lambda_time,
2082
+ lambda_unit=lambda_unit,
2083
+ lambda_nn=original_lambda_nn,
2084
+ factor_matrix=L_hat,
2085
+ effective_rank=effective_rank,
2086
+ loocv_score=best_score,
2087
+ alpha=self.alpha,
2088
+ n_pre_periods=n_pre_periods,
2089
+ n_post_periods=n_post_periods,
2090
+ n_bootstrap=self.n_bootstrap,
2091
+ bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
2092
+ )
2093
+
2094
+ self.is_fitted_ = True
2095
+ return self.results_
2096
+
2097
+ def _compute_observation_weights(
2098
+ self,
2099
+ Y: np.ndarray,
2100
+ D: np.ndarray,
2101
+ i: int,
2102
+ t: int,
2103
+ lambda_time: float,
2104
+ lambda_unit: float,
2105
+ control_unit_idx: np.ndarray,
2106
+ n_units: int,
2107
+ n_periods: int,
2108
+ ) -> np.ndarray:
2109
+ """
2110
+ Compute observation-specific weight matrix for treated observation (i, t).
2111
+
2112
+ Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
2113
+ - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
2114
+ - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
2115
+
2116
+ IMPORTANT (Issue A fix): The paper's objective sums over ALL observations
2117
+ where (1 - W_js) is non-zero, which includes pre-treatment observations of
2118
+ eventually-treated units since W_js = 0 for those. This method computes
2119
+ weights for ALL units where D[t, j] = 0 at the target period, not just
2120
+ never-treated units.
2121
+
2122
+ Uses pre-computed structures when available for efficiency.
2123
+
2124
+ Parameters
2125
+ ----------
2126
+ Y : np.ndarray
2127
+ Outcome matrix (n_periods x n_units).
2128
+ D : np.ndarray
2129
+ Treatment indicator matrix (n_periods x n_units).
2130
+ i : int
2131
+ Treated unit index.
2132
+ t : int
2133
+ Treatment period index.
2134
+ lambda_time : float
2135
+ Time weight decay parameter.
2136
+ lambda_unit : float
2137
+ Unit weight decay parameter.
2138
+ control_unit_idx : np.ndarray
2139
+ Indices of never-treated units (for backward compatibility, but not
2140
+ used for weight computation - we use D matrix directly).
2141
+ n_units : int
2142
+ Number of units.
2143
+ n_periods : int
2144
+ Number of periods.
2145
+
2146
+ Returns
2147
+ -------
2148
+ np.ndarray
2149
+ Weight matrix (n_periods x n_units) for observation (i, t).
2150
+ """
2151
+ # Use pre-computed structures when available
2152
+ if self._precomputed is not None:
2153
+ # Time weights from pre-computed time distance matrix
2154
+ # time_dist_matrix[t, s] = |t - s|
2155
+ time_weights = np.exp(-lambda_time * self._precomputed["time_dist_matrix"][t, :])
2156
+
2157
+ # Unit weights - computed for ALL units where D[t, j] = 0
2158
+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
2159
+ unit_weights = np.zeros(n_units)
2160
+ D_stored = self._precomputed["D"]
2161
+ Y_stored = self._precomputed["Y"]
2162
+
2163
+ # Valid control units at time t: D[t, j] == 0
2164
+ valid_control_at_t = D_stored[t, :] == 0
2165
+
2166
+ if lambda_unit == 0:
2167
+ # Uniform weights when lambda_unit = 0
2168
+ # All units not treated at time t get weight 1
2169
+ unit_weights[valid_control_at_t] = 1.0
2170
+ else:
2171
+ # Use observation-specific distances with target period excluded
2172
+ # (Issue B fix: compute exact per-observation distance)
2173
+ for j in range(n_units):
2174
+ if valid_control_at_t[j] and j != i:
2175
+ # Compute distance excluding target period t
2176
+ dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t)
2177
+ if np.isinf(dist):
2178
+ unit_weights[j] = 0.0
2179
+ else:
2180
+ unit_weights[j] = np.exp(-lambda_unit * dist)
2181
+
2182
+ # Treated unit i gets weight 1
2183
+ unit_weights[i] = 1.0
2184
+
2185
+ # Weight matrix: outer product (n_periods x n_units)
2186
+ return np.outer(time_weights, unit_weights)
2187
+
2188
+ # Fallback: compute from scratch (used in bootstrap)
2189
+ # Time distance: |t - s| following paper's Equation 3 (page 7)
2190
+ dist_time = np.abs(np.arange(n_periods) - t)
2191
+ time_weights = np.exp(-lambda_time * dist_time)
2192
+
2193
+ # Unit weights - computed for ALL units where D[t, j] = 0
2194
+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
2195
+ unit_weights = np.zeros(n_units)
2196
+
2197
+ # Valid control units at time t: D[t, j] == 0
2198
+ valid_control_at_t = D[t, :] == 0
2199
+
2200
+ if lambda_unit == 0:
2201
+ # Uniform weights when lambda_unit = 0
2202
+ unit_weights[valid_control_at_t] = 1.0
2203
+ else:
2204
+ for j in range(n_units):
2205
+ if valid_control_at_t[j] and j != i:
2206
+ # Compute distance excluding target period t (Issue B fix)
2207
+ dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
2208
+ if np.isinf(dist):
2209
+ unit_weights[j] = 0.0
2210
+ else:
2211
+ unit_weights[j] = np.exp(-lambda_unit * dist)
2212
+
2213
+ # Treated unit i gets weight 1 (or could be omitted since we fit on controls)
2214
+ # We include treated unit's own observation for model fitting
2215
+ unit_weights[i] = 1.0
2216
+
2217
+ # Weight matrix: outer product (n_periods x n_units)
2218
+ W = np.outer(time_weights, unit_weights)
2219
+
2220
+ return W
2221
+
2222
+ def _soft_threshold_svd(
2223
+ self,
2224
+ M: np.ndarray,
2225
+ threshold: float,
2226
+ ) -> np.ndarray:
2227
+ """
2228
+ Apply soft-thresholding to singular values (proximal operator for nuclear norm).
2229
+
2230
+ Parameters
2231
+ ----------
2232
+ M : np.ndarray
2233
+ Input matrix.
2234
+ threshold : float
2235
+ Soft-thresholding parameter.
2236
+
2237
+ Returns
2238
+ -------
2239
+ np.ndarray
2240
+ Matrix with soft-thresholded singular values.
2241
+ """
2242
+ if threshold <= 0:
2243
+ return M
2244
+
2245
+ # Handle NaN/Inf values in input
2246
+ if not np.isfinite(M).all():
2247
+ M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)
2248
+
2249
+ try:
2250
+ U, s, Vt = np.linalg.svd(M, full_matrices=False)
2251
+ except np.linalg.LinAlgError:
2252
+ # SVD failed, return zero matrix
2253
+ return np.zeros_like(M)
2254
+
2255
+ # Check for numerical issues in SVD output
2256
+ if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()):
2257
+ # SVD produced non-finite values, return zero matrix
2258
+ return np.zeros_like(M)
2259
+
2260
+ s_thresh = np.maximum(s - threshold, 0)
2261
+
2262
+ # Use truncated reconstruction with only non-zero singular values
2263
+ nonzero_mask = s_thresh > self.CONVERGENCE_TOL_SVD
2264
+ if not np.any(nonzero_mask):
2265
+ return np.zeros_like(M)
2266
+
2267
+ # Truncate to non-zero components for numerical stability
2268
+ U_trunc = U[:, nonzero_mask]
2269
+ s_trunc = s_thresh[nonzero_mask]
2270
+ Vt_trunc = Vt[nonzero_mask, :]
2271
+
2272
+ # Compute result, suppressing expected numerical warnings from
2273
+ # ill-conditioned matrices during alternating minimization
2274
+ with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
2275
+ result = (U_trunc * s_trunc) @ Vt_trunc
2276
+
2277
+ # Replace any NaN/Inf in result with zeros
2278
+ if not np.isfinite(result).all():
2279
+ result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
2280
+
2281
+ return result
2282
+
2283
+ def _weighted_nuclear_norm_solve(
2284
+ self,
2285
+ Y: np.ndarray,
2286
+ W: np.ndarray,
2287
+ L_init: np.ndarray,
2288
+ alpha: np.ndarray,
2289
+ beta: np.ndarray,
2290
+ lambda_nn: float,
2291
+ max_inner_iter: int = 20,
2292
+ ) -> np.ndarray:
2293
+ """
2294
+ Solve weighted nuclear norm problem using iterative weighted soft-impute.
2295
+
2296
+ Issue C fix: Implements the weighted nuclear norm optimization from the
2297
+ paper's Equation 2 (page 7). The full objective is:
2298
+ min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
2299
+
2300
+ This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
2301
+ L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
2302
+
2303
+ where W ⊙ denotes element-wise multiplication with normalized weights.
2304
+
2305
+ IMPORTANT: For observations with W=0 (treated observations), we keep
2306
+ L values from the previous iteration rather than setting L = R, which
2307
+ would absorb the treatment effect.
2308
+
2309
+ Parameters
2310
+ ----------
2311
+ Y : np.ndarray
2312
+ Outcome matrix (n_periods x n_units).
2313
+ W : np.ndarray
2314
+ Weight matrix (n_periods x n_units), non-negative. W=0 indicates
2315
+ observations that should not be used for fitting (treated obs).
2316
+ L_init : np.ndarray
2317
+ Initial estimate of L matrix.
2318
+ alpha : np.ndarray
2319
+ Current unit fixed effects estimate.
2320
+ beta : np.ndarray
2321
+ Current time fixed effects estimate.
2322
+ lambda_nn : float
2323
+ Nuclear norm regularization parameter.
2324
+ max_inner_iter : int, default=20
2325
+ Maximum inner iterations for the proximal algorithm.
2326
+
2327
+ Returns
2328
+ -------
2329
+ np.ndarray
2330
+ Updated L matrix estimate.
2331
+ """
2332
+ # Compute target residual R = Y - α - β
2333
+ R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
2334
+
2335
+ # Handle invalid values
2336
+ R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)
2337
+
2338
+ # For observations with W=0 (treated obs), keep L_init instead of R
2339
+ # This prevents L from absorbing the treatment effect
2340
+ valid_obs_mask = W > 0
2341
+ R_masked = np.where(valid_obs_mask, R, L_init)
2342
+
2343
+ if lambda_nn <= 0:
2344
+ # No regularization - just return masked residual
2345
+ # Use soft-thresholding with threshold=0 which returns the input
2346
+ return R_masked
2347
+
2348
+ # Normalize weights so max is 1 (for step size stability)
2349
+ W_max = np.max(W)
2350
+ if W_max > 0:
2351
+ W_norm = W / W_max
2352
+ else:
2353
+ W_norm = W
2354
+
2355
+ # Initialize L
2356
+ L = L_init.copy()
2357
+
2358
+ # Proximal gradient iteration with weighted soft-impute
2359
+ # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
2360
+ # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
2361
+ # where η is the step size (we use η = 1 with normalized weights)
2362
+ for _ in range(max_inner_iter):
2363
+ L_old = L.copy()
2364
+
2365
+ # Gradient step: L_k + W ⊙ (R - L_k)
2366
+ # For W=0 observations, this keeps L_k unchanged
2367
+ gradient_step = L + W_norm * (R_masked - L)
2368
+
2369
+ # Proximal step: soft-threshold singular values
2370
+ L = self._soft_threshold_svd(gradient_step, lambda_nn)
2371
+
2372
+ # Check convergence
2373
+ if np.max(np.abs(L - L_old)) < self.tol:
2374
+ break
2375
+
2376
+ return L
2377
+
2378
+ def _estimate_model(
2379
+ self,
2380
+ Y: np.ndarray,
2381
+ control_mask: np.ndarray,
2382
+ weight_matrix: np.ndarray,
2383
+ lambda_nn: float,
2384
+ n_units: int,
2385
+ n_periods: int,
2386
+ exclude_obs: Optional[Tuple[int, int]] = None,
2387
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
2388
+ """
2389
+ Estimate the model: Y = α + β + L + τD + ε with nuclear norm penalty on L.
2390
+
2391
+ Uses alternating minimization with vectorized operations:
2392
+ 1. Fix L, solve for α, β via weighted means
2393
+ 2. Fix α, β, solve for L via soft-thresholding
2394
+
2395
+ Parameters
2396
+ ----------
2397
+ Y : np.ndarray
2398
+ Outcome matrix (n_periods x n_units).
2399
+ control_mask : np.ndarray
2400
+ Boolean mask for control observations.
2401
+ weight_matrix : np.ndarray
2402
+ Pre-computed global weight matrix (n_periods x n_units).
2403
+ lambda_nn : float
2404
+ Nuclear norm regularization parameter.
2405
+ n_units : int
2406
+ Number of units.
2407
+ n_periods : int
2408
+ Number of periods.
2409
+ exclude_obs : tuple, optional
2410
+ (t, i) observation to exclude (for LOOCV).
2411
+
2412
+ Returns
2413
+ -------
2414
+ tuple
2415
+ (alpha, beta, L) estimated parameters.
2416
+ """
2417
+ W = weight_matrix
2418
+
2419
+ # Mask for estimation (control obs only, excluding LOOCV obs if specified)
2420
+ est_mask = control_mask.copy()
2421
+ if exclude_obs is not None:
2422
+ t_ex, i_ex = exclude_obs
2423
+ est_mask[t_ex, i_ex] = False
2424
+
2425
+ # Handle missing values
2426
+ valid_mask = ~np.isnan(Y) & est_mask
2427
+
2428
+ # Initialize
2429
+ alpha = np.zeros(n_units)
2430
+ beta = np.zeros(n_periods)
2431
+ L = np.zeros((n_periods, n_units))
2432
+
2433
+ # Pre-compute masked weights for vectorized operations
2434
+ # Set weights to 0 where not valid
2435
+ W_masked = W * valid_mask
2436
+
2437
+ # Pre-compute weight sums per unit and per time (for denominator)
2438
+ # shape: (n_units,) and (n_periods,)
2439
+ weight_sum_per_unit = np.sum(W_masked, axis=0) # sum over periods
2440
+ weight_sum_per_time = np.sum(W_masked, axis=1) # sum over units
2441
+
2442
+ # Handle units/periods with zero weight sum
2443
+ unit_has_obs = weight_sum_per_unit > 0
2444
+ time_has_obs = weight_sum_per_time > 0
2445
+
2446
+ # Create safe denominators (avoid division by zero)
2447
+ safe_unit_denom = np.where(unit_has_obs, weight_sum_per_unit, 1.0)
2448
+ safe_time_denom = np.where(time_has_obs, weight_sum_per_time, 1.0)
2449
+
2450
+ # Replace NaN in Y with 0 for computation (mask handles exclusion)
2451
+ Y_safe = np.where(np.isnan(Y), 0.0, Y)
2452
+
2453
+ # Alternating minimization following Algorithm 1 (page 9)
2454
+ # Minimize: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
2455
+ for _ in range(self.max_iter):
2456
+ alpha_old = alpha.copy()
2457
+ beta_old = beta.copy()
2458
+ L_old = L.copy()
2459
+
2460
+ # Step 1: Update α and β (weighted least squares)
2461
+ # Following Equation 2 (page 7), fix L and solve for α, β
2462
+ # R = Y - L (residual without fixed effects)
2463
+ R = Y_safe - L
2464
+
2465
+ # Alpha update (unit fixed effects):
2466
+ # α_i = argmin_α Σ_t W_{ti}(R_{ti} - α - β_t)²
2467
+ # Solution: α_i = Σ_t W_{ti}(R_{ti} - β_t) / Σ_t W_{ti}
2468
+ R_minus_beta = R - beta[:, np.newaxis] # (n_periods, n_units)
2469
+ weighted_R_minus_beta = W_masked * R_minus_beta
2470
+ alpha_numerator = np.sum(weighted_R_minus_beta, axis=0) # (n_units,)
2471
+ alpha = np.where(unit_has_obs, alpha_numerator / safe_unit_denom, 0.0)
2472
+
2473
+ # Beta update (time fixed effects):
2474
+ # β_t = argmin_β Σ_i W_{ti}(R_{ti} - α_i - β)²
2475
+ # Solution: β_t = Σ_i W_{ti}(R_{ti} - α_i) / Σ_i W_{ti}
2476
+ R_minus_alpha = R - alpha[np.newaxis, :] # (n_periods, n_units)
2477
+ weighted_R_minus_alpha = W_masked * R_minus_alpha
2478
+ beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,)
2479
+ beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
2480
+
2481
+ # Step 2: Update L with weighted nuclear norm penalty
2482
+ # Issue C fix: Use weighted soft-impute to properly account for
2483
+ # observation weights in the nuclear norm optimization.
2484
+ # Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
2485
+ L = self._weighted_nuclear_norm_solve(
2486
+ Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10
2487
+ )
2488
+
2489
+ # Check convergence
2490
+ alpha_diff = np.max(np.abs(alpha - alpha_old))
2491
+ beta_diff = np.max(np.abs(beta - beta_old))
2492
+ L_diff = np.max(np.abs(L - L_old))
2493
+
2494
+ if max(alpha_diff, beta_diff, L_diff) < self.tol:
2495
+ break
2496
+
2497
+ return alpha, beta, L
2498
+
2499
+ def _loocv_score_obs_specific(
2500
+ self,
2501
+ Y: np.ndarray,
2502
+ D: np.ndarray,
2503
+ control_mask: np.ndarray,
2504
+ control_unit_idx: np.ndarray,
2505
+ lambda_time: float,
2506
+ lambda_unit: float,
2507
+ lambda_nn: float,
2508
+ n_units: int,
2509
+ n_periods: int,
2510
+ ) -> float:
2511
+ """
2512
+ Compute leave-one-out cross-validation score with observation-specific weights.
2513
+
2514
+ Following the paper's Equation 5 (page 8):
2515
+ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
2516
+
2517
+ For each control observation (j, s), treat it as pseudo-treated,
2518
+ compute observation-specific weights, fit model excluding (j, s),
2519
+ and sum squared pseudo-treatment effects.
2520
+
2521
+ Uses pre-computed structures when available for efficiency.
2522
+
2523
+ Parameters
2524
+ ----------
2525
+ Y : np.ndarray
2526
+ Outcome matrix (n_periods x n_units).
2527
+ D : np.ndarray
2528
+ Treatment indicator matrix (n_periods x n_units).
2529
+ control_mask : np.ndarray
2530
+ Boolean mask for control observations.
2531
+ control_unit_idx : np.ndarray
2532
+ Indices of control units.
2533
+ lambda_time : float
2534
+ Time weight decay parameter.
2535
+ lambda_unit : float
2536
+ Unit weight decay parameter.
2537
+ lambda_nn : float
2538
+ Nuclear norm regularization parameter.
2539
+ n_units : int
2540
+ Number of units.
2541
+ n_periods : int
2542
+ Number of periods.
2543
+
2544
+ Returns
2545
+ -------
2546
+ float
2547
+ LOOCV score (lower is better).
2548
+ """
2549
+ # Use pre-computed control observations if available
2550
+ if self._precomputed is not None:
2551
+ control_obs = self._precomputed["control_obs"]
2552
+ else:
2553
+ # Get all control observations
2554
+ control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
2555
+ if control_mask[t, i] and not np.isnan(Y[t, i])]
2556
+
2557
+ # Empty control set check: if no control observations, return infinity
2558
+ # A score of 0.0 would incorrectly "win" over legitimate parameters
2559
+ if len(control_obs) == 0:
2560
+ warnings.warn(
2561
+ f"LOOCV: No valid control observations for "
2562
+ f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
2563
+ "Returning infinite score.",
2564
+ UserWarning
2565
+ )
2566
+ return np.inf
2567
+
2568
+ tau_squared_sum = 0.0
2569
+ n_valid = 0
2570
+
2571
+ for t, i in control_obs:
2572
+ try:
2573
+ # Compute observation-specific weights for pseudo-treated (i, t)
2574
+ # Uses pre-computed distance matrices when available
2575
+ weight_matrix = self._compute_observation_weights(
2576
+ Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
2577
+ n_units, n_periods
2578
+ )
2579
+
2580
+ # Estimate model excluding observation (t, i)
2581
+ alpha, beta, L = self._estimate_model(
2582
+ Y, control_mask, weight_matrix, lambda_nn,
2583
+ n_units, n_periods, exclude_obs=(t, i)
2584
+ )
2585
+
2586
+ # Pseudo treatment effect
2587
+ tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
2588
+ tau_squared_sum += tau_ti ** 2
2589
+ n_valid += 1
2590
+
2591
+ except (np.linalg.LinAlgError, ValueError):
2592
+ # Per Equation 5: Q(λ) must sum over ALL D==0 cells
2593
+ # Any failure means this λ cannot produce valid estimates for all cells
2594
+ warnings.warn(
2595
+ f"LOOCV: Fit failed for observation ({t}, {i}) with "
2596
+ f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
2597
+ "Returning infinite score per Equation 5.",
2598
+ UserWarning
2599
+ )
2600
+ return np.inf
2601
+
2602
+ # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
2603
+ # Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
2604
+ return tau_squared_sum
2605
+
2606
+ def _bootstrap_variance(
2607
+ self,
2608
+ data: pd.DataFrame,
2609
+ outcome: str,
2610
+ treatment: str,
2611
+ unit: str,
2612
+ time: str,
2613
+ optimal_lambda: Tuple[float, float, float],
2614
+ Y: Optional[np.ndarray] = None,
2615
+ D: Optional[np.ndarray] = None,
2616
+ control_unit_idx: Optional[np.ndarray] = None,
2617
+ ) -> Tuple[float, np.ndarray]:
2618
+ """
2619
+ Compute bootstrap standard error using unit-level block bootstrap.
2620
+
2621
+ When the optional Rust backend is available and the matrix parameters
2622
+ (Y, D, control_unit_idx) are provided, uses parallelized Rust
2623
+ implementation for 5-15x speedup. Falls back to Python implementation
2624
+ if Rust is unavailable or if matrix parameters are not provided.
2625
+
2626
+ Parameters
2627
+ ----------
2628
+ data : pd.DataFrame
2629
+ Original data in long format with unit, time, outcome, and treatment.
2630
+ outcome : str
2631
+ Name of the outcome column in data.
2632
+ treatment : str
2633
+ Name of the treatment indicator column in data.
2634
+ unit : str
2635
+ Name of the unit identifier column in data.
2636
+ time : str
2637
+ Name of the time period column in data.
2638
+ optimal_lambda : tuple of float
2639
+ Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
2640
+ from cross-validation. Used for model estimation in each bootstrap.
2641
+ Y : np.ndarray, optional
2642
+ Outcome matrix of shape (n_periods, n_units). Required for Rust
2643
+ backend acceleration. If None, falls back to Python implementation.
2644
+ D : np.ndarray, optional
2645
+ Treatment indicator matrix of shape (n_periods, n_units) where
2646
+ D[t,i]=1 indicates unit i is treated at time t. Required for Rust
2647
+ backend acceleration.
2648
+ control_unit_idx : np.ndarray, optional
2649
+ Array of indices for control units (never-treated). Required for
2650
+ Rust backend acceleration.
2651
+
2652
+ Returns
2653
+ -------
2654
+ se : float
2655
+ Bootstrap standard error of the ATT estimate.
2656
+ bootstrap_estimates : np.ndarray
2657
+ Array of ATT estimates from each bootstrap iteration. Length may
2658
+ be less than n_bootstrap if some iterations failed.
2659
+
2660
+ Notes
2661
+ -----
2662
+ Uses unit-level block bootstrap where entire unit time series are
2663
+ resampled with replacement. This preserves within-unit correlation
2664
+ structure and is appropriate for panel data.
2665
+ """
2666
+ lambda_time, lambda_unit, lambda_nn = optimal_lambda
2667
+
2668
+ # Try Rust backend for parallel bootstrap (5-15x speedup)
2669
+ if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
2670
+ and self._precomputed is not None and Y is not None
2671
+ and D is not None):
2672
+ try:
2673
+ control_mask = self._precomputed["control_mask"]
2674
+ time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
2675
+
2676
+ bootstrap_estimates, se = _rust_bootstrap_trop_variance(
2677
+ Y, D.astype(np.float64),
2678
+ control_mask.astype(np.uint8),
2679
+ time_dist_matrix,
2680
+ lambda_time, lambda_unit, lambda_nn,
2681
+ self.n_bootstrap, self.max_iter, self.tol,
2682
+ self.seed if self.seed is not None else 0
2683
+ )
2684
+
2685
+ if len(bootstrap_estimates) >= 10:
2686
+ return float(se), bootstrap_estimates
2687
+ # Fall through to Python if too few bootstrap samples
2688
+ logger.debug(
2689
+ "Rust bootstrap returned only %d samples, falling back to Python",
2690
+ len(bootstrap_estimates)
2691
+ )
2692
+ except Exception as e:
2693
+ logger.debug(
2694
+ "Rust bootstrap variance failed, falling back to Python: %s", e
2695
+ )
2696
+
2697
+ # Python implementation (fallback)
2698
+ rng = np.random.default_rng(self.seed)
2699
+
2700
+ # Issue D fix: Stratified bootstrap sampling
2701
+ # Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
2702
+ # and N_1 treated rows separately to preserve treatment ratio
2703
+ unit_ever_treated = data.groupby(unit)[treatment].max()
2704
+ treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
2705
+ control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)
2706
+
2707
+ n_treated_units = len(treated_units)
2708
+ n_control_units = len(control_units)
2709
+
2710
+ bootstrap_estimates_list = []
2711
+
2712
+ for _ in range(self.n_bootstrap):
2713
+ # Stratified sampling: sample control and treated units separately
2714
+ # This preserves the treatment ratio in each bootstrap sample
2715
+ if n_control_units > 0:
2716
+ sampled_control = rng.choice(
2717
+ control_units, size=n_control_units, replace=True
2718
+ )
2719
+ else:
2720
+ sampled_control = np.array([], dtype=control_units.dtype)
2721
+
2722
+ if n_treated_units > 0:
2723
+ sampled_treated = rng.choice(
2724
+ treated_units, size=n_treated_units, replace=True
2725
+ )
2726
+ else:
2727
+ sampled_treated = np.array([], dtype=treated_units.dtype)
2728
+
2729
+ # Combine stratified samples
2730
+ sampled_units = np.concatenate([sampled_control, sampled_treated])
2731
+
2732
+ # Create bootstrap sample with unique unit IDs
2733
+ boot_data = pd.concat([
2734
+ data[data[unit] == u].assign(**{unit: f"{u}_{idx}"})
2735
+ for idx, u in enumerate(sampled_units)
2736
+ ], ignore_index=True)
2737
+
2738
+ try:
2739
+ # Fit with fixed lambda (skip LOOCV for speed)
2740
+ att = self._fit_with_fixed_lambda(
2741
+ boot_data, outcome, treatment, unit, time,
2742
+ optimal_lambda
2743
+ )
2744
+ bootstrap_estimates_list.append(att)
2745
+ except (ValueError, np.linalg.LinAlgError, KeyError):
2746
+ continue
2747
+
2748
+ bootstrap_estimates = np.array(bootstrap_estimates_list)
2749
+
2750
+ if len(bootstrap_estimates) < 10:
2751
+ warnings.warn(
2752
+ f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
2753
+ "Standard errors may be unreliable.",
2754
+ UserWarning
2755
+ )
2756
+ if len(bootstrap_estimates) == 0:
2757
+ return 0.0, np.array([])
2758
+
2759
+ se = np.std(bootstrap_estimates, ddof=1)
2760
+ return float(se), bootstrap_estimates
2761
+
2762
+ def _fit_with_fixed_lambda(
2763
+ self,
2764
+ data: pd.DataFrame,
2765
+ outcome: str,
2766
+ treatment: str,
2767
+ unit: str,
2768
+ time: str,
2769
+ fixed_lambda: Tuple[float, float, float],
2770
+ ) -> float:
2771
+ """
2772
+ Fit model with fixed tuning parameters (for bootstrap).
2773
+
2774
+ Uses observation-specific weights following Algorithm 2.
2775
+ Returns only the ATT estimate.
2776
+ """
2777
+ lambda_time, lambda_unit, lambda_nn = fixed_lambda
2778
+
2779
+ # Setup matrices
2780
+ all_units = sorted(data[unit].unique())
2781
+ all_periods = sorted(data[time].unique())
2782
+
2783
+ n_units = len(all_units)
2784
+ n_periods = len(all_periods)
2785
+
2786
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
2787
+ period_to_idx = {p: i for i, p in enumerate(all_periods)}
2788
+
2789
+ # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
2790
+ Y = (
2791
+ data.pivot(index=time, columns=unit, values=outcome)
2792
+ .reindex(index=all_periods, columns=all_units)
2793
+ .values
2794
+ )
2795
+ D = (
2796
+ data.pivot(index=time, columns=unit, values=treatment)
2797
+ .reindex(index=all_periods, columns=all_units)
2798
+ .fillna(0)
2799
+ .astype(int)
2800
+ .values
2801
+ )
2802
+
2803
+ control_mask = D == 0
2804
+
2805
+ # Get control unit indices
2806
+ unit_ever_treated = np.any(D == 1, axis=0)
2807
+ control_unit_idx = np.where(~unit_ever_treated)[0]
2808
+
2809
+ # Get list of treated observations
2810
+ treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
2811
+ if D[t, i] == 1]
2812
+
2813
+ if not treated_observations:
2814
+ raise ValueError("No treated observations")
2815
+
2816
+ # Compute ATT using observation-specific weights (Algorithm 2)
2817
+ tau_values = []
2818
+ for t, i in treated_observations:
2819
+ # Compute observation-specific weights for this (i, t)
2820
+ weight_matrix = self._compute_observation_weights(
2821
+ Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
2822
+ n_units, n_periods
2823
+ )
2824
+
2825
+ # Fit model with these weights
2826
+ alpha, beta, L = self._estimate_model(
2827
+ Y, control_mask, weight_matrix, lambda_nn,
2828
+ n_units, n_periods
2829
+ )
2830
+
2831
+ # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
2832
+ tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
2833
+ tau_values.append(tau)
2834
+
2835
+ return np.mean(tau_values)
2836
+
2837
+ def get_params(self) -> Dict[str, Any]:
2838
+ """Get estimator parameters."""
2839
+ return {
2840
+ "method": self.method,
2841
+ "lambda_time_grid": self.lambda_time_grid,
2842
+ "lambda_unit_grid": self.lambda_unit_grid,
2843
+ "lambda_nn_grid": self.lambda_nn_grid,
2844
+ "max_iter": self.max_iter,
2845
+ "tol": self.tol,
2846
+ "alpha": self.alpha,
2847
+ "n_bootstrap": self.n_bootstrap,
2848
+ "seed": self.seed,
2849
+ }
2850
+
2851
+ def set_params(self, **params) -> "TROP":
2852
+ """Set estimator parameters."""
2853
+ for key, value in params.items():
2854
+ if hasattr(self, key):
2855
+ setattr(self, key, value)
2856
+ else:
2857
+ raise ValueError(f"Unknown parameter: {key}")
2858
+ return self
2859
+
2860
+
2861
def trop(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    **kwargs,
) -> TROPResults:
    """
    Convenience wrapper: construct a :class:`TROP` estimator and fit it.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    outcome : str
        Outcome variable column name.
    treatment : str
        Treatment indicator column name (0/1).

        IMPORTANT: This should be an ABSORBING STATE indicator, not a
        treatment timing indicator. For each unit, D=1 for ALL periods
        during and after treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for
        t >= g_i where g_i is the treatment start time for unit i).
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    **kwargs
        Forwarded to the :class:`TROP` constructor.

    Returns
    -------
    TROPResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import trop
    >>> results = trop(data, 'y', 'treated', 'unit', 'time')
    >>> print(f"ATT: {results.att:.3f}")
    """
    return TROP(**kwargs).fit(data, outcome, treatment, unit, time)