diff-diff 2.1.0__tar.gz → 2.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {diff_diff-2.1.0 → diff_diff-2.1.1}/PKG-INFO +1 -1
  2. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/__init__.py +1 -1
  3. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/_backend.py +16 -0
  4. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/pretrends.py +104 -11
  5. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/staggered.py +4 -0
  6. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/trop.py +516 -161
  7. {diff_diff-2.1.0 → diff_diff-2.1.1}/pyproject.toml +1 -1
  8. {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/Cargo.lock +6 -81
  9. {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/Cargo.toml +1 -1
  10. {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/lib.rs +6 -0
  11. diff_diff-2.1.1/rust/src/trop.rs +861 -0
  12. {diff_diff-2.1.0 → diff_diff-2.1.1}/README.md +0 -0
  13. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/bacon.py +0 -0
  14. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/datasets.py +0 -0
  15. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/diagnostics.py +0 -0
  16. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/estimators.py +0 -0
  17. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/honest_did.py +0 -0
  18. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/linalg.py +0 -0
  19. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/power.py +0 -0
  20. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/prep.py +0 -0
  21. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/results.py +0 -0
  22. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/sun_abraham.py +0 -0
  23. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/synthetic_did.py +0 -0
  24. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/triple_diff.py +0 -0
  25. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/twfe.py +0 -0
  26. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/utils.py +0 -0
  27. {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/visualization.py +0 -0
  28. {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/bootstrap.rs +0 -0
  29. {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/linalg.rs +0 -0
  30. {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.1.0
3
+ Version: 2.1.1
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -131,7 +131,7 @@ from diff_diff.datasets import (
131
131
  load_mpdta,
132
132
  )
133
133
 
134
- __version__ = "2.1.0"
134
+ __version__ = "2.1.1"
135
135
  __all__ = [
136
136
  # Estimators
137
137
  "DifferenceInDifferences",
@@ -23,6 +23,10 @@ try:
23
23
  project_simplex as _rust_project_simplex,
24
24
  solve_ols as _rust_solve_ols,
25
25
  compute_robust_vcov as _rust_compute_robust_vcov,
26
+ # TROP estimator acceleration
27
+ compute_unit_distance_matrix as _rust_unit_distance_matrix,
28
+ loocv_grid_search as _rust_loocv_grid_search,
29
+ bootstrap_trop_variance as _rust_bootstrap_trop_variance,
26
30
  )
27
31
  _rust_available = True
28
32
  except ImportError:
@@ -32,6 +36,10 @@ except ImportError:
32
36
  _rust_project_simplex = None
33
37
  _rust_solve_ols = None
34
38
  _rust_compute_robust_vcov = None
39
+ # TROP estimator acceleration
40
+ _rust_unit_distance_matrix = None
41
+ _rust_loocv_grid_search = None
42
+ _rust_bootstrap_trop_variance = None
35
43
 
36
44
  # Determine final backend based on environment variable and availability
37
45
  if _backend_env == 'python':
@@ -42,6 +50,10 @@ if _backend_env == 'python':
42
50
  _rust_project_simplex = None
43
51
  _rust_solve_ols = None
44
52
  _rust_compute_robust_vcov = None
53
+ # TROP estimator acceleration
54
+ _rust_unit_distance_matrix = None
55
+ _rust_loocv_grid_search = None
56
+ _rust_bootstrap_trop_variance = None
45
57
  elif _backend_env == 'rust':
46
58
  # Force Rust mode - fail if not available
47
59
  if not _rust_available:
@@ -61,4 +73,8 @@ __all__ = [
61
73
  '_rust_project_simplex',
62
74
  '_rust_solve_ols',
63
75
  '_rust_compute_robust_vcov',
76
+ # TROP estimator acceleration
77
+ '_rust_unit_distance_matrix',
78
+ '_rust_loocv_grid_search',
79
+ '_rust_bootstrap_trop_variance',
64
80
  ]
@@ -202,6 +202,63 @@ class PreTrendsPowerResults:
202
202
  """Convert results to DataFrame."""
203
203
  return pd.DataFrame([self.to_dict()])
204
204
 
205
+ def power_at(self, M: float) -> float:
206
+ """
207
+ Compute power to detect a specific violation magnitude.
208
+
209
+ This method allows computing power at different M values without
210
+ re-fitting the model, using the stored variance-covariance matrix.
211
+
212
+ Parameters
213
+ ----------
214
+ M : float
215
+ Violation magnitude to evaluate.
216
+
217
+ Returns
218
+ -------
219
+ float
220
+ Power to detect violation of magnitude M.
221
+ """
222
+ from scipy import stats
223
+
224
+ n_pre = self.n_pre_periods
225
+
226
+ # Reconstruct violation weights based on violation type
227
+ # Must match PreTrendsPower._get_violation_weights() exactly
228
+ if self.violation_type == "linear":
229
+ # Linear trend: weights decrease toward treatment
230
+ # [n-1, n-2, ..., 1, 0] for n pre-periods
231
+ weights = np.arange(-n_pre + 1, 1, dtype=float)
232
+ weights = -weights # Now [n-1, n-2, ..., 1, 0]
233
+ elif self.violation_type == "constant":
234
+ weights = np.ones(n_pre)
235
+ elif self.violation_type == "last_period":
236
+ weights = np.zeros(n_pre)
237
+ weights[-1] = 1.0
238
+ else:
239
+ # For custom, we can't reconstruct - use equal weights as fallback
240
+ weights = np.ones(n_pre)
241
+
242
+ # Normalize weights to unit L2 norm
243
+ norm = np.linalg.norm(weights)
244
+ if norm > 0:
245
+ weights = weights / norm
246
+
247
+ # Compute non-centrality parameter
248
+ try:
249
+ vcov_inv = np.linalg.inv(self.vcov)
250
+ except np.linalg.LinAlgError:
251
+ vcov_inv = np.linalg.pinv(self.vcov)
252
+
253
+ # delta = M * weights
254
+ # nc = delta' * V^{-1} * delta
255
+ noncentrality = M**2 * (weights @ vcov_inv @ weights)
256
+
257
+ # Compute power using non-central chi-squared
258
+ power = 1 - stats.ncx2.cdf(self.critical_value, df=n_pre, nc=noncentrality)
259
+
260
+ return float(power)
261
+
205
262
 
206
263
  @dataclass
207
264
  class PreTrendsPowerCurve:
@@ -471,10 +528,18 @@ class PreTrendsPower:
471
528
  def _extract_pre_period_params(
472
529
  self,
473
530
  results: Union[MultiPeriodDiDResults, Any],
531
+ pre_periods: Optional[List[int]] = None,
474
532
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
475
533
  """
476
534
  Extract pre-period parameters from results.
477
535
 
536
+ Parameters
537
+ ----------
538
+ results : MultiPeriodDiDResults or similar
539
+ Results object from event study estimation.
540
+ pre_periods : list of int, optional
541
+ Explicit list of pre-treatment periods. If None, uses results.pre_periods.
542
+
478
543
  Returns
479
544
  -------
480
545
  effects : np.ndarray
@@ -487,13 +552,18 @@ class PreTrendsPower:
487
552
  Number of pre-periods.
488
553
  """
489
554
  if isinstance(results, MultiPeriodDiDResults):
490
- # Get pre-period information
491
- all_pre_periods = results.pre_periods
555
+ # Get pre-period information - use explicit pre_periods if provided
556
+ if pre_periods is not None:
557
+ all_pre_periods = list(pre_periods)
558
+ else:
559
+ all_pre_periods = results.pre_periods
492
560
 
493
561
  if len(all_pre_periods) == 0:
494
562
  raise ValueError(
495
563
  "No pre-treatment periods found in results. "
496
- "Pre-trends power analysis requires pre-period coefficients."
564
+ "Pre-trends power analysis requires pre-period coefficients. "
565
+ "If you estimated all periods as post_periods, use the pre_periods "
566
+ "parameter to specify which are actually pre-treatment."
497
567
  )
498
568
 
499
569
  # Only include periods with actual estimated coefficients
@@ -775,6 +845,7 @@ class PreTrendsPower:
775
845
  self,
776
846
  results: Union[MultiPeriodDiDResults, Any],
777
847
  M: Optional[float] = None,
848
+ pre_periods: Optional[List[int]] = None,
778
849
  ) -> PreTrendsPowerResults:
779
850
  """
780
851
  Compute pre-trends power analysis.
@@ -786,6 +857,11 @@ class PreTrendsPower:
786
857
  M : float, optional
787
858
  Specific violation magnitude to evaluate. If None, evaluates at
788
859
  a default magnitude based on the data.
860
+ pre_periods : list of int, optional
861
+ Explicit list of pre-treatment periods to use for power analysis.
862
+ If None, attempts to infer from results.pre_periods. Use this when
863
+ you've estimated an event study with all periods in post_periods
864
+ and need to specify which are actually pre-treatment.
789
865
 
790
866
  Returns
791
867
  -------
@@ -793,7 +869,7 @@ class PreTrendsPower:
793
869
  Power analysis results including power and MDV.
794
870
  """
795
871
  # Extract pre-period parameters
796
- effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
872
+ effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
797
873
 
798
874
  # Get violation weights
799
875
  weights = self._get_violation_weights(n_pre)
@@ -831,6 +907,7 @@ class PreTrendsPower:
831
907
  self,
832
908
  results: Union[MultiPeriodDiDResults, Any],
833
909
  M: float,
910
+ pre_periods: Optional[List[int]] = None,
834
911
  ) -> float:
835
912
  """
836
913
  Compute power to detect a specific violation magnitude.
@@ -841,13 +918,15 @@ class PreTrendsPower:
841
918
  Event study results.
842
919
  M : float
843
920
  Violation magnitude.
921
+ pre_periods : list of int, optional
922
+ Explicit list of pre-treatment periods. See fit() for details.
844
923
 
845
924
  Returns
846
925
  -------
847
926
  float
848
927
  Power to detect violation of magnitude M.
849
928
  """
850
- result = self.fit(results, M=M)
929
+ result = self.fit(results, M=M, pre_periods=pre_periods)
851
930
  return result.power
852
931
 
853
932
  def power_curve(
@@ -855,6 +934,7 @@ class PreTrendsPower:
855
934
  results: Union[MultiPeriodDiDResults, Any],
856
935
  M_grid: Optional[List[float]] = None,
857
936
  n_points: int = 50,
937
+ pre_periods: Optional[List[int]] = None,
858
938
  ) -> PreTrendsPowerCurve:
859
939
  """
860
940
  Compute power across a range of violation magnitudes.
@@ -868,6 +948,8 @@ class PreTrendsPower:
868
948
  automatic grid from 0 to 2.5 * MDV.
869
949
  n_points : int, default=50
870
950
  Number of points in automatic grid.
951
+ pre_periods : list of int, optional
952
+ Explicit list of pre-treatment periods. See fit() for details.
871
953
 
872
954
  Returns
873
955
  -------
@@ -875,7 +957,7 @@ class PreTrendsPower:
875
957
  Power curve data with plot method.
876
958
  """
877
959
  # Extract parameters
878
- effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
960
+ _, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
879
961
  weights = self._get_violation_weights(n_pre)
880
962
 
881
963
  # Compute MDV
@@ -906,6 +988,7 @@ class PreTrendsPower:
906
988
  def sensitivity_to_honest_did(
907
989
  self,
908
990
  results: Union[MultiPeriodDiDResults, Any],
991
+ pre_periods: Optional[List[int]] = None,
909
992
  ) -> Dict[str, Any]:
910
993
  """
911
994
  Compare pre-trends power analysis with HonestDiD sensitivity.
@@ -917,6 +1000,8 @@ class PreTrendsPower:
917
1000
  ----------
918
1001
  results : results object
919
1002
  Event study results.
1003
+ pre_periods : list of int, optional
1004
+ Explicit list of pre-treatment periods. See fit() for details.
920
1005
 
921
1006
  Returns
922
1007
  -------
@@ -926,7 +1011,7 @@ class PreTrendsPower:
926
1011
  - honest_M_at_mdv: Corresponding M value for HonestDiD
927
1012
  - interpretation: Text explaining the relationship
928
1013
  """
929
- pt_results = self.fit(results)
1014
+ pt_results = self.fit(results, pre_periods=pre_periods)
930
1015
  mdv = pt_results.mdv
931
1016
 
932
1017
  # The MDV represents the size of violation the test could detect
@@ -993,6 +1078,7 @@ def compute_pretrends_power(
993
1078
  alpha: float = 0.05,
994
1079
  target_power: float = 0.80,
995
1080
  violation_type: str = "linear",
1081
+ pre_periods: Optional[List[int]] = None,
996
1082
  ) -> PreTrendsPowerResults:
997
1083
  """
998
1084
  Convenience function for pre-trends power analysis.
@@ -1009,6 +1095,9 @@ def compute_pretrends_power(
1009
1095
  Target power for MDV calculation.
1010
1096
  violation_type : str, default='linear'
1011
1097
  Type of violation pattern.
1098
+ pre_periods : list of int, optional
1099
+ Explicit list of pre-treatment periods. If None, attempts to infer
1100
+ from results. Use when you've estimated all periods as post_periods.
1012
1101
 
1013
1102
  Returns
1014
1103
  -------
@@ -1021,7 +1110,7 @@ def compute_pretrends_power(
1021
1110
  >>> from diff_diff.pretrends import compute_pretrends_power
1022
1111
  >>>
1023
1112
  >>> results = MultiPeriodDiD().fit(data, ...)
1024
- >>> power_results = compute_pretrends_power(results)
1113
+ >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
1025
1114
  >>> print(f"MDV: {power_results.mdv:.3f}")
1026
1115
  >>> print(f"Power: {power_results.power:.1%}")
1027
1116
  """
@@ -1030,7 +1119,7 @@ def compute_pretrends_power(
1030
1119
  power=target_power,
1031
1120
  violation_type=violation_type,
1032
1121
  )
1033
- return pt.fit(results, M=M)
1122
+ return pt.fit(results, M=M, pre_periods=pre_periods)
1034
1123
 
1035
1124
 
1036
1125
  def compute_mdv(
@@ -1038,6 +1127,7 @@ def compute_mdv(
1038
1127
  alpha: float = 0.05,
1039
1128
  target_power: float = 0.80,
1040
1129
  violation_type: str = "linear",
1130
+ pre_periods: Optional[List[int]] = None,
1041
1131
  ) -> float:
1042
1132
  """
1043
1133
  Compute minimum detectable violation.
@@ -1049,9 +1139,12 @@ def compute_mdv(
1049
1139
  alpha : float, default=0.05
1050
1140
  Significance level.
1051
1141
  target_power : float, default=0.80
1052
- Target power.
1142
+ Target power for MDV calculation.
1053
1143
  violation_type : str, default='linear'
1054
1144
  Type of violation pattern.
1145
+ pre_periods : list of int, optional
1146
+ Explicit list of pre-treatment periods. If None, attempts to infer
1147
+ from results. Use when you've estimated all periods as post_periods.
1055
1148
 
1056
1149
  Returns
1057
1150
  -------
@@ -1063,5 +1156,5 @@ def compute_mdv(
1063
1156
  power=target_power,
1064
1157
  violation_type=violation_type,
1065
1158
  )
1066
- result = pt.fit(results)
1159
+ result = pt.fit(results, pre_periods=pre_periods)
1067
1160
  return result.mdv
@@ -1053,6 +1053,10 @@ class CallawaySantAnna:
1053
1053
  df[time] = pd.to_numeric(df[time])
1054
1054
  df[first_treat] = pd.to_numeric(df[first_treat])
1055
1055
 
1056
+ # Standardize the first_treat column name for internal use
1057
+ # This avoids hardcoding column names in internal methods
1058
+ df['first_treat'] = df[first_treat]
1059
+
1056
1060
  # Identify groups and time periods
1057
1061
  time_periods = sorted(df[time].unique())
1058
1062
  treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])