diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
|
@@ -0,0 +1,1545 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Staggered Triple Difference (DDD) estimator.
|
|
3
|
+
|
|
4
|
+
Implements Ortiz-Villavicencio & Sant'Anna (2025) for staggered adoption
|
|
5
|
+
settings with an eligibility dimension, combining group-time DDD effects
|
|
6
|
+
via GMM-optimal weighting.
|
|
7
|
+
|
|
8
|
+
Core pairwise DiD computation matches R's triplediff::compute_did() exactly
|
|
9
|
+
(Riesz/Hajek normalization, separate M1/M3 OR corrections, hessian = (X'WX)^{-1}*n).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import warnings
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
from diff_diff.linalg import (
|
|
19
|
+
_check_propensity_diagnostics,
|
|
20
|
+
solve_logit,
|
|
21
|
+
)
|
|
22
|
+
from diff_diff.staggered_aggregation import (
|
|
23
|
+
CallawaySantAnnaAggregationMixin,
|
|
24
|
+
)
|
|
25
|
+
from diff_diff.staggered_bootstrap import (
|
|
26
|
+
CallawaySantAnnaBootstrapMixin,
|
|
27
|
+
)
|
|
28
|
+
from diff_diff.staggered_triple_diff_results import StaggeredTripleDiffResults
|
|
29
|
+
from diff_diff.utils import safe_inference
|
|
30
|
+
|
|
31
|
+
__all__ = [
|
|
32
|
+
"StaggeredTripleDifference",
|
|
33
|
+
"StaggeredTripleDiffResults",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
# Type alias for pre-computed structures
|
|
37
|
+
PrecomputedData = Dict[str, Any]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class StaggeredTripleDifference(
|
|
41
|
+
CallawaySantAnnaBootstrapMixin,
|
|
42
|
+
CallawaySantAnnaAggregationMixin,
|
|
43
|
+
):
|
|
44
|
+
"""
|
|
45
|
+
Staggered Triple Difference (DDD) estimator.
|
|
46
|
+
|
|
47
|
+
Computes group-time average treatment effects ATT(g,t) for settings
|
|
48
|
+
with staggered adoption and a binary eligibility dimension, using the
|
|
49
|
+
three-DiD decomposition of Ortiz-Villavicencio & Sant'Anna (2025).
|
|
50
|
+
|
|
51
|
+
Multiple comparison groups are combined via GMM-optimal (inverse-variance)
|
|
52
|
+
weighting. Event study, group, and overall aggregations are supported.
|
|
53
|
+
|
|
54
|
+
Parameters
|
|
55
|
+
----------
|
|
56
|
+
estimation_method : str, default="dr"
|
|
57
|
+
Estimation method: "dr" (doubly robust), "ipw" (inverse probability
|
|
58
|
+
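weighting), or "reg" (regression adjustment).
control_group : str, default="notyettreated"
    Comparison cohorts: "nevertreated" (only the never-enabled cohort) or
    "notyettreated" (never-enabled plus not-yet-treated cohorts).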
|
|
59
|
+
alpha : float, default=0.05
|
|
60
|
+
Significance level.
|
|
61
|
+
anticipation : int, default=0
|
|
62
|
+
Number of anticipation periods.
|
|
63
|
+
base_period : str, default="varying"
|
|
64
|
+
Base period selection: "varying" (consecutive comparisons) or
|
|
65
|
+
"universal" (always vs g-1-anticipation).
|
|
66
|
+
n_bootstrap : int, default=0
|
|
67
|
+
Number of multiplier bootstrap repetitions. 0 disables bootstrap.
|
|
68
|
+
bootstrap_weights : str, default="rademacher"
|
|
69
|
+
Bootstrap weight distribution: "rademacher", "mammen", or "webb".
|
|
70
|
+
seed : int or None, default=None
|
|
71
|
+
Random seed for reproducibility.
|
|
72
|
+
cband : bool, default=True
|
|
73
|
+
Whether to compute simultaneous confidence bands.
|
|
74
|
+
pscore_trim : float, default=0.01
|
|
75
|
+
Propensity score trimming bound.
|
|
76
|
+
cluster : str or None, default=None
|
|
77
|
+
Column name for cluster-robust standard errors.
|
|
78
|
+
rank_deficient_action : str, default="warn"
|
|
79
|
+
Action for rank-deficient design matrices: "warn", "error", "silent".
|
|
80
|
+
epv_threshold : float, default=10
|
|
81
|
+
Minimum events per variable for propensity score logistic regression.
|
|
82
|
+
A warning is emitted when EPV falls below this threshold.
|
|
83
|
+
pscore_fallback : str, default="error"
|
|
84
|
+
Action when propensity score estimation fails: "error" (raise) or
|
|
85
|
+
"unconditional" (fall back to unconditional propensity).
|
|
86
|
+
|
|
87
|
+
References
|
|
88
|
+
----------
|
|
89
|
+
Ortiz-Villavicencio, M. & Sant'Anna, P.H.C. (2025). "Better Understanding
|
|
90
|
+
Triple Differences Estimators." arXiv:2505.09942.
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
def __init__(
|
|
94
|
+
self,
|
|
95
|
+
estimation_method: str = "dr",
|
|
96
|
+
control_group: str = "notyettreated",
|
|
97
|
+
alpha: float = 0.05,
|
|
98
|
+
anticipation: int = 0,
|
|
99
|
+
base_period: str = "varying",
|
|
100
|
+
n_bootstrap: int = 0,
|
|
101
|
+
bootstrap_weights: str = "rademacher",
|
|
102
|
+
seed: Optional[int] = None,
|
|
103
|
+
cband: bool = True,
|
|
104
|
+
pscore_trim: float = 0.01,
|
|
105
|
+
cluster: Optional[str] = None,
|
|
106
|
+
rank_deficient_action: str = "warn",
|
|
107
|
+
epv_threshold: float = 10,
|
|
108
|
+
pscore_fallback: str = "error",
|
|
109
|
+
):
|
|
110
|
+
if estimation_method not in ["dr", "ipw", "reg"]:
|
|
111
|
+
raise ValueError(
|
|
112
|
+
f"estimation_method must be 'dr', 'ipw', or 'reg', " f"got '{estimation_method}'"
|
|
113
|
+
)
|
|
114
|
+
if control_group not in ["nevertreated", "notyettreated"]:
|
|
115
|
+
raise ValueError(
|
|
116
|
+
f"control_group must be 'nevertreated' or 'notyettreated', "
|
|
117
|
+
f"got '{control_group}'"
|
|
118
|
+
)
|
|
119
|
+
if not (0 < pscore_trim < 0.5):
|
|
120
|
+
raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}")
|
|
121
|
+
if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
|
|
122
|
+
raise ValueError(
|
|
123
|
+
f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
|
|
124
|
+
f"got '{bootstrap_weights}'"
|
|
125
|
+
)
|
|
126
|
+
if rank_deficient_action not in ["warn", "error", "silent"]:
|
|
127
|
+
raise ValueError(
|
|
128
|
+
f"rank_deficient_action must be 'warn', 'error', or 'silent', "
|
|
129
|
+
f"got '{rank_deficient_action}'"
|
|
130
|
+
)
|
|
131
|
+
if base_period not in ["varying", "universal"]:
|
|
132
|
+
raise ValueError(
|
|
133
|
+
f"base_period must be 'varying' or 'universal', " f"got '{base_period}'"
|
|
134
|
+
)
|
|
135
|
+
if epv_threshold <= 0:
|
|
136
|
+
raise ValueError(f"epv_threshold must be > 0, got {epv_threshold}")
|
|
137
|
+
if pscore_fallback not in ["error", "unconditional"]:
|
|
138
|
+
raise ValueError(
|
|
139
|
+
f"pscore_fallback must be 'error' or 'unconditional', "
|
|
140
|
+
f"got '{pscore_fallback}'"
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
self.estimation_method = estimation_method
|
|
144
|
+
self.control_group = control_group
|
|
145
|
+
self.alpha = alpha
|
|
146
|
+
self.anticipation = anticipation
|
|
147
|
+
self.base_period = base_period
|
|
148
|
+
self.n_bootstrap = n_bootstrap
|
|
149
|
+
self.bootstrap_weights = bootstrap_weights
|
|
150
|
+
self.seed = seed
|
|
151
|
+
self.cband = cband
|
|
152
|
+
self.pscore_trim = pscore_trim
|
|
153
|
+
self.cluster = cluster
|
|
154
|
+
self.rank_deficient_action = rank_deficient_action
|
|
155
|
+
self.epv_threshold = epv_threshold
|
|
156
|
+
self.pscore_fallback = pscore_fallback
|
|
157
|
+
|
|
158
|
+
self.is_fitted_ = False
|
|
159
|
+
self.results_: Optional[StaggeredTripleDiffResults] = None
|
|
160
|
+
|
|
161
|
+
def get_params(self) -> Dict[str, Any]:
    """Return the constructor parameters as a name -> value dict (sklearn-compatible)."""
    # Listed in constructor order; the returned dict keeps this order.
    param_names = (
        "estimation_method",
        "control_group",
        "alpha",
        "anticipation",
        "base_period",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "cband",
        "pscore_trim",
        "cluster",
        "rank_deficient_action",
        "epv_threshold",
        "pscore_fallback",
    )
    return {name: getattr(self, name) for name in param_names}
|
|
179
|
+
|
|
180
|
+
def set_params(self, **params) -> "StaggeredTripleDifference":
    """
    Set estimator parameters (sklearn-compatible).

    All names are checked against ``get_params()`` before anything is
    assigned, so an unknown name never leaves the estimator partially
    updated. Values themselves are not re-validated here.

    Parameters
    ----------
    **params
        Parameter names and their new values.

    Returns
    -------
    StaggeredTripleDifference
        ``self``, to allow call chaining.

    Raises
    ------
    ValueError
        If any name is not a parameter reported by ``get_params()``.
    """
    valid_params = self.get_params()
    # Reject the first unknown name (in call order) before mutating state.
    for key in params:
        if key not in valid_params:
            raise ValueError(f"Unknown parameter: {key}")
    for key, value in params.items():
        setattr(self, key, value)
    return self
|
|
190
|
+
|
|
191
|
+
# ------------------------------------------------------------------
|
|
192
|
+
# fit()
|
|
193
|
+
# ------------------------------------------------------------------
|
|
194
|
+
|
|
195
|
+
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    eligibility: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    survey_design: object = None,
) -> StaggeredTripleDiffResults:
    """
    Fit the staggered triple difference estimator.

    For every (cohort g, period t) cell, one DDD estimate is computed per
    valid comparison cohort, the estimates are combined via
    ``_combine_gmm``, and the combined influence function is stored for
    the aggregation and bootstrap mixins.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    first_treat : str
        Column with the enabling period for each unit's group.
        Use 0 or np.inf for never-enabled units.
    eligibility : str
        Binary eligibility indicator column (0/1, time-invariant).
    covariates : list of str, optional
        Covariate column names.
    aggregate : str, optional
        Aggregation method: "event_study", "group", "simple", or "all".
    balance_e : int, optional
        Event time to balance on for event study.
    survey_design : SurveyDesign, optional
        Survey design specification for complex survey data. When
        provided, uses survey weights for estimation (weighted Riesz
        representers, weighted logit, weighted OLS) and design-based
        variance for aggregated SEs (overall, event study, group) via
        Taylor Series Linearization or replicate weights. Requires
        ``weight_type='pweight'``.

    Returns
    -------
    StaggeredTripleDiffResults

    Raises
    ------
    ValueError
        If the survey weight type is not ``'pweight'``, if ``aggregate``
        is not a recognised option, or if no group-time effect could be
        computed.
    NotImplementedError
        If ``n_bootstrap > 0`` is combined with a replicate-weight
        survey design.
    """
    from diff_diff.survey import (
        _resolve_survey_for_fit,
        _validate_unit_constant_survey,
        compute_survey_metadata,
    )

    # Only the resolved design and its metadata are used below.
    resolved_survey, _, _, survey_metadata = _resolve_survey_for_fit(
        survey_design, data, "analytical"
    )

    if resolved_survey is not None:
        _validate_unit_constant_survey(data, unit, survey_design)
        if resolved_survey.weight_type != "pweight":
            raise ValueError(
                f"StaggeredTripleDifference survey support requires "
                f"weight_type='pweight', got '{resolved_survey.weight_type}'. "
                f"The survey variance math assumes probability weights."
            )
    if aggregate is not None and aggregate not in [
        "event_study",
        "group",
        "simple",
        "all",
    ]:
        raise ValueError(
            f"aggregate must be 'event_study', 'group', 'simple', or 'all', "
            f"got '{aggregate}'"
        )

    df = data.copy()
    self._validate_inputs(df, outcome, unit, time, first_treat, eligibility, covariates)

    if self.cluster is not None:
        warnings.warn(
            "cluster parameter is accepted but cluster-robust analytical SEs "
            "are not yet implemented for staggered DDD. Use n_bootstrap > 0 "
            "for unit-level clustered inference via multiplier bootstrap.",
            UserWarning,
            stacklevel=2,
        )

    uses_replicate_variance = bool(getattr(resolved_survey, "uses_replicate_variance", False))

    # Reject replicate-weight designs for bootstrap — replicate variance
    # is an analytical alternative, not compatible with bootstrap.
    # Checked up front so no estimation work is wasted before failing.
    if self.n_bootstrap > 0 and uses_replicate_variance:
        raise NotImplementedError(
            "StaggeredTripleDifference bootstrap (n_bootstrap > 0) is not "
            "supported with replicate-weight survey designs. Replicate "
            "weights provide analytical variance; use n_bootstrap=0 instead."
        )

    # Normalise the cohort column name and encode never-enabled as 0.
    if first_treat != "first_treat":
        df["first_treat"] = df[first_treat]
    df["first_treat"] = df["first_treat"].replace([np.inf, float("inf")], 0)

    precomputed = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        eligibility,
        covariates,
        resolved_survey=resolved_survey,
    )

    # Recompute survey metadata from unit-level resolved survey
    if resolved_survey is not None and survey_metadata is not None:
        resolved_survey_unit = precomputed.get("resolved_survey_unit")
        if resolved_survey_unit is not None:
            unit_w = resolved_survey_unit.weights
            survey_metadata = compute_survey_metadata(resolved_survey_unit, unit_w)

    # Survey df for t-distribution critical values
    df_survey = precomputed.get("df_survey")
    if df_survey is None and uses_replicate_variance:
        df_survey = 0  # Forces NaN inference for undefined replicate df

    has_survey = resolved_survey is not None

    treatment_groups = precomputed["treatment_groups"]
    time_periods = precomputed["time_periods"]
    all_units = precomputed["all_units"]
    time_to_col = precomputed["time_to_col"]
    unit_cohorts = precomputed["unit_cohorts"]
    eligibility_per_unit = precomputed["eligibility_per_unit"]
    n_units = len(all_units)

    pscore_cache: Dict = {}
    # Skip Cholesky OR cache when survey weights present (X'WX != X'X)
    cho_cache: Optional[Dict] = {} if not has_survey else None

    group_time_effects: Dict[Tuple, Dict[str, Any]] = {}
    influence_func_info: Dict[Tuple, Dict[str, Any]] = {}
    comparison_group_counts: Dict[Tuple, int] = {}
    gmm_weights_store: Dict[Tuple, Dict] = {}
    epv_diagnostics: Optional[Dict[Tuple, Dict[str, Any]]] = (
        {} if (covariates and self.estimation_method in ("ipw", "dr")) else None
    )

    # Same for every (g, t) cell, so it is computed once outside the loops.
    has_never_enabled = bool(np.any(unit_cohorts == 0))

    for g in treatment_groups:
        # In universal mode, skip the reference period (t == g-1-anticipation)
        # so it's omitted from GT estimation. The event-study mixin injects
        # a synthetic reference row with effect=0, matching CS behavior.
        if self.base_period == "universal":
            universal_base = g - 1 - self.anticipation
            valid_periods = [t for t in time_periods if t != universal_base]
        else:
            valid_periods = time_periods

        for t in valid_periods:
            base_period_val = self._get_base_period(g, t)
            if base_period_val is None:
                continue
            if base_period_val not in time_to_col:
                warnings.warn(
                    f"Base period {base_period_val} for (g={g}, t={t}) is "
                    "outside the observed panel. Skipping this cell.",
                    UserWarning,
                    stacklevel=2,
                )
                continue
            if t not in time_to_col:
                continue

            if self.control_group == "nevertreated":
                # Only use never-enabled cohort as comparison
                valid_gc = [0] if has_never_enabled else []
            else:
                # Use all valid comparison cohorts (not-yet-treated + never)
                # Threshold accounts for anticipation: cohorts that start
                # treatment within the anticipation window are contaminated.
                nyt_threshold = max(t, base_period_val) + self.anticipation
                valid_gc = [gc for gc in treatment_groups if gc > nyt_threshold and gc != g]
                if has_never_enabled:
                    valid_gc = [0] + valid_gc

            if not valid_gc:
                warnings.warn(
                    f"No valid comparison groups for (g={g}, t={t}), skipping.",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            treated_mask = (unit_cohorts == g) & (eligibility_per_unit == 1)
            n_treated = int(np.sum(treated_mask))
            if n_treated == 0:
                continue

            att_vec = []
            inf_raw = []  # unrescaled IFs
            gc_labels = []
            gc_cell_sizes = []  # size_gt_ctrl per surviving gc

            for gc in valid_gc:
                result = self._compute_ddd_gt_gc(
                    precomputed,
                    g,
                    gc,
                    t,
                    base_period_val,
                    covariates,
                    pscore_cache,
                    cho_cache,
                    epv_diagnostics=epv_diagnostics,
                )
                if result is None:
                    continue
                att_gc, inf_gc, size_gt_ctrl = result
                if not np.isfinite(att_gc):
                    continue

                att_vec.append(att_gc)
                inf_raw.append(inf_gc)
                gc_labels.append(gc)
                gc_cell_sizes.append(size_gt_ctrl)

            if not att_vec:
                continue

            # Compute size_gt from SURVIVING comparison cohorts only
            # (not from all initially valid gc's)
            surviving_units = treated_mask.copy()
            for gc in gc_labels:
                surviving_units |= (unit_cohorts == gc) | (unit_cohorts == g)
            survey_w = precomputed.get("survey_weights")
            if survey_w is not None:
                size_gt = float(np.sum(survey_w[surviving_units]))
            else:
                size_gt = float(np.sum(surviving_units))

            # Apply IF rescaling now that size_gt is known
            inf_matrix = []
            for inf_gc, size_gt_ctrl in zip(inf_raw, gc_cell_sizes):
                if size_gt_ctrl > 0:
                    inf_gc = inf_gc * (size_gt / size_gt_ctrl)
                inf_matrix.append(inf_gc)

            att_gmm, inf_gmm, gmm_w, se_gt = self._combine_gmm(
                np.array(att_vec),
                np.array(inf_matrix),
                n_units,
            )

            if not np.isfinite(att_gmm):
                continue

            # R's single-gc SE uses size_gt in denominator, not n_total.
            # For multi-gc (GMM), the size_gt factor is already in Omega
            # via the per-gc rescaling, so n_total is correct.
            if len(gc_labels) == 1:
                se_gt = float(np.sqrt(np.sum(inf_gmm**2) / size_gt**2))

            if not np.isfinite(se_gt) or se_gt <= 0:
                se_gt = np.nan

            t_stat, p_value, conf_int = safe_inference(
                att_gmm, se_gt, alpha=self.alpha, df=df_survey
            )

            # Rescale IF for mixin compatibility.
            # R stores IF * (n/size_gt) in inf_func_mat, then uses
            # SE = sqrt(sum(IF^2)/n^2) = sqrt(sum(psi^2)) with psi = IF/n.
            # We need psi = IF_rescaled / n so mixin's sqrt(sum(psi^2)) works.
            # IF is already at size_gt/size_gt_ctrl scale from above.
            # Apply the final n/size_gt factor, then divide by n for mixin.
            inf_gmm_rescaled = inf_gmm * (n_units / size_gt)
            inf_gmm_scaled = inf_gmm_rescaled / n_units

            treated_idx = np.where(treated_mask)[0]
            treated_inf = inf_gmm_scaled[treated_idx]
            # Controls are the non-treated units with a non-zero IF entry.
            nonzero_mask = (inf_gmm_scaled != 0) & ~treated_mask
            control_idx = np.where(nonzero_mask)[0]
            control_inf = inf_gmm_scaled[control_idx]
            n_control = int(np.sum(nonzero_mask))

            group_time_effects[(g, t)] = {
                "effect": att_gmm,
                "se": se_gt,
                "t_stat": t_stat,
                "p_value": p_value,
                "conf_int": conf_int,
                "n_treated": n_treated,
                "n_control": n_control,
            }
            influence_func_info[(g, t)] = {
                "treated_idx": treated_idx,
                "control_idx": control_idx,
                "treated_units": all_units[treated_idx],
                "control_units": all_units[control_idx],
                "treated_inf": treated_inf,
                "control_inf": control_inf,
            }
            comparison_group_counts[(g, t)] = len(gc_labels)
            gmm_weights_store[(g, t)] = dict(zip(gc_labels, gmm_w.tolist()))

    # Consolidated EPV summary warning
    if epv_diagnostics:
        low_epv = {k: v for k, v in epv_diagnostics.items() if v.get("is_low")}
        if low_epv:
            n_affected = len(low_epv)
            n_total = len(epv_diagnostics)
            # One pass yields both the worst cell's key and its diagnostics.
            min_g, min_entry = min(low_epv.items(), key=lambda item: item[1]["epv"])
            warnings.warn(
                f"Low Events Per Variable (EPV) detected in "
                f"{n_affected} of {n_total} cohort-time cell(s). "
                f"Minimum EPV: {min_entry['epv']:.1f} (cohort g={min_g[0]}). "
                f"Consider estimation_method='reg' or fewer covariates. "
                f"Call results.epv_summary() for per-cohort details.",
                UserWarning,
                stacklevel=2,
            )

    if not group_time_effects:
        raise ValueError(
            "No valid group-time effects could be computed. "
            "Check that the data has sufficient variation in treatment "
            "timing and eligibility."
        )

    # For aggregation: use eligible-treated-only cohort assignments so
    # WIF weights match the point estimate weights (n_treated per cohort,
    # i.e. P(S=g, Q=1)). This matches the paper's Eq 4.13 which defines
    # aggregation weights over the treated population (G_i defined only
    # for Q=1 units). Ineligible units get cohort=0 so they don't
    # contribute to pg for any treatment group.
    # Both precomputed["unit_cohorts"] AND df["first_treat"] must be
    # zeroed for ineligible units because the WIF code reads both.
    precomputed_agg = dict(precomputed)
    cohorts_for_agg = precomputed["unit_cohorts"].copy()
    cohorts_for_agg[eligibility_per_unit == 0] = 0
    precomputed_agg["unit_cohorts"] = cohorts_for_agg

    df_agg = df.copy()
    df_agg.loc[df_agg[eligibility] == 0, "first_treat"] = 0

    # Overall ATT via aggregation mixin
    overall_att, overall_se, overall_effective_df = self._aggregate_simple(
        group_time_effects, influence_func_info, df_agg, unit, precomputed_agg
    )
    # Use per-statistic effective df from replicate aggregation if available;
    # otherwise fall back to the original df from the survey design.
    if overall_effective_df is not None:
        df_survey = overall_effective_df
        if survey_metadata is not None:
            survey_metadata.df_survey = df_survey
    overall_t_stat, overall_p_value, overall_conf_int = safe_inference(
        overall_att, overall_se, alpha=self.alpha, df=df_survey
    )

    # Aggregations
    event_study_effects = None
    group_effects = None
    if aggregate in ("event_study", "all"):
        event_study_effects = self._aggregate_event_study(
            group_time_effects,
            influence_func_info,
            treatment_groups,
            time_periods,
            balance_e,
            df_agg,
            unit,
            precomputed_agg,
        )
    if aggregate in ("group", "all"):
        group_effects = self._aggregate_by_group(
            group_time_effects,
            influence_func_info,
            treatment_groups,
            precomputed_agg,
            df_agg,
            unit,
        )

    # Bootstrap
    bootstrap_results = None
    cband_crit_value = None
    if self.n_bootstrap > 0:
        bootstrap_results = self._run_multiplier_bootstrap(
            group_time_effects,
            influence_func_info,
            aggregate,
            balance_e,
            treatment_groups,
            time_periods,
            df_agg,
            unit,
            precomputed_agg,
            self.cband,
        )
        if bootstrap_results is not None:
            # Bootstrap SE replaces the analytical one; the t-stat is
            # recomputed from it, while CI and p-value come straight
            # from the bootstrap results.
            overall_se = bootstrap_results.overall_att_se
            overall_t_stat, overall_p_value, overall_conf_int = safe_inference(
                overall_att, overall_se, alpha=self.alpha, df=df_survey
            )
            overall_conf_int = bootstrap_results.overall_att_ci
            overall_p_value = bootstrap_results.overall_att_p_value
            if bootstrap_results.cband_crit_value is not None:
                cband_crit_value = bootstrap_results.cband_crit_value

            # Update group-time effects with bootstrap SEs
            if bootstrap_results.group_time_ses:
                for gt_key in group_time_effects:
                    if gt_key in bootstrap_results.group_time_ses:
                        group_time_effects[gt_key]["se"] = bootstrap_results.group_time_ses[
                            gt_key
                        ]
                        group_time_effects[gt_key]["conf_int"] = (
                            bootstrap_results.group_time_cis[gt_key]
                        )
                        group_time_effects[gt_key]["p_value"] = (
                            bootstrap_results.group_time_p_values[gt_key]
                        )
                        t_val, _, _ = safe_inference(
                            group_time_effects[gt_key]["effect"],
                            bootstrap_results.group_time_ses[gt_key],
                            alpha=self.alpha,
                            df=df_survey,
                        )
                        group_time_effects[gt_key]["t_stat"] = t_val

            # Update event-study effects with bootstrap SEs
            if event_study_effects and bootstrap_results.event_study_ses:
                for e_key in event_study_effects:
                    if e_key in bootstrap_results.event_study_ses:
                        event_study_effects[e_key]["se"] = bootstrap_results.event_study_ses[
                            e_key
                        ]
                        event_study_effects[e_key]["conf_int"] = (
                            bootstrap_results.event_study_cis[e_key]
                        )
                        event_study_effects[e_key]["p_value"] = (
                            bootstrap_results.event_study_p_values[e_key]
                        )
                        t_val, _, _ = safe_inference(
                            event_study_effects[e_key]["effect"],
                            bootstrap_results.event_study_ses[e_key],
                            alpha=self.alpha,
                            df=df_survey,
                        )
                        event_study_effects[e_key]["t_stat"] = t_val
                        if cband_crit_value is not None:
                            bs_se = bootstrap_results.event_study_ses[e_key]
                            eff = event_study_effects[e_key]["effect"]
                            event_study_effects[e_key]["cband_conf_int"] = (
                                eff - cband_crit_value * bs_se,
                                eff + cband_crit_value * bs_se,
                            )

            # Update group effects with bootstrap SEs
            if (
                group_effects
                and bootstrap_results.group_effect_ses is not None
                and bootstrap_results.group_effect_cis is not None
                and bootstrap_results.group_effect_p_values is not None
            ):
                grp_keys = [g for g in group_effects if g in bootstrap_results.group_effect_ses]
                for g_key in grp_keys:
                    group_effects[g_key]["se"] = bootstrap_results.group_effect_ses[g_key]
                    group_effects[g_key]["conf_int"] = bootstrap_results.group_effect_cis[g_key]
                    group_effects[g_key]["p_value"] = bootstrap_results.group_effect_p_values[
                        g_key
                    ]
                    t_val, _, _ = safe_inference(
                        group_effects[g_key]["effect"],
                        bootstrap_results.group_effect_ses[g_key],
                        alpha=self.alpha,
                        df=df_survey,
                    )
                    group_effects[g_key]["t_stat"] = t_val

    # Unit-level sample counts reported on the results object.
    n_treated_units = int(np.sum((unit_cohorts > 0) & (eligibility_per_unit == 1)))
    n_control_units = n_units - n_treated_units
    n_never_enabled = int(np.sum(unit_cohorts == 0))
    n_eligible = int(np.sum(eligibility_per_unit == 1))
    n_ineligible = int(np.sum(eligibility_per_unit == 0))

    self.results_ = StaggeredTripleDiffResults(
        group_time_effects=group_time_effects,
        overall_att=overall_att,
        overall_se=overall_se,
        overall_t_stat=overall_t_stat,
        overall_p_value=overall_p_value,
        overall_conf_int=overall_conf_int,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=n_treated_units,
        n_control_units=n_control_units,
        n_never_enabled=n_never_enabled,
        n_eligible=n_eligible,
        n_ineligible=n_ineligible,
        alpha=self.alpha,
        control_group=self.control_group,
        base_period=self.base_period,
        estimation_method=self.estimation_method,
        event_study_effects=event_study_effects,
        group_effects=group_effects,
        bootstrap_results=bootstrap_results,
        cband_crit_value=cband_crit_value,
        pscore_trim=self.pscore_trim,
        survey_metadata=survey_metadata,
        comparison_group_counts=comparison_group_counts,
        gmm_weights=gmm_weights_store,
        epv_diagnostics=epv_diagnostics if epv_diagnostics else None,
        epv_threshold=self.epv_threshold,
        pscore_fallback=self.pscore_fallback,
    )
    self.is_fitted_ = True
    return self.results_
|
|
725
|
+
|
|
726
|
+
# ------------------------------------------------------------------
|
|
727
|
+
# Validation
|
|
728
|
+
# ------------------------------------------------------------------
|
|
729
|
+
|
|
730
|
+
def _validate_inputs(
|
|
731
|
+
self,
|
|
732
|
+
df: pd.DataFrame,
|
|
733
|
+
outcome: str,
|
|
734
|
+
unit: str,
|
|
735
|
+
time: str,
|
|
736
|
+
first_treat: str,
|
|
737
|
+
eligibility: str,
|
|
738
|
+
covariates: Optional[List[str]],
|
|
739
|
+
) -> None:
|
|
740
|
+
"""Validate input data."""
|
|
741
|
+
required_cols = [outcome, unit, time, first_treat, eligibility]
|
|
742
|
+
if covariates:
|
|
743
|
+
required_cols.extend(covariates)
|
|
744
|
+
missing = [c for c in required_cols if c not in df.columns]
|
|
745
|
+
if missing:
|
|
746
|
+
raise ValueError(f"Missing columns: {missing}")
|
|
747
|
+
|
|
748
|
+
elig_vals = df[eligibility].dropna().unique()
|
|
749
|
+
if not set(elig_vals).issubset({0, 1, 0.0, 1.0}):
|
|
750
|
+
raise ValueError(
|
|
751
|
+
f"Eligibility column '{eligibility}' must be binary (0/1). "
|
|
752
|
+
f"Found values: {sorted(elig_vals)}"
|
|
753
|
+
)
|
|
754
|
+
elig_by_unit = df.groupby(unit)[eligibility].nunique()
|
|
755
|
+
varying = elig_by_unit[elig_by_unit > 1]
|
|
756
|
+
if len(varying) > 0:
|
|
757
|
+
raise ValueError(
|
|
758
|
+
f"Eligibility must be time-invariant within units. "
|
|
759
|
+
f"Found {len(varying)} units with varying eligibility."
|
|
760
|
+
)
|
|
761
|
+
for col in [outcome, first_treat, eligibility]:
|
|
762
|
+
if df[col].isna().any():
|
|
763
|
+
raise ValueError(f"Column '{col}' contains missing values.")
|
|
764
|
+
|
|
765
|
+
# Reject non-finite outcomes (Inf/-Inf)
|
|
766
|
+
if not np.all(np.isfinite(df[outcome])):
|
|
767
|
+
raise ValueError(
|
|
768
|
+
f"Column '{outcome}' contains non-finite values (Inf/-Inf). "
|
|
769
|
+
"All outcome values must be finite."
|
|
770
|
+
)
|
|
771
|
+
|
|
772
|
+
# Reject non-finite covariates
|
|
773
|
+
if covariates:
|
|
774
|
+
for cov in covariates:
|
|
775
|
+
if df[cov].isna().any():
|
|
776
|
+
raise ValueError(f"Covariate '{cov}' contains missing values.")
|
|
777
|
+
if not np.all(np.isfinite(df[cov])):
|
|
778
|
+
raise ValueError(f"Covariate '{cov}' contains non-finite values.")
|
|
779
|
+
if df[eligibility].nunique() < 2:
|
|
780
|
+
raise ValueError(
|
|
781
|
+
"Need both eligible (Q=1) and ineligible (Q=0) units. "
|
|
782
|
+
f"Only found Q={df[eligibility].unique()[0]}."
|
|
783
|
+
)
|
|
784
|
+
|
|
785
|
+
# Check unique (unit, time) pairs — no duplicate rows
|
|
786
|
+
dup = df.duplicated(subset=[unit, time], keep=False)
|
|
787
|
+
if dup.any():
|
|
788
|
+
raise ValueError(
|
|
789
|
+
f"Duplicate (unit, time) rows found. "
|
|
790
|
+
f"{int(dup.sum())} duplicates detected. Panel must have unique rows."
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
# Check balanced panel — every unit observed in exactly the global period set
|
|
794
|
+
global_periods = set(df[time].unique())
|
|
795
|
+
n_global_periods = len(global_periods)
|
|
796
|
+
unit_period_sets = df.groupby(unit)[time].apply(set)
|
|
797
|
+
mismatched = unit_period_sets[unit_period_sets != global_periods]
|
|
798
|
+
if len(mismatched) > 0:
|
|
799
|
+
raise ValueError(
|
|
800
|
+
"Unbalanced panel detected. All units must be observed in "
|
|
801
|
+
f"all {n_global_periods} periods. "
|
|
802
|
+
f"Found {len(mismatched)} units with different period sets."
|
|
803
|
+
)
|
|
804
|
+
|
|
805
|
+
# Check time-invariant first_treat
|
|
806
|
+
ft_by_unit = df.groupby(unit)[first_treat].nunique()
|
|
807
|
+
varying_ft = ft_by_unit[ft_by_unit > 1]
|
|
808
|
+
if len(varying_ft) > 0:
|
|
809
|
+
raise ValueError(
|
|
810
|
+
f"first_treat must be time-invariant within units. "
|
|
811
|
+
f"Found {len(varying_ft)} units with varying first_treat."
|
|
812
|
+
)
|
|
813
|
+
|
|
814
|
+
# Check time-invariant covariates
|
|
815
|
+
if covariates:
|
|
816
|
+
for cov in covariates:
|
|
817
|
+
cov_nunique = df.groupby(unit)[cov].nunique()
|
|
818
|
+
varying_cov = cov_nunique[cov_nunique > 1]
|
|
819
|
+
if len(varying_cov) > 0:
|
|
820
|
+
raise ValueError(
|
|
821
|
+
f"Covariate '{cov}' must be time-invariant within units. "
|
|
822
|
+
f"Found {len(varying_cov)} units with varying values."
|
|
823
|
+
)
|
|
824
|
+
|
|
825
|
+
# ------------------------------------------------------------------
|
|
826
|
+
# Precomputation
|
|
827
|
+
# ------------------------------------------------------------------
|
|
828
|
+
|
|
829
|
+
def _precompute_structures(
|
|
830
|
+
self,
|
|
831
|
+
df: pd.DataFrame,
|
|
832
|
+
outcome: str,
|
|
833
|
+
unit: str,
|
|
834
|
+
time: str,
|
|
835
|
+
eligibility: str,
|
|
836
|
+
covariates: Optional[List[str]],
|
|
837
|
+
resolved_survey=None,
|
|
838
|
+
) -> PrecomputedData:
|
|
839
|
+
"""Build precomputed structures for efficient computation."""
|
|
840
|
+
all_units = np.array(sorted(df[unit].unique()))
|
|
841
|
+
time_periods = sorted(df[time].unique())
|
|
842
|
+
n_units = len(all_units)
|
|
843
|
+
n_periods = len(time_periods)
|
|
844
|
+
|
|
845
|
+
unit_to_idx = {u: i for i, u in enumerate(all_units)}
|
|
846
|
+
time_to_col = {t: j for j, t in enumerate(time_periods)}
|
|
847
|
+
|
|
848
|
+
outcome_matrix = np.full((n_units, n_periods), np.nan)
|
|
849
|
+
for _, row in df.iterrows():
|
|
850
|
+
u_idx = unit_to_idx[row[unit]]
|
|
851
|
+
t_idx = time_to_col[row[time]]
|
|
852
|
+
outcome_matrix[u_idx, t_idx] = row[outcome]
|
|
853
|
+
|
|
854
|
+
unit_df = df.groupby(unit).first().reindex(all_units)
|
|
855
|
+
unit_cohorts = unit_df["first_treat"].values.astype(float)
|
|
856
|
+
eligibility_per_unit = unit_df[eligibility].values.astype(int)
|
|
857
|
+
|
|
858
|
+
treatment_groups = sorted([g for g in np.unique(unit_cohorts) if g > 0])
|
|
859
|
+
|
|
860
|
+
covariate_matrix = None
|
|
861
|
+
if covariates:
|
|
862
|
+
cov_wide = {}
|
|
863
|
+
for cov in covariates:
|
|
864
|
+
cov_vals = np.full(n_units, np.nan)
|
|
865
|
+
for u_id, idx in unit_to_idx.items():
|
|
866
|
+
u_data = df.loc[df[unit] == u_id, cov]
|
|
867
|
+
if len(u_data) > 0:
|
|
868
|
+
cov_vals[idx] = u_data.iloc[0]
|
|
869
|
+
cov_wide[cov] = cov_vals
|
|
870
|
+
covariate_matrix = np.column_stack(list(cov_wide.values()))
|
|
871
|
+
|
|
872
|
+
# Extract per-unit survey weights and collapse design to unit level
|
|
873
|
+
survey_weights_arr = None
|
|
874
|
+
resolved_survey_unit = None
|
|
875
|
+
if resolved_survey is not None:
|
|
876
|
+
from diff_diff.survey import collapse_survey_to_unit_level
|
|
877
|
+
|
|
878
|
+
survey_weights_arr = (
|
|
879
|
+
pd.Series(resolved_survey.weights, index=df.index)
|
|
880
|
+
.groupby(df[unit])
|
|
881
|
+
.first()
|
|
882
|
+
.reindex(all_units)
|
|
883
|
+
.values.astype(np.float64)
|
|
884
|
+
)
|
|
885
|
+
# Normalize to sum=n for aggregation/rescaling (matches pweight
|
|
886
|
+
# convention). Raw weights preserved in resolved_survey_unit for
|
|
887
|
+
# replicate w_r/w_full ratios — those are inherently scale-invariant.
|
|
888
|
+
sw_sum = np.sum(survey_weights_arr)
|
|
889
|
+
if sw_sum > 0:
|
|
890
|
+
survey_weights_arr = survey_weights_arr * (len(survey_weights_arr) / sw_sum)
|
|
891
|
+
resolved_survey_unit = collapse_survey_to_unit_level(
|
|
892
|
+
resolved_survey, df, unit, all_units
|
|
893
|
+
)
|
|
894
|
+
|
|
895
|
+
return {
|
|
896
|
+
"all_units": all_units,
|
|
897
|
+
"unit_to_idx": unit_to_idx,
|
|
898
|
+
"time_periods": time_periods,
|
|
899
|
+
"time_to_col": time_to_col,
|
|
900
|
+
"outcome_matrix": outcome_matrix,
|
|
901
|
+
"unit_cohorts": unit_cohorts,
|
|
902
|
+
"eligibility_per_unit": eligibility_per_unit,
|
|
903
|
+
"treatment_groups": treatment_groups,
|
|
904
|
+
"covariate_matrix": covariate_matrix,
|
|
905
|
+
"n_units": n_units,
|
|
906
|
+
"n_periods": n_periods,
|
|
907
|
+
"survey_weights": survey_weights_arr,
|
|
908
|
+
"resolved_survey_unit": resolved_survey_unit,
|
|
909
|
+
"df_survey": (
|
|
910
|
+
resolved_survey_unit.df_survey if resolved_survey_unit is not None else None
|
|
911
|
+
),
|
|
912
|
+
}
|
|
913
|
+
|
|
914
|
+
# ------------------------------------------------------------------
|
|
915
|
+
# Base period
|
|
916
|
+
# ------------------------------------------------------------------
|
|
917
|
+
|
|
918
|
+
def _get_base_period(self, g: Any, t: Any) -> Optional[Any]:
|
|
919
|
+
"""Determine base period for a (g, t) pair."""
|
|
920
|
+
if self.base_period == "universal":
|
|
921
|
+
return g - 1 - self.anticipation
|
|
922
|
+
else:
|
|
923
|
+
if t < g - self.anticipation:
|
|
924
|
+
return t - 1
|
|
925
|
+
else:
|
|
926
|
+
return g - 1 - self.anticipation
|
|
927
|
+
|
|
928
|
+
# ------------------------------------------------------------------
|
|
929
|
+
# Three-DiD DDD for one (g, g_c, t) triple
|
|
930
|
+
# ------------------------------------------------------------------
|
|
931
|
+
|
|
932
|
+
def _compute_ddd_gt_gc(
|
|
933
|
+
self,
|
|
934
|
+
precomputed: PrecomputedData,
|
|
935
|
+
g: Any,
|
|
936
|
+
g_c: Any,
|
|
937
|
+
t: Any,
|
|
938
|
+
base_period_val: Any,
|
|
939
|
+
covariates: Optional[List[str]],
|
|
940
|
+
pscore_cache: Dict,
|
|
941
|
+
cho_cache: Optional[Dict],
|
|
942
|
+
epv_diagnostics: Optional[Dict] = None,
|
|
943
|
+
) -> Optional[Tuple[float, np.ndarray, float]]:
|
|
944
|
+
"""
|
|
945
|
+
Compute DDD ATT for one (g, g_c, t) triple.
|
|
946
|
+
|
|
947
|
+
Returns (att_ddd, inf_full_n_units, size_gt_ctrl) or None.
|
|
948
|
+
"""
|
|
949
|
+
outcome_matrix = precomputed["outcome_matrix"]
|
|
950
|
+
time_to_col = precomputed["time_to_col"]
|
|
951
|
+
unit_cohorts = precomputed["unit_cohorts"]
|
|
952
|
+
eligibility_per_unit = precomputed["eligibility_per_unit"]
|
|
953
|
+
covariate_matrix = precomputed["covariate_matrix"]
|
|
954
|
+
n_units = precomputed["n_units"]
|
|
955
|
+
survey_weights = precomputed.get("survey_weights")
|
|
956
|
+
|
|
957
|
+
t_col = time_to_col[t]
|
|
958
|
+
b_col = time_to_col[base_period_val]
|
|
959
|
+
|
|
960
|
+
# Four sub-groups within this (g, g_c) cell
|
|
961
|
+
treated_mask = (unit_cohorts == g) & (eligibility_per_unit == 1) # subgroup 4
|
|
962
|
+
sub_a_mask = (unit_cohorts == g) & (eligibility_per_unit == 0) # subgroup 3
|
|
963
|
+
sub_b_mask = (unit_cohorts == g_c) & (eligibility_per_unit == 1) # subgroup 2
|
|
964
|
+
sub_c_mask = (unit_cohorts == g_c) & (eligibility_per_unit == 0) # subgroup 1
|
|
965
|
+
|
|
966
|
+
n_treated = int(np.sum(treated_mask))
|
|
967
|
+
n_a = int(np.sum(sub_a_mask))
|
|
968
|
+
n_b = int(np.sum(sub_b_mask))
|
|
969
|
+
n_c = int(np.sum(sub_c_mask))
|
|
970
|
+
|
|
971
|
+
# Check for empty subgroups (by count or by survey weight mass)
|
|
972
|
+
empty = []
|
|
973
|
+
if n_treated == 0:
|
|
974
|
+
empty.append(f"(S={g},Q=1)")
|
|
975
|
+
if n_a == 0:
|
|
976
|
+
empty.append(f"(S={g},Q=0)")
|
|
977
|
+
if n_b == 0:
|
|
978
|
+
empty.append(f"(S={g_c},Q=1)")
|
|
979
|
+
if n_c == 0:
|
|
980
|
+
empty.append(f"(S={g_c},Q=0)")
|
|
981
|
+
# Zero survey-weight mass after subpopulation filtering = effectively empty
|
|
982
|
+
if not empty and survey_weights is not None:
|
|
983
|
+
if np.sum(survey_weights[treated_mask]) <= 0:
|
|
984
|
+
empty.append(f"(S={g},Q=1,mass=0)")
|
|
985
|
+
if np.sum(survey_weights[sub_a_mask]) <= 0:
|
|
986
|
+
empty.append(f"(S={g},Q=0,mass=0)")
|
|
987
|
+
if np.sum(survey_weights[sub_b_mask]) <= 0:
|
|
988
|
+
empty.append(f"(S={g_c},Q=1,mass=0)")
|
|
989
|
+
if np.sum(survey_weights[sub_c_mask]) <= 0:
|
|
990
|
+
empty.append(f"(S={g_c},Q=0,mass=0)")
|
|
991
|
+
if empty:
|
|
992
|
+
warnings.warn(
|
|
993
|
+
f"Empty subgroup(s) {', '.join(empty)} for "
|
|
994
|
+
f"(g={g}, g_c={g_c}, t={t}). "
|
|
995
|
+
"Comparison unidentified, skipping.",
|
|
996
|
+
UserWarning,
|
|
997
|
+
stacklevel=3,
|
|
998
|
+
)
|
|
999
|
+
return None
|
|
1000
|
+
|
|
1001
|
+
if min(n_treated, n_a, n_b, n_c) < 5:
|
|
1002
|
+
warnings.warn(
|
|
1003
|
+
f"Small cell size for (g={g}, g_c={g_c}, t={t}). " "Estimates may be unreliable.",
|
|
1004
|
+
UserWarning,
|
|
1005
|
+
stacklevel=3,
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
# Outcome changes
|
|
1009
|
+
delta_y_all = outcome_matrix[:, t_col] - outcome_matrix[:, b_col]
|
|
1010
|
+
valid = np.isfinite(delta_y_all)
|
|
1011
|
+
for m in [treated_mask, sub_a_mask, sub_b_mask, sub_c_mask]:
|
|
1012
|
+
if not np.all(valid[m]):
|
|
1013
|
+
return None
|
|
1014
|
+
|
|
1015
|
+
# Three pairwise DiDs, each on a 2-cell subset
|
|
1016
|
+
# Collect per-DiD EPV diagnostics; merge worst into (g,t) key later
|
|
1017
|
+
epv_diag_a = {} if epv_diagnostics is not None else None
|
|
1018
|
+
epv_diag_b = {} if epv_diagnostics is not None else None
|
|
1019
|
+
epv_diag_c = {} if epv_diagnostics is not None else None
|
|
1020
|
+
|
|
1021
|
+
# DiD_A: subgroup 4 vs 3 (treated-eligible vs treated-ineligible)
|
|
1022
|
+
pair_a_mask = treated_mask | sub_a_mask
|
|
1023
|
+
did_a = self._run_pairwise_did(
|
|
1024
|
+
delta_y_all,
|
|
1025
|
+
pair_a_mask,
|
|
1026
|
+
treated_mask,
|
|
1027
|
+
sub_a_mask,
|
|
1028
|
+
covariate_matrix,
|
|
1029
|
+
pscore_cache,
|
|
1030
|
+
(g, g, 0, base_period_val),
|
|
1031
|
+
cho_cache,
|
|
1032
|
+
("a", g, g, base_period_val),
|
|
1033
|
+
survey_weights=survey_weights,
|
|
1034
|
+
context_label=f"cohort g={g}, DiD_A (g_c={g_c})",
|
|
1035
|
+
epv_diagnostics_out=epv_diag_a,
|
|
1036
|
+
)
|
|
1037
|
+
|
|
1038
|
+
# DiD_B: subgroup 4 vs 2 (treated-eligible vs control-eligible)
|
|
1039
|
+
pair_b_mask = treated_mask | sub_b_mask
|
|
1040
|
+
did_b = self._run_pairwise_did(
|
|
1041
|
+
delta_y_all,
|
|
1042
|
+
pair_b_mask,
|
|
1043
|
+
treated_mask,
|
|
1044
|
+
sub_b_mask,
|
|
1045
|
+
covariate_matrix,
|
|
1046
|
+
pscore_cache,
|
|
1047
|
+
(g, g_c, 1, base_period_val),
|
|
1048
|
+
cho_cache,
|
|
1049
|
+
("b", g, g_c, base_period_val),
|
|
1050
|
+
survey_weights=survey_weights,
|
|
1051
|
+
context_label=f"cohort g={g}, DiD_B (g_c={g_c})",
|
|
1052
|
+
epv_diagnostics_out=epv_diag_b,
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
# DiD_C: subgroup 4 vs 1 (treated-eligible vs control-ineligible)
|
|
1056
|
+
pair_c_mask = treated_mask | sub_c_mask
|
|
1057
|
+
did_c = self._run_pairwise_did(
|
|
1058
|
+
delta_y_all,
|
|
1059
|
+
pair_c_mask,
|
|
1060
|
+
treated_mask,
|
|
1061
|
+
sub_c_mask,
|
|
1062
|
+
covariate_matrix,
|
|
1063
|
+
pscore_cache,
|
|
1064
|
+
(g, g_c, 0, base_period_val),
|
|
1065
|
+
cho_cache,
|
|
1066
|
+
("c", g, g_c, base_period_val),
|
|
1067
|
+
survey_weights=survey_weights,
|
|
1068
|
+
context_label=f"cohort g={g}, DiD_C (g_c={g_c})",
|
|
1069
|
+
epv_diagnostics_out=epv_diag_c,
|
|
1070
|
+
)
|
|
1071
|
+
|
|
1072
|
+
# Merge per-DiD EPV diagnostics: keep the worst (lowest EPV) entry
|
|
1073
|
+
# across all three DiDs for this g_c. If multiple g_c contribute to the
|
|
1074
|
+
# same (g, t) cell, retain the overall minimum EPV across all g_c calls.
|
|
1075
|
+
if epv_diagnostics is not None:
|
|
1076
|
+
candidates = [d for d in [epv_diag_a, epv_diag_b, epv_diag_c] if d]
|
|
1077
|
+
if candidates:
|
|
1078
|
+
worst = min(candidates, key=lambda d: d.get("epv", float("inf")))
|
|
1079
|
+
existing = epv_diagnostics.get((g, t))
|
|
1080
|
+
if existing is None or worst.get("epv", float("inf")) < existing.get(
|
|
1081
|
+
"epv", float("inf")
|
|
1082
|
+
):
|
|
1083
|
+
epv_diagnostics[(g, t)] = worst
|
|
1084
|
+
|
|
1085
|
+
if did_a is None or did_b is None or did_c is None:
|
|
1086
|
+
return None
|
|
1087
|
+
|
|
1088
|
+
att_a, inf_a = did_a
|
|
1089
|
+
att_b, inf_b = did_b
|
|
1090
|
+
att_c, inf_c = did_c
|
|
1091
|
+
|
|
1092
|
+
att_ddd = att_a + att_b - att_c
|
|
1093
|
+
|
|
1094
|
+
# Three-DiD IF combination: w_j = n_cell / n_pair_j (R's att_dr convention)
|
|
1095
|
+
# With survey weights, use survey-weighted cell sizes
|
|
1096
|
+
if survey_weights is not None:
|
|
1097
|
+
sw_4 = float(np.sum(survey_weights[treated_mask]))
|
|
1098
|
+
sw_3 = float(np.sum(survey_weights[sub_a_mask]))
|
|
1099
|
+
sw_2 = float(np.sum(survey_weights[sub_b_mask]))
|
|
1100
|
+
sw_1 = float(np.sum(survey_weights[sub_c_mask]))
|
|
1101
|
+
n_cell_w = sw_4 + sw_3 + sw_2 + sw_1
|
|
1102
|
+
n_pair_a_w = sw_4 + sw_3
|
|
1103
|
+
n_pair_b_w = sw_4 + sw_2
|
|
1104
|
+
n_pair_c_w = sw_4 + sw_1
|
|
1105
|
+
w_3 = n_cell_w / n_pair_a_w if n_pair_a_w > 0 else 1.0
|
|
1106
|
+
w_2 = n_cell_w / n_pair_b_w if n_pair_b_w > 0 else 1.0
|
|
1107
|
+
w_1 = n_cell_w / n_pair_c_w if n_pair_c_w > 0 else 1.0
|
|
1108
|
+
size_gt_ctrl = n_cell_w
|
|
1109
|
+
else:
|
|
1110
|
+
n_cell = n_treated + n_a + n_b + n_c
|
|
1111
|
+
n_pair_a = n_treated + n_a
|
|
1112
|
+
n_pair_b = n_treated + n_b
|
|
1113
|
+
n_pair_c = n_treated + n_c
|
|
1114
|
+
w_3 = n_cell / n_pair_a if n_pair_a > 0 else 1.0
|
|
1115
|
+
w_2 = n_cell / n_pair_b if n_pair_b > 0 else 1.0
|
|
1116
|
+
w_1 = n_cell / n_pair_c if n_pair_c > 0 else 1.0
|
|
1117
|
+
size_gt_ctrl = float(n_cell)
|
|
1118
|
+
|
|
1119
|
+
# Scatter pair-level IFs into n_units-length vector
|
|
1120
|
+
inf_full = np.zeros(n_units)
|
|
1121
|
+
pair_a_idx = np.where(pair_a_mask)[0]
|
|
1122
|
+
pair_b_idx = np.where(pair_b_mask)[0]
|
|
1123
|
+
pair_c_idx = np.where(pair_c_mask)[0]
|
|
1124
|
+
|
|
1125
|
+
inf_full[pair_a_idx] += w_3 * inf_a
|
|
1126
|
+
inf_full[pair_b_idx] += w_2 * inf_b
|
|
1127
|
+
inf_full[pair_c_idx] -= w_1 * inf_c
|
|
1128
|
+
|
|
1129
|
+
return att_ddd, inf_full, size_gt_ctrl
|
|
1130
|
+
|
|
1131
|
+
# ------------------------------------------------------------------
|
|
1132
|
+
# Pairwise DiD (matches R's compute_did)
|
|
1133
|
+
# ------------------------------------------------------------------
|
|
1134
|
+
|
|
1135
|
+
def _run_pairwise_did(
|
|
1136
|
+
self,
|
|
1137
|
+
delta_y_all: np.ndarray,
|
|
1138
|
+
pair_mask: np.ndarray,
|
|
1139
|
+
treated_mask: np.ndarray,
|
|
1140
|
+
control_mask: np.ndarray,
|
|
1141
|
+
covariate_matrix: Optional[np.ndarray],
|
|
1142
|
+
pscore_cache: Dict,
|
|
1143
|
+
pscore_key: Any,
|
|
1144
|
+
cho_cache: Optional[Dict],
|
|
1145
|
+
cho_key: Any,
|
|
1146
|
+
survey_weights: Optional[np.ndarray] = None,
|
|
1147
|
+
context_label: str = "",
|
|
1148
|
+
epv_diagnostics_out: Optional[dict] = None,
|
|
1149
|
+
) -> Optional[Tuple[float, np.ndarray]]:
|
|
1150
|
+
"""
|
|
1151
|
+
Compute a single pairwise DiD ATT and IF on a 2-cell subset.
|
|
1152
|
+
|
|
1153
|
+
Matches R's triplediff::compute_did() formulation exactly:
|
|
1154
|
+
Riesz/Hajek normalization, PS + OR IF corrections.
|
|
1155
|
+
|
|
1156
|
+
Returns (att, inf_func) where inf_func has length n_pair,
|
|
1157
|
+
ordered by pair_mask indices. Returns None if insufficient data.
|
|
1158
|
+
"""
|
|
1159
|
+
pair_idx = np.where(pair_mask)[0]
|
|
1160
|
+
n_pair = len(pair_idx)
|
|
1161
|
+
if n_pair == 0:
|
|
1162
|
+
return None
|
|
1163
|
+
|
|
1164
|
+
delta_y = delta_y_all[pair_idx]
|
|
1165
|
+
PA4 = treated_mask[pair_idx].astype(float)
|
|
1166
|
+
PAa = control_mask[pair_idx].astype(float)
|
|
1167
|
+
sw_pair = survey_weights[pair_idx] if survey_weights is not None else None
|
|
1168
|
+
|
|
1169
|
+
n_t = int(np.sum(PA4))
|
|
1170
|
+
n_c = int(np.sum(PAa))
|
|
1171
|
+
if n_t == 0 or n_c == 0:
|
|
1172
|
+
return None
|
|
1173
|
+
|
|
1174
|
+
has_covariates = covariate_matrix is not None and self.estimation_method != "none"
|
|
1175
|
+
|
|
1176
|
+
# Build covariate matrix with intercept for the pair
|
|
1177
|
+
covX = None
|
|
1178
|
+
if has_covariates:
|
|
1179
|
+
X_pair = covariate_matrix[pair_idx]
|
|
1180
|
+
covX = np.column_stack([np.ones(n_pair), X_pair])
|
|
1181
|
+
|
|
1182
|
+
# Compute nuisance parameters based on estimation method
|
|
1183
|
+
pscore = None
|
|
1184
|
+
hessian = None
|
|
1185
|
+
or_delta = np.zeros(n_pair)
|
|
1186
|
+
|
|
1187
|
+
if self.estimation_method in ("ipw", "dr") and covX is not None:
|
|
1188
|
+
pscore, hessian = self._compute_pscore(
|
|
1189
|
+
PA4,
|
|
1190
|
+
covX,
|
|
1191
|
+
pscore_cache,
|
|
1192
|
+
pscore_key,
|
|
1193
|
+
survey_weights=sw_pair,
|
|
1194
|
+
context_label=context_label,
|
|
1195
|
+
epv_diagnostics_out=epv_diagnostics_out,
|
|
1196
|
+
)
|
|
1197
|
+
|
|
1198
|
+
if self.estimation_method in ("reg", "dr") and covX is not None:
|
|
1199
|
+
# Skip Cholesky cache when survey weights present (cho_cache=None)
|
|
1200
|
+
or_delta = self._compute_or(
|
|
1201
|
+
delta_y,
|
|
1202
|
+
PAa,
|
|
1203
|
+
covX,
|
|
1204
|
+
cho_cache,
|
|
1205
|
+
cho_key,
|
|
1206
|
+
survey_weights=sw_pair,
|
|
1207
|
+
)
|
|
1208
|
+
|
|
1209
|
+
# Compute ATT and IF (R's compute_did formulation)
|
|
1210
|
+
return self._compute_did_panel(
|
|
1211
|
+
delta_y,
|
|
1212
|
+
PA4,
|
|
1213
|
+
PAa,
|
|
1214
|
+
covX,
|
|
1215
|
+
pscore,
|
|
1216
|
+
hessian,
|
|
1217
|
+
or_delta,
|
|
1218
|
+
survey_weights=sw_pair,
|
|
1219
|
+
)
|
|
1220
|
+
|
|
1221
|
+
# ------------------------------------------------------------------
|
|
1222
|
+
# Core DR/IPW/RA computation (matches R's compute_did exactly)
|
|
1223
|
+
# ------------------------------------------------------------------
|
|
1224
|
+
|
|
1225
|
+
def _compute_did_panel(
|
|
1226
|
+
self,
|
|
1227
|
+
delta_y: np.ndarray,
|
|
1228
|
+
PA4: np.ndarray,
|
|
1229
|
+
PAa: np.ndarray,
|
|
1230
|
+
covX: Optional[np.ndarray],
|
|
1231
|
+
pscore: Optional[np.ndarray],
|
|
1232
|
+
hessian: Optional[np.ndarray],
|
|
1233
|
+
or_delta: np.ndarray,
|
|
1234
|
+
survey_weights: Optional[np.ndarray] = None,
|
|
1235
|
+
) -> Tuple[float, np.ndarray]:
|
|
1236
|
+
"""
|
|
1237
|
+
Pairwise DiD ATT and influence function.
|
|
1238
|
+
Matches R's triplediff::compute_did() line-by-line.
|
|
1239
|
+
|
|
1240
|
+
Parameters
|
|
1241
|
+
----------
|
|
1242
|
+
delta_y : outcome changes for 2-cell subset (n_pair,)
|
|
1243
|
+
PA4 : treated indicator (n_pair,)
|
|
1244
|
+
PAa : control indicator (n_pair,)
|
|
1245
|
+
covX : covariate matrix with intercept (n_pair, p) or None
|
|
1246
|
+
pscore : propensity scores (n_pair,) or None
|
|
1247
|
+
hessian : (X'WX)^{-1} * n_pair or None
|
|
1248
|
+
or_delta : OR predictions (n_pair,), zeros if no covariates
|
|
1249
|
+
survey_weights : per-observation survey weights (n_pair,) or None
|
|
1250
|
+
|
|
1251
|
+
Returns
|
|
1252
|
+
-------
|
|
1253
|
+
(att, inf_func) where inf_func has length n_pair.
|
|
1254
|
+
"""
|
|
1255
|
+
n_pair = len(delta_y)
|
|
1256
|
+
est = self.estimation_method
|
|
1257
|
+
|
|
1258
|
+
# Riesz representers (R lines 243-250)
|
|
1259
|
+
if est == "reg" or pscore is None:
|
|
1260
|
+
w_treat = PA4.copy()
|
|
1261
|
+
w_control = PAa.copy()
|
|
1262
|
+
else:
|
|
1263
|
+
w_treat = PA4.copy()
|
|
1264
|
+
pscore_safe = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
|
|
1265
|
+
w_control = pscore_safe * PAa / (1 - pscore_safe)
|
|
1266
|
+
|
|
1267
|
+
# Incorporate survey weights into Riesz representers
|
|
1268
|
+
if survey_weights is not None:
|
|
1269
|
+
w_treat = w_treat * survey_weights
|
|
1270
|
+
w_control = w_control * survey_weights
|
|
1271
|
+
|
|
1272
|
+
# DR ATT via Hajek normalization (R lines 251-256)
|
|
1273
|
+
resid = delta_y - or_delta
|
|
1274
|
+
riesz_treat = w_treat * resid
|
|
1275
|
+
riesz_control = w_control * resid
|
|
1276
|
+
|
|
1277
|
+
mean_w_treat = np.mean(w_treat)
|
|
1278
|
+
mean_w_control = np.mean(w_control)
|
|
1279
|
+
|
|
1280
|
+
if mean_w_treat <= 0 or mean_w_control <= 0:
|
|
1281
|
+
return float("nan"), np.zeros(n_pair)
|
|
1282
|
+
|
|
1283
|
+
att_treat = np.mean(riesz_treat) / mean_w_treat
|
|
1284
|
+
att_control = np.mean(riesz_control) / mean_w_control
|
|
1285
|
+
dr_att = att_treat - att_control
|
|
1286
|
+
|
|
1287
|
+
# Base IF (R lines 302-304)
|
|
1288
|
+
inf_treat_did = riesz_treat - w_treat * att_treat
|
|
1289
|
+
inf_control_did = riesz_control - w_control * att_control
|
|
1290
|
+
|
|
1291
|
+
# PS correction (R lines 262-273) — IPW and DR only
|
|
1292
|
+
inf_control_pscore = 0.0
|
|
1293
|
+
if est != "reg" and hessian is not None and covX is not None:
|
|
1294
|
+
M2 = np.mean((w_control * (resid - att_control))[:, None] * covX, axis=0)
|
|
1295
|
+
if survey_weights is not None:
|
|
1296
|
+
score_ps = survey_weights[:, None] * (PA4 - pscore_safe)[:, None] * covX
|
|
1297
|
+
else:
|
|
1298
|
+
score_ps = (PA4 - pscore_safe)[:, None] * covX
|
|
1299
|
+
asy_lin_rep_ps = score_ps @ hessian
|
|
1300
|
+
inf_control_pscore = asy_lin_rep_ps @ M2
|
|
1301
|
+
|
|
1302
|
+
# OR correction (R lines 278-300) — reg and DR only
|
|
1303
|
+
inf_treat_or = 0.0
|
|
1304
|
+
inf_cont_or = 0.0
|
|
1305
|
+
if est != "ipw" and covX is not None:
|
|
1306
|
+
M1 = np.mean(w_treat[:, None] * covX, axis=0)
|
|
1307
|
+
M3 = np.mean(w_control[:, None] * covX, axis=0)
|
|
1308
|
+
|
|
1309
|
+
if survey_weights is not None:
|
|
1310
|
+
or_x = (PAa * survey_weights)[:, None] * covX
|
|
1311
|
+
or_ex = (PAa * survey_weights * resid)[:, None] * covX
|
|
1312
|
+
else:
|
|
1313
|
+
or_x = PAa[:, None] * covX
|
|
1314
|
+
or_ex = (PAa * resid)[:, None] * covX
|
|
1315
|
+
XpX = or_x.T @ covX / n_pair
|
|
1316
|
+
|
|
1317
|
+
try:
|
|
1318
|
+
asy_linear_or = (np.linalg.solve(XpX, or_ex.T)).T
|
|
1319
|
+
except np.linalg.LinAlgError:
|
|
1320
|
+
asy_linear_or = (np.linalg.lstsq(XpX, or_ex.T, rcond=None)[0]).T
|
|
1321
|
+
|
|
1322
|
+
inf_treat_or = -(asy_linear_or @ M1)
|
|
1323
|
+
inf_cont_or = -(asy_linear_or @ M3)
|
|
1324
|
+
|
|
1325
|
+
# Final IF assembly (R lines 307-310)
|
|
1326
|
+
inf_control = (inf_control_did + inf_control_pscore + inf_cont_or) / mean_w_control
|
|
1327
|
+
inf_treat = (inf_treat_did + inf_treat_or) / mean_w_treat
|
|
1328
|
+
inf_func = inf_treat - inf_control
|
|
1329
|
+
|
|
1330
|
+
return float(dr_att), inf_func
|
|
1331
|
+
|
|
1332
|
+
# ------------------------------------------------------------------
|
|
1333
|
+
# Nuisance parameter computation
|
|
1334
|
+
# ------------------------------------------------------------------
|
|
1335
|
+
|
|
1336
|
+
def _compute_pscore(
|
|
1337
|
+
self,
|
|
1338
|
+
PA4: np.ndarray,
|
|
1339
|
+
covX: np.ndarray,
|
|
1340
|
+
pscore_cache: Dict,
|
|
1341
|
+
pscore_key: Any,
|
|
1342
|
+
survey_weights: Optional[np.ndarray] = None,
|
|
1343
|
+
context_label: str = "",
|
|
1344
|
+
epv_diagnostics_out: Optional[dict] = None,
|
|
1345
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
1346
|
+
"""Fit logistic P(PA4=1|X). Returns (pscore, hessian).
|
|
1347
|
+
|
|
1348
|
+
hessian = (X'WX)^{-1} * n_pair, matching R's convention.
|
|
1349
|
+
When survey_weights is provided, IRLS uses survey-weighted
|
|
1350
|
+
working weights and the hessian accounts for survey weights.
|
|
1351
|
+
"""
|
|
1352
|
+
cached = pscore_cache.get(pscore_key)
|
|
1353
|
+
n_pair = len(PA4)
|
|
1354
|
+
|
|
1355
|
+
if cached is not None:
|
|
1356
|
+
beta_logistic, cached_diag = cached
|
|
1357
|
+
z = np.dot(covX, beta_logistic)
|
|
1358
|
+
z = np.clip(z, -500, 500)
|
|
1359
|
+
pscore = 1 / (1 + np.exp(-z))
|
|
1360
|
+
if epv_diagnostics_out is not None and cached_diag:
|
|
1361
|
+
epv_diagnostics_out.update(cached_diag)
|
|
1362
|
+
else:
|
|
1363
|
+
X_no_intercept = covX[:, 1:] # solve_logit adds its own intercept
|
|
1364
|
+
diag = {}
|
|
1365
|
+
try:
|
|
1366
|
+
beta_logistic, pscore = solve_logit(
|
|
1367
|
+
X_no_intercept,
|
|
1368
|
+
PA4,
|
|
1369
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
1370
|
+
weights=survey_weights,
|
|
1371
|
+
epv_threshold=self.epv_threshold,
|
|
1372
|
+
context_label=context_label,
|
|
1373
|
+
diagnostics_out=diag,
|
|
1374
|
+
)
|
|
1375
|
+
_check_propensity_diagnostics(pscore, self.pscore_trim)
|
|
1376
|
+
# Zero-fill NaN coefficients (from rank-deficient columns)
|
|
1377
|
+
# before caching, so cache reuse doesn't propagate NaN.
|
|
1378
|
+
# Cache alongside EPV diagnostics for replay on cache hits.
|
|
1379
|
+
beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0)
|
|
1380
|
+
pscore_cache[pscore_key] = (beta_clean, diag)
|
|
1381
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
1382
|
+
if (
|
|
1383
|
+
self.pscore_fallback == "error"
|
|
1384
|
+
or self.rank_deficient_action == "error"
|
|
1385
|
+
):
|
|
1386
|
+
raise
|
|
1387
|
+
ctx = f" for {context_label}" if context_label else ""
|
|
1388
|
+
warnings.warn(
|
|
1389
|
+
f"Propensity score estimation failed{ctx}. "
|
|
1390
|
+
f"Falling back to unconditional propensity "
|
|
1391
|
+
f"(propensity model ignores covariates; outcome "
|
|
1392
|
+
f"regression still uses them for DR). "
|
|
1393
|
+
f"Consider estimation_method='reg' to avoid "
|
|
1394
|
+
f"propensity scores entirely.",
|
|
1395
|
+
UserWarning,
|
|
1396
|
+
stacklevel=5,
|
|
1397
|
+
)
|
|
1398
|
+
# Use survey-weighted treated share when weights available
|
|
1399
|
+
if survey_weights is not None:
|
|
1400
|
+
pos = survey_weights > 0
|
|
1401
|
+
if np.any(pos):
|
|
1402
|
+
p_uc = np.average(PA4[pos], weights=survey_weights[pos])
|
|
1403
|
+
else:
|
|
1404
|
+
p_uc = np.mean(PA4)
|
|
1405
|
+
else:
|
|
1406
|
+
p_uc = np.mean(PA4)
|
|
1407
|
+
pscore = np.full(n_pair, p_uc)
|
|
1408
|
+
pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
|
|
1409
|
+
# No hessian for unconditional fallback
|
|
1410
|
+
return pscore, None
|
|
1411
|
+
if epv_diagnostics_out is not None and diag:
|
|
1412
|
+
epv_diagnostics_out.update(diag)
|
|
1413
|
+
|
|
1414
|
+
pscore = np.clip(pscore, 1e-6, 1 - 1e-6)
|
|
1415
|
+
|
|
1416
|
+
# Hessian: (X'WX)^{-1} * n (matching R's compute_pscore)
|
|
1417
|
+
W = pscore * (1 - pscore)
|
|
1418
|
+
if survey_weights is not None:
|
|
1419
|
+
W = W * survey_weights
|
|
1420
|
+
XWX = covX.T @ (W[:, None] * covX)
|
|
1421
|
+
try:
|
|
1422
|
+
hessian = np.linalg.inv(XWX) * n_pair
|
|
1423
|
+
except np.linalg.LinAlgError:
|
|
1424
|
+
hessian = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0] * n_pair
|
|
1425
|
+
|
|
1426
|
+
return pscore, hessian
|
|
1427
|
+
|
|
1428
|
+
def _compute_or(
|
|
1429
|
+
self,
|
|
1430
|
+
delta_y: np.ndarray,
|
|
1431
|
+
PAa: np.ndarray,
|
|
1432
|
+
covX: np.ndarray,
|
|
1433
|
+
cho_cache: Optional[Dict],
|
|
1434
|
+
cho_key: Any,
|
|
1435
|
+
survey_weights: Optional[np.ndarray] = None,
|
|
1436
|
+
) -> np.ndarray:
|
|
1437
|
+
"""Fit OLS on control outcome changes. Returns or_delta for all pair units.
|
|
1438
|
+
|
|
1439
|
+
Honors self.rank_deficient_action for collinear covariates.
|
|
1440
|
+
When survey_weights is provided, uses WLS via solve_ols(weights=...).
|
|
1441
|
+
Cholesky cache is disabled for the survey path (cho_cache=None).
|
|
1442
|
+
"""
|
|
1443
|
+
from diff_diff.linalg import solve_ols as _solve_ols
|
|
1444
|
+
|
|
1445
|
+
control_mask = PAa > 0
|
|
1446
|
+
n_c = int(np.sum(control_mask))
|
|
1447
|
+
if n_c == 0:
|
|
1448
|
+
return np.zeros(len(delta_y))
|
|
1449
|
+
|
|
1450
|
+
X_control = covX[control_mask]
|
|
1451
|
+
y_control = delta_y[control_mask]
|
|
1452
|
+
sw_control = survey_weights[control_mask] if survey_weights is not None else None
|
|
1453
|
+
|
|
1454
|
+
# Try Cholesky cache for fast path (full-rank only)
|
|
1455
|
+
# Skipped when cho_cache is None (survey weights present)
|
|
1456
|
+
beta = None
|
|
1457
|
+
if cho_cache is not None:
|
|
1458
|
+
cached_cho = cho_cache.get(cho_key)
|
|
1459
|
+
if cached_cho is False:
|
|
1460
|
+
pass # Previously detected rank-deficient; skip Cholesky
|
|
1461
|
+
elif cached_cho is not None:
|
|
1462
|
+
from scipy import linalg as sp_linalg
|
|
1463
|
+
|
|
1464
|
+
Xty = X_control.T @ y_control
|
|
1465
|
+
beta = sp_linalg.cho_solve(cached_cho, Xty)
|
|
1466
|
+
if np.any(~np.isfinite(beta)):
|
|
1467
|
+
beta = None
|
|
1468
|
+
elif cho_key not in cho_cache:
|
|
1469
|
+
XtX = X_control.T @ X_control
|
|
1470
|
+
try:
|
|
1471
|
+
from scipy import linalg as sp_linalg
|
|
1472
|
+
|
|
1473
|
+
cho_factor = sp_linalg.cho_factor(XtX)
|
|
1474
|
+
cho_cache[cho_key] = cho_factor
|
|
1475
|
+
Xty = X_control.T @ y_control
|
|
1476
|
+
beta = sp_linalg.cho_solve(cho_factor, Xty)
|
|
1477
|
+
if np.any(~np.isfinite(beta)):
|
|
1478
|
+
beta = None
|
|
1479
|
+
except np.linalg.LinAlgError:
|
|
1480
|
+
cho_cache[cho_key] = False
|
|
1481
|
+
|
|
1482
|
+
if beta is None:
|
|
1483
|
+
# Fallback (or survey path): use solve_ols with optional weights
|
|
1484
|
+
beta, _, _ = _solve_ols(
|
|
1485
|
+
X_control,
|
|
1486
|
+
y_control,
|
|
1487
|
+
rank_deficient_action=self.rank_deficient_action,
|
|
1488
|
+
weights=sw_control,
|
|
1489
|
+
)
|
|
1490
|
+
beta = np.where(np.isfinite(beta), beta, 0.0)
|
|
1491
|
+
|
|
1492
|
+
return covX @ beta
|
|
1493
|
+
|
|
1494
|
+
# ------------------------------------------------------------------
|
|
1495
|
+
# GMM-optimal combination (matches R's att_gt GMM procedure)
|
|
1496
|
+
# ------------------------------------------------------------------
|
|
1497
|
+
|
|
1498
|
+
def _combine_gmm(
|
|
1499
|
+
self,
|
|
1500
|
+
att_vec: np.ndarray,
|
|
1501
|
+
inf_func_matrix: np.ndarray,
|
|
1502
|
+
n_units: int,
|
|
1503
|
+
) -> Tuple[float, np.ndarray, np.ndarray, float]:
|
|
1504
|
+
"""
|
|
1505
|
+
Combine comparison-group-specific estimates via GMM-optimal weights.
|
|
1506
|
+
|
|
1507
|
+
Returns (att_gmm, inf_gmm, weights, se_gmm).
|
|
1508
|
+
"""
|
|
1509
|
+
k = len(att_vec)
|
|
1510
|
+
|
|
1511
|
+
if k == 1:
|
|
1512
|
+
att_gmm = float(att_vec[0])
|
|
1513
|
+
inf_gmm = inf_func_matrix[0].copy()
|
|
1514
|
+
# R's SE: sqrt(sum(IF^2) / n^2)
|
|
1515
|
+
se_gmm = float(np.sqrt(np.sum(inf_gmm**2) / n_units**2))
|
|
1516
|
+
return att_gmm, inf_gmm, np.array([1.0]), se_gmm
|
|
1517
|
+
|
|
1518
|
+
# R: OMEGA <- cov(inf_mat_local) — sample covariance, ddof=1
|
|
1519
|
+
Omega = np.cov(inf_func_matrix)
|
|
1520
|
+
|
|
1521
|
+
ones = np.ones(k)
|
|
1522
|
+
try:
|
|
1523
|
+
Omega_inv = np.linalg.inv(Omega)
|
|
1524
|
+
except np.linalg.LinAlgError:
|
|
1525
|
+
warnings.warn(
|
|
1526
|
+
"Singular covariance matrix in GMM combination. " "Using pseudoinverse.",
|
|
1527
|
+
UserWarning,
|
|
1528
|
+
stacklevel=3,
|
|
1529
|
+
)
|
|
1530
|
+
Omega_inv = np.linalg.pinv(Omega)
|
|
1531
|
+
|
|
1532
|
+
denom = float(ones @ Omega_inv @ ones)
|
|
1533
|
+
if denom <= 0 or not np.isfinite(denom):
|
|
1534
|
+
weights = np.full(k, 1.0 / k)
|
|
1535
|
+
att_gmm = float(weights @ att_vec)
|
|
1536
|
+
inf_gmm = weights @ inf_func_matrix
|
|
1537
|
+
se_gmm = float(np.sqrt(np.sum(inf_gmm**2) / n_units**2))
|
|
1538
|
+
else:
|
|
1539
|
+
weights = (Omega_inv @ ones) / denom
|
|
1540
|
+
att_gmm = float(weights @ att_vec)
|
|
1541
|
+
inf_gmm = weights @ inf_func_matrix
|
|
1542
|
+
# R: gmm_se <- sqrt(1 / (n * sum(inv_OMEGA)))
|
|
1543
|
+
se_gmm = float(np.sqrt(1.0 / (n_units * denom)))
|
|
1544
|
+
|
|
1545
|
+
return att_gmm, inf_gmm, weights, se_gmm
|