diff-diff 2.1.7__tar.gz → 2.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diff_diff-2.1.7 → diff_diff-2.1.8}/PKG-INFO +14 -9
- {diff_diff-2.1.7 → diff_diff-2.1.8}/README.md +13 -8
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/__init__.py +1 -1
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/trop.py +394 -88
- {diff_diff-2.1.7 → diff_diff-2.1.8}/pyproject.toml +1 -1
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/Cargo.lock +3 -3
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/Cargo.toml +1 -1
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/trop.rs +217 -45
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/_backend.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/bacon.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/datasets.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/diagnostics.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/estimators.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/honest_did.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/linalg.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/power.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/prep.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/prep_dgp.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/pretrends.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/results.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered_aggregation.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered_bootstrap.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered_results.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/sun_abraham.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/synthetic_did.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/triple_diff.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/twfe.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/utils.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/visualization.py +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/bootstrap.rs +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/lib.rs +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/linalg.rs +0 -0
- {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/weights.rs +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diff-diff
|
|
3
|
-
Version: 2.1.7
|
|
3
|
+
Version: 2.1.8
|
|
4
4
|
Classifier: Development Status :: 5 - Production/Stable
|
|
5
5
|
Classifier: Intended Audience :: Science/Research
|
|
6
6
|
Classifier: Operating System :: OS Independent
|
|
@@ -1173,13 +1173,15 @@ trop_est = TROP(
|
|
|
1173
1173
|
lambda_nn_grid=[0.0, 0.1, 1.0], # Nuclear norm grid
|
|
1174
1174
|
n_bootstrap=200
|
|
1175
1175
|
)
|
|
1176
|
+
# Note: TROP infers treatment periods from the treatment indicator column.
|
|
1177
|
+
# The 'treated' column must be an absorbing state (D=1 for all periods
|
|
1178
|
+
# during and after treatment starts for each unit).
|
|
1176
1179
|
results = trop_est.fit(
|
|
1177
1180
|
panel_data,
|
|
1178
1181
|
outcome='gdp_growth',
|
|
1179
1182
|
treatment='treated',
|
|
1180
1183
|
unit='state',
|
|
1181
|
-
time='year'
|
|
1182
|
-
post_periods=[2015, 2016, 2017, 2018]
|
|
1184
|
+
time='year'
|
|
1183
1185
|
)
|
|
1184
1186
|
|
|
1185
1187
|
# View results
|
|
@@ -1267,9 +1269,11 @@ sdid_results = sdid.fit(data, outcome='y', treatment='treated',
|
|
|
1267
1269
|
unit='unit', time='time', post_periods=[5,6,7])
|
|
1268
1270
|
|
|
1269
1271
|
# TROP (accounts for factors)
|
|
1272
|
+
# Note: TROP infers treatment periods from the treatment indicator column
|
|
1273
|
+
# (D=1 for treated observations, D=0 for control)
|
|
1270
1274
|
trop_est = TROP() # Uses default grids with LOOCV selection
|
|
1271
1275
|
trop_results = trop_est.fit(data, outcome='y', treatment='treated',
|
|
1272
|
-
unit='unit', time='time'
|
|
1276
|
+
unit='unit', time='time')
|
|
1273
1277
|
|
|
1274
1278
|
print(f"SDID estimate: {sdid_results.att:.3f}")
|
|
1275
1279
|
print(f"TROP estimate: {trop_results.att:.3f}")
|
|
@@ -1314,13 +1318,13 @@ TROP(
|
|
|
1314
1318
|
|
|
1315
1319
|
```python
|
|
1316
1320
|
# One-liner estimation with default tuning grids
|
|
1321
|
+
# Note: TROP infers treatment periods from the treatment indicator
|
|
1317
1322
|
results = trop(
|
|
1318
1323
|
data,
|
|
1319
1324
|
outcome='y',
|
|
1320
1325
|
treatment='treated',
|
|
1321
1326
|
unit='unit',
|
|
1322
1327
|
time='time',
|
|
1323
|
-
post_periods=[5, 6, 7],
|
|
1324
1328
|
n_bootstrap=200
|
|
1325
1329
|
)
|
|
1326
1330
|
```
|
|
@@ -1912,10 +1916,11 @@ TROP(
|
|
|
1912
1916
|
|-----------|------|-------------|
|
|
1913
1917
|
| `data` | DataFrame | Panel data |
|
|
1914
1918
|
| `outcome` | str | Outcome variable column name |
|
|
1915
|
-
| `treatment` | str | Treatment indicator column (0/1) |
|
|
1919
|
+
| `treatment` | str | Treatment indicator column (0/1 absorbing state) |
|
|
1916
1920
|
| `unit` | str | Unit identifier column |
|
|
1917
1921
|
| `time` | str | Time period column |
|
|
1918
|
-
|
|
1922
|
+
|
|
1923
|
+
Note: TROP infers treatment periods from the treatment indicator column. The treatment column should be an absorbing-state indicator: for each unit, D=1 in every period from the start of treatment onward.
|
|
1919
1924
|
|
|
1920
1925
|
### TROPResults
|
|
1921
1926
|
|
|
@@ -1941,8 +1946,8 @@ TROP(
|
|
|
1941
1946
|
| `factor_matrix` | Low-rank factor matrix L (n_periods x n_units) |
|
|
1942
1947
|
| `effective_rank` | Effective rank of factor matrix |
|
|
1943
1948
|
| `loocv_score` | LOOCV score for selected parameters |
|
|
1944
|
-
| `
|
|
1945
|
-
| `
|
|
1949
|
+
| `n_pre_periods` | Number of pre-treatment periods |
|
|
1950
|
+
| `n_post_periods` | Number of post-treatment periods |
|
|
1946
1951
|
| `variance_method` | Variance estimation method |
|
|
1947
1952
|
| `bootstrap_distribution` | Bootstrap distribution (if bootstrap) |
|
|
1948
1953
|
|
|
@@ -1138,13 +1138,15 @@ trop_est = TROP(
|
|
|
1138
1138
|
lambda_nn_grid=[0.0, 0.1, 1.0], # Nuclear norm grid
|
|
1139
1139
|
n_bootstrap=200
|
|
1140
1140
|
)
|
|
1141
|
+
# Note: TROP infers treatment periods from the treatment indicator column.
|
|
1142
|
+
# The 'treated' column must be an absorbing state (D=1 for all periods
|
|
1143
|
+
# during and after treatment starts for each unit).
|
|
1141
1144
|
results = trop_est.fit(
|
|
1142
1145
|
panel_data,
|
|
1143
1146
|
outcome='gdp_growth',
|
|
1144
1147
|
treatment='treated',
|
|
1145
1148
|
unit='state',
|
|
1146
|
-
time='year'
|
|
1147
|
-
post_periods=[2015, 2016, 2017, 2018]
|
|
1149
|
+
time='year'
|
|
1148
1150
|
)
|
|
1149
1151
|
|
|
1150
1152
|
# View results
|
|
@@ -1232,9 +1234,11 @@ sdid_results = sdid.fit(data, outcome='y', treatment='treated',
|
|
|
1232
1234
|
unit='unit', time='time', post_periods=[5,6,7])
|
|
1233
1235
|
|
|
1234
1236
|
# TROP (accounts for factors)
|
|
1237
|
+
# Note: TROP infers treatment periods from the treatment indicator column
|
|
1238
|
+
# (D=1 for treated observations, D=0 for control)
|
|
1235
1239
|
trop_est = TROP() # Uses default grids with LOOCV selection
|
|
1236
1240
|
trop_results = trop_est.fit(data, outcome='y', treatment='treated',
|
|
1237
|
-
unit='unit', time='time'
|
|
1241
|
+
unit='unit', time='time')
|
|
1238
1242
|
|
|
1239
1243
|
print(f"SDID estimate: {sdid_results.att:.3f}")
|
|
1240
1244
|
print(f"TROP estimate: {trop_results.att:.3f}")
|
|
@@ -1279,13 +1283,13 @@ TROP(
|
|
|
1279
1283
|
|
|
1280
1284
|
```python
|
|
1281
1285
|
# One-liner estimation with default tuning grids
|
|
1286
|
+
# Note: TROP infers treatment periods from the treatment indicator
|
|
1282
1287
|
results = trop(
|
|
1283
1288
|
data,
|
|
1284
1289
|
outcome='y',
|
|
1285
1290
|
treatment='treated',
|
|
1286
1291
|
unit='unit',
|
|
1287
1292
|
time='time',
|
|
1288
|
-
post_periods=[5, 6, 7],
|
|
1289
1293
|
n_bootstrap=200
|
|
1290
1294
|
)
|
|
1291
1295
|
```
|
|
@@ -1877,10 +1881,11 @@ TROP(
|
|
|
1877
1881
|
|-----------|------|-------------|
|
|
1878
1882
|
| `data` | DataFrame | Panel data |
|
|
1879
1883
|
| `outcome` | str | Outcome variable column name |
|
|
1880
|
-
| `treatment` | str | Treatment indicator column (0/1) |
|
|
1884
|
+
| `treatment` | str | Treatment indicator column (0/1 absorbing state) |
|
|
1881
1885
|
| `unit` | str | Unit identifier column |
|
|
1882
1886
|
| `time` | str | Time period column |
|
|
1883
|
-
|
|
1887
|
+
|
|
1888
|
+
Note: TROP infers treatment periods from the treatment indicator column. The treatment column should be an absorbing-state indicator: for each unit, D=1 in every period from the start of treatment onward.
|
|
1884
1889
|
|
|
1885
1890
|
### TROPResults
|
|
1886
1891
|
|
|
@@ -1906,8 +1911,8 @@ TROP(
|
|
|
1906
1911
|
| `factor_matrix` | Low-rank factor matrix L (n_periods x n_units) |
|
|
1907
1912
|
| `effective_rank` | Effective rank of factor matrix |
|
|
1908
1913
|
| `loocv_score` | LOOCV score for selected parameters |
|
|
1909
|
-
| `
|
|
1910
|
-
| `
|
|
1914
|
+
| `n_pre_periods` | Number of pre-treatment periods |
|
|
1915
|
+
| `n_post_periods` | Number of post-treatment periods |
|
|
1911
1916
|
| `variance_method` | Variance estimation method |
|
|
1912
1917
|
| `bootstrap_distribution` | Bootstrap distribution (if bootstrap) |
|
|
1913
1918
|
|
|
@@ -43,6 +43,11 @@ from diff_diff.results import _get_significance_stars
|
|
|
43
43
|
from diff_diff.utils import compute_confidence_interval, compute_p_value
|
|
44
44
|
|
|
45
45
|
|
|
46
|
+
# Sentinel value for "disabled" mode in LOOCV parameter search
|
|
47
|
+
# Following paper's footnote 2: λ=∞ disables the corresponding component
|
|
48
|
+
_LAMBDA_INF: float = float('inf')
|
|
49
|
+
|
|
50
|
+
|
|
46
51
|
class _PrecomputedStructures(TypedDict):
|
|
47
52
|
"""Type definition for pre-computed structures used across LOOCV iterations.
|
|
48
53
|
|
|
@@ -109,11 +114,15 @@ class TROPResults:
|
|
|
109
114
|
treatment_effects : dict
|
|
110
115
|
Individual treatment effects for each treated (unit, time) pair.
|
|
111
116
|
lambda_time : float
|
|
112
|
-
Selected time weight decay parameter.
|
|
117
|
+
Selected time weight decay parameter from grid. Note: infinity values
|
|
118
|
+
are converted internally (∞ → 0.0 for uniform weights) for computation.
|
|
113
119
|
lambda_unit : float
|
|
114
|
-
Selected unit weight decay parameter.
|
|
120
|
+
Selected unit weight decay parameter from grid. Note: infinity values
|
|
121
|
+
are converted internally (∞ → 0.0 for uniform weights) for computation.
|
|
115
122
|
lambda_nn : float
|
|
116
|
-
Selected nuclear norm regularization parameter.
|
|
123
|
+
Selected nuclear norm regularization parameter from grid. Note: infinity
|
|
124
|
+
values are converted internally (∞ → 1e10, factor model disabled) for
|
|
125
|
+
computation.
|
|
117
126
|
factor_matrix : np.ndarray
|
|
118
127
|
Estimated low-rank factor matrix L (n_periods x n_units).
|
|
119
128
|
effective_rank : float
|
|
@@ -124,10 +133,10 @@ class TROPResults:
|
|
|
124
133
|
Method used for variance estimation.
|
|
125
134
|
alpha : float
|
|
126
135
|
Significance level for confidence interval.
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
136
|
+
n_pre_periods : int
|
|
137
|
+
Number of pre-treatment periods.
|
|
138
|
+
n_post_periods : int
|
|
139
|
+
Number of post-treatment periods (periods with D=1 observations).
|
|
131
140
|
n_bootstrap : int, optional
|
|
132
141
|
Number of bootstrap replications (if bootstrap variance).
|
|
133
142
|
bootstrap_distribution : np.ndarray, optional
|
|
@@ -154,8 +163,8 @@ class TROPResults:
|
|
|
154
163
|
loocv_score: float
|
|
155
164
|
variance_method: str
|
|
156
165
|
alpha: float = 0.05
|
|
157
|
-
|
|
158
|
-
|
|
166
|
+
n_pre_periods: int = 0
|
|
167
|
+
n_post_periods: int = 0
|
|
159
168
|
n_bootstrap: Optional[int] = field(default=None)
|
|
160
169
|
bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
|
|
161
170
|
|
|
@@ -197,8 +206,8 @@ class TROPResults:
|
|
|
197
206
|
f"{'Treated units:':<25} {self.n_treated:>10}",
|
|
198
207
|
f"{'Control units:':<25} {self.n_control:>10}",
|
|
199
208
|
f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
|
|
200
|
-
f"{'Pre-treatment periods:':<25} {
|
|
201
|
-
f"{'Post-treatment periods:':<25} {
|
|
209
|
+
f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}",
|
|
210
|
+
f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}",
|
|
202
211
|
"",
|
|
203
212
|
"-" * 75,
|
|
204
213
|
"Tuning Parameters (selected via LOOCV)".center(75),
|
|
@@ -261,8 +270,8 @@ class TROPResults:
|
|
|
261
270
|
"n_treated": self.n_treated,
|
|
262
271
|
"n_control": self.n_control,
|
|
263
272
|
"n_treated_obs": self.n_treated_obs,
|
|
264
|
-
"n_pre_periods":
|
|
265
|
-
"n_post_periods":
|
|
273
|
+
"n_pre_periods": self.n_pre_periods,
|
|
274
|
+
"n_post_periods": self.n_post_periods,
|
|
266
275
|
"lambda_time": self.lambda_time,
|
|
267
276
|
"lambda_unit": self.lambda_unit,
|
|
268
277
|
"lambda_nn": self.lambda_nn,
|
|
@@ -397,7 +406,6 @@ class TROP:
|
|
|
397
406
|
... treatment='treated',
|
|
398
407
|
... unit='unit',
|
|
399
408
|
... time='period',
|
|
400
|
-
... post_periods=[5, 6, 7, 8]
|
|
401
409
|
... )
|
|
402
410
|
>>> results.print_summary()
|
|
403
411
|
|
|
@@ -658,6 +666,168 @@ class TROP:
|
|
|
658
666
|
else:
|
|
659
667
|
return np.inf
|
|
660
668
|
|
|
669
|
+
def _univariate_loocv_search(
|
|
670
|
+
self,
|
|
671
|
+
Y: np.ndarray,
|
|
672
|
+
D: np.ndarray,
|
|
673
|
+
control_mask: np.ndarray,
|
|
674
|
+
control_unit_idx: np.ndarray,
|
|
675
|
+
n_units: int,
|
|
676
|
+
n_periods: int,
|
|
677
|
+
param_name: str,
|
|
678
|
+
grid: List[float],
|
|
679
|
+
fixed_params: Dict[str, float],
|
|
680
|
+
) -> Tuple[float, float]:
|
|
681
|
+
"""
|
|
682
|
+
Search over one parameter with others fixed.
|
|
683
|
+
|
|
684
|
+
Following paper's footnote 2, this performs a univariate grid search
|
|
685
|
+
for one tuning parameter while holding others fixed. The fixed_params
|
|
686
|
+
can include _LAMBDA_INF values to disable specific components:
|
|
687
|
+
- lambda_nn = inf: Skip nuclear norm regularization (L=0)
|
|
688
|
+
- lambda_time = inf: Uniform time weights (treated as 0)
|
|
689
|
+
- lambda_unit = inf: Uniform unit weights (treated as 0)
|
|
690
|
+
|
|
691
|
+
Parameters
|
|
692
|
+
----------
|
|
693
|
+
Y : np.ndarray
|
|
694
|
+
Outcome matrix (n_periods x n_units).
|
|
695
|
+
D : np.ndarray
|
|
696
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
697
|
+
control_mask : np.ndarray
|
|
698
|
+
Boolean mask for control observations.
|
|
699
|
+
control_unit_idx : np.ndarray
|
|
700
|
+
Indices of control units.
|
|
701
|
+
n_units : int
|
|
702
|
+
Number of units.
|
|
703
|
+
n_periods : int
|
|
704
|
+
Number of periods.
|
|
705
|
+
param_name : str
|
|
706
|
+
Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
|
|
707
|
+
grid : List[float]
|
|
708
|
+
Grid of values to search over.
|
|
709
|
+
fixed_params : Dict[str, float]
|
|
710
|
+
Fixed values for other parameters. May include _LAMBDA_INF.
|
|
711
|
+
|
|
712
|
+
Returns
|
|
713
|
+
-------
|
|
714
|
+
Tuple[float, float]
|
|
715
|
+
(best_value, best_score) for the searched parameter.
|
|
716
|
+
"""
|
|
717
|
+
best_score = np.inf
|
|
718
|
+
best_value = grid[0] if grid else 0.0
|
|
719
|
+
|
|
720
|
+
for value in grid:
|
|
721
|
+
params = {**fixed_params, param_name: value}
|
|
722
|
+
|
|
723
|
+
# Read the raw parameter values; inf ("disabled") sentinels are converted below (λ_time/λ_unit → 0.0, λ_nn → 1e10)
|
|
724
|
+
lambda_time = params.get('lambda_time', 0.0)
|
|
725
|
+
lambda_unit = params.get('lambda_unit', 0.0)
|
|
726
|
+
lambda_nn = params.get('lambda_nn', 0.0)
|
|
727
|
+
|
|
728
|
+
# Handle infinity as "disabled" mode
|
|
729
|
+
# Per paper Equations 2-3:
|
|
730
|
+
# - λ_time/λ_unit=∞ → exp(-∞×dist)→0 for dist>0, uniform weights → use 0.0
|
|
731
|
+
# - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
|
|
732
|
+
# Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled"
|
|
733
|
+
if np.isinf(lambda_time):
|
|
734
|
+
lambda_time = 0.0 # Uniform time weights
|
|
735
|
+
if np.isinf(lambda_unit):
|
|
736
|
+
lambda_unit = 0.0 # Uniform unit weights
|
|
737
|
+
if np.isinf(lambda_nn):
|
|
738
|
+
lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
|
|
739
|
+
|
|
740
|
+
try:
|
|
741
|
+
score = self._loocv_score_obs_specific(
|
|
742
|
+
Y, D, control_mask, control_unit_idx,
|
|
743
|
+
lambda_time, lambda_unit, lambda_nn,
|
|
744
|
+
n_units, n_periods
|
|
745
|
+
)
|
|
746
|
+
if score < best_score:
|
|
747
|
+
best_score = score
|
|
748
|
+
best_value = value
|
|
749
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
750
|
+
continue
|
|
751
|
+
|
|
752
|
+
return best_value, best_score
|
|
753
|
+
|
|
754
|
+
def _cycling_parameter_search(
|
|
755
|
+
self,
|
|
756
|
+
Y: np.ndarray,
|
|
757
|
+
D: np.ndarray,
|
|
758
|
+
control_mask: np.ndarray,
|
|
759
|
+
control_unit_idx: np.ndarray,
|
|
760
|
+
n_units: int,
|
|
761
|
+
n_periods: int,
|
|
762
|
+
initial_lambda: Tuple[float, float, float],
|
|
763
|
+
max_cycles: int = 10,
|
|
764
|
+
) -> Tuple[float, float, float]:
|
|
765
|
+
"""
|
|
766
|
+
Cycle through parameters until convergence (coordinate descent).
|
|
767
|
+
|
|
768
|
+
Following paper's footnote 2 (Stage 2), this iteratively optimizes
|
|
769
|
+
each tuning parameter while holding the others fixed, until convergence.
|
|
770
|
+
|
|
771
|
+
Parameters
|
|
772
|
+
----------
|
|
773
|
+
Y : np.ndarray
|
|
774
|
+
Outcome matrix (n_periods x n_units).
|
|
775
|
+
D : np.ndarray
|
|
776
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
777
|
+
control_mask : np.ndarray
|
|
778
|
+
Boolean mask for control observations.
|
|
779
|
+
control_unit_idx : np.ndarray
|
|
780
|
+
Indices of control units.
|
|
781
|
+
n_units : int
|
|
782
|
+
Number of units.
|
|
783
|
+
n_periods : int
|
|
784
|
+
Number of periods.
|
|
785
|
+
initial_lambda : Tuple[float, float, float]
|
|
786
|
+
Initial values (lambda_time, lambda_unit, lambda_nn).
|
|
787
|
+
max_cycles : int, default=10
|
|
788
|
+
Maximum number of coordinate descent cycles.
|
|
789
|
+
|
|
790
|
+
Returns
|
|
791
|
+
-------
|
|
792
|
+
Tuple[float, float, float]
|
|
793
|
+
Optimized (lambda_time, lambda_unit, lambda_nn).
|
|
794
|
+
"""
|
|
795
|
+
lambda_time, lambda_unit, lambda_nn = initial_lambda
|
|
796
|
+
prev_score = np.inf
|
|
797
|
+
|
|
798
|
+
for cycle in range(max_cycles):
|
|
799
|
+
# Optimize λ_unit (fix λ_time, λ_nn)
|
|
800
|
+
lambda_unit, _ = self._univariate_loocv_search(
|
|
801
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
802
|
+
'lambda_unit', self.lambda_unit_grid,
|
|
803
|
+
{'lambda_time': lambda_time, 'lambda_nn': lambda_nn}
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
# Optimize λ_time (fix λ_unit, λ_nn)
|
|
807
|
+
lambda_time, _ = self._univariate_loocv_search(
|
|
808
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
809
|
+
'lambda_time', self.lambda_time_grid,
|
|
810
|
+
{'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn}
|
|
811
|
+
)
|
|
812
|
+
|
|
813
|
+
# Optimize λ_nn (fix λ_unit, λ_time)
|
|
814
|
+
lambda_nn, score = self._univariate_loocv_search(
|
|
815
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
816
|
+
'lambda_nn', self.lambda_nn_grid,
|
|
817
|
+
{'lambda_unit': lambda_unit, 'lambda_time': lambda_time}
|
|
818
|
+
)
|
|
819
|
+
|
|
820
|
+
# Check convergence
|
|
821
|
+
if abs(score - prev_score) < 1e-6:
|
|
822
|
+
logger.debug(
|
|
823
|
+
"Cycling search converged after %d cycles with score %.6f",
|
|
824
|
+
cycle + 1, score
|
|
825
|
+
)
|
|
826
|
+
break
|
|
827
|
+
prev_score = score
|
|
828
|
+
|
|
829
|
+
return lambda_time, lambda_unit, lambda_nn
|
|
830
|
+
|
|
661
831
|
def fit(
|
|
662
832
|
self,
|
|
663
833
|
data: pd.DataFrame,
|
|
@@ -665,7 +835,6 @@ class TROP:
|
|
|
665
835
|
treatment: str,
|
|
666
836
|
unit: str,
|
|
667
837
|
time: str,
|
|
668
|
-
post_periods: Optional[List[Any]] = None,
|
|
669
838
|
) -> TROPResults:
|
|
670
839
|
"""
|
|
671
840
|
Fit the TROP model.
|
|
@@ -679,20 +848,31 @@ class TROP:
|
|
|
679
848
|
Name of the outcome variable column.
|
|
680
849
|
treatment : str
|
|
681
850
|
Name of the treatment indicator column (0/1).
|
|
682
|
-
|
|
851
|
+
|
|
852
|
+
IMPORTANT: This should be an ABSORBING STATE indicator, not a
|
|
853
|
+
treatment timing indicator. For each unit, D=1 for ALL periods
|
|
854
|
+
during and after treatment:
|
|
855
|
+
|
|
856
|
+
- D[t, i] = 0 for all t < g_i (pre-treatment periods)
|
|
857
|
+
- D[t, i] = 1 for all t >= g_i (treatment and post-treatment)
|
|
858
|
+
|
|
859
|
+
where g_i is the treatment start time for unit i.
|
|
860
|
+
|
|
861
|
+
For staggered adoption, different units can have different g_i.
|
|
862
|
+
The ATT averages over ALL D=1 cells per Equation 1 of the paper.
|
|
683
863
|
unit : str
|
|
684
864
|
Name of the unit identifier column.
|
|
685
865
|
time : str
|
|
686
866
|
Name of the time period column.
|
|
687
|
-
post_periods : list, optional
|
|
688
|
-
List of time period values that are post-treatment.
|
|
689
|
-
If None, infers from treatment indicator.
|
|
690
867
|
|
|
691
868
|
Returns
|
|
692
869
|
-------
|
|
693
870
|
TROPResults
|
|
694
871
|
Object containing the ATT estimate, standard error,
|
|
695
|
-
factor estimates, and tuning parameters.
|
|
872
|
+
factor estimates, and tuning parameters. The lambda_*
|
|
873
|
+
attributes show the selected grid values. Infinity values
|
|
874
|
+
(∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform
|
|
875
|
+
weights), λ_nn=∞ → 1e10 (factor model disabled).
|
|
696
876
|
"""
|
|
697
877
|
# Validate inputs
|
|
698
878
|
required_cols = [outcome, treatment, unit, time]
|
|
@@ -720,13 +900,39 @@ class TROP:
|
|
|
720
900
|
.reindex(index=all_periods, columns=all_units)
|
|
721
901
|
.values
|
|
722
902
|
)
|
|
723
|
-
|
|
903
|
+
|
|
904
|
+
# For D matrix, track missing values BEFORE fillna to support unbalanced panels
|
|
905
|
+
# Issue 3 fix: Missing observations should not trigger spurious violations
|
|
906
|
+
D_raw = (
|
|
724
907
|
data.pivot(index=time, columns=unit, values=treatment)
|
|
725
908
|
.reindex(index=all_periods, columns=all_units)
|
|
726
|
-
.fillna(0)
|
|
727
|
-
.astype(int)
|
|
728
|
-
.values
|
|
729
909
|
)
|
|
910
|
+
missing_mask = pd.isna(D_raw).values # True where originally missing
|
|
911
|
+
D = D_raw.fillna(0).astype(int).values
|
|
912
|
+
|
|
913
|
+
# Validate D is monotonic non-decreasing per unit (absorbing state)
|
|
914
|
+
# D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
|
|
915
|
+
# Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
|
|
916
|
+
# This catches 1→0 violations that span missing period gaps
|
|
917
|
+
# Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
|
|
918
|
+
# adjacent period transitions don't show it (the gap hides the transition)
|
|
919
|
+
violating_units = []
|
|
920
|
+
for unit_idx in range(n_units):
|
|
921
|
+
# Get observed D values for this unit (where not missing)
|
|
922
|
+
observed_mask = ~missing_mask[:, unit_idx]
|
|
923
|
+
observed_d = D[observed_mask, unit_idx]
|
|
924
|
+
|
|
925
|
+
# Check if observed sequence is monotonically non-decreasing
|
|
926
|
+
if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
|
|
927
|
+
violating_units.append(all_units[unit_idx])
|
|
928
|
+
|
|
929
|
+
if violating_units:
|
|
930
|
+
raise ValueError(
|
|
931
|
+
f"Treatment indicator is not an absorbing state for units: {violating_units}. "
|
|
932
|
+
f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
|
|
933
|
+
f"If this is event-study style data, convert to absorbing state: "
|
|
934
|
+
f"D[t, i] = 1 for all t >= first treatment period."
|
|
935
|
+
)
|
|
730
936
|
|
|
731
937
|
# Identify treated observations
|
|
732
938
|
treated_mask = D == 1
|
|
@@ -743,28 +949,23 @@ class TROP:
|
|
|
743
949
|
if len(control_unit_idx) == 0:
|
|
744
950
|
raise ValueError("No control units found")
|
|
745
951
|
|
|
746
|
-
# Determine pre/post periods
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
raise ValueError("Could not infer post-treatment periods")
|
|
756
|
-
pre_period_idx = list(range(first_treat_period))
|
|
757
|
-
post_period_idx = list(range(first_treat_period, n_periods))
|
|
758
|
-
else:
|
|
759
|
-
post_period_idx = [period_to_idx[p] for p in post_periods if p in period_to_idx]
|
|
760
|
-
pre_period_idx = [i for i in range(n_periods) if i not in post_period_idx]
|
|
952
|
+
# Determine pre/post periods from treatment indicator D
|
|
953
|
+
# D matrix is the sole input for treatment timing per the paper
|
|
954
|
+
first_treat_period = None
|
|
955
|
+
for t in range(n_periods):
|
|
956
|
+
if np.any(D[t, :] == 1):
|
|
957
|
+
first_treat_period = t
|
|
958
|
+
break
|
|
959
|
+
if first_treat_period is None:
|
|
960
|
+
raise ValueError("Could not infer post-treatment periods from D matrix")
|
|
761
961
|
|
|
762
|
-
|
|
763
|
-
|
|
962
|
+
n_pre_periods = first_treat_period
|
|
963
|
+
# Count periods where D=1 is actually observed (matches docstring)
|
|
964
|
+
# Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
|
|
965
|
+
n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
|
|
764
966
|
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
n_treated_periods = len(post_period_idx)
|
|
967
|
+
if n_pre_periods < 2:
|
|
968
|
+
raise ValueError("Need at least 2 pre-treatment periods")
|
|
768
969
|
|
|
769
970
|
# Step 1: Grid search with LOOCV for tuning parameters
|
|
770
971
|
best_lambda = None
|
|
@@ -789,14 +990,45 @@ class TROP:
|
|
|
789
990
|
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
|
|
790
991
|
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
|
|
791
992
|
|
|
792
|
-
|
|
993
|
+
result = _rust_loocv_grid_search(
|
|
793
994
|
Y, D.astype(np.float64), control_mask_u8,
|
|
794
995
|
time_dist_matrix,
|
|
795
996
|
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
|
|
796
997
|
self.max_loocv_samples, self.max_iter, self.tol,
|
|
797
998
|
self.seed if self.seed is not None else 0
|
|
798
999
|
)
|
|
799
|
-
|
|
1000
|
+
# Unpack result - 7 values including optional first_failed_obs
|
|
1001
|
+
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
|
|
1002
|
+
# Only accept finite scores - infinite means all fits failed
|
|
1003
|
+
if np.isfinite(best_score):
|
|
1004
|
+
best_lambda = (best_lt, best_lu, best_ln)
|
|
1005
|
+
# else: best_lambda stays None, triggering defaults fallback
|
|
1006
|
+
# Emit warnings consistent with Python implementation
|
|
1007
|
+
if n_valid == 0:
|
|
1008
|
+
# Include failed observation coordinates if available (Issue 2 fix)
|
|
1009
|
+
obs_info = ""
|
|
1010
|
+
if first_failed_obs is not None:
|
|
1011
|
+
t_idx, i_idx = first_failed_obs
|
|
1012
|
+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
|
|
1013
|
+
warnings.warn(
|
|
1014
|
+
f"LOOCV: All {n_attempted} fits failed for "
|
|
1015
|
+
f"λ=({best_lt}, {best_lu}, {best_ln}). "
|
|
1016
|
+
f"Returning infinite score.{obs_info}",
|
|
1017
|
+
UserWarning
|
|
1018
|
+
)
|
|
1019
|
+
elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
|
|
1020
|
+
n_failed = n_attempted - n_valid
|
|
1021
|
+
# Include failed observation coordinates if available
|
|
1022
|
+
obs_info = ""
|
|
1023
|
+
if first_failed_obs is not None:
|
|
1024
|
+
t_idx, i_idx = first_failed_obs
|
|
1025
|
+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
|
|
1026
|
+
warnings.warn(
|
|
1027
|
+
f"LOOCV: {n_failed}/{n_attempted} fits failed for "
|
|
1028
|
+
f"λ=({best_lt}, {best_lu}, {best_ln}). "
|
|
1029
|
+
f"This may indicate numerical instability.{obs_info}",
|
|
1030
|
+
UserWarning
|
|
1031
|
+
)
|
|
800
1032
|
except Exception as e:
|
|
801
1033
|
# Fall back to Python implementation on error
|
|
802
1034
|
logger.debug(
|
|
@@ -806,21 +1038,54 @@ class TROP:
|
|
|
806
1038
|
best_score = np.inf
|
|
807
1039
|
|
|
808
1040
|
# Fall back to Python implementation if Rust unavailable or failed
|
|
1041
|
+
# Uses two-stage approach per paper's footnote 2:
|
|
1042
|
+
# Stage 1: Univariate searches for initial values
|
|
1043
|
+
# Stage 2: Cycling (coordinate descent) until convergence
|
|
809
1044
|
if best_lambda is None:
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
1045
|
+
# Stage 1: Univariate searches with extreme fixed values
|
|
1046
|
+
# Following paper's footnote 2 for initial bounds
|
|
1047
|
+
|
|
1048
|
+
# λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
|
|
1049
|
+
lambda_time_init, _ = self._univariate_loocv_search(
|
|
1050
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
1051
|
+
'lambda_time', self.lambda_time_grid,
|
|
1052
|
+
{'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF}
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
# λ_nn search: fix λ_time=∞ (uniform time weights), λ_unit=0
|
|
1056
|
+
lambda_nn_init, _ = self._univariate_loocv_search(
|
|
1057
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
1058
|
+
'lambda_nn', self.lambda_nn_grid,
|
|
1059
|
+
{'lambda_time': _LAMBDA_INF, 'lambda_unit': 0.0}
|
|
1060
|
+
)
|
|
1061
|
+
|
|
1062
|
+
# λ_unit search: fix λ_nn=∞, λ_time=0
|
|
1063
|
+
lambda_unit_init, _ = self._univariate_loocv_search(
|
|
1064
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
1065
|
+
'lambda_unit', self.lambda_unit_grid,
|
|
1066
|
+
{'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0}
|
|
1067
|
+
)
|
|
1068
|
+
|
|
1069
|
+
# Stage 2: Cycling refinement (coordinate descent)
|
|
1070
|
+
lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
|
|
1071
|
+
Y, D, control_mask, control_unit_idx, n_units, n_periods,
|
|
1072
|
+
(lambda_time_init, lambda_unit_init, lambda_nn_init)
|
|
1073
|
+
)
|
|
1074
|
+
|
|
1075
|
+
# Compute final score for the optimized parameters
|
|
1076
|
+
try:
|
|
1077
|
+
best_score = self._loocv_score_obs_specific(
|
|
1078
|
+
Y, D, control_mask, control_unit_idx,
|
|
1079
|
+
lambda_time, lambda_unit, lambda_nn,
|
|
1080
|
+
n_units, n_periods
|
|
1081
|
+
)
|
|
1082
|
+
# Only accept finite scores - infinite means all fits failed
|
|
1083
|
+
if np.isfinite(best_score):
|
|
1084
|
+
best_lambda = (lambda_time, lambda_unit, lambda_nn)
|
|
1085
|
+
# else: best_lambda stays None, triggering defaults fallback
|
|
1086
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
1087
|
+
# If even the optimized parameters fail, best_lambda stays None
|
|
1088
|
+
pass
|
|
824
1089
|
|
|
825
1090
|
if best_lambda is None:
|
|
826
1091
|
warnings.warn(
|
|
@@ -833,6 +1098,26 @@ class TROP:
|
|
|
833
1098
|
self._optimal_lambda = best_lambda
|
|
834
1099
|
lambda_time, lambda_unit, lambda_nn = best_lambda
|
|
835
1100
|
|
|
1101
|
+
# Convert infinity values for final estimation (matching LOOCV conversion)
|
|
1102
|
+
# This ensures final estimation uses the same effective parameters that LOOCV evaluated.
|
|
1103
|
+
# See REGISTRY.md "λ=∞ implementation" for rationale.
|
|
1104
|
+
#
|
|
1105
|
+
# IMPORTANT: Store original grid values for results, use converted for computation.
|
|
1106
|
+
# This lets users see what was selected from their grid, while ensuring consistent
|
|
1107
|
+
# behavior between point estimation and variance estimation.
|
|
1108
|
+
original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda
|
|
1109
|
+
|
|
1110
|
+
if np.isinf(lambda_time):
|
|
1111
|
+
lambda_time = 0.0 # Uniform time weights
|
|
1112
|
+
if np.isinf(lambda_unit):
|
|
1113
|
+
lambda_unit = 0.0 # Uniform unit weights
|
|
1114
|
+
if np.isinf(lambda_nn):
|
|
1115
|
+
lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
|
|
1116
|
+
|
|
1117
|
+
# Create effective_lambda with converted values for ALL downstream computation
|
|
1118
|
+
# This ensures variance estimation uses the same parameters as point estimation
|
|
1119
|
+
effective_lambda = (lambda_time, lambda_unit, lambda_nn)
|
|
1120
|
+
|
|
836
1121
|
# Step 2: Final estimation - per-observation model fitting following Algorithm 2
|
|
837
1122
|
# For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it}
|
|
838
1123
|
treatment_effects = {}
|
|
@@ -886,14 +1171,16 @@ class TROP:
|
|
|
886
1171
|
effective_rank = 0.0
|
|
887
1172
|
|
|
888
1173
|
# Step 4: Variance estimation
|
|
1174
|
+
# Use effective_lambda (converted values) to ensure SE is computed with same
|
|
1175
|
+
# parameters as point estimation. This fixes the variance inconsistency issue.
|
|
889
1176
|
if self.variance_method == "bootstrap":
|
|
890
1177
|
se, bootstrap_dist = self._bootstrap_variance(
|
|
891
|
-
data, outcome, treatment, unit, time,
|
|
892
|
-
|
|
1178
|
+
data, outcome, treatment, unit, time,
|
|
1179
|
+
effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
|
|
893
1180
|
)
|
|
894
1181
|
else:
|
|
895
1182
|
se, bootstrap_dist = self._jackknife_variance(
|
|
896
|
-
Y, D, control_mask, control_unit_idx,
|
|
1183
|
+
Y, D, control_mask, control_unit_idx, effective_lambda,
|
|
897
1184
|
n_units, n_periods
|
|
898
1185
|
)
|
|
899
1186
|
|
|
@@ -901,11 +1188,12 @@ class TROP:
|
|
|
901
1188
|
if se > 0:
|
|
902
1189
|
t_stat = att / se
|
|
903
1190
|
p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
|
|
1191
|
+
conf_int = compute_confidence_interval(att, se, self.alpha)
|
|
904
1192
|
else:
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
1193
|
+
# When SE is undefined/zero, ALL inference fields should be NaN
|
|
1194
|
+
t_stat = np.nan
|
|
1195
|
+
p_value = np.nan
|
|
1196
|
+
conf_int = (np.nan, np.nan)
|
|
909
1197
|
|
|
910
1198
|
# Create results dictionaries
|
|
911
1199
|
unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
|
|
@@ -925,16 +1213,18 @@ class TROP:
|
|
|
925
1213
|
unit_effects=unit_effects_dict,
|
|
926
1214
|
time_effects=time_effects_dict,
|
|
927
1215
|
treatment_effects=treatment_effects,
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
1216
|
+
# Store ORIGINAL grid values (possibly inf) so users see what was selected.
|
|
1217
|
+
# Internally, infinity values are converted for computation (see effective_lambda).
|
|
1218
|
+
lambda_time=original_lambda_time,
|
|
1219
|
+
lambda_unit=original_lambda_unit,
|
|
1220
|
+
lambda_nn=original_lambda_nn,
|
|
931
1221
|
factor_matrix=L_hat,
|
|
932
1222
|
effective_rank=effective_rank,
|
|
933
1223
|
loocv_score=best_score,
|
|
934
1224
|
variance_method=self.variance_method,
|
|
935
1225
|
alpha=self.alpha,
|
|
936
|
-
|
|
937
|
-
|
|
1226
|
+
n_pre_periods=n_pre_periods,
|
|
1227
|
+
n_post_periods=n_post_periods,
|
|
938
1228
|
n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None,
|
|
939
1229
|
bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
|
|
940
1230
|
)
|
|
@@ -1409,6 +1699,17 @@ class TROP:
|
|
|
1409
1699
|
indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
|
|
1410
1700
|
control_obs = [control_obs[idx] for idx in indices]
|
|
1411
1701
|
|
|
1702
|
+
# Empty control set check: if no control observations, return infinity
|
|
1703
|
+
# A score of 0.0 would incorrectly "win" over legitimate parameters
|
|
1704
|
+
if len(control_obs) == 0:
|
|
1705
|
+
warnings.warn(
|
|
1706
|
+
f"LOOCV: No valid control observations for "
|
|
1707
|
+
f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
|
|
1708
|
+
"Returning infinite score.",
|
|
1709
|
+
UserWarning
|
|
1710
|
+
)
|
|
1711
|
+
return np.inf
|
|
1712
|
+
|
|
1412
1713
|
tau_squared_sum = 0.0
|
|
1413
1714
|
n_valid = 0
|
|
1414
1715
|
|
|
@@ -1433,12 +1734,19 @@ class TROP:
|
|
|
1433
1734
|
n_valid += 1
|
|
1434
1735
|
|
|
1435
1736
|
except (np.linalg.LinAlgError, ValueError):
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
|
|
1737
|
+
# Per Equation 5: Q(λ) must sum over ALL D==0 cells
|
|
1738
|
+
# Any failure means this λ cannot produce valid estimates for all cells
|
|
1739
|
+
warnings.warn(
|
|
1740
|
+
f"LOOCV: Fit failed for observation ({t}, {i}) with "
|
|
1741
|
+
f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
|
|
1742
|
+
"Returning infinite score per Equation 5.",
|
|
1743
|
+
UserWarning
|
|
1744
|
+
)
|
|
1745
|
+
return np.inf
|
|
1440
1746
|
|
|
1441
|
-
|
|
1747
|
+
# Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
|
|
1748
|
+
# Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
|
|
1749
|
+
return tau_squared_sum
|
|
1442
1750
|
|
|
1443
1751
|
def _bootstrap_variance(
|
|
1444
1752
|
self,
|
|
@@ -1447,7 +1755,6 @@ class TROP:
|
|
|
1447
1755
|
treatment: str,
|
|
1448
1756
|
unit: str,
|
|
1449
1757
|
time: str,
|
|
1450
|
-
post_periods: List[Any],
|
|
1451
1758
|
optimal_lambda: Tuple[float, float, float],
|
|
1452
1759
|
Y: Optional[np.ndarray] = None,
|
|
1453
1760
|
D: Optional[np.ndarray] = None,
|
|
@@ -1473,8 +1780,6 @@ class TROP:
|
|
|
1473
1780
|
Name of the unit identifier column in data.
|
|
1474
1781
|
time : str
|
|
1475
1782
|
Name of the time period column in data.
|
|
1476
|
-
post_periods : list
|
|
1477
|
-
List of post-treatment time periods.
|
|
1478
1783
|
optimal_lambda : tuple of float
|
|
1479
1784
|
Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
|
|
1480
1785
|
from cross-validation. Used for model estimation in each bootstrap.
|
|
@@ -1579,7 +1884,7 @@ class TROP:
|
|
|
1579
1884
|
# Fit with fixed lambda (skip LOOCV for speed)
|
|
1580
1885
|
att = self._fit_with_fixed_lambda(
|
|
1581
1886
|
boot_data, outcome, treatment, unit, time,
|
|
1582
|
-
|
|
1887
|
+
optimal_lambda
|
|
1583
1888
|
)
|
|
1584
1889
|
bootstrap_estimates_list.append(att)
|
|
1585
1890
|
except (ValueError, np.linalg.LinAlgError, KeyError):
|
|
@@ -1703,7 +2008,6 @@ class TROP:
|
|
|
1703
2008
|
treatment: str,
|
|
1704
2009
|
unit: str,
|
|
1705
2010
|
time: str,
|
|
1706
|
-
post_periods: List[Any],
|
|
1707
2011
|
fixed_lambda: Tuple[float, float, float],
|
|
1708
2012
|
) -> float:
|
|
1709
2013
|
"""
|
|
@@ -1803,7 +2107,6 @@ def trop(
|
|
|
1803
2107
|
treatment: str,
|
|
1804
2108
|
unit: str,
|
|
1805
2109
|
time: str,
|
|
1806
|
-
post_periods: Optional[List[Any]] = None,
|
|
1807
2110
|
**kwargs,
|
|
1808
2111
|
) -> TROPResults:
|
|
1809
2112
|
"""
|
|
@@ -1816,13 +2119,16 @@ def trop(
|
|
|
1816
2119
|
outcome : str
|
|
1817
2120
|
Outcome variable column name.
|
|
1818
2121
|
treatment : str
|
|
1819
|
-
Treatment indicator column name.
|
|
2122
|
+
Treatment indicator column name (0/1).
|
|
2123
|
+
|
|
2124
|
+
IMPORTANT: This should be an ABSORBING STATE indicator, not a treatment
|
|
2125
|
+
timing indicator. For each unit, D=1 for ALL periods during and after
|
|
2126
|
+
treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for t >= g_i where g_i is
|
|
2127
|
+
the treatment start time for unit i).
|
|
1820
2128
|
unit : str
|
|
1821
2129
|
Unit identifier column name.
|
|
1822
2130
|
time : str
|
|
1823
2131
|
Time period column name.
|
|
1824
|
-
post_periods : list, optional
|
|
1825
|
-
Post-treatment periods.
|
|
1826
2132
|
**kwargs
|
|
1827
2133
|
Additional arguments passed to TROP constructor.
|
|
1828
2134
|
|
|
@@ -1834,8 +2140,8 @@ def trop(
|
|
|
1834
2140
|
Examples
|
|
1835
2141
|
--------
|
|
1836
2142
|
>>> from diff_diff import trop
|
|
1837
|
-
>>> results = trop(data, 'y', 'treated', 'unit', 'time'
|
|
2143
|
+
>>> results = trop(data, 'y', 'treated', 'unit', 'time')
|
|
1838
2144
|
>>> print(f"ATT: {results.att:.3f}")
|
|
1839
2145
|
"""
|
|
1840
2146
|
estimator = TROP(**kwargs)
|
|
1841
|
-
return estimator.fit(data, outcome, treatment, unit, time
|
|
2147
|
+
return estimator.fit(data, outcome, treatment, unit, time)
|
|
@@ -289,7 +289,7 @@ dependencies = [
|
|
|
289
289
|
|
|
290
290
|
[[package]]
|
|
291
291
|
name = "diff_diff_rust"
|
|
292
|
-
version = "2.1.7"
|
|
292
|
+
version = "2.1.8"
|
|
293
293
|
dependencies = [
|
|
294
294
|
"ndarray",
|
|
295
295
|
"ndarray-linalg",
|
|
@@ -2316,6 +2316,6 @@ dependencies = [
|
|
|
2316
2316
|
|
|
2317
2317
|
[[package]]
|
|
2318
2318
|
name = "zmij"
|
|
2319
|
-
version = "1.0.
|
|
2319
|
+
version = "1.0.17"
|
|
2320
2320
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
2321
|
-
checksum = "
|
|
2321
|
+
checksum = "02aae0f83f69aafc94776e879363e9771d7ecbffe2c7fbb6c14c5e00dfe88439"
|
|
@@ -146,10 +146,160 @@ fn compute_pair_distance(
|
|
|
146
146
|
}
|
|
147
147
|
}
|
|
148
148
|
|
|
149
|
-
/// Perform LOOCV
|
|
149
|
+
/// Perform univariate LOOCV search over a single parameter.
|
|
150
150
|
///
|
|
151
|
-
///
|
|
152
|
-
///
|
|
151
|
+
/// Following paper's footnote 2, this performs a grid search for one parameter
|
|
152
|
+
/// while holding others fixed. Used in the two-stage LOOCV approach.
|
|
153
|
+
///
|
|
154
|
+
/// # Arguments
|
|
155
|
+
/// * `y` - Outcome matrix (n_periods x n_units)
|
|
156
|
+
/// * `d` - Treatment indicator matrix
|
|
157
|
+
/// * `control_mask` - Boolean mask for control observations
|
|
158
|
+
/// * `time_dist` - Time distance matrix
|
|
159
|
+
/// * `control_obs` - List of control observations for LOOCV
|
|
160
|
+
/// * `grid` - Grid of values to search
|
|
161
|
+
/// * `fixed_time` - Fixed lambda_time (inf for disabled)
|
|
162
|
+
/// * `fixed_unit` - Fixed lambda_unit (inf for disabled)
|
|
163
|
+
/// * `fixed_nn` - Fixed lambda_nn (inf for disabled)
|
|
164
|
+
/// * `param_type` - Which parameter to search: 0=time, 1=unit, 2=nn
|
|
165
|
+
/// * `max_iter` - Maximum iterations
|
|
166
|
+
/// * `tol` - Convergence tolerance
|
|
167
|
+
///
|
|
168
|
+
/// # Returns
|
|
169
|
+
/// (best_value, best_score)
|
|
170
|
+
fn univariate_loocv_search(
|
|
171
|
+
y: &ArrayView2<f64>,
|
|
172
|
+
d: &ArrayView2<f64>,
|
|
173
|
+
control_mask: &ArrayView2<u8>,
|
|
174
|
+
time_dist: &ArrayView2<i64>,
|
|
175
|
+
control_obs: &[(usize, usize)],
|
|
176
|
+
grid: &[f64],
|
|
177
|
+
fixed_time: f64,
|
|
178
|
+
fixed_unit: f64,
|
|
179
|
+
fixed_nn: f64,
|
|
180
|
+
param_type: usize, // 0=time, 1=unit, 2=nn
|
|
181
|
+
max_iter: usize,
|
|
182
|
+
tol: f64,
|
|
183
|
+
) -> (f64, f64) {
|
|
184
|
+
let mut best_score = f64::INFINITY;
|
|
185
|
+
let mut best_value = grid.first().copied().unwrap_or(0.0);
|
|
186
|
+
|
|
187
|
+
// Parallelize over grid values
|
|
188
|
+
let results: Vec<(f64, f64)> = grid
|
|
189
|
+
.par_iter()
|
|
190
|
+
.map(|&value| {
|
|
191
|
+
// Set parameters, converting inf for "disabled" mode
|
|
192
|
+
// Per paper Equations 2-3:
|
|
193
|
+
// - λ_time/λ_unit=∞ → uniform weights → use 0.0
|
|
194
|
+
// - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
|
|
195
|
+
// Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled"
|
|
196
|
+
//
|
|
197
|
+
// IMPORTANT: Convert the grid value BEFORE using it, matching Python behavior.
|
|
198
|
+
// This ensures Rust and Python evaluate the same objective for infinity grids.
|
|
199
|
+
let (lambda_time, lambda_unit, lambda_nn) = match param_type {
|
|
200
|
+
0 => {
|
|
201
|
+
// Searching λ_time: convert grid value if infinite
|
|
202
|
+
let value_converted = if value.is_infinite() { 0.0 } else { value };
|
|
203
|
+
(value_converted,
|
|
204
|
+
if fixed_unit.is_infinite() { 0.0 } else { fixed_unit },
|
|
205
|
+
if fixed_nn.is_infinite() { 1e10 } else { fixed_nn })
|
|
206
|
+
},
|
|
207
|
+
1 => {
|
|
208
|
+
// Searching λ_unit: convert grid value if infinite
|
|
209
|
+
let value_converted = if value.is_infinite() { 0.0 } else { value };
|
|
210
|
+
(if fixed_time.is_infinite() { 0.0 } else { fixed_time },
|
|
211
|
+
value_converted,
|
|
212
|
+
if fixed_nn.is_infinite() { 1e10 } else { fixed_nn })
|
|
213
|
+
},
|
|
214
|
+
_ => {
|
|
215
|
+
// Searching λ_nn: convert grid value if infinite
|
|
216
|
+
let value_converted = if value.is_infinite() { 1e10 } else { value };
|
|
217
|
+
(if fixed_time.is_infinite() { 0.0 } else { fixed_time },
|
|
218
|
+
if fixed_unit.is_infinite() { 0.0 } else { fixed_unit },
|
|
219
|
+
value_converted)
|
|
220
|
+
},
|
|
221
|
+
};
|
|
222
|
+
|
|
223
|
+
let (score, _, _) = loocv_score_for_params(
|
|
224
|
+
y, d, control_mask, time_dist, control_obs,
|
|
225
|
+
lambda_time, lambda_unit, lambda_nn,
|
|
226
|
+
max_iter, tol,
|
|
227
|
+
);
|
|
228
|
+
(value, score)
|
|
229
|
+
})
|
|
230
|
+
.collect();
|
|
231
|
+
|
|
232
|
+
for (value, score) in results {
|
|
233
|
+
if score < best_score {
|
|
234
|
+
best_score = score;
|
|
235
|
+
best_value = value;
|
|
236
|
+
}
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
(best_value, best_score)
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
/// Cycle through parameters until convergence (coordinate descent).
|
|
243
|
+
///
|
|
244
|
+
/// Following paper's footnote 2 (Stage 2), iteratively optimize each parameter.
|
|
245
|
+
fn cycling_parameter_search(
|
|
246
|
+
y: &ArrayView2<f64>,
|
|
247
|
+
d: &ArrayView2<f64>,
|
|
248
|
+
control_mask: &ArrayView2<u8>,
|
|
249
|
+
time_dist: &ArrayView2<i64>,
|
|
250
|
+
control_obs: &[(usize, usize)],
|
|
251
|
+
lambda_time_grid: &[f64],
|
|
252
|
+
lambda_unit_grid: &[f64],
|
|
253
|
+
lambda_nn_grid: &[f64],
|
|
254
|
+
initial_time: f64,
|
|
255
|
+
initial_unit: f64,
|
|
256
|
+
initial_nn: f64,
|
|
257
|
+
max_iter: usize,
|
|
258
|
+
tol: f64,
|
|
259
|
+
max_cycles: usize,
|
|
260
|
+
) -> (f64, f64, f64) {
|
|
261
|
+
let mut lambda_time = initial_time;
|
|
262
|
+
let mut lambda_unit = initial_unit;
|
|
263
|
+
let mut lambda_nn = initial_nn;
|
|
264
|
+
let mut prev_score = f64::INFINITY;
|
|
265
|
+
|
|
266
|
+
for _cycle in 0..max_cycles {
|
|
267
|
+
// Optimize λ_unit (fix λ_time, λ_nn)
|
|
268
|
+
let (new_unit, _) = univariate_loocv_search(
|
|
269
|
+
y, d, control_mask, time_dist, control_obs,
|
|
270
|
+
lambda_unit_grid, lambda_time, 0.0, lambda_nn, 1, max_iter, tol,
|
|
271
|
+
);
|
|
272
|
+
lambda_unit = new_unit;
|
|
273
|
+
|
|
274
|
+
// Optimize λ_time (fix λ_unit, λ_nn)
|
|
275
|
+
let (new_time, _) = univariate_loocv_search(
|
|
276
|
+
y, d, control_mask, time_dist, control_obs,
|
|
277
|
+
lambda_time_grid, 0.0, lambda_unit, lambda_nn, 0, max_iter, tol,
|
|
278
|
+
);
|
|
279
|
+
lambda_time = new_time;
|
|
280
|
+
|
|
281
|
+
// Optimize λ_nn (fix λ_unit, λ_time)
|
|
282
|
+
let (new_nn, score) = univariate_loocv_search(
|
|
283
|
+
y, d, control_mask, time_dist, control_obs,
|
|
284
|
+
lambda_nn_grid, lambda_time, lambda_unit, 0.0, 2, max_iter, tol,
|
|
285
|
+
);
|
|
286
|
+
lambda_nn = new_nn;
|
|
287
|
+
|
|
288
|
+
// Check convergence
|
|
289
|
+
if (score - prev_score).abs() < 1e-6 {
|
|
290
|
+
break;
|
|
291
|
+
}
|
|
292
|
+
prev_score = score;
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
(lambda_time, lambda_unit, lambda_nn)
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
/// Perform LOOCV grid search over tuning parameters using two-stage approach.
|
|
299
|
+
///
|
|
300
|
+
/// Following paper's footnote 2:
|
|
301
|
+
/// - Stage 1: Univariate searches for initial values with extreme fixed parameters
|
|
302
|
+
/// - Stage 2: Cycling (coordinate descent) until convergence
|
|
153
303
|
///
|
|
154
304
|
/// Following TROP Equation 5 (page 8):
|
|
155
305
|
/// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
|
|
@@ -158,8 +308,6 @@ fn compute_pair_distance(
|
|
|
158
308
|
/// * `y` - Outcome matrix (n_periods x n_units)
|
|
159
309
|
/// * `d` - Treatment indicator matrix (n_periods x n_units)
|
|
160
310
|
/// * `control_mask` - Boolean mask (n_periods x n_units) for control observations
|
|
161
|
-
/// * `control_unit_idx` - Array of control unit indices
|
|
162
|
-
/// * `unit_dist_matrix` - Pre-computed unit distance matrix (n_units x n_units)
|
|
163
311
|
/// * `time_dist_matrix` - Pre-computed time distance matrix (n_periods x n_periods)
|
|
164
312
|
/// * `lambda_time_grid` - Grid of time decay parameters
|
|
165
313
|
/// * `lambda_unit_grid` - Grid of unit distance parameters
|
|
@@ -170,7 +318,10 @@ fn compute_pair_distance(
|
|
|
170
318
|
/// * `seed` - Random seed for subsampling
|
|
171
319
|
///
|
|
172
320
|
/// # Returns
|
|
173
|
-
/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score)
|
|
321
|
+
/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs)
|
|
322
|
+
/// where n_valid and n_attempted are the counts for the best parameter combination,
|
|
323
|
+
/// allowing Python to emit warnings when >10% of fits fail.
|
|
324
|
+
/// first_failed_obs is Some((t, i)) if a fit failed during final score computation, None otherwise.
|
|
174
325
|
#[pyfunction]
|
|
175
326
|
#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
|
|
176
327
|
#[allow(clippy::too_many_arguments)]
|
|
@@ -187,7 +338,7 @@ pub fn loocv_grid_search<'py>(
|
|
|
187
338
|
max_iter: usize,
|
|
188
339
|
tol: f64,
|
|
189
340
|
seed: u64,
|
|
190
|
-
) -> PyResult<(f64, f64, f64, f64)> {
|
|
341
|
+
) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> {
|
|
191
342
|
let y_arr = y.as_array();
|
|
192
343
|
let d_arr = d.as_array();
|
|
193
344
|
let control_mask_arr = control_mask.as_array();
|
|
@@ -204,43 +355,53 @@ pub fn loocv_grid_search<'py>(
|
|
|
204
355
|
seed,
|
|
205
356
|
);
|
|
206
357
|
|
|
207
|
-
|
|
208
|
-
let mut param_combos: Vec<(f64, f64, f64)> = Vec::new();
|
|
209
|
-
for &lt in &lambda_time_vec {
|
|
210
|
-
for &lu in &lambda_unit_vec {
|
|
211
|
-
for &ln in &lambda_nn_vec {
|
|
212
|
-
param_combos.push((lt, lu, ln));
|
|
213
|
-
}
|
|
214
|
-
}
|
|
215
|
-
}
|
|
358
|
+
let n_attempted = control_obs.len();
|
|
216
359
|
|
|
217
|
-
//
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
&d_arr,
|
|
224
|
-
&control_mask_arr,
|
|
225
|
-
&time_dist_arr,
|
|
226
|
-
&control_obs,
|
|
227
|
-
lambda_time,
|
|
228
|
-
lambda_unit,
|
|
229
|
-
lambda_nn,
|
|
230
|
-
max_iter,
|
|
231
|
-
tol,
|
|
232
|
-
);
|
|
233
|
-
(lambda_time, lambda_unit, lambda_nn, score)
|
|
234
|
-
})
|
|
235
|
-
.collect();
|
|
360
|
+
// Stage 1: Univariate searches for initial values (paper footnote 2)
|
|
361
|
+
// λ_time search: fix λ_unit=0, λ_nn=∞ (disabled)
|
|
362
|
+
let (lambda_time_init, _) = univariate_loocv_search(
|
|
363
|
+
&y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
|
|
364
|
+
&lambda_time_vec, 0.0, 0.0, f64::INFINITY, 0, max_iter, tol,
|
|
365
|
+
);
|
|
236
366
|
|
|
237
|
-
//
|
|
238
|
-
let
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
367
|
+
// λ_nn search: fix λ_time=∞ (disabled), λ_unit=0
|
|
368
|
+
let (lambda_nn_init, _) = univariate_loocv_search(
|
|
369
|
+
&y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
|
|
370
|
+
&lambda_nn_vec, f64::INFINITY, 0.0, 0.0, 2, max_iter, tol,
|
|
371
|
+
);
|
|
372
|
+
|
|
373
|
+
// λ_unit search: fix λ_nn=∞, λ_time=0
|
|
374
|
+
let (lambda_unit_init, _) = univariate_loocv_search(
|
|
375
|
+
&y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
|
|
376
|
+
&lambda_unit_vec, 0.0, 0.0, f64::INFINITY, 1, max_iter, tol,
|
|
377
|
+
);
|
|
242
378
|
|
|
243
|
-
|
|
379
|
+
// Stage 2: Cycling refinement
|
|
380
|
+
let (best_time, best_unit, best_nn) = cycling_parameter_search(
|
|
381
|
+
&y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
|
|
382
|
+
&lambda_time_vec, &lambda_unit_vec, &lambda_nn_vec,
|
|
383
|
+
lambda_time_init, lambda_unit_init, lambda_nn_init,
|
|
384
|
+
max_iter, tol, 10,
|
|
385
|
+
);
|
|
386
|
+
|
|
387
|
+
// Convert infinity values BEFORE computing final score (Issue 1 fix)
|
|
388
|
+
// Per paper Equations 2-3:
|
|
389
|
+
// - λ_time/λ_unit=∞ → uniform weights → use 0.0
|
|
390
|
+
// - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
|
|
391
|
+
// This ensures final score computation matches what LOOCV evaluated.
|
|
392
|
+
let best_time_eff = if best_time.is_infinite() { 0.0 } else { best_time };
|
|
393
|
+
let best_unit_eff = if best_unit.is_infinite() { 0.0 } else { best_unit };
|
|
394
|
+
let best_nn_eff = if best_nn.is_infinite() { 1e10 } else { best_nn };
|
|
395
|
+
|
|
396
|
+
// Compute final score with converted values
|
|
397
|
+
let (best_score, n_valid, first_failed) = loocv_score_for_params(
|
|
398
|
+
&y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
|
|
399
|
+
best_time_eff, best_unit_eff, best_nn_eff,
|
|
400
|
+
max_iter, tol,
|
|
401
|
+
);
|
|
402
|
+
|
|
403
|
+
// Return ORIGINAL grid values (for user visibility) but score computed with converted
|
|
404
|
+
Ok((best_time, best_unit, best_nn, best_score, n_valid, n_attempted, first_failed))
|
|
244
405
|
}
|
|
245
406
|
|
|
246
407
|
/// Get sampled control observations for LOOCV.
|
|
@@ -277,6 +438,10 @@ fn get_control_observations(
|
|
|
277
438
|
}
|
|
278
439
|
|
|
279
440
|
/// Compute LOOCV score for a specific parameter combination.
|
|
441
|
+
///
|
|
442
|
+
/// # Returns
|
|
443
|
+
/// (score, n_valid, first_failed_obs) - the LOOCV score, number of successful fits,
|
|
444
|
+
/// and the first failed observation (t, i) if any fit failed, None otherwise.
|
|
280
445
|
#[allow(clippy::too_many_arguments)]
|
|
281
446
|
fn loocv_score_for_params(
|
|
282
447
|
y: &ArrayView2<f64>,
|
|
@@ -289,7 +454,7 @@ fn loocv_score_for_params(
|
|
|
289
454
|
lambda_nn: f64,
|
|
290
455
|
max_iter: usize,
|
|
291
456
|
tol: f64,
|
|
292
|
-
) -> f64 {
|
|
457
|
+
) -> (f64, usize, Option<(usize, usize)>) {
|
|
293
458
|
let n_periods = y.nrows();
|
|
294
459
|
let n_units = y.ncols();
|
|
295
460
|
|
|
@@ -328,14 +493,21 @@ fn loocv_score_for_params(
|
|
|
328
493
|
tau_sq_sum += tau * tau;
|
|
329
494
|
n_valid += 1;
|
|
330
495
|
}
|
|
331
|
-
None =>
|
|
496
|
+
None => {
|
|
497
|
+
// Per Equation 5: Q(λ) must sum over ALL D==0 cells
|
|
498
|
+
// Any failure means this λ cannot produce valid estimates for all cells
|
|
499
|
+
// Return the failed observation (t, i) for warning metadata
|
|
500
|
+
return (f64::INFINITY, n_valid, Some((t, i)));
|
|
501
|
+
}
|
|
332
502
|
}
|
|
333
503
|
}
|
|
334
504
|
|
|
335
505
|
if n_valid == 0 {
|
|
336
|
-
f64::INFINITY
|
|
506
|
+
(f64::INFINITY, 0, None)
|
|
337
507
|
} else {
|
|
338
|
-
|
|
508
|
+
// Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
|
|
509
|
+
// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
|
|
510
|
+
(tau_sq_sum, n_valid, None)
|
|
339
511
|
}
|
|
340
512
|
}
|
|
341
513
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|