diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,2458 @@
1
+ """
2
+ Borusyak-Jaravel-Spiess (2024) Imputation DiD Estimator.
3
+
4
+ Implements the efficient imputation estimator for staggered
5
+ Difference-in-Differences from Borusyak, Jaravel & Spiess (2024),
6
+ "Revisiting Event-Study Designs: Robust and Efficient Estimation",
7
+ Review of Economic Studies.
8
+
9
+ The estimator:
10
+ 1. Runs OLS on untreated observations to estimate unit + time fixed effects
11
+ 2. Imputes counterfactual Y(0) for treated observations
12
+ 3. Aggregates imputed treatment effects with researcher-chosen weights
13
+
14
+ Inference uses the conservative clustered variance estimator (Theorem 3).
15
+ """
16
+
17
+ import warnings
18
+ from typing import Any, Dict, List, Optional, Set, Tuple
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+ from scipy import sparse, stats
23
+ from scipy.sparse.linalg import spsolve
24
+
25
+ from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin, _compute_target_weights
26
+ from diff_diff.imputation_results import ( # noqa: F401 (re-export)
27
+ ImputationBootstrapResults,
28
+ ImputationDiDResults,
29
+ )
30
+ from diff_diff.linalg import solve_ols
31
+ from diff_diff.utils import safe_inference
32
+
33
+ # =============================================================================
34
+ # Main Estimator
35
+ # =============================================================================
36
+
37
+
38
+ class ImputationDiD(ImputationDiDBootstrapMixin):
39
+ """
40
+ Borusyak-Jaravel-Spiess (2024) imputation DiD estimator.
41
+
42
+ This is the efficient estimator for staggered Difference-in-Differences
43
+ under parallel trends. It produces shorter confidence intervals than
44
+ Callaway-Sant'Anna (~50% shorter) and Sun-Abraham (2-3.5x shorter)
45
+ under homogeneous treatment effects.
46
+
47
+ The estimation procedure:
48
+ 1. Run OLS on untreated observations to estimate unit + time fixed effects
49
+ 2. Impute counterfactual Y(0) for treated observations
50
+ 3. Aggregate imputed treatment effects with researcher-chosen weights
51
+
52
+ Inference uses the conservative clustered variance estimator from Theorem 3
53
+ of the paper.
54
+
55
+ Parameters
56
+ ----------
57
+ anticipation : int, default=0
58
+ Number of periods before treatment where effects may occur.
59
+ alpha : float, default=0.05
60
+ Significance level for confidence intervals.
61
+ cluster : str, optional
62
+ Column name for cluster-robust standard errors.
63
+ If None, clusters at the unit level by default.
64
+ n_bootstrap : int, default=0
65
+ Number of bootstrap iterations. If 0, uses analytical inference
66
+ (conservative variance from Theorem 3).
67
+ bootstrap_weights : str, default="rademacher"
68
+ Type of bootstrap weights: "rademacher", "mammen", or "webb".
69
+ seed : int, optional
70
+ Random seed for reproducibility.
71
+ rank_deficient_action : str, default="warn"
72
+ Action when design matrix is rank-deficient:
73
+ - "warn": Issue warning and drop linearly dependent columns
74
+ - "error": Raise ValueError
75
+ - "silent": Drop columns silently
76
+ horizon_max : int, optional
77
+ Maximum event-study horizon. If set, event study effects are only
78
+ computed for |h| <= horizon_max.
79
+ aux_partition : str, default="cohort_horizon"
80
+ Controls the auxiliary model partition for Theorem 3 variance:
81
+ - "cohort_horizon": Groups by cohort x relative time (tightest SEs)
82
+ - "cohort": Groups by cohort only (more conservative)
83
+ - "horizon": Groups by relative time only (more conservative)
84
+ pretrends : bool, default=False
85
+ If True, event study includes pre-treatment horizons for visual
86
+ pre-trends assessment. Pre-period effects should be ~0 under
87
+ parallel trends. Only affects event_study aggregation; overall
88
+ ATT and group aggregation are unchanged.
89
+
90
+ Attributes
91
+ ----------
92
+ results_ : ImputationDiDResults
93
+ Estimation results after calling fit().
94
+ is_fitted_ : bool
95
+ Whether the model has been fitted.
96
+
97
+ Examples
98
+ --------
99
+ Basic usage:
100
+
101
+ >>> from diff_diff import ImputationDiD, generate_staggered_data
102
+ >>> data = generate_staggered_data(n_units=200, seed=42)
103
+ >>> est = ImputationDiD()
104
+ >>> results = est.fit(data, outcome='outcome', unit='unit',
105
+ ... time='time', first_treat='first_treat')
106
+ >>> results.print_summary()
107
+
108
+ With event study:
109
+
110
+ >>> est = ImputationDiD()
111
+ >>> results = est.fit(data, outcome='outcome', unit='unit',
112
+ ... time='time', first_treat='first_treat',
113
+ ... aggregate='event_study')
114
+ >>> from diff_diff import plot_event_study
115
+ >>> plot_event_study(results)
116
+
117
+ Notes
118
+ -----
119
+ The imputation estimator uses ALL untreated observations (never-treated +
120
+ not-yet-treated periods of eventually-treated units) to estimate the
121
+ counterfactual model. There is no ``control_group`` parameter because this
122
+ is fundamental to the method's efficiency.
123
+
124
+ References
125
+ ----------
126
+ Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event-Study
127
+ Designs: Robust and Efficient Estimation. Review of Economic Studies,
128
+ 91(6), 3253-3285.
129
+ """
130
+
131
def __init__(
    self,
    anticipation: int = 0,
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    bootstrap_weights: str = "rademacher",
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
    horizon_max: Optional[int] = None,
    aux_partition: str = "cohort_horizon",
    pretrends: bool = False,
):
    """
    Record the estimator configuration and reset the fitted state.

    Only the three closed-choice string options are validated here
    (``rank_deficient_action``, ``bootstrap_weights``, ``aux_partition``);
    every other argument is stored on the instance exactly as given.

    Raises
    ------
    ValueError
        If ``rank_deficient_action``, ``bootstrap_weights`` or
        ``aux_partition`` is not one of its accepted values.
    """
    # (supplied value, accepted values, message prefix). The tuple order is
    # the order in which the options are checked, so when several options
    # are invalid the first entry listed here is the one reported.
    choice_checks = (
        (
            rank_deficient_action,
            ("warn", "error", "silent"),
            "rank_deficient_action must be 'warn', 'error', or 'silent', ",
        ),
        (
            bootstrap_weights,
            ("rademacher", "mammen", "webb"),
            "bootstrap_weights must be 'rademacher', 'mammen', or 'webb', ",
        ),
        (
            aux_partition,
            ("cohort_horizon", "cohort", "horizon"),
            "aux_partition must be 'cohort_horizon', 'cohort', or 'horizon', ",
        ),
    )
    for supplied, accepted, prefix in choice_checks:
        if supplied in accepted:
            continue
        raise ValueError(f"{prefix}got '{supplied}'")

    # Options with a closed set of choices (validated above).
    self.rank_deficient_action = rank_deficient_action
    self.bootstrap_weights = bootstrap_weights
    self.aux_partition = aux_partition

    # Remaining configuration, stored without validation.
    self.anticipation = anticipation
    self.alpha = alpha
    self.cluster = cluster
    self.n_bootstrap = n_bootstrap
    self.seed = seed
    self.horizon_max = horizon_max
    self.pretrends = pretrends

    # Fitted state: starts empty.
    self.is_fitted_ = False
    self.results_: Optional[ImputationDiDResults] = None

    # Internal state preserved for pretrend_test()
    self._fit_data: Optional[Dict[str, Any]] = None
176
+
177
+ def fit(
178
+ self,
179
+ data: pd.DataFrame,
180
+ outcome: str,
181
+ unit: str,
182
+ time: str,
183
+ first_treat: str,
184
+ covariates: Optional[List[str]] = None,
185
+ aggregate: Optional[str] = None,
186
+ balance_e: Optional[int] = None,
187
+ survey_design: object = None,
188
+ ) -> ImputationDiDResults:
189
+ """
190
+ Fit the imputation DiD estimator.
191
+
192
+ Parameters
193
+ ----------
194
+ data : pd.DataFrame
195
+ Panel data with unit and time identifiers.
196
+ outcome : str
197
+ Name of outcome variable column.
198
+ unit : str
199
+ Name of unit identifier column.
200
+ time : str
201
+ Name of time period column.
202
+ first_treat : str
203
+ Name of column indicating when unit was first treated.
204
+ Use 0 (or np.inf) for never-treated units.
205
+ covariates : list of str, optional
206
+ List of covariate column names.
207
+ aggregate : str, optional
208
+ Aggregation mode: None/"simple" (overall ATT only),
209
+ "event_study", "group", or "all".
210
+ balance_e : int, optional
211
+ When computing event study, restrict to cohorts observed at all
212
+ relative times in [-balance_e, max_h].
213
+ survey_design : SurveyDesign, optional
214
+ Survey design specification for design-based inference. Supports
215
+ pweight only (aweight/fweight raise ValueError). Supports strata,
216
+ PSU, and FPC for design-based variance via compute_survey_if_variance().
217
+ Strata enters survey df for t-distribution inference.
218
+ Both analytical (n_bootstrap=0) and bootstrap inference are supported.
219
+
220
+ Returns
221
+ -------
222
+ ImputationDiDResults
223
+ Object containing all estimation results.
224
+
225
+ Raises
226
+ ------
227
+ ValueError
228
+ If required columns are missing or data validation fails.
229
+ """
230
+ # Validate inputs
231
+ required_cols = [outcome, unit, time, first_treat]
232
+ if covariates:
233
+ required_cols.extend(covariates)
234
+
235
+ missing = [c for c in required_cols if c not in data.columns]
236
+ if missing:
237
+ raise ValueError(f"Missing columns: {missing}")
238
+
239
+ # pretrends + analytical survey is supported (Phase 8e-iii).
240
+ # Replicate-weight surveys need per-replicate lead regression refits
241
+ # which are not yet implemented — reject that combination.
242
+ if (
243
+ self.pretrends
244
+ and survey_design is not None
245
+ and survey_design.replicate_method is not None
246
+ and aggregate in ("event_study", "all")
247
+ ):
248
+ raise NotImplementedError(
249
+ "pretrends=True is not yet compatible with replicate-weight "
250
+ "survey designs. Analytical survey designs (strata/PSU/FPC) "
251
+ "are supported. Use pretrends=False with replicate weights."
252
+ )
253
+
254
+ # Create working copy
255
+ df = data.copy()
256
+
257
+ # Resolve survey design if provided
258
+ from diff_diff.survey import (
259
+ _inject_cluster_as_psu,
260
+ _resolve_effective_cluster,
261
+ _resolve_survey_for_fit,
262
+ _validate_unit_constant_survey,
263
+ )
264
+
265
+ resolved_survey, survey_weights, _, survey_metadata = _resolve_survey_for_fit(
266
+ survey_design, data, "analytical"
267
+ )
268
+
269
+ _uses_replicate_imp = (
270
+ resolved_survey is not None and resolved_survey.uses_replicate_variance
271
+ )
272
+ if _uses_replicate_imp and self.n_bootstrap > 0:
273
+ raise ValueError(
274
+ "Cannot use n_bootstrap > 0 with replicate-weight survey designs. "
275
+ "Replicate weights provide their own variance estimation."
276
+ )
277
+ # Validate within-unit constancy for panel survey designs
278
+ if resolved_survey is not None:
279
+ _validate_unit_constant_survey(data, unit, survey_design)
280
+ if resolved_survey.weight_type != "pweight":
281
+ raise ValueError(
282
+ f"ImputationDiD survey support requires weight_type='pweight', "
283
+ f"got '{resolved_survey.weight_type}'. The survey variance math "
284
+ f"assumes probability weights (pweight)."
285
+ )
286
+ # FPC is supported — threaded through compute_survey_if_variance()
287
+ # in _compute_conservative_variance().
288
+
289
+ # Bootstrap + survey supported via PSU-level multiplier bootstrap.
290
+
291
+ # Ensure numeric types
292
+ df[time] = pd.to_numeric(df[time])
293
+ df[first_treat] = pd.to_numeric(df[first_treat])
294
+
295
+ # Validate absorbing treatment: first_treat must be constant within each unit
296
+ ft_nunique = df.groupby(unit)[first_treat].nunique()
297
+ non_constant = ft_nunique[ft_nunique > 1]
298
+ if len(non_constant) > 0:
299
+ example_unit = non_constant.index[0]
300
+ example_vals = sorted(df.loc[df[unit] == example_unit, first_treat].unique())
301
+ warnings.warn(
302
+ f"{len(non_constant)} unit(s) have non-constant '{first_treat}' "
303
+ f"values (e.g., unit '{example_unit}' has values {example_vals}). "
304
+ f"ImputationDiD assumes treatment is an absorbing state "
305
+ f"(once treated, always treated) with a single treatment onset "
306
+ f"time per unit. Non-constant first_treat violates this assumption "
307
+ f"and may produce unreliable estimates.",
308
+ UserWarning,
309
+ stacklevel=2,
310
+ )
311
+
312
+ # Coerce to per-unit value so downstream code
313
+ # (_never_treated, _treated, _rel_time) uses a single
314
+ # consistent first_treat per unit.
315
+ df[first_treat] = df.groupby(unit)[first_treat].transform("first")
316
+
317
+ # Identify treatment status
318
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
319
+
320
+ # Check for always-treated units (treated in all observed periods)
321
+ min_time = df[time].min()
322
+ always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time)
323
+ n_always_treated = df.loc[always_treated_mask, unit].nunique()
324
+ if n_always_treated > 0:
325
+ warnings.warn(
326
+ f"{n_always_treated} unit(s) are treated in all observed periods "
327
+ f"(first_treat <= {min_time}). These units have no untreated "
328
+ "observations and cannot contribute to the counterfactual model. "
329
+ "Their treatment effects will be imputed but may be unreliable.",
330
+ UserWarning,
331
+ stacklevel=2,
332
+ )
333
+
334
+ # Create treatment indicator D_it
335
+ # D_it = 1 if t >= first_treat and first_treat > 0
336
+ # With anticipation: D_it = 1 if t >= first_treat - anticipation
337
+ effective_treat = df[first_treat] - self.anticipation
338
+ df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat)
339
+
340
+ # Identify Omega_0 (untreated) and Omega_1 (treated)
341
+ omega_0_mask = ~df["_treated"]
342
+ omega_1_mask = df["_treated"]
343
+
344
+ n_omega_0 = int(omega_0_mask.sum())
345
+ n_omega_1 = int(omega_1_mask.sum())
346
+
347
+ if n_omega_0 == 0:
348
+ raise ValueError(
349
+ "No untreated observations found. Cannot estimate counterfactual model."
350
+ )
351
+ if n_omega_1 == 0:
352
+ raise ValueError("No treated observations found. Nothing to estimate.")
353
+
354
+ # Identify groups and time periods
355
+ time_periods = sorted(df[time].unique())
356
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and g != np.inf])
357
+
358
+ if len(treatment_groups) == 0:
359
+ raise ValueError("No treated units found. Check 'first_treat' column.")
360
+
361
+ # Unit info
362
+ unit_info = (
363
+ df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index()
364
+ )
365
+ n_treated_units = int((~unit_info["_never_treated"]).sum())
366
+ # Control units = units with at least one untreated observation
367
+ units_in_omega_0 = df.loc[omega_0_mask, unit].unique()
368
+ n_control_units = len(units_in_omega_0)
369
+
370
+ # Cluster variable
371
+ cluster_var = self.cluster if self.cluster is not None else unit
372
+ if self.cluster is not None and self.cluster not in df.columns:
373
+ raise ValueError(
374
+ f"Cluster column '{self.cluster}' not found in data. "
375
+ f"Available columns: {list(df.columns)}"
376
+ )
377
+
378
+ # Resolve effective cluster and inject cluster-as-PSU for survey variance
379
+ if resolved_survey is not None:
380
+ cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None
381
+ effective_cluster_ids = _resolve_effective_cluster(
382
+ resolved_survey,
383
+ cluster_ids_raw,
384
+ cluster_var if self.cluster is not None else None,
385
+ )
386
+ resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
387
+ # When survey PSU is present, use it as the effective cluster for
388
+ # Theorem 3 variance (PSU overrides unit-level clustering)
389
+ if resolved_survey.psu is not None:
390
+ # Create a temporary column with PSU IDs for cluster_var
391
+ df["_survey_cluster"] = resolved_survey.psu
392
+ cluster_var = "_survey_cluster"
393
+ # Recompute metadata after PSU injection
394
+ if resolved_survey.psu is not None and survey_metadata is not None:
395
+ from diff_diff.survey import compute_survey_metadata
396
+
397
+ raw_w = (
398
+ data[survey_design.weights].values.astype(np.float64)
399
+ if survey_design.weights
400
+ else np.ones(len(data), dtype=np.float64)
401
+ )
402
+ survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
403
+
404
+ # Compute relative time
405
+ df["_rel_time"] = np.where(
406
+ ~df["_never_treated"],
407
+ df[time] - df[first_treat],
408
+ np.nan,
409
+ )
410
+
411
+ # ---- Step 1: OLS on untreated observations ----
412
+ unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model(
413
+ df, outcome, unit, time, covariates, omega_0_mask, weights=survey_weights
414
+ )
415
+
416
+ # ---- Rank condition checks ----
417
+ # Check: every treated unit should have >= 1 untreated period (for unit FE)
418
+ treated_unit_ids = df.loc[omega_1_mask, unit].unique()
419
+ units_with_fe = set(unit_fe.keys())
420
+ units_missing_fe = set(treated_unit_ids) - units_with_fe
421
+
422
+ # Check: every post-treatment period should have >= 1 untreated unit (for time FE)
423
+ post_period_ids = df.loc[omega_1_mask, time].unique()
424
+ periods_with_fe = set(time_fe.keys())
425
+ periods_missing_fe = set(post_period_ids) - periods_with_fe
426
+
427
+ if units_missing_fe or periods_missing_fe:
428
+ parts = []
429
+ if units_missing_fe:
430
+ sorted_missing = sorted(units_missing_fe)
431
+ parts.append(
432
+ f"{len(units_missing_fe)} treated unit(s) have no untreated "
433
+ f"periods (units: {sorted_missing[:5]}"
434
+ f"{'...' if len(units_missing_fe) > 5 else ''})"
435
+ )
436
+ if periods_missing_fe:
437
+ sorted_missing = sorted(periods_missing_fe)
438
+ parts.append(
439
+ f"{len(periods_missing_fe)} post-treatment period(s) have no "
440
+ f"untreated units (periods: {sorted_missing[:5]}"
441
+ f"{'...' if len(periods_missing_fe) > 5 else ''})"
442
+ )
443
+ msg = (
444
+ "Rank condition violated: "
445
+ + "; ".join(parts)
446
+ + ". Affected treatment effects will be NaN."
447
+ )
448
+ if self.rank_deficient_action == "error":
449
+ raise ValueError(msg)
450
+ elif self.rank_deficient_action == "warn":
451
+ warnings.warn(msg, UserWarning, stacklevel=2)
452
+ # "silent": continue without warning
453
+
454
+ # ---- Step 2: Impute treatment effects ----
455
+ tau_hat, y_hat_0 = self._impute_treatment_effects(
456
+ df,
457
+ outcome,
458
+ unit,
459
+ time,
460
+ covariates,
461
+ omega_1_mask,
462
+ unit_fe,
463
+ time_fe,
464
+ grand_mean,
465
+ delta_hat,
466
+ )
467
+
468
+ # Store tau_hat in dataframe
469
+ df["_tau_hat"] = np.nan
470
+ df.loc[omega_1_mask, "_tau_hat"] = tau_hat
471
+
472
+ # ---- Step 3: Aggregate ----
473
+ # Always compute overall ATT (simple aggregation)
474
+ finite_mask = np.isfinite(tau_hat)
475
+ valid_tau = tau_hat[finite_mask]
476
+
477
+ if len(valid_tau) == 0:
478
+ overall_att = np.nan
479
+ elif survey_weights is not None:
480
+ # Survey-weighted ATT: use treated obs' survey weights
481
+ treated_survey_w = survey_weights[omega_1_mask.values]
482
+ w_finite = treated_survey_w[finite_mask]
483
+ overall_att = float(np.average(valid_tau, weights=w_finite))
484
+ else:
485
+ overall_att = float(np.mean(valid_tau))
486
+
487
+ # ---- Variance ----
488
+ _n_valid_rep_imp = None
489
+ _vcov_rep_imp = None
490
+ overall_se = np.nan # placeholder; overridden by replicate or conservative path
491
+
492
+ if not _uses_replicate_imp:
493
+ # Conservative variance (Theorem 3)
494
+ overall_weights = np.zeros(n_omega_1)
495
+ n_valid = int(finite_mask.sum())
496
+ if n_valid > 0:
497
+ if survey_weights is not None:
498
+ treated_sw = survey_weights[omega_1_mask.values]
499
+ sw_finite = treated_sw[finite_mask]
500
+ overall_weights[finite_mask] = sw_finite / sw_finite.sum()
501
+ else:
502
+ overall_weights[finite_mask] = 1.0 / n_valid
503
+
504
+ if n_valid == 0:
505
+ overall_se = np.nan
506
+ else:
507
+ overall_se = self._compute_conservative_variance(
508
+ df=df,
509
+ outcome=outcome,
510
+ unit=unit,
511
+ time=time,
512
+ first_treat=first_treat,
513
+ covariates=covariates,
514
+ omega_0_mask=omega_0_mask,
515
+ omega_1_mask=omega_1_mask,
516
+ unit_fe=unit_fe,
517
+ time_fe=time_fe,
518
+ grand_mean=grand_mean,
519
+ delta_hat=delta_hat,
520
+ weights=overall_weights,
521
+ cluster_var=cluster_var,
522
+ kept_cov_mask=kept_cov_mask,
523
+ survey_weights=survey_weights,
524
+ resolved_survey=(resolved_survey if not _uses_replicate_imp else None),
525
+ )
526
+
527
+ # Survey degrees of freedom for t-distribution inference
528
+ _survey_df = resolved_survey.df_survey if resolved_survey is not None else None
529
+ # Replicate df: rank-deficient → NaN inference; dropped replicates → n_valid-1
530
+ if _uses_replicate_imp and _survey_df is None:
531
+ _survey_df = 0 # rank-deficient replicate → NaN inference
532
+
533
+ # Compute overall inference (may be overridden by replicate below)
534
+ overall_t, overall_p, overall_ci = safe_inference(
535
+ overall_att, overall_se, alpha=self.alpha, df=_survey_df
536
+ )
537
+
538
+ # Event study and group aggregation (full-sample, for point estimates)
539
+ event_study_effects = None
540
+ group_effects = None
541
+
542
+ if aggregate in ("event_study", "all"):
543
+ event_study_effects = self._aggregate_event_study(
544
+ df=df,
545
+ outcome=outcome,
546
+ unit=unit,
547
+ time=time,
548
+ first_treat=first_treat,
549
+ covariates=covariates,
550
+ omega_0_mask=omega_0_mask,
551
+ omega_1_mask=omega_1_mask,
552
+ unit_fe=unit_fe,
553
+ time_fe=time_fe,
554
+ grand_mean=grand_mean,
555
+ delta_hat=delta_hat,
556
+ cluster_var=cluster_var,
557
+ treatment_groups=treatment_groups,
558
+ balance_e=balance_e,
559
+ kept_cov_mask=kept_cov_mask,
560
+ survey_weights=survey_weights,
561
+ survey_df=_survey_df,
562
+ resolved_survey=(resolved_survey if not _uses_replicate_imp else None),
563
+ )
564
+
565
+ if aggregate in ("group", "all"):
566
+ group_effects = self._aggregate_group(
567
+ df=df,
568
+ outcome=outcome,
569
+ unit=unit,
570
+ time=time,
571
+ first_treat=first_treat,
572
+ covariates=covariates,
573
+ omega_0_mask=omega_0_mask,
574
+ omega_1_mask=omega_1_mask,
575
+ unit_fe=unit_fe,
576
+ time_fe=time_fe,
577
+ grand_mean=grand_mean,
578
+ delta_hat=delta_hat,
579
+ cluster_var=cluster_var,
580
+ treatment_groups=treatment_groups,
581
+ kept_cov_mask=kept_cov_mask,
582
+ survey_weights=survey_weights,
583
+ survey_df=_survey_df,
584
+ resolved_survey=(resolved_survey if not _uses_replicate_imp else None),
585
+ )
586
+
587
+ # Replicate variance: derive keys from actual outputs (after filtering)
588
+ if _uses_replicate_imp:
589
+ from diff_diff.survey import compute_replicate_refit_variance
590
+
591
+ _rel_times_treated = df.loc[omega_1_mask, "_rel_time"].values
592
+ _cohorts_treated = df.loc[omega_1_mask, first_treat].values
593
+
594
+ # Derive keys from actual outputs (excludes filtered/Prop5/ref)
595
+ _sorted_rel_times = sorted(
596
+ e
597
+ for e in (event_study_effects or {}).keys()
598
+ if np.isfinite(event_study_effects[e]["effect"])
599
+ and event_study_effects[e].get("n_obs", 1) > 0
600
+ )
601
+ _sorted_groups = sorted(
602
+ g for g in (group_effects or {}).keys() if np.isfinite(group_effects[g]["effect"])
603
+ )
604
+ _n_es = len(_sorted_rel_times)
605
+
606
+ # Pre-compute balanced cohort mask for balance_e
607
+ _balanced_mask_treated = None
608
+ if balance_e is not None and _sorted_rel_times:
609
+ df_1 = df.loc[omega_1_mask]
610
+ rel_times_all = df_1["_rel_time"].values
611
+ all_horizons_full = sorted(set(int(h) for h in rel_times_all if np.isfinite(h)))
612
+ if self.horizon_max is not None:
613
+ all_horizons_full = [h for h in all_horizons_full if abs(h) <= self.horizon_max]
614
+ cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
615
+ _balanced_mask_treated = self._compute_balanced_cohort_mask(
616
+ df_1, first_treat, all_horizons_full, balance_e, cohort_rel_times
617
+ )
618
+
619
+ # Single vectorized refit: [overall, es_e0..., grp_g0...]
620
+ def _refit_imp(w_r):
621
+ ufe_r, tfe_r, gm_r, delta_r, _ = self._fit_untreated_model(
622
+ df,
623
+ outcome,
624
+ unit,
625
+ time,
626
+ covariates,
627
+ omega_0_mask,
628
+ weights=w_r,
629
+ )
630
+ tau_r, _ = self._impute_treatment_effects(
631
+ df,
632
+ outcome,
633
+ unit,
634
+ time,
635
+ covariates,
636
+ omega_1_mask,
637
+ ufe_r,
638
+ tfe_r,
639
+ gm_r,
640
+ delta_r,
641
+ )
642
+ fin = np.isfinite(tau_r)
643
+ treated_w = w_r[omega_1_mask.values]
644
+ results = []
645
+ # [0] Overall ATT
646
+ tw_fin = treated_w[fin]
647
+ tw_sum = np.sum(tw_fin)
648
+ results.append(
649
+ float(np.sum(tau_r[fin] * tw_fin) / tw_sum) if tw_sum > 0 else np.nan
650
+ )
651
+ # [1..n_es] Event-study (identified only)
652
+ for e in _sorted_rel_times:
653
+ mask_e = fin & (_rel_times_treated == e)
654
+ if _balanced_mask_treated is not None:
655
+ mask_e = mask_e & _balanced_mask_treated
656
+ tw_e = treated_w[mask_e]
657
+ s = np.sum(tw_e)
658
+ results.append(float(np.sum(tau_r[mask_e] * tw_e) / s) if s > 0 else np.nan)
659
+ # [n_es+1..] Group (identified only)
660
+ for g in _sorted_groups:
661
+ mask_g = fin & (_cohorts_treated == g)
662
+ tw_g = treated_w[mask_g]
663
+ s = np.sum(tw_g)
664
+ results.append(float(np.sum(tau_r[mask_g] * tw_g) / s) if s > 0 else np.nan)
665
+ return np.array(results)
666
+
667
+ # Build full-sample estimate from actual effects
668
+ _full_est = [overall_att]
669
+ _full_est.extend([event_study_effects[e]["effect"] for e in _sorted_rel_times])
670
+ _full_est.extend([group_effects[g]["effect"] for g in _sorted_groups])
671
+
672
+ _vcov_rep_imp, _n_valid_rep_imp = compute_replicate_refit_variance(
673
+ _refit_imp, np.array(_full_est), resolved_survey
674
+ )
675
+ overall_se = float(np.sqrt(max(_vcov_rep_imp[0, 0], 0.0)))
676
+
677
+ # Override df if replicates were dropped
678
+ if _n_valid_rep_imp < resolved_survey.n_replicates:
679
+ _survey_df = _n_valid_rep_imp - 1 if _n_valid_rep_imp > 1 else 0
680
+ if survey_metadata is not None:
681
+ survey_metadata.df_survey = _survey_df if _survey_df and _survey_df > 0 else None
682
+
683
+ overall_t, overall_p, overall_ci = safe_inference(
684
+ overall_att, overall_se, alpha=self.alpha, df=_survey_df
685
+ )
686
+
687
+ # Override event-study SEs from vcov diagonal
688
+ for i, e in enumerate(_sorted_rel_times):
689
+ if event_study_effects is not None and e in event_study_effects:
690
+ se_e = float(np.sqrt(max(_vcov_rep_imp[1 + i, 1 + i], 0.0)))
691
+ eff_e = event_study_effects[e]["effect"]
692
+ t_e, p_e, ci_e = safe_inference(eff_e, se_e, alpha=self.alpha, df=_survey_df)
693
+ event_study_effects[e]["se"] = se_e
694
+ event_study_effects[e]["t_stat"] = t_e
695
+ event_study_effects[e]["p_value"] = p_e
696
+ event_study_effects[e]["conf_int"] = ci_e
697
+
698
+ # Override group SEs from vcov diagonal
699
+ for j, g in enumerate(_sorted_groups):
700
+ if group_effects is not None and g in group_effects:
701
+ se_g = float(np.sqrt(max(_vcov_rep_imp[1 + _n_es + j, 1 + _n_es + j], 0.0)))
702
+ eff_g = group_effects[g]["effect"]
703
+ t_g, p_g, ci_g = safe_inference(eff_g, se_g, alpha=self.alpha, df=_survey_df)
704
+ group_effects[g]["se"] = se_g
705
+ group_effects[g]["t_stat"] = t_g
706
+ group_effects[g]["p_value"] = p_g
707
+ group_effects[g]["conf_int"] = ci_g
708
+
709
+ # Build treatment effects dataframe
710
+ treated_df = df.loc[omega_1_mask, [unit, time, "_tau_hat", "_rel_time"]].copy()
711
+ treated_df = treated_df.rename(columns={"_tau_hat": "tau_hat", "_rel_time": "rel_time"})
712
+ # Weights consistent with actual ATT: zero for NaN tau_hat
713
+ tau_finite = treated_df["tau_hat"].notna()
714
+ n_valid_te = int(tau_finite.sum())
715
+ if n_valid_te > 0:
716
+ if survey_weights is not None:
717
+ # Survey-weighted: use normalized survey weights for treated obs
718
+ treated_sw = survey_weights[omega_1_mask.values]
719
+ sw_finite = np.where(tau_finite, treated_sw, 0.0)
720
+ sw_sum = sw_finite.sum()
721
+ treated_df["weight"] = sw_finite / sw_sum if sw_sum > 0 else 0.0
722
+ else:
723
+ treated_df["weight"] = np.where(tau_finite, 1.0 / n_valid_te, 0.0)
724
+ else:
725
+ treated_df["weight"] = 0.0
726
+
727
+ # Store fit data for pretrend_test
728
+ self._fit_data = {
729
+ "df": df,
730
+ "outcome": outcome,
731
+ "unit": unit,
732
+ "time": time,
733
+ "first_treat": first_treat,
734
+ "covariates": covariates,
735
+ "omega_0_mask": omega_0_mask,
736
+ "omega_1_mask": omega_1_mask,
737
+ "cluster_var": cluster_var,
738
+ "unit_fe": unit_fe,
739
+ "time_fe": time_fe,
740
+ "grand_mean": grand_mean,
741
+ "delta_hat": delta_hat,
742
+ "kept_cov_mask": kept_cov_mask,
743
+ "survey_design": survey_design,
744
+ "resolved_survey": resolved_survey,
745
+ "survey_weights": survey_weights,
746
+ }
747
+
748
+ # Pre-compute cluster psi sums for bootstrap
749
+ psi_data = None
750
+ if self.n_bootstrap > 0 and n_valid > 0:
751
+ try:
752
+ # Extract survey weights for untreated obs (same as analytical path)
753
+ _sw_0 = survey_weights[omega_0_mask.values] if survey_weights is not None else None
754
+ # Extract survey weights for treated obs (event-study/group bootstrap paths)
755
+ _sw_1 = survey_weights[omega_1_mask.values] if survey_weights is not None else None
756
+ psi_data = self._precompute_bootstrap_psi(
757
+ df=df,
758
+ outcome=outcome,
759
+ unit=unit,
760
+ time=time,
761
+ first_treat=first_treat,
762
+ covariates=covariates,
763
+ omega_0_mask=omega_0_mask,
764
+ omega_1_mask=omega_1_mask,
765
+ unit_fe=unit_fe,
766
+ time_fe=time_fe,
767
+ grand_mean=grand_mean,
768
+ delta_hat=delta_hat,
769
+ cluster_var=cluster_var,
770
+ kept_cov_mask=kept_cov_mask,
771
+ overall_weights=overall_weights,
772
+ event_study_effects=event_study_effects,
773
+ group_effects=group_effects,
774
+ treatment_groups=treatment_groups,
775
+ tau_hat=tau_hat,
776
+ balance_e=balance_e,
777
+ survey_weights_0=_sw_0,
778
+ survey_weights_1=_sw_1,
779
+ )
780
+ except Exception as e:
781
+ warnings.warn(
782
+ f"Bootstrap pre-computation failed: {e}. " "Skipping bootstrap inference.",
783
+ UserWarning,
784
+ stacklevel=2,
785
+ )
786
+ psi_data = None
787
+
788
+ # Bootstrap
789
+ bootstrap_results = None
790
+ if self.n_bootstrap > 0 and psi_data is not None:
791
+ bootstrap_results = self._run_bootstrap(
792
+ original_att=overall_att,
793
+ original_event_study=event_study_effects,
794
+ original_group=group_effects,
795
+ psi_data=psi_data,
796
+ resolved_survey=resolved_survey,
797
+ )
798
+
799
+ # Update inference with bootstrap results
800
+ overall_se = bootstrap_results.overall_att_se
801
+ overall_t = (
802
+ overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan
803
+ )
804
+ overall_p = bootstrap_results.overall_att_p_value
805
+ overall_ci = bootstrap_results.overall_att_ci
806
+
807
+ # Update event study
808
+ if event_study_effects and bootstrap_results.event_study_ses:
809
+ for h in event_study_effects:
810
+ if (
811
+ h in bootstrap_results.event_study_ses
812
+ and event_study_effects[h].get("n_obs", 1) > 0
813
+ ):
814
+ event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h]
815
+ assert bootstrap_results.event_study_cis is not None
816
+ event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[h]
817
+ assert bootstrap_results.event_study_p_values is not None
818
+ event_study_effects[h]["p_value"] = bootstrap_results.event_study_p_values[
819
+ h
820
+ ]
821
+ eff_val = event_study_effects[h]["effect"]
822
+ se_val = event_study_effects[h]["se"]
823
+ event_study_effects[h]["t_stat"] = safe_inference(
824
+ eff_val, se_val, alpha=self.alpha
825
+ )[0]
826
+
827
+ # Update group effects
828
+ if group_effects and bootstrap_results.group_ses:
829
+ for g in group_effects:
830
+ if g in bootstrap_results.group_ses:
831
+ group_effects[g]["se"] = bootstrap_results.group_ses[g]
832
+ assert bootstrap_results.group_cis is not None
833
+ group_effects[g]["conf_int"] = bootstrap_results.group_cis[g]
834
+ assert bootstrap_results.group_p_values is not None
835
+ group_effects[g]["p_value"] = bootstrap_results.group_p_values[g]
836
+ eff_val = group_effects[g]["effect"]
837
+ se_val = group_effects[g]["se"]
838
+ group_effects[g]["t_stat"] = safe_inference(
839
+ eff_val, se_val, alpha=self.alpha
840
+ )[0]
841
+
842
+ # Construct results
843
+ self.results_ = ImputationDiDResults(
844
+ treatment_effects=treated_df,
845
+ overall_att=overall_att,
846
+ overall_se=overall_se,
847
+ overall_t_stat=overall_t,
848
+ overall_p_value=overall_p,
849
+ overall_conf_int=overall_ci,
850
+ event_study_effects=event_study_effects,
851
+ group_effects=group_effects,
852
+ groups=treatment_groups,
853
+ time_periods=time_periods,
854
+ n_obs=len(df),
855
+ n_treated_obs=n_omega_1,
856
+ n_untreated_obs=n_omega_0,
857
+ n_treated_units=n_treated_units,
858
+ n_control_units=n_control_units,
859
+ alpha=self.alpha,
860
+ bootstrap_results=bootstrap_results,
861
+ _estimator_ref=self,
862
+ survey_metadata=survey_metadata,
863
+ )
864
+
865
+ self.is_fitted_ = True
866
+ return self.results_
867
+
868
+ # =========================================================================
869
+ # Step 1: OLS on untreated observations
870
+ # =========================================================================
871
+
872
+ def _iterative_fe(
873
+ self,
874
+ y: np.ndarray,
875
+ unit_vals: np.ndarray,
876
+ time_vals: np.ndarray,
877
+ idx: pd.Index,
878
+ max_iter: int = 100,
879
+ tol: float = 1e-10,
880
+ weights: Optional[np.ndarray] = None,
881
+ ) -> Tuple[Dict[Any, float], Dict[Any, float]]:
882
+ """
883
+ Estimate unit and time FE via iterative alternating projection (Gauss-Seidel).
884
+
885
+ Converges to the exact OLS solution for both balanced and unbalanced panels.
886
+ For balanced panels, converges in 1-2 iterations (identical to one-pass).
887
+ For unbalanced panels, typically 5-20 iterations.
888
+
889
+ Parameters
890
+ ----------
891
+ weights : np.ndarray, optional
892
+ Survey weights. When provided, uses weighted group means
893
+ (sum(w*x)/sum(w)) instead of unweighted means.
894
+
895
+ Returns
896
+ -------
897
+ unit_fe : dict
898
+ Mapping from unit -> unit fixed effect.
899
+ time_fe : dict
900
+ Mapping from time -> time fixed effect.
901
+ """
902
+ n = len(y)
903
+ alpha = np.zeros(n) # unit FE broadcast to obs level
904
+ beta = np.zeros(n) # time FE broadcast to obs level
905
+
906
+ # Precompute per-group weight sums (invariant across iterations)
907
+ if weights is not None:
908
+ w_series = pd.Series(weights, index=idx)
909
+ wsum_t = w_series.groupby(time_vals).transform("sum").values
910
+ wsum_u = w_series.groupby(unit_vals).transform("sum").values
911
+
912
+ with np.errstate(invalid="ignore", divide="ignore"):
913
+ for iteration in range(max_iter):
914
+ resid_after_alpha = y - alpha
915
+ if weights is not None:
916
+ wr_t = pd.Series(resid_after_alpha * weights, index=idx)
917
+ beta_new = wr_t.groupby(time_vals).transform("sum").values / wsum_t
918
+ else:
919
+ beta_new = (
920
+ pd.Series(resid_after_alpha, index=idx)
921
+ .groupby(time_vals)
922
+ .transform("mean")
923
+ .values
924
+ )
925
+
926
+ resid_after_beta = y - beta_new
927
+ if weights is not None:
928
+ wr_u = pd.Series(resid_after_beta * weights, index=idx)
929
+ alpha_new = wr_u.groupby(unit_vals).transform("sum").values / wsum_u
930
+ else:
931
+ alpha_new = (
932
+ pd.Series(resid_after_beta, index=idx)
933
+ .groupby(unit_vals)
934
+ .transform("mean")
935
+ .values
936
+ )
937
+
938
+ # Check convergence on FE changes
939
+ max_change = max(
940
+ np.max(np.abs(alpha_new - alpha)),
941
+ np.max(np.abs(beta_new - beta)),
942
+ )
943
+ alpha = alpha_new
944
+ beta = beta_new
945
+ if max_change < tol:
946
+ break
947
+
948
+ unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
949
+ time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
950
+ return unit_fe, time_fe
951
+
952
+ @staticmethod
953
+ def _iterative_demean(
954
+ vals: np.ndarray,
955
+ unit_vals: np.ndarray,
956
+ time_vals: np.ndarray,
957
+ idx: pd.Index,
958
+ max_iter: int = 100,
959
+ tol: float = 1e-10,
960
+ weights: Optional[np.ndarray] = None,
961
+ ) -> np.ndarray:
962
+ """Demean a vector by iterative alternating projection (unit + time FE removal).
963
+
964
+ Converges to the exact within-transformation for both balanced and
965
+ unbalanced panels. For balanced panels, converges in 1-2 iterations.
966
+
967
+ Parameters
968
+ ----------
969
+ weights : np.ndarray, optional
970
+ Survey weights. When provided, uses weighted group means
971
+ (sum(w*x)/sum(w)) instead of unweighted means.
972
+ """
973
+ result = vals.copy()
974
+
975
+ # Precompute per-group weight sums (invariant across iterations)
976
+ if weights is not None:
977
+ w_series = pd.Series(weights, index=idx)
978
+ wsum_t = w_series.groupby(time_vals).transform("sum").values
979
+ wsum_u = w_series.groupby(unit_vals).transform("sum").values
980
+
981
+ with np.errstate(invalid="ignore", divide="ignore"):
982
+ for _ in range(max_iter):
983
+ if weights is not None:
984
+ wr_t = pd.Series(result * weights, index=idx)
985
+ time_means = wr_t.groupby(time_vals).transform("sum").values / wsum_t
986
+ else:
987
+ time_means = (
988
+ pd.Series(result, index=idx).groupby(time_vals).transform("mean").values
989
+ )
990
+ result_after_time = result - time_means
991
+ if weights is not None:
992
+ wr_u = pd.Series(result_after_time * weights, index=idx)
993
+ unit_means = wr_u.groupby(unit_vals).transform("sum").values / wsum_u
994
+ else:
995
+ unit_means = (
996
+ pd.Series(result_after_time, index=idx)
997
+ .groupby(unit_vals)
998
+ .transform("mean")
999
+ .values
1000
+ )
1001
+ result_new = result_after_time - unit_means
1002
+ if np.max(np.abs(result_new - result)) < tol:
1003
+ result = result_new
1004
+ break
1005
+ result = result_new
1006
+ return result
1007
+
1008
+ @staticmethod
1009
+ def _compute_balanced_cohort_mask(
1010
+ df_treated: pd.DataFrame,
1011
+ first_treat: str,
1012
+ all_horizons: List[int],
1013
+ balance_e: int,
1014
+ cohort_rel_times: Dict[Any, Set[int]],
1015
+ ) -> np.ndarray:
1016
+ """Compute boolean mask selecting treated obs from balanced cohorts.
1017
+
1018
+ A cohort is 'balanced' if it has observations at every relative time
1019
+ in [-balance_e, max(all_horizons)].
1020
+
1021
+ Parameters
1022
+ ----------
1023
+ df_treated : pd.DataFrame
1024
+ Post-treatment observations (Omega_1).
1025
+ first_treat : str
1026
+ Column name for cohort identifier.
1027
+ all_horizons : list of int
1028
+ Post-treatment horizons in the event study.
1029
+ balance_e : int
1030
+ Number of pre-treatment periods to require.
1031
+ cohort_rel_times : dict
1032
+ Maps each cohort value to the set of all observed relative times
1033
+ (including pre-treatment) from the full panel. Built by
1034
+ _build_cohort_rel_times().
1035
+ """
1036
+ if not all_horizons:
1037
+ return np.ones(len(df_treated), dtype=bool)
1038
+
1039
+ max_h = max(all_horizons)
1040
+ required_range = set(range(-balance_e, max_h + 1))
1041
+
1042
+ balanced_cohorts = set()
1043
+ for g, horizons in cohort_rel_times.items():
1044
+ if required_range.issubset(horizons):
1045
+ balanced_cohorts.add(g)
1046
+
1047
+ return df_treated[first_treat].isin(balanced_cohorts).values
1048
+
1049
+ @staticmethod
1050
+ def _build_cohort_rel_times(
1051
+ df: pd.DataFrame,
1052
+ first_treat: str,
1053
+ ) -> Dict[Any, Set[int]]:
1054
+ """Build mapping of cohort -> set of observed relative times from full panel.
1055
+
1056
+ Precondition: df must have '_never_treated' and '_rel_time' columns
1057
+ (set by fit() before any aggregation calls).
1058
+ """
1059
+ treated_mask = ~df["_never_treated"]
1060
+ treated_df = df.loc[treated_mask]
1061
+ result: Dict[Any, Set[int]] = {}
1062
+ ft_vals = treated_df[first_treat].values
1063
+ rt_vals = treated_df["_rel_time"].values
1064
+ for i in range(len(treated_df)):
1065
+ h = rt_vals[i]
1066
+ if np.isfinite(h):
1067
+ result.setdefault(ft_vals[i], set()).add(int(h))
1068
+ return result
1069
+
1070
def _fit_untreated_model(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    weights: Optional[np.ndarray] = None,
) -> Tuple[
    Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray]
]:
    """
    Step 1: fit unit + time fixed effects on the untreated observations.

    Without covariates the fixed effects come straight from
    ``_iterative_fe``. With covariates, the outcome and every covariate are
    first demeaned with ``_iterative_demean``, ``solve_ols`` estimates the
    covariate coefficients on the demeaned data, and the fixed effects are
    then recovered from the covariate-adjusted outcome.

    Parameters
    ----------
    df : pd.DataFrame
        Full panel.
    outcome, unit, time : str
        Column names of the outcome, unit id and time period.
    covariates : list of str, optional
        Covariate column names; ``None`` or empty means FE-only.
    omega_0_mask : pd.Series
        Boolean mask selecting the untreated rows of ``df``.
    weights : np.ndarray, optional
        Full-panel survey weights (same length as df). The untreated subset
        is extracted internally via omega_0_mask. When None, unweighted.

    Returns
    -------
    unit_fe : dict
        Unit fixed effects {unit_id: alpha_i}.
    time_fe : dict
        Time fixed effects {time_period: beta_t}.
    grand_mean : float
        Always 0.0 (the intercept is absorbed by the fixed effects).
    delta_hat : np.ndarray or None
        Covariate coefficients with non-finite entries replaced by 0;
        None when no covariates are given.
    kept_cov_mask : np.ndarray or None
        Boolean mask of shape (n_covariates,) marking covariates whose
        raw coefficient was finite. None if no covariates.
    """
    untreated = df.loc[omega_0_mask]
    untreated_w = None if weights is None else weights[omega_0_mask.values]
    outcome_vals = untreated[outcome].values.copy()

    has_covariates = covariates is not None and len(covariates) > 0
    if not has_covariates:
        # FE-only model: the alternating projection gives the effects directly.
        unit_fe, time_fe = self._iterative_fe(
            outcome_vals,
            untreated[unit].values,
            untreated[time].values,
            untreated.index,
            weights=untreated_w,
        )
        return unit_fe, time_fe, 0.0, None, None

    cov_matrix = untreated[covariates].values.copy()
    unit_ids = untreated[unit].values
    period_ids = untreated[time].values
    row_index = untreated.index

    def _within(column):
        # Strip unit and time group averages from one column.
        return self._iterative_demean(
            column, unit_ids, period_ids, row_index, weights=untreated_w
        )

    # Step A: demean the outcome and each covariate column.
    outcome_within = _within(outcome_vals)
    cov_within = np.column_stack([_within(cov_matrix[:, j]) for j in range(len(covariates))])

    # Step B: covariate coefficients from OLS on the demeaned data.
    ols_output = solve_ols(
        cov_within,
        outcome_within,
        return_vcov=False,
        rank_deficient_action=self.rank_deficient_action,
        column_names=covariates,
        weights=untreated_w,
    )
    delta_hat = ols_output[0]

    # Record which coefficients were finite before cleaning, so callers can
    # drop the remaining covariates from later design matrices.
    kept_cov_mask = np.isfinite(delta_hat)
    delta_hat_clean = np.where(kept_cov_mask, delta_hat, 0.0)

    # Step C: fixed effects from the covariate-adjusted outcome.
    adjusted_outcome = outcome_vals - np.dot(cov_matrix, delta_hat_clean)
    unit_fe, time_fe = self._iterative_fe(
        adjusted_outcome, unit_ids, period_ids, row_index, weights=untreated_w
    )
    return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask
1165
+
1166
+ # =========================================================================
1167
+ # Step 2: Impute counterfactuals
1168
+ # =========================================================================
1169
+
1170
+ def _impute_treatment_effects(
1171
+ self,
1172
+ df: pd.DataFrame,
1173
+ outcome: str,
1174
+ unit: str,
1175
+ time: str,
1176
+ covariates: Optional[List[str]],
1177
+ omega_1_mask: pd.Series,
1178
+ unit_fe: Dict[Any, float],
1179
+ time_fe: Dict[Any, float],
1180
+ grand_mean: float,
1181
+ delta_hat: Optional[np.ndarray],
1182
+ ) -> Tuple[np.ndarray, np.ndarray]:
1183
+ """
1184
+ Step 2: Impute Y(0) for treated observations and compute tau_hat.
1185
+
1186
+ Returns
1187
+ -------
1188
+ tau_hat : np.ndarray
1189
+ Imputed treatment effects for each treated observation.
1190
+ y_hat_0 : np.ndarray
1191
+ Imputed counterfactual Y(0).
1192
+ """
1193
+ df_1 = df.loc[omega_1_mask]
1194
+ n_1 = len(df_1)
1195
+
1196
+ # Look up unit and time FE
1197
+ alpha_i = df_1[unit].map(unit_fe).values
1198
+ beta_t = df_1[time].map(time_fe).values
1199
+
1200
+ # Handle missing FE (set to NaN)
1201
+ alpha_i = np.where(pd.isna(alpha_i), np.nan, alpha_i).astype(float)
1202
+ beta_t = np.where(pd.isna(beta_t), np.nan, beta_t).astype(float)
1203
+
1204
+ y_hat_0 = grand_mean + alpha_i + beta_t
1205
+
1206
+ if delta_hat is not None and covariates:
1207
+ X_1 = df_1[covariates].values
1208
+ y_hat_0 = y_hat_0 + np.dot(X_1, delta_hat)
1209
+
1210
+ tau_hat = df_1[outcome].values - y_hat_0
1211
+
1212
+ return tau_hat, y_hat_0
1213
+
1214
+ # =========================================================================
1215
+ # Conservative Variance (Theorem 3)
1216
+ # =========================================================================
1217
+
1218
def _compute_cluster_psi_sums(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    weights: np.ndarray,
    cluster_var: str,
    kept_cov_mask: Optional[np.ndarray] = None,
    survey_weights_0: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute cluster-level influence function sums (Theorem 3).

    psi_it = v_it * epsilon_tilde_it, summed within each cluster.

    Parameters
    ----------
    df : pd.DataFrame
        Full panel; must contain ``cluster_var``.
    outcome, unit, time, first_treat : str
        Column names.
    covariates : list of str, optional
        Covariate columns. ``None``/empty selects the closed-form FE-only
        weights; otherwise ``_compute_v_untreated_with_covariates`` is used.
    omega_0_mask, omega_1_mask : pd.Series
        Boolean masks selecting untreated and treated rows of ``df``.
    unit_fe, time_fe, grand_mean, delta_hat
        Step 1 estimates, forwarded to the residual helpers.
    weights : np.ndarray
        Aggregation weights for the treated observations, ordered like
        ``df.loc[omega_1_mask]``.
    cluster_var : str
        Column whose values define the clusters.
    kept_cov_mask : np.ndarray, optional
        Forwarded to ``_compute_v_untreated_with_covariates``.
    survey_weights_0 : np.ndarray, optional
        Survey weights of the untreated observations. When provided, the
        untreated denominators use weighted sums and each untreated v_it is
        multiplied by its survey weight.

    Returns
    -------
    cluster_psi_sums : np.ndarray
        Array of cluster-level psi sums.
    cluster_ids_unique : np.ndarray
        Unique cluster identifiers (matching order of psi sums).
    ve_product : np.ndarray
        Observation-level psi_it for every row of ``df`` (NaN replaced by 0).
    """
    df_0 = df.loc[omega_0_mask]
    df_1 = df.loc[omega_1_mask]
    n_0 = len(df_0)

    # ---- Compute v_it for treated observations ----
    v_treated = weights.copy()

    # ---- Compute v_it for untreated observations ----
    if covariates is None or len(covariates) == 0:
        # FE-only case: closed-form.
        # Accumulate the target weights by unit and by period in row order.
        w_by_unit: Dict[Any, float] = {}
        w_by_time: Dict[Any, float] = {}
        for u, t, w in zip(df_1[unit].values, df_1[time].values, weights):
            w_by_unit[u] = w_by_unit.get(u, 0.0) + w
            w_by_time[t] = w_by_time.get(t, 0.0) + w

        w_total = float(np.sum(weights))

        # Use survey-weighted sums for untreated denominators when present
        if survey_weights_0 is not None:
            sw0_series = pd.Series(survey_weights_0, index=df_0.index)
            n0_by_unit = sw0_series.groupby(df_0[unit]).sum().to_dict()
            n0_by_time = sw0_series.groupby(df_0[time]).sum().to_dict()
            n0_denom = float(np.sum(survey_weights_0))
        else:
            n0_by_unit = df_0.groupby(unit).size().to_dict()
            n0_by_time = df_0.groupby(time).size().to_dict()
            n0_denom = n_0

        v_untreated = np.zeros(n_0)
        untreated_pairs = zip(df_0[unit].values, df_0[time].values)
        for j, (u, t) in enumerate(untreated_pairs):
            w_i = w_by_unit.get(u, 0.0)
            w_t = w_by_time.get(t, 0.0)
            n0_i = n0_by_unit.get(u, 1)
            n0_t = n0_by_time.get(t, 1)
            base_v = -(w_i / n0_i + w_t / n0_t - w_total / n0_denom)
            # WLS projection requires per-obs survey weight factor
            if survey_weights_0 is not None:
                base_v *= survey_weights_0[j]
            v_untreated[j] = base_v
    else:
        v_untreated = self._compute_v_untreated_with_covariates(
            df_0,
            df_1,
            unit,
            time,
            covariates,
            weights,
            delta_hat,
            kept_cov_mask=kept_cov_mask,
            survey_weights_0=survey_weights_0,
        )

    # ---- Compute auxiliary model residuals (Equation 8) ----
    epsilon_treated = self._compute_auxiliary_residuals_treated(
        df_1,
        outcome,
        unit,
        time,
        first_treat,
        covariates,
        unit_fe,
        time_fe,
        grand_mean,
        delta_hat,
        v_treated,
    )
    epsilon_untreated = self._compute_residuals_untreated(
        df_0, outcome, unit, time, covariates, unit_fe, time_fe, grand_mean, delta_hat
    )

    # ---- psi_it = v_it * epsilon_tilde_it ----
    # Zero-initialised so that a row selected by neither mask contributes
    # exactly 0 instead of whatever bytes an uninitialised buffer holds.
    v_all = np.zeros(len(df))
    v_all[omega_1_mask.values] = v_treated
    v_all[omega_0_mask.values] = v_untreated

    eps_all = np.zeros(len(df))
    eps_all[omega_1_mask.values] = epsilon_treated
    eps_all[omega_0_mask.values] = epsilon_untreated

    ve_product = v_all * eps_all
    # NaN eps from missing FE (rank condition violation). Zero their variance
    # contribution — matches R's did_imputation which drops unimputable obs.
    np.nan_to_num(ve_product, copy=False, nan=0.0)

    # Sum within clusters
    cluster_ids = df[cluster_var].values
    ve_series = pd.Series(ve_product, index=df.index)
    cluster_sums = ve_series.groupby(cluster_ids).sum()

    return cluster_sums.values, cluster_sums.index.values, ve_product
1355
+
1356
def _compute_conservative_variance(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    weights: np.ndarray,
    cluster_var: str,
    kept_cov_mask: Optional[np.ndarray] = None,
    survey_weights: Optional[np.ndarray] = None,
    resolved_survey=None,
) -> float:
    """
    Standard error from the conservative clustered variance (Theorem 3, Eq. 7).

    Delegates to ``_compute_cluster_psi_sums`` and then either sums the
    squared cluster-level psi values or, when a survey design is supplied,
    hands the observation-level psi values to
    ``compute_survey_if_variance()``.

    Parameters
    ----------
    weights : np.ndarray
        Aggregation weights w_it for treated observations.
        Shape: (n_treated,), must sum to 1.
    survey_weights : np.ndarray, optional
        Full-panel survey weights. Only the untreated subset (selected by
        ``omega_0_mask``) is passed on to the psi computation.
    resolved_survey : ResolvedSurveyDesign, optional
        When provided, the variance comes from
        ``compute_survey_if_variance()`` instead of the cluster sums.

    Returns
    -------
    float
        Standard error; NaN when the design-based variance is NaN.
    """
    untreated_sw = None if survey_weights is None else survey_weights[omega_0_mask.values]
    psi_by_cluster, _, psi_by_obs = self._compute_cluster_psi_sums(
        df=df,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        omega_1_mask=omega_1_mask,
        unit_fe=unit_fe,
        time_fe=time_fe,
        grand_mean=grand_mean,
        delta_hat=delta_hat,
        weights=weights,
        cluster_var=cluster_var,
        kept_cov_mask=kept_cov_mask,
        survey_weights_0=untreated_sw,
    )

    if resolved_survey is None:
        # Plain cluster-robust form: sum of squared cluster psi sums.
        squared_total = float((psi_by_cluster**2).sum())
        return np.sqrt(max(squared_total, 0.0))

    # Design-based variance; imported lazily, only on the survey path.
    from diff_diff.survey import compute_survey_if_variance

    design_variance = compute_survey_if_variance(psi_by_obs, resolved_survey)
    if np.isnan(design_variance):
        return np.nan
    return np.sqrt(max(design_variance, 0.0))
1427
+
1428
+ def _compute_v_untreated_with_covariates(
1429
+ self,
1430
+ df_0: pd.DataFrame,
1431
+ df_1: pd.DataFrame,
1432
+ unit: str,
1433
+ time: str,
1434
+ covariates: List[str],
1435
+ weights: np.ndarray,
1436
+ delta_hat: Optional[np.ndarray],
1437
+ kept_cov_mask: Optional[np.ndarray] = None,
1438
+ survey_weights_0: Optional[np.ndarray] = None,
1439
+ ) -> np.ndarray:
1440
+ """
1441
+ Compute v_it for untreated observations with covariates.
1442
+
1443
+ Uses the projection: v_untreated = -A_0 (A_0'A_0)^{-1} A_1' w_treated
1444
+ When survey_weights_0 is provided, uses weighted normal equations:
1445
+ v_untreated = -A_0 (A_0' W A_0)^{-1} A_1' w_treated
1446
+
1447
+ Uses scipy.sparse for FE dummy columns to reduce memory from O(N*(U+T))
1448
+ to O(N) for the FE portion.
1449
+ """
1450
+ # Exclude rank-deficient covariates from design matrices
1451
+ if kept_cov_mask is not None and not np.all(kept_cov_mask):
1452
+ covariates = [c for c, k in zip(covariates, kept_cov_mask) if k]
1453
+
1454
+ units_0 = df_0[unit].values
1455
+ times_0 = df_0[time].values
1456
+ units_1 = df_1[unit].values
1457
+ times_1 = df_1[time].values
1458
+
1459
+ all_units = np.unique(np.concatenate([units_0, units_1]))
1460
+ all_times = np.unique(np.concatenate([times_0, times_1]))
1461
+ unit_to_idx = {u: i for i, u in enumerate(all_units)}
1462
+ time_to_idx = {t: i for i, t in enumerate(all_times)}
1463
+ n_units = len(all_units)
1464
+ n_times = len(all_times)
1465
+ n_cov = len(covariates)
1466
+ n_fe_cols = (n_units - 1) + (n_times - 1)
1467
+
1468
+ def _build_A_sparse(df_sub, unit_vals, time_vals):
1469
+ n = len(df_sub)
1470
+
1471
+ # Unit dummies (drop first) — vectorized
1472
+ u_indices = np.array([unit_to_idx[u] for u in unit_vals])
1473
+ u_mask = u_indices > 0 # skip first unit (dropped)
1474
+ u_rows = np.arange(n)[u_mask]
1475
+ u_cols = u_indices[u_mask] - 1
1476
+
1477
+ # Time dummies (drop first) — vectorized
1478
+ t_indices = np.array([time_to_idx[t] for t in time_vals])
1479
+ t_mask = t_indices > 0
1480
+ t_rows = np.arange(n)[t_mask]
1481
+ t_cols = (n_units - 1) + t_indices[t_mask] - 1
1482
+
1483
+ rows = np.concatenate([u_rows, t_rows])
1484
+ cols = np.concatenate([u_cols, t_cols])
1485
+ data = np.ones(len(rows))
1486
+
1487
+ A_fe = sparse.csr_matrix((data, (rows, cols)), shape=(n, n_fe_cols))
1488
+
1489
+ # Covariates (dense, typically few columns)
1490
+ if n_cov > 0:
1491
+ A_cov = sparse.csr_matrix(df_sub[covariates].values)
1492
+ A = sparse.hstack([A_fe, A_cov], format="csr")
1493
+ else:
1494
+ A = A_fe
1495
+
1496
+ return A
1497
+
1498
+ A_0 = _build_A_sparse(df_0, units_0, times_0)
1499
+ A_1 = _build_A_sparse(df_1, units_1, times_1)
1500
+
1501
+ # Compute A_1' w (sparse.T @ dense -> dense)
1502
+ A1_w = A_1.T @ weights # shape (p,)
1503
+
1504
+ # Solve (A_0' [W] A_0) z = A_1' w using sparse direct solver
1505
+ # When survey weights present, use weighted normal equations A_0' W A_0
1506
+ if survey_weights_0 is not None:
1507
+ A0tA0_sparse = A_0.T @ A_0.multiply(survey_weights_0[:, None])
1508
+ else:
1509
+ A0tA0_sparse = A_0.T @ A_0 # stays sparse
1510
+ try:
1511
+ z = spsolve(A0tA0_sparse.tocsc(), A1_w)
1512
+ except Exception:
1513
+ # Fallback to dense lstsq if sparse solver fails (e.g., singular matrix)
1514
+ A0tA0_dense = A0tA0_sparse.toarray()
1515
+ z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None)
1516
+
1517
+ # v_untreated = -[W_0] A_0 z (WLS projection requires per-obs weight)
1518
+ v_untreated = -(A_0 @ z)
1519
+ if survey_weights_0 is not None:
1520
+ v_untreated = v_untreated * survey_weights_0
1521
+ return v_untreated
1522
+
1523
+ def _compute_auxiliary_residuals_treated(
1524
+ self,
1525
+ df_1: pd.DataFrame,
1526
+ outcome: str,
1527
+ unit: str,
1528
+ time: str,
1529
+ first_treat: str,
1530
+ covariates: Optional[List[str]],
1531
+ unit_fe: Dict[Any, float],
1532
+ time_fe: Dict[Any, float],
1533
+ grand_mean: float,
1534
+ delta_hat: Optional[np.ndarray],
1535
+ v_treated: np.ndarray,
1536
+ ) -> np.ndarray:
1537
+ """
1538
+ Compute v_it-weighted auxiliary residuals for treated obs (Equation 8).
1539
+
1540
+ Computes v_it-weighted tau_tilde_g per Equation 8 of Borusyak et al. (2024):
1541
+ tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g.
1542
+
1543
+ epsilon_tilde_it = Y_it - alpha_i - beta_t [- X'delta] - tau_tilde_g
1544
+ """
1545
+ n_1 = len(df_1)
1546
+
1547
+ # Compute base residuals (Y - Y_hat(0) = tau_hat)
1548
+ # NaN for missing FE (consistent with _impute_treatment_effects)
1549
+ alpha_i = df_1[unit].map(unit_fe).values.astype(float) # NaN for missing
1550
+ beta_t = df_1[time].map(time_fe).values.astype(float) # NaN for missing
1551
+ y_hat_0 = grand_mean + alpha_i + beta_t
1552
+
1553
+ if delta_hat is not None and covariates:
1554
+ y_hat_0 = y_hat_0 + np.dot(df_1[covariates].values, delta_hat)
1555
+
1556
+ tau_hat = df_1[outcome].values - y_hat_0
1557
+
1558
+ # Partition Omega_1 and compute tau_tilde for each group
1559
+ if self.aux_partition == "cohort_horizon":
1560
+ group_keys = list(zip(df_1[first_treat].values, df_1["_rel_time"].values))
1561
+ elif self.aux_partition == "cohort":
1562
+ group_keys = list(df_1[first_treat].values)
1563
+ elif self.aux_partition == "horizon":
1564
+ group_keys = list(df_1["_rel_time"].values)
1565
+ else:
1566
+ group_keys = list(range(n_1)) # each obs is its own group
1567
+
1568
+ # Compute v_it-weighted average tau within each partition group (Equation 8)
1569
+ # tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g
1570
+ group_series = pd.Series(group_keys, index=df_1.index)
1571
+ tau_series = pd.Series(tau_hat, index=df_1.index)
1572
+ v_series = pd.Series(v_treated, index=df_1.index)
1573
+
1574
+ weighted_tau_sum = (v_series * tau_series).groupby(group_series).sum()
1575
+ weight_sum = v_series.groupby(group_series).sum()
1576
+
1577
+ # Guard: zero-weight groups -> their tau_tilde doesn't affect variance
1578
+ # (v_it ~ 0 means these obs contribute nothing to the estimand)
1579
+ # Use simple mean as fallback. This is common for event-study SE computation
1580
+ # where weights target a specific horizon, making other partition groups zero.
1581
+ zero_weight_groups = weight_sum.abs() < 1e-15
1582
+ if zero_weight_groups.any():
1583
+ simple_means = tau_series.groupby(group_series).mean()
1584
+ tau_tilde_map = weighted_tau_sum / weight_sum
1585
+ tau_tilde_map = tau_tilde_map.where(~zero_weight_groups, simple_means)
1586
+ else:
1587
+ tau_tilde_map = weighted_tau_sum / weight_sum
1588
+
1589
+ tau_tilde = group_series.map(tau_tilde_map).values
1590
+
1591
+ # Auxiliary residuals
1592
+ epsilon_treated = tau_hat - tau_tilde
1593
+
1594
+ return epsilon_treated
1595
+
1596
+ def _compute_residuals_untreated(
1597
+ self,
1598
+ df_0: pd.DataFrame,
1599
+ outcome: str,
1600
+ unit: str,
1601
+ time: str,
1602
+ covariates: Optional[List[str]],
1603
+ unit_fe: Dict[Any, float],
1604
+ time_fe: Dict[Any, float],
1605
+ grand_mean: float,
1606
+ delta_hat: Optional[np.ndarray],
1607
+ ) -> np.ndarray:
1608
+ """Compute Step 1 residuals for untreated observations."""
1609
+ alpha_i = df_0[unit].map(unit_fe).fillna(0.0).values
1610
+ beta_t = df_0[time].map(time_fe).fillna(0.0).values
1611
+ y_hat = grand_mean + alpha_i + beta_t
1612
+
1613
+ if delta_hat is not None and covariates:
1614
+ y_hat = y_hat + np.dot(df_0[covariates].values, delta_hat)
1615
+
1616
+ return df_0[outcome].values - y_hat
1617
+
1618
+ # =========================================================================
1619
+ # Aggregation
1620
+ # =========================================================================
1621
+
1622
+ def _aggregate_event_study(
1623
+ self,
1624
+ df: pd.DataFrame,
1625
+ outcome: str,
1626
+ unit: str,
1627
+ time: str,
1628
+ first_treat: str,
1629
+ covariates: Optional[List[str]],
1630
+ omega_0_mask: pd.Series,
1631
+ omega_1_mask: pd.Series,
1632
+ unit_fe: Dict[Any, float],
1633
+ time_fe: Dict[Any, float],
1634
+ grand_mean: float,
1635
+ delta_hat: Optional[np.ndarray],
1636
+ cluster_var: str,
1637
+ treatment_groups: List[Any],
1638
+ balance_e: Optional[int] = None,
1639
+ kept_cov_mask: Optional[np.ndarray] = None,
1640
+ survey_weights: Optional[np.ndarray] = None,
1641
+ survey_df: Optional[int] = None,
1642
+ resolved_survey=None,
1643
+ ) -> Dict[int, Dict[str, Any]]:
1644
+ """Aggregate treatment effects by event-study horizon."""
1645
+ df_1 = df.loc[omega_1_mask]
1646
+ tau_hat = df["_tau_hat"].loc[omega_1_mask].values
1647
+ rel_times = df_1["_rel_time"].values
1648
+
1649
+ # Get all horizons
1650
+ all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h)))
1651
+
1652
+ # Apply horizon_max filter
1653
+ if self.horizon_max is not None:
1654
+ all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]
1655
+
1656
+ # Apply balance_e filter
1657
+ if balance_e is not None:
1658
+ cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
1659
+ balanced_mask = pd.Series(
1660
+ self._compute_balanced_cohort_mask(
1661
+ df_1, first_treat, all_horizons, balance_e, cohort_rel_times
1662
+ ),
1663
+ index=df_1.index,
1664
+ )
1665
+ else:
1666
+ balanced_mask = pd.Series(True, index=df_1.index)
1667
+
1668
+ # Check Proposition 5: no never-treated units
1669
+ has_never_treated = df["_never_treated"].any()
1670
+ h_bar = np.inf
1671
+ if not has_never_treated and len(treatment_groups) > 1:
1672
+ h_bar = max(treatment_groups) - min(treatment_groups)
1673
+
1674
+ # Reference period
1675
+ ref_period = -1 - self.anticipation
1676
+
1677
+ event_study_effects: Dict[int, Dict[str, Any]] = {}
1678
+
1679
+ # Add reference period marker
1680
+ event_study_effects[ref_period] = {
1681
+ "effect": 0.0,
1682
+ "se": 0.0,
1683
+ "t_stat": np.nan,
1684
+ "p_value": np.nan,
1685
+ "conf_int": (0.0, 0.0),
1686
+ "n_obs": 0,
1687
+ }
1688
+
1689
+ # Pre-period coefficients via BJS Test 1 lead regression
1690
+ if self.pretrends:
1691
+ df_0 = df.loc[omega_0_mask].copy()
1692
+
1693
+ # Determine which cohorts' lead indicators to include.
1694
+ # balance_e restricts which cohorts contribute lead dummies,
1695
+ # but the full Omega_0 sample (including never-treated controls)
1696
+ # is kept for the within-transformed OLS (BJS Test 1, Equation 9).
1697
+ balanced_cohorts = None
1698
+ skip_preperiods = False
1699
+ if balance_e is not None:
1700
+ cohort_rel_times_0 = self._build_cohort_rel_times(df, first_treat)
1701
+ balanced_cohorts = set()
1702
+ if all_horizons:
1703
+ max_h = max(all_horizons)
1704
+ required_range = set(range(-balance_e, max_h + 1))
1705
+ for g, horizons in cohort_rel_times_0.items():
1706
+ if required_range.issubset(horizons):
1707
+ balanced_cohorts.add(g)
1708
+ if not balanced_cohorts:
1709
+ skip_preperiods = True # No cohorts qualify — skip entirely
1710
+
1711
+ if not skip_preperiods:
1712
+ rel_time_0 = np.where(
1713
+ ~df_0["_never_treated"],
1714
+ df_0[time] - df_0[first_treat],
1715
+ np.nan,
1716
+ )
1717
+
1718
+ # When balance_e is set, only include leads from balanced cohorts
1719
+ if balanced_cohorts is not None:
1720
+ is_balanced = df_0[first_treat].isin(balanced_cohorts).values
1721
+ rel_time_for_leads = np.where(is_balanced, rel_time_0, np.nan)
1722
+ else:
1723
+ rel_time_for_leads = rel_time_0
1724
+
1725
+ pre_rel_times = sorted(
1726
+ set(
1727
+ int(h)
1728
+ for h in rel_time_for_leads
1729
+ if np.isfinite(h) and h < -self.anticipation
1730
+ )
1731
+ )
1732
+ pre_rel_times = [h for h in pre_rel_times if h != ref_period]
1733
+ if self.horizon_max is not None:
1734
+ pre_rel_times = [h for h in pre_rel_times if abs(h) <= self.horizon_max]
1735
+ if pre_rel_times:
1736
+ # Survey pretrends: pass full design (subpopulation approach)
1737
+ _sw_0_pre = None
1738
+ _rs_full_pre = None
1739
+ _n_full_pre = None
1740
+ _o0_idx_pre = None
1741
+ if survey_weights is not None and resolved_survey is not None:
1742
+ _sw_0_pre = survey_weights[omega_0_mask.values]
1743
+ _rs_full_pre = resolved_survey
1744
+ _n_full_pre = len(df)
1745
+ _o0_idx_pre = np.where(omega_0_mask.values)[0]
1746
+ _survey_df_pre = (
1747
+ resolved_survey.df_survey if resolved_survey is not None else None
1748
+ )
1749
+ pre_effects, _, _ = self._compute_lead_coefficients(
1750
+ df_0,
1751
+ outcome,
1752
+ unit,
1753
+ time,
1754
+ first_treat,
1755
+ covariates,
1756
+ cluster_var,
1757
+ pre_rel_times,
1758
+ alpha=self.alpha,
1759
+ balanced_cohorts=balanced_cohorts,
1760
+ survey_weights_0=_sw_0_pre,
1761
+ resolved_survey_full=_rs_full_pre,
1762
+ n_obs_full=_n_full_pre,
1763
+ omega_0_indices=_o0_idx_pre,
1764
+ survey_df=_survey_df_pre,
1765
+ )
1766
+ event_study_effects.update(pre_effects)
1767
+
1768
+ # Collect horizons with Proposition 5 violations
1769
+ prop5_horizons = []
1770
+
1771
+ for h in all_horizons:
1772
+ if h == ref_period:
1773
+ continue
1774
+
1775
+ # Select treated obs at this horizon from balanced cohorts
1776
+ h_mask = (rel_times == h) & balanced_mask.values
1777
+ n_h = int(h_mask.sum())
1778
+
1779
+ if n_h == 0:
1780
+ continue
1781
+
1782
+ # Proposition 5 check
1783
+ if not has_never_treated and h >= h_bar:
1784
+ prop5_horizons.append(h)
1785
+ event_study_effects[h] = {
1786
+ "effect": np.nan,
1787
+ "se": np.nan,
1788
+ "t_stat": np.nan,
1789
+ "p_value": np.nan,
1790
+ "conf_int": (np.nan, np.nan),
1791
+ "n_obs": n_h,
1792
+ }
1793
+ continue
1794
+
1795
+ tau_h = tau_hat[h_mask]
1796
+ finite_h = np.isfinite(tau_h)
1797
+ valid_tau = tau_h[finite_h]
1798
+
1799
+ if len(valid_tau) == 0:
1800
+ event_study_effects[h] = {
1801
+ "effect": np.nan,
1802
+ "se": np.nan,
1803
+ "t_stat": np.nan,
1804
+ "p_value": np.nan,
1805
+ "conf_int": (np.nan, np.nan),
1806
+ "n_obs": n_h,
1807
+ }
1808
+ continue
1809
+
1810
+ # Survey-weighted or simple mean for per-horizon effect
1811
+ if survey_weights is not None:
1812
+ treated_sw = survey_weights[omega_1_mask.values]
1813
+ sw_h = treated_sw[h_mask]
1814
+ sw_valid = sw_h[finite_h]
1815
+ effect = float(np.average(valid_tau, weights=sw_valid))
1816
+ else:
1817
+ effect = float(np.mean(valid_tau))
1818
+
1819
+ # Compute SE via conservative variance with horizon-specific weights
1820
+ # When survey, aggregation weights are proportional to survey weights
1821
+ if survey_weights is not None:
1822
+ treated_sw = survey_weights[omega_1_mask.values]
1823
+ n_1 = len(tau_hat)
1824
+ weights_h = np.zeros(n_1)
1825
+ sw_h = treated_sw[h_mask]
1826
+ finite_in_h = np.isfinite(tau_h)
1827
+ sw_finite = sw_h[finite_in_h]
1828
+ # Set weights proportional to survey weights, summing to 1
1829
+ if sw_finite.sum() > 0:
1830
+ h_indices = np.where(h_mask)[0]
1831
+ finite_indices = h_indices[finite_in_h]
1832
+ weights_h[finite_indices] = sw_finite / sw_finite.sum()
1833
+ n_valid = int(finite_in_h.sum())
1834
+ else:
1835
+ weights_h, n_valid = _compute_target_weights(tau_hat, h_mask)
1836
+
1837
+ se = self._compute_conservative_variance(
1838
+ df=df,
1839
+ outcome=outcome,
1840
+ unit=unit,
1841
+ time=time,
1842
+ first_treat=first_treat,
1843
+ covariates=covariates,
1844
+ omega_0_mask=omega_0_mask,
1845
+ omega_1_mask=omega_1_mask,
1846
+ unit_fe=unit_fe,
1847
+ time_fe=time_fe,
1848
+ grand_mean=grand_mean,
1849
+ delta_hat=delta_hat,
1850
+ weights=weights_h,
1851
+ cluster_var=cluster_var,
1852
+ kept_cov_mask=kept_cov_mask,
1853
+ survey_weights=survey_weights,
1854
+ resolved_survey=resolved_survey,
1855
+ )
1856
+
1857
+ t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df)
1858
+
1859
+ event_study_effects[h] = {
1860
+ "effect": effect,
1861
+ "se": se,
1862
+ "t_stat": t_stat,
1863
+ "p_value": p_value,
1864
+ "conf_int": conf_int,
1865
+ "n_obs": n_h,
1866
+ }
1867
+
1868
+ # Proposition 5 warning
1869
+ if prop5_horizons:
1870
+ warnings.warn(
1871
+ f"Horizons {prop5_horizons} are not identified without "
1872
+ f"never-treated units (Proposition 5). Set to NaN.",
1873
+ UserWarning,
1874
+ stacklevel=3,
1875
+ )
1876
+
1877
+ # Check for empty result set after filtering
1878
+ real_effects = [
1879
+ h for h, v in event_study_effects.items() if h != ref_period and v.get("n_obs", 0) > 0
1880
+ ]
1881
+ if len(real_effects) == 0:
1882
+ filter_info = []
1883
+ if balance_e is not None:
1884
+ filter_info.append(f"balance_e={balance_e}")
1885
+ if self.horizon_max is not None:
1886
+ filter_info.append(f"horizon_max={self.horizon_max}")
1887
+ filter_str = " and ".join(filter_info) if filter_info else "filters"
1888
+ warnings.warn(
1889
+ f"Event study aggregation produced no horizons with observations "
1890
+ f"after applying {filter_str}. The result contains only the "
1891
+ f"reference period marker. Consider relaxing filter parameters.",
1892
+ UserWarning,
1893
+ stacklevel=3,
1894
+ )
1895
+
1896
+ return event_study_effects
1897
+
1898
+ def _aggregate_group(
1899
+ self,
1900
+ df: pd.DataFrame,
1901
+ outcome: str,
1902
+ unit: str,
1903
+ time: str,
1904
+ first_treat: str,
1905
+ covariates: Optional[List[str]],
1906
+ omega_0_mask: pd.Series,
1907
+ omega_1_mask: pd.Series,
1908
+ unit_fe: Dict[Any, float],
1909
+ time_fe: Dict[Any, float],
1910
+ grand_mean: float,
1911
+ delta_hat: Optional[np.ndarray],
1912
+ cluster_var: str,
1913
+ treatment_groups: List[Any],
1914
+ kept_cov_mask: Optional[np.ndarray] = None,
1915
+ survey_weights: Optional[np.ndarray] = None,
1916
+ survey_df: Optional[int] = None,
1917
+ resolved_survey=None,
1918
+ ) -> Dict[Any, Dict[str, Any]]:
1919
+ """Aggregate treatment effects by cohort."""
1920
+ df_1 = df.loc[omega_1_mask]
1921
+ tau_hat = df["_tau_hat"].loc[omega_1_mask].values
1922
+ cohorts = df_1[first_treat].values
1923
+
1924
+ group_effects: Dict[Any, Dict[str, Any]] = {}
1925
+
1926
+ for g in treatment_groups:
1927
+ g_mask = cohorts == g
1928
+ n_g = int(g_mask.sum())
1929
+
1930
+ if n_g == 0:
1931
+ continue
1932
+
1933
+ tau_g = tau_hat[g_mask]
1934
+ finite_g = np.isfinite(tau_g)
1935
+ valid_tau = tau_g[finite_g]
1936
+
1937
+ if len(valid_tau) == 0:
1938
+ group_effects[g] = {
1939
+ "effect": np.nan,
1940
+ "se": np.nan,
1941
+ "t_stat": np.nan,
1942
+ "p_value": np.nan,
1943
+ "conf_int": (np.nan, np.nan),
1944
+ "n_obs": n_g,
1945
+ }
1946
+ continue
1947
+
1948
+ # Survey-weighted or simple mean for per-group effect
1949
+ if survey_weights is not None:
1950
+ treated_sw = survey_weights[omega_1_mask.values]
1951
+ sw_g = treated_sw[g_mask]
1952
+ sw_valid = sw_g[finite_g]
1953
+ effect = float(np.average(valid_tau, weights=sw_valid))
1954
+ else:
1955
+ effect = float(np.mean(valid_tau))
1956
+
1957
+ # Compute SE with group-specific weights
1958
+ # When survey, aggregation weights proportional to survey weights
1959
+ if survey_weights is not None:
1960
+ treated_sw = survey_weights[omega_1_mask.values]
1961
+ n_1 = len(tau_hat)
1962
+ weights_g = np.zeros(n_1)
1963
+ sw_g = treated_sw[g_mask]
1964
+ sw_finite = sw_g[finite_g]
1965
+ if sw_finite.sum() > 0:
1966
+ g_indices = np.where(g_mask)[0]
1967
+ finite_indices = g_indices[finite_g]
1968
+ weights_g[finite_indices] = sw_finite / sw_finite.sum()
1969
+ else:
1970
+ weights_g, _ = _compute_target_weights(tau_hat, g_mask)
1971
+
1972
+ se = self._compute_conservative_variance(
1973
+ df=df,
1974
+ outcome=outcome,
1975
+ unit=unit,
1976
+ time=time,
1977
+ first_treat=first_treat,
1978
+ covariates=covariates,
1979
+ omega_0_mask=omega_0_mask,
1980
+ omega_1_mask=omega_1_mask,
1981
+ unit_fe=unit_fe,
1982
+ time_fe=time_fe,
1983
+ grand_mean=grand_mean,
1984
+ delta_hat=delta_hat,
1985
+ weights=weights_g,
1986
+ cluster_var=cluster_var,
1987
+ kept_cov_mask=kept_cov_mask,
1988
+ survey_weights=survey_weights,
1989
+ resolved_survey=resolved_survey,
1990
+ )
1991
+
1992
+ t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df)
1993
+
1994
+ group_effects[g] = {
1995
+ "effect": effect,
1996
+ "se": se,
1997
+ "t_stat": t_stat,
1998
+ "p_value": p_value,
1999
+ "conf_int": conf_int,
2000
+ "n_obs": n_g,
2001
+ }
2002
+
2003
+ return group_effects
2004
+
2005
+ # =========================================================================
2006
+ # Pre-trend test (Equation 9) & pre-period lead coefficients
2007
+ # =========================================================================
2008
+
2009
def _compute_lead_coefficients(
    self,
    df_0: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    cluster_var: str,
    pre_rel_times: List[int],
    alpha: float = 0.05,
    balanced_cohorts: Optional[set] = None,
    survey_weights_0: Optional[np.ndarray] = None,
    resolved_survey_full=None,
    n_obs_full: Optional[int] = None,
    omega_0_indices: Optional[np.ndarray] = None,
    survey_df: Optional[int] = None,
) -> Tuple[Dict[int, Dict[str, Any]], np.ndarray, np.ndarray]:
    """
    Compute pre-period lead coefficients via within-transformed OLS (Test 1).

    Builds lead indicators W_it(h) = 1[K_it = h] for every ``h`` in
    ``pre_rel_times``, demeans the outcome and all regressors with
    ``_iterative_demean`` (survey-weighted when ``survey_weights_0`` is
    given) and estimates the coefficients with ``solve_ols``. The VCV is
    cluster-robust by default. When ``resolved_survey_full`` is provided it
    is replaced by ``compute_survey_vcov`` on scores zero-padded to
    full-panel length (subpopulation approach, preserving PSU/strata
    structure).

    The full Omega_0 sample passed as ``df_0`` is always used for the
    within-transformation. When ``balanced_cohorts`` is provided, lead
    indicators are zeroed for observations outside those cohorts.

    ``df_0`` is not modified: the lead indicators are kept as local arrays.

    Parameters
    ----------
    df_0 : DataFrame
        Untreated observations; must contain ``_never_treated``.
    pre_rel_times : list of int
        Relative times that get a lead indicator.
    alpha : float
        Significance level passed to ``safe_inference``.
    n_obs_full, omega_0_indices
        Full-panel length and positions of the ``df_0`` rows within it.
        Required when ``resolved_survey_full`` is given.
    survey_df : int, optional
        Degrees of freedom passed to ``safe_inference``.

    Returns
    -------
    effects : dict
        Per-horizon event_study_effects entries.
    gamma : ndarray
        Lead coefficient vector.
    V_gamma : ndarray
        Sub-VCV matrix for lead coefficients.

    Raises
    ------
    ValueError
        If ``resolved_survey_full`` is given without ``n_obs_full`` and
        ``omega_0_indices``.

    Notes
    -----
    If ``solve_ols`` raises ``IndexError`` or ``LinAlgError``, every
    horizon is reported as NaN and ``gamma`` / ``V_gamma`` are NaN-filled.
    """
    use_survey_vcov = resolved_survey_full is not None
    if use_survey_vcov and (n_obs_full is None or omega_0_indices is None):
        raise ValueError(
            "n_obs_full and omega_0_indices are required when "
            "resolved_survey_full is provided."
        )

    rel_time_0 = np.where(
        ~df_0["_never_treated"],
        df_0[time] - df_0[first_treat],
        np.nan,
    )

    # Build lead indicators — restrict to balanced cohorts if specified
    if balanced_cohorts is not None:
        is_balanced = df_0[first_treat].isin(balanced_cohorts).values
    else:
        is_balanced = None

    lead_cols: List[str] = []
    lead_indicators: Dict[str, np.ndarray] = {}
    for h in pre_rel_times:
        indicator = (rel_time_0 == h).astype(float)
        if is_balanced is not None:
            indicator = indicator * is_balanced  # zero out non-balanced cohorts
        col_name = f"_lead_{h}"
        lead_cols.append(col_name)
        lead_indicators[col_name] = indicator

    # n_obs from the lead indicator (respects balanced_cohorts restriction)
    lead_n_obs = {h: int(lead_indicators[f"_lead_{h}"].sum()) for h in pre_rel_times}

    unit_values = df_0[unit].values
    time_values = df_0[time].values

    def _demean(values: np.ndarray) -> np.ndarray:
        # Within-transform one column (survey-weighted when present).
        return self._iterative_demean(
            values,
            unit_values,
            time_values,
            df_0.index,
            weights=survey_weights_0,
        )

    y_dm = _demean(df_0[outcome].values)

    covariate_cols = list(covariates) if covariates else []
    all_x_cols = lead_cols + covariate_cols
    raw_columns = [lead_indicators[col] for col in lead_cols]
    raw_columns.extend(df_0[col].values for col in covariate_cols)
    X_dm = np.column_stack([_demean(col) for col in raw_columns])

    # OLS for point estimates + VCV. When survey VCV will replace the
    # cluster-robust VCV, skip cluster_ids to avoid errors on domains
    # with few PSUs (the cluster-robust VCV is discarded anyway).
    cluster_ids = df_0[cluster_var].values
    try:
        result = solve_ols(
            X_dm,
            y_dm,
            weights=survey_weights_0,
            weight_type="pweight" if survey_weights_0 is not None else None,
            cluster_ids=None if use_survey_vcov else cluster_ids,
            return_vcov=True,
            rank_deficient_action=self.rank_deficient_action,
            column_names=all_x_cols,
        )
    except (IndexError, np.linalg.LinAlgError):
        # All lead columns dropped (rank deficient after demeaning)
        nan_effects: Dict[int, Dict[str, Any]] = {
            h: {
                "effect": np.nan,
                "se": np.nan,
                "t_stat": np.nan,
                "p_value": np.nan,
                "conf_int": (np.nan, np.nan),
                "n_obs": lead_n_obs[h],
            }
            for h in pre_rel_times
        }
        n_pre = len(pre_rel_times)
        return (
            nan_effects,
            np.full(n_pre, np.nan),
            np.full((n_pre, n_pre), np.nan),
        )

    coefficients = result[0]
    vcov = result[2]
    assert vcov is not None

    # Replace cluster-robust VCV with survey design-based VCV.
    # Use the FULL survey design (subpopulation approach): zero-pad
    # the Omega_0 scores back to full-panel length so PSU/strata
    # structure is preserved for variance estimation.
    if use_survey_vcov:
        from diff_diff.survey import compute_survey_vcov

        # Use residuals from solve_ols (safe for rank-deficient fits).
        residuals_0 = result[1]

        # Reduce to kept (finite-coefficient) columns for VCV
        kept_mask = np.isfinite(coefficients)
        all_kept = bool(np.all(kept_mask))
        X_for_vcov = X_dm if all_kept else X_dm[:, kept_mask]

        # Observations outside Omega_0 contribute zero to the score.
        X_full = np.zeros((n_obs_full, X_for_vcov.shape[1]), dtype=np.float64)
        res_full = np.zeros(n_obs_full, dtype=np.float64)
        X_full[omega_0_indices] = X_for_vcov
        res_full[omega_0_indices] = residuals_0

        vcov_kept = compute_survey_vcov(X_full, res_full, resolved_survey_full)

        if all_kept:
            vcov = vcov_kept
        else:
            # Expand back: NaN rows/cols for dropped columns
            n_coef = len(coefficients)
            vcov = np.full((n_coef, n_coef), np.nan)
            kept_idx = np.where(kept_mask)[0]
            vcov[np.ix_(kept_idx, kept_idx)] = vcov_kept

    n_leads = len(lead_cols)
    gamma = coefficients[:n_leads]
    V_gamma = vcov[:n_leads, :n_leads]

    # Build per-horizon effects (survey_df drives t-distribution inference)
    effects: Dict[int, Dict[str, Any]] = {}
    for j, h in enumerate(pre_rel_times):
        effect = float(gamma[j])
        se = float(np.sqrt(max(V_gamma[j, j], 0.0)))
        t_stat, p_value, conf_int = safe_inference(effect, se, alpha=alpha, df=survey_df)
        effects[h] = {
            "effect": effect,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_value,
            "conf_int": conf_int,
            "n_obs": lead_n_obs[h],
        }

    return effects, gamma, V_gamma
2208
+
2209
+ def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
2210
+ """
2211
+ Run pre-trend test (Equation 9).
2212
+
2213
+ Adds pre-treatment lead indicators to the Step 1 OLS on Omega_0
2214
+ and tests their joint significance via Wald F-test (cluster-robust
2215
+ or design-based survey VCV when survey_design is present).
2216
+ """
2217
+ if self._fit_data is None:
2218
+ raise RuntimeError("Must call fit() before pretrend_test().")
2219
+
2220
+ fd = self._fit_data
2221
+ resolved_survey = fd.get("resolved_survey")
2222
+ if resolved_survey is not None and resolved_survey.uses_replicate_variance:
2223
+ raise NotImplementedError(
2224
+ "pretrend_test() is not yet supported for replicate-weight "
2225
+ "survey designs. Per-replicate Equation 9 lead regression "
2226
+ "refits are not implemented. Use analytical survey designs "
2227
+ "(strata/PSU/FPC) or call pretrend_test() without survey."
2228
+ )
2229
+
2230
+ df = fd["df"]
2231
+ outcome = fd["outcome"]
2232
+ unit = fd["unit"]
2233
+ time = fd["time"]
2234
+ first_treat = fd["first_treat"]
2235
+ covariates = fd["covariates"]
2236
+ omega_0_mask = fd["omega_0_mask"]
2237
+ cluster_var = fd["cluster_var"]
2238
+ resolved_survey = fd.get("resolved_survey")
2239
+ survey_weights = fd.get("survey_weights")
2240
+
2241
+ df_0 = df.loc[omega_0_mask].copy()
2242
+
2243
+ # Compute relative time for untreated obs
2244
+ rel_time_0 = np.where(
2245
+ ~df_0["_never_treated"],
2246
+ df_0[time] - df_0[first_treat],
2247
+ np.nan,
2248
+ )
2249
+
2250
+ # Get available pre-treatment relative times (negative values)
2251
+ pre_rel_times = sorted(
2252
+ set(int(h) for h in rel_time_0 if np.isfinite(h) and h < -self.anticipation)
2253
+ )
2254
+
2255
+ if len(pre_rel_times) == 0:
2256
+ return {
2257
+ "f_stat": np.nan,
2258
+ "p_value": np.nan,
2259
+ "df": 0,
2260
+ "n_leads": 0,
2261
+ "lead_coefficients": {},
2262
+ }
2263
+
2264
+ # Exclude the reference period (last pre-treatment period)
2265
+ ref = -1 - self.anticipation
2266
+ pre_rel_times = [h for h in pre_rel_times if h != ref]
2267
+
2268
+ if n_leads is not None:
2269
+ pre_rel_times = sorted(pre_rel_times, reverse=True)[:n_leads]
2270
+ pre_rel_times = sorted(pre_rel_times)
2271
+
2272
+ if len(pre_rel_times) == 0:
2273
+ return {
2274
+ "f_stat": np.nan,
2275
+ "p_value": np.nan,
2276
+ "df": 0,
2277
+ "n_leads": 0,
2278
+ "lead_coefficients": {},
2279
+ }
2280
+
2281
+ # Survey pretrends: pass full design (subpopulation approach)
2282
+ _sw_0_pt = None
2283
+ _rs_full_pt = None
2284
+ _n_full_pt = None
2285
+ _o0_idx_pt = None
2286
+ if survey_weights is not None and resolved_survey is not None:
2287
+ _sw_0_pt = survey_weights[omega_0_mask.values]
2288
+ _rs_full_pt = resolved_survey
2289
+ _n_full_pt = len(fd["df"])
2290
+ _o0_idx_pt = np.where(omega_0_mask.values)[0]
2291
+
2292
+ # Use shared lead coefficient computation
2293
+ effects, gamma, V_gamma = self._compute_lead_coefficients(
2294
+ df_0,
2295
+ outcome,
2296
+ unit,
2297
+ time,
2298
+ first_treat,
2299
+ covariates,
2300
+ cluster_var,
2301
+ pre_rel_times,
2302
+ alpha=self.alpha,
2303
+ survey_weights_0=_sw_0_pt,
2304
+ resolved_survey_full=_rs_full_pt,
2305
+ n_obs_full=_n_full_pt,
2306
+ omega_0_indices=_o0_idx_pt,
2307
+ survey_df=(resolved_survey.df_survey if resolved_survey is not None else None),
2308
+ )
2309
+
2310
+ n_leads_actual = len(pre_rel_times)
2311
+
2312
+ # Wald F-test: F = (gamma' V^{-1} gamma) / n_leads
2313
+ try:
2314
+ V_inv_gamma = np.linalg.solve(V_gamma, gamma)
2315
+ wald_stat = float(gamma @ V_inv_gamma)
2316
+ f_stat = wald_stat / n_leads_actual
2317
+ except np.linalg.LinAlgError:
2318
+ f_stat = np.nan
2319
+
2320
+ # P-value from F distribution (survey df when available)
2321
+ if np.isfinite(f_stat) and f_stat >= 0:
2322
+ if resolved_survey is not None and resolved_survey.df_survey is not None:
2323
+ df_denom = resolved_survey.df_survey
2324
+ else:
2325
+ cluster_ids = df_0[cluster_var].values
2326
+ n_clusters = len(np.unique(cluster_ids))
2327
+ df_denom = max(n_clusters - 1, 1)
2328
+ if df_denom <= 0:
2329
+ p_value = np.nan
2330
+ else:
2331
+ p_value = float(stats.f.sf(f_stat, n_leads_actual, df_denom))
2332
+ else:
2333
+ p_value = np.nan
2334
+
2335
+ lead_coefficients = {h: effects[h]["effect"] for h in pre_rel_times}
2336
+
2337
+ return {
2338
+ "f_stat": f_stat,
2339
+ "p_value": p_value,
2340
+ "df": n_leads_actual,
2341
+ "n_leads": n_leads_actual,
2342
+ "lead_coefficients": lead_coefficients,
2343
+ }
2344
+
2345
+ # =========================================================================
2346
+ # sklearn-compatible interface
2347
+ # =========================================================================
2348
+
2349
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's configuration attributes keyed by name (sklearn-compatible)."""
    param_names = (
        "anticipation",
        "alpha",
        "cluster",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "rank_deficient_action",
        "horizon_max",
        "aux_partition",
        "pretrends",
    )
    return {name: getattr(self, name) for name in param_names}
2363
+
2364
def set_params(self, **params) -> "ImputationDiD":
    """
    Set estimator parameters (sklearn-compatible).

    Every key must name an existing attribute of the estimator. All keys
    are validated before any value is assigned, so a failing call leaves
    the estimator unchanged.

    Returns
    -------
    ImputationDiD
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If a key does not name an existing attribute. The message names
        the first offending key.
    """
    for key in params:
        if not hasattr(self, key):
            raise ValueError(f"Unknown parameter: {key}")
    for key, value in params.items():
        setattr(self, key, value)
    return self
2372
+
2373
def summary(self) -> str:
    """
    Return the text summary produced by the stored results object.

    Raises
    ------
    RuntimeError
        If ``is_fitted_`` is falsy.
    """
    if self.is_fitted_:
        fitted_results = self.results_
        assert fitted_results is not None
        return fitted_results.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
2379
+
2380
def print_summary(self) -> None:
    """Write the text returned by ``summary()`` to standard output."""
    report = self.summary()
    print(report)
2383
+
2384
+
2385
+ # =============================================================================
2386
+ # Convenience function
2387
+ # =============================================================================
2388
+
2389
+
2390
def imputation_did(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    survey_design: object = None,
    **kwargs,
) -> ImputationDiDResults:
    """
    Fit an imputation DiD model in a single call.

    Builds an ``ImputationDiD`` estimator from ``**kwargs`` and returns the
    result of its ``fit()`` method for the given data and column names.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of the outcome column.
    unit : str
        Name of the unit identifier column.
    time : str
        Name of the time period column.
    first_treat : str
        Name of the column holding the first treatment period
        (0 for never-treated).
    covariates : list of str, optional
        Names of covariate columns.
    aggregate : str, optional
        Aggregation mode: None, "simple", "event_study", "group", "all".
    balance_e : int, optional
        Balance the event study to cohorts observed at all relative times.
    survey_design : SurveyDesign, optional
        Survey design specification for design-based inference. Only
        pweight is supported (aweight/fweight raise ValueError). Strata,
        PSU and FPC are supported for design-based variance, and strata
        enter the survey df used for t-distribution inference. Works with
        both analytical (n_bootstrap=0) and bootstrap inference.
    **kwargs
        Extra keyword arguments forwarded to the ImputationDiD constructor.

    Returns
    -------
    ImputationDiDResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import imputation_did, generate_staggered_data
    >>> data = generate_staggered_data(seed=42)
    >>> results = imputation_did(data, 'outcome', 'unit', 'time', 'first_treat',
    ...                          aggregate='event_study')
    >>> results.print_summary()
    """
    estimator = ImputationDiD(**kwargs)
    fit_options = {
        "outcome": outcome,
        "unit": unit,
        "time": time,
        "first_treat": first_treat,
        "covariates": covariates,
        "aggregate": aggregate,
        "balance_e": balance_e,
        "survey_design": survey_design,
    }
    return estimator.fit(data, **fit_options)