diff-diff 2.1.7__tar.gz → 2.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {diff_diff-2.1.7 → diff_diff-2.1.8}/PKG-INFO +14 -9
  2. {diff_diff-2.1.7 → diff_diff-2.1.8}/README.md +13 -8
  3. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/__init__.py +1 -1
  4. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/trop.py +394 -88
  5. {diff_diff-2.1.7 → diff_diff-2.1.8}/pyproject.toml +1 -1
  6. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/Cargo.lock +3 -3
  7. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/Cargo.toml +1 -1
  8. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/trop.rs +217 -45
  9. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/_backend.py +0 -0
  10. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/bacon.py +0 -0
  11. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/datasets.py +0 -0
  12. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/diagnostics.py +0 -0
  13. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/estimators.py +0 -0
  14. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/honest_did.py +0 -0
  15. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/linalg.py +0 -0
  16. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/power.py +0 -0
  17. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/prep.py +0 -0
  18. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/prep_dgp.py +0 -0
  19. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/pretrends.py +0 -0
  20. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/results.py +0 -0
  21. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered.py +0 -0
  22. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered_aggregation.py +0 -0
  23. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered_bootstrap.py +0 -0
  24. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/staggered_results.py +0 -0
  25. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/sun_abraham.py +0 -0
  26. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/synthetic_did.py +0 -0
  27. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/triple_diff.py +0 -0
  28. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/twfe.py +0 -0
  29. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/utils.py +0 -0
  30. {diff_diff-2.1.7 → diff_diff-2.1.8}/diff_diff/visualization.py +0 -0
  31. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/bootstrap.rs +0 -0
  32. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/lib.rs +0 -0
  33. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/linalg.rs +0 -0
  34. {diff_diff-2.1.7 → diff_diff-2.1.8}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.1.7
3
+ Version: 2.1.8
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -1173,13 +1173,15 @@ trop_est = TROP(
1173
1173
  lambda_nn_grid=[0.0, 0.1, 1.0], # Nuclear norm grid
1174
1174
  n_bootstrap=200
1175
1175
  )
1176
+ # Note: TROP infers treatment periods from the treatment indicator column.
1177
+ # The 'treated' column must be an absorbing state (D=1 for all periods
1178
+ # during and after treatment starts for each unit).
1176
1179
  results = trop_est.fit(
1177
1180
  panel_data,
1178
1181
  outcome='gdp_growth',
1179
1182
  treatment='treated',
1180
1183
  unit='state',
1181
- time='year',
1182
- post_periods=[2015, 2016, 2017, 2018]
1184
+ time='year'
1183
1185
  )
1184
1186
 
1185
1187
  # View results
@@ -1267,9 +1269,11 @@ sdid_results = sdid.fit(data, outcome='y', treatment='treated',
1267
1269
  unit='unit', time='time', post_periods=[5,6,7])
1268
1270
 
1269
1271
  # TROP (accounts for factors)
1272
+ # Note: TROP infers treatment periods from the treatment indicator column
1273
+ # (D=1 for treated observations, D=0 for control)
1270
1274
  trop_est = TROP() # Uses default grids with LOOCV selection
1271
1275
  trop_results = trop_est.fit(data, outcome='y', treatment='treated',
1272
- unit='unit', time='time', post_periods=[5,6,7])
1276
+ unit='unit', time='time')
1273
1277
 
1274
1278
  print(f"SDID estimate: {sdid_results.att:.3f}")
1275
1279
  print(f"TROP estimate: {trop_results.att:.3f}")
@@ -1314,13 +1318,13 @@ TROP(
1314
1318
 
1315
1319
  ```python
1316
1320
  # One-liner estimation with default tuning grids
1321
+ # Note: TROP infers treatment periods from the treatment indicator
1317
1322
  results = trop(
1318
1323
  data,
1319
1324
  outcome='y',
1320
1325
  treatment='treated',
1321
1326
  unit='unit',
1322
1327
  time='time',
1323
- post_periods=[5, 6, 7],
1324
1328
  n_bootstrap=200
1325
1329
  )
1326
1330
  ```
@@ -1912,10 +1916,11 @@ TROP(
1912
1916
  |-----------|------|-------------|
1913
1917
  | `data` | DataFrame | Panel data |
1914
1918
  | `outcome` | str | Outcome variable column name |
1915
- | `treatment` | str | Treatment indicator column (0/1) |
1919
+ | `treatment` | str | Treatment indicator column (0/1 absorbing state) |
1916
1920
  | `unit` | str | Unit identifier column |
1917
1921
  | `time` | str | Time period column |
1918
- | `post_periods` | list | List of post-treatment period values |
1922
+
1923
+ Note: TROP infers treatment periods from the treatment indicator column. The treatment column should be an absorbing state indicator where D=1 for all periods during and after treatment starts.
1919
1924
 
1920
1925
  ### TROPResults
1921
1926
 
@@ -1941,8 +1946,8 @@ TROP(
1941
1946
  | `factor_matrix` | Low-rank factor matrix L (n_periods x n_units) |
1942
1947
  | `effective_rank` | Effective rank of factor matrix |
1943
1948
  | `loocv_score` | LOOCV score for selected parameters |
1944
- | `pre_periods` | List of pre-treatment periods |
1945
- | `post_periods` | List of post-treatment periods |
1949
+ | `n_pre_periods` | Number of pre-treatment periods |
1950
+ | `n_post_periods` | Number of post-treatment periods |
1946
1951
  | `variance_method` | Variance estimation method |
1947
1952
  | `bootstrap_distribution` | Bootstrap distribution (if bootstrap) |
1948
1953
 
@@ -1138,13 +1138,15 @@ trop_est = TROP(
1138
1138
  lambda_nn_grid=[0.0, 0.1, 1.0], # Nuclear norm grid
1139
1139
  n_bootstrap=200
1140
1140
  )
1141
+ # Note: TROP infers treatment periods from the treatment indicator column.
1142
+ # The 'treated' column must be an absorbing state (D=1 for all periods
1143
+ # during and after treatment starts for each unit).
1141
1144
  results = trop_est.fit(
1142
1145
  panel_data,
1143
1146
  outcome='gdp_growth',
1144
1147
  treatment='treated',
1145
1148
  unit='state',
1146
- time='year',
1147
- post_periods=[2015, 2016, 2017, 2018]
1149
+ time='year'
1148
1150
  )
1149
1151
 
1150
1152
  # View results
@@ -1232,9 +1234,11 @@ sdid_results = sdid.fit(data, outcome='y', treatment='treated',
1232
1234
  unit='unit', time='time', post_periods=[5,6,7])
1233
1235
 
1234
1236
  # TROP (accounts for factors)
1237
+ # Note: TROP infers treatment periods from the treatment indicator column
1238
+ # (D=1 for treated observations, D=0 for control)
1235
1239
  trop_est = TROP() # Uses default grids with LOOCV selection
1236
1240
  trop_results = trop_est.fit(data, outcome='y', treatment='treated',
1237
- unit='unit', time='time', post_periods=[5,6,7])
1241
+ unit='unit', time='time')
1238
1242
 
1239
1243
  print(f"SDID estimate: {sdid_results.att:.3f}")
1240
1244
  print(f"TROP estimate: {trop_results.att:.3f}")
@@ -1279,13 +1283,13 @@ TROP(
1279
1283
 
1280
1284
  ```python
1281
1285
  # One-liner estimation with default tuning grids
1286
+ # Note: TROP infers treatment periods from the treatment indicator
1282
1287
  results = trop(
1283
1288
  data,
1284
1289
  outcome='y',
1285
1290
  treatment='treated',
1286
1291
  unit='unit',
1287
1292
  time='time',
1288
- post_periods=[5, 6, 7],
1289
1293
  n_bootstrap=200
1290
1294
  )
1291
1295
  ```
@@ -1877,10 +1881,11 @@ TROP(
1877
1881
  |-----------|------|-------------|
1878
1882
  | `data` | DataFrame | Panel data |
1879
1883
  | `outcome` | str | Outcome variable column name |
1880
- | `treatment` | str | Treatment indicator column (0/1) |
1884
+ | `treatment` | str | Treatment indicator column (0/1 absorbing state) |
1881
1885
  | `unit` | str | Unit identifier column |
1882
1886
  | `time` | str | Time period column |
1883
- | `post_periods` | list | List of post-treatment period values |
1887
+
1888
+ Note: TROP infers treatment periods from the treatment indicator column. The treatment column should be an absorbing state indicator where D=1 for all periods during and after treatment starts.
1884
1889
 
1885
1890
  ### TROPResults
1886
1891
 
@@ -1906,8 +1911,8 @@ TROP(
1906
1911
  | `factor_matrix` | Low-rank factor matrix L (n_periods x n_units) |
1907
1912
  | `effective_rank` | Effective rank of factor matrix |
1908
1913
  | `loocv_score` | LOOCV score for selected parameters |
1909
- | `pre_periods` | List of pre-treatment periods |
1910
- | `post_periods` | List of post-treatment periods |
1914
+ | `n_pre_periods` | Number of pre-treatment periods |
1915
+ | `n_post_periods` | Number of post-treatment periods |
1911
1916
  | `variance_method` | Variance estimation method |
1912
1917
  | `bootstrap_distribution` | Bootstrap distribution (if bootstrap) |
1913
1918
 
@@ -136,7 +136,7 @@ from diff_diff.datasets import (
136
136
  load_mpdta,
137
137
  )
138
138
 
139
- __version__ = "2.1.7"
139
+ __version__ = "2.1.8"
140
140
  __all__ = [
141
141
  # Estimators
142
142
  "DifferenceInDifferences",
@@ -43,6 +43,11 @@ from diff_diff.results import _get_significance_stars
43
43
  from diff_diff.utils import compute_confidence_interval, compute_p_value
44
44
 
45
45
 
46
+ # Sentinel value for "disabled" mode in LOOCV parameter search
47
+ # Following paper's footnote 2: λ=∞ disables the corresponding component
48
+ _LAMBDA_INF: float = float('inf')
49
+
50
+
46
51
  class _PrecomputedStructures(TypedDict):
47
52
  """Type definition for pre-computed structures used across LOOCV iterations.
48
53
 
@@ -109,11 +114,15 @@ class TROPResults:
109
114
  treatment_effects : dict
110
115
  Individual treatment effects for each treated (unit, time) pair.
111
116
  lambda_time : float
112
- Selected time weight decay parameter.
117
+ Selected time weight decay parameter from grid. Note: infinity values
118
+ are converted internally (∞ → 0.0 for uniform weights) for computation.
113
119
  lambda_unit : float
114
- Selected unit weight decay parameter.
120
+ Selected unit weight decay parameter from grid. Note: infinity values
121
+ are converted internally (∞ → 0.0 for uniform weights) for computation.
115
122
  lambda_nn : float
116
- Selected nuclear norm regularization parameter.
123
+ Selected nuclear norm regularization parameter from grid. Note: infinity
124
+ values are converted internally (∞ → 1e10, factor model disabled) for
125
+ computation.
117
126
  factor_matrix : np.ndarray
118
127
  Estimated low-rank factor matrix L (n_periods x n_units).
119
128
  effective_rank : float
@@ -124,10 +133,10 @@ class TROPResults:
124
133
  Method used for variance estimation.
125
134
  alpha : float
126
135
  Significance level for confidence interval.
127
- pre_periods : list
128
- List of pre-treatment period identifiers.
129
- post_periods : list
130
- List of post-treatment period identifiers.
136
+ n_pre_periods : int
137
+ Number of pre-treatment periods.
138
+ n_post_periods : int
139
+ Number of post-treatment periods (periods with D=1 observations).
131
140
  n_bootstrap : int, optional
132
141
  Number of bootstrap replications (if bootstrap variance).
133
142
  bootstrap_distribution : np.ndarray, optional
@@ -154,8 +163,8 @@ class TROPResults:
154
163
  loocv_score: float
155
164
  variance_method: str
156
165
  alpha: float = 0.05
157
- pre_periods: List[Any] = field(default_factory=list)
158
- post_periods: List[Any] = field(default_factory=list)
166
+ n_pre_periods: int = 0
167
+ n_post_periods: int = 0
159
168
  n_bootstrap: Optional[int] = field(default=None)
160
169
  bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
161
170
 
@@ -197,8 +206,8 @@ class TROPResults:
197
206
  f"{'Treated units:':<25} {self.n_treated:>10}",
198
207
  f"{'Control units:':<25} {self.n_control:>10}",
199
208
  f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
200
- f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
201
- f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
209
+ f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}",
210
+ f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}",
202
211
  "",
203
212
  "-" * 75,
204
213
  "Tuning Parameters (selected via LOOCV)".center(75),
@@ -261,8 +270,8 @@ class TROPResults:
261
270
  "n_treated": self.n_treated,
262
271
  "n_control": self.n_control,
263
272
  "n_treated_obs": self.n_treated_obs,
264
- "n_pre_periods": len(self.pre_periods),
265
- "n_post_periods": len(self.post_periods),
273
+ "n_pre_periods": self.n_pre_periods,
274
+ "n_post_periods": self.n_post_periods,
266
275
  "lambda_time": self.lambda_time,
267
276
  "lambda_unit": self.lambda_unit,
268
277
  "lambda_nn": self.lambda_nn,
@@ -397,7 +406,6 @@ class TROP:
397
406
  ... treatment='treated',
398
407
  ... unit='unit',
399
408
  ... time='period',
400
- ... post_periods=[5, 6, 7, 8]
401
409
  ... )
402
410
  >>> results.print_summary()
403
411
 
@@ -658,6 +666,168 @@ class TROP:
658
666
  else:
659
667
  return np.inf
660
668
 
669
+ def _univariate_loocv_search(
670
+ self,
671
+ Y: np.ndarray,
672
+ D: np.ndarray,
673
+ control_mask: np.ndarray,
674
+ control_unit_idx: np.ndarray,
675
+ n_units: int,
676
+ n_periods: int,
677
+ param_name: str,
678
+ grid: List[float],
679
+ fixed_params: Dict[str, float],
680
+ ) -> Tuple[float, float]:
681
+ """
682
+ Search over one parameter with others fixed.
683
+
684
+ Following paper's footnote 2, this performs a univariate grid search
685
+ for one tuning parameter while holding others fixed. The fixed_params
686
+ can include _LAMBDA_INF values to disable specific components:
687
+ - lambda_nn = inf: Skip nuclear norm regularization (L=0)
688
+ - lambda_time = inf: Uniform time weights (treated as 0)
689
+ - lambda_unit = inf: Uniform unit weights (treated as 0)
690
+
691
+ Parameters
692
+ ----------
693
+ Y : np.ndarray
694
+ Outcome matrix (n_periods x n_units).
695
+ D : np.ndarray
696
+ Treatment indicator matrix (n_periods x n_units).
697
+ control_mask : np.ndarray
698
+ Boolean mask for control observations.
699
+ control_unit_idx : np.ndarray
700
+ Indices of control units.
701
+ n_units : int
702
+ Number of units.
703
+ n_periods : int
704
+ Number of periods.
705
+ param_name : str
706
+ Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
707
+ grid : List[float]
708
+ Grid of values to search over.
709
+ fixed_params : Dict[str, float]
710
+ Fixed values for other parameters. May include _LAMBDA_INF.
711
+
712
+ Returns
713
+ -------
714
+ Tuple[float, float]
715
+ (best_value, best_score) for the searched parameter.
716
+ """
717
+ best_score = np.inf
718
+ best_value = grid[0] if grid else 0.0
719
+
720
+ for value in grid:
721
+ params = {**fixed_params, param_name: value}
722
+
723
+ # Convert inf values to 0 for computation (inf means "disabled" = uniform weights)
724
+ lambda_time = params.get('lambda_time', 0.0)
725
+ lambda_unit = params.get('lambda_unit', 0.0)
726
+ lambda_nn = params.get('lambda_nn', 0.0)
727
+
728
+ # Handle infinity as "disabled" mode
729
+ # Per paper Equations 2-3:
730
+ # - λ_time/λ_unit=∞ → exp(-∞×dist)→0 for dist>0, uniform weights → use 0.0
731
+ # - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
732
+ # Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled"
733
+ if np.isinf(lambda_time):
734
+ lambda_time = 0.0 # Uniform time weights
735
+ if np.isinf(lambda_unit):
736
+ lambda_unit = 0.0 # Uniform unit weights
737
+ if np.isinf(lambda_nn):
738
+ lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
739
+
740
+ try:
741
+ score = self._loocv_score_obs_specific(
742
+ Y, D, control_mask, control_unit_idx,
743
+ lambda_time, lambda_unit, lambda_nn,
744
+ n_units, n_periods
745
+ )
746
+ if score < best_score:
747
+ best_score = score
748
+ best_value = value
749
+ except (np.linalg.LinAlgError, ValueError):
750
+ continue
751
+
752
+ return best_value, best_score
753
+
754
+ def _cycling_parameter_search(
755
+ self,
756
+ Y: np.ndarray,
757
+ D: np.ndarray,
758
+ control_mask: np.ndarray,
759
+ control_unit_idx: np.ndarray,
760
+ n_units: int,
761
+ n_periods: int,
762
+ initial_lambda: Tuple[float, float, float],
763
+ max_cycles: int = 10,
764
+ ) -> Tuple[float, float, float]:
765
+ """
766
+ Cycle through parameters until convergence (coordinate descent).
767
+
768
+ Following paper's footnote 2 (Stage 2), this iteratively optimizes
769
+ each tuning parameter while holding the others fixed, until convergence.
770
+
771
+ Parameters
772
+ ----------
773
+ Y : np.ndarray
774
+ Outcome matrix (n_periods x n_units).
775
+ D : np.ndarray
776
+ Treatment indicator matrix (n_periods x n_units).
777
+ control_mask : np.ndarray
778
+ Boolean mask for control observations.
779
+ control_unit_idx : np.ndarray
780
+ Indices of control units.
781
+ n_units : int
782
+ Number of units.
783
+ n_periods : int
784
+ Number of periods.
785
+ initial_lambda : Tuple[float, float, float]
786
+ Initial values (lambda_time, lambda_unit, lambda_nn).
787
+ max_cycles : int, default=10
788
+ Maximum number of coordinate descent cycles.
789
+
790
+ Returns
791
+ -------
792
+ Tuple[float, float, float]
793
+ Optimized (lambda_time, lambda_unit, lambda_nn).
794
+ """
795
+ lambda_time, lambda_unit, lambda_nn = initial_lambda
796
+ prev_score = np.inf
797
+
798
+ for cycle in range(max_cycles):
799
+ # Optimize λ_unit (fix λ_time, λ_nn)
800
+ lambda_unit, _ = self._univariate_loocv_search(
801
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
802
+ 'lambda_unit', self.lambda_unit_grid,
803
+ {'lambda_time': lambda_time, 'lambda_nn': lambda_nn}
804
+ )
805
+
806
+ # Optimize λ_time (fix λ_unit, λ_nn)
807
+ lambda_time, _ = self._univariate_loocv_search(
808
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
809
+ 'lambda_time', self.lambda_time_grid,
810
+ {'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn}
811
+ )
812
+
813
+ # Optimize λ_nn (fix λ_unit, λ_time)
814
+ lambda_nn, score = self._univariate_loocv_search(
815
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
816
+ 'lambda_nn', self.lambda_nn_grid,
817
+ {'lambda_unit': lambda_unit, 'lambda_time': lambda_time}
818
+ )
819
+
820
+ # Check convergence
821
+ if abs(score - prev_score) < 1e-6:
822
+ logger.debug(
823
+ "Cycling search converged after %d cycles with score %.6f",
824
+ cycle + 1, score
825
+ )
826
+ break
827
+ prev_score = score
828
+
829
+ return lambda_time, lambda_unit, lambda_nn
830
+
661
831
  def fit(
662
832
  self,
663
833
  data: pd.DataFrame,
@@ -665,7 +835,6 @@ class TROP:
665
835
  treatment: str,
666
836
  unit: str,
667
837
  time: str,
668
- post_periods: Optional[List[Any]] = None,
669
838
  ) -> TROPResults:
670
839
  """
671
840
  Fit the TROP model.
@@ -679,20 +848,31 @@ class TROP:
679
848
  Name of the outcome variable column.
680
849
  treatment : str
681
850
  Name of the treatment indicator column (0/1).
682
- Should be 1 for treated unit-time observations.
851
+
852
+ IMPORTANT: This should be an ABSORBING STATE indicator, not a
853
+ treatment timing indicator. For each unit, D=1 for ALL periods
854
+ during and after treatment:
855
+
856
+ - D[t, i] = 0 for all t < g_i (pre-treatment periods)
857
+ - D[t, i] = 1 for all t >= g_i (treatment and post-treatment)
858
+
859
+ where g_i is the treatment start time for unit i.
860
+
861
+ For staggered adoption, different units can have different g_i.
862
+ The ATT averages over ALL D=1 cells per Equation 1 of the paper.
683
863
  unit : str
684
864
  Name of the unit identifier column.
685
865
  time : str
686
866
  Name of the time period column.
687
- post_periods : list, optional
688
- List of time period values that are post-treatment.
689
- If None, infers from treatment indicator.
690
867
 
691
868
  Returns
692
869
  -------
693
870
  TROPResults
694
871
  Object containing the ATT estimate, standard error,
695
- factor estimates, and tuning parameters.
872
+ factor estimates, and tuning parameters. The lambda_*
873
+ attributes show the selected grid values. Infinity values
874
+ (∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform
875
+ weights), λ_nn=∞ → 1e10 (factor model disabled).
696
876
  """
697
877
  # Validate inputs
698
878
  required_cols = [outcome, treatment, unit, time]
@@ -720,13 +900,39 @@ class TROP:
720
900
  .reindex(index=all_periods, columns=all_units)
721
901
  .values
722
902
  )
723
- D = (
903
+
904
+ # For D matrix, track missing values BEFORE fillna to support unbalanced panels
905
+ # Issue 3 fix: Missing observations should not trigger spurious violations
906
+ D_raw = (
724
907
  data.pivot(index=time, columns=unit, values=treatment)
725
908
  .reindex(index=all_periods, columns=all_units)
726
- .fillna(0)
727
- .astype(int)
728
- .values
729
909
  )
910
+ missing_mask = pd.isna(D_raw).values # True where originally missing
911
+ D = D_raw.fillna(0).astype(int).values
912
+
913
+ # Validate D is monotonic non-decreasing per unit (absorbing state)
914
+ # D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
915
+ # Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
916
+ # This catches 1→0 violations that span missing period gaps
917
+ # Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
918
+ # adjacent period transitions don't show it (the gap hides the transition)
919
+ violating_units = []
920
+ for unit_idx in range(n_units):
921
+ # Get observed D values for this unit (where not missing)
922
+ observed_mask = ~missing_mask[:, unit_idx]
923
+ observed_d = D[observed_mask, unit_idx]
924
+
925
+ # Check if observed sequence is monotonically non-decreasing
926
+ if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
927
+ violating_units.append(all_units[unit_idx])
928
+
929
+ if violating_units:
930
+ raise ValueError(
931
+ f"Treatment indicator is not an absorbing state for units: {violating_units}. "
932
+ f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
933
+ f"If this is event-study style data, convert to absorbing state: "
934
+ f"D[t, i] = 1 for all t >= first treatment period."
935
+ )
730
936
 
731
937
  # Identify treated observations
732
938
  treated_mask = D == 1
@@ -743,28 +949,23 @@ class TROP:
743
949
  if len(control_unit_idx) == 0:
744
950
  raise ValueError("No control units found")
745
951
 
746
- # Determine pre/post periods
747
- if post_periods is None:
748
- # Infer from first treatment time
749
- first_treat_period = None
750
- for t in range(n_periods):
751
- if np.any(D[t, :] == 1):
752
- first_treat_period = t
753
- break
754
- if first_treat_period is None:
755
- raise ValueError("Could not infer post-treatment periods")
756
- pre_period_idx = list(range(first_treat_period))
757
- post_period_idx = list(range(first_treat_period, n_periods))
758
- else:
759
- post_period_idx = [period_to_idx[p] for p in post_periods if p in period_to_idx]
760
- pre_period_idx = [i for i in range(n_periods) if i not in post_period_idx]
952
+ # Determine pre/post periods from treatment indicator D
953
+ # D matrix is the sole input for treatment timing per the paper
954
+ first_treat_period = None
955
+ for t in range(n_periods):
956
+ if np.any(D[t, :] == 1):
957
+ first_treat_period = t
958
+ break
959
+ if first_treat_period is None:
960
+ raise ValueError("Could not infer post-treatment periods from D matrix")
761
961
 
762
- if len(pre_period_idx) < 2:
763
- raise ValueError("Need at least 2 pre-treatment periods")
962
+ n_pre_periods = first_treat_period
963
+ # Count periods where D=1 is actually observed (matches docstring)
964
+ # Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
965
+ n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
764
966
 
765
- pre_periods_list = [idx_to_period[i] for i in pre_period_idx]
766
- post_periods_list = [idx_to_period[i] for i in post_period_idx]
767
- n_treated_periods = len(post_period_idx)
967
+ if n_pre_periods < 2:
968
+ raise ValueError("Need at least 2 pre-treatment periods")
768
969
 
769
970
  # Step 1: Grid search with LOOCV for tuning parameters
770
971
  best_lambda = None
@@ -789,14 +990,45 @@ class TROP:
789
990
  lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
790
991
  lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
791
992
 
792
- best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search(
993
+ result = _rust_loocv_grid_search(
793
994
  Y, D.astype(np.float64), control_mask_u8,
794
995
  time_dist_matrix,
795
996
  lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
796
997
  self.max_loocv_samples, self.max_iter, self.tol,
797
998
  self.seed if self.seed is not None else 0
798
999
  )
799
- best_lambda = (best_lt, best_lu, best_ln)
1000
+ # Unpack result - 7 values including optional first_failed_obs
1001
+ best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
1002
+ # Only accept finite scores - infinite means all fits failed
1003
+ if np.isfinite(best_score):
1004
+ best_lambda = (best_lt, best_lu, best_ln)
1005
+ # else: best_lambda stays None, triggering defaults fallback
1006
+ # Emit warnings consistent with Python implementation
1007
+ if n_valid == 0:
1008
+ # Include failed observation coordinates if available (Issue 2 fix)
1009
+ obs_info = ""
1010
+ if first_failed_obs is not None:
1011
+ t_idx, i_idx = first_failed_obs
1012
+ obs_info = f" First failure at observation ({t_idx}, {i_idx})."
1013
+ warnings.warn(
1014
+ f"LOOCV: All {n_attempted} fits failed for "
1015
+ f"λ=({best_lt}, {best_lu}, {best_ln}). "
1016
+ f"Returning infinite score.{obs_info}",
1017
+ UserWarning
1018
+ )
1019
+ elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
1020
+ n_failed = n_attempted - n_valid
1021
+ # Include failed observation coordinates if available
1022
+ obs_info = ""
1023
+ if first_failed_obs is not None:
1024
+ t_idx, i_idx = first_failed_obs
1025
+ obs_info = f" First failure at observation ({t_idx}, {i_idx})."
1026
+ warnings.warn(
1027
+ f"LOOCV: {n_failed}/{n_attempted} fits failed for "
1028
+ f"λ=({best_lt}, {best_lu}, {best_ln}). "
1029
+ f"This may indicate numerical instability.{obs_info}",
1030
+ UserWarning
1031
+ )
800
1032
  except Exception as e:
801
1033
  # Fall back to Python implementation on error
802
1034
  logger.debug(
@@ -806,21 +1038,54 @@ class TROP:
806
1038
  best_score = np.inf
807
1039
 
808
1040
  # Fall back to Python implementation if Rust unavailable or failed
1041
+ # Uses two-stage approach per paper's footnote 2:
1042
+ # Stage 1: Univariate searches for initial values
1043
+ # Stage 2: Cycling (coordinate descent) until convergence
809
1044
  if best_lambda is None:
810
- for lambda_time in self.lambda_time_grid:
811
- for lambda_unit in self.lambda_unit_grid:
812
- for lambda_nn in self.lambda_nn_grid:
813
- try:
814
- score = self._loocv_score_obs_specific(
815
- Y, D, control_mask, control_unit_idx,
816
- lambda_time, lambda_unit, lambda_nn,
817
- n_units, n_periods
818
- )
819
- if score < best_score:
820
- best_score = score
821
- best_lambda = (lambda_time, lambda_unit, lambda_nn)
822
- except (np.linalg.LinAlgError, ValueError):
823
- continue
1045
+ # Stage 1: Univariate searches with extreme fixed values
1046
+ # Following paper's footnote 2 for initial bounds
1047
+
1048
+ # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
1049
+ lambda_time_init, _ = self._univariate_loocv_search(
1050
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1051
+ 'lambda_time', self.lambda_time_grid,
1052
+ {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF}
1053
+ )
1054
+
1055
+ # λ_nn search: fix λ_time=∞ (uniform time weights), λ_unit=0
1056
+ lambda_nn_init, _ = self._univariate_loocv_search(
1057
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1058
+ 'lambda_nn', self.lambda_nn_grid,
1059
+ {'lambda_time': _LAMBDA_INF, 'lambda_unit': 0.0}
1060
+ )
1061
+
1062
+ # λ_unit search: fix λ_nn=∞, λ_time=0
1063
+ lambda_unit_init, _ = self._univariate_loocv_search(
1064
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1065
+ 'lambda_unit', self.lambda_unit_grid,
1066
+ {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0}
1067
+ )
1068
+
1069
+ # Stage 2: Cycling refinement (coordinate descent)
1070
+ lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
1071
+ Y, D, control_mask, control_unit_idx, n_units, n_periods,
1072
+ (lambda_time_init, lambda_unit_init, lambda_nn_init)
1073
+ )
1074
+
1075
+ # Compute final score for the optimized parameters
1076
+ try:
1077
+ best_score = self._loocv_score_obs_specific(
1078
+ Y, D, control_mask, control_unit_idx,
1079
+ lambda_time, lambda_unit, lambda_nn,
1080
+ n_units, n_periods
1081
+ )
1082
+ # Only accept finite scores - infinite means all fits failed
1083
+ if np.isfinite(best_score):
1084
+ best_lambda = (lambda_time, lambda_unit, lambda_nn)
1085
+ # else: best_lambda stays None, triggering defaults fallback
1086
+ except (np.linalg.LinAlgError, ValueError):
1087
+ # If even the optimized parameters fail, best_lambda stays None
1088
+ pass
824
1089
 
825
1090
  if best_lambda is None:
826
1091
  warnings.warn(
@@ -833,6 +1098,26 @@ class TROP:
833
1098
  self._optimal_lambda = best_lambda
834
1099
  lambda_time, lambda_unit, lambda_nn = best_lambda
835
1100
 
1101
+ # Convert infinity values for final estimation (matching LOOCV conversion)
1102
+ # This ensures final estimation uses the same effective parameters that LOOCV evaluated.
1103
+ # See REGISTRY.md "λ=∞ implementation" for rationale.
1104
+ #
1105
+ # IMPORTANT: Store original grid values for results, use converted for computation.
1106
+ # This lets users see what was selected from their grid, while ensuring consistent
1107
+ # behavior between point estimation and variance estimation.
1108
+ original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda
1109
+
1110
+ if np.isinf(lambda_time):
1111
+ lambda_time = 0.0 # Uniform time weights
1112
+ if np.isinf(lambda_unit):
1113
+ lambda_unit = 0.0 # Uniform unit weights
1114
+ if np.isinf(lambda_nn):
1115
+ lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
1116
+
1117
+ # Create effective_lambda with converted values for ALL downstream computation
1118
+ # This ensures variance estimation uses the same parameters as point estimation
1119
+ effective_lambda = (lambda_time, lambda_unit, lambda_nn)
1120
+
836
1121
  # Step 2: Final estimation - per-observation model fitting following Algorithm 2
837
1122
  # For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it}
838
1123
  treatment_effects = {}
@@ -886,14 +1171,16 @@ class TROP:
886
1171
  effective_rank = 0.0
887
1172
 
888
1173
  # Step 4: Variance estimation
1174
+ # Use effective_lambda (converted values) to ensure SE is computed with same
1175
+ # parameters as point estimation. This fixes the variance inconsistency issue.
889
1176
  if self.variance_method == "bootstrap":
890
1177
  se, bootstrap_dist = self._bootstrap_variance(
891
- data, outcome, treatment, unit, time, post_periods_list,
892
- best_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
1178
+ data, outcome, treatment, unit, time,
1179
+ effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
893
1180
  )
894
1181
  else:
895
1182
  se, bootstrap_dist = self._jackknife_variance(
896
- Y, D, control_mask, control_unit_idx, best_lambda,
1183
+ Y, D, control_mask, control_unit_idx, effective_lambda,
897
1184
  n_units, n_periods
898
1185
  )
899
1186
 
@@ -901,11 +1188,12 @@ class TROP:
901
1188
  if se > 0:
902
1189
  t_stat = att / se
903
1190
  p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1)))
1191
+ conf_int = compute_confidence_interval(att, se, self.alpha)
904
1192
  else:
905
- t_stat = 0.0
906
- p_value = 1.0
907
-
908
- conf_int = compute_confidence_interval(att, se, self.alpha)
1193
+ # When SE is undefined/zero, ALL inference fields should be NaN
1194
+ t_stat = np.nan
1195
+ p_value = np.nan
1196
+ conf_int = (np.nan, np.nan)
909
1197
 
910
1198
  # Create results dictionaries
911
1199
  unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
@@ -925,16 +1213,18 @@ class TROP:
925
1213
  unit_effects=unit_effects_dict,
926
1214
  time_effects=time_effects_dict,
927
1215
  treatment_effects=treatment_effects,
928
- lambda_time=lambda_time,
929
- lambda_unit=lambda_unit,
930
- lambda_nn=lambda_nn,
1216
+ # Store ORIGINAL grid values (possibly inf) so users see what was selected.
1217
+ # Internally, infinity values are converted for computation (see effective_lambda).
1218
+ lambda_time=original_lambda_time,
1219
+ lambda_unit=original_lambda_unit,
1220
+ lambda_nn=original_lambda_nn,
931
1221
  factor_matrix=L_hat,
932
1222
  effective_rank=effective_rank,
933
1223
  loocv_score=best_score,
934
1224
  variance_method=self.variance_method,
935
1225
  alpha=self.alpha,
936
- pre_periods=pre_periods_list,
937
- post_periods=post_periods_list,
1226
+ n_pre_periods=n_pre_periods,
1227
+ n_post_periods=n_post_periods,
938
1228
  n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None,
939
1229
  bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
940
1230
  )
@@ -1409,6 +1699,17 @@ class TROP:
1409
1699
  indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
1410
1700
  control_obs = [control_obs[idx] for idx in indices]
1411
1701
 
1702
+ # Empty control set check: if no control observations, return infinity
1703
+ # A score of 0.0 would incorrectly "win" over legitimate parameters
1704
+ if len(control_obs) == 0:
1705
+ warnings.warn(
1706
+ f"LOOCV: No valid control observations for "
1707
+ f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
1708
+ "Returning infinite score.",
1709
+ UserWarning
1710
+ )
1711
+ return np.inf
1712
+
1412
1713
  tau_squared_sum = 0.0
1413
1714
  n_valid = 0
1414
1715
 
@@ -1433,12 +1734,19 @@ class TROP:
1433
1734
  n_valid += 1
1434
1735
 
1435
1736
  except (np.linalg.LinAlgError, ValueError):
1436
- continue
1437
-
1438
- if n_valid == 0:
1439
- return np.inf
1737
+ # Per Equation 5: Q(λ) must sum over ALL D==0 cells
1738
+ # Any failure means this λ cannot produce valid estimates for all cells
1739
+ warnings.warn(
1740
+ f"LOOCV: Fit failed for observation ({t}, {i}) with "
1741
+ f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). "
1742
+ "Returning infinite score per Equation 5.",
1743
+ UserWarning
1744
+ )
1745
+ return np.inf
1440
1746
 
1441
- return tau_squared_sum / n_valid
1747
+ # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
1748
+ # Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
1749
+ return tau_squared_sum
1442
1750
 
1443
1751
  def _bootstrap_variance(
1444
1752
  self,
@@ -1447,7 +1755,6 @@ class TROP:
1447
1755
  treatment: str,
1448
1756
  unit: str,
1449
1757
  time: str,
1450
- post_periods: List[Any],
1451
1758
  optimal_lambda: Tuple[float, float, float],
1452
1759
  Y: Optional[np.ndarray] = None,
1453
1760
  D: Optional[np.ndarray] = None,
@@ -1473,8 +1780,6 @@ class TROP:
1473
1780
  Name of the unit identifier column in data.
1474
1781
  time : str
1475
1782
  Name of the time period column in data.
1476
- post_periods : list
1477
- List of post-treatment time periods.
1478
1783
  optimal_lambda : tuple of float
1479
1784
  Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
1480
1785
  from cross-validation. Used for model estimation in each bootstrap.
@@ -1579,7 +1884,7 @@ class TROP:
1579
1884
  # Fit with fixed lambda (skip LOOCV for speed)
1580
1885
  att = self._fit_with_fixed_lambda(
1581
1886
  boot_data, outcome, treatment, unit, time,
1582
- post_periods, optimal_lambda
1887
+ optimal_lambda
1583
1888
  )
1584
1889
  bootstrap_estimates_list.append(att)
1585
1890
  except (ValueError, np.linalg.LinAlgError, KeyError):
@@ -1703,7 +2008,6 @@ class TROP:
1703
2008
  treatment: str,
1704
2009
  unit: str,
1705
2010
  time: str,
1706
- post_periods: List[Any],
1707
2011
  fixed_lambda: Tuple[float, float, float],
1708
2012
  ) -> float:
1709
2013
  """
@@ -1803,7 +2107,6 @@ def trop(
1803
2107
  treatment: str,
1804
2108
  unit: str,
1805
2109
  time: str,
1806
- post_periods: Optional[List[Any]] = None,
1807
2110
  **kwargs,
1808
2111
  ) -> TROPResults:
1809
2112
  """
@@ -1816,13 +2119,16 @@ def trop(
1816
2119
  outcome : str
1817
2120
  Outcome variable column name.
1818
2121
  treatment : str
1819
- Treatment indicator column name.
2122
+ Treatment indicator column name (0/1).
2123
+
2124
+ IMPORTANT: This should be an ABSORBING STATE indicator, not a treatment
2125
+ timing indicator. For each unit, D=1 for ALL periods during and after
2126
+ treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for t >= g_i where g_i is
2127
+ the treatment start time for unit i).
1820
2128
  unit : str
1821
2129
  Unit identifier column name.
1822
2130
  time : str
1823
2131
  Time period column name.
1824
- post_periods : list, optional
1825
- Post-treatment periods.
1826
2132
  **kwargs
1827
2133
  Additional arguments passed to TROP constructor.
1828
2134
 
@@ -1834,8 +2140,8 @@ def trop(
1834
2140
  Examples
1835
2141
  --------
1836
2142
  >>> from diff_diff import trop
1837
- >>> results = trop(data, 'y', 'treated', 'unit', 'time', post_periods=[5,6,7])
2143
+ >>> results = trop(data, 'y', 'treated', 'unit', 'time')
1838
2144
  >>> print(f"ATT: {results.att:.3f}")
1839
2145
  """
1840
2146
  estimator = TROP(**kwargs)
1841
- return estimator.fit(data, outcome, treatment, unit, time, post_periods)
2147
+ return estimator.fit(data, outcome, treatment, unit, time)
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "diff-diff"
7
- version = "2.1.7"
7
+ version = "2.1.8"
8
8
  description = "A library for Difference-in-Differences causal inference analysis"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -289,7 +289,7 @@ dependencies = [
289
289
 
290
290
  [[package]]
291
291
  name = "diff_diff_rust"
292
- version = "2.1.7"
292
+ version = "2.1.8"
293
293
  dependencies = [
294
294
  "ndarray",
295
295
  "ndarray-linalg",
@@ -2316,6 +2316,6 @@ dependencies = [
2316
2316
 
2317
2317
  [[package]]
2318
2318
  name = "zmij"
2319
- version = "1.0.16"
2319
+ version = "1.0.17"
2320
2320
  source = "registry+https://github.com/rust-lang/crates.io-index"
2321
- checksum = "dfcd145825aace48cff44a8844de64bf75feec3080e0aa5cdbde72961ae51a65"
2321
+ checksum = "02aae0f83f69aafc94776e879363e9771d7ecbffe2c7fbb6c14c5e00dfe88439"
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "diff_diff_rust"
3
- version = "2.1.7"
3
+ version = "2.1.8"
4
4
  edition = "2021"
5
5
  description = "Rust backend for diff-diff DiD library"
6
6
  license = "MIT"
@@ -146,10 +146,160 @@ fn compute_pair_distance(
146
146
  }
147
147
  }
148
148
 
149
- /// Perform LOOCV grid search over tuning parameters in parallel.
149
+ /// Perform univariate LOOCV search over a single parameter.
150
150
  ///
151
- /// Evaluates all combinations of (lambda_time, lambda_unit, lambda_nn) in parallel
152
- /// and returns the combination with the lowest LOOCV score.
151
+ /// Following paper's footnote 2, this performs a grid search for one parameter
152
+ /// while holding others fixed. Used in the two-stage LOOCV approach.
153
+ ///
154
+ /// # Arguments
155
+ /// * `y` - Outcome matrix (n_periods x n_units)
156
+ /// * `d` - Treatment indicator matrix
157
+ /// * `control_mask` - Boolean mask for control observations
158
+ /// * `time_dist` - Time distance matrix
159
+ /// * `control_obs` - List of control observations for LOOCV
160
+ /// * `grid` - Grid of values to search
161
+ /// * `fixed_time` - Fixed lambda_time (inf for disabled)
162
+ /// * `fixed_unit` - Fixed lambda_unit (inf for disabled)
163
+ /// * `fixed_nn` - Fixed lambda_nn (inf for disabled)
164
+ /// * `param_type` - Which parameter to search: 0=time, 1=unit, 2=nn
165
+ /// * `max_iter` - Maximum iterations
166
+ /// * `tol` - Convergence tolerance
167
+ ///
168
+ /// # Returns
169
+ /// (best_value, best_score)
170
+ fn univariate_loocv_search(
171
+ y: &ArrayView2<f64>,
172
+ d: &ArrayView2<f64>,
173
+ control_mask: &ArrayView2<u8>,
174
+ time_dist: &ArrayView2<i64>,
175
+ control_obs: &[(usize, usize)],
176
+ grid: &[f64],
177
+ fixed_time: f64,
178
+ fixed_unit: f64,
179
+ fixed_nn: f64,
180
+ param_type: usize, // 0=time, 1=unit, 2=nn
181
+ max_iter: usize,
182
+ tol: f64,
183
+ ) -> (f64, f64) {
184
+ let mut best_score = f64::INFINITY;
185
+ let mut best_value = grid.first().copied().unwrap_or(0.0);
186
+
187
+ // Parallelize over grid values
188
+ let results: Vec<(f64, f64)> = grid
189
+ .par_iter()
190
+ .map(|&value| {
191
+ // Set parameters, converting inf for "disabled" mode
192
+ // Per paper Equations 2-3:
193
+ // - λ_time/λ_unit=∞ → uniform weights → use 0.0
194
+ // - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
195
+ // Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled"
196
+ //
197
+ // IMPORTANT: Convert the grid value BEFORE using it, matching Python behavior.
198
+ // This ensures Rust and Python evaluate the same objective for infinity grids.
199
+ let (lambda_time, lambda_unit, lambda_nn) = match param_type {
200
+ 0 => {
201
+ // Searching λ_time: convert grid value if infinite
202
+ let value_converted = if value.is_infinite() { 0.0 } else { value };
203
+ (value_converted,
204
+ if fixed_unit.is_infinite() { 0.0 } else { fixed_unit },
205
+ if fixed_nn.is_infinite() { 1e10 } else { fixed_nn })
206
+ },
207
+ 1 => {
208
+ // Searching λ_unit: convert grid value if infinite
209
+ let value_converted = if value.is_infinite() { 0.0 } else { value };
210
+ (if fixed_time.is_infinite() { 0.0 } else { fixed_time },
211
+ value_converted,
212
+ if fixed_nn.is_infinite() { 1e10 } else { fixed_nn })
213
+ },
214
+ _ => {
215
+ // Searching λ_nn: convert grid value if infinite
216
+ let value_converted = if value.is_infinite() { 1e10 } else { value };
217
+ (if fixed_time.is_infinite() { 0.0 } else { fixed_time },
218
+ if fixed_unit.is_infinite() { 0.0 } else { fixed_unit },
219
+ value_converted)
220
+ },
221
+ };
222
+
223
+ let (score, _, _) = loocv_score_for_params(
224
+ y, d, control_mask, time_dist, control_obs,
225
+ lambda_time, lambda_unit, lambda_nn,
226
+ max_iter, tol,
227
+ );
228
+ (value, score)
229
+ })
230
+ .collect();
231
+
232
+ for (value, score) in results {
233
+ if score < best_score {
234
+ best_score = score;
235
+ best_value = value;
236
+ }
237
+ }
238
+
239
+ (best_value, best_score)
240
+ }
241
+
242
+ /// Cycle through parameters until convergence (coordinate descent).
243
+ ///
244
+ /// Following paper's footnote 2 (Stage 2), iteratively optimize each parameter.
245
+ fn cycling_parameter_search(
246
+ y: &ArrayView2<f64>,
247
+ d: &ArrayView2<f64>,
248
+ control_mask: &ArrayView2<u8>,
249
+ time_dist: &ArrayView2<i64>,
250
+ control_obs: &[(usize, usize)],
251
+ lambda_time_grid: &[f64],
252
+ lambda_unit_grid: &[f64],
253
+ lambda_nn_grid: &[f64],
254
+ initial_time: f64,
255
+ initial_unit: f64,
256
+ initial_nn: f64,
257
+ max_iter: usize,
258
+ tol: f64,
259
+ max_cycles: usize,
260
+ ) -> (f64, f64, f64) {
261
+ let mut lambda_time = initial_time;
262
+ let mut lambda_unit = initial_unit;
263
+ let mut lambda_nn = initial_nn;
264
+ let mut prev_score = f64::INFINITY;
265
+
266
+ for _cycle in 0..max_cycles {
267
+ // Optimize λ_unit (fix λ_time, λ_nn)
268
+ let (new_unit, _) = univariate_loocv_search(
269
+ y, d, control_mask, time_dist, control_obs,
270
+ lambda_unit_grid, lambda_time, 0.0, lambda_nn, 1, max_iter, tol,
271
+ );
272
+ lambda_unit = new_unit;
273
+
274
+ // Optimize λ_time (fix λ_unit, λ_nn)
275
+ let (new_time, _) = univariate_loocv_search(
276
+ y, d, control_mask, time_dist, control_obs,
277
+ lambda_time_grid, 0.0, lambda_unit, lambda_nn, 0, max_iter, tol,
278
+ );
279
+ lambda_time = new_time;
280
+
281
+ // Optimize λ_nn (fix λ_unit, λ_time)
282
+ let (new_nn, score) = univariate_loocv_search(
283
+ y, d, control_mask, time_dist, control_obs,
284
+ lambda_nn_grid, lambda_time, lambda_unit, 0.0, 2, max_iter, tol,
285
+ );
286
+ lambda_nn = new_nn;
287
+
288
+ // Check convergence
289
+ if (score - prev_score).abs() < 1e-6 {
290
+ break;
291
+ }
292
+ prev_score = score;
293
+ }
294
+
295
+ (lambda_time, lambda_unit, lambda_nn)
296
+ }
297
+
298
+ /// Perform LOOCV grid search over tuning parameters using two-stage approach.
299
+ ///
300
+ /// Following paper's footnote 2:
301
+ /// - Stage 1: Univariate searches for initial values with extreme fixed parameters
302
+ /// - Stage 2: Cycling (coordinate descent) until convergence
153
303
  ///
154
304
  /// Following TROP Equation 5 (page 8):
155
305
  /// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
@@ -158,8 +308,6 @@ fn compute_pair_distance(
158
308
  /// * `y` - Outcome matrix (n_periods x n_units)
159
309
  /// * `d` - Treatment indicator matrix (n_periods x n_units)
160
310
  /// * `control_mask` - Boolean mask (n_periods x n_units) for control observations
161
- /// * `control_unit_idx` - Array of control unit indices
162
- /// * `unit_dist_matrix` - Pre-computed unit distance matrix (n_units x n_units)
163
311
  /// * `time_dist_matrix` - Pre-computed time distance matrix (n_periods x n_periods)
164
312
  /// * `lambda_time_grid` - Grid of time decay parameters
165
313
  /// * `lambda_unit_grid` - Grid of unit distance parameters
@@ -170,7 +318,10 @@ fn compute_pair_distance(
170
318
  /// * `seed` - Random seed for subsampling
171
319
  ///
172
320
  /// # Returns
173
- /// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score)
321
+ /// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs)
322
+ /// where n_attempted is the number of sampled control observations (the same for every
323
+ /// parameter combination) and n_valid is the successful-fit count from the final best-parameter score computation, allowing Python to emit warnings when >10% of fits fail.
324
+ /// first_failed_obs is Some((t, i)) if a fit failed during final score computation, None otherwise.
174
325
  #[pyfunction]
175
326
  #[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
176
327
  #[allow(clippy::too_many_arguments)]
@@ -187,7 +338,7 @@ pub fn loocv_grid_search<'py>(
187
338
  max_iter: usize,
188
339
  tol: f64,
189
340
  seed: u64,
190
- ) -> PyResult<(f64, f64, f64, f64)> {
341
+ ) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> {
191
342
  let y_arr = y.as_array();
192
343
  let d_arr = d.as_array();
193
344
  let control_mask_arr = control_mask.as_array();
@@ -204,43 +355,53 @@ pub fn loocv_grid_search<'py>(
204
355
  seed,
205
356
  );
206
357
 
207
- // Generate all parameter combinations
208
- let mut param_combos: Vec<(f64, f64, f64)> = Vec::new();
209
- for &lt in &lambda_time_vec {
210
- for &lu in &lambda_unit_vec {
211
- for &ln in &lambda_nn_vec {
212
- param_combos.push((lt, lu, ln));
213
- }
214
- }
215
- }
358
+ let n_attempted = control_obs.len();
216
359
 
217
- // Evaluate all combinations in parallel
218
- let results: Vec<(f64, f64, f64, f64)> = param_combos
219
- .par_iter()
220
- .map(|&(lambda_time, lambda_unit, lambda_nn)| {
221
- let score = loocv_score_for_params(
222
- &y_arr,
223
- &d_arr,
224
- &control_mask_arr,
225
- &time_dist_arr,
226
- &control_obs,
227
- lambda_time,
228
- lambda_unit,
229
- lambda_nn,
230
- max_iter,
231
- tol,
232
- );
233
- (lambda_time, lambda_unit, lambda_nn, score)
234
- })
235
- .collect();
360
+ // Stage 1: Univariate searches for initial values (paper footnote 2)
361
+ // λ_time search: fix λ_unit=0, λ_nn=∞ (disabled)
362
+ let (lambda_time_init, _) = univariate_loocv_search(
363
+ &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
364
+ &lambda_time_vec, 0.0, 0.0, f64::INFINITY, 0, max_iter, tol,
365
+ );
236
366
 
237
- // Find best (minimum score)
238
- let best = results
239
- .into_iter()
240
- .min_by(|a, b| a.3.partial_cmp(&b.3).unwrap_or(std::cmp::Ordering::Equal))
241
- .unwrap_or((1.0, 1.0, 0.1, f64::INFINITY));
367
+ // λ_nn search: fix λ_time=∞ (disabled), λ_unit=0
368
+ let (lambda_nn_init, _) = univariate_loocv_search(
369
+ &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
370
+ &lambda_nn_vec, f64::INFINITY, 0.0, 0.0, 2, max_iter, tol,
371
+ );
372
+
373
+ // λ_unit search: fix λ_nn=∞, λ_time=0
374
+ let (lambda_unit_init, _) = univariate_loocv_search(
375
+ &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
376
+ &lambda_unit_vec, 0.0, 0.0, f64::INFINITY, 1, max_iter, tol,
377
+ );
242
378
 
243
- Ok(best)
379
+ // Stage 2: Cycling refinement
380
+ let (best_time, best_unit, best_nn) = cycling_parameter_search(
381
+ &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
382
+ &lambda_time_vec, &lambda_unit_vec, &lambda_nn_vec,
383
+ lambda_time_init, lambda_unit_init, lambda_nn_init,
384
+ max_iter, tol, 10,
385
+ );
386
+
387
+ // Convert infinity values BEFORE computing final score (Issue 1 fix)
388
+ // Per paper Equations 2-3:
389
+ // - λ_time/λ_unit=∞ → uniform weights → use 0.0
390
+ // - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
391
+ // This ensures final score computation matches what LOOCV evaluated.
392
+ let best_time_eff = if best_time.is_infinite() { 0.0 } else { best_time };
393
+ let best_unit_eff = if best_unit.is_infinite() { 0.0 } else { best_unit };
394
+ let best_nn_eff = if best_nn.is_infinite() { 1e10 } else { best_nn };
395
+
396
+ // Compute final score with converted values
397
+ let (best_score, n_valid, first_failed) = loocv_score_for_params(
398
+ &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
399
+ best_time_eff, best_unit_eff, best_nn_eff,
400
+ max_iter, tol,
401
+ );
402
+
403
+ // Return ORIGINAL grid values (for user visibility) but score computed with converted
404
+ Ok((best_time, best_unit, best_nn, best_score, n_valid, n_attempted, first_failed))
244
405
  }
245
406
 
246
407
  /// Get sampled control observations for LOOCV.
@@ -277,6 +438,10 @@ fn get_control_observations(
277
438
  }
278
439
 
279
440
  /// Compute LOOCV score for a specific parameter combination.
441
+ ///
442
+ /// # Returns
443
+ /// (score, n_valid, first_failed_obs) - the LOOCV score, the number of successful fits,
444
+ /// and the first observation (t, i) whose fit failed, or None if every fit succeeded.
280
445
  #[allow(clippy::too_many_arguments)]
281
446
  fn loocv_score_for_params(
282
447
  y: &ArrayView2<f64>,
@@ -289,7 +454,7 @@ fn loocv_score_for_params(
289
454
  lambda_nn: f64,
290
455
  max_iter: usize,
291
456
  tol: f64,
292
- ) -> f64 {
457
+ ) -> (f64, usize, Option<(usize, usize)>) {
293
458
  let n_periods = y.nrows();
294
459
  let n_units = y.ncols();
295
460
 
@@ -328,14 +493,21 @@ fn loocv_score_for_params(
328
493
  tau_sq_sum += tau * tau;
329
494
  n_valid += 1;
330
495
  }
331
- None => continue, // Skip if estimation failed
496
+ None => {
497
+ // Per Equation 5: Q(λ) must sum over ALL D==0 cells
498
+ // Any failure means this λ cannot produce valid estimates for all cells
499
+ // Return the failed observation (t, i) for warning metadata
500
+ return (f64::INFINITY, n_valid, Some((t, i)));
501
+ }
332
502
  }
333
503
  }
334
504
 
335
505
  if n_valid == 0 {
336
- f64::INFINITY
506
+ (f64::INFINITY, 0, None)
337
507
  } else {
338
- tau_sq_sum / n_valid as f64
508
+ // Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
509
+ // Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
510
+ (tau_sq_sum, n_valid, None)
339
511
  }
340
512
  }
341
513
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes