diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/trop_global.py
ADDED
|
@@ -0,0 +1,1270 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Global estimation method for the TROP estimator.
|
|
3
|
+
|
|
4
|
+
Contains the TROPGlobalMixin class with all methods for the global
|
|
5
|
+
(joint) estimation pathway. The global method fits a single weighted
|
|
6
|
+
model on control observations and extracts per-observation treatment
|
|
7
|
+
effects as post-hoc residuals.
|
|
8
|
+
|
|
9
|
+
This module is used via mixin inheritance — see trop.py for the
|
|
10
|
+
main TROP class definition.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import warnings
|
|
15
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pandas as pd
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
from diff_diff._backend import (
|
|
23
|
+
HAS_RUST_BACKEND,
|
|
24
|
+
_rust_bootstrap_trop_variance_global,
|
|
25
|
+
_rust_loocv_grid_search_global,
|
|
26
|
+
)
|
|
27
|
+
from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment
|
|
28
|
+
from diff_diff.trop_results import TROPResults
|
|
29
|
+
from diff_diff.utils import safe_inference
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TROPGlobalMixin:
|
|
33
|
+
"""Mixin providing global estimation method for TROP.
|
|
34
|
+
|
|
35
|
+
Methods in this mixin access the following attributes from the main
|
|
36
|
+
TROP class via ``self``:
|
|
37
|
+
|
|
38
|
+
- Tuning grids: ``lambda_time_grid``, ``lambda_unit_grid``, ``lambda_nn_grid``
|
|
39
|
+
- Solver params: ``max_iter``, ``tol``
|
|
40
|
+
- Inference params: ``alpha``, ``n_bootstrap``, ``seed``
|
|
41
|
+
- State: ``results_``, ``is_fitted_``
|
|
42
|
+
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
# Type hints for attributes accessed from the main TROP class
|
|
46
|
+
lambda_time_grid: List[float]
|
|
47
|
+
lambda_unit_grid: List[float]
|
|
48
|
+
lambda_nn_grid: List[float]
|
|
49
|
+
max_iter: int
|
|
50
|
+
tol: float
|
|
51
|
+
alpha: float
|
|
52
|
+
n_bootstrap: int
|
|
53
|
+
seed: Optional[int]
|
|
54
|
+
results_: Any
|
|
55
|
+
is_fitted_: bool
|
|
56
|
+
|
|
57
|
+
def _compute_global_weights(
|
|
58
|
+
self,
|
|
59
|
+
Y: np.ndarray,
|
|
60
|
+
D: np.ndarray,
|
|
61
|
+
lambda_time: float,
|
|
62
|
+
lambda_unit: float,
|
|
63
|
+
treated_periods: int,
|
|
64
|
+
n_units: int,
|
|
65
|
+
n_periods: int,
|
|
66
|
+
) -> np.ndarray:
|
|
67
|
+
"""
|
|
68
|
+
Compute distance-based weights for global estimation.
|
|
69
|
+
|
|
70
|
+
Following the reference implementation, weights are computed based on:
|
|
71
|
+
- Time distance: distance to center of treated block
|
|
72
|
+
- Unit distance: RMSE to average treated trajectory over pre-periods
|
|
73
|
+
|
|
74
|
+
Parameters
|
|
75
|
+
----------
|
|
76
|
+
Y : np.ndarray
|
|
77
|
+
Outcome matrix (n_periods x n_units).
|
|
78
|
+
D : np.ndarray
|
|
79
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
80
|
+
lambda_time : float
|
|
81
|
+
Time weight decay parameter.
|
|
82
|
+
lambda_unit : float
|
|
83
|
+
Unit weight decay parameter.
|
|
84
|
+
treated_periods : int
|
|
85
|
+
Number of post-treatment periods.
|
|
86
|
+
n_units : int
|
|
87
|
+
Number of units.
|
|
88
|
+
n_periods : int
|
|
89
|
+
Number of periods.
|
|
90
|
+
|
|
91
|
+
Returns
|
|
92
|
+
-------
|
|
93
|
+
np.ndarray
|
|
94
|
+
Weight matrix (n_periods x n_units).
|
|
95
|
+
"""
|
|
96
|
+
# Identify treated units (ever treated)
|
|
97
|
+
treated_mask = np.any(D == 1, axis=0)
|
|
98
|
+
treated_unit_idx = np.where(treated_mask)[0]
|
|
99
|
+
|
|
100
|
+
if len(treated_unit_idx) == 0:
|
|
101
|
+
raise ValueError("No treated units found")
|
|
102
|
+
|
|
103
|
+
# Time weights: distance to center of treated block
|
|
104
|
+
# Following reference: center = T - treated_periods/2
|
|
105
|
+
center = n_periods - treated_periods / 2.0
|
|
106
|
+
dist_time = np.abs(np.arange(n_periods, dtype=float) - center)
|
|
107
|
+
delta_time = np.exp(-lambda_time * dist_time)
|
|
108
|
+
|
|
109
|
+
# Unit weights: RMSE to average treated trajectory over pre-periods
|
|
110
|
+
# Compute average treated trajectory (use nanmean to handle NaN)
|
|
111
|
+
average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1)
|
|
112
|
+
|
|
113
|
+
# Pre-period mask: 1 in pre, 0 in post
|
|
114
|
+
pre_mask = np.ones(n_periods, dtype=float)
|
|
115
|
+
pre_mask[-treated_periods:] = 0.0
|
|
116
|
+
|
|
117
|
+
# Compute RMS distance for each unit
|
|
118
|
+
# dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre)
|
|
119
|
+
# Use NaN-safe operations: treat NaN differences as 0 (excluded)
|
|
120
|
+
diff = average_treated[:, np.newaxis] - Y
|
|
121
|
+
diff_sq = np.where(np.isfinite(diff), diff**2, 0.0) * pre_mask[:, np.newaxis]
|
|
122
|
+
|
|
123
|
+
# Count valid observations per unit in pre-period
|
|
124
|
+
# Must check diff is finite (both Y and average_treated finite)
|
|
125
|
+
# to match the periods contributing to diff_sq
|
|
126
|
+
valid_count = np.sum(np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0)
|
|
127
|
+
sum_sq = np.sum(diff_sq, axis=0)
|
|
128
|
+
n_pre = np.sum(pre_mask)
|
|
129
|
+
|
|
130
|
+
if n_pre == 0:
|
|
131
|
+
raise ValueError("No pre-treatment periods")
|
|
132
|
+
|
|
133
|
+
# Track units with no valid pre-period data
|
|
134
|
+
no_valid_pre = valid_count == 0
|
|
135
|
+
|
|
136
|
+
# Use valid count per unit (avoid division by zero for calculation)
|
|
137
|
+
valid_count_safe = np.maximum(valid_count, 1)
|
|
138
|
+
dist_unit = np.sqrt(sum_sq / valid_count_safe)
|
|
139
|
+
|
|
140
|
+
# Units with no valid pre-period data get zero weight
|
|
141
|
+
# (dist is undefined, so we set it to inf -> delta_unit = exp(-inf) = 0)
|
|
142
|
+
delta_unit = np.exp(-lambda_unit * dist_unit)
|
|
143
|
+
delta_unit[no_valid_pre] = 0.0
|
|
144
|
+
|
|
145
|
+
# Outer product: (n_periods x n_units)
|
|
146
|
+
delta = np.outer(delta_time, delta_unit)
|
|
147
|
+
|
|
148
|
+
# (1-W) masking: zero out treated observations per paper Eq. 2
|
|
149
|
+
# Model is fit on control data only; tau extracted post-hoc
|
|
150
|
+
delta = delta * (1 - D)
|
|
151
|
+
|
|
152
|
+
return delta
|
|
153
|
+
|
|
154
|
+
def _solve_global_model(
|
|
155
|
+
self,
|
|
156
|
+
Y: np.ndarray,
|
|
157
|
+
delta: np.ndarray,
|
|
158
|
+
lambda_nn: float,
|
|
159
|
+
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
|
|
160
|
+
"""
|
|
161
|
+
Dispatch to no-lowrank or with-lowrank solver based on lambda_nn.
|
|
162
|
+
|
|
163
|
+
Returns (mu, alpha, beta, L) in all cases.
|
|
164
|
+
"""
|
|
165
|
+
n_periods, n_units = Y.shape
|
|
166
|
+
if lambda_nn >= 1e10:
|
|
167
|
+
mu, alpha, beta = self._solve_global_no_lowrank(Y, delta)
|
|
168
|
+
L = np.zeros((n_periods, n_units))
|
|
169
|
+
else:
|
|
170
|
+
mu, alpha, beta, L = self._solve_global_with_lowrank(
|
|
171
|
+
Y, delta, lambda_nn, self.max_iter, self.tol
|
|
172
|
+
)
|
|
173
|
+
return mu, alpha, beta, L
|
|
174
|
+
|
|
175
|
+
@staticmethod
|
|
176
|
+
def _extract_posthoc_tau(
|
|
177
|
+
Y: np.ndarray,
|
|
178
|
+
D: np.ndarray,
|
|
179
|
+
mu: float,
|
|
180
|
+
alpha: np.ndarray,
|
|
181
|
+
beta: np.ndarray,
|
|
182
|
+
L: np.ndarray,
|
|
183
|
+
idx_to_unit: Optional[Dict] = None,
|
|
184
|
+
idx_to_period: Optional[Dict] = None,
|
|
185
|
+
unit_weights: Optional[np.ndarray] = None,
|
|
186
|
+
) -> Tuple[float, Dict, List[float]]:
|
|
187
|
+
"""
|
|
188
|
+
Extract post-hoc treatment effects: tau_it = Y - mu - alpha - beta - L.
|
|
189
|
+
|
|
190
|
+
Returns (att, treatment_effects_dict, tau_values_list).
|
|
191
|
+
When idx_to_unit/idx_to_period are None, treatment_effects uses raw indices.
|
|
192
|
+
"""
|
|
193
|
+
counterfactual = mu + alpha[np.newaxis, :] + beta[:, np.newaxis] + L
|
|
194
|
+
tau_matrix = Y - counterfactual
|
|
195
|
+
|
|
196
|
+
treated_mask = D == 1
|
|
197
|
+
finite_mask = np.isfinite(Y)
|
|
198
|
+
valid_treated = treated_mask & finite_mask
|
|
199
|
+
|
|
200
|
+
tau_values = tau_matrix[valid_treated].tolist()
|
|
201
|
+
if unit_weights is not None and tau_values:
|
|
202
|
+
obs_weights = unit_weights[np.where(valid_treated)[1]]
|
|
203
|
+
att = float(np.average(tau_values, weights=obs_weights))
|
|
204
|
+
else:
|
|
205
|
+
att = float(np.mean(tau_values)) if tau_values else np.nan
|
|
206
|
+
|
|
207
|
+
# Build treatment effects dict
|
|
208
|
+
treatment_effects: Dict = {}
|
|
209
|
+
n_periods, n_units = D.shape
|
|
210
|
+
for t in range(n_periods):
|
|
211
|
+
for i in range(n_units):
|
|
212
|
+
if D[t, i] == 1:
|
|
213
|
+
uid = idx_to_unit[i] if idx_to_unit is not None else i
|
|
214
|
+
tid = idx_to_period[t] if idx_to_period is not None else t
|
|
215
|
+
if finite_mask[t, i]:
|
|
216
|
+
treatment_effects[(uid, tid)] = tau_matrix[t, i]
|
|
217
|
+
else:
|
|
218
|
+
treatment_effects[(uid, tid)] = np.nan
|
|
219
|
+
|
|
220
|
+
return att, treatment_effects, tau_values
|
|
221
|
+
|
|
222
|
+
def _loocv_score_global(
|
|
223
|
+
self,
|
|
224
|
+
Y: np.ndarray,
|
|
225
|
+
D: np.ndarray,
|
|
226
|
+
control_obs: List[Tuple[int, int]],
|
|
227
|
+
lambda_time: float,
|
|
228
|
+
lambda_unit: float,
|
|
229
|
+
lambda_nn: float,
|
|
230
|
+
treated_periods: int,
|
|
231
|
+
n_units: int,
|
|
232
|
+
n_periods: int,
|
|
233
|
+
) -> float:
|
|
234
|
+
"""
|
|
235
|
+
Compute LOOCV score for global method with specific parameter combination.
|
|
236
|
+
|
|
237
|
+
Following paper's Equation 5:
|
|
238
|
+
Q(lambda) = sum_{j,s: D_js=0} [tau_js^loocv(lambda)]^2
|
|
239
|
+
|
|
240
|
+
For global method, we exclude each control observation, fit the global model
|
|
241
|
+
on remaining data, and compute the pseudo-treatment effect at the excluded obs.
|
|
242
|
+
|
|
243
|
+
Parameters
|
|
244
|
+
----------
|
|
245
|
+
Y : np.ndarray
|
|
246
|
+
Outcome matrix (n_periods x n_units).
|
|
247
|
+
D : np.ndarray
|
|
248
|
+
Treatment indicator matrix (n_periods x n_units).
|
|
249
|
+
control_obs : List[Tuple[int, int]]
|
|
250
|
+
List of (t, i) control observations for LOOCV.
|
|
251
|
+
lambda_time : float
|
|
252
|
+
Time weight decay parameter.
|
|
253
|
+
lambda_unit : float
|
|
254
|
+
Unit weight decay parameter.
|
|
255
|
+
lambda_nn : float
|
|
256
|
+
Nuclear norm regularization parameter.
|
|
257
|
+
treated_periods : int
|
|
258
|
+
Number of post-treatment periods.
|
|
259
|
+
n_units : int
|
|
260
|
+
Number of units.
|
|
261
|
+
n_periods : int
|
|
262
|
+
Number of periods.
|
|
263
|
+
|
|
264
|
+
Returns
|
|
265
|
+
-------
|
|
266
|
+
float
|
|
267
|
+
LOOCV score (sum of squared pseudo-treatment effects).
|
|
268
|
+
"""
|
|
269
|
+
# Compute global weights (same for all LOOCV iterations)
|
|
270
|
+
delta = self._compute_global_weights(
|
|
271
|
+
Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
tau_sq_sum = 0.0
|
|
275
|
+
n_valid = 0
|
|
276
|
+
|
|
277
|
+
for t_ex, i_ex in control_obs:
|
|
278
|
+
# Create modified delta with excluded observation zeroed out
|
|
279
|
+
delta_ex = delta.copy()
|
|
280
|
+
delta_ex[t_ex, i_ex] = 0.0
|
|
281
|
+
|
|
282
|
+
try:
|
|
283
|
+
mu, alpha, beta, L = self._solve_global_model(Y, delta_ex, lambda_nn)
|
|
284
|
+
|
|
285
|
+
# Pseudo treatment effect: tau = Y - mu - alpha - beta - L
|
|
286
|
+
if np.isfinite(Y[t_ex, i_ex]):
|
|
287
|
+
tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex]
|
|
288
|
+
tau_sq_sum += tau_loocv**2
|
|
289
|
+
n_valid += 1
|
|
290
|
+
|
|
291
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
292
|
+
# Any failure means this lambda combination is invalid per Equation 5
|
|
293
|
+
return np.inf
|
|
294
|
+
|
|
295
|
+
if n_valid == 0:
|
|
296
|
+
return np.inf
|
|
297
|
+
|
|
298
|
+
return tau_sq_sum
|
|
299
|
+
|
|
300
|
+
def _solve_global_no_lowrank(
|
|
301
|
+
self,
|
|
302
|
+
Y: np.ndarray,
|
|
303
|
+
delta: np.ndarray,
|
|
304
|
+
) -> Tuple[float, np.ndarray, np.ndarray]:
|
|
305
|
+
"""
|
|
306
|
+
Solve TWFE via weighted least squares on control data (no low-rank).
|
|
307
|
+
|
|
308
|
+
Solves: min sum (1-W)*delta_{it}(Y_{it} - mu - alpha_i - beta_t)^2
|
|
309
|
+
|
|
310
|
+
The (1-W) masking is already applied to delta by _compute_global_weights,
|
|
311
|
+
so treated observations have zero weight and do not affect the fit.
|
|
312
|
+
|
|
313
|
+
Parameters
|
|
314
|
+
----------
|
|
315
|
+
Y : np.ndarray
|
|
316
|
+
Outcome matrix (n_periods x n_units).
|
|
317
|
+
delta : np.ndarray
|
|
318
|
+
Weight matrix (n_periods x n_units), already (1-W) masked.
|
|
319
|
+
|
|
320
|
+
Returns
|
|
321
|
+
-------
|
|
322
|
+
Tuple[float, np.ndarray, np.ndarray]
|
|
323
|
+
(mu, alpha, beta) estimated parameters.
|
|
324
|
+
"""
|
|
325
|
+
n_periods, n_units = Y.shape
|
|
326
|
+
|
|
327
|
+
# Flatten matrices for regression
|
|
328
|
+
y = Y.flatten() # length n_periods * n_units
|
|
329
|
+
weights = delta.flatten()
|
|
330
|
+
|
|
331
|
+
# Handle NaN values: zero weight for NaN outcomes/weights, impute with 0
|
|
332
|
+
# This ensures NaN observations don't contribute to estimation
|
|
333
|
+
valid_y = np.isfinite(y)
|
|
334
|
+
valid_w = np.isfinite(weights)
|
|
335
|
+
valid_mask = valid_y & valid_w
|
|
336
|
+
weights = np.where(valid_mask, weights, 0.0)
|
|
337
|
+
y = np.where(valid_mask, y, 0.0)
|
|
338
|
+
|
|
339
|
+
sqrt_weights = np.sqrt(np.maximum(weights, 0))
|
|
340
|
+
|
|
341
|
+
# Check for all-zero weights (matches Rust's sum_w < 1e-10 check)
|
|
342
|
+
sum_w = np.sum(weights)
|
|
343
|
+
if sum_w < 1e-10:
|
|
344
|
+
raise ValueError("All weights are zero - cannot estimate")
|
|
345
|
+
|
|
346
|
+
# Build design matrix: [intercept, unit_dummies, time_dummies]
|
|
347
|
+
# Drop first unit (unit 0) and first time (time 0) for identification
|
|
348
|
+
n_obs = n_periods * n_units
|
|
349
|
+
n_params = 1 + (n_units - 1) + (n_periods - 1)
|
|
350
|
+
|
|
351
|
+
X = np.zeros((n_obs, n_params))
|
|
352
|
+
X[:, 0] = 1.0 # intercept
|
|
353
|
+
|
|
354
|
+
# Unit dummies (skip unit 0)
|
|
355
|
+
for i in range(1, n_units):
|
|
356
|
+
for t in range(n_periods):
|
|
357
|
+
X[t * n_units + i, i] = 1.0
|
|
358
|
+
|
|
359
|
+
# Time dummies (skip time 0)
|
|
360
|
+
for t in range(1, n_periods):
|
|
361
|
+
for i in range(n_units):
|
|
362
|
+
X[t * n_units + i, (n_units - 1) + t] = 1.0
|
|
363
|
+
|
|
364
|
+
# Apply weights
|
|
365
|
+
X_weighted = X * sqrt_weights[:, np.newaxis]
|
|
366
|
+
y_weighted = y * sqrt_weights
|
|
367
|
+
|
|
368
|
+
# Solve weighted least squares
|
|
369
|
+
try:
|
|
370
|
+
coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None)
|
|
371
|
+
except np.linalg.LinAlgError:
|
|
372
|
+
# Fallback: use pseudo-inverse
|
|
373
|
+
warnings.warn(
|
|
374
|
+
"Least-squares solver failed in TROP global estimation; "
|
|
375
|
+
"falling back to pseudo-inverse. Results may be less "
|
|
376
|
+
"numerically stable.",
|
|
377
|
+
UserWarning,
|
|
378
|
+
stacklevel=2,
|
|
379
|
+
)
|
|
380
|
+
coeffs = np.dot(np.linalg.pinv(X_weighted), y_weighted)
|
|
381
|
+
|
|
382
|
+
# Extract parameters
|
|
383
|
+
mu = coeffs[0]
|
|
384
|
+
alpha = np.zeros(n_units)
|
|
385
|
+
alpha[1:] = coeffs[1:n_units]
|
|
386
|
+
beta = np.zeros(n_periods)
|
|
387
|
+
beta[1:] = coeffs[n_units : (n_units + n_periods - 1)]
|
|
388
|
+
|
|
389
|
+
return float(mu), alpha, beta
|
|
390
|
+
|
|
391
|
+
def _solve_global_with_lowrank(
    self,
    Y: np.ndarray,
    delta: np.ndarray,
    lambda_nn: float,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    """
    Solve TWFE + low-rank on control data via alternating minimization.

    Solves: min sum (1-W)*delta_{it}(Y_{it} - mu - alpha_i - beta_t - L_{it})^2 + lambda_nn||L||_*

    Each outer iteration (i) holds L fixed and re-fits (mu, alpha, beta) with
    ``self._solve_global_no_lowrank`` on ``Y - L``, then (ii) holds the TWFE
    terms fixed and updates L with at most 20 FISTA steps: a weighted
    gradient step followed by singular-value soft-thresholding. After the
    outer loop, the TWFE terms are re-fit once more against the final L.
    With ``max_iter <= 0`` the outer loop does not run and L stays zero.

    The (1-W) masking is already applied to delta by _compute_global_weights,
    so treated observations have zero weight and do not affect the fit.
    Non-finite outcomes are additionally given zero weight here.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    delta : np.ndarray
        Weight matrix (n_periods x n_units), already (1-W) masked.
    lambda_nn : float
        Nuclear norm regularization parameter.
    max_iter : int, default=100
        Maximum outer (alternating) iterations. The inner FISTA loop has a
        separate fixed cap of 20 steps.
    tol : float, default=1e-6
        Convergence tolerance on the max absolute change in L. Used for both
        the inner FISTA loop and the outer loop.

    Returns
    -------
    Tuple[float, np.ndarray, np.ndarray, np.ndarray]
        (mu, alpha, beta, L) estimated parameters; L has the shape of Y.

    Raises
    ------
    ValueError
        Propagated from ``_solve_global_no_lowrank`` when all effective
        weights are zero.
    """
    n_periods, n_units = Y.shape

    # Non-finite outcomes are imputed with 0 so the array arithmetic below
    # stays finite. Their weight is zeroed next, so the imputed value does
    # not enter the fit.
    Y_safe = np.where(np.isfinite(Y), Y, 0.0)

    # Mask delta to exclude NaN outcomes from estimation
    # This ensures NaN observations don't contribute to the gradient step
    nan_mask = ~np.isfinite(Y)
    delta_masked = delta.copy()
    delta_masked[nan_mask] = 0.0

    # Precompute normalized weights and threshold (constant across iterations).
    # Dividing by delta_max makes the L-update a gradient step of size
    # 1 / (2 * delta_max) on sum delta * (R - L)^2. The nuclear-norm
    # threshold lambda_nn / (2 * delta_max) is lambda_nn scaled by that same
    # step size.
    delta_max = np.max(delta_masked)
    if delta_max > 0:
        delta_norm = delta_masked / delta_max
    else:
        delta_norm = delta_masked
    threshold = lambda_nn / (2.0 * delta_max) if delta_max > 0 else lambda_nn / 2.0

    # Initialize L = 0
    L = np.zeros((n_periods, n_units))

    for iteration in range(max_iter):
        L_old = L.copy()

        # Step 1: Fix L, solve for (mu, alpha, beta)
        Y_adj = Y_safe - L
        mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked)

        # Step 2: Fix (mu, alpha, beta), update L with FISTA acceleration
        R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis]

        # For delta=0 observations (treated/NaN), keep L rather than R.
        # delta_norm is also 0 at those cells, so the gradient step does not
        # pull them toward R.
        R_masked = np.where(delta_masked > 0, R, L)

        # Inner FISTA loop for L update (fixed cap of 20 steps)
        L_inner = L.copy()
        L_inner_prev = L_inner  # share reference initially (no copy needed)
        t_fista = 1.0

        for _ in range(20):
            # FISTA momentum: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
            # Momentum is 0 on the first step because t_fista starts at 1.
            t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
            momentum = (t_fista - 1.0) / t_fista_new
            L_momentum = L_inner + momentum * (L_inner - L_inner_prev)

            # Gradient step from momentum point
            gradient_step = L_momentum + delta_norm * (R_masked - L_momentum)

            # Proximal step: soft-threshold singular values
            L_inner_prev = L_inner
            L_inner = _soft_threshold_svd(gradient_step, threshold)
            t_fista = t_fista_new

            # Convergence check: L_inner_prev is the previous iterate, i.e.
            # the value before this step's proximal update
            if np.max(np.abs(L_inner - L_inner_prev)) < tol:
                break

        L = L_inner

        # Outer convergence check
        if np.max(np.abs(L - L_old)) < tol:
            break

    # Final re-solve with converged L (match Rust behavior)
    Y_adj = Y_safe - L
    mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked)

    return mu, alpha, beta, L
|
|
495
|
+
|
|
496
|
+
def _fit_global(
|
|
497
|
+
self,
|
|
498
|
+
data: pd.DataFrame,
|
|
499
|
+
outcome: str,
|
|
500
|
+
treatment: str,
|
|
501
|
+
unit: str,
|
|
502
|
+
time: str,
|
|
503
|
+
resolved_survey=None,
|
|
504
|
+
survey_metadata=None,
|
|
505
|
+
survey_design=None,
|
|
506
|
+
) -> TROPResults:
|
|
507
|
+
"""
|
|
508
|
+
Fit TROP using global weighted least squares method.
|
|
509
|
+
|
|
510
|
+
Fits a single model on control observations using (1-W) masked weights,
|
|
511
|
+
then extracts per-observation treatment effects as post-hoc residuals.
|
|
512
|
+
ATT is the mean of these heterogeneous effects.
|
|
513
|
+
|
|
514
|
+
Parameters
|
|
515
|
+
----------
|
|
516
|
+
data : pd.DataFrame
|
|
517
|
+
Panel data.
|
|
518
|
+
outcome : str
|
|
519
|
+
Outcome variable column name.
|
|
520
|
+
treatment : str
|
|
521
|
+
Treatment indicator column name.
|
|
522
|
+
unit : str
|
|
523
|
+
Unit identifier column name.
|
|
524
|
+
time : str
|
|
525
|
+
Time period column name.
|
|
526
|
+
|
|
527
|
+
Returns
|
|
528
|
+
-------
|
|
529
|
+
TROPResults
|
|
530
|
+
Estimation results.
|
|
531
|
+
|
|
532
|
+
Notes
|
|
533
|
+
-----
|
|
534
|
+
Bootstrap variance estimation assumes simultaneous treatment adoption
|
|
535
|
+
(fixed `treated_periods` across resamples). The treatment timing is
|
|
536
|
+
inferred from the data once and held constant for all bootstrap
|
|
537
|
+
iterations. For staggered adoption designs where treatment timing varies
|
|
538
|
+
across units, use `method="local"` which computes observation-specific
|
|
539
|
+
weights that naturally handle heterogeneous timing.
|
|
540
|
+
"""
|
|
541
|
+
# Data setup (same as local method)
|
|
542
|
+
all_units = sorted(data[unit].unique())
|
|
543
|
+
all_periods = sorted(data[time].unique())
|
|
544
|
+
|
|
545
|
+
# Extract per-unit survey weights for weighted ATT aggregation
|
|
546
|
+
if resolved_survey is not None:
|
|
547
|
+
from diff_diff.survey import _extract_unit_survey_weights
|
|
548
|
+
|
|
549
|
+
unit_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)
|
|
550
|
+
else:
|
|
551
|
+
unit_weight_arr = None
|
|
552
|
+
|
|
553
|
+
n_units = len(all_units)
|
|
554
|
+
n_periods = len(all_periods)
|
|
555
|
+
|
|
556
|
+
idx_to_unit = {i: u for i, u in enumerate(all_units)}
|
|
557
|
+
idx_to_period = {i: p for i, p in enumerate(all_periods)}
|
|
558
|
+
|
|
559
|
+
# Create matrices
|
|
560
|
+
Y = (
|
|
561
|
+
data.pivot(index=time, columns=unit, values=outcome)
|
|
562
|
+
.reindex(index=all_periods, columns=all_units)
|
|
563
|
+
.values
|
|
564
|
+
)
|
|
565
|
+
|
|
566
|
+
D, missing_mask = _validate_and_pivot_treatment(
|
|
567
|
+
data, time, unit, treatment, all_periods, all_units
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
# Validate absorbing state
|
|
571
|
+
violating_units = []
|
|
572
|
+
for unit_idx in range(n_units):
|
|
573
|
+
observed_mask = ~missing_mask[:, unit_idx]
|
|
574
|
+
observed_d = D[observed_mask, unit_idx]
|
|
575
|
+
if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
|
|
576
|
+
violating_units.append(all_units[unit_idx])
|
|
577
|
+
|
|
578
|
+
if violating_units:
|
|
579
|
+
raise ValueError(
|
|
580
|
+
f"Treatment indicator is not an absorbing state for units: {violating_units}. "
|
|
581
|
+
f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
|
|
582
|
+
f"If this is event-study style data, convert to absorbing state: "
|
|
583
|
+
f"D[t, i] = 1 for all t >= first treatment period."
|
|
584
|
+
)
|
|
585
|
+
|
|
586
|
+
# Identify treated observations
|
|
587
|
+
treated_mask = D == 1
|
|
588
|
+
n_treated_obs = np.sum(treated_mask)
|
|
589
|
+
|
|
590
|
+
if n_treated_obs == 0:
|
|
591
|
+
raise ValueError("No treated observations found")
|
|
592
|
+
|
|
593
|
+
# Identify treated and control units
|
|
594
|
+
unit_ever_treated = np.any(D == 1, axis=0)
|
|
595
|
+
treated_unit_idx = np.where(unit_ever_treated)[0]
|
|
596
|
+
control_unit_idx = np.where(~unit_ever_treated)[0]
|
|
597
|
+
|
|
598
|
+
if len(control_unit_idx) == 0:
|
|
599
|
+
raise ValueError("No control units found")
|
|
600
|
+
|
|
601
|
+
# Determine pre/post periods
|
|
602
|
+
first_treat_period = None
|
|
603
|
+
for t in range(n_periods):
|
|
604
|
+
if np.any(D[t, :] == 1):
|
|
605
|
+
first_treat_period = t
|
|
606
|
+
break
|
|
607
|
+
|
|
608
|
+
if first_treat_period is None:
|
|
609
|
+
raise ValueError("Could not infer post-treatment periods from D matrix")
|
|
610
|
+
|
|
611
|
+
n_pre_periods = first_treat_period
|
|
612
|
+
treated_periods = n_periods - first_treat_period
|
|
613
|
+
n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
|
|
614
|
+
|
|
615
|
+
if n_pre_periods < 2:
|
|
616
|
+
raise ValueError("Need at least 2 pre-treatment periods")
|
|
617
|
+
|
|
618
|
+
# Check for staggered adoption (global method requires simultaneous treatment)
|
|
619
|
+
# Use only observed periods (skip missing) to avoid false positives on unbalanced panels
|
|
620
|
+
first_treat_by_unit = []
|
|
621
|
+
for i in treated_unit_idx:
|
|
622
|
+
observed_mask = ~missing_mask[:, i]
|
|
623
|
+
# Get D values for observed periods only
|
|
624
|
+
observed_d = D[observed_mask, i]
|
|
625
|
+
observed_periods = np.where(observed_mask)[0]
|
|
626
|
+
# Find first treatment among observed periods
|
|
627
|
+
treated_idx = np.where(observed_d == 1)[0]
|
|
628
|
+
if len(treated_idx) > 0:
|
|
629
|
+
first_treat_by_unit.append(observed_periods[treated_idx[0]])
|
|
630
|
+
|
|
631
|
+
unique_starts = sorted(set(first_treat_by_unit))
|
|
632
|
+
if len(unique_starts) > 1:
|
|
633
|
+
raise ValueError(
|
|
634
|
+
f"method='global' requires simultaneous treatment adoption, but your data "
|
|
635
|
+
f"shows staggered adoption (units first treated at periods {unique_starts}). "
|
|
636
|
+
f"Use method='local' which properly handles staggered adoption designs."
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
# LOOCV grid search for tuning parameters
|
|
640
|
+
# Use Rust backend when available for parallel LOOCV (5-10x speedup)
|
|
641
|
+
best_lambda = None
|
|
642
|
+
best_score = np.inf
|
|
643
|
+
control_mask = D == 0
|
|
644
|
+
|
|
645
|
+
if HAS_RUST_BACKEND and _rust_loocv_grid_search_global is not None:
|
|
646
|
+
try:
|
|
647
|
+
# Prepare inputs for Rust function
|
|
648
|
+
control_mask_u8 = control_mask.astype(np.uint8)
|
|
649
|
+
|
|
650
|
+
lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
|
|
651
|
+
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
|
|
652
|
+
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
|
|
653
|
+
|
|
654
|
+
result = _rust_loocv_grid_search_global(
|
|
655
|
+
Y,
|
|
656
|
+
D.astype(np.float64),
|
|
657
|
+
control_mask_u8,
|
|
658
|
+
lambda_time_arr,
|
|
659
|
+
lambda_unit_arr,
|
|
660
|
+
lambda_nn_arr,
|
|
661
|
+
self.max_iter,
|
|
662
|
+
self.tol,
|
|
663
|
+
)
|
|
664
|
+
# Unpack result - 7 values including optional first_failed_obs
|
|
665
|
+
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = (
|
|
666
|
+
result
|
|
667
|
+
)
|
|
668
|
+
# Only accept finite scores - infinite means all fits failed
|
|
669
|
+
if np.isfinite(best_score):
|
|
670
|
+
best_lambda = (best_lt, best_lu, best_ln)
|
|
671
|
+
# Emit warnings consistent with Python implementation
|
|
672
|
+
if n_valid == 0:
|
|
673
|
+
obs_info = ""
|
|
674
|
+
if first_failed_obs is not None:
|
|
675
|
+
t_idx, i_idx = first_failed_obs
|
|
676
|
+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
|
|
677
|
+
warnings.warn(
|
|
678
|
+
f"LOOCV: All {n_attempted} fits failed for "
|
|
679
|
+
f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
|
|
680
|
+
f"Returning infinite score.{obs_info}",
|
|
681
|
+
UserWarning,
|
|
682
|
+
)
|
|
683
|
+
elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
|
|
684
|
+
n_failed = n_attempted - n_valid
|
|
685
|
+
obs_info = ""
|
|
686
|
+
if first_failed_obs is not None:
|
|
687
|
+
t_idx, i_idx = first_failed_obs
|
|
688
|
+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
|
|
689
|
+
warnings.warn(
|
|
690
|
+
f"LOOCV: {n_failed}/{n_attempted} fits failed for "
|
|
691
|
+
f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
|
|
692
|
+
f"This may indicate numerical instability.{obs_info}",
|
|
693
|
+
UserWarning,
|
|
694
|
+
)
|
|
695
|
+
except Exception as e:
|
|
696
|
+
# Fall back to Python implementation on error
|
|
697
|
+
logger.debug(
|
|
698
|
+
"Rust LOOCV grid search (global) failed, falling back to Python: %s", e
|
|
699
|
+
)
|
|
700
|
+
warnings.warn(
|
|
701
|
+
f"Rust backend failed for LOOCV grid search (global); "
|
|
702
|
+
f"falling back to Python. Performance may be reduced. "
|
|
703
|
+
f"Error: {e}",
|
|
704
|
+
UserWarning,
|
|
705
|
+
stacklevel=2,
|
|
706
|
+
)
|
|
707
|
+
best_lambda = None
|
|
708
|
+
best_score = np.inf
|
|
709
|
+
|
|
710
|
+
# Fall back to Python implementation if Rust unavailable or failed
|
|
711
|
+
if best_lambda is None:
|
|
712
|
+
# Get control observations for LOOCV
|
|
713
|
+
control_obs = [
|
|
714
|
+
(t, i)
|
|
715
|
+
for t in range(n_periods)
|
|
716
|
+
for i in range(n_units)
|
|
717
|
+
if control_mask[t, i] and not np.isnan(Y[t, i])
|
|
718
|
+
]
|
|
719
|
+
|
|
720
|
+
# Grid search with true LOOCV
|
|
721
|
+
for lambda_time_val in self.lambda_time_grid:
|
|
722
|
+
for lambda_unit_val in self.lambda_unit_grid:
|
|
723
|
+
for lambda_nn_val in self.lambda_nn_grid:
|
|
724
|
+
# Convert lambda_nn=inf -> large finite value (factor model disabled)
|
|
725
|
+
lt = lambda_time_val
|
|
726
|
+
lu = lambda_unit_val
|
|
727
|
+
ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val
|
|
728
|
+
|
|
729
|
+
try:
|
|
730
|
+
score = self._loocv_score_global(
|
|
731
|
+
Y, D, control_obs, lt, lu, ln, treated_periods, n_units, n_periods
|
|
732
|
+
)
|
|
733
|
+
|
|
734
|
+
if score < best_score:
|
|
735
|
+
best_score = score
|
|
736
|
+
best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val)
|
|
737
|
+
|
|
738
|
+
except (np.linalg.LinAlgError, ValueError):
|
|
739
|
+
continue
|
|
740
|
+
|
|
741
|
+
if best_lambda is None:
|
|
742
|
+
warnings.warn("All tuning parameter combinations failed. Using defaults.", UserWarning)
|
|
743
|
+
best_lambda = (1.0, 1.0, 0.1)
|
|
744
|
+
best_score = np.nan
|
|
745
|
+
|
|
746
|
+
# Final estimation with best parameters
|
|
747
|
+
lambda_time, lambda_unit, lambda_nn = best_lambda
|
|
748
|
+
original_lambda_nn = lambda_nn
|
|
749
|
+
|
|
750
|
+
# Convert lambda_nn=inf -> large finite value (factor model disabled, L~0)
|
|
751
|
+
# lambda_time and lambda_unit use 0.0 for uniform weights directly (no conversion needed)
|
|
752
|
+
if np.isinf(lambda_nn):
|
|
753
|
+
lambda_nn = 1e10
|
|
754
|
+
|
|
755
|
+
# Compute final weights and fit
|
|
756
|
+
delta = self._compute_global_weights(
|
|
757
|
+
Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
|
|
761
|
+
|
|
762
|
+
# Post-hoc tau extraction (per paper Eq. 2)
|
|
763
|
+
att, treatment_effects, tau_values = self._extract_posthoc_tau(
|
|
764
|
+
Y,
|
|
765
|
+
D,
|
|
766
|
+
mu,
|
|
767
|
+
alpha,
|
|
768
|
+
beta,
|
|
769
|
+
L,
|
|
770
|
+
idx_to_unit,
|
|
771
|
+
idx_to_period,
|
|
772
|
+
unit_weights=unit_weight_arr,
|
|
773
|
+
)
|
|
774
|
+
|
|
775
|
+
# Use count of valid (finite) treated outcomes for df and metadata
|
|
776
|
+
n_valid_treated = len(tau_values)
|
|
777
|
+
if n_valid_treated == 0:
|
|
778
|
+
warnings.warn(
|
|
779
|
+
"All treated outcomes are NaN/missing. Cannot estimate ATT.",
|
|
780
|
+
UserWarning,
|
|
781
|
+
)
|
|
782
|
+
elif n_valid_treated < n_treated_obs:
|
|
783
|
+
warnings.warn(
|
|
784
|
+
f"Only {n_valid_treated} of {n_treated_obs} treated outcomes are finite. "
|
|
785
|
+
"df and n_treated_obs reflect valid observations only.",
|
|
786
|
+
UserWarning,
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
# Compute effective rank of L
|
|
790
|
+
_, s, _ = np.linalg.svd(L, full_matrices=False)
|
|
791
|
+
if s[0] > 0:
|
|
792
|
+
effective_rank = np.sum(s) / s[0]
|
|
793
|
+
else:
|
|
794
|
+
effective_rank = 0.0
|
|
795
|
+
|
|
796
|
+
# Bootstrap variance estimation
|
|
797
|
+
effective_lambda = (lambda_time, lambda_unit, lambda_nn)
|
|
798
|
+
|
|
799
|
+
se, bootstrap_dist = self._bootstrap_variance_global(
|
|
800
|
+
data,
|
|
801
|
+
outcome,
|
|
802
|
+
treatment,
|
|
803
|
+
unit,
|
|
804
|
+
time,
|
|
805
|
+
effective_lambda,
|
|
806
|
+
treated_periods,
|
|
807
|
+
survey_design=survey_design,
|
|
808
|
+
unit_weight_arr=unit_weight_arr,
|
|
809
|
+
resolved_survey=resolved_survey,
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
# Compute test statistics
|
|
813
|
+
df_trop = max(1, n_valid_treated - 1)
|
|
814
|
+
t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop)
|
|
815
|
+
|
|
816
|
+
# Create results dictionaries
|
|
817
|
+
unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)}
|
|
818
|
+
time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)}
|
|
819
|
+
|
|
820
|
+
self.results_ = TROPResults(
|
|
821
|
+
att=float(att),
|
|
822
|
+
se=float(se),
|
|
823
|
+
t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat,
|
|
824
|
+
p_value=float(p_value) if np.isfinite(p_value) else p_value,
|
|
825
|
+
conf_int=conf_int,
|
|
826
|
+
n_obs=len(data),
|
|
827
|
+
n_treated=len(treated_unit_idx),
|
|
828
|
+
n_control=len(control_unit_idx),
|
|
829
|
+
n_treated_obs=int(n_valid_treated),
|
|
830
|
+
unit_effects=unit_effects_dict,
|
|
831
|
+
time_effects=time_effects_dict,
|
|
832
|
+
treatment_effects=treatment_effects,
|
|
833
|
+
lambda_time=lambda_time,
|
|
834
|
+
lambda_unit=lambda_unit,
|
|
835
|
+
lambda_nn=original_lambda_nn,
|
|
836
|
+
factor_matrix=L,
|
|
837
|
+
effective_rank=effective_rank,
|
|
838
|
+
loocv_score=best_score,
|
|
839
|
+
alpha=self.alpha,
|
|
840
|
+
n_pre_periods=n_pre_periods,
|
|
841
|
+
n_post_periods=n_post_periods,
|
|
842
|
+
n_bootstrap=self.n_bootstrap,
|
|
843
|
+
bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
|
|
844
|
+
survey_metadata=survey_metadata,
|
|
845
|
+
)
|
|
846
|
+
|
|
847
|
+
self.is_fitted_ = True
|
|
848
|
+
return self.results_
|
|
849
|
+
|
|
850
|
+
def _bootstrap_variance_global(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    treated_periods: int,
    survey_design=None,
    unit_weight_arr: Optional[np.ndarray] = None,
    resolved_survey=None,
) -> Tuple[float, np.ndarray]:
    """
    Compute bootstrap standard error for global method.

    Dispatch order:

    1. If ``resolved_survey`` carries strata, PSU or FPC information, delegate
       to ``_bootstrap_rao_wu_global`` (Rao-Wu rescaled bootstrap). The Rust
       path is skipped.
    2. Otherwise, if the Rust backend is available, run its parallel bootstrap
       (5-15x speedup), passing ``unit_weight_arr`` through. Any exception
       from the Rust call falls back to Python with a ``UserWarning``.
    3. Python fallback: stratified unit bootstrap. Ever-treated and
       never-treated units are resampled separately with replacement, and
       each resample is refit with the fixed ``optimal_lambda``.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome : str
        Outcome column name.
    treatment : str
        Treatment column name.
    unit : str
        Unit column name.
    time : str
        Time column name.
    optimal_lambda : tuple
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn).
    treated_periods : int
        Number of post-treatment periods.
    survey_design : SurveyDesign, optional
        Survey design specification.
    unit_weight_arr : np.ndarray, optional
        Unit-level survey weights (used by the Rust path).
    resolved_survey : ResolvedSurveyDesign, optional
        Resolved survey design (observation-level).

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). ``se`` is NaN and the array empty when no
        replicate succeeded. In the Python fallback, ``se`` is also NaN when
        exactly one replicate succeeded.
    """
    lambda_time, lambda_unit, lambda_nn = optimal_lambda

    # Check for full survey design (strata/PSU/FPC present)
    _has_full_design = resolved_survey is not None and (
        resolved_survey.strata is not None
        or resolved_survey.psu is not None
        or resolved_survey.fpc is not None
    )

    # Full survey design: use Python Rao-Wu rescaled bootstrap
    if _has_full_design:
        return self._bootstrap_rao_wu_global(
            data,
            outcome,
            treatment,
            unit,
            time,
            optimal_lambda,
            treated_periods,
            resolved_survey,
            survey_design,
        )

    # Try Rust backend for parallel bootstrap (5-15x speedup)
    # Only used for pweight-only designs (no strata/PSU/FPC)
    if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_global is not None:
        try:
            # Create (period x unit) matrices for the Rust function
            all_units = sorted(data[unit].unique())
            all_periods = sorted(data[time].unique())

            Y = (
                data.pivot(index=time, columns=unit, values=outcome)
                .reindex(index=all_periods, columns=all_units)
                .values
            )
            D = (
                data.pivot(index=time, columns=unit, values=treatment)
                .reindex(index=all_periods, columns=all_units)
                .fillna(0)
                .astype(np.float64)
                .values
            )

            bootstrap_estimates, se = _rust_bootstrap_trop_variance_global(
                Y,
                D,
                lambda_time,
                lambda_unit,
                lambda_nn,
                self.n_bootstrap,
                self.max_iter,
                self.tol,
                self.seed if self.seed is not None else 0,
                unit_weight_arr,
            )

            if len(bootstrap_estimates) < 10:
                warnings.warn(
                    f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
                    UserWarning,
                )
                if len(bootstrap_estimates) == 0:
                    return np.nan, np.array([])

            return float(se), np.array(bootstrap_estimates)

        except Exception as e:
            # Deliberate best-effort: any Rust failure falls through to Python.
            logger.debug("Rust bootstrap (global) failed, falling back to Python: %s", e)
            warnings.warn(
                f"Rust backend failed for bootstrap variance (global); "
                f"falling back to Python. Performance may be reduced. "
                f"Error: {e}",
                UserWarning,
                stacklevel=2,
            )

    # Python fallback implementation
    rng = np.random.default_rng(self.seed)

    # Stratified bootstrap sampling: treated and control units are resampled
    # separately so every replicate keeps the original group sizes.
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist())
    control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist())

    n_treated_units = len(treated_units)
    n_control_units = len(control_units)

    # Slice each unit's rows once. Re-filtering the whole frame for every
    # sampled unit inside the loop costs O(n_units * n_rows) per replicate.
    # The keys are the same array elements rng.choice draws from, and the
    # mask expression is unchanged, so each replicate selects identical rows.
    unit_rows = {u: data[data[unit] == u] for u in (*control_units, *treated_units)}

    bootstrap_estimates_list: List[float] = []

    for _ in range(self.n_bootstrap):
        if n_control_units > 0:
            sampled_control = rng.choice(control_units, size=n_control_units, replace=True)
        else:
            sampled_control = np.array([], dtype=object)

        if n_treated_units > 0:
            sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True)
        else:
            sampled_treated = np.array([], dtype=object)

        sampled_units = np.concatenate([sampled_control, sampled_treated])

        # Relabel each draw so duplicated units become distinct panel units.
        boot_data = pd.concat(
            [
                unit_rows[u].assign(**{unit: f"{u}_{idx}"})
                for idx, u in enumerate(sampled_units)
            ],
            ignore_index=True,
        )

        try:
            tau = self._fit_global_with_fixed_lambda(
                boot_data,
                outcome,
                treatment,
                unit,
                time,
                optimal_lambda,
                treated_periods,
                survey_design=survey_design,
            )
            if np.isfinite(tau):
                bootstrap_estimates_list.append(tau)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)
    n_success = len(bootstrap_estimates)

    if n_success < 10:
        warnings.warn(f"Only {n_success} bootstrap iterations succeeded.", UserWarning)
    if n_success == 0:
        return np.nan, np.array([])
    if n_success == 1:
        # A ddof=1 std is undefined for a single draw; report NaN explicitly
        # rather than letting numpy emit a degrees-of-freedom RuntimeWarning.
        return np.nan, bootstrap_estimates

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
|
|
1038
|
+
|
|
1039
|
+
def _bootstrap_rao_wu_global(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    treated_periods: int,
    resolved_survey,
    survey_design,
) -> Tuple[float, np.ndarray]:
    """
    Rao-Wu rescaled bootstrap for global method with full survey design.

    Instead of physically resampling units, each iteration generates
    rescaled unit-level weights via Rao-Wu (1988) weight perturbation.
    Survey strata are cross-classified with treatment group (ever-treated
    vs. never-treated) to preserve the stratified resampling structure.

    The global model (``delta`` and ``mu, alpha, beta, L``) is fit once.
    Its inputs (Y, D, fixed lambdas) do not involve the bootstrap weights,
    so refitting per replicate would only reproduce the same result. Each
    replicate re-runs ``_extract_posthoc_tau`` with that replicate's weights.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome, treatment, unit, time : str
        Column names.
    optimal_lambda : tuple
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn).
        ``lambda_nn=inf`` is mapped to ``1e10``, as elsewhere in this file.
    treated_periods : int
        Number of post-treatment periods.
    resolved_survey : ResolvedSurveyDesign
        Resolved survey design (observation-level).
    survey_design : SurveyDesign
        Original survey design specification.

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). Returns ``(nan, empty array)`` when the
        variance is unidentified (a PSU column with fewer than 2 PSUs and no
        strata) or when no replicate succeeded. ``se`` is NaN when exactly
        one replicate succeeded.
    """
    from diff_diff.bootstrap_utils import generate_rao_wu_weights
    from diff_diff.linalg import _factorize_cluster_ids
    from diff_diff.survey import ResolvedSurveyDesign

    lambda_time, lambda_unit, lambda_nn = optimal_lambda
    # lambda_nn=inf means "factor model disabled"; the solver receives it as a
    # large finite value, consistent with the main fit and LOOCV paths.
    if np.isinf(lambda_nn):
        lambda_nn = 1e10
    rng = np.random.default_rng(self.seed)

    # Build unit-level resolved survey with cross-classified strata
    all_units = sorted(data[unit].unique())
    n_units = len(all_units)

    # Ever-treated indicator per unit (ordered like all_units)
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treatment_group = np.array([int(unit_ever_treated[u]) for u in all_units], dtype=np.int64)

    # Unit-level survey design fields, taken from groupby(unit).first()
    first_rows = data.groupby(unit).first().loc[all_units]

    # Weights (unit-level)
    if survey_design.weights is not None:
        unit_weights = first_rows[survey_design.weights].values.astype(np.float64)
    else:
        unit_weights = np.ones(n_units, dtype=np.float64)

    # Strata: cross-classify survey strata x treatment group
    if survey_design.strata is not None:
        survey_strata = first_rows[survey_design.strata].values
        cross_labels = np.array([f"{s}_{g}" for s, g in zip(survey_strata, treatment_group)])
        cross_strata = _factorize_cluster_ids(cross_labels)
    else:
        # No survey strata: use treatment group as strata
        cross_strata = treatment_group.copy()
    n_strata = len(np.unique(cross_strata))

    # PSU (unit-level)
    if survey_design.psu is not None:
        psu_raw = first_rows[survey_design.psu].values
        if survey_design.nest and survey_design.strata is not None:
            # Nested PSU labels are only unique within (cross-)strata
            combined = np.array([f"{s}_{p}" for s, p in zip(cross_strata, psu_raw)])
            psu_arr = _factorize_cluster_ids(combined)
        else:
            psu_arr = _factorize_cluster_ids(psu_raw)
        n_psu = len(np.unique(psu_arr))
    else:
        # Implicit PSU: each unit is its own PSU
        psu_arr = np.arange(n_units, dtype=np.int64)
        n_psu = n_units

    # FPC (unit-level)
    fpc_arr = None
    if survey_design.fpc is not None:
        fpc_arr = first_rows[survey_design.fpc].values.astype(np.float64)

    unit_resolved = ResolvedSurveyDesign(
        weights=unit_weights,
        weight_type=resolved_survey.weight_type,
        strata=cross_strata,
        psu=psu_arr,
        fpc=fpc_arr,
        n_strata=n_strata,
        n_psu=n_psu,
        lonely_psu=resolved_survey.lonely_psu,
    )

    # Check for unidentified variance (single unstratified PSU)
    if (
        survey_design.psu is not None
        and unit_resolved.n_psu < 2
        and survey_design.strata is None
    ):
        return np.nan, np.array([])

    # (period x unit) outcome and treatment matrices
    all_periods = sorted(data[time].unique())
    n_periods = len(all_periods)

    Y = (
        data.pivot(index=time, columns=unit, values=outcome)
        .reindex(index=all_periods, columns=all_units)
        .values
    )
    D = (
        data.pivot(index=time, columns=unit, values=treatment)
        .reindex(index=all_periods, columns=all_units)
        .fillna(0)
        .astype(int)
        .values
    )

    # Loop-invariant group masks over units
    control_mask_units = treatment_group == 0
    treated_mask_units = treatment_group == 1

    # The model fit depends only on Y, D and the fixed lambdas -- never on the
    # Rao-Wu draw -- so compute it once. If it fails, every replicate would
    # fail as well; the empty-estimate path below reports that.
    # NOTE(review): because boot_weights never reach the fit, replicate
    # variation comes only from the reweighting inside _extract_posthoc_tau;
    # control-side sampling variability is presumably not propagated into the
    # SE. Confirm this is the intended design.
    try:
        delta = self._compute_global_weights(
            Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
        )
        mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
        model_ok = True
    except (ValueError, np.linalg.LinAlgError, KeyError):
        model_ok = False

    bootstrap_estimates_list: List[float] = []

    if model_ok:
        for _ in range(self.n_bootstrap):
            try:
                # Generate Rao-Wu rescaled weights (unit-level)
                boot_weights = generate_rao_wu_weights(unit_resolved, rng)

                # Skip draws where either group receives zero total weight
                if boot_weights[control_mask_units].sum() == 0:
                    continue
                if boot_weights[treated_mask_units].sum() == 0:
                    continue

                # Extract weighted ATT using Rao-Wu rescaled weights
                att, _, _ = self._extract_posthoc_tau(
                    Y, D, mu, alpha, beta, L, unit_weights=boot_weights
                )

                if np.isfinite(att):
                    bootstrap_estimates_list.append(att)
            except (ValueError, np.linalg.LinAlgError, KeyError):
                continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)
    n_success = len(bootstrap_estimates)

    if n_success < 10:
        warnings.warn(
            f"Only {n_success} bootstrap iterations succeeded.",
            UserWarning,
        )
    if n_success == 0:
        return np.nan, np.array([])
    if n_success == 1:
        # ddof=1 std is undefined for a single draw; avoid numpy's RuntimeWarning
        return np.nan, bootstrap_estimates

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
|
|
1214
|
+
|
|
1215
|
+
def _fit_global_with_fixed_lambda(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    fixed_lambda: Tuple[float, float, float],
    treated_periods: int,
    survey_design=None,
) -> float:
    """
    Fit global model with fixed tuning parameters.

    Steps:

    1. Pivot ``data`` into (period x unit) outcome and treatment matrices.
       Missing treatment cells are filled with 0.
    2. Compute the global weights and solve the model.
    3. Return the post-hoc ATT from ``_extract_posthoc_tau``.

    When ``survey_design.weights`` is set, per-unit survey weights are passed
    through, so the ATT is survey-weighted. Otherwise it is unweighted.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data (possibly a bootstrap resample).
    outcome, treatment, unit, time : str
        Column names.
    fixed_lambda : tuple
        (lambda_time, lambda_unit, lambda_nn). ``lambda_nn=inf`` (factor
        model disabled) is mapped to ``1e10``, matching the main fit path.
    treated_periods : int
        Number of post-treatment periods.
    survey_design : SurveyDesign, optional
        Survey design; only its ``weights`` column is used here.

    Returns
    -------
    float
        The ATT estimate.
    """
    lambda_time, lambda_unit, lambda_nn = fixed_lambda
    # The solver is always given a finite lambda_nn elsewhere in this file;
    # apply the same inf -> 1e10 mapping (idempotent for finite values).
    if np.isinf(lambda_nn):
        lambda_nn = 1e10

    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    # Extract per-unit survey weights for weighted ATT in bootstrap
    if survey_design is not None and survey_design.weights is not None:
        from diff_diff.survey import _extract_unit_survey_weights

        local_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)
    else:
        local_weight_arr = None

    n_units = len(all_units)
    n_periods = len(all_periods)

    Y = (
        data.pivot(index=time, columns=unit, values=outcome)
        .reindex(index=all_periods, columns=all_units)
        .values
    )
    D = (
        data.pivot(index=time, columns=unit, values=treatment)
        .reindex(index=all_periods, columns=all_units)
        .fillna(0)
        .astype(int)
        .values
    )

    # Compute weights (includes (1-W) masking)
    delta = self._compute_global_weights(
        Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
    )

    # Fit model on control data and extract post-hoc tau
    mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
    att, _, _ = self._extract_posthoc_tau(
        Y, D, mu, alpha, beta, L, unit_weights=local_weight_arr
    )
    return att
|