diff-diff 2.1.0__tar.gz → 2.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diff_diff-2.1.0 → diff_diff-2.1.1}/PKG-INFO +1 -1
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/__init__.py +1 -1
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/_backend.py +16 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/pretrends.py +104 -11
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/staggered.py +4 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/trop.py +516 -161
- {diff_diff-2.1.0 → diff_diff-2.1.1}/pyproject.toml +1 -1
- {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/Cargo.lock +6 -81
- {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/Cargo.toml +1 -1
- {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/lib.rs +6 -0
- diff_diff-2.1.1/rust/src/trop.rs +861 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/README.md +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/bacon.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/datasets.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/diagnostics.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/estimators.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/honest_did.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/linalg.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/power.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/prep.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/results.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/sun_abraham.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/synthetic_did.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/triple_diff.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/twfe.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/utils.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/diff_diff/visualization.py +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/bootstrap.rs +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/linalg.rs +0 -0
- {diff_diff-2.1.0 → diff_diff-2.1.1}/rust/src/weights.rs +0 -0
|
@@ -23,6 +23,10 @@ try:
|
|
|
23
23
|
project_simplex as _rust_project_simplex,
|
|
24
24
|
solve_ols as _rust_solve_ols,
|
|
25
25
|
compute_robust_vcov as _rust_compute_robust_vcov,
|
|
26
|
+
# TROP estimator acceleration
|
|
27
|
+
compute_unit_distance_matrix as _rust_unit_distance_matrix,
|
|
28
|
+
loocv_grid_search as _rust_loocv_grid_search,
|
|
29
|
+
bootstrap_trop_variance as _rust_bootstrap_trop_variance,
|
|
26
30
|
)
|
|
27
31
|
_rust_available = True
|
|
28
32
|
except ImportError:
|
|
@@ -32,6 +36,10 @@ except ImportError:
|
|
|
32
36
|
_rust_project_simplex = None
|
|
33
37
|
_rust_solve_ols = None
|
|
34
38
|
_rust_compute_robust_vcov = None
|
|
39
|
+
# TROP estimator acceleration
|
|
40
|
+
_rust_unit_distance_matrix = None
|
|
41
|
+
_rust_loocv_grid_search = None
|
|
42
|
+
_rust_bootstrap_trop_variance = None
|
|
35
43
|
|
|
36
44
|
# Determine final backend based on environment variable and availability
|
|
37
45
|
if _backend_env == 'python':
|
|
@@ -42,6 +50,10 @@ if _backend_env == 'python':
|
|
|
42
50
|
_rust_project_simplex = None
|
|
43
51
|
_rust_solve_ols = None
|
|
44
52
|
_rust_compute_robust_vcov = None
|
|
53
|
+
# TROP estimator acceleration
|
|
54
|
+
_rust_unit_distance_matrix = None
|
|
55
|
+
_rust_loocv_grid_search = None
|
|
56
|
+
_rust_bootstrap_trop_variance = None
|
|
45
57
|
elif _backend_env == 'rust':
|
|
46
58
|
# Force Rust mode - fail if not available
|
|
47
59
|
if not _rust_available:
|
|
@@ -61,4 +73,8 @@ __all__ = [
|
|
|
61
73
|
'_rust_project_simplex',
|
|
62
74
|
'_rust_solve_ols',
|
|
63
75
|
'_rust_compute_robust_vcov',
|
|
76
|
+
# TROP estimator acceleration
|
|
77
|
+
'_rust_unit_distance_matrix',
|
|
78
|
+
'_rust_loocv_grid_search',
|
|
79
|
+
'_rust_bootstrap_trop_variance',
|
|
64
80
|
]
|
|
@@ -202,6 +202,63 @@ class PreTrendsPowerResults:
|
|
|
202
202
|
"""Convert results to DataFrame."""
|
|
203
203
|
return pd.DataFrame([self.to_dict()])
|
|
204
204
|
|
|
205
|
+
def power_at(self, M: float) -> float:
|
|
206
|
+
"""
|
|
207
|
+
Compute power to detect a specific violation magnitude.
|
|
208
|
+
|
|
209
|
+
This method allows computing power at different M values without
|
|
210
|
+
re-fitting the model, using the stored variance-covariance matrix.
|
|
211
|
+
|
|
212
|
+
Parameters
|
|
213
|
+
----------
|
|
214
|
+
M : float
|
|
215
|
+
Violation magnitude to evaluate.
|
|
216
|
+
|
|
217
|
+
Returns
|
|
218
|
+
-------
|
|
219
|
+
float
|
|
220
|
+
Power to detect violation of magnitude M.
|
|
221
|
+
"""
|
|
222
|
+
from scipy import stats
|
|
223
|
+
|
|
224
|
+
n_pre = self.n_pre_periods
|
|
225
|
+
|
|
226
|
+
# Reconstruct violation weights based on violation type
|
|
227
|
+
# Must match PreTrendsPower._get_violation_weights() exactly
|
|
228
|
+
if self.violation_type == "linear":
|
|
229
|
+
# Linear trend: weights decrease toward treatment
|
|
230
|
+
# [n-1, n-2, ..., 1, 0] for n pre-periods
|
|
231
|
+
weights = np.arange(-n_pre + 1, 1, dtype=float)
|
|
232
|
+
weights = -weights # Now [n-1, n-2, ..., 1, 0]
|
|
233
|
+
elif self.violation_type == "constant":
|
|
234
|
+
weights = np.ones(n_pre)
|
|
235
|
+
elif self.violation_type == "last_period":
|
|
236
|
+
weights = np.zeros(n_pre)
|
|
237
|
+
weights[-1] = 1.0
|
|
238
|
+
else:
|
|
239
|
+
# For custom, we can't reconstruct - use equal weights as fallback
|
|
240
|
+
weights = np.ones(n_pre)
|
|
241
|
+
|
|
242
|
+
# Normalize weights to unit L2 norm
|
|
243
|
+
norm = np.linalg.norm(weights)
|
|
244
|
+
if norm > 0:
|
|
245
|
+
weights = weights / norm
|
|
246
|
+
|
|
247
|
+
# Compute non-centrality parameter
|
|
248
|
+
try:
|
|
249
|
+
vcov_inv = np.linalg.inv(self.vcov)
|
|
250
|
+
except np.linalg.LinAlgError:
|
|
251
|
+
vcov_inv = np.linalg.pinv(self.vcov)
|
|
252
|
+
|
|
253
|
+
# delta = M * weights
|
|
254
|
+
# nc = delta' * V^{-1} * delta
|
|
255
|
+
noncentrality = M**2 * (weights @ vcov_inv @ weights)
|
|
256
|
+
|
|
257
|
+
# Compute power using non-central chi-squared
|
|
258
|
+
power = 1 - stats.ncx2.cdf(self.critical_value, df=n_pre, nc=noncentrality)
|
|
259
|
+
|
|
260
|
+
return float(power)
|
|
261
|
+
|
|
205
262
|
|
|
206
263
|
@dataclass
|
|
207
264
|
class PreTrendsPowerCurve:
|
|
@@ -471,10 +528,18 @@ class PreTrendsPower:
|
|
|
471
528
|
def _extract_pre_period_params(
|
|
472
529
|
self,
|
|
473
530
|
results: Union[MultiPeriodDiDResults, Any],
|
|
531
|
+
pre_periods: Optional[List[int]] = None,
|
|
474
532
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
|
|
475
533
|
"""
|
|
476
534
|
Extract pre-period parameters from results.
|
|
477
535
|
|
|
536
|
+
Parameters
|
|
537
|
+
----------
|
|
538
|
+
results : MultiPeriodDiDResults or similar
|
|
539
|
+
Results object from event study estimation.
|
|
540
|
+
pre_periods : list of int, optional
|
|
541
|
+
Explicit list of pre-treatment periods. If None, uses results.pre_periods.
|
|
542
|
+
|
|
478
543
|
Returns
|
|
479
544
|
-------
|
|
480
545
|
effects : np.ndarray
|
|
@@ -487,13 +552,18 @@ class PreTrendsPower:
|
|
|
487
552
|
Number of pre-periods.
|
|
488
553
|
"""
|
|
489
554
|
if isinstance(results, MultiPeriodDiDResults):
|
|
490
|
-
# Get pre-period information
|
|
491
|
-
|
|
555
|
+
# Get pre-period information - use explicit pre_periods if provided
|
|
556
|
+
if pre_periods is not None:
|
|
557
|
+
all_pre_periods = list(pre_periods)
|
|
558
|
+
else:
|
|
559
|
+
all_pre_periods = results.pre_periods
|
|
492
560
|
|
|
493
561
|
if len(all_pre_periods) == 0:
|
|
494
562
|
raise ValueError(
|
|
495
563
|
"No pre-treatment periods found in results. "
|
|
496
|
-
"Pre-trends power analysis requires pre-period coefficients."
|
|
564
|
+
"Pre-trends power analysis requires pre-period coefficients. "
|
|
565
|
+
"If you estimated all periods as post_periods, use the pre_periods "
|
|
566
|
+
"parameter to specify which are actually pre-treatment."
|
|
497
567
|
)
|
|
498
568
|
|
|
499
569
|
# Only include periods with actual estimated coefficients
|
|
@@ -775,6 +845,7 @@ class PreTrendsPower:
|
|
|
775
845
|
self,
|
|
776
846
|
results: Union[MultiPeriodDiDResults, Any],
|
|
777
847
|
M: Optional[float] = None,
|
|
848
|
+
pre_periods: Optional[List[int]] = None,
|
|
778
849
|
) -> PreTrendsPowerResults:
|
|
779
850
|
"""
|
|
780
851
|
Compute pre-trends power analysis.
|
|
@@ -786,6 +857,11 @@ class PreTrendsPower:
|
|
|
786
857
|
M : float, optional
|
|
787
858
|
Specific violation magnitude to evaluate. If None, evaluates at
|
|
788
859
|
a default magnitude based on the data.
|
|
860
|
+
pre_periods : list of int, optional
|
|
861
|
+
Explicit list of pre-treatment periods to use for power analysis.
|
|
862
|
+
If None, attempts to infer from results.pre_periods. Use this when
|
|
863
|
+
you've estimated an event study with all periods in post_periods
|
|
864
|
+
and need to specify which are actually pre-treatment.
|
|
789
865
|
|
|
790
866
|
Returns
|
|
791
867
|
-------
|
|
@@ -793,7 +869,7 @@ class PreTrendsPower:
|
|
|
793
869
|
Power analysis results including power and MDV.
|
|
794
870
|
"""
|
|
795
871
|
# Extract pre-period parameters
|
|
796
|
-
effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
|
|
872
|
+
effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
|
|
797
873
|
|
|
798
874
|
# Get violation weights
|
|
799
875
|
weights = self._get_violation_weights(n_pre)
|
|
@@ -831,6 +907,7 @@ class PreTrendsPower:
|
|
|
831
907
|
self,
|
|
832
908
|
results: Union[MultiPeriodDiDResults, Any],
|
|
833
909
|
M: float,
|
|
910
|
+
pre_periods: Optional[List[int]] = None,
|
|
834
911
|
) -> float:
|
|
835
912
|
"""
|
|
836
913
|
Compute power to detect a specific violation magnitude.
|
|
@@ -841,13 +918,15 @@ class PreTrendsPower:
|
|
|
841
918
|
Event study results.
|
|
842
919
|
M : float
|
|
843
920
|
Violation magnitude.
|
|
921
|
+
pre_periods : list of int, optional
|
|
922
|
+
Explicit list of pre-treatment periods. See fit() for details.
|
|
844
923
|
|
|
845
924
|
Returns
|
|
846
925
|
-------
|
|
847
926
|
float
|
|
848
927
|
Power to detect violation of magnitude M.
|
|
849
928
|
"""
|
|
850
|
-
result = self.fit(results, M=M)
|
|
929
|
+
result = self.fit(results, M=M, pre_periods=pre_periods)
|
|
851
930
|
return result.power
|
|
852
931
|
|
|
853
932
|
def power_curve(
|
|
@@ -855,6 +934,7 @@ class PreTrendsPower:
|
|
|
855
934
|
results: Union[MultiPeriodDiDResults, Any],
|
|
856
935
|
M_grid: Optional[List[float]] = None,
|
|
857
936
|
n_points: int = 50,
|
|
937
|
+
pre_periods: Optional[List[int]] = None,
|
|
858
938
|
) -> PreTrendsPowerCurve:
|
|
859
939
|
"""
|
|
860
940
|
Compute power across a range of violation magnitudes.
|
|
@@ -868,6 +948,8 @@ class PreTrendsPower:
|
|
|
868
948
|
automatic grid from 0 to 2.5 * MDV.
|
|
869
949
|
n_points : int, default=50
|
|
870
950
|
Number of points in automatic grid.
|
|
951
|
+
pre_periods : list of int, optional
|
|
952
|
+
Explicit list of pre-treatment periods. See fit() for details.
|
|
871
953
|
|
|
872
954
|
Returns
|
|
873
955
|
-------
|
|
@@ -875,7 +957,7 @@ class PreTrendsPower:
|
|
|
875
957
|
Power curve data with plot method.
|
|
876
958
|
"""
|
|
877
959
|
# Extract parameters
|
|
878
|
-
|
|
960
|
+
_, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
|
|
879
961
|
weights = self._get_violation_weights(n_pre)
|
|
880
962
|
|
|
881
963
|
# Compute MDV
|
|
@@ -906,6 +988,7 @@ class PreTrendsPower:
|
|
|
906
988
|
def sensitivity_to_honest_did(
|
|
907
989
|
self,
|
|
908
990
|
results: Union[MultiPeriodDiDResults, Any],
|
|
991
|
+
pre_periods: Optional[List[int]] = None,
|
|
909
992
|
) -> Dict[str, Any]:
|
|
910
993
|
"""
|
|
911
994
|
Compare pre-trends power analysis with HonestDiD sensitivity.
|
|
@@ -917,6 +1000,8 @@ class PreTrendsPower:
|
|
|
917
1000
|
----------
|
|
918
1001
|
results : results object
|
|
919
1002
|
Event study results.
|
|
1003
|
+
pre_periods : list of int, optional
|
|
1004
|
+
Explicit list of pre-treatment periods. See fit() for details.
|
|
920
1005
|
|
|
921
1006
|
Returns
|
|
922
1007
|
-------
|
|
@@ -926,7 +1011,7 @@ class PreTrendsPower:
|
|
|
926
1011
|
- honest_M_at_mdv: Corresponding M value for HonestDiD
|
|
927
1012
|
- interpretation: Text explaining the relationship
|
|
928
1013
|
"""
|
|
929
|
-
pt_results = self.fit(results)
|
|
1014
|
+
pt_results = self.fit(results, pre_periods=pre_periods)
|
|
930
1015
|
mdv = pt_results.mdv
|
|
931
1016
|
|
|
932
1017
|
# The MDV represents the size of violation the test could detect
|
|
@@ -993,6 +1078,7 @@ def compute_pretrends_power(
|
|
|
993
1078
|
alpha: float = 0.05,
|
|
994
1079
|
target_power: float = 0.80,
|
|
995
1080
|
violation_type: str = "linear",
|
|
1081
|
+
pre_periods: Optional[List[int]] = None,
|
|
996
1082
|
) -> PreTrendsPowerResults:
|
|
997
1083
|
"""
|
|
998
1084
|
Convenience function for pre-trends power analysis.
|
|
@@ -1009,6 +1095,9 @@ def compute_pretrends_power(
|
|
|
1009
1095
|
Target power for MDV calculation.
|
|
1010
1096
|
violation_type : str, default='linear'
|
|
1011
1097
|
Type of violation pattern.
|
|
1098
|
+
pre_periods : list of int, optional
|
|
1099
|
+
Explicit list of pre-treatment periods. If None, attempts to infer
|
|
1100
|
+
from results. Use when you've estimated all periods as post_periods.
|
|
1012
1101
|
|
|
1013
1102
|
Returns
|
|
1014
1103
|
-------
|
|
@@ -1021,7 +1110,7 @@ def compute_pretrends_power(
|
|
|
1021
1110
|
>>> from diff_diff.pretrends import compute_pretrends_power
|
|
1022
1111
|
>>>
|
|
1023
1112
|
>>> results = MultiPeriodDiD().fit(data, ...)
|
|
1024
|
-
>>> power_results = compute_pretrends_power(results)
|
|
1113
|
+
>>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
|
|
1025
1114
|
>>> print(f"MDV: {power_results.mdv:.3f}")
|
|
1026
1115
|
>>> print(f"Power: {power_results.power:.1%}")
|
|
1027
1116
|
"""
|
|
@@ -1030,7 +1119,7 @@ def compute_pretrends_power(
|
|
|
1030
1119
|
power=target_power,
|
|
1031
1120
|
violation_type=violation_type,
|
|
1032
1121
|
)
|
|
1033
|
-
return pt.fit(results, M=M)
|
|
1122
|
+
return pt.fit(results, M=M, pre_periods=pre_periods)
|
|
1034
1123
|
|
|
1035
1124
|
|
|
1036
1125
|
def compute_mdv(
|
|
@@ -1038,6 +1127,7 @@ def compute_mdv(
|
|
|
1038
1127
|
alpha: float = 0.05,
|
|
1039
1128
|
target_power: float = 0.80,
|
|
1040
1129
|
violation_type: str = "linear",
|
|
1130
|
+
pre_periods: Optional[List[int]] = None,
|
|
1041
1131
|
) -> float:
|
|
1042
1132
|
"""
|
|
1043
1133
|
Compute minimum detectable violation.
|
|
@@ -1049,9 +1139,12 @@ def compute_mdv(
|
|
|
1049
1139
|
alpha : float, default=0.05
|
|
1050
1140
|
Significance level.
|
|
1051
1141
|
target_power : float, default=0.80
|
|
1052
|
-
Target power.
|
|
1142
|
+
Target power for MDV calculation.
|
|
1053
1143
|
violation_type : str, default='linear'
|
|
1054
1144
|
Type of violation pattern.
|
|
1145
|
+
pre_periods : list of int, optional
|
|
1146
|
+
Explicit list of pre-treatment periods. If None, attempts to infer
|
|
1147
|
+
from results. Use when you've estimated all periods as post_periods.
|
|
1055
1148
|
|
|
1056
1149
|
Returns
|
|
1057
1150
|
-------
|
|
@@ -1063,5 +1156,5 @@ def compute_mdv(
|
|
|
1063
1156
|
power=target_power,
|
|
1064
1157
|
violation_type=violation_type,
|
|
1065
1158
|
)
|
|
1066
|
-
result = pt.fit(results)
|
|
1159
|
+
result = pt.fit(results, pre_periods=pre_periods)
|
|
1067
1160
|
return result.mdv
|
|
@@ -1053,6 +1053,10 @@ class CallawaySantAnna:
|
|
|
1053
1053
|
df[time] = pd.to_numeric(df[time])
|
|
1054
1054
|
df[first_treat] = pd.to_numeric(df[first_treat])
|
|
1055
1055
|
|
|
1056
|
+
# Standardize the first_treat column name for internal use
|
|
1057
|
+
# This avoids hardcoding column names in internal methods
|
|
1058
|
+
df['first_treat'] = df[first_treat]
|
|
1059
|
+
|
|
1056
1060
|
# Identify groups and time periods
|
|
1057
1061
|
time_periods = sorted(df[time].unique())
|
|
1058
1062
|
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
|