diff-diff 2.1.2__tar.gz → 2.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diff_diff-2.1.2 → diff_diff-2.1.3}/PKG-INFO +1 -1
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/__init__.py +1 -1
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/trop.py +184 -46
- {diff_diff-2.1.2 → diff_diff-2.1.3}/pyproject.toml +1 -1
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/Cargo.lock +1 -1
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/Cargo.toml +1 -1
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/trop.rs +131 -60
- {diff_diff-2.1.2 → diff_diff-2.1.3}/README.md +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/_backend.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/bacon.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/datasets.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/diagnostics.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/estimators.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/honest_did.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/linalg.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/power.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/prep.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/pretrends.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/results.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/staggered.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/sun_abraham.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/synthetic_did.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/triple_diff.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/twfe.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/utils.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/visualization.py +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/bootstrap.rs +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/lib.rs +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/linalg.rs +0 -0
- {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/weights.rs +0 -0
|
@@ -63,7 +63,11 @@ class _PrecomputedStructures(TypedDict):
|
|
|
63
63
|
control_obs: List[Tuple[int, int]]
|
|
64
64
|
"""List of (t, i) tuples for valid control observations."""
|
|
65
65
|
control_unit_idx: np.ndarray
|
|
66
|
-
"""Array of
|
|
66
|
+
"""Array of never-treated unit indices (for backward compatibility)."""
|
|
67
|
+
D: np.ndarray
|
|
68
|
+
"""Treatment indicator matrix (n_periods x n_units) for dynamic control sets."""
|
|
69
|
+
Y: np.ndarray
|
|
70
|
+
"""Outcome matrix (n_periods x n_units)."""
|
|
67
71
|
n_units: int
|
|
68
72
|
"""Number of units."""
|
|
69
73
|
n_periods: int
|
|
@@ -529,6 +533,8 @@ class TROP:
|
|
|
529
533
|
"treated_observations": treated_observations,
|
|
530
534
|
"control_obs": control_obs,
|
|
531
535
|
"control_unit_idx": control_unit_idx,
|
|
536
|
+
"D": D,
|
|
537
|
+
"Y": Y,
|
|
532
538
|
"n_units": n_units,
|
|
533
539
|
"n_periods": n_periods,
|
|
534
540
|
}
|
|
@@ -778,16 +784,14 @@ class TROP:
|
|
|
778
784
|
# Prepare inputs for Rust function
|
|
779
785
|
control_mask_u8 = control_mask.astype(np.uint8)
|
|
780
786
|
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
|
|
781
|
-
unit_dist_matrix = self._precomputed["unit_dist_matrix"]
|
|
782
|
-
control_unit_idx_i64 = control_unit_idx.astype(np.int64)
|
|
783
787
|
|
|
784
788
|
lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
|
|
785
789
|
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
|
|
786
790
|
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
|
|
787
791
|
|
|
788
792
|
best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search(
|
|
789
|
-
Y, D.astype(np.float64), control_mask_u8,
|
|
790
|
-
|
|
793
|
+
Y, D.astype(np.float64), control_mask_u8,
|
|
794
|
+
time_dist_matrix,
|
|
791
795
|
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
|
|
792
796
|
self.max_loocv_samples, self.max_iter, self.tol,
|
|
793
797
|
self.seed if self.seed is not None else 0
|
|
@@ -953,10 +957,16 @@ class TROP:
|
|
|
953
957
|
"""
|
|
954
958
|
Compute observation-specific weight matrix for treated observation (i, t).
|
|
955
959
|
|
|
956
|
-
Following the paper's Algorithm 2 (page 27):
|
|
960
|
+
Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
|
|
957
961
|
- Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
|
|
958
962
|
- Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
|
|
959
963
|
|
|
964
|
+
IMPORTANT (Issue A fix): The paper's objective sums over ALL observations
|
|
965
|
+
where (1 - W_js) is non-zero, which includes pre-treatment observations of
|
|
966
|
+
eventually-treated units since W_js = 0 for those. This method computes
|
|
967
|
+
weights for ALL units where D[t, j] = 0 at the target period, not just
|
|
968
|
+
never-treated units.
|
|
969
|
+
|
|
960
970
|
Uses pre-computed structures when available for efficiency.
|
|
961
971
|
|
|
962
972
|
Parameters
|
|
@@ -974,7 +984,8 @@ class TROP:
|
|
|
974
984
|
lambda_unit : float
|
|
975
985
|
Unit weight decay parameter.
|
|
976
986
|
control_unit_idx : np.ndarray
|
|
977
|
-
Indices of
|
|
987
|
+
Indices of never-treated units (for backward compatibility, but not
|
|
988
|
+
used for weight computation - we use D matrix directly).
|
|
978
989
|
n_units : int
|
|
979
990
|
Number of units.
|
|
980
991
|
n_periods : int
|
|
@@ -991,21 +1002,30 @@ class TROP:
|
|
|
991
1002
|
# time_dist_matrix[t, s] = |t - s|
|
|
992
1003
|
time_weights = np.exp(-lambda_time * self._precomputed["time_dist_matrix"][t, :])
|
|
993
1004
|
|
|
994
|
-
# Unit weights
|
|
1005
|
+
# Unit weights - computed for ALL units where D[t, j] = 0
|
|
1006
|
+
# (Issue A fix: includes pre-treatment obs of eventually-treated units)
|
|
995
1007
|
unit_weights = np.zeros(n_units)
|
|
1008
|
+
D_stored = self._precomputed["D"]
|
|
1009
|
+
Y_stored = self._precomputed["Y"]
|
|
1010
|
+
|
|
1011
|
+
# Valid control units at time t: D[t, j] == 0
|
|
1012
|
+
valid_control_at_t = D_stored[t, :] == 0
|
|
996
1013
|
|
|
997
1014
|
if lambda_unit == 0:
|
|
998
1015
|
# Uniform weights when lambda_unit = 0
|
|
999
|
-
|
|
1016
|
+
# All units not treated at time t get weight 1
|
|
1017
|
+
unit_weights[valid_control_at_t] = 1.0
|
|
1000
1018
|
else:
|
|
1001
|
-
# Use
|
|
1002
|
-
|
|
1003
|
-
for j in
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1019
|
+
# Use observation-specific distances with target period excluded
|
|
1020
|
+
# (Issue B fix: compute exact per-observation distance)
|
|
1021
|
+
for j in range(n_units):
|
|
1022
|
+
if valid_control_at_t[j] and j != i:
|
|
1023
|
+
# Compute distance excluding target period t
|
|
1024
|
+
dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t)
|
|
1025
|
+
if np.isinf(dist):
|
|
1026
|
+
unit_weights[j] = 0.0
|
|
1027
|
+
else:
|
|
1028
|
+
unit_weights[j] = np.exp(-lambda_unit * dist)
|
|
1009
1029
|
|
|
1010
1030
|
# Treated unit i gets weight 1
|
|
1011
1031
|
unit_weights[i] = 1.0
|
|
@@ -1018,19 +1038,25 @@ class TROP:
|
|
|
1018
1038
|
dist_time = np.abs(np.arange(n_periods) - t)
|
|
1019
1039
|
time_weights = np.exp(-lambda_time * dist_time)
|
|
1020
1040
|
|
|
1021
|
-
# Unit
|
|
1041
|
+
# Unit weights - computed for ALL units where D[t, j] = 0
|
|
1042
|
+
# (Issue A fix: includes pre-treatment obs of eventually-treated units)
|
|
1022
1043
|
unit_weights = np.zeros(n_units)
|
|
1023
1044
|
|
|
1045
|
+
# Valid control units at time t: D[t, j] == 0
|
|
1046
|
+
valid_control_at_t = D[t, :] == 0
|
|
1047
|
+
|
|
1024
1048
|
if lambda_unit == 0:
|
|
1025
1049
|
# Uniform weights when lambda_unit = 0
|
|
1026
|
-
unit_weights[
|
|
1050
|
+
unit_weights[valid_control_at_t] = 1.0
|
|
1027
1051
|
else:
|
|
1028
|
-
for j in
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1052
|
+
for j in range(n_units):
|
|
1053
|
+
if valid_control_at_t[j] and j != i:
|
|
1054
|
+
# Compute distance excluding target period t (Issue B fix)
|
|
1055
|
+
dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
|
|
1056
|
+
if np.isinf(dist):
|
|
1057
|
+
unit_weights[j] = 0.0
|
|
1058
|
+
else:
|
|
1059
|
+
unit_weights[j] = np.exp(-lambda_unit * dist)
|
|
1034
1060
|
|
|
1035
1061
|
# Treated unit i gets weight 1 (or could be omitted since we fit on controls)
|
|
1036
1062
|
# We include treated unit's own observation for model fitting
|
|
@@ -1102,6 +1128,101 @@ class TROP:
|
|
|
1102
1128
|
|
|
1103
1129
|
return result
|
|
1104
1130
|
|
|
1131
|
+
def _weighted_nuclear_norm_solve(
|
|
1132
|
+
self,
|
|
1133
|
+
Y: np.ndarray,
|
|
1134
|
+
W: np.ndarray,
|
|
1135
|
+
L_init: np.ndarray,
|
|
1136
|
+
alpha: np.ndarray,
|
|
1137
|
+
beta: np.ndarray,
|
|
1138
|
+
lambda_nn: float,
|
|
1139
|
+
max_inner_iter: int = 20,
|
|
1140
|
+
) -> np.ndarray:
|
|
1141
|
+
"""
|
|
1142
|
+
Solve weighted nuclear norm problem using iterative weighted soft-impute.
|
|
1143
|
+
|
|
1144
|
+
Issue C fix: Implements the weighted nuclear norm optimization from the
|
|
1145
|
+
paper's Equation 2 (page 7). The full objective is:
|
|
1146
|
+
min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
|
|
1147
|
+
|
|
1148
|
+
This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
|
|
1149
|
+
L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
|
|
1150
|
+
|
|
1151
|
+
where W ⊙ denotes element-wise multiplication with normalized weights.
|
|
1152
|
+
|
|
1153
|
+
IMPORTANT: For observations with W=0 (treated observations), we keep
|
|
1154
|
+
L values from the previous iteration rather than setting L = R, which
|
|
1155
|
+
would absorb the treatment effect.
|
|
1156
|
+
|
|
1157
|
+
Parameters
|
|
1158
|
+
----------
|
|
1159
|
+
Y : np.ndarray
|
|
1160
|
+
Outcome matrix (n_periods x n_units).
|
|
1161
|
+
W : np.ndarray
|
|
1162
|
+
Weight matrix (n_periods x n_units), non-negative. W=0 indicates
|
|
1163
|
+
observations that should not be used for fitting (treated obs).
|
|
1164
|
+
L_init : np.ndarray
|
|
1165
|
+
Initial estimate of L matrix.
|
|
1166
|
+
alpha : np.ndarray
|
|
1167
|
+
Current unit fixed effects estimate.
|
|
1168
|
+
beta : np.ndarray
|
|
1169
|
+
Current time fixed effects estimate.
|
|
1170
|
+
lambda_nn : float
|
|
1171
|
+
Nuclear norm regularization parameter.
|
|
1172
|
+
max_inner_iter : int, default=20
|
|
1173
|
+
Maximum inner iterations for the proximal algorithm.
|
|
1174
|
+
|
|
1175
|
+
Returns
|
|
1176
|
+
-------
|
|
1177
|
+
np.ndarray
|
|
1178
|
+
Updated L matrix estimate.
|
|
1179
|
+
"""
|
|
1180
|
+
# Compute target residual R = Y - α - β
|
|
1181
|
+
R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
|
|
1182
|
+
|
|
1183
|
+
# Handle invalid values
|
|
1184
|
+
R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)
|
|
1185
|
+
|
|
1186
|
+
# For observations with W=0 (treated obs), keep L_init instead of R
|
|
1187
|
+
# This prevents L from absorbing the treatment effect
|
|
1188
|
+
valid_obs_mask = W > 0
|
|
1189
|
+
R_masked = np.where(valid_obs_mask, R, L_init)
|
|
1190
|
+
|
|
1191
|
+
if lambda_nn <= 0:
|
|
1192
|
+
# No regularization - just return masked residual
|
|
1193
|
+
# Use soft-thresholding with threshold=0 which returns the input
|
|
1194
|
+
return R_masked
|
|
1195
|
+
|
|
1196
|
+
# Normalize weights so max is 1 (for step size stability)
|
|
1197
|
+
W_max = np.max(W)
|
|
1198
|
+
if W_max > 0:
|
|
1199
|
+
W_norm = W / W_max
|
|
1200
|
+
else:
|
|
1201
|
+
W_norm = W
|
|
1202
|
+
|
|
1203
|
+
# Initialize L
|
|
1204
|
+
L = L_init.copy()
|
|
1205
|
+
|
|
1206
|
+
# Proximal gradient iteration with weighted soft-impute
|
|
1207
|
+
# This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
|
|
1208
|
+
# Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
|
|
1209
|
+
# where η is the step size (we use η = 1 with normalized weights)
|
|
1210
|
+
for _ in range(max_inner_iter):
|
|
1211
|
+
L_old = L.copy()
|
|
1212
|
+
|
|
1213
|
+
# Gradient step: L_k + W ⊙ (R - L_k)
|
|
1214
|
+
# For W=0 observations, this keeps L_k unchanged
|
|
1215
|
+
gradient_step = L + W_norm * (R_masked - L)
|
|
1216
|
+
|
|
1217
|
+
# Proximal step: soft-threshold singular values
|
|
1218
|
+
L = self._soft_threshold_svd(gradient_step, lambda_nn)
|
|
1219
|
+
|
|
1220
|
+
# Check convergence
|
|
1221
|
+
if np.max(np.abs(L - L_old)) < self.tol:
|
|
1222
|
+
break
|
|
1223
|
+
|
|
1224
|
+
return L
|
|
1225
|
+
|
|
1105
1226
|
def _estimate_model(
|
|
1106
1227
|
self,
|
|
1107
1228
|
Y: np.ndarray,
|
|
@@ -1205,14 +1326,13 @@ class TROP:
|
|
|
1205
1326
|
beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,)
|
|
1206
1327
|
beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
|
|
1207
1328
|
|
|
1208
|
-
# Step 2: Update L with nuclear norm penalty
|
|
1209
|
-
#
|
|
1210
|
-
#
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
L = self._soft_threshold_svd(R_for_L, lambda_nn)
|
|
1329
|
+
# Step 2: Update L with weighted nuclear norm penalty
|
|
1330
|
+
# Issue C fix: Use weighted soft-impute to properly account for
|
|
1331
|
+
# observation weights in the nuclear norm optimization.
|
|
1332
|
+
# Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
|
|
1333
|
+
L = self._weighted_nuclear_norm_solve(
|
|
1334
|
+
Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10
|
|
1335
|
+
)
|
|
1216
1336
|
|
|
1217
1337
|
# Check convergence
|
|
1218
1338
|
alpha_diff = np.max(np.abs(alpha - alpha_old))
|
|
@@ -1388,21 +1508,15 @@ class TROP:
|
|
|
1388
1508
|
# Try Rust backend for parallel bootstrap (5-15x speedup)
|
|
1389
1509
|
if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
|
|
1390
1510
|
and self._precomputed is not None and Y is not None
|
|
1391
|
-
and D is not None
|
|
1511
|
+
and D is not None):
|
|
1392
1512
|
try:
|
|
1393
|
-
# Prepare inputs
|
|
1394
|
-
treated_observations = self._precomputed["treated_observations"]
|
|
1395
|
-
treated_t = np.array([t for t, i in treated_observations], dtype=np.int64)
|
|
1396
|
-
treated_i = np.array([i for t, i in treated_observations], dtype=np.int64)
|
|
1397
1513
|
control_mask = self._precomputed["control_mask"]
|
|
1514
|
+
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
|
|
1398
1515
|
|
|
1399
1516
|
bootstrap_estimates, se = _rust_bootstrap_trop_variance(
|
|
1400
1517
|
Y, D.astype(np.float64),
|
|
1401
1518
|
control_mask.astype(np.uint8),
|
|
1402
|
-
|
|
1403
|
-
treated_t, treated_i,
|
|
1404
|
-
self._precomputed["unit_dist_matrix"],
|
|
1405
|
-
self._precomputed["time_dist_matrix"].astype(np.int64),
|
|
1519
|
+
time_dist_matrix,
|
|
1406
1520
|
lambda_time, lambda_unit, lambda_nn,
|
|
1407
1521
|
self.n_bootstrap, self.max_iter, self.tol,
|
|
1408
1522
|
self.seed if self.seed is not None else 0
|
|
@@ -1422,14 +1536,38 @@ class TROP:
|
|
|
1422
1536
|
|
|
1423
1537
|
# Python implementation (fallback)
|
|
1424
1538
|
rng = np.random.default_rng(self.seed)
|
|
1425
|
-
|
|
1426
|
-
|
|
1539
|
+
|
|
1540
|
+
# Issue D fix: Stratified bootstrap sampling
|
|
1541
|
+
# Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
|
|
1542
|
+
# and N_1 treated rows separately to preserve treatment ratio
|
|
1543
|
+
unit_ever_treated = data.groupby(unit)[treatment].max()
|
|
1544
|
+
treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
|
|
1545
|
+
control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)
|
|
1546
|
+
|
|
1547
|
+
n_treated_units = len(treated_units)
|
|
1548
|
+
n_control_units = len(control_units)
|
|
1427
1549
|
|
|
1428
1550
|
bootstrap_estimates_list = []
|
|
1429
1551
|
|
|
1430
1552
|
for _ in range(self.n_bootstrap):
|
|
1431
|
-
#
|
|
1432
|
-
|
|
1553
|
+
# Stratified sampling: sample control and treated units separately
|
|
1554
|
+
# This preserves the treatment ratio in each bootstrap sample
|
|
1555
|
+
if n_control_units > 0:
|
|
1556
|
+
sampled_control = rng.choice(
|
|
1557
|
+
control_units, size=n_control_units, replace=True
|
|
1558
|
+
)
|
|
1559
|
+
else:
|
|
1560
|
+
sampled_control = np.array([], dtype=control_units.dtype)
|
|
1561
|
+
|
|
1562
|
+
if n_treated_units > 0:
|
|
1563
|
+
sampled_treated = rng.choice(
|
|
1564
|
+
treated_units, size=n_treated_units, replace=True
|
|
1565
|
+
)
|
|
1566
|
+
else:
|
|
1567
|
+
sampled_treated = np.array([], dtype=treated_units.dtype)
|
|
1568
|
+
|
|
1569
|
+
# Combine stratified samples
|
|
1570
|
+
sampled_units = np.concatenate([sampled_control, sampled_treated])
|
|
1433
1571
|
|
|
1434
1572
|
# Create bootstrap sample with unique unit IDs
|
|
1435
1573
|
boot_data = pd.concat([
|
|
@@ -172,15 +172,13 @@ fn compute_pair_distance(
|
|
|
172
172
|
/// # Returns
|
|
173
173
|
/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score)
|
|
174
174
|
#[pyfunction]
|
|
175
|
-
#[pyo3(signature = (y, d, control_mask,
|
|
175
|
+
#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
|
|
176
176
|
#[allow(clippy::too_many_arguments)]
|
|
177
177
|
pub fn loocv_grid_search<'py>(
|
|
178
178
|
_py: Python<'py>,
|
|
179
179
|
y: PyReadonlyArray2<'py, f64>,
|
|
180
180
|
d: PyReadonlyArray2<'py, f64>,
|
|
181
181
|
control_mask: PyReadonlyArray2<'py, u8>,
|
|
182
|
-
control_unit_idx: PyReadonlyArray1<'py, i64>,
|
|
183
|
-
unit_dist_matrix: PyReadonlyArray2<'py, f64>,
|
|
184
182
|
time_dist_matrix: PyReadonlyArray2<'py, i64>,
|
|
185
183
|
lambda_time_grid: PyReadonlyArray1<'py, f64>,
|
|
186
184
|
lambda_unit_grid: PyReadonlyArray1<'py, f64>,
|
|
@@ -193,19 +191,11 @@ pub fn loocv_grid_search<'py>(
|
|
|
193
191
|
let y_arr = y.as_array();
|
|
194
192
|
let d_arr = d.as_array();
|
|
195
193
|
let control_mask_arr = control_mask.as_array();
|
|
196
|
-
let control_unit_idx_arr = control_unit_idx.as_array();
|
|
197
|
-
let unit_dist_arr = unit_dist_matrix.as_array();
|
|
198
194
|
let time_dist_arr = time_dist_matrix.as_array();
|
|
199
195
|
let lambda_time_vec: Vec<f64> = lambda_time_grid.as_array().to_vec();
|
|
200
196
|
let lambda_unit_vec: Vec<f64> = lambda_unit_grid.as_array().to_vec();
|
|
201
197
|
let lambda_nn_vec: Vec<f64> = lambda_nn_grid.as_array().to_vec();
|
|
202
198
|
|
|
203
|
-
// Convert control_unit_idx to Vec<usize>
|
|
204
|
-
let control_units: Vec<usize> = control_unit_idx_arr
|
|
205
|
-
.iter()
|
|
206
|
-
.map(|&idx| idx as usize)
|
|
207
|
-
.collect();
|
|
208
|
-
|
|
209
199
|
// Get control observations for LOOCV
|
|
210
200
|
let control_obs = get_control_observations(
|
|
211
201
|
&y_arr,
|
|
@@ -232,8 +222,6 @@ pub fn loocv_grid_search<'py>(
|
|
|
232
222
|
&y_arr,
|
|
233
223
|
&d_arr,
|
|
234
224
|
&control_mask_arr,
|
|
235
|
-
&control_units,
|
|
236
|
-
&unit_dist_arr,
|
|
237
225
|
&time_dist_arr,
|
|
238
226
|
&control_obs,
|
|
239
227
|
lambda_time,
|
|
@@ -291,10 +279,8 @@ fn get_control_observations(
|
|
|
291
279
|
/// Compute LOOCV score for a specific parameter combination.
|
|
292
280
|
fn loocv_score_for_params(
|
|
293
281
|
y: &ArrayView2<f64>,
|
|
294
|
-
|
|
282
|
+
d: &ArrayView2<f64>,
|
|
295
283
|
control_mask: &ArrayView2<u8>,
|
|
296
|
-
control_units: &[usize],
|
|
297
|
-
unit_dist: &ArrayView2<f64>,
|
|
298
284
|
time_dist: &ArrayView2<i64>,
|
|
299
285
|
control_obs: &[(usize, usize)],
|
|
300
286
|
lambda_time: f64,
|
|
@@ -312,14 +298,14 @@ fn loocv_score_for_params(
|
|
|
312
298
|
for &(t, i) in control_obs {
|
|
313
299
|
// Compute observation-specific weight matrix
|
|
314
300
|
let weight_matrix = compute_weight_matrix(
|
|
301
|
+
y,
|
|
302
|
+
d,
|
|
315
303
|
n_periods,
|
|
316
304
|
n_units,
|
|
317
305
|
i,
|
|
318
306
|
t,
|
|
319
307
|
lambda_time,
|
|
320
308
|
lambda_unit,
|
|
321
|
-
control_units,
|
|
322
|
-
unit_dist,
|
|
323
309
|
time_dist,
|
|
324
310
|
);
|
|
325
311
|
|
|
@@ -352,46 +338,107 @@ fn loocv_score_for_params(
|
|
|
352
338
|
}
|
|
353
339
|
}
|
|
354
340
|
|
|
341
|
+
/// Compute observation-specific distance from unit j to unit i, excluding target period.
|
|
342
|
+
///
|
|
343
|
+
/// Issue B fix: Follows Equation 3 (page 7) which specifies 1{u ≠ t} to exclude target period.
|
|
344
|
+
fn compute_unit_distance_for_obs(
|
|
345
|
+
y: &ArrayView2<f64>,
|
|
346
|
+
d: &ArrayView2<f64>,
|
|
347
|
+
j: usize,
|
|
348
|
+
i: usize,
|
|
349
|
+
target_period: usize,
|
|
350
|
+
) -> f64 {
|
|
351
|
+
let n_periods = y.nrows();
|
|
352
|
+
let mut sum_sq = 0.0;
|
|
353
|
+
let mut n_valid = 0usize;
|
|
354
|
+
|
|
355
|
+
for t in 0..n_periods {
|
|
356
|
+
// Exclude target period (Issue B fix)
|
|
357
|
+
if t == target_period {
|
|
358
|
+
continue;
|
|
359
|
+
}
|
|
360
|
+
// Both units must be control at this period and have valid values
|
|
361
|
+
if d[[t, i]] == 0.0 && d[[t, j]] == 0.0
|
|
362
|
+
&& y[[t, i]].is_finite() && y[[t, j]].is_finite()
|
|
363
|
+
{
|
|
364
|
+
let diff = y[[t, i]] - y[[t, j]];
|
|
365
|
+
sum_sq += diff * diff;
|
|
366
|
+
n_valid += 1;
|
|
367
|
+
}
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
if n_valid > 0 {
|
|
371
|
+
(sum_sq / n_valid as f64).sqrt()
|
|
372
|
+
} else {
|
|
373
|
+
f64::INFINITY
|
|
374
|
+
}
|
|
375
|
+
}
|
|
376
|
+
|
|
355
377
|
/// Compute observation-specific weight matrix for TROP.
|
|
356
378
|
///
|
|
357
379
|
/// Time weights: θ_s = exp(-λ_time × |t - s|)
|
|
358
380
|
/// Unit weights: ω_j = exp(-λ_unit × dist(j, i))
|
|
381
|
+
///
|
|
382
|
+
/// Paper alignment notes:
|
|
383
|
+
/// - ALL units get weights (not just those untreated at target period)
|
|
384
|
+
/// - The (1 - D_js) masking in the loss naturally excludes treated cells
|
|
385
|
+
/// - Weights are normalized to sum to 1 (probability weights)
|
|
386
|
+
/// - Distance excludes target period t per Equation 3
|
|
359
387
|
fn compute_weight_matrix(
|
|
388
|
+
y: &ArrayView2<f64>,
|
|
389
|
+
d: &ArrayView2<f64>,
|
|
360
390
|
n_periods: usize,
|
|
361
391
|
n_units: usize,
|
|
362
392
|
target_unit: usize,
|
|
363
393
|
target_period: usize,
|
|
364
394
|
lambda_time: f64,
|
|
365
395
|
lambda_unit: f64,
|
|
366
|
-
control_units: &[usize],
|
|
367
|
-
unit_dist: &ArrayView2<f64>,
|
|
368
396
|
time_dist: &ArrayView2<i64>,
|
|
369
397
|
) -> Array2<f64> {
|
|
370
|
-
// Time weights for this target period
|
|
371
|
-
let time_weights: Array1<f64> = Array1::from_shape_fn(n_periods, |s| {
|
|
398
|
+
// Time weights for this target period: θ_s = exp(-λ_time × |t - s|)
|
|
399
|
+
let mut time_weights: Array1<f64> = Array1::from_shape_fn(n_periods, |s| {
|
|
372
400
|
let dist = time_dist[[target_period, s]] as f64;
|
|
373
401
|
(-lambda_time * dist).exp()
|
|
374
402
|
});
|
|
375
403
|
|
|
376
|
-
//
|
|
404
|
+
// Normalize time weights to sum to 1
|
|
405
|
+
let time_sum: f64 = time_weights.sum();
|
|
406
|
+
if time_sum > 0.0 {
|
|
407
|
+
time_weights /= time_sum;
|
|
408
|
+
}
|
|
409
|
+
|
|
410
|
+
// Unit weights: ω_j = exp(-λ_unit × dist(j, i))
|
|
411
|
+
// Paper alignment: compute for ALL units, let control masking handle exclusion
|
|
377
412
|
let mut unit_weights = Array1::<f64>::zeros(n_units);
|
|
378
413
|
|
|
379
414
|
if lambda_unit == 0.0 {
|
|
380
415
|
// Uniform weights when lambda_unit = 0
|
|
416
|
+
// All units get weight 1 (control masking will handle exclusion)
|
|
381
417
|
unit_weights.fill(1.0);
|
|
382
418
|
} else {
|
|
383
|
-
for
|
|
384
|
-
|
|
385
|
-
if
|
|
386
|
-
|
|
419
|
+
// Compute per-observation distance for all units (excluding target unit itself)
|
|
420
|
+
for j in 0..n_units {
|
|
421
|
+
if j != target_unit {
|
|
422
|
+
let dist = compute_unit_distance_for_obs(y, d, j, target_unit, target_period);
|
|
423
|
+
if dist.is_finite() {
|
|
424
|
+
unit_weights[j] = (-lambda_unit * dist).exp();
|
|
425
|
+
}
|
|
426
|
+
// Units with infinite distance (no valid comparison periods) get weight 0
|
|
387
427
|
}
|
|
388
428
|
}
|
|
389
429
|
}
|
|
390
430
|
|
|
391
|
-
// Target unit gets weight 1
|
|
431
|
+
// Target unit gets weight 1 (will be masked out in estimation anyway)
|
|
392
432
|
unit_weights[target_unit] = 1.0;
|
|
393
433
|
|
|
434
|
+
// Normalize unit weights to sum to 1
|
|
435
|
+
let unit_sum: f64 = unit_weights.sum();
|
|
436
|
+
if unit_sum > 0.0 {
|
|
437
|
+
unit_weights /= unit_sum;
|
|
438
|
+
}
|
|
439
|
+
|
|
394
440
|
// Outer product: W[t, i] = time_weights[t] * unit_weights[i]
|
|
441
|
+
// Result is normalized since both components sum to 1
|
|
395
442
|
let mut weight_matrix = Array2::<f64>::zeros((n_periods, n_units));
|
|
396
443
|
for t in 0..n_periods {
|
|
397
444
|
for i in 0..n_units {
|
|
@@ -406,6 +453,10 @@ fn compute_weight_matrix(
|
|
|
406
453
|
///
|
|
407
454
|
/// Minimizes: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
|
|
408
455
|
///
|
|
456
|
+
/// Paper alignment: Uses weighted proximal gradient for L update:
|
|
457
|
+
/// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
|
|
458
|
+
/// where η ≤ 1/max(W) for convergence.
|
|
459
|
+
///
|
|
409
460
|
/// Returns None if estimation fails due to numerical issues.
|
|
410
461
|
fn estimate_model(
|
|
411
462
|
y: &ArrayView2<f64>,
|
|
@@ -432,7 +483,7 @@ fn estimate_model(
|
|
|
432
483
|
y[[t, i]].is_finite() && est_mask[[t, i]]
|
|
433
484
|
});
|
|
434
485
|
|
|
435
|
-
// Masked weights
|
|
486
|
+
// Masked weights: W=0 for invalid/treated observations
|
|
436
487
|
let w_masked = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
|
|
437
488
|
if valid_mask[[t, i]] {
|
|
438
489
|
weight_matrix[[t, i]]
|
|
@@ -441,6 +492,10 @@ fn estimate_model(
|
|
|
441
492
|
}
|
|
442
493
|
});
|
|
443
494
|
|
|
495
|
+
// Compute step size for proximal gradient: η ≤ 1/max(W)
|
|
496
|
+
let w_max = w_masked.iter().cloned().fold(0.0_f64, f64::max);
|
|
497
|
+
let eta = if w_max > 0.0 { 1.0 / w_max } else { 1.0 };
|
|
498
|
+
|
|
444
499
|
// Weight sums per unit and time
|
|
445
500
|
let weight_sum_per_unit: Array1<f64> = w_masked.sum_axis(Axis(0));
|
|
446
501
|
let weight_sum_per_time: Array1<f64> = w_masked.sum_axis(Axis(1));
|
|
@@ -472,7 +527,7 @@ fn estimate_model(
|
|
|
472
527
|
let beta_old = beta.clone();
|
|
473
528
|
let l_old = l.clone();
|
|
474
529
|
|
|
475
|
-
// Step 1: Update α and β
|
|
530
|
+
// Step 1: Update α and β (weighted least squares)
|
|
476
531
|
// R = Y - L
|
|
477
532
|
let r = &y_safe - &l;
|
|
478
533
|
|
|
@@ -498,25 +553,31 @@ fn estimate_model(
|
|
|
498
553
|
}
|
|
499
554
|
}
|
|
500
555
|
|
|
501
|
-
// Step 2: Update L with nuclear norm penalty
|
|
502
|
-
//
|
|
503
|
-
|
|
556
|
+
// Step 2: Update L with WEIGHTED nuclear norm penalty
|
|
557
|
+
// Paper alignment: Use proximal gradient instead of direct soft-thresholding
|
|
558
|
+
// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
|
|
559
|
+
// where R = Y - α - β
|
|
560
|
+
|
|
561
|
+
// Compute target residual R = Y - α - β
|
|
562
|
+
let mut r_target = Array2::<f64>::zeros((n_periods, n_units));
|
|
504
563
|
for t in 0..n_periods {
|
|
505
564
|
for i in 0..n_units {
|
|
506
|
-
|
|
565
|
+
r_target[[t, i]] = y_safe[[t, i]] - alpha[i] - beta[t];
|
|
507
566
|
}
|
|
508
567
|
}
|
|
509
568
|
|
|
510
|
-
//
|
|
569
|
+
// Weighted proximal gradient step:
|
|
570
|
+
// gradient_step = L + η * W ⊙ (R - L)
|
|
571
|
+
// For W=0 cells (treated obs), this keeps L unchanged
|
|
572
|
+
let mut gradient_step = Array2::<f64>::zeros((n_periods, n_units));
|
|
511
573
|
for t in 0..n_periods {
|
|
512
574
|
for i in 0..n_units {
|
|
513
|
-
|
|
514
|
-
r_for_l[[t, i]] = l[[t, i]];
|
|
515
|
-
}
|
|
575
|
+
gradient_step[[t, i]] = l[[t, i]] + eta * w_masked[[t, i]] * (r_target[[t, i]] - l[[t, i]]);
|
|
516
576
|
}
|
|
517
577
|
}
|
|
518
578
|
|
|
519
|
-
|
|
579
|
+
// Proximal step: soft-threshold singular values with scaled lambda
|
|
580
|
+
l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?;
|
|
520
581
|
|
|
521
582
|
// Check convergence
|
|
522
583
|
let alpha_diff = max_abs_diff(&alpha, &alpha_old);
|
|
@@ -627,17 +688,13 @@ fn max_abs_diff_2d(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
|
|
|
627
688
|
/// # Returns
|
|
628
689
|
/// (bootstrap_estimates, standard_error)
|
|
629
690
|
#[pyfunction]
|
|
630
|
-
#[pyo3(signature = (y, d, control_mask,
|
|
691
|
+
#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))]
|
|
631
692
|
#[allow(clippy::too_many_arguments)]
|
|
632
693
|
pub fn bootstrap_trop_variance<'py>(
|
|
633
694
|
py: Python<'py>,
|
|
634
695
|
y: PyReadonlyArray2<'py, f64>,
|
|
635
696
|
d: PyReadonlyArray2<'py, f64>,
|
|
636
697
|
control_mask: PyReadonlyArray2<'py, u8>,
|
|
637
|
-
control_unit_idx: PyReadonlyArray1<'py, i64>,
|
|
638
|
-
treated_obs_t: PyReadonlyArray1<'py, i64>,
|
|
639
|
-
treated_obs_i: PyReadonlyArray1<'py, i64>,
|
|
640
|
-
unit_dist_matrix: PyReadonlyArray2<'py, f64>,
|
|
641
698
|
time_dist_matrix: PyReadonlyArray2<'py, i64>,
|
|
642
699
|
lambda_time: f64,
|
|
643
700
|
lambda_unit: f64,
|
|
@@ -650,16 +707,25 @@ pub fn bootstrap_trop_variance<'py>(
|
|
|
650
707
|
let y_arr = y.as_array().to_owned();
|
|
651
708
|
let d_arr = d.as_array().to_owned();
|
|
652
709
|
let control_mask_arr = control_mask.as_array().to_owned();
|
|
653
|
-
let unit_dist_arr = unit_dist_matrix.as_array().to_owned();
|
|
654
710
|
let time_dist_arr = time_dist_matrix.as_array().to_owned();
|
|
655
711
|
|
|
656
712
|
let n_units = y_arr.ncols();
|
|
657
713
|
let n_periods = y_arr.nrows();
|
|
658
714
|
|
|
659
|
-
//
|
|
660
|
-
//
|
|
661
|
-
|
|
662
|
-
let
|
|
715
|
+
// Issue D fix: Identify treated and control units for stratified sampling
|
|
716
|
+
// Following paper's Algorithm 3 (page 27): sample N_0 control and N_1 treated separately
|
|
717
|
+
let mut original_treated_units: Vec<usize> = Vec::new();
|
|
718
|
+
let mut original_control_units: Vec<usize> = Vec::new();
|
|
719
|
+
for i in 0..n_units {
|
|
720
|
+
let is_ever_treated = (0..n_periods).any(|t| d_arr[[t, i]] == 1.0);
|
|
721
|
+
if is_ever_treated {
|
|
722
|
+
original_treated_units.push(i);
|
|
723
|
+
} else {
|
|
724
|
+
original_control_units.push(i);
|
|
725
|
+
}
|
|
726
|
+
}
|
|
727
|
+
let n_treated_units = original_treated_units.len();
|
|
728
|
+
let n_control_units = original_control_units.len();
|
|
663
729
|
|
|
664
730
|
// Run bootstrap iterations in parallel
|
|
665
731
|
let bootstrap_estimates: Vec<f64> = (0..n_bootstrap)
|
|
@@ -670,16 +736,25 @@ pub fn bootstrap_trop_variance<'py>(
|
|
|
670
736
|
|
|
671
737
|
let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(b as u64));
|
|
672
738
|
|
|
673
|
-
//
|
|
674
|
-
let sampled_units: Vec<usize> = (
|
|
675
|
-
|
|
676
|
-
|
|
739
|
+
// Issue D fix: Stratified sampling - sample control and treated units separately
|
|
740
|
+
let mut sampled_units: Vec<usize> = Vec::with_capacity(n_units);
|
|
741
|
+
|
|
742
|
+
// Sample control units with replacement
|
|
743
|
+
for _ in 0..n_control_units {
|
|
744
|
+
let idx = rng.gen_range(0..n_control_units);
|
|
745
|
+
sampled_units.push(original_control_units[idx]);
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
// Sample treated units with replacement
|
|
749
|
+
for _ in 0..n_treated_units {
|
|
750
|
+
let idx = rng.gen_range(0..n_treated_units);
|
|
751
|
+
sampled_units.push(original_treated_units[idx]);
|
|
752
|
+
}
|
|
677
753
|
|
|
678
754
|
// Create bootstrap matrices by selecting columns
|
|
679
755
|
let mut y_boot = Array2::<f64>::zeros((n_periods, n_units));
|
|
680
756
|
let mut d_boot = Array2::<f64>::zeros((n_periods, n_units));
|
|
681
757
|
let mut control_mask_boot = Array2::<u8>::zeros((n_periods, n_units));
|
|
682
|
-
let mut unit_dist_boot = Array2::<f64>::zeros((n_units, n_units));
|
|
683
758
|
|
|
684
759
|
for (new_idx, &old_idx) in sampled_units.iter().enumerate() {
|
|
685
760
|
for t in 0..n_periods {
|
|
@@ -687,10 +762,6 @@ pub fn bootstrap_trop_variance<'py>(
|
|
|
687
762
|
d_boot[[t, new_idx]] = d_arr[[t, old_idx]];
|
|
688
763
|
control_mask_boot[[t, new_idx]] = control_mask_arr[[t, old_idx]];
|
|
689
764
|
}
|
|
690
|
-
|
|
691
|
-
for (new_j, &old_j) in sampled_units.iter().enumerate() {
|
|
692
|
-
unit_dist_boot[[new_idx, new_j]] = unit_dist_arr[[old_idx, old_j]];
|
|
693
|
-
}
|
|
694
765
|
}
|
|
695
766
|
|
|
696
767
|
// Get treated observations in bootstrap sample
|
|
@@ -725,14 +796,14 @@ pub fn bootstrap_trop_variance<'py>(
|
|
|
725
796
|
|
|
726
797
|
for (t, i) in boot_treated {
|
|
727
798
|
let weight_matrix = compute_weight_matrix(
|
|
799
|
+
&y_boot.view(),
|
|
800
|
+
&d_boot.view(),
|
|
728
801
|
n_periods,
|
|
729
802
|
n_units,
|
|
730
803
|
i,
|
|
731
804
|
t,
|
|
732
805
|
lambda_time,
|
|
733
806
|
lambda_unit,
|
|
734
|
-
&boot_control_units,
|
|
735
|
-
&unit_dist_boot.view(),
|
|
736
807
|
&time_dist_arr.view(),
|
|
737
808
|
);
|
|
738
809
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|