diff-diff 2.3.2__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +254 -0
- diff_diff/_backend.py +112 -0
- diff_diff/_rust_backend.cp313-win_amd64.pyd +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1161 -0
- diff_diff/honest_did.py +1511 -0
- diff_diff/imputation.py +2480 -0
- diff_diff/linalg.py +1537 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1241 -0
- diff_diff/prep_dgp.py +777 -0
- diff_diff/pretrends.py +1104 -0
- diff_diff/results.py +794 -0
- diff_diff/staggered.py +1120 -0
- diff_diff/staggered_aggregation.py +492 -0
- diff_diff/staggered_bootstrap.py +753 -0
- diff_diff/staggered_results.py +296 -0
- diff_diff/sun_abraham.py +1227 -0
- diff_diff/synthetic_did.py +858 -0
- diff_diff/triple_diff.py +1322 -0
- diff_diff/trop.py +2904 -0
- diff_diff/twfe.py +428 -0
- diff_diff/utils.py +1845 -0
- diff_diff/visualization.py +1676 -0
- diff_diff-2.3.2.dist-info/METADATA +2646 -0
- diff_diff-2.3.2.dist-info/RECORD +30 -0
- diff_diff-2.3.2.dist-info/WHEEL +4 -0
- diff_diff-2.3.2.dist-info/sboms/diff_diff_rust.cyclonedx.json +5952 -0
diff_diff/trop.py
ADDED
|
@@ -0,0 +1,2904 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Triply Robust Panel (TROP) estimator.
|
|
3
|
+
|
|
4
|
+
Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
|
|
5
|
+
TROP combines three robustness components:
|
|
6
|
+
1. Nuclear norm regularized factor model (interactive fixed effects)
|
|
7
|
+
2. Exponential distance-based unit weights
|
|
8
|
+
3. Exponential time decay weights
|
|
9
|
+
|
|
10
|
+
The estimator uses leave-one-out cross-validation for tuning parameter
|
|
11
|
+
selection and provides robust treatment effect estimates under factor
|
|
12
|
+
confounding.
|
|
13
|
+
|
|
14
|
+
References
|
|
15
|
+
----------
|
|
16
|
+
Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
|
|
17
|
+
Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import logging
|
|
21
|
+
import warnings
|
|
22
|
+
from dataclasses import dataclass, field
|
|
23
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
import pandas as pd
|
|
27
|
+
from scipy import stats
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
from typing import TypedDict
|
|
33
|
+
except ImportError:
|
|
34
|
+
from typing_extensions import TypedDict
|
|
35
|
+
|
|
36
|
+
from diff_diff._backend import (
|
|
37
|
+
HAS_RUST_BACKEND,
|
|
38
|
+
_rust_unit_distance_matrix,
|
|
39
|
+
_rust_loocv_grid_search,
|
|
40
|
+
_rust_bootstrap_trop_variance,
|
|
41
|
+
_rust_loocv_grid_search_joint,
|
|
42
|
+
_rust_bootstrap_trop_variance_joint,
|
|
43
|
+
)
|
|
44
|
+
from diff_diff.results import _get_significance_stars
|
|
45
|
+
from diff_diff.utils import compute_confidence_interval, compute_p_value
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Sentinel value for "disabled" λ_nn in LOOCV parameter search.
|
|
49
|
+
# Per paper's footnote 2: λ_nn=∞ disables the factor model (L=0).
|
|
50
|
+
# For λ_time and λ_unit, 0.0 means disabled (uniform weights) per Eq. 3:
|
|
51
|
+
# exp(-0 × dist) = 1 for all distances.
|
|
52
|
+
_LAMBDA_INF: float = float('inf')
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _PrecomputedStructures(TypedDict):
    """Typed dictionary of structures shared across LOOCV and estimation.

    An instance is built once by `_precompute_structures()` and then reused,
    so the same distance matrices, masks and index lists do not have to be
    recomputed for every LOOCV iteration or for the final fit.
    """

    #: Pairwise unit distance matrix (n_units x n_units).
    unit_dist_matrix: np.ndarray
    #: Time distance matrix where [t, s] = |t - s| (n_periods x n_periods).
    time_dist_matrix: np.ndarray
    #: Boolean mask for control observations (D == 0).
    control_mask: np.ndarray
    #: Boolean mask for treated observations (D == 1).
    treated_mask: np.ndarray
    #: List of (t, i) tuples for treated observations.
    treated_observations: List[Tuple[int, int]]
    #: List of (t, i) tuples for valid control observations.
    control_obs: List[Tuple[int, int]]
    #: Array of never-treated unit indices (for backward compatibility).
    control_unit_idx: np.ndarray
    #: Treatment indicator matrix (n_periods x n_units) for dynamic control sets.
    D: np.ndarray
    #: Outcome matrix (n_periods x n_units).
    Y: np.ndarray
    #: Number of units.
    n_units: int
    #: Number of time periods.
    n_periods: int
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class TROPResults:
    """
    Results from a Triply Robust Panel (TROP) estimation.

    TROP combines nuclear norm regularized factor estimation with
    exponential distance-based unit weights and time decay weights.

    Attributes
    ----------
    att : float
        Average Treatment effect on the Treated (ATT).
    se : float
        Standard error of the ATT estimate.
    t_stat : float
        T-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_treated_obs : int
        Number of treated unit-time observations.
    unit_effects : dict
        Estimated unit fixed effects (alpha_i).
    time_effects : dict
        Estimated time fixed effects (beta_t).
    treatment_effects : dict
        Individual treatment effects for each treated (unit, time) pair.
    lambda_time : float
        Selected time weight decay parameter from grid. 0.0 = uniform time
        weights (disabled) per Eq. 3.
    lambda_unit : float
        Selected unit weight decay parameter from grid. 0.0 = uniform unit
        weights (disabled) per Eq. 3.
    lambda_nn : float
        Selected nuclear norm regularization parameter from grid. inf = factor
        model disabled (L=0); converted to 1e10 internally for computation.
    factor_matrix : np.ndarray
        Estimated low-rank factor matrix L (n_periods x n_units).
    effective_rank : float
        Effective rank of the factor matrix (sum of singular values / max).
    loocv_score : float
        Leave-one-out cross-validation score for selected parameters.
    alpha : float
        Significance level for confidence interval.
    n_pre_periods : int
        Number of pre-treatment periods.
    n_post_periods : int
        Number of post-treatment periods (periods with D=1 observations).
    n_bootstrap : int, optional
        Number of bootstrap replications (if bootstrap variance).
    bootstrap_distribution : np.ndarray, optional
        Bootstrap distribution of estimates.
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    n_treated_obs: int
    unit_effects: Dict[Any, float]
    time_effects: Dict[Any, float]
    treatment_effects: Dict[Tuple[Any, Any], float]
    lambda_time: float
    lambda_unit: float
    lambda_nn: float
    factor_matrix: np.ndarray
    effective_rank: float
    loocv_score: float
    alpha: float = 0.05
    n_pre_periods: int = 0
    n_post_periods: int = 0
    n_bootstrap: Optional[int] = field(default=None)
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.p_value)
        return (
            f"TROPResults(ATT={self.att:.4f}{sig}, "
            f"SE={self.se:.4f}, "
            f"eff_rank={self.effective_rank:.1f}, "
            f"p={self.p_value:.4f})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level used to label the confidence interval.
            Defaults to the alpha stored on the results object.

        Returns
        -------
        str
            Formatted summary table.
        """
        # `is None` rather than `or`, so an explicit alpha=0.0 is honoured.
        if alpha is None:
            alpha = self.alpha
        # Round away floating point noise instead of truncating: for example
        # (1 - 0.07) * 100 evaluates to 92.99999999999999, which int() would
        # report as 92. The 'g' format keeps whole levels as "95" and
        # fractional ones as e.g. "97.5".
        conf_level = f"{round((1 - alpha) * 100, 6):g}"
        # NOTE(review): the interval printed below is the stored `conf_int`;
        # passing a different `alpha` only changes the label - confirm that
        # this is the intended behavior.

        width = 75
        rule_heavy = "=" * width
        rule_light = "-" * width

        def _row(label: str, value: Any, spec: str = ">10") -> str:
            # One "label   value" line of the summary table.
            return f"{label:<25} {format(value, spec)}"

        lines = [
            rule_heavy,
            "Triply Robust Panel (TROP) Estimation Results".center(width),
            "Athey, Imbens, Qu & Viviano (2025)".center(width),
            rule_heavy,
            "",
            _row("Observations:", self.n_obs),
            _row("Treated units:", self.n_treated),
            _row("Control units:", self.n_control),
            _row("Treated observations:", self.n_treated_obs),
            _row("Pre-treatment periods:", self.n_pre_periods),
            _row("Post-treatment periods:", self.n_post_periods),
            "",
            rule_light,
            "Tuning Parameters (selected via LOOCV)".center(width),
            rule_light,
            _row("Lambda (time decay):", self.lambda_time, ">10.4f"),
            _row("Lambda (unit distance):", self.lambda_unit, ">10.4f"),
            _row("Lambda (nuclear norm):", self.lambda_nn, ">10.4f"),
            _row("Effective rank:", self.effective_rank, ">10.2f"),
            _row("LOOCV score:", self.loocv_score, ">10.6f"),
        ]

        # Variance info
        if self.n_bootstrap is not None:
            lines.append(_row("Bootstrap replications:", self.n_bootstrap))

        lines.extend([
            "",
            rule_light,
            f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'':>5}",
            rule_light,
            f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} "
            f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
            rule_light,
            "",
            f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
        ])

        # Add significance codes
        lines.extend([
            "",
            "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
            rule_heavy,
        ])

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the scalar results to a dictionary.

        The per-unit, per-period and per-observation effect dictionaries, the
        factor matrix, `alpha` and the bootstrap fields are not included.

        Returns
        -------
        Dict[str, Any]
            Dictionary of scalar estimation results, with the confidence
            interval split into ``conf_int_lower`` / ``conf_int_upper``.
        """
        lower, upper = self.conf_int[0], self.conf_int[1]
        return {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": lower,
            "conf_int_upper": upper,
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "n_treated_obs": self.n_treated_obs,
            "n_pre_periods": self.n_pre_periods,
            "n_post_periods": self.n_post_periods,
            "lambda_time": self.lambda_time,
            "lambda_unit": self.lambda_unit,
            "lambda_nn": self.lambda_nn,
            "effective_rank": self.effective_rank,
            "loocv_score": self.loocv_score,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Single-row DataFrame built from `to_dict()`.
        """
        return pd.DataFrame([self.to_dict()])

    def get_treatment_effects_df(self) -> pd.DataFrame:
        """
        Get individual treatment effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit, time, and treatment effect columns. The
            columns are present even when there are no treatment effects.
        """
        records = [
            {"unit": unit, "time": time, "effect": effect}
            for (unit, time), effect in self.treatment_effects.items()
        ]
        # Explicit columns keep the schema stable for an empty result.
        return pd.DataFrame(records, columns=["unit", "time", "effect"])

    def get_unit_effects_df(self) -> pd.DataFrame:
        """
        Get unit fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit and effect columns. The columns are present
            even when there are no unit effects.
        """
        records = [
            {"unit": unit, "effect": effect}
            for unit, effect in self.unit_effects.items()
        ]
        return pd.DataFrame(records, columns=["unit", "effect"])

    def get_time_effects_df(self) -> pd.DataFrame:
        """
        Get time fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with time and effect columns. The columns are present
            even when there are no time effects.
        """
        records = [
            {"time": time, "effect": effect}
            for time, effect in self.time_effects.items()
        ]
        return pd.DataFrame(records, columns=["time", "effect"])

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class TROP:
|
|
346
|
+
"""
|
|
347
|
+
Triply Robust Panel (TROP) estimator.
|
|
348
|
+
|
|
349
|
+
Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025).
|
|
350
|
+
TROP combines three robustness components:
|
|
351
|
+
|
|
352
|
+
1. **Nuclear norm regularized factor model**: Estimates interactive fixed
|
|
353
|
+
effects L_it via matrix completion with nuclear norm penalty ||L||_*
|
|
354
|
+
|
|
355
|
+
2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
|
|
356
|
+
where d(j,i) is the RMSE of outcome differences between units
|
|
357
|
+
|
|
358
|
+
3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
|
|
359
|
+
weighting pre-treatment periods by proximity to treatment
|
|
360
|
+
|
|
361
|
+
Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
|
|
362
|
+
cross-validation on control observations.
|
|
363
|
+
|
|
364
|
+
Parameters
|
|
365
|
+
----------
|
|
366
|
+
method : str, default='twostep'
|
|
367
|
+
Estimation method to use:
|
|
368
|
+
|
|
369
|
+
- 'twostep': Per-observation model fitting following Algorithm 2 of
|
|
370
|
+
Athey et al. (2025). Computes observation-specific weights and fits
|
|
371
|
+
a model for each treated observation, averaging the individual
|
|
372
|
+
treatment effects. More flexible but computationally intensive.
|
|
373
|
+
|
|
374
|
+
- 'joint': Joint weighted least squares optimization. Estimates a
|
|
375
|
+
single scalar treatment effect τ along with fixed effects and
|
|
376
|
+
optional low-rank factor adjustment. Faster but assumes homogeneous
|
|
377
|
+
treatment effects. Uses alternating minimization when nuclear norm
|
|
378
|
+
penalty is finite.
|
|
379
|
+
|
|
380
|
+
lambda_time_grid : list, optional
|
|
381
|
+
Grid of time weight decay parameters. 0.0 = uniform weights (disabled).
|
|
382
|
+
Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
|
|
383
|
+
lambda_unit_grid : list, optional
|
|
384
|
+
Grid of unit weight decay parameters. 0.0 = uniform weights (disabled).
|
|
385
|
+
Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
|
|
386
|
+
lambda_nn_grid : list, optional
|
|
387
|
+
Grid of nuclear norm regularization parameters. inf = factor model
|
|
388
|
+
disabled (L=0). Default: [0, 0.01, 0.1, 1, 10].
|
|
389
|
+
max_iter : int, default=100
|
|
390
|
+
Maximum iterations for nuclear norm optimization.
|
|
391
|
+
tol : float, default=1e-6
|
|
392
|
+
Convergence tolerance for optimization.
|
|
393
|
+
alpha : float, default=0.05
|
|
394
|
+
Significance level for confidence intervals.
|
|
395
|
+
n_bootstrap : int, default=200
|
|
396
|
+
Number of bootstrap replications for variance estimation.
|
|
397
|
+
seed : int, optional
|
|
398
|
+
Random seed for reproducibility.
|
|
399
|
+
|
|
400
|
+
Attributes
|
|
401
|
+
----------
|
|
402
|
+
results_ : TROPResults
|
|
403
|
+
Estimation results after calling fit().
|
|
404
|
+
is_fitted_ : bool
|
|
405
|
+
Whether the model has been fitted.
|
|
406
|
+
|
|
407
|
+
Examples
|
|
408
|
+
--------
|
|
409
|
+
>>> from diff_diff import TROP
|
|
410
|
+
>>> trop = TROP()
|
|
411
|
+
>>> results = trop.fit(
|
|
412
|
+
... data,
|
|
413
|
+
... outcome='outcome',
|
|
414
|
+
... treatment='treated',
|
|
415
|
+
... unit='unit',
|
|
416
|
+
... time='period',
|
|
417
|
+
... )
|
|
418
|
+
>>> results.print_summary()
|
|
419
|
+
|
|
420
|
+
References
|
|
421
|
+
----------
|
|
422
|
+
Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
|
|
423
|
+
Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
|
|
424
|
+
"""
|
|
425
|
+
|
|
426
|
+
# Class constants
|
|
427
|
+
CONVERGENCE_TOL_SVD: float = 1e-10
|
|
428
|
+
"""Tolerance for singular value truncation in soft-thresholding.
|
|
429
|
+
|
|
430
|
+
Singular values below this threshold after soft-thresholding are treated
|
|
431
|
+
as zero to improve numerical stability.
|
|
432
|
+
"""
|
|
433
|
+
|
|
434
|
+
def __init__(
|
|
435
|
+
self,
|
|
436
|
+
method: str = "twostep",
|
|
437
|
+
lambda_time_grid: Optional[List[float]] = None,
|
|
438
|
+
lambda_unit_grid: Optional[List[float]] = None,
|
|
439
|
+
lambda_nn_grid: Optional[List[float]] = None,
|
|
440
|
+
max_iter: int = 100,
|
|
441
|
+
tol: float = 1e-6,
|
|
442
|
+
alpha: float = 0.05,
|
|
443
|
+
n_bootstrap: int = 200,
|
|
444
|
+
seed: Optional[int] = None,
|
|
445
|
+
):
|
|
446
|
+
# Validate method parameter
|
|
447
|
+
valid_methods = ("twostep", "joint")
|
|
448
|
+
if method not in valid_methods:
|
|
449
|
+
raise ValueError(
|
|
450
|
+
f"method must be one of {valid_methods}, got '{method}'"
|
|
451
|
+
)
|
|
452
|
+
self.method = method
|
|
453
|
+
|
|
454
|
+
# Default grids from paper
|
|
455
|
+
self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
|
|
456
|
+
self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
|
|
457
|
+
self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
|
|
458
|
+
|
|
459
|
+
self.max_iter = max_iter
|
|
460
|
+
self.tol = tol
|
|
461
|
+
self.alpha = alpha
|
|
462
|
+
self.n_bootstrap = n_bootstrap
|
|
463
|
+
self.seed = seed
|
|
464
|
+
|
|
465
|
+
# Validate that time/unit grids do not contain inf.
|
|
466
|
+
# Per Athey et al. (2025) Eq. 3, λ_time=0 and λ_unit=0 give uniform
|
|
467
|
+
# weights (exp(-0 × dist) = 1). Using inf is a misunderstanding of
|
|
468
|
+
# the paper's convention. Only λ_nn=∞ is valid (disables factor model).
|
|
469
|
+
for grid_name, grid_vals in [
|
|
470
|
+
("lambda_time_grid", self.lambda_time_grid),
|
|
471
|
+
("lambda_unit_grid", self.lambda_unit_grid),
|
|
472
|
+
]:
|
|
473
|
+
if any(np.isinf(v) for v in grid_vals):
|
|
474
|
+
raise ValueError(
|
|
475
|
+
f"{grid_name} must not contain inf. Use 0.0 for uniform "
|
|
476
|
+
f"weights (disabled) per Athey et al. (2025) Eq. 3: "
|
|
477
|
+
f"exp(-0 × dist) = 1 for all distances."
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
# Internal state
|
|
481
|
+
self.results_: Optional[TROPResults] = None
|
|
482
|
+
self.is_fitted_: bool = False
|
|
483
|
+
self._optimal_lambda: Optional[Tuple[float, float, float]] = None
|
|
484
|
+
|
|
485
|
+
# Pre-computed structures (set during fit)
|
|
486
|
+
self._precomputed: Optional[_PrecomputedStructures] = None
|
|
487
|
+
|
|
488
|
+
def _precompute_structures(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
) -> _PrecomputedStructures:
    """
    Build, once, the structures shared by LOOCV and the final estimation.

    The result bundles the pairwise unit distance matrix, the |t - s| time
    distance matrix, the control/treated masks and the lists of treated and
    usable control (period, unit) index pairs, together with the inputs.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_unit_idx : np.ndarray
        Indices of control units; stored unchanged in the result.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    _PrecomputedStructures
        Pre-computed structures for efficient reuse.
    """
    # Pairwise unit distances: use the Rust implementation when it is
    # available, otherwise the NumPy implementation on this class.
    rust_available = HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None
    if rust_available:
        unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64))
    else:
        unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)

    # time_dist_matrix[t, s] = |t - s|, shape (n_periods, n_periods)
    period_index = np.arange(n_periods)
    time_dist_matrix = np.abs(
        period_index[:, np.newaxis] - period_index[np.newaxis, :]
    )

    control_mask = D == 0
    treated_mask = D == 1

    # (t, i) index pairs of the treated cells
    treated_rows, treated_cols = np.where(treated_mask)
    treated_observations = list(zip(treated_rows, treated_cols))

    # (t, i) index pairs of control cells with an observed outcome, in
    # period-major order; these are the candidates for LOOCV.
    usable_control = control_mask & ~np.isnan(Y)
    control_obs = [
        (t, i)
        for t in range(n_periods)
        for i in range(n_units)
        if usable_control[t, i]
    ]

    precomputed = {
        "unit_dist_matrix": unit_dist_matrix,
        "time_dist_matrix": time_dist_matrix,
        "control_mask": control_mask,
        "treated_mask": treated_mask,
        "treated_observations": treated_observations,
        "control_obs": control_obs,
        "control_unit_idx": control_unit_idx,
        "D": D,
        "Y": Y,
        "n_units": n_units,
        "n_periods": n_periods,
    }
    return precomputed
|
|
560
|
+
|
|
561
|
+
def _compute_all_unit_distances(
|
|
562
|
+
self,
|
|
563
|
+
Y: np.ndarray,
|
|
564
|
+
D: np.ndarray,
|
|
565
|
+
n_units: int,
|
|
566
|
+
n_periods: int,
|
|
567
|
+
) -> np.ndarray:
|
|
568
|
+
"""
|
|
569
|
+
Compute pairwise unit distance matrix using vectorized operations.
|
|
570
|
+
|
|
571
|
+
Following Equation 3 (page 7):
|
|
572
|
+
dist_unit_{-t}(j, i) = sqrt(Σ_u (Y_{iu} - Y_{ju})² / n_valid)
|
|
573
|
+
|
|
574
|
+
For efficiency, we compute a base distance matrix excluding all treated
|
|
575
|
+
observations, which provides a good approximation. The exact per-observation
|
|
576
|
+
distances are refined when needed.
|
|
577
|
+
|
|
578
|
+
Uses vectorized numpy operations with masked arrays for O(n²) complexity
|
|
579
|
+
but with highly optimized inner loops via numpy/BLAS.
|
|
580
|
+
|
|
581
|
+
Parameters
|
|
582
|
+
----------
|
|
583
|
+
Y : np.ndarray
|
|
584
|
+
Outcome matrix (n_periods x n_units).
|
|
585
|
+
D : np.ndarray
|
|
586
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
587
|
+
n_units : int
|
|
588
|
+
Number of units.
|
|
589
|
+
n_periods : int
|
|
590
|
+
Number of periods.
|
|
591
|
+
|
|
592
|
+
Returns
|
|
593
|
+
-------
|
|
594
|
+
np.ndarray
|
|
595
|
+
Pairwise distance matrix (n_units x n_units).
|
|
596
|
+
"""
|
|
597
|
+
# Mask for valid observations: control periods only (D=0), non-NaN
|
|
598
|
+
valid_mask = (D == 0) & ~np.isnan(Y)
|
|
599
|
+
|
|
600
|
+
# Replace invalid values with NaN for masked computation
|
|
601
|
+
Y_masked = np.where(valid_mask, Y, np.nan)
|
|
602
|
+
|
|
603
|
+
# Transpose to (n_units, n_periods) for easier broadcasting
|
|
604
|
+
Y_T = Y_masked.T # (n_units, n_periods)
|
|
605
|
+
|
|
606
|
+
# Compute pairwise squared differences using broadcasting
|
|
607
|
+
# Y_T[:, np.newaxis, :] has shape (n_units, 1, n_periods)
|
|
608
|
+
# Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods)
|
|
609
|
+
# diff has shape (n_units, n_units, n_periods)
|
|
610
|
+
diff = Y_T[:, np.newaxis, :] - Y_T[np.newaxis, :, :]
|
|
611
|
+
sq_diff = diff ** 2
|
|
612
|
+
|
|
613
|
+
# Count valid (non-NaN) observations per pair
|
|
614
|
+
# A difference is valid only if both units have valid observations
|
|
615
|
+
valid_diff = ~np.isnan(sq_diff)
|
|
616
|
+
n_valid = np.sum(valid_diff, axis=2) # (n_units, n_units)
|
|
617
|
+
|
|
618
|
+
# Compute sum of squared differences (treating NaN as 0)
|
|
619
|
+
sq_diff_sum = np.nansum(sq_diff, axis=2) # (n_units, n_units)
|
|
620
|
+
|
|
621
|
+
# Compute RMSE distance: sqrt(sum / n_valid)
|
|
622
|
+
# Avoid division by zero
|
|
623
|
+
with np.errstate(divide='ignore', invalid='ignore'):
|
|
624
|
+
dist_matrix = np.sqrt(sq_diff_sum / n_valid)
|
|
625
|
+
|
|
626
|
+
# Set pairs with no valid observations to inf
|
|
627
|
+
dist_matrix = np.where(n_valid > 0, dist_matrix, np.inf)
|
|
628
|
+
|
|
629
|
+
# Ensure diagonal is 0 (same unit distance)
|
|
630
|
+
np.fill_diagonal(dist_matrix, 0.0)
|
|
631
|
+
|
|
632
|
+
return dist_matrix
|
|
633
|
+
|
|
634
|
+
def _compute_unit_distance_for_obs(
|
|
635
|
+
self,
|
|
636
|
+
Y: np.ndarray,
|
|
637
|
+
D: np.ndarray,
|
|
638
|
+
j: int,
|
|
639
|
+
i: int,
|
|
640
|
+
target_period: int,
|
|
641
|
+
) -> float:
|
|
642
|
+
"""
|
|
643
|
+
Compute observation-specific pairwise distance from unit j to unit i.
|
|
644
|
+
|
|
645
|
+
This is the exact computation from Equation 3, excluding the target period.
|
|
646
|
+
Used when the base distance matrix approximation is insufficient.
|
|
647
|
+
|
|
648
|
+
Parameters
|
|
649
|
+
----------
|
|
650
|
+
Y : np.ndarray
|
|
651
|
+
Outcome matrix (n_periods x n_units).
|
|
652
|
+
D : np.ndarray
|
|
653
|
+
Treatment indicator matrix.
|
|
654
|
+
j : int
|
|
655
|
+
Control unit index.
|
|
656
|
+
i : int
|
|
657
|
+
Treated unit index.
|
|
658
|
+
target_period : int
|
|
659
|
+
Target period to exclude.
|
|
660
|
+
|
|
661
|
+
Returns
|
|
662
|
+
-------
|
|
663
|
+
float
|
|
664
|
+
Pairwise RMSE distance.
|
|
665
|
+
"""
|
|
666
|
+
n_periods = Y.shape[0]
|
|
667
|
+
|
|
668
|
+
# Mask: exclude target period, both units must be untreated, non-NaN
|
|
669
|
+
valid = np.ones(n_periods, dtype=bool)
|
|
670
|
+
valid[target_period] = False
|
|
671
|
+
valid &= (D[:, i] == 0) & (D[:, j] == 0)
|
|
672
|
+
valid &= ~np.isnan(Y[:, i]) & ~np.isnan(Y[:, j])
|
|
673
|
+
|
|
674
|
+
if np.any(valid):
|
|
675
|
+
sq_diffs = (Y[valid, i] - Y[valid, j]) ** 2
|
|
676
|
+
return np.sqrt(np.mean(sq_diffs))
|
|
677
|
+
else:
|
|
678
|
+
return np.inf
|
|
679
|
+
|
|
680
|
+
def _univariate_loocv_search(
|
|
681
|
+
self,
|
|
682
|
+
Y: np.ndarray,
|
|
683
|
+
D: np.ndarray,
|
|
684
|
+
control_mask: np.ndarray,
|
|
685
|
+
control_unit_idx: np.ndarray,
|
|
686
|
+
n_units: int,
|
|
687
|
+
n_periods: int,
|
|
688
|
+
param_name: str,
|
|
689
|
+
grid: List[float],
|
|
690
|
+
fixed_params: Dict[str, float],
|
|
691
|
+
) -> Tuple[float, float]:
|
|
692
|
+
"""
|
|
693
|
+
Search over one parameter with others fixed.
|
|
694
|
+
|
|
695
|
+
Following paper's footnote 2, this performs a univariate grid search
|
|
696
|
+
for one tuning parameter while holding others fixed. The fixed_params
|
|
697
|
+
use 0.0 for disabled time/unit weights and _LAMBDA_INF for disabled
|
|
698
|
+
factor model:
|
|
699
|
+
- lambda_nn = inf: Skip nuclear norm regularization (L=0)
|
|
700
|
+
- lambda_time = 0.0: Uniform time weights (exp(-0×dist)=1)
|
|
701
|
+
- lambda_unit = 0.0: Uniform unit weights (exp(-0×dist)=1)
|
|
702
|
+
|
|
703
|
+
Parameters
|
|
704
|
+
----------
|
|
705
|
+
Y : np.ndarray
|
|
706
|
+
Outcome matrix (n_periods x n_units).
|
|
707
|
+
D : np.ndarray
|
|
708
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
709
|
+
control_mask : np.ndarray
|
|
710
|
+
Boolean mask for control observations.
|
|
711
|
+
control_unit_idx : np.ndarray
|
|
712
|
+
Indices of control units.
|
|
713
|
+
n_units : int
|
|
714
|
+
Number of units.
|
|
715
|
+
n_periods : int
|
|
716
|
+
Number of periods.
|
|
717
|
+
param_name : str
|
|
718
|
+
Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
|
|
719
|
+
grid : List[float]
|
|
720
|
+
Grid of values to search over.
|
|
721
|
+
fixed_params : Dict[str, float]
|
|
722
|
+
Fixed values for other parameters. May include _LAMBDA_INF for lambda_nn.
|
|
723
|
+
|
|
724
|
+
Returns
|
|
725
|
+
-------
|
|
726
|
+
Tuple[float, float]
|
|
727
|
+
(best_value, best_score) for the searched parameter.
|
|
728
|
+
"""
|
|
729
|
+
best_score = np.inf
|
|
730
|
+
best_value = grid[0] if grid else 0.0
|
|
731
|
+
|
|
732
|
+
for value in grid:
|
|
733
|
+
params = {**fixed_params, param_name: value}
|
|
734
|
+
|
|
735
|
+
lambda_time = params.get('lambda_time', 0.0)
|
|
736
|
+
lambda_unit = params.get('lambda_unit', 0.0)
|
|
737
|
+
lambda_nn = params.get('lambda_nn', 0.0)
|
|
738
|
+
|
|
739
|
+
# Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
|
|
740
|
+
# λ_time and λ_unit use 0.0 for uniform weights per Eq. 3 (no inf conversion needed)
|
|
741
|
+
if np.isinf(lambda_nn):
|
|
742
|
+
lambda_nn = 1e10
|
|
743
|
+
|
|
744
|
+
try:
|
|
745
|
+
score = self._loocv_score_obs_specific(
|
|
746
|
+
Y, D, control_mask, control_unit_idx,
|
|
747
|
+
lambda_time, lambda_unit, lambda_nn,
|
|
748
|
+
n_units, n_periods
|
|
749
|
+
)
|
|
750
|
+
if score < best_score:
|
|
751
|
+
best_score = score
|
|
752
|
+
best_value = value
|
|
753
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
754
|
+
continue
|
|
755
|
+
|
|
756
|
+
return best_value, best_score
|
|
757
|
+
|
|
758
|
+
def _cycling_parameter_search(
|
|
759
|
+
self,
|
|
760
|
+
Y: np.ndarray,
|
|
761
|
+
D: np.ndarray,
|
|
762
|
+
control_mask: np.ndarray,
|
|
763
|
+
control_unit_idx: np.ndarray,
|
|
764
|
+
n_units: int,
|
|
765
|
+
n_periods: int,
|
|
766
|
+
initial_lambda: Tuple[float, float, float],
|
|
767
|
+
max_cycles: int = 10,
|
|
768
|
+
) -> Tuple[float, float, float]:
|
|
769
|
+
"""
|
|
770
|
+
Cycle through parameters until convergence (coordinate descent).
|
|
771
|
+
|
|
772
|
+
Following paper's footnote 2 (Stage 2), this iteratively optimizes
|
|
773
|
+
each tuning parameter while holding the others fixed, until convergence.
|
|
774
|
+
|
|
775
|
+
Parameters
|
|
776
|
+
----------
|
|
777
|
+
Y : np.ndarray
|
|
778
|
+
Outcome matrix (n_periods x n_units).
|
|
779
|
+
D : np.ndarray
|
|
780
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
781
|
+
control_mask : np.ndarray
|
|
782
|
+
Boolean mask for control observations.
|
|
783
|
+
control_unit_idx : np.ndarray
|
|
784
|
+
Indices of control units.
|
|
785
|
+
n_units : int
|
|
786
|
+
Number of units.
|
|
787
|
+
n_periods : int
|
|
788
|
+
Number of periods.
|
|
789
|
+
initial_lambda : Tuple[float, float, float]
|
|
790
|
+
Initial values (lambda_time, lambda_unit, lambda_nn).
|
|
791
|
+
max_cycles : int, default=10
|
|
792
|
+
Maximum number of coordinate descent cycles.
|
|
793
|
+
|
|
794
|
+
Returns
|
|
795
|
+
-------
|
|
796
|
+
Tuple[float, float, float]
|
|
797
|
+
Optimized (lambda_time, lambda_unit, lambda_nn).
|
|
798
|
+
"""
|
|
799
|
+
lambda_time, lambda_unit, lambda_nn = initial_lambda
|
|
800
|
+
prev_score = np.inf
|
|
801
|
+
|
|
802
|
+
for cycle in range(max_cycles):
|
|
803
|
+
# Optimize λ_unit (fix λ_time, λ_nn)
|
|
804
|
+
lambda_unit, _ = self._univariate_loocv_search(
|
|
805
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
806
|
+
'lambda_unit', self.lambda_unit_grid,
|
|
807
|
+
{'lambda_time': lambda_time, 'lambda_nn': lambda_nn}
|
|
808
|
+
)
|
|
809
|
+
|
|
810
|
+
# Optimize λ_time (fix λ_unit, λ_nn)
|
|
811
|
+
lambda_time, _ = self._univariate_loocv_search(
|
|
812
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
813
|
+
'lambda_time', self.lambda_time_grid,
|
|
814
|
+
{'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn}
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
# Optimize λ_nn (fix λ_unit, λ_time)
|
|
818
|
+
lambda_nn, score = self._univariate_loocv_search(
|
|
819
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
820
|
+
'lambda_nn', self.lambda_nn_grid,
|
|
821
|
+
{'lambda_unit': lambda_unit, 'lambda_time': lambda_time}
|
|
822
|
+
)
|
|
823
|
+
|
|
824
|
+
# Check convergence
|
|
825
|
+
if abs(score - prev_score) < 1e-6:
|
|
826
|
+
logger.debug(
|
|
827
|
+
"Cycling search converged after %d cycles with score %.6f",
|
|
828
|
+
cycle + 1, score
|
|
829
|
+
)
|
|
830
|
+
break
|
|
831
|
+
prev_score = score
|
|
832
|
+
|
|
833
|
+
return lambda_time, lambda_unit, lambda_nn
|
|
834
|
+
|
|
835
|
+
# =========================================================================
|
|
836
|
+
# Joint estimation method
|
|
837
|
+
# =========================================================================
|
|
838
|
+
|
|
839
|
+
def _compute_joint_weights(
|
|
840
|
+
self,
|
|
841
|
+
Y: np.ndarray,
|
|
842
|
+
D: np.ndarray,
|
|
843
|
+
lambda_time: float,
|
|
844
|
+
lambda_unit: float,
|
|
845
|
+
treated_periods: int,
|
|
846
|
+
n_units: int,
|
|
847
|
+
n_periods: int,
|
|
848
|
+
) -> np.ndarray:
|
|
849
|
+
"""
|
|
850
|
+
Compute distance-based weights for joint estimation.
|
|
851
|
+
|
|
852
|
+
Following the reference implementation, weights are computed based on:
|
|
853
|
+
- Time distance: distance to center of treated block
|
|
854
|
+
- Unit distance: RMSE to average treated trajectory over pre-periods
|
|
855
|
+
|
|
856
|
+
Parameters
|
|
857
|
+
----------
|
|
858
|
+
Y : np.ndarray
|
|
859
|
+
Outcome matrix (n_periods x n_units).
|
|
860
|
+
D : np.ndarray
|
|
861
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
862
|
+
lambda_time : float
|
|
863
|
+
Time weight decay parameter.
|
|
864
|
+
lambda_unit : float
|
|
865
|
+
Unit weight decay parameter.
|
|
866
|
+
treated_periods : int
|
|
867
|
+
Number of post-treatment periods.
|
|
868
|
+
n_units : int
|
|
869
|
+
Number of units.
|
|
870
|
+
n_periods : int
|
|
871
|
+
Number of periods.
|
|
872
|
+
|
|
873
|
+
Returns
|
|
874
|
+
-------
|
|
875
|
+
np.ndarray
|
|
876
|
+
Weight matrix (n_periods x n_units).
|
|
877
|
+
"""
|
|
878
|
+
# Identify treated units (ever treated)
|
|
879
|
+
treated_mask = np.any(D == 1, axis=0)
|
|
880
|
+
treated_unit_idx = np.where(treated_mask)[0]
|
|
881
|
+
|
|
882
|
+
if len(treated_unit_idx) == 0:
|
|
883
|
+
raise ValueError("No treated units found")
|
|
884
|
+
|
|
885
|
+
# Time weights: distance to center of treated block
|
|
886
|
+
# Following reference: center = T - treated_periods/2
|
|
887
|
+
center = n_periods - treated_periods / 2.0
|
|
888
|
+
dist_time = np.abs(np.arange(n_periods, dtype=float) - center)
|
|
889
|
+
delta_time = np.exp(-lambda_time * dist_time)
|
|
890
|
+
|
|
891
|
+
# Unit weights: RMSE to average treated trajectory over pre-periods
|
|
892
|
+
# Compute average treated trajectory (use nanmean to handle NaN)
|
|
893
|
+
average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1)
|
|
894
|
+
|
|
895
|
+
# Pre-period mask: 1 in pre, 0 in post
|
|
896
|
+
pre_mask = np.ones(n_periods, dtype=float)
|
|
897
|
+
pre_mask[-treated_periods:] = 0.0
|
|
898
|
+
|
|
899
|
+
# Compute RMS distance for each unit
|
|
900
|
+
# dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre)
|
|
901
|
+
# Use NaN-safe operations: treat NaN differences as 0 (excluded)
|
|
902
|
+
diff = average_treated[:, np.newaxis] - Y
|
|
903
|
+
diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis]
|
|
904
|
+
|
|
905
|
+
# Count valid observations per unit in pre-period
|
|
906
|
+
# Must check diff is finite (both Y and average_treated finite)
|
|
907
|
+
# to match the periods contributing to diff_sq
|
|
908
|
+
valid_count = np.sum(
|
|
909
|
+
np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0
|
|
910
|
+
)
|
|
911
|
+
sum_sq = np.sum(diff_sq, axis=0)
|
|
912
|
+
n_pre = np.sum(pre_mask)
|
|
913
|
+
|
|
914
|
+
if n_pre == 0:
|
|
915
|
+
raise ValueError("No pre-treatment periods")
|
|
916
|
+
|
|
917
|
+
# Track units with no valid pre-period data
|
|
918
|
+
no_valid_pre = valid_count == 0
|
|
919
|
+
|
|
920
|
+
# Use valid count per unit (avoid division by zero for calculation)
|
|
921
|
+
valid_count_safe = np.maximum(valid_count, 1)
|
|
922
|
+
dist_unit = np.sqrt(sum_sq / valid_count_safe)
|
|
923
|
+
|
|
924
|
+
# Units with no valid pre-period data get zero weight
|
|
925
|
+
# (dist is undefined, so we set it to inf -> delta_unit = exp(-inf) = 0)
|
|
926
|
+
delta_unit = np.exp(-lambda_unit * dist_unit)
|
|
927
|
+
delta_unit[no_valid_pre] = 0.0
|
|
928
|
+
|
|
929
|
+
# Outer product: (n_periods x n_units)
|
|
930
|
+
delta = np.outer(delta_time, delta_unit)
|
|
931
|
+
|
|
932
|
+
return delta
|
|
933
|
+
|
|
934
|
+
def _loocv_score_joint(
|
|
935
|
+
self,
|
|
936
|
+
Y: np.ndarray,
|
|
937
|
+
D: np.ndarray,
|
|
938
|
+
control_obs: List[Tuple[int, int]],
|
|
939
|
+
lambda_time: float,
|
|
940
|
+
lambda_unit: float,
|
|
941
|
+
lambda_nn: float,
|
|
942
|
+
treated_periods: int,
|
|
943
|
+
n_units: int,
|
|
944
|
+
n_periods: int,
|
|
945
|
+
) -> float:
|
|
946
|
+
"""
|
|
947
|
+
Compute LOOCV score for joint method with specific parameter combination.
|
|
948
|
+
|
|
949
|
+
Following paper's Equation 5:
|
|
950
|
+
Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
|
|
951
|
+
|
|
952
|
+
For joint method, we exclude each control observation, fit the joint model
|
|
953
|
+
on remaining data, and compute the pseudo-treatment effect at the excluded obs.
|
|
954
|
+
|
|
955
|
+
Parameters
|
|
956
|
+
----------
|
|
957
|
+
Y : np.ndarray
|
|
958
|
+
Outcome matrix (n_periods x n_units).
|
|
959
|
+
D : np.ndarray
|
|
960
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
961
|
+
control_obs : List[Tuple[int, int]]
|
|
962
|
+
List of (t, i) control observations for LOOCV.
|
|
963
|
+
lambda_time : float
|
|
964
|
+
Time weight decay parameter.
|
|
965
|
+
lambda_unit : float
|
|
966
|
+
Unit weight decay parameter.
|
|
967
|
+
lambda_nn : float
|
|
968
|
+
Nuclear norm regularization parameter.
|
|
969
|
+
treated_periods : int
|
|
970
|
+
Number of post-treatment periods.
|
|
971
|
+
n_units : int
|
|
972
|
+
Number of units.
|
|
973
|
+
n_periods : int
|
|
974
|
+
Number of periods.
|
|
975
|
+
|
|
976
|
+
Returns
|
|
977
|
+
-------
|
|
978
|
+
float
|
|
979
|
+
LOOCV score (sum of squared pseudo-treatment effects).
|
|
980
|
+
"""
|
|
981
|
+
# Compute global weights (same for all LOOCV iterations)
|
|
982
|
+
delta = self._compute_joint_weights(
|
|
983
|
+
Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
|
|
984
|
+
)
|
|
985
|
+
|
|
986
|
+
tau_sq_sum = 0.0
|
|
987
|
+
n_valid = 0
|
|
988
|
+
|
|
989
|
+
for t_ex, i_ex in control_obs:
|
|
990
|
+
# Create modified delta with excluded observation zeroed out
|
|
991
|
+
delta_ex = delta.copy()
|
|
992
|
+
delta_ex[t_ex, i_ex] = 0.0
|
|
993
|
+
|
|
994
|
+
try:
|
|
995
|
+
# Fit joint model excluding this observation
|
|
996
|
+
if lambda_nn >= 1e10:
|
|
997
|
+
mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta_ex)
|
|
998
|
+
L = np.zeros((n_periods, n_units))
|
|
999
|
+
else:
|
|
1000
|
+
mu, alpha, beta, L, tau = self._solve_joint_with_lowrank(
|
|
1001
|
+
Y, D, delta_ex, lambda_nn, self.max_iter, self.tol
|
|
1002
|
+
)
|
|
1003
|
+
|
|
1004
|
+
# Pseudo treatment effect: τ = Y - μ - α - β - L
|
|
1005
|
+
if np.isfinite(Y[t_ex, i_ex]):
|
|
1006
|
+
tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex]
|
|
1007
|
+
tau_sq_sum += tau_loocv ** 2
|
|
1008
|
+
n_valid += 1
|
|
1009
|
+
|
|
1010
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
1011
|
+
# Any failure means this λ combination is invalid per Equation 5
|
|
1012
|
+
return np.inf
|
|
1013
|
+
|
|
1014
|
+
if n_valid == 0:
|
|
1015
|
+
return np.inf
|
|
1016
|
+
|
|
1017
|
+
return tau_sq_sum
|
|
1018
|
+
|
|
1019
|
+
def _solve_joint_no_lowrank(
|
|
1020
|
+
self,
|
|
1021
|
+
Y: np.ndarray,
|
|
1022
|
+
D: np.ndarray,
|
|
1023
|
+
delta: np.ndarray,
|
|
1024
|
+
) -> Tuple[float, np.ndarray, np.ndarray, float]:
|
|
1025
|
+
"""
|
|
1026
|
+
Solve joint TWFE + treatment via weighted least squares (no low-rank).
|
|
1027
|
+
|
|
1028
|
+
Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})²
|
|
1029
|
+
|
|
1030
|
+
Parameters
|
|
1031
|
+
----------
|
|
1032
|
+
Y : np.ndarray
|
|
1033
|
+
Outcome matrix (n_periods x n_units).
|
|
1034
|
+
D : np.ndarray
|
|
1035
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
1036
|
+
delta : np.ndarray
|
|
1037
|
+
Weight matrix (n_periods x n_units).
|
|
1038
|
+
|
|
1039
|
+
Returns
|
|
1040
|
+
-------
|
|
1041
|
+
Tuple[float, np.ndarray, np.ndarray, float]
|
|
1042
|
+
(mu, alpha, beta, tau) estimated parameters.
|
|
1043
|
+
"""
|
|
1044
|
+
n_periods, n_units = Y.shape
|
|
1045
|
+
|
|
1046
|
+
# Flatten matrices for regression
|
|
1047
|
+
y = Y.flatten() # length n_periods * n_units
|
|
1048
|
+
w = D.flatten()
|
|
1049
|
+
weights = delta.flatten()
|
|
1050
|
+
|
|
1051
|
+
# Handle NaN values: zero weight for NaN outcomes/weights, impute with 0
|
|
1052
|
+
# This ensures NaN observations don't contribute to estimation
|
|
1053
|
+
valid_y = np.isfinite(y)
|
|
1054
|
+
valid_w = np.isfinite(weights)
|
|
1055
|
+
valid_mask = valid_y & valid_w
|
|
1056
|
+
weights = np.where(valid_mask, weights, 0.0)
|
|
1057
|
+
y = np.where(valid_mask, y, 0.0)
|
|
1058
|
+
|
|
1059
|
+
sqrt_weights = np.sqrt(np.maximum(weights, 0))
|
|
1060
|
+
|
|
1061
|
+
# Check for all-zero weights (matches Rust's sum_w < 1e-10 check)
|
|
1062
|
+
sum_w = np.sum(weights)
|
|
1063
|
+
if sum_w < 1e-10:
|
|
1064
|
+
raise ValueError("All weights are zero - cannot estimate")
|
|
1065
|
+
|
|
1066
|
+
# Build design matrix: [intercept, unit_dummies, time_dummies, treatment]
|
|
1067
|
+
# Total columns: 1 + n_units + n_periods + 1
|
|
1068
|
+
# But we need to drop one unit and one time dummy for identification
|
|
1069
|
+
# Drop first unit (unit 0) and first time (time 0)
|
|
1070
|
+
n_obs = n_periods * n_units
|
|
1071
|
+
n_params = 1 + (n_units - 1) + (n_periods - 1) + 1
|
|
1072
|
+
|
|
1073
|
+
X = np.zeros((n_obs, n_params))
|
|
1074
|
+
X[:, 0] = 1.0 # intercept
|
|
1075
|
+
|
|
1076
|
+
# Unit dummies (skip unit 0)
|
|
1077
|
+
for i in range(1, n_units):
|
|
1078
|
+
for t in range(n_periods):
|
|
1079
|
+
X[t * n_units + i, i] = 1.0
|
|
1080
|
+
|
|
1081
|
+
# Time dummies (skip time 0)
|
|
1082
|
+
for t in range(1, n_periods):
|
|
1083
|
+
for i in range(n_units):
|
|
1084
|
+
X[t * n_units + i, (n_units - 1) + t] = 1.0
|
|
1085
|
+
|
|
1086
|
+
# Treatment indicator
|
|
1087
|
+
X[:, -1] = w
|
|
1088
|
+
|
|
1089
|
+
# Apply weights
|
|
1090
|
+
X_weighted = X * sqrt_weights[:, np.newaxis]
|
|
1091
|
+
y_weighted = y * sqrt_weights
|
|
1092
|
+
|
|
1093
|
+
# Solve weighted least squares
|
|
1094
|
+
try:
|
|
1095
|
+
coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None)
|
|
1096
|
+
except np.linalg.LinAlgError:
|
|
1097
|
+
# Fallback: use pseudo-inverse
|
|
1098
|
+
coeffs = np.linalg.pinv(X_weighted) @ y_weighted
|
|
1099
|
+
|
|
1100
|
+
# Extract parameters
|
|
1101
|
+
mu = coeffs[0]
|
|
1102
|
+
alpha = np.zeros(n_units)
|
|
1103
|
+
alpha[1:] = coeffs[1:n_units]
|
|
1104
|
+
beta = np.zeros(n_periods)
|
|
1105
|
+
beta[1:] = coeffs[n_units:(n_units + n_periods - 1)]
|
|
1106
|
+
tau = coeffs[-1]
|
|
1107
|
+
|
|
1108
|
+
return float(mu), alpha, beta, float(tau)
|
|
1109
|
+
|
|
1110
|
+
def _solve_joint_with_lowrank(
|
|
1111
|
+
self,
|
|
1112
|
+
Y: np.ndarray,
|
|
1113
|
+
D: np.ndarray,
|
|
1114
|
+
delta: np.ndarray,
|
|
1115
|
+
lambda_nn: float,
|
|
1116
|
+
max_iter: int = 100,
|
|
1117
|
+
tol: float = 1e-6,
|
|
1118
|
+
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]:
|
|
1119
|
+
"""
|
|
1120
|
+
Solve joint TWFE + treatment + low-rank via alternating minimization.
|
|
1121
|
+
|
|
1122
|
+
Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_*
|
|
1123
|
+
|
|
1124
|
+
Parameters
|
|
1125
|
+
----------
|
|
1126
|
+
Y : np.ndarray
|
|
1127
|
+
Outcome matrix (n_periods x n_units).
|
|
1128
|
+
D : np.ndarray
|
|
1129
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
1130
|
+
delta : np.ndarray
|
|
1131
|
+
Weight matrix (n_periods x n_units).
|
|
1132
|
+
lambda_nn : float
|
|
1133
|
+
Nuclear norm regularization parameter.
|
|
1134
|
+
max_iter : int, default=100
|
|
1135
|
+
Maximum iterations for alternating minimization.
|
|
1136
|
+
tol : float, default=1e-6
|
|
1137
|
+
Convergence tolerance.
|
|
1138
|
+
|
|
1139
|
+
Returns
|
|
1140
|
+
-------
|
|
1141
|
+
Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]
|
|
1142
|
+
(mu, alpha, beta, L, tau) estimated parameters.
|
|
1143
|
+
"""
|
|
1144
|
+
n_periods, n_units = Y.shape
|
|
1145
|
+
|
|
1146
|
+
# Handle NaN values: impute with 0 for computations
|
|
1147
|
+
# The solver will also zero weights for NaN observations
|
|
1148
|
+
Y_safe = np.where(np.isfinite(Y), Y, 0.0)
|
|
1149
|
+
|
|
1150
|
+
# Mask delta to exclude NaN outcomes from estimation
|
|
1151
|
+
# This ensures NaN observations don't contribute to the gradient step
|
|
1152
|
+
nan_mask = ~np.isfinite(Y)
|
|
1153
|
+
delta_masked = delta.copy()
|
|
1154
|
+
delta_masked[nan_mask] = 0.0
|
|
1155
|
+
|
|
1156
|
+
# Initialize L = 0
|
|
1157
|
+
L = np.zeros((n_periods, n_units))
|
|
1158
|
+
|
|
1159
|
+
for iteration in range(max_iter):
|
|
1160
|
+
L_old = L.copy()
|
|
1161
|
+
|
|
1162
|
+
# Step 1: Fix L, solve for (mu, alpha, beta, tau)
|
|
1163
|
+
# Adjusted outcome: Y - L (using NaN-safe Y)
|
|
1164
|
+
# Pass masked delta to exclude NaN observations from WLS
|
|
1165
|
+
Y_adj = Y_safe - L
|
|
1166
|
+
mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta_masked)
|
|
1167
|
+
|
|
1168
|
+
# Step 2: Fix (mu, alpha, beta, tau), update L
|
|
1169
|
+
# Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y)
|
|
1170
|
+
R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D
|
|
1171
|
+
|
|
1172
|
+
# Weighted proximal step for L (soft-threshold SVD)
|
|
1173
|
+
# Normalize weights (using masked delta to exclude NaN observations)
|
|
1174
|
+
delta_max = np.max(delta_masked)
|
|
1175
|
+
if delta_max > 0:
|
|
1176
|
+
delta_norm = delta_masked / delta_max
|
|
1177
|
+
else:
|
|
1178
|
+
delta_norm = delta_masked
|
|
1179
|
+
|
|
1180
|
+
# Weighted average between current L and target R
|
|
1181
|
+
# L_next = L + delta_norm * (R - L), then soft-threshold
|
|
1182
|
+
# NaN observations have delta_norm=0, so they don't influence L update
|
|
1183
|
+
gradient_step = L + delta_norm * (R - L)
|
|
1184
|
+
|
|
1185
|
+
# Soft-threshold singular values
|
|
1186
|
+
# Use eta * lambda_nn for proper proximal step size (matches Rust)
|
|
1187
|
+
eta = 1.0 / delta_max if delta_max > 0 else 1.0
|
|
1188
|
+
L = self._soft_threshold_svd(gradient_step, eta * lambda_nn)
|
|
1189
|
+
|
|
1190
|
+
# Check convergence
|
|
1191
|
+
if np.max(np.abs(L - L_old)) < tol:
|
|
1192
|
+
break
|
|
1193
|
+
|
|
1194
|
+
return mu, alpha, beta, L, tau
|
|
1195
|
+
|
|
1196
|
+
def _fit_joint(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
) -> TROPResults:
    """
    Fit TROP using joint weighted least squares method.

    This method estimates a single scalar treatment effect τ along with
    fixed effects and optional low-rank factor adjustment.

    The steps are: pivot the long panel into (n_periods x n_units)
    matrices, validate the treatment pattern, pick (λ_time, λ_unit, λ_nn)
    by LOOCV grid search (Rust backend if available, otherwise Python),
    refit with the selected parameters, and compute a bootstrap standard
    error. The result is stored on ``self.results_`` and
    ``self.is_fitted_`` is set to True.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    treatment : str
        Treatment indicator column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.

    Returns
    -------
    TROPResults
        Estimation results.

    Raises
    ------
    ValueError
        If treatment is not absorbing for some unit, there are no treated
        observations, there are no never-treated units, fewer than 2
        periods precede the first treated period, or treated units start
        treatment in different periods.

    Warns
    -----
    UserWarning
        If the Rust LOOCV reports that all, or more than 10%, of the fits
        failed, or if every tuning parameter combination failed and the
        defaults (1.0, 1.0, 0.1) are used.

    Notes
    -----
    Bootstrap variance estimation assumes simultaneous treatment adoption
    (fixed `treated_periods` across resamples). The treatment timing is
    inferred from the data once and held constant for all bootstrap
    iterations. For staggered adoption designs where treatment timing varies
    across units, use `method="twostep"` which computes observation-specific
    weights that naturally handle heterogeneous timing.
    """
    # Data setup (same as twostep method)
    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    n_units = len(all_units)
    n_periods = len(all_periods)

    # Positional index -> original label, used when reporting results.
    idx_to_unit = {i: u for i, u in enumerate(all_units)}
    idx_to_period = {i: p for i, p in enumerate(all_periods)}

    # Create matrices (rows = periods, columns = units; absent cells are NaN)
    Y = (
        data.pivot(index=time, columns=unit, values=outcome)
        .reindex(index=all_periods, columns=all_units)
        .values
    )

    D_raw = (
        data.pivot(index=time, columns=unit, values=treatment)
        .reindex(index=all_periods, columns=all_units)
    )
    # Remember which treatment cells were missing before they are filled
    # with 0, so the validation below can skip them.
    missing_mask = pd.isna(D_raw).values
    D = D_raw.fillna(0).astype(int).values

    # Validate absorbing state: among observed periods, D must never drop
    # from 1 back to 0 for any unit.
    violating_units = []
    for unit_idx in range(n_units):
        observed_mask = ~missing_mask[:, unit_idx]
        observed_d = D[observed_mask, unit_idx]
        if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
            violating_units.append(all_units[unit_idx])

    if violating_units:
        raise ValueError(
            f"Treatment indicator is not an absorbing state for units: {violating_units}. "
            f"D[t, unit] must be monotonic non-decreasing."
        )

    # Identify treated observations
    treated_mask = D == 1
    n_treated_obs = np.sum(treated_mask)

    if n_treated_obs == 0:
        raise ValueError("No treated observations found")

    # Identify treated and control units (control = never treated)
    unit_ever_treated = np.any(D == 1, axis=0)
    treated_unit_idx = np.where(unit_ever_treated)[0]
    control_unit_idx = np.where(~unit_ever_treated)[0]

    if len(control_unit_idx) == 0:
        raise ValueError("No control units found")

    # Determine pre/post periods: the first row containing any treated cell
    first_treat_period = None
    for t in range(n_periods):
        if np.any(D[t, :] == 1):
            first_treat_period = t
            break

    if first_treat_period is None:
        raise ValueError("Could not infer post-treatment periods from D matrix")

    n_pre_periods = first_treat_period
    # treated_periods spans every period from the first treated one to the
    # end of the panel; n_post_periods counts only those periods in that
    # span that actually contain a treated cell.
    treated_periods = n_periods - first_treat_period
    n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))

    if n_pre_periods < 2:
        raise ValueError("Need at least 2 pre-treatment periods")

    # Check for staggered adoption (joint method requires simultaneous treatment)
    # Use only observed periods (skip missing) to avoid false positives on unbalanced panels
    first_treat_by_unit = []
    for i in treated_unit_idx:
        observed_mask = ~missing_mask[:, i]
        # Get D values for observed periods only
        observed_d = D[observed_mask, i]
        observed_periods = np.where(observed_mask)[0]
        # Find first treatment among observed periods
        treated_idx = np.where(observed_d == 1)[0]
        if len(treated_idx) > 0:
            first_treat_by_unit.append(observed_periods[treated_idx[0]])

    unique_starts = sorted(set(first_treat_by_unit))
    if len(unique_starts) > 1:
        raise ValueError(
            f"method='joint' requires simultaneous treatment adoption, but your data "
            f"shows staggered adoption (units first treated at periods {unique_starts}). "
            f"Use method='twostep' which properly handles staggered adoption designs."
        )

    # LOOCV grid search for tuning parameters
    # Use Rust backend when available for parallel LOOCV (5-10x speedup)
    best_lambda = None
    best_score = np.inf
    # Missing treatment cells were filled with 0 above, so they are marked
    # as control here as well.
    control_mask = D == 0

    if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None:
        try:
            # Prepare inputs for Rust function
            control_mask_u8 = control_mask.astype(np.uint8)

            lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
            lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
            lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)

            result = _rust_loocv_grid_search_joint(
                Y, D.astype(np.float64), control_mask_u8,
                lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
                self.max_iter, self.tol,
            )
            # Unpack result - 7 values including optional first_failed_obs
            best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
            # Only accept finite scores - infinite means all fits failed.
            # Otherwise best_lambda stays None and the Python search below runs.
            if np.isfinite(best_score):
                best_lambda = (best_lt, best_lu, best_ln)
            # Emit warnings consistent with Python implementation
            if n_valid == 0:
                obs_info = ""
                if first_failed_obs is not None:
                    t_idx, i_idx = first_failed_obs
                    obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                warnings.warn(
                    f"LOOCV: All {n_attempted} fits failed for "
                    f"λ=({best_lt}, {best_lu}, {best_ln}). "
                    f"Returning infinite score.{obs_info}",
                    UserWarning
                )
            elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
                # More than 10% of the attempted fits failed.
                n_failed = n_attempted - n_valid
                obs_info = ""
                if first_failed_obs is not None:
                    t_idx, i_idx = first_failed_obs
                    obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                warnings.warn(
                    f"LOOCV: {n_failed}/{n_attempted} fits failed for "
                    f"λ=({best_lt}, {best_lu}, {best_ln}). "
                    f"This may indicate numerical instability.{obs_info}",
                    UserWarning
                )
        except Exception as e:
            # Fall back to Python implementation on error
            logger.debug(
                "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e
            )
            best_lambda = None
            best_score = np.inf

    # Fall back to Python implementation if Rust unavailable or failed
    if best_lambda is None:
        # Get control observations for LOOCV (cells with a non-NaN outcome)
        control_obs = [
            (t, i) for t in range(n_periods) for i in range(n_units)
            if control_mask[t, i] and not np.isnan(Y[t, i])
        ]

        # Grid search with true LOOCV
        for lambda_time_val in self.lambda_time_grid:
            for lambda_unit_val in self.lambda_unit_grid:
                for lambda_nn_val in self.lambda_nn_grid:
                    # Convert λ_nn=∞ → large finite value (factor model disabled)
                    lt = lambda_time_val
                    lu = lambda_unit_val
                    ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val

                    try:
                        score = self._loocv_score_joint(
                            Y, D, control_obs, lt, lu, ln,
                            treated_periods, n_units, n_periods
                        )

                        if score < best_score:
                            best_score = score
                            # Store the original grid values (λ_nn may be inf)
                            best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val)

                    except (np.linalg.LinAlgError, ValueError):
                        continue

    if best_lambda is None:
        warnings.warn(
            "All tuning parameter combinations failed. Using defaults.",
            UserWarning
        )
        best_lambda = (1.0, 1.0, 0.1)
        best_score = np.nan

    # Final estimation with best parameters
    lambda_time, lambda_unit, lambda_nn = best_lambda
    # Keep the selected λ_nn (possibly inf) for reporting in the results.
    original_lambda_nn = lambda_nn

    # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
    # λ_time and λ_unit use 0.0 for uniform weights directly (no conversion needed)
    if np.isinf(lambda_nn):
        lambda_nn = 1e10

    # Compute final weights and fit
    delta = self._compute_joint_weights(
        Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
    )

    if lambda_nn >= 1e10:
        mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta)
        L = np.zeros((n_periods, n_units))
    else:
        mu, alpha, beta, L, tau = self._solve_joint_with_lowrank(
            Y, D, delta, lambda_nn, self.max_iter, self.tol
        )

    # ATT is the scalar treatment effect
    att = tau

    # Compute individual treatment effects for reporting (same τ for all)
    treatment_effects = {}
    for t in range(n_periods):
        for i in range(n_units):
            if D[t, i] == 1:
                unit_id = idx_to_unit[i]
                time_id = idx_to_period[t]
                treatment_effects[(unit_id, time_id)] = tau

    # Compute effective rank of L: sum of singular values divided by the
    # largest one (0.0 when L has no positive singular value).
    _, s, _ = np.linalg.svd(L, full_matrices=False)
    if s[0] > 0:
        effective_rank = np.sum(s) / s[0]
    else:
        effective_rank = 0.0

    # Bootstrap variance estimation (uses the finite λ_nn, not inf)
    effective_lambda = (lambda_time, lambda_unit, lambda_nn)

    se, bootstrap_dist = self._bootstrap_variance_joint(
        data, outcome, treatment, unit, time,
        effective_lambda, treated_periods
    )

    # Compute test statistics: two-sided p-value from a Student t
    # distribution with max(1, n_treated_obs - 1) degrees of freedom.
    # A non-positive (or NaN) se yields NaN statistics.
    if se > 0:
        t_stat = att / se
        p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
        conf_int = compute_confidence_interval(att, se, self.alpha)
    else:
        t_stat = np.nan
        p_value = np.nan
        conf_int = (np.nan, np.nan)

    # Create results dictionaries keyed by the original unit/period labels
    unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)}
    time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)}

    self.results_ = TROPResults(
        att=float(att),
        se=float(se),
        t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat,
        p_value=float(p_value) if np.isfinite(p_value) else p_value,
        conf_int=conf_int,
        n_obs=len(data),
        n_treated=len(treated_unit_idx),
        n_control=len(control_unit_idx),
        n_treated_obs=int(n_treated_obs),
        unit_effects=unit_effects_dict,
        time_effects=time_effects_dict,
        treatment_effects=treatment_effects,
        lambda_time=lambda_time,
        lambda_unit=lambda_unit,
        lambda_nn=original_lambda_nn,
        factor_matrix=L,
        effective_rank=effective_rank,
        loocv_score=best_score,
        alpha=self.alpha,
        n_pre_periods=n_pre_periods,
        n_post_periods=n_post_periods,
        n_bootstrap=self.n_bootstrap,
        bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
    )

    self.is_fitted_ = True
    return self.results_
|
|
1514
|
+
|
|
1515
|
+
def _bootstrap_variance_joint(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    treated_periods: int,
) -> Tuple[float, np.ndarray]:
    """
    Compute bootstrap standard error for joint method.

    Uses Rust backend when available for parallel bootstrap (5-15x speedup).
    If the Rust call raises for any reason, the failure is logged at debug
    level and the pure-Python implementation below is used instead.

    The Python path draws a stratified unit bootstrap: never-treated and
    ever-treated units are resampled with replacement separately (keeping
    the size of each group fixed), every draw is given a fresh unit label,
    and the joint model is refitted with the tuning parameters held fixed.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome : str
        Outcome column name.
    treatment : str
        Treatment column name.
    unit : str
        Unit column name.
    time : str
        Time column name.
    optimal_lambda : tuple
        Optimal tuning parameters ``(lambda_time, lambda_unit, lambda_nn)``.
    treated_periods : int
        Number of post-treatment periods.

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). ``(0.0, empty array)`` is returned when
        no bootstrap replicate succeeds.

    Warns
    -----
    UserWarning
        If fewer than 10 bootstrap replicates succeed.
    """
    lambda_time, lambda_unit, lambda_nn = optimal_lambda

    # Try Rust backend for parallel bootstrap (5-15x speedup)
    if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None:
        try:
            # Create matrices for Rust function
            all_units = sorted(data[unit].unique())
            all_periods = sorted(data[time].unique())

            Y = (
                data.pivot(index=time, columns=unit, values=outcome)
                .reindex(index=all_periods, columns=all_units)
                .values
            )
            D = (
                data.pivot(index=time, columns=unit, values=treatment)
                .reindex(index=all_periods, columns=all_units)
                .fillna(0)
                .astype(np.float64)
                .values
            )

            bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint(
                Y, D,
                lambda_time, lambda_unit, lambda_nn,
                self.n_bootstrap, self.max_iter, self.tol,
                # A seed of None is passed to the Rust function as 0.
                self.seed if self.seed is not None else 0
            )

            if len(bootstrap_estimates) < 10:
                warnings.warn(
                    f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
                    UserWarning
                )
            if len(bootstrap_estimates) == 0:
                return 0.0, np.array([])

            return float(se), np.array(bootstrap_estimates)

        except Exception as e:
            # Deliberate best-effort: any Rust-side failure falls through
            # to the Python implementation below.
            logger.debug(
                "Rust bootstrap (joint) failed, falling back to Python: %s", e
            )

    # Python fallback implementation
    rng = np.random.default_rng(self.seed)

    # Stratified bootstrap sampling: split units by ever-treated status
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist())
    control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist())

    n_treated_units = len(treated_units)
    n_control_units = len(control_units)

    # Rows belonging to a given unit never change between replicates, so
    # each unit's rows are selected at most once and then reused. The
    # selection expression is the same one the per-draw code would use.
    unit_rows = {}

    bootstrap_estimates_list: List[float] = []

    for _ in range(self.n_bootstrap):
        # Stratified sampling (control stratum is drawn first so the RNG
        # stream is consumed in a fixed order)
        if n_control_units > 0:
            sampled_control = rng.choice(
                control_units, size=n_control_units, replace=True
            )
        else:
            sampled_control = np.array([], dtype=object)

        if n_treated_units > 0:
            sampled_treated = rng.choice(
                treated_units, size=n_treated_units, replace=True
            )
        else:
            sampled_treated = np.array([], dtype=object)

        sampled_units = np.concatenate([sampled_control, sampled_treated])

        # Create bootstrap sample; each draw gets a unique label so a unit
        # sampled twice appears as two distinct units in the refit.
        frames = []
        for idx, u in enumerate(sampled_units):
            rows = unit_rows.get(u)
            if rows is None:
                rows = data[data[unit] == u]
                unit_rows[u] = rows
            frames.append(rows.assign(**{unit: f"{u}_{idx}"}))
        boot_data = pd.concat(frames, ignore_index=True)

        try:
            tau = self._fit_joint_with_fixed_lambda(
                boot_data, outcome, treatment, unit, time,
                optimal_lambda, treated_periods
            )
            bootstrap_estimates_list.append(tau)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            # A failed replicate is dropped; the count of survivors is
            # checked after the loop.
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
            UserWarning
        )
    if len(bootstrap_estimates) == 0:
        return 0.0, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
|
|
1654
|
+
|
|
1655
|
+
def _fit_joint_with_fixed_lambda(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    fixed_lambda: Tuple[float, float, float],
    treated_periods: int,
) -> float:
    """
    Refit the joint model on ``data`` without re-tuning and return τ.

    The long-format panel is reshaped into period-by-unit matrices, joint
    weights are computed from the supplied λ_time / λ_unit, and the model is
    solved either without the low-rank component (λ_nn >= 1e10) or with it.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    outcome, treatment, unit, time : str
        Column names for the outcome, treatment indicator, unit id and
        time period.
    fixed_lambda : tuple
        ``(lambda_time, lambda_unit, lambda_nn)`` to use as-is.
    treated_periods : int
        Number of post-treatment periods, forwarded to the weight
        computation.

    Returns
    -------
    float
        The treatment effect τ (last element returned by the solver).
    """
    lambda_time, lambda_unit, lambda_nn = fixed_lambda

    unit_ids = sorted(data[unit].unique())
    period_ids = sorted(data[time].unique())

    def _wide(values_col):
        # Periods on rows, units on columns, both in sorted order.
        table = data.pivot(index=time, columns=unit, values=values_col)
        return table.reindex(index=period_ids, columns=unit_ids)

    Y = _wide(outcome).values
    # Cells with no treatment record are treated as untreated (0).
    D = _wide(treatment).fillna(0).astype(int).values

    # Compute weights
    delta = self._compute_joint_weights(
        Y, D, lambda_time, lambda_unit, treated_periods,
        len(unit_ids), len(period_ids)
    )

    # Fit model; a λ_nn of 1e10 or more means the factor model is disabled.
    if lambda_nn >= 1e10:
        _mu, _alpha, _beta, tau = self._solve_joint_no_lowrank(Y, D, delta)
    else:
        _mu, _alpha, _beta, _L, tau = self._solve_joint_with_lowrank(
            Y, D, delta, lambda_nn, self.max_iter, self.tol
        )

    return tau
|
|
1705
|
+
|
|
1706
|
+
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
) -> TROPResults:
    """
    Fit the TROP model.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with observations for multiple units over multiple
        time periods.
    outcome : str
        Name of the outcome variable column.
    treatment : str
        Name of the treatment indicator column (0/1).

        IMPORTANT: This should be an ABSORBING STATE indicator, not a
        treatment timing indicator. For each unit, D=1 for ALL periods
        during and after treatment:

        - D[t, i] = 0 for all t < g_i (pre-treatment periods)
        - D[t, i] = 1 for all t >= g_i (treatment and post-treatment)

        where g_i is the treatment start time for unit i.

        For staggered adoption, different units can have different g_i.
        The ATT averages over ALL D=1 cells per Equation 1 of the paper.
    unit : str
        Name of the unit identifier column.
    time : str
        Name of the time period column.

    Returns
    -------
    TROPResults
        Object containing the ATT estimate, standard error,
        factor estimates, and tuning parameters. The lambda_*
        attributes show the selected grid values. For λ_time and
        λ_unit, 0.0 means uniform weights; inf is not accepted.
        For λ_nn, ∞ is converted to 1e10 (factor model disabled).

    Raises
    ------
    ValueError
        If required columns are missing, the treatment indicator is not
        an absorbing state, there are no treated observations, there are
        no never-treated units, or fewer than 2 periods precede the first
        treated period.
    """
    # Validate inputs
    required_cols = [outcome, treatment, unit, time]
    missing = [c for c in required_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    # Dispatch based on estimation method
    if self.method == "joint":
        return self._fit_joint(data, outcome, treatment, unit, time)

    # Below is the twostep method (default)
    # Get unique units and periods
    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    n_units = len(all_units)
    n_periods = len(all_periods)

    # Matrix index -> original label mappings
    idx_to_unit = dict(enumerate(all_units))
    idx_to_period = dict(enumerate(all_periods))

    # Create outcome matrix Y (n_periods x n_units) and treatment matrix D
    Y = (
        data.pivot(index=time, columns=unit, values=outcome)
        .reindex(index=all_periods, columns=all_units)
        .values
    )

    # For D matrix, track missing values BEFORE fillna to support unbalanced
    # panels: missing observations should not trigger spurious violations.
    D_raw = (
        data.pivot(index=time, columns=unit, values=treatment)
        .reindex(index=all_periods, columns=all_units)
    )
    missing_mask = pd.isna(D_raw).values  # True where originally missing
    D = D_raw.fillna(0).astype(int).values

    # Validate D is monotonic non-decreasing per unit (absorbing state).
    # Each unit's OBSERVED D sequence is checked, which also catches 1→0
    # violations that span missing-period gaps (e.g. D[2]=1, missing [3,4],
    # D[5]=0 is a real violation even though no adjacent pair shows it).
    violating_units = []
    for unit_idx in range(n_units):
        observed_mask = ~missing_mask[:, unit_idx]
        observed_d = D[observed_mask, unit_idx]

        if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
            violating_units.append(all_units[unit_idx])

    if violating_units:
        raise ValueError(
            f"Treatment indicator is not an absorbing state for units: {violating_units}. "
            f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
            f"If this is event-study style data, convert to absorbing state: "
            f"D[t, i] = 1 for all t >= first treatment period."
        )

    # Identify treated observations
    treated_mask = D == 1
    n_treated_obs = int(np.sum(treated_mask))

    if n_treated_obs == 0:
        raise ValueError("No treated observations found")

    # Identify treated and control units
    unit_ever_treated = np.any(treated_mask, axis=0)
    treated_unit_idx = np.where(unit_ever_treated)[0]
    control_unit_idx = np.where(~unit_ever_treated)[0]

    if len(control_unit_idx) == 0:
        raise ValueError("No control units found")

    # Determine pre/post periods from treatment indicator D
    # (D matrix is the sole input for treatment timing per the paper)
    period_has_treatment = np.any(treated_mask, axis=1)
    if not period_has_treatment.any():
        raise ValueError("Could not infer post-treatment periods from D matrix")
    # argmax on a boolean vector returns the first True position
    first_treat_period = int(np.argmax(period_has_treatment))

    n_pre_periods = first_treat_period
    # Count periods where D=1 is actually observed
    n_post_periods = int(np.count_nonzero(period_has_treatment[first_treat_period:]))

    if n_pre_periods < 2:
        raise ValueError("Need at least 2 pre-treatment periods")

    # Step 1: Grid search with LOOCV for tuning parameters
    best_lambda = None
    best_score = np.inf

    # Control observations mask (for LOOCV)
    control_mask = D == 0

    # Pre-compute structures that are reused across LOOCV iterations
    self._precomputed = self._precompute_structures(
        Y, D, control_unit_idx, n_units, n_periods
    )

    # Use Rust backend for parallel LOOCV grid search (10-50x speedup)
    if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None:
        try:
            # Prepare inputs for Rust function
            control_mask_u8 = control_mask.astype(np.uint8)
            time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)

            lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
            lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
            lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)

            result = _rust_loocv_grid_search(
                Y, D.astype(np.float64), control_mask_u8,
                time_dist_matrix,
                lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
                self.max_iter, self.tol,
            )
            # Unpack result - 7 values including optional first_failed_obs
            best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
            # Only accept finite scores - infinite means all fits failed
            if np.isfinite(best_score):
                best_lambda = (best_lt, best_lu, best_ln)
            # else: best_lambda stays None, triggering defaults fallback

            # Emit warnings consistent with Python implementation
            n_failed = n_attempted - n_valid
            all_failed = n_valid == 0
            many_failed = n_attempted > 0 and n_failed > 0.1 * n_attempted
            if all_failed or many_failed:
                # Include failed observation coordinates if available
                obs_info = ""
                if first_failed_obs is not None:
                    t_idx, i_idx = first_failed_obs
                    obs_info = f" First failure at observation ({t_idx}, {i_idx})."
                if all_failed:
                    warnings.warn(
                        f"LOOCV: All {n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"Returning infinite score.{obs_info}",
                        UserWarning
                    )
                else:
                    warnings.warn(
                        f"LOOCV: {n_failed}/{n_attempted} fits failed for "
                        f"λ=({best_lt}, {best_lu}, {best_ln}). "
                        f"This may indicate numerical instability.{obs_info}",
                        UserWarning
                    )
        except Exception as e:
            # Fall back to Python implementation on error
            logger.debug(
                "Rust LOOCV grid search failed, falling back to Python: %s", e
            )
            best_lambda = None
            best_score = np.inf

    # Fall back to Python implementation if Rust unavailable or failed
    # Uses two-stage approach per paper's footnote 2:
    # Stage 1: Univariate searches for initial values
    # Stage 2: Cycling (coordinate descent) until convergence
    if best_lambda is None:
        # Stage 1: Univariate searches with extreme fixed values

        # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
        lambda_time_init, _ = self._univariate_loocv_search(
            Y, D, control_mask, control_unit_idx, n_units, n_periods,
            'lambda_time', self.lambda_time_grid,
            {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF}
        )

        # λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0
        lambda_nn_init, _ = self._univariate_loocv_search(
            Y, D, control_mask, control_unit_idx, n_units, n_periods,
            'lambda_nn', self.lambda_nn_grid,
            {'lambda_time': 0.0, 'lambda_unit': 0.0}
        )

        # λ_unit search: fix λ_nn=∞, λ_time=0
        lambda_unit_init, _ = self._univariate_loocv_search(
            Y, D, control_mask, control_unit_idx, n_units, n_periods,
            'lambda_unit', self.lambda_unit_grid,
            {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0}
        )

        # Stage 2: Cycling refinement (coordinate descent)
        lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
            Y, D, control_mask, control_unit_idx, n_units, n_periods,
            (lambda_time_init, lambda_unit_init, lambda_nn_init)
        )

        # Compute final score for the optimized parameters
        try:
            best_score = self._loocv_score_obs_specific(
                Y, D, control_mask, control_unit_idx,
                lambda_time, lambda_unit, lambda_nn,
                n_units, n_periods
            )
            # Only accept finite scores - infinite means all fits failed
            if np.isfinite(best_score):
                best_lambda = (lambda_time, lambda_unit, lambda_nn)
            # else: best_lambda stays None, triggering defaults fallback
        except (np.linalg.LinAlgError, ValueError):
            # If even the optimized parameters fail, best_lambda stays None
            pass

    if best_lambda is None:
        warnings.warn(
            "All tuning parameter combinations failed. Using defaults.",
            UserWarning
        )
        best_lambda = (1.0, 1.0, 0.1)
        best_score = np.nan

    self._optimal_lambda = best_lambda
    lambda_time, lambda_unit, lambda_nn = best_lambda

    # Store original λ_nn for results (only λ_nn needs original→effective conversion).
    # λ_time and λ_unit use 0.0 for uniform weights directly per Eq. 3.
    original_lambda_nn = lambda_nn

    # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
    if np.isinf(lambda_nn):
        lambda_nn = 1e10

    # effective_lambda with converted λ_nn for ALL downstream computation
    # (variance estimation uses the same parameters as point estimation)
    effective_lambda = (lambda_time, lambda_unit, lambda_nn)

    # Step 2: Final estimation - per-observation model fitting following Algorithm 2
    # For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it}
    treatment_effects = {}
    tau_values = []
    alpha_estimates = []
    beta_estimates = []
    L_estimates = []

    # Use pre-computed treated observations
    treated_observations = self._precomputed["treated_observations"]

    for t, i in treated_observations:
        # Compute observation-specific weights for this (i, t)
        weight_matrix = self._compute_observation_weights(
            Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
            n_units, n_periods
        )

        # Fit model with these weights
        alpha_hat, beta_hat, L_hat = self._estimate_model(
            Y, control_mask, weight_matrix, lambda_nn,
            n_units, n_periods
        )

        # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
        tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]

        treatment_effects[(idx_to_unit[i], idx_to_period[t])] = tau_it
        tau_values.append(tau_it)

        # Store for averaging
        alpha_estimates.append(alpha_hat)
        beta_estimates.append(beta_hat)
        L_estimates.append(L_hat)

    # Average ATT
    att = np.mean(tau_values)

    # Average parameter estimates for output (representative)
    alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
    beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
    L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))

    # Compute effective rank: sum of singular values over the largest one
    _, s, _ = np.linalg.svd(L_hat, full_matrices=False)
    if s[0] > 0:
        effective_rank = float(np.sum(s) / s[0])
    else:
        effective_rank = 0.0

    # Step 4: Variance estimation
    # Use effective_lambda (converted values) to ensure SE is computed with same
    # parameters as point estimation.
    se, bootstrap_dist = self._bootstrap_variance(
        data, outcome, treatment, unit, time,
        effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
    )

    # Compute test statistics
    if se > 0:
        t_stat = att / se
        # sf() evaluates the upper tail directly; 1 - cdf() cancels to
        # exactly 0.0 once the tail probability drops below float epsilon.
        p_value = 2 * stats.t.sf(abs(t_stat), df=max(1, n_treated_obs - 1))
        conf_int = compute_confidence_interval(att, se, self.alpha)
    else:
        # When SE is undefined/zero, ALL inference fields should be NaN
        t_stat = np.nan
        p_value = np.nan
        conf_int = (np.nan, np.nan)

    # Create results dictionaries
    unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
    time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}

    # Store results (scalars cast to built-in types, as the joint path does)
    self.results_ = TROPResults(
        att=float(att),
        se=float(se),
        t_stat=float(t_stat),
        p_value=float(p_value),
        conf_int=conf_int,
        n_obs=len(data),
        n_treated=len(treated_unit_idx),
        n_control=len(control_unit_idx),
        n_treated_obs=n_treated_obs,
        unit_effects=unit_effects_dict,
        time_effects=time_effects_dict,
        treatment_effects=treatment_effects,
        lambda_time=lambda_time,
        lambda_unit=lambda_unit,
        lambda_nn=original_lambda_nn,
        factor_matrix=L_hat,
        effective_rank=effective_rank,
        loocv_score=best_score,
        alpha=self.alpha,
        n_pre_periods=n_pre_periods,
        n_post_periods=n_post_periods,
        n_bootstrap=self.n_bootstrap,
        bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
    )

    self.is_fitted_ = True
    return self.results_
|
|
2096
|
+
|
|
2097
|
+
def _compute_observation_weights(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    i: int,
    t: int,
    lambda_time: float,
    lambda_unit: float,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
) -> np.ndarray:
    """
    Build the weight matrix used to fit the model for treated cell (i, t).

    The result is the outer product of two vectors:

    - time weights ``exp(-lambda_time * |t - s|)`` for every period s;
    - unit weights, which are 1 for the target unit ``i``, 0 for any other
      unit with D[t, j] != 0, and for the remaining units either 1
      (``lambda_unit == 0``) or ``exp(-lambda_unit * dist)``, where ``dist``
      comes from ``_compute_unit_distance_for_obs(..., j, i, t)``. An
      infinite distance yields weight 0.

    Units are selected by D[t, j] == 0 at the target period, so
    not-yet-treated units take part alongside never-treated ones.

    When ``self._precomputed`` is set, its stored ``Y``, ``D`` and
    ``time_dist_matrix`` are used and the ``Y`` / ``D`` arguments are
    ignored; otherwise everything is derived from the arguments.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    i : int
        Treated unit index.
    t : int
        Treatment period index.
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    control_unit_idx : np.ndarray
        Indices of never-treated units. Accepted for signature
        compatibility; not read by this method.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    np.ndarray
        Weight matrix (n_periods x n_units) for observation (i, t).
    """
    cache = self._precomputed
    if cache is not None:
        # Row t of the stored matrix holds |t - s| for every period s.
        period_gaps = cache["time_dist_matrix"][t, :]
        treat = cache["D"]
        outcomes = cache["Y"]
    else:
        # No cache: derive the period gaps and use the matrices passed in.
        period_gaps = np.abs(np.arange(n_periods) - t)
        treat = D
        outcomes = Y

    time_weights = np.exp(-lambda_time * period_gaps)

    # Units that are untreated at the target period.
    untreated_now = treat[t, :] == 0

    unit_weights = np.zeros(n_units)
    if lambda_unit == 0:
        # No decay: every untreated-at-t unit counts equally.
        unit_weights[untreated_now] = 1.0
    else:
        for j in range(n_units):
            if j == i or not untreated_now[j]:
                continue
            # Distance between units j and i for target period t.
            dist = self._compute_unit_distance_for_obs(outcomes, treat, j, i, t)
            unit_weights[j] = 0.0 if np.isinf(dist) else np.exp(-lambda_unit * dist)

    # The target unit itself always gets full weight.
    unit_weights[i] = 1.0

    return np.outer(time_weights, unit_weights)
|
|
2221
|
+
|
|
2222
|
+
def _soft_threshold_svd(
    self,
    M: np.ndarray,
    threshold: float,
) -> np.ndarray:
    """
    Shrink the singular values of ``M`` by ``threshold`` (nuclear-norm prox).

    Each singular value σ is replaced by ``max(σ - threshold, 0)`` and the
    matrix is rebuilt from the components whose shrunken value exceeds
    ``self.CONVERGENCE_TOL_SVD``.

    Parameters
    ----------
    M : np.ndarray
        Input matrix.
    threshold : float
        Soft-thresholding parameter. If ``threshold <= 0`` the input is
        returned unchanged (same object, no sanitising).

    Returns
    -------
    np.ndarray
        Matrix with soft-thresholded singular values. A zero matrix shaped
        like ``M`` is returned when the SVD fails, when the SVD output
        contains non-finite values, or when no singular value survives
        the threshold.
    """
    if threshold <= 0:
        return M

    # Non-finite entries are zeroed before decomposing.
    if not np.isfinite(M).all():
        M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)

    try:
        left, sing, right_t = np.linalg.svd(M, full_matrices=False)
    except np.linalg.LinAlgError:
        # SVD failed, return zero matrix
        return np.zeros_like(M)

    # SVD produced non-finite values, return zero matrix
    if not all(np.isfinite(part).all() for part in (left, sing, right_t)):
        return np.zeros_like(M)

    shrunk = np.maximum(sing - threshold, 0)

    # Only components that remain meaningfully positive are kept.
    keep = shrunk > self.CONVERGENCE_TOL_SVD
    if not keep.any():
        return np.zeros_like(M)

    # Suppress expected floating-point warnings from ill-conditioned
    # matrices during the reconstruction product.
    with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
        low_rank = (left[:, keep] * shrunk[keep]) @ right_t[keep, :]

    if np.isfinite(low_rank).all():
        return low_rank
    # Replace any NaN/Inf in result with zeros
    return np.nan_to_num(low_rank, nan=0.0, posinf=0.0, neginf=0.0)
|
|
2282
|
+
|
|
2283
|
+
def _weighted_nuclear_norm_solve(
|
|
2284
|
+
self,
|
|
2285
|
+
Y: np.ndarray,
|
|
2286
|
+
W: np.ndarray,
|
|
2287
|
+
L_init: np.ndarray,
|
|
2288
|
+
alpha: np.ndarray,
|
|
2289
|
+
beta: np.ndarray,
|
|
2290
|
+
lambda_nn: float,
|
|
2291
|
+
max_inner_iter: int = 20,
|
|
2292
|
+
) -> np.ndarray:
|
|
2293
|
+
"""
|
|
2294
|
+
Solve weighted nuclear norm problem using iterative weighted soft-impute.
|
|
2295
|
+
|
|
2296
|
+
Issue C fix: Implements the weighted nuclear norm optimization from the
|
|
2297
|
+
paper's Equation 2 (page 7). The full objective is:
|
|
2298
|
+
min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
|
|
2299
|
+
|
|
2300
|
+
This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
|
|
2301
|
+
L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
|
|
2302
|
+
|
|
2303
|
+
where W ⊙ denotes element-wise multiplication with normalized weights.
|
|
2304
|
+
|
|
2305
|
+
IMPORTANT: For observations with W=0 (treated observations), we keep
|
|
2306
|
+
L values from the previous iteration rather than setting L = R, which
|
|
2307
|
+
would absorb the treatment effect.
|
|
2308
|
+
|
|
2309
|
+
Parameters
|
|
2310
|
+
----------
|
|
2311
|
+
Y : np.ndarray
|
|
2312
|
+
Outcome matrix (n_periods x n_units).
|
|
2313
|
+
W : np.ndarray
|
|
2314
|
+
Weight matrix (n_periods x n_units), non-negative. W=0 indicates
|
|
2315
|
+
observations that should not be used for fitting (treated obs).
|
|
2316
|
+
L_init : np.ndarray
|
|
2317
|
+
Initial estimate of L matrix.
|
|
2318
|
+
alpha : np.ndarray
|
|
2319
|
+
Current unit fixed effects estimate.
|
|
2320
|
+
beta : np.ndarray
|
|
2321
|
+
Current time fixed effects estimate.
|
|
2322
|
+
lambda_nn : float
|
|
2323
|
+
Nuclear norm regularization parameter.
|
|
2324
|
+
max_inner_iter : int, default=20
|
|
2325
|
+
Maximum inner iterations for the proximal algorithm.
|
|
2326
|
+
|
|
2327
|
+
Returns
|
|
2328
|
+
-------
|
|
2329
|
+
np.ndarray
|
|
2330
|
+
Updated L matrix estimate.
|
|
2331
|
+
"""
|
|
2332
|
+
# Compute target residual R = Y - α - β
|
|
2333
|
+
R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
|
|
2334
|
+
|
|
2335
|
+
# Handle invalid values
|
|
2336
|
+
R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)
|
|
2337
|
+
|
|
2338
|
+
# For observations with W=0 (treated obs), keep L_init instead of R
|
|
2339
|
+
# This prevents L from absorbing the treatment effect
|
|
2340
|
+
valid_obs_mask = W > 0
|
|
2341
|
+
R_masked = np.where(valid_obs_mask, R, L_init)
|
|
2342
|
+
|
|
2343
|
+
if lambda_nn <= 0:
|
|
2344
|
+
# No regularization - just return masked residual
|
|
2345
|
+
# Use soft-thresholding with threshold=0 which returns the input
|
|
2346
|
+
return R_masked
|
|
2347
|
+
|
|
2348
|
+
# Normalize weights so max is 1 (for step size stability)
|
|
2349
|
+
W_max = np.max(W)
|
|
2350
|
+
if W_max > 0:
|
|
2351
|
+
W_norm = W / W_max
|
|
2352
|
+
else:
|
|
2353
|
+
W_norm = W
|
|
2354
|
+
|
|
2355
|
+
# Initialize L
|
|
2356
|
+
L = L_init.copy()
|
|
2357
|
+
|
|
2358
|
+
# Proximal gradient iteration with weighted soft-impute
|
|
2359
|
+
# This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
|
|
2360
|
+
# Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
|
|
2361
|
+
# where η is the step size (we use η = 1 with normalized weights)
|
|
2362
|
+
for _ in range(max_inner_iter):
|
|
2363
|
+
L_old = L.copy()
|
|
2364
|
+
|
|
2365
|
+
# Gradient step: L_k + W ⊙ (R - L_k)
|
|
2366
|
+
# For W=0 observations, this keeps L_k unchanged
|
|
2367
|
+
gradient_step = L + W_norm * (R_masked - L)
|
|
2368
|
+
|
|
2369
|
+
# Proximal step: soft-threshold singular values
|
|
2370
|
+
L = self._soft_threshold_svd(gradient_step, lambda_nn)
|
|
2371
|
+
|
|
2372
|
+
# Check convergence
|
|
2373
|
+
if np.max(np.abs(L - L_old)) < self.tol:
|
|
2374
|
+
break
|
|
2375
|
+
|
|
2376
|
+
return L
|
|
2377
|
+
|
|
2378
|
+
def _estimate_model(
|
|
2379
|
+
self,
|
|
2380
|
+
Y: np.ndarray,
|
|
2381
|
+
control_mask: np.ndarray,
|
|
2382
|
+
weight_matrix: np.ndarray,
|
|
2383
|
+
lambda_nn: float,
|
|
2384
|
+
n_units: int,
|
|
2385
|
+
n_periods: int,
|
|
2386
|
+
exclude_obs: Optional[Tuple[int, int]] = None,
|
|
2387
|
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
2388
|
+
"""
|
|
2389
|
+
Estimate the model: Y = α + β + L + τD + ε with nuclear norm penalty on L.
|
|
2390
|
+
|
|
2391
|
+
Uses alternating minimization with vectorized operations:
|
|
2392
|
+
1. Fix L, solve for α, β via weighted means
|
|
2393
|
+
2. Fix α, β, solve for L via soft-thresholding
|
|
2394
|
+
|
|
2395
|
+
Parameters
|
|
2396
|
+
----------
|
|
2397
|
+
Y : np.ndarray
|
|
2398
|
+
Outcome matrix (n_periods x n_units).
|
|
2399
|
+
control_mask : np.ndarray
|
|
2400
|
+
Boolean mask for control observations.
|
|
2401
|
+
weight_matrix : np.ndarray
|
|
2402
|
+
Pre-computed global weight matrix (n_periods x n_units).
|
|
2403
|
+
lambda_nn : float
|
|
2404
|
+
Nuclear norm regularization parameter.
|
|
2405
|
+
n_units : int
|
|
2406
|
+
Number of units.
|
|
2407
|
+
n_periods : int
|
|
2408
|
+
Number of periods.
|
|
2409
|
+
exclude_obs : tuple, optional
|
|
2410
|
+
(t, i) observation to exclude (for LOOCV).
|
|
2411
|
+
|
|
2412
|
+
Returns
|
|
2413
|
+
-------
|
|
2414
|
+
tuple
|
|
2415
|
+
(alpha, beta, L) estimated parameters.
|
|
2416
|
+
"""
|
|
2417
|
+
W = weight_matrix
|
|
2418
|
+
|
|
2419
|
+
# Mask for estimation (control obs only, excluding LOOCV obs if specified)
|
|
2420
|
+
est_mask = control_mask.copy()
|
|
2421
|
+
if exclude_obs is not None:
|
|
2422
|
+
t_ex, i_ex = exclude_obs
|
|
2423
|
+
est_mask[t_ex, i_ex] = False
|
|
2424
|
+
|
|
2425
|
+
# Handle missing values
|
|
2426
|
+
valid_mask = ~np.isnan(Y) & est_mask
|
|
2427
|
+
|
|
2428
|
+
# Initialize
|
|
2429
|
+
alpha = np.zeros(n_units)
|
|
2430
|
+
beta = np.zeros(n_periods)
|
|
2431
|
+
L = np.zeros((n_periods, n_units))
|
|
2432
|
+
|
|
2433
|
+
# Pre-compute masked weights for vectorized operations
|
|
2434
|
+
# Set weights to 0 where not valid
|
|
2435
|
+
W_masked = W * valid_mask
|
|
2436
|
+
|
|
2437
|
+
# Pre-compute weight sums per unit and per time (for denominator)
|
|
2438
|
+
# shape: (n_units,) and (n_periods,)
|
|
2439
|
+
weight_sum_per_unit = np.sum(W_masked, axis=0) # sum over periods
|
|
2440
|
+
weight_sum_per_time = np.sum(W_masked, axis=1) # sum over units
|
|
2441
|
+
|
|
2442
|
+
# Handle units/periods with zero weight sum
|
|
2443
|
+
unit_has_obs = weight_sum_per_unit > 0
|
|
2444
|
+
time_has_obs = weight_sum_per_time > 0
|
|
2445
|
+
|
|
2446
|
+
# Create safe denominators (avoid division by zero)
|
|
2447
|
+
safe_unit_denom = np.where(unit_has_obs, weight_sum_per_unit, 1.0)
|
|
2448
|
+
safe_time_denom = np.where(time_has_obs, weight_sum_per_time, 1.0)
|
|
2449
|
+
|
|
2450
|
+
# Replace NaN in Y with 0 for computation (mask handles exclusion)
|
|
2451
|
+
Y_safe = np.where(np.isnan(Y), 0.0, Y)
|
|
2452
|
+
|
|
2453
|
+
# Alternating minimization following Algorithm 1 (page 9)
|
|
2454
|
+
# Minimize: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
|
|
2455
|
+
for _ in range(self.max_iter):
|
|
2456
|
+
alpha_old = alpha.copy()
|
|
2457
|
+
beta_old = beta.copy()
|
|
2458
|
+
L_old = L.copy()
|
|
2459
|
+
|
|
2460
|
+
# Step 1: Update α and β (weighted least squares)
|
|
2461
|
+
# Following Equation 2 (page 7), fix L and solve for α, β
|
|
2462
|
+
# R = Y - L (residual without fixed effects)
|
|
2463
|
+
R = Y_safe - L
|
|
2464
|
+
|
|
2465
|
+
# Alpha update (unit fixed effects):
|
|
2466
|
+
# α_i = argmin_α Σ_t W_{ti}(R_{ti} - α - β_t)²
|
|
2467
|
+
# Solution: α_i = Σ_t W_{ti}(R_{ti} - β_t) / Σ_t W_{ti}
|
|
2468
|
+
R_minus_beta = R - beta[:, np.newaxis] # (n_periods, n_units)
|
|
2469
|
+
weighted_R_minus_beta = W_masked * R_minus_beta
|
|
2470
|
+
alpha_numerator = np.sum(weighted_R_minus_beta, axis=0) # (n_units,)
|
|
2471
|
+
alpha = np.where(unit_has_obs, alpha_numerator / safe_unit_denom, 0.0)
|
|
2472
|
+
|
|
2473
|
+
# Beta update (time fixed effects):
|
|
2474
|
+
# β_t = argmin_β Σ_i W_{ti}(R_{ti} - α_i - β)²
|
|
2475
|
+
# Solution: β_t = Σ_i W_{ti}(R_{ti} - α_i) / Σ_i W_{ti}
|
|
2476
|
+
R_minus_alpha = R - alpha[np.newaxis, :] # (n_periods, n_units)
|
|
2477
|
+
weighted_R_minus_alpha = W_masked * R_minus_alpha
|
|
2478
|
+
beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,)
|
|
2479
|
+
beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
|
|
2480
|
+
|
|
2481
|
+
# Step 2: Update L with weighted nuclear norm penalty
|
|
2482
|
+
# Issue C fix: Use weighted soft-impute to properly account for
|
|
2483
|
+
# observation weights in the nuclear norm optimization.
|
|
2484
|
+
# Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
|
|
2485
|
+
L = self._weighted_nuclear_norm_solve(
|
|
2486
|
+
Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10
|
|
2487
|
+
)
|
|
2488
|
+
|
|
2489
|
+
# Check convergence
|
|
2490
|
+
alpha_diff = np.max(np.abs(alpha - alpha_old))
|
|
2491
|
+
beta_diff = np.max(np.abs(beta - beta_old))
|
|
2492
|
+
L_diff = np.max(np.abs(L - L_old))
|
|
2493
|
+
|
|
2494
|
+
if max(alpha_diff, beta_diff, L_diff) < self.tol:
|
|
2495
|
+
break
|
|
2496
|
+
|
|
2497
|
+
return alpha, beta, L
|
|
2498
|
+
|
|
2499
|
+
def _loocv_score_obs_specific(
|
|
2500
|
+
self,
|
|
2501
|
+
Y: np.ndarray,
|
|
2502
|
+
D: np.ndarray,
|
|
2503
|
+
control_mask: np.ndarray,
|
|
2504
|
+
control_unit_idx: np.ndarray,
|
|
2505
|
+
lambda_time: float,
|
|
2506
|
+
lambda_unit: float,
|
|
2507
|
+
lambda_nn: float,
|
|
2508
|
+
n_units: int,
|
|
2509
|
+
n_periods: int,
|
|
2510
|
+
) -> float:
|
|
2511
|
+
"""
|
|
2512
|
+
Compute leave-one-out cross-validation score with observation-specific weights.
|
|
2513
|
+
|
|
2514
|
+
Following the paper's Equation 5 (page 8):
|
|
2515
|
+
Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
|
|
2516
|
+
|
|
2517
|
+
For each control observation (j, s), treat it as pseudo-treated,
|
|
2518
|
+
compute observation-specific weights, fit model excluding (j, s),
|
|
2519
|
+
and sum squared pseudo-treatment effects.
|
|
2520
|
+
|
|
2521
|
+
Uses pre-computed structures when available for efficiency.
|
|
2522
|
+
|
|
2523
|
+
Parameters
|
|
2524
|
+
----------
|
|
2525
|
+
Y : np.ndarray
|
|
2526
|
+
Outcome matrix (n_periods x n_units).
|
|
2527
|
+
D : np.ndarray
|
|
2528
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
2529
|
+
control_mask : np.ndarray
|
|
2530
|
+
Boolean mask for control observations.
|
|
2531
|
+
control_unit_idx : np.ndarray
|
|
2532
|
+
Indices of control units.
|
|
2533
|
+
lambda_time : float
|
|
2534
|
+
Time weight decay parameter.
|
|
2535
|
+
lambda_unit : float
|
|
2536
|
+
Unit weight decay parameter.
|
|
2537
|
+
lambda_nn : float
|
|
2538
|
+
Nuclear norm regularization parameter.
|
|
2539
|
+
n_units : int
|
|
2540
|
+
Number of units.
|
|
2541
|
+
n_periods : int
|
|
2542
|
+
Number of periods.
|
|
2543
|
+
|
|
2544
|
+
Returns
|
|
2545
|
+
-------
|
|
2546
|
+
float
|
|
2547
|
+
LOOCV score (lower is better).
|
|
2548
|
+
"""
|
|
2549
|
+
# Use pre-computed control observations if available
|
|
2550
|
+
if self._precomputed is not None:
|
|
2551
|
+
control_obs = self._precomputed["control_obs"]
|
|
2552
|
+
else:
|
|
2553
|
+
# Get all control observations
|
|
2554
|
+
control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
|
|
2555
|
+
if control_mask[t, i] and not np.isnan(Y[t, i])]
|
|
2556
|
+
|
|
2557
|
+
# Empty control set check: if no control observations, return infinity
|
|
2558
|
+
# A score of 0.0 would incorrectly "win" over legitimate parameters
|
|
2559
|
+
if len(control_obs) == 0:
|
|
2560
|
+
warnings.warn(
|
|
2561
|
+
f"LOOCV: No valid control observations for "
|
|
2562
|
+
f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
|
|
2563
|
+
"Returning infinite score.",
|
|
2564
|
+
UserWarning
|
|
2565
|
+
)
|
|
2566
|
+
return np.inf
|
|
2567
|
+
|
|
2568
|
+
tau_squared_sum = 0.0
|
|
2569
|
+
n_valid = 0
|
|
2570
|
+
|
|
2571
|
+
for t, i in control_obs:
|
|
2572
|
+
try:
|
|
2573
|
+
# Compute observation-specific weights for pseudo-treated (i, t)
|
|
2574
|
+
# Uses pre-computed distance matrices when available
|
|
2575
|
+
weight_matrix = self._compute_observation_weights(
|
|
2576
|
+
Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
|
|
2577
|
+
n_units, n_periods
|
|
2578
|
+
)
|
|
2579
|
+
|
|
2580
|
+
# Estimate model excluding observation (t, i)
|
|
2581
|
+
alpha, beta, L = self._estimate_model(
|
|
2582
|
+
Y, control_mask, weight_matrix, lambda_nn,
|
|
2583
|
+
n_units, n_periods, exclude_obs=(t, i)
|
|
2584
|
+
)
|
|
2585
|
+
|
|
2586
|
+
# Pseudo treatment effect
|
|
2587
|
+
tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
|
|
2588
|
+
tau_squared_sum += tau_ti ** 2
|
|
2589
|
+
n_valid += 1
|
|
2590
|
+
|
|
2591
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
2592
|
+
# Per Equation 5: Q(λ) must sum over ALL D==0 cells
|
|
2593
|
+
# Any failure means this λ cannot produce valid estimates for all cells
|
|
2594
|
+
warnings.warn(
|
|
2595
|
+
f"LOOCV: Fit failed for observation ({t}, {i}) with "
|
|
2596
|
+
f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
|
|
2597
|
+
"Returning infinite score per Equation 5.",
|
|
2598
|
+
UserWarning
|
|
2599
|
+
)
|
|
2600
|
+
return np.inf
|
|
2601
|
+
|
|
2602
|
+
# Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
|
|
2603
|
+
# Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
|
|
2604
|
+
return tau_squared_sum
|
|
2605
|
+
|
|
2606
|
+
def _bootstrap_variance(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    Y: Optional[np.ndarray] = None,
    D: Optional[np.ndarray] = None,
    control_unit_idx: Optional[np.ndarray] = None,
) -> Tuple[float, np.ndarray]:
    """
    Compute the bootstrap standard error using a unit-level block bootstrap.

    The Rust implementation is tried first when the backend is available,
    ``self._precomputed`` is populated, and both ``Y`` and ``D`` are given.
    Its result is used only if it yields at least 10 bootstrap estimates;
    any exception or a too-small result falls through (logged at debug
    level) to the pure-Python implementation.

    Parameters
    ----------
    data : pd.DataFrame
        Original data in long format with unit, time, outcome, and treatment.
    outcome : str
        Name of the outcome column in data.
    treatment : str
        Name of the treatment indicator column in data.
    unit : str
        Name of the unit identifier column in data.
    time : str
        Name of the time period column in data.
    optimal_lambda : tuple of float
        Tuning parameters (lambda_time, lambda_unit, lambda_nn) used for
        every bootstrap fit.
    Y : np.ndarray, optional
        Outcome matrix of shape (n_periods, n_units). Needed for the Rust
        path; if None the Python implementation is used.
    D : np.ndarray, optional
        Treatment indicator matrix of shape (n_periods, n_units). Needed
        for the Rust path.
    control_unit_idx : np.ndarray, optional
        Accepted for interface compatibility; this method does not read it.

    Returns
    -------
    se : float
        Bootstrap standard error of the ATT estimate (sample standard
        deviation, ddof=1). 0.0 if no bootstrap iteration succeeded.
    bootstrap_estimates : np.ndarray
        ATT estimates from the successful bootstrap iterations. Length may
        be less than n_bootstrap if some iterations failed.

    Notes
    -----
    Entire unit time series are resampled with replacement, separately
    for never-treated and ever-treated units (stratified sampling, paper's
    Algorithm 3, page 27), which preserves within-unit correlation and the
    treated/control ratio in every bootstrap sample.
    """
    lambda_time, lambda_unit, lambda_nn = optimal_lambda

    # Try the Rust backend for a parallel bootstrap
    if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
            and self._precomputed is not None and Y is not None
            and D is not None):
        try:
            control_mask = self._precomputed["control_mask"]
            time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)

            bootstrap_estimates, se = _rust_bootstrap_trop_variance(
                Y, D.astype(np.float64),
                control_mask.astype(np.uint8),
                time_dist_matrix,
                lambda_time, lambda_unit, lambda_nn,
                self.n_bootstrap, self.max_iter, self.tol,
                # The Rust call is always given an integer seed; None maps to 0.
                self.seed if self.seed is not None else 0
            )

            if len(bootstrap_estimates) >= 10:
                return float(se), bootstrap_estimates
            # Fall through to Python if too few bootstrap samples
            logger.debug(
                "Rust bootstrap returned only %d samples, falling back to Python",
                len(bootstrap_estimates)
            )
        except Exception as e:
            # Deliberate best-effort: any Rust-side failure falls back to Python.
            logger.debug(
                "Rust bootstrap variance failed, falling back to Python: %s", e
            )

    # Python implementation (fallback)
    rng = np.random.default_rng(self.seed)

    # Stratify units by whether they are ever treated
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
    control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)

    n_treated_units = len(treated_units)
    n_control_units = len(control_units)

    # Each unit's rows are the same in every bootstrap draw, so slice them
    # out of `data` once per distinct unit instead of rescanning the whole
    # frame for every sampled unit in every iteration.
    rows_by_unit = {}

    bootstrap_estimates_list = []

    for _ in range(self.n_bootstrap):
        # Sample control and treated units separately (with replacement)
        if n_control_units > 0:
            sampled_control = rng.choice(
                control_units, size=n_control_units, replace=True
            )
        else:
            sampled_control = np.array([], dtype=control_units.dtype)

        if n_treated_units > 0:
            sampled_treated = rng.choice(
                treated_units, size=n_treated_units, replace=True
            )
        else:
            sampled_treated = np.array([], dtype=treated_units.dtype)

        sampled_units = np.concatenate([sampled_control, sampled_treated])

        # Give each draw a unique unit ID so repeated draws of the same
        # unit are treated as distinct units by the fit.
        pieces = []
        for idx, u in enumerate(sampled_units):
            unit_rows = rows_by_unit.get(u)
            if unit_rows is None:
                unit_rows = data[data[unit] == u]
                rows_by_unit[u] = unit_rows
            pieces.append(unit_rows.assign(**{unit: f"{u}_{idx}"}))
        boot_data = pd.concat(pieces, ignore_index=True)

        try:
            # Fit with fixed lambda (skip LOOCV for speed)
            att = self._fit_with_fixed_lambda(
                boot_data, outcome, treatment, unit, time,
                optimal_lambda
            )
            bootstrap_estimates_list.append(att)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            # A failed replicate is dropped; the shortfall is reported below.
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
            "Standard errors may be unreliable.",
            UserWarning
        )
        if len(bootstrap_estimates) == 0:
            return 0.0, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
|
|
2761
|
+
|
|
2762
|
+
def _fit_with_fixed_lambda(
|
|
2763
|
+
self,
|
|
2764
|
+
data: pd.DataFrame,
|
|
2765
|
+
outcome: str,
|
|
2766
|
+
treatment: str,
|
|
2767
|
+
unit: str,
|
|
2768
|
+
time: str,
|
|
2769
|
+
fixed_lambda: Tuple[float, float, float],
|
|
2770
|
+
) -> float:
|
|
2771
|
+
"""
|
|
2772
|
+
Fit model with fixed tuning parameters (for bootstrap).
|
|
2773
|
+
|
|
2774
|
+
Uses observation-specific weights following Algorithm 2.
|
|
2775
|
+
Returns only the ATT estimate.
|
|
2776
|
+
"""
|
|
2777
|
+
lambda_time, lambda_unit, lambda_nn = fixed_lambda
|
|
2778
|
+
|
|
2779
|
+
# Setup matrices
|
|
2780
|
+
all_units = sorted(data[unit].unique())
|
|
2781
|
+
all_periods = sorted(data[time].unique())
|
|
2782
|
+
|
|
2783
|
+
n_units = len(all_units)
|
|
2784
|
+
n_periods = len(all_periods)
|
|
2785
|
+
|
|
2786
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
2787
|
+
period_to_idx = {p: i for i, p in enumerate(all_periods)}
|
|
2788
|
+
|
|
2789
|
+
# Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
|
|
2790
|
+
Y = (
|
|
2791
|
+
data.pivot(index=time, columns=unit, values=outcome)
|
|
2792
|
+
.reindex(index=all_periods, columns=all_units)
|
|
2793
|
+
.values
|
|
2794
|
+
)
|
|
2795
|
+
D = (
|
|
2796
|
+
data.pivot(index=time, columns=unit, values=treatment)
|
|
2797
|
+
.reindex(index=all_periods, columns=all_units)
|
|
2798
|
+
.fillna(0)
|
|
2799
|
+
.astype(int)
|
|
2800
|
+
.values
|
|
2801
|
+
)
|
|
2802
|
+
|
|
2803
|
+
control_mask = D == 0
|
|
2804
|
+
|
|
2805
|
+
# Get control unit indices
|
|
2806
|
+
unit_ever_treated = np.any(D == 1, axis=0)
|
|
2807
|
+
control_unit_idx = np.where(~unit_ever_treated)[0]
|
|
2808
|
+
|
|
2809
|
+
# Get list of treated observations
|
|
2810
|
+
treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
|
|
2811
|
+
if D[t, i] == 1]
|
|
2812
|
+
|
|
2813
|
+
if not treated_observations:
|
|
2814
|
+
raise ValueError("No treated observations")
|
|
2815
|
+
|
|
2816
|
+
# Compute ATT using observation-specific weights (Algorithm 2)
|
|
2817
|
+
tau_values = []
|
|
2818
|
+
for t, i in treated_observations:
|
|
2819
|
+
# Compute observation-specific weights for this (i, t)
|
|
2820
|
+
weight_matrix = self._compute_observation_weights(
|
|
2821
|
+
Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
|
|
2822
|
+
n_units, n_periods
|
|
2823
|
+
)
|
|
2824
|
+
|
|
2825
|
+
# Fit model with these weights
|
|
2826
|
+
alpha, beta, L = self._estimate_model(
|
|
2827
|
+
Y, control_mask, weight_matrix, lambda_nn,
|
|
2828
|
+
n_units, n_periods
|
|
2829
|
+
)
|
|
2830
|
+
|
|
2831
|
+
# Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
|
|
2832
|
+
tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
|
|
2833
|
+
tau_values.append(tau)
|
|
2834
|
+
|
|
2835
|
+
return np.mean(tau_values)
|
|
2836
|
+
|
|
2837
|
+
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's parameters as a name -> current value dict."""
    # Order here is the order of keys in the returned dictionary.
    names = (
        "method",
        "lambda_time_grid",
        "lambda_unit_grid",
        "lambda_nn_grid",
        "max_iter",
        "tol",
        "alpha",
        "n_bootstrap",
        "seed",
    )
    return {name: getattr(self, name) for name in names}
|
|
2850
|
+
|
|
2851
|
+
def set_params(self, **params) -> "TROP":
    """
    Set estimator parameters and return the estimator itself.

    All names are validated before anything is assigned, so a call with
    an invalid name leaves the estimator unchanged. A name is accepted
    only if it is an existing, public, non-callable attribute; this keeps
    callers from overwriting methods or private state by accident.

    Parameters
    ----------
    **params
        Parameter names mapped to their new values.

    Returns
    -------
    TROP
        ``self``, to allow call chaining.

    Raises
    ------
    ValueError
        If any name does not refer to a settable parameter.
    """
    for key in params:
        if (
            key.startswith("_")
            or not hasattr(self, key)
            or callable(getattr(self, key))
        ):
            raise ValueError(f"Unknown parameter: {key}")

    for key, value in params.items():
        setattr(self, key, value)
    return self
|
|
2859
|
+
|
|
2860
|
+
|
|
2861
|
+
def trop(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    **kwargs,
) -> TROPResults:
    """
    Run TROP estimation in a single call.

    Builds a ``TROP`` estimator from ``**kwargs`` and fits it to ``data``.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    treatment : str
        Treatment indicator column name (0/1).

        IMPORTANT: This should be an ABSORBING STATE indicator, not a
        treatment timing indicator. For each unit, D=1 for ALL periods
        during and after treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for
        t >= g_i, where g_i is the treatment start time for unit i).
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    **kwargs
        Additional arguments passed to the TROP constructor.

    Returns
    -------
    TROPResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import trop
    >>> results = trop(data, 'y', 'treated', 'unit', 'time')
    >>> print(f"ATT: {results.att:.3f}")
    """
    return TROP(**kwargs).fit(data, outcome, treatment, unit, time)
|