diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +234 -0
- diff_diff/_backend.py +64 -0
- diff_diff/_rust_backend.cpython-39-darwin.so +0 -0
- diff_diff/bacon.py +979 -0
- diff_diff/datasets.py +708 -0
- diff_diff/diagnostics.py +927 -0
- diff_diff/estimators.py +1000 -0
- diff_diff/honest_did.py +1493 -0
- diff_diff/linalg.py +980 -0
- diff_diff/power.py +1350 -0
- diff_diff/prep.py +1338 -0
- diff_diff/pretrends.py +1067 -0
- diff_diff/results.py +703 -0
- diff_diff/staggered.py +2297 -0
- diff_diff/sun_abraham.py +1176 -0
- diff_diff/synthetic_did.py +738 -0
- diff_diff/triple_diff.py +1291 -0
- diff_diff/trop.py +1348 -0
- diff_diff/twfe.py +344 -0
- diff_diff/utils.py +1481 -0
- diff_diff/visualization.py +1627 -0
- diff_diff-2.1.0.dist-info/METADATA +2511 -0
- diff_diff-2.1.0.dist-info/RECORD +24 -0
- diff_diff-2.1.0.dist-info/WHEEL +4 -0
diff_diff/trop.py
ADDED
@@ -0,0 +1,1348 @@
"""
Triply Robust Panel (TROP) estimator.

Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
TROP combines three robustness components:
1. Nuclear norm regularized factor model (interactive fixed effects)
2. Exponential distance-based unit weights
3. Exponential time decay weights

The estimator uses leave-one-out cross-validation for tuning parameter
selection and provides robust treatment effect estimates under factor
confounding.

References
----------
Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
"""

import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from scipy import stats

from diff_diff.results import _get_significance_stars
from diff_diff.utils import compute_confidence_interval, compute_p_value

@dataclass
class TROPResults:
    """
    Results from a Triply Robust Panel (TROP) estimation.

    TROP combines nuclear norm regularized factor estimation with
    exponential distance-based unit weights and time decay weights.

    Attributes
    ----------
    att : float
        Average treatment effect on the treated (ATT).
    se : float
        Standard error of the ATT estimate.
    t_stat : float
        t-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_treated_obs : int
        Number of treated unit-time observations.
    unit_effects : dict
        Estimated unit fixed effects (alpha_i).
    time_effects : dict
        Estimated time fixed effects (beta_t).
    treatment_effects : dict
        Individual treatment effects for each treated (unit, time) pair.
    lambda_time : float
        Selected time weight decay parameter.
    lambda_unit : float
        Selected unit weight decay parameter.
    lambda_nn : float
        Selected nuclear norm regularization parameter.
    factor_matrix : np.ndarray
        Estimated low-rank factor matrix L (n_periods x n_units).
    effective_rank : float
        Effective rank of the factor matrix (sum of singular values / max).
    loocv_score : float
        Leave-one-out cross-validation score for selected parameters.
    variance_method : str
        Method used for variance estimation.
    alpha : float
        Significance level for confidence interval.
    pre_periods : list
        List of pre-treatment period identifiers.
    post_periods : list
        List of post-treatment period identifiers.
    n_bootstrap : int, optional
        Number of bootstrap replications (if bootstrap variance).
    bootstrap_distribution : np.ndarray, optional
        Bootstrap distribution of estimates.
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    n_treated_obs: int
    unit_effects: Dict[Any, float]
    time_effects: Dict[Any, float]
    treatment_effects: Dict[Tuple[Any, Any], float]
    lambda_time: float
    lambda_unit: float
    lambda_nn: float
    factor_matrix: np.ndarray
    effective_rank: float
    loocv_score: float
    variance_method: str
    alpha: float = 0.05
    pre_periods: List[Any] = field(default_factory=list)
    post_periods: List[Any] = field(default_factory=list)
    n_bootstrap: Optional[int] = field(default=None)
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.p_value)
        return (
            f"TROPResults(ATT={self.att:.4f}{sig}, "
            f"SE={self.se:.4f}, "
            f"eff_rank={self.effective_rank:.1f}, "
            f"p={self.p_value:.4f})"
        )

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level for confidence intervals. Defaults to the
            alpha used during estimation.

        Returns
        -------
        str
            Formatted summary table.
        """
        alpha = alpha or self.alpha
        conf_level = int((1 - alpha) * 100)

        lines = [
            "=" * 75,
            "Triply Robust Panel (TROP) Estimation Results".center(75),
            "Athey, Imbens, Qu & Viviano (2025)".center(75),
            "=" * 75,
            "",
            f"{'Observations:':<25} {self.n_obs:>10}",
            f"{'Treated units:':<25} {self.n_treated:>10}",
            f"{'Control units:':<25} {self.n_control:>10}",
            f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
            f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
            f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
            "",
            "-" * 75,
            "Tuning Parameters (selected via LOOCV)".center(75),
            "-" * 75,
            f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}",
            f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}",
            f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}",
            f"{'Effective rank:':<25} {self.effective_rank:>10.2f}",
            f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}",
        ]

        # Variance method info
        lines.append(f"{'Variance method:':<25} {self.variance_method:>10}")
        if self.variance_method == "bootstrap" and self.n_bootstrap is not None:
            lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")

        lines.extend([
            "",
            "-" * 75,
            f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'':>5}",
            "-" * 75,
            f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} "
            f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
            "-" * 75,
            "",
            f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
        ])

        # Add significance codes
        lines.extend([
            "",
            "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
            "=" * 75,
        ])

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all estimation results.
        """
        return {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": self.conf_int[0],
            "conf_int_upper": self.conf_int[1],
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "n_treated_obs": self.n_treated_obs,
            "n_pre_periods": len(self.pre_periods),
            "n_post_periods": len(self.post_periods),
            "lambda_time": self.lambda_time,
            "lambda_unit": self.lambda_unit,
            "lambda_nn": self.lambda_nn,
            "effective_rank": self.effective_rank,
            "loocv_score": self.loocv_score,
            "variance_method": self.variance_method,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with estimation results.
        """
        return pd.DataFrame([self.to_dict()])

    def get_treatment_effects_df(self) -> pd.DataFrame:
        """
        Get individual treatment effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit, time, and treatment effect columns.
        """
        return pd.DataFrame([
            {"unit": unit, "time": time, "effect": effect}
            for (unit, time), effect in self.treatment_effects.items()
        ])

    def get_unit_effects_df(self) -> pd.DataFrame:
        """
        Get unit fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit and effect columns.
        """
        return pd.DataFrame([
            {"unit": unit, "effect": effect}
            for unit, effect in self.unit_effects.items()
        ])

    def get_time_effects_df(self) -> pd.DataFrame:
        """
        Get time fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with time and effect columns.
        """
        return pd.DataFrame([
            {"time": time, "effect": effect}
            for time, effect in self.time_effects.items()
        ])

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)


class TROP:
    """
    Triply Robust Panel (TROP) estimator.

    Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025).
    TROP combines three robustness components:

    1. **Nuclear norm regularized factor model**: Estimates interactive fixed
       effects L_it via matrix completion with nuclear norm penalty ||L||_*

    2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
       where d(j,i) is the RMSE of outcome differences between units

    3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
       weighting pre-treatment periods by proximity to treatment

    Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
    cross-validation on control observations.

    Parameters
    ----------
    lambda_time_grid : list, optional
        Grid of time weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
    lambda_unit_grid : list, optional
        Grid of unit weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
    lambda_nn_grid : list, optional
        Grid of nuclear norm regularization parameters.
        Default: [0, 0.01, 0.1, 1, 10].
    max_iter : int, default=100
        Maximum iterations for nuclear norm optimization.
    tol : float, default=1e-6
        Convergence tolerance for optimization.
    alpha : float, default=0.05
        Significance level for confidence intervals.
    variance_method : str, default='bootstrap'
        Method for variance estimation: 'bootstrap' or 'jackknife'.
    n_bootstrap : int, default=200
        Number of replications for variance estimation.
    seed : int, optional
        Random seed for reproducibility.

    Attributes
    ----------
    results_ : TROPResults
        Estimation results after calling fit().
    is_fitted_ : bool
        Whether the model has been fitted.

    Examples
    --------
    >>> from diff_diff import TROP
    >>> trop = TROP()
    >>> results = trop.fit(
    ...     data,
    ...     outcome='outcome',
    ...     treatment='treated',
    ...     unit='unit',
    ...     time='period',
    ...     post_periods=[5, 6, 7, 8]
    ... )
    >>> results.print_summary()

    References
    ----------
    Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
    Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
    """

    def __init__(
        self,
        lambda_time_grid: Optional[List[float]] = None,
        lambda_unit_grid: Optional[List[float]] = None,
        lambda_nn_grid: Optional[List[float]] = None,
        max_iter: int = 100,
        tol: float = 1e-6,
        alpha: float = 0.05,
        variance_method: str = 'bootstrap',
        n_bootstrap: int = 200,
        seed: Optional[int] = None,
    ):
        # Default grids from paper
        self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
        self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
        self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]

        self.max_iter = max_iter
        self.tol = tol
        self.alpha = alpha
        self.variance_method = variance_method
        self.n_bootstrap = n_bootstrap
        self.seed = seed

        # Validate parameters
        valid_variance_methods = ("bootstrap", "jackknife")
        if variance_method not in valid_variance_methods:
            raise ValueError(
                f"variance_method must be one of {valid_variance_methods}, "
                f"got '{variance_method}'"
            )

        # Internal state
        self.results_: Optional[TROPResults] = None
        self.is_fitted_: bool = False
        self._optimal_lambda: Optional[Tuple[float, float, float]] = None

    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        unit: str,
        time: str,
        post_periods: Optional[List[Any]] = None,
    ) -> TROPResults:
        """
        Fit the TROP model.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with observations for multiple units over multiple
            time periods.
        outcome : str
            Name of the outcome variable column.
        treatment : str
            Name of the treatment indicator column (0/1).
            Should be 1 for treated unit-time observations.
        unit : str
            Name of the unit identifier column.
        time : str
            Name of the time period column.
        post_periods : list, optional
            List of time period values that are post-treatment.
            If None, infers from treatment indicator.

        Returns
        -------
        TROPResults
            Object containing the ATT estimate, standard error,
            factor estimates, and tuning parameters.
        """
        # Validate inputs
        required_cols = [outcome, treatment, unit, time]
        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Get unique units and periods
        all_units = sorted(data[unit].unique())
        all_periods = sorted(data[time].unique())

        n_units = len(all_units)
        n_periods = len(all_periods)

        # Create mappings
        unit_to_idx = {u: i for i, u in enumerate(all_units)}
        period_to_idx = {p: i for i, p in enumerate(all_periods)}
        idx_to_unit = {i: u for u, i in unit_to_idx.items()}
        idx_to_period = {i: p for p, i in period_to_idx.items()}

        # Create outcome matrix Y (n_periods x n_units) and treatment matrix D
        Y = np.full((n_periods, n_units), np.nan)
        D = np.zeros((n_periods, n_units), dtype=int)

        for _, row in data.iterrows():
            i = unit_to_idx[row[unit]]
            t = period_to_idx[row[time]]
            Y[t, i] = row[outcome]
            D[t, i] = int(row[treatment])

        # Identify treated observations
        treated_mask = D == 1
        n_treated_obs = np.sum(treated_mask)

        if n_treated_obs == 0:
            raise ValueError("No treated observations found")

        # Identify treated and control units
        unit_ever_treated = np.any(D == 1, axis=0)
        treated_unit_idx = np.where(unit_ever_treated)[0]
        control_unit_idx = np.where(~unit_ever_treated)[0]

        if len(control_unit_idx) == 0:
            raise ValueError("No control units found")

        # Determine pre/post periods
        if post_periods is None:
            # Infer from first treatment time
            first_treat_period = None
            for t in range(n_periods):
                if np.any(D[t, :] == 1):
                    first_treat_period = t
                    break
            if first_treat_period is None:
                raise ValueError("Could not infer post-treatment periods")
            pre_period_idx = list(range(first_treat_period))
            post_period_idx = list(range(first_treat_period, n_periods))
        else:
            post_period_idx = [period_to_idx[p] for p in post_periods if p in period_to_idx]
            pre_period_idx = [i for i in range(n_periods) if i not in post_period_idx]

        if len(pre_period_idx) < 2:
            raise ValueError("Need at least 2 pre-treatment periods")

        pre_periods_list = [idx_to_period[i] for i in pre_period_idx]
        post_periods_list = [idx_to_period[i] for i in post_period_idx]
        n_treated_periods = len(post_period_idx)

        # Step 1: Grid search with LOOCV for tuning parameters
        best_lambda = None
        best_score = np.inf

        # Control observations mask (for LOOCV)
        control_mask = D == 0

        for lambda_time in self.lambda_time_grid:
            for lambda_unit in self.lambda_unit_grid:
                for lambda_nn in self.lambda_nn_grid:
                    try:
                        score = self._loocv_score_obs_specific(
                            Y, D, control_mask, control_unit_idx,
                            lambda_time, lambda_unit, lambda_nn,
                            n_units, n_periods
                        )
                        if score < best_score:
                            best_score = score
                            best_lambda = (lambda_time, lambda_unit, lambda_nn)
                    except (np.linalg.LinAlgError, ValueError):
                        continue

        if best_lambda is None:
            warnings.warn(
                "All tuning parameter combinations failed. Using defaults.",
                UserWarning
            )
            best_lambda = (1.0, 1.0, 0.1)
            best_score = np.nan

        self._optimal_lambda = best_lambda
        lambda_time, lambda_unit, lambda_nn = best_lambda

        # Step 2: Final estimation - per-observation model fitting following Algorithm 2.
        # For each treated (i,t): compute observation-specific weights, fit model,
        # compute τ̂_{it}
        treatment_effects = {}
        tau_values = []
        alpha_estimates = []
        beta_estimates = []
        L_estimates = []

        # Get list of treated observations
        treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
                                if D[t, i] == 1]

        for t, i in treated_observations:
            # Compute observation-specific weights for this (i, t)
            weight_matrix = self._compute_observation_weights(
                Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
                n_units, n_periods
            )

            # Fit model with these weights
            alpha_hat, beta_hat, L_hat = self._estimate_model(
                Y, control_mask, weight_matrix, lambda_nn,
                n_units, n_periods
            )

            # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
            tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]

            unit_id = idx_to_unit[i]
            time_id = idx_to_period[t]
            treatment_effects[(unit_id, time_id)] = tau_it
            tau_values.append(tau_it)

            # Store for averaging
            alpha_estimates.append(alpha_hat)
            beta_estimates.append(beta_hat)
            L_estimates.append(L_hat)

        # Average ATT
        att = np.mean(tau_values)

        # Average parameter estimates for output (representative)
        alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
        beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
        L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))

        # Compute effective rank
        _, s, _ = np.linalg.svd(L_hat, full_matrices=False)
        if s[0] > 0:
            effective_rank = np.sum(s) / s[0]
        else:
            effective_rank = 0.0

        # Step 3: Variance estimation
        if self.variance_method == "bootstrap":
            se, bootstrap_dist = self._bootstrap_variance(
                data, outcome, treatment, unit, time, post_periods_list,
                best_lambda
            )
        else:
            se, bootstrap_dist = self._jackknife_variance(
                Y, D, control_mask, control_unit_idx, best_lambda,
                n_units, n_periods
            )

        # Compute test statistics
        if se > 0:
            t_stat = att / se
            p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
        else:
            t_stat = 0.0
            p_value = 1.0

        conf_int = compute_confidence_interval(att, se, self.alpha)

        # Create results dictionaries
        unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
        time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}

        # Store results
        self.results_ = TROPResults(
            att=att,
            se=se,
            t_stat=t_stat,
            p_value=p_value,
            conf_int=conf_int,
            n_obs=len(data),
            n_treated=len(treated_unit_idx),
            n_control=len(control_unit_idx),
            n_treated_obs=n_treated_obs,
            unit_effects=unit_effects_dict,
            time_effects=time_effects_dict,
            treatment_effects=treatment_effects,
            lambda_time=lambda_time,
            lambda_unit=lambda_unit,
            lambda_nn=lambda_nn,
            factor_matrix=L_hat,
            effective_rank=effective_rank,
            loocv_score=best_score,
            variance_method=self.variance_method,
            alpha=self.alpha,
            pre_periods=pre_periods_list,
            post_periods=post_periods_list,
            n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None,
            bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
        )

        self.is_fitted_ = True
        return self.results_

    def _compute_unit_distance_pairwise(
        self,
        Y: np.ndarray,
        D: np.ndarray,
        j: int,
        i: int,
        target_period: int,
    ) -> float:
        """
        Compute pairwise distance from control unit j to treated unit i.

        Following the paper's Equation 3 (page 7):

            dist_unit_{-t}(j, i) = sqrt(
                Σ_u 1{u≠t}(1-W_{iu})(1-W_{ju})(Y_{iu} - Y_{ju})²
                / Σ_u 1{u≠t}(1-W_{iu})(1-W_{ju})
            )

        This computes the RMSE between units j and i over periods where
        both are untreated, excluding the target period t.

        Parameters
        ----------
        Y : np.ndarray
            Outcome matrix (n_periods x n_units).
        D : np.ndarray
            Treatment indicator matrix (n_periods x n_units).
        j : int
            Index of control unit.
        i : int
            Index of treated unit.
        target_period : int
            Target treatment period t (excluded from distance computation).

        Returns
        -------
        float
            Pairwise RMSE distance between units j and i.
        """
        n_periods = Y.shape[0]

        sq_diffs = []
        for u in range(n_periods):
            # Exclude target period and periods where either unit is treated
            if u == target_period:
                continue
            # (1 - W_{iu})(1 - W_{ju}) means both must be untreated
            if D[u, i] == 1 or D[u, j] == 1:
                continue
            if np.isnan(Y[u, i]) or np.isnan(Y[u, j]):
                continue

            sq_diffs.append((Y[u, i] - Y[u, j]) ** 2)

        if len(sq_diffs) > 0:
            return np.sqrt(np.mean(sq_diffs))
        else:
            return np.inf

    def _compute_observation_weights(
        self,
        Y: np.ndarray,
        D: np.ndarray,
        i: int,
        t: int,
        lambda_time: float,
        lambda_unit: float,
        control_unit_idx: np.ndarray,
        n_units: int,
        n_periods: int,
    ) -> np.ndarray:
        """
        Compute observation-specific weight matrix for treated observation (i, t).

        Following the paper's Algorithm 2 (page 27):
        - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
        - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))

        Parameters
        ----------
        Y : np.ndarray
            Outcome matrix (n_periods x n_units).
        D : np.ndarray
            Treatment indicator matrix (n_periods x n_units).
        i : int
            Treated unit index.
        t : int
            Treatment period index.
        lambda_time : float
            Time weight decay parameter.
        lambda_unit : float
            Unit weight decay parameter.
        control_unit_idx : np.ndarray
            Indices of control units.
        n_units : int
            Number of units.
        n_periods : int
            Number of periods.

        Returns
        -------
        np.ndarray
            Weight matrix (n_periods x n_units) for observation (i, t).
        """
        # Time distance: |t - s| following paper's Equation 3 (page 7)
        dist_time = np.array([abs(t - s) for s in range(n_periods)])
        time_weights = np.exp(-lambda_time * dist_time)

        # Unit distance: pairwise RMSE from each control j to treated i
        unit_weights = np.zeros(n_units)

        if lambda_unit == 0:
            # Uniform weights when lambda_unit = 0
            unit_weights[:] = 1.0
        else:
            for j in control_unit_idx:
                dist = self._compute_unit_distance_pairwise(Y, D, j, i, t)
                if np.isinf(dist):
                    unit_weights[j] = 0.0
                else:
                    unit_weights[j] = np.exp(-lambda_unit * dist)

        # Treated unit i gets weight 1 (or could be omitted since we fit on controls)
        # We include treated unit's own observation for model fitting
        unit_weights[i] = 1.0

        # Weight matrix: outer product (n_periods x n_units)
        W = np.outer(time_weights, unit_weights)

        return W

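    # Worked illustration of the exponential weights above (hypothetical
    # numbers, not drawn from the paper): with lambda_time = 0.5 and target
    # period t = 5, period s = 3 receives theta = exp(-0.5 * |5 - 3|) =
    # exp(-1.0) ≈ 0.368, and with lambda_unit = 1.0 a control unit j at
    # pre-treatment RMSE d(j, i) = 0.8 receives omega = exp(-0.8) ≈ 0.449.
    # Larger lambdas concentrate weight on periods close to t and on controls
    # whose untreated outcome paths track unit i closely.
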
    def _soft_threshold_svd(
        self,
        M: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply soft-thresholding to singular values (proximal operator for nuclear norm).

        Parameters
        ----------
        M : np.ndarray
            Input matrix.
        threshold : float
            Soft-thresholding parameter.

        Returns
        -------
        np.ndarray
            Matrix with soft-thresholded singular values.
        """
        if threshold <= 0:
            return M

        # Handle NaN/Inf values in input
        if not np.isfinite(M).all():
            M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)

        try:
            U, s, Vt = np.linalg.svd(M, full_matrices=False)
        except np.linalg.LinAlgError:
            # SVD failed, return zero matrix
            return np.zeros_like(M)

        # Check for numerical issues in SVD output
        if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()):
            # SVD produced non-finite values, return zero matrix
            return np.zeros_like(M)

        s_thresh = np.maximum(s - threshold, 0)

        # Use truncated reconstruction with only non-zero singular values
        nonzero_mask = s_thresh > 1e-10
        if not np.any(nonzero_mask):
            return np.zeros_like(M)

        # Truncate to non-zero components for numerical stability
        U_trunc = U[:, nonzero_mask]
        s_trunc = s_thresh[nonzero_mask]
        Vt_trunc = Vt[nonzero_mask, :]

        # Compute result, suppressing expected numerical warnings from
        # ill-conditioned matrices during alternating minimization
        with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
            result = (U_trunc * s_trunc) @ Vt_trunc

        # Replace any NaN/Inf in result with zeros
        if not np.isfinite(result).all():
            result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)

        return result

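    # Worked illustration of the soft-thresholding step above (hypothetical
    # numbers): for singular values s = [3.0, 1.2, 0.4] and threshold 0.5,
    # max(s - 0.5, 0) = [2.5, 0.7, 0.0], so the third component is dropped and
    # the reconstructed matrix has rank 2 instead of 3. This is how lambda_nn
    # shrinks and truncates the spectrum of the factor matrix L.
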
    def _estimate_model(
        self,
        Y: np.ndarray,
        control_mask: np.ndarray,
        weight_matrix: np.ndarray,
        lambda_nn: float,
        n_units: int,
        n_periods: int,
        exclude_obs: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Estimate the model: Y = α + β + L + τD + ε with nuclear norm penalty on L.

        Uses alternating minimization:
        1. Fix L, solve for α, β
        2. Fix α, β, solve for L via soft-thresholding

        Parameters
        ----------
        Y : np.ndarray
            Outcome matrix (n_periods x n_units).
        control_mask : np.ndarray
            Boolean mask for control observations.
        weight_matrix : np.ndarray
            Pre-computed global weight matrix (n_periods x n_units).
        lambda_nn : float
            Nuclear norm regularization parameter.
        n_units : int
            Number of units.
        n_periods : int
            Number of periods.
        exclude_obs : tuple, optional
            (t, i) observation to exclude (for LOOCV).

        Returns
        -------
        tuple
            (alpha, beta, L) estimated parameters.
        """
        W = weight_matrix

        # Mask for estimation (control obs only, excluding LOOCV obs if specified)
        est_mask = control_mask.copy()
        if exclude_obs is not None:
            t_ex, i_ex = exclude_obs
            est_mask[t_ex, i_ex] = False

        # Handle missing values
        valid_mask = ~np.isnan(Y) & est_mask

        # Initialize
        alpha = np.zeros(n_units)
        beta = np.zeros(n_periods)
        L = np.zeros((n_periods, n_units))

        # Alternating minimization
        for iteration in range(self.max_iter):
            alpha_old = alpha.copy()
            beta_old = beta.copy()
            L_old = L.copy()

            # Step 1: Update α and β (weighted means)
            R = Y - L  # Residual without fixed effects

            # Weighted mean for alpha (unit effects)
            for i in range(n_units):
                mask_i = valid_mask[:, i]
                if np.any(mask_i):
                    weights_i = W[mask_i, i]
                    # Handle case where weights sum to zero (unit not in weight computation)
                    weight_sum = np.sum(weights_i)
                    if weight_sum > 0:
                        alpha[i] = np.average(R[mask_i, i] - beta[mask_i], weights=weights_i)
                    else:
                        # Use unweighted mean for units with zero total weight
                        alpha[i] = np.mean(R[mask_i, i] - beta[mask_i])
                else:
                    alpha[i] = 0.0

            # Weighted mean for beta (time effects)
            for t in range(n_periods):
                mask_t = valid_mask[t, :]
                if np.any(mask_t):
                    weights_t = W[t, mask_t]
                    # Handle case where weights sum to zero
                    weight_sum = np.sum(weights_t)
                    if weight_sum > 0:
                        beta[t] = np.average(R[t, mask_t] - alpha[mask_t], weights=weights_t)
                    else:
                        # Use unweighted mean for periods with zero total weight
                        beta[t] = np.mean(R[t, mask_t] - alpha[mask_t])
                else:
                    beta[t] = 0.0

            # Step 2: Update L with nuclear norm penalty
            # L = soft_threshold(Y - α - β, λ_nn)
            R_for_L = np.zeros((n_periods, n_units))
            for t in range(n_periods):
                for i in range(n_units):
                    if valid_mask[t, i]:
                        R_for_L[t, i] = Y[t, i] - alpha[i] - beta[t]
                    else:
                        # Impute with current L
                        R_for_L[t, i] = L[t, i]

            L = self._soft_threshold_svd(R_for_L, lambda_nn)

            # Check convergence
            alpha_diff = np.max(np.abs(alpha - alpha_old))
            beta_diff = np.max(np.abs(beta - beta_old))
            L_diff = np.max(np.abs(L - L_old))

            if max(alpha_diff, beta_diff, L_diff) < self.tol:
                break

        return alpha, beta, L

    def _loocv_score_obs_specific(
        self,
        Y: np.ndarray,
        D: np.ndarray,
        control_mask: np.ndarray,
        control_unit_idx: np.ndarray,
        lambda_time: float,
        lambda_unit: float,
        lambda_nn: float,
        n_units: int,
        n_periods: int,
    ) -> float:
        """
        Compute leave-one-out cross-validation score with observation-specific weights.

        Following the paper's Equation 5 (page 8):

            Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²

        For each control observation (j, s), treat it as pseudo-treated,
        compute observation-specific weights, fit the model excluding (j, s),
        and average the squared pseudo-treatment effects over the sampled cells.

        Parameters
        ----------
        Y : np.ndarray
            Outcome matrix (n_periods x n_units).
        D : np.ndarray
            Treatment indicator matrix (n_periods x n_units).
        control_mask : np.ndarray
            Boolean mask for control observations.
        control_unit_idx : np.ndarray
            Indices of control units.
        lambda_time : float
            Time weight decay parameter.
        lambda_unit : float
            Unit weight decay parameter.
        lambda_nn : float
            Nuclear norm regularization parameter.
        n_units : int
            Number of units.
        n_periods : int
            Number of periods.

        Returns
        -------
        float
            LOOCV score (lower is better).
        """
        # Get all control observations
        control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
                       if control_mask[t, i] and not np.isnan(Y[t, i])]

        # Subsample for computational tractability (as noted in paper's footnote)
        rng = np.random.default_rng(self.seed)
        max_loocv = min(100, len(control_obs))
        if len(control_obs) > max_loocv:
            indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
            control_obs = [control_obs[idx] for idx in indices]

        tau_squared_sum = 0.0
        n_valid = 0

        for t, i in control_obs:
            try:
                # Compute observation-specific weights for pseudo-treated (i, t)
                weight_matrix = self._compute_observation_weights(
                    Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
                    n_units, n_periods
                )

                # Estimate model excluding observation (t, i)
                alpha, beta, L = self._estimate_model(
                    Y, control_mask, weight_matrix, lambda_nn,
                    n_units, n_periods, exclude_obs=(t, i)
                )

                # Pseudo treatment effect
                tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
                tau_squared_sum += tau_ti ** 2
                n_valid += 1

            except (np.linalg.LinAlgError, ValueError):
                continue

        if n_valid == 0:
            return np.inf

        return tau_squared_sum / n_valid

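    # Worked illustration of the LOOCV criterion above (hypothetical numbers):
    # if three sampled control cells yield pseudo-effects [0.2, -0.1, 0.05],
    # the returned score is (0.04 + 0.01 + 0.0025) / 3 ≈ 0.0175. Parameter
    # combinations that predict held-out control outcomes well push these
    # pseudo-effects, and hence the score, toward zero.
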
    def _bootstrap_variance(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        unit: str,
        time: str,
        post_periods: List[Any],
        optimal_lambda: Tuple[float, float, float],
    ) -> Tuple[float, np.ndarray]:
        """
        Compute bootstrap standard error using unit-level block bootstrap.

        Parameters
        ----------
        data : pd.DataFrame
            Original data.
        outcome : str
            Outcome column name.
        treatment : str
            Treatment column name.
        unit : str
            Unit column name.
        time : str
            Time column name.
        post_periods : list
            Post-treatment periods.
        optimal_lambda : tuple
            Optimal (lambda_time, lambda_unit, lambda_nn).

        Returns
        -------
        tuple
            (se, bootstrap_estimates).
        """
        rng = np.random.default_rng(self.seed)
        all_units = data[unit].unique()
        n_units = len(all_units)

        bootstrap_estimates = []

        for b in range(self.n_bootstrap):
            # Sample units with replacement
            sampled_units = rng.choice(all_units, size=n_units, replace=True)

            # Create bootstrap sample with unique unit IDs
            boot_data = pd.concat([
                data[data[unit] == u].assign(**{unit: f"{u}_{idx}"})
                for idx, u in enumerate(sampled_units)
            ], ignore_index=True)

            try:
                # Fit with fixed lambda (skip LOOCV for speed)
                att = self._fit_with_fixed_lambda(
                    boot_data, outcome, treatment, unit, time,
                    post_periods, optimal_lambda
                )
                bootstrap_estimates.append(att)
            except (ValueError, np.linalg.LinAlgError, KeyError):
                continue

        bootstrap_estimates = np.array(bootstrap_estimates)

        if len(bootstrap_estimates) < 10:
            warnings.warn(
                f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
                "Standard errors may be unreliable.",
                UserWarning
            )
            if len(bootstrap_estimates) == 0:
                return 0.0, np.array([])

        se = np.std(bootstrap_estimates, ddof=1)
        return se, bootstrap_estimates

    def _jackknife_variance(
        self,
        Y: np.ndarray,
        D: np.ndarray,
        control_mask: np.ndarray,
        control_unit_idx: np.ndarray,
        optimal_lambda: Tuple[float, float, float],
        n_units: int,
        n_periods: int,
    ) -> Tuple[float, np.ndarray]:
        """
        Compute jackknife standard error (leave-one-unit-out).

        Uses observation-specific weights following Algorithm 2.

        Parameters
        ----------
        Y : np.ndarray
            Outcome matrix.
        D : np.ndarray
            Treatment matrix.
        control_mask : np.ndarray
            Control observation mask.
        control_unit_idx : np.ndarray
            Indices of control units.
        optimal_lambda : tuple
            Optimal tuning parameters.
        n_units : int
            Number of units.
        n_periods : int
            Number of periods.

        Returns
        -------
        tuple
            (se, jackknife_estimates).
        """
        lambda_time, lambda_unit, lambda_nn = optimal_lambda
        jackknife_estimates = []

        # Get treated unit indices
        treated_unit_idx = np.where(np.any(D == 1, axis=0))[0]

        for leave_out in treated_unit_idx:
            # Create mask excluding this unit
            Y_jack = Y.copy()
            D_jack = D.copy()
            Y_jack[:, leave_out] = np.nan
            D_jack[:, leave_out] = 0

            control_mask_jack = D_jack == 0

            # Get remaining treated observations
            treated_obs_jack = [(t, i) for t in range(n_periods) for i in range(n_units)
                                if D_jack[t, i] == 1]

            if not treated_obs_jack:
                continue

            try:
                # Compute ATT using observation-specific weights (Algorithm 2)
                tau_values = []
                for t, i in treated_obs_jack:
                    # Compute observation-specific weights for this (i, t)
                    weight_matrix = self._compute_observation_weights(
                        Y_jack, D_jack, i, t, lambda_time, lambda_unit,
                        control_unit_idx, n_units, n_periods
                    )

                    # Fit model with these weights
                    alpha, beta, L = self._estimate_model(
                        Y_jack, control_mask_jack, weight_matrix, lambda_nn,
                        n_units, n_periods
                    )

                    # Compute treatment effect
                    tau = Y_jack[t, i] - alpha[i] - beta[t] - L[t, i]
                    tau_values.append(tau)

                if tau_values:
                    jackknife_estimates.append(np.mean(tau_values))

            except (np.linalg.LinAlgError, ValueError):
                continue

        jackknife_estimates = np.array(jackknife_estimates)

        if len(jackknife_estimates) < 2:
            return 0.0, jackknife_estimates

        # Jackknife SE formula
        n = len(jackknife_estimates)
        mean_est = np.mean(jackknife_estimates)
        se = np.sqrt((n - 1) / n * np.sum((jackknife_estimates - mean_est) ** 2))

        return se, jackknife_estimates

    def _fit_with_fixed_lambda(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        unit: str,
        time: str,
        post_periods: List[Any],
        fixed_lambda: Tuple[float, float, float],
    ) -> float:
        """
        Fit model with fixed tuning parameters (for bootstrap).

        Uses observation-specific weights following Algorithm 2.
        Returns only the ATT estimate.
        """
        lambda_time, lambda_unit, lambda_nn = fixed_lambda

        # Setup matrices
        all_units = sorted(data[unit].unique())
        all_periods = sorted(data[time].unique())

        n_units = len(all_units)
        n_periods = len(all_periods)

        unit_to_idx = {u: i for i, u in enumerate(all_units)}
        period_to_idx = {p: i for i, p in enumerate(all_periods)}

        Y = np.full((n_periods, n_units), np.nan)
        D = np.zeros((n_periods, n_units), dtype=int)

        for _, row in data.iterrows():
            i = unit_to_idx[row[unit]]
            t = period_to_idx[row[time]]
            Y[t, i] = row[outcome]
            D[t, i] = int(row[treatment])

        control_mask = D == 0

        # Get control unit indices
        unit_ever_treated = np.any(D == 1, axis=0)
        control_unit_idx = np.where(~unit_ever_treated)[0]

        # Get list of treated observations
        treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units)
                                if D[t, i] == 1]

        if not treated_observations:
            raise ValueError("No treated observations")

        # Compute ATT using observation-specific weights (Algorithm 2)
        tau_values = []
        for t, i in treated_observations:
            # Compute observation-specific weights for this (i, t)
            weight_matrix = self._compute_observation_weights(
                Y, D, i, t, lambda_time, lambda_unit, control_unit_idx,
                n_units, n_periods
            )

            # Fit model with these weights
            alpha, beta, L = self._estimate_model(
                Y, control_mask, weight_matrix, lambda_nn,
                n_units, n_periods
            )

            # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it}
            tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
            tau_values.append(tau)

        return np.mean(tau_values)

    def get_params(self) -> Dict[str, Any]:
        """Get estimator parameters."""
        return {
            "lambda_time_grid": self.lambda_time_grid,
            "lambda_unit_grid": self.lambda_unit_grid,
            "lambda_nn_grid": self.lambda_nn_grid,
            "max_iter": self.max_iter,
            "tol": self.tol,
            "alpha": self.alpha,
            "variance_method": self.variance_method,
            "n_bootstrap": self.n_bootstrap,
            "seed": self.seed,
        }

    def set_params(self, **params) -> "TROP":
        """Set estimator parameters."""
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Unknown parameter: {key}")
        return self


def trop(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    post_periods: Optional[List[Any]] = None,
    **kwargs,
) -> TROPResults:
    """
    Convenience function for TROP estimation.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    treatment : str
        Treatment indicator column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    post_periods : list, optional
        Post-treatment periods.
    **kwargs
        Additional arguments passed to TROP constructor.

    Returns
    -------
    TROPResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import trop
    >>> results = trop(data, 'y', 'treated', 'unit', 'time', post_periods=[5,6,7])
    >>> print(f"ATT: {results.att:.3f}")
    """
    estimator = TROP(**kwargs)
    return estimator.fit(data, outcome, treatment, unit, time, post_periods)
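A minimal end-to-end sketch of driving the new module, for readers evaluating this release. It relies only on the `trop()` signature and constructor arguments shown in the diff above; the column names and the simulated panel are illustrative choices, not package requirements.

import numpy as np
import pandas as pd

from diff_diff import trop

# Toy balanced panel: 10 units x 8 periods, units 8-9 treated in periods 5-7,
# with a true effect of 2.0 added to treated cells.
rng = np.random.default_rng(0)
rows = [
    {
        "unit": u,
        "time": t,
        "treated": int(u >= 8 and t >= 5),
        "y": 1.0 * u + 0.5 * t + 2.0 * (u >= 8 and t >= 5) + rng.normal(0, 0.1),
    }
    for u in range(10)
    for t in range(8)
]
data = pd.DataFrame(rows)

# Extra keyword arguments are forwarded to the TROP constructor; a small
# n_bootstrap keeps the sketch fast (the package default is 200).
results = trop(data, "y", "treated", "unit", "time",
               post_periods=[5, 6, 7], n_bootstrap=50, seed=0)
results.print_summary()
print(results.get_treatment_effects_df())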