diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1270 @@
1
+ """
2
+ Global estimation method for the TROP estimator.
3
+
4
+ Contains the TROPGlobalMixin class with all methods for the global
5
+ (joint) estimation pathway. The global method fits a single weighted
6
+ model on control observations and extracts per-observation treatment
7
+ effects as post-hoc residuals.
8
+
9
+ This module is used via mixin inheritance — see trop.py for the
10
+ main TROP class definition.
11
+ """
12
+
13
+ import logging
14
+ import warnings
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ from diff_diff._backend import (
23
+ HAS_RUST_BACKEND,
24
+ _rust_bootstrap_trop_variance_global,
25
+ _rust_loocv_grid_search_global,
26
+ )
27
+ from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment
28
+ from diff_diff.trop_results import TROPResults
29
+ from diff_diff.utils import safe_inference
30
+
31
+
32
+ class TROPGlobalMixin:
33
+ """Mixin providing global estimation method for TROP.
34
+
35
+ Methods in this mixin access the following attributes from the main
36
+ TROP class via ``self``:
37
+
38
+ - Tuning grids: ``lambda_time_grid``, ``lambda_unit_grid``, ``lambda_nn_grid``
39
+ - Solver params: ``max_iter``, ``tol``
40
+ - Inference params: ``alpha``, ``n_bootstrap``, ``seed``
41
+ - State: ``results_``, ``is_fitted_``
42
+
43
+ """
44
+
45
+ # Type hints for attributes accessed from the main TROP class
46
+ lambda_time_grid: List[float]
47
+ lambda_unit_grid: List[float]
48
+ lambda_nn_grid: List[float]
49
+ max_iter: int
50
+ tol: float
51
+ alpha: float
52
+ n_bootstrap: int
53
+ seed: Optional[int]
54
+ results_: Any
55
+ is_fitted_: bool
56
+
57
def _compute_global_weights(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    lambda_time: float,
    lambda_unit: float,
    treated_periods: int,
    n_units: int,
    n_periods: int,
) -> np.ndarray:
    """
    Compute distance-based weights for global estimation.

    The weight of observation (t, i) is the product of

    - a time weight ``exp(-lambda_time * |t - center|)`` where
      ``center = n_periods - treated_periods / 2``, and
    - a unit weight ``exp(-lambda_unit * dist_i)`` where ``dist_i`` is the
      root-mean-square gap between unit ``i`` and the average ever-treated
      trajectory over the pre-treatment periods,

    multiplied by ``(1 - D)`` so that treated observations get zero weight.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units). May contain NaN.
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    treated_periods : int
        Number of trailing periods treated as post-treatment. Must satisfy
        ``1 <= treated_periods < n_periods``.
    n_units : int
        Number of units (kept for interface compatibility; not used here).
    n_periods : int
        Number of periods.

    Returns
    -------
    np.ndarray
        Weight matrix (n_periods x n_units).

    Raises
    ------
    ValueError
        If no unit is ever treated, if ``treated_periods`` is not in
        ``[1, n_periods]``, or if no pre-treatment period remains.
    """
    # Ever-treated units define the reference trajectory.
    treated_unit_idx = np.flatnonzero(np.any(D == 1, axis=0))
    if treated_unit_idx.size == 0:
        raise ValueError("No treated units found")

    # Validate explicitly: a negative-index slice such as x[-0:] selects the
    # whole array, and a negative count would silently mark the wrong periods.
    if not 0 < treated_periods <= n_periods:
        raise ValueError(
            f"treated_periods must be between 1 and n_periods ({n_periods}), "
            f"got {treated_periods}"
        )
    n_pre = n_periods - treated_periods
    if n_pre == 0:
        raise ValueError("No pre-treatment periods")

    # Time weights: distance to the center of the treated block.
    center = n_periods - treated_periods / 2.0
    dist_time = np.abs(np.arange(n_periods, dtype=float) - center)
    delta_time = np.exp(-lambda_time * dist_time)

    # Average treated trajectory. A period in which every treated unit is
    # missing yields NaN; that period is excluded below, so the
    # "Mean of empty slice" RuntimeWarning is noise and is suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1)

    # Boolean pre-period indicator (True for the first n_pre periods).
    is_pre = np.arange(n_periods) < n_pre

    # A (period, unit) cell contributes to the distance only when it is a
    # pre-period cell and both the unit outcome and the treated average are
    # finite. Selecting with a mask (rather than multiplying by 0) keeps a
    # non-finite square in a post period from turning into NaN.
    diff = average_treated[:, np.newaxis] - Y
    usable = np.isfinite(diff) & is_pre[:, np.newaxis]
    sum_sq = np.sum(np.where(usable, diff**2, 0.0), axis=0)
    valid_count = np.sum(usable, axis=0)

    # RMS distance per unit; the max(., 1) only guards the division for
    # units without any usable pre-period cell.
    dist_unit = np.sqrt(sum_sq / np.maximum(valid_count, 1))

    # Units with no usable pre-period data have an undefined distance and
    # are given zero weight directly.
    delta_unit = np.exp(-lambda_unit * dist_unit)
    delta_unit[valid_count == 0] = 0.0

    # Outer product gives the (n_periods x n_units) weight matrix; the
    # (1 - D) factor removes treated observations so the model is fit on
    # control data only.
    return np.outer(delta_time, delta_unit) * (1 - D)
153
+
154
def _solve_global_model(
    self,
    Y: np.ndarray,
    delta: np.ndarray,
    lambda_nn: float,
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    """
    Pick the solver that matches ``lambda_nn`` and return ``(mu, alpha, beta, L)``.

    A value of ``lambda_nn`` at or above ``1e10`` is treated as "factor model
    disabled": the plain two-way fixed effects solver is used and ``L`` is an
    all-zero matrix with the shape of ``Y``. Any other value goes through the
    low-rank solver with this estimator's ``max_iter`` and ``tol``.
    """
    n_periods, n_units = Y.shape

    factor_model_disabled = lambda_nn >= 1e10
    if not factor_model_disabled:
        mu, alpha, beta, L = self._solve_global_with_lowrank(
            Y, delta, lambda_nn, self.max_iter, self.tol
        )
        return mu, alpha, beta, L

    mu, alpha, beta = self._solve_global_no_lowrank(Y, delta)
    return mu, alpha, beta, np.zeros((n_periods, n_units))
174
+
175
+ @staticmethod
176
+ def _extract_posthoc_tau(
177
+ Y: np.ndarray,
178
+ D: np.ndarray,
179
+ mu: float,
180
+ alpha: np.ndarray,
181
+ beta: np.ndarray,
182
+ L: np.ndarray,
183
+ idx_to_unit: Optional[Dict] = None,
184
+ idx_to_period: Optional[Dict] = None,
185
+ unit_weights: Optional[np.ndarray] = None,
186
+ ) -> Tuple[float, Dict, List[float]]:
187
+ """
188
+ Extract post-hoc treatment effects: tau_it = Y - mu - alpha - beta - L.
189
+
190
+ Returns (att, treatment_effects_dict, tau_values_list).
191
+ When idx_to_unit/idx_to_period are None, treatment_effects uses raw indices.
192
+ """
193
+ counterfactual = mu + alpha[np.newaxis, :] + beta[:, np.newaxis] + L
194
+ tau_matrix = Y - counterfactual
195
+
196
+ treated_mask = D == 1
197
+ finite_mask = np.isfinite(Y)
198
+ valid_treated = treated_mask & finite_mask
199
+
200
+ tau_values = tau_matrix[valid_treated].tolist()
201
+ if unit_weights is not None and tau_values:
202
+ obs_weights = unit_weights[np.where(valid_treated)[1]]
203
+ att = float(np.average(tau_values, weights=obs_weights))
204
+ else:
205
+ att = float(np.mean(tau_values)) if tau_values else np.nan
206
+
207
+ # Build treatment effects dict
208
+ treatment_effects: Dict = {}
209
+ n_periods, n_units = D.shape
210
+ for t in range(n_periods):
211
+ for i in range(n_units):
212
+ if D[t, i] == 1:
213
+ uid = idx_to_unit[i] if idx_to_unit is not None else i
214
+ tid = idx_to_period[t] if idx_to_period is not None else t
215
+ if finite_mask[t, i]:
216
+ treatment_effects[(uid, tid)] = tau_matrix[t, i]
217
+ else:
218
+ treatment_effects[(uid, tid)] = np.nan
219
+
220
+ return att, treatment_effects, tau_values
221
+
222
def _loocv_score_global(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_obs: List[Tuple[int, int]],
    lambda_time: float,
    lambda_unit: float,
    lambda_nn: float,
    treated_periods: int,
    n_units: int,
    n_periods: int,
) -> float:
    """
    Leave-one-out cross-validation score for one tuning-parameter triple.

    Implements the criterion
        Q(lambda) = sum_{j,s: D_js=0} [tau_js^loocv(lambda)]^2
    by giving each listed control observation zero weight in turn,
    refitting the global model, and accumulating the squared residual at
    the held-out observation.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_obs : List[Tuple[int, int]]
        ``(t, i)`` index pairs of the control observations to hold out.
    lambda_time : float
        Time weight decay parameter.
    lambda_unit : float
        Unit weight decay parameter.
    lambda_nn : float
        Nuclear norm regularization parameter.
    treated_periods : int
        Number of post-treatment periods.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.

    Returns
    -------
    float
        Sum of squared held-out residuals. ``np.inf`` if any refit raises
        ``LinAlgError``/``ValueError`` or if no held-out observation had a
        finite outcome.
    """
    # The base weights do not depend on which observation is held out.
    base_weights = self._compute_global_weights(
        Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
    )

    squared_error_total = 0.0
    scored = 0

    for t_ex, i_ex in control_obs:
        # Holding out an observation means giving it zero weight.
        held_out_weights = base_weights.copy()
        held_out_weights[t_ex, i_ex] = 0.0

        try:
            mu, alpha, beta, L = self._solve_global_model(Y, held_out_weights, lambda_nn)

            if np.isfinite(Y[t_ex, i_ex]):
                pseudo_tau = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex]
                squared_error_total += pseudo_tau**2
                scored += 1

        except (np.linalg.LinAlgError, ValueError):
            # A single failed refit invalidates this parameter combination.
            return np.inf

    return squared_error_total if scored else np.inf
299
+
300
def _solve_global_no_lowrank(
    self,
    Y: np.ndarray,
    delta: np.ndarray,
) -> Tuple[float, np.ndarray, np.ndarray]:
    """
    Weighted least-squares fit of a two-way fixed effects model (no low-rank term).

    Minimizes ``sum delta_{it} * (Y_{it} - mu - alpha_i - beta_t)^2`` with the
    first unit effect and the first time effect fixed at zero for
    identification. Observations with a non-finite outcome or weight are
    given zero weight.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    delta : np.ndarray
        Weight matrix (n_periods x n_units). Callers in this module pass
        weights that are already zero for treated observations.

    Returns
    -------
    Tuple[float, np.ndarray, np.ndarray]
        ``(mu, alpha, beta)``: intercept, unit effects (length n_units) and
        time effects (length n_periods).

    Raises
    ------
    ValueError
        If the usable weights sum to less than ``1e-10``.
    """
    n_periods, n_units = Y.shape
    n_obs = n_periods * n_units

    # Row-major flattening: row index of cell (t, i) is t * n_units + i.
    y_flat = Y.flatten()
    w_flat = delta.flatten()

    # Drop unusable cells by zeroing both their weight and their outcome.
    usable = np.isfinite(y_flat) & np.isfinite(w_flat)
    w_flat = np.where(usable, w_flat, 0.0)
    y_flat = np.where(usable, y_flat, 0.0)

    root_w = np.sqrt(np.maximum(w_flat, 0))

    # Nothing to fit when (numerically) all weight is gone.
    if np.sum(w_flat) < 1e-10:
        raise ValueError("All weights are zero - cannot estimate")

    # Design matrix columns: [intercept | units 1..n_units-1 | periods 1..n_periods-1]
    n_params = 1 + (n_units - 1) + (n_periods - 1)
    design = np.zeros((n_obs, n_params))
    design[:, 0] = 1.0

    row_ids = np.arange(n_obs)
    unit_of_row = np.tile(np.arange(n_units), n_periods)
    period_of_row = np.repeat(np.arange(n_periods), n_units)

    # Unit dummies (unit 0 is the reference category).
    non_ref_unit = unit_of_row >= 1
    design[row_ids[non_ref_unit], unit_of_row[non_ref_unit]] = 1.0

    # Time dummies (period 0 is the reference category).
    non_ref_period = period_of_row >= 1
    design[row_ids[non_ref_period], (n_units - 1) + period_of_row[non_ref_period]] = 1.0

    # Scaling rows by sqrt(weight) turns WLS into ordinary least squares.
    design_w = design * root_w[:, np.newaxis]
    y_w = y_flat * root_w

    try:
        coeffs, _, _, _ = np.linalg.lstsq(design_w, y_w, rcond=None)
    except np.linalg.LinAlgError:
        warnings.warn(
            "Least-squares solver failed in TROP global estimation; "
            "falling back to pseudo-inverse. Results may be less "
            "numerically stable.",
            UserWarning,
            stacklevel=2,
        )
        coeffs = np.dot(np.linalg.pinv(design_w), y_w)

    # Unpack the coefficient vector; reference categories stay at zero.
    alpha = np.zeros(n_units)
    alpha[1:] = coeffs[1:n_units]
    beta = np.zeros(n_periods)
    beta[1:] = coeffs[n_units : (n_units + n_periods - 1)]

    return float(coeffs[0]), alpha, beta
390
+
391
def _solve_global_with_lowrank(
    self,
    Y: np.ndarray,
    delta: np.ndarray,
    lambda_nn: float,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit two-way fixed effects plus a low-rank matrix by alternating minimization.

    Objective:
        min sum delta_{it} (Y_{it} - mu - alpha_i - beta_t - L_{it})^2 + lambda_nn ||L||_*

    Each outer iteration (1) refits ``(mu, alpha, beta)`` with ``L`` held
    fixed, then (2) updates ``L`` with up to 20 momentum-accelerated
    proximal steps that soft-threshold singular values. After the loop the
    fixed effects are refit once more against the final ``L``.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units). Non-finite entries are
        replaced by 0 and given zero weight.
    delta : np.ndarray
        Weight matrix (n_periods x n_units). Not modified.
    lambda_nn : float
        Nuclear norm regularization parameter.
    max_iter : int, default=100
        Maximum number of outer iterations.
    tol : float, default=1e-6
        Convergence tolerance on the largest absolute change in ``L``
        (used for both the inner and the outer loop).

    Returns
    -------
    Tuple[float, np.ndarray, np.ndarray, np.ndarray]
        ``(mu, alpha, beta, L)`` estimated parameters.
    """
    n_periods, n_units = Y.shape

    # Neutralize missing outcomes: value 0 and weight 0.
    observed = np.isfinite(Y)
    Y_filled = np.where(observed, Y, 0.0)
    weights = delta.copy()
    weights[~observed] = 0.0

    # Weights scaled to a maximum of 1 act as the per-cell step size; the
    # singular-value threshold is rescaled by the same factor. Both are
    # constant across iterations.
    peak_weight = np.max(weights)
    if peak_weight > 0:
        step_weights = weights / peak_weight
        shrinkage = lambda_nn / (2.0 * peak_weight)
    else:
        step_weights = weights
        shrinkage = lambda_nn / 2.0

    L = np.zeros((n_periods, n_units))

    for _outer in range(max_iter):
        L_at_start = L.copy()

        # Step 1: fixed effects given the current L.
        mu, alpha, beta = self._solve_global_no_lowrank(Y_filled - L, weights)

        # Step 2: update L given the fixed effects.
        residual = Y_filled - mu - alpha[np.newaxis, :] - beta[:, np.newaxis]

        # Zero-weight cells carry no information; target the current L there.
        target = np.where(weights > 0, residual, L)

        current = L.copy()
        previous = current  # same array on the first pass, so momentum starts at zero
        t_k = 1.0

        for _inner in range(20):
            # Momentum extrapolation.
            t_next = (1.0 + np.sqrt(1.0 + 4.0 * t_k**2)) / 2.0
            momentum_coef = (t_k - 1.0) / t_next
            lookahead = current + momentum_coef * (current - previous)

            # Weighted gradient step from the extrapolated point.
            stepped = lookahead + step_weights * (target - lookahead)

            # Proximal step: soft-threshold the singular values.
            previous = current
            current = _soft_threshold_svd(stepped, shrinkage)
            t_k = t_next

            # Stop the inner loop once successive iterates agree.
            if np.max(np.abs(current - previous)) < tol:
                break

        L = current

        # Stop the outer loop once L no longer moves.
        if np.max(np.abs(L - L_at_start)) < tol:
            break

    # Final refit so the fixed effects correspond to the returned L.
    mu, alpha, beta = self._solve_global_no_lowrank(Y_filled - L, weights)

    return mu, alpha, beta, L
495
+
496
def _fit_global(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    resolved_survey=None,
    survey_metadata=None,
    survey_design=None,
) -> TROPResults:
    """
    Fit TROP with the global weighted least-squares method.

    A single model is fit on control observations using (1-W) masked
    weights; per-observation treatment effects are then read off as
    post-hoc residuals and averaged into the ATT.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    treatment : str
        Treatment indicator column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    resolved_survey : optional
        When not None, per-unit survey weights are extracted and used for
        the ATT; also forwarded to the bootstrap.
    survey_metadata : optional
        Stored on the returned results unchanged.
    survey_design : optional
        Forwarded to the survey-weight extraction and to the bootstrap.

    Returns
    -------
    TROPResults
        Estimation results (also stored on ``self.results_``).

    Raises
    ------
    ValueError
        If treatment is not absorbing, there are no treated observations or
        no control units, fewer than 2 pre-treatment periods exist, or
        treatment adoption is staggered.

    Notes
    -----
    Bootstrap variance estimation assumes simultaneous treatment adoption
    (fixed `treated_periods` across resamples). The treatment timing is
    inferred from the data once and held constant for all bootstrap
    iterations. For staggered adoption designs where treatment timing varies
    across units, use `method="local"` which computes observation-specific
    weights that naturally handle heterogeneous timing.
    """

    def _first_failure_note(failed_obs) -> str:
        # Suffix naming the first failed LOOCV observation, if one was reported.
        if failed_obs is None:
            return ""
        t_idx, i_idx = failed_obs
        return f" First failure at observation ({t_idx}, {i_idx})."

    # ---- Panel layout ----------------------------------------------------
    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    # Per-unit survey weights for the weighted ATT (only with a survey).
    if resolved_survey is None:
        unit_weight_arr = None
    else:
        from diff_diff.survey import _extract_unit_survey_weights

        unit_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)

    n_units = len(all_units)
    n_periods = len(all_periods)

    idx_to_unit = dict(enumerate(all_units))
    idx_to_period = dict(enumerate(all_periods))

    # Outcome matrix, periods in rows and units in columns.
    outcome_wide = data.pivot(index=time, columns=unit, values=outcome)
    Y = outcome_wide.reindex(index=all_periods, columns=all_units).values

    D, missing_mask = _validate_and_pivot_treatment(
        data, time, unit, treatment, all_periods, all_units
    )

    # ---- Validation ------------------------------------------------------
    # Treatment must never switch off within a unit's observed periods.
    violating_units = []
    for col, unit_label in enumerate(all_units):
        observed_d = D[~missing_mask[:, col], col]
        if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
            violating_units.append(unit_label)

    if violating_units:
        raise ValueError(
            f"Treatment indicator is not an absorbing state for units: {violating_units}. "
            f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
            f"If this is event-study style data, convert to absorbing state: "
            f"D[t, i] = 1 for all t >= first treatment period."
        )

    treated_mask = D == 1
    n_treated_obs = np.sum(treated_mask)
    if n_treated_obs == 0:
        raise ValueError("No treated observations found")

    unit_ever_treated = np.any(D == 1, axis=0)
    treated_unit_idx = np.where(unit_ever_treated)[0]
    control_unit_idx = np.where(~unit_ever_treated)[0]
    if len(control_unit_idx) == 0:
        raise ValueError("No control units found")

    # First period in which any unit is treated.
    first_treat_period = next(
        (t for t in range(n_periods) if np.any(D[t, :] == 1)),
        None,
    )
    if first_treat_period is None:
        raise ValueError("Could not infer post-treatment periods from D matrix")

    n_pre_periods = first_treat_period
    treated_periods = n_periods - first_treat_period
    n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))

    if n_pre_periods < 2:
        raise ValueError("Need at least 2 pre-treatment periods")

    # The global method needs one common adoption date. Only observed
    # periods are inspected so gaps in an unbalanced panel do not look
    # like late adoption.
    first_treat_by_unit = []
    for i in treated_unit_idx:
        observed_periods = np.where(~missing_mask[:, i])[0]
        treated_positions = np.where(D[observed_periods, i] == 1)[0]
        if len(treated_positions) > 0:
            first_treat_by_unit.append(observed_periods[treated_positions[0]])

    unique_starts = sorted(set(first_treat_by_unit))
    if len(unique_starts) > 1:
        raise ValueError(
            f"method='global' requires simultaneous treatment adoption, but your data "
            f"shows staggered adoption (units first treated at periods {unique_starts}). "
            f"Use method='local' which properly handles staggered adoption designs."
        )

    # ---- Tuning-parameter selection by LOOCV -----------------------------
    best_lambda = None
    best_score = np.inf
    control_mask = D == 0

    # Try the Rust grid search first when it is available.
    if HAS_RUST_BACKEND and _rust_loocv_grid_search_global is not None:
        try:
            (
                best_lt,
                best_lu,
                best_ln,
                best_score,
                n_valid,
                n_attempted,
                first_failed_obs,
            ) = _rust_loocv_grid_search_global(
                Y,
                D.astype(np.float64),
                control_mask.astype(np.uint8),
                np.array(self.lambda_time_grid, dtype=np.float64),
                np.array(self.lambda_unit_grid, dtype=np.float64),
                np.array(self.lambda_nn_grid, dtype=np.float64),
                self.max_iter,
                self.tol,
            )
            # A non-finite score means no combination could be evaluated.
            if np.isfinite(best_score):
                best_lambda = (best_lt, best_lu, best_ln)
                n_failed = n_attempted - n_valid
                if n_valid == 0:
                    obs_info = _first_failure_note(first_failed_obs)
                    warnings.warn(
                        f"LOOCV: All {n_attempted} fits failed for "
                        f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
                        f"Returning infinite score.{obs_info}",
                        UserWarning,
                    )
                elif n_attempted > 0 and n_failed > 0.1 * n_attempted:
                    obs_info = _first_failure_note(first_failed_obs)
                    warnings.warn(
                        f"LOOCV: {n_failed}/{n_attempted} fits failed for "
                        f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
                        f"This may indicate numerical instability.{obs_info}",
                        UserWarning,
                    )
        except Exception as e:
            # Any backend problem: report it and redo the search in Python.
            logger.debug(
                "Rust LOOCV grid search (global) failed, falling back to Python: %s", e
            )
            warnings.warn(
                f"Rust backend failed for LOOCV grid search (global); "
                f"falling back to Python. Performance may be reduced. "
                f"Error: {e}",
                UserWarning,
                stacklevel=2,
            )
            best_lambda = None
            best_score = np.inf

    # Pure-Python grid search (no backend, backend error, or no finite score).
    if best_lambda is None:
        control_obs = [
            (t, i)
            for t in range(n_periods)
            for i in range(n_units)
            if control_mask[t, i] and not np.isnan(Y[t, i])
        ]

        candidates = (
            (lt, lu, ln)
            for lt in self.lambda_time_grid
            for lu in self.lambda_unit_grid
            for ln in self.lambda_nn_grid
        )
        for lambda_time_val, lambda_unit_val, lambda_nn_val in candidates:
            # lambda_nn=inf disables the factor model; the solver expects
            # a large finite number instead.
            ln_finite = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val

            try:
                score = self._loocv_score_global(
                    Y,
                    D,
                    control_obs,
                    lambda_time_val,
                    lambda_unit_val,
                    ln_finite,
                    treated_periods,
                    n_units,
                    n_periods,
                )
            except (np.linalg.LinAlgError, ValueError):
                continue

            if score < best_score:
                best_score = score
                best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val)

    if best_lambda is None:
        warnings.warn("All tuning parameter combinations failed. Using defaults.", UserWarning)
        best_lambda = (1.0, 1.0, 0.1)
        best_score = np.nan

    # ---- Final fit with the selected parameters --------------------------
    lambda_time, lambda_unit, lambda_nn = best_lambda
    original_lambda_nn = lambda_nn  # reported as selected, before conversion

    # Same inf -> 1e10 conversion as in the grid search.
    if np.isinf(lambda_nn):
        lambda_nn = 1e10

    delta = self._compute_global_weights(
        Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
    )
    mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)

    att, treatment_effects, tau_values = self._extract_posthoc_tau(
        Y,
        D,
        mu,
        alpha,
        beta,
        L,
        idx_to_unit,
        idx_to_period,
        unit_weights=unit_weight_arr,
    )

    # Degrees of freedom and reported counts use finite treated outcomes only.
    n_valid_treated = len(tau_values)
    if n_valid_treated == 0:
        warnings.warn(
            "All treated outcomes are NaN/missing. Cannot estimate ATT.",
            UserWarning,
        )
    elif n_valid_treated < n_treated_obs:
        warnings.warn(
            f"Only {n_valid_treated} of {n_treated_obs} treated outcomes are finite. "
            "df and n_treated_obs reflect valid observations only.",
            UserWarning,
        )

    # Effective rank of L: sum of singular values over the largest one.
    _, s, _ = np.linalg.svd(L, full_matrices=False)
    effective_rank = np.sum(s) / s[0] if s[0] > 0 else 0.0

    # ---- Inference -------------------------------------------------------
    effective_lambda = (lambda_time, lambda_unit, lambda_nn)

    se, bootstrap_dist = self._bootstrap_variance_global(
        data,
        outcome,
        treatment,
        unit,
        time,
        effective_lambda,
        treated_periods,
        survey_design=survey_design,
        unit_weight_arr=unit_weight_arr,
        resolved_survey=resolved_survey,
    )

    df_trop = max(1, n_valid_treated - 1)
    t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop)

    # ---- Results ---------------------------------------------------------
    unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)}
    time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)}

    self.results_ = TROPResults(
        att=float(att),
        se=float(se),
        t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat,
        p_value=float(p_value) if np.isfinite(p_value) else p_value,
        conf_int=conf_int,
        n_obs=len(data),
        n_treated=len(treated_unit_idx),
        n_control=len(control_unit_idx),
        n_treated_obs=int(n_valid_treated),
        unit_effects=unit_effects_dict,
        time_effects=time_effects_dict,
        treatment_effects=treatment_effects,
        lambda_time=lambda_time,
        lambda_unit=lambda_unit,
        lambda_nn=original_lambda_nn,
        factor_matrix=L,
        effective_rank=effective_rank,
        loocv_score=best_score,
        alpha=self.alpha,
        n_pre_periods=n_pre_periods,
        n_post_periods=n_post_periods,
        n_bootstrap=self.n_bootstrap,
        bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
        survey_metadata=survey_metadata,
    )

    self.is_fitted_ = True
    return self.results_
849
+
850
def _bootstrap_variance_global(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    treated_periods: int,
    survey_design=None,
    unit_weight_arr: Optional[np.ndarray] = None,
    resolved_survey=None,
) -> Tuple[float, np.ndarray]:
    """
    Compute bootstrap standard error for global method.

    Dispatch order:

    1. If ``resolved_survey`` carries strata, PSU or FPC information, the
       work is delegated to ``_bootstrap_rao_wu_global``.
    2. Otherwise the Rust backend is tried when it is available.
    3. If the Rust call raises, a warning is emitted and the pure-Python
       unit bootstrap below is used instead.

    The Python fallback resamples control and ever-treated units
    separately (with replacement, group sizes preserved), relabels each
    drawn unit so duplicates stay distinct, and refits with the fixed
    tuning parameters.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome : str
        Outcome column name.
    treatment : str
        Treatment column name.
    unit : str
        Unit column name.
    time : str
        Time column name.
    optimal_lambda : tuple
        Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn).
    treated_periods : int
        Number of post-treatment periods.
    survey_design : SurveyDesign, optional
        Survey design specification.
    unit_weight_arr : np.ndarray, optional
        Unit-level survey weights (forwarded to the Rust backend only).
    resolved_survey : ResolvedSurveyDesign, optional
        Resolved survey design (observation-level).

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). ``se`` is NaN when fewer than two
        replications succeeded; the estimates array is empty when none did.
    """
    lambda_time, lambda_unit, lambda_nn = optimal_lambda

    # Check for full survey design (strata/PSU/FPC present)
    _has_full_design = resolved_survey is not None and (
        resolved_survey.strata is not None
        or resolved_survey.psu is not None
        or resolved_survey.fpc is not None
    )

    # Full survey design: use Python Rao-Wu rescaled bootstrap
    if _has_full_design:
        return self._bootstrap_rao_wu_global(
            data,
            outcome,
            treatment,
            unit,
            time,
            optimal_lambda,
            treated_periods,
            resolved_survey,
            survey_design,
        )

    # Try Rust backend for parallel bootstrap.
    # Only used for pweight-only designs (no strata/PSU/FPC)
    if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_global is not None:
        try:
            # Create (period x unit) matrices for the Rust function
            all_units = sorted(data[unit].unique())
            all_periods = sorted(data[time].unique())

            Y = (
                data.pivot(index=time, columns=unit, values=outcome)
                .reindex(index=all_periods, columns=all_units)
                .values
            )
            D = (
                data.pivot(index=time, columns=unit, values=treatment)
                .reindex(index=all_periods, columns=all_units)
                .fillna(0)
                .astype(np.float64)
                .values
            )

            bootstrap_estimates, se = _rust_bootstrap_trop_variance_global(
                Y,
                D,
                lambda_time,
                lambda_unit,
                lambda_nn,
                self.n_bootstrap,
                self.max_iter,
                self.tol,
                # An unset seed is passed to the backend as 0.
                self.seed if self.seed is not None else 0,
                unit_weight_arr,
            )

            if len(bootstrap_estimates) < 10:
                warnings.warn(
                    f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
                    UserWarning,
                )
            if len(bootstrap_estimates) == 0:
                return np.nan, np.array([])

            return float(se), np.array(bootstrap_estimates)

        except Exception as e:
            # Deliberate best-effort: any backend failure falls through to
            # the Python implementation below instead of aborting the fit.
            logger.debug("Rust bootstrap (global) failed, falling back to Python: %s", e)
            warnings.warn(
                "Rust backend failed for bootstrap variance (global); "
                "falling back to Python. Performance may be reduced. "
                f"Error: {e}",
                UserWarning,
                stacklevel=2,
            )

    # Python fallback implementation
    rng = np.random.default_rng(self.seed)

    # Stratified bootstrap sampling: ever-treated vs never-treated units
    unit_ever_treated = data.groupby(unit)[treatment].max()
    treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist())
    control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist())

    n_treated_units = len(treated_units)
    n_control_units = len(control_units)

    # Rows belonging to each original unit, filled on first use. The filter
    # does not depend on the bootstrap draw, so each unit is scanned at most
    # once instead of once per draw per replication.
    rows_by_unit: dict = {}

    def _rows_for(u):
        rows = rows_by_unit.get(u)
        if rows is None:
            rows = rows_by_unit[u] = data[data[unit] == u]
        return rows

    bootstrap_estimates_list: List[float] = []

    for _ in range(self.n_bootstrap):
        # Stratified sampling (control draw first, then treated, so the RNG
        # stream is consumed in a fixed order)
        if n_control_units > 0:
            sampled_control = rng.choice(control_units, size=n_control_units, replace=True)
        else:
            sampled_control = np.array([], dtype=object)

        if n_treated_units > 0:
            sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True)
        else:
            sampled_treated = np.array([], dtype=object)

        sampled_units = np.concatenate([sampled_control, sampled_treated])

        # Create bootstrap sample. ``assign`` returns a new frame, so the
        # cached per-unit rows are never modified; the positional suffix
        # keeps repeated draws of the same unit distinct.
        boot_data = pd.concat(
            [
                _rows_for(u).assign(**{unit: f"{u}_{idx}"})
                for idx, u in enumerate(sampled_units)
            ],
            ignore_index=True,
        )

        try:
            tau = self._fit_global_with_fixed_lambda(
                boot_data,
                outcome,
                treatment,
                unit,
                time,
                optimal_lambda,
                treated_periods,
                survey_design=survey_design,
            )
            if np.isfinite(tau):
                bootstrap_estimates_list.append(tau)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            # A failed replication is dropped; the count check below warns
            # when too few survive.
            continue

    bootstrap_estimates = np.array(bootstrap_estimates_list)

    if len(bootstrap_estimates) < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning
        )
    if len(bootstrap_estimates) == 0:
        return np.nan, np.array([])
    if len(bootstrap_estimates) < 2:
        # The sample standard deviation (ddof=1) is undefined for a single
        # draw; report NaN directly instead of via a NumPy RuntimeWarning.
        return np.nan, bootstrap_estimates

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
1038
+
1039
def _bootstrap_rao_wu_global(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    optimal_lambda: Tuple[float, float, float],
    treated_periods: int,
    resolved_survey,
    survey_design,
) -> Tuple[float, np.ndarray]:
    """
    Rao-Wu rescaled bootstrap for the global method under a full survey design.

    No units are physically resampled. A unit-level survey design is built
    (strata cross-classified with ever-treated status), and every
    replication draws a fresh vector of rescaled unit weights from it. Those
    weights are passed to ``_extract_posthoc_tau`` to obtain one ATT draw.

    Parameters
    ----------
    data : pd.DataFrame
        Original data.
    outcome, treatment, unit, time : str
        Column names.
    optimal_lambda : tuple
        Tuning parameters (lambda_time, lambda_unit, lambda_nn).
    treated_periods : int
        Number of post-treatment periods.
    resolved_survey : ResolvedSurveyDesign
        Resolved survey design (observation-level); supplies
        ``weight_type`` and ``lonely_psu`` for the unit-level design.
    survey_design : SurveyDesign
        Original survey design specification; supplies the column names
        for weights, strata, PSU and FPC plus the ``nest`` flag.

    Returns
    -------
    Tuple[float, np.ndarray]
        (se, bootstrap_estimates). ``(nan, empty array)`` when the design
        has an explicit PSU column with fewer than two PSUs and no strata,
        or when no replication produced a finite estimate.
    """
    from diff_diff.bootstrap_utils import generate_rao_wu_weights
    from diff_diff.survey import ResolvedSurveyDesign

    lambda_time, lambda_unit, lambda_nn = optimal_lambda
    rng = np.random.default_rng(self.seed)

    # ---- Unit-level design ------------------------------------------------
    all_units = sorted(data[unit].unique())
    n_units = len(all_units)

    # 0/1-style group label per unit, in the order of ``all_units``
    ever_treated = data.groupby(unit)[treatment].max()
    treatment_group = np.array([int(ever_treated[u]) for u in all_units], dtype=np.int64)

    # One row per unit: the design columns are read from here
    per_unit = data.groupby(unit).first().loc[all_units]

    # Weights (unit-level); unit weights of one when no weight column is set
    if survey_design.weights is None:
        unit_weights = np.ones(n_units, dtype=np.float64)
    else:
        unit_weights = per_unit[survey_design.weights].values.astype(np.float64)

    from diff_diff.linalg import _factorize_cluster_ids

    # Strata: survey strata x treatment group, or treatment group alone
    if survey_design.strata is None:
        cross_strata = treatment_group.copy()
    else:
        strata_values = per_unit[survey_design.strata].values
        labels = np.array([f"{s}_{g}" for s, g in zip(strata_values, treatment_group)])
        cross_strata = _factorize_cluster_ids(labels)
    n_strata = len(np.unique(cross_strata))

    # PSU (unit-level)
    if survey_design.psu is None:
        # Implicit PSU: each unit is its own PSU
        psu_arr = np.arange(n_units, dtype=np.int64)
        n_psu = n_units
    else:
        raw_psu = per_unit[survey_design.psu].values
        if survey_design.nest and survey_design.strata is not None:
            # Nested PSUs: make ids unique within each cross-classified stratum
            nested = np.array([f"{s}_{p}" for s, p in zip(cross_strata, raw_psu)])
            psu_arr = _factorize_cluster_ids(nested)
        else:
            psu_arr = _factorize_cluster_ids(raw_psu)
        n_psu = len(np.unique(psu_arr))

    # FPC (unit-level)
    fpc_arr = None
    if survey_design.fpc is not None:
        fpc_arr = per_unit[survey_design.fpc].values.astype(np.float64)

    unit_resolved = ResolvedSurveyDesign(
        weights=unit_weights,
        weight_type=resolved_survey.weight_type,
        strata=cross_strata,
        psu=psu_arr,
        fpc=fpc_arr,
        n_strata=n_strata,
        n_psu=n_psu,
        lonely_psu=resolved_survey.lonely_psu,
    )

    # Unidentified variance: a single explicit PSU without survey strata
    single_unstratified_psu = (
        survey_design.psu is not None
        and unit_resolved.n_psu < 2
        and survey_design.strata is None
    )
    if single_unstratified_psu:
        return np.nan, np.array([])

    # ---- Panel matrices (period x unit) -----------------------------------
    all_periods = sorted(data[time].unique())
    n_periods = len(all_periods)

    outcome_wide = data.pivot(index=time, columns=unit, values=outcome)
    Y = outcome_wide.reindex(index=all_periods, columns=all_units).values

    treatment_wide = data.pivot(index=time, columns=unit, values=treatment)
    D = (
        treatment_wide.reindex(index=all_periods, columns=all_units)
        .fillna(0)
        .astype(int)
        .values
    )

    # Group masks are fixed across replications
    is_control = treatment_group == 0
    is_treated = treatment_group == 1

    # ---- Replications -----------------------------------------------------
    draws: List[float] = []

    for _ in range(self.n_bootstrap):
        try:
            # Generate Rao-Wu rescaled weights (unit-level)
            boot_weights = generate_rao_wu_weights(unit_resolved, rng)

            # A replication with no weight left on either group is skipped
            if boot_weights[is_control].sum() == 0 or boot_weights[is_treated].sum() == 0:
                continue

            # NOTE(review): the inputs to the next two calls (Y, D, the
            # lambdas, the dimensions) do not change between replications;
            # only ``boot_weights`` does. Presumably the fit could be
            # computed once if these helpers are deterministic -- confirm.
            delta = self._compute_global_weights(
                Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
            )
            mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)

            # Weighted ATT using the Rao-Wu rescaled weights
            att, _, _ = self._extract_posthoc_tau(
                Y, D, mu, alpha, beta, L, unit_weights=boot_weights
            )

            if np.isfinite(att):
                draws.append(att)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            continue

    bootstrap_estimates = np.array(draws)
    n_ok = len(bootstrap_estimates)

    if n_ok < 10:
        warnings.warn(
            f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
            UserWarning,
        )
    if n_ok == 0:
        return np.nan, np.array([])

    se = np.std(bootstrap_estimates, ddof=1)
    return float(se), bootstrap_estimates
1214
+
1215
def _fit_global_with_fixed_lambda(
    self,
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    fixed_lambda: Tuple[float, float, float],
    treated_periods: int,
    survey_design=None,
) -> float:
    """
    Fit the global model with fixed tuning parameters and return the ATT.

    The long-format ``data`` is reshaped into (period x unit) outcome and
    treatment matrices, global weights are computed, the model is solved,
    and the ATT is taken from ``_extract_posthoc_tau``. When
    ``survey_design`` names a weight column, per-unit survey weights are
    extracted from ``data`` and passed to that final step.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    outcome, treatment, unit, time : str
        Column names.
    fixed_lambda : tuple
        Tuning parameters (lambda_time, lambda_unit, lambda_nn).
    treated_periods : int
        Number of post-treatment periods.
    survey_design : SurveyDesign, optional
        Survey design specification; only its ``weights`` field is read here.

    Returns
    -------
    float
        The ATT (first element returned by ``_extract_posthoc_tau``).
    """
    lambda_time, lambda_unit, lambda_nn = fixed_lambda

    all_units = sorted(data[unit].unique())
    all_periods = sorted(data[time].unique())

    # Per-unit survey weights for the weighted ATT (None means unweighted)
    local_weight_arr = None
    if survey_design is not None and survey_design.weights is not None:
        from diff_diff.survey import _extract_unit_survey_weights

        local_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)

    n_units = len(all_units)
    n_periods = len(all_periods)

    def _wide(column: str) -> pd.DataFrame:
        # Long -> wide, with rows and columns in sorted period/unit order
        wide = data.pivot(index=time, columns=unit, values=column)
        return wide.reindex(index=all_periods, columns=all_units)

    Y = _wide(outcome).values
    # Cells missing from the panel are treated as untreated
    D = _wide(treatment).fillna(0).astype(int).values

    # Compute weights (includes (1-W) masking)
    delta = self._compute_global_weights(
        Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
    )

    # Solve the model, then extract the post-hoc treatment effect
    mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
    att, _, _ = self._extract_posthoc_tau(
        Y, D, mu, alpha, beta, L, unit_weights=local_weight_arr
    )
    return att