diff-diff 2.2.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/trop.py ADDED
@@ -0,0 +1,3149 @@
1
+ """
2
+ Triply Robust Panel (TROP) estimator.
3
+
4
+ Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
5
+ TROP combines three robustness components:
6
+ 1. Nuclear norm regularized factor model (interactive fixed effects)
7
+ 2. Exponential distance-based unit weights
8
+ 3. Exponential time decay weights
9
+
10
+ The estimator uses leave-one-out cross-validation for tuning parameter
11
+ selection and provides robust treatment effect estimates under factor
12
+ confounding.
13
+
14
+ References
15
+ ----------
16
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
17
+ Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
18
+ """
19
+
20
+ import logging
21
+ import warnings
22
+ from dataclasses import dataclass, field
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+ from scipy import stats
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ try:
32
+ from typing import TypedDict
33
+ except ImportError:
34
+ from typing_extensions import TypedDict
35
+
36
+ from diff_diff._backend import (
37
+ HAS_RUST_BACKEND,
38
+ _rust_unit_distance_matrix,
39
+ _rust_loocv_grid_search,
40
+ _rust_bootstrap_trop_variance,
41
+ _rust_loocv_grid_search_joint,
42
+ _rust_bootstrap_trop_variance_joint,
43
+ )
44
+ from diff_diff.results import _get_significance_stars
45
+ from diff_diff.utils import compute_confidence_interval, compute_p_value
46
+
47
+
48
+ # Sentinel value for "disabled" mode in LOOCV parameter search
49
+ # Following paper's footnote 2: λ=∞ disables the corresponding component
50
+ _LAMBDA_INF: float = float('inf')
51
+
52
+
53
class _PrecomputedStructures(TypedDict):
    """Type definition for pre-computed structures used across LOOCV iterations.

    These structures are computed once in `_precompute_structures()` and reused
    to avoid redundant computation during LOOCV and final estimation. All
    matrices are oriented as (n_periods x n_units), matching Y and D.
    """

    unit_dist_matrix: np.ndarray
    """Pairwise unit distance matrix (n_units x n_units)."""
    time_dist_matrix: np.ndarray
    """Time distance matrix where [t, s] = |t - s| (n_periods x n_periods)."""
    control_mask: np.ndarray
    """Boolean mask for control observations (D == 0)."""
    treated_mask: np.ndarray
    """Boolean mask for treated observations (D == 1)."""
    treated_observations: List[Tuple[int, int]]
    """List of (t, i) tuples for treated observations."""
    control_obs: List[Tuple[int, int]]
    """List of (t, i) tuples for valid (non-NaN) control observations."""
    control_unit_idx: np.ndarray
    """Array of never-treated unit indices (for backward compatibility)."""
    D: np.ndarray
    """Treatment indicator matrix (n_periods x n_units) for dynamic control sets."""
    Y: np.ndarray
    """Outcome matrix (n_periods x n_units)."""
    n_units: int
    """Number of units."""
    n_periods: int
    """Number of time periods."""
82
+
83
+
84
@dataclass
class TROPResults:
    """
    Results from a Triply Robust Panel (TROP) estimation.

    TROP combines nuclear norm regularized factor estimation with
    exponential distance-based unit weights and time decay weights.

    Attributes
    ----------
    att : float
        Average Treatment effect on the Treated (ATT).
    se : float
        Standard error of the ATT estimate.
    t_stat : float
        T-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_treated_obs : int
        Number of treated unit-time observations.
    unit_effects : dict
        Estimated unit fixed effects (alpha_i).
    time_effects : dict
        Estimated time fixed effects (beta_t).
    treatment_effects : dict
        Individual treatment effects for each treated (unit, time) pair.
    lambda_time : float
        Selected time weight decay parameter from grid. Note: infinity values
        are converted internally (∞ → 0.0 for uniform weights) for computation.
    lambda_unit : float
        Selected unit weight decay parameter from grid. Note: infinity values
        are converted internally (∞ → 0.0 for uniform weights) for computation.
    lambda_nn : float
        Selected nuclear norm regularization parameter from grid. Note: infinity
        values are converted internally (∞ → 1e10, factor model disabled) for
        computation.
    factor_matrix : np.ndarray
        Estimated low-rank factor matrix L (n_periods x n_units).
    effective_rank : float
        Effective rank of the factor matrix (sum of singular values / max).
    loocv_score : float
        Leave-one-out cross-validation score for selected parameters.
    variance_method : str
        Method used for variance estimation.
    alpha : float
        Significance level for confidence interval.
    n_pre_periods : int
        Number of pre-treatment periods.
    n_post_periods : int
        Number of post-treatment periods (periods with D=1 observations).
    n_bootstrap : int, optional
        Number of bootstrap replications (if bootstrap variance).
    bootstrap_distribution : np.ndarray, optional
        Bootstrap distribution of estimates.
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    n_treated_obs: int
    unit_effects: Dict[Any, float]
    time_effects: Dict[Any, float]
    treatment_effects: Dict[Tuple[Any, Any], float]
    lambda_time: float
    lambda_unit: float
    lambda_nn: float
    factor_matrix: np.ndarray
    effective_rank: float
    loocv_score: float
    variance_method: str
    alpha: float = 0.05
    n_pre_periods: int = 0
    n_post_periods: int = 0
    n_bootstrap: Optional[int] = field(default=None)
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.p_value)
        return (
            f"TROPResults(ATT={self.att:.4f}{sig}, "
            f"SE={self.se:.4f}, "
            f"eff_rank={self.effective_rank:.1f}, "
            f"p={self.p_value:.4f})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level for confidence intervals. Defaults to the
            alpha used during estimation.

        Returns
        -------
        str
            Formatted summary table.
        """
        # Explicit None check (not `alpha or self.alpha`) so a caller-supplied
        # 0.0 is not silently replaced by the estimation alpha.
        alpha = self.alpha if alpha is None else alpha
        # round() rather than int(): (1 - 0.05) * 100 == 94.99999999999999 in
        # floating point, so int() would truncate 95% down to "94%".
        conf_level = round((1 - alpha) * 100)

        lines = [
            "=" * 75,
            "Triply Robust Panel (TROP) Estimation Results".center(75),
            "Athey, Imbens, Qu & Viviano (2025)".center(75),
            "=" * 75,
            "",
            f"{'Observations:':<25} {self.n_obs:>10}",
            f"{'Treated units:':<25} {self.n_treated:>10}",
            f"{'Control units:':<25} {self.n_control:>10}",
            f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
            f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}",
            f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}",
            "",
            "-" * 75,
            "Tuning Parameters (selected via LOOCV)".center(75),
            "-" * 75,
            f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}",
            f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}",
            f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}",
            f"{'Effective rank:':<25} {self.effective_rank:>10.2f}",
            f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}",
        ]

        # Variance method info
        lines.append(f"{'Variance method:':<25} {self.variance_method:>10}")
        if self.variance_method == "bootstrap" and self.n_bootstrap is not None:
            lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")

        lines.extend([
            "",
            "-" * 75,
            f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'':>5}",
            "-" * 75,
            f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} "
            f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
            "-" * 75,
            "",
            f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
        ])

        # Add significance codes
        lines.extend([
            "",
            "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
            "=" * 75,
        ])

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all estimation results.
        """
        return {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": self.conf_int[0],
            "conf_int_upper": self.conf_int[1],
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "n_treated_obs": self.n_treated_obs,
            "n_pre_periods": self.n_pre_periods,
            "n_post_periods": self.n_post_periods,
            "lambda_time": self.lambda_time,
            "lambda_unit": self.lambda_unit,
            "lambda_nn": self.lambda_nn,
            "effective_rank": self.effective_rank,
            "loocv_score": self.loocv_score,
            "variance_method": self.variance_method,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with estimation results (a single row).
        """
        return pd.DataFrame([self.to_dict()])

    def get_treatment_effects_df(self) -> pd.DataFrame:
        """
        Get individual treatment effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit, time, and treatment effect columns.
        """
        return pd.DataFrame([
            {"unit": unit, "time": time, "effect": effect}
            for (unit, time), effect in self.treatment_effects.items()
        ])

    def get_unit_effects_df(self) -> pd.DataFrame:
        """
        Get unit fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit and effect columns.
        """
        return pd.DataFrame([
            {"unit": unit, "effect": effect}
            for unit, effect in self.unit_effects.items()
        ])

    def get_time_effects_df(self) -> pd.DataFrame:
        """
        Get time fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with time and effect columns.
        """
        return pd.DataFrame([
            {"time": time, "effect": effect}
            for time, effect in self.time_effects.items()
        ])

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)
347
+
348
+
349
+ class TROP:
350
+ """
351
+ Triply Robust Panel (TROP) estimator.
352
+
353
+ Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025).
354
+ TROP combines three robustness components:
355
+
356
+ 1. **Nuclear norm regularized factor model**: Estimates interactive fixed
357
+ effects L_it via matrix completion with nuclear norm penalty ||L||_*
358
+
359
+ 2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
360
+ where d(j,i) is the RMSE of outcome differences between units
361
+
362
+ 3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
363
+ weighting pre-treatment periods by proximity to treatment
364
+
365
+ Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
366
+ cross-validation on control observations.
367
+
368
+ Parameters
369
+ ----------
370
+ method : str, default='twostep'
371
+ Estimation method to use:
372
+
373
+ - 'twostep': Per-observation model fitting following Algorithm 2 of
374
+ Athey et al. (2025). Computes observation-specific weights and fits
375
+ a model for each treated observation, averaging the individual
376
+ treatment effects. More flexible but computationally intensive.
377
+
378
+ - 'joint': Joint weighted least squares optimization. Estimates a
379
+ single scalar treatment effect τ along with fixed effects and
380
+ optional low-rank factor adjustment. Faster but assumes homogeneous
381
+ treatment effects. Uses alternating minimization when nuclear norm
382
+ penalty is finite.
383
+
384
+ lambda_time_grid : list, optional
385
+ Grid of time weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
386
+ lambda_unit_grid : list, optional
387
+ Grid of unit weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
388
+ lambda_nn_grid : list, optional
389
+ Grid of nuclear norm regularization parameters. Default: [0, 0.01, 0.1, 1].
390
+ max_iter : int, default=100
391
+ Maximum iterations for nuclear norm optimization.
392
+ tol : float, default=1e-6
393
+ Convergence tolerance for optimization.
394
+ alpha : float, default=0.05
395
+ Significance level for confidence intervals.
396
+ variance_method : str, default='bootstrap'
397
+ Method for variance estimation: 'bootstrap' or 'jackknife'.
398
+ n_bootstrap : int, default=200
399
+ Number of replications for variance estimation.
400
+ max_loocv_samples : int, default=100
401
+ Maximum control observations to use in LOOCV for tuning parameter
402
+ selection. Subsampling is used for computational tractability as
403
+ noted in the paper. Increase for more precise tuning at the cost
404
+ of computational time.
405
+ seed : int, optional
406
+ Random seed for reproducibility.
407
+
408
+ Attributes
409
+ ----------
410
+ results_ : TROPResults
411
+ Estimation results after calling fit().
412
+ is_fitted_ : bool
413
+ Whether the model has been fitted.
414
+
415
+ Examples
416
+ --------
417
+ >>> from diff_diff import TROP
418
+ >>> trop = TROP()
419
+ >>> results = trop.fit(
420
+ ... data,
421
+ ... outcome='outcome',
422
+ ... treatment='treated',
423
+ ... unit='unit',
424
+ ... time='period',
425
+ ... )
426
+ >>> results.print_summary()
427
+
428
+ References
429
+ ----------
430
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
431
+ Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
432
+ """
433
+
434
+ # Class constants
435
+ DEFAULT_LOOCV_MAX_SAMPLES: int = 100
436
+ """Maximum control observations to use in LOOCV (for computational tractability).
437
+
438
+ As noted in the paper's footnote, LOOCV is subsampled for computational
439
+ tractability. This constant controls the maximum number of control observations
440
+ used in each LOOCV evaluation. Increase for more precise tuning at the cost
441
+ of computational time.
442
+ """
443
+
444
+ CONVERGENCE_TOL_SVD: float = 1e-10
445
+ """Tolerance for singular value truncation in soft-thresholding.
446
+
447
+ Singular values below this threshold after soft-thresholding are treated
448
+ as zero to improve numerical stability.
449
+ """
450
+
451
+ def __init__(
452
+ self,
453
+ method: str = "twostep",
454
+ lambda_time_grid: Optional[List[float]] = None,
455
+ lambda_unit_grid: Optional[List[float]] = None,
456
+ lambda_nn_grid: Optional[List[float]] = None,
457
+ max_iter: int = 100,
458
+ tol: float = 1e-6,
459
+ alpha: float = 0.05,
460
+ variance_method: str = 'bootstrap',
461
+ n_bootstrap: int = 200,
462
+ max_loocv_samples: int = 100,
463
+ seed: Optional[int] = None,
464
+ ):
465
+ # Validate method parameter
466
+ valid_methods = ("twostep", "joint")
467
+ if method not in valid_methods:
468
+ raise ValueError(
469
+ f"method must be one of {valid_methods}, got '{method}'"
470
+ )
471
+ self.method = method
472
+
473
+ # Default grids from paper
474
+ self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
475
+ self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
476
+ self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
477
+
478
+ self.max_iter = max_iter
479
+ self.tol = tol
480
+ self.alpha = alpha
481
+ self.variance_method = variance_method
482
+ self.n_bootstrap = n_bootstrap
483
+ self.max_loocv_samples = max_loocv_samples
484
+ self.seed = seed
485
+
486
+ # Validate parameters
487
+ valid_variance_methods = ("bootstrap", "jackknife")
488
+ if variance_method not in valid_variance_methods:
489
+ raise ValueError(
490
+ f"variance_method must be one of {valid_variance_methods}, "
491
+ f"got '{variance_method}'"
492
+ )
493
+
494
+ # Internal state
495
+ self.results_: Optional[TROPResults] = None
496
+ self.is_fitted_: bool = False
497
+ self._optimal_lambda: Optional[Tuple[float, float, float]] = None
498
+
499
+ # Pre-computed structures (set during fit)
500
+ self._precomputed: Optional[_PrecomputedStructures] = None
501
+
502
+ def _precompute_structures(
503
+ self,
504
+ Y: np.ndarray,
505
+ D: np.ndarray,
506
+ control_unit_idx: np.ndarray,
507
+ n_units: int,
508
+ n_periods: int,
509
+ ) -> _PrecomputedStructures:
510
+ """
511
+ Pre-compute data structures that are reused across LOOCV and estimation.
512
+
513
+ This method computes once what would otherwise be computed repeatedly:
514
+ - Pairwise unit distance matrix
515
+ - Time distance vectors
516
+ - Masks and indices
517
+
518
+ Parameters
519
+ ----------
520
+ Y : np.ndarray
521
+ Outcome matrix (n_periods x n_units).
522
+ D : np.ndarray
523
+ Treatment indicator matrix (n_periods x n_units).
524
+ control_unit_idx : np.ndarray
525
+ Indices of control units.
526
+ n_units : int
527
+ Number of units.
528
+ n_periods : int
529
+ Number of periods.
530
+
531
+ Returns
532
+ -------
533
+ _PrecomputedStructures
534
+ Pre-computed structures for efficient reuse.
535
+ """
536
+ # Compute pairwise unit distances (for all observation-specific weights)
537
+ # Following Equation 3 (page 7): RMSE between units over pre-treatment
538
+ if HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None:
539
+ # Use Rust backend for parallel distance computation (4-8x speedup)
540
+ unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64))
541
+ else:
542
+ unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)
543
+
544
+ # Pre-compute time distance vectors for each target period
545
+ # Time distance: |t - s| for all s and each target t
546
+ time_dist_matrix = np.abs(
547
+ np.arange(n_periods)[:, np.newaxis] - np.arange(n_periods)[np.newaxis, :]
548
+ ) # (n_periods, n_periods) where [t, s] = |t - s|
549
+
550
+ # Control and treatment masks
551
+ control_mask = D == 0
552
+ treated_mask = D == 1
553
+
554
+ # Identify treated observations
555
+ treated_observations = list(zip(*np.where(treated_mask)))
556
+
557
+ # Control observations for LOOCV
558
+ control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
559
+ if control_mask[t, i] and not np.isnan(Y[t, i])]
560
+
561
+ return {
562
+ "unit_dist_matrix": unit_dist_matrix,
563
+ "time_dist_matrix": time_dist_matrix,
564
+ "control_mask": control_mask,
565
+ "treated_mask": treated_mask,
566
+ "treated_observations": treated_observations,
567
+ "control_obs": control_obs,
568
+ "control_unit_idx": control_unit_idx,
569
+ "D": D,
570
+ "Y": Y,
571
+ "n_units": n_units,
572
+ "n_periods": n_periods,
573
+ }
574
+
575
+ def _compute_all_unit_distances(
576
+ self,
577
+ Y: np.ndarray,
578
+ D: np.ndarray,
579
+ n_units: int,
580
+ n_periods: int,
581
+ ) -> np.ndarray:
582
+ """
583
+ Compute pairwise unit distance matrix using vectorized operations.
584
+
585
+ Following Equation 3 (page 7):
586
+ dist_unit_{-t}(j, i) = sqrt(Σ_u (Y_{iu} - Y_{ju})² / n_valid)
587
+
588
+ For efficiency, we compute a base distance matrix excluding all treated
589
+ observations, which provides a good approximation. The exact per-observation
590
+ distances are refined when needed.
591
+
592
+ Uses vectorized numpy operations with masked arrays for O(n²) complexity
593
+ but with highly optimized inner loops via numpy/BLAS.
594
+
595
+ Parameters
596
+ ----------
597
+ Y : np.ndarray
598
+ Outcome matrix (n_periods x n_units).
599
+ D : np.ndarray
600
+ Treatment indicator matrix (n_periods x n_units).
601
+ n_units : int
602
+ Number of units.
603
+ n_periods : int
604
+ Number of periods.
605
+
606
+ Returns
607
+ -------
608
+ np.ndarray
609
+ Pairwise distance matrix (n_units x n_units).
610
+ """
611
+ # Mask for valid observations: control periods only (D=0), non-NaN
612
+ valid_mask = (D == 0) & ~np.isnan(Y)
613
+
614
+ # Replace invalid values with NaN for masked computation
615
+ Y_masked = np.where(valid_mask, Y, np.nan)
616
+
617
+ # Transpose to (n_units, n_periods) for easier broadcasting
618
+ Y_T = Y_masked.T # (n_units, n_periods)
619
+
620
+ # Compute pairwise squared differences using broadcasting
621
+ # Y_T[:, np.newaxis, :] has shape (n_units, 1, n_periods)
622
+ # Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods)
623
+ # diff has shape (n_units, n_units, n_periods)
624
+ diff = Y_T[:, np.newaxis, :] - Y_T[np.newaxis, :, :]
625
+ sq_diff = diff ** 2
626
+
627
+ # Count valid (non-NaN) observations per pair
628
+ # A difference is valid only if both units have valid observations
629
+ valid_diff = ~np.isnan(sq_diff)
630
+ n_valid = np.sum(valid_diff, axis=2) # (n_units, n_units)
631
+
632
+ # Compute sum of squared differences (treating NaN as 0)
633
+ sq_diff_sum = np.nansum(sq_diff, axis=2) # (n_units, n_units)
634
+
635
+ # Compute RMSE distance: sqrt(sum / n_valid)
636
+ # Avoid division by zero
637
+ with np.errstate(divide='ignore', invalid='ignore'):
638
+ dist_matrix = np.sqrt(sq_diff_sum / n_valid)
639
+
640
+ # Set pairs with no valid observations to inf
641
+ dist_matrix = np.where(n_valid > 0, dist_matrix, np.inf)
642
+
643
+ # Ensure diagonal is 0 (same unit distance)
644
+ np.fill_diagonal(dist_matrix, 0.0)
645
+
646
+ return dist_matrix
647
+
648
+ def _compute_unit_distance_for_obs(
649
+ self,
650
+ Y: np.ndarray,
651
+ D: np.ndarray,
652
+ j: int,
653
+ i: int,
654
+ target_period: int,
655
+ ) -> float:
656
+ """
657
+ Compute observation-specific pairwise distance from unit j to unit i.
658
+
659
+ This is the exact computation from Equation 3, excluding the target period.
660
+ Used when the base distance matrix approximation is insufficient.
661
+
662
+ Parameters
663
+ ----------
664
+ Y : np.ndarray
665
+ Outcome matrix (n_periods x n_units).
666
+ D : np.ndarray
667
+ Treatment indicator matrix.
668
+ j : int
669
+ Control unit index.
670
+ i : int
671
+ Treated unit index.
672
+ target_period : int
673
+ Target period to exclude.
674
+
675
+ Returns
676
+ -------
677
+ float
678
+ Pairwise RMSE distance.
679
+ """
680
+ n_periods = Y.shape[0]
681
+
682
+ # Mask: exclude target period, both units must be untreated, non-NaN
683
+ valid = np.ones(n_periods, dtype=bool)
684
+ valid[target_period] = False
685
+ valid &= (D[:, i] == 0) & (D[:, j] == 0)
686
+ valid &= ~np.isnan(Y[:, i]) & ~np.isnan(Y[:, j])
687
+
688
+ if np.any(valid):
689
+ sq_diffs = (Y[valid, i] - Y[valid, j]) ** 2
690
+ return np.sqrt(np.mean(sq_diffs))
691
+ else:
692
+ return np.inf
693
+
694
+ def _univariate_loocv_search(
695
+ self,
696
+ Y: np.ndarray,
697
+ D: np.ndarray,
698
+ control_mask: np.ndarray,
699
+ control_unit_idx: np.ndarray,
700
+ n_units: int,
701
+ n_periods: int,
702
+ param_name: str,
703
+ grid: List[float],
704
+ fixed_params: Dict[str, float],
705
+ ) -> Tuple[float, float]:
706
+ """
707
+ Search over one parameter with others fixed.
708
+
709
+ Following paper's footnote 2, this performs a univariate grid search
710
+ for one tuning parameter while holding others fixed. The fixed_params
711
+ can include _LAMBDA_INF values to disable specific components:
712
+ - lambda_nn = inf: Skip nuclear norm regularization (L=0)
713
+ - lambda_time = inf: Uniform time weights (treated as 0)
714
+ - lambda_unit = inf: Uniform unit weights (treated as 0)
715
+
716
+ Parameters
717
+ ----------
718
+ Y : np.ndarray
719
+ Outcome matrix (n_periods x n_units).
720
+ D : np.ndarray
721
+ Treatment indicator matrix (n_periods x n_units).
722
+ control_mask : np.ndarray
723
+ Boolean mask for control observations.
724
+ control_unit_idx : np.ndarray
725
+ Indices of control units.
726
+ n_units : int
727
+ Number of units.
728
+ n_periods : int
729
+ Number of periods.
730
+ param_name : str
731
+ Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
732
+ grid : List[float]
733
+ Grid of values to search over.
734
+ fixed_params : Dict[str, float]
735
+ Fixed values for other parameters. May include _LAMBDA_INF.
736
+
737
+ Returns
738
+ -------
739
+ Tuple[float, float]
740
+ (best_value, best_score) for the searched parameter.
741
+ """
742
+ best_score = np.inf
743
+ best_value = grid[0] if grid else 0.0
744
+
745
+ for value in grid:
746
+ params = {**fixed_params, param_name: value}
747
+
748
+ # Convert inf values to 0 for computation (inf means "disabled" = uniform weights)
749
+ lambda_time = params.get('lambda_time', 0.0)
750
+ lambda_unit = params.get('lambda_unit', 0.0)
751
+ lambda_nn = params.get('lambda_nn', 0.0)
752
+
753
+ # Handle infinity as "disabled" mode
754
+ # Per paper Equations 2-3:
755
+ # - λ_time/λ_unit=∞ → exp(-∞×dist)→0 for dist>0, uniform weights → use 0.0
756
+ # - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
757
+ # Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled"
758
+ if np.isinf(lambda_time):
759
+ lambda_time = 0.0 # Uniform time weights
760
+ if np.isinf(lambda_unit):
761
+ lambda_unit = 0.0 # Uniform unit weights
762
+ if np.isinf(lambda_nn):
763
+ lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
764
+
765
+ try:
766
+ score = self._loocv_score_obs_specific(
767
+ Y, D, control_mask, control_unit_idx,
768
+ lambda_time, lambda_unit, lambda_nn,
769
+ n_units, n_periods
770
+ )
771
+ if score < best_score:
772
+ best_score = score
773
+ best_value = value
774
+ except (np.linalg.LinAlgError, ValueError):
775
+ continue
776
+
777
+ return best_value, best_score
778
+
779
+ def _cycling_parameter_search(
780
+ self,
781
+ Y: np.ndarray,
782
+ D: np.ndarray,
783
+ control_mask: np.ndarray,
784
+ control_unit_idx: np.ndarray,
785
+ n_units: int,
786
+ n_periods: int,
787
+ initial_lambda: Tuple[float, float, float],
788
+ max_cycles: int = 10,
789
+ ) -> Tuple[float, float, float]:
790
+ """
791
+ Cycle through parameters until convergence (coordinate descent).
792
+
793
+ Following paper's footnote 2 (Stage 2), this iteratively optimizes
794
+ each tuning parameter while holding the others fixed, until convergence.
795
+
796
+ Parameters
797
+ ----------
798
+ Y : np.ndarray
799
+ Outcome matrix (n_periods x n_units).
800
+ D : np.ndarray
801
+ Treatment indicator matrix (n_periods x n_units).
802
+ control_mask : np.ndarray
803
+ Boolean mask for control observations.
804
+ control_unit_idx : np.ndarray
805
+ Indices of control units.
806
+ n_units : int
807
+ Number of units.
808
+ n_periods : int
809
+ Number of periods.
810
+ initial_lambda : Tuple[float, float, float]
811
+ Initial values (lambda_time, lambda_unit, lambda_nn).
812
+ max_cycles : int, default=10
813
+ Maximum number of coordinate descent cycles.
814
+
815
+ Returns
816
+ -------
817
+ Tuple[float, float, float]
818
+ Optimized (lambda_time, lambda_unit, lambda_nn).
819
+ """
820
+ lambda_time, lambda_unit, lambda_nn = initial_lambda
821
+ prev_score = np.inf
822
+
823
+ for cycle in range(max_cycles):
824
+ # Optimize λ_unit (fix λ_time, λ_nn)
825
+ lambda_unit, _ = self._univariate_loocv_search(
826
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
827
+ 'lambda_unit', self.lambda_unit_grid,
828
+ {'lambda_time': lambda_time, 'lambda_nn': lambda_nn}
829
+ )
830
+
831
+ # Optimize λ_time (fix λ_unit, λ_nn)
832
+ lambda_time, _ = self._univariate_loocv_search(
833
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
834
+ 'lambda_time', self.lambda_time_grid,
835
+ {'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn}
836
+ )
837
+
838
+ # Optimize λ_nn (fix λ_unit, λ_time)
839
+ lambda_nn, score = self._univariate_loocv_search(
840
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
841
+ 'lambda_nn', self.lambda_nn_grid,
842
+ {'lambda_unit': lambda_unit, 'lambda_time': lambda_time}
843
+ )
844
+
845
+ # Check convergence
846
+ if abs(score - prev_score) < 1e-6:
847
+ logger.debug(
848
+ "Cycling search converged after %d cycles with score %.6f",
849
+ cycle + 1, score
850
+ )
851
+ break
852
+ prev_score = score
853
+
854
+ return lambda_time, lambda_unit, lambda_nn
855
+
856
+ # =========================================================================
857
+ # Joint estimation method
858
+ # =========================================================================
859
+
860
+ def _compute_joint_weights(
861
+ self,
862
+ Y: np.ndarray,
863
+ D: np.ndarray,
864
+ lambda_time: float,
865
+ lambda_unit: float,
866
+ treated_periods: int,
867
+ n_units: int,
868
+ n_periods: int,
869
+ ) -> np.ndarray:
870
+ """
871
+ Compute distance-based weights for joint estimation.
872
+
873
+ Following the reference implementation, weights are computed based on:
874
+ - Time distance: distance to center of treated block
875
+ - Unit distance: RMSE to average treated trajectory over pre-periods
876
+
877
+ Parameters
878
+ ----------
879
+ Y : np.ndarray
880
+ Outcome matrix (n_periods x n_units).
881
+ D : np.ndarray
882
+ Treatment indicator matrix (n_periods x n_units).
883
+ lambda_time : float
884
+ Time weight decay parameter.
885
+ lambda_unit : float
886
+ Unit weight decay parameter.
887
+ treated_periods : int
888
+ Number of post-treatment periods.
889
+ n_units : int
890
+ Number of units.
891
+ n_periods : int
892
+ Number of periods.
893
+
894
+ Returns
895
+ -------
896
+ np.ndarray
897
+ Weight matrix (n_periods x n_units).
898
+ """
899
+ # Identify treated units (ever treated)
900
+ treated_mask = np.any(D == 1, axis=0)
901
+ treated_unit_idx = np.where(treated_mask)[0]
902
+
903
+ if len(treated_unit_idx) == 0:
904
+ raise ValueError("No treated units found")
905
+
906
+ # Time weights: distance to center of treated block
907
+ # Following reference: center = T - treated_periods/2
908
+ center = n_periods - treated_periods / 2.0
909
+ dist_time = np.abs(np.arange(n_periods, dtype=float) - center)
910
+ delta_time = np.exp(-lambda_time * dist_time)
911
+
912
+ # Unit weights: RMSE to average treated trajectory over pre-periods
913
+ # Compute average treated trajectory (use nanmean to handle NaN)
914
+ average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1)
915
+
916
+ # Pre-period mask: 1 in pre, 0 in post
917
+ pre_mask = np.ones(n_periods, dtype=float)
918
+ pre_mask[-treated_periods:] = 0.0
919
+
920
+ # Compute RMS distance for each unit
921
+ # dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre)
922
+ # Use NaN-safe operations: treat NaN differences as 0 (excluded)
923
+ diff = average_treated[:, np.newaxis] - Y
924
+ diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis]
925
+
926
+ # Count valid observations per unit in pre-period
927
+ # Must check diff is finite (both Y and average_treated finite)
928
+ # to match the periods contributing to diff_sq
929
+ valid_count = np.sum(
930
+ np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0
931
+ )
932
+ sum_sq = np.sum(diff_sq, axis=0)
933
+ n_pre = np.sum(pre_mask)
934
+
935
+ if n_pre == 0:
936
+ raise ValueError("No pre-treatment periods")
937
+
938
+ # Track units with no valid pre-period data
939
+ no_valid_pre = valid_count == 0
940
+
941
+ # Use valid count per unit (avoid division by zero for calculation)
942
+ valid_count_safe = np.maximum(valid_count, 1)
943
+ dist_unit = np.sqrt(sum_sq / valid_count_safe)
944
+
945
+ # Units with no valid pre-period data get zero weight
946
+ # (dist is undefined, so we set it to inf -> delta_unit = exp(-inf) = 0)
947
+ delta_unit = np.exp(-lambda_unit * dist_unit)
948
+ delta_unit[no_valid_pre] = 0.0
949
+
950
+ # Outer product: (n_periods x n_units)
951
+ delta = np.outer(delta_time, delta_unit)
952
+
953
+ return delta
954
+
955
+ def _loocv_score_joint(
956
+ self,
957
+ Y: np.ndarray,
958
+ D: np.ndarray,
959
+ control_obs: List[Tuple[int, int]],
960
+ lambda_time: float,
961
+ lambda_unit: float,
962
+ lambda_nn: float,
963
+ treated_periods: int,
964
+ n_units: int,
965
+ n_periods: int,
966
+ ) -> float:
967
+ """
968
+ Compute LOOCV score for joint method with specific parameter combination.
969
+
970
+ Following paper's Equation 5:
971
+ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
972
+
973
+ For joint method, we exclude each control observation, fit the joint model
974
+ on remaining data, and compute the pseudo-treatment effect at the excluded obs.
975
+
976
+ Parameters
977
+ ----------
978
+ Y : np.ndarray
979
+ Outcome matrix (n_periods x n_units).
980
+ D : np.ndarray
981
+ Treatment indicator matrix (n_periods x n_units).
982
+ control_obs : List[Tuple[int, int]]
983
+ List of (t, i) control observations for LOOCV.
984
+ lambda_time : float
985
+ Time weight decay parameter.
986
+ lambda_unit : float
987
+ Unit weight decay parameter.
988
+ lambda_nn : float
989
+ Nuclear norm regularization parameter.
990
+ treated_periods : int
991
+ Number of post-treatment periods.
992
+ n_units : int
993
+ Number of units.
994
+ n_periods : int
995
+ Number of periods.
996
+
997
+ Returns
998
+ -------
999
+ float
1000
+ LOOCV score (sum of squared pseudo-treatment effects).
1001
+ """
1002
+ # Compute global weights (same for all LOOCV iterations)
1003
+ delta = self._compute_joint_weights(
1004
+ Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
1005
+ )
1006
+
1007
+ tau_sq_sum = 0.0
1008
+ n_valid = 0
1009
+
1010
+ for t_ex, i_ex in control_obs:
1011
+ # Create modified delta with excluded observation zeroed out
1012
+ delta_ex = delta.copy()
1013
+ delta_ex[t_ex, i_ex] = 0.0
1014
+
1015
+ try:
1016
+ # Fit joint model excluding this observation
1017
+ if lambda_nn >= 1e10:
1018
+ mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta_ex)
1019
+ L = np.zeros((n_periods, n_units))
1020
+ else:
1021
+ mu, alpha, beta, L, tau = self._solve_joint_with_lowrank(
1022
+ Y, D, delta_ex, lambda_nn, self.max_iter, self.tol
1023
+ )
1024
+
1025
+ # Pseudo treatment effect: τ = Y - μ - α - β - L
1026
+ if np.isfinite(Y[t_ex, i_ex]):
1027
+ tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex]
1028
+ tau_sq_sum += tau_loocv ** 2
1029
+ n_valid += 1
1030
+
1031
+ except (np.linalg.LinAlgError, ValueError):
1032
+ # Any failure means this λ combination is invalid per Equation 5
1033
+ return np.inf
1034
+
1035
+ if n_valid == 0:
1036
+ return np.inf
1037
+
1038
+ return tau_sq_sum
1039
+
1040
+ def _solve_joint_no_lowrank(
1041
+ self,
1042
+ Y: np.ndarray,
1043
+ D: np.ndarray,
1044
+ delta: np.ndarray,
1045
+ ) -> Tuple[float, np.ndarray, np.ndarray, float]:
1046
+ """
1047
+ Solve joint TWFE + treatment via weighted least squares (no low-rank).
1048
+
1049
+ Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})²
1050
+
1051
+ Parameters
1052
+ ----------
1053
+ Y : np.ndarray
1054
+ Outcome matrix (n_periods x n_units).
1055
+ D : np.ndarray
1056
+ Treatment indicator matrix (n_periods x n_units).
1057
+ delta : np.ndarray
1058
+ Weight matrix (n_periods x n_units).
1059
+
1060
+ Returns
1061
+ -------
1062
+ Tuple[float, np.ndarray, np.ndarray, float]
1063
+ (mu, alpha, beta, tau) estimated parameters.
1064
+ """
1065
+ n_periods, n_units = Y.shape
1066
+
1067
+ # Flatten matrices for regression
1068
+ y = Y.flatten() # length n_periods * n_units
1069
+ w = D.flatten()
1070
+ weights = delta.flatten()
1071
+
1072
+ # Handle NaN values: zero weight for NaN outcomes/weights, impute with 0
1073
+ # This ensures NaN observations don't contribute to estimation
1074
+ valid_y = np.isfinite(y)
1075
+ valid_w = np.isfinite(weights)
1076
+ valid_mask = valid_y & valid_w
1077
+ weights = np.where(valid_mask, weights, 0.0)
1078
+ y = np.where(valid_mask, y, 0.0)
1079
+
1080
+ sqrt_weights = np.sqrt(np.maximum(weights, 0))
1081
+
1082
+ # Check for all-zero weights (matches Rust's sum_w < 1e-10 check)
1083
+ sum_w = np.sum(weights)
1084
+ if sum_w < 1e-10:
1085
+ raise ValueError("All weights are zero - cannot estimate")
1086
+
1087
+ # Build design matrix: [intercept, unit_dummies, time_dummies, treatment]
1088
+ # Total columns: 1 + n_units + n_periods + 1
1089
+ # But we need to drop one unit and one time dummy for identification
1090
+ # Drop first unit (unit 0) and first time (time 0)
1091
+ n_obs = n_periods * n_units
1092
+ n_params = 1 + (n_units - 1) + (n_periods - 1) + 1
1093
+
1094
+ X = np.zeros((n_obs, n_params))
1095
+ X[:, 0] = 1.0 # intercept
1096
+
1097
+ # Unit dummies (skip unit 0)
1098
+ for i in range(1, n_units):
1099
+ for t in range(n_periods):
1100
+ X[t * n_units + i, i] = 1.0
1101
+
1102
+ # Time dummies (skip time 0)
1103
+ for t in range(1, n_periods):
1104
+ for i in range(n_units):
1105
+ X[t * n_units + i, (n_units - 1) + t] = 1.0
1106
+
1107
+ # Treatment indicator
1108
+ X[:, -1] = w
1109
+
1110
+ # Apply weights
1111
+ X_weighted = X * sqrt_weights[:, np.newaxis]
1112
+ y_weighted = y * sqrt_weights
1113
+
1114
+ # Solve weighted least squares
1115
+ try:
1116
+ coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None)
1117
+ except np.linalg.LinAlgError:
1118
+ # Fallback: use pseudo-inverse
1119
+ coeffs = np.linalg.pinv(X_weighted) @ y_weighted
1120
+
1121
+ # Extract parameters
1122
+ mu = coeffs[0]
1123
+ alpha = np.zeros(n_units)
1124
+ alpha[1:] = coeffs[1:n_units]
1125
+ beta = np.zeros(n_periods)
1126
+ beta[1:] = coeffs[n_units:(n_units + n_periods - 1)]
1127
+ tau = coeffs[-1]
1128
+
1129
+ return float(mu), alpha, beta, float(tau)
1130
+
1131
+ def _solve_joint_with_lowrank(
1132
+ self,
1133
+ Y: np.ndarray,
1134
+ D: np.ndarray,
1135
+ delta: np.ndarray,
1136
+ lambda_nn: float,
1137
+ max_iter: int = 100,
1138
+ tol: float = 1e-6,
1139
+ ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]:
1140
+ """
1141
+ Solve joint TWFE + treatment + low-rank via alternating minimization.
1142
+
1143
+ Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_*
1144
+
1145
+ Parameters
1146
+ ----------
1147
+ Y : np.ndarray
1148
+ Outcome matrix (n_periods x n_units).
1149
+ D : np.ndarray
1150
+ Treatment indicator matrix (n_periods x n_units).
1151
+ delta : np.ndarray
1152
+ Weight matrix (n_periods x n_units).
1153
+ lambda_nn : float
1154
+ Nuclear norm regularization parameter.
1155
+ max_iter : int, default=100
1156
+ Maximum iterations for alternating minimization.
1157
+ tol : float, default=1e-6
1158
+ Convergence tolerance.
1159
+
1160
+ Returns
1161
+ -------
1162
+ Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]
1163
+ (mu, alpha, beta, L, tau) estimated parameters.
1164
+ """
1165
+ n_periods, n_units = Y.shape
1166
+
1167
+ # Handle NaN values: impute with 0 for computations
1168
+ # The solver will also zero weights for NaN observations
1169
+ Y_safe = np.where(np.isfinite(Y), Y, 0.0)
1170
+
1171
+ # Mask delta to exclude NaN outcomes from estimation
1172
+ # This ensures NaN observations don't contribute to the gradient step
1173
+ nan_mask = ~np.isfinite(Y)
1174
+ delta_masked = delta.copy()
1175
+ delta_masked[nan_mask] = 0.0
1176
+
1177
+ # Initialize L = 0
1178
+ L = np.zeros((n_periods, n_units))
1179
+
1180
+ for iteration in range(max_iter):
1181
+ L_old = L.copy()
1182
+
1183
+ # Step 1: Fix L, solve for (mu, alpha, beta, tau)
1184
+ # Adjusted outcome: Y - L (using NaN-safe Y)
1185
+ # Pass masked delta to exclude NaN observations from WLS
1186
+ Y_adj = Y_safe - L
1187
+ mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta_masked)
1188
+
1189
+ # Step 2: Fix (mu, alpha, beta, tau), update L
1190
+ # Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y)
1191
+ R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D
1192
+
1193
+ # Weighted proximal step for L (soft-threshold SVD)
1194
+ # Normalize weights (using masked delta to exclude NaN observations)
1195
+ delta_max = np.max(delta_masked)
1196
+ if delta_max > 0:
1197
+ delta_norm = delta_masked / delta_max
1198
+ else:
1199
+ delta_norm = delta_masked
1200
+
1201
+ # Weighted average between current L and target R
1202
+ # L_next = L + delta_norm * (R - L), then soft-threshold
1203
+ # NaN observations have delta_norm=0, so they don't influence L update
1204
+ gradient_step = L + delta_norm * (R - L)
1205
+
1206
+ # Soft-threshold singular values
1207
+ # Use eta * lambda_nn for proper proximal step size (matches Rust)
1208
+ eta = 1.0 / delta_max if delta_max > 0 else 1.0
1209
+ L = self._soft_threshold_svd(gradient_step, eta * lambda_nn)
1210
+
1211
+ # Check convergence
1212
+ if np.max(np.abs(L - L_old)) < tol:
1213
+ break
1214
+
1215
+ return mu, alpha, beta, L, tau
1216
+
1217
    def _fit_joint(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        unit: str,
        time: str,
    ) -> TROPResults:
        """
        Fit TROP using joint weighted least squares method.

        This method estimates a single scalar treatment effect τ along with
        fixed effects and optional low-rank factor adjustment.

        Pipeline: pivot the panel into (n_periods x n_units) matrices,
        validate the design (absorbing treatment, >=2 pre-periods,
        simultaneous adoption), select (λ_time, λ_unit, λ_nn) by LOOCV
        grid search (Rust backend when available, Python fallback
        otherwise), re-fit with the selected parameters, and estimate the
        variance by bootstrap or jackknife.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data.
        outcome : str
            Outcome variable column name.
        treatment : str
            Treatment indicator column name.
        unit : str
            Unit identifier column name.
        time : str
            Time period column name.

        Returns
        -------
        TROPResults
            Estimation results.

        Raises
        ------
        ValueError
            If treatment is not absorbing, there are no treated observations,
            no control units, fewer than 2 pre-periods, or adoption is
            staggered.

        Notes
        -----
        Bootstrap and jackknife variance estimation assume simultaneous treatment
        adoption (fixed `treated_periods` across resamples). The treatment timing
        is inferred from the data once and held constant for all bootstrap/jackknife
        iterations. For staggered adoption designs where treatment timing varies
        across units, use `method="twostep"` which computes observation-specific
        weights that naturally handle heterogeneous timing.
        """
        # Data setup (same as twostep method)
        all_units = sorted(data[unit].unique())
        all_periods = sorted(data[time].unique())

        n_units = len(all_units)
        n_periods = len(all_periods)

        # Index -> original label maps, used when assembling result dicts.
        idx_to_unit = {i: u for i, u in enumerate(all_units)}
        idx_to_period = {i: p for i, p in enumerate(all_periods)}

        # Create matrices (n_periods x n_units); missing cells become NaN.
        Y = (
            data.pivot(index=time, columns=unit, values=outcome)
            .reindex(index=all_periods, columns=all_units)
            .values
        )

        # Track missing treatment cells BEFORE fillna so unbalanced panels
        # do not produce spurious absorbing-state violations below.
        D_raw = (
            data.pivot(index=time, columns=unit, values=treatment)
            .reindex(index=all_periods, columns=all_units)
        )
        missing_mask = pd.isna(D_raw).values
        D = D_raw.fillna(0).astype(int).values

        # Validate absorbing state: over OBSERVED periods only, D must never
        # step down from 1 back to 0 for any unit.
        violating_units = []
        for unit_idx in range(n_units):
            observed_mask = ~missing_mask[:, unit_idx]
            observed_d = D[observed_mask, unit_idx]
            if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
                violating_units.append(all_units[unit_idx])

        if violating_units:
            raise ValueError(
                f"Treatment indicator is not an absorbing state for units: {violating_units}. "
                f"D[t, unit] must be monotonic non-decreasing."
            )

        # Identify treated observations
        treated_mask = D == 1
        n_treated_obs = np.sum(treated_mask)

        if n_treated_obs == 0:
            raise ValueError("No treated observations found")

        # Identify treated (ever-treated) and control (never-treated) units.
        unit_ever_treated = np.any(D == 1, axis=0)
        treated_unit_idx = np.where(unit_ever_treated)[0]
        control_unit_idx = np.where(~unit_ever_treated)[0]

        if len(control_unit_idx) == 0:
            raise ValueError("No control units found")

        # Determine pre/post periods: first period with any treated cell.
        first_treat_period = None
        for t in range(n_periods):
            if np.any(D[t, :] == 1):
                first_treat_period = t
                break

        if first_treat_period is None:
            raise ValueError("Could not infer post-treatment periods from D matrix")

        n_pre_periods = first_treat_period
        treated_periods = n_periods - first_treat_period
        n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))

        if n_pre_periods < 2:
            raise ValueError("Need at least 2 pre-treatment periods")

        # Check for staggered adoption (joint method requires simultaneous treatment)
        # Use only observed periods (skip missing) to avoid false positives on unbalanced panels
        first_treat_by_unit = []
        for i in treated_unit_idx:
            observed_mask = ~missing_mask[:, i]
            # Get D values for observed periods only
            observed_d = D[observed_mask, i]
            observed_periods = np.where(observed_mask)[0]
            # Find first treatment among observed periods
            treated_idx = np.where(observed_d == 1)[0]
            if len(treated_idx) > 0:
                first_treat_by_unit.append(observed_periods[treated_idx[0]])

        unique_starts = sorted(set(first_treat_by_unit))
        if len(unique_starts) > 1:
            raise ValueError(
                f"method='joint' requires simultaneous treatment adoption, but your data "
                f"shows staggered adoption (units first treated at periods {unique_starts}). "
                f"Use method='twostep' which properly handles staggered adoption designs."
            )

        # LOOCV grid search for tuning parameters
        # Use Rust backend when available for parallel LOOCV (5-10x speedup)
        best_lambda = None
        best_score = np.inf
        control_mask = D == 0

        if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None:
            try:
                # Prepare inputs for Rust function
                control_mask_u8 = control_mask.astype(np.uint8)

                lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
                lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
                lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)

                result = _rust_loocv_grid_search_joint(
                    Y, D.astype(np.float64), control_mask_u8,
                    lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
                    self.max_loocv_samples, self.max_iter, self.tol,
                    self.seed if self.seed is not None else 0
                )
                # Unpack result - 7 values including optional first_failed_obs
                best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
                # Only accept finite scores - infinite means all fits failed
                # (best_lambda then stays None and the Python path runs below).
                if np.isfinite(best_score):
                    best_lambda = (best_lt, best_lu, best_ln)
                # Emit warnings consistent with Python implementation
                if n_valid == 0:
                    obs_info = ""
                    if first_failed_obs is not None:
                        t_idx, i_idx = first_failed_obs
                        obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                    warnings.warn(
                        f"LOOCV: All {n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"Returning infinite score.{obs_info}",
                        UserWarning
                    )
                elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
                    # More than 10% of the LOOCV fits failed — suspicious.
                    n_failed = n_attempted - n_valid
                    obs_info = ""
                    if first_failed_obs is not None:
                        t_idx, i_idx = first_failed_obs
                        obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                    warnings.warn(
                        f"LOOCV: {n_failed}/{n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"This may indicate numerical instability.{obs_info}",
                        UserWarning
                    )
            except Exception as e:
                # Fall back to Python implementation on error
                logger.debug(
                    "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e
                )
                best_lambda = None
                best_score = np.inf

        # Fall back to Python implementation if Rust unavailable or failed
        if best_lambda is None:
            # Get control observations for LOOCV
            control_obs = [
                (t, i) for t in range(n_periods) for i in range(n_units)
                if control_mask[t, i] and not np.isnan(Y[t, i])
            ]

            # Subsample if needed (sample indices to avoid ValueError on list of tuples)
            rng = np.random.default_rng(self.seed)
            max_loocv = min(self.max_loocv_samples, len(control_obs))
            if len(control_obs) > max_loocv:
                indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
                control_obs = [control_obs[idx] for idx in indices]

            # Grid search with true LOOCV
            for lambda_time_val in self.lambda_time_grid:
                for lambda_unit_val in self.lambda_unit_grid:
                    for lambda_nn_val in self.lambda_nn_grid:
                        # Convert infinity values to internal sentinels:
                        # λ_time/λ_unit=∞ -> 0.0 (uniform weights),
                        # λ_nn=∞ -> 1e10 (factor model disabled).
                        lt = 0.0 if np.isinf(lambda_time_val) else lambda_time_val
                        lu = 0.0 if np.isinf(lambda_unit_val) else lambda_unit_val
                        ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val

                        try:
                            score = self._loocv_score_joint(
                                Y, D, control_obs, lt, lu, ln,
                                treated_periods, n_units, n_periods
                            )

                            if score < best_score:
                                best_score = score
                                # Keep the ORIGINAL grid values (possibly ∞)
                                # for reporting; conversion happens again
                                # before the final fit below.
                                best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val)

                        except (np.linalg.LinAlgError, ValueError):
                            continue

        if best_lambda is None:
            warnings.warn(
                "All tuning parameter combinations failed. Using defaults.",
                UserWarning
            )
            best_lambda = (1.0, 1.0, 0.1)
            best_score = np.nan

        # Final estimation with best parameters
        lambda_time, lambda_unit, lambda_nn = best_lambda
        # Preserve the grid values (possibly ∞) for the results object.
        original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda

        # Convert infinity values for computation
        if np.isinf(lambda_time):
            lambda_time = 0.0
        if np.isinf(lambda_unit):
            lambda_unit = 0.0
        if np.isinf(lambda_nn):
            lambda_nn = 1e10

        # Compute final weights and fit
        delta = self._compute_joint_weights(
            Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
        )

        # λ_nn at the sentinel disables the low-rank factor component.
        if lambda_nn >= 1e10:
            mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta)
            L = np.zeros((n_periods, n_units))
        else:
            mu, alpha, beta, L, tau = self._solve_joint_with_lowrank(
                Y, D, delta, lambda_nn, self.max_iter, self.tol
            )

        # ATT is the scalar treatment effect
        att = tau

        # Compute individual treatment effects for reporting (same τ for all)
        treatment_effects = {}
        for t in range(n_periods):
            for i in range(n_units):
                if D[t, i] == 1:
                    unit_id = idx_to_unit[i]
                    time_id = idx_to_period[t]
                    treatment_effects[(unit_id, time_id)] = tau

        # Compute effective rank of L (sum of singular values over the
        # largest one; 0 when L is identically zero).
        _, s, _ = np.linalg.svd(L, full_matrices=False)
        if s[0] > 0:
            effective_rank = np.sum(s) / s[0]
        else:
            effective_rank = 0.0

        # Variance estimation uses the CONVERTED (finite) lambdas.
        effective_lambda = (lambda_time, lambda_unit, lambda_nn)

        if self.variance_method == "bootstrap":
            se, bootstrap_dist = self._bootstrap_variance_joint(
                data, outcome, treatment, unit, time,
                effective_lambda, treated_periods
            )
        else:
            # Jackknife for joint method
            se, bootstrap_dist = self._jackknife_variance_joint(
                Y, D, effective_lambda, treated_periods,
                n_units, n_periods
            )

        # Compute test statistics (t distribution with df based on the
        # number of treated observations; degenerate se -> NaN stats).
        if se > 0:
            t_stat = att / se
            p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
            conf_int = compute_confidence_interval(att, se, self.alpha)
        else:
            t_stat = np.nan
            p_value = np.nan
            conf_int = (np.nan, np.nan)

        # Create results dictionaries keyed by the original labels.
        unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)}
        time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)}

        self.results_ = TROPResults(
            att=float(att),
            se=float(se),
            t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat,
            p_value=float(p_value) if np.isfinite(p_value) else p_value,
            conf_int=conf_int,
            n_obs=len(data),
            n_treated=len(treated_unit_idx),
            n_control=len(control_unit_idx),
            n_treated_obs=int(n_treated_obs),
            unit_effects=unit_effects_dict,
            time_effects=time_effects_dict,
            treatment_effects=treatment_effects,
            lambda_time=original_lambda_time,
            lambda_unit=original_lambda_unit,
            lambda_nn=original_lambda_nn,
            factor_matrix=L,
            effective_rank=effective_rank,
            loocv_score=best_score,
            variance_method=self.variance_method,
            alpha=self.alpha,
            n_pre_periods=n_pre_periods,
            n_post_periods=n_post_periods,
            n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None,
            bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
        )

        self.is_fitted_ = True
        return self.results_
1554
+
1555
+ def _bootstrap_variance_joint(
1556
+ self,
1557
+ data: pd.DataFrame,
1558
+ outcome: str,
1559
+ treatment: str,
1560
+ unit: str,
1561
+ time: str,
1562
+ optimal_lambda: Tuple[float, float, float],
1563
+ treated_periods: int,
1564
+ ) -> Tuple[float, np.ndarray]:
1565
+ """
1566
+ Compute bootstrap standard error for joint method.
1567
+
1568
+ Uses Rust backend when available for parallel bootstrap (5-15x speedup).
1569
+
1570
+ Parameters
1571
+ ----------
1572
+ data : pd.DataFrame
1573
+ Original data.
1574
+ outcome : str
1575
+ Outcome column name.
1576
+ treatment : str
1577
+ Treatment column name.
1578
+ unit : str
1579
+ Unit column name.
1580
+ time : str
1581
+ Time column name.
1582
+ optimal_lambda : tuple
1583
+ Optimal tuning parameters.
1584
+ treated_periods : int
1585
+ Number of post-treatment periods.
1586
+
1587
+ Returns
1588
+ -------
1589
+ Tuple[float, np.ndarray]
1590
+ (se, bootstrap_estimates).
1591
+ """
1592
+ lambda_time, lambda_unit, lambda_nn = optimal_lambda
1593
+
1594
+ # Try Rust backend for parallel bootstrap (5-15x speedup)
1595
+ if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None:
1596
+ try:
1597
+ # Create matrices for Rust function
1598
+ all_units = sorted(data[unit].unique())
1599
+ all_periods = sorted(data[time].unique())
1600
+
1601
+ Y = (
1602
+ data.pivot(index=time, columns=unit, values=outcome)
1603
+ .reindex(index=all_periods, columns=all_units)
1604
+ .values
1605
+ )
1606
+ D = (
1607
+ data.pivot(index=time, columns=unit, values=treatment)
1608
+ .reindex(index=all_periods, columns=all_units)
1609
+ .fillna(0)
1610
+ .astype(np.float64)
1611
+ .values
1612
+ )
1613
+
1614
+ bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint(
1615
+ Y, D,
1616
+ lambda_time, lambda_unit, lambda_nn,
1617
+ self.n_bootstrap, self.max_iter, self.tol,
1618
+ self.seed if self.seed is not None else 0
1619
+ )
1620
+
1621
+ if len(bootstrap_estimates) < 10:
1622
+ warnings.warn(
1623
+ f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
1624
+ UserWarning
1625
+ )
1626
+ if len(bootstrap_estimates) == 0:
1627
+ return 0.0, np.array([])
1628
+
1629
+ return float(se), np.array(bootstrap_estimates)
1630
+
1631
+ except Exception as e:
1632
+ logger.debug(
1633
+ "Rust bootstrap (joint) failed, falling back to Python: %s", e
1634
+ )
1635
+
1636
+ # Python fallback implementation
1637
+ rng = np.random.default_rng(self.seed)
1638
+
1639
+ # Stratified bootstrap sampling
1640
+ unit_ever_treated = data.groupby(unit)[treatment].max()
1641
+ treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist())
1642
+ control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist())
1643
+
1644
+ n_treated_units = len(treated_units)
1645
+ n_control_units = len(control_units)
1646
+
1647
+ bootstrap_estimates_list: List[float] = []
1648
+
1649
+ for _ in range(self.n_bootstrap):
1650
+ # Stratified sampling
1651
+ if n_control_units > 0:
1652
+ sampled_control = rng.choice(
1653
+ control_units, size=n_control_units, replace=True
1654
+ )
1655
+ else:
1656
+ sampled_control = np.array([], dtype=object)
1657
+
1658
+ if n_treated_units > 0:
1659
+ sampled_treated = rng.choice(
1660
+ treated_units, size=n_treated_units, replace=True
1661
+ )
1662
+ else:
1663
+ sampled_treated = np.array([], dtype=object)
1664
+
1665
+ sampled_units = np.concatenate([sampled_control, sampled_treated])
1666
+
1667
+ # Create bootstrap sample
1668
+ boot_data = pd.concat([
1669
+ data[data[unit] == u].assign(**{unit: f"{u}_{idx}"})
1670
+ for idx, u in enumerate(sampled_units)
1671
+ ], ignore_index=True)
1672
+
1673
+ try:
1674
+ tau = self._fit_joint_with_fixed_lambda(
1675
+ boot_data, outcome, treatment, unit, time,
1676
+ optimal_lambda, treated_periods
1677
+ )
1678
+ bootstrap_estimates_list.append(tau)
1679
+ except (ValueError, np.linalg.LinAlgError, KeyError):
1680
+ continue
1681
+
1682
+ bootstrap_estimates = np.array(bootstrap_estimates_list)
1683
+
1684
+ if len(bootstrap_estimates) < 10:
1685
+ warnings.warn(
1686
+ f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
1687
+ UserWarning
1688
+ )
1689
+ if len(bootstrap_estimates) == 0:
1690
+ return 0.0, np.array([])
1691
+
1692
+ se = np.std(bootstrap_estimates, ddof=1)
1693
+ return float(se), bootstrap_estimates
1694
+
1695
+ def _fit_joint_with_fixed_lambda(
1696
+ self,
1697
+ data: pd.DataFrame,
1698
+ outcome: str,
1699
+ treatment: str,
1700
+ unit: str,
1701
+ time: str,
1702
+ fixed_lambda: Tuple[float, float, float],
1703
+ treated_periods: int,
1704
+ ) -> float:
1705
+ """
1706
+ Fit joint model with fixed tuning parameters.
1707
+
1708
+ Returns only the treatment effect τ.
1709
+ """
1710
+ lambda_time, lambda_unit, lambda_nn = fixed_lambda
1711
+
1712
+ all_units = sorted(data[unit].unique())
1713
+ all_periods = sorted(data[time].unique())
1714
+
1715
+ n_units = len(all_units)
1716
+ n_periods = len(all_periods)
1717
+
1718
+ Y = (
1719
+ data.pivot(index=time, columns=unit, values=outcome)
1720
+ .reindex(index=all_periods, columns=all_units)
1721
+ .values
1722
+ )
1723
+ D = (
1724
+ data.pivot(index=time, columns=unit, values=treatment)
1725
+ .reindex(index=all_periods, columns=all_units)
1726
+ .fillna(0)
1727
+ .astype(int)
1728
+ .values
1729
+ )
1730
+
1731
+ # Compute weights
1732
+ delta = self._compute_joint_weights(
1733
+ Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
1734
+ )
1735
+
1736
+ # Fit model
1737
+ if lambda_nn >= 1e10:
1738
+ _, _, _, tau = self._solve_joint_no_lowrank(Y, D, delta)
1739
+ else:
1740
+ _, _, _, _, tau = self._solve_joint_with_lowrank(
1741
+ Y, D, delta, lambda_nn, self.max_iter, self.tol
1742
+ )
1743
+
1744
+ return tau
1745
+
1746
+ def _jackknife_variance_joint(
1747
+ self,
1748
+ Y: np.ndarray,
1749
+ D: np.ndarray,
1750
+ optimal_lambda: Tuple[float, float, float],
1751
+ treated_periods: int,
1752
+ n_units: int,
1753
+ n_periods: int,
1754
+ ) -> Tuple[float, np.ndarray]:
1755
+ """
1756
+ Compute jackknife standard error for joint method.
1757
+
1758
+ Parameters
1759
+ ----------
1760
+ Y : np.ndarray
1761
+ Outcome matrix.
1762
+ D : np.ndarray
1763
+ Treatment matrix.
1764
+ optimal_lambda : tuple
1765
+ Optimal tuning parameters.
1766
+ treated_periods : int
1767
+ Number of post-treatment periods.
1768
+ n_units : int
1769
+ Number of units.
1770
+ n_periods : int
1771
+ Number of periods.
1772
+
1773
+ Returns
1774
+ -------
1775
+ Tuple[float, np.ndarray]
1776
+ (se, jackknife_estimates).
1777
+ """
1778
+ lambda_time, lambda_unit, lambda_nn = optimal_lambda
1779
+ jackknife_estimates = []
1780
+
1781
+ # Get treated unit indices
1782
+ treated_unit_idx = np.where(np.any(D == 1, axis=0))[0]
1783
+
1784
+ for leave_out in treated_unit_idx:
1785
+ # True leave-one-out: zero the delta weight for the left-out unit
1786
+ # This excludes the unit from estimation without imputation
1787
+ Y_jack = Y.copy()
1788
+ D_jack = D.copy()
1789
+ D_jack[:, leave_out] = 0 # Mark as not treated for weight computation
1790
+
1791
+ try:
1792
+ # Compute weights (left-out unit is still in calculation)
1793
+ delta = self._compute_joint_weights(
1794
+ Y_jack, D_jack, lambda_time, lambda_unit,
1795
+ treated_periods, n_units, n_periods
1796
+ )
1797
+
1798
+ # Zero the delta weight for the left-out unit
1799
+ # This ensures the unit doesn't contribute to estimation
1800
+ delta[:, leave_out] = 0.0
1801
+
1802
+ # Fit model (left-out unit has zero weight, truly excluded)
1803
+ if lambda_nn >= 1e10:
1804
+ _, _, _, tau = self._solve_joint_no_lowrank(Y_jack, D_jack, delta)
1805
+ else:
1806
+ _, _, _, _, tau = self._solve_joint_with_lowrank(
1807
+ Y_jack, D_jack, delta, lambda_nn, self.max_iter, self.tol
1808
+ )
1809
+
1810
+ jackknife_estimates.append(tau)
1811
+
1812
+ except (np.linalg.LinAlgError, ValueError):
1813
+ continue
1814
+
1815
+ jackknife_estimates = np.array(jackknife_estimates)
1816
+
1817
+ if len(jackknife_estimates) < 2:
1818
+ return 0.0, jackknife_estimates
1819
+
1820
+ # Jackknife SE formula
1821
+ n = len(jackknife_estimates)
1822
+ mean_est = np.mean(jackknife_estimates)
1823
+ se = np.sqrt((n - 1) / n * np.sum((jackknife_estimates - mean_est) ** 2))
1824
+
1825
+ return float(se), jackknife_estimates
1826
+
1827
    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        unit: str,
        time: str,
    ) -> TROPResults:
        """
        Fit the TROP model.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with observations for multiple units over multiple
            time periods.
        outcome : str
            Name of the outcome variable column.
        treatment : str
            Name of the treatment indicator column (0/1).

            IMPORTANT: This should be an ABSORBING STATE indicator, not a
            treatment timing indicator. For each unit, D=1 for ALL periods
            during and after treatment:

            - D[t, i] = 0 for all t < g_i (pre-treatment periods)
            - D[t, i] = 1 for all t >= g_i (treatment and post-treatment)

            where g_i is the treatment start time for unit i.

            For staggered adoption, different units can have different g_i.
            The ATT averages over ALL D=1 cells per Equation 1 of the paper.
        unit : str
            Name of the unit identifier column.
        time : str
            Name of the time period column.

        Returns
        -------
        TROPResults
            Object containing the ATT estimate, standard error,
            factor estimates, and tuning parameters. The lambda_*
            attributes show the selected grid values. Infinity values
            (∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform
            weights), λ_nn=∞ → 1e10 (factor model disabled).

        Raises
        ------
        ValueError
            If required columns are missing, the treatment indicator is
            not an absorbing state, there are no treated observations,
            there are no control units, treatment timing cannot be
            inferred from D, or fewer than 2 pre-treatment periods exist.
        """
        # Validate inputs
        required_cols = [outcome, treatment, unit, time]
        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Dispatch based on estimation method
        if self.method == "joint":
            return self._fit_joint(data, outcome, treatment, unit, time)

        # Below is the twostep method (default)
        # Get unique units and periods
        all_units = sorted(data[unit].unique())
        all_periods = sorted(data[time].unique())

        n_units = len(all_units)
        n_periods = len(all_periods)

        # Create mappings (index <-> original identifier, both directions)
        unit_to_idx = {u: i for i, u in enumerate(all_units)}
        period_to_idx = {p: i for i, p in enumerate(all_periods)}
        idx_to_unit = {i: u for u, i in unit_to_idx.items()}
        idx_to_period = {i: p for p, i in period_to_idx.items()}

        # Create outcome matrix Y (n_periods x n_units) and treatment matrix D
        # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
        # Cells absent from `data` become NaN here (unbalanced panels allowed).
        Y = (
            data.pivot(index=time, columns=unit, values=outcome)
            .reindex(index=all_periods, columns=all_units)
            .values
        )

        # For D matrix, track missing values BEFORE fillna to support unbalanced panels
        # Issue 3 fix: Missing observations should not trigger spurious violations
        D_raw = (
            data.pivot(index=time, columns=unit, values=treatment)
            .reindex(index=all_periods, columns=all_units)
        )
        missing_mask = pd.isna(D_raw).values  # True where originally missing
        D = D_raw.fillna(0).astype(int).values

        # Validate D is monotonic non-decreasing per unit (absorbing state)
        # D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
        # Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
        # This catches 1→0 violations that span missing period gaps
        # Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
        # adjacent period transitions don't show it (the gap hides the transition)
        violating_units = []
        for unit_idx in range(n_units):
            # Get observed D values for this unit (where not missing)
            observed_mask = ~missing_mask[:, unit_idx]
            observed_d = D[observed_mask, unit_idx]

            # Check if observed sequence is monotonically non-decreasing
            if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
                violating_units.append(all_units[unit_idx])

        if violating_units:
            raise ValueError(
                f"Treatment indicator is not an absorbing state for units: {violating_units}. "
                f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
                f"If this is event-study style data, convert to absorbing state: "
                f"D[t, i] = 1 for all t >= first treatment period."
            )

        # Identify treated observations
        treated_mask = D == 1
        n_treated_obs = np.sum(treated_mask)

        if n_treated_obs == 0:
            raise ValueError("No treated observations found")

        # Identify treated and control units
        unit_ever_treated = np.any(D == 1, axis=0)
        treated_unit_idx = np.where(unit_ever_treated)[0]
        control_unit_idx = np.where(~unit_ever_treated)[0]

        if len(control_unit_idx) == 0:
            raise ValueError("No control units found")

        # Determine pre/post periods from treatment indicator D
        # D matrix is the sole input for treatment timing per the paper
        first_treat_period = None
        for t in range(n_periods):
            if np.any(D[t, :] == 1):
                first_treat_period = t
                break
        if first_treat_period is None:
            raise ValueError("Could not infer post-treatment periods from D matrix")

        n_pre_periods = first_treat_period
        # Count periods where D=1 is actually observed (matches docstring)
        # Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
        n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))

        if n_pre_periods < 2:
            raise ValueError("Need at least 2 pre-treatment periods")

        # Step 1: Grid search with LOOCV for tuning parameters
        best_lambda = None
        best_score = np.inf

        # Control observations mask (for LOOCV)
        control_mask = D == 0

        # Pre-compute structures that are reused across LOOCV iterations
        self._precomputed = self._precompute_structures(
            Y, D, control_unit_idx, n_units, n_periods
        )

        # Use Rust backend for parallel LOOCV grid search (10-50x speedup)
        if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None:
            try:
                # Prepare inputs for Rust function
                control_mask_u8 = control_mask.astype(np.uint8)
                time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)

                lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
                lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
                lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)

                result = _rust_loocv_grid_search(
                    Y, D.astype(np.float64), control_mask_u8,
                    time_dist_matrix,
                    lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
                    self.max_loocv_samples, self.max_iter, self.tol,
                    self.seed if self.seed is not None else 0
                )
                # Unpack result - 7 values including optional first_failed_obs
                best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
                # Only accept finite scores - infinite means all fits failed
                if np.isfinite(best_score):
                    best_lambda = (best_lt, best_lu, best_ln)
                # else: best_lambda stays None, triggering defaults fallback
                # Emit warnings consistent with Python implementation
                if n_valid == 0:
                    # Include failed observation coordinates if available (Issue 2 fix)
                    obs_info = ""
                    if first_failed_obs is not None:
                        t_idx, i_idx = first_failed_obs
                        obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                    warnings.warn(
                        f"LOOCV: All {n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"Returning infinite score.{obs_info}",
                        UserWarning
                    )
                elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
                    n_failed = n_attempted - n_valid
                    # Include failed observation coordinates if available
                    obs_info = ""
                    if first_failed_obs is not None:
                        t_idx, i_idx = first_failed_obs
                        obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                    warnings.warn(
                        f"LOOCV: {n_failed}/{n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"This may indicate numerical instability.{obs_info}",
                        UserWarning
                    )
            except Exception as e:
                # Fall back to Python implementation on error
                logger.debug(
                    "Rust LOOCV grid search failed, falling back to Python: %s", e
                )
                best_lambda = None
                best_score = np.inf

        # Fall back to Python implementation if Rust unavailable or failed
        # Uses two-stage approach per paper's footnote 2:
        # Stage 1: Univariate searches for initial values
        # Stage 2: Cycling (coordinate descent) until convergence
        if best_lambda is None:
            # Stage 1: Univariate searches with extreme fixed values
            # Following paper's footnote 2 for initial bounds

            # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
            lambda_time_init, _ = self._univariate_loocv_search(
                Y, D, control_mask, control_unit_idx, n_units, n_periods,
                'lambda_time', self.lambda_time_grid,
                {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF}
            )

            # λ_nn search: fix λ_time=∞ (uniform time weights), λ_unit=0
            lambda_nn_init, _ = self._univariate_loocv_search(
                Y, D, control_mask, control_unit_idx, n_units, n_periods,
                'lambda_nn', self.lambda_nn_grid,
                {'lambda_time': _LAMBDA_INF, 'lambda_unit': 0.0}
            )

            # λ_unit search: fix λ_nn=∞, λ_time=0
            lambda_unit_init, _ = self._univariate_loocv_search(
                Y, D, control_mask, control_unit_idx, n_units, n_periods,
                'lambda_unit', self.lambda_unit_grid,
                {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0}
            )

            # Stage 2: Cycling refinement (coordinate descent)
            lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
                Y, D, control_mask, control_unit_idx, n_units, n_periods,
                (lambda_time_init, lambda_unit_init, lambda_nn_init)
            )

            # Compute final score for the optimized parameters
            try:
                best_score = self._loocv_score_obs_specific(
                    Y, D, control_mask, control_unit_idx,
                    lambda_time, lambda_unit, lambda_nn,
                    n_units, n_periods
                )
                # Only accept finite scores - infinite means all fits failed
                if np.isfinite(best_score):
                    best_lambda = (lambda_time, lambda_unit, lambda_nn)
                # else: best_lambda stays None, triggering defaults fallback
            except (np.linalg.LinAlgError, ValueError):
                # If even the optimized parameters fail, best_lambda stays None
                pass

        if best_lambda is None:
            warnings.warn(
                "All tuning parameter combinations failed. Using defaults.",
                UserWarning
            )
            best_lambda = (1.0, 1.0, 0.1)
            best_score = np.nan

        self._optimal_lambda = best_lambda
        lambda_time, lambda_unit, lambda_nn = best_lambda

        # Convert infinity values for final estimation (matching LOOCV conversion)
        # This ensures final estimation uses the same effective parameters that LOOCV evaluated.
        # See REGISTRY.md "λ=∞ implementation" for rationale.
        #
        # IMPORTANT: Store original grid values for results, use converted for computation.
        # This lets users see what was selected from their grid, while ensuring consistent
        # behavior between point estimation and variance estimation.
        original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda

        if np.isinf(lambda_time):
            lambda_time = 0.0  # Uniform time weights
        if np.isinf(lambda_unit):
            lambda_unit = 0.0  # Uniform unit weights
        if np.isinf(lambda_nn):
            lambda_nn = 1e10  # Very large → L≈0 (factor model disabled)

        # Create effective_lambda with converted values for ALL downstream computation
        # This ensures variance estimation uses the same parameters as point estimation
        effective_lambda = (lambda_time, lambda_unit, lambda_nn)

        # Step 2: Final estimation - per-observation model fitting following Algorithm 2
        # For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it}
        treatment_effects = {}
        tau_values = []
        alpha_estimates = []
        beta_estimates = []
        L_estimates = []

        # Use pre-computed treated observations
        treated_observations = self._precomputed["treated_observations"]

        for t, i in treated_observations:
            # Compute observation-specific weights for this (i, t)
            weight_matrix = self._compute_observation_weights(
                Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
                n_units, n_periods
            )

            # Fit model with these weights
            alpha_hat, beta_hat, L_hat = self._estimate_model(
                Y, control_mask, weight_matrix, lambda_nn,
                n_units, n_periods
            )

            # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
            tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]

            unit_id = idx_to_unit[i]
            time_id = idx_to_period[t]
            treatment_effects[(unit_id, time_id)] = tau_it
            tau_values.append(tau_it)

            # Store for averaging
            alpha_estimates.append(alpha_hat)
            beta_estimates.append(beta_hat)
            L_estimates.append(L_hat)

        # Average ATT (simple mean over all treated cells, per Equation 1)
        att = np.mean(tau_values)

        # Average parameter estimates for output (representative)
        alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
        beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
        L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))

        # Compute effective rank (nuclear norm / spectral norm of averaged L)
        _, s, _ = np.linalg.svd(L_hat, full_matrices=False)
        if s[0] > 0:
            effective_rank = np.sum(s) / s[0]
        else:
            effective_rank = 0.0

        # Step 4: Variance estimation
        # Use effective_lambda (converted values) to ensure SE is computed with same
        # parameters as point estimation. This fixes the variance inconsistency issue.
        if self.variance_method == "bootstrap":
            se, bootstrap_dist = self._bootstrap_variance(
                data, outcome, treatment, unit, time,
                effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
            )
        else:
            se, bootstrap_dist = self._jackknife_variance(
                Y, D, control_mask, control_unit_idx, effective_lambda,
                n_units, n_periods
            )

        # Compute test statistics (t-test with df = n_treated_obs - 1)
        if se > 0:
            t_stat = att / se
            p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
            conf_int = compute_confidence_interval(att, se, self.alpha)
        else:
            # When SE is undefined/zero, ALL inference fields should be NaN
            t_stat = np.nan
            p_value = np.nan
            conf_int = (np.nan, np.nan)

        # Create results dictionaries keyed by original unit/period identifiers
        unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
        time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}

        # Store results
        self.results_ = TROPResults(
            att=att,
            se=se,
            t_stat=t_stat,
            p_value=p_value,
            conf_int=conf_int,
            n_obs=len(data),
            n_treated=len(treated_unit_idx),
            n_control=len(control_unit_idx),
            n_treated_obs=n_treated_obs,
            unit_effects=unit_effects_dict,
            time_effects=time_effects_dict,
            treatment_effects=treatment_effects,
            # Store ORIGINAL grid values (possibly inf) so users see what was selected.
            # Internally, infinity values are converted for computation (see effective_lambda).
            lambda_time=original_lambda_time,
            lambda_unit=original_lambda_unit,
            lambda_nn=original_lambda_nn,
            factor_matrix=L_hat,
            effective_rank=effective_rank,
            loocv_score=best_score,
            variance_method=self.variance_method,
            alpha=self.alpha,
            n_pre_periods=n_pre_periods,
            n_post_periods=n_post_periods,
            n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None,
            bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
        )

        self.is_fitted_ = True
        return self.results_
2235
+
2236
+ def _compute_observation_weights(
2237
+ self,
2238
+ Y: np.ndarray,
2239
+ D: np.ndarray,
2240
+ i: int,
2241
+ t: int,
2242
+ lambda_time: float,
2243
+ lambda_unit: float,
2244
+ control_unit_idx: np.ndarray,
2245
+ n_units: int,
2246
+ n_periods: int,
2247
+ ) -> np.ndarray:
2248
+ """
2249
+ Compute observation-specific weight matrix for treated observation (i, t).
2250
+
2251
+ Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
2252
+ - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
2253
+ - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
2254
+
2255
+ IMPORTANT (Issue A fix): The paper's objective sums over ALL observations
2256
+ where (1 - W_js) is non-zero, which includes pre-treatment observations of
2257
+ eventually-treated units since W_js = 0 for those. This method computes
2258
+ weights for ALL units where D[t, j] = 0 at the target period, not just
2259
+ never-treated units.
2260
+
2261
+ Uses pre-computed structures when available for efficiency.
2262
+
2263
+ Parameters
2264
+ ----------
2265
+ Y : np.ndarray
2266
+ Outcome matrix (n_periods x n_units).
2267
+ D : np.ndarray
2268
+ Treatment indicator matrix (n_periods x n_units).
2269
+ i : int
2270
+ Treated unit index.
2271
+ t : int
2272
+ Treatment period index.
2273
+ lambda_time : float
2274
+ Time weight decay parameter.
2275
+ lambda_unit : float
2276
+ Unit weight decay parameter.
2277
+ control_unit_idx : np.ndarray
2278
+ Indices of never-treated units (for backward compatibility, but not
2279
+ used for weight computation - we use D matrix directly).
2280
+ n_units : int
2281
+ Number of units.
2282
+ n_periods : int
2283
+ Number of periods.
2284
+
2285
+ Returns
2286
+ -------
2287
+ np.ndarray
2288
+ Weight matrix (n_periods x n_units) for observation (i, t).
2289
+ """
2290
+ # Use pre-computed structures when available
2291
+ if self._precomputed is not None:
2292
+ # Time weights from pre-computed time distance matrix
2293
+ # time_dist_matrix[t, s] = |t - s|
2294
+ time_weights = np.exp(-lambda_time * self._precomputed["time_dist_matrix"][t, :])
2295
+
2296
+ # Unit weights - computed for ALL units where D[t, j] = 0
2297
+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
2298
+ unit_weights = np.zeros(n_units)
2299
+ D_stored = self._precomputed["D"]
2300
+ Y_stored = self._precomputed["Y"]
2301
+
2302
+ # Valid control units at time t: D[t, j] == 0
2303
+ valid_control_at_t = D_stored[t, :] == 0
2304
+
2305
+ if lambda_unit == 0:
2306
+ # Uniform weights when lambda_unit = 0
2307
+ # All units not treated at time t get weight 1
2308
+ unit_weights[valid_control_at_t] = 1.0
2309
+ else:
2310
+ # Use observation-specific distances with target period excluded
2311
+ # (Issue B fix: compute exact per-observation distance)
2312
+ for j in range(n_units):
2313
+ if valid_control_at_t[j] and j != i:
2314
+ # Compute distance excluding target period t
2315
+ dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t)
2316
+ if np.isinf(dist):
2317
+ unit_weights[j] = 0.0
2318
+ else:
2319
+ unit_weights[j] = np.exp(-lambda_unit * dist)
2320
+
2321
+ # Treated unit i gets weight 1
2322
+ unit_weights[i] = 1.0
2323
+
2324
+ # Weight matrix: outer product (n_periods x n_units)
2325
+ return np.outer(time_weights, unit_weights)
2326
+
2327
+ # Fallback: compute from scratch (used in bootstrap/jackknife)
2328
+ # Time distance: |t - s| following paper's Equation 3 (page 7)
2329
+ dist_time = np.abs(np.arange(n_periods) - t)
2330
+ time_weights = np.exp(-lambda_time * dist_time)
2331
+
2332
+ # Unit weights - computed for ALL units where D[t, j] = 0
2333
+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
2334
+ unit_weights = np.zeros(n_units)
2335
+
2336
+ # Valid control units at time t: D[t, j] == 0
2337
+ valid_control_at_t = D[t, :] == 0
2338
+
2339
+ if lambda_unit == 0:
2340
+ # Uniform weights when lambda_unit = 0
2341
+ unit_weights[valid_control_at_t] = 1.0
2342
+ else:
2343
+ for j in range(n_units):
2344
+ if valid_control_at_t[j] and j != i:
2345
+ # Compute distance excluding target period t (Issue B fix)
2346
+ dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
2347
+ if np.isinf(dist):
2348
+ unit_weights[j] = 0.0
2349
+ else:
2350
+ unit_weights[j] = np.exp(-lambda_unit * dist)
2351
+
2352
+ # Treated unit i gets weight 1 (or could be omitted since we fit on controls)
2353
+ # We include treated unit's own observation for model fitting
2354
+ unit_weights[i] = 1.0
2355
+
2356
+ # Weight matrix: outer product (n_periods x n_units)
2357
+ W = np.outer(time_weights, unit_weights)
2358
+
2359
+ return W
2360
+
2361
+ def _soft_threshold_svd(
2362
+ self,
2363
+ M: np.ndarray,
2364
+ threshold: float,
2365
+ ) -> np.ndarray:
2366
+ """
2367
+ Apply soft-thresholding to singular values (proximal operator for nuclear norm).
2368
+
2369
+ Parameters
2370
+ ----------
2371
+ M : np.ndarray
2372
+ Input matrix.
2373
+ threshold : float
2374
+ Soft-thresholding parameter.
2375
+
2376
+ Returns
2377
+ -------
2378
+ np.ndarray
2379
+ Matrix with soft-thresholded singular values.
2380
+ """
2381
+ if threshold <= 0:
2382
+ return M
2383
+
2384
+ # Handle NaN/Inf values in input
2385
+ if not np.isfinite(M).all():
2386
+ M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)
2387
+
2388
+ try:
2389
+ U, s, Vt = np.linalg.svd(M, full_matrices=False)
2390
+ except np.linalg.LinAlgError:
2391
+ # SVD failed, return zero matrix
2392
+ return np.zeros_like(M)
2393
+
2394
+ # Check for numerical issues in SVD output
2395
+ if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()):
2396
+ # SVD produced non-finite values, return zero matrix
2397
+ return np.zeros_like(M)
2398
+
2399
+ s_thresh = np.maximum(s - threshold, 0)
2400
+
2401
+ # Use truncated reconstruction with only non-zero singular values
2402
+ nonzero_mask = s_thresh > self.CONVERGENCE_TOL_SVD
2403
+ if not np.any(nonzero_mask):
2404
+ return np.zeros_like(M)
2405
+
2406
+ # Truncate to non-zero components for numerical stability
2407
+ U_trunc = U[:, nonzero_mask]
2408
+ s_trunc = s_thresh[nonzero_mask]
2409
+ Vt_trunc = Vt[nonzero_mask, :]
2410
+
2411
+ # Compute result, suppressing expected numerical warnings from
2412
+ # ill-conditioned matrices during alternating minimization
2413
+ with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
2414
+ result = (U_trunc * s_trunc) @ Vt_trunc
2415
+
2416
+ # Replace any NaN/Inf in result with zeros
2417
+ if not np.isfinite(result).all():
2418
+ result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
2419
+
2420
+ return result
2421
+
2422
+ def _weighted_nuclear_norm_solve(
2423
+ self,
2424
+ Y: np.ndarray,
2425
+ W: np.ndarray,
2426
+ L_init: np.ndarray,
2427
+ alpha: np.ndarray,
2428
+ beta: np.ndarray,
2429
+ lambda_nn: float,
2430
+ max_inner_iter: int = 20,
2431
+ ) -> np.ndarray:
2432
+ """
2433
+ Solve weighted nuclear norm problem using iterative weighted soft-impute.
2434
+
2435
+ Issue C fix: Implements the weighted nuclear norm optimization from the
2436
+ paper's Equation 2 (page 7). The full objective is:
2437
+ min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
2438
+
2439
+ This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
2440
+ L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
2441
+
2442
+ where W ⊙ denotes element-wise multiplication with normalized weights.
2443
+
2444
+ IMPORTANT: For observations with W=0 (treated observations), we keep
2445
+ L values from the previous iteration rather than setting L = R, which
2446
+ would absorb the treatment effect.
2447
+
2448
+ Parameters
2449
+ ----------
2450
+ Y : np.ndarray
2451
+ Outcome matrix (n_periods x n_units).
2452
+ W : np.ndarray
2453
+ Weight matrix (n_periods x n_units), non-negative. W=0 indicates
2454
+ observations that should not be used for fitting (treated obs).
2455
+ L_init : np.ndarray
2456
+ Initial estimate of L matrix.
2457
+ alpha : np.ndarray
2458
+ Current unit fixed effects estimate.
2459
+ beta : np.ndarray
2460
+ Current time fixed effects estimate.
2461
+ lambda_nn : float
2462
+ Nuclear norm regularization parameter.
2463
+ max_inner_iter : int, default=20
2464
+ Maximum inner iterations for the proximal algorithm.
2465
+
2466
+ Returns
2467
+ -------
2468
+ np.ndarray
2469
+ Updated L matrix estimate.
2470
+ """
2471
+ # Compute target residual R = Y - α - β
2472
+ R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
2473
+
2474
+ # Handle invalid values
2475
+ R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)
2476
+
2477
+ # For observations with W=0 (treated obs), keep L_init instead of R
2478
+ # This prevents L from absorbing the treatment effect
2479
+ valid_obs_mask = W > 0
2480
+ R_masked = np.where(valid_obs_mask, R, L_init)
2481
+
2482
+ if lambda_nn <= 0:
2483
+ # No regularization - just return masked residual
2484
+ # Use soft-thresholding with threshold=0 which returns the input
2485
+ return R_masked
2486
+
2487
+ # Normalize weights so max is 1 (for step size stability)
2488
+ W_max = np.max(W)
2489
+ if W_max > 0:
2490
+ W_norm = W / W_max
2491
+ else:
2492
+ W_norm = W
2493
+
2494
+ # Initialize L
2495
+ L = L_init.copy()
2496
+
2497
+ # Proximal gradient iteration with weighted soft-impute
2498
+ # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
2499
+ # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
2500
+ # where η is the step size (we use η = 1 with normalized weights)
2501
+ for _ in range(max_inner_iter):
2502
+ L_old = L.copy()
2503
+
2504
+ # Gradient step: L_k + W ⊙ (R - L_k)
2505
+ # For W=0 observations, this keeps L_k unchanged
2506
+ gradient_step = L + W_norm * (R_masked - L)
2507
+
2508
+ # Proximal step: soft-threshold singular values
2509
+ L = self._soft_threshold_svd(gradient_step, lambda_nn)
2510
+
2511
+ # Check convergence
2512
+ if np.max(np.abs(L - L_old)) < self.tol:
2513
+ break
2514
+
2515
+ return L
2516
+
2517
+ def _estimate_model(
2518
+ self,
2519
+ Y: np.ndarray,
2520
+ control_mask: np.ndarray,
2521
+ weight_matrix: np.ndarray,
2522
+ lambda_nn: float,
2523
+ n_units: int,
2524
+ n_periods: int,
2525
+ exclude_obs: Optional[Tuple[int, int]] = None,
2526
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
2527
+ """
2528
+ Estimate the model: Y = α + β + L + τD + ε with nuclear norm penalty on L.
2529
+
2530
+ Uses alternating minimization with vectorized operations:
2531
+ 1. Fix L, solve for α, β via weighted means
2532
+ 2. Fix α, β, solve for L via soft-thresholding
2533
+
2534
+ Parameters
2535
+ ----------
2536
+ Y : np.ndarray
2537
+ Outcome matrix (n_periods x n_units).
2538
+ control_mask : np.ndarray
2539
+ Boolean mask for control observations.
2540
+ weight_matrix : np.ndarray
2541
+ Pre-computed global weight matrix (n_periods x n_units).
2542
+ lambda_nn : float
2543
+ Nuclear norm regularization parameter.
2544
+ n_units : int
2545
+ Number of units.
2546
+ n_periods : int
2547
+ Number of periods.
2548
+ exclude_obs : tuple, optional
2549
+ (t, i) observation to exclude (for LOOCV).
2550
+
2551
+ Returns
2552
+ -------
2553
+ tuple
2554
+ (alpha, beta, L) estimated parameters.
2555
+ """
2556
+ W = weight_matrix
2557
+
2558
+ # Mask for estimation (control obs only, excluding LOOCV obs if specified)
2559
+ est_mask = control_mask.copy()
2560
+ if exclude_obs is not None:
2561
+ t_ex, i_ex = exclude_obs
2562
+ est_mask[t_ex, i_ex] = False
2563
+
2564
+ # Handle missing values
2565
+ valid_mask = ~np.isnan(Y) & est_mask
2566
+
2567
+ # Initialize
2568
+ alpha = np.zeros(n_units)
2569
+ beta = np.zeros(n_periods)
2570
+ L = np.zeros((n_periods, n_units))
2571
+
2572
+ # Pre-compute masked weights for vectorized operations
2573
+ # Set weights to 0 where not valid
2574
+ W_masked = W * valid_mask
2575
+
2576
+ # Pre-compute weight sums per unit and per time (for denominator)
2577
+ # shape: (n_units,) and (n_periods,)
2578
+ weight_sum_per_unit = np.sum(W_masked, axis=0) # sum over periods
2579
+ weight_sum_per_time = np.sum(W_masked, axis=1) # sum over units
2580
+
2581
+ # Handle units/periods with zero weight sum
2582
+ unit_has_obs = weight_sum_per_unit > 0
2583
+ time_has_obs = weight_sum_per_time > 0
2584
+
2585
+ # Create safe denominators (avoid division by zero)
2586
+ safe_unit_denom = np.where(unit_has_obs, weight_sum_per_unit, 1.0)
2587
+ safe_time_denom = np.where(time_has_obs, weight_sum_per_time, 1.0)
2588
+
2589
+ # Replace NaN in Y with 0 for computation (mask handles exclusion)
2590
+ Y_safe = np.where(np.isnan(Y), 0.0, Y)
2591
+
2592
+ # Alternating minimization following Algorithm 1 (page 9)
2593
+ # Minimize: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
2594
+ for _ in range(self.max_iter):
2595
+ alpha_old = alpha.copy()
2596
+ beta_old = beta.copy()
2597
+ L_old = L.copy()
2598
+
2599
+ # Step 1: Update α and β (weighted least squares)
2600
+ # Following Equation 2 (page 7), fix L and solve for α, β
2601
+ # R = Y - L (residual without fixed effects)
2602
+ R = Y_safe - L
2603
+
2604
+ # Alpha update (unit fixed effects):
2605
+ # α_i = argmin_α Σ_t W_{ti}(R_{ti} - α - β_t)²
2606
+ # Solution: α_i = Σ_t W_{ti}(R_{ti} - β_t) / Σ_t W_{ti}
2607
+ R_minus_beta = R - beta[:, np.newaxis] # (n_periods, n_units)
2608
+ weighted_R_minus_beta = W_masked * R_minus_beta
2609
+ alpha_numerator = np.sum(weighted_R_minus_beta, axis=0) # (n_units,)
2610
+ alpha = np.where(unit_has_obs, alpha_numerator / safe_unit_denom, 0.0)
2611
+
2612
+ # Beta update (time fixed effects):
2613
+ # β_t = argmin_β Σ_i W_{ti}(R_{ti} - α_i - β)²
2614
+ # Solution: β_t = Σ_i W_{ti}(R_{ti} - α_i) / Σ_i W_{ti}
2615
+ R_minus_alpha = R - alpha[np.newaxis, :] # (n_periods, n_units)
2616
+ weighted_R_minus_alpha = W_masked * R_minus_alpha
2617
+ beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,)
2618
+ beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
2619
+
2620
+ # Step 2: Update L with weighted nuclear norm penalty
2621
+ # Issue C fix: Use weighted soft-impute to properly account for
2622
+ # observation weights in the nuclear norm optimization.
2623
+ # Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
2624
+ L = self._weighted_nuclear_norm_solve(
2625
+ Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10
2626
+ )
2627
+
2628
+ # Check convergence
2629
+ alpha_diff = np.max(np.abs(alpha - alpha_old))
2630
+ beta_diff = np.max(np.abs(beta - beta_old))
2631
+ L_diff = np.max(np.abs(L - L_old))
2632
+
2633
+ if max(alpha_diff, beta_diff, L_diff) < self.tol:
2634
+ break
2635
+
2636
+ return alpha, beta, L
2637
+
2638
+ def _loocv_score_obs_specific(
2639
+ self,
2640
+ Y: np.ndarray,
2641
+ D: np.ndarray,
2642
+ control_mask: np.ndarray,
2643
+ control_unit_idx: np.ndarray,
2644
+ lambda_time: float,
2645
+ lambda_unit: float,
2646
+ lambda_nn: float,
2647
+ n_units: int,
2648
+ n_periods: int,
2649
+ ) -> float:
2650
+ """
2651
+ Compute leave-one-out cross-validation score with observation-specific weights.
2652
+
2653
+ Following the paper's Equation 5 (page 8):
2654
+ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
2655
+
2656
+ For each control observation (j, s), treat it as pseudo-treated,
2657
+ compute observation-specific weights, fit model excluding (j, s),
2658
+ and sum squared pseudo-treatment effects.
2659
+
2660
+ Uses pre-computed structures when available for efficiency.
2661
+
2662
+ Parameters
2663
+ ----------
2664
+ Y : np.ndarray
2665
+ Outcome matrix (n_periods x n_units).
2666
+ D : np.ndarray
2667
+ Treatment indicator matrix (n_periods x n_units).
2668
+ control_mask : np.ndarray
2669
+ Boolean mask for control observations.
2670
+ control_unit_idx : np.ndarray
2671
+ Indices of control units.
2672
+ lambda_time : float
2673
+ Time weight decay parameter.
2674
+ lambda_unit : float
2675
+ Unit weight decay parameter.
2676
+ lambda_nn : float
2677
+ Nuclear norm regularization parameter.
2678
+ n_units : int
2679
+ Number of units.
2680
+ n_periods : int
2681
+ Number of periods.
2682
+
2683
+ Returns
2684
+ -------
2685
+ float
2686
+ LOOCV score (lower is better).
2687
+ """
2688
+ # Use pre-computed control observations if available
2689
+ if self._precomputed is not None:
2690
+ control_obs = self._precomputed["control_obs"]
2691
+ else:
2692
+ # Get all control observations
2693
+ control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
2694
+ if control_mask[t, i] and not np.isnan(Y[t, i])]
2695
+
2696
+ # Subsample for computational tractability (as noted in paper's footnote)
2697
+ rng = np.random.default_rng(self.seed)
2698
+ max_loocv = min(self.max_loocv_samples, len(control_obs))
2699
+ if len(control_obs) > max_loocv:
2700
+ indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
2701
+ control_obs = [control_obs[idx] for idx in indices]
2702
+
2703
+ # Empty control set check: if no control observations, return infinity
2704
+ # A score of 0.0 would incorrectly "win" over legitimate parameters
2705
+ if len(control_obs) == 0:
2706
+ warnings.warn(
2707
+ f"LOOCV: No valid control observations for "
2708
+ f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
2709
+ "Returning infinite score.",
2710
+ UserWarning
2711
+ )
2712
+ return np.inf
2713
+
2714
+ tau_squared_sum = 0.0
2715
+ n_valid = 0
2716
+
2717
+ for t, i in control_obs:
2718
+ try:
2719
+ # Compute observation-specific weights for pseudo-treated (i, t)
2720
+ # Uses pre-computed distance matrices when available
2721
+ weight_matrix = self._compute_observation_weights(
2722
+ Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
2723
+ n_units, n_periods
2724
+ )
2725
+
2726
+ # Estimate model excluding observation (t, i)
2727
+ alpha, beta, L = self._estimate_model(
2728
+ Y, control_mask, weight_matrix, lambda_nn,
2729
+ n_units, n_periods, exclude_obs=(t, i)
2730
+ )
2731
+
2732
+ # Pseudo treatment effect
2733
+ tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
2734
+ tau_squared_sum += tau_ti ** 2
2735
+ n_valid += 1
2736
+
2737
+ except (np.linalg.LinAlgError, ValueError):
2738
+ # Per Equation 5: Q(λ) must sum over ALL D==0 cells
2739
+ # Any failure means this λ cannot produce valid estimates for all cells
2740
+ warnings.warn(
2741
+ f"LOOCV: Fit failed for observation ({t}, {i}) with "
2742
+ f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
2743
+ "Returning infinite score per Equation 5.",
2744
+ UserWarning
2745
+ )
2746
+ return np.inf
2747
+
2748
+ # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
2749
+ # Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
2750
+ return tau_squared_sum
2751
+
2752
+ def _bootstrap_variance(
2753
+ self,
2754
+ data: pd.DataFrame,
2755
+ outcome: str,
2756
+ treatment: str,
2757
+ unit: str,
2758
+ time: str,
2759
+ optimal_lambda: Tuple[float, float, float],
2760
+ Y: Optional[np.ndarray] = None,
2761
+ D: Optional[np.ndarray] = None,
2762
+ control_unit_idx: Optional[np.ndarray] = None,
2763
+ ) -> Tuple[float, np.ndarray]:
2764
+ """
2765
+ Compute bootstrap standard error using unit-level block bootstrap.
2766
+
2767
+ When the optional Rust backend is available and the matrix parameters
2768
+ (Y, D, control_unit_idx) are provided, uses parallelized Rust
2769
+ implementation for 5-15x speedup. Falls back to Python implementation
2770
+ if Rust is unavailable or if matrix parameters are not provided.
2771
+
2772
+ Parameters
2773
+ ----------
2774
+ data : pd.DataFrame
2775
+ Original data in long format with unit, time, outcome, and treatment.
2776
+ outcome : str
2777
+ Name of the outcome column in data.
2778
+ treatment : str
2779
+ Name of the treatment indicator column in data.
2780
+ unit : str
2781
+ Name of the unit identifier column in data.
2782
+ time : str
2783
+ Name of the time period column in data.
2784
+ optimal_lambda : tuple of float
2785
+ Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
2786
+ from cross-validation. Used for model estimation in each bootstrap.
2787
+ Y : np.ndarray, optional
2788
+ Outcome matrix of shape (n_periods, n_units). Required for Rust
2789
+ backend acceleration. If None, falls back to Python implementation.
2790
+ D : np.ndarray, optional
2791
+ Treatment indicator matrix of shape (n_periods, n_units) where
2792
+ D[t,i]=1 indicates unit i is treated at time t. Required for Rust
2793
+ backend acceleration.
2794
+ control_unit_idx : np.ndarray, optional
2795
+ Array of indices for control units (never-treated). Required for
2796
+ Rust backend acceleration.
2797
+
2798
+ Returns
2799
+ -------
2800
+ se : float
2801
+ Bootstrap standard error of the ATT estimate.
2802
+ bootstrap_estimates : np.ndarray
2803
+ Array of ATT estimates from each bootstrap iteration. Length may
2804
+ be less than n_bootstrap if some iterations failed.
2805
+
2806
+ Notes
2807
+ -----
2808
+ Uses unit-level block bootstrap where entire unit time series are
2809
+ resampled with replacement. This preserves within-unit correlation
2810
+ structure and is appropriate for panel data.
2811
+ """
2812
+ lambda_time, lambda_unit, lambda_nn = optimal_lambda
2813
+
2814
+ # Try Rust backend for parallel bootstrap (5-15x speedup)
2815
+ if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
2816
+ and self._precomputed is not None and Y is not None
2817
+ and D is not None):
2818
+ try:
2819
+ control_mask = self._precomputed["control_mask"]
2820
+ time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
2821
+
2822
+ bootstrap_estimates, se = _rust_bootstrap_trop_variance(
2823
+ Y, D.astype(np.float64),
2824
+ control_mask.astype(np.uint8),
2825
+ time_dist_matrix,
2826
+ lambda_time, lambda_unit, lambda_nn,
2827
+ self.n_bootstrap, self.max_iter, self.tol,
2828
+ self.seed if self.seed is not None else 0
2829
+ )
2830
+
2831
+ if len(bootstrap_estimates) >= 10:
2832
+ return float(se), bootstrap_estimates
2833
+ # Fall through to Python if too few bootstrap samples
2834
+ logger.debug(
2835
+ "Rust bootstrap returned only %d samples, falling back to Python",
2836
+ len(bootstrap_estimates)
2837
+ )
2838
+ except Exception as e:
2839
+ logger.debug(
2840
+ "Rust bootstrap variance failed, falling back to Python: %s", e
2841
+ )
2842
+
2843
+ # Python implementation (fallback)
2844
+ rng = np.random.default_rng(self.seed)
2845
+
2846
+ # Issue D fix: Stratified bootstrap sampling
2847
+ # Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
2848
+ # and N_1 treated rows separately to preserve treatment ratio
2849
+ unit_ever_treated = data.groupby(unit)[treatment].max()
2850
+ treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
2851
+ control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)
2852
+
2853
+ n_treated_units = len(treated_units)
2854
+ n_control_units = len(control_units)
2855
+
2856
+ bootstrap_estimates_list = []
2857
+
2858
+ for _ in range(self.n_bootstrap):
2859
+ # Stratified sampling: sample control and treated units separately
2860
+ # This preserves the treatment ratio in each bootstrap sample
2861
+ if n_control_units > 0:
2862
+ sampled_control = rng.choice(
2863
+ control_units, size=n_control_units, replace=True
2864
+ )
2865
+ else:
2866
+ sampled_control = np.array([], dtype=control_units.dtype)
2867
+
2868
+ if n_treated_units > 0:
2869
+ sampled_treated = rng.choice(
2870
+ treated_units, size=n_treated_units, replace=True
2871
+ )
2872
+ else:
2873
+ sampled_treated = np.array([], dtype=treated_units.dtype)
2874
+
2875
+ # Combine stratified samples
2876
+ sampled_units = np.concatenate([sampled_control, sampled_treated])
2877
+
2878
+ # Create bootstrap sample with unique unit IDs
2879
+ boot_data = pd.concat([
2880
+ data[data[unit] == u].assign(**{unit: f"{u}_{idx}"})
2881
+ for idx, u in enumerate(sampled_units)
2882
+ ], ignore_index=True)
2883
+
2884
+ try:
2885
+ # Fit with fixed lambda (skip LOOCV for speed)
2886
+ att = self._fit_with_fixed_lambda(
2887
+ boot_data, outcome, treatment, unit, time,
2888
+ optimal_lambda
2889
+ )
2890
+ bootstrap_estimates_list.append(att)
2891
+ except (ValueError, np.linalg.LinAlgError, KeyError):
2892
+ continue
2893
+
2894
+ bootstrap_estimates = np.array(bootstrap_estimates_list)
2895
+
2896
+ if len(bootstrap_estimates) < 10:
2897
+ warnings.warn(
2898
+ f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
2899
+ "Standard errors may be unreliable.",
2900
+ UserWarning
2901
+ )
2902
+ if len(bootstrap_estimates) == 0:
2903
+ return 0.0, np.array([])
2904
+
2905
+ se = np.std(bootstrap_estimates, ddof=1)
2906
+ return float(se), bootstrap_estimates
2907
+
2908
+ def _jackknife_variance(
2909
+ self,
2910
+ Y: np.ndarray,
2911
+ D: np.ndarray,
2912
+ control_mask: np.ndarray,
2913
+ control_unit_idx: np.ndarray,
2914
+ optimal_lambda: Tuple[float, float, float],
2915
+ n_units: int,
2916
+ n_periods: int,
2917
+ ) -> Tuple[float, np.ndarray]:
2918
+ """
2919
+ Compute jackknife standard error (leave-one-unit-out).
2920
+
2921
+ Uses observation-specific weights following Algorithm 2.
2922
+
2923
+ Parameters
2924
+ ----------
2925
+ Y : np.ndarray
2926
+ Outcome matrix.
2927
+ D : np.ndarray
2928
+ Treatment matrix.
2929
+ control_mask : np.ndarray
2930
+ Control observation mask.
2931
+ control_unit_idx : np.ndarray
2932
+ Indices of control units.
2933
+ optimal_lambda : tuple
2934
+ Optimal tuning parameters.
2935
+ n_units : int
2936
+ Number of units.
2937
+ n_periods : int
2938
+ Number of periods.
2939
+
2940
+ Returns
2941
+ -------
2942
+ tuple
2943
+ (se, jackknife_estimates).
2944
+ """
2945
+ lambda_time, lambda_unit, lambda_nn = optimal_lambda
2946
+ jackknife_estimates = []
2947
+
2948
+ # Get treated unit indices
2949
+ treated_unit_idx = np.where(np.any(D == 1, axis=0))[0]
2950
+
2951
+ for leave_out in treated_unit_idx:
2952
+ # Create mask excluding this unit
2953
+ Y_jack = Y.copy()
2954
+ D_jack = D.copy()
2955
+ Y_jack[:, leave_out] = np.nan
2956
+ D_jack[:, leave_out] = 0
2957
+
2958
+ control_mask_jack = D_jack == 0
2959
+
2960
+ # Get remaining treated observations
2961
+ treated_obs_jack = [(t, i) for t in range(n_periods) for i in range(n_units)
2962
+ if D_jack[t, i] == 1]
2963
+
2964
+ if not treated_obs_jack:
2965
+ continue
2966
+
2967
+ try:
2968
+ # Compute ATT using observation-specific weights (Algorithm 2)
2969
+ tau_values = []
2970
+ for t, i in treated_obs_jack:
2971
+ # Compute observation-specific weights for this (i, t)
2972
+ weight_matrix = self._compute_observation_weights(
2973
+ Y_jack, D_jack, i, t, lambda_time, lambda_unit,
2974
+ control_unit_idx, n_units, n_periods
2975
+ )
2976
+
2977
+ # Fit model with these weights
2978
+ alpha, beta, L = self._estimate_model(
2979
+ Y_jack, control_mask_jack, weight_matrix, lambda_nn,
2980
+ n_units, n_periods
2981
+ )
2982
+
2983
+ # Compute treatment effect
2984
+ tau = Y_jack[t, i] - alpha[i] - beta[t] - L[t, i]
2985
+ tau_values.append(tau)
2986
+
2987
+ if tau_values:
2988
+ jackknife_estimates.append(np.mean(tau_values))
2989
+
2990
+ except (np.linalg.LinAlgError, ValueError):
2991
+ continue
2992
+
2993
+ jackknife_estimates = np.array(jackknife_estimates)
2994
+
2995
+ if len(jackknife_estimates) < 2:
2996
+ return 0.0, jackknife_estimates
2997
+
2998
+ # Jackknife SE formula
2999
+ n = len(jackknife_estimates)
3000
+ mean_est = np.mean(jackknife_estimates)
3001
+ se = np.sqrt((n - 1) / n * np.sum((jackknife_estimates - mean_est) ** 2))
3002
+
3003
+ return se, jackknife_estimates
3004
+
3005
+ def _fit_with_fixed_lambda(
3006
+ self,
3007
+ data: pd.DataFrame,
3008
+ outcome: str,
3009
+ treatment: str,
3010
+ unit: str,
3011
+ time: str,
3012
+ fixed_lambda: Tuple[float, float, float],
3013
+ ) -> float:
3014
+ """
3015
+ Fit model with fixed tuning parameters (for bootstrap).
3016
+
3017
+ Uses observation-specific weights following Algorithm 2.
3018
+ Returns only the ATT estimate.
3019
+ """
3020
+ lambda_time, lambda_unit, lambda_nn = fixed_lambda
3021
+
3022
+ # Setup matrices
3023
+ all_units = sorted(data[unit].unique())
3024
+ all_periods = sorted(data[time].unique())
3025
+
3026
+ n_units = len(all_units)
3027
+ n_periods = len(all_periods)
3028
+
3029
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
3030
+ period_to_idx = {p: i for i, p in enumerate(all_periods)}
3031
+
3032
+ # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
3033
+ Y = (
3034
+ data.pivot(index=time, columns=unit, values=outcome)
3035
+ .reindex(index=all_periods, columns=all_units)
3036
+ .values
3037
+ )
3038
+ D = (
3039
+ data.pivot(index=time, columns=unit, values=treatment)
3040
+ .reindex(index=all_periods, columns=all_units)
3041
+ .fillna(0)
3042
+ .astype(int)
3043
+ .values
3044
+ )
3045
+
3046
+ control_mask = D == 0
3047
+
3048
+ # Get control unit indices
3049
+ unit_ever_treated = np.any(D == 1, axis=0)
3050
+ control_unit_idx = np.where(~unit_ever_treated)[0]
3051
+
3052
+ # Get list of treated observations
3053
+ treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
3054
+ if D[t, i] == 1]
3055
+
3056
+ if not treated_observations:
3057
+ raise ValueError("No treated observations")
3058
+
3059
+ # Compute ATT using observation-specific weights (Algorithm 2)
3060
+ tau_values = []
3061
+ for t, i in treated_observations:
3062
+ # Compute observation-specific weights for this (i, t)
3063
+ weight_matrix = self._compute_observation_weights(
3064
+ Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
3065
+ n_units, n_periods
3066
+ )
3067
+
3068
+ # Fit model with these weights
3069
+ alpha, beta, L = self._estimate_model(
3070
+ Y, control_mask, weight_matrix, lambda_nn,
3071
+ n_units, n_periods
3072
+ )
3073
+
3074
+ # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
3075
+ tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
3076
+ tau_values.append(tau)
3077
+
3078
+ return np.mean(tau_values)
3079
+
3080
+ def get_params(self) -> Dict[str, Any]:
3081
+ """Get estimator parameters."""
3082
+ return {
3083
+ "method": self.method,
3084
+ "lambda_time_grid": self.lambda_time_grid,
3085
+ "lambda_unit_grid": self.lambda_unit_grid,
3086
+ "lambda_nn_grid": self.lambda_nn_grid,
3087
+ "max_iter": self.max_iter,
3088
+ "tol": self.tol,
3089
+ "alpha": self.alpha,
3090
+ "variance_method": self.variance_method,
3091
+ "n_bootstrap": self.n_bootstrap,
3092
+ "max_loocv_samples": self.max_loocv_samples,
3093
+ "seed": self.seed,
3094
+ }
3095
+
3096
+ def set_params(self, **params) -> "TROP":
3097
+ """Set estimator parameters."""
3098
+ for key, value in params.items():
3099
+ if hasattr(self, key):
3100
+ setattr(self, key, value)
3101
+ else:
3102
+ raise ValueError(f"Unknown parameter: {key}")
3103
+ return self
3104
+
3105
+
3106
def trop(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    **kwargs,
) -> TROPResults:
    """
    Convenience wrapper: construct a TROP estimator and fit it in one call.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    treatment : str
        Treatment indicator column name (0/1).

        IMPORTANT: This should be an ABSORBING STATE indicator, not a treatment
        timing indicator. For each unit, D=1 for ALL periods during and after
        treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for t >= g_i where g_i is
        the treatment start time for unit i).
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    **kwargs
        Additional arguments passed to TROP constructor.

    Returns
    -------
    TROPResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import trop
    >>> results = trop(data, 'y', 'treated', 'unit', 'time')
    >>> print(f"ATT: {results.att:.3f}")
    """
    return TROP(**kwargs).fit(data, outcome, treatment, unit, time)