diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/trop.py ADDED
@@ -0,0 +1,952 @@
1
+ """
2
+ Triply Robust Panel (TROP) estimator.
3
+
4
+ Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
5
+ TROP combines three robustness components:
6
+ 1. Nuclear norm regularized factor model (interactive fixed effects)
7
+ 2. Exponential distance-based unit weights
8
+ 3. Exponential time decay weights
9
+
10
+ The estimator uses leave-one-out cross-validation for tuning parameter
11
+ selection and provides robust treatment effect estimates under factor
12
+ confounding.
13
+
14
+ References
15
+ ----------
16
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
17
+ Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
18
+ """
19
+
20
+ import logging
21
+ import warnings
22
+ from typing import Any, Dict, List, Optional, Tuple
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ from diff_diff._backend import (
30
+ HAS_RUST_BACKEND,
31
+ _rust_loocv_grid_search,
32
+ )
33
+ from diff_diff.trop_global import TROPGlobalMixin
34
+ from diff_diff.trop_local import TROPLocalMixin, _validate_and_pivot_treatment
35
+ from diff_diff.trop_results import (
36
+ _LAMBDA_INF,
37
+ _PrecomputedStructures,
38
+ TROPResults,
39
+ )
40
+ from diff_diff.utils import safe_inference
41
+
42
+
43
+ class TROP(TROPLocalMixin, TROPGlobalMixin):
44
+ """
45
+ Triply Robust Panel (TROP) estimator.
46
+
47
+ Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025).
48
+ TROP combines three robustness components:
49
+
50
+ 1. **Nuclear norm regularized factor model**: Estimates interactive fixed
51
+ effects L_it via matrix completion with nuclear norm penalty ||L||_*
52
+
53
+ 2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i))
54
+ where d(j,i) is the RMSE of outcome differences between units
55
+
56
+ 3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|)
57
+ weighting pre-treatment periods by proximity to treatment
58
+
59
+ Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out
60
+ cross-validation on control observations.
61
+
62
+ Parameters
63
+ ----------
64
+ method : str, default='local'
65
+ Estimation method to use:
66
+
67
+ - 'local': Per-observation model fitting following Algorithm 2 of
68
+ Athey et al. (2025). Computes observation-specific weights and fits
69
+ a model for each treated observation, averaging the individual
70
+ treatment effects. More flexible but computationally intensive.
71
+
72
+ - 'global': Computationally efficient adaptation using the (1-W)
73
+ masking principle from Eq. 2. Fits a single model on control
74
+ observations with global weights, then computes per-observation
75
+ treatment effects as residuals:
76
+ tau_it = Y_it - mu - alpha_i - beta_t - L_it for treated cells.
77
+ ATT is the mean of these effects. For the paper's full
78
+ per-treated-cell estimator, use ``method='local'``.
79
+
80
+ lambda_time_grid : list, optional
81
+ Grid of time weight decay parameters. 0.0 = uniform weights (disabled).
82
+ Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
83
+ lambda_unit_grid : list, optional
84
+ Grid of unit weight decay parameters. 0.0 = uniform weights (disabled).
85
+ Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
86
+ lambda_nn_grid : list, optional
87
+ Grid of nuclear norm regularization parameters. inf = factor model
88
+ disabled (L=0). Default: [0, 0.01, 0.1, 1, 10].
89
+ max_iter : int, default=100
90
+ Maximum iterations for nuclear norm optimization.
91
+ tol : float, default=1e-6
92
+ Convergence tolerance for optimization.
93
+ alpha : float, default=0.05
94
+ Significance level for confidence intervals.
95
+ n_bootstrap : int, default=200
96
+ Number of bootstrap replications for variance estimation. Must be >= 2.
97
+ seed : int, optional
98
+ Random seed for reproducibility.
99
+
100
+ Attributes
101
+ ----------
102
+ results_ : TROPResults
103
+ Estimation results after calling fit().
104
+ is_fitted_ : bool
105
+ Whether the model has been fitted.
106
+
107
+ Examples
108
+ --------
109
+ >>> from diff_diff import TROP
110
+ >>> trop = TROP()
111
+ >>> results = trop.fit(
112
+ ... data,
113
+ ... outcome='outcome',
114
+ ... treatment='treated',
115
+ ... unit='unit',
116
+ ... time='period',
117
+ ... )
118
+ >>> results.print_summary()
119
+
120
+ References
121
+ ----------
122
+ Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust
123
+ Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
124
+ """
125
+
126
def __init__(
    self,
    method: str = "local",
    lambda_time_grid: Optional[List[float]] = None,
    lambda_unit_grid: Optional[List[float]] = None,
    lambda_nn_grid: Optional[List[float]] = None,
    max_iter: int = 100,
    tol: float = 1e-6,
    alpha: float = 0.05,
    n_bootstrap: int = 200,
    seed: Optional[int] = None,
):
    """
    Store the estimator configuration and validate it.

    Parameters
    ----------
    method : str, default='local'
        Estimation method, either 'local' or 'global'.
    lambda_time_grid : sequence of float, optional
        Candidate time-decay values. ``None`` or an empty sequence selects
        [0, 0.1, 0.5, 1, 2, 5]. Must not contain inf.
    lambda_unit_grid : sequence of float, optional
        Candidate unit-decay values. ``None`` or an empty sequence selects
        [0, 0.1, 0.5, 1, 2, 5]. Must not contain inf.
    lambda_nn_grid : sequence of float, optional
        Candidate nuclear-norm penalties. ``None`` or an empty sequence
        selects [0, 0.01, 0.1, 1, 10]. inf is allowed here.
    max_iter : int, default=100
        Stored as ``self.max_iter``.
    tol : float, default=1e-6
        Stored as ``self.tol``.
    alpha : float, default=0.05
        Stored as ``self.alpha``.
    n_bootstrap : int, default=200
        Number of bootstrap replications. Must be >= 2.
    seed : int, optional
        Stored as ``self.seed``.

    Raises
    ------
    ValueError
        If ``method`` is not recognised, ``n_bootstrap`` is below 2, or a
        time/unit grid contains inf.
    """

    def _resolve_grid(grid, default):
        # ``grid or default`` raises "truth value of an array is ambiguous"
        # for NumPy arrays. Copying into a list first makes the emptiness
        # test well defined for any iterable and avoids aliasing the
        # caller's (mutable) list.
        values = [] if grid is None else list(grid)
        return values if values else list(default)

    # Validate method parameter
    valid_methods = ("local", "global")
    if method not in valid_methods:
        raise ValueError(f"method must be one of {valid_methods}, got '{method}'")
    self.method = method

    # Default grids are used when the caller passes None or an empty grid.
    self.lambda_time_grid = _resolve_grid(
        lambda_time_grid, (0.0, 0.1, 0.5, 1.0, 2.0, 5.0)
    )
    self.lambda_unit_grid = _resolve_grid(
        lambda_unit_grid, (0.0, 0.1, 0.5, 1.0, 2.0, 5.0)
    )
    self.lambda_nn_grid = _resolve_grid(
        lambda_nn_grid, (0.0, 0.01, 0.1, 1.0, 10.0)
    )

    if n_bootstrap < 2:
        raise ValueError(
            "n_bootstrap must be >= 2 for TROP (bootstrap variance "
            "estimation is always used)"
        )

    self.max_iter = max_iter
    self.tol = tol
    self.alpha = alpha
    self.n_bootstrap = n_bootstrap
    self.seed = seed

    # Time/unit grids must be finite: λ=0 already yields uniform weights
    # (exp(-0 × dist) = 1), so inf is never the right way to "disable"
    # them. Only λ_nn may be inf (factor model disabled).
    for grid_name, grid_vals in [
        ("lambda_time_grid", self.lambda_time_grid),
        ("lambda_unit_grid", self.lambda_unit_grid),
    ]:
        if any(np.isinf(v) for v in grid_vals):
            raise ValueError(
                f"{grid_name} must not contain inf. Use 0.0 for uniform "
                f"weights (disabled) per Athey et al. (2025) Eq. 3: "
                f"exp(-0 × dist) = 1 for all distances."
            )

    # Internal state, populated by fit()
    self.results_: Optional[TROPResults] = None
    self.is_fitted_: bool = False
    self._optimal_lambda: Optional[Tuple[float, float, float]] = None

    # Pre-computed structures (set during fit)
    self._precomputed: Optional[_PrecomputedStructures] = None
183
+
184
+ # =========================================================================
185
+ # Parameter search (used by local method's fit() path)
186
+ # =========================================================================
187
+
188
def _univariate_loocv_search(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_mask: np.ndarray,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
    param_name: str,
    grid: List[float],
    fixed_params: Dict[str, float],
) -> Tuple[float, float]:
    """
    Grid-search a single tuning parameter while the others stay fixed.

    Following the paper's footnote 2, each value in ``grid`` is substituted
    for ``param_name`` and scored with ``_loocv_score_obs_specific``; the
    value with the strictly lowest score wins. Parameters missing from
    ``fixed_params`` default to 0.0. Conventions for "disabled" components:

    - lambda_nn = inf: factor model disabled; passed on as 1e10
    - lambda_time = 0.0: uniform time weights (exp(-0×dist)=1)
    - lambda_unit = 0.0: uniform unit weights (exp(-0×dist)=1)

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_mask : np.ndarray
        Boolean mask for control observations.
    control_unit_idx : np.ndarray
        Indices of control units.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.
    param_name : str
        Parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'.
    grid : List[float]
        Candidate values for ``param_name``.
    fixed_params : Dict[str, float]
        Values of the other parameters. May hold _LAMBDA_INF for lambda_nn.

    Returns
    -------
    Tuple[float, float]
        (best_value, best_score). If no candidate yields a finite score,
        this is (first grid value, inf); for an empty grid it is (0.0, inf).

    Raises
    ------
    ValueError
        If ``param_name`` is not one of the three tuning parameters.
    """
    searchable = ("lambda_time", "lambda_unit", "lambda_nn")
    if param_name not in searchable:
        # An unknown name would be written into the params dict and then
        # never read, making every grid value score identically.
        raise ValueError(f"param_name must be one of {searchable}, got '{param_name}'")

    best_score = np.inf
    # len() rather than truthiness so array-like grids do not raise.
    best_value = grid[0] if len(grid) > 0 else 0.0

    for value in grid:
        params = {**fixed_params, param_name: value}

        lambda_time = params.get("lambda_time", 0.0)
        lambda_unit = params.get("lambda_unit", 0.0)
        lambda_nn = params.get("lambda_nn", 0.0)

        # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0).
        # λ_time and λ_unit use 0.0 for uniform weights, so need no conversion.
        if np.isinf(lambda_nn):
            lambda_nn = 1e10

        try:
            score = self._loocv_score_obs_specific(
                Y,
                D,
                control_mask,
                control_unit_idx,
                lambda_time,
                lambda_unit,
                lambda_nn,
                n_units,
                n_periods,
            )
        except (np.linalg.LinAlgError, ValueError) as exc:
            # Best-effort search: a failing candidate is skipped, but leave
            # a trace so a fully failing grid can be diagnosed.
            logger.debug("LOOCV failed for %s=%s: %s", param_name, value, exc)
            continue

        # Strict '<' keeps the earliest grid value on ties; NaN never wins.
        if score < best_score:
            best_score = score
            best_value = value

    return best_value, best_score
271
+
272
def _cycling_parameter_search(
    self,
    Y: np.ndarray,
    D: np.ndarray,
    control_mask: np.ndarray,
    control_unit_idx: np.ndarray,
    n_units: int,
    n_periods: int,
    initial_lambda: Tuple[float, float, float],
    max_cycles: int = 10,
) -> Tuple[float, float, float]:
    """
    Refine the tuning parameters by coordinate descent over the grids.

    Following the paper's footnote 2 (Stage 2), each cycle re-optimizes
    λ_unit, then λ_time, then λ_nn via ``_univariate_loocv_search`` while
    the other two are held at their current values. Cycling stops when

    - a full cycle leaves all three parameters unchanged (fixed point), or
    - the finite LOOCV score moves by less than 1e-6 between cycles, or
    - ``max_cycles`` cycles have been run.

    Parameters
    ----------
    Y : np.ndarray
        Outcome matrix (n_periods x n_units).
    D : np.ndarray
        Treatment indicator matrix (n_periods x n_units).
    control_mask : np.ndarray
        Boolean mask for control observations.
    control_unit_idx : np.ndarray
        Indices of control units.
    n_units : int
        Number of units.
    n_periods : int
        Number of periods.
    initial_lambda : Tuple[float, float, float]
        Starting values (lambda_time, lambda_unit, lambda_nn).
    max_cycles : int, default=10
        Maximum number of coordinate descent cycles.

    Returns
    -------
    Tuple[float, float, float]
        Optimized (lambda_time, lambda_unit, lambda_nn).
    """
    lambda_time, lambda_unit, lambda_nn = initial_lambda
    prev_score = np.inf
    search_args = (Y, D, control_mask, control_unit_idx, n_units, n_periods)

    for cycle in range(max_cycles):
        cycle_start = (lambda_time, lambda_unit, lambda_nn)

        # Optimize λ_unit (fix λ_time, λ_nn)
        lambda_unit, _ = self._univariate_loocv_search(
            *search_args,
            "lambda_unit",
            self.lambda_unit_grid,
            {"lambda_time": lambda_time, "lambda_nn": lambda_nn},
        )

        # Optimize λ_time (fix λ_unit, λ_nn)
        lambda_time, _ = self._univariate_loocv_search(
            *search_args,
            "lambda_time",
            self.lambda_time_grid,
            {"lambda_unit": lambda_unit, "lambda_nn": lambda_nn},
        )

        # Optimize λ_nn (fix λ_unit, λ_time)
        lambda_nn, score = self._univariate_loocv_search(
            *search_args,
            "lambda_nn",
            self.lambda_nn_grid,
            {"lambda_unit": lambda_unit, "lambda_time": lambda_time},
        )

        # Fixed point: another cycle would repeat the very same searches.
        # NOTE(review): assumes the LOOCV score is deterministic for fixed
        # inputs -- confirm against _loocv_score_obs_specific.
        at_fixed_point = (lambda_time, lambda_unit, lambda_nn) == cycle_start

        # Guard with isfinite: when every fit fails both scores are inf and
        # abs(inf - inf) is NaN, which never compares below the tolerance
        # and previously forced all max_cycles (expensive) cycles to run.
        score_settled = bool(np.isfinite(score)) and abs(score - prev_score) < 1e-6

        if at_fixed_point or score_settled:
            logger.debug(
                "Cycling search converged after %d cycles with score %.6f", cycle + 1, score
            )
            break
        prev_score = score
    else:
        logger.debug("Cycling search stopped after max_cycles=%d without converging", max_cycles)

    return lambda_time, lambda_unit, lambda_nn
365
+
366
+ # =========================================================================
367
+ # Main fit method
368
+ # =========================================================================
369
+
370
+ def fit(
371
+ self,
372
+ data: pd.DataFrame,
373
+ outcome: str,
374
+ treatment: str,
375
+ unit: str,
376
+ time: str,
377
+ survey_design=None,
378
+ ) -> TROPResults:
379
+ """
380
+ Fit the TROP model.
381
+
382
+ Parameters
383
+ ----------
384
+ data : pd.DataFrame
385
+ Panel data with observations for multiple units over multiple
386
+ time periods.
387
+ outcome : str
388
+ Name of the outcome variable column.
389
+ treatment : str
390
+ Name of the treatment indicator column (0/1).
391
+
392
+ IMPORTANT: This should be an ABSORBING STATE indicator, not a
393
+ treatment timing indicator. For each unit, D=1 for ALL periods
394
+ during and after treatment:
395
+
396
+ - D[t, i] = 0 for all t < g_i (pre-treatment periods)
397
+ - D[t, i] = 1 for all t >= g_i (treatment and post-treatment)
398
+
399
+ where g_i is the treatment start time for unit i.
400
+
401
+ For staggered adoption, different units can have different g_i.
402
+ The ATT averages over ALL D=1 cells per Equation 1 of the paper.
403
+ unit : str
404
+ Name of the unit identifier column.
405
+ time : str
406
+ Name of the time period column.
407
+ survey_design : SurveyDesign, optional
408
+ Survey design specification. Supports pweight, strata, PSU, and
409
+ FPC. Full-design surveys (strata/PSU/FPC) use Rao-Wu rescaled
410
+ bootstrap; Rust backend is pweight-only (Python fallback for
411
+ full design). Survey weights enter ATT aggregation only.
412
+
413
+ Returns
414
+ -------
415
+ TROPResults
416
+ Object containing the ATT estimate, standard error,
417
+ factor estimates, and tuning parameters. The lambda_*
418
+ attributes show the selected grid values. For lambda_time and
419
+ lambda_unit, 0.0 means uniform weights; inf is not accepted.
420
+ For lambda_nn, inf is converted to 1e10 (factor model disabled).
421
+
422
+ Raises
423
+ ------
424
+ ValueError
425
+ If required columns are missing or non-pweight survey design.
426
+ """
427
+ # Validate inputs
428
+ required_cols = [outcome, treatment, unit, time]
429
+ missing = [c for c in required_cols if c not in data.columns]
430
+ if missing:
431
+ raise ValueError(f"Missing columns: {missing}")
432
+
433
+ # Resolve survey design
434
+ from diff_diff.survey import (
435
+ _extract_unit_survey_weights,
436
+ _resolve_survey_for_fit,
437
+ _validate_unit_constant_survey,
438
+ )
439
+
440
+ resolved_survey, _survey_weights, _survey_wt, survey_metadata = _resolve_survey_for_fit(
441
+ survey_design, data, "analytical"
442
+ )
443
+ # Reject replicate-weight designs — TROP uses Rao-Wu bootstrap
444
+ if resolved_survey is not None and resolved_survey.uses_replicate_variance:
445
+ raise NotImplementedError(
446
+ "TROP does not yet support replicate-weight survey designs. "
447
+ "Use a TSL-based survey design (strata/psu/fpc)."
448
+ )
449
+ # Validate weight_type is pweight (keep restriction), but allow
450
+ # strata/PSU/FPC — those are handled via Rao-Wu rescaled bootstrap.
451
+ if resolved_survey is not None and resolved_survey.weight_type != "pweight":
452
+ raise ValueError(
453
+ "TROP requires pweight survey weights. "
454
+ f"Got weight_type='{resolved_survey.weight_type}'."
455
+ )
456
+ if resolved_survey is not None:
457
+ _validate_unit_constant_survey(data, unit, survey_design)
458
+
459
+ # Dispatch based on estimation method
460
+ if self.method == "global":
461
+ return self._fit_global(
462
+ data,
463
+ outcome,
464
+ treatment,
465
+ unit,
466
+ time,
467
+ resolved_survey=resolved_survey,
468
+ survey_metadata=survey_metadata,
469
+ survey_design=survey_design,
470
+ )
471
+
472
+ # Below is the local method (default)
473
+ # Get unique units and periods
474
+ all_units = sorted(data[unit].unique())
475
+
476
+ # Extract unit-level survey weights
477
+ if resolved_survey is not None:
478
+ unit_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units)
479
+ else:
480
+ unit_weight_arr = None
481
+ all_periods = sorted(data[time].unique())
482
+
483
+ n_units = len(all_units)
484
+ n_periods = len(all_periods)
485
+
486
+ # Create mappings
487
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
488
+ period_to_idx = {p: i for i, p in enumerate(all_periods)}
489
+ idx_to_unit = {i: u for u, i in unit_to_idx.items()}
490
+ idx_to_period = {i: p for p, i in period_to_idx.items()}
491
+
492
+ # Create outcome matrix Y (n_periods x n_units) and treatment matrix D
493
+ # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
494
+ Y = (
495
+ data.pivot(index=time, columns=unit, values=outcome)
496
+ .reindex(index=all_periods, columns=all_units)
497
+ .values
498
+ )
499
+
500
+ # For D matrix, validate observed treatment and handle unbalanced panels
501
+ D, missing_mask = _validate_and_pivot_treatment(
502
+ data, time, unit, treatment, all_periods, all_units
503
+ )
504
+
505
+ # Validate D is monotonic non-decreasing per unit (absorbing state)
506
+ # D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
507
+ # Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
508
+ # This catches 1->0 violations that span missing period gaps
509
+ # Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
510
+ # adjacent period transitions don't show it (the gap hides the transition)
511
+ violating_units = []
512
+ for unit_idx in range(n_units):
513
+ # Get observed D values for this unit (where not missing)
514
+ observed_mask = ~missing_mask[:, unit_idx]
515
+ observed_d = D[observed_mask, unit_idx]
516
+
517
+ # Check if observed sequence is monotonically non-decreasing
518
+ if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
519
+ violating_units.append(all_units[unit_idx])
520
+
521
+ if violating_units:
522
+ raise ValueError(
523
+ f"Treatment indicator is not an absorbing state for units: {violating_units}. "
524
+ f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
525
+ f"If this is event-study style data, convert to absorbing state: "
526
+ f"D[t, i] = 1 for all t >= first treatment period."
527
+ )
528
+
529
+ # Identify treated observations
530
+ treated_mask = D == 1
531
+ n_treated_obs = np.sum(treated_mask)
532
+
533
+ if n_treated_obs == 0:
534
+ raise ValueError("No treated observations found")
535
+
536
+ # Identify treated and control units
537
+ unit_ever_treated = np.any(D == 1, axis=0)
538
+ treated_unit_idx = np.where(unit_ever_treated)[0]
539
+ control_unit_idx = np.where(~unit_ever_treated)[0]
540
+
541
+ if len(control_unit_idx) == 0:
542
+ raise ValueError("No control units found")
543
+
544
+ # Determine pre/post periods from treatment indicator D
545
+ # D matrix is the sole input for treatment timing per the paper
546
+ first_treat_period = None
547
+ for t in range(n_periods):
548
+ if np.any(D[t, :] == 1):
549
+ first_treat_period = t
550
+ break
551
+ if first_treat_period is None:
552
+ raise ValueError("Could not infer post-treatment periods from D matrix")
553
+
554
+ n_pre_periods = first_treat_period
555
+ # Count periods where D=1 is actually observed (matches docstring)
556
+ # Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
557
+ n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
558
+
559
+ if n_pre_periods < 2:
560
+ raise ValueError("Need at least 2 pre-treatment periods")
561
+
562
+ # Step 1: Grid search with LOOCV for tuning parameters
563
+ best_lambda = None
564
+ best_score = np.inf
565
+
566
+ # Control observations mask (for LOOCV)
567
+ control_mask = D == 0
568
+
569
+ # Pre-compute structures that are reused across LOOCV iterations
570
+ self._precomputed = self._precompute_structures(Y, D, control_unit_idx, n_units, n_periods)
571
+
572
+ # Use Rust backend for parallel LOOCV grid search (10-50x speedup)
573
+ if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None:
574
+ try:
575
+ # Prepare inputs for Rust function
576
+ control_mask_u8 = control_mask.astype(np.uint8)
577
+ time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
578
+
579
+ lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
580
+ lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
581
+ lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
582
+
583
+ result = _rust_loocv_grid_search(
584
+ Y,
585
+ D.astype(np.float64),
586
+ control_mask_u8,
587
+ time_dist_matrix,
588
+ lambda_time_arr,
589
+ lambda_unit_arr,
590
+ lambda_nn_arr,
591
+ self.max_iter,
592
+ self.tol,
593
+ )
594
+ # Unpack result - 7 values including optional first_failed_obs
595
+ best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = (
596
+ result
597
+ )
598
+ # Only accept finite scores - infinite means all fits failed
599
+ if np.isfinite(best_score):
600
+ best_lambda = (best_lt, best_lu, best_ln)
601
+ # else: best_lambda stays None, triggering defaults fallback
602
+ # Emit warnings consistent with Python implementation
603
+ if n_valid == 0:
604
+ # Include failed observation coordinates if available (Issue 2 fix)
605
+ obs_info = ""
606
+ if first_failed_obs is not None:
607
+ t_idx, i_idx = first_failed_obs
608
+ obs_info = f" First failure at observation ({t_idx}, {i_idx})."
609
+ warnings.warn(
610
+ f"LOOCV: All {n_attempted} fits failed for "
611
+ f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
612
+ f"Returning infinite score.{obs_info}",
613
+ UserWarning,
614
+ )
615
+ elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
616
+ n_failed = n_attempted - n_valid
617
+ # Include failed observation coordinates if available
618
+ obs_info = ""
619
+ if first_failed_obs is not None:
620
+ t_idx, i_idx = first_failed_obs
621
+ obs_info = f" First failure at observation ({t_idx}, {i_idx})."
622
+ warnings.warn(
623
+ f"LOOCV: {n_failed}/{n_attempted} fits failed for "
624
+ f"\u03bb=({best_lt}, {best_lu}, {best_ln}). "
625
+ f"This may indicate numerical instability.{obs_info}",
626
+ UserWarning,
627
+ )
628
+ except Exception as e:
629
+ # Fall back to Python implementation on error
630
+ logger.debug("Rust LOOCV grid search failed, falling back to Python: %s", e)
631
+ warnings.warn(
632
+ f"Rust backend failed for LOOCV grid search; "
633
+ f"falling back to Python. Performance may be reduced. "
634
+ f"Error: {e}",
635
+ UserWarning,
636
+ stacklevel=2,
637
+ )
638
+ best_lambda = None
639
+ best_score = np.inf
640
+
641
+ # Fall back to Python implementation if Rust unavailable or failed
642
+ # Uses two-stage approach per paper's footnote 2:
643
+ # Stage 1: Univariate searches for initial values
644
+ # Stage 2: Cycling (coordinate descent) until convergence
645
+ if best_lambda is None:
646
+ # Stage 1: Univariate searches with extreme fixed values
647
+ # Following paper's footnote 2 for initial bounds
648
+
649
+ # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment)
650
+ lambda_time_init, _ = self._univariate_loocv_search(
651
+ Y,
652
+ D,
653
+ control_mask,
654
+ control_unit_idx,
655
+ n_units,
656
+ n_periods,
657
+ "lambda_time",
658
+ self.lambda_time_grid,
659
+ {"lambda_unit": 0.0, "lambda_nn": _LAMBDA_INF},
660
+ )
661
+
662
+ # λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0
663
+ lambda_nn_init, _ = self._univariate_loocv_search(
664
+ Y,
665
+ D,
666
+ control_mask,
667
+ control_unit_idx,
668
+ n_units,
669
+ n_periods,
670
+ "lambda_nn",
671
+ self.lambda_nn_grid,
672
+ {"lambda_time": 0.0, "lambda_unit": 0.0},
673
+ )
674
+
675
+ # λ_unit search: fix λ_nn=∞, λ_time=0
676
+ lambda_unit_init, _ = self._univariate_loocv_search(
677
+ Y,
678
+ D,
679
+ control_mask,
680
+ control_unit_idx,
681
+ n_units,
682
+ n_periods,
683
+ "lambda_unit",
684
+ self.lambda_unit_grid,
685
+ {"lambda_nn": _LAMBDA_INF, "lambda_time": 0.0},
686
+ )
687
+
688
+ # Stage 2: Cycling refinement (coordinate descent)
689
+ lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search(
690
+ Y,
691
+ D,
692
+ control_mask,
693
+ control_unit_idx,
694
+ n_units,
695
+ n_periods,
696
+ (lambda_time_init, lambda_unit_init, lambda_nn_init),
697
+ )
698
+
699
+ # Compute final score for the optimized parameters
700
+ try:
701
+ best_score = self._loocv_score_obs_specific(
702
+ Y,
703
+ D,
704
+ control_mask,
705
+ control_unit_idx,
706
+ lambda_time,
707
+ lambda_unit,
708
+ lambda_nn,
709
+ n_units,
710
+ n_periods,
711
+ )
712
+ # Only accept finite scores - infinite means all fits failed
713
+ if np.isfinite(best_score):
714
+ best_lambda = (lambda_time, lambda_unit, lambda_nn)
715
+ # else: best_lambda stays None, triggering defaults fallback
716
+ except (np.linalg.LinAlgError, ValueError):
717
+ # If even the optimized parameters fail, best_lambda stays None
718
+ pass
719
+
720
+ if best_lambda is None:
721
+ warnings.warn("All tuning parameter combinations failed. Using defaults.", UserWarning)
722
+ best_lambda = (1.0, 1.0, 0.1)
723
+ best_score = np.nan
724
+
725
+ self._optimal_lambda = best_lambda
726
+ lambda_time, lambda_unit, lambda_nn = best_lambda
727
+
728
+ # Store original λ_nn for results (only λ_nn needs original→effective conversion).
729
+ # λ_time and λ_unit use 0.0 for uniform weights directly per Eq. 3.
730
+ original_lambda_nn = lambda_nn
731
+
732
+ # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
733
+ if np.isinf(lambda_nn):
734
+ lambda_nn = 1e10
735
+
736
+ # effective_lambda with converted λ_nn for ALL downstream computation
737
+ # (variance estimation uses the same parameters as point estimation)
738
+ effective_lambda = (lambda_time, lambda_unit, lambda_nn)
739
+
740
+ # Step 2: Final estimation - per-observation model fitting following Algorithm 2
741
+ # For each treated (i,t): compute observation-specific weights, fit model, compute tau_{it}
742
+ treatment_effects = {}
743
+ tau_values = []
744
+ tau_weights = [] # parallel to tau_values for survey-weighted ATT
745
+ alpha_estimates = []
746
+ beta_estimates = []
747
+ L_estimates = []
748
+
749
+ # Use pre-computed treated observations
750
+ treated_observations = self._precomputed["treated_observations"]
751
+
752
+ for t, i in treated_observations:
753
+ unit_id = idx_to_unit[i]
754
+ time_id = idx_to_period[t]
755
+
756
+ # Skip observations where outcome is missing -- record NaN but
757
+ # don't fit the model or include in tau_values (avoids NaN poisoning)
758
+ if not np.isfinite(Y[t, i]):
759
+ treatment_effects[(unit_id, time_id)] = np.nan
760
+ continue
761
+
762
+ # Compute observation-specific weights for this (i, t)
763
+ weight_matrix = self._compute_observation_weights(
764
+ Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
765
+ )
766
+
767
+ # Fit model with these weights
768
+ alpha_hat, beta_hat, L_hat = self._estimate_model(
769
+ Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods
770
+ )
771
+
772
+ # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}
773
+ tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i]
774
+
775
+ treatment_effects[(unit_id, time_id)] = tau_it
776
+ tau_values.append(tau_it)
777
+ if unit_weight_arr is not None:
778
+ tau_weights.append(unit_weight_arr[i])
779
+
780
+ # Store for averaging
781
+ alpha_estimates.append(alpha_hat)
782
+ beta_estimates.append(beta_hat)
783
+ L_estimates.append(L_hat)
784
+
785
+ # Count valid treated observations
786
+ n_valid_treated = len(tau_values)
787
+ if n_valid_treated == 0:
788
+ warnings.warn(
789
+ "All treated outcomes are NaN/missing. Cannot estimate ATT.",
790
+ UserWarning,
791
+ )
792
+ elif n_valid_treated < n_treated_obs:
793
+ warnings.warn(
794
+ f"Only {n_valid_treated} of {n_treated_obs} treated outcomes are finite. "
795
+ "df and n_treated_obs reflect valid observations only.",
796
+ UserWarning,
797
+ )
798
+
799
+ # Average ATT (survey-weighted when applicable)
800
+ if unit_weight_arr is not None and tau_values:
801
+ att = float(np.average(tau_values, weights=tau_weights))
802
+ else:
803
+ att = np.mean(tau_values) if tau_values else np.nan
804
+
805
+ # Average parameter estimates for output (representative)
806
+ alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units)
807
+ beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods)
808
+ L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units))
809
+
810
+ # Compute effective rank
811
+ _, s, _ = np.linalg.svd(L_hat, full_matrices=False)
812
+ if s[0] > 0:
813
+ effective_rank = np.sum(s) / s[0]
814
+ else:
815
+ effective_rank = 0.0
816
+
817
+ # Step 4: Variance estimation
818
+ # Use effective_lambda (converted values) to ensure SE is computed with same
819
+ # parameters as point estimation. This fixes the variance inconsistency issue.
820
+ se, bootstrap_dist = self._bootstrap_variance(
821
+ data,
822
+ outcome,
823
+ treatment,
824
+ unit,
825
+ time,
826
+ effective_lambda,
827
+ Y=Y,
828
+ D=D,
829
+ control_unit_idx=control_unit_idx,
830
+ survey_design=survey_design,
831
+ unit_weight_arr=unit_weight_arr,
832
+ resolved_survey=resolved_survey,
833
+ )
834
+
835
+ # Compute test statistics
836
+ df_trop = max(1, n_valid_treated - 1)
837
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop)
838
+
839
+ # Create results dictionaries
840
+ unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)}
841
+ time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)}
842
+
843
+ # Store results
844
+ self.results_ = TROPResults(
845
+ att=att,
846
+ se=se,
847
+ t_stat=t_stat,
848
+ p_value=p_value,
849
+ conf_int=conf_int,
850
+ n_obs=len(data),
851
+ n_treated=len(treated_unit_idx),
852
+ n_control=len(control_unit_idx),
853
+ n_treated_obs=int(n_valid_treated),
854
+ unit_effects=unit_effects_dict,
855
+ time_effects=time_effects_dict,
856
+ treatment_effects=treatment_effects,
857
+ lambda_time=lambda_time,
858
+ lambda_unit=lambda_unit,
859
+ lambda_nn=original_lambda_nn,
860
+ factor_matrix=L_hat,
861
+ effective_rank=effective_rank,
862
+ loocv_score=best_score,
863
+ alpha=self.alpha,
864
+ n_pre_periods=n_pre_periods,
865
+ n_post_periods=n_post_periods,
866
+ n_bootstrap=self.n_bootstrap,
867
+ bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None,
868
+ survey_metadata=survey_metadata,
869
+ )
870
+
871
+ self.is_fitted_ = True
872
+ return self.results_
873
+
874
+ # =========================================================================
875
+ # sklearn-like API
876
+ # =========================================================================
877
+
878
def get_params(self) -> Dict[str, Any]:
    """
    Return the estimator's current configuration.

    Returns
    -------
    Dict[str, Any]
        Mapping from parameter name to the value currently stored on the
        instance, covering ``method``, the three lambda grids,
        ``max_iter``, ``tol``, ``alpha``, ``n_bootstrap`` and ``seed``
        (in that order).
    """
    # Ordered list of the attributes exposed as parameters; the resulting
    # dictionary preserves this order.
    param_names = (
        "method",
        "lambda_time_grid",
        "lambda_unit_grid",
        "lambda_nn_grid",
        "max_iter",
        "tol",
        "alpha",
        "n_bootstrap",
        "seed",
    )
    return {name: getattr(self, name) for name in param_names}
891
+
892
def set_params(self, **params) -> "TROP":
    """
    Set estimator parameters.

    Every supplied parameter is validated before any of them is assigned,
    so a call that raises leaves the estimator exactly as it was.

    Parameters
    ----------
    **params
        Parameter names and their new values. Each name must already exist
        as an attribute on the estimator, and ``method`` must be either
        ``'local'`` or ``'global'``.

    Returns
    -------
    TROP
        The estimator itself, which allows call chaining.

    Raises
    ------
    ValueError
        If ``method`` is not one of the allowed values, or if a name does
        not correspond to an existing attribute.
    """
    # Pass 1: validate everything up front. Doing this before any setattr
    # guarantees that a rejected call cannot leave the estimator in a
    # half-updated state.
    for key, value in params.items():
        if key == "method" and value not in ("local", "global"):
            raise ValueError(
                f"method must be one of ('local', 'global'), got '{value}'"
            )
        if not hasattr(self, key):
            raise ValueError(f"Unknown parameter: {key}")

    # Pass 2: all inputs are known to be acceptable; apply them.
    for key, value in params.items():
        setattr(self, key, value)
    return self
904
+
905
+
906
def trop(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    unit: str,
    time: str,
    survey_design=None,
    **kwargs,
) -> TROPResults:
    """
    Fit a TROP estimator in a single call.

    Builds a ``TROP`` instance from ``**kwargs`` and immediately fits it on
    the supplied panel, returning whatever ``TROP.fit`` returns.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of the outcome variable column.
    treatment : str
        Name of the treatment indicator column (0/1).

        IMPORTANT: this must be an ABSORBING STATE indicator, not a
        treatment timing indicator. For each unit, D=1 for ALL periods
        during and after treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for
        t >= g_i, where g_i is the treatment start time for unit i).
    unit : str
        Name of the unit identifier column.
    time : str
        Name of the time period column.
    survey_design : SurveyDesign, optional
        Survey design specification. Supports pweight, strata, PSU, and FPC.
    **kwargs
        Extra keyword arguments forwarded to the ``TROP`` constructor.

    Returns
    -------
    TROPResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import trop
    >>> results = trop(data, 'y', 'treated', 'unit', 'time')
    >>> print(f"ATT: {results.att:.3f}")
    """
    # Construct and fit in one expression; survey_design is forwarded by
    # keyword, the column names positionally.
    return TROP(**kwargs).fit(
        data,
        outcome,
        treatment,
        unit,
        time,
        survey_design=survey_design,
    )