diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/staggered.py ADDED
@@ -0,0 +1,3895 @@
1
+ """
2
+ Staggered Difference-in-Differences estimators.
3
+
4
+ Implements modern methods for DiD with variation in treatment timing,
5
+ including the Callaway-Sant'Anna (2021) estimator.
6
+ """
7
+
8
+ import warnings
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from scipy import linalg as scipy_linalg
14
+
15
+ from diff_diff.linalg import (
16
+ _check_propensity_diagnostics,
17
+ _detect_rank_deficiency,
18
+ _format_dropped_columns,
19
+ solve_logit,
20
+ solve_ols,
21
+ )
22
+ from diff_diff.staggered_aggregation import (
23
+ CallawaySantAnnaAggregationMixin,
24
+ )
25
+ from diff_diff.staggered_bootstrap import (
26
+ CallawaySantAnnaBootstrapMixin,
27
+ CSBootstrapResults,
28
+ )
29
+
30
+ # Import from split modules
31
+ from diff_diff.staggered_results import (
32
+ CallawaySantAnnaResults,
33
+ GroupTimeEffect,
34
+ )
35
+ from diff_diff.utils import safe_inference, safe_inference_batch
36
+
37
# Public API of this module. The results and bootstrap classes are defined in
# the split modules (staggered_results / staggered_bootstrap) and re-exported
# here for backward compatibility.
__all__ = [
    "CallawaySantAnna",
    "CallawaySantAnnaResults",
    "CSBootstrapResults",
    "GroupTimeEffect",
]

# Type alias for the dict built by CallawaySantAnna._precompute_structures and
# consumed by _compute_att_gt_fast: string keys mapping to numpy arrays,
# lookup dicts and scalar flags.
PrecomputedData = Dict[str, Any]
47
+
48
+
49
def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit a linear model of ``y`` on ``X`` plus an intercept column.

    Parameters
    ----------
    X : np.ndarray
        Regressor matrix of shape (n_samples, n_features). A column of ones
        is prepended internally, so callers must not include an intercept.
    y : np.ndarray
        Response vector.
    rank_deficient_action : str, default "warn"
        Forwarded to ``solve_ols``; what to do when the design matrix is
        rank-deficient:
        - "warn": warn and drop linearly dependent columns (default)
        - "error": raise ValueError
        - "silent": drop the columns without warning
    weights : np.ndarray, optional
        Per-observation weights. ``None`` gives ordinary least squares;
        otherwise weighted least squares is used.

    Returns
    -------
    beta : np.ndarray
        Estimated coefficients, intercept first.
    residuals : np.ndarray
        Residuals of the fit.
    """
    intercept_column = np.ones(X.shape[0])
    design = np.column_stack([intercept_column, X])

    # Only coefficients and residuals are needed, so skip the vcov computation.
    coefficients, fit_residuals, _unused_vcov = solve_ols(
        design,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
        weights=weights,
    )
    return coefficients, fit_residuals
93
+
94
+
95
def _safe_inv(A: np.ndarray) -> np.ndarray:
    """
    Return the inverse of the square matrix ``A``.

    The inverse is obtained by solving ``A @ X = I`` with ``np.linalg.solve``.
    Only when that raises ``LinAlgError`` (i.e. LAPACK reports the matrix as
    singular) does it fall back to ``np.linalg.lstsq``, which returns the
    minimum-norm least-squares solution, i.e. the pseudo-inverse.
    Ill-conditioned but non-singular matrices generally do not raise and so
    still take the ``solve`` path.
    """
    identity = np.eye(A.shape[0])
    try:
        inverse = np.linalg.solve(A, identity)
    except np.linalg.LinAlgError:
        inverse = np.linalg.lstsq(A, identity, rcond=None)[0]
    return inverse
101
+
102
+
103
+ class CallawaySantAnna(
104
+ CallawaySantAnnaBootstrapMixin,
105
+ CallawaySantAnnaAggregationMixin,
106
+ ):
107
+ """
108
+ Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
109
+
110
+ This estimator handles DiD designs with variation in treatment timing
111
+ (staggered adoption) and heterogeneous treatment effects. It avoids the
112
+ bias of traditional two-way fixed effects (TWFE) estimators by:
113
+
114
+ 1. Computing group-time average treatment effects ATT(g,t) for each
115
+ cohort g (units first treated in period g) and time t.
116
+ 2. Aggregating these to summary measures (overall ATT, event study, etc.)
117
+ using appropriate weights.
118
+
119
+ Parameters
120
+ ----------
121
+ control_group : str, default="never_treated"
122
+ Which units to use as controls:
123
+ - "never_treated": Use only never-treated units (recommended)
124
+ - "not_yet_treated": Use never-treated and not-yet-treated units
125
+ anticipation : int, default=0
126
+ Number of periods before treatment where effects may occur.
127
+ Set to > 0 if treatment effects can begin before the official
128
+ treatment date.
129
+ estimation_method : str, default="dr"
130
+ Estimation method:
131
+ - "dr": Doubly robust (recommended)
132
+ - "ipw": Inverse probability weighting
133
+ - "reg": Outcome regression
134
+ alpha : float, default=0.05
135
+ Significance level for confidence intervals.
136
+ cluster : str, optional
137
+ Column name for cluster-robust standard errors.
138
+ Defaults to unit-level clustering.
139
+ n_bootstrap : int, default=0
140
+ Number of bootstrap iterations for inference.
141
+ If 0, uses analytical standard errors.
142
+ Recommended: 999 or more for reliable inference.
143
+
144
+ .. note:: Memory Usage
145
+ The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
146
+ float64 array. For large datasets, this can be significant:
147
+ - 1K bootstrap × 10K units = ~80 MB
148
+ - 10K bootstrap × 100K units = ~8 GB
149
+ Consider reducing n_bootstrap if memory is constrained.
150
+
151
+ bootstrap_weights : str, default="rademacher"
152
+ Type of weights for multiplier bootstrap:
153
+ - "rademacher": +1/-1 with equal probability (standard choice)
154
+ - "mammen": Two-point distribution (asymptotically valid, matches skewness)
155
+ - "webb": Six-point distribution (recommended when n_clusters < 20)
156
+ seed : int, optional
157
+ Random seed for reproducibility.
158
+ rank_deficient_action : str, default="warn"
159
+ Action when design matrix is rank-deficient (linearly dependent columns):
160
+ - "warn": Issue warning and drop linearly dependent columns (default)
161
+ - "error": Raise ValueError
162
+ - "silent": Drop columns silently without warning
163
+ base_period : str, default="varying"
164
+ Method for selecting the base (reference) period for computing
165
+ ATT(g,t). Options:
166
+ - "varying": For pre-treatment periods (t < g - anticipation), use
167
+ t-1 as base (consecutive comparisons). For post-treatment, use
168
+ g-1-anticipation. Requires t-1 to exist in data.
169
+ - "universal": Always use g-1-anticipation as base period.
170
+ Both produce identical post-treatment effects. Matches R's
171
+ did::att_gt() base_period parameter.
172
+ cband : bool, default=True
173
+ Whether to compute simultaneous confidence bands (sup-t) for
174
+ event study aggregation. Requires ``n_bootstrap > 0``.
175
+ When True, results include ``cband_crit_value`` and per-event-time
176
+ ``cband_conf_int`` entries controlling family-wise error rate.
177
+ pscore_trim : float, default=0.01
178
+ Trimming bound for propensity scores. Scores are clipped to
179
+ ``[pscore_trim, 1 - pscore_trim]`` before weight computation
180
+ in IPW and DR estimation. Must be in ``(0, 0.5)``.
181
+ panel : bool, default=True
182
+ Whether the data is a balanced/unbalanced panel (units observed
183
+ across multiple time periods). Set to ``False`` for stationary
184
+ repeated cross-sections where each observation has a unique unit
185
+ ID and units do not repeat across periods. Requires that the
186
+ cross-sectional samples are drawn from the same population in
187
+ each period (stationarity). Uses cross-sectional DRDID
188
+ (Sant'Anna & Zhao 2020, Section 4) with per-observation influence
189
+ functions.
190
+ epv_threshold : float, default=10
191
+ Events Per Variable threshold for propensity score logit.
192
+ When the ratio of minority-class observations to predictor
193
+ variables (excluding intercept) falls below this value, a
194
+ warning is emitted (or ``ValueError`` raised if
195
+ ``rank_deficient_action="error"``). Based on Peduzzi et al.
196
+ (1996). Only applies to IPW and DR estimation methods.
197
+ Use ``diagnose_propensity()`` for a pre-estimation check across
198
+ all cohorts.
199
+ pscore_fallback : str, default="error"
200
+ Action when propensity score estimation fails entirely
201
+ (``LinAlgError`` or ``ValueError`` from IRLS):
202
+ - "error": Raise the exception (default). Ensures the user is
203
+ aware of estimation failures.
204
+ - "unconditional": Fall back to unconditional propensity
205
+ with a warning. For IPW, this drops all covariates. For DR,
206
+ the propensity model becomes unconditional but outcome
207
+ regression still uses covariates.
208
+ When ``rank_deficient_action="error"``, errors are always
209
+ re-raised regardless of this setting.
210
+
211
+ Attributes
212
+ ----------
213
+ results_ : CallawaySantAnnaResults
214
+ Estimation results after calling fit().
215
+ is_fitted_ : bool
216
+ Whether the model has been fitted.
217
+
218
+ Examples
219
+ --------
220
+ Basic usage:
221
+
222
+ >>> import pandas as pd
223
+ >>> from diff_diff import CallawaySantAnna
224
+ >>>
225
+ >>> # Panel data with staggered treatment
226
+ >>> # 'first_treat' = period when unit was first treated (0 if never treated)
227
+ >>> data = pd.DataFrame({
228
+ ... 'unit': [...],
229
+ ... 'time': [...],
230
+ ... 'outcome': [...],
231
+ ... 'first_treat': [...] # 0 for never-treated, else first treatment period
232
+ ... })
233
+ >>>
234
+ >>> cs = CallawaySantAnna()
235
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
236
+ ... time='time', first_treat='first_treat')
237
+ >>>
238
+ >>> results.print_summary()
239
+
240
+ With event study aggregation:
241
+
242
+ >>> cs = CallawaySantAnna()
243
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
244
+ ... time='time', first_treat='first_treat',
245
+ ... aggregate='event_study')
246
+ >>>
247
+ >>> # Plot event study
248
+ >>> from diff_diff import plot_event_study
249
+ >>> plot_event_study(results)
250
+
251
+ With covariate adjustment (conditional parallel trends):
252
+
253
+ >>> # When parallel trends only holds conditional on covariates
254
+ >>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
255
+ >>> results = cs.fit(data, outcome='outcome', unit='unit',
256
+ ... time='time', first_treat='first_treat',
257
+ ... covariates=['age', 'income'])
258
+ >>>
259
+ >>> # DR is recommended: consistent if either outcome model
260
+ >>> # or propensity model is correctly specified
261
+
262
+ Notes
263
+ -----
264
+ The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
265
+ approach: instead of estimating a single treatment effect, they estimate
266
+ ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
267
+ problem where already-treated units act as controls.
268
+
269
+ The ATT(g,t) is identified under parallel trends conditional on covariates:
270
+
271
+ E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]
272
+
273
+ where G=g indicates treatment cohort g and C=1 indicates control units.
274
+ This uses g-1 as the base period, which applies to post-treatment (t >= g).
275
+ With base_period="varying" (default), pre-treatment uses t-1 as base for
276
+ consecutive comparisons useful in parallel trends diagnostics.
277
+
278
+ References
279
+ ----------
280
+ Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
281
+ multiple time periods. Journal of Econometrics, 225(2), 200-230.
282
+ """
283
+
284
def __init__(
    self,
    control_group: str = "never_treated",
    anticipation: int = 0,
    estimation_method: str = "dr",
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    bootstrap_weights: Optional[str] = None,
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
    base_period: str = "varying",
    cband: bool = True,
    pscore_trim: float = 0.01,
    panel: bool = True,
    epv_threshold: float = 10,
    pscore_fallback: str = "error",
):
    """
    Validate and store the estimator configuration.

    Parameters are documented on the class. Every option is checked here, so
    an invalid configuration fails at construction time rather than partway
    through ``fit()``.

    Raises
    ------
    ValueError
        If a string option is not one of its allowed values, or a numeric
        option is out of range (``pscore_trim`` outside ``(0, 0.5)``,
        ``epv_threshold <= 0``, ``anticipation < 0``, ``alpha`` outside
        ``(0, 1)``, ``n_bootstrap < 0``).
    """
    if control_group not in ["never_treated", "not_yet_treated"]:
        raise ValueError(
            f"control_group must be 'never_treated' or 'not_yet_treated', "
            f"got '{control_group}'"
        )
    if estimation_method not in ["dr", "ipw", "reg"]:
        raise ValueError(
            f"estimation_method must be 'dr', 'ipw', or 'reg', got '{estimation_method}'"
        )
    if not (0 < pscore_trim < 0.5):
        raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}")
    if epv_threshold <= 0:
        raise ValueError(f"epv_threshold must be > 0, got {epv_threshold}")
    if pscore_fallback not in ["error", "unconditional"]:
        raise ValueError(
            f"pscore_fallback must be 'error' or 'unconditional', got '{pscore_fallback}'"
        )

    # None selects the default multiplier-bootstrap distribution.
    if bootstrap_weights is None:
        bootstrap_weights = "rademacher"

    if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
        raise ValueError(
            f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
            f"got '{bootstrap_weights}'"
        )

    if rank_deficient_action not in ["warn", "error", "silent"]:
        raise ValueError(
            f"rank_deficient_action must be 'warn', 'error', or 'silent', "
            f"got '{rank_deficient_action}'"
        )

    if base_period not in ["varying", "universal"]:
        raise ValueError(
            f"base_period must be 'varying' or 'universal', got '{base_period}'"
        )

    # Numeric range checks. A negative anticipation would move the base period
    # g - 1 - anticipation past g - 1, into the post-treatment window; alpha is
    # a significance level, so only values in (0, 1) are meaningful; a
    # bootstrap cannot have a negative number of replications.
    if anticipation < 0:
        raise ValueError(f"anticipation must be >= 0, got {anticipation}")
    if not (0 < alpha < 1):
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")
    if n_bootstrap < 0:
        raise ValueError(f"n_bootstrap must be >= 0, got {n_bootstrap}")

    self.control_group = control_group
    self.anticipation = anticipation
    self.estimation_method = estimation_method
    self.alpha = alpha
    self.cluster = cluster
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed
    self.rank_deficient_action = rank_deficient_action
    self.base_period = base_period

    self.cband = cband
    self.pscore_trim = pscore_trim
    self.panel = panel
    self.epv_threshold = epv_threshold
    self.pscore_fallback = pscore_fallback

    # Fitted state; populated by fit().
    self.is_fitted_ = False
    self.results_: Optional[CallawaySantAnnaResults] = None
362
+
363
def diagnose_propensity(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Check Events Per Variable (EPV) across all cohorts without estimation.

    Examines the data to identify cohorts where the propensity score logit
    may be unreliable due to too few events per covariate. Based on Peduzzi
    et al. (1996).

    This is a raw-count heuristic: it uses total cohort/control unit
    counts without filtering for missing outcomes, zero survey weights,
    or period-specific validity. The actual fit-time EPV (stored in
    ``results.epv_diagnostics``) may be lower because ``fit()`` operates
    on the valid base/post outcome pair and the positive-weight effective
    sample. Use this method as a quick pre-check; rely on
    ``results.epv_diagnostics`` for authoritative per-cell EPV.

    Parameters
    ----------
    df, outcome, unit, time, first_treat, covariates
        Same arguments as ``fit()``.

    Returns
    -------
    pd.DataFrame
        Per-cohort EPV diagnostics with columns: group, n_treated,
        n_control, n_covariates, n_params, epv, status. The frame always
        carries these columns, even when it has no rows (outcome
        regression, no covariates, or no treated cohorts).

    Raises
    ------
    NotImplementedError
        For repeated cross-sections (``panel=False``) or
        ``control_group='not_yet_treated'``.
    """
    result_columns = [
        "group",
        "n_treated",
        "n_control",
        "n_covariates",
        "n_params",
        "epv",
        "status",
    ]

    if not self.panel:
        raise NotImplementedError(
            "diagnose_propensity() is not yet supported for repeated "
            "cross-section data (panel=False). Use fit() with covariates "
            "and check results.epv_diagnostics instead."
        )
    if self.control_group == "not_yet_treated":
        raise NotImplementedError(
            "diagnose_propensity() is not yet supported for "
            "control_group='not_yet_treated' because the control set "
            "varies per (g, t) cell. Use fit() with covariates and "
            "check results.epv_diagnostics instead."
        )
    # Outcome regression fits no propensity model, and without covariates
    # there are no predictors to count events against.
    if self.estimation_method == "reg" or not covariates:
        return pd.DataFrame(columns=result_columns)

    # Normalize np.inf -> 0 for never-treated encoding (same as fit())
    df = df.copy()
    inf_mask = df[first_treat].isin([np.inf])
    if inf_mask.any():
        n_inf_units = df.loc[inf_mask, unit].nunique()
        warnings.warn(
            f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 "
            f"(never-treated). Use first_treat=0 to suppress this warning.",
            UserWarning,
            stacklevel=2,
        )
        df[first_treat] = df[first_treat].replace(np.inf, 0)

    # Compute time_periods and treatment_groups (same logic as fit())
    time_periods = sorted(df[time].unique())
    treatment_groups = sorted(g for g in df[first_treat].unique() if g > 0)
    precomputed = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        first_treat,
        covariates,
        time_periods=time_periods,
        treatment_groups=treatment_groups,
    )
    cohort_masks = precomputed["cohort_masks"]

    # not_yet_treated was rejected above, so the control set is always the
    # never-treated units and is the same for every cohort.
    n_control = int(np.sum(precomputed["never_treated_mask"]))
    n_covariates = len(covariates)
    # Predictor count, excluding the intercept (Peduzzi convention). Always
    # > 0 here because an empty covariate list returned early.
    n_params = n_covariates

    rows = []
    for g in sorted(cohort_masks):
        n_treated = int(np.sum(cohort_masks[g]))
        # The smaller of the two classes limits how reliable the logit is.
        epv = min(n_treated, n_control) / n_params

        if epv >= self.epv_threshold:
            status = "ok"
        elif epv >= 2:
            status = "low"
        else:
            status = "critical"

        rows.append(
            {
                "group": g,
                "n_treated": n_treated,
                "n_control": n_control,
                "n_covariates": n_covariates,
                "n_params": n_params,
                "epv": round(epv, 1),
                "status": status,
            }
        )

    return pd.DataFrame(rows, columns=result_columns)
505
+
506
@staticmethod
def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
    """Reduce an observation-level survey design to one row per unit.

    Thin delegate to ``diff_diff.survey.collapse_survey_to_unit_level``, used
    for panel influence-function-based variance. Survey design columns are
    expected to be constant within each unit (validated upstream), and the
    returned rows follow the order of ``all_units``.
    """
    # Imported inside the function rather than at module top.
    from diff_diff import survey as survey_module

    collapse = survey_module.collapse_survey_to_unit_level
    return collapse(resolved_survey, df, unit_col, all_units)
516
+
517
def _precompute_structures(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    time_periods: List[Any],
    treatment_groups: List[Any],
    resolved_survey=None,
) -> PrecomputedData:
    """
    Build the array-based structures consumed by the fast ATT(g,t) path.

    Units are ordered as produced by ``df.groupby(unit)``. Every per-unit
    array in the result (outcome rows, cohort masks, covariate rows, survey
    weights) follows that same order.

    Returns
    -------
    PrecomputedData
        Dict with keys ``all_units``, ``unit_to_idx``, ``unit_cohorts``,
        ``outcome_matrix`` (units x periods, NaN where a unit is not
        observed), ``period_to_col``, ``cohort_masks`` (cohort -> boolean
        mask), ``never_treated_mask``, ``covariate_by_period`` (period ->
        units x covariates array, or None without covariates),
        ``time_periods``, ``is_balanced``, ``is_panel``, ``canonical_size``,
        ``survey_weights``, ``resolved_survey``, ``resolved_survey_unit`` and
        ``df_survey``.
    """
    # One cohort label per unit; this index fixes the row order used below.
    cohort_by_unit = df.groupby(unit)[first_treat].first()
    all_units = cohort_by_unit.index.values
    unit_cohorts = cohort_by_unit.values
    unit_to_idx = dict(zip(all_units, range(len(all_units))))

    # Wide outcome panel: rows = units (aligned to all_units, so units with
    # gaps still get a row), columns = observed time periods.
    wide_outcomes = df.pivot(index=unit, columns=time, values=outcome).reindex(all_units)
    outcome_matrix = wide_outcomes.values
    period_to_col = {period: col for col, period in enumerate(wide_outcomes.columns)}

    cohort_masks = {cohort: unit_cohorts == cohort for cohort in treatment_groups}

    # 0 encodes never-treated; np.inf was normalized to 0 in fit(), so the
    # np.inf comparison is purely defensive.
    never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)

    # Covariates per period (each comparison reads them from its base period).
    covariate_by_period = None
    if covariates:
        covariate_by_period = {
            period: df[df[time] == period].set_index(unit).reindex(all_units)[covariates].values
            for period in time_periods
        }

    is_balanced = not np.isnan(outcome_matrix).any()

    # Survey design reduced to one weight (and one design row) per unit.
    survey_weights_arr = None
    resolved_survey_unit = None
    if resolved_survey is not None:
        weights_by_unit = (
            pd.Series(resolved_survey.weights, index=df.index).groupby(df[unit]).first()
        )
        survey_weights_arr = weights_by_unit.reindex(all_units).values
        resolved_survey_unit = self._collapse_survey_to_unit_level(
            resolved_survey, df, unit, all_units
        )

    return {
        "all_units": all_units,
        "unit_to_idx": unit_to_idx,
        "unit_cohorts": unit_cohorts,
        "outcome_matrix": outcome_matrix,
        "period_to_col": period_to_col,
        "cohort_masks": cohort_masks,
        "never_treated_mask": never_treated_mask,
        "covariate_by_period": covariate_by_period,
        "time_periods": time_periods,
        "is_balanced": is_balanced,
        "is_panel": True,
        "canonical_size": len(all_units),
        "survey_weights": survey_weights_arr,
        "resolved_survey": resolved_survey,
        "resolved_survey_unit": resolved_survey_unit,
        "df_survey": (
            None if resolved_survey_unit is None else resolved_survey_unit.df_survey
        ),
    }
614
+
615
def _compute_att_gt_fast(
    self,
    precomputed: PrecomputedData,
    g: Any,
    t: Any,
    covariates: Optional[List[str]],
    pscore_cache: Optional[Dict] = None,
    cho_cache: Optional[Dict] = None,
    epv_diagnostics: Optional[Dict] = None,
) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]:
    """
    Compute ATT(g,t) using pre-computed data structures (fast version).

    Uses vectorized numpy operations on the pre-pivoted outcome matrix
    instead of repeated pandas filtering.

    Parameters
    ----------
    precomputed : PrecomputedData
        Output of ``_precompute_structures``.
    g : Any
        Treatment cohort (first treated period).
    t : Any
        Evaluation period.
    covariates : list of str, optional
        Covariate names; their values are read from the base period.
    pscore_cache, cho_cache : dict, optional
        Caches forwarded to the IPW/DR estimators (propensity scores and
        the DR outcome-regression Cholesky factors).
    epv_diagnostics : dict, optional
        When given, non-empty IPW/DR EPV diagnostics are stored under
        ``(g, t)``.

    Returns
    -------
    att_gt : float or None
        None when the cell cannot be estimated: base or evaluation period
        missing, no valid treated or control units, or zero survey-weight
        mass.
    se_gt : float
    n_treated : int
    n_control : int
    inf_func_info : dict or None
    survey_weight_sum : float or None
        Sum of survey weights for treated units (for aggregation weighting).
    """
    # Returned whenever this (g, t) cell cannot be estimated.
    not_estimable = (None, 0.0, 0, 0, None, None)

    period_to_col = precomputed["period_to_col"]
    outcome_matrix = precomputed["outcome_matrix"]
    cohort_masks = precomputed["cohort_masks"]
    never_treated_mask = precomputed["never_treated_mask"]
    unit_cohorts = precomputed["unit_cohorts"]
    covariate_by_period = precomputed["covariate_by_period"]

    # Base period selection based on mode
    if self.base_period == "universal":
        # Universal: always use g - 1 - anticipation
        base_period_val = g - 1 - self.anticipation
    elif t < g - self.anticipation:
        # Varying, pre-treatment: use t - 1 (consecutive comparison)
        base_period_val = t - 1
    else:
        # Varying, post-treatment: use g - 1 - anticipation
        base_period_val = g - 1 - self.anticipation

    # Both periods must exist in the data. There is deliberately no fallback
    # to another base period, to keep comparisons methodologically consistent.
    if base_period_val not in period_to_col or t not in period_to_col:
        return not_estimable

    base_col = period_to_col[base_period_val]
    post_col = period_to_col[t]

    # Treated units: cohort g
    treated_mask = cohort_masks[g]

    if self.control_group == "never_treated":
        control_mask = never_treated_mask
    else:  # not_yet_treated
        # Not yet treated at BOTH time t and the base period:
        # Controls must be untreated at whichever is later, otherwise
        # their outcome at the base period is contaminated by treatment.
        nyt_threshold = max(t, base_period_val) + self.anticipation
        control_mask = never_treated_mask | (
            (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
        )

    y_base = outcome_matrix[:, base_col]
    y_post = outcome_matrix[:, post_col]
    outcome_change = y_post - y_base

    # Keep only units observed (non-NaN) in both periods.
    valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
    treated_valid = treated_mask & valid_mask
    control_valid = control_mask & valid_mask

    n_treated = np.sum(treated_valid)
    n_control = np.sum(control_valid)
    if n_treated == 0 or n_control == 0:
        return not_estimable

    treated_change = outcome_change[treated_valid]
    control_change = outcome_change[control_valid]

    survey_w = precomputed.get("survey_weights")
    sw_treated = survey_w[treated_valid] if survey_w is not None else None
    sw_control = survey_w[control_valid] if survey_w is not None else None

    # Guard against zero effective mass after subpopulation filtering
    if sw_treated is not None and np.sum(sw_treated) <= 0:
        return not_estimable
    if sw_control is not None and np.sum(sw_control) <= 0:
        return not_estimable

    # Covariates are taken from the base period.
    X_treated = None
    X_control = None
    if covariates and covariate_by_period is not None:
        cov_matrix = covariate_by_period[base_period_val]
        X_treated = cov_matrix[treated_valid]
        X_control = cov_matrix[control_valid]

        if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
            warnings.warn(
                f"Missing values in covariates for group {g}, time {t}. "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=3,
            )
            X_treated = None
            X_control = None

    # In a balanced panel with never-treated controls, the treated/control
    # samples (and base-period covariates) are the same for every t sharing
    # a base period, so fitted models can be reused across evaluation periods.
    share_across_t = (
        precomputed.get("is_balanced", False) and self.control_group == "never_treated"
    )

    # Cache key for propensity score reuse
    pscore_key = None
    if pscore_cache is not None and X_treated is not None:
        pscore_key = (g, base_period_val) if share_across_t else (g, base_period_val, t)

    # Cache key for Cholesky reuse (DR outcome regression)
    cho_key = None
    if cho_cache is not None and X_control is not None:
        cho_key = base_period_val if share_across_t else (g, base_period_val, t)

    if self.estimation_method == "reg":
        att_gt, se_gt, inf_func = self._outcome_regression(
            treated_change,
            control_change,
            X_treated,
            X_control,
            sw_treated=sw_treated,
            sw_control=sw_control,
        )
    else:
        # IPW and DR share the pooled survey weights and EPV bookkeeping.
        sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None
        epv_diag: dict = {}
        if self.estimation_method == "ipw":
            att_gt, se_gt, inf_func = self._ipw_estimation(
                treated_change,
                control_change,
                int(n_treated),
                int(n_control),
                X_treated,
                X_control,
                pscore_cache=pscore_cache,
                pscore_key=pscore_key,
                sw_treated=sw_treated,
                sw_control=sw_control,
                sw_all=sw_all,
                context_label=f"cohort g={g}",
                epv_diagnostics_out=epv_diag,
            )
        else:  # doubly robust
            att_gt, se_gt, inf_func = self._doubly_robust(
                treated_change,
                control_change,
                X_treated,
                X_control,
                pscore_cache=pscore_cache,
                pscore_key=pscore_key,
                cho_cache=cho_cache,
                cho_key=cho_key,
                sw_treated=sw_treated,
                sw_control=sw_control,
                sw_all=sw_all,
                context_label=f"cohort g={g}",
                epv_diagnostics_out=epv_diag,
            )
        if epv_diagnostics is not None and epv_diag:
            epv_diagnostics[(g, t)] = epv_diag

    # Package influence function info with index arrays (positions into
    # precomputed['all_units']) for O(1) downstream lookups instead of
    # O(n) Python dict lookups. The first n_treated entries of inf_func
    # belong to treated units, the rest to controls.
    n_t = int(n_treated)
    all_units = precomputed["all_units"]
    treated_positions = np.where(treated_valid)[0]
    control_positions = np.where(control_valid)[0]
    inf_func_info = {
        "treated_idx": treated_positions,
        "control_idx": control_positions,
        "treated_units": all_units[treated_positions],
        "control_units": all_units[control_positions],
        "treated_inf": inf_func[:n_t],
        "control_inf": inf_func[n_t:],
    }

    sw_sum = float(np.sum(sw_treated)) if sw_treated is not None else None
    return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info, sw_sum
827
+
828
+ def _compute_all_att_gt_vectorized(
829
+ self,
830
+ precomputed: PrecomputedData,
831
+ treatment_groups: List[Any],
832
+ time_periods: List[Any],
833
+ min_period: Any,
834
+ ) -> Tuple[Dict, Dict, Dict]:
835
+ """
836
+ Vectorized computation of all ATT(g,t) for the no-covariates regression case.
837
+
838
+ This inlines the simple difference-in-means path from _outcome_regression()
839
+ and eliminates per-(g,t) Python function call overhead.
840
+
841
+ Returns
842
+ -------
843
+ group_time_effects : dict
844
+ Mapping (g, t) -> effect dict.
845
+ influence_func_info : dict
846
+ Mapping (g, t) -> influence function info dict.
847
+ """
848
+ period_to_col = precomputed["period_to_col"]
849
+ outcome_matrix = precomputed["outcome_matrix"]
850
+ cohort_masks = precomputed["cohort_masks"]
851
+ never_treated_mask = precomputed["never_treated_mask"]
852
+ unit_cohorts = precomputed["unit_cohorts"]
853
+ survey_w = precomputed.get("survey_weights")
854
+
855
+ group_time_effects = {}
856
+ influence_func_info = {}
857
+ skipped_missing_period: List[Tuple] = []
858
+ skipped_empty_cell: List[Tuple] = []
859
+
860
+ # Collect all valid (g, t, base_col, post_col) tuples
861
+ tasks = []
862
+ for g in treatment_groups:
863
+ if self.base_period == "universal":
864
+ universal_base = g - 1 - self.anticipation
865
+ valid_periods = [t for t in time_periods if t != universal_base]
866
+ else:
867
+ valid_periods = [
868
+ t for t in time_periods if t >= g - self.anticipation or t > min_period
869
+ ]
870
+
871
+ for t in valid_periods:
872
+ # Base period selection
873
+ if self.base_period == "universal":
874
+ base_period_val = g - 1 - self.anticipation
875
+ else:
876
+ if t < g - self.anticipation:
877
+ base_period_val = t - 1
878
+ else:
879
+ base_period_val = g - 1 - self.anticipation
880
+
881
+ if base_period_val not in period_to_col or t not in period_to_col:
882
+ skipped_missing_period.append((g, t))
883
+ continue
884
+
885
+ tasks.append(
886
+ (g, t, period_to_col[base_period_val], period_to_col[t], base_period_val)
887
+ )
888
+
889
+ # Process all tasks
890
+ atts = []
891
+ ses = []
892
+ task_keys = []
893
+
894
+ for g, t, base_col, post_col, base_period_val in tasks:
895
+ treated_mask = cohort_masks[g]
896
+
897
+ if self.control_group == "never_treated":
898
+ control_mask = never_treated_mask
899
+ else:
900
+ # Controls must be untreated at both t and base_period_val
901
+ nyt_threshold = max(t, base_period_val) + self.anticipation
902
+ control_mask = never_treated_mask | (
903
+ (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
904
+ )
905
+
906
+ y_base = outcome_matrix[:, base_col]
907
+ y_post = outcome_matrix[:, post_col]
908
+ outcome_change = y_post - y_base
909
+ valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
910
+
911
+ treated_valid = treated_mask & valid_mask
912
+ control_valid = control_mask & valid_mask
913
+
914
+ n_treated = np.sum(treated_valid)
915
+ n_control = np.sum(control_valid)
916
+
917
+ if n_treated == 0 or n_control == 0:
918
+ skipped_empty_cell.append((g, t))
919
+ continue
920
+
921
+ treated_change = outcome_change[treated_valid]
922
+ control_change = outcome_change[control_valid]
923
+
924
+ n_t = int(n_treated)
925
+ n_c = int(n_control)
926
+
927
+ # Inline no-covariates regression (difference in means)
928
+ if survey_w is not None:
929
+ sw_t = survey_w[treated_valid]
930
+ sw_c = survey_w[control_valid]
931
+ # Guard against zero effective mass
932
+ if np.sum(sw_t) <= 0 or np.sum(sw_c) <= 0:
933
+ skipped_empty_cell.append((g, t))
934
+ continue
935
+ sw_t_norm = sw_t / np.sum(sw_t)
936
+ sw_c_norm = sw_c / np.sum(sw_c)
937
+ mu_t = float(np.sum(sw_t_norm * treated_change))
938
+ mu_c = float(np.sum(sw_c_norm * control_change))
939
+ att = mu_t - mu_c
940
+
941
+ # Influence function (survey-weighted)
942
+ inf_treated = sw_t_norm * (treated_change - mu_t)
943
+ inf_control = -sw_c_norm * (control_change - mu_c)
944
+ # SE derived from IF: sum(IF_i^2)
945
+ se = (
946
+ float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
947
+ if (n_t > 0 and n_c > 0)
948
+ else 0.0
949
+ )
950
+ sw_sum = float(np.sum(sw_t))
951
+ else:
952
+ att = float(np.mean(treated_change) - np.mean(control_change))
953
+
954
+ var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0
955
+ var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0
956
+ se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0
957
+
958
+ # Influence function
959
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
960
+ inf_control = -(control_change - np.mean(control_change)) / n_c
961
+ sw_sum = None
962
+
963
+ gte_entry = {
964
+ "effect": att,
965
+ "se": se,
966
+ # t_stat, p_value, conf_int filled by batch inference below
967
+ "t_stat": np.nan,
968
+ "p_value": np.nan,
969
+ "conf_int": (np.nan, np.nan),
970
+ "n_treated": n_t,
971
+ "n_control": n_c,
972
+ }
973
+ if sw_sum is not None:
974
+ gte_entry["survey_weight_sum"] = sw_sum
975
+ group_time_effects[(g, t)] = gte_entry
976
+
977
+ all_units = precomputed["all_units"]
978
+ treated_positions = np.where(treated_valid)[0]
979
+ control_positions = np.where(control_valid)[0]
980
+ influence_func_info[(g, t)] = {
981
+ "treated_idx": treated_positions,
982
+ "control_idx": control_positions,
983
+ "treated_units": all_units[treated_positions],
984
+ "control_units": all_units[control_positions],
985
+ "treated_inf": inf_treated,
986
+ "control_inf": inf_control,
987
+ }
988
+
989
+ atts.append(att)
990
+ ses.append(se)
991
+ task_keys.append((g, t))
992
+
993
+ # Batch inference for all (g,t) pairs at once
994
+ if task_keys:
995
+ df_survey_val = precomputed.get("df_survey")
996
+ # Guard: replicate design with undefined df → NaN inference
997
+ if (
998
+ df_survey_val is None
999
+ and precomputed.get("resolved_survey_unit") is not None
1000
+ and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance")
1001
+ and precomputed["resolved_survey_unit"].uses_replicate_variance
1002
+ ):
1003
+ df_survey_val = 0
1004
+ t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
1005
+ np.array(atts),
1006
+ np.array(ses),
1007
+ alpha=self.alpha,
1008
+ df=df_survey_val,
1009
+ )
1010
+ for idx, key in enumerate(task_keys):
1011
+ group_time_effects[key]["t_stat"] = float(t_stats[idx])
1012
+ group_time_effects[key]["p_value"] = float(p_values[idx])
1013
+ group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx]))
1014
+
1015
+ skip_info = {
1016
+ "missing_period": skipped_missing_period,
1017
+ "empty_cell": skipped_empty_cell,
1018
+ }
1019
+ return group_time_effects, influence_func_info, skip_info
1020
+
1021
+ def _compute_all_att_gt_covariate_reg(
1022
+ self,
1023
+ precomputed: PrecomputedData,
1024
+ treatment_groups: List[Any],
1025
+ time_periods: List[Any],
1026
+ min_period: Any,
1027
+ ) -> Tuple[Dict, Dict, Dict]:
1028
+ """
1029
+ Optimized computation of all ATT(g,t) for the covariate regression case.
1030
+
1031
+ Groups (g,t) pairs by their control regression key to reuse Cholesky
1032
+ factorizations of X^T X across pairs that share the same control design
1033
+ matrix.
1034
+
1035
+ Returns
1036
+ -------
1037
+ group_time_effects : dict
1038
+ Mapping (g, t) -> effect dict.
1039
+ influence_func_info : dict
1040
+ Mapping (g, t) -> influence function info dict.
1041
+ """
1042
+ period_to_col = precomputed["period_to_col"]
1043
+ outcome_matrix = precomputed["outcome_matrix"]
1044
+ cohort_masks = precomputed["cohort_masks"]
1045
+ never_treated_mask = precomputed["never_treated_mask"]
1046
+ unit_cohorts = precomputed["unit_cohorts"]
1047
+ covariate_by_period = precomputed["covariate_by_period"]
1048
+ is_balanced = precomputed["is_balanced"]
1049
+
1050
+ group_time_effects = {}
1051
+ influence_func_info = {}
1052
+ atts = []
1053
+ ses = []
1054
+ task_keys = []
1055
+ n_nan_cells = 0
1056
+ skipped_missing_period: List[Tuple] = []
1057
+ skipped_empty_cell: List[Tuple] = []
1058
+
1059
+ # Collect all valid (g, t) tasks with their base periods
1060
+ tasks_by_group = {} # control_key -> list of (g, t, base_period_val, base_col, post_col)
1061
+ for g in treatment_groups:
1062
+ if self.base_period == "universal":
1063
+ universal_base = g - 1 - self.anticipation
1064
+ valid_periods = [t for t in time_periods if t != universal_base]
1065
+ else:
1066
+ valid_periods = [
1067
+ t for t in time_periods if t >= g - self.anticipation or t > min_period
1068
+ ]
1069
+
1070
+ for t in valid_periods:
1071
+ if self.base_period == "universal":
1072
+ base_period_val = g - 1 - self.anticipation
1073
+ else:
1074
+ if t < g - self.anticipation:
1075
+ base_period_val = t - 1
1076
+ else:
1077
+ base_period_val = g - 1 - self.anticipation
1078
+
1079
+ if base_period_val not in period_to_col or t not in period_to_col:
1080
+ skipped_missing_period.append((g, t))
1081
+ continue
1082
+
1083
+ # Determine control regression grouping key.
1084
+ # For balanced panels with never_treated control, X_control depends
1085
+ # only on base_period_val (control mask is time-invariant).
1086
+ # For not_yet_treated, the control mask excludes cohort g, so include g.
1087
+ if is_balanced and self.control_group == "never_treated":
1088
+ control_key = base_period_val
1089
+ else:
1090
+ control_key = (g, base_period_val, t)
1091
+
1092
+ tasks_by_group.setdefault(control_key, []).append(
1093
+ (g, t, base_period_val, period_to_col[base_period_val], period_to_col[t])
1094
+ )
1095
+
1096
+ # Process each group of tasks sharing the same control regression
1097
+ for control_key, tasks in tasks_by_group.items():
1098
+ # Use the first task to build X_control (same for all in the group)
1099
+ first_g, first_t, base_period_val, first_base_col, first_post_col = tasks[0]
1100
+
1101
+ cov_matrix = covariate_by_period[base_period_val]
1102
+
1103
+ # Build control mask (same for all tasks in this group)
1104
+ if self.control_group == "never_treated":
1105
+ control_mask = never_treated_mask
1106
+ else:
1107
+ # Controls must be untreated at both t and base_period_val
1108
+ nyt_threshold = max(first_t, base_period_val) + self.anticipation
1109
+ control_mask = never_treated_mask | (
1110
+ (unit_cohorts > nyt_threshold) & (unit_cohorts != first_g)
1111
+ )
1112
+
1113
+ # For balanced panels, valid_mask is all True so control_valid = control_mask
1114
+ if is_balanced:
1115
+ control_valid_base = control_mask
1116
+ else:
1117
+ y_base_first = outcome_matrix[:, first_base_col]
1118
+ y_post_first = outcome_matrix[:, first_post_col]
1119
+ valid_first = ~(np.isnan(y_base_first) | np.isnan(y_post_first))
1120
+ control_valid_base = control_mask & valid_first
1121
+
1122
+ X_ctrl_raw = cov_matrix[control_valid_base]
1123
+
1124
+ # Check for NaN in control covariates
1125
+ ctrl_has_nan = bool(np.any(np.isnan(X_ctrl_raw)))
1126
+
1127
+ # Build X_ctrl with intercept
1128
+ n_c_base = int(np.sum(control_valid_base))
1129
+ if n_c_base == 0:
1130
+ skipped_empty_cell.extend((g, t) for g, t, *_ in tasks)
1131
+ continue
1132
+
1133
+ X_ctrl = None
1134
+ cho = None
1135
+ kept_cols = None
1136
+ if not ctrl_has_nan:
1137
+ X_ctrl = np.column_stack([np.ones(n_c_base), X_ctrl_raw])
1138
+
1139
+ # One-time rank check for this control group
1140
+ rank, dropped_cols, _ = _detect_rank_deficiency(X_ctrl)
1141
+
1142
+ if len(dropped_cols) > 0:
1143
+ # Rank-deficient: force lstsq for both "warn" and "silent".
1144
+ # Cholesky on near-singular XtX could yield unstable coefficients.
1145
+ if self.rank_deficient_action == "warn":
1146
+ col_info = _format_dropped_columns(dropped_cols)
1147
+ warnings.warn(
1148
+ f"Rank-deficient covariate design (control_key={control_key}): "
1149
+ f"dropped columns {col_info}. Rank {rank} < {X_ctrl.shape[1]}. "
1150
+ "Using minimum-norm least-squares solution.",
1151
+ UserWarning,
1152
+ stacklevel=2,
1153
+ )
1154
+ cho = None # Force lstsq path for ALL rank-deficient cases
1155
+ kept_cols = np.array(
1156
+ [i for i in range(X_ctrl.shape[1]) if i not in dropped_cols]
1157
+ )
1158
+ else:
1159
+ kept_cols = None # Full rank — use all columns
1160
+ with np.errstate(all="ignore"):
1161
+ XtX = X_ctrl.T @ X_ctrl
1162
+ try:
1163
+ cho = scipy_linalg.cho_factor(XtX)
1164
+ except np.linalg.LinAlgError:
1165
+ cho = None
1166
+
1167
+ # Process each (g, t) pair in this group
1168
+ for g, t, bp_val, base_col, post_col in tasks:
1169
+ treated_mask = cohort_masks[g]
1170
+
1171
+ # Recompute control mask for not_yet_treated (varies by g, t)
1172
+ if self.control_group == "not_yet_treated":
1173
+ # Controls must be untreated at both t and base period
1174
+ nyt_threshold = max(t, bp_val) + self.anticipation
1175
+ control_mask = never_treated_mask | (
1176
+ (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
1177
+ )
1178
+
1179
+ y_base = outcome_matrix[:, base_col]
1180
+ y_post = outcome_matrix[:, post_col]
1181
+ outcome_change = y_post - y_base
1182
+
1183
+ if is_balanced:
1184
+ valid_mask_pair = np.ones(len(y_base), dtype=bool)
1185
+ else:
1186
+ valid_mask_pair = ~(np.isnan(y_base) | np.isnan(y_post))
1187
+
1188
+ treated_valid = treated_mask & valid_mask_pair
1189
+ # For balanced + never_treated, control_valid is same as control_valid_base
1190
+ if is_balanced and self.control_group == "never_treated":
1191
+ control_valid = control_valid_base
1192
+ else:
1193
+ control_valid = control_mask & valid_mask_pair
1194
+
1195
+ n_t = int(np.sum(treated_valid))
1196
+ n_c = int(np.sum(control_valid))
1197
+
1198
+ if n_t == 0 or n_c == 0:
1199
+ skipped_empty_cell.append((g, t))
1200
+ continue
1201
+
1202
+ treated_change = outcome_change[treated_valid]
1203
+ control_change = outcome_change[control_valid]
1204
+
1205
+ X_treated_pair = cov_matrix[treated_valid]
1206
+ X_control_pair = cov_matrix[control_valid]
1207
+
1208
+ # Check for NaN in this pair's covariates
1209
+ if np.any(np.isnan(X_treated_pair)) or np.any(np.isnan(X_control_pair)):
1210
+ # Fall back to unconditional (difference in means)
1211
+ warnings.warn(
1212
+ f"Missing values in covariates for group {g}, time {t}. "
1213
+ "Falling back to unconditional estimation.",
1214
+ UserWarning,
1215
+ stacklevel=3,
1216
+ )
1217
+ att = float(np.mean(treated_change) - np.mean(control_change))
1218
+ var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0
1219
+ var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0
1220
+ se = float(np.sqrt(var_t / n_t + var_c / n_c))
1221
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
1222
+ inf_control = -(control_change - np.mean(control_change)) / n_c
1223
+ else:
1224
+ # Build per-pair X_ctrl if control_valid differs from base
1225
+ if is_balanced and self.control_group == "never_treated" and X_ctrl is not None:
1226
+ pair_X_ctrl = X_ctrl
1227
+ pair_n_c = n_c_base
1228
+ else:
1229
+ pair_X_ctrl = np.column_stack([np.ones(n_c), X_control_pair])
1230
+ pair_n_c = n_c
1231
+
1232
+ # Solve for beta
1233
+ beta = None
1234
+ with np.errstate(all="ignore"):
1235
+ if (
1236
+ cho is not None
1237
+ and is_balanced
1238
+ and self.control_group == "never_treated"
1239
+ ):
1240
+ # Use cached Cholesky
1241
+ Xty = pair_X_ctrl.T @ control_change
1242
+ beta = scipy_linalg.cho_solve(cho, Xty)
1243
+ else:
1244
+ # Compute per-pair Cholesky or lstsq fallback
1245
+ if kept_cols is not None:
1246
+ # Rank-deficient: skip Cholesky, use reduced lstsq
1247
+ pass
1248
+ else:
1249
+ pair_XtX = pair_X_ctrl.T @ pair_X_ctrl
1250
+ try:
1251
+ pair_cho = scipy_linalg.cho_factor(pair_XtX)
1252
+ Xty = pair_X_ctrl.T @ control_change
1253
+ beta = scipy_linalg.cho_solve(pair_cho, Xty)
1254
+ except np.linalg.LinAlgError:
1255
+ pass
1256
+
1257
+ if beta is None or np.any(~np.isfinite(beta)):
1258
+ if kept_cols is not None:
1259
+ # Reduced solve for rank-deficient design
1260
+ result = scipy_linalg.lstsq(
1261
+ pair_X_ctrl[:, kept_cols],
1262
+ control_change,
1263
+ cond=1e-07,
1264
+ )
1265
+ beta = np.zeros(pair_X_ctrl.shape[1])
1266
+ beta[kept_cols] = result[0]
1267
+ else:
1268
+ # Full-rank lstsq fallback (Cholesky numerical failure)
1269
+ result = scipy_linalg.lstsq(
1270
+ pair_X_ctrl,
1271
+ control_change,
1272
+ cond=1e-07,
1273
+ )
1274
+ beta = result[0]
1275
+
1276
+ nan_cell = False
1277
+
1278
+ if beta is None or np.any(~np.isfinite(beta)):
1279
+ nan_cell = True
1280
+ n_nan_cells += 1
1281
+
1282
+ if not nan_cell:
1283
+ X_treated_w_intercept = np.column_stack([np.ones(n_t), X_treated_pair])
1284
+ with np.errstate(all="ignore"):
1285
+ predicted_control = X_treated_w_intercept @ beta
1286
+ treated_residuals = treated_change - predicted_control
1287
+ if np.any(~np.isfinite(predicted_control)):
1288
+ nan_cell = True
1289
+ n_nan_cells += 1
1290
+
1291
+ if not nan_cell:
1292
+ att = float(np.mean(treated_residuals))
1293
+ with np.errstate(all="ignore"):
1294
+ residuals = control_change - pair_X_ctrl @ beta
1295
+ if np.any(~np.isfinite(residuals)):
1296
+ nan_cell = True
1297
+ n_nan_cells += 1
1298
+
1299
+ if nan_cell:
1300
+ att = np.nan
1301
+ se = np.nan
1302
+ inf_treated = np.zeros(n_t)
1303
+ inf_control = np.zeros(n_c)
1304
+ else:
1305
+ var_t = float(np.var(treated_residuals, ddof=1)) if n_t > 1 else 0.0
1306
+ var_c = float(np.var(residuals, ddof=1)) if pair_n_c > 1 else 0.0
1307
+ se = float(np.sqrt(var_t / n_t + var_c / pair_n_c))
1308
+ inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
1309
+ inf_control = -residuals / pair_n_c
1310
+
1311
+ group_time_effects[(g, t)] = {
1312
+ "effect": att,
1313
+ "se": se,
1314
+ "t_stat": np.nan,
1315
+ "p_value": np.nan,
1316
+ "conf_int": (np.nan, np.nan),
1317
+ "n_treated": n_t,
1318
+ "n_control": n_c,
1319
+ }
1320
+
1321
+ all_units = precomputed["all_units"]
1322
+ treated_positions = np.where(treated_valid)[0]
1323
+ control_positions = np.where(control_valid)[0]
1324
+ influence_func_info[(g, t)] = {
1325
+ "treated_idx": treated_positions,
1326
+ "control_idx": control_positions,
1327
+ "treated_units": all_units[treated_positions],
1328
+ "control_units": all_units[control_positions],
1329
+ "treated_inf": inf_treated,
1330
+ "control_inf": inf_control,
1331
+ }
1332
+
1333
+ atts.append(att)
1334
+ ses.append(se)
1335
+ task_keys.append((g, t))
1336
+
1337
+ if n_nan_cells > 0:
1338
+ warnings.warn(
1339
+ f"{n_nan_cells} group-time cell(s) have non-finite regression results "
1340
+ "(near-singular covariates). These cells are preserved with NaN inference.",
1341
+ UserWarning,
1342
+ stacklevel=2,
1343
+ )
1344
+
1345
+ # Batch inference
1346
+ if task_keys:
1347
+ # Use survey df for replicate designs (propagated from precomputed)
1348
+ _ipw_dr_df = precomputed.get("df_survey") if precomputed is not None else None
1349
+ # Guard: replicate design with undefined df → NaN inference
1350
+ if (
1351
+ _ipw_dr_df is None
1352
+ and precomputed is not None
1353
+ and precomputed.get("resolved_survey_unit") is not None
1354
+ and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance")
1355
+ and precomputed["resolved_survey_unit"].uses_replicate_variance
1356
+ ):
1357
+ _ipw_dr_df = 0
1358
+ t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
1359
+ np.array(atts), np.array(ses), alpha=self.alpha, df=_ipw_dr_df
1360
+ )
1361
+ for idx, key in enumerate(task_keys):
1362
+ group_time_effects[key]["t_stat"] = float(t_stats[idx])
1363
+ group_time_effects[key]["p_value"] = float(p_values[idx])
1364
+ group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx]))
1365
+
1366
+ skip_info = {
1367
+ "missing_period": skipped_missing_period,
1368
+ "empty_cell": skipped_empty_cell,
1369
+ }
1370
+ return group_time_effects, influence_func_info, skip_info
1371
+
1372
+ def fit(
1373
+ self,
1374
+ data: pd.DataFrame,
1375
+ outcome: str,
1376
+ unit: str,
1377
+ time: str,
1378
+ first_treat: str,
1379
+ covariates: Optional[List[str]] = None,
1380
+ aggregate: Optional[str] = None,
1381
+ balance_e: Optional[int] = None,
1382
+ survey_design: object = None,
1383
+ ) -> CallawaySantAnnaResults:
1384
+ """
1385
+ Fit the Callaway-Sant'Anna estimator.
1386
+
1387
+ Parameters
1388
+ ----------
1389
+ data : pd.DataFrame
1390
+ Panel data with unit and time identifiers. For repeated
1391
+ cross-sections (``panel=False``), each observation should
1392
+ have a unique unit ID — units do not repeat across periods.
1393
+ outcome : str
1394
+ Name of outcome variable column.
1395
+ unit : str
1396
+ Name of unit identifier column.
1397
+ time : str
1398
+ Name of time period column.
1399
+ first_treat : str
1400
+ Name of column indicating when unit was first treated.
1401
+ Use 0 (or np.inf) for never-treated units.
1402
+ covariates : list, optional
1403
+ List of covariate column names for conditional parallel trends.
1404
+ aggregate : str, optional
1405
+ How to aggregate group-time effects:
1406
+ - None: Only compute ATT(g,t) (default)
1407
+ - "simple": Simple weighted average (overall ATT)
1408
+ - "event_study": Aggregate by relative time (event study)
1409
+ - "group": Aggregate by treatment cohort
1410
+ - "all": Compute all aggregations
1411
+ balance_e : int, optional
1412
+ For event study, balance the panel at relative time e.
1413
+ Ensures all groups contribute to each relative period.
1414
+ survey_design : SurveyDesign, optional
1415
+ Survey design specification. Supports pweight with strata/PSU/FPC.
1416
+ Aggregated SEs (overall, event study, group) use design-based
1417
+ variance via compute_survey_if_variance(). All estimation methods
1418
+ (reg, ipw, dr) support covariates + survey. For repeated
1419
+ cross-sections (``panel=False``), survey weights are
1420
+ per-observation (no unit-level collapse).
1421
+
1422
+ Returns
1423
+ -------
1424
+ CallawaySantAnnaResults
1425
+ Object containing all estimation results.
1426
+
1427
+ Raises
1428
+ ------
1429
+ ValueError
1430
+ If required columns are missing or data validation fails.
1431
+ """
1432
+ # Validate pscore_trim (may have been changed via set_params)
1433
+ if not (0 < self.pscore_trim < 0.5):
1434
+ raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}")
1435
+
1436
+ # Reset stale state from prior fit (prevents leaking event-study VCV)
1437
+ self._event_study_vcov = None
1438
+
1439
+ if not self.panel:
1440
+ warnings.warn(
1441
+ "panel=False uses repeated cross-section DRDID estimators "
1442
+ "(Sant'Anna & Zhao 2020, Section 4) which assume stationary "
1443
+ "cross-sectional sampling: the population distribution of "
1444
+ "(Y, X, G) must be stable across periods. This assumption "
1445
+ "is not data-checkable.",
1446
+ UserWarning,
1447
+ stacklevel=2,
1448
+ )
1449
+
1450
+ # Validate unique unit IDs for panel=False
1451
+ if not self.panel:
1452
+ if data[unit].duplicated().any():
1453
+ raise ValueError(
1454
+ "panel=False requires unique unit IDs (one observation per unit). "
1455
+ "Found duplicate unit IDs. If your data is a panel, use panel=True."
1456
+ )
1457
+
1458
+ # Normalize empty covariates list to None
1459
+ if covariates is not None and len(covariates) == 0:
1460
+ covariates = None
1461
+
1462
+ # Resolve survey design if provided
1463
+ from diff_diff.survey import (
1464
+ _resolve_survey_for_fit,
1465
+ _validate_unit_constant_survey,
1466
+ )
1467
+
1468
+ resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
1469
+ _resolve_survey_for_fit(survey_design, data, "analytical")
1470
+ )
1471
+
1472
+ # Validate within-unit constancy for panel survey designs
1473
+ if resolved_survey is not None:
1474
+ if self.panel:
1475
+ _validate_unit_constant_survey(data, unit, survey_design)
1476
+ if resolved_survey.weight_type != "pweight":
1477
+ raise ValueError(
1478
+ f"CallawaySantAnna survey support requires weight_type='pweight', "
1479
+ f"got '{resolved_survey.weight_type}'. The survey variance math "
1480
+ f"assumes probability weights (pweight)."
1481
+ )
1482
+ # Note: strata/PSU/FPC are now supported — aggregated SEs use
1483
+ # compute_survey_if_variance() for design-based inference.
1484
+
1485
+ # Bootstrap + survey is now supported via PSU-level multiplier bootstrap.
1486
+
1487
+ # Validate inputs
1488
+ required_cols = [outcome, unit, time, first_treat]
1489
+ if covariates:
1490
+ required_cols.extend(covariates)
1491
+
1492
+ missing = [c for c in required_cols if c not in data.columns]
1493
+ if missing:
1494
+ raise ValueError(f"Missing columns: {missing}")
1495
+
1496
+ # Create working copy
1497
+ df = data.copy()
1498
+
1499
+ # Ensure numeric types
1500
+ df[time] = pd.to_numeric(df[time])
1501
+ df[first_treat] = pd.to_numeric(df[first_treat])
1502
+
1503
+ # Standardize the first_treat column name for internal use
1504
+ # This avoids hardcoding column names in internal methods
1505
+ df["first_treat"] = df[first_treat]
1506
+
1507
+ # Never-treated indicator (must precede treatment_groups to exclude np.inf)
1508
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
1509
+ # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
1510
+ _inf_mask = df[first_treat] == np.inf
1511
+ if _inf_mask.any():
1512
+ n_inf_units = df.loc[_inf_mask, unit].nunique()
1513
+ warnings.warn(
1514
+ f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 "
1515
+ f"(never-treated). Use first_treat=0 to suppress this warning.",
1516
+ UserWarning,
1517
+ stacklevel=2,
1518
+ )
1519
+ df.loc[_inf_mask, first_treat] = 0
1520
+
1521
+ # Identify groups and time periods
1522
+ time_periods = sorted(df[time].unique())
1523
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
1524
+
1525
+ if self.panel:
1526
+ # Panel: count unique units
1527
+ unit_info = (
1528
+ df.groupby(unit)
1529
+ .agg({first_treat: "first", "_never_treated": "first"})
1530
+ .reset_index()
1531
+ )
1532
+ n_treated_units = (unit_info[first_treat] > 0).sum()
1533
+ n_control_units = (unit_info["_never_treated"]).sum()
1534
+ else:
1535
+ # RCS: count observations per cohort (no unit tracking)
1536
+ n_treated_units = int((df[first_treat] > 0).sum())
1537
+ n_control_units = int(df["_never_treated"].sum())
1538
+
1539
+ if n_control_units == 0 and self.control_group == "never_treated":
1540
+ raise ValueError(
1541
+ "No never-treated units found. Check 'first_treat' column. "
1542
+ "Use control_group='not_yet_treated' if all units are eventually treated."
1543
+ )
1544
+ if n_control_units == 0 and self.control_group == "not_yet_treated":
1545
+ # With not_yet_treated, controls are units not yet treated at each
1546
+ # (g, t) pair — never-treated units are not required.
1547
+ if len(treatment_groups) < 2:
1548
+ raise ValueError(
1549
+ "not_yet_treated control group requires at least 2 treatment "
1550
+ "cohorts when there are no never-treated units."
1551
+ )
1552
+
1553
+ # Note: CallawaySantAnna supports survey weights, strata, PSU, and FPC.
1554
+ # Per-cell SEs use IF-based variance; aggregated SEs use design-based
1555
+ # variance via compute_survey_if_variance() or PSU-level bootstrap.
1556
+ # Pre-compute data structures for efficient ATT(g,t) computation
1557
+ if self.panel:
1558
+ precomputed = self._precompute_structures(
1559
+ df,
1560
+ outcome,
1561
+ unit,
1562
+ time,
1563
+ first_treat,
1564
+ covariates,
1565
+ time_periods,
1566
+ treatment_groups,
1567
+ resolved_survey=resolved_survey,
1568
+ )
1569
+ else:
1570
+ precomputed = self._precompute_structures_rc(
1571
+ df,
1572
+ outcome,
1573
+ unit,
1574
+ time,
1575
+ first_treat,
1576
+ covariates,
1577
+ time_periods,
1578
+ treatment_groups,
1579
+ resolved_survey=resolved_survey,
1580
+ )
1581
+
1582
+ # Recompute survey metadata from the unit-level resolved survey so
1583
+ # that n_psu and df_survey reflect the actual survey design (explicit
1584
+ # PSU/strata) rather than hard-coding n_units.
1585
+ if resolved_survey is not None and survey_metadata is not None:
1586
+ resolved_survey_unit = precomputed.get("resolved_survey_unit")
1587
+ if resolved_survey_unit is not None:
1588
+ from diff_diff.survey import compute_survey_metadata
1589
+
1590
+ unit_w = resolved_survey_unit.weights
1591
+ survey_metadata = compute_survey_metadata(resolved_survey_unit, unit_w)
1592
+
1593
+ # Survey df for safe_inference calls — use the unit-level resolved
1594
+ # survey df computed in _precompute_structures for consistency.
1595
+ df_survey = precomputed.get("df_survey")
1596
+ # Guard: replicate design with undefined df (rank <= 1) → NaN inference
1597
+ if (
1598
+ df_survey is None
1599
+ and resolved_survey is not None
1600
+ and hasattr(resolved_survey, "uses_replicate_variance")
1601
+ and resolved_survey.uses_replicate_variance
1602
+ ):
1603
+ df_survey = 0
1604
+
1605
+ # Compute ATT(g,t) for each group-time combination
1606
+ min_period = min(time_periods)
1607
+ has_survey = resolved_survey is not None
1608
+
1609
+ _skip_info = {"missing_period": [], "empty_cell": []}
1610
+ _n_skipped_other = 0
1611
+
1612
+ if not self.panel:
1613
+ # --- Repeated cross-section path ---
1614
+ # No vectorized/Cholesky fast paths (panel-only optimizations).
1615
+ # Loop using _compute_att_gt_rc() for each (g,t).
1616
+ group_time_effects = {}
1617
+ influence_func_info = {}
1618
+ epv_diagnostics = (
1619
+ {} if (covariates and self.estimation_method in ("ipw", "dr")) else None
1620
+ )
1621
+
1622
+ for g in treatment_groups:
1623
+ if self.base_period == "universal":
1624
+ universal_base = g - 1 - self.anticipation
1625
+ valid_periods = [t for t in time_periods if t != universal_base]
1626
+ else:
1627
+ valid_periods = [
1628
+ t for t in time_periods if t >= g - self.anticipation or t > min_period
1629
+ ]
1630
+
1631
+ for t in valid_periods:
1632
+ rc_result = self._compute_att_gt_rc(
1633
+ precomputed,
1634
+ g,
1635
+ t,
1636
+ covariates,
1637
+ epv_diagnostics=epv_diagnostics,
1638
+ )
1639
+ att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = rc_result[:6]
1640
+ agg_w = rc_result[6] if len(rc_result) > 6 else n_treat
1641
+
1642
+ if att_gt is not None:
1643
+ t_stat, p_val, ci = safe_inference(
1644
+ att_gt,
1645
+ se_gt,
1646
+ alpha=self.alpha,
1647
+ df=df_survey,
1648
+ )
1649
+
1650
+ gte_entry = {
1651
+ "effect": att_gt,
1652
+ "se": se_gt,
1653
+ "t_stat": t_stat,
1654
+ "p_value": p_val,
1655
+ "conf_int": ci,
1656
+ "n_treated": n_treat,
1657
+ "n_control": n_ctrl,
1658
+ "agg_weight": agg_w,
1659
+ }
1660
+ if sw_sum is not None:
1661
+ gte_entry["survey_weight_sum"] = sw_sum
1662
+ group_time_effects[(g, t)] = gte_entry
1663
+
1664
+ if inf_info is not None:
1665
+ influence_func_info[(g, t)] = inf_info
1666
+ else:
1667
+ _n_skipped_other += 1
1668
+
1669
+ elif covariates is None and self.estimation_method == "reg":
1670
+ # Fast vectorized path for the common no-covariates regression case
1671
+ group_time_effects, influence_func_info, _skip_info = (
1672
+ self._compute_all_att_gt_vectorized(
1673
+ precomputed, treatment_groups, time_periods, min_period
1674
+ )
1675
+ )
1676
+ epv_diagnostics = None # No logit in this path
1677
+ elif (
1678
+ covariates is not None
1679
+ and self.estimation_method == "reg"
1680
+ and self.rank_deficient_action != "error"
1681
+ and not has_survey # Cholesky cache uses X'X; survey needs X'WX
1682
+ ):
1683
+ # Optimized covariate regression path with Cholesky caching
1684
+ group_time_effects, influence_func_info, _skip_info = (
1685
+ self._compute_all_att_gt_covariate_reg(
1686
+ precomputed, treatment_groups, time_periods, min_period
1687
+ )
1688
+ )
1689
+ epv_diagnostics = None # No logit in this path
1690
+ else:
1691
+ # General path: IPW, DR, rank_deficient_action="error", or edge cases
1692
+ group_time_effects = {}
1693
+ influence_func_info = {}
1694
+
1695
+ # Propensity score cache for IPW/DR with covariates
1696
+ pscore_cache = {} if (covariates and self.estimation_method in ("ipw", "dr")) else None
1697
+ # Cholesky cache for DR outcome regression component
1698
+ # Skip cache when survey weights present (X'WX differs from X'X)
1699
+ cho_cache = (
1700
+ {}
1701
+ if (
1702
+ covariates
1703
+ and self.estimation_method == "dr"
1704
+ and self.rank_deficient_action != "error"
1705
+ and not has_survey
1706
+ )
1707
+ else None
1708
+ )
1709
+
1710
+ epv_diagnostics = (
1711
+ {} if (covariates and self.estimation_method in ("ipw", "dr")) else None
1712
+ )
1713
+
1714
+ for g in treatment_groups:
1715
+ if self.base_period == "universal":
1716
+ universal_base = g - 1 - self.anticipation
1717
+ valid_periods = [t for t in time_periods if t != universal_base]
1718
+ else:
1719
+ valid_periods = [
1720
+ t for t in time_periods if t >= g - self.anticipation or t > min_period
1721
+ ]
1722
+
1723
+ for t in valid_periods:
1724
+ att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_fast(
1725
+ precomputed,
1726
+ g,
1727
+ t,
1728
+ covariates,
1729
+ pscore_cache=pscore_cache,
1730
+ cho_cache=cho_cache,
1731
+ epv_diagnostics=epv_diagnostics,
1732
+ )
1733
+
1734
+ if att_gt is not None:
1735
+ t_stat, p_val, ci = safe_inference(
1736
+ att_gt,
1737
+ se_gt,
1738
+ alpha=self.alpha,
1739
+ df=df_survey,
1740
+ )
1741
+
1742
+ gte_entry = {
1743
+ "effect": att_gt,
1744
+ "se": se_gt,
1745
+ "t_stat": t_stat,
1746
+ "p_value": p_val,
1747
+ "conf_int": ci,
1748
+ "n_treated": n_treat,
1749
+ "n_control": n_ctrl,
1750
+ }
1751
+ if sw_sum is not None:
1752
+ gte_entry["survey_weight_sum"] = sw_sum
1753
+ group_time_effects[(g, t)] = gte_entry
1754
+
1755
+ if inf_info is not None:
1756
+ influence_func_info[(g, t)] = inf_info
1757
+ else:
1758
+ _n_skipped_other += 1
1759
+
1760
+ if not group_time_effects:
1761
+ raise ValueError(
1762
+ "Could not estimate any group-time effects. "
1763
+ "Check that data has sufficient observations."
1764
+ )
1765
+
1766
+ # Consolidated EPV summary warning
1767
+ if epv_diagnostics:
1768
+ low_epv = {k: v for k, v in epv_diagnostics.items() if v.get("is_low")}
1769
+ if low_epv:
1770
+ n_affected = len(low_epv)
1771
+ n_total = len(epv_diagnostics)
1772
+ min_entry = min(low_epv.values(), key=lambda v: v["epv"])
1773
+ min_g = min(low_epv.keys(), key=lambda k: low_epv[k]["epv"])
1774
+ warnings.warn(
1775
+ f"Low Events Per Variable (EPV) detected in propensity "
1776
+ f"score estimation for {n_affected} of {n_total} cell(s). "
1777
+ f"Minimum EPV = {min_entry['epv']:.1f} "
1778
+ f"(cohort g={min_g[0]}). "
1779
+ f"Consider estimation_method='reg' (avoids propensity "
1780
+ f"scores) or reducing the number of covariates. "
1781
+ f"See results.epv_summary() for details.",
1782
+ UserWarning,
1783
+ stacklevel=2,
1784
+ )
1785
+
1786
+ # Consolidated (g,t) cell skip warning (all paths)
1787
+ _n_missing = len(_skip_info.get("missing_period", []))
1788
+ _n_empty = len(_skip_info.get("empty_cell", []))
1789
+ _n_total_skipped = _n_missing + _n_empty + _n_skipped_other
1790
+ if _n_total_skipped > 0:
1791
+ _parts = []
1792
+ if _n_missing:
1793
+ _parts.append(
1794
+ f"{_n_missing} due to missing base/post period " f"in panel structure"
1795
+ )
1796
+ if _n_empty:
1797
+ _parts.append(f"{_n_empty} due to zero treated or control " f"observations")
1798
+ if _n_skipped_other:
1799
+ _parts.append(
1800
+ f"{_n_skipped_other} due to insufficient data or " f"non-estimable cells"
1801
+ )
1802
+ warnings.warn(
1803
+ f"{_n_total_skipped} (group, time) cell(s) could not be "
1804
+ f"estimated: {'; '.join(_parts)}.",
1805
+ UserWarning,
1806
+ stacklevel=2,
1807
+ )
1808
+
1809
+ # Compute overall ATT (simple aggregation)
1810
+ overall_att, overall_se, overall_effective_df = self._aggregate_simple(
1811
+ group_time_effects, influence_func_info, df, unit, precomputed
1812
+ )
1813
+ # Use per-statistic effective df from replicate aggregation if available;
1814
+ # otherwise fall back to the original df from the survey design.
1815
+ if overall_effective_df is not None:
1816
+ df_survey = overall_effective_df
1817
+ # Propagate to survey_metadata for display consistency
1818
+ if survey_metadata is not None:
1819
+ survey_metadata.df_survey = df_survey
1820
+ # Guard: replicate design with undefined df (rank <= 1) → NaN inference
1821
+ if (
1822
+ df_survey is None
1823
+ and resolved_survey is not None
1824
+ and hasattr(resolved_survey, "uses_replicate_variance")
1825
+ and resolved_survey.uses_replicate_variance
1826
+ ):
1827
+ df_survey = 0
1828
+ overall_t, overall_p, overall_ci = safe_inference(
1829
+ overall_att,
1830
+ overall_se,
1831
+ alpha=self.alpha,
1832
+ df=df_survey,
1833
+ )
1834
+
1835
+ # Compute additional aggregations if requested
1836
+ event_study_effects = None
1837
+ group_effects = None
1838
+
1839
+ if aggregate in ["event_study", "all"]:
1840
+ event_study_effects = self._aggregate_event_study(
1841
+ group_time_effects,
1842
+ influence_func_info,
1843
+ treatment_groups,
1844
+ time_periods,
1845
+ balance_e,
1846
+ df,
1847
+ unit,
1848
+ precomputed,
1849
+ )
1850
+
1851
+ if aggregate in ["group", "all"]:
1852
+ group_effects = self._aggregate_by_group(
1853
+ group_time_effects,
1854
+ influence_func_info,
1855
+ treatment_groups,
1856
+ precomputed=precomputed,
1857
+ df=df,
1858
+ unit=unit,
1859
+ )
1860
+
1861
+ # Reject replicate-weight designs for bootstrap — replicate variance
1862
+ # is an analytical alternative, not compatible with bootstrap
1863
+ if (
1864
+ self.n_bootstrap > 0
1865
+ and resolved_survey is not None
1866
+ and hasattr(resolved_survey, "uses_replicate_variance")
1867
+ and resolved_survey.uses_replicate_variance
1868
+ ):
1869
+ raise NotImplementedError(
1870
+ "CallawaySantAnna bootstrap (n_bootstrap > 0) is not supported "
1871
+ "with replicate-weight survey designs. Replicate weights provide "
1872
+ "analytical variance; use n_bootstrap=0 instead."
1873
+ )
1874
+
1875
+ # Run bootstrap inference if requested
1876
+ bootstrap_results = None
1877
+ if self.n_bootstrap > 0 and influence_func_info:
1878
+ bootstrap_results = self._run_multiplier_bootstrap(
1879
+ group_time_effects=group_time_effects,
1880
+ influence_func_info=influence_func_info,
1881
+ aggregate=aggregate,
1882
+ balance_e=balance_e,
1883
+ treatment_groups=treatment_groups,
1884
+ time_periods=time_periods,
1885
+ df=df,
1886
+ unit=unit,
1887
+ precomputed=precomputed,
1888
+ cband=self.cband,
1889
+ )
1890
+
1891
+ # Update estimates with bootstrap inference
1892
+ overall_se = bootstrap_results.overall_att_se
1893
+ overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
1894
+ overall_p = bootstrap_results.overall_att_p_value
1895
+ overall_ci = bootstrap_results.overall_att_ci
1896
+
1897
+ # Update group-time effects with bootstrap SEs (batched)
1898
+ gt_keys = [gt for gt in group_time_effects if gt in bootstrap_results.group_time_ses]
1899
+ if gt_keys:
1900
+ gt_effects_arr = np.array(
1901
+ [float(group_time_effects[gt]["effect"]) for gt in gt_keys]
1902
+ )
1903
+ gt_ses_arr = np.array(
1904
+ [float(bootstrap_results.group_time_ses[gt]) for gt in gt_keys]
1905
+ )
1906
+ gt_t_stats, _, _, _ = safe_inference_batch(
1907
+ gt_effects_arr, gt_ses_arr, alpha=self.alpha
1908
+ )
1909
+ for idx, gt in enumerate(gt_keys):
1910
+ group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt]
1911
+ group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt]
1912
+ group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt]
1913
+ group_time_effects[gt]["t_stat"] = float(gt_t_stats[idx])
1914
+
1915
+ # Update event study effects with bootstrap SEs (batched)
1916
+ if (
1917
+ event_study_effects is not None
1918
+ and bootstrap_results.event_study_ses is not None
1919
+ and bootstrap_results.event_study_cis is not None
1920
+ and bootstrap_results.event_study_p_values is not None
1921
+ ):
1922
+ es_keys = [e for e in event_study_effects if e in bootstrap_results.event_study_ses]
1923
+ if es_keys:
1924
+ es_effects_arr = np.array(
1925
+ [float(event_study_effects[e]["effect"]) for e in es_keys]
1926
+ )
1927
+ es_ses_arr = np.array(
1928
+ [float(bootstrap_results.event_study_ses[e]) for e in es_keys]
1929
+ )
1930
+ es_t_stats, _, _, _ = safe_inference_batch(
1931
+ es_effects_arr, es_ses_arr, alpha=self.alpha
1932
+ )
1933
+ for idx, e in enumerate(es_keys):
1934
+ event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
1935
+ event_study_effects[e]["conf_int"] = bootstrap_results.event_study_cis[e]
1936
+ event_study_effects[e]["p_value"] = bootstrap_results.event_study_p_values[
1937
+ e
1938
+ ]
1939
+ event_study_effects[e]["t_stat"] = float(es_t_stats[idx])
1940
+
1941
+ # Update group effects with bootstrap SEs (batched)
1942
+ if (
1943
+ group_effects is not None
1944
+ and bootstrap_results.group_effect_ses is not None
1945
+ and bootstrap_results.group_effect_cis is not None
1946
+ and bootstrap_results.group_effect_p_values is not None
1947
+ ):
1948
+ grp_keys = [g for g in group_effects if g in bootstrap_results.group_effect_ses]
1949
+ if grp_keys:
1950
+ grp_effects_arr = np.array(
1951
+ [float(group_effects[g]["effect"]) for g in grp_keys]
1952
+ )
1953
+ grp_ses_arr = np.array(
1954
+ [float(bootstrap_results.group_effect_ses[g]) for g in grp_keys]
1955
+ )
1956
+ grp_t_stats, _, _, _ = safe_inference_batch(
1957
+ grp_effects_arr, grp_ses_arr, alpha=self.alpha
1958
+ )
1959
+ for idx, g in enumerate(grp_keys):
1960
+ group_effects[g]["se"] = bootstrap_results.group_effect_ses[g]
1961
+ group_effects[g]["conf_int"] = bootstrap_results.group_effect_cis[g]
1962
+ group_effects[g]["p_value"] = bootstrap_results.group_effect_p_values[g]
1963
+ group_effects[g]["t_stat"] = float(grp_t_stats[idx])
1964
+
1965
+ # Compute simultaneous confidence band CIs if cband is available
1966
+ cband_crit_value = None
1967
+ if bootstrap_results is not None:
1968
+ cband_crit_value = bootstrap_results.cband_crit_value
1969
+
1970
+ if cband_crit_value is not None and event_study_effects is not None:
1971
+ for e, eff_data in event_study_effects.items():
1972
+ se_val = eff_data["se"]
1973
+ if np.isfinite(se_val) and se_val > 0:
1974
+ eff_data["cband_conf_int"] = (
1975
+ eff_data["effect"] - cband_crit_value * se_val,
1976
+ eff_data["effect"] + cband_crit_value * se_val,
1977
+ )
1978
+
1979
+ # Store results
1980
+ # Retrieve event-study VCV from aggregation mixin (Phase 7d).
1981
+ # Clear it when bootstrap overwrites event-study SEs to prevent
1982
+ # HonestDiD from mixing analytical VCV with bootstrap SEs.
1983
+ event_study_vcov = getattr(self, "_event_study_vcov", None)
1984
+ event_study_vcov_index = getattr(self, "_event_study_vcov_index", None)
1985
+ if bootstrap_results is not None and event_study_vcov is not None:
1986
+ event_study_vcov = None
1987
+ event_study_vcov_index = None
1988
+
1989
+ self.results_ = CallawaySantAnnaResults(
1990
+ group_time_effects=group_time_effects,
1991
+ overall_att=overall_att,
1992
+ overall_se=overall_se,
1993
+ overall_t_stat=overall_t,
1994
+ overall_p_value=overall_p,
1995
+ overall_conf_int=overall_ci,
1996
+ groups=treatment_groups,
1997
+ time_periods=time_periods,
1998
+ n_obs=len(df),
1999
+ n_treated_units=n_treated_units,
2000
+ n_control_units=n_control_units,
2001
+ alpha=self.alpha,
2002
+ control_group=self.control_group,
2003
+ base_period=self.base_period,
2004
+ event_study_effects=event_study_effects,
2005
+ group_effects=group_effects,
2006
+ bootstrap_results=bootstrap_results,
2007
+ cband_crit_value=cband_crit_value,
2008
+ pscore_trim=self.pscore_trim,
2009
+ survey_metadata=survey_metadata,
2010
+ event_study_vcov=event_study_vcov,
2011
+ event_study_vcov_index=event_study_vcov_index,
2012
+ panel=self.panel,
2013
+ epv_diagnostics=epv_diagnostics if epv_diagnostics else None,
2014
+ epv_threshold=self.epv_threshold,
2015
+ pscore_fallback=self.pscore_fallback,
2016
+ )
2017
+
2018
+ self.is_fitted_ = True
2019
+ return self.results_
2020
+
2021
def _outcome_regression(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
    sw_treated: Optional[np.ndarray] = None,
    sw_control: Optional[np.ndarray] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using outcome regression.

    With covariates:
        1. Regress outcome changes on covariates for control group
        2. Predict counterfactual for treated using their covariates
        3. ATT = mean(treated_change) - mean(predicted_counterfactual)

    Without covariates:
        Simple difference in means.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes for treated and control units.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices for treated and control units. An intercept
        column is prepended here for prediction. The covariate path is
        used only when both are given and have at least one column.
    sw_treated, sw_control : np.ndarray, optional
        Survey weights for treated and control units.

    Returns
    -------
    att : float
        Estimated ATT.
    se : float
        Standard error (0.0 when either group is empty).
    inf_func : np.ndarray
        Influence function: treated entries first, then control entries.
    """
    n_t = len(treated_change)
    n_c = len(control_change)

    if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
        # Covariate-adjusted outcome regression
        # Fit regression on control units: E[Delta Y | X, D=0]
        beta, residuals = _linear_regression(
            X_control,
            control_change,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_control,
        )

        # Zero NaN coefficients for prediction (dropped rank-deficient columns
        # contribute 0 to the column space projection, matching DR path convention)
        beta = np.where(np.isfinite(beta), beta, 0.0)

        # Predict counterfactual for treated units
        X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
        predicted_control = np.dot(X_treated_with_intercept, beta)

        # ATT: (survey-weighted) mean of treated residuals
        treated_residuals = treated_change - predicted_control

        if sw_treated is not None:
            sw_t_sum = float(np.sum(sw_treated))
            sw_c_sum = float(np.sum(sw_control))
            sw_t_norm = sw_treated / sw_t_sum
            sw_c_norm = sw_control / sw_c_sum
            att = float(np.sum(sw_t_norm * treated_residuals))

            # Survey-weighted OR influence function.
            # Mirrors unweighted: inf_treated = (resid-ATT)/n_t,
            # inf_control = -resid/n_c. Survey: w_i/sum(w_group).
            X_c_int = np.column_stack([np.ones(n_c), X_control])
            resid_c = control_change - np.dot(X_c_int, beta)

            inf_treated = sw_t_norm * (treated_residuals - att)
            inf_control = -sw_c_norm * resid_c
            inf_func = np.concatenate([inf_treated, inf_control])

            # SE from the influence function: sum(w_norm^2 * r^2) is the
            # weighted analogue of var/n (with equal weights it reduces to
            # var_pop/n). A plain sum(w_norm * r^2) is a variance of the
            # outcome, not of the mean, and would overstate the SE by ~sqrt(n).
            se = (
                float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
                if (n_t > 0 and n_c > 0)
                else 0.0
            )
        else:
            att = float(np.mean(treated_residuals))

            # Standard error using sandwich estimator
            var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0
            se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function
            inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
            inf_control = -residuals / n_c
            inf_func = np.concatenate([inf_treated, inf_control])
    else:
        # Simple difference in means (no covariates)
        if sw_treated is not None:
            sw_t_norm = sw_treated / np.sum(sw_treated)
            sw_c_norm = sw_control / np.sum(sw_control)
            mu_t = float(np.sum(sw_t_norm * treated_change))
            mu_c = float(np.sum(sw_c_norm * control_change))
            att = mu_t - mu_c

            # Influence function (survey-weighted)
            inf_treated = sw_t_norm * (treated_change - mu_t)
            inf_control = -sw_c_norm * (control_change - mu_c)
            inf_func = np.concatenate([inf_treated, inf_control])

            # SE from influence function variance
            se = (
                float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
                if (n_t > 0 and n_c > 0)
                else 0.0
            )
        else:
            att = float(np.mean(treated_change) - np.mean(control_change))

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
            se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function (for aggregation)
            inf_treated = treated_change - np.mean(treated_change)
            inf_control = control_change - np.mean(control_change)
            inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])

    return att, se, inf_func
2137
+
2138
def _ipw_estimation(
    self,
    treated_change: np.ndarray,
    control_change: np.ndarray,
    n_treated: int,
    n_control: int,
    X_treated: Optional[np.ndarray] = None,
    X_control: Optional[np.ndarray] = None,
    pscore_cache: Optional[Dict] = None,
    pscore_key: Optional[Any] = None,
    sw_treated: Optional[np.ndarray] = None,
    sw_control: Optional[np.ndarray] = None,
    sw_all: Optional[np.ndarray] = None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
) -> Tuple[float, float, np.ndarray]:
    """
    Estimate ATT using inverse probability weighting.

    With covariates:
        1. Estimate propensity score P(D=1|X) using logistic regression
        2. Reweight control units to match treated covariate distribution
        3. ATT = mean(treated) - weighted_mean(control)

    Without covariates:
        Simple difference in means with unconditional propensity weighting.

    Parameters
    ----------
    treated_change, control_change : np.ndarray
        Outcome changes for treated and control units.
    n_treated, n_control : int
        Group sizes used for the unconditional propensity in the
        no-covariate, unweighted branch.
    X_treated, X_control : np.ndarray, optional
        Covariate matrices; an intercept column is prepended here. The
        covariate path is used only when both are given and have at
        least one column.
    pscore_cache, pscore_key : optional
        Cache of ``(logit coefficients, EPV diagnostics)`` keyed by
        ``pscore_key``. A cached entry is reused instead of refitting;
        a successful fit is stored back (non-finite coefficients zeroed).
    sw_treated, sw_control, sw_all : np.ndarray, optional
        Survey weights for treated, control, and all units.
    context_label : str
        Label included in propensity-score warnings.
    epv_diagnostics_out : dict, optional
        Updated in place with EPV diagnostics from the fit or the cache.

    Returns
    -------
    att : float
        Estimated ATT.
    se : float
        Standard error (0.0 when it cannot be computed).
    inf_func : np.ndarray
        Influence function: treated entries first, then control entries.

    Raises
    ------
    np.linalg.LinAlgError, ValueError
        Re-raised from propensity-score estimation when
        ``pscore_fallback == "error"`` or ``rank_deficient_action == "error"``.
    """
    n_t = len(treated_change)
    n_c = len(control_change)
    n_total = n_treated + n_control

    if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
        # Covariate-adjusted IPW estimation
        ps_fallback_used = False
        # Check propensity score cache
        cached_pscore = None
        if pscore_cache is not None and pscore_key is not None:
            cached_pscore = pscore_cache.get(pscore_key)

        if cached_pscore is not None:
            # Use cached propensity scores (beta coefficients + EPV diag)
            beta_logistic, cached_diag = cached_pscore
            X_all = np.vstack([X_treated, X_control])
            X_all_with_intercept = np.column_stack([np.ones(n_t + n_c), X_all])
            z = np.dot(X_all_with_intercept, beta_logistic)
            # Clip the linear predictor so np.exp cannot overflow
            z = np.clip(z, -500, 500)
            pscore = 1 / (1 + np.exp(-z))
            if epv_diagnostics_out is not None and cached_diag:
                epv_diagnostics_out.update(cached_diag)
        else:
            # Stack covariates and create treatment indicator
            X_all = np.vstack([X_treated, X_control])
            D = np.concatenate([np.ones(n_t), np.zeros(n_c)])

            # Estimate propensity scores using IRLS logistic regression
            diag = {}
            try:
                beta_logistic, pscore = solve_logit(
                    X_all,
                    D,
                    rank_deficient_action=self.rank_deficient_action,
                    weights=sw_all,
                    epv_threshold=self.epv_threshold,
                    context_label=context_label,
                    diagnostics_out=diag,
                )
                _check_propensity_diagnostics(pscore, self.pscore_trim)
                # Cache the fitted coefficients (zero-fill NaN from
                # dropped rank-deficient columns to prevent NaN
                # propagation on cache reuse) alongside EPV diagnostics
                if pscore_cache is not None and pscore_key is not None:
                    beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0)
                    pscore_cache[pscore_key] = (beta_clean, diag)
            except (np.linalg.LinAlgError, ValueError):
                if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
                    raise
                # Fallback to unconditional if logistic regression fails
                ctx = f" for {context_label}" if context_label else ""
                warnings.warn(
                    f"Propensity score estimation failed{ctx}. "
                    f"Falling back to unconditional propensity "
                    f"(all covariates dropped for this cell). "
                    f"Consider estimation_method='reg' to avoid "
                    f"propensity scores entirely.",
                    UserWarning,
                    stacklevel=4,
                )
                if sw_all is not None:
                    pos = sw_all > 0
                    p_uc = float(np.average(D[pos], weights=sw_all[pos]))
                else:
                    p_uc = n_t / (n_t + n_c)
                pscore = np.full(len(D), p_uc)
                ps_fallback_used = True
            if epv_diagnostics_out is not None and diag:
                epv_diagnostics_out.update(diag)

        # Propensity scores for treated and control
        pscore_treated = pscore[:n_t]
        pscore_control = pscore[n_t:]

        # Clip propensity scores to avoid extreme weights
        pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim)
        pscore_treated = np.clip(pscore_treated, self.pscore_trim, 1 - self.pscore_trim)

        if sw_treated is not None:
            # IPW weights compose with survey weights:
            # w_i = sw_i * p(X_i) / (1 - p(X_i))
            weights_control = sw_control * pscore_control / (1 - pscore_control)
            weights_control_norm = weights_control / np.sum(weights_control)

            # ATT: survey-weighted treated mean minus composite-weighted control mean
            sw_t_norm = sw_treated / np.sum(sw_treated)
            mu_t = float(np.sum(sw_t_norm * treated_change))
            att = mu_t - float(np.sum(weights_control_norm * control_change))

            # Influence function (survey-weighted)
            inf_treated = sw_t_norm * (treated_change - mu_t)
            inf_control = -weights_control_norm * (
                control_change - np.sum(weights_control_norm * control_change)
            )
            inf_func = np.concatenate([inf_treated, inf_control])

            if not ps_fallback_used:
                # Propensity score IF correction
                # Accounts for estimation uncertainty in logistic regression coefficients
                X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])
                pscore_all = np.concatenate([pscore_treated, pscore_control])

                # PS IF correction — compute in R's psi convention, convert to phi
                n_all_panel = n_t + n_c
                W_ps = pscore_all * (1 - pscore_all)
                if sw_all is not None:
                    W_ps = W_ps * sw_all
                # R: Hessian.ps = crossprod(X * sqrt(W)) / n
                H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
                H_psi_inv = _safe_inv(H_psi)

                D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
                score_ps = (D_all - pscore_all)[:, None] * X_all_int
                if sw_all is not None:
                    score_ps = score_ps * sw_all[:, None]
                # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs)
                asy_lin_rep_psi = score_ps @ H_psi_inv

                att_control_weighted = np.sum(weights_control_norm * control_change)
                # R: M2 = colMeans(w.cont * (y - att) * X) / mean(w.cont)
                # np.sum (not mean): subset sum with normalized weights matches
                # R's full-sample colMeans/mean(w) after cancellation
                M2 = np.sum(
                    (weights_control_norm * (control_change - att_control_weighted))[:, None]
                    * X_all_int[n_t:],
                    axis=0,
                )

                # psi-scale correction, convert to phi for storage
                # Subtract: R adds PS correction to inf.control, then att = treat - control
                inf_func = inf_func - (asy_lin_rep_psi @ M2) / n_all_panel

            # SE from influence function variance
            var_psi = np.sum(inf_func**2)
            se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
        else:
            # IPW weights for control units: p(X) / (1 - p(X))
            # This reweights controls to have same covariate distribution as treated
            weights_control = pscore_control / (1 - pscore_control)
            weights_control = weights_control / np.sum(weights_control)  # normalize

            # ATT = mean(treated) - weighted_mean(control)
            att = float(np.mean(treated_change) - np.sum(weights_control * control_change))

            # Compute standard error
            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0

            # Weighted variance of the control outcome changes (a variance of
            # the outcome, not of the weighted mean — hence the / n_c below,
            # mirroring var_t / n_t).
            weighted_var_c = np.sum(
                weights_control
                * (control_change - np.sum(weights_control * control_change)) ** 2
            )

            se = (
                float(np.sqrt(var_t / n_t + weighted_var_c / n_c))
                if (n_t > 0 and n_c > 0)
                else 0.0
            )

            # Influence function
            # NOTE(review): unlike the survey branch above, no propensity-score
            # estimation correction is applied to inf_func here — confirm intended.
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = -weights_control * (
                control_change - np.sum(weights_control * control_change)
            )
            inf_func = np.concatenate([inf_treated, inf_control])
    else:
        # Unconditional IPW (reduces to difference in means)
        if sw_treated is not None:
            # Survey-weighted difference in means
            sw_t_norm = sw_treated / np.sum(sw_treated)
            sw_c_norm = sw_control / np.sum(sw_control)
            mu_t = float(np.sum(sw_t_norm * treated_change))
            mu_c = float(np.sum(sw_c_norm * control_change))
            att = mu_t - mu_c

            inf_treated = sw_t_norm * (treated_change - mu_t)
            inf_control = -sw_c_norm * (control_change - mu_c)
            inf_func = np.concatenate([inf_treated, inf_control])

            se = (
                float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
                if (n_t > 0 and n_c > 0)
                else 0.0
            )
        else:
            # Unconditional propensity score; guard against empty totals so the
            # p_treat > 0 check below (not a ZeroDivisionError) handles it.
            p_treat = n_treated / n_total if n_total > 0 else 0.0

            att = float(np.mean(treated_change) - np.mean(control_change))

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

            # Adjusted variance for IPW
            se = float(
                np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat))
                if (n_t > 0 and n_c > 0 and p_treat > 0)
                else 0.0
            )

            # Influence function (for aggregation)
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = (control_change - np.mean(control_change)) / n_c
            inf_func = np.concatenate([inf_treated, -inf_control])

    return att, se, inf_func
2370
+
2371
+ def _doubly_robust(
2372
+ self,
2373
+ treated_change: np.ndarray,
2374
+ control_change: np.ndarray,
2375
+ X_treated: Optional[np.ndarray] = None,
2376
+ X_control: Optional[np.ndarray] = None,
2377
+ pscore_cache: Optional[Dict] = None,
2378
+ pscore_key: Optional[Any] = None,
2379
+ cho_cache: Optional[Dict] = None,
2380
+ cho_key: Optional[Any] = None,
2381
+ sw_treated: Optional[np.ndarray] = None,
2382
+ sw_control: Optional[np.ndarray] = None,
2383
+ sw_all: Optional[np.ndarray] = None,
2384
+ context_label: str = "",
2385
+ epv_diagnostics_out: Optional[dict] = None,
2386
+ ) -> Tuple[float, float, np.ndarray]:
2387
+ """
2388
+ Estimate ATT using doubly robust estimation.
2389
+
2390
+ With covariates:
2391
+ Combines outcome regression and IPW for double robustness.
2392
+ The estimator is consistent if either the outcome model OR
2393
+ the propensity model is correctly specified.
2394
+
2395
+ ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
2396
+ + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]
2397
+
2398
+ where m(X) is the outcome model and w_i are IPW weights.
2399
+
2400
+ Without covariates:
2401
+ Reduces to simple difference in means.
2402
+
2403
+ Parameters
2404
+ ----------
2405
+ sw_treated, sw_control, sw_all : np.ndarray, optional
2406
+ Survey weights for treated, control, and all units.
2407
+ """
2408
+ n_t = len(treated_change)
2409
+ n_c = len(control_change)
2410
+
2411
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
2412
+ # Doubly robust estimation with covariates
2413
+ ps_fallback_used = False
2414
+ # Step 1: Outcome regression - fit E[Delta Y | X] on control
2415
+ # Try Cholesky cache for outcome regression (disabled when survey weights present)
2416
+ beta = None
2417
+ X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
2418
+ if cho_cache is not None and cho_key is not None:
2419
+ cached_cho = cho_cache.get(cho_key)
2420
+
2421
+ if cached_cho is False:
2422
+ # Rank-deficient sentinel: skip Cholesky, fall through
2423
+ pass
2424
+ elif cached_cho is not None:
2425
+ Xty = X_control_with_intercept.T @ control_change
2426
+ beta = scipy_linalg.cho_solve(cached_cho, Xty)
2427
+ if np.any(~np.isfinite(beta)):
2428
+ beta = None
2429
+ else:
2430
+ # First time for this cho_key: check rank before Cholesky
2431
+ rank_info = _detect_rank_deficiency(X_control_with_intercept)
2432
+ if len(rank_info[1]) > 0:
2433
+ cho_cache[cho_key] = False # Sentinel
2434
+ else:
2435
+ XtX = X_control_with_intercept.T @ X_control_with_intercept
2436
+ try:
2437
+ cho_factor = scipy_linalg.cho_factor(XtX)
2438
+ cho_cache[cho_key] = cho_factor
2439
+ Xty = X_control_with_intercept.T @ control_change
2440
+ beta = scipy_linalg.cho_solve(cho_factor, Xty)
2441
+ if np.any(~np.isfinite(beta)):
2442
+ beta = None
2443
+ except np.linalg.LinAlgError:
2444
+ pass
2445
+
2446
+ if beta is None:
2447
+ beta, _ = _linear_regression(
2448
+ X_control,
2449
+ control_change,
2450
+ rank_deficient_action=self.rank_deficient_action,
2451
+ weights=sw_control,
2452
+ )
2453
+ # Zero NaN coefficients for prediction only — dropped columns
2454
+ # contribute 0 to the column space projection. Note: solve_ols
2455
+ # deliberately uses NaN (R's lm() convention) for inference, but
2456
+ # here we only need beta for prediction (m_treated, m_control).
2457
+ beta = np.where(np.isfinite(beta), beta, 0.0)
2458
+
2459
+ # Predict counterfactual for both treated and control
2460
+ X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
2461
+ m_treated = np.dot(X_treated_with_intercept, beta)
2462
+ m_control = np.dot(X_control_with_intercept, beta)
2463
+
2464
+ # Step 2: Propensity score estimation
2465
+ # Check propensity score cache
2466
+ cached_pscore = None
2467
+ if pscore_cache is not None and pscore_key is not None:
2468
+ cached_pscore = pscore_cache.get(pscore_key)
2469
+
2470
+ if cached_pscore is not None:
2471
+ beta_logistic, cached_diag = cached_pscore
2472
+ X_all = np.vstack([X_treated, X_control])
2473
+ X_all_with_intercept = np.column_stack([np.ones(n_t + n_c), X_all])
2474
+ z = np.dot(X_all_with_intercept, beta_logistic)
2475
+ z = np.clip(z, -500, 500)
2476
+ pscore = 1 / (1 + np.exp(-z))
2477
+ if epv_diagnostics_out is not None and cached_diag:
2478
+ epv_diagnostics_out.update(cached_diag)
2479
+ else:
2480
+ X_all = np.vstack([X_treated, X_control])
2481
+ D = np.concatenate([np.ones(n_t), np.zeros(n_c)])
2482
+
2483
+ diag = {}
2484
+ try:
2485
+ beta_logistic, pscore = solve_logit(
2486
+ X_all,
2487
+ D,
2488
+ rank_deficient_action=self.rank_deficient_action,
2489
+ weights=sw_all,
2490
+ epv_threshold=self.epv_threshold,
2491
+ context_label=context_label,
2492
+ diagnostics_out=diag,
2493
+ )
2494
+ _check_propensity_diagnostics(pscore, self.pscore_trim)
2495
+ if pscore_cache is not None and pscore_key is not None:
2496
+ beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0)
2497
+ pscore_cache[pscore_key] = (beta_clean, diag)
2498
+ except (np.linalg.LinAlgError, ValueError):
2499
+ if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
2500
+ raise
2501
+ # Fallback to unconditional if logistic regression fails
2502
+ ctx = f" for {context_label}" if context_label else ""
2503
+ warnings.warn(
2504
+ f"Propensity score estimation failed{ctx}. "
2505
+ f"Falling back to unconditional propensity "
2506
+ f"(propensity model ignores covariates; outcome "
2507
+ f"regression still uses them). "
2508
+ f"Consider estimation_method='reg' to avoid "
2509
+ f"propensity scores entirely.",
2510
+ UserWarning,
2511
+ stacklevel=4,
2512
+ )
2513
+ if sw_all is not None:
2514
+ pos = sw_all > 0
2515
+ p_uc = float(np.average(D[pos], weights=sw_all[pos]))
2516
+ else:
2517
+ p_uc = n_t / (n_t + n_c)
2518
+ pscore = np.full(len(D), p_uc)
2519
+ ps_fallback_used = True
2520
+ if epv_diagnostics_out is not None and diag:
2521
+ epv_diagnostics_out.update(diag)
2522
+
2523
+ pscore_control = pscore[n_t:]
2524
+
2525
+ # Clip propensity scores
2526
+ pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim)
2527
+
2528
+ if sw_treated is not None:
2529
+ # IPW weights compose with survey weights
2530
+ weights_control = sw_control * pscore_control / (1 - pscore_control)
2531
+
2532
+ # Step 3: DR ATT (survey-weighted)
2533
+ sw_t_sum = np.sum(sw_treated)
2534
+ att_treated_part = float(
2535
+ np.sum(sw_treated * (treated_change - m_treated)) / sw_t_sum
2536
+ )
2537
+ augmentation = float(
2538
+ np.sum(weights_control * (m_control - control_change)) / sw_t_sum
2539
+ )
2540
+ att = att_treated_part + augmentation
2541
+
2542
+ # Step 4: Influence function (survey-weighted DR)
2543
+ # Start with plug-in IF, then add nuisance parameter corrections
2544
+ # (Sant'Anna & Zhao 2020, Theorem 3.1)
2545
+ psi_treated = (sw_treated / sw_t_sum) * (treated_change - m_treated - att)
2546
+ psi_control = (weights_control / sw_t_sum) * (m_control - control_change)
2547
+ inf_func = np.concatenate([psi_treated, psi_control])
2548
+
2549
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
2550
+ if not ps_fallback_used:
2551
+ # --- PS IF correction (mirrors IPW L1929-1961) ---
2552
+ # Accounts for propensity score estimation uncertainty
2553
+ X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])
2554
+ pscore_treated_clipped = np.clip(
2555
+ pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim
2556
+ )
2557
+ pscore_all = np.concatenate([pscore_treated_clipped, pscore_control])
2558
+
2559
+ # PS IF correction — psi convention, convert to phi
2560
+ n_all_panel = n_t + n_c
2561
+ W_ps = pscore_all * (1 - pscore_all)
2562
+ if sw_all is not None:
2563
+ W_ps = W_ps * sw_all
2564
+ H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
2565
+ H_psi_inv = _safe_inv(H_psi)
2566
+
2567
+ D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
2568
+ score_ps = (D_all - pscore_all)[:, None] * X_all_int
2569
+ if sw_all is not None:
2570
+ score_ps = score_ps * sw_all[:, None]
2571
+ # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale)
2572
+ asy_lin_rep_psi = score_ps @ H_psi_inv
2573
+
2574
+ dr_resid_control = m_control - control_change
2575
+ M2_dr = np.sum(
2576
+ ((weights_control / sw_t_sum) * dr_resid_control)[:, None]
2577
+ * X_all_int[n_t:],
2578
+ axis=0,
2579
+ )
2580
+ inf_func = inf_func + (asy_lin_rep_psi @ M2_dr) / n_all_panel
2581
+
2582
+ # --- OR IF correction ---
2583
+ # Accounts for outcome regression estimation uncertainty
2584
+ X_c_int = X_control_with_intercept
2585
+ W_diag = sw_control if sw_control is not None else np.ones(n_c)
2586
+ XtWX = X_c_int.T @ (W_diag[:, None] * X_c_int)
2587
+ bread = _safe_inv(XtWX)
2588
+
2589
+ # M1: dATT/dbeta — gradient of DR ATT w.r.t. OR parameters
2590
+ X_t_int = X_treated_with_intercept
2591
+ M1 = (
2592
+ -np.sum(sw_treated[:, None] * X_t_int, axis=0)
2593
+ + np.sum(weights_control[:, None] * X_c_int, axis=0)
2594
+ ) / sw_t_sum
2595
+
2596
+ # OR asymptotic linear representation (control-only)
2597
+ resid_c = control_change - m_control
2598
+ asy_lin_rep_or = (W_diag * resid_c)[:, None] * X_c_int @ bread
2599
+ # Apply to control portion only (treated contribute zero)
2600
+ inf_func[n_t:] += asy_lin_rep_or @ M1
2601
+
2602
+ # Recompute SE from corrected IF
2603
+ var_psi = np.sum(inf_func**2)
2604
+ se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
2605
+ else:
2606
+ # IPW weights for control: p(X) / (1 - p(X))
2607
+ weights_control = pscore_control / (1 - pscore_control)
2608
+
2609
+ # Step 3: Doubly robust ATT
2610
+ att_treated_part = float(np.mean(treated_change - m_treated))
2611
+ augmentation = float(np.sum(weights_control * (m_control - control_change)) / n_t)
2612
+ att = att_treated_part + augmentation
2613
+
2614
+ # Step 4: Influence function with nuisance IF corrections
2615
+ psi_treated = (treated_change - m_treated - att) / n_t
2616
+ psi_control = (weights_control * (m_control - control_change)) / n_t
2617
+ inf_func = np.concatenate([psi_treated, psi_control])
2618
+
2619
+ if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
2620
+ if not ps_fallback_used:
2621
+ # --- PS IF correction — psi convention, convert to phi ---
2622
+ n_all_panel = n_t + n_c
2623
+ X_all_int = np.column_stack([np.ones(n_all_panel), X_all])
2624
+ pscore_treated_clipped = np.clip(
2625
+ pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim
2626
+ )
2627
+ pscore_all = np.concatenate([pscore_treated_clipped, pscore_control])
2628
+
2629
+ W_ps = pscore_all * (1 - pscore_all)
2630
+ H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
2631
+ H_psi_inv = _safe_inv(H_psi)
2632
+
2633
+ D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
2634
+ score_ps = (D_all - pscore_all)[:, None] * X_all_int
2635
+ # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale)
2636
+ asy_lin_rep_psi = score_ps @ H_psi_inv
2637
+
2638
+ dr_resid_control = m_control - control_change
2639
+ M2_dr = np.sum(
2640
+ ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:],
2641
+ axis=0,
2642
+ )
2643
+ inf_func = inf_func + (asy_lin_rep_psi @ M2_dr) / n_all_panel
2644
+
2645
+ # --- OR IF correction ---
2646
+ X_c_int = X_control_with_intercept
2647
+ XtX = X_c_int.T @ X_c_int
2648
+ bread = _safe_inv(XtX)
2649
+
2650
+ X_t_int = X_treated_with_intercept
2651
+ M1 = (
2652
+ -np.sum(X_t_int, axis=0)
2653
+ + np.sum(weights_control[:, None] * X_c_int, axis=0)
2654
+ ) / n_t
2655
+
2656
+ resid_c = control_change - m_control
2657
+ asy_lin_rep_or = resid_c[:, None] * X_c_int @ bread
2658
+ inf_func[n_t:] += asy_lin_rep_or @ M1
2659
+
2660
+ # Recompute SE from corrected IF
2661
+ var_psi = np.sum(inf_func**2)
2662
+ se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
2663
+ else:
2664
+ # Without covariates, DR simplifies to difference in means
2665
+ if sw_treated is not None:
2666
+ sw_t_norm = sw_treated / np.sum(sw_treated)
2667
+ sw_c_norm = sw_control / np.sum(sw_control)
2668
+ mu_t = float(np.sum(sw_t_norm * treated_change))
2669
+ mu_c = float(np.sum(sw_c_norm * control_change))
2670
+ att = mu_t - mu_c
2671
+
2672
+ inf_treated = sw_t_norm * (treated_change - mu_t)
2673
+ inf_control = -sw_c_norm * (control_change - mu_c)
2674
+ inf_func = np.concatenate([inf_treated, inf_control])
2675
+
2676
+ se = (
2677
+ float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
2678
+ if (n_t > 0 and n_c > 0)
2679
+ else 0.0
2680
+ )
2681
+ else:
2682
+ att = float(np.mean(treated_change) - np.mean(control_change))
2683
+
2684
+ var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
2685
+ var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0
2686
+
2687
+ se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0
2688
+
2689
+ # Influence function for DR estimator
2690
+ inf_treated = (treated_change - np.mean(treated_change)) / n_t
2691
+ inf_control = (control_change - np.mean(control_change)) / n_c
2692
+ inf_func = np.concatenate([inf_treated, -inf_control])
2693
+
2694
+ return att, se, inf_func
2695
+
2696
+ # =========================================================================
2697
+ # Repeated Cross-Section (RCS) methods
2698
+ # =========================================================================
2699
+
2700
def _precompute_structures_rc(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    time_periods: List[Any],
    treatment_groups: List[Any],
    resolved_survey=None,
) -> PrecomputedData:
    """
    Build observation-level lookup structures for repeated cross-sections.

    The RCS path never pivots to wide format: each row of ``df`` is an
    independent observation (no within-unit differencing), so every mask
    and array produced here is indexed by row position.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format data, one row per observation.
    outcome, unit, time, first_treat : str
        Column names. ``unit`` is accepted for signature parity with the
        panel path and is not read here.
    covariates : list of str or None
        Covariate columns; skipped when ``None`` or empty.
    time_periods : list
        Observed time periods.
    treatment_groups : list
        Cohort values (of ``first_treat``) for which masks are built.
    resolved_survey : optional
        Resolved survey design; its ``weights`` are used per observation.

    Returns
    -------
    PrecomputedData
        Dictionary with pre-computed observation-level structures
        (``"is_panel"`` is ``False``).
    """
    n_rows = len(df)
    cohort_of_obs = df[first_treat].values

    # One boolean row mask per treatment cohort.
    masks_by_cohort = {g: cohort_of_obs == g for g in treatment_groups}

    # Never-treated rows are coded with a cohort of 0 or +inf.
    never_treated = (cohort_of_obs == 0) | (cohort_of_obs == np.inf)

    # Period -> position map (identity-like for RCS; used for base-period checks).
    col_of_period = {period: pos for pos, period in enumerate(sorted(time_periods))}

    X_obs = df[covariates].values if covariates else None

    # For RCS the survey weights are already per-observation.
    weights_obs = None if resolved_survey is None else resolved_survey.weights.copy()

    # Fixed cohort masses (row counts across all periods), used as
    # aggregation weights so that n_treated is consistent with the WIF.
    mass_by_cohort = {g: int(np.sum(masks_by_cohort[g])) for g in treatment_groups}

    return {
        # Integer row indices stand in for "units" so the aggregation code
        # written for the panel path keeps working.
        "all_units": np.arange(n_rows),
        "unit_to_idx": None,  # RCS: obs indices are positions
        "unit_cohorts": cohort_of_obs,
        "canonical_size": n_rows,
        "is_panel": False,
        "obs_time": df[time].values,
        "obs_outcome": df[outcome].values,
        "obs_covariates": X_obs,
        "cohort_masks": masks_by_cohort,
        "never_treated_mask": never_treated,
        "time_periods": time_periods,
        "period_to_col": col_of_period,
        "is_balanced": False,
        "survey_weights": weights_obs,
        "resolved_survey": resolved_survey,
        "resolved_survey_unit": resolved_survey,
        "df_survey": getattr(resolved_survey, "df_survey", None),
        "rcs_cohort_masses": mass_by_cohort,
    }
2789
+
2790
def _compute_att_gt_rc(
    self,
    precomputed: PrecomputedData,
    g: Any,
    t: Any,
    covariates: Optional[List[str]],
    epv_diagnostics: Optional[Dict] = None,
) -> Tuple[Any, ...]:
    """
    Compute ATT(g,t) for repeated cross-section data.

    For RCS, the 2x2 DiD compares outcomes across two independent
    cross-sections (periods t and base period s) rather than
    within-unit changes.

    Parameters
    ----------
    precomputed : PrecomputedData
        Observation-level structures from ``_precompute_structures_rc``.
    g : Any
        Treatment cohort.
    t : Any
        Evaluation period.
    covariates : list of str or None
        Covariate names. When ``None`` (or when the cell's covariates
        contain NaN) the unconditional 2x2 DiD is used.
    epv_diagnostics : dict or None
        When given, non-empty propensity-score diagnostics from the IPW /
        DR estimators are stored under the key ``(g, t)``.

    Returns
    -------
    tuple
        On success, a 7-tuple ``(att_gt, se_gt, n_treated, n_control,
        inf_func_info, survey_weight_sum, cohort_mass)``:

        - ``n_treated`` / ``n_control``: treated / control obs at period t.
        - ``inf_func_info``: dict with treated/control observation indices
          (both periods pooled) and the matching influence-function values.
        - ``survey_weight_sum``: sum of survey weights of treated obs at t,
          or ``None`` without survey weights.
        - ``cohort_mass``: total observations of cohort g across all
          periods (aggregation weight), falling back to ``n_treated``.

        When the cell cannot be estimated (t or the base period is not an
        observed period, one of the four cells is empty, or a cell has
        non-positive survey-weight mass) the 6-tuple
        ``(None, 0.0, 0, 0, None, None)`` is returned instead.
    """
    # Result for non-estimable cells. NOTE: 6 entries, whereas a successful
    # call returns 7 (cohort_mass appended) -- callers handle both shapes.
    not_estimable = (None, 0.0, 0, 0, None, None)

    cohort_masks = precomputed["cohort_masks"]
    never_treated_mask = precomputed["never_treated_mask"]
    unit_cohorts = precomputed["unit_cohorts"]
    obs_time = precomputed["obs_time"]
    obs_outcome = precomputed["obs_outcome"]
    period_to_col = precomputed["period_to_col"]

    # Base period selection (same logic as panel)
    if self.base_period == "universal":
        base_period_val = g - 1 - self.anticipation
    else:  # varying
        if t < g - self.anticipation:
            base_period_val = t - 1
        else:
            base_period_val = g - 1 - self.anticipation

    if base_period_val not in period_to_col or t not in period_to_col:
        return not_estimable

    # Treated mask = cohort g
    treated_mask = cohort_masks[g]

    # Control mask (same logic as panel)
    if self.control_group == "never_treated":
        control_mask = never_treated_mask
    else:  # not_yet_treated: still untreated through max(t, base period)
        nyt_threshold = max(t, base_period_val) + self.anticipation
        control_mask = never_treated_mask | (
            (unit_cohorts > nyt_threshold) & (unit_cohorts != g)
        )

    # Period masks
    at_t = obs_time == t
    at_s = obs_time == base_period_val

    # The four cells of the cross-sectional 2x2 design
    treated_t = treated_mask & at_t
    treated_s = treated_mask & at_s
    control_t = control_mask & at_t
    control_s = control_mask & at_s

    n_gt = int(np.sum(treated_t))
    n_gs = int(np.sum(treated_s))
    n_ct = int(np.sum(control_t))
    n_cs = int(np.sum(control_s))

    if n_gt == 0 or n_ct == 0 or n_gs == 0 or n_cs == 0:
        return not_estimable

    # Outcomes for each cell
    y_gt = obs_outcome[treated_t]
    y_gs = obs_outcome[treated_s]
    y_ct = obs_outcome[control_t]
    y_cs = obs_outcome[control_s]

    # Survey weights per cell; guard against zero effective mass.
    survey_w = precomputed.get("survey_weights")
    if survey_w is not None:
        sw_gt = survey_w[treated_t]
        sw_gs = survey_w[treated_s]
        sw_ct = survey_w[control_t]
        sw_cs = survey_w[control_s]
        if np.sum(sw_gt) <= 0 or np.sum(sw_gs) <= 0:
            return not_estimable
        if np.sum(sw_ct) <= 0 or np.sum(sw_cs) <= 0:
            return not_estimable
    else:
        sw_gt = sw_gs = sw_ct = sw_cs = None

    # Get covariates if specified
    obs_covariates = precomputed.get("obs_covariates")
    has_covariates = covariates is not None and obs_covariates is not None

    if has_covariates:
        X_gt = obs_covariates[treated_t]
        X_gs = obs_covariates[treated_s]
        X_ct = obs_covariates[control_t]
        X_cs = obs_covariates[control_s]

        # Check for NaN in covariates
        if (
            np.any(np.isnan(X_gt))
            or np.any(np.isnan(X_gs))
            or np.any(np.isnan(X_ct))
            or np.any(np.isnan(X_cs))
        ):
            warnings.warn(
                f"Missing values in covariates for group {g}, time {t} (RCS). "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=3,
            )
            has_covariates = False

    # Shared estimator arguments: outcomes in [gt, gs, ct, cs] order plus
    # per-cell survey weights. Every estimator returns
    # (att, se, inf_func_concat, idx_concat); idx_concat is rebuilt below.
    cells = (y_gt, y_gs, y_ct, y_cs)
    cell_weights = {"sw_gt": sw_gt, "sw_gs": sw_gs, "sw_ct": sw_ct, "sw_cs": sw_cs}
    method = self.estimation_method if has_covariates else None

    if method == "reg":
        att, se, inf_func_all, _ = self._outcome_regression_rc(
            *cells, X_gt, X_gs, X_ct, X_cs, **cell_weights
        )
    elif method in ("ipw", "dr"):
        estimator = self._ipw_estimation_rc if method == "ipw" else self._doubly_robust_rc
        epv_diag: dict = {}
        att, se, inf_func_all, _ = estimator(
            *cells,
            X_gt,
            X_gs,
            X_ct,
            X_cs,
            **cell_weights,
            context_label=f"cohort g={g}",
            epv_diagnostics_out=epv_diag,
        )
        if epv_diagnostics is not None and epv_diag:
            epv_diagnostics[(g, t)] = epv_diag
    else:
        # No-covariates 2x2 DiD (all methods reduce to same)
        att, se, inf_func_all, _ = self._rc_2x2_did(
            *cells, treated_t, treated_s, control_t, control_s, **cell_weights
        )

    # Build influence function info. For RCS, treated_idx/control_idx pool
    # obs from BOTH periods in the same (t, then s) order as the IF vector.
    treated_idx = np.concatenate([np.where(treated_t)[0], np.where(treated_s)[0]])
    control_idx = np.concatenate([np.where(control_t)[0], np.where(control_s)[0]])

    n_treated_combined = len(treated_idx)
    inf_func_info = {
        "treated_idx": treated_idx,
        "control_idx": control_idx,
        "treated_units": treated_idx,  # For RCS, obs indices = "units"
        "control_units": control_idx,
        "treated_inf": inf_func_all[:n_treated_combined],
        "control_inf": inf_func_all[n_treated_combined:],
    }

    sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None
    # n_treated = per-cell treated count at period t (for display).
    # cohort_mass = total treated across all periods (for aggregation weights).
    cohort_mass = precomputed.get("rcs_cohort_masses", {}).get(g, n_gt)
    return att, se, n_gt, n_ct, inf_func_info, sw_sum, cohort_mass
3000
+
3001
+ def _rc_2x2_did(
3002
+ self,
3003
+ y_gt,
3004
+ y_gs,
3005
+ y_ct,
3006
+ y_cs,
3007
+ mask_gt,
3008
+ mask_gs,
3009
+ mask_ct,
3010
+ mask_cs,
3011
+ sw_gt=None,
3012
+ sw_gs=None,
3013
+ sw_ct=None,
3014
+ sw_cs=None,
3015
+ ):
3016
+ """
3017
+ Compute the basic 2x2 DiD for RCS (no covariates).
3018
+
3019
+ ATT = (mean(Y_treated_t) - mean(Y_control_t))
3020
+ - (mean(Y_treated_s) - mean(Y_control_s))
3021
+
3022
+ Returns (att, se, inf_func_concat, idx_concat) where inf_func_concat
3023
+ has treated obs (both periods) first, then control obs (both periods).
3024
+ """
3025
+ n_gt = len(y_gt)
3026
+ n_gs = len(y_gs)
3027
+ n_ct = len(y_ct)
3028
+ n_cs = len(y_cs)
3029
+
3030
+ if sw_gt is not None:
3031
+ sw_gt_norm = sw_gt / np.sum(sw_gt)
3032
+ sw_gs_norm = sw_gs / np.sum(sw_gs)
3033
+ sw_ct_norm = sw_ct / np.sum(sw_ct)
3034
+ sw_cs_norm = sw_cs / np.sum(sw_cs)
3035
+
3036
+ mu_gt = float(np.sum(sw_gt_norm * y_gt))
3037
+ mu_gs = float(np.sum(sw_gs_norm * y_gs))
3038
+ mu_ct = float(np.sum(sw_ct_norm * y_ct))
3039
+ mu_cs = float(np.sum(sw_cs_norm * y_cs))
3040
+
3041
+ att = (mu_gt - mu_ct) - (mu_gs - mu_cs)
3042
+
3043
+ # Influence function for 4 groups (survey-weighted)
3044
+ inf_gt = sw_gt_norm * (y_gt - mu_gt)
3045
+ inf_ct = -sw_ct_norm * (y_ct - mu_ct)
3046
+ inf_gs = -sw_gs_norm * (y_gs - mu_gs)
3047
+ inf_cs = sw_cs_norm * (y_cs - mu_cs)
3048
+ else:
3049
+ mu_gt = float(np.mean(y_gt))
3050
+ mu_gs = float(np.mean(y_gs))
3051
+ mu_ct = float(np.mean(y_ct))
3052
+ mu_cs = float(np.mean(y_cs))
3053
+
3054
+ att = (mu_gt - mu_ct) - (mu_gs - mu_cs)
3055
+
3056
+ # Influence function for 4 groups
3057
+ inf_gt = (y_gt - mu_gt) / n_gt
3058
+ inf_ct = -(y_ct - mu_ct) / n_ct
3059
+ inf_gs = -(y_gs - mu_gs) / n_gs
3060
+ inf_cs = (y_cs - mu_cs) / n_cs
3061
+
3062
+ # Concatenate: treated (t then s), control (t then s)
3063
+ inf_treated = np.concatenate([inf_gt, inf_gs])
3064
+ inf_control = np.concatenate([inf_ct, inf_cs])
3065
+ inf_all = np.concatenate([inf_treated, inf_control])
3066
+
3067
+ # SE from influence function
3068
+ se = float(np.sqrt(np.sum(inf_all**2)))
3069
+
3070
+ idx_all = np.concatenate(
3071
+ [
3072
+ np.where(mask_gt)[0],
3073
+ np.where(mask_gs)[0],
3074
+ np.where(mask_ct)[0],
3075
+ np.where(mask_cs)[0],
3076
+ ]
3077
+ )
3078
+
3079
+ return att, se, inf_all, idx_all
3080
+
3081
def _outcome_regression_rc(
    self,
    y_gt,
    y_gs,
    y_ct,
    y_cs,
    X_gt,
    X_gs,
    X_ct,
    X_cs,
    sw_gt=None,
    sw_gs=None,
    sw_ct=None,
    sw_cs=None,
):
    """
    Cross-sectional outcome regression for ATT(g,t).

    Matches R DRDID::reg_did_rc (Sant'Anna & Zhao 2020, Eq 2.2).

    Two OLS models fit on controls (period t and base period s).
    Predictions made for ALL treated (both periods).
    OR correction pools ALL treated observations across both periods.

    Parameters
    ----------
    y_gt, y_gs, y_ct, y_cs : np.ndarray
        Outcomes of the treated-at-t, treated-at-s, control-at-t and
        control-at-s observations.
    X_gt, X_gs, X_ct, X_cs : np.ndarray
        Covariates (without intercept) for the same four cells.
    sw_gt, sw_gs, sw_ct, sw_cs : np.ndarray or None
        Survey weights per cell; all ``None`` when unweighted.

    IF convention
    -------------
    Intermediate terms use R's unnormalized psi_i convention throughout.
    R computes SE as ``sd(psi) / sqrt(n)``; with mean(psi) approx 0 this
    equals ``sqrt(sum(psi^2)) / n``. At the end we convert to the
    library's pre-scaled phi_i = psi_i / n convention where
    ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer.

    Returns
    -------
    tuple
        ``(att, se, inf_func_concat, idx_concat)``. The influence function
        is ordered ``[gt, gs, ct, cs]``; ``idx_concat`` is always ``None``
        (the caller rebuilds indices from its masks).
    """
    n_gt = len(y_gt)
    n_gs = len(y_gs)
    n_ct = len(y_ct)
    n_cs = len(y_cs)
    n_all = n_gt + n_gs + n_ct + n_cs

    # --- Fit 2 OLS on control groups (period t and s separately) ---
    # NOTE(review): resid_ct / resid_cs are used below as R's (y - out.y);
    # this presumes _linear_regression returns raw residuals -- confirm.
    beta_t, resid_ct = _linear_regression(
        X_ct,
        y_ct,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_ct,
    )
    beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0)

    beta_s, resid_cs = _linear_regression(
        X_cs,
        y_cs,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_cs,
    )
    beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0)

    # --- Predict counterfactual for ALL treated (both periods) ---
    X_gt_int = np.column_stack([np.ones(n_gt), X_gt])
    X_gs_int = np.column_stack([np.ones(n_gs), X_gs])
    X_ct_int = np.column_stack([np.ones(n_ct), X_ct])
    X_cs_int = np.column_stack([np.ones(n_cs), X_cs])

    # mu_hat_{0,t}(X) and mu_hat_{0,s}(X) for each treated obs
    mu_post_gt = X_gt_int @ beta_t  # treated-post predicted at post model
    mu_pre_gt = X_gt_int @ beta_s  # treated-post predicted at pre model
    mu_post_gs = X_gs_int @ beta_t  # treated-pre predicted at post model
    mu_pre_gs = X_gs_int @ beta_s  # treated-pre predicted at pre model

    # --- Group weights (R: w.treat.pre, w.treat.post, w.cont = w.D) ---
    if sw_gt is not None:
        w_treat_post = sw_gt  # treated at t
        w_treat_pre = sw_gs  # treated at s
        w_D_gt = sw_gt  # ALL treated: t portion
        w_D_gs = sw_gs  # ALL treated: s portion
    else:
        w_treat_post = np.ones(n_gt)
        w_treat_pre = np.ones(n_gs)
        w_D_gt = np.ones(n_gt)
        w_D_gs = np.ones(n_gs)

    sum_w_treat_post = np.sum(w_treat_post)
    sum_w_treat_pre = np.sum(w_treat_pre)
    sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs)  # pool ALL treated

    # R: mean(w.treat.post), mean(w.treat.pre), mean(w.cont)
    mean_w_treat_post = sum_w_treat_post / n_all
    mean_w_treat_pre = sum_w_treat_pre / n_all
    mean_w_D = sum_w_D / n_all

    # --- Treated means (period-specific Hajek means) ---
    eta_treat_post = np.sum(w_treat_post * y_gt) / sum_w_treat_post
    eta_treat_pre = np.sum(w_treat_pre * y_gs) / sum_w_treat_pre

    # --- OR correction: pools ALL treated ---
    # R: out.y.post - out.y.pre for each treated obs
    or_diff_gt = mu_post_gt - mu_pre_gt  # treated at t
    or_diff_gs = mu_post_gs - mu_pre_gs  # treated at s
    eta_cont = (np.sum(w_D_gt * or_diff_gt) + np.sum(w_D_gs * or_diff_gs)) / sum_w_D

    # --- Point estimate ---
    att = float(eta_treat_post - eta_treat_pre - eta_cont)

    # =================================================================
    # Influence function in R's unnormalized psi convention
    # (R: reg_did_rc.R, psi = n * phi)
    # =================================================================

    # --- Treated psi (R: eta.treat.post, eta.treat.pre) ---
    # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post)
    psi_treat_post = w_treat_post * (y_gt - eta_treat_post) / mean_w_treat_post
    # R: w.treat.pre * (y - eta.treat.pre) / mean(w.treat.pre)
    psi_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / mean_w_treat_pre

    # --- Control psi: leading term (R: inf.cont.1) ---
    # R: w.cont * (or_diff - eta.cont) [before /mean(w.cont)]
    psi_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont)
    psi_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont)

    # --- Control psi: estimation effect (R: inf.cont.2) ---
    # R: XpX.inv = solve(crossprod(X_ctrl, W * X_ctrl) / n)
    #            = n * (X'WX)^{-1}
    # The factor n is required: asy.lin.rep.ols must be O(1) per
    # observation to sit on the same psi scale as the other terms. Using
    # the bare (X'WX)^{-1} together with R's colMeans M1 shrinks
    # inf.cont.2 by a factor of n_all and makes the control observations'
    # contribution to the SE vanish. Check: with an intercept-only,
    # unweighted model psi_ct reduces to -(y_ct - mean(y_ct)) * n_all / n_ct,
    # i.e. the plain 2x2 DiD influence function after the final /n_all.
    W_ct = sw_ct if sw_ct is not None else np.ones(n_ct)
    W_cs = sw_cs if sw_cs is not None else np.ones(n_cs)
    bread_t = n_all * _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int))
    bread_s = n_all * _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int))

    # R: M1 = colMeans(w.cont * out.x) = sum(w_D * X) / n_all
    M1 = (
        np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
    ) / n_all

    # R: asy.lin.rep.ols (per-obs OLS score * bread)
    asy_lin_rep_ols_t = (W_ct * resid_ct)[:, None] * X_ct_int @ bread_t
    asy_lin_rep_ols_s = (W_cs * resid_cs)[:, None] * X_cs_int @ bread_s

    # R: inf.cont.2.post = asy.lin.rep.ols_t %*% M1
    psi_cont_2_ct = asy_lin_rep_ols_t @ M1  # (n_ct,)
    # R: inf.cont.2.pre = asy.lin.rep.ols_s %*% M1
    psi_cont_2_cs = asy_lin_rep_ols_s @ M1  # (n_cs,)

    # --- Assemble per-group psi ---
    # R: inf.treat = inf.treat.post - inf.treat.pre (across groups)
    # R: inf.cont = (inf.cont.1 + inf.cont.2.post - inf.cont.2.pre) / mean(w.cont)
    # R: att.inf.func = inf.treat - inf.cont
    psi_gt = psi_treat_post - psi_cont_1_gt / mean_w_D
    psi_gs = -psi_treat_pre - psi_cont_1_gs / mean_w_D
    psi_ct = -psi_cont_2_ct / mean_w_D
    psi_cs = psi_cont_2_cs / mean_w_D

    psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs])

    # =================================================================
    # Convert to library convention: phi = psi / n_all
    # se = sqrt(sum(phi^2)) == sqrt(sum(psi^2)) / n_all
    # =================================================================
    inf_all = psi_all / n_all
    se = float(np.sqrt(np.sum(inf_all**2)))

    idx_all = None  # caller builds idx from masks
    return att, se, inf_all, idx_all
3244
+
3245
+ def _ipw_estimation_rc(
3246
+ self,
3247
+ y_gt,
3248
+ y_gs,
3249
+ y_ct,
3250
+ y_cs,
3251
+ X_gt,
3252
+ X_gs,
3253
+ X_ct,
3254
+ X_cs,
3255
+ sw_gt=None,
3256
+ sw_gs=None,
3257
+ sw_ct=None,
3258
+ sw_cs=None,
3259
+ context_label: str = "",
3260
+ epv_diagnostics_out: Optional[dict] = None,
3261
+ ):
3262
+ """
3263
+ Cross-sectional IPW estimation for ATT(g,t).
3264
+
3265
+ Propensity score P(G=g | X) estimated on pooled treated+control
3266
+ observations from both periods. Reweight controls in each period.
3267
+
3268
+ IF convention
3269
+ -------------
3270
+ Intermediate terms use R's unnormalized psi_i convention throughout
3271
+ (R: ``ipw_did_rc``). R computes SE as ``sd(psi) / sqrt(n)``.
3272
+ At the end we convert to the library's pre-scaled phi_i = psi_i / n
3273
+ convention where ``se = sqrt(sum(phi^2))``, used by the
3274
+ aggregation/bootstrap layer.
3275
+
3276
+ Returns (att, se, inf_func_concat, idx_concat).
3277
+ """
3278
+ n_gt = len(y_gt)
3279
+ n_gs = len(y_gs)
3280
+ n_ct = len(y_ct)
3281
+ n_cs = len(y_cs)
3282
+ n_all = n_gt + n_gs + n_ct + n_cs
3283
+
3284
+ # Pool treated and control for propensity score
3285
+ X_all = np.vstack([X_gt, X_gs, X_ct, X_cs])
3286
+ D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)])
3287
+
3288
+ sw_all = None
3289
+ if sw_gt is not None:
3290
+ sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs])
3291
+
3292
+ ps_fallback_used = False
3293
+ diag = {}
3294
+ try:
3295
+ beta_logistic, pscore = solve_logit(
3296
+ X_all,
3297
+ D_all,
3298
+ rank_deficient_action=self.rank_deficient_action,
3299
+ weights=sw_all,
3300
+ epv_threshold=self.epv_threshold,
3301
+ context_label=context_label,
3302
+ diagnostics_out=diag,
3303
+ )
3304
+ _check_propensity_diagnostics(pscore, self.pscore_trim)
3305
+ except (np.linalg.LinAlgError, ValueError):
3306
+ if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
3307
+ raise
3308
+ ctx = f" for {context_label}" if context_label else ""
3309
+ warnings.warn(
3310
+ f"Propensity score estimation failed{ctx} (RCS IPW). "
3311
+ f"Falling back to unconditional propensity "
3312
+ f"(all covariates dropped for this cell). "
3313
+ f"Consider estimation_method='reg' to avoid "
3314
+ f"propensity scores entirely.",
3315
+ UserWarning,
3316
+ stacklevel=4,
3317
+ )
3318
+ if sw_all is not None:
3319
+ pos = sw_all > 0
3320
+ p_treat = float(np.average(D_all[pos], weights=sw_all[pos]))
3321
+ else:
3322
+ p_treat = (n_gt + n_gs) / len(D_all)
3323
+ pscore = np.full(len(D_all), p_treat)
3324
+ ps_fallback_used = True
3325
+ if epv_diagnostics_out is not None and diag:
3326
+ epv_diagnostics_out.update(diag)
3327
+
3328
+ # Clip propensity scores
3329
+ pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
3330
+
3331
+ # Split propensity scores (treated ps not used -- only control IPW weights)
3332
+ ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct]
3333
+ ps_cs = pscore[n_gt + n_gs + n_ct :]
3334
+
3335
+ # IPW weights for controls (R: w1.x = ps / (1 - ps))
3336
+ w_ct = ps_ct / (1 - ps_ct)
3337
+ w_cs = ps_cs / (1 - ps_cs)
3338
+
3339
+ if sw_gt is not None:
3340
+ w_ct = sw_ct * w_ct
3341
+ w_cs = sw_cs * w_cs
3342
+
3343
+ # R: mean(w.treat.post), mean(w.treat.pre), mean(w.ipw.ct), mean(w.ipw.cs)
3344
+ if sw_gt is not None:
3345
+ sum_w_treat_post = np.sum(sw_gt)
3346
+ sum_w_treat_pre = np.sum(sw_gs)
3347
+ else:
3348
+ sum_w_treat_post = float(n_gt)
3349
+ sum_w_treat_pre = float(n_gs)
3350
+
3351
+ mean_w_treat_post = sum_w_treat_post / n_all
3352
+ mean_w_treat_pre = sum_w_treat_pre / n_all
3353
+
3354
+ sum_w_ct = np.sum(w_ct)
3355
+ sum_w_cs = np.sum(w_cs)
3356
+ mean_w_ct = sum_w_ct / n_all
3357
+ mean_w_cs = sum_w_cs / n_all
3358
+
3359
+ # Hajek-normalized weights (R normalizes by sum for point estimate)
3360
+ w_ct_norm = w_ct / sum_w_ct if sum_w_ct > 0 else w_ct
3361
+ w_cs_norm = w_cs / sum_w_cs if sum_w_cs > 0 else w_cs
3362
+
3363
+ if sw_gt is not None:
3364
+ sw_gt_norm = sw_gt / sum_w_treat_post
3365
+ sw_gs_norm = sw_gs / sum_w_treat_pre
3366
+ mu_gt = float(np.sum(sw_gt_norm * y_gt))
3367
+ mu_gs = float(np.sum(sw_gs_norm * y_gs))
3368
+ else:
3369
+ mu_gt = float(np.mean(y_gt))
3370
+ mu_gs = float(np.mean(y_gs))
3371
+
3372
+ mu_ct_ipw = float(np.sum(w_ct_norm * y_ct))
3373
+ mu_cs_ipw = float(np.sum(w_cs_norm * y_cs))
3374
+
3375
+ att = (mu_gt - mu_ct_ipw) - (mu_gs - mu_cs_ipw)
3376
+
3377
+ # =================================================================
3378
+ # Influence function in R's unnormalized psi convention
3379
+ # (R: ipw_did_rc.R, psi = n * phi)
3380
+ # =================================================================
3381
+
3382
+ # --- Treated psi (R: eta.treat.post, eta.treat.pre) ---
3383
+ # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post)
3384
+ if sw_gt is not None:
3385
+ psi_gt = sw_gt * (y_gt - mu_gt) / mean_w_treat_post
3386
+ psi_gs = -sw_gs * (y_gs - mu_gs) / mean_w_treat_pre
3387
+ else:
3388
+ psi_gt = (y_gt - mu_gt) / mean_w_treat_post
3389
+ psi_gs = -(y_gs - mu_gs) / mean_w_treat_pre
3390
+
3391
+ # --- Control psi (R: eta.cont.post, eta.cont.pre) ---
3392
+ # R: w.ipw * (y - eta.cont) / mean(w.ipw)
3393
+ psi_ct = -w_ct * (y_ct - mu_ct_ipw) / mean_w_ct if mean_w_ct > 0 else np.zeros(n_ct)
3394
+ psi_cs = w_cs * (y_cs - mu_cs_ipw) / mean_w_cs if mean_w_cs > 0 else np.zeros(n_cs)
3395
+
3396
+ psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs])
3397
+
3398
+ # Convert leading psi to phi: phi = psi / n_all
3399
+ inf_all = psi_all / n_all
3400
+
3401
+ if not ps_fallback_used:
3402
+ # --- PS IF correction — psi convention, convert to phi ---
3403
+ X_all_int = np.column_stack([np.ones(n_all), X_all])
3404
+
3405
+ W_ps = pscore * (1 - pscore)
3406
+ if sw_all is not None:
3407
+ W_ps = W_ps * sw_all
3408
+ # R: Hessian.ps = crossprod(X * sqrt(W)) / n
3409
+ H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all
3410
+ H_psi_inv = _safe_inv(H_psi)
3411
+
3412
+ score_ps = (D_all - pscore)[:, None] * X_all_int
3413
+ if sw_all is not None:
3414
+ score_ps = score_ps * sw_all[:, None]
3415
+ # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs)
3416
+ asy_lin_rep_psi = score_ps @ H_psi_inv
3417
+
3418
+ # PS nuisance correction in psi convention
3419
+ # R: M2 = colMeans(w_ipw * (y-mu) * X)
3420
+ ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw)
3421
+ ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw)
3422
+
3423
+ ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct)
3424
+ cs_slice = slice(n_gt + n_gs + n_ct, None)
3425
+
3426
+ M2 = np.zeros(X_all_int.shape[1])
3427
+ M2 += np.sum(ipw_resid_ct[:, None] * X_all_int[ct_slice], axis=0)
3428
+ M2 -= np.sum(ipw_resid_cs[:, None] * X_all_int[cs_slice], axis=0)
3429
+
3430
+ # psi-scale correction, convert to phi
3431
+ # Subtract: R adds PS correction to inf.control, then att = treat - control
3432
+ inf_all = inf_all - (asy_lin_rep_psi @ M2) / n_all
3433
+
3434
+ # =================================================================
3435
+ # SE from phi: se = sqrt(sum(phi^2))
3436
+ # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0.
3437
+ # =================================================================
3438
+ se = float(np.sqrt(np.sum(inf_all**2)))
3439
+
3440
+ idx_all = None
3441
+ return att, se, inf_all, idx_all
3442
+
3443
def _doubly_robust_rc(
    self,
    y_gt,
    y_gs,
    y_ct,
    y_cs,
    X_gt,
    X_gs,
    X_ct,
    X_cs,
    sw_gt=None,
    sw_gs=None,
    sw_ct=None,
    sw_cs=None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
):
    """
    Cross-sectional doubly robust estimation for ATT(g,t).

    Matches R DRDID::drdid_rc (Sant'Anna & Zhao 2020, Eq 3.1).
    Locally efficient DR estimator with 4 OLS fits (control pre/post,
    treated pre/post) plus propensity score.

    Parameters
    ----------
    y_gt, y_gs, y_ct, y_cs : 1-D arrays
        Outcomes for the four repeated-cross-section cells: treated units
        observed in period t (post), treated units observed in the base
        period s (pre), control units in t, control units in s.
    X_gt, X_gs, X_ct, X_cs : 2-D arrays
        Covariates for the same four cells, *without* an intercept column
        (an intercept is prepended internally with ``np.column_stack``).
    sw_gt, sw_gs, sw_ct, sw_cs : arrays, optional
        Per-observation sampling weights. The stacked weight vector used by
        the propensity model is only built when ``sw_gt`` is given, so the
        four are presumably supplied together or not at all
        (NOTE(review): confirm at call sites).
    context_label : str
        Forwarded to ``solve_logit`` and included in the fallback warning.
    epv_diagnostics_out : dict, optional
        If given, updated in place with whatever ``solve_logit`` wrote into
        its ``diagnostics_out`` dict (only when that dict is non-empty).

    IF convention
    -------------
    Intermediate terms use R's unnormalized psi_i convention throughout
    (R: ``drdid_rc``). R computes SE as ``sd(psi) / sqrt(n)``.
    At the end we convert to the library's pre-scaled phi_i = psi_i / n
    convention where ``se = sqrt(sum(phi^2))``, used by the
    aggregation/bootstrap layer. The OLS nuisance corrections (steps 10
    and 11) are built directly on the phi scale.

    Returns
    -------
    att : float
        Point estimate ``tau_1 + tau_2``.
    se : float
        ``sqrt(sum(inf_func_concat**2))``.
    inf_func_concat : ndarray
        Per-observation phi, stacked in the order ``[gt, gs, ct, cs]``
        (length ``n_gt + n_gs + n_ct + n_cs``).
    idx_concat : None
        Always ``None`` in this method.

    Raises
    ------
    numpy.linalg.LinAlgError, ValueError
        Re-raised from propensity estimation when
        ``self.pscore_fallback == "error"`` or
        ``self.rank_deficient_action == "error"``. Otherwise a
        ``UserWarning`` is issued and a constant (unconditional)
        propensity is used instead.
    """
    n_gt = len(y_gt)
    n_gs = len(y_gs)
    n_ct = len(y_ct)
    n_cs = len(y_cs)
    n_all = n_gt + n_gs + n_ct + n_cs

    # =====================================================================
    # 1. Outcome regression: 4 OLS fits
    # =====================================================================
    # Control OLS: E[Y|X, D=0, T=t] and E[Y|X, D=0, T=s]
    beta_ct, resid_ct = _linear_regression(
        X_ct,
        y_ct,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_ct,
    )
    # Non-finite coefficients (presumably the NaNs _linear_regression uses
    # for dropped rank-deficient columns -- verify) are zeroed so that the
    # predictions below stay finite.
    beta_ct = np.where(np.isfinite(beta_ct), beta_ct, 0.0)

    beta_cs, resid_cs = _linear_regression(
        X_cs,
        y_cs,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_cs,
    )
    beta_cs = np.where(np.isfinite(beta_cs), beta_cs, 0.0)

    # Treated OLS: E[Y|X, D=1, T=t] and E[Y|X, D=1, T=s]
    beta_gt, resid_gt = _linear_regression(
        X_gt,
        y_gt,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_gt,
    )
    beta_gt = np.where(np.isfinite(beta_gt), beta_gt, 0.0)

    beta_gs, resid_gs = _linear_regression(
        X_gs,
        y_gs,
        rank_deficient_action=self.rank_deficient_action,
        weights=sw_gs,
    )
    beta_gs = np.where(np.isfinite(beta_gs), beta_gs, 0.0)

    # Intercept-augmented design matrices (first column is the intercept,
    # matching the coefficient layout assumed by the `@ beta_*` products).
    X_gt_int = np.column_stack([np.ones(n_gt), X_gt])
    X_gs_int = np.column_stack([np.ones(n_gs), X_gs])
    X_ct_int = np.column_stack([np.ones(n_ct), X_ct])
    X_cs_int = np.column_stack([np.ones(n_cs), X_cs])

    # Control OR predictions for all groups.
    # (mu0_pre_ct and mu0_post_cs are computed for symmetry but are not
    # used anywhere below.)
    mu0_post_gt = X_gt_int @ beta_ct  # mu_{0,1}(X) for treated-post
    mu0_pre_gt = X_gt_int @ beta_cs  # mu_{0,0}(X) for treated-post
    mu0_post_gs = X_gs_int @ beta_ct  # mu_{0,1}(X) for treated-pre
    mu0_pre_gs = X_gs_int @ beta_cs  # mu_{0,0}(X) for treated-pre
    mu0_post_ct = X_ct_int @ beta_ct  # mu_{0,1}(X) for control-post
    mu0_pre_ct = X_ct_int @ beta_cs  # mu_{0,0}(X) for control-post
    mu0_post_cs = X_cs_int @ beta_ct  # mu_{0,1}(X) for control-pre
    mu0_pre_cs = X_cs_int @ beta_cs  # mu_{0,0}(X) for control-pre

    # Treated OR predictions for all groups (for local efficiency adjustment)
    mu1_post_gt = X_gt_int @ beta_gt  # mu_{1,1}(X) for treated-post
    mu1_pre_gt = X_gt_int @ beta_gs  # mu_{1,0}(X) for treated-post
    mu1_post_gs = X_gs_int @ beta_gt  # mu_{1,1}(X) for treated-pre
    mu1_pre_gs = X_gs_int @ beta_gs  # mu_{1,0}(X) for treated-pre

    # mu_{0,Y}(T_i, X_i): control OR evaluated at own period
    mu0Y_gt = mu0_post_gt  # treated-post: use post control model
    mu0Y_gs = mu0_pre_gs  # treated-pre: use pre control model
    mu0Y_ct = mu0_post_ct  # control-post: use post control model
    mu0Y_cs = mu0_pre_cs  # control-pre: use pre control model

    # =====================================================================
    # 2. Propensity score
    # =====================================================================
    # Stacked in the order [gt, gs, ct, cs]; D = 1 for both treated cells.
    X_all = np.vstack([X_gt, X_gs, X_ct, X_cs])
    D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)])
    sw_all = None
    if sw_gt is not None:
        sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs])

    ps_fallback_used = False
    diag = {}
    try:
        # beta_logistic is not used below; only the fitted scores are needed.
        beta_logistic, pscore = solve_logit(
            X_all,
            D_all,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_all,
            epv_threshold=self.epv_threshold,
            context_label=context_label,
            diagnostics_out=diag,
        )
        # A ValueError raised here also routes into the fallback below.
        _check_propensity_diagnostics(pscore, self.pscore_trim)
    except (np.linalg.LinAlgError, ValueError):
        if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
            raise
        ctx = f" for {context_label}" if context_label else ""
        warnings.warn(
            f"Propensity score estimation failed{ctx} (RCS DR). "
            f"Falling back to unconditional propensity "
            f"(propensity model ignores covariates; outcome "
            f"regression still uses them). "
            f"Consider estimation_method='reg' to avoid "
            f"propensity scores entirely.",
            UserWarning,
            stacklevel=4,
        )
        # Constant propensity = (weighted) share of treated observations;
        # zero-weight rows are excluded from the weighted average.
        if sw_all is not None:
            pos = sw_all > 0
            p_treat = float(np.average(D_all[pos], weights=sw_all[pos]))
        else:
            p_treat = (n_gt + n_gs) / len(D_all)
        pscore = np.full(len(D_all), p_treat)
        ps_fallback_used = True
    if epv_diagnostics_out is not None and diag:
        epv_diagnostics_out.update(diag)

    pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)

    # Split propensity scores per group (ps_gt / ps_gs are not used below;
    # only the control scores enter the IPW weights).
    ps_gt = pscore[:n_gt]
    ps_gs = pscore[n_gt : n_gt + n_gs]
    ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct]
    ps_cs = pscore[n_gt + n_gs + n_ct :]

    # =====================================================================
    # 3. Group weights and R-convention means
    # =====================================================================
    if sw_gt is not None:
        w_treat_post = sw_gt
        w_treat_pre = sw_gs
        w_D_gt = sw_gt
        w_D_gs = sw_gs
    else:
        w_treat_post = np.ones(n_gt)
        w_treat_pre = np.ones(n_gs)
        w_D_gt = np.ones(n_gt)
        w_D_gs = np.ones(n_gs)

    # NOTE(review): unlike the control IPW sums below, these treated sums
    # are never guarded against zero; a zero total yields NaN/inf.
    # Presumably callers guarantee non-empty, positively weighted treated cells.
    sum_w_treat_post = np.sum(w_treat_post)
    sum_w_treat_pre = np.sum(w_treat_pre)
    sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs)

    # R: mean(w) = sum(w) / n -- used in psi normalizers
    mean_w_treat_post = sum_w_treat_post / n_all
    mean_w_treat_pre = sum_w_treat_pre / n_all
    mean_w_D = sum_w_D / n_all

    # IPW control weights: sw * ps/(1-ps) for controls
    w_ipw_ct = ps_ct / (1 - ps_ct)
    w_ipw_cs = ps_cs / (1 - ps_cs)
    if sw_ct is not None:
        w_ipw_ct = sw_ct * w_ipw_ct
        w_ipw_cs = sw_cs * w_ipw_cs

    sum_w_ipw_ct = np.sum(w_ipw_ct)
    sum_w_ipw_cs = np.sum(w_ipw_cs)
    mean_w_ipw_ct = sum_w_ipw_ct / n_all
    mean_w_ipw_cs = sum_w_ipw_cs / n_all

    # =====================================================================
    # 4. Point estimate: tau_1 (AIPW using control ORs)
    # =====================================================================
    # Hajek-normalized means of (y - mu0Y) per group
    eta_treat_post = np.sum(w_treat_post * (y_gt - mu0Y_gt)) / sum_w_treat_post
    eta_treat_pre = np.sum(w_treat_pre * (y_gs - mu0Y_gs)) / sum_w_treat_pre

    # Zero total IPW weight on a control cell: that term contributes 0.
    eta_cont_post = (
        np.sum(w_ipw_ct * (y_ct - mu0Y_ct)) / sum_w_ipw_ct if sum_w_ipw_ct > 0 else 0.0
    )
    eta_cont_pre = (
        np.sum(w_ipw_cs * (y_cs - mu0Y_cs)) / sum_w_ipw_cs if sum_w_ipw_cs > 0 else 0.0
    )

    tau_1 = (eta_treat_post - eta_cont_post) - (eta_treat_pre - eta_cont_pre)

    # =====================================================================
    # 5. Point estimate: local efficiency adjustment (tau_2)
    # =====================================================================
    # Differences mu_{1,t}(X) - mu_{0,t}(X) for treated obs
    or_diff_post_gt = mu1_post_gt - mu0_post_gt  # at treated-post
    or_diff_post_gs = mu1_post_gs - mu0_post_gs  # at treated-pre
    or_diff_pre_gt = mu1_pre_gt - mu0_pre_gt  # at treated-post
    or_diff_pre_gs = mu1_pre_gs - mu0_pre_gs  # at treated-pre

    # att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) -- all treated
    att_d_post = (np.sum(w_D_gt * or_diff_post_gt) + np.sum(w_D_gs * or_diff_post_gs)) / sum_w_D
    # att_dt1_post -- treated-post only
    att_dt1_post = np.sum(w_treat_post * or_diff_post_gt) / sum_w_treat_post
    # att_d_pre -- all treated
    att_d_pre = (np.sum(w_D_gt * or_diff_pre_gt) + np.sum(w_D_gs * or_diff_pre_gs)) / sum_w_D
    # att_dt0_pre -- treated-pre only
    att_dt0_pre = np.sum(w_treat_pre * or_diff_pre_gs) / sum_w_treat_pre

    tau_2 = (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre)

    att = float(tau_1 + tau_2)

    # =====================================================================
    # 6. Influence function in R's unnormalized psi convention
    #    (R: drdid_rc.R, psi = n * phi)
    # =====================================================================

    # --- tau_1: treated psi (R: eta.treat.post / mean(w.treat.post)) ---
    # R: w.treat.post * (y - mu0Y - eta.treat.post) / mean(w.treat.post)
    psi_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / mean_w_treat_post
    psi_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / mean_w_treat_pre

    # --- tau_1: control psi (R: eta.cont.post / mean(w.ipw)) ---
    # R: w.ipw * (y - mu0Y - eta.cont) / mean(w.ipw)
    psi_cont_post_ct = (
        w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / mean_w_ipw_ct
        if mean_w_ipw_ct > 0
        else np.zeros(n_ct)
    )
    psi_cont_pre_cs = (
        w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / mean_w_ipw_cs
        if mean_w_ipw_cs > 0
        else np.zeros(n_cs)
    )

    # tau_1 psi per group; signs follow
    # tau_1 = (treat_post - cont_post) - (treat_pre - cont_pre).
    psi_gt_tau1 = psi_treat_post
    psi_gs_tau1 = -psi_treat_pre
    psi_ct_tau1 = -psi_cont_post_ct
    psi_cs_tau1 = psi_cont_pre_cs

    # =====================================================================
    # 7. tau_2 leading terms (R: att.d.post, att.dt1.post, etc.)
    # =====================================================================
    # R: w.D * (or_diff - att.d.post) / mean(w.D)
    psi_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / mean_w_D
    psi_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / mean_w_D
    # R: w.treat.post * (or_diff - att.dt1.post) / mean(w.treat.post)
    psi_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / mean_w_treat_post
    # R: w.D * (or_diff_pre - att.d.pre) / mean(w.D)
    psi_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / mean_w_D
    psi_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / mean_w_D
    # R: w.treat.pre * (or_diff_pre - att.dt0.pre) / mean(w.treat.pre)
    psi_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / mean_w_treat_pre

    # tau_2 psi per group (controls contribute zero).
    # R: (inf.eff1 - inf.eff2) - (inf.eff3 - inf.eff4), where the dt1 term
    # is nonzero only for treated-post and the dt0 term only for treated-pre.
    psi_gt_tau2 = (psi_d_post_gt - psi_dt1_post) - psi_d_pre_gt
    psi_gs_tau2 = psi_d_post_gs - (-psi_dt0_pre + psi_d_pre_gs)

    # =====================================================================
    # 8. Combined plug-in psi (before nuisance corrections)
    # =====================================================================
    psi_gt = psi_gt_tau1 + psi_gt_tau2
    psi_gs = psi_gs_tau1 + psi_gs_tau2
    psi_ct = psi_ct_tau1
    psi_cs = psi_cs_tau1

    psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs])

    # =================================================================
    # Convert leading psi to library phi convention: phi = psi / n_all
    # =================================================================
    inf_all = psi_all / n_all

    # =====================================================================
    # 9. PS nuisance correction — psi convention, convert to phi
    # =====================================================================
    X_all_int = np.column_stack([np.ones(n_all), X_all])
    # Skipped under the unconditional fallback. There pscore is constant, so
    # w_ipw_* is proportional to the sampling weights, and dr_resid_* below
    # are centred at their w_ipw-weighted means (that is how eta_cont_* is
    # defined). The intercept entry of M2 -- the only one an intercept-only
    # propensity model would use -- is therefore exactly zero.
    if not ps_fallback_used:
        # Logit Hessian weights, evaluated at the clipped scores.
        W_ps = pscore * (1 - pscore)
        if sw_all is not None:
            W_ps = W_ps * sw_all
        # R: Hessian.ps = crossprod(X * sqrt(W)) / n
        H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all
        H_psi_inv = _safe_inv(H_psi)

        score_ps = (D_all - pscore)[:, None] * X_all_int
        if sw_all is not None:
            score_ps = score_ps * sw_all[:, None]
        # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs)
        asy_lin_rep_psi = score_ps @ H_psi_inv

        # R: M2 = colMeans(w_ipw * dr_resid / mean(w_ipw) * X)
        ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct)
        cs_slice = slice(n_gt + n_gs + n_ct, None)

        dr_resid_ct = y_ct - mu0Y_ct - eta_cont_post
        dr_resid_cs = y_cs - mu0Y_cs - eta_cont_pre

        # d(att)/d(gamma): eta_cont_post enters att with "-", eta_cont_pre
        # with "+"; d[ps/(1-ps)]/d(gamma) = ps/(1-ps) * x for the logit.
        M2 = np.zeros(X_all_int.shape[1])
        if sum_w_ipw_ct > 0:
            M2 -= np.sum(
                ((w_ipw_ct * dr_resid_ct / sum_w_ipw_ct)[:, None] * X_all_int[ct_slice]),
                axis=0,
            )
        if sum_w_ipw_cs > 0:
            M2 += np.sum(
                ((w_ipw_cs * dr_resid_cs / sum_w_ipw_cs)[:, None] * X_all_int[cs_slice]),
                axis=0,
            )

        # psi-scale correction, convert to phi
        inf_all = inf_all + (asy_lin_rep_psi @ M2) / n_all

    # =====================================================================
    # 10. Control OR nuisance corrections (phi-scale)
    # =====================================================================
    W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct)
    W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs)
    # Bread is (X'WX)^{-1} with no 1/n factor (unlike H_psi above), so the
    # asy_lin_rep_* terms below are already on the phi scale (R's psi-scale
    # asy.lin.rep.ols divided by n).
    bread_ct = _safe_inv(X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int))
    bread_cs = _safe_inv(X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int))

    # R: asy.lin.rep.ols (per-obs OLS score * bread)
    # Evaluated left to right: ((w * e)[:, None] * X) @ bread. This assumes
    # resid_* are the OLS residuals y - X @ beta from _linear_regression
    # (NOTE(review): confirm).
    asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct
    asy_lin_rep_cs = (W_cs_vals * resid_cs)[:, None] * X_cs_int @ bread_cs

    # M1 for control-post model (beta_ct): gradient from tau_1 + tau_2
    #   tau_1: -w_treat_post*X/sum_w_treat_post   (eta_treat_post via mu0Y_gt)
    #          +w_ipw_ct*X/sum_w_ipw_ct           (eta_cont_post via mu0Y_ct)
    #   tau_2: -w_D*X/sum_w_D                     (att_d_post via mu0_post at all treated)
    #          +w_treat_post*X/sum_w_treat_post   (att_dt1_post via mu0_post)
    # The first and last terms cancel exactly; both are kept to mirror R.
    M1_ct = np.zeros(X_all_int.shape[1])
    M1_ct -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
    if sum_w_ipw_ct > 0:
        M1_ct += np.sum(w_ipw_ct[:, None] * X_ct_int, axis=0) / sum_w_ipw_ct
    M1_ct -= (
        np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
    ) / sum_w_D
    M1_ct += np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post

    # M1 for control-pre model (beta_cs): mirror image of M1_ct with the
    # opposite overall sign (pre-period terms enter att with "-"). The
    # w_treat_pre terms again cancel exactly.
    M1_cs = np.zeros(X_all_int.shape[1])
    M1_cs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
    if sum_w_ipw_cs > 0:
        M1_cs -= np.sum(w_ipw_cs[:, None] * X_cs_int, axis=0) / sum_w_ipw_cs
    M1_cs += (
        np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
    ) / sum_w_D
    M1_cs -= np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre

    # Each OLS correction lands only on the observations that entered that
    # regression (control-post rows for beta_ct, control-pre rows for beta_cs).
    inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_ct @ M1_ct
    inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_cs @ M1_cs

    # =====================================================================
    # 11. Treated OR nuisance corrections (phi-scale)
    # =====================================================================
    W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
    W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs)
    bread_gt = _safe_inv(X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int))
    bread_gs = _safe_inv(X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int))

    asy_lin_rep_gt = (W_gt_vals * resid_gt)[:, None] * X_gt_int @ bread_gt
    asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs

    # M1 for treated-post model (beta_gt): mu_{1,1}(X)
    #   From att_d_post:   +w_D*X/sum_w_D (all treated)
    #   From att_dt1_post: -w_treat_post*X/sum_w_treat_post (treated-post)
    M1_gt = np.zeros(X_all_int.shape[1])
    M1_gt += (
        np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
    ) / sum_w_D
    M1_gt -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post

    # M1 for treated-pre model (beta_gs): mu_{1,0}(X)
    #   From att_d_pre:   -w_D*X/sum_w_D
    #   From att_dt0_pre: +w_treat_pre*X/sum_w_treat_pre
    M1_gs = np.zeros(X_all_int.shape[1])
    M1_gs -= (
        np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
    ) / sum_w_D
    M1_gs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre

    inf_all[:n_gt] += asy_lin_rep_gt @ M1_gt
    inf_all[n_gt : n_gt + n_gs] += asy_lin_rep_gs @ M1_gs

    # =================================================================
    # SE from phi: se = sqrt(sum(phi^2))
    # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0.
    # =================================================================
    se = float(np.sqrt(np.sum(inf_all**2)))

    # No index mapping is produced by this estimator; callers receive None.
    idx_all = None
    return att, se, inf_all, idx_all
3856
+
3857
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's configuration parameters (sklearn-compatible).

    Each key is a parameter name, and each value is the instance attribute
    of the same name as it currently stands. Keys come back in a fixed,
    documented order.
    """
    param_names = (
        "control_group",
        "anticipation",
        "estimation_method",
        "alpha",
        "cluster",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "rank_deficient_action",
        "base_period",
        "cband",
        "pscore_trim",
        "panel",
        "epv_threshold",
        "pscore_fallback",
    )
    return {name: getattr(self, name) for name in param_names}
3876
+
3877
def set_params(self, **params) -> "CallawaySantAnna":
    """Set estimator parameters (sklearn-compatible).

    Only names reported by ``get_params()`` are accepted. Any other
    attribute is rejected, including methods such as ``fit`` and fitted
    state such as ``results_``, so it cannot be overwritten by accident.
    Every key is validated before any attribute is modified, so a call with
    an unknown key leaves the estimator unchanged.

    Parameters
    ----------
    **params
        Parameter names and their new values.

    Returns
    -------
    self
        The estimator itself, to allow call chaining.

    Raises
    ------
    ValueError
        If any key is not one of the estimator's parameters.
    """
    valid_params = self.get_params()
    # Validate everything first so a bad key never leaves a half-applied update.
    for key in params:
        if key not in valid_params:
            raise ValueError(f"Unknown parameter: {key}")
    for key, value in params.items():
        setattr(self, key, value)
    return self
3885
+
3886
def summary(self) -> str:
    """Get summary of estimation results.

    Returns
    -------
    str
        The text produced by the stored results object's ``summary()``.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted (``is_fitted_`` is falsy) or
        no results object is stored.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would turn a missing result into an opaque
    # AttributeError on None. The check also narrows ``results_`` for
    # type checkers.
    if not self.is_fitted_ or self.results_ is None:
        raise RuntimeError("Model must be fitted before calling summary()")
    return self.results_.summary()
3892
+
3893
def print_summary(self) -> None:
    """Write the text returned by ``self.summary()`` to standard output."""
    report = self.summary()
    print(report)