diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1307 @@
1
+ """
2
+ Local (observation-specific) estimation method for the TROP estimator.
3
+
4
+ Contains the TROPLocalMixin class with all methods for the local
5
+ estimation pathway, including preprocessing, distance computation,
6
+ per-observation weight computation, model fitting, LOOCV scoring,
7
+ and bootstrap variance estimation.
8
+
9
+ This module is used via mixin inheritance — see trop.py for the
10
+ main TROP class definition.
11
+ """
12
+
13
+ import logging
14
+ import warnings
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ from diff_diff._backend import (
23
+ HAS_RUST_BACKEND,
24
+ _rust_bootstrap_trop_variance,
25
+ _rust_unit_distance_matrix,
26
+ )
27
+ from diff_diff.trop_results import _PrecomputedStructures
28
+
29
+
30
def _validate_and_pivot_treatment(data, time, unit, treatment, all_periods, all_units):
    """Validate the treatment column and pivot it into a (time x unit) matrix.

    Observed rows with a missing treatment value are a data-quality error and
    are rejected. After pivoting and reindexing to the full panel grid, any
    (period, unit) cell that has no row in ``data`` (a structural gap, as in
    unbalanced panels) is filled with 0, i.e. assumed untreated. A
    ``UserWarning`` reports how many cells were filled.

    Parameters
    ----------
    data : pd.DataFrame
        Long-format panel with one row per observed (unit, period).
    time : str
        Name of the time-period column (becomes the row index).
    unit : str
        Name of the unit-identifier column (becomes the columns).
    treatment : str
        Name of the treatment-indicator column.
    all_periods : sequence
        Full, ordered set of periods used as the row index.
    all_units : sequence
        Full, ordered set of units used as the columns.

    Returns
    -------
    D : ndarray
        Treatment matrix (n_periods x n_units), int.
    missing_mask : ndarray
        Boolean mask of structurally absent cells (n_periods x n_units).

    Raises
    ------
    ValueError
        If any observed row has a missing treatment value. ``DataFrame.pivot``
        also raises ``ValueError`` when a (time, unit) pair occurs twice.
    """
    n_nan_observed = int(data[treatment].isna().sum())
    if n_nan_observed > 0:
        raise ValueError(
            f"{n_nan_observed} observation(s) have missing treatment values. "
            "TROP requires non-missing treatment indicators for all observed "
            "rows. Remove or impute missing values before fitting."
        )

    D_raw = data.pivot(index=time, columns=unit, values=treatment).reindex(
        index=all_periods, columns=all_units
    )
    # Cells still NaN after reindexing have no row in the input at all.
    missing_mask = pd.isna(D_raw).values
    n_missing_structural = int(missing_mask.sum())
    if n_missing_structural > 0:
        # stacklevel=3 attributes the warning to the caller of our caller.
        warnings.warn(
            f"{n_missing_structural} missing treatment indicator(s) in the "
            "(time x unit) panel matrix filled with 0 (assumed "
            "untreated). This typically occurs in unbalanced panels.",
            UserWarning,
            stacklevel=3,
        )
    D = D_raw.fillna(0).astype(int).values
    return D, missing_mask
67
+
68
+
69
# Default truncation tolerance for _soft_threshold_svd: any singular value
# that is still at or below this level after shrinkage is dropped (treated
# as exactly zero), which keeps the low-rank reconstruction stable.
_CONVERGENCE_TOL_SVD: float = 1e-10


def _soft_threshold_svd(
    M: np.ndarray,
    threshold: float,
    convergence_tol: float = _CONVERGENCE_TOL_SVD,
) -> np.ndarray:
    """
    Shrink the singular values of ``M`` by ``threshold`` (nuclear-norm prox).

    Computes ``U @ diag(max(s - threshold, 0)) @ Vt`` from the thin SVD of
    ``M``. Only components whose shrunk singular value exceeds
    ``convergence_tol`` are kept.

    Parameters
    ----------
    M : np.ndarray
        Matrix to shrink.
    threshold : float
        Amount subtracted from every singular value. When ``threshold <= 0``
        the input object itself is returned (no copy, no NaN cleaning).
    convergence_tol : float, default=1e-10
        Shrunk singular values at or below this value are dropped.

    Returns
    -------
    np.ndarray
        The shrunk matrix. Non-finite entries of ``M`` are replaced by 0
        before the SVD. A zero matrix is returned if the SVD raises, yields
        non-finite factors, or leaves no singular value above the tolerance.
        Non-finite values in the reconstruction are replaced by 0.
    """
    if threshold <= 0:
        return M

    if np.isfinite(M).all():
        work = M
    else:
        work = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)

    try:
        left, sing, right = np.linalg.svd(work, full_matrices=False)
    except np.linalg.LinAlgError:
        return np.zeros_like(work)

    if not all(np.isfinite(part).all() for part in (left, sing, right)):
        return np.zeros_like(work)

    shrunk = np.maximum(sing - threshold, 0)
    keep = shrunk > convergence_tol
    if not keep.any():
        return np.zeros_like(work)

    # Ill-conditioned inputs during alternating minimisation can trigger
    # harmless floating-point warnings here; they are silenced on purpose.
    with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
        out = (left[:, keep] * shrunk[keep]) @ right[keep, :]

    if np.isfinite(out).all():
        return out
    return np.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
137
+
138
+
139
+ class TROPLocalMixin:
140
+ """Mixin providing local (observation-specific) estimation for TROP.
141
+
142
+ Methods in this mixin access the following attributes from the main
143
+ TROP class via ``self``:
144
+
145
+ - Solver params: ``max_iter``, ``tol``
146
+ - Inference params: ``n_bootstrap``, ``seed``
147
+ - State: ``_precomputed``
148
+ """
149
+
150
+ # Type hints for attributes accessed from the main TROP class
151
+ max_iter: int
152
+ tol: float
153
+ n_bootstrap: int
154
+ seed: Optional[int]
155
+ _precomputed: Optional[_PrecomputedStructures]
156
+
157
+ # Convergence tolerance for SVD singular value truncation
158
+ CONVERGENCE_TOL_SVD: float = 1e-10
159
+
160
+ # =========================================================================
161
+ # Preprocessing and distance computation
162
+ # =========================================================================
163
+
164
def _precompute_structures(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
) -> _PrecomputedStructures:
    """
    Pre-compute data structures that are reused across LOOCV and estimation.

    Computes once what would otherwise be recomputed repeatedly:

    - the pairwise unit distance matrix (Rust kernel when the backend exposes
      it, otherwise ``self._compute_all_unit_distances``),
    - the period-distance table ``time_dist_matrix[t, s] = |t - s|``,
    - control (``D == 0``) and treated (``D == 1``) masks,
    - the treated cells and the control cells with a non-NaN outcome.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_unit_idx : np.ndarray
        Indices of control units (stored as-is).
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    _PrecomputedStructures
        Dict with keys ``unit_dist_matrix``, ``time_dist_matrix``,
        ``control_mask``, ``treated_mask``, ``treated_observations``,
        ``control_obs``, ``control_unit_idx``, ``D``, ``Y``, ``n_units`` and
        ``n_periods``. ``control_obs`` is a list of ``(t, i)`` tuples of
        Python ints in row-major (period, then unit) order.
    """
    # Pairwise unit distances (Equation 3, page 7): RMSE over untreated periods.
    if HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None:
        # Parallel Rust implementation of the same distance matrix.
        unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64))
    else:
        unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)

    # time_dist_matrix[t, s] == |t - s| for every pair of periods.
    periods = np.arange(n_periods)
    time_dist_matrix = np.abs(periods[:, np.newaxis] - periods[np.newaxis, :])

    control_mask = D == 0
    treated_mask = D == 1

    treated_observations = list(zip(*np.where(treated_mask)))

    # Control cells with an observed outcome, used as LOOCV pseudo-treated
    # cells. np.nonzero returns indices in row-major order, matching a
    # nested (t, then i) loop; .tolist() yields plain Python ints.
    ctrl_t, ctrl_i = np.nonzero(control_mask & ~np.isnan(Y))
    control_obs = list(zip(ctrl_t.tolist(), ctrl_i.tolist()))

    return {
        "unit_dist_matrix": unit_dist_matrix,
        "time_dist_matrix": time_dist_matrix,
        "control_mask": control_mask,
        "treated_mask": treated_mask,
        "treated_observations": treated_observations,
        "control_obs": control_obs,
        "control_unit_idx": control_unit_idx,
        "D": D,
        "Y": Y,
        "n_units": n_units,
        "n_periods": n_periods,
    }
240
+
241
def _compute_all_unit_distances(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    n_units: int,
    n_periods: int,
) -> np.ndarray:
    """
    Compute the pairwise unit distance matrix over untreated periods.

    Following Equation 3 (page 7):
        dist_unit(j, i) = sqrt(sum_u (Y_{iu} - Y_{ju})^2 / n_valid)

    A period u contributes to pair (i, j) only when both units are untreated
    there (D == 0) and both outcomes are non-NaN. ``n_valid`` is the number
    of such periods. No target period is excluded, so this is a base matrix
    shared by all observations. The exact leave-target-period-out distance
    is computed by ``_compute_unit_distance_for_obs``.

    The matrix is filled one row at a time, so temporary memory is
    O(n_units * n_periods) instead of the O(n_units^2 * n_periods) needed
    to broadcast every pair at once. Total work is O(n_units^2 * n_periods),
    and each entry uses the same element-wise arithmetic as a fully
    broadcast computation.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    n_units : int
        Number of units. Not used; sizes are taken from ``Y``.
    n_periods : int
        Number of periods. Not used; sizes are taken from ``Y``.

    Returns
    -------
    np.ndarray
        Symmetric (n_units x n_units) float matrix with a zero diagonal.
        Pairs that share no valid period get ``np.inf``.
    """
    # Treated or missing cells become NaN so they drop out of every pair.
    Y_units = np.where((D == 0) & ~np.isnan(Y), Y, np.nan).T  # (n_units, n_periods)
    n = Y_units.shape[0]
    dist_matrix = np.empty((n, n), dtype=np.float64)

    # 0/0 for pairs without overlap is expected; those become inf below.
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(n):
            # Row i against all units: NaN wherever either side is invalid.
            sq_diff = (Y_units[i][np.newaxis, :] - Y_units) ** 2
            n_valid = np.sum(~np.isnan(sq_diff), axis=1)
            sq_sum = np.nansum(sq_diff, axis=1)
            dist_matrix[i] = np.where(n_valid > 0, np.sqrt(sq_sum / n_valid), np.inf)

    # A unit is at distance 0 from itself, even with no valid periods.
    np.fill_diagonal(dist_matrix, 0.0)
    return dist_matrix
313
+
314
def _compute_unit_distance_for_obs(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    j: int,
    i: int,
    target_period: int,
) -> float:
    """
    Compute the exact leave-target-out RMSE distance between units j and i.

    Implements Equation 3 for a single observation: the root mean squared
    outcome gap between the two units, taken over periods where both units
    are untreated (D == 0), both outcomes are non-NaN, and the period is not
    ``target_period``.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix.
    j : int
        Control (donor) unit index.
    i : int
        Treated (target) unit index.
    target_period : int
        Period index left out of the comparison.

    Returns
    -------
    float
        Pairwise RMSE distance, or ``np.inf`` when no period qualifies.
    """
    col_i = Y[:, i]
    col_j = Y[:, j]
    usable = (D[:, i] == 0) & (D[:, j] == 0) & ~np.isnan(col_i) & ~np.isnan(col_j)
    usable[target_period] = False

    if not usable.any():
        return np.inf

    gaps = col_i[usable] - col_j[usable]
    return np.sqrt(np.mean(gaps**2))
359
+
360
+ # =========================================================================
361
+ # Observation-specific estimation
362
+ # =========================================================================
363
+
364
def _compute_observation_weights(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    i: int,
    t: int,
    lambda_time: float,
    lambda_unit: float,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
) -> np.ndarray:
    """
    Compute the observation-specific weight matrix for target cell (t, i).

    Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
    - time weights  theta_s^{i,t} = exp(-lambda_time * |t - s|)
    - unit weights  omega_j^{i,t} = exp(-lambda_unit * dist_unit_{-t}(j, i))

    Every unit with D[t, j] == 0 is a candidate donor, so pre-treatment
    periods of eventually-treated units are used as well (Issue A fix).
    Donor distances leave out the target period (Issue B fix), and a donor
    with infinite distance gets weight 0. When ``lambda_unit == 0`` every
    donor gets weight 1. The target unit ``i`` always gets weight 1.

    If ``self._precomputed`` is set, its stored ``Y``, ``D`` and
    ``time_dist_matrix`` are used and the ``Y``/``D`` arguments are ignored.
    Otherwise (e.g. during bootstrap) everything comes from the arguments.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    i : int
        Target (treated or pseudo-treated) unit index.
    t : int
        Target period index.
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    control_unit_idx : np.ndarray
        Unused; kept for backward compatibility (donors come from ``D``).
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    np.ndarray
        ``outer(time_weights, unit_weights)``, shape (n_periods x n_units).
    """
    cache = self._precomputed
    if cache is not None:
        dist_time = cache["time_dist_matrix"][t, :]
        D_src = cache["D"]
        Y_src = cache["Y"]
    else:
        dist_time = np.abs(np.arange(n_periods) - t)
        D_src = D
        Y_src = Y
    time_weights = np.exp(-lambda_time * dist_time)

    # Donors: all units untreated at the target period.
    donors = D_src[t, :] == 0
    unit_weights = np.zeros(n_units)

    if lambda_unit == 0:
        unit_weights[donors] = 1.0
    else:
        for j in range(n_units):
            if donors[j] and j != i:
                d_ji = self._compute_unit_distance_for_obs(Y_src, D_src, j, i, t)
                unit_weights[j] = 0.0 if np.isinf(d_ji) else np.exp(-lambda_unit * d_ji)

    # The target unit's own observations always carry full weight.
    unit_weights[i] = 1.0
    return np.outer(time_weights, unit_weights)
488
+
489
def _soft_threshold_svd(
    self,
    M: np.ndarray,
    threshold: float,
) -> np.ndarray:
    """Soft-threshold the singular values of ``M`` by ``threshold``.

    Thin wrapper around the module-level ``_soft_threshold_svd`` that passes
    this class's ``CONVERGENCE_TOL_SVD`` as the truncation tolerance.
    """
    tol = self.CONVERGENCE_TOL_SVD
    return _soft_threshold_svd(M, threshold, tol)
496
+
497
def _weighted_nuclear_norm_solve(
    self,
    Y: np.ndarray,
    W: np.ndarray,
    L_init: np.ndarray,
    alpha: np.ndarray,
    beta: np.ndarray,
    lambda_nn: float,
    max_inner_iter: int = 20,
) -> np.ndarray:
    """
    Approximately solve the weighted nuclear-norm problem for L.

    Issue C fix: implements the weighted objective from the paper's
    Equation 2 (page 7):
        min_L sum W_{ti}(R_{ti} - L_{ti})^2 + lambda_nn ||L||_*
    with R = Y - alpha_i - beta_t.

    Uses proximal gradient descent (Mazumder et al. 2010) with
    FISTA/Nesterov acceleration. The Lipschitz constant is L_f = 2*max(W),
    so the step size is eta = 1/(2*max(W)) and the prox threshold is
    eta*lambda_nn:
        G_k     = L_m + (W/max(W)) * (R - L_m)      (L_m = momentum point)
        L_{k+1} = prox_{eta*lambda_nn*||.||_*}(G_k)
    If max(W) is not positive, W is used unnormalised and the threshold is
    lambda_nn / 2. Iteration stops after ``max_inner_iter`` steps or once
    max |L_{k+1} - L_k| < ``self.tol``.

    IMPORTANT: for cells with W=0 (treated / held-out observations) the
    target R is replaced by ``L_init``. Setting L = R there would absorb the
    treatment effect.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    W : np.ndarray
        Non-negative weight matrix (n_periods x n_units). W=0 marks cells
        that must not inform the fit.
    L_init : np.ndarray
        Initial estimate of L (not modified).
    alpha : np.ndarray
        Current unit fixed effects (n_units,).
    beta : np.ndarray
        Current time fixed effects (n_periods,).
    lambda_nn : float
        Nuclear norm regularization parameter.
    max_inner_iter : int, default=20
        Maximum number of proximal-gradient iterations.

    Returns
    -------
    np.ndarray
        Updated L estimate. When ``lambda_nn <= 0`` this is the masked
        residual itself, with no proximal step.
    """
    # Target residual; non-finite entries are neutralised to 0.
    R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
    R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)

    # Zero-weight cells are pinned to the incoming L so the low-rank fit
    # cannot absorb the treatment effect.
    R_masked = np.where(W > 0, R, L_init)

    if lambda_nn <= 0:
        # No nuclear-norm penalty: return the masked residual directly.
        return R_masked

    # Normalise weights so the largest is 1. Both the normalisation and the
    # prox threshold are loop-invariant, so they are computed once.
    W_max = np.max(W)
    if W_max > 0:
        W_norm = W / W_max
        threshold = lambda_nn / (2.0 * W_max)
    else:
        W_norm = W
        threshold = lambda_nn / 2.0

    # L is only ever rebound (never mutated in place), so keeping a
    # reference to the previous iterate is enough; no per-step copies needed.
    L = L_init.copy()
    L_prev = L
    t_fista = 1.0
    for _ in range(max_inner_iter):
        # FISTA momentum
        t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
        momentum = (t_fista - 1.0) / t_fista_new
        L_momentum = L + momentum * (L - L_prev)

        # Gradient step from the momentum point. Where W=0 it leaves L_m as is.
        gradient_step = L_momentum + W_norm * (R_masked - L_momentum)

        # Proximal step: soft-threshold the singular values.
        L_prev, L = L, self._soft_threshold_svd(gradient_step, threshold)
        t_fista = t_fista_new

        if np.max(np.abs(L - L_prev)) < self.tol:
            break

    return L
602
+
603
def _estimate_model(
    self,
    Y: np.ndarray,
    control_mask: np.ndarray,
    weight_matrix: np.ndarray,
    lambda_nn: float,
    n_units: int,
    n_periods: int,
    exclude_obs: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit Y ~ alpha_i + beta_t + L_ti on weighted control cells.

    Minimizes sum_{t,i} W_ti (Y_ti - alpha_i - beta_t - L_ti)^2
    + lambda_nn * ||L||_* by alternating updates (Algorithm 1, page 9):

    1. with L fixed, update alpha (unit effects) and then beta (time
       effects, using the new alpha) as weighted means of the residuals;
    2. with alpha and beta fixed, update L via
       ``self._weighted_nuclear_norm_solve`` (10 inner iterations).

    Iteration stops after ``self.max_iter`` rounds, or once the largest
    absolute change in alpha, beta and L is below ``self.tol``.

    Only cells that are in ``control_mask``, have a non-NaN outcome and are
    not ``exclude_obs`` get non-zero weight. A unit or period whose total
    weight is zero gets a fixed effect of 0.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    control_mask : np.ndarray
        Boolean mask of control observations.
    weight_matrix : np.ndarray
        Observation weights (n_periods x n_units).
    lambda_nn : float
        Nuclear norm regularization parameter.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.
    exclude_obs : tuple, optional
        ``(t, i)`` cell held out of the fit (used by LOOCV).

    Returns
    -------
    tuple
        ``(alpha, beta, L)`` estimated parameters.
    """
    # Cells allowed to inform the fit: controls, minus any held-out cell.
    fit_mask = control_mask.copy()
    if exclude_obs is not None:
        t_held, i_held = exclude_obs
        fit_mask[t_held, i_held] = False

    observed = ~np.isnan(Y)
    # Effective weights: zero wherever a cell is missing or not eligible.
    W_eff = weight_matrix * (observed & fit_mask)

    unit_total = W_eff.sum(axis=0)
    period_total = W_eff.sum(axis=1)
    unit_ok = unit_total > 0
    period_ok = period_total > 0
    # Zero totals are swapped for 1 to keep the division safe; those
    # entries are then forced to 0 by np.where.
    unit_denom = np.where(unit_ok, unit_total, 1.0)
    period_denom = np.where(period_ok, period_total, 1.0)

    # NaN outcomes become 0; their weight is already 0.
    Y_filled = np.where(observed, Y, 0.0)

    alpha = np.zeros(n_units)
    beta = np.zeros(n_periods)
    L = np.zeros((n_periods, n_units))

    for _ in range(self.max_iter):
        alpha_prev = alpha.copy()
        beta_prev = beta.copy()
        L_prev = L.copy()

        resid = Y_filled - L

        # alpha_i = sum_t W_ti (R_ti - beta_t) / sum_t W_ti
        alpha_num = np.sum(W_eff * (resid - beta[:, np.newaxis]), axis=0)
        alpha = np.where(unit_ok, alpha_num / unit_denom, 0.0)

        # beta_t = sum_i W_ti (R_ti - alpha_i) / sum_i W_ti  (fresh alpha)
        beta_num = np.sum(W_eff * (resid - alpha[np.newaxis, :]), axis=1)
        beta = np.where(period_ok, beta_num / period_denom, 0.0)

        # Issue C fix: weighted soft-impute update of the low-rank component.
        L = self._weighted_nuclear_norm_solve(
            Y_filled, W_eff, L, alpha, beta, lambda_nn, max_inner_iter=10
        )

        largest_change = max(
            np.max(np.abs(alpha - alpha_prev)),
            np.max(np.abs(beta - beta_prev)),
            np.max(np.abs(L - L_prev)),
        )
        if largest_change < self.tol:
            break

    return alpha, beta, L
723
+
724
def _loocv_score_obs_specific(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_mask: np.ndarray,
    control_unit_idx: np.ndarray,
    lambda_time: float,
    lambda_unit: float,
    lambda_nn: float,
    n_units: int,
    n_periods: int,
) -> float:
    """
    Compute the leave-one-out CV score with observation-specific weights.

    Following the paper's Equation 5 (page 8):
        Q(lambda) = sum_{j,s: D_js=0} [tau_js^loocv(lambda)]^2

    Each control cell (j, s) with an observed outcome is treated in turn as
    pseudo-treated. Observation-specific weights are computed, the model is
    fit without that cell, and the squared pseudo-treatment effect is added
    to the score.

    Uses ``self._precomputed["control_obs"]`` when pre-computed structures
    are available.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_mask : np.ndarray
        Boolean mask for control observations.
    control_unit_idx : np.ndarray
        Indices of control units (forwarded to the weight computation).
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    lambda_nn : float
        Nuclear norm regularization parameter.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    float
        LOOCV score (lower is better). ``np.inf``, with a ``UserWarning``, when
        there are no control cells, when any fit raises ``LinAlgError`` /
        ``ValueError``, or when any pseudo-treatment effect is non-finite.
        Equation 5 needs a valid estimate for every cell.
    """
    if self._precomputed is not None:
        control_obs = self._precomputed["control_obs"]
    else:
        control_obs = [
            (t, i)
            for t in range(n_periods)
            for i in range(n_units)
            if control_mask[t, i] and not np.isnan(Y[t, i])
        ]

    # With no control cells, a score of 0.0 would wrongly "win" the search.
    if len(control_obs) == 0:
        warnings.warn(
            "LOOCV: No valid control observations for "
            f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). "
            "Returning infinite score.",
            UserWarning,
        )
        return np.inf

    tau_squared_sum = 0.0
    for t, i in control_obs:
        try:
            weight_matrix = self._compute_observation_weights(
                Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
            )
            alpha, beta, L = self._estimate_model(
                Y,
                control_mask,
                weight_matrix,
                lambda_nn,
                n_units,
                n_periods,
                exclude_obs=(t, i),
            )
        except (np.linalg.LinAlgError, ValueError):
            # Per Equation 5, Q(lambda) must cover ALL D==0 cells, so one
            # failed fit disqualifies this lambda.
            warnings.warn(
                f"LOOCV: Fit failed for observation ({t}, {i}) with "
                f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). "
                "Returning infinite score per Equation 5.",
                UserWarning,
            )
            return np.inf

        tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i]
        # A NaN would poison the sum and break comparisons between lambdas;
        # treat it as a failed estimate, like an exception above.
        if not np.isfinite(tau_ti):
            warnings.warn(
                f"LOOCV: Non-finite pseudo-treatment effect for observation "
                f"({t}, {i}) with "
                f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). "
                "Returning infinite score per Equation 5.",
                UserWarning,
            )
            return np.inf
        tau_squared_sum += tau_ti**2

    # SUM (not mean) of squared pseudo-effects, per Equation 5 (page 8).
    return tau_squared_sum
838
+
839
def _bootstrap_variance(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    Y: Optional[np.ndarray] = None,
    D: Optional[np.ndarray] = None,
    control_unit_idx: Optional[np.ndarray] = None,
    survey_design=None,
    unit_weight_arr: Optional[np.ndarray] = None,
    resolved_survey=None,
) -> Tuple[float, np.ndarray]:
    """
    Compute bootstrap standard error using unit-level block bootstrap.

    Dispatch order:

    1. Full survey design (strata/PSU/FPC present in ``resolved_survey``):
       delegate to the Rao-Wu rescaled bootstrap (Rust path skipped).
    2. Rust backend available, ``self._precomputed`` set and ``Y``/``D``
       provided: use the parallel Rust bootstrap. Its result is accepted
       only if it returns at least 10 replicates. Otherwise, or if it
       raises, fall through to Python.
    3. Python fallback: stratified unit resampling (treated and control
       units are resampled separately) with a refit per draw.

    Parameters
    ----------
    data : pd.DataFrame
        Original data in long format with unit, time, outcome, and treatment.
    outcome : str
        Name of the outcome column in data.
    treatment : str
        Name of the treatment indicator column in data.
    unit : str
        Name of the unit identifier column in data.
    time : str
        Name of the time period column in data.
    optimal_lambda : tuple of float
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
        from cross-validation. Used for model estimation in each bootstrap.
    Y : np.ndarray, optional
        Outcome matrix of shape (n_periods, n_units). Required for the Rust
        backend. If None, the Python implementation is used.
    D : np.ndarray, optional
        Treatment indicator matrix of shape (n_periods, n_units) where
        D[t,i]=1 indicates unit i is treated at time t. Required for the
        Rust backend.
    control_unit_idx : np.ndarray, optional
        Indices of control (never-treated) units. Accepted for interface
        compatibility; not used by this method.
    survey_design : SurveyDesign, optional
        Survey design specification, forwarded to the per-draw refits.
    unit_weight_arr : np.ndarray, optional
        Unit-level survey weights, passed to the Rust backend.
    resolved_survey : ResolvedSurveyDesign, optional
        Resolved survey design (observation-level).

    Returns
    -------
    se : float
        Bootstrap standard error of the ATT estimate (NaN if no draw
        succeeded).
    bootstrap_estimates : np.ndarray
        Array of ATT estimates from each successful bootstrap iteration.
        Length may be less than n_bootstrap if some iterations failed.

    Notes
    -----
    Uses unit-level block bootstrap where entire unit time series are
    resampled with replacement. This preserves within-unit correlation
    structure and is appropriate for panel data.
    """
    lambda_time, lambda_unit, lambda_nn = optimal_lambda

    # Check for full survey design (strata/PSU/FPC present)
    _has_full_design = resolved_survey is not None and (
        resolved_survey.strata is not None
        or resolved_survey.psu is not None
        or resolved_survey.fpc is not None
    )

    # Full survey design: use Python Rao-Wu rescaled bootstrap
    if _has_full_design:
        return self._bootstrap_rao_wu_local(
            data,
            outcome,
            treatment,
            unit,
            time,
            optimal_lambda,
            resolved_survey,
            survey_design,
        )

    # Try Rust backend for parallel bootstrap.
    # Only used for pweight-only designs (no strata/PSU/FPC).
    if (
        HAS_RUST_BACKEND
        and _rust_bootstrap_trop_variance is not None
        and self._precomputed is not None
        and Y is not None
        and D is not None
    ):
        try:
            control_mask = self._precomputed["control_mask"]
            time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)

            bootstrap_estimates, se = _rust_bootstrap_trop_variance(
                Y,
                D.astype(np.float64),
                control_mask.astype(np.uint8),
                time_dist_matrix,
                lambda_time,
                lambda_unit,
                lambda_nn,
                self.n_bootstrap,
                self.max_iter,
                self.tol,
                self.seed if self.seed is not None else 0,
                unit_weight_arr,
            )

            if len(bootstrap_estimates) >= 10:
                return float(se), bootstrap_estimates
            # Fall through to Python if too few bootstrap samples
            logger.debug(
                "Rust bootstrap returned only %d samples, falling back to Python",
                len(bootstrap_estimates),
            )
        except Exception as e:
            logger.debug("Rust bootstrap variance failed, falling back to Python: %s", e)
            warnings.warn(
                f"Rust backend failed for bootstrap variance; "
                f"falling back to Python. Performance may be reduced. "
                f"Error: {e}",
                UserWarning,
                stacklevel=2,
            )

    # Python implementation (fallback)
    rng = np.random.default_rng(self.seed)

    # Stratified bootstrap sampling: the paper's Algorithm 3 (page 27)
    # samples N_0 control rows and N_1 treated rows separately to preserve
    # the treatment ratio.
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
    control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)

    n_treated_units = len(treated_units)
    n_control_units = len(control_units)

    # Split the panel by unit ONCE. Filtering with data[data[unit] == u]
    # inside the loop would scan all of `data` for every sampled unit in
    # every draw. Each group keeps its rows in original order, so the
    # assembled bootstrap panels are identical.
    unit_frames = {u: frame for u, frame in data.groupby(unit)}

    bootstrap_estimates_list = []

    for _ in range(self.n_bootstrap):
        # Draw controls first, then treated units (fixed RNG call order
        # keeps results reproducible for a given seed).
        if n_control_units > 0:
            sampled_control = rng.choice(control_units, size=n_control_units, replace=True)
        else:
            sampled_control = np.array([], dtype=control_units.dtype)

        if n_treated_units > 0:
            sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True)
        else:
            sampled_treated = np.array([], dtype=treated_units.dtype)

        sampled_units = np.concatenate([sampled_control, sampled_treated])

        # Relabel each draw with a unique unit ID so that a unit sampled
        # several times enters the panel as distinct units.
        boot_data = pd.concat(
            [
                unit_frames[u].assign(**{unit: f"{u}_{idx}"})
                for idx, u in enumerate(sampled_units)
            ],
            ignore_index=True,
        )

        try:
            # Fit with fixed lambda (skip LOOCV for speed)
            att = self._fit_with_fixed_lambda(
                boot_data,
                outcome,
                treatment,
                unit,
                time,
                optimal_lambda,
                survey_design=survey_design,
            )
            if np.isfinite(att):
                bootstrap_estimates_list.append(att)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
            "Standard errors may be unreliable.",
            UserWarning,
        )
        if len(bootstrap_estimates) == 0:
            return np.nan, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
1047
+
1048
def _bootstrap_rao_wu_local(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    resolved_survey,
    survey_design,
) -> Tuple[float, np.ndarray]:
    """
    Rao-Wu rescaled bootstrap for local method with full survey design.

    Instead of physically resampling units, each iteration generates
    rescaled unit weights via Rao-Wu (1988) weight perturbation and refits
    the model with them. Survey strata are cross-classified with treatment
    group (ever-treated vs never-treated) to preserve the stratified
    resampling structure.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome, treatment, unit, time : str
        Column names.
    optimal_lambda : tuple
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn).
    resolved_survey : ResolvedSurveyDesign
        Resolved survey design (observation-level). Only its
        ``weight_type`` and ``lonely_psu`` settings are reused here.
    survey_design : SurveyDesign
        Original survey design specification. Its column names are used
        to read unit-level weights/strata/PSU/FPC from ``data``.

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). ``se`` is NaN when no replicate
        succeeded, or when a single unstratified PSU leaves the variance
        unidentified.
    """
    from diff_diff.bootstrap_utils import generate_rao_wu_weights
    from diff_diff.linalg import _factorize_cluster_ids
    from diff_diff.survey import ResolvedSurveyDesign

    rng = np.random.default_rng(self.seed)

    # Units in sorted order. _fit_with_fixed_lambda orders units the same
    # way, so the generated weight vectors line up with its unit axis.
    all_units = sorted(data[unit].unique())
    n_units = len(all_units)

    # Treatment status per unit (1 = ever treated)
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treatment_group = np.array([int(unit_ever_treated[u]) for u in all_units], dtype=np.int64)

    # Unit-level survey fields, taken from each unit's first non-null value
    first_rows = data.groupby(unit).first().loc[all_units]

    # Weights (unit-level)
    if survey_design.weights is not None:
        unit_weights = first_rows[survey_design.weights].values.astype(np.float64)
    else:
        unit_weights = np.ones(n_units, dtype=np.float64)

    # Strata: cross-classify survey strata x treatment group
    if survey_design.strata is not None:
        survey_strata = first_rows[survey_design.strata].values
        cross_labels = np.array([f"{s}_{g}" for s, g in zip(survey_strata, treatment_group)])
        cross_strata = _factorize_cluster_ids(cross_labels)
    else:
        # No survey strata: use treatment group as strata
        cross_strata = treatment_group.copy()
    n_strata = len(np.unique(cross_strata))

    # PSU (unit-level)
    if survey_design.psu is not None:
        psu_raw = first_rows[survey_design.psu].values
        if survey_design.nest and survey_design.strata is not None:
            # Nested PSU ids are only unique within a stratum: qualify them
            combined = np.array([f"{s}_{p}" for s, p in zip(cross_strata, psu_raw)])
            psu_arr = _factorize_cluster_ids(combined)
        else:
            psu_arr = _factorize_cluster_ids(psu_raw)
        n_psu = len(np.unique(psu_arr))
    else:
        # Implicit PSU: each unit is its own PSU
        psu_arr = np.arange(n_units, dtype=np.int64)
        n_psu = n_units

    # FPC (unit-level)
    fpc_arr = None
    if survey_design.fpc is not None:
        fpc_arr = first_rows[survey_design.fpc].values.astype(np.float64)

    unit_resolved = ResolvedSurveyDesign(
        weights=unit_weights,
        weight_type=resolved_survey.weight_type,
        strata=cross_strata,
        psu=psu_arr,
        fpc=fpc_arr,
        n_strata=n_strata,
        n_psu=n_psu,
        lonely_psu=resolved_survey.lonely_psu,
    )

    # Unidentified variance (single unstratified PSU): no valid replicates
    if (
        survey_design.psu is not None
        and unit_resolved.n_psu < 2
        and survey_design.strata is None
    ):
        return np.nan, np.array([])

    # Bootstrap loop: refit the full model per draw with Rao-Wu rescaled
    # weights, mirroring the physical-resampling bootstrap but using weight
    # perturbation instead of unit resampling.
    # NOTE(review): _fit_with_fixed_lambda, as written in this file, uses
    # `unit_weight_arr` only to weight the average of treated-cell effects.
    # Unless the model-fitting helpers pick the weights up some other way,
    # the draws for control units do not move the estimate. Confirm this
    # is intended.
    bootstrap_estimates_list = []

    for _ in range(self.n_bootstrap):
        try:
            # Generate Rao-Wu rescaled unit weights
            boot_weights = generate_rao_wu_weights(unit_resolved, rng)

            # Skip if all weights are zero
            if boot_weights.sum() == 0:
                continue

            # Refit the full local model with rescaled weights
            att = self._fit_with_fixed_lambda(
                data,
                outcome,
                treatment,
                unit,
                time,
                optimal_lambda,
                survey_design=survey_design,
                unit_weight_arr=boot_weights,
            )
        except (ValueError, ZeroDivisionError, np.linalg.LinAlgError, KeyError):
            # ZeroDivisionError: np.average rejects a draw in which the
            # contributing treated cells carry zero total weight. Drop
            # that replicate instead of aborting the whole bootstrap.
            continue

        if np.isfinite(att):
            bootstrap_estimates_list.append(att)

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
            "Standard errors may be unreliable.",
            UserWarning,
        )
        if len(bootstrap_estimates) == 0:
            return np.nan, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
1204
+
1205
def _fit_with_fixed_lambda(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    fixed_lambda: Tuple[float, float, float],
    survey_design=None,
    unit_weight_arr: Optional[np.ndarray] = None,
) -> float:
    """
    Fit model with fixed tuning parameters (for bootstrap).

    Uses observation-specific weights following Algorithm 2. For every
    treated cell (t, i) with a finite outcome, a model is fitted and the
    effect ``Y[t, i] - alpha[i] - beta[t] - L[t, i]`` is computed. Returns
    only the ATT estimate: the mean of those effects, weighted by unit
    weight when weights are available.

    Parameters
    ----------
    data : pd.DataFrame
        Long-format panel with unit, time, outcome and treatment columns.
    outcome, treatment, unit, time : str
        Column names in ``data``.
    fixed_lambda : tuple of float
        Tuning parameters ``(lambda_time, lambda_unit, lambda_nn)``.
    survey_design : SurveyDesign, optional
        When it defines ``weights`` (and ``unit_weight_arr`` is None),
        unit-level survey weights are extracted and used to weight the
        average of the per-cell effects.
    unit_weight_arr : np.ndarray, optional
        Pre-computed unit-level weights (e.g. Rao-Wu rescaled weights),
        indexed by position in the sorted unique values of ``data[unit]``.
        When provided, overrides weights extracted from survey_design.

    Returns
    -------
    float
        ATT estimate. NaN when no treated cell has a finite outcome, or
        when the weights of the contributing treated units sum to zero.

    Raises
    ------
    ValueError
        If there are no treated observations, or if ``DataFrame.pivot``
        rejects duplicate (time, unit) entries.
    """
    lambda_time, lambda_unit, lambda_nn = fixed_lambda

    # Setup: units/periods in sorted order define the matrix axes
    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    n_units = len(all_units)
    n_periods = len(all_periods)

    # Use pre-computed weights if provided (e.g. Rao-Wu bootstrap),
    # otherwise extract from survey_design (aligned with all_units).
    if unit_weight_arr is not None:
        local_weight_arr = unit_weight_arr
    elif survey_design is not None and survey_design.weights is not None:
        from diff_diff.survey import _extract_unit_survey_weights

        local_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)
    else:
        local_weight_arr = None

    # (n_periods, n_units) matrices; unit/time pairs absent from `data`
    # become NaN in Y and 0 (untreated) in D.
    Y = (
        data.pivot(index=time, columns=unit, values=outcome)
        .reindex(index=all_periods, columns=all_units)
        .values
    )
    D = (
        data.pivot(index=time, columns=unit, values=treatment)
        .reindex(index=all_periods, columns=all_units)
        .fillna(0)
        .astype(int)
        .values
    )

    control_mask = D == 0

    # Control units: never treated in any period
    unit_ever_treated = np.any(D == 1, axis=0)
    control_unit_idx = np.where(~unit_ever_treated)[0]

    # Treated cells in row-major (t, then i) order, as Python ints
    treated_observations = [tuple(cell) for cell in np.argwhere(D == 1).tolist()]

    if not treated_observations:
        raise ValueError("No treated observations")

    # Compute ATT using observation-specific weights (Algorithm 2)
    tau_values = []
    tau_weights = []
    for t, i in treated_observations:
        # Skip non-finite outcomes (match main fit NaN contract)
        if not np.isfinite(Y[t, i]):
            continue

        # Compute observation-specific weights for this (i, t)
        weight_matrix = self._compute_observation_weights(
            Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
        )

        # Fit model with these weights
        alpha, beta, L = self._estimate_model(
            Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods
        )

        # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}
        tau = Y[t, i] - alpha[i] - beta[t] - L[t, i]
        tau_values.append(tau)
        if local_weight_arr is not None:
            tau_weights.append(local_weight_arr[i])

    if not tau_values:
        return float("nan")
    if local_weight_arr is None:
        return float(np.mean(tau_values))

    # np.average raises ZeroDivisionError when the weights sum to zero,
    # e.g. a Rao-Wu draw that zeroes every contributing treated unit.
    # Report an undefined estimate instead, like the no-finite-outcome case.
    if float(np.sum(tau_weights)) == 0.0:
        return float("nan")
    return float(np.average(tau_values, weights=tau_weights))