diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
|
@@ -0,0 +1,1770 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Efficient Difference-in-Differences estimator.
|
|
3
|
+
|
|
4
|
+
Implements the ATT estimator from Chen, Sant'Anna & Xie (2025).
|
|
5
|
+
Without covariates, achieves the semiparametric efficiency bound via
|
|
6
|
+
closed-form within-group covariances. With covariates, uses a doubly
|
|
7
|
+
robust path with OLS outcome regression, sieve propensity ratios, and
|
|
8
|
+
kernel-smoothed conditional Omega*(X) (see class docstring for caveats).
|
|
9
|
+
|
|
10
|
+
Under PT-All the model is overidentified and EDiD exploits this for
|
|
11
|
+
tighter inference; under PT-Post it reduces to the standard
|
|
12
|
+
single-baseline estimator (Callaway-Sant'Anna).
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import warnings
|
|
16
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pandas as pd
|
|
20
|
+
|
|
21
|
+
from diff_diff.efficient_did_bootstrap import (
|
|
22
|
+
EDiDBootstrapResults,
|
|
23
|
+
EfficientDiDBootstrapMixin,
|
|
24
|
+
)
|
|
25
|
+
from diff_diff.efficient_did_covariates import (
|
|
26
|
+
compute_eif_cov,
|
|
27
|
+
compute_generated_outcomes_cov,
|
|
28
|
+
compute_omega_star_conditional,
|
|
29
|
+
compute_per_unit_weights,
|
|
30
|
+
estimate_inverse_propensity_sieve,
|
|
31
|
+
estimate_outcome_regression,
|
|
32
|
+
estimate_propensity_ratio_sieve,
|
|
33
|
+
)
|
|
34
|
+
from diff_diff.efficient_did_results import EfficientDiDResults, HausmanPretestResult
|
|
35
|
+
from diff_diff.efficient_did_weights import (
|
|
36
|
+
compute_efficient_weights,
|
|
37
|
+
compute_eif_nocov,
|
|
38
|
+
compute_generated_outcomes_nocov,
|
|
39
|
+
compute_omega_star_nocov,
|
|
40
|
+
enumerate_valid_triples,
|
|
41
|
+
)
|
|
42
|
+
from diff_diff.utils import safe_inference
|
|
43
|
+
|
|
44
|
+
# Re-export for convenience
|
|
45
|
+
__all__ = ["EfficientDiD", "EfficientDiDResults", "EDiDBootstrapResults"]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _validate_and_build_cluster_mapping(
    df: pd.DataFrame,
    unit: str,
    cluster: str,
    all_units: list,
) -> Tuple[np.ndarray, int]:
    """Validate the cluster column and map each unit to a cluster index.

    Checks that ``cluster`` exists in ``df``, has no missing values, is
    constant within every unit, and defines at least two clusters.

    Parameters
    ----------
    df : DataFrame
        Long-format panel containing the ``unit`` and ``cluster`` columns.
    unit : str
        Unit identifier column name.
    cluster : str
        Cluster identifier column name.
    all_units : list
        Unit identifiers, in the order the returned indices should follow.
        Expected to be exactly the units present in ``df``.

    Returns
    -------
    cluster_indices : ndarray of int, shape (len(all_units),)
        ``cluster_indices[i]`` is the 0-based code of the cluster that
        ``all_units[i]`` belongs to. Codes follow the sorted order of the
        distinct cluster labels.
    n_clusters : int
        Number of distinct clusters.

    Raises
    ------
    ValueError
        If the column is missing, contains missing values, varies within a
        unit, or defines fewer than two clusters.
    """
    if cluster not in df.columns:
        raise ValueError(f"Cluster column '{cluster}' not found in data.")
    if df[cluster].isna().any():
        raise ValueError(f"Cluster column '{cluster}' contains missing values.")
    cluster_by_unit = df.groupby(unit)[cluster]
    if (cluster_by_unit.nunique() > 1).any():
        raise ValueError(
            f"Cluster column '{cluster}' varies within unit. "
            "Cluster assignment must be constant per unit."
        )
    # One cluster label per unit, aligned to ``all_units``.
    cluster_col = cluster_by_unit.first().reindex(all_units).values
    # ``return_inverse`` gives each unit's position among the sorted unique
    # labels. This is the same 0..G-1 coding a label->index dict produces,
    # but it is computed in NumPy rather than in a Python-level loop.
    unique_clusters, indices = np.unique(cluster_col, return_inverse=True)
    n_clusters = len(unique_clusters)
    if n_clusters < 2:
        raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}.")
    return indices.reshape(-1), n_clusters
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _cluster_aggregate(
    eif_mat: np.ndarray,
    cluster_indices: np.ndarray,
    n_clusters: int,
) -> np.ndarray:
    """Sum EIF values within clusters and center the cluster totals.

    Parameters
    ----------
    eif_mat : ndarray, shape (n_units,) or (n_units, k)
        EIF values: 1-D for a single estimand, 2-D for multiple estimands
        (``k`` may be 0).
    cluster_indices : ndarray of int, shape (n_units,)
        Non-negative integer cluster code per unit. Each code must be less
        than ``n_clusters``.
    n_clusters : int
        Number of clusters.

    Returns
    -------
    ndarray, shape (n_clusters,) or (n_clusters, k)
        Cluster-level sums minus their mean across clusters.

    Raises
    ------
    ValueError
        If ``eif_mat`` is neither 1-D nor 2-D.
    """
    eif_mat = np.asarray(eif_mat)
    if eif_mat.ndim == 1:
        sums = np.bincount(cluster_indices, weights=eif_mat, minlength=n_clusters).astype(float)
    elif eif_mat.ndim == 2:
        # Scatter-add all columns in one pass, accumulating in the same
        # order as per-column bincounts. Unlike stacking per-column
        # results, this also handles k == 0, where ``np.column_stack([])``
        # would raise.
        sums = np.zeros((n_clusters, eif_mat.shape[1]), dtype=float)
        np.add.at(sums, cluster_indices, eif_mat)
    else:
        raise ValueError(f"eif_mat must be 1-D or 2-D, got {eif_mat.ndim}-D.")
    return sums - sums.mean(axis=0)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _compute_se_from_eif(
    eif: np.ndarray,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> float:
    """Return the standard error implied by per-unit EIF values.

    Unclustered case (``cluster_indices`` or ``n_clusters`` is None)::

        SE = sqrt(mean(eif**2) / n_units)

    Clustered case (Liang-Zeger sandwich). EIF values are summed within
    clusters and centred across clusters via ``_cluster_aggregate``, giving
    ``S_c``. Then::

        SE = sqrt(G / (G - 1) * sum_c S_c**2 / n_units**2)

    where ``G = n_clusters``. The ``G / (G - 1)`` small-sample factor is
    replaced by 1 when ``G <= 1``.
    """
    if cluster_indices is None or n_clusters is None:
        return float(np.sqrt(np.mean(eif**2) / n_units))

    cluster_scores = _cluster_aggregate(eif, cluster_indices, n_clusters)
    if n_clusters > 1:
        small_sample = n_clusters / (n_clusters - 1)
    else:
        small_sample = 1.0
    variance = small_sample * np.sum(cluster_scores**2) / (n_units**2)
    # Defensive clamp at zero. Note that Python's max() passes NaN through
    # unchanged.
    return float(np.sqrt(max(variance, 0.0)))
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class EfficientDiD(EfficientDiDBootstrapMixin):
|
|
133
|
+
"""Efficient DiD estimator (Chen, Sant'Anna & Xie 2025).
|
|
134
|
+
|
|
135
|
+
Without covariates, achieves the semiparametric efficiency bound for
|
|
136
|
+
ATT(g,t) using a closed-form estimator based on within-group sample
|
|
137
|
+
means and covariances.
|
|
138
|
+
|
|
139
|
+
With covariates, uses a doubly robust path: sieve-based propensity
|
|
140
|
+
score ratios (Eq 4.1-4.2), OLS outcome regression, sieve-estimated
|
|
141
|
+
inverse propensities (algorithm step 4), and kernel-smoothed
|
|
142
|
+
conditional Omega*(X) with per-unit efficient weights (Eq 3.12).
|
|
143
|
+
The DR property ensures consistency if either the OLS outcome model
|
|
144
|
+
or the sieve propensity ratio is correctly specified. The OLS
|
|
145
|
+
working model for outcome regressions does not generically guarantee
|
|
146
|
+
the semiparametric efficiency bound (see REGISTRY.md).
|
|
147
|
+
|
|
148
|
+
Parameters
|
|
149
|
+
----------
|
|
150
|
+
pt_assumption : str, default ``"all"``
|
|
151
|
+
Parallel trends variant: ``"all"`` (overidentified, uses all
|
|
152
|
+
pre-treatment periods and comparison groups) or ``"post"``
|
|
153
|
+
(just-identified, single baseline, equivalent to CS).
|
|
154
|
+
alpha : float, default 0.05
|
|
155
|
+
Significance level.
|
|
156
|
+
cluster : str or None
|
|
157
|
+
Column name for cluster-robust SEs. When set, analytical SEs
|
|
158
|
+
use the Liang-Zeger clustered sandwich estimator on EIF values.
|
|
159
|
+
With ``n_bootstrap > 0``, bootstrap weights are generated at the
|
|
160
|
+
cluster level (all units in a cluster share the same weight).
|
|
161
|
+
control_group : str, default ``"never_treated"``
|
|
162
|
+
Which units serve as the comparison group:
|
|
163
|
+
``"never_treated"`` requires a never-treated cohort (raises if
|
|
164
|
+
none exist); ``"last_cohort"`` reclassifies the latest treatment
|
|
165
|
+
cohort as pseudo-never-treated and drops post-treatment periods
|
|
166
|
+
for that cohort. Distinct from CallawaySantAnna's
|
|
167
|
+
``"not_yet_treated"`` — see REGISTRY.md for details.
|
|
168
|
+
n_bootstrap : int, default 0
|
|
169
|
+
Number of multiplier bootstrap iterations (0 = analytical only).
|
|
170
|
+
bootstrap_weights : str, default ``"rademacher"``
|
|
171
|
+
Bootstrap weight distribution.
|
|
172
|
+
seed : int or None
|
|
173
|
+
Random seed for reproducibility.
|
|
174
|
+
anticipation : int, default 0
|
|
175
|
+
Number of anticipation periods (shifts the effective treatment
|
|
176
|
+
boundary forward by this amount).
|
|
177
|
+
sieve_k_max : int or None
|
|
178
|
+
Maximum polynomial degree for sieve ratio estimation. None = auto
|
|
179
|
+
(``min(floor(n_gp^{1/5}), 5)``). Only used with covariates.
|
|
180
|
+
sieve_criterion : str, default ``"bic"``
|
|
181
|
+
Information criterion for sieve degree selection: ``"aic"`` or ``"bic"``.
|
|
182
|
+
ratio_clip : float, default 20.0
|
|
183
|
+
Clip sieve propensity ratios to ``[1/ratio_clip, ratio_clip]``.
|
|
184
|
+
kernel_bandwidth : float or None
|
|
185
|
+
Bandwidth for Gaussian kernel in conditional Omega* estimation.
|
|
186
|
+
None = Silverman's rule-of-thumb (automatic).
|
|
187
|
+
|
|
188
|
+
Examples
|
|
189
|
+
--------
|
|
190
|
+
>>> from diff_diff import EfficientDiD
|
|
191
|
+
>>> edid = EfficientDiD(pt_assumption="all")
|
|
192
|
+
>>> results = edid.fit(data, outcome="y", unit="id", time="t",
|
|
193
|
+
... first_treat="first_treat", aggregate="all")
|
|
194
|
+
>>> results.print_summary()
|
|
195
|
+
"""
|
|
196
|
+
|
|
197
|
+
def __init__(
    self,
    pt_assumption: str = "all",
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    control_group: str = "never_treated",
    n_bootstrap: int = 0,
    bootstrap_weights: str = "rademacher",
    seed: Optional[int] = None,
    anticipation: int = 0,
    sieve_k_max: Optional[int] = None,
    sieve_criterion: str = "bic",
    ratio_clip: float = 20.0,
    kernel_bandwidth: Optional[float] = None,
):
    """Store the estimator settings and validate them immediately.

    See the class docstring for the meaning of each parameter.
    """
    # Identification: parallel-trends variant, comparison group, anticipation.
    self.pt_assumption = pt_assumption
    self.control_group = control_group
    self.anticipation = anticipation

    # Analytical inference.
    self.alpha = alpha
    self.cluster = cluster

    # Multiplier bootstrap.
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed

    # Tuning for the covariate (doubly robust) path.
    self.sieve_k_max = sieve_k_max
    self.sieve_criterion = sieve_criterion
    self.ratio_clip = ratio_clip
    self.kernel_bandwidth = kernel_bandwidth

    # Fit state; populated by ``fit``.
    self.is_fitted_ = False
    self.results_: Optional[EfficientDiDResults] = None
    self._unit_resolved_survey = None

    # Fail fast on an invalid configuration.
    self._validate_params()
|
|
228
|
+
|
|
229
|
+
def _validate_params(self) -> None:
    """Validate the constructor parameters.

    Invoked from ``__init__``, ``set_params`` and ``fit``, so invalid
    settings are reported before any estimation work starts.

    Raises
    ------
    ValueError
        Raised when any of the following holds:

        - a string option is not one of its allowed values;
        - ``ratio_clip`` is not finite and > 1;
        - ``kernel_bandwidth`` is neither None nor finite and > 0;
        - ``sieve_k_max`` is neither None nor a positive integer;
        - ``alpha`` is not in the open interval (0, 1);
        - ``n_bootstrap`` or ``anticipation`` is not a non-negative integer.
    """

    def _is_int(value) -> bool:
        # ``bool`` is a subclass of ``int``. Reject it so that True/False
        # are not silently accepted as 1/0.
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    if self.pt_assumption not in ("all", "post"):
        raise ValueError(f"pt_assumption must be 'all' or 'post', got '{self.pt_assumption}'")
    if self.control_group not in ("never_treated", "last_cohort"):
        raise ValueError(
            f"control_group must be 'never_treated' or 'last_cohort', "
            f"got '{self.control_group}'"
        )
    valid_weights = ("rademacher", "mammen", "webb")
    if self.bootstrap_weights not in valid_weights:
        raise ValueError(
            f"bootstrap_weights must be one of {valid_weights}, "
            f"got '{self.bootstrap_weights}'"
        )
    if self.sieve_criterion not in ("aic", "bic"):
        raise ValueError(
            f"sieve_criterion must be 'aic' or 'bic', got '{self.sieve_criterion}'"
        )
    if not (np.isfinite(self.ratio_clip) and self.ratio_clip > 1.0):
        raise ValueError(f"ratio_clip must be finite and > 1.0, got {self.ratio_clip}")
    if self.kernel_bandwidth is not None:
        if not (np.isfinite(self.kernel_bandwidth) and self.kernel_bandwidth > 0):
            raise ValueError(
                f"kernel_bandwidth must be finite and > 0 (or None for auto), "
                f"got {self.kernel_bandwidth}"
            )
    if self.sieve_k_max is not None:
        if not (_is_int(self.sieve_k_max) and self.sieve_k_max > 0):
            raise ValueError(
                f"sieve_k_max must be a positive integer (or None for auto), "
                f"got {self.sieve_k_max}"
            )
    # alpha is the significance level. Values outside (0, 1), or NaN,
    # cannot define a valid test or interval.
    if not (0.0 < self.alpha < 1.0):
        raise ValueError(f"alpha must be in the open interval (0, 1), got {self.alpha}")
    # 0 means "analytical only"; negative or fractional counts are meaningless.
    if not (_is_int(self.n_bootstrap) and self.n_bootstrap >= 0):
        raise ValueError(f"n_bootstrap must be a non-negative integer, got {self.n_bootstrap}")
    # anticipation is a count of periods subtracted from treatment dates.
    if not (_is_int(self.anticipation) and self.anticipation >= 0):
        raise ValueError(
            f"anticipation must be a non-negative integer, got {self.anticipation}"
        )
|
|
262
|
+
|
|
263
|
+
# -- sklearn compatibility ------------------------------------------------
|
|
264
|
+
|
|
265
|
+
def get_params(self, deep: bool = True) -> Dict[str, Any]:
    """Get estimator parameters (sklearn-compatible).

    Parameters
    ----------
    deep : bool, default True
        Accepted for compatibility with scikit-learn's
        ``get_params(deep=...)`` protocol; ``sklearn.base.clone`` calls
        ``get_params(deep=False)``. This estimator holds no nested
        estimators, so the flag does not change the result.

    Returns
    -------
    dict
        Mapping from each constructor argument name to its current value.
    """
    return {
        "pt_assumption": self.pt_assumption,
        "anticipation": self.anticipation,
        "alpha": self.alpha,
        "cluster": self.cluster,
        "control_group": self.control_group,
        "n_bootstrap": self.n_bootstrap,
        "bootstrap_weights": self.bootstrap_weights,
        "seed": self.seed,
        "sieve_k_max": self.sieve_k_max,
        "sieve_criterion": self.sieve_criterion,
        "ratio_clip": self.ratio_clip,
        "kernel_bandwidth": self.kernel_bandwidth,
    }
|
|
281
|
+
|
|
282
|
+
def set_params(self, **params: Any) -> "EfficientDiD":
    """Set estimator parameters (sklearn-compatible).

    Only names returned by ``get_params()`` are accepted. Methods and
    fitted state such as ``fit`` or ``results_`` cannot be overwritten
    through this entry point. The update is all-or-nothing:
    unknown names are rejected before anything is changed, and if the new
    combination fails ``_validate_params`` the previous values are restored
    before the error propagates.

    Returns
    -------
    EfficientDiD
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If a name is not a constructor parameter, or if the resulting
        configuration is invalid.
    """
    valid_keys = self.get_params()
    for key in params:
        if key not in valid_keys:
            raise ValueError(f"Unknown parameter: {key}")

    # Snapshot the current values so a failed validation can be undone.
    previous = {key: getattr(self, key) for key in params}
    for key, value in params.items():
        setattr(self, key, value)
    try:
        self._validate_params()
    except Exception:
        for key, value in previous.items():
            setattr(self, key, value)
        raise
    return self
|
|
291
|
+
|
|
292
|
+
# -- Main estimation ------------------------------------------------------
|
|
293
|
+
|
|
294
|
+
def fit(
|
|
295
|
+
self,
|
|
296
|
+
data: pd.DataFrame,
|
|
297
|
+
outcome: str,
|
|
298
|
+
unit: str,
|
|
299
|
+
time: str,
|
|
300
|
+
first_treat: str,
|
|
301
|
+
covariates: Optional[List[str]] = None,
|
|
302
|
+
aggregate: Optional[str] = None,
|
|
303
|
+
balance_e: Optional[int] = None,
|
|
304
|
+
survey_design: Optional[Any] = None,
|
|
305
|
+
store_eif: bool = False,
|
|
306
|
+
) -> EfficientDiDResults:
|
|
307
|
+
"""Fit the Efficient DiD estimator.
|
|
308
|
+
|
|
309
|
+
Parameters
|
|
310
|
+
----------
|
|
311
|
+
data : DataFrame
|
|
312
|
+
Balanced panel data.
|
|
313
|
+
outcome : str
|
|
314
|
+
Outcome variable column name.
|
|
315
|
+
unit : str
|
|
316
|
+
Unit identifier column name.
|
|
317
|
+
time : str
|
|
318
|
+
Time period column name.
|
|
319
|
+
first_treat : str
|
|
320
|
+
Column indicating first treatment period.
|
|
321
|
+
Use 0 or ``np.inf`` for never-treated units.
|
|
322
|
+
covariates : list of str, optional
|
|
323
|
+
Column names for time-invariant unit-level covariates.
|
|
324
|
+
When provided, uses the doubly robust path (outcome regression
|
|
325
|
+
+ propensity score ratios).
|
|
326
|
+
aggregate : str, optional
|
|
327
|
+
``None``, ``"simple"``, ``"event_study"``, ``"group"``, or
|
|
328
|
+
``"all"``.
|
|
329
|
+
balance_e : int, optional
|
|
330
|
+
Balance event study at this relative period.
|
|
331
|
+
survey_design : SurveyDesign, optional
|
|
332
|
+
Survey design specification for design-based inference.
|
|
333
|
+
Applies survey weights to all means, covariances, and cohort
|
|
334
|
+
fractions, and uses Taylor Series Linearization for SE
|
|
335
|
+
estimation. Cannot be combined with ``cluster``.
|
|
336
|
+
store_eif : bool, default False
|
|
337
|
+
Store per-(g,t) EIF vectors in the results object. Used
|
|
338
|
+
internally by :meth:`hausman_pretest`; not needed for
|
|
339
|
+
normal usage.
|
|
340
|
+
|
|
341
|
+
Returns
|
|
342
|
+
-------
|
|
343
|
+
EfficientDiDResults
|
|
344
|
+
|
|
345
|
+
Raises
|
|
346
|
+
------
|
|
347
|
+
ValueError
|
|
348
|
+
Missing columns, unbalanced panel, non-absorbing treatment,
|
|
349
|
+
or PT-Post without a never-treated group.
|
|
350
|
+
"""
|
|
351
|
+
self._validate_params()
|
|
352
|
+
|
|
353
|
+
if self.cluster is not None and survey_design is not None:
|
|
354
|
+
raise NotImplementedError(
|
|
355
|
+
"cluster and survey_design cannot both be set. "
|
|
356
|
+
"Use survey_design with PSU/strata for cluster-robust inference."
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
# Resolve survey design if provided
|
|
360
|
+
from diff_diff.survey import _resolve_survey_for_fit
|
|
361
|
+
|
|
362
|
+
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
|
|
363
|
+
_resolve_survey_for_fit(survey_design, data, "analytical")
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# Validate within-unit constancy for panel survey designs
|
|
367
|
+
if resolved_survey is not None:
|
|
368
|
+
from diff_diff.survey import _validate_unit_constant_survey
|
|
369
|
+
|
|
370
|
+
_validate_unit_constant_survey(data, unit, survey_design)
|
|
371
|
+
|
|
372
|
+
# Store survey df for safe_inference calls (t-distribution with survey df)
|
|
373
|
+
self._survey_df = survey_metadata.df_survey if survey_metadata is not None else None
|
|
374
|
+
# Guard: replicate design with undefined df → NaN inference
|
|
375
|
+
if (self._survey_df is None and resolved_survey is not None
|
|
376
|
+
and hasattr(resolved_survey, 'uses_replicate_variance')
|
|
377
|
+
and resolved_survey.uses_replicate_variance):
|
|
378
|
+
self._survey_df = 0
|
|
379
|
+
|
|
380
|
+
# Bootstrap + survey supported via PSU-level multiplier bootstrap.
|
|
381
|
+
|
|
382
|
+
# Normalize empty covariates list to None (use nocov path)
|
|
383
|
+
if covariates is not None and len(covariates) == 0:
|
|
384
|
+
covariates = None
|
|
385
|
+
use_covariates = covariates is not None
|
|
386
|
+
|
|
387
|
+
# ----- Validate inputs -----
|
|
388
|
+
required_cols = [outcome, unit, time, first_treat]
|
|
389
|
+
missing = [c for c in required_cols if c not in data.columns]
|
|
390
|
+
if missing:
|
|
391
|
+
raise ValueError(f"Missing columns: {missing}")
|
|
392
|
+
|
|
393
|
+
df = data.copy()
|
|
394
|
+
df[time] = pd.to_numeric(df[time])
|
|
395
|
+
df[first_treat] = pd.to_numeric(df[first_treat])
|
|
396
|
+
|
|
397
|
+
# Normalize never-treated: inf -> 0 internally, keep track
|
|
398
|
+
df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
|
|
399
|
+
df.loc[df[first_treat] == np.inf, first_treat] = 0
|
|
400
|
+
|
|
401
|
+
time_periods = sorted(df[time].unique())
|
|
402
|
+
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
|
|
403
|
+
|
|
404
|
+
# Validate balanced panel
|
|
405
|
+
unit_period_counts = df.groupby(unit)[time].nunique()
|
|
406
|
+
n_periods = len(time_periods)
|
|
407
|
+
if (unit_period_counts != n_periods).any():
|
|
408
|
+
raise ValueError(
|
|
409
|
+
"Unbalanced panel detected. EfficientDiD requires a balanced "
|
|
410
|
+
"panel where every unit is observed in every time period."
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
# Reject non-finite outcomes (NaN/Inf corrupt Omega*/EIF calculations)
|
|
414
|
+
non_finite_mask = ~np.isfinite(df[outcome])
|
|
415
|
+
if non_finite_mask.any():
|
|
416
|
+
n_bad = int(non_finite_mask.sum())
|
|
417
|
+
raise ValueError(
|
|
418
|
+
f"Found {n_bad} non-finite value(s) in outcome column '{outcome}'. "
|
|
419
|
+
"EfficientDiD requires finite outcomes for all unit-period observations."
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
# Reject duplicate (unit, time) rows
|
|
423
|
+
dup_mask = df.duplicated(subset=[unit, time], keep=False)
|
|
424
|
+
if dup_mask.any():
|
|
425
|
+
n_dups = int(dup_mask.sum())
|
|
426
|
+
raise ValueError(
|
|
427
|
+
f"Found {n_dups} duplicate ({unit}, {time}) rows. "
|
|
428
|
+
"EfficientDiD requires exactly one observation per unit-period."
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
# Validate absorbing treatment (vectorized)
|
|
432
|
+
ft_nunique = df.groupby(unit)[first_treat].nunique()
|
|
433
|
+
bad_units = ft_nunique[ft_nunique > 1]
|
|
434
|
+
if len(bad_units) > 0:
|
|
435
|
+
uid = bad_units.index[0]
|
|
436
|
+
raise ValueError(
|
|
437
|
+
f"Non-absorbing treatment detected for unit {uid}: "
|
|
438
|
+
"first_treat value changes over time."
|
|
439
|
+
)
|
|
440
|
+
|
|
441
|
+
# Unit info
|
|
442
|
+
unit_info = (
|
|
443
|
+
df.groupby(unit)
|
|
444
|
+
.agg(
|
|
445
|
+
{
|
|
446
|
+
first_treat: "first",
|
|
447
|
+
"_never_treated": "first",
|
|
448
|
+
}
|
|
449
|
+
)
|
|
450
|
+
.reset_index()
|
|
451
|
+
)
|
|
452
|
+
n_treated_units = int((unit_info[first_treat] > 0).sum())
|
|
453
|
+
n_control_units = int(unit_info["_never_treated"].sum())
|
|
454
|
+
|
|
455
|
+
# Control group logic
|
|
456
|
+
if self.control_group == "last_cohort":
|
|
457
|
+
# Always reclassify last cohort as pseudo-control when requested
|
|
458
|
+
if not treatment_groups:
|
|
459
|
+
raise ValueError(
|
|
460
|
+
"No treated cohorts found. control_group='last_cohort' requires "
|
|
461
|
+
"at least 2 treatment cohorts."
|
|
462
|
+
)
|
|
463
|
+
last_g = max(treatment_groups)
|
|
464
|
+
treatment_groups = [g for g in treatment_groups if g != last_g]
|
|
465
|
+
if not treatment_groups:
|
|
466
|
+
raise ValueError("Only one treatment cohort; cannot use last_cohort control.")
|
|
467
|
+
effective_last = last_g - self.anticipation
|
|
468
|
+
time_periods = [t for t in time_periods if t < effective_last]
|
|
469
|
+
if len(time_periods) < 2:
|
|
470
|
+
raise ValueError(
|
|
471
|
+
"Fewer than 2 time periods remain after trimming for last_cohort control."
|
|
472
|
+
)
|
|
473
|
+
unit_info.loc[unit_info[first_treat] == last_g, first_treat] = 0
|
|
474
|
+
unit_info.loc[unit_info[first_treat] == 0, "_never_treated"] = True
|
|
475
|
+
n_treated_units = int((unit_info[first_treat] > 0).sum())
|
|
476
|
+
n_control_units = int(unit_info["_never_treated"].sum())
|
|
477
|
+
elif n_control_units == 0:
|
|
478
|
+
raise ValueError(
|
|
479
|
+
"No never-treated units found. Use control_group='last_cohort' "
|
|
480
|
+
"to use the last treatment cohort as a pseudo-control."
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
# ----- Prepare data -----
|
|
484
|
+
all_units = sorted(df[unit].unique())
|
|
485
|
+
n_units = len(all_units)
|
|
486
|
+
|
|
487
|
+
# Build unit-to-first-panel-row index aligned to all_units (sorted)
|
|
488
|
+
# order. The previous approach (groupby cumcount == 0) yielded
|
|
489
|
+
# first-appearance order which can differ from sorted order when the
|
|
490
|
+
# input DataFrame is not pre-sorted by unit.
|
|
491
|
+
first_pos: Dict[Any, int] = {}
|
|
492
|
+
for i, u in enumerate(df[unit].values):
|
|
493
|
+
if u not in first_pos:
|
|
494
|
+
first_pos[u] = i
|
|
495
|
+
self._unit_first_panel_row = np.array([first_pos[u] for u in all_units])
|
|
496
|
+
|
|
497
|
+
# Build unit-level ResolvedSurveyDesign once (avoids repeated
|
|
498
|
+
# construction in _compute_survey_eif_se and ensures consistent
|
|
499
|
+
# unit-level df for safe_inference t-distribution).
|
|
500
|
+
if resolved_survey is not None:
|
|
501
|
+
from diff_diff.survey import ResolvedSurveyDesign
|
|
502
|
+
|
|
503
|
+
row_idx = self._unit_first_panel_row
|
|
504
|
+
unit_weights_s = resolved_survey.weights[row_idx]
|
|
505
|
+
unit_strata = (
|
|
506
|
+
resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
|
|
507
|
+
)
|
|
508
|
+
unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
|
|
509
|
+
unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
|
|
510
|
+
n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
|
|
511
|
+
n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
|
|
512
|
+
self._unit_resolved_survey = resolved_survey.subset_to_units(
|
|
513
|
+
row_idx, unit_weights_s, unit_strata, unit_psu, unit_fpc,
|
|
514
|
+
n_strata_u, n_psu_u,
|
|
515
|
+
)
|
|
516
|
+
# Use unit-level df (not panel-level) for t-distribution
|
|
517
|
+
self._survey_df = self._unit_resolved_survey.df_survey
|
|
518
|
+
# Re-apply replicate guard: undefined df → NaN inference
|
|
519
|
+
if (self._survey_df is None
|
|
520
|
+
and self._unit_resolved_survey.uses_replicate_variance):
|
|
521
|
+
self._survey_df = 0
|
|
522
|
+
else:
|
|
523
|
+
self._unit_resolved_survey = None
|
|
524
|
+
|
|
525
|
+
# Build cluster mapping if cluster-robust SEs requested
|
|
526
|
+
if self.cluster is not None:
|
|
527
|
+
unit_cluster_indices, n_clusters = _validate_and_build_cluster_mapping(
|
|
528
|
+
df, unit, self.cluster, all_units
|
|
529
|
+
)
|
|
530
|
+
if n_clusters < 50:
|
|
531
|
+
warnings.warn(
|
|
532
|
+
f"Only {n_clusters} clusters. Analytical clustered SEs may "
|
|
533
|
+
"be unreliable. Consider n_bootstrap > 0 for cluster "
|
|
534
|
+
"bootstrap inference.",
|
|
535
|
+
UserWarning,
|
|
536
|
+
stacklevel=2,
|
|
537
|
+
)
|
|
538
|
+
else:
|
|
539
|
+
unit_cluster_indices = None
|
|
540
|
+
n_clusters = None
|
|
541
|
+
|
|
542
|
+
period_to_col = {p: i for i, p in enumerate(time_periods)}
|
|
543
|
+
period_1 = time_periods[0]
|
|
544
|
+
period_1_col = period_to_col[period_1]
|
|
545
|
+
|
|
546
|
+
# Pivot outcome to wide matrix (n_units, n_periods)
|
|
547
|
+
pivot = df.pivot(index=unit, columns=time, values=outcome)
|
|
548
|
+
# Reindex to match all_units ordering and time_periods column order
|
|
549
|
+
pivot = pivot.reindex(index=all_units, columns=time_periods)
|
|
550
|
+
outcome_wide = pivot.values.astype(float)
|
|
551
|
+
|
|
552
|
+
# Build cohort masks and fractions
|
|
553
|
+
unit_info_indexed = unit_info.set_index(unit)
|
|
554
|
+
unit_cohorts = unit_info_indexed.reindex(all_units)[first_treat].values.astype(
|
|
555
|
+
float
|
|
556
|
+
) # 0 = never-treated
|
|
557
|
+
|
|
558
|
+
cohort_masks: Dict[float, np.ndarray] = {}
|
|
559
|
+
for g in treatment_groups:
|
|
560
|
+
cohort_masks[g] = unit_cohorts == g
|
|
561
|
+
never_treated_mask = unit_cohorts == 0
|
|
562
|
+
cohort_masks[np.inf] = never_treated_mask # also keyed by inf sentinel
|
|
563
|
+
|
|
564
|
+
# ----- Unit-level survey weights -----
|
|
565
|
+
# Survey weights in the panel are at obs level (unit x time).
|
|
566
|
+
# EfficientDiD works at unit level. Extract one weight per unit
|
|
567
|
+
# by taking the first observation per unit (balanced panel, so
|
|
568
|
+
# weights should be constant within unit).
|
|
569
|
+
unit_level_weights: Optional[np.ndarray] = None
|
|
570
|
+
if resolved_survey is not None:
|
|
571
|
+
# Use the resolved survey's weights (already normalized per weight_type)
|
|
572
|
+
# subset to unit level via _unit_first_panel_row (aligned to all_units)
|
|
573
|
+
unit_level_weights = self._unit_resolved_survey.weights
|
|
574
|
+
self._unit_level_weights = unit_level_weights
|
|
575
|
+
|
|
576
|
+
cohort_fractions: Dict[float, float] = {}
|
|
577
|
+
if unit_level_weights is not None:
|
|
578
|
+
# Survey-weighted cohort fractions: sum(w_i for i in cohort) / sum(w_i)
|
|
579
|
+
total_w = float(np.sum(unit_level_weights))
|
|
580
|
+
for g in treatment_groups:
|
|
581
|
+
cohort_fractions[g] = float(np.sum(unit_level_weights[cohort_masks[g]])) / total_w
|
|
582
|
+
cohort_fractions[np.inf] = (
|
|
583
|
+
float(np.sum(unit_level_weights[never_treated_mask])) / total_w
|
|
584
|
+
)
|
|
585
|
+
else:
|
|
586
|
+
for g in treatment_groups:
|
|
587
|
+
cohort_fractions[g] = float(np.sum(cohort_masks[g])) / n_units
|
|
588
|
+
cohort_fractions[np.inf] = float(np.sum(never_treated_mask)) / n_units
|
|
589
|
+
|
|
590
|
+
# ----- Small cohort warnings -----
|
|
591
|
+
for g in treatment_groups:
|
|
592
|
+
n_g = int(np.sum(cohort_masks[g]))
|
|
593
|
+
frac_g = cohort_fractions[g]
|
|
594
|
+
if n_g < 2:
|
|
595
|
+
warnings.warn(
|
|
596
|
+
f"Cohort {g} has only {n_g} unit. Omega* inversion and "
|
|
597
|
+
"EIF computation may be numerically unstable.",
|
|
598
|
+
UserWarning,
|
|
599
|
+
stacklevel=2,
|
|
600
|
+
)
|
|
601
|
+
elif frac_g < 0.01:
|
|
602
|
+
warnings.warn(
|
|
603
|
+
f"Cohort {g} represents {frac_g:.1%} of the sample (< 1%). "
|
|
604
|
+
"Efficient weights may be imprecise.",
|
|
605
|
+
UserWarning,
|
|
606
|
+
stacklevel=2,
|
|
607
|
+
)
|
|
608
|
+
|
|
609
|
+
# Guard: never-treated with zero survey weight → no valid comparisons
|
|
610
|
+
# Applies to both covariates (DR nuisance) and nocov (weighted means) paths
|
|
611
|
+
if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
|
|
612
|
+
raise ValueError(
|
|
613
|
+
"Never-treated group has zero survey weight. EfficientDiD "
|
|
614
|
+
"requires a never-treated control group with positive "
|
|
615
|
+
"survey weight for estimation."
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
# ----- Covariate preparation (if provided) -----
|
|
619
|
+
covariate_matrix: Optional[np.ndarray] = None
|
|
620
|
+
m_hat_cache: Dict[Tuple, np.ndarray] = {}
|
|
621
|
+
r_hat_cache: Dict[Tuple[float, float], np.ndarray] = {}
|
|
622
|
+
s_hat_cache: Dict[float, np.ndarray] = {} # inverse propensities per group
|
|
623
|
+
|
|
624
|
+
if use_covariates:
|
|
625
|
+
assert covariates is not None # for type narrowing
|
|
626
|
+
|
|
627
|
+
# Validate covariate columns exist
|
|
628
|
+
missing_cov = [c for c in covariates if c not in data.columns]
|
|
629
|
+
if missing_cov:
|
|
630
|
+
raise ValueError(f"Missing covariate columns: {missing_cov}")
|
|
631
|
+
|
|
632
|
+
# Validate no NaN/Inf in covariates
|
|
633
|
+
for col_name in covariates:
|
|
634
|
+
non_finite_cov = ~np.isfinite(pd.to_numeric(df[col_name], errors="coerce"))
|
|
635
|
+
if non_finite_cov.any():
|
|
636
|
+
n_bad = int(non_finite_cov.sum())
|
|
637
|
+
raise ValueError(
|
|
638
|
+
f"Found {n_bad} non-finite value(s) in covariate column "
|
|
639
|
+
f"'{col_name}'. Covariates must be finite."
|
|
640
|
+
)
|
|
641
|
+
|
|
642
|
+
# Validate time-invariance: covariates must be constant within each unit
|
|
643
|
+
for col_name in covariates:
|
|
644
|
+
cov_nunique = df.groupby(unit)[col_name].nunique()
|
|
645
|
+
varying = cov_nunique[cov_nunique > 1]
|
|
646
|
+
if len(varying) > 0:
|
|
647
|
+
uid = varying.index[0]
|
|
648
|
+
raise ValueError(
|
|
649
|
+
f"Covariate '{col_name}' varies over time for unit {uid}. "
|
|
650
|
+
"EfficientDiD requires time-invariant covariates. "
|
|
651
|
+
"Extract base-period values before calling fit()."
|
|
652
|
+
)
|
|
653
|
+
|
|
654
|
+
# Extract unit-level covariate matrix from period_1 observations
|
|
655
|
+
base_df = df[df[time] == period_1].set_index(unit).reindex(all_units)
|
|
656
|
+
covariate_matrix = base_df[list(covariates)].values.astype(float)
|
|
657
|
+
|
|
658
|
+
# ----- Core estimation: ATT(g, t) for each target -----
|
|
659
|
+
# Precompute per-group unit counts (avoid repeated np.sum in loop)
|
|
660
|
+
n_treated_per_g = {g: int(np.sum(cohort_masks[g])) for g in treatment_groups}
|
|
661
|
+
n_control_count = int(np.sum(never_treated_mask))
|
|
662
|
+
|
|
663
|
+
group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] = {}
|
|
664
|
+
eif_by_gt: Dict[Tuple[Any, Any], np.ndarray] = {}
|
|
665
|
+
stored_weights: Dict[Tuple[Any, Any], np.ndarray] = {}
|
|
666
|
+
stored_cond: Dict[Tuple[Any, Any], float] = {}
|
|
667
|
+
|
|
668
|
+
for g in treatment_groups:
|
|
669
|
+
# Under PT-Post, use per-group baseline Y_{g-1-anticipation}
|
|
670
|
+
# instead of the universal Y_1. This implements the weaker
|
|
671
|
+
# PT-Post assumption (parallel trends only from g-1 onward),
|
|
672
|
+
# matching the Callaway-Sant'Anna estimator exactly.
|
|
673
|
+
if self.pt_assumption == "post":
|
|
674
|
+
effective_base = g - 1 - self.anticipation
|
|
675
|
+
if effective_base not in period_to_col:
|
|
676
|
+
warnings.warn(
|
|
677
|
+
f"Cohort g={g} dropped: baseline period {effective_base} "
|
|
678
|
+
f"(g-1-anticipation) is not in the data.",
|
|
679
|
+
UserWarning,
|
|
680
|
+
stacklevel=2,
|
|
681
|
+
)
|
|
682
|
+
continue
|
|
683
|
+
effective_p1_col = period_to_col[effective_base]
|
|
684
|
+
else:
|
|
685
|
+
effective_p1_col = period_1_col
|
|
686
|
+
|
|
687
|
+
# Guard: skip cohorts with zero survey weight (all units zero-weighted)
|
|
688
|
+
if cohort_fractions[g] <= 0:
|
|
689
|
+
warnings.warn(
|
|
690
|
+
f"Cohort {g} has zero survey weight; skipping.",
|
|
691
|
+
UserWarning,
|
|
692
|
+
stacklevel=2,
|
|
693
|
+
)
|
|
694
|
+
continue
|
|
695
|
+
|
|
696
|
+
# Estimate all (g, t) cells including pre-treatment. Under PT-Post,
|
|
697
|
+
# pre-treatment cells serve as placebo/pre-trend diagnostics, matching
|
|
698
|
+
# the CallawaySantAnna implementation. Users filter to t >= g for
|
|
699
|
+
# post-treatment effects; pre-treatment cells are clearly labeled by
|
|
700
|
+
# their (g, t) coordinates in the results object.
|
|
701
|
+
for t in time_periods:
|
|
702
|
+
# Skip period_1 — it's the universal reference baseline,
|
|
703
|
+
# not a target period
|
|
704
|
+
if t == period_1:
|
|
705
|
+
continue
|
|
706
|
+
|
|
707
|
+
# Enumerate valid comparison pairs
|
|
708
|
+
pairs = enumerate_valid_triples(
|
|
709
|
+
target_g=g,
|
|
710
|
+
treatment_groups=treatment_groups,
|
|
711
|
+
time_periods=time_periods,
|
|
712
|
+
period_1=period_1,
|
|
713
|
+
pt_assumption=self.pt_assumption,
|
|
714
|
+
anticipation=self.anticipation,
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
# Filter out comparison pairs with zero survey weight
|
|
718
|
+
if unit_level_weights is not None and pairs:
|
|
719
|
+
pairs = [
|
|
720
|
+
(gp, tpre) for gp, tpre in pairs
|
|
721
|
+
if np.sum(unit_level_weights[
|
|
722
|
+
never_treated_mask if np.isinf(gp) else cohort_masks[gp]
|
|
723
|
+
]) > 0
|
|
724
|
+
]
|
|
725
|
+
|
|
726
|
+
if not pairs:
|
|
727
|
+
warnings.warn(
|
|
728
|
+
f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
|
|
729
|
+
UserWarning,
|
|
730
|
+
stacklevel=2,
|
|
731
|
+
)
|
|
732
|
+
t_stat, p_val, ci = np.nan, np.nan, (np.nan, np.nan)
|
|
733
|
+
group_time_effects[(g, t)] = {
|
|
734
|
+
"effect": np.nan,
|
|
735
|
+
"se": np.nan,
|
|
736
|
+
"t_stat": t_stat,
|
|
737
|
+
"p_value": p_val,
|
|
738
|
+
"conf_int": ci,
|
|
739
|
+
"n_treated": n_treated_per_g[g],
|
|
740
|
+
"n_control": n_control_count,
|
|
741
|
+
}
|
|
742
|
+
eif_by_gt[(g, t)] = np.zeros(n_units)
|
|
743
|
+
continue
|
|
744
|
+
|
|
745
|
+
if use_covariates:
|
|
746
|
+
assert covariate_matrix is not None
|
|
747
|
+
t_col_val = period_to_col[t]
|
|
748
|
+
|
|
749
|
+
# Lazily populate nuisance caches for this (g, t)
|
|
750
|
+
for gp, tpre in pairs:
|
|
751
|
+
tpre_col_val = period_to_col[tpre]
|
|
752
|
+
# m_{inf, t, tpre}(X)
|
|
753
|
+
key_inf_t = (np.inf, t_col_val, tpre_col_val)
|
|
754
|
+
if key_inf_t not in m_hat_cache:
|
|
755
|
+
m_hat_cache[key_inf_t] = estimate_outcome_regression(
|
|
756
|
+
outcome_wide,
|
|
757
|
+
covariate_matrix,
|
|
758
|
+
never_treated_mask,
|
|
759
|
+
t_col_val,
|
|
760
|
+
tpre_col_val,
|
|
761
|
+
unit_weights=unit_level_weights,
|
|
762
|
+
)
|
|
763
|
+
# m_{g', tpre, 1}(X)
|
|
764
|
+
key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
|
|
765
|
+
if key_gp_tpre not in m_hat_cache:
|
|
766
|
+
gp_mask_for_reg = (
|
|
767
|
+
never_treated_mask if np.isinf(gp) else cohort_masks[gp]
|
|
768
|
+
)
|
|
769
|
+
m_hat_cache[key_gp_tpre] = estimate_outcome_regression(
|
|
770
|
+
outcome_wide,
|
|
771
|
+
covariate_matrix,
|
|
772
|
+
gp_mask_for_reg,
|
|
773
|
+
tpre_col_val,
|
|
774
|
+
effective_p1_col,
|
|
775
|
+
unit_weights=unit_level_weights,
|
|
776
|
+
)
|
|
777
|
+
# r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
|
|
778
|
+
for comp in {np.inf, gp}:
|
|
779
|
+
rkey = (g, comp)
|
|
780
|
+
if rkey not in r_hat_cache:
|
|
781
|
+
comp_mask = (
|
|
782
|
+
never_treated_mask if np.isinf(comp) else cohort_masks[comp]
|
|
783
|
+
)
|
|
784
|
+
r_hat_cache[rkey] = estimate_propensity_ratio_sieve(
|
|
785
|
+
covariate_matrix,
|
|
786
|
+
cohort_masks[g],
|
|
787
|
+
comp_mask,
|
|
788
|
+
k_max=self.sieve_k_max,
|
|
789
|
+
criterion=self.sieve_criterion,
|
|
790
|
+
ratio_clip=self.ratio_clip,
|
|
791
|
+
unit_weights=unit_level_weights,
|
|
792
|
+
)
|
|
793
|
+
|
|
794
|
+
# Per-unit DR generated outcomes: shape (n_units, H)
|
|
795
|
+
gen_out = compute_generated_outcomes_cov(
|
|
796
|
+
target_g=g,
|
|
797
|
+
target_t=t,
|
|
798
|
+
valid_pairs=pairs,
|
|
799
|
+
outcome_wide=outcome_wide,
|
|
800
|
+
cohort_masks=cohort_masks,
|
|
801
|
+
never_treated_mask=never_treated_mask,
|
|
802
|
+
period_to_col=period_to_col,
|
|
803
|
+
period_1_col=effective_p1_col,
|
|
804
|
+
cohort_fractions=cohort_fractions,
|
|
805
|
+
m_hat_cache=m_hat_cache,
|
|
806
|
+
r_hat_cache=r_hat_cache,
|
|
807
|
+
)
|
|
808
|
+
|
|
809
|
+
y_hat = np.mean(gen_out, axis=0) # shape (H,)
|
|
810
|
+
|
|
811
|
+
# Inverse propensity estimation (algorithm step 4)
|
|
812
|
+
# s_hat_{g'}(X) = 1/p_{g'}(X) for Eq 3.12 scaling
|
|
813
|
+
for group_id in {g, np.inf} | {gp for gp, _ in pairs}:
|
|
814
|
+
if group_id not in s_hat_cache:
|
|
815
|
+
group_mask_s = (
|
|
816
|
+
never_treated_mask if np.isinf(group_id) else cohort_masks[group_id]
|
|
817
|
+
)
|
|
818
|
+
s_hat_cache[group_id] = estimate_inverse_propensity_sieve(
|
|
819
|
+
covariate_matrix,
|
|
820
|
+
group_mask_s,
|
|
821
|
+
k_max=self.sieve_k_max,
|
|
822
|
+
criterion=self.sieve_criterion,
|
|
823
|
+
unit_weights=unit_level_weights,
|
|
824
|
+
)
|
|
825
|
+
|
|
826
|
+
# Conditional Omega*(X) with per-unit propensities (Eq 3.12)
|
|
827
|
+
omega_cond = compute_omega_star_conditional(
|
|
828
|
+
target_g=g,
|
|
829
|
+
target_t=t,
|
|
830
|
+
valid_pairs=pairs,
|
|
831
|
+
outcome_wide=outcome_wide,
|
|
832
|
+
cohort_masks=cohort_masks,
|
|
833
|
+
never_treated_mask=never_treated_mask,
|
|
834
|
+
period_to_col=period_to_col,
|
|
835
|
+
period_1_col=effective_p1_col,
|
|
836
|
+
cohort_fractions=cohort_fractions,
|
|
837
|
+
covariate_matrix=covariate_matrix,
|
|
838
|
+
s_hat_cache=s_hat_cache,
|
|
839
|
+
bandwidth=self.kernel_bandwidth,
|
|
840
|
+
unit_weights=unit_level_weights,
|
|
841
|
+
)
|
|
842
|
+
|
|
843
|
+
# Per-unit weights: (n_units, H)
|
|
844
|
+
per_unit_w = compute_per_unit_weights(omega_cond)
|
|
845
|
+
|
|
846
|
+
# ATT = (survey-)weighted mean of per-unit DR scores
|
|
847
|
+
if per_unit_w.shape[1] > 0:
|
|
848
|
+
per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
|
|
849
|
+
if unit_level_weights is not None:
|
|
850
|
+
att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
|
|
851
|
+
else:
|
|
852
|
+
att_gt = float(np.mean(per_unit_scores))
|
|
853
|
+
else:
|
|
854
|
+
att_gt = np.nan
|
|
855
|
+
|
|
856
|
+
# EIF with per-unit weights (Remark 4.2: plug-in valid)
|
|
857
|
+
# Center on scalar ATT, not per-pair means (ensures mean(EIF) ≈ 0)
|
|
858
|
+
eif_vals = compute_eif_cov(per_unit_w, gen_out, att_gt, n_units)
|
|
859
|
+
eif_by_gt[(g, t)] = eif_vals
|
|
860
|
+
else:
|
|
861
|
+
# No-covariates path (closed-form)
|
|
862
|
+
omega = compute_omega_star_nocov(
|
|
863
|
+
target_g=g,
|
|
864
|
+
target_t=t,
|
|
865
|
+
valid_pairs=pairs,
|
|
866
|
+
outcome_wide=outcome_wide,
|
|
867
|
+
cohort_masks=cohort_masks,
|
|
868
|
+
never_treated_mask=never_treated_mask,
|
|
869
|
+
period_to_col=period_to_col,
|
|
870
|
+
period_1_col=effective_p1_col,
|
|
871
|
+
cohort_fractions=cohort_fractions,
|
|
872
|
+
unit_weights=unit_level_weights,
|
|
873
|
+
)
|
|
874
|
+
|
|
875
|
+
weights, _, cond_num = compute_efficient_weights(omega)
|
|
876
|
+
stored_weights[(g, t)] = weights
|
|
877
|
+
if omega.size > 0:
|
|
878
|
+
stored_cond[(g, t)] = cond_num
|
|
879
|
+
|
|
880
|
+
y_hat = compute_generated_outcomes_nocov(
|
|
881
|
+
target_g=g,
|
|
882
|
+
target_t=t,
|
|
883
|
+
valid_pairs=pairs,
|
|
884
|
+
outcome_wide=outcome_wide,
|
|
885
|
+
cohort_masks=cohort_masks,
|
|
886
|
+
never_treated_mask=never_treated_mask,
|
|
887
|
+
period_to_col=period_to_col,
|
|
888
|
+
period_1_col=effective_p1_col,
|
|
889
|
+
unit_weights=unit_level_weights,
|
|
890
|
+
)
|
|
891
|
+
|
|
892
|
+
att_gt = float(weights @ y_hat) if len(weights) > 0 else np.nan
|
|
893
|
+
|
|
894
|
+
eif_vals = compute_eif_nocov(
|
|
895
|
+
target_g=g,
|
|
896
|
+
target_t=t,
|
|
897
|
+
weights=weights,
|
|
898
|
+
valid_pairs=pairs,
|
|
899
|
+
outcome_wide=outcome_wide,
|
|
900
|
+
cohort_masks=cohort_masks,
|
|
901
|
+
never_treated_mask=never_treated_mask,
|
|
902
|
+
period_to_col=period_to_col,
|
|
903
|
+
period_1_col=effective_p1_col,
|
|
904
|
+
cohort_fractions=cohort_fractions,
|
|
905
|
+
n_units=n_units,
|
|
906
|
+
unit_weights=unit_level_weights,
|
|
907
|
+
)
|
|
908
|
+
eif_by_gt[(g, t)] = eif_vals
|
|
909
|
+
|
|
910
|
+
# Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
|
|
911
|
+
# With survey: use TSL variance via compute_survey_vcov
|
|
912
|
+
if self._unit_resolved_survey is not None:
|
|
913
|
+
se_gt = self._compute_survey_eif_se(eif_vals)
|
|
914
|
+
else:
|
|
915
|
+
se_gt = _compute_se_from_eif(
|
|
916
|
+
eif_vals, n_units, unit_cluster_indices, n_clusters
|
|
917
|
+
)
|
|
918
|
+
|
|
919
|
+
t_stat, p_val, ci = safe_inference(
|
|
920
|
+
att_gt, se_gt, alpha=self.alpha, df=self._survey_df
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
group_time_effects[(g, t)] = {
|
|
924
|
+
"effect": att_gt,
|
|
925
|
+
"se": se_gt,
|
|
926
|
+
"t_stat": t_stat,
|
|
927
|
+
"p_value": p_val,
|
|
928
|
+
"conf_int": ci,
|
|
929
|
+
"n_treated": int(np.sum(cohort_masks[g])),
|
|
930
|
+
"n_control": int(np.sum(never_treated_mask)),
|
|
931
|
+
}
|
|
932
|
+
|
|
933
|
+
if not group_time_effects:
|
|
934
|
+
raise ValueError(
|
|
935
|
+
"Could not estimate any group-time effects. "
|
|
936
|
+
"Check data has sufficient observations."
|
|
937
|
+
)
|
|
938
|
+
|
|
939
|
+
# ----- Aggregation -----
|
|
940
|
+
overall_att, overall_se = self._aggregate_overall(
|
|
941
|
+
group_time_effects,
|
|
942
|
+
eif_by_gt,
|
|
943
|
+
n_units,
|
|
944
|
+
cohort_fractions,
|
|
945
|
+
unit_cohorts,
|
|
946
|
+
cluster_indices=unit_cluster_indices,
|
|
947
|
+
n_clusters=n_clusters,
|
|
948
|
+
)
|
|
949
|
+
overall_t, overall_p, overall_ci = safe_inference(
|
|
950
|
+
overall_att, overall_se, alpha=self.alpha, df=self._survey_df
|
|
951
|
+
)
|
|
952
|
+
|
|
953
|
+
event_study_effects = None
|
|
954
|
+
group_effects = None
|
|
955
|
+
|
|
956
|
+
if aggregate in ("event_study", "all"):
|
|
957
|
+
event_study_effects = self._aggregate_event_study(
|
|
958
|
+
group_time_effects,
|
|
959
|
+
eif_by_gt,
|
|
960
|
+
n_units,
|
|
961
|
+
cohort_fractions,
|
|
962
|
+
treatment_groups,
|
|
963
|
+
time_periods,
|
|
964
|
+
balance_e,
|
|
965
|
+
unit_cohorts=unit_cohorts,
|
|
966
|
+
cluster_indices=unit_cluster_indices,
|
|
967
|
+
n_clusters=n_clusters,
|
|
968
|
+
)
|
|
969
|
+
if aggregate in ("group", "all"):
|
|
970
|
+
group_effects = self._aggregate_by_group(
|
|
971
|
+
group_time_effects,
|
|
972
|
+
eif_by_gt,
|
|
973
|
+
n_units,
|
|
974
|
+
cohort_fractions,
|
|
975
|
+
treatment_groups,
|
|
976
|
+
unit_cohorts=unit_cohorts,
|
|
977
|
+
cluster_indices=unit_cluster_indices,
|
|
978
|
+
n_clusters=n_clusters,
|
|
979
|
+
)
|
|
980
|
+
|
|
981
|
+
# ----- Bootstrap -----
|
|
982
|
+
# Reject replicate-weight designs for bootstrap — replicate variance
|
|
983
|
+
# is an analytical alternative, not compatible with bootstrap
|
|
984
|
+
if (
|
|
985
|
+
self.n_bootstrap > 0
|
|
986
|
+
and self._unit_resolved_survey is not None
|
|
987
|
+
and self._unit_resolved_survey.uses_replicate_variance
|
|
988
|
+
):
|
|
989
|
+
raise NotImplementedError(
|
|
990
|
+
"EfficientDiD bootstrap (n_bootstrap > 0) is not supported "
|
|
991
|
+
"with replicate-weight survey designs. Replicate weights provide "
|
|
992
|
+
"analytical variance; use n_bootstrap=0 instead."
|
|
993
|
+
)
|
|
994
|
+
bootstrap_results = None
|
|
995
|
+
if self.n_bootstrap > 0 and eif_by_gt:
|
|
996
|
+
bootstrap_results = self._run_multiplier_bootstrap(
|
|
997
|
+
group_time_effects=group_time_effects,
|
|
998
|
+
eif_by_gt=eif_by_gt,
|
|
999
|
+
n_units=n_units,
|
|
1000
|
+
aggregate=aggregate,
|
|
1001
|
+
balance_e=balance_e,
|
|
1002
|
+
treatment_groups=treatment_groups,
|
|
1003
|
+
cohort_fractions=cohort_fractions,
|
|
1004
|
+
cluster_indices=unit_cluster_indices,
|
|
1005
|
+
n_clusters=n_clusters,
|
|
1006
|
+
resolved_survey=self._unit_resolved_survey,
|
|
1007
|
+
unit_level_weights=self._unit_level_weights,
|
|
1008
|
+
)
|
|
1009
|
+
# Update estimates with bootstrap inference
|
|
1010
|
+
overall_se = bootstrap_results.overall_att_se
|
|
1011
|
+
overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
|
|
1012
|
+
overall_p = bootstrap_results.overall_att_p_value
|
|
1013
|
+
overall_ci = bootstrap_results.overall_att_ci
|
|
1014
|
+
|
|
1015
|
+
for gt in group_time_effects:
|
|
1016
|
+
if gt in bootstrap_results.group_time_ses:
|
|
1017
|
+
group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt]
|
|
1018
|
+
group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt]
|
|
1019
|
+
group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt]
|
|
1020
|
+
eff = float(group_time_effects[gt]["effect"])
|
|
1021
|
+
se = float(group_time_effects[gt]["se"])
|
|
1022
|
+
group_time_effects[gt]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
|
|
1023
|
+
|
|
1024
|
+
es_cis = bootstrap_results.event_study_cis
|
|
1025
|
+
es_pvs = bootstrap_results.event_study_p_values
|
|
1026
|
+
if (
|
|
1027
|
+
event_study_effects is not None
|
|
1028
|
+
and bootstrap_results.event_study_ses is not None
|
|
1029
|
+
and es_cis is not None
|
|
1030
|
+
and es_pvs is not None
|
|
1031
|
+
):
|
|
1032
|
+
for e in event_study_effects:
|
|
1033
|
+
if e in bootstrap_results.event_study_ses:
|
|
1034
|
+
event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
|
|
1035
|
+
event_study_effects[e]["conf_int"] = es_cis[e]
|
|
1036
|
+
event_study_effects[e]["p_value"] = es_pvs[e]
|
|
1037
|
+
eff = float(event_study_effects[e]["effect"])
|
|
1038
|
+
se = float(event_study_effects[e]["se"])
|
|
1039
|
+
event_study_effects[e]["t_stat"] = safe_inference(
|
|
1040
|
+
eff, se, alpha=self.alpha
|
|
1041
|
+
)[0]
|
|
1042
|
+
|
|
1043
|
+
g_cis = bootstrap_results.group_effect_cis
|
|
1044
|
+
g_pvs = bootstrap_results.group_effect_p_values
|
|
1045
|
+
if (
|
|
1046
|
+
group_effects is not None
|
|
1047
|
+
and bootstrap_results.group_effect_ses is not None
|
|
1048
|
+
and g_cis is not None
|
|
1049
|
+
and g_pvs is not None
|
|
1050
|
+
):
|
|
1051
|
+
for g in group_effects:
|
|
1052
|
+
if g in bootstrap_results.group_effect_ses:
|
|
1053
|
+
group_effects[g]["se"] = bootstrap_results.group_effect_ses[g]
|
|
1054
|
+
group_effects[g]["conf_int"] = g_cis[g]
|
|
1055
|
+
group_effects[g]["p_value"] = g_pvs[g]
|
|
1056
|
+
eff = float(group_effects[g]["effect"])
|
|
1057
|
+
se = float(group_effects[g]["se"])
|
|
1058
|
+
group_effects[g]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
|
|
1059
|
+
|
|
1060
|
+
# ----- Build results -----
|
|
1061
|
+
self.results_ = EfficientDiDResults(
|
|
1062
|
+
group_time_effects=group_time_effects,
|
|
1063
|
+
overall_att=overall_att,
|
|
1064
|
+
overall_se=overall_se,
|
|
1065
|
+
overall_t_stat=overall_t,
|
|
1066
|
+
overall_p_value=overall_p,
|
|
1067
|
+
overall_conf_int=overall_ci,
|
|
1068
|
+
groups=treatment_groups,
|
|
1069
|
+
time_periods=time_periods,
|
|
1070
|
+
n_obs=n_units * len(time_periods),
|
|
1071
|
+
n_treated_units=n_treated_units,
|
|
1072
|
+
n_control_units=n_control_units,
|
|
1073
|
+
alpha=self.alpha,
|
|
1074
|
+
pt_assumption=self.pt_assumption,
|
|
1075
|
+
anticipation=self.anticipation,
|
|
1076
|
+
n_bootstrap=self.n_bootstrap,
|
|
1077
|
+
bootstrap_weights=self.bootstrap_weights,
|
|
1078
|
+
seed=self.seed,
|
|
1079
|
+
event_study_effects=event_study_effects,
|
|
1080
|
+
group_effects=group_effects,
|
|
1081
|
+
efficient_weights=stored_weights if stored_weights else None,
|
|
1082
|
+
omega_condition_numbers=stored_cond if stored_cond else None,
|
|
1083
|
+
control_group=self.control_group,
|
|
1084
|
+
influence_functions=eif_by_gt if store_eif else None,
|
|
1085
|
+
bootstrap_results=bootstrap_results,
|
|
1086
|
+
estimation_path="dr" if use_covariates else "nocov",
|
|
1087
|
+
sieve_k_max=self.sieve_k_max,
|
|
1088
|
+
sieve_criterion=self.sieve_criterion,
|
|
1089
|
+
ratio_clip=self.ratio_clip,
|
|
1090
|
+
kernel_bandwidth=self.kernel_bandwidth,
|
|
1091
|
+
survey_metadata=(
|
|
1092
|
+
self._recompute_unit_survey_metadata(survey_metadata)
|
|
1093
|
+
if survey_metadata is not None
|
|
1094
|
+
else None
|
|
1095
|
+
),
|
|
1096
|
+
)
|
|
1097
|
+
self.is_fitted_ = True
|
|
1098
|
+
return self.results_
|
|
1099
|
+
|
|
1100
|
+
def _recompute_unit_survey_metadata(self, panel_metadata):
    """Return survey metadata describing the unit-level design, if one was built.

    When ``fit()`` constructed a unit-level design (``_unit_resolved_survey``),
    metadata is recomputed from it so reported counts and df reflect units
    rather than panel rows. Otherwise the panel-level metadata passed in is
    returned unchanged.

    Parameters
    ----------
    panel_metadata : object or None
        Metadata computed from the panel-level survey design.

    Returns
    -------
    object or None
        Unit-level metadata when a unit-level design exists, otherwise
        ``panel_metadata`` as given.
    """
    unit_design = self._unit_resolved_survey
    if unit_design is None:
        return panel_metadata

    from diff_diff.survey import compute_survey_metadata

    unit_meta = compute_survey_metadata(unit_design, unit_design.weights)

    # Propagate the effective (possibly replicate-reduced) df. ``None`` means
    # nothing to propagate; ``0`` is the sentinel for an undefined df, in
    # which case the recomputed metadata is left as-is.
    effective_df = self._survey_df
    df_is_informative = effective_df is not None and effective_df != 0
    if df_is_informative and unit_meta.df_survey != effective_df:
        unit_meta.df_survey = effective_df
    return unit_meta
|
|
1116
|
+
|
|
1117
|
+
# -- Survey SE helpers ----------------------------------------------------
|
|
1118
|
+
|
|
1119
|
+
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
    """Compute an SE from per-unit EIF scores under the unit-level survey design.

    Uses the unit-level ``_unit_resolved_survey`` built once in ``fit()``,
    so every call sees the same unit-level arrays. Panel-level survey data
    is not re-subset on each call.

    Two variance paths are used:

    * Replicate-weight designs: the EIF is score-scaled as
      ``psi = w * eif / sum(w)`` and passed to
      ``compute_replicate_if_variance``.
    * All other designs: Taylor Series Linearization via
      ``compute_survey_vcov`` with an intercept-only design matrix.

    Parameters
    ----------
    eif_vals : ndarray, shape (n_units,)
        Per-unit efficient influence function values, aligned with the
        unit-level survey design.

    Returns
    -------
    float
        Standard error. For replicate designs, this is NaN when the
        replicate variance is not finite.

    Notes
    -----
    Side effect (replicate designs only): when fewer than all replicates
    produce valid estimates, ``self._survey_df`` is lowered to
    ``n_valid - 1``. When ``n_valid <= 1`` the df is undefined. In that
    case the ``0`` sentinel is stored, the same convention ``fit()`` uses
    for undefined replicate df, so downstream inference is NaN. Storing
    ``None`` would silently switch later inference to a normal
    approximation.
    """
    design = self._unit_resolved_survey

    if design.uses_replicate_variance:
        from diff_diff.survey import compute_replicate_if_variance

        # Score-scale IFs to match the TSL bread: psi = w * eif / sum(w)
        w = design.weights
        psi_scaled = w * eif_vals / w.sum()
        variance, n_valid = compute_replicate_if_variance(psi_scaled, design)
        # Update survey df to reflect the effective replicate count; use the
        # 0 sentinel (not None) when df is undefined so inference is NaN.
        if n_valid < design.n_replicates:
            self._survey_df = n_valid - 1 if n_valid > 1 else 0
        return float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan

    from diff_diff.survey import compute_survey_vcov

    # Intercept-only "regression" so the survey vcov reduces to the
    # design-based variance of the mean of the EIF scores.
    X_ones = np.ones((len(eif_vals), 1))
    vcov = compute_survey_vcov(X_ones, eif_vals, design)
    return float(np.sqrt(np.abs(vcov[0, 0])))
|
|
1143
|
+
|
|
1144
|
+
def _eif_se(
    self,
    eif_vals: np.ndarray,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> float:
    """Return the standard error implied by aggregated EIF scores.

    If a unit-level survey design is active (``_unit_resolved_survey`` is
    set), the survey-based SE helper is used. Otherwise the SE comes from
    ``_compute_se_from_eif``, which receives the optional cluster mapping.

    Parameters
    ----------
    eif_vals : ndarray, shape (n_units,)
        Aggregated per-unit influence-function values.
    n_units : int
        Number of units.
    cluster_indices : ndarray, optional
        Per-unit cluster codes, used only on the non-survey path.
    n_clusters : int, optional
        Number of clusters, used only on the non-survey path.

    Returns
    -------
    float
        The standard error.
    """
    survey_design = self._unit_resolved_survey
    if survey_design is None:
        return _compute_se_from_eif(eif_vals, n_units, cluster_indices, n_clusters)
    return self._compute_survey_eif_se(eif_vals)
|
|
1159
|
+
|
|
1160
|
+
# -- Aggregation helpers --------------------------------------------------
|
|
1161
|
+
|
|
1162
|
+
def _compute_wif_contribution(
    self,
    keepers: List[Tuple],
    effects: np.ndarray,
    unit_cohorts: np.ndarray,
    cohort_fractions: Dict[float, float],
    n_units: int,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Weight-influence-function (WIF) correction for cohort-share weights.

    The aggregation weights ``pg_k / sum(pg)`` are themselves estimated, so
    their sampling noise contributes to the aggregated parameter's influence
    function. This returns that extra term on the same O(1) scale as the
    aggregated EIF, so the two can be added. The formula follows R's ``did``
    package WIF, as implemented in ``staggered_aggregation.py``.

    Parameters
    ----------
    keepers : list of (g, t) tuples
        Post-treatment group-time cells included in the aggregation.
    effects : ndarray, shape (n_keepers,)
        ATT estimate for each keeper, in the same order as ``keepers``.
    unit_cohorts : ndarray, shape (n_units,)
        Cohort of each unit (0 = never-treated).
    cohort_fractions : dict
        Cohort share per cohort. Cohorts missing from the dict count as 0.
    n_units : int
        Total number of units (length of the returned array).
    unit_weights : ndarray, shape (n_units,), optional
        Unit-level survey weights. When given, the cohort-membership
        indicator is multiplied by the unit weight before centering,
        i.e. ``IF_i(p_g) = w_i * 1{G_i=g} - pg_k``.

    Returns
    -------
    ndarray, shape (n_units,)
        WIF contribution per unit. All zeros when the keepers' cohort
        shares sum to zero, which includes the case of no keepers.
    """
    keeper_groups = np.array([grp for grp, _ in keepers])
    keeper_pg = np.array([cohort_fractions.get(grp, 0.0) for grp, _ in keepers])
    pg_total = keeper_pg.sum()
    if pg_total == 0:
        return np.zeros(n_units)

    # membership[i, k] == 1.0 when unit i belongs to the cohort of keeper k
    membership = (unit_cohorts[:, None] == keeper_groups[None, :]).astype(float)
    if unit_weights is not None:
        # Survey-weighted variant: scale each unit's row by its weight
        membership = membership * unit_weights[:, None]

    centered = membership - keeper_pg
    row_totals = np.sum(centered, axis=1)

    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        # Derivative of pg_k / sum(pg): direct term minus normalization term
        direct_term = centered / pg_total
        normalization_term = np.outer(row_totals, keeper_pg) / pg_total**2
        return (direct_term - normalization_term) @ effects
|
|
1222
|
+
|
|
1223
|
+
def _aggregate_overall(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    unit_cohorts: np.ndarray,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Tuple[float, float]:
    """Aggregate post-treatment ATT(g, t) into an overall ATT with a WIF-adjusted SE.

    A cell (g, t) is kept when ``t >= g - self.anticipation`` and its effect
    is finite. Kept cells are weighted by their cohort's share (from
    ``cohort_fractions``), renormalised to sum to one over the kept cells.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ..., ...}}`` group-time ATT estimates.
    eif_by_gt : dict
        ``{(g, t): ndarray}`` per-unit EIF values for each (g, t).
    n_units : int
        Total number of units (length of each EIF vector).
    cohort_fractions : dict
        Cohort size fractions; cohorts missing from the dict get share 0.
    unit_cohorts : ndarray, shape (n_units,)
        Cohort assignment for each unit, forwarded to the WIF computation.
    cluster_indices : ndarray, optional
        Per-unit cluster indices, forwarded to ``self._eif_se`` on the
        non-survey path.
    n_clusters : int, optional
        Number of clusters, forwarded to ``self._eif_se``.

    Returns
    -------
    tuple of (float, float)
        ``(overall_att, se)``. Both are NaN when no cell qualifies or when
        the kept cohorts' shares sum to zero.
    """
    post_cells = [
        cell
        for cell in group_time_effects
        if cell[1] >= cell[0] - self.anticipation
        and np.isfinite(group_time_effects[cell]["effect"])
    ]
    if not post_cells:
        return np.nan, np.nan

    # Cohort-size weights, renormalised over the kept cells.
    shares = np.array([cohort_fractions.get(cell[0], 0.0) for cell in post_cells])
    share_total = shares.sum()
    if share_total == 0:
        return np.nan, np.nan
    weights = shares / share_total

    cell_effects = np.array([group_time_effects[cell]["effect"] for cell in post_cells])
    overall_att = float(np.sum(weights * cell_effects))

    # Weighted combination of the per-cell EIFs.
    combined_eif = np.zeros(n_units)
    for weight, cell in zip(weights, post_cells):
        combined_eif += weight * eif_by_gt[cell]

    # WIF term: extra variability from estimating the cohort-size weights.
    wif = self._compute_wif_contribution(
        post_cells, cell_effects, unit_cohorts, cohort_fractions, n_units,
        unit_weights=self._unit_level_weights,
    )

    survey = self._unit_resolved_survey
    if survey is None:
        se = self._eif_se(combined_eif + wif, n_units, cluster_indices, n_clusters)
        return overall_att, se

    # Survey path: build a score-level psi and hand it to the survey
    # variance routines directly. Going through compute_survey_vcov would
    # apply w_i again and double-weight the already survey-weighted WIF term.
    uw = self._unit_level_weights
    total_w = float(np.sum(uw))
    psi_total = uw * combined_eif / total_w + wif / total_w

    if getattr(survey, "uses_replicate_variance", False):
        from diff_diff.survey import compute_replicate_if_variance

        variance, _ = compute_replicate_if_variance(psi_total, survey)
    else:
        from diff_diff.survey import compute_survey_if_variance

        variance = compute_survey_if_variance(psi_total, survey)

    se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
    return overall_att, se
|
|
1304
|
+
|
|
1305
|
+
def _aggregate_event_study(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    treatment_groups: List[Any],
    time_periods: List[Any],
    balance_e: Optional[int] = None,
    unit_cohorts: Optional[np.ndarray] = None,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Dict[int, Dict[str, Any]]:
    """Aggregate ATT(g,t) by relative time e = t - g.

    Every finite cell contributes, including pre-treatment ones (e < 0).
    Within each e, cells are weighted by cohort share. Equal weights are
    used when all shares in that e are zero.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ..., ...}}`` group-time ATT estimates.
    eif_by_gt : dict
        Per-unit EIF values for each (g, t).
    n_units : int
        Total number of units.
    cohort_fractions : dict
        Cohort size fractions (missing cohorts get share 0).
    treatment_groups : list
        Treatment cohort identifiers (not read by this method).
    time_periods : list
        All time periods (not read by this method).
    balance_e : int, optional
        When given, only cohorts with a finite effect at e = ``balance_e``
        are kept at every relative time. A UserWarning is issued if that
        leaves nothing.
    unit_cohorts : ndarray, optional
        Cohort assignment for each unit. When given, a WIF correction for
        the estimated cohort weights is added.
    cluster_indices : ndarray, optional
        Per-unit cluster indices, forwarded to ``self._eif_se``.
    n_clusters : int, optional
        Number of clusters, forwarded to ``self._eif_se``.

    Returns
    -------
    dict
        ``{e: {"effect", "se", "t_stat", "p_value", "conf_int", "n_groups"}}``
        in ascending order of e.
    """
    # Bucket finite cells by relative time: (cell, effect, cohort share).
    by_event_time: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
    for cell, cell_data in group_time_effects.items():
        effect = cell_data["effect"]
        if not np.isfinite(effect):
            continue
        g, t = cell
        by_event_time.setdefault(int(t - g), []).append(
            (cell, effect, cohort_fractions.get(g, 0.0))
        )

    if balance_e is not None:
        # Keep only cohorts observed (finitely) at the anchor horizon.
        anchored = {entry[0][0] for entry in by_event_time.get(balance_e, [])}
        filtered: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
        for e, entries in by_event_time.items():
            kept = [entry for entry in entries if entry[0][0] in anchored]
            if kept:
                filtered[e] = kept
        by_event_time = filtered

        if not by_event_time:
            warnings.warn(
                f"balance_e={balance_e}: no cohort has a finite effect at the "
                "anchor horizon. Event study will be empty.",
                UserWarning,
                stacklevel=2,
            )

    result: Dict[int, Dict[str, Any]] = {}
    for e in sorted(by_event_time):
        entries = by_event_time[e]
        cells = [entry[0] for entry in entries]
        effs = np.array([entry[1] for entry in entries])
        shares = np.array([entry[2] for entry in entries])

        share_total = shares.sum()
        if share_total > 0:
            weights = shares / share_total
        else:
            weights = np.ones(len(shares)) / len(shares)

        estimate = float(np.sum(weights * effs))

        combined_eif = np.zeros(n_units)
        for weight, cell in zip(weights, cells):
            combined_eif += weight * eif_by_gt[cell]

        # WIF correction for the estimated cohort-size weights.
        wif = np.zeros(n_units)
        if unit_cohorts is not None:
            wif = self._compute_wif_contribution(
                cells, effs, unit_cohorts, cohort_fractions, n_units,
                unit_weights=self._unit_level_weights,
            )

        survey = self._unit_resolved_survey
        if survey is None:
            se = self._eif_se(combined_eif + wif, n_units, cluster_indices, n_clusters)
        else:
            # Score-level psi avoids double-weighting the survey-weighted WIF.
            uw = self._unit_level_weights
            total_w = float(np.sum(uw))
            psi_total = uw * combined_eif / total_w + wif / total_w
            if getattr(survey, "uses_replicate_variance", False):
                from diff_diff.survey import compute_replicate_if_variance

                variance, _ = compute_replicate_if_variance(psi_total, survey)
            else:
                from diff_diff.survey import compute_survey_if_variance

                variance = compute_survey_if_variance(psi_total, survey)
            se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan

        t_stat, p_val, ci = safe_inference(
            estimate, se, alpha=self.alpha, df=self._survey_df
        )
        result[e] = {
            "effect": estimate,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_groups": len(entries),
        }

    return result
|
|
1432
|
+
|
|
1433
|
+
def _aggregate_by_group(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    treatment_groups: List[Any],
    unit_cohorts: Optional[np.ndarray] = None,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Dict[Any, Dict[str, Any]]:
    """Aggregate ATT(g,t) by treatment cohort.

    For each cohort g, post-treatment cells (``t >= g - self.anticipation``)
    with finite effects are averaged with equal weights. Cohorts with no
    such cell are omitted from the result.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ..., ...}}`` group-time ATT estimates.
    eif_by_gt : dict
        Per-unit EIF values for each (g, t).
    n_units : int
        Total number of units.
    cohort_fractions : dict
        Cohort size fractions (not read: weights here are equal).
    treatment_groups : list
        Treatment cohort identifiers, visited in this order.
    unit_cohorts : ndarray, optional
        Unused. Group aggregation uses fixed equal weights, so no WIF term
        is needed.
    cluster_indices : ndarray, optional
        Per-unit cluster indices, forwarded to ``self._eif_se``.
    n_clusters : int, optional
        Number of clusters, forwarded to ``self._eif_se``.

    Returns
    -------
    dict
        ``{g: {"effect", "se", "t_stat", "p_value", "conf_int", "n_periods"}}``.
    """
    result: Dict[Any, Dict[str, Any]] = {}
    for g in treatment_groups:
        cells = [
            cell
            for cell in group_time_effects
            if cell[0] == g
            and cell[1] >= g - self.anticipation
            and np.isfinite(group_time_effects[cell]["effect"])
        ]
        if not cells:
            continue

        effs = np.array([group_time_effects[cell]["effect"] for cell in cells])
        n_cells = len(effs)
        weights = np.ones(n_cells) / n_cells
        estimate = float(np.sum(weights * effs))

        combined_eif = np.zeros(n_units)
        for weight, cell in zip(weights, cells):
            combined_eif += weight * eif_by_gt[cell]
        se = self._eif_se(combined_eif, n_units, cluster_indices, n_clusters)

        t_stat, p_val, ci = safe_inference(
            estimate, se, alpha=self.alpha, df=self._survey_df
        )
        result[g] = {
            "effect": estimate,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_periods": len(cells),
        }

    return result
|
|
1496
|
+
|
|
1497
|
+
def summary(self) -> str:
    """Return the summary text of the fitted results.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet.
    """
    if self.is_fitted_:
        assert self.results_ is not None
        return self.results_.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
|
|
1503
|
+
|
|
1504
|
+
def print_summary(self) -> None:
    """Write the text returned by ``self.summary()`` to standard output."""
    text = self.summary()
    print(text)
|
|
1507
|
+
|
|
1508
|
+
# -- Hausman pretest -------------------------------------------------------
|
|
1509
|
+
|
|
1510
|
+
@classmethod
def hausman_pretest(
    cls,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    cluster: Optional[str] = None,
    anticipation: int = 0,
    control_group: str = "never_treated",
    alpha: float = 0.05,
    **nuisance_kwargs: Any,
) -> HausmanPretestResult:
    """Hausman pretest for PT-All vs PT-Post (Theorem A.1).

    Fits the estimator under both parallel trends assumptions and
    compares the results. Under H0 (PT-All holds), both are consistent
    but PT-All is more efficient. Rejection suggests PT-All is too
    strong; use PT-Post instead.

    The comparison is made on post-treatment event-study effects ES(e),
    e >= 0. These are aggregated with cohort-size weights, and their EIFs
    include a WIF correction for the estimated weights. With
    ``delta = ES_post - ES_all`` and ``V = Cov_post - Cov_all``, the
    statistic is ``delta' V^+ delta``. ``V^+`` is the generalized inverse
    restricted to the eigen-directions of V whose eigenvalues exceed a
    relative tolerance. The chi-squared reference distribution uses the
    number of those directions as its degrees of freedom.

    Parameters
    ----------
    data, outcome, unit, time, first_treat, covariates
        Same as :meth:`fit`.
    cluster : str, optional
        Cluster column for cluster-robust covariance.
    anticipation : int
        Anticipation periods.
    control_group : str
        ``"never_treated"`` or ``"last_cohort"``.
    alpha : float
        Significance level for the test.
    **nuisance_kwargs
        Passed to both fits (e.g. ``sieve_k_max``, ``ratio_clip``).

    Returns
    -------
    HausmanPretestResult
        The statistic and p-value are NaN and the recommendation is
        ``"inconclusive"`` in these cases: no common (g, t) cells, no
        stored EIFs, no common post-treatment horizon, a non-finite V, or
        a V with no positive eigenvalue.

    Warns
    -----
    UserWarning
        If V contains non-finite values. Also if V has substantially
        negative eigenvalues; those directions are excluded from the
        statistic.
    """
    from scipy.stats import chi2

    # Fit under both assumptions (analytical SEs only, no bootstrap)
    common_kwargs = dict(
        cluster=cluster,
        control_group=control_group,
        anticipation=anticipation,
        n_bootstrap=0,
        **nuisance_kwargs,
    )
    fit_kwargs = dict(
        data=data,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        aggregate=None,
    )

    edid_all = cls(pt_assumption="all", alpha=alpha, **common_kwargs)
    result_all = edid_all.fit(**fit_kwargs, store_eif=True)

    edid_post = cls(pt_assumption="post", alpha=alpha, **common_kwargs)
    result_post = edid_post.fit(**fit_kwargs, store_eif=True)

    # Common (g,t) pairs — PT-Post pairs are a subset of PT-All.
    # Only emptiness matters here, so no ordering is needed.
    common_gts = set(result_all.group_time_effects) & set(
        result_post.group_time_effects
    )

    def _nan_result() -> HausmanPretestResult:
        return HausmanPretestResult(
            statistic=np.nan,
            p_value=np.nan,
            df=0,
            reject=False,
            alpha=alpha,
            att_all=result_all.overall_att,
            att_post=result_post.overall_att,
            recommendation="inconclusive",
            gt_details=None,
        )

    if not common_gts:
        return _nan_result()

    eif_all = result_all.influence_functions
    eif_post = result_post.influence_functions
    assert eif_all is not None and eif_post is not None
    if not eif_all:
        # No stored EIF vector to size the unit dimension from.
        return _nan_result()
    n_units = len(next(iter(eif_all.values())))

    # --- Aggregate to post-treatment ES(e) per Theorem A.1 ---
    # Derive cohort fractions from data for proper weights
    all_units_list = sorted(data[unit].unique())
    unit_cohorts = (
        data.groupby(unit)[first_treat].first().reindex(all_units_list).values.astype(float)
    )
    cohort_fractions: Dict[float, float] = {}
    for g in set(result_all.groups) | set(result_post.groups):
        cohort_fractions[g] = float(np.sum(unit_cohorts == g)) / n_units

    def _aggregate_es(
        gt_effects: Dict, eif_dict: Dict
    ) -> Dict[int, Tuple[float, np.ndarray]]:
        """Aggregate (g,t) effects to ES(e) for e >= 0, with WIF-corrected EIF."""
        by_e: Dict[int, List[Tuple[Tuple, float, float, np.ndarray]]] = {}
        for (g, t), d in gt_effects.items():
            e = int(t - g)
            # Only post-treatment horizons (e >= 0) enter the comparison.
            if e < 0:
                continue
            if not np.isfinite(d["effect"]):
                continue
            if (g, t) not in eif_dict:
                continue
            eif_vec = eif_dict[(g, t)]
            if not np.all(np.isfinite(eif_vec)):
                continue
            pg = cohort_fractions.get(g, 0.0)
            by_e.setdefault(e, []).append(((g, t), d["effect"], pg, eif_vec))

        result: Dict[int, Tuple[float, np.ndarray]] = {}
        for e, items in by_e.items():
            effs = np.array([x[1] for x in items])
            pgs = np.array([x[2] for x in items])
            eifs = [x[3] for x in items]
            gt_pairs_e = [x[0] for x in items]
            total_pg = pgs.sum()
            w = pgs / total_pg if total_pg > 0 else np.ones(len(pgs)) / len(pgs)
            es_eff = float(np.sum(w * effs))
            es_eif = np.zeros(n_units)
            for w_k, eif_k in zip(w, eifs):
                es_eif += w_k * eif_k
            # WIF correction for estimated cohort-size weights
            groups_e = np.array([g for (g, t) in gt_pairs_e])
            pg_e = np.array([cohort_fractions.get(g, 0.0) for g, t in gt_pairs_e])
            sum_pg = pg_e.sum()
            if sum_pg > 0:
                indicator = (unit_cohorts[:, None] == groups_e[None, :]).astype(float)
                indicator_sum = np.sum(indicator - pg_e, axis=1)
                with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                    if1 = (indicator - pg_e) / sum_pg
                    if2 = np.outer(indicator_sum, pg_e) / sum_pg**2
                    wif = (if1 - if2) @ effs
                es_eif = es_eif + wif
            result[e] = (es_eff, es_eif)
        return result

    es_all = _aggregate_es(result_all.group_time_effects, eif_all)
    es_post = _aggregate_es(result_post.group_time_effects, eif_post)

    # Find common post-treatment horizons
    common_e = sorted(set(es_all.keys()) & set(es_post.keys()))
    if not common_e:
        return _nan_result()

    delta = np.array([es_post[e][0] - es_all[e][0] for e in common_e])

    # Build ES(e)-level EIF matrices
    eif_all_mat = np.column_stack([es_all[e][1] for e in common_e])
    eif_post_mat = np.column_stack([es_post[e][1] for e in common_e])

    # Filter units with non-finite EIF values
    row_finite = np.all(np.isfinite(eif_all_mat), axis=1) & np.all(
        np.isfinite(eif_post_mat), axis=1
    )
    cl_idx: Optional[np.ndarray] = None
    n_cl: Optional[int] = None
    if cluster is not None:
        cl_idx, n_cl = _validate_and_build_cluster_mapping(data, unit, cluster, all_units_list)
    if not np.all(row_finite):
        eif_all_mat = eif_all_mat[row_finite]
        eif_post_mat = eif_post_mat[row_finite]
        n_units = int(np.sum(row_finite))
        if cl_idx is not None:
            cl_idx = cl_idx[row_finite]
            # Recompute effective cluster count and remap to contiguous
            # indices — entire clusters may have been dropped by filtering
            unique_cl, cl_idx = np.unique(cl_idx, return_inverse=True)
            n_cl = len(unique_cl)

    # Compute full covariance matrices
    if cl_idx is not None and n_cl is not None:

        def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
            centered = _cluster_aggregate(eif_mat, cl_idx, n_cl)
            correction = n_cl / (n_cl - 1) if n_cl > 1 else 1.0
            return correction * (centered.T @ centered) / (n_units**2)

        cov_all = _eif_cov(eif_all_mat)
        cov_post = _eif_cov(eif_post_mat)
    else:
        with np.errstate(over="ignore", invalid="ignore"):
            cov_all = (eif_all_mat.T @ eif_all_mat) / (n_units**2)
            cov_post = (eif_post_mat.T @ eif_post_mat) / (n_units**2)

    V = cov_post - cov_all

    if not np.all(np.isfinite(V)):
        warnings.warn(
            "Hausman covariance matrix contains non-finite values. " "The test is unreliable.",
            UserWarning,
            stacklevel=2,
        )
        return _nan_result()

    # Eigendecompose V (symmetric: difference of two X'X-type matrices)
    # and check for non-PSD.
    eigvals, eigvecs = np.linalg.eigh(V)
    max_eigval = float(np.max(np.abs(eigvals))) if eigvals.size > 0 else 0.0
    tol = max(1e-10 * max_eigval, 1e-15)

    n_negative = int(np.sum(eigvals < -tol))
    if n_negative > 0:
        warnings.warn(
            f"Hausman variance-difference matrix V has {n_negative} "
            "substantially negative eigenvalue(s). The test may be "
            "unreliable (finite-sample efficiency reversal).",
            UserWarning,
            stacklevel=2,
        )

    positive = eigvals > tol
    effective_rank = int(np.sum(positive))
    if effective_rank == 0:
        return _nan_result()

    # Generalized inverse on the positive eigenspace only. np.linalg.pinv
    # thresholds on |singular value|, so it would also invert substantially
    # negative eigen-directions. Those pull H down, and the inverse's rank
    # would then disagree with the chi2 df (effective_rank) used below.
    # When V is PSD this matches the previous pinv-based statistic.
    proj = eigvecs[:, positive].T @ delta
    H = float(np.sum(proj**2 / eigvals[positive]))
    H = max(H, 0.0)

    p_value = float(chi2.sf(H, df=effective_rank))
    reject = p_value < alpha

    es_details = pd.DataFrame(
        {
            "relative_period": common_e,
            "es_all": [es_all[e][0] for e in common_e],
            "es_post": [es_post[e][0] for e in common_e],
            "delta": delta,
        }
    )

    return HausmanPretestResult(
        statistic=H,
        p_value=p_value,
        df=effective_rank,
        reject=reject,
        alpha=alpha,
        att_all=result_all.overall_att,
        att_post=result_post.overall_att,
        recommendation="pt_post" if reject else "pt_all",
        gt_details=es_details,
    )
|