diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/trop.py
ADDED
|
@@ -0,0 +1,952 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Triply Robust Panel (TROP) estimator.
|
|
3
|
+
|
|
4
|
+
Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
|
|
5
|
+
TROP combines three robustness components:
|
|
6
|
+
1. Nuclear norm regularized factor model (interactive fixed effects)
|
|
7
|
+
2. Exponential distance-based unit weights
|
|
8
|
+
3. Exponential time decay weights
|
|
9
|
+
|
|
10
|
+
The estimator uses leave-one-out cross-validation for tuning parameter
|
|
11
|
+
selection and provides robust treatment effect estimates under factor
|
|
12
|
+
confounding.
|
|
13
|
+
|
|
14
|
+
References
|
|
15
|
+
----------
|
|
16
|
+
Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
|
|
17
|
+
Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import logging
|
|
21
|
+
import warnings
|
|
22
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
from diff_diff._backend import (
|
|
30
|
+
HAS_RUST_BACKEND,
|
|
31
|
+
_rust_loocv_grid_search,
|
|
32
|
+
)
|
|
33
|
+
from diff_diff.trop_global import TROPGlobalMixin
|
|
34
|
+
from diff_diff.trop_local import TROPLocalMixin, _validate_and_pivot_treatment
|
|
35
|
+
from diff_diff.trop_results import (
|
|
36
|
+
_LAMBDA_INF,
|
|
37
|
+
_PrecomputedStructures,
|
|
38
|
+
TROPResults,
|
|
39
|
+
)
|
|
40
|
+
from diff_diff.utils import safe_inference
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TROP(TROPLocalMixin, TROPGlobalMixin):
|
|
44
|
+
"""
|
|
45
|
+
Triply Robust Panel (TROP) estimator.
|
|
46
|
+
|
|
47
|
+
Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025).
|
|
48
|
+
TROP combines three robustness components:
|
|
49
|
+
|
|
50
|
+
1. **Nuclear norm regularized factor model**: Estimates interactive fixed
|
|
51
|
+
effects L_it via matrix completion with nuclear norm penalty ||L||_*
|
|
52
|
+
|
|
53
|
+
2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
|
|
54
|
+
where d(j,i) is the RMSE of outcome differences between units
|
|
55
|
+
|
|
56
|
+
3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
|
|
57
|
+
weighting pre-treatment periods by proximity to treatment
|
|
58
|
+
|
|
59
|
+
Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
|
|
60
|
+
cross-validation on control observations.
|
|
61
|
+
|
|
62
|
+
Parameters
|
|
63
|
+
----------
|
|
64
|
+
method : str, default='local'
|
|
65
|
+
Estimation method to use:
|
|
66
|
+
|
|
67
|
+
- 'local': Per-observation model fitting following Algorithm 2 of
|
|
68
|
+
Athey et al. (2025). Computes observation-specific weights and fits
|
|
69
|
+
a model for each treated observation, averaging the individual
|
|
70
|
+
treatment effects. More flexible but computationally intensive.
|
|
71
|
+
|
|
72
|
+
- 'global': Computationally efficient adaptation using the (1-W)
|
|
73
|
+
masking principle from Eq. 2. Fits a single model on control
|
|
74
|
+
observations with global weights, then computes per-observation
|
|
75
|
+
treatment effects as residuals:
|
|
76
|
+
tau_it = Y_it - mu - alpha_i - beta_t - L_it for treated cells.
|
|
77
|
+
ATT is the mean of these effects. For the paper's full
|
|
78
|
+
per-treated-cell estimator, use ``method='local'``.
|
|
79
|
+
|
|
80
|
+
lambda_time_grid : list, optional
|
|
81
|
+
Grid of time weight decay parameters. 0.0 = uniform weights (disabled).
|
|
82
|
+
Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
|
|
83
|
+
lambda_unit_grid : list, optional
|
|
84
|
+
Grid of unit weight decay parameters. 0.0 = uniform weights (disabled).
|
|
85
|
+
Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
|
|
86
|
+
lambda_nn_grid : list, optional
|
|
87
|
+
Grid of nuclear norm regularization parameters. inf = factor model
|
|
88
|
+
disabled (L=0). Default: [0, 0.01, 0.1, 1, 10].
|
|
89
|
+
max_iter : int, default=100
|
|
90
|
+
Maximum iterations for nuclear norm optimization.
|
|
91
|
+
tol : float, default=1e-6
|
|
92
|
+
Convergence tolerance for optimization.
|
|
93
|
+
alpha : float, default=0.05
|
|
94
|
+
Significance level for confidence intervals.
|
|
95
|
+
n_bootstrap : int, default=200
|
|
96
|
+
Number of bootstrap replications for variance estimation. Must be >= 2.
|
|
97
|
+
seed : int, optional
|
|
98
|
+
Random seed for reproducibility.
|
|
99
|
+
|
|
100
|
+
Attributes
|
|
101
|
+
----------
|
|
102
|
+
results_ : TROPResults
|
|
103
|
+
Estimation results after calling fit().
|
|
104
|
+
is_fitted_ : bool
|
|
105
|
+
Whether the model has been fitted.
|
|
106
|
+
|
|
107
|
+
Examples
|
|
108
|
+
--------
|
|
109
|
+
>>> from diff_diff import TROP
|
|
110
|
+
>>> trop = TROP()
|
|
111
|
+
>>> results = trop.fit(
|
|
112
|
+
... data,
|
|
113
|
+
... outcome='outcome',
|
|
114
|
+
... treatment='treated',
|
|
115
|
+
... unit='unit',
|
|
116
|
+
... time='period',
|
|
117
|
+
... )
|
|
118
|
+
>>> results.print_summary()
|
|
119
|
+
|
|
120
|
+
References
|
|
121
|
+
----------
|
|
122
|
+
Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
|
|
123
|
+
Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
def __init__(
|
|
127
|
+
self,
|
|
128
|
+
method: str = "local",
|
|
129
|
+
lambda_time_grid: Optional[List[float]] = None,
|
|
130
|
+
lambda_unit_grid: Optional[List[float]] = None,
|
|
131
|
+
lambda_nn_grid: Optional[List[float]] = None,
|
|
132
|
+
max_iter: int = 100,
|
|
133
|
+
tol: float = 1e-6,
|
|
134
|
+
alpha: float = 0.05,
|
|
135
|
+
n_bootstrap: int = 200,
|
|
136
|
+
seed: Optional[int] = None,
|
|
137
|
+
):
|
|
138
|
+
# Validate method parameter
|
|
139
|
+
valid_methods = ("local", "global")
|
|
140
|
+
if method not in valid_methods:
|
|
141
|
+
raise ValueError(f"method must be one of {valid_methods}, got '{method}'")
|
|
142
|
+
self.method = method
|
|
143
|
+
|
|
144
|
+
# Default grids from paper
|
|
145
|
+
self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
|
|
146
|
+
self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
|
|
147
|
+
self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
|
|
148
|
+
|
|
149
|
+
if n_bootstrap < 2:
|
|
150
|
+
raise ValueError(
|
|
151
|
+
"n_bootstrap must be >= 2 for TROP (bootstrap variance "
|
|
152
|
+
"estimation is always used)"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
self.max_iter = max_iter
|
|
156
|
+
self.tol = tol
|
|
157
|
+
self.alpha = alpha
|
|
158
|
+
self.n_bootstrap = n_bootstrap
|
|
159
|
+
self.seed = seed
|
|
160
|
+
|
|
161
|
+
# Validate that time/unit grids do not contain inf.
|
|
162
|
+
# Per Athey et al. (2025) Eq. 3, λ_time=0 and λ_unit=0 give uniform
|
|
163
|
+
# weights (exp(-0 × dist) = 1). Using inf is a misunderstanding of
|
|
164
|
+
# the paper's convention. Only λ_nn=∞ is valid (disables factor model).
|
|
165
|
+
for grid_name, grid_vals in [
|
|
166
|
+
("lambda_time_grid", self.lambda_time_grid),
|
|
167
|
+
("lambda_unit_grid", self.lambda_unit_grid),
|
|
168
|
+
]:
|
|
169
|
+
if any(np.isinf(v) for v in grid_vals):
|
|
170
|
+
raise ValueError(
|
|
171
|
+
f"{grid_name} must not contain inf. Use 0.0 for uniform "
|
|
172
|
+
f"weights (disabled) per Athey et al. (2025) Eq. 3: "
|
|
173
|
+
f"exp(-0 × dist) = 1 for all distances."
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
# Internal state
|
|
177
|
+
self.results_: Optional[TROPResults] = None
|
|
178
|
+
self.is_fitted_: bool = False
|
|
179
|
+
self._optimal_lambda: Optional[Tuple[float, float, float]] = None
|
|
180
|
+
|
|
181
|
+
# Pre-computed structures (set during fit)
|
|
182
|
+
self._precomputed: Optional[_PrecomputedStructures] = None
|
|
183
|
+
|
|
184
|
+
# =========================================================================
|
|
185
|
+
# Parameter search (used by local method's fit() path)
|
|
186
|
+
# =========================================================================
|
|
187
|
+
|
|
188
|
+
def _univariate_loocv_search(
|
|
189
|
+
self,
|
|
190
|
+
Y: np.ndarray,
|
|
191
|
+
D: np.ndarray,
|
|
192
|
+
control_mask: np.ndarray,
|
|
193
|
+
control_unit_idx: np.ndarray,
|
|
194
|
+
n_units: int,
|
|
195
|
+
n_periods: int,
|
|
196
|
+
param_name: str,
|
|
197
|
+
grid: List[float],
|
|
198
|
+
fixed_params: Dict[str, float],
|
|
199
|
+
) -> Tuple[float, float]:
|
|
200
|
+
"""
|
|
201
|
+
Search over one parameter with others fixed.
|
|
202
|
+
|
|
203
|
+
Following paper's footnote 2, this performs a univariate grid search
|
|
204
|
+
for one tuning parameter while holding others fixed. The fixed_params
|
|
205
|
+
use 0.0 for disabled time/unit weights and _LAMBDA_INF for disabled
|
|
206
|
+
factor model:
|
|
207
|
+
- lambda_nn = inf: Skip nuclear norm regularization (L=0)
|
|
208
|
+
- lambda_time = 0.0: Uniform time weights (exp(-0×dist)=1)
|
|
209
|
+
- lambda_unit = 0.0: Uniform unit weights (exp(-0×dist)=1)
|
|
210
|
+
|
|
211
|
+
Parameters
|
|
212
|
+
----------
|
|
213
|
+
Y : np.ndarray
|
|
214
|
+
Outcome matrix (n_periods x n_units).
|
|
215
|
+
D : np.ndarray
|
|
216
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
217
|
+
control_mask : np.ndarray
|
|
218
|
+
Boolean mask for control observations.
|
|
219
|
+
control_unit_idx : np.ndarray
|
|
220
|
+
Indices of control units.
|
|
221
|
+
n_units : int
|
|
222
|
+
Number of units.
|
|
223
|
+
n_periods : int
|
|
224
|
+
Number of periods.
|
|
225
|
+
param_name : str
|
|
226
|
+
Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
|
|
227
|
+
grid : List[float]
|
|
228
|
+
Grid of values to search over.
|
|
229
|
+
fixed_params : Dict[str, float]
|
|
230
|
+
Fixed values for other parameters. May include _LAMBDA_INF for lambda_nn.
|
|
231
|
+
|
|
232
|
+
Returns
|
|
233
|
+
-------
|
|
234
|
+
Tuple[float, float]
|
|
235
|
+
(best_value, best_score) for the searched parameter.
|
|
236
|
+
"""
|
|
237
|
+
best_score = np.inf
|
|
238
|
+
best_value = grid[0] if grid else 0.0
|
|
239
|
+
|
|
240
|
+
for value in grid:
|
|
241
|
+
params = {**fixed_params, param_name: value}
|
|
242
|
+
|
|
243
|
+
lambda_time = params.get("lambda_time", 0.0)
|
|
244
|
+
lambda_unit = params.get("lambda_unit", 0.0)
|
|
245
|
+
lambda_nn = params.get("lambda_nn", 0.0)
|
|
246
|
+
|
|
247
|
+
# Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
|
|
248
|
+
# λ_time and λ_unit use 0.0 for uniform weights per Eq. 3 (no inf conversion needed)
|
|
249
|
+
if np.isinf(lambda_nn):
|
|
250
|
+
lambda_nn = 1e10
|
|
251
|
+
|
|
252
|
+
try:
|
|
253
|
+
score = self._loocv_score_obs_specific(
|
|
254
|
+
Y,
|
|
255
|
+
D,
|
|
256
|
+
control_mask,
|
|
257
|
+
control_unit_idx,
|
|
258
|
+
lambda_time,
|
|
259
|
+
lambda_unit,
|
|
260
|
+
lambda_nn,
|
|
261
|
+
n_units,
|
|
262
|
+
n_periods,
|
|
263
|
+
)
|
|
264
|
+
if score < best_score:
|
|
265
|
+
best_score = score
|
|
266
|
+
best_value = value
|
|
267
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
268
|
+
continue
|
|
269
|
+
|
|
270
|
+
return best_value, best_score
|
|
271
|
+
|
|
272
|
+
def _cycling_parameter_search(
|
|
273
|
+
self,
|
|
274
|
+
Y: np.ndarray,
|
|
275
|
+
D: np.ndarray,
|
|
276
|
+
control_mask: np.ndarray,
|
|
277
|
+
control_unit_idx: np.ndarray,
|
|
278
|
+
n_units: int,
|
|
279
|
+
n_periods: int,
|
|
280
|
+
initial_lambda: Tuple[float, float, float],
|
|
281
|
+
max_cycles: int = 10,
|
|
282
|
+
) -> Tuple[float, float, float]:
|
|
283
|
+
"""
|
|
284
|
+
Cycle through parameters until convergence (coordinate descent).
|
|
285
|
+
|
|
286
|
+
Following paper's footnote 2 (Stage 2), this iteratively optimizes
|
|
287
|
+
each tuning parameter while holding the others fixed, until convergence.
|
|
288
|
+
|
|
289
|
+
Parameters
|
|
290
|
+
----------
|
|
291
|
+
Y : np.ndarray
|
|
292
|
+
Outcome matrix (n_periods x n_units).
|
|
293
|
+
D : np.ndarray
|
|
294
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
295
|
+
control_mask : np.ndarray
|
|
296
|
+
Boolean mask for control observations.
|
|
297
|
+
control_unit_idx : np.ndarray
|
|
298
|
+
Indices of control units.
|
|
299
|
+
n_units : int
|
|
300
|
+
Number of units.
|
|
301
|
+
n_periods : int
|
|
302
|
+
Number of periods.
|
|
303
|
+
initial_lambda : Tuple[float, float, float]
|
|
304
|
+
Initial values (lambda_time, lambda_unit, lambda_nn).
|
|
305
|
+
max_cycles : int, default=10
|
|
306
|
+
Maximum number of coordinate descent cycles.
|
|
307
|
+
|
|
308
|
+
Returns
|
|
309
|
+
-------
|
|
310
|
+
Tuple[float, float, float]
|
|
311
|
+
Optimized (lambda_time, lambda_unit, lambda_nn).
|
|
312
|
+
"""
|
|
313
|
+
lambda_time, lambda_unit, lambda_nn = initial_lambda
|
|
314
|
+
prev_score = np.inf
|
|
315
|
+
|
|
316
|
+
for cycle in range(max_cycles):
|
|
317
|
+
# Optimize λ_unit (fix λ_time, λ_nn)
|
|
318
|
+
lambda_unit, _ = self._univariate_loocv_search(
|
|
319
|
+
Y,
|
|
320
|
+
D,
|
|
321
|
+
control_mask,
|
|
322
|
+
control_unit_idx,
|
|
323
|
+
n_units,
|
|
324
|
+
n_periods,
|
|
325
|
+
"lambda_unit",
|
|
326
|
+
self.lambda_unit_grid,
|
|
327
|
+
{"lambda_time": lambda_time, "lambda_nn": lambda_nn},
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
# Optimize λ_time (fix λ_unit, λ_nn)
|
|
331
|
+
lambda_time, _ = self._univariate_loocv_search(
|
|
332
|
+
Y,
|
|
333
|
+
D,
|
|
334
|
+
control_mask,
|
|
335
|
+
control_unit_idx,
|
|
336
|
+
n_units,
|
|
337
|
+
n_periods,
|
|
338
|
+
"lambda_time",
|
|
339
|
+
self.lambda_time_grid,
|
|
340
|
+
{"lambda_unit": lambda_unit, "lambda_nn": lambda_nn},
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
# Optimize λ_nn (fix λ_unit, λ_time)
|
|
344
|
+
lambda_nn, score = self._univariate_loocv_search(
|
|
345
|
+
Y,
|
|
346
|
+
D,
|
|
347
|
+
control_mask,
|
|
348
|
+
control_unit_idx,
|
|
349
|
+
n_units,
|
|
350
|
+
n_periods,
|
|
351
|
+
"lambda_nn",
|
|
352
|
+
self.lambda_nn_grid,
|
|
353
|
+
{"lambda_unit": lambda_unit, "lambda_time": lambda_time},
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
# Check convergence
|
|
357
|
+
if abs(score - prev_score) < 1e-6:
|
|
358
|
+
logger.debug(
|
|
359
|
+
"Cycling search converged after %d cycles with score %.6f", cycle + 1, score
|
|
360
|
+
)
|
|
361
|
+
break
|
|
362
|
+
prev_score = score
|
|
363
|
+
|
|
364
|
+
return lambda_time, lambda_unit, lambda_nn
|
|
365
|
+
|
|
366
|
+
# =========================================================================
|
|
367
|
+
# Main fit method
|
|
368
|
+
# =========================================================================
|
|
369
|
+
|
|
370
|
+
def fit(
|
|
371
|
+
self,
|
|
372
|
+
data: pd.DataFrame,
|
|
373
|
+
outcome: str,
|
|
374
|
+
treatment: str,
|
|
375
|
+
unit: str,
|
|
376
|
+
time: str,
|
|
377
|
+
survey_design=None,
|
|
378
|
+
) -> TROPResults:
|
|
379
|
+
"""
|
|
380
|
+
Fit the TROP model.
|
|
381
|
+
|
|
382
|
+
Parameters
|
|
383
|
+
----------
|
|
384
|
+
data : pd.DataFrame
|
|
385
|
+
Panel data with observations for multiple units over multiple
|
|
386
|
+
time periods.
|
|
387
|
+
outcome : str
|
|
388
|
+
Name of the outcome variable column.
|
|
389
|
+
treatment : str
|
|
390
|
+
Name of the treatment indicator column (0/1).
|
|
391
|
+
|
|
392
|
+
IMPORTANT: This should be an ABSORBING STATE indicator, not a
|
|
393
|
+
treatment timing indicator. For each unit, D=1 for ALL periods
|
|
394
|
+
during and after treatment:
|
|
395
|
+
|
|
396
|
+
- D[t, i] = 0 for all t < g_i (pre-treatment periods)
|
|
397
|
+
- D[t, i] = 1 for all t >= g_i (treatment and post-treatment)
|
|
398
|
+
|
|
399
|
+
where g_i is the treatment start time for unit i.
|
|
400
|
+
|
|
401
|
+
For staggered adoption, different units can have different g_i.
|
|
402
|
+
The ATT averages over ALL D=1 cells per Equation 1 of the paper.
|
|
403
|
+
unit : str
|
|
404
|
+
Name of the unit identifier column.
|
|
405
|
+
time : str
|
|
406
|
+
Name of the time period column.
|
|
407
|
+
survey_design : SurveyDesign, optional
|
|
408
|
+
Survey design specification. Supports pweight, strata, PSU, and
|
|
409
|
+
FPC. Full-design surveys (strata/PSU/FPC) use Rao-Wu rescaled
|
|
410
|
+
bootstrap; Rust backend is pweight-only (Python fallback for
|
|
411
|
+
full design). Survey weights enter ATT aggregation only.
|
|
412
|
+
|
|
413
|
+
Returns
|
|
414
|
+
-------
|
|
415
|
+
TROPResults
|
|
416
|
+
Object containing the ATT estimate, standard error,
|
|
417
|
+
factor estimates, and tuning parameters. The lambda_*
|
|
418
|
+
attributes show the selected grid values. For lambda_time and
|
|
419
|
+
lambda_unit, 0.0 means uniform weights; inf is not accepted.
|
|
420
|
+
For lambda_nn, inf is converted to 1e10 (factor model disabled).
|
|
421
|
+
|
|
422
|
+
Raises
|
|
423
|
+
------
|
|
424
|
+
ValueError
|
|
425
|
+
If required columns are missing or the survey design does not use pweight weights.
|
|
426
|
+
"""
|
|
427
|
+
# Validate inputs
|
|
428
|
+
required_cols = [outcome, treatment, unit, time]
|
|
429
|
+
missing = [c for c in required_cols if c not in data.columns]
|
|
430
|
+
if missing:
|
|
431
|
+
raise ValueError(f"Missing columns: {missing}")
|
|
432
|
+
|
|
433
|
+
# Resolve survey design
|
|
434
|
+
from diff_diff.survey import (
|
|
435
|
+
_extract_unit_survey_weights,
|
|
436
|
+
_resolve_survey_for_fit,
|
|
437
|
+
_validate_unit_constant_survey,
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
resolved_survey, _survey_weights, _survey_wt, survey_metadata = _resolve_survey_for_fit(
|
|
441
|
+
survey_design, data, "analytical"
|
|
442
|
+
)
|
|
443
|
+
# Reject replicate-weight designs — TROP uses Rao-Wu bootstrap
|
|
444
|
+
if resolved_survey is not None and resolved_survey.uses_replicate_variance:
|
|
445
|
+
raise NotImplementedError(
|
|
446
|
+
"TROP does not yet support replicate-weight survey designs. "
|
|
447
|
+
"Use a TSL-based survey design (strata/psu/fpc)."
|
|
448
|
+
)
|
|
449
|
+
# Validate weight_type is pweight (keep restriction), but allow
|
|
450
|
+
# strata/PSU/FPC — those are handled via Rao-Wu rescaled bootstrap.
|
|
451
|
+
if resolved_survey is not None and resolved_survey.weight_type != "pweight":
|
|
452
|
+
raise ValueError(
|
|
453
|
+
"TROP requires pweight survey weights. "
|
|
454
|
+
f"Got weight_type='{resolved_survey.weight_type}'."
|
|
455
|
+
)
|
|
456
|
+
if resolved_survey is not None:
|
|
457
|
+
_validate_unit_constant_survey(data, unit, survey_design)
|
|
458
|
+
|
|
459
|
+
# Dispatch based on estimation method
|
|
460
|
+
if self.method == "global":
|
|
461
|
+
return self._fit_global(
|
|
462
|
+
data,
|
|
463
|
+
outcome,
|
|
464
|
+
treatment,
|
|
465
|
+
unit,
|
|
466
|
+
time,
|
|
467
|
+
resolved_survey=resolved_survey,
|
|
468
|
+
survey_metadata=survey_metadata,
|
|
469
|
+
survey_design=survey_design,
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
# Below is the local method (default)
|
|
473
|
+
# Get unique units and periods
|
|
474
|
+
all_units = sorted(data[unit].unique())
|
|
475
|
+
|
|
476
|
+
# Extract unit-level survey weights
|
|
477
|
+
if resolved_survey is not None:
|
|
478
|
+
unit_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)
|
|
479
|
+
else:
|
|
480
|
+
unit_weight_arr = None
|
|
481
|
+
all_periods = sorted(data[time].unique())
|
|
482
|
+
|
|
483
|
+
n_units = len(all_units)
|
|
484
|
+
n_periods = len(all_periods)
|
|
485
|
+
|
|
486
|
+
# Create mappings
|
|
487
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
488
|
+
period_to_idx = {p: i for i, p in enumerate(all_periods)}
|
|
489
|
+
idx_to_unit = {i: u for u, i in unit_to_idx.items()}
|
|
490
|
+
idx_to_period = {i: p for p, i in period_to_idx.items()}
|
|
491
|
+
|
|
492
|
+
# Create outcome matrix Y (n_periods x n_units) and treatment matrix D
|
|
493
|
+
# Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
|
|
494
|
+
Y = (
|
|
495
|
+
data.pivot(index=time, columns=unit, values=outcome)
|
|
496
|
+
.reindex(index=all_periods, columns=all_units)
|
|
497
|
+
.values
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
# For D matrix, validate observed treatment and handle unbalanced panels
|
|
501
|
+
D, missing_mask = _validate_and_pivot_treatment(
|
|
502
|
+
data, time, unit, treatment, all_periods, all_units
|
|
503
|
+
)
|
|
504
|
+
|
|
505
|
+
# Validate D is monotonic non-decreasing per unit (absorbing state)
|
|
506
|
+
# D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
|
|
507
|
+
# Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
|
|
508
|
+
# This catches 1->0 violations that span missing period gaps
|
|
509
|
+
# Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
|
|
510
|
+
# adjacent period transitions don't show it (the gap hides the transition)
|
|
511
|
+
violating_units = []
|
|
512
|
+
for unit_idx in range(n_units):
|
|
513
|
+
# Get observed D values for this unit (where not missing)
|
|
514
|
+
observed_mask = ~missing_mask[:, unit_idx]
|
|
515
|
+
observed_d = D[observed_mask, unit_idx]
|
|
516
|
+
|
|
517
|
+
# Check if observed sequence is monotonically non-decreasing
|
|
518
|
+
if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
|
|
519
|
+
violating_units.append(all_units[unit_idx])
|
|
520
|
+
|
|
521
|
+
if violating_units:
|
|
522
|
+
raise ValueError(
|
|
523
|
+
f"Treatment indicator is not an absorbing state for units: {violating_units}. "
|
|
524
|
+
f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
|
|
525
|
+
f"If this is event-study style data, convert to absorbing state: "
|
|
526
|
+
f"D[t, i] = 1 for all t >= first treatment period."
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
# Identify treated observations
|
|
530
|
+
treated_mask = D == 1
|
|
531
|
+
n_treated_obs = np.sum(treated_mask)
|
|
532
|
+
|
|
533
|
+
if n_treated_obs == 0:
|
|
534
|
+
raise ValueError("No treated observations found")
|
|
535
|
+
|
|
536
|
+
# Identify treated and control units
|
|
537
|
+
unit_ever_treated = np.any(D == 1, axis=0)
|
|
538
|
+
treated_unit_idx = np.where(unit_ever_treated)[0]
|
|
539
|
+
control_unit_idx = np.where(~unit_ever_treated)[0]
|
|
540
|
+
|
|
541
|
+
if len(control_unit_idx) == 0:
|
|
542
|
+
raise ValueError("No control units found")
|
|
543
|
+
|
|
544
|
+
# Determine pre/post periods from treatment indicator D
|
|
545
|
+
# D matrix is the sole input for treatment timing per the paper
|
|
546
|
+
first_treat_period = None
|
|
547
|
+
for t in range(n_periods):
|
|
548
|
+
if np.any(D[t, :] == 1):
|
|
549
|
+
first_treat_period = t
|
|
550
|
+
break
|
|
551
|
+
if first_treat_period is None:
|
|
552
|
+
raise ValueError("Could not infer post-treatment periods from D matrix")
|
|
553
|
+
|
|
554
|
+
n_pre_periods = first_treat_period
|
|
555
|
+
# Count periods where D=1 is actually observed (matches docstring)
|
|
556
|
+
# Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
|
|
557
|
+
n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
|
|
558
|
+
|
|
559
|
+
if n_pre_periods < 2:
|
|
560
|
+
raise ValueError("Need at least 2 pre-treatment periods")
|
|
561
|
+
|
|
562
|
+
# Step 1: Grid search with LOOCV for tuning parameters
|
|
563
|
+
best_lambda = None
|
|
564
|
+
best_score = np.inf
|
|
565
|
+
|
|
566
|
+
# Control observations mask (for LOOCV)
|
|
567
|
+
control_mask = D == 0
|
|
568
|
+
|
|
569
|
+
# Pre-compute structures that are reused across LOOCV iterations
|
|
570
|
+
self._precomputed = self._precompute_structures(Y, D, control_unit_idx, n_units, n_periods)
|
|
571
|
+
|
|
572
|
+
# Use Rust backend for parallel LOOCV grid search (10-50x speedup)
|
|
573
|
+
if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None:
|
|
574
|
+
try:
|
|
575
|
+
# Prepare inputs for Rust function
|
|
576
|
+
control_mask_u8 = control_mask.astype(np.uint8)
|
|
577
|
+
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
|
|
578
|
+
|
|
579
|
+
lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
|
|
580
|
+
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
|
|
581
|
+
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
|
|
582
|
+
|
|
583
|
+
result = _rust_loocv_grid_search(
|
|
584
|
+
Y,
|
|
585
|
+
D.astype(np.float64),
|
|
586
|
+
control_mask_u8,
|
|
587
|
+
time_dist_matrix,
|
|
588
|
+
lambda_time_arr,
|
|
589
|
+
lambda_unit_arr,
|
|
590
|
+
lambda_nn_arr,
|
|
591
|
+
self.max_iter,
|
|
592
|
+
self.tol,
|
|
593
|
+
)
|
|
594
|
+
# Unpack result - 7 values including optional first_failed_obs
|
|
595
|
+
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = (
|
|
596
|
+
result
|
|
597
|
+
)
|
|
598
|
+
# Only accept finite scores - infinite means all fits failed
|
|
599
|
+
if np.isfinite(best_score):
|
|
600
|
+
best_lambda = (best_lt, best_lu, best_ln)
|
|
601
|
+
# else: best_lambda stays None, triggering defaults fallback
|
|
602
|
+
# Emit warnings consistent with Python implementation
|
|
603
|
+
if n_valid == 0:
|
|
604
|
+
# Include failed observation coordinates if available (Issue 2 fix)
|
|
605
|
+
obs_info = ""
|
|
606
|
+
if first_failed_obs is not None:
|
|
607
|
+
t_idx, i_idx = first_failed_obs
|
|
608
|
+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
|
|
609
|
+
warnings.warn(
|
|
610
|
+
f"LOOCV: All {n_attempted} fits failed for "
|
|
611
|
+
f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
|
|
612
|
+
f"Returning infinite score.{obs_info}",
|
|
613
|
+
UserWarning,
|
|
614
|
+
)
|
|
615
|
+
elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
|
|
616
|
+
n_failed = n_attempted - n_valid
|
|
617
|
+
# Include failed observation coordinates if available
|
|
618
|
+
obs_info = ""
|
|
619
|
+
if first_failed_obs is not None:
|
|
620
|
+
t_idx, i_idx = first_failed_obs
|
|
621
|
+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
|
|
622
|
+
warnings.warn(
|
|
623
|
+
f"LOOCV: {n_failed}/{n_attempted} fits failed for "
|
|
624
|
+
f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
|
|
625
|
+
f"This may indicate numerical instability.{obs_info}",
|
|
626
|
+
UserWarning,
|
|
627
|
+
)
|
|
628
|
+
except Exception as e:
|
|
629
|
+
# Fall back to Python implementation on error
|
|
630
|
+
logger.debug("Rust LOOCV grid search failed, falling back to Python: %s", e)
|
|
631
|
+
warnings.warn(
|
|
632
|
+
f"Rust backend failed for LOOCV grid search; "
|
|
633
|
+
f"falling back to Python. Performance may be reduced. "
|
|
634
|
+
f"Error: {e}",
|
|
635
|
+
UserWarning,
|
|
636
|
+
stacklevel=2,
|
|
637
|
+
)
|
|
638
|
+
best_lambda = None
|
|
639
|
+
best_score = np.inf
|
|
640
|
+
|
|
641
|
+
# Fall back to Python implementation if Rust unavailable or failed
|
|
642
|
+
# Uses two-stage approach per paper's footnote 2:
|
|
643
|
+
# Stage 1: Univariate searches for initial values
|
|
644
|
+
# Stage 2: Cycling (coordinate descent) until convergence
|
|
645
|
+
if best_lambda is None:
|
|
646
|
+
# Stage 1: Univariate searches with extreme fixed values
|
|
647
|
+
# Following paper's footnote 2 for initial bounds
|
|
648
|
+
|
|
649
|
+
# λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
|
|
650
|
+
lambda_time_init, _ = self._univariate_loocv_search(
|
|
651
|
+
Y,
|
|
652
|
+
D,
|
|
653
|
+
control_mask,
|
|
654
|
+
control_unit_idx,
|
|
655
|
+
n_units,
|
|
656
|
+
n_periods,
|
|
657
|
+
"lambda_time",
|
|
658
|
+
self.lambda_time_grid,
|
|
659
|
+
{"lambda_unit": 0.0, "lambda_nn": _LAMBDA_INF},
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
# λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0
|
|
663
|
+
lambda_nn_init, _ = self._univariate_loocv_search(
|
|
664
|
+
Y,
|
|
665
|
+
D,
|
|
666
|
+
control_mask,
|
|
667
|
+
control_unit_idx,
|
|
668
|
+
n_units,
|
|
669
|
+
n_periods,
|
|
670
|
+
"lambda_nn",
|
|
671
|
+
self.lambda_nn_grid,
|
|
672
|
+
{"lambda_time": 0.0, "lambda_unit": 0.0},
|
|
673
|
+
)
|
|
674
|
+
|
|
675
|
+
# λ_unit search: fix λ_nn=∞, λ_time=0
|
|
676
|
+
lambda_unit_init, _ = self._univariate_loocv_search(
|
|
677
|
+
Y,
|
|
678
|
+
D,
|
|
679
|
+
control_mask,
|
|
680
|
+
control_unit_idx,
|
|
681
|
+
n_units,
|
|
682
|
+
n_periods,
|
|
683
|
+
"lambda_unit",
|
|
684
|
+
self.lambda_unit_grid,
|
|
685
|
+
{"lambda_nn": _LAMBDA_INF, "lambda_time": 0.0},
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
# Stage 2: Cycling refinement (coordinate descent)
|
|
689
|
+
lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
|
|
690
|
+
Y,
|
|
691
|
+
D,
|
|
692
|
+
control_mask,
|
|
693
|
+
control_unit_idx,
|
|
694
|
+
n_units,
|
|
695
|
+
n_periods,
|
|
696
|
+
(lambda_time_init, lambda_unit_init, lambda_nn_init),
|
|
697
|
+
)
|
|
698
|
+
|
|
699
|
+
# Compute final score for the optimized parameters
|
|
700
|
+
try:
|
|
701
|
+
best_score = self._loocv_score_obs_specific(
|
|
702
|
+
Y,
|
|
703
|
+
D,
|
|
704
|
+
control_mask,
|
|
705
|
+
control_unit_idx,
|
|
706
|
+
lambda_time,
|
|
707
|
+
lambda_unit,
|
|
708
|
+
lambda_nn,
|
|
709
|
+
n_units,
|
|
710
|
+
n_periods,
|
|
711
|
+
)
|
|
712
|
+
# Only accept finite scores - infinite means all fits failed
|
|
713
|
+
if np.isfinite(best_score):
|
|
714
|
+
best_lambda = (lambda_time, lambda_unit, lambda_nn)
|
|
715
|
+
# else: best_lambda stays None, triggering defaults fallback
|
|
716
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
717
|
+
# If even the optimized parameters fail, best_lambda stays None
|
|
718
|
+
pass
|
|
719
|
+
|
|
720
|
+
if best_lambda is None:
|
|
721
|
+
warnings.warn("All tuning parameter combinations failed. Using defaults.", UserWarning)
|
|
722
|
+
best_lambda = (1.0, 1.0, 0.1)
|
|
723
|
+
best_score = np.nan
|
|
724
|
+
|
|
725
|
+
self._optimal_lambda = best_lambda
|
|
726
|
+
lambda_time, lambda_unit, lambda_nn = best_lambda
|
|
727
|
+
|
|
728
|
+
# Store original λ_nn for results (only λ_nn needs original→effective conversion).
|
|
729
|
+
# λ_time and λ_unit use 0.0 for uniform weights directly per Eq. 3.
|
|
730
|
+
original_lambda_nn = lambda_nn
|
|
731
|
+
|
|
732
|
+
# Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
|
|
733
|
+
if np.isinf(lambda_nn):
|
|
734
|
+
lambda_nn = 1e10
|
|
735
|
+
|
|
736
|
+
# effective_lambda with converted λ_nn for ALL downstream computation
|
|
737
|
+
# (variance estimation uses the same parameters as point estimation)
|
|
738
|
+
effective_lambda = (lambda_time, lambda_unit, lambda_nn)
|
|
739
|
+
|
|
740
|
+
# Step 2: Final estimation - per-observation model fitting following Algorithm 2
|
|
741
|
+
# For each treated (i,t): compute observation-specific weights, fit model, compute tau_{it}
|
|
742
|
+
treatment_effects = {}
|
|
743
|
+
tau_values = []
|
|
744
|
+
tau_weights = [] # parallel to tau_values for survey-weighted ATT
|
|
745
|
+
alpha_estimates = []
|
|
746
|
+
beta_estimates = []
|
|
747
|
+
L_estimates = []
|
|
748
|
+
|
|
749
|
+
# Use pre-computed treated observations
|
|
750
|
+
treated_observations = self._precomputed["treated_observations"]
|
|
751
|
+
|
|
752
|
+
for t, i in treated_observations:
|
|
753
|
+
unit_id = idx_to_unit[i]
|
|
754
|
+
time_id = idx_to_period[t]
|
|
755
|
+
|
|
756
|
+
# Skip observations where outcome is missing -- record NaN but
|
|
757
|
+
# don't fit the model or include in tau_values (avoids NaN poisoning)
|
|
758
|
+
if not np.isfinite(Y[t, i]):
|
|
759
|
+
treatment_effects[(unit_id, time_id)] = np.nan
|
|
760
|
+
continue
|
|
761
|
+
|
|
762
|
+
# Compute observation-specific weights for this (i, t)
|
|
763
|
+
weight_matrix = self._compute_observation_weights(
|
|
764
|
+
Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
# Fit model with these weights
|
|
768
|
+
alpha_hat, beta_hat, L_hat = self._estimate_model(
|
|
769
|
+
Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods
|
|
770
|
+
)
|
|
771
|
+
|
|
772
|
+
# Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}
|
|
773
|
+
tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]
|
|
774
|
+
|
|
775
|
+
treatment_effects[(unit_id, time_id)] = tau_it
|
|
776
|
+
tau_values.append(tau_it)
|
|
777
|
+
if unit_weight_arr is not None:
|
|
778
|
+
tau_weights.append(unit_weight_arr[i])
|
|
779
|
+
|
|
780
|
+
# Store for averaging
|
|
781
|
+
alpha_estimates.append(alpha_hat)
|
|
782
|
+
beta_estimates.append(beta_hat)
|
|
783
|
+
L_estimates.append(L_hat)
|
|
784
|
+
|
|
785
|
+
# Count valid treated observations
|
|
786
|
+
n_valid_treated = len(tau_values)
|
|
787
|
+
if n_valid_treated == 0:
|
|
788
|
+
warnings.warn(
|
|
789
|
+
"All treated outcomes are NaN/missing. Cannot estimate ATT.",
|
|
790
|
+
UserWarning,
|
|
791
|
+
)
|
|
792
|
+
elif n_valid_treated < n_treated_obs:
|
|
793
|
+
warnings.warn(
|
|
794
|
+
f"Only {n_valid_treated} of {n_treated_obs} treated outcomes are finite. "
|
|
795
|
+
"df and n_treated_obs reflect valid observations only.",
|
|
796
|
+
UserWarning,
|
|
797
|
+
)
|
|
798
|
+
|
|
799
|
+
# Average ATT (survey-weighted when applicable)
|
|
800
|
+
if unit_weight_arr is not None and tau_values:
|
|
801
|
+
att = float(np.average(tau_values, weights=tau_weights))
|
|
802
|
+
else:
|
|
803
|
+
att = np.mean(tau_values) if tau_values else np.nan
|
|
804
|
+
|
|
805
|
+
# Average parameter estimates for output (representative)
|
|
806
|
+
alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
|
|
807
|
+
beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
|
|
808
|
+
L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))
|
|
809
|
+
|
|
810
|
+
# Compute effective rank
|
|
811
|
+
_, s, _ = np.linalg.svd(L_hat, full_matrices=False)
|
|
812
|
+
if s[0] > 0:
|
|
813
|
+
effective_rank = np.sum(s) / s[0]
|
|
814
|
+
else:
|
|
815
|
+
effective_rank = 0.0
|
|
816
|
+
|
|
817
|
+
# Step 4: Variance estimation
|
|
818
|
+
# Use effective_lambda (converted values) to ensure SE is computed with same
|
|
819
|
+
# parameters as point estimation. This fixes the variance inconsistency issue.
|
|
820
|
+
se, bootstrap_dist = self._bootstrap_variance(
|
|
821
|
+
data,
|
|
822
|
+
outcome,
|
|
823
|
+
treatment,
|
|
824
|
+
unit,
|
|
825
|
+
time,
|
|
826
|
+
effective_lambda,
|
|
827
|
+
Y=Y,
|
|
828
|
+
D=D,
|
|
829
|
+
control_unit_idx=control_unit_idx,
|
|
830
|
+
survey_design=survey_design,
|
|
831
|
+
unit_weight_arr=unit_weight_arr,
|
|
832
|
+
resolved_survey=resolved_survey,
|
|
833
|
+
)
|
|
834
|
+
|
|
835
|
+
# Compute test statistics
|
|
836
|
+
df_trop = max(1, n_valid_treated - 1)
|
|
837
|
+
t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop)
|
|
838
|
+
|
|
839
|
+
# Create results dictionaries
|
|
840
|
+
unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
|
|
841
|
+
time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}
|
|
842
|
+
|
|
843
|
+
# Store results
|
|
844
|
+
self.results_ = TROPResults(
|
|
845
|
+
att=att,
|
|
846
|
+
se=se,
|
|
847
|
+
t_stat=t_stat,
|
|
848
|
+
p_value=p_value,
|
|
849
|
+
conf_int=conf_int,
|
|
850
|
+
n_obs=len(data),
|
|
851
|
+
n_treated=len(treated_unit_idx),
|
|
852
|
+
n_control=len(control_unit_idx),
|
|
853
|
+
n_treated_obs=int(n_valid_treated),
|
|
854
|
+
unit_effects=unit_effects_dict,
|
|
855
|
+
time_effects=time_effects_dict,
|
|
856
|
+
treatment_effects=treatment_effects,
|
|
857
|
+
lambda_time=lambda_time,
|
|
858
|
+
lambda_unit=lambda_unit,
|
|
859
|
+
lambda_nn=original_lambda_nn,
|
|
860
|
+
factor_matrix=L_hat,
|
|
861
|
+
effective_rank=effective_rank,
|
|
862
|
+
loocv_score=best_score,
|
|
863
|
+
alpha=self.alpha,
|
|
864
|
+
n_pre_periods=n_pre_periods,
|
|
865
|
+
n_post_periods=n_post_periods,
|
|
866
|
+
n_bootstrap=self.n_bootstrap,
|
|
867
|
+
bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
|
|
868
|
+
survey_metadata=survey_metadata,
|
|
869
|
+
)
|
|
870
|
+
|
|
871
|
+
self.is_fitted_ = True
|
|
872
|
+
return self.results_
|
|
873
|
+
|
|
874
|
+
# =========================================================================
|
|
875
|
+
# sklearn-like API
|
|
876
|
+
# =========================================================================
|
|
877
|
+
|
|
878
|
+
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's configuration as a ``{name: value}`` dict.

    Every value is read from the attribute of the same name on ``self``;
    the values are returned as-is (no copies are made).
    """
    # Names are listed in the order they should appear in the result.
    param_names = (
        "method",
        "lambda_time_grid",
        "lambda_unit_grid",
        "lambda_nn_grid",
        "max_iter",
        "tol",
        "alpha",
        "n_bootstrap",
        "seed",
    )
    return {name: getattr(self, name) for name in param_names}
|
|
891
|
+
|
|
892
|
+
def set_params(self, **params) -> "TROP":
    """Set estimator parameters.

    All supplied parameters are validated before any of them is assigned,
    so a call that raises leaves the estimator exactly as it was.

    Parameters
    ----------
    **params
        Attribute names and their new values. Each name must already exist
        as an attribute on the estimator.

    Returns
    -------
    TROP
        ``self``, so calls can be chained.

    Raises
    ------
    ValueError
        If ``method`` is given and is not ``'local'`` or ``'global'``, or
        if a name does not correspond to an existing attribute.
    """
    # Pass 1: validate every key up front. Checking and assigning in the
    # same loop would leave the estimator half-updated when a later key
    # turns out to be invalid.
    for key, value in params.items():
        if key == "method" and value not in ("local", "global"):
            raise ValueError(
                f"method must be one of ('local', 'global'), got '{value}'"
            )
        if not hasattr(self, key):
            raise ValueError(f"Unknown parameter: {key}")

    # Pass 2: everything is valid, so apply the updates.
    for key, value in params.items():
        setattr(self, key, value)
    return self
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def trop(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    survey_design=None,
    **kwargs,
) -> TROPResults:
    """
    Fit a TROP estimator in a single call.

    Builds ``TROP(**kwargs)`` and immediately fits it on ``data``.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of the outcome column.
    treatment : str
        Name of the 0/1 treatment indicator column.

        IMPORTANT: the indicator must encode an ABSORBING STATE, not the
        moment of treatment onset. For unit i with treatment start g_i,
        D[t,i]=0 for every t < g_i and D[t,i]=1 for every t >= g_i.
    unit : str
        Name of the unit identifier column.
    time : str
        Name of the time period column.
    survey_design : SurveyDesign, optional
        Survey design specification. Supports pweight, strata, PSU, and FPC.
    **kwargs
        Forwarded unchanged to the TROP constructor.

    Returns
    -------
    TROPResults
        The results returned by ``TROP.fit``.

    Examples
    --------
    >>> from diff_diff import trop
    >>> results = trop(data, 'y', 'treated', 'unit', 'time')
    >>> print(f"ATT: {results.att:.3f}")
    """
    return TROP(**kwargs).fit(
        data,
        outcome,
        treatment,
        unit,
        time,
        survey_design=survey_design,
    )
|