diff-diff 2.1.2__tar.gz → 2.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {diff_diff-2.1.2 → diff_diff-2.1.3}/PKG-INFO +1 -1
  2. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/__init__.py +1 -1
  3. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/trop.py +184 -46
  4. {diff_diff-2.1.2 → diff_diff-2.1.3}/pyproject.toml +1 -1
  5. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/Cargo.lock +1 -1
  6. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/Cargo.toml +1 -1
  7. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/trop.rs +131 -60
  8. {diff_diff-2.1.2 → diff_diff-2.1.3}/README.md +0 -0
  9. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/_backend.py +0 -0
  10. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/bacon.py +0 -0
  11. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/datasets.py +0 -0
  12. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/diagnostics.py +0 -0
  13. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/estimators.py +0 -0
  14. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/honest_did.py +0 -0
  15. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/linalg.py +0 -0
  16. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/power.py +0 -0
  17. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/prep.py +0 -0
  18. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/pretrends.py +0 -0
  19. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/results.py +0 -0
  20. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/staggered.py +0 -0
  21. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/sun_abraham.py +0 -0
  22. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/synthetic_did.py +0 -0
  23. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/triple_diff.py +0 -0
  24. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/twfe.py +0 -0
  25. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/utils.py +0 -0
  26. {diff_diff-2.1.2 → diff_diff-2.1.3}/diff_diff/visualization.py +0 -0
  27. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/bootstrap.rs +0 -0
  28. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/lib.rs +0 -0
  29. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/linalg.rs +0 -0
  30. {diff_diff-2.1.2 → diff_diff-2.1.3}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.1.2
3
+ Version: 2.1.3
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -136,7 +136,7 @@ from diff_diff.datasets import (
136
136
  load_mpdta,
137
137
  )
138
138
 
139
- __version__ = "2.1.2"
139
+ __version__ = "2.1.3"
140
140
  __all__ = [
141
141
  # Estimators
142
142
  "DifferenceInDifferences",
@@ -63,7 +63,11 @@ class _PrecomputedStructures(TypedDict):
63
63
  control_obs: List[Tuple[int, int]]
64
64
  """List of (t, i) tuples for valid control observations."""
65
65
  control_unit_idx: np.ndarray
66
- """Array of control unit indices."""
66
+ """Array of never-treated unit indices (for backward compatibility)."""
67
+ D: np.ndarray
68
+ """Treatment indicator matrix (n_periods x n_units) for dynamic control sets."""
69
+ Y: np.ndarray
70
+ """Outcome matrix (n_periods x n_units)."""
67
71
  n_units: int
68
72
  """Number of units."""
69
73
  n_periods: int
@@ -529,6 +533,8 @@ class TROP:
529
533
  "treated_observations": treated_observations,
530
534
  "control_obs": control_obs,
531
535
  "control_unit_idx": control_unit_idx,
536
+ "D": D,
537
+ "Y": Y,
532
538
  "n_units": n_units,
533
539
  "n_periods": n_periods,
534
540
  }
@@ -778,16 +784,14 @@ class TROP:
778
784
  # Prepare inputs for Rust function
779
785
  control_mask_u8 = control_mask.astype(np.uint8)
780
786
  time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
781
- unit_dist_matrix = self._precomputed["unit_dist_matrix"]
782
- control_unit_idx_i64 = control_unit_idx.astype(np.int64)
783
787
 
784
788
  lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
785
789
  lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
786
790
  lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
787
791
 
788
792
  best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search(
789
- Y, D.astype(np.float64), control_mask_u8, control_unit_idx_i64,
790
- unit_dist_matrix, time_dist_matrix,
793
+ Y, D.astype(np.float64), control_mask_u8,
794
+ time_dist_matrix,
791
795
  lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
792
796
  self.max_loocv_samples, self.max_iter, self.tol,
793
797
  self.seed if self.seed is not None else 0
@@ -953,10 +957,16 @@ class TROP:
953
957
  """
954
958
  Compute observation-specific weight matrix for treated observation (i, t).
955
959
 
956
- Following the paper's Algorithm 2 (page 27):
960
+ Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
957
961
  - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
958
962
  - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
959
963
 
964
+ IMPORTANT (Issue A fix): The paper's objective sums over ALL observations
965
+ where (1 - W_js) is non-zero, which includes pre-treatment observations of
966
+ eventually-treated units since W_js = 0 for those. This method computes
967
+ weights for ALL units where D[t, j] = 0 at the target period, not just
968
+ never-treated units.
969
+
960
970
  Uses pre-computed structures when available for efficiency.
961
971
 
962
972
  Parameters
@@ -974,7 +984,8 @@ class TROP:
974
984
  lambda_unit : float
975
985
  Unit weight decay parameter.
976
986
  control_unit_idx : np.ndarray
977
- Indices of control units.
987
+ Indices of never-treated units (for backward compatibility, but not
988
+ used for weight computation - we use D matrix directly).
978
989
  n_units : int
979
990
  Number of units.
980
991
  n_periods : int
@@ -991,21 +1002,30 @@ class TROP:
991
1002
  # time_dist_matrix[t, s] = |t - s|
992
1003
  time_weights = np.exp(-lambda_time * self._precomputed["time_dist_matrix"][t, :])
993
1004
 
994
- # Unit weights from pre-computed unit distance matrix
1005
+ # Unit weights - computed for ALL units where D[t, j] = 0
1006
+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
995
1007
  unit_weights = np.zeros(n_units)
1008
+ D_stored = self._precomputed["D"]
1009
+ Y_stored = self._precomputed["Y"]
1010
+
1011
+ # Valid control units at time t: D[t, j] == 0
1012
+ valid_control_at_t = D_stored[t, :] == 0
996
1013
 
997
1014
  if lambda_unit == 0:
998
1015
  # Uniform weights when lambda_unit = 0
999
- unit_weights[:] = 1.0
1016
+ # All units not treated at time t get weight 1
1017
+ unit_weights[valid_control_at_t] = 1.0
1000
1018
  else:
1001
- # Use pre-computed distances: unit_dist_matrix[j, i] = dist(j, i)
1002
- dist_matrix = self._precomputed["unit_dist_matrix"]
1003
- for j in control_unit_idx:
1004
- dist = dist_matrix[j, i]
1005
- if np.isinf(dist):
1006
- unit_weights[j] = 0.0
1007
- else:
1008
- unit_weights[j] = np.exp(-lambda_unit * dist)
1019
+ # Use observation-specific distances with target period excluded
1020
+ # (Issue B fix: compute exact per-observation distance)
1021
+ for j in range(n_units):
1022
+ if valid_control_at_t[j] and j != i:
1023
+ # Compute distance excluding target period t
1024
+ dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t)
1025
+ if np.isinf(dist):
1026
+ unit_weights[j] = 0.0
1027
+ else:
1028
+ unit_weights[j] = np.exp(-lambda_unit * dist)
1009
1029
 
1010
1030
  # Treated unit i gets weight 1
1011
1031
  unit_weights[i] = 1.0
@@ -1018,19 +1038,25 @@ class TROP:
1018
1038
  dist_time = np.abs(np.arange(n_periods) - t)
1019
1039
  time_weights = np.exp(-lambda_time * dist_time)
1020
1040
 
1021
- # Unit distance: pairwise RMSE from each control j to treated i
1041
+ # Unit weights - computed for ALL units where D[t, j] = 0
1042
+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
1022
1043
  unit_weights = np.zeros(n_units)
1023
1044
 
1045
+ # Valid control units at time t: D[t, j] == 0
1046
+ valid_control_at_t = D[t, :] == 0
1047
+
1024
1048
  if lambda_unit == 0:
1025
1049
  # Uniform weights when lambda_unit = 0
1026
- unit_weights[:] = 1.0
1050
+ unit_weights[valid_control_at_t] = 1.0
1027
1051
  else:
1028
- for j in control_unit_idx:
1029
- dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
1030
- if np.isinf(dist):
1031
- unit_weights[j] = 0.0
1032
- else:
1033
- unit_weights[j] = np.exp(-lambda_unit * dist)
1052
+ for j in range(n_units):
1053
+ if valid_control_at_t[j] and j != i:
1054
+ # Compute distance excluding target period t (Issue B fix)
1055
+ dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
1056
+ if np.isinf(dist):
1057
+ unit_weights[j] = 0.0
1058
+ else:
1059
+ unit_weights[j] = np.exp(-lambda_unit * dist)
1034
1060
 
1035
1061
  # Treated unit i gets weight 1 (or could be omitted since we fit on controls)
1036
1062
  # We include treated unit's own observation for model fitting
@@ -1102,6 +1128,101 @@ class TROP:
1102
1128
 
1103
1129
  return result
1104
1130
 
1131
+ def _weighted_nuclear_norm_solve(
1132
+ self,
1133
+ Y: np.ndarray,
1134
+ W: np.ndarray,
1135
+ L_init: np.ndarray,
1136
+ alpha: np.ndarray,
1137
+ beta: np.ndarray,
1138
+ lambda_nn: float,
1139
+ max_inner_iter: int = 20,
1140
+ ) -> np.ndarray:
1141
+ """
1142
+ Solve weighted nuclear norm problem using iterative weighted soft-impute.
1143
+
1144
+ Issue C fix: Implements the weighted nuclear norm optimization from the
1145
+ paper's Equation 2 (page 7). The full objective is:
1146
+ min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
1147
+
1148
+ This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
1149
+ L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
1150
+
1151
+ where W ⊙ denotes element-wise multiplication with normalized weights.
1152
+
1153
+ IMPORTANT: For observations with W=0 (treated observations), we keep
1154
+ L values from the previous iteration rather than setting L = R, which
1155
+ would absorb the treatment effect.
1156
+
1157
+ Parameters
1158
+ ----------
1159
+ Y : np.ndarray
1160
+ Outcome matrix (n_periods x n_units).
1161
+ W : np.ndarray
1162
+ Weight matrix (n_periods x n_units), non-negative. W=0 indicates
1163
+ observations that should not be used for fitting (treated obs).
1164
+ L_init : np.ndarray
1165
+ Initial estimate of L matrix.
1166
+ alpha : np.ndarray
1167
+ Current unit fixed effects estimate.
1168
+ beta : np.ndarray
1169
+ Current time fixed effects estimate.
1170
+ lambda_nn : float
1171
+ Nuclear norm regularization parameter.
1172
+ max_inner_iter : int, default=20
1173
+ Maximum inner iterations for the proximal algorithm.
1174
+
1175
+ Returns
1176
+ -------
1177
+ np.ndarray
1178
+ Updated L matrix estimate.
1179
+ """
1180
+ # Compute target residual R = Y - α - β
1181
+ R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
1182
+
1183
+ # Handle invalid values
1184
+ R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)
1185
+
1186
+ # For observations with W=0 (treated obs), keep L_init instead of R
1187
+ # This prevents L from absorbing the treatment effect
1188
+ valid_obs_mask = W > 0
1189
+ R_masked = np.where(valid_obs_mask, R, L_init)
1190
+
1191
+ if lambda_nn <= 0:
1192
+ # No regularization - just return masked residual
1193
+ # Use soft-thresholding with threshold=0 which returns the input
1194
+ return R_masked
1195
+
1196
+ # Normalize weights so max is 1 (for step size stability)
1197
+ W_max = np.max(W)
1198
+ if W_max > 0:
1199
+ W_norm = W / W_max
1200
+ else:
1201
+ W_norm = W
1202
+
1203
+ # Initialize L
1204
+ L = L_init.copy()
1205
+
1206
+ # Proximal gradient iteration with weighted soft-impute
1207
+ # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
1208
+ # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
1209
+ # where η is the step size (we use η = 1 with normalized weights)
1210
+ for _ in range(max_inner_iter):
1211
+ L_old = L.copy()
1212
+
1213
+ # Gradient step: L_k + W ⊙ (R - L_k)
1214
+ # For W=0 observations, this keeps L_k unchanged
1215
+ gradient_step = L + W_norm * (R_masked - L)
1216
+
1217
+ # Proximal step: soft-threshold singular values
1218
+ L = self._soft_threshold_svd(gradient_step, lambda_nn)
1219
+
1220
+ # Check convergence
1221
+ if np.max(np.abs(L - L_old)) < self.tol:
1222
+ break
1223
+
1224
+ return L
1225
+
1105
1226
  def _estimate_model(
1106
1227
  self,
1107
1228
  Y: np.ndarray,
@@ -1205,14 +1326,13 @@ class TROP:
1205
1326
  beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,)
1206
1327
  beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
1207
1328
 
1208
- # Step 2: Update L with nuclear norm penalty
1209
- # Following Equation 2 (page 7): L = prox_{λ_nn||·||_*}(Y - α - β)
1210
- # The proximal operator for nuclear norm is soft-thresholding of SVD
1211
- R_for_L = Y_safe - alpha[np.newaxis, :] - beta[:, np.newaxis]
1212
- # Impute invalid observations with current L for stable SVD
1213
- R_for_L = np.where(valid_mask, R_for_L, L)
1214
-
1215
- L = self._soft_threshold_svd(R_for_L, lambda_nn)
1329
+ # Step 2: Update L with weighted nuclear norm penalty
1330
+ # Issue C fix: Use weighted soft-impute to properly account for
1331
+ # observation weights in the nuclear norm optimization.
1332
+ # Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
1333
+ L = self._weighted_nuclear_norm_solve(
1334
+ Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10
1335
+ )
1216
1336
 
1217
1337
  # Check convergence
1218
1338
  alpha_diff = np.max(np.abs(alpha - alpha_old))
@@ -1388,21 +1508,15 @@ class TROP:
1388
1508
  # Try Rust backend for parallel bootstrap (5-15x speedup)
1389
1509
  if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
1390
1510
  and self._precomputed is not None and Y is not None
1391
- and D is not None and control_unit_idx is not None):
1511
+ and D is not None):
1392
1512
  try:
1393
- # Prepare inputs
1394
- treated_observations = self._precomputed["treated_observations"]
1395
- treated_t = np.array([t for t, i in treated_observations], dtype=np.int64)
1396
- treated_i = np.array([i for t, i in treated_observations], dtype=np.int64)
1397
1513
  control_mask = self._precomputed["control_mask"]
1514
+ time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
1398
1515
 
1399
1516
  bootstrap_estimates, se = _rust_bootstrap_trop_variance(
1400
1517
  Y, D.astype(np.float64),
1401
1518
  control_mask.astype(np.uint8),
1402
- control_unit_idx.astype(np.int64),
1403
- treated_t, treated_i,
1404
- self._precomputed["unit_dist_matrix"],
1405
- self._precomputed["time_dist_matrix"].astype(np.int64),
1519
+ time_dist_matrix,
1406
1520
  lambda_time, lambda_unit, lambda_nn,
1407
1521
  self.n_bootstrap, self.max_iter, self.tol,
1408
1522
  self.seed if self.seed is not None else 0
@@ -1422,14 +1536,38 @@ class TROP:
1422
1536
 
1423
1537
  # Python implementation (fallback)
1424
1538
  rng = np.random.default_rng(self.seed)
1425
- all_units = data[unit].unique()
1426
- n_units_data = len(all_units)
1539
+
1540
+ # Issue D fix: Stratified bootstrap sampling
1541
+ # Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
1542
+ # and N_1 treated rows separately to preserve treatment ratio
1543
+ unit_ever_treated = data.groupby(unit)[treatment].max()
1544
+ treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
1545
+ control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)
1546
+
1547
+ n_treated_units = len(treated_units)
1548
+ n_control_units = len(control_units)
1427
1549
 
1428
1550
  bootstrap_estimates_list = []
1429
1551
 
1430
1552
  for _ in range(self.n_bootstrap):
1431
- # Sample units with replacement
1432
- sampled_units = rng.choice(all_units, size=n_units_data, replace=True)
1553
+ # Stratified sampling: sample control and treated units separately
1554
+ # This preserves the treatment ratio in each bootstrap sample
1555
+ if n_control_units > 0:
1556
+ sampled_control = rng.choice(
1557
+ control_units, size=n_control_units, replace=True
1558
+ )
1559
+ else:
1560
+ sampled_control = np.array([], dtype=control_units.dtype)
1561
+
1562
+ if n_treated_units > 0:
1563
+ sampled_treated = rng.choice(
1564
+ treated_units, size=n_treated_units, replace=True
1565
+ )
1566
+ else:
1567
+ sampled_treated = np.array([], dtype=treated_units.dtype)
1568
+
1569
+ # Combine stratified samples
1570
+ sampled_units = np.concatenate([sampled_control, sampled_treated])
1433
1571
 
1434
1572
  # Create bootstrap sample with unique unit IDs
1435
1573
  boot_data = pd.concat([
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "diff-diff"
7
- version = "2.1.2"
7
+ version = "2.1.3"
8
8
  description = "A library for Difference-in-Differences causal inference analysis"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -289,7 +289,7 @@ dependencies = [
289
289
 
290
290
  [[package]]
291
291
  name = "diff_diff_rust"
292
- version = "2.1.2"
292
+ version = "2.1.3"
293
293
  dependencies = [
294
294
  "ndarray",
295
295
  "ndarray-linalg",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "diff_diff_rust"
3
- version = "2.1.2"
3
+ version = "2.1.3"
4
4
  edition = "2021"
5
5
  description = "Rust backend for diff-diff DiD library"
6
6
  license = "MIT"
@@ -172,15 +172,13 @@ fn compute_pair_distance(
172
172
  /// # Returns
173
173
  /// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score)
174
174
  #[pyfunction]
175
- #[pyo3(signature = (y, d, control_mask, control_unit_idx, unit_dist_matrix, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
175
+ #[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
176
176
  #[allow(clippy::too_many_arguments)]
177
177
  pub fn loocv_grid_search<'py>(
178
178
  _py: Python<'py>,
179
179
  y: PyReadonlyArray2<'py, f64>,
180
180
  d: PyReadonlyArray2<'py, f64>,
181
181
  control_mask: PyReadonlyArray2<'py, u8>,
182
- control_unit_idx: PyReadonlyArray1<'py, i64>,
183
- unit_dist_matrix: PyReadonlyArray2<'py, f64>,
184
182
  time_dist_matrix: PyReadonlyArray2<'py, i64>,
185
183
  lambda_time_grid: PyReadonlyArray1<'py, f64>,
186
184
  lambda_unit_grid: PyReadonlyArray1<'py, f64>,
@@ -193,19 +191,11 @@ pub fn loocv_grid_search<'py>(
193
191
  let y_arr = y.as_array();
194
192
  let d_arr = d.as_array();
195
193
  let control_mask_arr = control_mask.as_array();
196
- let control_unit_idx_arr = control_unit_idx.as_array();
197
- let unit_dist_arr = unit_dist_matrix.as_array();
198
194
  let time_dist_arr = time_dist_matrix.as_array();
199
195
  let lambda_time_vec: Vec<f64> = lambda_time_grid.as_array().to_vec();
200
196
  let lambda_unit_vec: Vec<f64> = lambda_unit_grid.as_array().to_vec();
201
197
  let lambda_nn_vec: Vec<f64> = lambda_nn_grid.as_array().to_vec();
202
198
 
203
- // Convert control_unit_idx to Vec<usize>
204
- let control_units: Vec<usize> = control_unit_idx_arr
205
- .iter()
206
- .map(|&idx| idx as usize)
207
- .collect();
208
-
209
199
  // Get control observations for LOOCV
210
200
  let control_obs = get_control_observations(
211
201
  &y_arr,
@@ -232,8 +222,6 @@ pub fn loocv_grid_search<'py>(
232
222
  &y_arr,
233
223
  &d_arr,
234
224
  &control_mask_arr,
235
- &control_units,
236
- &unit_dist_arr,
237
225
  &time_dist_arr,
238
226
  &control_obs,
239
227
  lambda_time,
@@ -291,10 +279,8 @@ fn get_control_observations(
291
279
  /// Compute LOOCV score for a specific parameter combination.
292
280
  fn loocv_score_for_params(
293
281
  y: &ArrayView2<f64>,
294
- _d: &ArrayView2<f64>,
282
+ d: &ArrayView2<f64>,
295
283
  control_mask: &ArrayView2<u8>,
296
- control_units: &[usize],
297
- unit_dist: &ArrayView2<f64>,
298
284
  time_dist: &ArrayView2<i64>,
299
285
  control_obs: &[(usize, usize)],
300
286
  lambda_time: f64,
@@ -312,14 +298,14 @@ fn loocv_score_for_params(
312
298
  for &(t, i) in control_obs {
313
299
  // Compute observation-specific weight matrix
314
300
  let weight_matrix = compute_weight_matrix(
301
+ y,
302
+ d,
315
303
  n_periods,
316
304
  n_units,
317
305
  i,
318
306
  t,
319
307
  lambda_time,
320
308
  lambda_unit,
321
- control_units,
322
- unit_dist,
323
309
  time_dist,
324
310
  );
325
311
 
@@ -352,46 +338,107 @@ fn loocv_score_for_params(
352
338
  }
353
339
  }
354
340
 
341
+ /// Compute observation-specific distance from unit j to unit i, excluding target period.
342
+ ///
343
+ /// Issue B fix: Follows Equation 3 (page 7) which specifies 1{u ≠ t} to exclude target period.
344
+ fn compute_unit_distance_for_obs(
345
+ y: &ArrayView2<f64>,
346
+ d: &ArrayView2<f64>,
347
+ j: usize,
348
+ i: usize,
349
+ target_period: usize,
350
+ ) -> f64 {
351
+ let n_periods = y.nrows();
352
+ let mut sum_sq = 0.0;
353
+ let mut n_valid = 0usize;
354
+
355
+ for t in 0..n_periods {
356
+ // Exclude target period (Issue B fix)
357
+ if t == target_period {
358
+ continue;
359
+ }
360
+ // Both units must be control at this period and have valid values
361
+ if d[[t, i]] == 0.0 && d[[t, j]] == 0.0
362
+ && y[[t, i]].is_finite() && y[[t, j]].is_finite()
363
+ {
364
+ let diff = y[[t, i]] - y[[t, j]];
365
+ sum_sq += diff * diff;
366
+ n_valid += 1;
367
+ }
368
+ }
369
+
370
+ if n_valid > 0 {
371
+ (sum_sq / n_valid as f64).sqrt()
372
+ } else {
373
+ f64::INFINITY
374
+ }
375
+ }
376
+
355
377
  /// Compute observation-specific weight matrix for TROP.
356
378
  ///
357
379
  /// Time weights: θ_s = exp(-λ_time × |t - s|)
358
380
  /// Unit weights: ω_j = exp(-λ_unit × dist(j, i))
381
+ ///
382
+ /// Paper alignment notes:
383
+ /// - ALL units get weights (not just those untreated at target period)
384
+ /// - The (1 - D_js) masking in the loss naturally excludes treated cells
385
+ /// - Weights are normalized to sum to 1 (probability weights)
386
+ /// - Distance excludes target period t per Equation 3
359
387
  fn compute_weight_matrix(
388
+ y: &ArrayView2<f64>,
389
+ d: &ArrayView2<f64>,
360
390
  n_periods: usize,
361
391
  n_units: usize,
362
392
  target_unit: usize,
363
393
  target_period: usize,
364
394
  lambda_time: f64,
365
395
  lambda_unit: f64,
366
- control_units: &[usize],
367
- unit_dist: &ArrayView2<f64>,
368
396
  time_dist: &ArrayView2<i64>,
369
397
  ) -> Array2<f64> {
370
- // Time weights for this target period
371
- let time_weights: Array1<f64> = Array1::from_shape_fn(n_periods, |s| {
398
+ // Time weights for this target period: θ_s = exp(-λ_time × |t - s|)
399
+ let mut time_weights: Array1<f64> = Array1::from_shape_fn(n_periods, |s| {
372
400
  let dist = time_dist[[target_period, s]] as f64;
373
401
  (-lambda_time * dist).exp()
374
402
  });
375
403
 
376
- // Unit weights
404
+ // Normalize time weights to sum to 1
405
+ let time_sum: f64 = time_weights.sum();
406
+ if time_sum > 0.0 {
407
+ time_weights /= time_sum;
408
+ }
409
+
410
+ // Unit weights: ω_j = exp(-λ_unit × dist(j, i))
411
+ // Paper alignment: compute for ALL units, let control masking handle exclusion
377
412
  let mut unit_weights = Array1::<f64>::zeros(n_units);
378
413
 
379
414
  if lambda_unit == 0.0 {
380
415
  // Uniform weights when lambda_unit = 0
416
+ // All units get weight 1 (control masking will handle exclusion)
381
417
  unit_weights.fill(1.0);
382
418
  } else {
383
- for &j in control_units {
384
- let dist = unit_dist[[j, target_unit]];
385
- if dist.is_finite() {
386
- unit_weights[j] = (-lambda_unit * dist).exp();
419
+ // Compute per-observation distance for all units (excluding target unit itself)
420
+ for j in 0..n_units {
421
+ if j != target_unit {
422
+ let dist = compute_unit_distance_for_obs(y, d, j, target_unit, target_period);
423
+ if dist.is_finite() {
424
+ unit_weights[j] = (-lambda_unit * dist).exp();
425
+ }
426
+ // Units with infinite distance (no valid comparison periods) get weight 0
387
427
  }
388
428
  }
389
429
  }
390
430
 
391
- // Target unit gets weight 1
431
+ // Target unit gets weight 1 (will be masked out in estimation anyway)
392
432
  unit_weights[target_unit] = 1.0;
393
433
 
434
+ // Normalize unit weights to sum to 1
435
+ let unit_sum: f64 = unit_weights.sum();
436
+ if unit_sum > 0.0 {
437
+ unit_weights /= unit_sum;
438
+ }
439
+
394
440
  // Outer product: W[t, i] = time_weights[t] * unit_weights[i]
441
+ // Result is normalized since both components sum to 1
395
442
  let mut weight_matrix = Array2::<f64>::zeros((n_periods, n_units));
396
443
  for t in 0..n_periods {
397
444
  for i in 0..n_units {
@@ -406,6 +453,10 @@ fn compute_weight_matrix(
406
453
  ///
407
454
  /// Minimizes: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
408
455
  ///
456
+ /// Paper alignment: Uses weighted proximal gradient for L update:
457
+ /// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
458
+ /// where η ≤ 1/max(W) for convergence.
459
+ ///
409
460
  /// Returns None if estimation fails due to numerical issues.
410
461
  fn estimate_model(
411
462
  y: &ArrayView2<f64>,
@@ -432,7 +483,7 @@ fn estimate_model(
432
483
  y[[t, i]].is_finite() && est_mask[[t, i]]
433
484
  });
434
485
 
435
- // Masked weights
486
+ // Masked weights: W=0 for invalid/treated observations
436
487
  let w_masked = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
437
488
  if valid_mask[[t, i]] {
438
489
  weight_matrix[[t, i]]
@@ -441,6 +492,10 @@ fn estimate_model(
441
492
  }
442
493
  });
443
494
 
495
+ // Compute step size for proximal gradient: η ≤ 1/max(W)
496
+ let w_max = w_masked.iter().cloned().fold(0.0_f64, f64::max);
497
+ let eta = if w_max > 0.0 { 1.0 / w_max } else { 1.0 };
498
+
444
499
  // Weight sums per unit and time
445
500
  let weight_sum_per_unit: Array1<f64> = w_masked.sum_axis(Axis(0));
446
501
  let weight_sum_per_time: Array1<f64> = w_masked.sum_axis(Axis(1));
@@ -472,7 +527,7 @@ fn estimate_model(
472
527
  let beta_old = beta.clone();
473
528
  let l_old = l.clone();
474
529
 
475
- // Step 1: Update α and β
530
+ // Step 1: Update α and β (weighted least squares)
476
531
  // R = Y - L
477
532
  let r = &y_safe - &l;
478
533
 
@@ -498,25 +553,31 @@ fn estimate_model(
498
553
  }
499
554
  }
500
555
 
501
- // Step 2: Update L with nuclear norm penalty
502
- // R_for_L = Y - α - β
503
- let mut r_for_l = Array2::<f64>::zeros((n_periods, n_units));
556
+ // Step 2: Update L with WEIGHTED nuclear norm penalty
557
+ // Paper alignment: Use proximal gradient instead of direct soft-thresholding
558
+ // L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
559
+ // where R = Y - α - β
560
+
561
+ // Compute target residual R = Y - α - β
562
+ let mut r_target = Array2::<f64>::zeros((n_periods, n_units));
504
563
  for t in 0..n_periods {
505
564
  for i in 0..n_units {
506
- r_for_l[[t, i]] = y_safe[[t, i]] - alpha[i] - beta[t];
565
+ r_target[[t, i]] = y_safe[[t, i]] - alpha[i] - beta[t];
507
566
  }
508
567
  }
509
568
 
510
- // Impute invalid observations with current L
569
+ // Weighted proximal gradient step:
570
+ // gradient_step = L + η * W ⊙ (R - L)
571
+ // For W=0 cells (treated obs), this keeps L unchanged
572
+ let mut gradient_step = Array2::<f64>::zeros((n_periods, n_units));
511
573
  for t in 0..n_periods {
512
574
  for i in 0..n_units {
513
- if !valid_mask[[t, i]] {
514
- r_for_l[[t, i]] = l[[t, i]];
515
- }
575
+ gradient_step[[t, i]] = l[[t, i]] + eta * w_masked[[t, i]] * (r_target[[t, i]] - l[[t, i]]);
516
576
  }
517
577
  }
518
578
 
519
- l = soft_threshold_svd(&r_for_l, lambda_nn)?;
579
+ // Proximal step: soft-threshold singular values with scaled lambda
580
+ l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?;
520
581
 
521
582
  // Check convergence
522
583
  let alpha_diff = max_abs_diff(&alpha, &alpha_old);
@@ -627,17 +688,13 @@ fn max_abs_diff_2d(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
627
688
  /// # Returns
628
689
  /// (bootstrap_estimates, standard_error)
629
690
  #[pyfunction]
630
- #[pyo3(signature = (y, d, control_mask, control_unit_idx, treated_obs_t, treated_obs_i, unit_dist_matrix, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))]
691
+ #[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))]
631
692
  #[allow(clippy::too_many_arguments)]
632
693
  pub fn bootstrap_trop_variance<'py>(
633
694
  py: Python<'py>,
634
695
  y: PyReadonlyArray2<'py, f64>,
635
696
  d: PyReadonlyArray2<'py, f64>,
636
697
  control_mask: PyReadonlyArray2<'py, u8>,
637
- control_unit_idx: PyReadonlyArray1<'py, i64>,
638
- treated_obs_t: PyReadonlyArray1<'py, i64>,
639
- treated_obs_i: PyReadonlyArray1<'py, i64>,
640
- unit_dist_matrix: PyReadonlyArray2<'py, f64>,
641
698
  time_dist_matrix: PyReadonlyArray2<'py, i64>,
642
699
  lambda_time: f64,
643
700
  lambda_unit: f64,
@@ -650,16 +707,25 @@ pub fn bootstrap_trop_variance<'py>(
650
707
  let y_arr = y.as_array().to_owned();
651
708
  let d_arr = d.as_array().to_owned();
652
709
  let control_mask_arr = control_mask.as_array().to_owned();
653
- let unit_dist_arr = unit_dist_matrix.as_array().to_owned();
654
710
  let time_dist_arr = time_dist_matrix.as_array().to_owned();
655
711
 
656
712
  let n_units = y_arr.ncols();
657
713
  let n_periods = y_arr.nrows();
658
714
 
659
- // Note: control_unit_idx, treated_obs_t, treated_obs_i are passed for API
660
- // compatibility but not used directly - each bootstrap iteration recomputes
661
- // control units and treated observations from the resampled data.
662
- let _ = (control_unit_idx, treated_obs_t, treated_obs_i);
715
+ // Issue D fix: Identify treated and control units for stratified sampling
716
+ // Following paper's Algorithm 3 (page 27): sample N_0 control and N_1 treated separately
717
+ let mut original_treated_units: Vec<usize> = Vec::new();
718
+ let mut original_control_units: Vec<usize> = Vec::new();
719
+ for i in 0..n_units {
720
+ let is_ever_treated = (0..n_periods).any(|t| d_arr[[t, i]] == 1.0);
721
+ if is_ever_treated {
722
+ original_treated_units.push(i);
723
+ } else {
724
+ original_control_units.push(i);
725
+ }
726
+ }
727
+ let n_treated_units = original_treated_units.len();
728
+ let n_control_units = original_control_units.len();
663
729
 
664
730
  // Run bootstrap iterations in parallel
665
731
  let bootstrap_estimates: Vec<f64> = (0..n_bootstrap)
@@ -670,16 +736,25 @@ pub fn bootstrap_trop_variance<'py>(
670
736
 
671
737
  let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(b as u64));
672
738
 
673
- // Sample units with replacement
674
- let sampled_units: Vec<usize> = (0..n_units)
675
- .map(|_| rng.gen_range(0..n_units))
676
- .collect();
739
+ // Issue D fix: Stratified sampling - sample control and treated units separately
740
+ let mut sampled_units: Vec<usize> = Vec::with_capacity(n_units);
741
+
742
+ // Sample control units with replacement
743
+ for _ in 0..n_control_units {
744
+ let idx = rng.gen_range(0..n_control_units);
745
+ sampled_units.push(original_control_units[idx]);
746
+ }
747
+
748
+ // Sample treated units with replacement
749
+ for _ in 0..n_treated_units {
750
+ let idx = rng.gen_range(0..n_treated_units);
751
+ sampled_units.push(original_treated_units[idx]);
752
+ }
677
753
 
678
754
  // Create bootstrap matrices by selecting columns
679
755
  let mut y_boot = Array2::<f64>::zeros((n_periods, n_units));
680
756
  let mut d_boot = Array2::<f64>::zeros((n_periods, n_units));
681
757
  let mut control_mask_boot = Array2::<u8>::zeros((n_periods, n_units));
682
- let mut unit_dist_boot = Array2::<f64>::zeros((n_units, n_units));
683
758
 
684
759
  for (new_idx, &old_idx) in sampled_units.iter().enumerate() {
685
760
  for t in 0..n_periods {
@@ -687,10 +762,6 @@ pub fn bootstrap_trop_variance<'py>(
687
762
  d_boot[[t, new_idx]] = d_arr[[t, old_idx]];
688
763
  control_mask_boot[[t, new_idx]] = control_mask_arr[[t, old_idx]];
689
764
  }
690
-
691
- for (new_j, &old_j) in sampled_units.iter().enumerate() {
692
- unit_dist_boot[[new_idx, new_j]] = unit_dist_arr[[old_idx, old_j]];
693
- }
694
765
  }
695
766
 
696
767
  // Get treated observations in bootstrap sample
@@ -725,14 +796,14 @@ pub fn bootstrap_trop_variance<'py>(
725
796
 
726
797
  for (t, i) in boot_treated {
727
798
  let weight_matrix = compute_weight_matrix(
799
+ &y_boot.view(),
800
+ &d_boot.view(),
728
801
  n_periods,
729
802
  n_units,
730
803
  i,
731
804
  t,
732
805
  lambda_time,
733
806
  lambda_unit,
734
- &boot_control_units,
735
- &unit_dist_boot.view(),
736
807
  &time_dist_arr.view(),
737
808
  );
738
809
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes