diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/trop_local.py
ADDED
|
@@ -0,0 +1,1307 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local (observation-specific) estimation method for the TROP estimator.
|
|
3
|
+
|
|
4
|
+
Contains the TROPLocalMixin class with all methods for the local
|
|
5
|
+
estimation pathway, including preprocessing, distance computation,
|
|
6
|
+
per-observation weight computation, model fitting, LOOCV scoring,
|
|
7
|
+
and bootstrap variance estimation.
|
|
8
|
+
|
|
9
|
+
This module is used via mixin inheritance — see trop.py for the
|
|
10
|
+
main TROP class definition.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import warnings
|
|
15
|
+
from typing import Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pandas as pd
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
from diff_diff._backend import (
|
|
23
|
+
HAS_RUST_BACKEND,
|
|
24
|
+
_rust_bootstrap_trop_variance,
|
|
25
|
+
_rust_unit_distance_matrix,
|
|
26
|
+
)
|
|
27
|
+
from diff_diff.trop_results import _PrecomputedStructures
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _validate_and_pivot_treatment(data, time, unit, treatment, all_periods, all_units):
    """Build the (time x unit) treatment matrix and flag structurally absent cells.

    Observed rows whose treatment value is missing are treated as a data-quality
    error and rejected up front. The remaining rows are pivoted to a wide
    matrix and reindexed onto the full ``all_periods`` x ``all_units`` grid.
    Grid cells with no corresponding row (typical of unbalanced panels) are
    filled with 0, i.e. assumed untreated, and a ``UserWarning`` reports how
    many cells were filled.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format panel.
    time, unit, treatment : str
        Column names of the period, the unit identifier and the treatment
        indicator.
    all_periods, all_units : sequence
        Row and column labels of the output grid, in output order.

    Returns
    -------
    D : ndarray
        Integer treatment matrix of shape (n_periods, n_units).
    missing_mask : ndarray
        Boolean array of the same shape, True where no row existed.

    Raises
    ------
    ValueError
        If any observed row has a missing treatment value.
    """
    n_bad = int(data[treatment].isna().sum())
    if n_bad:
        raise ValueError(
            f"{n_bad} observation(s) have missing treatment values. "
            "TROP requires non-missing treatment indicators for all observed "
            "rows. Remove or impute missing values before fitting."
        )

    wide = data.pivot(index=time, columns=unit, values=treatment)
    wide = wide.reindex(index=all_periods, columns=all_units)

    # Cells still NaN after reindexing had no row in ``data`` at all.
    absent = pd.isna(wide).values
    n_absent = int(absent.sum())
    if n_absent:
        warnings.warn(
            f"{n_absent} missing treatment indicator(s) in the "
            "(time x unit) panel matrix filled with 0 (assumed "
            "untreated). This typically occurs in unbalanced panels.",
            UserWarning,
            stacklevel=3,
        )

    return wide.fillna(0).astype(int).values, absent
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Default truncation tolerance for ``_soft_threshold_svd``: after shrinkage,
# singular values at or below this are dropped from the reconstruction.
_CONVERGENCE_TOL_SVD: float = 1e-10


def _soft_threshold_svd(
    M: np.ndarray,
    threshold: float,
    convergence_tol: float = _CONVERGENCE_TOL_SVD,
) -> np.ndarray:
    """
    Shrink the singular values of ``M`` by ``threshold`` (nuclear-norm prox).

    Computes ``U @ diag(max(s - threshold, 0)) @ Vt``, keeping only the
    components whose shrunk singular value exceeds ``convergence_tol``.

    Parameters
    ----------
    M : np.ndarray
        Matrix to shrink. Non-finite entries are replaced by 0 before the SVD.
    threshold : float
        Amount subtracted from every singular value. When ``threshold <= 0``
        the input object is returned unchanged (no copy, no NaN cleanup).
    convergence_tol : float, default=1e-10
        Shrunk singular values at or below this value are discarded.

    Returns
    -------
    np.ndarray
        The shrunk matrix. A zero matrix shaped like ``M`` is returned when
        the SVD fails, yields non-finite factors, or no component survives.
        Any non-finite entries in the product are replaced with 0.
    """
    if threshold <= 0:
        return M

    if not np.isfinite(M).all():
        M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)

    try:
        U, s, Vt = np.linalg.svd(M, full_matrices=False)
    except np.linalg.LinAlgError:
        return np.zeros_like(M)

    if not all(np.isfinite(factor).all() for factor in (U, s, Vt)):
        return np.zeros_like(M)

    shrunk = np.maximum(s - threshold, 0)
    keep = shrunk > convergence_tol
    if not np.any(keep):
        return np.zeros_like(M)

    # Ill-conditioned inputs can trip floating-point warnings during the
    # product; those are silenced here and any bad entries zeroed below.
    with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
        low_rank = (U[:, keep] * shrunk[keep]) @ Vt[keep, :]

    if not np.isfinite(low_rank).all():
        low_rank = np.nan_to_num(low_rank, nan=0.0, posinf=0.0, neginf=0.0)
    return low_rank
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class TROPLocalMixin:
|
|
140
|
+
"""Mixin providing local (observation-specific) estimation for TROP.
|
|
141
|
+
|
|
142
|
+
Methods in this mixin access the following attributes from the main
|
|
143
|
+
TROP class via ``self``:
|
|
144
|
+
|
|
145
|
+
- Solver params: ``max_iter``, ``tol``
|
|
146
|
+
- Inference params: ``n_bootstrap``, ``seed``
|
|
147
|
+
- State: ``_precomputed``
|
|
148
|
+
"""
|
|
149
|
+
|
|
150
|
+
# Type hints for attributes accessed from the main TROP class
|
|
151
|
+
max_iter: int
|
|
152
|
+
tol: float
|
|
153
|
+
n_bootstrap: int
|
|
154
|
+
seed: Optional[int]
|
|
155
|
+
_precomputed: Optional[_PrecomputedStructures]
|
|
156
|
+
|
|
157
|
+
# Convergence tolerance for SVD singular value truncation
|
|
158
|
+
CONVERGENCE_TOL_SVD: float = 1e-10
|
|
159
|
+
|
|
160
|
+
# =========================================================================
|
|
161
|
+
# Preprocessing and distance computation
|
|
162
|
+
# =========================================================================
|
|
163
|
+
|
|
164
|
+
def _precompute_structures(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
) -> _PrecomputedStructures:
    """
    Build the lookup structures shared by LOOCV and the final estimation.

    Everything here depends only on the data, not on the tuning parameters,
    so it is computed once and reused.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units); NaN marks missing outcomes.
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_unit_idx : np.ndarray
        Indices of control units (stored as-is).
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    _PrecomputedStructures
        Dict with keys ``unit_dist_matrix`` (pairwise unit RMSE distances,
        from the Rust backend when available, else the NumPy fallback),
        ``time_dist_matrix`` (``[t, s] = |t - s|``), ``control_mask``
        (``D == 0``), ``treated_mask`` (``D == 1``), ``treated_observations``
        (``(t, i)`` pairs where treated), ``control_obs`` (``(t, i)`` control
        cells with a non-NaN outcome, in row-major order), plus
        ``control_unit_idx``, ``D``, ``Y``, ``n_units`` and ``n_periods``.
    """
    # Pairwise unit distances (Equation 3, page 7): RMSE between units over
    # untreated periods.
    use_rust = HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None
    if use_rust:
        unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64))
    else:
        unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)

    periods = np.arange(n_periods)
    time_dist_matrix = np.abs(periods[:, np.newaxis] - periods[np.newaxis, :])

    control_mask = D == 0
    treated_mask = D == 1
    treated_observations = list(zip(*np.where(treated_mask)))

    # Candidate pseudo-treated cells for LOOCV: untreated and observed.
    observed = ~np.isnan(Y)
    control_obs = [
        (t, i)
        for t in range(n_periods)
        for i in range(n_units)
        if control_mask[t, i] and observed[t, i]
    ]

    return {
        "unit_dist_matrix": unit_dist_matrix,
        "time_dist_matrix": time_dist_matrix,
        "control_mask": control_mask,
        "treated_mask": treated_mask,
        "treated_observations": treated_observations,
        "control_obs": control_obs,
        "control_unit_idx": control_unit_idx,
        "D": D,
        "Y": Y,
        "n_units": n_units,
        "n_periods": n_periods,
    }
|
|
240
|
+
|
|
241
|
+
def _compute_all_unit_distances(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    n_units: int,
    n_periods: int,
) -> np.ndarray:
    """
    Compute the base pairwise unit distance matrix over untreated periods.

    For each pair of units (j, i), this is the root-mean-squared outcome
    difference over the periods in which BOTH units are untreated
    (``D == 0``) and observed (non-NaN). It is Equation 3 (page 7) without
    the target-period exclusion:

        dist(j, i) = sqrt(sum_u (Y_{iu} - Y_{ju})^2 / n_valid)

    Pairs with no shared valid period get ``inf``, and the diagonal is 0.
    The observation-specific variant that also drops the target period is
    ``_compute_unit_distance_for_obs``.

    Units are processed one row at a time. Peak extra memory is therefore
    O(n_units * n_periods) instead of the O(n_units^2 * n_periods) needed
    by a fully broadcast difference tensor. Total work is still
    O(n_units^2 * n_periods).

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    n_units : int
        Number of units. Kept for interface compatibility; the size is
        taken from ``Y``.
    n_periods : int
        Number of periods. Kept for interface compatibility.

    Returns
    -------
    np.ndarray
        Symmetric pairwise distance matrix (n_units x n_units).
    """
    # Invalid cells (treated or missing) become NaN. They then drop out of
    # both the squared-difference sum and the per-pair count.
    valid_mask = (D == 0) & ~np.isnan(Y)
    Y_T = np.where(valid_mask, Y, np.nan).T  # (n_units, n_periods)
    n = Y_T.shape[0]

    dist_matrix = np.empty((n, n), dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        for a in range(n):
            # Row ``a``: the differences of unit ``a`` against every unit,
            # shape (n_units, n_periods).
            sq_diff = (Y_T[a] - Y_T) ** 2
            n_valid = np.sum(~np.isnan(sq_diff), axis=1)
            sq_sum = np.nansum(sq_diff, axis=1)
            row = np.sqrt(sq_sum / n_valid)
            # No shared valid period means the distance is undefined, so inf.
            dist_matrix[a] = np.where(n_valid > 0, row, np.inf)

    np.fill_diagonal(dist_matrix, 0.0)
    return dist_matrix
|
|
313
|
+
|
|
314
|
+
def _compute_unit_distance_for_obs(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    j: int,
    i: int,
    target_period: int,
) -> float:
    """
    RMSE distance between units ``j`` and ``i``, leaving out ``target_period``.

    This is Equation 3 evaluated exactly for one observation. The periods
    used are those other than ``target_period`` in which both units are
    untreated (``D == 0``) and both outcomes are non-NaN.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix.
    j : int
        Control unit index.
    i : int
        Treated unit index.
    target_period : int
        Period excluded from the comparison.

    Returns
    -------
    float
        sqrt(mean squared difference) over the usable periods, or ``inf``
        when there are none.
    """
    keep = np.ones(Y.shape[0], dtype=bool)
    keep[target_period] = False

    col_i = Y[:, i]
    col_j = Y[:, j]
    both_untreated = (D[:, i] == 0) & (D[:, j] == 0)
    both_observed = ~np.isnan(col_i) & ~np.isnan(col_j)
    keep &= both_untreated & both_observed

    if not np.any(keep):
        return np.inf
    return np.sqrt(np.mean((col_i[keep] - col_j[keep]) ** 2))
|
|
359
|
+
|
|
360
|
+
# =========================================================================
|
|
361
|
+
# Observation-specific estimation
|
|
362
|
+
# =========================================================================
|
|
363
|
+
|
|
364
|
+
def _compute_observation_weights(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    i: int,
    t: int,
    lambda_time: float,
    lambda_unit: float,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
) -> np.ndarray:
    """
    Compute observation-specific weight matrix for treated observation (i, t).

    Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
      - Time weights theta_s^{i,t} = exp(-lambda_time * |t - s|)
      - Unit weights omega_j^{i,t} = exp(-lambda_unit * dist_unit_{-t}(j, i))

    Unit weights are computed for ALL units with D[t, j] = 0 at the target
    period, not only never-treated units (Issue A fix). The paper's
    objective sums over every observation with W_js = 0, which includes
    pre-treatment observations of eventually-treated units.

    Unit weights by case:
      - Units treated at ``t`` get 0.
      - When ``lambda_unit == 0``, every unit untreated at ``t`` gets 1.
      - Otherwise, each unit untreated at ``t`` uses the exact distance from
        ``_compute_unit_distance_for_obs``, which excludes period ``t``
        (Issue B fix). An infinite distance (no usable periods) gives
        weight 0.
      - The treated unit ``i`` always gets 1.

    When ``self._precomputed`` is set, the stored ``time_dist_matrix``,
    ``Y`` and ``D`` are used, and the ``Y`` / ``D`` arguments are not read.
    Otherwise (e.g. in bootstrap) everything is computed from the arguments.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    i : int
        Treated unit index.
    t : int
        Treatment period index.
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    control_unit_idx : np.ndarray
        Indices of never-treated units. Kept for backward compatibility and
        not used; the D matrix decides eligibility.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    np.ndarray
        Weight matrix (n_periods x n_units), the outer product of the time
        and unit weight vectors.
    """
    # Pick the data source once, so a single code path computes the weights.
    if self._precomputed is not None:
        # time_dist_matrix[t, s] = |t - s|
        time_dist = self._precomputed["time_dist_matrix"][t, :]
        Y_src = self._precomputed["Y"]
        D_src = self._precomputed["D"]
    else:
        # Fallback: compute from scratch (used in bootstrap).
        time_dist = np.abs(np.arange(n_periods) - t)
        Y_src = Y
        D_src = D

    time_weights = np.exp(-lambda_time * time_dist)

    unit_weights = np.zeros(n_units)
    valid_control_at_t = D_src[t, :] == 0

    if lambda_unit == 0:
        # Uniform weights: every unit not treated at time t gets weight 1.
        unit_weights[valid_control_at_t] = 1.0
    else:
        for j in range(n_units):
            if valid_control_at_t[j] and j != i:
                dist = self._compute_unit_distance_for_obs(Y_src, D_src, j, i, t)
                unit_weights[j] = 0.0 if np.isinf(dist) else np.exp(-lambda_unit * dist)

    # The treated unit's own row participates in model fitting with weight 1.
    unit_weights[i] = 1.0

    return np.outer(time_weights, unit_weights)
|
|
488
|
+
|
|
489
|
+
def _soft_threshold_svd(
    self,
    M: np.ndarray,
    threshold: float,
) -> np.ndarray:
    """Soft-threshold the singular values of ``M`` with this class's tolerance.

    Thin wrapper over the module-level ``_soft_threshold_svd`` that passes
    ``self.CONVERGENCE_TOL_SVD`` as the truncation tolerance.
    """
    truncation_tol = self.CONVERGENCE_TOL_SVD
    return _soft_threshold_svd(M, threshold, truncation_tol)
|
|
496
|
+
|
|
497
|
+
def _weighted_nuclear_norm_solve(
    self,
    Y: np.ndarray,
    W: np.ndarray,
    L_init: np.ndarray,
    alpha: np.ndarray,
    beta: np.ndarray,
    lambda_nn: float,
    max_inner_iter: int = 20,
) -> np.ndarray:
    """
    Update L for the weighted nuclear-norm problem (Issue C fix).

    Objective (paper's Equation 2, page 7), with R = Y - alpha - beta:

        min_L  sum W_{ti} (R_{ti} - L_{ti})^2 + lambda_nn * ||L||_*

    Solved by proximal gradient descent with FISTA/Nesterov momentum.
    The Lipschitz constant is 2*max(W), so the step is eta = 1/(2*max(W))
    and the prox threshold is eta*lambda_nn:

        G_k     = L_k + (W/max(W)) * (R - L_k)
        L_{k+1} = prox_{eta*lambda_nn*||.||_*}(G_k)

    Cells with W = 0 (e.g. treated observations) keep their previous L value
    instead of being pulled towards R, so L does not absorb the treatment
    effect.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    W : np.ndarray
        Non-negative weight matrix (n_periods x n_units). Zero marks cells
        excluded from the fit.
    L_init : np.ndarray
        Starting value for L.
    alpha : np.ndarray
        Unit fixed effects, subtracted column-wise.
    beta : np.ndarray
        Time fixed effects, subtracted row-wise.
    lambda_nn : float
        Nuclear-norm penalty. When ``<= 0``, no iterations run and the
        masked residual is returned.
    max_inner_iter : int, default=20
        Maximum number of proximal-gradient iterations. Iteration stops
        early once the max absolute change in L falls below ``self.tol``.

    Returns
    -------
    np.ndarray
        Updated L matrix.
    """
    residual = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
    residual = np.nan_to_num(residual, nan=0.0, posinf=0.0, neginf=0.0)

    # Zero-weight cells target the current L instead of the residual.
    target = np.where(W > 0, residual, L_init)

    if lambda_nn <= 0:
        # No penalty: the masked residual is returned directly.
        return target

    w_max = np.max(W)
    if w_max > 0:
        w_scaled = W / w_max
        prox_threshold = lambda_nn / (2.0 * w_max)
    else:
        w_scaled = W
        prox_threshold = lambda_nn / 2.0

    L_cur = L_init.copy()
    L_last = L_cur.copy()
    fista_t = 1.0

    for _ in range(max_inner_iter):
        next_t = (1.0 + np.sqrt(1.0 + 4.0 * fista_t**2)) / 2.0
        extrapolated = L_cur + ((fista_t - 1.0) / next_t) * (L_cur - L_last)

        # Gradient step from the extrapolated point. Cells with zero weight
        # stay at the extrapolated value.
        stepped = extrapolated + w_scaled * (target - extrapolated)

        # Neither array is mutated in place below, so rebinding the previous
        # iterate is enough; no copy is needed.
        L_last = L_cur
        L_cur = self._soft_threshold_svd(stepped, prox_threshold)
        fista_t = next_t

        if np.max(np.abs(L_cur - L_last)) < self.tol:
            break

    return L_cur
|
|
602
|
+
|
|
603
|
+
def _estimate_model(
    self,
    Y: np.ndarray,
    control_mask: np.ndarray,
    weight_matrix: np.ndarray,
    lambda_nn: float,
    n_units: int,
    n_periods: int,
    exclude_obs: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit Y = alpha + beta + L + eps on weighted control cells.

    Alternating minimization following Algorithm 1 (page 9) of

        sum W_{ti} (Y_{ti} - alpha_i - beta_t - L_{ti})^2 + lambda_nn * ||L||_*

    Each outer iteration does two updates:
      1. With L fixed, update alpha (unit effects) as a weighted mean over
         periods, then beta (time effects) as a weighted mean over units,
         using the freshly updated alpha.
      2. With alpha and beta fixed, update L via
         ``self._weighted_nuclear_norm_solve`` (10 inner iterations).

    Iteration stops after ``self.max_iter`` rounds, or earlier once the
    largest absolute change in alpha, beta and L drops below ``self.tol``.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units); NaN cells are ignored.
    control_mask : np.ndarray
        Boolean mask of cells eligible for fitting.
    weight_matrix : np.ndarray
        Weight matrix (n_periods x n_units).
    lambda_nn : float
        Nuclear-norm regularization parameter.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.
    exclude_obs : tuple, optional
        ``(t, i)`` cell additionally left out of the fit (used by LOOCV).

    Returns
    -------
    tuple
        ``(alpha, beta, L)``. A unit or period with zero total weight keeps
        a 0 fixed effect.
    """
    fit_mask = control_mask.copy()
    if exclude_obs is not None:
        t_ex, i_ex = exclude_obs
        fit_mask[t_ex, i_ex] = False

    usable = fit_mask & ~np.isnan(Y)
    # Weights restricted to usable cells; every other cell contributes nothing.
    W_fit = weight_matrix * usable

    unit_w = W_fit.sum(axis=0)  # (n_units,)
    time_w = W_fit.sum(axis=1)  # (n_periods,)
    unit_ok = unit_w > 0
    time_ok = time_w > 0
    # Placeholder 1.0 denominators avoid 0/0; those entries are overwritten
    # with 0 by the np.where calls below.
    unit_den = np.where(unit_ok, unit_w, 1.0)
    time_den = np.where(time_ok, time_w, 1.0)

    Y0 = np.where(np.isnan(Y), 0.0, Y)

    alpha = np.zeros(n_units)
    beta = np.zeros(n_periods)
    L = np.zeros((n_periods, n_units))

    for _ in range(self.max_iter):
        alpha_prev = alpha.copy()
        beta_prev = beta.copy()
        L_prev = L.copy()

        partial = Y0 - L

        # alpha_i = sum_t W_ti (R_ti - beta_t) / sum_t W_ti
        alpha_num = np.sum(W_fit * (partial - beta[:, np.newaxis]), axis=0)
        alpha = np.where(unit_ok, alpha_num / unit_den, 0.0)

        # beta_t = sum_i W_ti (R_ti - alpha_i) / sum_i W_ti
        beta_num = np.sum(W_fit * (partial - alpha[np.newaxis, :]), axis=1)
        beta = np.where(time_ok, beta_num / time_den, 0.0)

        # Weighted nuclear-norm update of L (Issue C fix).
        L = self._weighted_nuclear_norm_solve(
            Y0, W_fit, L, alpha, beta, lambda_nn, max_inner_iter=10
        )

        largest_change = max(
            np.max(np.abs(alpha - alpha_prev)),
            np.max(np.abs(beta - beta_prev)),
            np.max(np.abs(L - L_prev)),
        )
        if largest_change < self.tol:
            break

    return alpha, beta, L
|
|
723
|
+
|
|
724
|
+
def _loocv_score_obs_specific(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_mask: np.ndarray,
    control_unit_idx: np.ndarray,
    lambda_time: float,
    lambda_unit: float,
    lambda_nn: float,
    n_units: int,
    n_periods: int,
) -> float:
    """
    Compute leave-one-out cross-validation score with observation-specific weights.

    Following the paper's Equation 5 (page 8):
        Q(lambda) = sum_{j,s: D_js=0} [tau_js^loocv(lambda)]^2

    For each control observation (j, s), treat it as pseudo-treated,
    compute observation-specific weights, fit model excluding (j, s),
    and sum squared pseudo-treatment effects.

    Uses pre-computed control observations (``self._precomputed["control_obs"]``)
    when available; otherwise enumerates cells where ``control_mask`` is True
    and ``Y`` is not NaN.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_mask : np.ndarray
        Boolean mask for control observations.
    control_unit_idx : np.ndarray
        Indices of control units.
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    lambda_nn : float
        Nuclear norm regularization parameter.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    float
        LOOCV score (lower is better). ``np.inf`` (with a ``UserWarning``)
        when there are no valid control observations, when any
        leave-one-out fit raises ``LinAlgError``/``ValueError``, or when any
        pseudo-treatment effect is non-finite. Equation 5 needs a valid term
        for every D == 0 cell, so a partial sum is never returned.
    """
    # Use pre-computed control observations if available
    if self._precomputed is not None:
        control_obs = self._precomputed["control_obs"]
    else:
        # Every observed (non-NaN) control cell, as (period, unit) pairs
        control_obs = [
            (t, i)
            for t in range(n_periods)
            for i in range(n_units)
            if control_mask[t, i] and not np.isnan(Y[t, i])
        ]

    # Empty control set: a score of 0.0 would incorrectly "win" over
    # legitimate parameters, so report infinity instead.
    if len(control_obs) == 0:
        warnings.warn(
            f"LOOCV: No valid control observations for "
            f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). "
            "Returning infinite score.",
            UserWarning,
        )
        return np.inf

    tau_squared_sum = 0.0

    for t, i in control_obs:
        try:
            # Observation-specific weights for pseudo-treated cell (t, i);
            # uses pre-computed distance matrices when available.
            weight_matrix = self._compute_observation_weights(
                Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
            )

            # Estimate model excluding observation (t, i)
            alpha, beta, L = self._estimate_model(
                Y,
                control_mask,
                weight_matrix,
                lambda_nn,
                n_units,
                n_periods,
                exclude_obs=(t, i),
            )
        except (np.linalg.LinAlgError, ValueError):
            # Per Equation 5: Q(lambda) must sum over ALL D==0 cells.
            # Any failure means this lambda cannot produce valid estimates for all cells.
            warnings.warn(
                f"LOOCV: Fit failed for observation ({t}, {i}) with "
                f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). "
                "Returning infinite score per Equation 5.",
                UserWarning,
            )
            return np.inf

        # Pseudo treatment effect
        tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
        if not np.isfinite(tau_ti):
            # A NaN/inf term would otherwise poison the sum into NaN, which
            # can be mis-selected by argmin-style searches; treat it like a
            # failed fit, consistent with the branch above.
            warnings.warn(
                f"LOOCV: Non-finite pseudo-treatment effect for observation "
                f"({t}, {i}) with "
                f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). "
                "Returning infinite score per Equation 5.",
                UserWarning,
            )
            return np.inf
        tau_squared_sum += tau_ti**2

    # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
    # Q(lambda) = sum_{j,s: D_js=0} [tau_js^loocv(lambda)]^2
    return tau_squared_sum
|
|
838
|
+
|
|
839
|
+
def _bootstrap_variance(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    Y: Optional[np.ndarray] = None,
    D: Optional[np.ndarray] = None,
    control_unit_idx: Optional[np.ndarray] = None,
    survey_design=None,
    unit_weight_arr: Optional[np.ndarray] = None,
    resolved_survey=None,
) -> Tuple[float, np.ndarray]:
    """
    Compute bootstrap standard error using unit-level block bootstrap.

    Dispatch order:

    1. If ``resolved_survey`` has strata, PSU, or FPC information, delegate
       to ``_bootstrap_rao_wu_local`` (Rao-Wu rescaled bootstrap); the Rust
       path is skipped.
    2. Otherwise, when the optional Rust backend is available, pre-computed
       structures exist (``self._precomputed``), and ``Y`` and ``D`` are
       provided, run the parallelized Rust bootstrap (5-15x speedup). Its
       result is returned if it produced at least 10 estimates; otherwise,
       or if it raises, fall back to Python.
    3. Python fallback: stratified resampling of whole units with
       replacement, refitting with the fixed tuning parameters each draw.

    Parameters
    ----------
    data : pd.DataFrame
        Original data in long format with unit, time, outcome, and treatment.
    outcome : str
        Name of the outcome column in data.
    treatment : str
        Name of the treatment indicator column in data.
    unit : str
        Name of the unit identifier column in data.
    time : str
        Name of the time period column in data.
    optimal_lambda : tuple of float
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
        from cross-validation. Used for model estimation in each bootstrap.
    Y : np.ndarray, optional
        Outcome matrix of shape (n_periods, n_units). Required for Rust
        backend acceleration. If None, falls back to Python implementation.
    D : np.ndarray, optional
        Treatment indicator matrix of shape (n_periods, n_units) where
        D[t,i]=1 indicates unit i is treated at time t. Required for Rust
        backend acceleration.
    control_unit_idx : np.ndarray, optional
        Array of indices for control units (never-treated). Accepted for
        interface compatibility; this method does not read it.
    survey_design : SurveyDesign, optional
        Survey design specification. Forwarded to each Python refit.
    unit_weight_arr : np.ndarray, optional
        Unit-level survey weights, passed to the Rust backend.
    resolved_survey : ResolvedSurveyDesign, optional
        Resolved survey design (observation-level).

    Returns
    -------
    se : float
        Bootstrap standard error of the ATT estimate (NaN if no
        iteration succeeded).
    bootstrap_estimates : np.ndarray
        Array of ATT estimates from each bootstrap iteration. Length may
        be less than n_bootstrap if some iterations failed.

    Notes
    -----
    Uses unit-level block bootstrap where entire unit time series are
    resampled with replacement. This preserves within-unit correlation
    structure and is appropriate for panel data.

    The Rust path seeds with ``self.seed`` (0 when it is None); the Python
    path uses ``np.random.default_rng(self.seed)``.
    """
    lambda_time, lambda_unit, lambda_nn = optimal_lambda

    # Check for full survey design (strata/PSU/FPC present)
    _has_full_design = resolved_survey is not None and (
        resolved_survey.strata is not None
        or resolved_survey.psu is not None
        or resolved_survey.fpc is not None
    )

    # Full survey design: use Python Rao-Wu rescaled bootstrap
    if _has_full_design:
        return self._bootstrap_rao_wu_local(
            data,
            outcome,
            treatment,
            unit,
            time,
            optimal_lambda,
            resolved_survey,
            survey_design,
        )

    # Try Rust backend for parallel bootstrap (5-15x speedup).
    # Only used for pweight-only designs (no strata/PSU/FPC).
    if (
        HAS_RUST_BACKEND
        and _rust_bootstrap_trop_variance is not None
        and self._precomputed is not None
        and Y is not None
        and D is not None
    ):
        try:
            control_mask = self._precomputed["control_mask"]
            time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)

            bootstrap_estimates, se = _rust_bootstrap_trop_variance(
                Y,
                D.astype(np.float64),
                control_mask.astype(np.uint8),
                time_dist_matrix,
                lambda_time,
                lambda_unit,
                lambda_nn,
                self.n_bootstrap,
                self.max_iter,
                self.tol,
                self.seed if self.seed is not None else 0,
                unit_weight_arr,
            )

            if len(bootstrap_estimates) >= 10:
                return float(se), bootstrap_estimates
            # Fall through to Python if too few bootstrap samples
            logger.debug(
                "Rust bootstrap returned only %d samples, falling back to Python",
                len(bootstrap_estimates),
            )
        except Exception as e:
            logger.debug("Rust bootstrap variance failed, falling back to Python: %s", e)
            warnings.warn(
                "Rust backend failed for bootstrap variance; "
                "falling back to Python. Performance may be reduced. "
                f"Error: {e}",
                UserWarning,
                stacklevel=2,
            )

    # Python implementation (fallback)
    rng = np.random.default_rng(self.seed)

    # Issue D fix: Stratified bootstrap sampling.
    # Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
    # and N_1 treated rows separately to preserve the treatment ratio.
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
    control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)

    n_treated_units = len(treated_units)
    n_control_units = len(control_units)

    # Slice each unit's rows once instead of rescanning the whole frame for
    # every sampled unit in every draw. Keys are the elements of the very
    # arrays the draws are taken from, so each lookup returns exactly the
    # rows ``data[data[unit] == u]`` selected in the per-draw version.
    unit_rows = {}
    for units_arr in (control_units, treated_units):
        for u in units_arr:
            unit_rows[u] = data[data[unit] == u]

    bootstrap_estimates_list = []

    for _ in range(self.n_bootstrap):
        # Stratified sampling: sample control and treated units separately.
        # This preserves the treatment ratio in each bootstrap sample.
        if n_control_units > 0:
            sampled_control = rng.choice(control_units, size=n_control_units, replace=True)
        else:
            sampled_control = np.array([], dtype=control_units.dtype)

        if n_treated_units > 0:
            sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True)
        else:
            sampled_treated = np.array([], dtype=treated_units.dtype)

        # Combine stratified samples
        sampled_units = np.concatenate([sampled_control, sampled_treated])

        # Relabel every draw with a unique unit ID so units drawn more than
        # once enter the refit as distinct units.
        boot_data = pd.concat(
            [
                unit_rows[u].assign(**{unit: f"{u}_{idx}"})
                for idx, u in enumerate(sampled_units)
            ],
            ignore_index=True,
        )

        try:
            # Fit with fixed lambda (skip LOOCV for speed)
            att = self._fit_with_fixed_lambda(
                boot_data,
                outcome,
                treatment,
                unit,
                time,
                optimal_lambda,
                survey_design=survey_design,
            )
            if np.isfinite(att):
                bootstrap_estimates_list.append(att)
        except (ValueError, np.linalg.LinAlgError, KeyError, ZeroDivisionError):
            # ZeroDivisionError: a weighted ATT average whose treated weights
            # sum to zero; skip the draw like any other failed refit.
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
            "Standard errors may be unreliable.",
            UserWarning,
        )
        if len(bootstrap_estimates) == 0:
            return np.nan, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
|
|
1047
|
+
|
|
1048
|
+
def _bootstrap_rao_wu_local(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    resolved_survey,
    survey_design,
) -> Tuple[float, np.ndarray]:
    """
    Rao-Wu rescaled bootstrap for local method with full survey design.

    Instead of physically resampling units, each iteration generates
    rescaled unit weights via Rao-Wu (1988) weight perturbation and refits
    on the original data with those weights. Survey strata are
    cross-classified with treatment group (ever-treated vs. never-treated)
    to preserve the stratified resampling structure.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome, treatment, unit, time : str
        Column names.
    optimal_lambda : tuple
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn).
    resolved_survey : ResolvedSurveyDesign
        Resolved survey design (observation-level). Supplies
        ``weight_type`` and ``lonely_psu`` for the unit-level design.
    survey_design : SurveyDesign
        Original survey design specification; its ``weights``, ``strata``,
        ``psu``, ``fpc`` and ``nest`` fields are read from the first row of
        each unit.

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). ``(nan, empty)`` when an explicit PSU
        column yields fewer than 2 PSUs without strata (variance not
        identified) or when no iteration succeeds.
    """
    from diff_diff.bootstrap_utils import generate_rao_wu_weights
    from diff_diff.linalg import _factorize_cluster_ids
    from diff_diff.survey import ResolvedSurveyDesign

    rng = np.random.default_rng(self.seed)

    # Build unit-level resolved survey with cross-classified strata.
    # Unit order matches the sorted ordering used when refitting.
    all_units = sorted(data[unit].unique())
    n_units = len(all_units)

    # Treatment group per unit: 1 if ever treated, else 0
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treatment_group = np.array([int(unit_ever_treated[u]) for u in all_units], dtype=np.int64)

    # Extract unit-level survey design fields
    first_rows = data.groupby(unit).first().loc[all_units]

    # Weights (unit-level)
    if survey_design.weights is not None:
        unit_weights = first_rows[survey_design.weights].values.astype(np.float64)
    else:
        unit_weights = np.ones(n_units, dtype=np.float64)

    # Strata: cross-classify survey strata x treatment group
    if survey_design.strata is not None:
        survey_strata = first_rows[survey_design.strata].values
        cross_labels = np.array([f"{s}_{g}" for s, g in zip(survey_strata, treatment_group)])
        cross_strata = _factorize_cluster_ids(cross_labels)
    else:
        # No survey strata: use treatment group as strata
        cross_strata = treatment_group.copy()
    n_strata = len(np.unique(cross_strata))

    # PSU (unit-level)
    if survey_design.psu is not None:
        psu_raw = first_rows[survey_design.psu].values
        if survey_design.nest and survey_design.strata is not None:
            # Nested PSU labels: make IDs unique within each cross stratum
            combined = np.array([f"{s}_{p}" for s, p in zip(cross_strata, psu_raw)])
            psu_arr = _factorize_cluster_ids(combined)
        else:
            psu_arr = _factorize_cluster_ids(psu_raw)
        n_psu = len(np.unique(psu_arr))
    else:
        # Implicit PSU: each unit is its own PSU
        psu_arr = np.arange(n_units, dtype=np.int64)
        n_psu = n_units

    # FPC (unit-level)
    fpc_arr = None
    if survey_design.fpc is not None:
        fpc_arr = first_rows[survey_design.fpc].values.astype(np.float64)

    unit_resolved = ResolvedSurveyDesign(
        weights=unit_weights,
        weight_type=resolved_survey.weight_type,
        strata=cross_strata,
        psu=psu_arr,
        fpc=fpc_arr,
        n_strata=n_strata,
        n_psu=n_psu,
        lonely_psu=resolved_survey.lonely_psu,
    )

    # Check for unidentified variance (single unstratified PSU)
    if (
        survey_design.psu is not None
        and unit_resolved.n_psu < 2
        and survey_design.strata is None
    ):
        return np.nan, np.array([])

    # Bootstrap loop: refit the full model per draw with Rao-Wu rescaled
    # weights, mirroring the physical-resampling bootstrap but using weight
    # perturbation instead of unit resampling.
    bootstrap_estimates_list = []

    for _ in range(self.n_bootstrap):
        try:
            # Generate Rao-Wu rescaled unit weights
            boot_weights = generate_rao_wu_weights(unit_resolved, rng)

            # Skip if all weights are zero
            if boot_weights.sum() == 0:
                continue

            # Refit the full local model with rescaled weights
            att = self._fit_with_fixed_lambda(
                data,
                outcome,
                treatment,
                unit,
                time,
                optimal_lambda,
                survey_design=survey_design,
                unit_weight_arr=boot_weights,
            )

            if np.isfinite(att):
                bootstrap_estimates_list.append(att)
        except (ValueError, np.linalg.LinAlgError, KeyError, ZeroDivisionError):
            # ZeroDivisionError: a draw can zero the weights of every treated
            # unit that contributes an effect, so the weighted ATT average
            # has no mass; skip the draw instead of aborting the bootstrap.
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
            "Standard errors may be unreliable.",
            UserWarning,
        )
        if len(bootstrap_estimates) == 0:
            return np.nan, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
|
|
1204
|
+
|
|
1205
|
+
def _fit_with_fixed_lambda(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    fixed_lambda: Tuple[float, float, float],
    survey_design=None,
    unit_weight_arr: Optional[np.ndarray] = None,
) -> float:
    """
    Fit model with fixed tuning parameters (for bootstrap).

    Uses observation-specific weights following Algorithm 2: for every
    treated cell with a finite outcome, compute weights, fit the model, and
    take tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}. Returns only the
    ATT estimate.

    Parameters
    ----------
    data : pd.DataFrame
        Long-format panel.
    outcome, treatment, unit, time : str
        Column names.
    fixed_lambda : tuple of float
        (lambda_time, lambda_unit, lambda_nn).
    survey_design : SurveyDesign, optional
        If it has ``weights``, unit weights are extracted from ``data`` and
        used to weight the per-cell effects in the ATT average.
    unit_weight_arr : np.ndarray, optional
        Pre-computed unit-level weights (e.g. Rao-Wu rescaled weights),
        aligned with the sorted unique units of ``data``. When provided,
        overrides weights extracted from survey_design.

    Returns
    -------
    float
        Mean (or unit-weighted mean) of the per-cell effects. NaN when no
        treated cell has a finite outcome, or when the unit weights of the
        contributing treated cells sum to zero.

    Raises
    ------
    ValueError
        If ``data`` contains no treated observations.
    """
    lambda_time, lambda_unit, lambda_nn = fixed_lambda

    # Setup matrices (sorted so weight arrays align by position)
    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    n_units = len(all_units)
    n_periods = len(all_periods)

    # Use pre-computed weights if provided (e.g. Rao-Wu bootstrap),
    # otherwise extract from survey_design.
    if unit_weight_arr is not None:
        local_weight_arr = unit_weight_arr
    elif survey_design is not None and survey_design.weights is not None:
        from diff_diff.survey import _extract_unit_survey_weights

        local_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)
    else:
        local_weight_arr = None

    # Vectorized: use pivot for reshaping instead of an iterrows loop
    Y = (
        data.pivot(index=time, columns=unit, values=outcome)
        .reindex(index=all_periods, columns=all_units)
        .values
    )
    D = (
        data.pivot(index=time, columns=unit, values=treatment)
        .reindex(index=all_periods, columns=all_units)
        .fillna(0)
        .astype(int)
        .values
    )

    control_mask = D == 0

    # Get control unit indices (units never treated in any period)
    unit_ever_treated = np.any(D == 1, axis=0)
    control_unit_idx = np.where(~unit_ever_treated)[0]

    # Get list of treated observations as (period, unit) pairs
    treated_observations = [
        (t, i) for t in range(n_periods) for i in range(n_units) if D[t, i] == 1
    ]

    if not treated_observations:
        raise ValueError("No treated observations")

    # Compute ATT using observation-specific weights (Algorithm 2)
    tau_values = []
    tau_weights = []
    for t, i in treated_observations:
        # Skip non-finite outcomes (match main fit NaN contract)
        if not np.isfinite(Y[t, i]):
            continue

        # Compute observation-specific weights for this (t, i)
        weight_matrix = self._compute_observation_weights(
            Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
        )

        # Fit model with these weights
        alpha, beta, L = self._estimate_model(
            Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods
        )

        # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}
        tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
        tau_values.append(tau)
        if local_weight_arr is not None:
            tau_weights.append(local_weight_arr[i])

    if not tau_values:
        return float("nan")
    if local_weight_arr is not None:
        # np.average raises ZeroDivisionError when the weights sum to zero
        # (e.g. a Rao-Wu draw that drops every contributing treated unit).
        # Report NaN so bootstrap callers skip the draw via np.isfinite.
        if float(np.sum(tau_weights)) == 0.0:
            return float("nan")
        return float(np.average(tau_values, weights=tau_weights))
    return float(np.mean(tau_values))
|