diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1770 @@
1
+ """
2
+ Efficient Difference-in-Differences estimator.
3
+
4
+ Implements the ATT estimator from Chen, Sant'Anna & Xie (2025).
5
+ Without covariates, achieves the semiparametric efficiency bound via
6
+ closed-form within-group covariances. With covariates, uses a doubly
7
+ robust path with OLS outcome regression, sieve propensity ratios, and
8
+ kernel-smoothed conditional Omega*(X) (see class docstring for caveats).
9
+
10
+ Under PT-All the model is overidentified and EDiD exploits this for
11
+ tighter inference; under PT-Post it reduces to the standard
12
+ single-baseline estimator (Callaway-Sant'Anna).
13
+ """
14
+
15
+ import warnings
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ from diff_diff.efficient_did_bootstrap import (
22
+ EDiDBootstrapResults,
23
+ EfficientDiDBootstrapMixin,
24
+ )
25
+ from diff_diff.efficient_did_covariates import (
26
+ compute_eif_cov,
27
+ compute_generated_outcomes_cov,
28
+ compute_omega_star_conditional,
29
+ compute_per_unit_weights,
30
+ estimate_inverse_propensity_sieve,
31
+ estimate_outcome_regression,
32
+ estimate_propensity_ratio_sieve,
33
+ )
34
+ from diff_diff.efficient_did_results import EfficientDiDResults, HausmanPretestResult
35
+ from diff_diff.efficient_did_weights import (
36
+ compute_efficient_weights,
37
+ compute_eif_nocov,
38
+ compute_generated_outcomes_nocov,
39
+ compute_omega_star_nocov,
40
+ enumerate_valid_triples,
41
+ )
42
+ from diff_diff.utils import safe_inference
43
+
44
# Public names of this module; the two results classes are re-exported
# here for convenience so callers can import them alongside the estimator.
__all__ = [
    "EfficientDiD",
    "EfficientDiDResults",
    "EDiDBootstrapResults",
]
46
+
47
+
48
def _validate_and_build_cluster_mapping(
    df: pd.DataFrame,
    unit: str,
    cluster: str,
    all_units: list,
) -> Tuple[np.ndarray, int]:
    """Check the cluster column and map every unit to a cluster index.

    The cluster column must exist in ``df``, contain no missing values,
    be constant within each unit, and define at least two distinct
    clusters.

    Parameters
    ----------
    df : DataFrame
        Long-format data containing the ``unit`` and ``cluster`` columns.
    unit : str
        Unit identifier column.
    cluster : str
        Cluster identifier column.
    all_units : list
        Unit identifiers; the returned indices follow this order.

    Returns
    -------
    cluster_indices : ndarray of int
        For each unit in ``all_units``, the position of its cluster label
        within the sorted distinct labels (``np.unique`` order).
    n_clusters : int
        Number of distinct cluster labels.

    Raises
    ------
    ValueError
        If the column is missing, has NaNs, varies within a unit, or
        yields fewer than two clusters.
    """
    if cluster not in df.columns:
        raise ValueError(f"Cluster column '{cluster}' not found in data.")
    if df[cluster].isna().any():
        raise ValueError(f"Cluster column '{cluster}' contains missing values.")

    grouped = df.groupby(unit)[cluster]
    varies_within_unit = grouped.nunique().gt(1).any()
    if varies_within_unit:
        raise ValueError(
            f"Cluster column '{cluster}' varies within unit. "
            "Cluster assignment must be constant per unit."
        )

    # One cluster label per unit, ordered like ``all_units``.
    labels = grouped.first().reindex(all_units).values
    distinct = np.unique(labels)
    n_clusters = len(distinct)
    if n_clusters < 2:
        raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}.")

    position_of = dict(zip(distinct, range(n_clusters)))
    indices = np.array([position_of[label] for label in labels])
    return indices, n_clusters
77
+
78
+
79
def _cluster_aggregate(
    eif_mat: np.ndarray,
    cluster_indices: np.ndarray,
    n_clusters: int,
) -> np.ndarray:
    """Sum EIF values within clusters and center the cluster totals.

    Parameters
    ----------
    eif_mat : ndarray, shape (n_units,) or (n_units, k)
        EIF values — 1-D for a single estimand, 2-D for multiple.
    cluster_indices : ndarray, shape (n_units,)
        Non-negative integer cluster assignment per unit.
    n_clusters : int
        Number of unique clusters (minimum number of output rows).

    Returns
    -------
    ndarray, shape (n_clusters,) or (n_clusters, k)
        Cluster-level sums minus their column means. A 2-D input with
        zero columns yields an empty ``(n_clusters, 0)`` array.
    """
    if eif_mat.ndim == 1:
        sums = np.bincount(cluster_indices, weights=eif_mat, minlength=n_clusters)
    else:
        n_cols = eif_mat.shape[1]
        if n_cols == 0:
            # np.column_stack raises on an empty list of columns; return a
            # correctly shaped empty result instead of crashing.
            return np.zeros((n_clusters, 0))
        # np.bincount only accepts 1-D weights, so aggregate column by column.
        sums = np.column_stack(
            [
                np.bincount(cluster_indices, weights=eif_mat[:, j], minlength=n_clusters)
                for j in range(n_cols)
            ]
        )
    sums = sums.astype(float, copy=False)
    return sums - sums.mean(axis=0)
110
+
111
+
112
def _compute_se_from_eif(
    eif: np.ndarray,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> float:
    """Standard error implied by per-unit EIF values.

    Without clustering the SE is ``sqrt(mean(EIF**2) / n_units)``.
    When both ``cluster_indices`` and ``n_clusters`` are supplied, the EIF
    is summed within clusters and centered (Liang-Zeger sandwich), and the
    variance is scaled by the small-sample factor ``G / (G - 1)``, which
    falls back to 1.0 when there is at most one cluster.
    """
    clustered = cluster_indices is not None and n_clusters is not None
    if not clustered:
        return float(np.sqrt(np.mean(eif**2) / n_units))

    cluster_dev = _cluster_aggregate(eif, cluster_indices, n_clusters)
    small_sample = 1.0 if n_clusters <= 1 else n_clusters / (n_clusters - 1)
    variance = small_sample * np.sum(cluster_dev**2) / (n_units**2)
    return float(np.sqrt(max(variance, 0.0)))
130
+
131
+
132
+ class EfficientDiD(EfficientDiDBootstrapMixin):
133
+ """Efficient DiD estimator (Chen, Sant'Anna & Xie 2025).
134
+
135
+ Without covariates, achieves the semiparametric efficiency bound for
136
+ ATT(g,t) using a closed-form estimator based on within-group sample
137
+ means and covariances.
138
+
139
+ With covariates, uses a doubly robust path: sieve-based propensity
140
+ score ratios (Eq 4.1-4.2), OLS outcome regression, sieve-estimated
141
+ inverse propensities (algorithm step 4), and kernel-smoothed
142
+ conditional Omega*(X) with per-unit efficient weights (Eq 3.12).
143
+ The DR property ensures consistency if either the OLS outcome model
144
+ or the sieve propensity ratio is correctly specified. The OLS
145
+ working model for outcome regressions does not generically guarantee
146
+ the semiparametric efficiency bound (see REGISTRY.md).
147
+
148
+ Parameters
149
+ ----------
150
+ pt_assumption : str, default ``"all"``
151
+ Parallel trends variant: ``"all"`` (overidentified, uses all
152
+ pre-treatment periods and comparison groups) or ``"post"``
153
+ (just-identified, single baseline, equivalent to CS).
154
+ alpha : float, default 0.05
155
+ Significance level.
156
+ cluster : str or None
157
+ Column name for cluster-robust SEs. When set, analytical SEs
158
+ use the Liang-Zeger clustered sandwich estimator on EIF values.
159
+ With ``n_bootstrap > 0``, bootstrap weights are generated at the
160
+ cluster level (all units in a cluster share the same weight).
161
+ control_group : str, default ``"never_treated"``
162
+ Which units serve as the comparison group:
163
+ ``"never_treated"`` requires a never-treated cohort (raises if
164
+ none exist); ``"last_cohort"`` reclassifies the latest treatment
165
+ cohort as pseudo-never-treated and drops post-treatment periods
166
+ for that cohort. Distinct from CallawaySantAnna's
167
+ ``"not_yet_treated"`` — see REGISTRY.md for details.
168
+ n_bootstrap : int, default 0
169
+ Number of multiplier bootstrap iterations (0 = analytical only).
170
+ bootstrap_weights : str, default ``"rademacher"``
171
+ Bootstrap weight distribution.
172
+ seed : int or None
173
+ Random seed for reproducibility.
174
+ anticipation : int, default 0
175
+ Number of anticipation periods (shifts the effective treatment
176
+ boundary forward by this amount).
177
+ sieve_k_max : int or None
178
+ Maximum polynomial degree for sieve ratio estimation. None = auto
179
+ (``min(floor(n_gp^{1/5}), 5)``). Only used with covariates.
180
+ sieve_criterion : str, default ``"bic"``
181
+ Information criterion for sieve degree selection: ``"aic"`` or ``"bic"``.
182
+ ratio_clip : float, default 20.0
183
+ Clip sieve propensity ratios to ``[1/ratio_clip, ratio_clip]``.
184
+ kernel_bandwidth : float or None
185
+ Bandwidth for Gaussian kernel in conditional Omega* estimation.
186
+ None = Silverman's rule-of-thumb (automatic).
187
+
188
+ Examples
189
+ --------
190
+ >>> from diff_diff import EfficientDiD
191
+ >>> edid = EfficientDiD(pt_assumption="all")
192
+ >>> results = edid.fit(data, outcome="y", unit="id", time="t",
193
+ ... first_treat="first_treat", aggregate="all")
194
+ >>> results.print_summary()
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ pt_assumption: str = "all",
200
+ alpha: float = 0.05,
201
+ cluster: Optional[str] = None,
202
+ control_group: str = "never_treated",
203
+ n_bootstrap: int = 0,
204
+ bootstrap_weights: str = "rademacher",
205
+ seed: Optional[int] = None,
206
+ anticipation: int = 0,
207
+ sieve_k_max: Optional[int] = None,
208
+ sieve_criterion: str = "bic",
209
+ ratio_clip: float = 20.0,
210
+ kernel_bandwidth: Optional[float] = None,
211
+ ):
212
+ self.pt_assumption = pt_assumption
213
+ self.alpha = alpha
214
+ self.cluster = cluster
215
+ self.control_group = control_group
216
+ self.n_bootstrap = n_bootstrap
217
+ self.bootstrap_weights = bootstrap_weights
218
+ self.seed = seed
219
+ self.anticipation = anticipation
220
+ self.sieve_k_max = sieve_k_max
221
+ self.sieve_criterion = sieve_criterion
222
+ self.ratio_clip = ratio_clip
223
+ self.kernel_bandwidth = kernel_bandwidth
224
+ self.is_fitted_ = False
225
+ self.results_: Optional[EfficientDiDResults] = None
226
+ self._unit_resolved_survey = None
227
+ self._validate_params()
228
+
229
+ def _validate_params(self) -> None:
230
+ """Validate constrained parameters."""
231
+ if self.pt_assumption not in ("all", "post"):
232
+ raise ValueError(f"pt_assumption must be 'all' or 'post', got '{self.pt_assumption}'")
233
+ if self.control_group not in ("never_treated", "last_cohort"):
234
+ raise ValueError(
235
+ f"control_group must be 'never_treated' or 'last_cohort', "
236
+ f"got '{self.control_group}'"
237
+ )
238
+ valid_weights = ("rademacher", "mammen", "webb")
239
+ if self.bootstrap_weights not in valid_weights:
240
+ raise ValueError(
241
+ f"bootstrap_weights must be one of {valid_weights}, "
242
+ f"got '{self.bootstrap_weights}'"
243
+ )
244
+ if self.sieve_criterion not in ("aic", "bic"):
245
+ raise ValueError(
246
+ f"sieve_criterion must be 'aic' or 'bic', got '{self.sieve_criterion}'"
247
+ )
248
+ if not (np.isfinite(self.ratio_clip) and self.ratio_clip > 1.0):
249
+ raise ValueError(f"ratio_clip must be finite and > 1.0, got {self.ratio_clip}")
250
+ if self.kernel_bandwidth is not None:
251
+ if not (np.isfinite(self.kernel_bandwidth) and self.kernel_bandwidth > 0):
252
+ raise ValueError(
253
+ f"kernel_bandwidth must be finite and > 0 (or None for auto), "
254
+ f"got {self.kernel_bandwidth}"
255
+ )
256
+ if self.sieve_k_max is not None:
257
+ if not (isinstance(self.sieve_k_max, (int, np.integer)) and self.sieve_k_max > 0):
258
+ raise ValueError(
259
+ f"sieve_k_max must be a positive integer (or None for auto), "
260
+ f"got {self.sieve_k_max}"
261
+ )
262
+
263
+ # -- sklearn compatibility ------------------------------------------------
264
+
265
+ def get_params(self) -> Dict[str, Any]:
266
+ """Get estimator parameters (sklearn-compatible)."""
267
+ return {
268
+ "pt_assumption": self.pt_assumption,
269
+ "anticipation": self.anticipation,
270
+ "alpha": self.alpha,
271
+ "cluster": self.cluster,
272
+ "control_group": self.control_group,
273
+ "n_bootstrap": self.n_bootstrap,
274
+ "bootstrap_weights": self.bootstrap_weights,
275
+ "seed": self.seed,
276
+ "sieve_k_max": self.sieve_k_max,
277
+ "sieve_criterion": self.sieve_criterion,
278
+ "ratio_clip": self.ratio_clip,
279
+ "kernel_bandwidth": self.kernel_bandwidth,
280
+ }
281
+
282
+ def set_params(self, **params: Any) -> "EfficientDiD":
283
+ """Set estimator parameters (sklearn-compatible)."""
284
+ for key, value in params.items():
285
+ if hasattr(self, key):
286
+ setattr(self, key, value)
287
+ else:
288
+ raise ValueError(f"Unknown parameter: {key}")
289
+ self._validate_params()
290
+ return self
291
+
292
+ # -- Main estimation ------------------------------------------------------
293
+
294
+ def fit(
295
+ self,
296
+ data: pd.DataFrame,
297
+ outcome: str,
298
+ unit: str,
299
+ time: str,
300
+ first_treat: str,
301
+ covariates: Optional[List[str]] = None,
302
+ aggregate: Optional[str] = None,
303
+ balance_e: Optional[int] = None,
304
+ survey_design: Optional[Any] = None,
305
+ store_eif: bool = False,
306
+ ) -> EfficientDiDResults:
307
+ """Fit the Efficient DiD estimator.
308
+
309
+ Parameters
310
+ ----------
311
+ data : DataFrame
312
+ Balanced panel data.
313
+ outcome : str
314
+ Outcome variable column name.
315
+ unit : str
316
+ Unit identifier column name.
317
+ time : str
318
+ Time period column name.
319
+ first_treat : str
320
+ Column indicating first treatment period.
321
+ Use 0 or ``np.inf`` for never-treated units.
322
+ covariates : list of str, optional
323
+ Column names for time-invariant unit-level covariates.
324
+ When provided, uses the doubly robust path (outcome regression
325
+ + propensity score ratios).
326
+ aggregate : str, optional
327
+ ``None``, ``"simple"``, ``"event_study"``, ``"group"``, or
328
+ ``"all"``.
329
+ balance_e : int, optional
330
+ Balance event study at this relative period.
331
+ survey_design : SurveyDesign, optional
332
+ Survey design specification for design-based inference.
333
+ Applies survey weights to all means, covariances, and cohort
334
+ fractions, and uses Taylor Series Linearization for SE
335
+ estimation. Cannot be combined with ``cluster``.
336
+ store_eif : bool, default False
337
+ Store per-(g,t) EIF vectors in the results object. Used
338
+ internally by :meth:`hausman_pretest`; not needed for
339
+ normal usage.
340
+
341
+ Returns
342
+ -------
343
+ EfficientDiDResults
344
+
345
+ Raises
346
+ ------
347
+ ValueError
348
+ Missing columns, unbalanced panel, non-absorbing treatment,
349
+ or PT-Post without a never-treated group.
350
+ """
351
+ self._validate_params()
352
+
353
+ if self.cluster is not None and survey_design is not None:
354
+ raise NotImplementedError(
355
+ "cluster and survey_design cannot both be set. "
356
+ "Use survey_design with PSU/strata for cluster-robust inference."
357
+ )
358
+
359
+ # Resolve survey design if provided
360
+ from diff_diff.survey import _resolve_survey_for_fit
361
+
362
+ resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
363
+ _resolve_survey_for_fit(survey_design, data, "analytical")
364
+ )
365
+
366
+ # Validate within-unit constancy for panel survey designs
367
+ if resolved_survey is not None:
368
+ from diff_diff.survey import _validate_unit_constant_survey
369
+
370
+ _validate_unit_constant_survey(data, unit, survey_design)
371
+
372
+ # Store survey df for safe_inference calls (t-distribution with survey df)
373
+ self._survey_df = survey_metadata.df_survey if survey_metadata is not None else None
374
+ # Guard: replicate design with undefined df → NaN inference
375
+ if (self._survey_df is None and resolved_survey is not None
376
+ and hasattr(resolved_survey, 'uses_replicate_variance')
377
+ and resolved_survey.uses_replicate_variance):
378
+ self._survey_df = 0
379
+
380
+ # Bootstrap + survey supported via PSU-level multiplier bootstrap.
381
+
382
+ # Normalize empty covariates list to None (use nocov path)
383
+ if covariates is not None and len(covariates) == 0:
384
+ covariates = None
385
+ use_covariates = covariates is not None
386
+
387
+ # ----- Validate inputs -----
388
+ required_cols = [outcome, unit, time, first_treat]
389
+ missing = [c for c in required_cols if c not in data.columns]
390
+ if missing:
391
+ raise ValueError(f"Missing columns: {missing}")
392
+
393
+ df = data.copy()
394
+ df[time] = pd.to_numeric(df[time])
395
+ df[first_treat] = pd.to_numeric(df[first_treat])
396
+
397
+ # Normalize never-treated: inf -> 0 internally, keep track
398
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
399
+ df.loc[df[first_treat] == np.inf, first_treat] = 0
400
+
401
+ time_periods = sorted(df[time].unique())
402
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
403
+
404
+ # Validate balanced panel
405
+ unit_period_counts = df.groupby(unit)[time].nunique()
406
+ n_periods = len(time_periods)
407
+ if (unit_period_counts != n_periods).any():
408
+ raise ValueError(
409
+ "Unbalanced panel detected. EfficientDiD requires a balanced "
410
+ "panel where every unit is observed in every time period."
411
+ )
412
+
413
+ # Reject non-finite outcomes (NaN/Inf corrupt Omega*/EIF calculations)
414
+ non_finite_mask = ~np.isfinite(df[outcome])
415
+ if non_finite_mask.any():
416
+ n_bad = int(non_finite_mask.sum())
417
+ raise ValueError(
418
+ f"Found {n_bad} non-finite value(s) in outcome column '{outcome}'. "
419
+ "EfficientDiD requires finite outcomes for all unit-period observations."
420
+ )
421
+
422
+ # Reject duplicate (unit, time) rows
423
+ dup_mask = df.duplicated(subset=[unit, time], keep=False)
424
+ if dup_mask.any():
425
+ n_dups = int(dup_mask.sum())
426
+ raise ValueError(
427
+ f"Found {n_dups} duplicate ({unit}, {time}) rows. "
428
+ "EfficientDiD requires exactly one observation per unit-period."
429
+ )
430
+
431
+ # Validate absorbing treatment (vectorized)
432
+ ft_nunique = df.groupby(unit)[first_treat].nunique()
433
+ bad_units = ft_nunique[ft_nunique > 1]
434
+ if len(bad_units) > 0:
435
+ uid = bad_units.index[0]
436
+ raise ValueError(
437
+ f"Non-absorbing treatment detected for unit {uid}: "
438
+ "first_treat value changes over time."
439
+ )
440
+
441
+ # Unit info
442
+ unit_info = (
443
+ df.groupby(unit)
444
+ .agg(
445
+ {
446
+ first_treat: "first",
447
+ "_never_treated": "first",
448
+ }
449
+ )
450
+ .reset_index()
451
+ )
452
+ n_treated_units = int((unit_info[first_treat] > 0).sum())
453
+ n_control_units = int(unit_info["_never_treated"].sum())
454
+
455
+ # Control group logic
456
+ if self.control_group == "last_cohort":
457
+ # Always reclassify last cohort as pseudo-control when requested
458
+ if not treatment_groups:
459
+ raise ValueError(
460
+ "No treated cohorts found. control_group='last_cohort' requires "
461
+ "at least 2 treatment cohorts."
462
+ )
463
+ last_g = max(treatment_groups)
464
+ treatment_groups = [g for g in treatment_groups if g != last_g]
465
+ if not treatment_groups:
466
+ raise ValueError("Only one treatment cohort; cannot use last_cohort control.")
467
+ effective_last = last_g - self.anticipation
468
+ time_periods = [t for t in time_periods if t < effective_last]
469
+ if len(time_periods) < 2:
470
+ raise ValueError(
471
+ "Fewer than 2 time periods remain after trimming for last_cohort control."
472
+ )
473
+ unit_info.loc[unit_info[first_treat] == last_g, first_treat] = 0
474
+ unit_info.loc[unit_info[first_treat] == 0, "_never_treated"] = True
475
+ n_treated_units = int((unit_info[first_treat] > 0).sum())
476
+ n_control_units = int(unit_info["_never_treated"].sum())
477
+ elif n_control_units == 0:
478
+ raise ValueError(
479
+ "No never-treated units found. Use control_group='last_cohort' "
480
+ "to use the last treatment cohort as a pseudo-control."
481
+ )
482
+
483
+ # ----- Prepare data -----
484
+ all_units = sorted(df[unit].unique())
485
+ n_units = len(all_units)
486
+
487
+ # Build unit-to-first-panel-row index aligned to all_units (sorted)
488
+ # order. The previous approach (groupby cumcount == 0) yielded
489
+ # first-appearance order which can differ from sorted order when the
490
+ # input DataFrame is not pre-sorted by unit.
491
+ first_pos: Dict[Any, int] = {}
492
+ for i, u in enumerate(df[unit].values):
493
+ if u not in first_pos:
494
+ first_pos[u] = i
495
+ self._unit_first_panel_row = np.array([first_pos[u] for u in all_units])
496
+
497
+ # Build unit-level ResolvedSurveyDesign once (avoids repeated
498
+ # construction in _compute_survey_eif_se and ensures consistent
499
+ # unit-level df for safe_inference t-distribution).
500
+ if resolved_survey is not None:
501
+ from diff_diff.survey import ResolvedSurveyDesign
502
+
503
+ row_idx = self._unit_first_panel_row
504
+ unit_weights_s = resolved_survey.weights[row_idx]
505
+ unit_strata = (
506
+ resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
507
+ )
508
+ unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
509
+ unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
510
+ n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
511
+ n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
512
+ self._unit_resolved_survey = resolved_survey.subset_to_units(
513
+ row_idx, unit_weights_s, unit_strata, unit_psu, unit_fpc,
514
+ n_strata_u, n_psu_u,
515
+ )
516
+ # Use unit-level df (not panel-level) for t-distribution
517
+ self._survey_df = self._unit_resolved_survey.df_survey
518
+ # Re-apply replicate guard: undefined df → NaN inference
519
+ if (self._survey_df is None
520
+ and self._unit_resolved_survey.uses_replicate_variance):
521
+ self._survey_df = 0
522
+ else:
523
+ self._unit_resolved_survey = None
524
+
525
+ # Build cluster mapping if cluster-robust SEs requested
526
+ if self.cluster is not None:
527
+ unit_cluster_indices, n_clusters = _validate_and_build_cluster_mapping(
528
+ df, unit, self.cluster, all_units
529
+ )
530
+ if n_clusters < 50:
531
+ warnings.warn(
532
+ f"Only {n_clusters} clusters. Analytical clustered SEs may "
533
+ "be unreliable. Consider n_bootstrap > 0 for cluster "
534
+ "bootstrap inference.",
535
+ UserWarning,
536
+ stacklevel=2,
537
+ )
538
+ else:
539
+ unit_cluster_indices = None
540
+ n_clusters = None
541
+
542
+ period_to_col = {p: i for i, p in enumerate(time_periods)}
543
+ period_1 = time_periods[0]
544
+ period_1_col = period_to_col[period_1]
545
+
546
+ # Pivot outcome to wide matrix (n_units, n_periods)
547
+ pivot = df.pivot(index=unit, columns=time, values=outcome)
548
+ # Reindex to match all_units ordering and time_periods column order
549
+ pivot = pivot.reindex(index=all_units, columns=time_periods)
550
+ outcome_wide = pivot.values.astype(float)
551
+
552
+ # Build cohort masks and fractions
553
+ unit_info_indexed = unit_info.set_index(unit)
554
+ unit_cohorts = unit_info_indexed.reindex(all_units)[first_treat].values.astype(
555
+ float
556
+ ) # 0 = never-treated
557
+
558
+ cohort_masks: Dict[float, np.ndarray] = {}
559
+ for g in treatment_groups:
560
+ cohort_masks[g] = unit_cohorts == g
561
+ never_treated_mask = unit_cohorts == 0
562
+ cohort_masks[np.inf] = never_treated_mask # also keyed by inf sentinel
563
+
564
+ # ----- Unit-level survey weights -----
565
+ # Survey weights in the panel are at obs level (unit x time).
566
+ # EfficientDiD works at unit level. Extract one weight per unit
567
+ # by taking the first observation per unit (balanced panel, so
568
+ # weights should be constant within unit).
569
+ unit_level_weights: Optional[np.ndarray] = None
570
+ if resolved_survey is not None:
571
+ # Use the resolved survey's weights (already normalized per weight_type)
572
+ # subset to unit level via _unit_first_panel_row (aligned to all_units)
573
+ unit_level_weights = self._unit_resolved_survey.weights
574
+ self._unit_level_weights = unit_level_weights
575
+
576
+ cohort_fractions: Dict[float, float] = {}
577
+ if unit_level_weights is not None:
578
+ # Survey-weighted cohort fractions: sum(w_i for i in cohort) / sum(w_i)
579
+ total_w = float(np.sum(unit_level_weights))
580
+ for g in treatment_groups:
581
+ cohort_fractions[g] = float(np.sum(unit_level_weights[cohort_masks[g]])) / total_w
582
+ cohort_fractions[np.inf] = (
583
+ float(np.sum(unit_level_weights[never_treated_mask])) / total_w
584
+ )
585
+ else:
586
+ for g in treatment_groups:
587
+ cohort_fractions[g] = float(np.sum(cohort_masks[g])) / n_units
588
+ cohort_fractions[np.inf] = float(np.sum(never_treated_mask)) / n_units
589
+
590
+ # ----- Small cohort warnings -----
591
+ for g in treatment_groups:
592
+ n_g = int(np.sum(cohort_masks[g]))
593
+ frac_g = cohort_fractions[g]
594
+ if n_g < 2:
595
+ warnings.warn(
596
+ f"Cohort {g} has only {n_g} unit. Omega* inversion and "
597
+ "EIF computation may be numerically unstable.",
598
+ UserWarning,
599
+ stacklevel=2,
600
+ )
601
+ elif frac_g < 0.01:
602
+ warnings.warn(
603
+ f"Cohort {g} represents {frac_g:.1%} of the sample (< 1%). "
604
+ "Efficient weights may be imprecise.",
605
+ UserWarning,
606
+ stacklevel=2,
607
+ )
608
+
609
+ # Guard: never-treated with zero survey weight → no valid comparisons
610
+ # Applies to both covariates (DR nuisance) and nocov (weighted means) paths
611
+ if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
612
+ raise ValueError(
613
+ "Never-treated group has zero survey weight. EfficientDiD "
614
+ "requires a never-treated control group with positive "
615
+ "survey weight for estimation."
616
+ )
617
+
618
+ # ----- Covariate preparation (if provided) -----
619
+ covariate_matrix: Optional[np.ndarray] = None
620
+ m_hat_cache: Dict[Tuple, np.ndarray] = {}
621
+ r_hat_cache: Dict[Tuple[float, float], np.ndarray] = {}
622
+ s_hat_cache: Dict[float, np.ndarray] = {} # inverse propensities per group
623
+
624
+ if use_covariates:
625
+ assert covariates is not None # for type narrowing
626
+
627
+ # Validate covariate columns exist
628
+ missing_cov = [c for c in covariates if c not in data.columns]
629
+ if missing_cov:
630
+ raise ValueError(f"Missing covariate columns: {missing_cov}")
631
+
632
+ # Validate no NaN/Inf in covariates
633
+ for col_name in covariates:
634
+ non_finite_cov = ~np.isfinite(pd.to_numeric(df[col_name], errors="coerce"))
635
+ if non_finite_cov.any():
636
+ n_bad = int(non_finite_cov.sum())
637
+ raise ValueError(
638
+ f"Found {n_bad} non-finite value(s) in covariate column "
639
+ f"'{col_name}'. Covariates must be finite."
640
+ )
641
+
642
+ # Validate time-invariance: covariates must be constant within each unit
643
+ for col_name in covariates:
644
+ cov_nunique = df.groupby(unit)[col_name].nunique()
645
+ varying = cov_nunique[cov_nunique > 1]
646
+ if len(varying) > 0:
647
+ uid = varying.index[0]
648
+ raise ValueError(
649
+ f"Covariate '{col_name}' varies over time for unit {uid}. "
650
+ "EfficientDiD requires time-invariant covariates. "
651
+ "Extract base-period values before calling fit()."
652
+ )
653
+
654
+ # Extract unit-level covariate matrix from period_1 observations
655
+ base_df = df[df[time] == period_1].set_index(unit).reindex(all_units)
656
+ covariate_matrix = base_df[list(covariates)].values.astype(float)
657
+
658
+ # ----- Core estimation: ATT(g, t) for each target -----
659
+ # Precompute per-group unit counts (avoid repeated np.sum in loop)
660
+ n_treated_per_g = {g: int(np.sum(cohort_masks[g])) for g in treatment_groups}
661
+ n_control_count = int(np.sum(never_treated_mask))
662
+
663
+ group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] = {}
664
+ eif_by_gt: Dict[Tuple[Any, Any], np.ndarray] = {}
665
+ stored_weights: Dict[Tuple[Any, Any], np.ndarray] = {}
666
+ stored_cond: Dict[Tuple[Any, Any], float] = {}
667
+
668
+ for g in treatment_groups:
669
+ # Under PT-Post, use per-group baseline Y_{g-1-anticipation}
670
+ # instead of the universal Y_1. This implements the weaker
671
+ # PT-Post assumption (parallel trends only from g-1 onward),
672
+ # matching the Callaway-Sant'Anna estimator exactly.
673
+ if self.pt_assumption == "post":
674
+ effective_base = g - 1 - self.anticipation
675
+ if effective_base not in period_to_col:
676
+ warnings.warn(
677
+ f"Cohort g={g} dropped: baseline period {effective_base} "
678
+ f"(g-1-anticipation) is not in the data.",
679
+ UserWarning,
680
+ stacklevel=2,
681
+ )
682
+ continue
683
+ effective_p1_col = period_to_col[effective_base]
684
+ else:
685
+ effective_p1_col = period_1_col
686
+
687
+ # Guard: skip cohorts with zero survey weight (all units zero-weighted)
688
+ if cohort_fractions[g] <= 0:
689
+ warnings.warn(
690
+ f"Cohort {g} has zero survey weight; skipping.",
691
+ UserWarning,
692
+ stacklevel=2,
693
+ )
694
+ continue
695
+
696
+ # Estimate all (g, t) cells including pre-treatment. Under PT-Post,
697
+ # pre-treatment cells serve as placebo/pre-trend diagnostics, matching
698
+ # the CallawaySantAnna implementation. Users filter to t >= g for
699
+ # post-treatment effects; pre-treatment cells are clearly labeled by
700
+ # their (g, t) coordinates in the results object.
701
+ for t in time_periods:
702
+ # Skip period_1 — it's the universal reference baseline,
703
+ # not a target period
704
+ if t == period_1:
705
+ continue
706
+
707
+ # Enumerate valid comparison pairs
708
+ pairs = enumerate_valid_triples(
709
+ target_g=g,
710
+ treatment_groups=treatment_groups,
711
+ time_periods=time_periods,
712
+ period_1=period_1,
713
+ pt_assumption=self.pt_assumption,
714
+ anticipation=self.anticipation,
715
+ )
716
+
717
+ # Filter out comparison pairs with zero survey weight
718
+ if unit_level_weights is not None and pairs:
719
+ pairs = [
720
+ (gp, tpre) for gp, tpre in pairs
721
+ if np.sum(unit_level_weights[
722
+ never_treated_mask if np.isinf(gp) else cohort_masks[gp]
723
+ ]) > 0
724
+ ]
725
+
726
+ if not pairs:
727
+ warnings.warn(
728
+ f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
729
+ UserWarning,
730
+ stacklevel=2,
731
+ )
732
+ t_stat, p_val, ci = np.nan, np.nan, (np.nan, np.nan)
733
+ group_time_effects[(g, t)] = {
734
+ "effect": np.nan,
735
+ "se": np.nan,
736
+ "t_stat": t_stat,
737
+ "p_value": p_val,
738
+ "conf_int": ci,
739
+ "n_treated": n_treated_per_g[g],
740
+ "n_control": n_control_count,
741
+ }
742
+ eif_by_gt[(g, t)] = np.zeros(n_units)
743
+ continue
744
+
745
+ if use_covariates:
746
+ assert covariate_matrix is not None
747
+ t_col_val = period_to_col[t]
748
+
749
+ # Lazily populate nuisance caches for this (g, t)
750
+ for gp, tpre in pairs:
751
+ tpre_col_val = period_to_col[tpre]
752
+ # m_{inf, t, tpre}(X)
753
+ key_inf_t = (np.inf, t_col_val, tpre_col_val)
754
+ if key_inf_t not in m_hat_cache:
755
+ m_hat_cache[key_inf_t] = estimate_outcome_regression(
756
+ outcome_wide,
757
+ covariate_matrix,
758
+ never_treated_mask,
759
+ t_col_val,
760
+ tpre_col_val,
761
+ unit_weights=unit_level_weights,
762
+ )
763
+ # m_{g', tpre, 1}(X)
764
+ key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
765
+ if key_gp_tpre not in m_hat_cache:
766
+ gp_mask_for_reg = (
767
+ never_treated_mask if np.isinf(gp) else cohort_masks[gp]
768
+ )
769
+ m_hat_cache[key_gp_tpre] = estimate_outcome_regression(
770
+ outcome_wide,
771
+ covariate_matrix,
772
+ gp_mask_for_reg,
773
+ tpre_col_val,
774
+ effective_p1_col,
775
+ unit_weights=unit_level_weights,
776
+ )
777
+ # r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
778
+ for comp in {np.inf, gp}:
779
+ rkey = (g, comp)
780
+ if rkey not in r_hat_cache:
781
+ comp_mask = (
782
+ never_treated_mask if np.isinf(comp) else cohort_masks[comp]
783
+ )
784
+ r_hat_cache[rkey] = estimate_propensity_ratio_sieve(
785
+ covariate_matrix,
786
+ cohort_masks[g],
787
+ comp_mask,
788
+ k_max=self.sieve_k_max,
789
+ criterion=self.sieve_criterion,
790
+ ratio_clip=self.ratio_clip,
791
+ unit_weights=unit_level_weights,
792
+ )
793
+
794
+ # Per-unit DR generated outcomes: shape (n_units, H)
795
+ gen_out = compute_generated_outcomes_cov(
796
+ target_g=g,
797
+ target_t=t,
798
+ valid_pairs=pairs,
799
+ outcome_wide=outcome_wide,
800
+ cohort_masks=cohort_masks,
801
+ never_treated_mask=never_treated_mask,
802
+ period_to_col=period_to_col,
803
+ period_1_col=effective_p1_col,
804
+ cohort_fractions=cohort_fractions,
805
+ m_hat_cache=m_hat_cache,
806
+ r_hat_cache=r_hat_cache,
807
+ )
808
+
809
+ y_hat = np.mean(gen_out, axis=0) # shape (H,)
810
+
811
+ # Inverse propensity estimation (algorithm step 4)
812
+ # s_hat_{g'}(X) = 1/p_{g'}(X) for Eq 3.12 scaling
813
+ for group_id in {g, np.inf} | {gp for gp, _ in pairs}:
814
+ if group_id not in s_hat_cache:
815
+ group_mask_s = (
816
+ never_treated_mask if np.isinf(group_id) else cohort_masks[group_id]
817
+ )
818
+ s_hat_cache[group_id] = estimate_inverse_propensity_sieve(
819
+ covariate_matrix,
820
+ group_mask_s,
821
+ k_max=self.sieve_k_max,
822
+ criterion=self.sieve_criterion,
823
+ unit_weights=unit_level_weights,
824
+ )
825
+
826
+ # Conditional Omega*(X) with per-unit propensities (Eq 3.12)
827
+ omega_cond = compute_omega_star_conditional(
828
+ target_g=g,
829
+ target_t=t,
830
+ valid_pairs=pairs,
831
+ outcome_wide=outcome_wide,
832
+ cohort_masks=cohort_masks,
833
+ never_treated_mask=never_treated_mask,
834
+ period_to_col=period_to_col,
835
+ period_1_col=effective_p1_col,
836
+ cohort_fractions=cohort_fractions,
837
+ covariate_matrix=covariate_matrix,
838
+ s_hat_cache=s_hat_cache,
839
+ bandwidth=self.kernel_bandwidth,
840
+ unit_weights=unit_level_weights,
841
+ )
842
+
843
+ # Per-unit weights: (n_units, H)
844
+ per_unit_w = compute_per_unit_weights(omega_cond)
845
+
846
+ # ATT = (survey-)weighted mean of per-unit DR scores
847
+ if per_unit_w.shape[1] > 0:
848
+ per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
849
+ if unit_level_weights is not None:
850
+ att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
851
+ else:
852
+ att_gt = float(np.mean(per_unit_scores))
853
+ else:
854
+ att_gt = np.nan
855
+
856
+ # EIF with per-unit weights (Remark 4.2: plug-in valid)
857
+ # Center on scalar ATT, not per-pair means (ensures mean(EIF) ≈ 0)
858
+ eif_vals = compute_eif_cov(per_unit_w, gen_out, att_gt, n_units)
859
+ eif_by_gt[(g, t)] = eif_vals
860
+ else:
861
+ # No-covariates path (closed-form)
862
+ omega = compute_omega_star_nocov(
863
+ target_g=g,
864
+ target_t=t,
865
+ valid_pairs=pairs,
866
+ outcome_wide=outcome_wide,
867
+ cohort_masks=cohort_masks,
868
+ never_treated_mask=never_treated_mask,
869
+ period_to_col=period_to_col,
870
+ period_1_col=effective_p1_col,
871
+ cohort_fractions=cohort_fractions,
872
+ unit_weights=unit_level_weights,
873
+ )
874
+
875
+ weights, _, cond_num = compute_efficient_weights(omega)
876
+ stored_weights[(g, t)] = weights
877
+ if omega.size > 0:
878
+ stored_cond[(g, t)] = cond_num
879
+
880
+ y_hat = compute_generated_outcomes_nocov(
881
+ target_g=g,
882
+ target_t=t,
883
+ valid_pairs=pairs,
884
+ outcome_wide=outcome_wide,
885
+ cohort_masks=cohort_masks,
886
+ never_treated_mask=never_treated_mask,
887
+ period_to_col=period_to_col,
888
+ period_1_col=effective_p1_col,
889
+ unit_weights=unit_level_weights,
890
+ )
891
+
892
+ att_gt = float(weights @ y_hat) if len(weights) > 0 else np.nan
893
+
894
+ eif_vals = compute_eif_nocov(
895
+ target_g=g,
896
+ target_t=t,
897
+ weights=weights,
898
+ valid_pairs=pairs,
899
+ outcome_wide=outcome_wide,
900
+ cohort_masks=cohort_masks,
901
+ never_treated_mask=never_treated_mask,
902
+ period_to_col=period_to_col,
903
+ period_1_col=effective_p1_col,
904
+ cohort_fractions=cohort_fractions,
905
+ n_units=n_units,
906
+ unit_weights=unit_level_weights,
907
+ )
908
+ eif_by_gt[(g, t)] = eif_vals
909
+
910
+ # Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
911
+ # With survey: use TSL variance via compute_survey_vcov
912
+ if self._unit_resolved_survey is not None:
913
+ se_gt = self._compute_survey_eif_se(eif_vals)
914
+ else:
915
+ se_gt = _compute_se_from_eif(
916
+ eif_vals, n_units, unit_cluster_indices, n_clusters
917
+ )
918
+
919
+ t_stat, p_val, ci = safe_inference(
920
+ att_gt, se_gt, alpha=self.alpha, df=self._survey_df
921
+ )
922
+
923
+ group_time_effects[(g, t)] = {
924
+ "effect": att_gt,
925
+ "se": se_gt,
926
+ "t_stat": t_stat,
927
+ "p_value": p_val,
928
+ "conf_int": ci,
929
+ "n_treated": int(np.sum(cohort_masks[g])),
930
+ "n_control": int(np.sum(never_treated_mask)),
931
+ }
932
+
933
+ if not group_time_effects:
934
+ raise ValueError(
935
+ "Could not estimate any group-time effects. "
936
+ "Check data has sufficient observations."
937
+ )
938
+
939
+ # ----- Aggregation -----
940
+ overall_att, overall_se = self._aggregate_overall(
941
+ group_time_effects,
942
+ eif_by_gt,
943
+ n_units,
944
+ cohort_fractions,
945
+ unit_cohorts,
946
+ cluster_indices=unit_cluster_indices,
947
+ n_clusters=n_clusters,
948
+ )
949
+ overall_t, overall_p, overall_ci = safe_inference(
950
+ overall_att, overall_se, alpha=self.alpha, df=self._survey_df
951
+ )
952
+
953
+ event_study_effects = None
954
+ group_effects = None
955
+
956
+ if aggregate in ("event_study", "all"):
957
+ event_study_effects = self._aggregate_event_study(
958
+ group_time_effects,
959
+ eif_by_gt,
960
+ n_units,
961
+ cohort_fractions,
962
+ treatment_groups,
963
+ time_periods,
964
+ balance_e,
965
+ unit_cohorts=unit_cohorts,
966
+ cluster_indices=unit_cluster_indices,
967
+ n_clusters=n_clusters,
968
+ )
969
+ if aggregate in ("group", "all"):
970
+ group_effects = self._aggregate_by_group(
971
+ group_time_effects,
972
+ eif_by_gt,
973
+ n_units,
974
+ cohort_fractions,
975
+ treatment_groups,
976
+ unit_cohorts=unit_cohorts,
977
+ cluster_indices=unit_cluster_indices,
978
+ n_clusters=n_clusters,
979
+ )
980
+
981
+ # ----- Bootstrap -----
982
+ # Reject replicate-weight designs for bootstrap — replicate variance
983
+ # is an analytical alternative, not compatible with bootstrap
984
+ if (
985
+ self.n_bootstrap > 0
986
+ and self._unit_resolved_survey is not None
987
+ and self._unit_resolved_survey.uses_replicate_variance
988
+ ):
989
+ raise NotImplementedError(
990
+ "EfficientDiD bootstrap (n_bootstrap > 0) is not supported "
991
+ "with replicate-weight survey designs. Replicate weights provide "
992
+ "analytical variance; use n_bootstrap=0 instead."
993
+ )
994
+ bootstrap_results = None
995
+ if self.n_bootstrap > 0 and eif_by_gt:
996
+ bootstrap_results = self._run_multiplier_bootstrap(
997
+ group_time_effects=group_time_effects,
998
+ eif_by_gt=eif_by_gt,
999
+ n_units=n_units,
1000
+ aggregate=aggregate,
1001
+ balance_e=balance_e,
1002
+ treatment_groups=treatment_groups,
1003
+ cohort_fractions=cohort_fractions,
1004
+ cluster_indices=unit_cluster_indices,
1005
+ n_clusters=n_clusters,
1006
+ resolved_survey=self._unit_resolved_survey,
1007
+ unit_level_weights=self._unit_level_weights,
1008
+ )
1009
+ # Update estimates with bootstrap inference
1010
+ overall_se = bootstrap_results.overall_att_se
1011
+ overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
1012
+ overall_p = bootstrap_results.overall_att_p_value
1013
+ overall_ci = bootstrap_results.overall_att_ci
1014
+
1015
+ for gt in group_time_effects:
1016
+ if gt in bootstrap_results.group_time_ses:
1017
+ group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt]
1018
+ group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt]
1019
+ group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt]
1020
+ eff = float(group_time_effects[gt]["effect"])
1021
+ se = float(group_time_effects[gt]["se"])
1022
+ group_time_effects[gt]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
1023
+
1024
+ es_cis = bootstrap_results.event_study_cis
1025
+ es_pvs = bootstrap_results.event_study_p_values
1026
+ if (
1027
+ event_study_effects is not None
1028
+ and bootstrap_results.event_study_ses is not None
1029
+ and es_cis is not None
1030
+ and es_pvs is not None
1031
+ ):
1032
+ for e in event_study_effects:
1033
+ if e in bootstrap_results.event_study_ses:
1034
+ event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
1035
+ event_study_effects[e]["conf_int"] = es_cis[e]
1036
+ event_study_effects[e]["p_value"] = es_pvs[e]
1037
+ eff = float(event_study_effects[e]["effect"])
1038
+ se = float(event_study_effects[e]["se"])
1039
+ event_study_effects[e]["t_stat"] = safe_inference(
1040
+ eff, se, alpha=self.alpha
1041
+ )[0]
1042
+
1043
+ g_cis = bootstrap_results.group_effect_cis
1044
+ g_pvs = bootstrap_results.group_effect_p_values
1045
+ if (
1046
+ group_effects is not None
1047
+ and bootstrap_results.group_effect_ses is not None
1048
+ and g_cis is not None
1049
+ and g_pvs is not None
1050
+ ):
1051
+ for g in group_effects:
1052
+ if g in bootstrap_results.group_effect_ses:
1053
+ group_effects[g]["se"] = bootstrap_results.group_effect_ses[g]
1054
+ group_effects[g]["conf_int"] = g_cis[g]
1055
+ group_effects[g]["p_value"] = g_pvs[g]
1056
+ eff = float(group_effects[g]["effect"])
1057
+ se = float(group_effects[g]["se"])
1058
+ group_effects[g]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
1059
+
1060
+ # ----- Build results -----
1061
+ self.results_ = EfficientDiDResults(
1062
+ group_time_effects=group_time_effects,
1063
+ overall_att=overall_att,
1064
+ overall_se=overall_se,
1065
+ overall_t_stat=overall_t,
1066
+ overall_p_value=overall_p,
1067
+ overall_conf_int=overall_ci,
1068
+ groups=treatment_groups,
1069
+ time_periods=time_periods,
1070
+ n_obs=n_units * len(time_periods),
1071
+ n_treated_units=n_treated_units,
1072
+ n_control_units=n_control_units,
1073
+ alpha=self.alpha,
1074
+ pt_assumption=self.pt_assumption,
1075
+ anticipation=self.anticipation,
1076
+ n_bootstrap=self.n_bootstrap,
1077
+ bootstrap_weights=self.bootstrap_weights,
1078
+ seed=self.seed,
1079
+ event_study_effects=event_study_effects,
1080
+ group_effects=group_effects,
1081
+ efficient_weights=stored_weights if stored_weights else None,
1082
+ omega_condition_numbers=stored_cond if stored_cond else None,
1083
+ control_group=self.control_group,
1084
+ influence_functions=eif_by_gt if store_eif else None,
1085
+ bootstrap_results=bootstrap_results,
1086
+ estimation_path="dr" if use_covariates else "nocov",
1087
+ sieve_k_max=self.sieve_k_max,
1088
+ sieve_criterion=self.sieve_criterion,
1089
+ ratio_clip=self.ratio_clip,
1090
+ kernel_bandwidth=self.kernel_bandwidth,
1091
+ survey_metadata=(
1092
+ self._recompute_unit_survey_metadata(survey_metadata)
1093
+ if survey_metadata is not None
1094
+ else None
1095
+ ),
1096
+ )
1097
+ self.is_fitted_ = True
1098
+ return self.results_
1099
+
1100
+ def _recompute_unit_survey_metadata(self, panel_metadata):
1101
+ """Recompute survey metadata from unit-level design if available."""
1102
+ if self._unit_resolved_survey is not None:
1103
+ from diff_diff.survey import compute_survey_metadata
1104
+
1105
+ meta = compute_survey_metadata(
1106
+ self._unit_resolved_survey,
1107
+ self._unit_resolved_survey.weights,
1108
+ )
1109
+ # Propagate effective replicate df if available
1110
+ # (but not the df=0 sentinel — keep metadata as None for undefined df)
1111
+ if (self._survey_df is not None and self._survey_df != 0
1112
+ and meta.df_survey != self._survey_df):
1113
+ meta.df_survey = self._survey_df
1114
+ return meta
1115
+ return panel_metadata
1116
+
1117
+ # -- Survey SE helpers ----------------------------------------------------
1118
+
1119
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
    """Compute SE from EIF scores under the unit-level survey design.

    Uses the pre-built unit-level ``_unit_resolved_survey`` constructed
    once in ``fit()``, ensuring consistent unit-level arrays and
    avoiding repeated subsetting of panel-level survey data.

    Replicate-weight designs (``uses_replicate_variance``):

    - The EIF is rescaled to score scale, ``psi_i = w_i * eif_i / sum(w)``,
      to match the TSL bread.
    - It is passed to ``compute_replicate_if_variance``.
    - When fewer than ``n_replicates`` replicates are valid,
      ``self._survey_df`` is updated to ``n_valid - 1``, or ``None`` if
      ``n_valid <= 1``.

    All other designs use Taylor Series Linearization via
    ``compute_survey_vcov`` with an intercept-only design matrix.

    Parameters
    ----------
    eif_vals : np.ndarray
        Per-unit EIF values (one entry per unit of the unit-level design).

    Returns
    -------
    float
        ``sqrt(max(variance, 0))``, or NaN when the variance is not finite.
        Both variance paths share this conversion. A negative variance
        (numerical round-off) is clamped to zero rather than having its
        sign discarded, which would report a spurious positive SE.
    """
    design = self._unit_resolved_survey
    if design.uses_replicate_variance:
        from diff_diff.survey import compute_replicate_if_variance

        # Score-scale IFs to match TSL bread: psi = w * eif / sum(w)
        w = design.weights
        psi_scaled = w * eif_vals / w.sum()
        variance, n_valid = compute_replicate_if_variance(psi_scaled, design)
        # Update survey df to reflect effective replicate count
        if n_valid < design.n_replicates:
            self._survey_df = n_valid - 1 if n_valid > 1 else None
    else:
        from diff_diff.survey import compute_survey_vcov

        X_ones = np.ones((len(eif_vals), 1))
        vcov = compute_survey_vcov(X_ones, eif_vals, design)
        variance = float(vcov[0, 0])

    if not np.isfinite(variance):
        return np.nan
    return float(np.sqrt(max(variance, 0.0)))
1143
+
1144
+ def _eif_se(
1145
+ self,
1146
+ eif_vals: np.ndarray,
1147
+ n_units: int,
1148
+ cluster_indices: Optional[np.ndarray] = None,
1149
+ n_clusters: Optional[int] = None,
1150
+ ) -> float:
1151
+ """Compute SE from aggregated EIF scores.
1152
+
1153
+ Dispatches to survey TSL when ``_unit_resolved_survey`` is set
1154
+ (during fit), otherwise uses cluster-robust or standard formula.
1155
+ """
1156
+ if self._unit_resolved_survey is not None:
1157
+ return self._compute_survey_eif_se(eif_vals)
1158
+ return _compute_se_from_eif(eif_vals, n_units, cluster_indices, n_clusters)
1159
+
1160
+ # -- Aggregation helpers --------------------------------------------------
1161
+
1162
+ def _compute_wif_contribution(
1163
+ self,
1164
+ keepers: List[Tuple],
1165
+ effects: np.ndarray,
1166
+ unit_cohorts: np.ndarray,
1167
+ cohort_fractions: Dict[float, float],
1168
+ n_units: int,
1169
+ unit_weights: Optional[np.ndarray] = None,
1170
+ ) -> np.ndarray:
1171
+ """Compute weight influence function correction (O(1) scale, matching EIF).
1172
+
1173
+ This accounts for uncertainty in cohort-size aggregation weights.
1174
+ Matches R's ``did`` package WIF formula (staggered_aggregation.py:282-309),
1175
+ adapted to EDiD's EIF scale.
1176
+
1177
+ Parameters
1178
+ ----------
1179
+ keepers : list of (g, t) tuples
1180
+ Post-treatment group-time pairs included in aggregation.
1181
+ effects : ndarray, shape (n_keepers,)
1182
+ ATT estimates for each keeper.
1183
+ unit_cohorts : ndarray, shape (n_units,)
1184
+ Cohort assignment for each unit (0 = never-treated).
1185
+ cohort_fractions : dict
1186
+ ``{cohort: n_cohort / n}`` for each cohort.
1187
+ n_units : int
1188
+ Total number of units.
1189
+ unit_weights : ndarray, shape (n_units,), optional
1190
+ Survey weights at the unit level. When provided, uses the
1191
+ survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).
1192
+
1193
+ Returns
1194
+ -------
1195
+ ndarray, shape (n_units,)
1196
+ WIF contribution at O(1) scale, additive with ``agg_eif``.
1197
+ """
1198
+ groups_for_keepers = np.array([g for (g, t) in keepers])
1199
+ pg_keepers = np.array([cohort_fractions.get(g, 0.0) for g, t in keepers])
1200
+ sum_pg = pg_keepers.sum()
1201
+ if sum_pg == 0:
1202
+ return np.zeros(n_units)
1203
+
1204
+ indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)
1205
+
1206
+ if unit_weights is not None:
1207
+ # Survey-weighted WIF (matches staggered_aggregation.py:392-401):
1208
+ # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
1209
+ weighted_indicator = indicator * unit_weights[:, None]
1210
+ indicator_diff = weighted_indicator - pg_keepers
1211
+ indicator_sum = np.sum(indicator_diff, axis=1)
1212
+ else:
1213
+ indicator_diff = indicator - pg_keepers
1214
+ indicator_sum = np.sum(indicator_diff, axis=1)
1215
+
1216
+ with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
1217
+ if1 = indicator_diff / sum_pg
1218
+ if2 = np.outer(indicator_sum, pg_keepers) / sum_pg**2
1219
+ wif_matrix = if1 - if2
1220
+ wif_contrib = wif_matrix @ effects
1221
+ return wif_contrib # O(1) scale, same as agg_eif
1222
+
1223
+ def _aggregate_overall(
1224
+ self,
1225
+ group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
1226
+ eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
1227
+ n_units: int,
1228
+ cohort_fractions: Dict[float, float],
1229
+ unit_cohorts: np.ndarray,
1230
+ cluster_indices: Optional[np.ndarray] = None,
1231
+ n_clusters: Optional[int] = None,
1232
+ ) -> Tuple[float, float]:
1233
+ """Compute overall ATT with WIF-adjusted SE.
1234
+
1235
+ Parameters
1236
+ ----------
1237
+ group_time_effects : dict
1238
+ Group-time ATT estimates.
1239
+ eif_by_gt : dict
1240
+ Per-unit EIF values for each (g, t).
1241
+ n_units : int
1242
+ Total number of units.
1243
+ cohort_fractions : dict
1244
+ Cohort size fractions.
1245
+ unit_cohorts : ndarray, shape (n_units,)
1246
+ Cohort assignment for each unit.
1247
+ """
1248
+ # Filter to post-treatment effects
1249
+ keepers = [
1250
+ (g, t)
1251
+ for (g, t) in group_time_effects
1252
+ if t >= g - self.anticipation and np.isfinite(group_time_effects[(g, t)]["effect"])
1253
+ ]
1254
+ if not keepers:
1255
+ return np.nan, np.nan
1256
+
1257
+ # Cohort-size weights
1258
+ pg = np.array([cohort_fractions.get(g, 0.0) for (g, _) in keepers])
1259
+ total_pg = pg.sum()
1260
+ if total_pg == 0:
1261
+ return np.nan, np.nan
1262
+ w = pg / total_pg
1263
+
1264
+ effects = np.array([group_time_effects[gt]["effect"] for gt in keepers])
1265
+ overall_att = float(np.sum(w * effects))
1266
+
1267
+ # Aggregate EIF
1268
+ agg_eif = np.zeros(n_units)
1269
+ for k, gt in enumerate(keepers):
1270
+ agg_eif += w[k] * eif_by_gt[gt]
1271
+
1272
+ # WIF correction: accounts for uncertainty in cohort-size weights
1273
+ wif = self._compute_wif_contribution(
1274
+ keepers, effects, unit_cohorts, cohort_fractions, n_units,
1275
+ unit_weights=self._unit_level_weights,
1276
+ )
1277
+ # Compute SE: survey path uses score-level psi to avoid double-weighting
1278
+ # (compute_survey_vcov applies w_i internally, which would double-weight
1279
+ # the survey-weighted WIF term). Dispatch replicate vs TSL.
1280
+ if self._unit_resolved_survey is not None:
1281
+ uw = self._unit_level_weights
1282
+ total_w = float(np.sum(uw))
1283
+ psi_total = uw * agg_eif / total_w + wif / total_w
1284
+
1285
+ if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1286
+ and self._unit_resolved_survey.uses_replicate_variance):
1287
+ from diff_diff.survey import compute_replicate_if_variance
1288
+
1289
+ variance, _ = compute_replicate_if_variance(
1290
+ psi_total, self._unit_resolved_survey
1291
+ )
1292
+ else:
1293
+ from diff_diff.survey import compute_survey_if_variance
1294
+
1295
+ variance = compute_survey_if_variance(
1296
+ psi_total, self._unit_resolved_survey
1297
+ )
1298
+ se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
1299
+ else:
1300
+ agg_eif_total = agg_eif + wif
1301
+ se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
1302
+
1303
+ return overall_att, se
1304
+
1305
+ def _aggregate_event_study(
1306
+ self,
1307
+ group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
1308
+ eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
1309
+ n_units: int,
1310
+ cohort_fractions: Dict[float, float],
1311
+ treatment_groups: List[Any],
1312
+ time_periods: List[Any],
1313
+ balance_e: Optional[int] = None,
1314
+ unit_cohorts: Optional[np.ndarray] = None,
1315
+ cluster_indices: Optional[np.ndarray] = None,
1316
+ n_clusters: Optional[int] = None,
1317
+ ) -> Dict[int, Dict[str, Any]]:
1318
+ """Aggregate ATT(g,t) by relative time e = t - g.
1319
+
1320
+ Parameters
1321
+ ----------
1322
+ group_time_effects : dict
1323
+ Group-time ATT estimates.
1324
+ eif_by_gt : dict
1325
+ Per-unit EIF values for each (g, t).
1326
+ n_units : int
1327
+ Total number of units.
1328
+ cohort_fractions : dict
1329
+ Cohort size fractions.
1330
+ treatment_groups : list
1331
+ Treatment cohort identifiers.
1332
+ time_periods : list
1333
+ All time periods.
1334
+ balance_e : int, optional
1335
+ Balance event study at this relative period.
1336
+ unit_cohorts : ndarray, optional
1337
+ Cohort assignment for each unit (for WIF correction).
1338
+ """
1339
+ # Organize by relative time
1340
+ effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
1341
+ for (g, t), data in group_time_effects.items():
1342
+ if not np.isfinite(data["effect"]):
1343
+ continue
1344
+ e = int(t - g)
1345
+ if e not in effects_by_e:
1346
+ effects_by_e[e] = []
1347
+ effects_by_e[e].append(((g, t), data["effect"], cohort_fractions.get(g, 0.0)))
1348
+
1349
+ # Balance if requested
1350
+ if balance_e is not None:
1351
+ groups_at_e = {gt[0] for gt, _, _ in effects_by_e.get(balance_e, [])}
1352
+ balanced: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
1353
+ for (g, t), data in group_time_effects.items():
1354
+ if not np.isfinite(data["effect"]):
1355
+ continue
1356
+ if g in groups_at_e:
1357
+ e = int(t - g)
1358
+ if e not in balanced:
1359
+ balanced[e] = []
1360
+ balanced[e].append(((g, t), data["effect"], cohort_fractions.get(g, 0.0)))
1361
+ effects_by_e = balanced
1362
+
1363
+ if balance_e is not None and not effects_by_e:
1364
+ warnings.warn(
1365
+ f"balance_e={balance_e}: no cohort has a finite effect at the "
1366
+ "anchor horizon. Event study will be empty.",
1367
+ UserWarning,
1368
+ stacklevel=2,
1369
+ )
1370
+
1371
+ result: Dict[int, Dict[str, Any]] = {}
1372
+ for e, elist in sorted(effects_by_e.items()):
1373
+ gt_pairs = [x[0] for x in elist]
1374
+ effs = np.array([x[1] for x in elist])
1375
+ pgs = np.array([x[2] for x in elist])
1376
+ total_pg = pgs.sum()
1377
+ w = pgs / total_pg if total_pg > 0 else np.ones(len(pgs)) / len(pgs)
1378
+
1379
+ agg_eff = float(np.sum(w * effs))
1380
+
1381
+ # Aggregate EIF
1382
+ agg_eif = np.zeros(n_units)
1383
+ for k, gt in enumerate(gt_pairs):
1384
+ agg_eif += w[k] * eif_by_gt[gt]
1385
+
1386
+ # WIF correction for event-study aggregation
1387
+ wif_e = np.zeros(n_units)
1388
+ if unit_cohorts is not None:
1389
+ es_keepers = [(g, t) for (g, t) in gt_pairs]
1390
+ es_effects = effs
1391
+ wif_e = self._compute_wif_contribution(
1392
+ es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units,
1393
+ unit_weights=self._unit_level_weights,
1394
+ )
1395
+
1396
+ if self._unit_resolved_survey is not None:
1397
+ uw = self._unit_level_weights
1398
+ total_w = float(np.sum(uw))
1399
+ psi_total = uw * agg_eif / total_w + wif_e / total_w
1400
+
1401
+ if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1402
+ and self._unit_resolved_survey.uses_replicate_variance):
1403
+ from diff_diff.survey import compute_replicate_if_variance
1404
+
1405
+ variance, _ = compute_replicate_if_variance(
1406
+ psi_total, self._unit_resolved_survey
1407
+ )
1408
+ else:
1409
+ from diff_diff.survey import compute_survey_if_variance
1410
+
1411
+ variance = compute_survey_if_variance(
1412
+ psi_total, self._unit_resolved_survey
1413
+ )
1414
+ agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
1415
+ else:
1416
+ agg_eif = agg_eif + wif_e
1417
+ agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
1418
+
1419
+ t_stat, p_val, ci = safe_inference(
1420
+ agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
1421
+ )
1422
+ result[e] = {
1423
+ "effect": agg_eff,
1424
+ "se": agg_se,
1425
+ "t_stat": t_stat,
1426
+ "p_value": p_val,
1427
+ "conf_int": ci,
1428
+ "n_groups": len(elist),
1429
+ }
1430
+
1431
+ return result
1432
+
1433
+ def _aggregate_by_group(
1434
+ self,
1435
+ group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
1436
+ eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
1437
+ n_units: int,
1438
+ cohort_fractions: Dict[float, float],
1439
+ treatment_groups: List[Any],
1440
+ unit_cohorts: Optional[np.ndarray] = None,
1441
+ cluster_indices: Optional[np.ndarray] = None,
1442
+ n_clusters: Optional[int] = None,
1443
+ ) -> Dict[Any, Dict[str, Any]]:
1444
+ """Aggregate ATT(g,t) by treatment cohort.
1445
+
1446
+ Parameters
1447
+ ----------
1448
+ group_time_effects : dict
1449
+ Group-time ATT estimates.
1450
+ eif_by_gt : dict
1451
+ Per-unit EIF values for each (g, t).
1452
+ n_units : int
1453
+ Total number of units.
1454
+ cohort_fractions : dict
1455
+ Cohort size fractions.
1456
+ treatment_groups : list
1457
+ Treatment cohort identifiers.
1458
+ unit_cohorts : ndarray, optional
1459
+ Cohort assignment for each unit (unused — group aggregation
1460
+ uses equal weights, not cohort-size weights).
1461
+ """
1462
+ result: Dict[Any, Dict[str, Any]] = {}
1463
+ for g in treatment_groups:
1464
+ g_gts = [
1465
+ (gg, t)
1466
+ for (gg, t) in group_time_effects
1467
+ if gg == g
1468
+ and t >= g - self.anticipation
1469
+ and np.isfinite(group_time_effects[(gg, t)]["effect"])
1470
+ ]
1471
+ if not g_gts:
1472
+ continue
1473
+
1474
+ effs = np.array([group_time_effects[gt]["effect"] for gt in g_gts])
1475
+ w = np.ones(len(effs)) / len(effs)
1476
+ agg_eff = float(np.sum(w * effs))
1477
+
1478
+ agg_eif = np.zeros(n_units)
1479
+ for k, gt in enumerate(g_gts):
1480
+ agg_eif += w[k] * eif_by_gt[gt]
1481
+ agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
1482
+
1483
+ t_stat, p_val, ci = safe_inference(
1484
+ agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
1485
+ )
1486
+ result[g] = {
1487
+ "effect": agg_eff,
1488
+ "se": agg_se,
1489
+ "t_stat": t_stat,
1490
+ "p_value": p_val,
1491
+ "conf_int": ci,
1492
+ "n_periods": len(g_gts),
1493
+ }
1494
+
1495
+ return result
1496
+
1497
def summary(self) -> str:
    """Return the formatted summary of the fitted estimation results.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet.
    """
    if self.is_fitted_:
        assert self.results_ is not None
        return self.results_.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
1503
+
1504
def print_summary(self) -> None:
    """Write the estimation summary (from ``summary()``) to stdout."""
    report = self.summary()
    print(report)
1507
+
1508
+ # -- Hausman pretest -------------------------------------------------------
1509
+
1510
+ @classmethod
1511
+ def hausman_pretest(
1512
+ cls,
1513
+ data: pd.DataFrame,
1514
+ outcome: str,
1515
+ unit: str,
1516
+ time: str,
1517
+ first_treat: str,
1518
+ covariates: Optional[List[str]] = None,
1519
+ cluster: Optional[str] = None,
1520
+ anticipation: int = 0,
1521
+ control_group: str = "never_treated",
1522
+ alpha: float = 0.05,
1523
+ **nuisance_kwargs: Any,
1524
+ ) -> HausmanPretestResult:
1525
+ """Hausman pretest for PT-All vs PT-Post (Theorem A.1).
1526
+
1527
+ Fits the estimator under both parallel trends assumptions and
1528
+ compares the results. Under H0 (PT-All holds), both are consistent
1529
+ but PT-All is more efficient. Rejection suggests PT-All is too
1530
+ strong; use PT-Post instead.
1531
+
1532
+ Parameters
1533
+ ----------
1534
+ data, outcome, unit, time, first_treat, covariates
1535
+ Same as :meth:`fit`.
1536
+ cluster : str, optional
1537
+ Cluster column for cluster-robust covariance.
1538
+ anticipation : int
1539
+ Anticipation periods.
1540
+ control_group : str
1541
+ ``"never_treated"`` or ``"last_cohort"``.
1542
+ alpha : float
1543
+ Significance level for the test.
1544
+ **nuisance_kwargs
1545
+ Passed to both fits (e.g. ``sieve_k_max``, ``ratio_clip``).
1546
+
1547
+ Returns
1548
+ -------
1549
+ HausmanPretestResult
1550
+ """
1551
+ from scipy.stats import chi2
1552
+
1553
+ # Fit under both assumptions (analytical SEs only, no bootstrap)
1554
+ common_kwargs = dict(
1555
+ cluster=cluster,
1556
+ control_group=control_group,
1557
+ anticipation=anticipation,
1558
+ n_bootstrap=0,
1559
+ **nuisance_kwargs,
1560
+ )
1561
+ fit_kwargs = dict(
1562
+ data=data,
1563
+ outcome=outcome,
1564
+ unit=unit,
1565
+ time=time,
1566
+ first_treat=first_treat,
1567
+ covariates=covariates,
1568
+ aggregate=None,
1569
+ )
1570
+
1571
+ edid_all = cls(pt_assumption="all", alpha=alpha, **common_kwargs)
1572
+ result_all = edid_all.fit(**fit_kwargs, store_eif=True)
1573
+
1574
+ edid_post = cls(pt_assumption="post", alpha=alpha, **common_kwargs)
1575
+ result_post = edid_post.fit(**fit_kwargs, store_eif=True)
1576
+
1577
+ # Find common (g,t) pairs — PT-Post pairs are a subset of PT-All
1578
+ common_gts = sorted(
1579
+ set(result_all.group_time_effects.keys()) & set(result_post.group_time_effects.keys())
1580
+ )
1581
+
1582
+ def _nan_result() -> HausmanPretestResult:
1583
+ return HausmanPretestResult(
1584
+ statistic=np.nan,
1585
+ p_value=np.nan,
1586
+ df=0,
1587
+ reject=False,
1588
+ alpha=alpha,
1589
+ att_all=result_all.overall_att,
1590
+ att_post=result_post.overall_att,
1591
+ recommendation="inconclusive",
1592
+ gt_details=None,
1593
+ )
1594
+
1595
+ if not common_gts:
1596
+ return _nan_result()
1597
+
1598
+ eif_all = result_all.influence_functions
1599
+ eif_post = result_post.influence_functions
1600
+ assert eif_all is not None and eif_post is not None
1601
+ n_units = len(next(iter(eif_all.values())))
1602
+
1603
+ # --- Aggregate to post-treatment ES(e) per Theorem A.1 ---
1604
+ # Derive cohort fractions from data for proper weights
1605
+ all_units_list = sorted(data[unit].unique())
1606
+ unit_cohorts = (
1607
+ data.groupby(unit)[first_treat].first().reindex(all_units_list).values.astype(float)
1608
+ )
1609
+ cohort_fractions: Dict[float, float] = {}
1610
+ for g in set(result_all.groups) | set(result_post.groups):
1611
+ cohort_fractions[g] = float(np.sum(unit_cohorts == g)) / n_units
1612
+
1613
+ def _aggregate_es(
1614
+ gt_effects: Dict, eif_dict: Dict, groups: List, ant: int
1615
+ ) -> Dict[int, Tuple[float, np.ndarray]]:
1616
+ """Aggregate (g,t) effects to post-treatment ES(e) with WIF-corrected EIF."""
1617
+ by_e: Dict[int, List[Tuple[Tuple, float, float, np.ndarray]]] = {}
1618
+ for (g, t), d in gt_effects.items():
1619
+ e = int(t - g)
1620
+ if e < -ant:
1621
+ continue
1622
+ if not np.isfinite(d["effect"]):
1623
+ continue
1624
+ if (g, t) not in eif_dict:
1625
+ continue
1626
+ eif_vec = eif_dict[(g, t)]
1627
+ if not np.all(np.isfinite(eif_vec)):
1628
+ continue
1629
+ pg = cohort_fractions.get(g, 0.0)
1630
+ if e not in by_e:
1631
+ by_e[e] = []
1632
+ by_e[e].append(((g, t), d["effect"], pg, eif_vec))
1633
+
1634
+ result: Dict[int, Tuple[float, np.ndarray]] = {}
1635
+ for e, items in by_e.items():
1636
+ if e < 0:
1637
+ continue
1638
+ effs = np.array([x[1] for x in items])
1639
+ pgs = np.array([x[2] for x in items])
1640
+ eifs = [x[3] for x in items]
1641
+ gt_pairs_e = [x[0] for x in items]
1642
+ total_pg = pgs.sum()
1643
+ w = pgs / total_pg if total_pg > 0 else np.ones(len(pgs)) / len(pgs)
1644
+ es_eff = float(np.sum(w * effs))
1645
+ es_eif = np.zeros(n_units)
1646
+ for k_idx in range(len(eifs)):
1647
+ es_eif += w[k_idx] * eifs[k_idx]
1648
+ # WIF correction for estimated cohort-size weights
1649
+ groups_e = np.array([g for (g, t) in gt_pairs_e])
1650
+ pg_e = np.array([cohort_fractions.get(g, 0.0) for g, t in gt_pairs_e])
1651
+ sum_pg = pg_e.sum()
1652
+ if sum_pg > 0:
1653
+ indicator = (unit_cohorts[:, None] == groups_e[None, :]).astype(float)
1654
+ indicator_sum = np.sum(indicator - pg_e, axis=1)
1655
+ with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
1656
+ if1 = (indicator - pg_e) / sum_pg
1657
+ if2 = np.outer(indicator_sum, pg_e) / sum_pg**2
1658
+ wif = (if1 - if2) @ effs
1659
+ es_eif = es_eif + wif
1660
+ result[e] = (es_eff, es_eif)
1661
+ return result
1662
+
1663
+ es_all = _aggregate_es(
1664
+ result_all.group_time_effects, eif_all, result_all.groups, anticipation
1665
+ )
1666
+ es_post = _aggregate_es(
1667
+ result_post.group_time_effects, eif_post, result_post.groups, anticipation
1668
+ )
1669
+
1670
+ # Find common post-treatment horizons
1671
+ common_e = sorted(set(es_all.keys()) & set(es_post.keys()))
1672
+ if not common_e:
1673
+ return _nan_result()
1674
+
1675
+ delta = np.array([es_post[e][0] - es_all[e][0] for e in common_e])
1676
+
1677
+ # Build ES(e)-level EIF matrices
1678
+ eif_all_mat = np.column_stack([es_all[e][1] for e in common_e])
1679
+ eif_post_mat = np.column_stack([es_post[e][1] for e in common_e])
1680
+
1681
+ # Filter units with non-finite EIF values
1682
+ row_finite = np.all(np.isfinite(eif_all_mat), axis=1) & np.all(
1683
+ np.isfinite(eif_post_mat), axis=1
1684
+ )
1685
+ cl_idx: Optional[np.ndarray] = None
1686
+ n_cl: Optional[int] = None
1687
+ if cluster is not None:
1688
+ cl_idx, n_cl = _validate_and_build_cluster_mapping(data, unit, cluster, all_units_list)
1689
+ if not np.all(row_finite):
1690
+ eif_all_mat = eif_all_mat[row_finite]
1691
+ eif_post_mat = eif_post_mat[row_finite]
1692
+ n_units = int(np.sum(row_finite))
1693
+ if cl_idx is not None:
1694
+ cl_idx = cl_idx[row_finite]
1695
+ # Recompute effective cluster count and remap to contiguous
1696
+ # indices — entire clusters may have been dropped by filtering
1697
+ unique_cl, cl_idx = np.unique(cl_idx, return_inverse=True)
1698
+ n_cl = len(unique_cl)
1699
+
1700
+ # Compute full covariance matrices
1701
+ if cl_idx is not None and n_cl is not None:
1702
+
1703
+ def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
1704
+ centered = _cluster_aggregate(eif_mat, cl_idx, n_cl)
1705
+ correction = n_cl / (n_cl - 1) if n_cl > 1 else 1.0
1706
+ return correction * (centered.T @ centered) / (n_units**2)
1707
+
1708
+ cov_all = _eif_cov(eif_all_mat)
1709
+ cov_post = _eif_cov(eif_post_mat)
1710
+ else:
1711
+ with np.errstate(over="ignore", invalid="ignore"):
1712
+ cov_all = (eif_all_mat.T @ eif_all_mat) / (n_units**2)
1713
+ cov_post = (eif_post_mat.T @ eif_post_mat) / (n_units**2)
1714
+
1715
+ V = cov_post - cov_all
1716
+
1717
+ if not np.all(np.isfinite(V)):
1718
+ warnings.warn(
1719
+ "Hausman covariance matrix contains non-finite values. " "The test is unreliable.",
1720
+ UserWarning,
1721
+ stacklevel=2,
1722
+ )
1723
+ return _nan_result()
1724
+
1725
+ # Eigendecompose V — check for non-PSD
1726
+ eigvals = np.linalg.eigvalsh(V)
1727
+ max_eigval = np.max(np.abs(eigvals)) if len(eigvals) > 0 else 0.0
1728
+ tol = max(1e-10 * max_eigval, 1e-15)
1729
+
1730
+ n_negative = int(np.sum(eigvals < -tol))
1731
+ if n_negative > 0:
1732
+ warnings.warn(
1733
+ f"Hausman variance-difference matrix V has {n_negative} "
1734
+ "substantially negative eigenvalue(s). The test may be "
1735
+ "unreliable (finite-sample efficiency reversal).",
1736
+ UserWarning,
1737
+ stacklevel=2,
1738
+ )
1739
+
1740
+ effective_rank = int(np.sum(eigvals > tol))
1741
+ if effective_rank == 0:
1742
+ return _nan_result()
1743
+
1744
+ V_pinv = np.linalg.pinv(V, rcond=tol / max_eigval if max_eigval > 0 else 1e-10)
1745
+ H = float(delta @ V_pinv @ delta)
1746
+ H = max(H, 0.0)
1747
+
1748
+ p_value = float(chi2.sf(H, df=effective_rank))
1749
+ reject = p_value < alpha
1750
+
1751
+ es_details = pd.DataFrame(
1752
+ {
1753
+ "relative_period": common_e,
1754
+ "es_all": [es_all[e][0] for e in common_e],
1755
+ "es_post": [es_post[e][0] for e in common_e],
1756
+ "delta": delta,
1757
+ }
1758
+ )
1759
+
1760
+ return HausmanPretestResult(
1761
+ statistic=H,
1762
+ p_value=p_value,
1763
+ df=effective_rank,
1764
+ reject=reject,
1765
+ alpha=alpha,
1766
+ att_all=result_all.overall_att,
1767
+ att_post=result_post.overall_att,
1768
+ recommendation="pt_post" if reject else "pt_all",
1769
+ gt_details=es_details,
1770
+ )