diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/two_stage.py ADDED
@@ -0,0 +1,1952 @@
1
+ """
2
+ Gardner (2022) Two-Stage Difference-in-Differences Estimator.
3
+
4
+ Implements the two-stage DiD estimator from Gardner (2022), "Two-stage
5
+ differences in differences". The method:
6
+ 1. Estimates unit + time fixed effects on untreated observations only
7
+ 2. Residualizes ALL outcomes using estimated FEs
8
+ 3. Regresses residualized outcomes on treatment indicators (Stage 2)
9
+
10
+ Inference uses the GMM sandwich variance estimator from Butts & Gardner
11
+ (2022) that correctly accounts for first-stage estimation uncertainty.
12
+
13
+ Point estimates are identical to ImputationDiD (Borusyak et al. 2024);
14
+ the key difference is the variance estimator (GMM sandwich vs. conservative).
15
+
16
+ References
17
+ ----------
18
+ Gardner, J. (2022). Two-stage differences in differences.
19
+ arXiv:2207.05943.
20
+ Butts, K. & Gardner, J. (2022). did2s: Two-Stage
21
+ Difference-in-Differences. R Journal, 14(1), 162-173.
22
+ """
23
+
24
+ import warnings
25
+ from dataclasses import replace
26
+ from typing import Any, Dict, List, Optional, Tuple
27
+
28
+ import numpy as np
29
+ import pandas as pd
30
+ from scipy import sparse
31
+ from scipy.sparse.linalg import factorized as sparse_factorized
32
+
33
+ # Maximum number of elements before falling back to per-column sparse aggregation.
34
+ # 10M float64 elements ≈ 80 MB peak allocation. Above this, per-column .getcol()
35
+ # trades throughput for bounded memory. Keep in sync with two_stage_bootstrap.py.
36
+ _SPARSE_DENSE_THRESHOLD = 10_000_000
37
+
38
+ from diff_diff.linalg import solve_ols
39
+ from diff_diff.two_stage_bootstrap import TwoStageDiDBootstrapMixin
40
+ from diff_diff.two_stage_results import (
41
+ TwoStageBootstrapResults, # noqa: F401
42
+ TwoStageDiDResults,
43
+ ) # noqa: F401 (re-export)
44
+ from diff_diff.utils import safe_inference
45
+
46
+ # =============================================================================
47
+ # Main Estimator
48
+ # =============================================================================
49
+
50
+
51
+ class TwoStageDiD(TwoStageDiDBootstrapMixin):
52
+ """
53
+ Gardner (2022) two-stage Difference-in-Differences estimator.
54
+
55
+ This estimator addresses TWFE bias under heterogeneous treatment
56
+ effects by:
57
+ 1. Estimating unit + time FEs on untreated observations only
58
+ 2. Residualizing ALL outcomes using estimated FEs
59
+ 3. Regressing residualized outcomes on treatment indicators
60
+
61
+ Point estimates are identical to ImputationDiD (Borusyak et al. 2024).
62
+ The key difference is the variance estimator: TwoStageDiD uses a GMM
63
+ sandwich variance that accounts for first-stage estimation uncertainty,
64
+ while ImputationDiD uses the conservative variance from Theorem 3.
65
+
66
+ Parameters
67
+ ----------
68
+ anticipation : int, default=0
69
+ Number of periods before treatment where effects may occur.
70
+ alpha : float, default=0.05
71
+ Significance level for confidence intervals.
72
+ cluster : str, optional
73
+ Column name for cluster-robust standard errors.
74
+ If None, clusters at the unit level by default.
75
+ n_bootstrap : int, default=0
76
+ Number of bootstrap iterations. If 0, uses analytical GMM
77
+ sandwich inference.
78
+ bootstrap_weights : str, default="rademacher"
79
+ Type of bootstrap weights: "rademacher", "mammen", or "webb".
80
+ seed : int, optional
81
+ Random seed for reproducibility.
82
+ rank_deficient_action : str, default="warn"
83
+ Action when design matrix is rank-deficient:
84
+ - "warn": Issue warning and drop linearly dependent columns
85
+ - "error": Raise ValueError
86
+ - "silent": Drop columns silently
87
+ horizon_max : int, optional
88
+ Maximum event-study horizon. If set, event study effects are only
89
+ computed for |h| <= horizon_max.
90
+ pretrends : bool, default=False
91
+ If True, event study includes pre-treatment horizons for visual
92
+ pre-trends assessment. Pre-period effects should be ~0 under
93
+ parallel trends. Only affects event_study aggregation; overall
94
+ ATT and group aggregation are unchanged.
95
+
96
+ Attributes
97
+ ----------
98
+ results_ : TwoStageDiDResults
99
+ Estimation results after calling fit().
100
+ is_fitted_ : bool
101
+ Whether the model has been fitted.
102
+
103
+ Examples
104
+ --------
105
+ Basic usage:
106
+
107
+ >>> from diff_diff import TwoStageDiD, generate_staggered_data
108
+ >>> data = generate_staggered_data(n_units=200, seed=42)
109
+ >>> est = TwoStageDiD()
110
+ >>> results = est.fit(data, outcome='outcome', unit='unit',
111
+ ... time='period', first_treat='first_treat')
112
+ >>> results.print_summary()
113
+
114
+ With event study:
115
+
116
+ >>> est = TwoStageDiD()
117
+ >>> results = est.fit(data, outcome='outcome', unit='unit',
118
+ ... time='period', first_treat='first_treat',
119
+ ... aggregate='event_study')
120
+ >>> from diff_diff import plot_event_study
121
+ >>> plot_event_study(results)
122
+
123
+ Notes
124
+ -----
125
+ The two-stage estimator uses ALL untreated observations (never-treated +
126
+ not-yet-treated periods of eventually-treated units) to estimate the
127
+ counterfactual model.
128
+
129
+ References
130
+ ----------
131
+ Gardner, J. (2022). Two-stage differences in differences.
132
+ arXiv:2207.05943.
133
+ Butts, K. & Gardner, J. (2022). did2s: Two-Stage
134
+ Difference-in-Differences. R Journal, 14(1), 162-173.
135
+ """
136
+
137
def __init__(
    self,
    anticipation: int = 0,
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    bootstrap_weights: str = "rademacher",
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
    horizon_max: Optional[int] = None,
    pretrends: bool = False,
):
    """
    Record the estimator settings and mark the instance as unfitted.

    The meaning of each argument is described in the class docstring.
    Only ``rank_deficient_action`` and ``bootstrap_weights`` are checked
    here; every other argument is stored exactly as passed.

    Raises
    ------
    ValueError
        If ``rank_deficient_action`` is not "warn", "error" or "silent",
        or if ``bootstrap_weights`` is not "rademacher", "mammen" or
        "webb".
    """
    # Validate the string-valued options before touching any state, so a
    # rejected call leaves no partially configured instance behind.
    rank_action_ok = rank_deficient_action in ("warn", "error", "silent")
    if not rank_action_ok:
        rank_msg = (
            "rank_deficient_action must be 'warn', 'error', or 'silent', "
            f"got '{rank_deficient_action}'"
        )
        raise ValueError(rank_msg)

    weights_ok = bootstrap_weights in ("rademacher", "mammen", "webb")
    if not weights_ok:
        weights_msg = (
            "bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
            f"got '{bootstrap_weights}'"
        )
        raise ValueError(weights_msg)

    # Configuration, stored verbatim under the constructor argument names.
    self.anticipation = anticipation
    self.alpha = alpha
    self.cluster = cluster
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed
    self.rank_deficient_action = rank_deficient_action
    self.horizon_max = horizon_max
    self.pretrends = pretrends

    # Fitted state: nothing has been estimated yet.
    self.is_fitted_ = False
    self.results_: Optional[TwoStageDiDResults] = None
172
+
173
+ def fit(
174
+ self,
175
+ data: pd.DataFrame,
176
+ outcome: str,
177
+ unit: str,
178
+ time: str,
179
+ first_treat: str,
180
+ covariates: Optional[List[str]] = None,
181
+ aggregate: Optional[str] = None,
182
+ balance_e: Optional[int] = None,
183
+ survey_design: object = None,
184
+ ) -> TwoStageDiDResults:
185
+ """
186
+ Fit the two-stage DiD estimator.
187
+
188
+ Parameters
189
+ ----------
190
+ data : pd.DataFrame
191
+ Panel data with unit and time identifiers.
192
+ outcome : str
193
+ Name of outcome variable column.
194
+ unit : str
195
+ Name of unit identifier column.
196
+ time : str
197
+ Name of time period column.
198
+ first_treat : str
199
+ Name of column indicating when unit was first treated.
200
+ Use 0 (or np.inf) for never-treated units.
201
+ covariates : list of str, optional
202
+ List of covariate column names.
203
+ aggregate : str, optional
204
+ Aggregation mode: None/"simple" (overall ATT only),
205
+ "event_study", "group", or "all".
206
+ balance_e : int, optional
207
+ When computing event study, restrict to cohorts observed at all
208
+ relative times in [-balance_e, max_h].
209
+ survey_design : SurveyDesign, optional
210
+ Survey design specification for design-based inference. Supports
211
+ pweight only (aweight/fweight raise ValueError). Supports strata,
212
+ PSU, and FPC for design-based GMM sandwich variance. Strata enters
213
+ survey df for t-distribution inference.
214
+ Both analytical (n_bootstrap=0) and bootstrap inference are supported.
215
+
216
+ Returns
217
+ -------
218
+ TwoStageDiDResults
219
+ Object containing all estimation results.
220
+
221
+ Raises
222
+ ------
223
+ ValueError
224
+ If required columns are missing or data validation fails.
225
+ """
226
+ # ---- Data validation ----
227
+ required_cols = [outcome, unit, time, first_treat]
228
+ if covariates:
229
+ required_cols.extend(covariates)
230
+
231
+ missing = [c for c in required_cols if c not in data.columns]
232
+ if missing:
233
+ raise ValueError(f"Missing columns: {missing}")
234
+
235
+ # Create working copy
236
+ df = data.copy()
237
+
238
+ # Resolve survey design if provided
239
+ from diff_diff.survey import (
240
+ _inject_cluster_as_psu,
241
+ _resolve_effective_cluster,
242
+ _resolve_survey_for_fit,
243
+ _validate_unit_constant_survey,
244
+ )
245
+
246
+ resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
247
+ _resolve_survey_for_fit(survey_design, data, "analytical")
248
+ )
249
+
250
+ _uses_replicate_ts = resolved_survey is not None and resolved_survey.uses_replicate_variance
251
+ if _uses_replicate_ts and self.n_bootstrap > 0:
252
+ raise ValueError(
253
+ "Cannot use n_bootstrap > 0 with replicate-weight survey designs. "
254
+ "Replicate weights provide their own variance estimation."
255
+ )
256
+ # Validate within-unit constancy for panel survey designs
257
+ if resolved_survey is not None:
258
+ _validate_unit_constant_survey(data, unit, survey_design)
259
+ if resolved_survey.weight_type != "pweight":
260
+ raise ValueError(
261
+ f"TwoStageDiD survey support requires weight_type='pweight', "
262
+ f"got '{resolved_survey.weight_type}'. The survey variance math "
263
+ f"assumes probability weights (pweight)."
264
+ )
265
+ # FPC is supported — threaded through _compute_stratified_meat_from_psu_scores()
266
+ # in _compute_gmm_variance().
267
+
268
+ # Bootstrap + survey supported via PSU-level multiplier bootstrap.
269
+
270
+ df[time] = pd.to_numeric(df[time])
271
+ df[first_treat] = pd.to_numeric(df[first_treat])
272
+
273
+ # Validate absorbing treatment
274
+ ft_nunique = df.groupby(unit)[first_treat].nunique()
275
+ non_constant = ft_nunique[ft_nunique > 1]
276
+ if len(non_constant) > 0:
277
+ example_unit = non_constant.index[0]
278
+ example_vals = sorted(df.loc[df[unit] == example_unit, first_treat].unique())
279
+ warnings.warn(
280
+ f"{len(non_constant)} unit(s) have non-constant '{first_treat}' "
281
+ f"values (e.g., unit '{example_unit}' has values {example_vals}). "
282
+ f"TwoStageDiD assumes treatment is an absorbing state "
283
+ f"(once treated, always treated) with a single treatment onset "
284
+ f"time per unit. Non-constant first_treat violates this assumption "
285
+ f"and may produce unreliable estimates.",
286
+ UserWarning,
287
+ stacklevel=2,
288
+ )
289
+ df[first_treat] = df.groupby(unit)[first_treat].transform("first")
290
+
291
+ # Identify treatment status
292
+ df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
293
+
294
+ # Check for always-treated units
295
+ min_time = df[time].min()
296
+ always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time)
297
+ always_treated_units = df.loc[always_treated_mask, unit].unique()
298
+ n_always_treated = len(always_treated_units)
299
+ if n_always_treated > 0:
300
+ unit_list = ", ".join(str(u) for u in always_treated_units[:10])
301
+ suffix = f" (and {n_always_treated - 10} more)" if n_always_treated > 10 else ""
302
+ survey_note = ""
303
+ if survey_weights is not None or resolved_survey is not None:
304
+ survey_note = " Associated survey weights and design arrays " "adjusted to match."
305
+ warnings.warn(
306
+ f"{n_always_treated} unit(s) are treated in all observed periods "
307
+ f"(first_treat <= {min_time}): [{unit_list}{suffix}]. "
308
+ "These units have no untreated observations and cannot contribute "
309
+ f"to the counterfactual model. Excluding from estimation.{survey_note}",
310
+ UserWarning,
311
+ stacklevel=2,
312
+ )
313
+ df = df[~df[unit].isin(always_treated_units)].copy()
314
+
315
+ # Subset survey arrays to match filtered df
316
+ if survey_weights is not None:
317
+ keep_mask = ~data[unit].isin(always_treated_units)
318
+ survey_weights = survey_weights[keep_mask.values]
319
+ if resolved_survey is not None:
320
+ keep_mask = ~data[unit].isin(always_treated_units)
321
+ resolved_survey = replace(
322
+ resolved_survey,
323
+ weights=resolved_survey.weights[keep_mask.values],
324
+ strata=(
325
+ resolved_survey.strata[keep_mask.values]
326
+ if resolved_survey.strata is not None
327
+ else None
328
+ ),
329
+ psu=(
330
+ resolved_survey.psu[keep_mask.values]
331
+ if resolved_survey.psu is not None
332
+ else None
333
+ ),
334
+ fpc=(
335
+ resolved_survey.fpc[keep_mask.values]
336
+ if resolved_survey.fpc is not None
337
+ else None
338
+ ),
339
+ replicate_weights=(
340
+ resolved_survey.replicate_weights[keep_mask.values]
341
+ if resolved_survey.replicate_weights is not None
342
+ else None
343
+ ),
344
+ )
345
+ # Recompute n_psu/n_strata after subsetting
346
+ new_n_psu = (
347
+ len(np.unique(resolved_survey.psu)) if resolved_survey.psu is not None else 0
348
+ )
349
+ new_n_strata = (
350
+ len(np.unique(resolved_survey.strata))
351
+ if resolved_survey.strata is not None
352
+ else 0
353
+ )
354
+ resolved_survey = replace(resolved_survey, n_psu=new_n_psu, n_strata=new_n_strata)
355
+ # Recompute survey_metadata since it depends on these counts
356
+ from diff_diff.survey import compute_survey_metadata
357
+
358
+ raw_w = (
359
+ df[survey_design.weights].values.astype(np.float64)
360
+ if survey_design.weights
361
+ else np.ones(len(df), dtype=np.float64)
362
+ )
363
+ survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
364
+
365
+ # Treatment indicator with anticipation
366
+ effective_treat = df[first_treat] - self.anticipation
367
+ df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat)
368
+
369
+ # Partition into Omega_0 (untreated) and Omega_1 (treated)
370
+ omega_0_mask = ~df["_treated"]
371
+ omega_1_mask = df["_treated"]
372
+
373
+ n_omega_0 = int(omega_0_mask.sum())
374
+ n_omega_1 = int(omega_1_mask.sum())
375
+
376
+ if n_omega_0 == 0:
377
+ raise ValueError(
378
+ "No untreated observations found. Cannot estimate counterfactual model."
379
+ )
380
+ if n_omega_1 == 0:
381
+ raise ValueError("No treated observations found. Nothing to estimate.")
382
+
383
+ # Groups and time periods
384
+ time_periods = sorted(df[time].unique())
385
+ treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and g != np.inf])
386
+
387
+ if len(treatment_groups) == 0:
388
+ raise ValueError("No treated units found. Check 'first_treat' column.")
389
+
390
+ # Unit info
391
+ unit_info = (
392
+ df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index()
393
+ )
394
+ n_treated_units = int((~unit_info["_never_treated"]).sum())
395
+ units_in_omega_0 = df.loc[omega_0_mask, unit].unique()
396
+ n_control_units = len(units_in_omega_0)
397
+
398
+ # Cluster variable
399
+ cluster_var = self.cluster if self.cluster is not None else unit
400
+ if self.cluster is not None and self.cluster not in df.columns:
401
+ raise ValueError(
402
+ f"Cluster column '{self.cluster}' not found in data. "
403
+ f"Available columns: {list(df.columns)}"
404
+ )
405
+
406
+ # Resolve effective cluster and inject cluster-as-PSU for survey variance
407
+ if resolved_survey is not None:
408
+ cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None
409
+ effective_cluster_ids = _resolve_effective_cluster(
410
+ resolved_survey,
411
+ cluster_ids_raw,
412
+ cluster_var if self.cluster is not None else None,
413
+ )
414
+ resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
415
+ # When survey PSU is present, use it as the effective cluster for
416
+ # GMM variance (PSU overrides unit-level clustering)
417
+ if resolved_survey.psu is not None:
418
+ df["_survey_cluster"] = resolved_survey.psu
419
+ cluster_var = "_survey_cluster"
420
+ # Recompute metadata after PSU injection
421
+ if resolved_survey.psu is not None and survey_metadata is not None:
422
+ from diff_diff.survey import compute_survey_metadata
423
+
424
+ raw_w = (
425
+ df[survey_design.weights].values.astype(np.float64)
426
+ if survey_design.weights
427
+ else np.ones(len(df), dtype=np.float64)
428
+ )
429
+ survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
430
+
431
+ # Relative time
432
+ df["_rel_time"] = np.where(
433
+ ~df["_never_treated"],
434
+ df[time] - df[first_treat],
435
+ np.nan,
436
+ )
437
+
438
+ # ---- Stage 1: OLS on untreated observations ----
439
+ unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model(
440
+ df, outcome, unit, time, covariates, omega_0_mask, weights=survey_weights
441
+ )
442
+
443
+ # ---- Rank condition checks ----
444
+ treated_unit_ids = df.loc[omega_1_mask, unit].unique()
445
+ units_with_fe = set(unit_fe.keys())
446
+ units_missing_fe = set(treated_unit_ids) - units_with_fe
447
+
448
+ post_period_ids = df.loc[omega_1_mask, time].unique()
449
+ periods_with_fe = set(time_fe.keys())
450
+ periods_missing_fe = set(post_period_ids) - periods_with_fe
451
+
452
+ if units_missing_fe or periods_missing_fe:
453
+ parts = []
454
+ if units_missing_fe:
455
+ sorted_missing = sorted(units_missing_fe)
456
+ parts.append(
457
+ f"{len(units_missing_fe)} treated unit(s) have no untreated "
458
+ f"periods (units: {sorted_missing[:5]}"
459
+ f"{'...' if len(units_missing_fe) > 5 else ''})"
460
+ )
461
+ if periods_missing_fe:
462
+ sorted_missing = sorted(periods_missing_fe)
463
+ parts.append(
464
+ f"{len(periods_missing_fe)} post-treatment period(s) have no "
465
+ f"untreated units (periods: {sorted_missing[:5]}"
466
+ f"{'...' if len(periods_missing_fe) > 5 else ''})"
467
+ )
468
+ msg = (
469
+ "Rank condition violated: "
470
+ + "; ".join(parts)
471
+ + ". Affected treatment effects will be NaN."
472
+ )
473
+ if self.rank_deficient_action == "error":
474
+ raise ValueError(msg)
475
+ elif self.rank_deficient_action == "warn":
476
+ warnings.warn(msg, UserWarning, stacklevel=2)
477
+
478
+ # ---- Residualize ALL observations ----
479
+ y_tilde = self._residualize(
480
+ df, outcome, unit, time, covariates, unit_fe, time_fe, grand_mean, delta_hat
481
+ )
482
+ df["_y_tilde"] = y_tilde
483
+
484
+ # ---- Stage 2: OLS of y_tilde on treatment indicators ----
485
+ # Build design matrices and compute effects + GMM variance
486
+ ref_period = -1 - self.anticipation
487
+
488
+ # Survey degrees of freedom for t-distribution inference
489
+ _survey_df = resolved_survey.df_survey if resolved_survey is not None else None
490
+ # Replicate df: rank-deficient → NaN inference
491
+ if _uses_replicate_ts and _survey_df is None:
492
+ _survey_df = 0
493
+
494
+ # Always compute overall ATT (static specification)
495
+ overall_att, overall_se = self._stage2_static(
496
+ df=df,
497
+ unit=unit,
498
+ time=time,
499
+ first_treat=first_treat,
500
+ covariates=covariates,
501
+ omega_0_mask=omega_0_mask,
502
+ omega_1_mask=omega_1_mask,
503
+ unit_fe=unit_fe,
504
+ time_fe=time_fe,
505
+ grand_mean=grand_mean,
506
+ delta_hat=delta_hat,
507
+ cluster_var=cluster_var,
508
+ kept_cov_mask=kept_cov_mask,
509
+ survey_weights=survey_weights,
510
+ survey_weight_type=survey_weight_type,
511
+ resolved_survey=(resolved_survey if not _uses_replicate_ts else None),
512
+ )
513
+
514
+ # Compute overall ATT inference (may be overridden by replicate below)
515
+ overall_t, overall_p, overall_ci = safe_inference(
516
+ overall_att, overall_se, alpha=self.alpha, df=_survey_df
517
+ )
518
+
519
+ # Event study and group aggregation (full-sample, for point estimates)
520
+ event_study_effects = None
521
+ group_effects = None
522
+
523
+ if aggregate in ("event_study", "all"):
524
+ event_study_effects = self._stage2_event_study(
525
+ df=df,
526
+ unit=unit,
527
+ time=time,
528
+ first_treat=first_treat,
529
+ covariates=covariates,
530
+ omega_0_mask=omega_0_mask,
531
+ omega_1_mask=omega_1_mask,
532
+ unit_fe=unit_fe,
533
+ time_fe=time_fe,
534
+ grand_mean=grand_mean,
535
+ delta_hat=delta_hat,
536
+ cluster_var=cluster_var,
537
+ treatment_groups=treatment_groups,
538
+ ref_period=ref_period,
539
+ balance_e=balance_e,
540
+ kept_cov_mask=kept_cov_mask,
541
+ survey_weights=survey_weights,
542
+ survey_weight_type=survey_weight_type,
543
+ survey_df=_survey_df,
544
+ resolved_survey=(resolved_survey if not _uses_replicate_ts else None),
545
+ )
546
+
547
+ if aggregate in ("group", "all"):
548
+ group_effects = self._stage2_group(
549
+ df=df,
550
+ unit=unit,
551
+ time=time,
552
+ first_treat=first_treat,
553
+ covariates=covariates,
554
+ omega_0_mask=omega_0_mask,
555
+ omega_1_mask=omega_1_mask,
556
+ unit_fe=unit_fe,
557
+ time_fe=time_fe,
558
+ grand_mean=grand_mean,
559
+ delta_hat=delta_hat,
560
+ cluster_var=cluster_var,
561
+ treatment_groups=treatment_groups,
562
+ kept_cov_mask=kept_cov_mask,
563
+ survey_weights=survey_weights,
564
+ survey_weight_type=survey_weight_type,
565
+ survey_df=_survey_df,
566
+ resolved_survey=(resolved_survey if not _uses_replicate_ts else None),
567
+ )
568
+
569
+ # Replicate variance override: derive keys from actual outputs, then refit
570
+ _n_valid_rep_ts = None
571
+ _vcov_rep_ts = None
572
+ if _uses_replicate_ts:
573
+ from diff_diff.survey import compute_replicate_refit_variance
574
+
575
+ # Derive keys from actual outputs (excludes filtered/Prop5 horizons)
576
+ _sorted_es_periods_ts = sorted(
577
+ e
578
+ for e in (event_study_effects or {}).keys()
579
+ if np.isfinite(event_study_effects[e]["effect"])
580
+ )
581
+ _sorted_groups_ts = sorted(
582
+ g for g in (group_effects or {}).keys() if np.isfinite(group_effects[g]["effect"])
583
+ )
584
+ _n_es_ts = len(_sorted_es_periods_ts)
585
+ _n_grp_ts = len(_sorted_groups_ts)
586
+
587
+ # Build full-sample estimate from actual outputs
588
+ _full_est_ts = [overall_att]
589
+ _full_est_ts.extend([event_study_effects[e]["effect"] for e in _sorted_es_periods_ts])
590
+ _full_est_ts.extend([group_effects[g]["effect"] for g in _sorted_groups_ts])
591
+
592
def _refit_ts(w_r):
    """
    Re-run both stages under the weights ``w_r`` and return the effects.

    Used as the refit callback handed to
    ``compute_replicate_refit_variance``. The returned 1-D array has the
    same layout as the full-sample estimate vector: the overall ATT
    first, then one entry per horizon in ``_sorted_es_periods_ts``, then
    one entry per cohort in ``_sorted_groups_ts``. A horizon or cohort
    that is absent from the refit output is reported as NaN.

    ``w_r`` is presumably one replicate-weight vector aligned with the
    rows of ``df`` -- confirm against ``compute_replicate_refit_variance``.
    """
    # Stage 1: first-stage model on untreated rows, using the given weights.
    stage1_fit = self._fit_untreated_model(
        df, outcome, unit, time, covariates, omega_0_mask, weights=w_r
    )
    fe_unit, fe_time, mu_hat, delta_rep, cov_keep = stage1_fit

    # Residualize every observation with the refit first-stage estimates.
    resid_rep = self._residualize(
        df, outcome, unit, time, covariates, fe_unit, fe_time, mu_hat, delta_rep
    )
    # Work on a copy so the full-sample "_y_tilde" column in df is untouched.
    rep_df = df.copy()
    rep_df["_y_tilde"] = resid_rep

    # Keyword arguments shared by all three stage-2 specifications.
    stage2_common = dict(
        df=rep_df,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        omega_1_mask=omega_1_mask,
        unit_fe=fe_unit,
        time_fe=fe_time,
        grand_mean=mu_hat,
        delta_hat=delta_rep,
        cluster_var=cluster_var,
        kept_cov_mask=cov_keep,
        survey_weights=w_r,
        survey_weight_type="pweight",
    )

    # Overall ATT; the accompanying standard error is not needed here.
    att_rep, _ = self._stage2_static(**stage2_common)
    estimates = [att_rep]

    # Event-study effects, only when the full-sample fit reported any.
    if _sorted_es_periods_ts:
        es_rep = self._stage2_event_study(
            treatment_groups=treatment_groups,
            ref_period=ref_period,
            balance_e=balance_e,
            survey_df=None,
            **stage2_common,
        )
        estimates.extend(
            es_rep[e]["effect"] if e in es_rep else np.nan for e in _sorted_es_periods_ts
        )

    # Group (cohort) effects, only when the full-sample fit reported any.
    if _sorted_groups_ts:
        grp_rep = self._stage2_group(
            treatment_groups=treatment_groups,
            survey_df=None,
            **stage2_common,
        )
        estimates.extend(
            grp_rep[g]["effect"] if g in grp_rep else np.nan for g in _sorted_groups_ts
        )

    return np.array(estimates)
685
+
686
+ _vcov_rep_ts, _n_valid_rep_ts = compute_replicate_refit_variance(
687
+ _refit_ts, np.array(_full_est_ts), resolved_survey
688
+ )
689
+ overall_se = float(np.sqrt(max(_vcov_rep_ts[0, 0], 0.0)))
690
+
691
+ # Override df if replicates were dropped
692
+ if _n_valid_rep_ts < resolved_survey.n_replicates:
693
+ _survey_df = _n_valid_rep_ts - 1 if _n_valid_rep_ts > 1 else 0
694
+ if survey_metadata is not None:
695
+ survey_metadata.df_survey = _survey_df if _survey_df and _survey_df > 0 else None
696
+
697
+ # Recompute overall inference with replicate SE/df
698
+ overall_t, overall_p, overall_ci = safe_inference(
699
+ overall_att, overall_se, alpha=self.alpha, df=_survey_df
700
+ )
701
+
702
+ # Override event-study SEs (only for identified effects)
703
+ for i, e in enumerate(_sorted_es_periods_ts):
704
+ if event_study_effects is not None and e in event_study_effects:
705
+ se_e = float(np.sqrt(max(_vcov_rep_ts[1 + i, 1 + i], 0.0)))
706
+ eff_e = event_study_effects[e]["effect"]
707
+ t_e, p_e, ci_e = safe_inference(eff_e, se_e, alpha=self.alpha, df=_survey_df)
708
+ event_study_effects[e]["se"] = se_e
709
+ event_study_effects[e]["t_stat"] = t_e
710
+ event_study_effects[e]["p_value"] = p_e
711
+ event_study_effects[e]["conf_int"] = ci_e
712
+
713
+ # Override group SEs (only for identified effects)
714
+ for j, g in enumerate(_sorted_groups_ts):
715
+ if group_effects is not None and g in group_effects:
716
+ se_g = float(
717
+ np.sqrt(max(_vcov_rep_ts[1 + _n_es_ts + j, 1 + _n_es_ts + j], 0.0))
718
+ )
719
+ eff_g = group_effects[g]["effect"]
720
+ t_g, p_g, ci_g = safe_inference(eff_g, se_g, alpha=self.alpha, df=_survey_df)
721
+ group_effects[g]["se"] = se_g
722
+ group_effects[g]["t_stat"] = t_g
723
+ group_effects[g]["p_value"] = p_g
724
+ group_effects[g]["conf_int"] = ci_g
725
+
726
+ # Build treatment effects DataFrame
727
+ treated_df = df.loc[omega_1_mask, [unit, time, "_y_tilde", "_rel_time"]].copy()
728
+ treated_df = treated_df.rename(columns={"_y_tilde": "tau_hat", "_rel_time": "rel_time"})
729
+ tau_finite = treated_df["tau_hat"].notna() & np.isfinite(treated_df["tau_hat"].values)
730
+ n_valid_te = int(tau_finite.sum())
731
+ if n_valid_te > 0:
732
+ if survey_weights is not None:
733
+ treated_sw = survey_weights[omega_1_mask.values]
734
+ sw_finite = np.where(tau_finite, treated_sw, 0.0)
735
+ sw_sum = sw_finite.sum()
736
+ treated_df["weight"] = sw_finite / sw_sum if sw_sum > 0 else 0.0
737
+ else:
738
+ treated_df["weight"] = np.where(tau_finite, 1.0 / n_valid_te, 0.0)
739
+ else:
740
+ treated_df["weight"] = 0.0
741
+
742
+ # ---- Bootstrap ----
743
+ bootstrap_results = None
744
+ if self.n_bootstrap > 0:
745
+ try:
746
+ bootstrap_results = self._run_bootstrap(
747
+ df=df,
748
+ unit=unit,
749
+ time=time,
750
+ first_treat=first_treat,
751
+ covariates=covariates,
752
+ omega_0_mask=omega_0_mask,
753
+ omega_1_mask=omega_1_mask,
754
+ unit_fe=unit_fe,
755
+ time_fe=time_fe,
756
+ grand_mean=grand_mean,
757
+ delta_hat=delta_hat,
758
+ cluster_var=cluster_var,
759
+ kept_cov_mask=kept_cov_mask,
760
+ treatment_groups=treatment_groups,
761
+ ref_period=ref_period,
762
+ balance_e=balance_e,
763
+ original_att=overall_att,
764
+ original_event_study=event_study_effects,
765
+ original_group=group_effects,
766
+ aggregate=aggregate,
767
+ resolved_survey=resolved_survey,
768
+ )
769
+ except NotImplementedError:
770
+ raise # Don't swallow explicit rejections (e.g. lonely_psu="adjust")
771
+ except Exception as e:
772
+ warnings.warn(
773
+ f"Bootstrap failed: {e}. Skipping bootstrap inference.",
774
+ UserWarning,
775
+ stacklevel=2,
776
+ )
777
+
778
+ if bootstrap_results is not None:
779
+ # Update inference with bootstrap results
780
+ overall_se = bootstrap_results.overall_att_se
781
+ overall_t = (
782
+ overall_att / overall_se
783
+ if np.isfinite(overall_se) and overall_se > 0
784
+ else np.nan
785
+ )
786
+ overall_p = bootstrap_results.overall_att_p_value
787
+ overall_ci = bootstrap_results.overall_att_ci
788
+
789
+ # Update event study
790
+ if event_study_effects and bootstrap_results.event_study_ses:
791
+ for h in event_study_effects:
792
+ if (
793
+ h in bootstrap_results.event_study_ses
794
+ and event_study_effects[h].get("n_obs", 1) > 0
795
+ ):
796
+ event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h]
797
+ assert bootstrap_results.event_study_cis is not None
798
+ event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[
799
+ h
800
+ ]
801
+ assert bootstrap_results.event_study_p_values is not None
802
+ event_study_effects[h]["p_value"] = (
803
+ bootstrap_results.event_study_p_values[h]
804
+ )
805
+ eff_val = event_study_effects[h]["effect"]
806
+ se_val = event_study_effects[h]["se"]
807
+ event_study_effects[h]["t_stat"] = safe_inference(
808
+ eff_val, se_val, alpha=self.alpha
809
+ )[0]
810
+
811
+ # Update group effects
812
+ if group_effects and bootstrap_results.group_ses:
813
+ for g in group_effects:
814
+ if g in bootstrap_results.group_ses:
815
+ group_effects[g]["se"] = bootstrap_results.group_ses[g]
816
+ assert bootstrap_results.group_cis is not None
817
+ group_effects[g]["conf_int"] = bootstrap_results.group_cis[g]
818
+ assert bootstrap_results.group_p_values is not None
819
+ group_effects[g]["p_value"] = bootstrap_results.group_p_values[g]
820
+ eff_val = group_effects[g]["effect"]
821
+ se_val = group_effects[g]["se"]
822
+ group_effects[g]["t_stat"] = safe_inference(
823
+ eff_val, se_val, alpha=self.alpha
824
+ )[0]
825
+
826
+ # Construct results
827
+ self.results_ = TwoStageDiDResults(
828
+ treatment_effects=treated_df,
829
+ overall_att=overall_att,
830
+ overall_se=overall_se,
831
+ overall_t_stat=overall_t,
832
+ overall_p_value=overall_p,
833
+ overall_conf_int=overall_ci,
834
+ event_study_effects=event_study_effects,
835
+ group_effects=group_effects,
836
+ groups=treatment_groups,
837
+ time_periods=time_periods,
838
+ n_obs=len(df),
839
+ n_treated_obs=n_omega_1,
840
+ n_untreated_obs=n_omega_0,
841
+ n_treated_units=n_treated_units,
842
+ n_control_units=n_control_units,
843
+ alpha=self.alpha,
844
+ bootstrap_results=bootstrap_results,
845
+ survey_metadata=survey_metadata,
846
+ )
847
+
848
+ self.is_fitted_ = True
849
+ return self.results_
850
+
851
+ # =========================================================================
852
+ # Stage 1: OLS on untreated observations
853
+ # =========================================================================
854
+
855
def _iterative_fe(
    self,
    y: np.ndarray,
    unit_vals: np.ndarray,
    time_vals: np.ndarray,
    idx: pd.Index,
    max_iter: int = 100,
    tol: float = 1e-10,
    weights: Optional[np.ndarray] = None,
) -> Tuple[Dict[Any, float], Dict[Any, float]]:
    """
    Estimate unit and time fixed effects by alternating group means.

    Both components start at zero. Each sweep first sets the time component
    to the per-period mean of ``y`` minus the current unit component, then
    sets the unit component to the per-unit mean of ``y`` minus the freshly
    updated time component. Sweeping stops after ``max_iter`` sweeps or as
    soon as the largest absolute change in either component is below ``tol``.

    Parameters
    ----------
    y : np.ndarray
        Outcome values, one per row.
    unit_vals : np.ndarray
        Unit key of each row (used as the groupby key).
    time_vals : np.ndarray
        Time key of each row (used as the groupby key).
    idx : pd.Index
        Index used to wrap the row-level arrays in ``pd.Series``.
    max_iter : int, default 100
        Maximum number of sweeps.
    tol : float, default 1e-10
        Convergence threshold on the largest absolute update.
    weights : np.ndarray, optional
        Survey weights. When provided, uses weighted group means
        (sum(w*x)/sum(w)) instead of unweighted means.

    Returns
    -------
    unit_fe : dict
        Mapping from unit -> unit fixed effect.
    time_fe : dict
        Mapping from time -> time fixed effect.
    """
    n_rows = len(y)
    unit_part = np.zeros(n_rows)
    time_part = np.zeros(n_rows)

    # Weight totals per group never change across sweeps, so compute them once.
    time_wsum = None
    unit_wsum = None
    if weights is not None:
        w_by_row = pd.Series(weights, index=idx)
        time_wsum = w_by_row.groupby(time_vals).transform("sum").values
        unit_wsum = w_by_row.groupby(unit_vals).transform("sum").values

    def _group_avg(values, keys, wsum):
        # Row-aligned group average; weighted form is sum(w*x) / sum(w).
        if weights is not None:
            scaled = pd.Series(values * weights, index=idx)
            return scaled.groupby(keys).transform("sum").values / wsum
        return pd.Series(values, index=idx).groupby(keys).transform("mean").values

    # Zero weight totals produce 0/0; silence the floating-point warnings.
    with np.errstate(invalid="ignore", divide="ignore"):
        for _ in range(max_iter):
            time_next = _group_avg(y - unit_part, time_vals, time_wsum)
            unit_next = _group_avg(y - time_next, unit_vals, unit_wsum)

            largest_step = max(
                np.max(np.abs(unit_next - unit_part)),
                np.max(np.abs(time_next - time_part)),
            )
            unit_part, time_part = unit_next, time_next
            if largest_step < tol:
                break

    # Every row in a group carries the same value; collapse to one per key.
    unit_fe = pd.Series(unit_part, index=idx).groupby(unit_vals).first().to_dict()
    time_fe = pd.Series(time_part, index=idx).groupby(time_vals).first().to_dict()
    return unit_fe, time_fe
928
+
929
@staticmethod
def _iterative_demean(
    vals: np.ndarray,
    unit_vals: np.ndarray,
    time_vals: np.ndarray,
    idx: pd.Index,
    max_iter: int = 100,
    tol: float = 1e-10,
    weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Remove unit and time means from a vector by alternating projection.

    Each sweep subtracts the per-period mean and then the per-unit mean of
    the running result. Sweeping stops after ``max_iter`` sweeps or once the
    largest absolute change between consecutive results is below ``tol``.
    The input array is copied, not modified.

    Parameters
    ----------
    vals : np.ndarray
        Values to demean, one per row.
    unit_vals : np.ndarray
        Unit key of each row (used as the groupby key).
    time_vals : np.ndarray
        Time key of each row (used as the groupby key).
    idx : pd.Index
        Index used to wrap the row-level arrays in ``pd.Series``.
    max_iter : int, default 100
        Maximum number of sweeps.
    tol : float, default 1e-10
        Convergence threshold on the largest absolute change.
    weights : np.ndarray, optional
        Survey weights. When provided, uses weighted group means
        (sum(w*x)/sum(w)) instead of unweighted means.

    Returns
    -------
    np.ndarray
        The demeaned vector.
    """
    current = vals.copy()

    # Per-group weight totals are constant across sweeps.
    time_wsum = None
    unit_wsum = None
    if weights is not None:
        w_by_row = pd.Series(weights, index=idx)
        time_wsum = w_by_row.groupby(time_vals).transform("sum").values
        unit_wsum = w_by_row.groupby(unit_vals).transform("sum").values

    def _group_avg(values, keys, wsum):
        # Row-aligned group average; weighted form is sum(w*x) / sum(w).
        if weights is not None:
            scaled = pd.Series(values * weights, index=idx)
            return scaled.groupby(keys).transform("sum").values / wsum
        return pd.Series(values, index=idx).groupby(keys).transform("mean").values

    with np.errstate(invalid="ignore", divide="ignore"):
        for _ in range(max_iter):
            minus_time = current - _group_avg(current, time_vals, time_wsum)
            updated = minus_time - _group_avg(minus_time, unit_vals, unit_wsum)
            step = np.max(np.abs(updated - current))
            current = updated
            if step < tol:
                break
    return current
980
+
981
def _fit_untreated_model(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    weights: Optional[np.ndarray] = None,
) -> Tuple[
    Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray]
]:
    """
    Stage 1: fit unit + time fixed effects on the untreated rows only.

    Without covariates the fixed effects are estimated directly from the
    outcome. With covariates, outcome and covariates are demeaned, the
    covariate coefficients are obtained via ``solve_ols`` on the demeaned
    data, and the fixed effects are then estimated from the outcome net of
    the covariate contribution.

    Parameters
    ----------
    df : pd.DataFrame
        Full panel.
    outcome, unit, time : str
        Column names.
    covariates : list of str, optional
        Covariate column names; ``None`` or empty selects the FE-only path.
    omega_0_mask : pd.Series
        Boolean mask selecting the untreated rows of ``df``.
    weights : np.ndarray, optional
        Full-panel survey weights (same length as df). The untreated subset
        is extracted internally via omega_0_mask. When None, unweighted.

    Returns
    -------
    unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask
        ``grand_mean`` is always returned as 0.0. ``delta_hat`` has
        non-finite coefficients replaced by 0.0 and ``kept_cov_mask`` flags
        the coefficients that were finite; both are ``None`` on the
        FE-only path.
    """
    untreated = df.loc[omega_0_mask]
    w_untreated = None if weights is None else weights[omega_0_mask.values]

    outcome_vals = untreated[outcome].values.copy()
    units = untreated[unit].values
    times = untreated[time].values
    row_index = untreated.index

    n_cov = 0 if covariates is None else len(covariates)
    if n_cov == 0:
        unit_fe, time_fe = self._iterative_fe(
            outcome_vals, units, times, row_index, weights=w_untreated
        )
        return unit_fe, time_fe, 0.0, None, None

    X_raw = untreated[covariates].values.copy()

    # Partial out unit and time FE from the outcome and every covariate.
    y_dm = self._iterative_demean(outcome_vals, units, times, row_index, weights=w_untreated)
    demeaned_cols = []
    for j in range(n_cov):
        demeaned_cols.append(
            self._iterative_demean(X_raw[:, j], units, times, row_index, weights=w_untreated)
        )
    X_dm = np.column_stack(demeaned_cols)

    ols_out = solve_ols(
        X_dm,
        y_dm,
        return_vcov=False,
        rank_deficient_action=self.rank_deficient_action,
        column_names=covariates,
        weights=w_untreated,
    )
    delta_hat = ols_out[0]

    # Non-finite coefficients are treated as dropped: zero contribution.
    kept_cov_mask = np.isfinite(delta_hat)
    delta_hat_clean = np.where(kept_cov_mask, delta_hat, 0.0)

    y_adj = outcome_vals - np.dot(X_raw, delta_hat_clean)
    unit_fe, time_fe = self._iterative_fe(y_adj, units, times, row_index, weights=w_untreated)

    return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask
1047
+
1048
+ # =========================================================================
1049
+ # Residualization
1050
+ # =========================================================================
1051
+
1052
def _residualize(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    covariates: Optional[List[str]],
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
) -> np.ndarray:
    """
    Subtract the Stage 1 fit from the outcome for every row of ``df``.

    y_tilde_i = y_i - grand_mean - mu_hat_i - eta_hat_t [- X_i @ delta_hat]

    Rows whose unit or period has no entry in the fixed-effect mappings
    get NaN, because the mapped effect is NaN and propagates through the
    subtraction.

    Parameters
    ----------
    df : pd.DataFrame
        Full panel.
    outcome, unit, time : str
        Column names.
    covariates : list of str, optional
        Covariate columns; used only when ``delta_hat`` is also given.
    unit_fe, time_fe : dict
        Fixed-effect mappings keyed by unit / time value.
    grand_mean : float
        Constant added to the fitted value.
    delta_hat : np.ndarray, optional
        Covariate coefficients aligned with ``covariates``.

    Returns
    -------
    np.ndarray
        Residualized outcome, one entry per row of ``df``.
    """

    def _mapped_effect(column, mapping):
        # Keys absent from the mapping come back missing; force them to float NaN.
        looked_up = df[column].map(mapping).values
        return np.where(pd.isna(looked_up), np.nan, looked_up).astype(float)

    fitted = grand_mean + _mapped_effect(unit, unit_fe) + _mapped_effect(time, time_fe)

    if delta_hat is not None and covariates:
        fitted = fitted + np.dot(df[covariates].values, delta_hat)

    return df[outcome].values - fitted
1083
+
1084
+ # =========================================================================
1085
+ # Stage 2 specifications
1086
+ # =========================================================================
1087
+
1088
@staticmethod
def _mask_nan_ytilde(y_tilde):
    """Flag non-finite entries of y_tilde, zero them in place, and warn.

    Parameters
    ----------
    y_tilde : np.ndarray
        Residualized outcomes. Modified in place: every non-finite entry
        (NaN or +/-inf) is overwritten with 0.0.

    Returns
    -------
    np.ndarray
        Boolean mask that is True where the original entry was non-finite.

    Warns
    -----
    UserWarning
        When at least one entry is non-finite.
    """
    nan_mask = ~np.isfinite(y_tilde)
    n_nan = int(nan_mask.sum())
    if n_nan == 0:
        return nan_mask

    warnings.warn(
        f"{n_nan} observation(s) have non-finite imputed outcomes "
        f"(y_tilde) from unidentified fixed effects. These "
        f"observations are excluded from ATT estimation.",
        UserWarning,
        stacklevel=3,
    )
    y_tilde[nan_mask] = 0.0
    return nan_mask
1107
+
1108
def _stage2_static(
    self,
    df: pd.DataFrame,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    cluster_var: str,
    kept_cov_mask: Optional[np.ndarray],
    survey_weights: Optional[np.ndarray] = None,
    survey_weight_type: str = "pweight",
    resolved_survey=None,
) -> Tuple[float, float]:
    """
    Static (single ATT) Stage 2: regress y_tilde on the treatment dummy.

    The regressor is the treated-row indicator with no intercept. Rows whose
    ``_y_tilde`` is non-finite are zeroed in both the outcome and the
    indicator so they do not contribute. The point estimate comes from
    ``solve_ols``; the standard error is the square root of the (floored at
    zero) GMM sandwich variance from ``_compute_gmm_variance``.

    ``first_treat`` and ``grand_mean`` are accepted but not read here.

    Returns
    -------
    (att, se) : tuple of float
        ``(nan, nan)`` when no treated row has a finite ``_y_tilde``.
    """
    resid_outcome = df["_y_tilde"].values.copy()
    unusable = self._mask_nan_ytilde(resid_outcome)

    # Treatment dummy; rows with unusable y_tilde are dropped from the ATT.
    treat_ind = omega_1_mask.values.astype(float)
    treat_ind[unusable] = 0.0

    # Degenerate case: nothing left to identify the ATT from.
    if treat_ind.sum() == 0:
        return np.nan, np.nan

    # Single-column design, no intercept.
    X_2 = treat_ind.reshape(-1, 1)

    # Only the coefficient is used; the naive OLS variance is not requested.
    coef, _, _ = solve_ols(
        X_2,
        resid_outcome,
        return_vcov=False,
        weights=survey_weights,
        weight_type=survey_weight_type,
    )
    att = float(coef[0])

    stage2_resid = resid_outcome - np.dot(X_2, coef)

    V = self._compute_gmm_variance(
        df=df,
        unit=unit,
        time=time,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        unit_fe=unit_fe,
        time_fe=time_fe,
        delta_hat=delta_hat,
        kept_cov_mask=kept_cov_mask,
        X_2=X_2,
        eps_2=stage2_resid,
        cluster_ids=df[cluster_var].values,
        survey_weights=survey_weights,
        resolved_survey=resolved_survey,
    )

    variance = max(V[0, 0], 0.0)
    return att, float(np.sqrt(variance))
1178
+
1179
def _stage2_event_study(
    self,
    df: pd.DataFrame,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    cluster_var: str,
    treatment_groups: List[Any],
    ref_period: int,
    balance_e: Optional[int],
    kept_cov_mask: Optional[np.ndarray],
    survey_weights: Optional[np.ndarray] = None,
    survey_weight_type: str = "pweight",
    survey_df: Optional[int] = None,
    resolved_survey=None,
) -> Dict[int, Dict[str, Any]]:
    """Event study Stage 2: regress y_tilde on relative-time dummies.

    Steps, in order:

    1. Collect the integer horizons present in ``_rel_time`` (all
       not-never-treated rows when ``self.pretrends`` is set, otherwise
       treated rows only), then apply ``self.horizon_max``.
    2. If ``balance_e`` is given, keep only cohorts observed at every
       horizon from ``-balance_e`` to the largest retained horizon.
    3. When there are no never-treated rows and more than one cohort,
       horizons at or beyond ``max(treatment_groups) - min(treatment_groups)``
       that have treated rows are reported as NaN instead of estimated.
    4. Fit one dummy per remaining horizon (reference period excluded, no
       intercept) with ``solve_ols`` and take standard errors from
       ``_compute_gmm_variance``.

    ``grand_mean`` is accepted but not read here.

    Returns
    -------
    dict
        Horizon -> dict with keys ``effect``, ``se``, ``t_stat``,
        ``p_value``, ``conf_int``, ``n_obs``. The reference period is always
        present as a zero-effect marker with ``n_obs`` 0.

    Warns
    -----
    UserWarning
        When no cohort satisfies ``balance_e``, or when horizons are set to
        NaN under step 3.
    """

    def _reference_entry() -> Dict[str, Any]:
        # Marker for the omitted reference period (normalised to zero).
        return {
            "effect": 0.0,
            "se": 0.0,
            "t_stat": np.nan,
            "p_value": np.nan,
            "conf_int": (0.0, 0.0),
            "n_obs": 0,
        }

    def _nan_entry(count: int) -> Dict[str, Any]:
        # Placeholder for a horizon that cannot be estimated.
        return {
            "effect": np.nan,
            "se": np.nan,
            "t_stat": np.nan,
            "p_value": np.nan,
            "conf_int": (np.nan, np.nan),
            "n_obs": count,
        }

    y_tilde = df["_y_tilde"].values.copy()
    nan_mask = self._mask_nan_ytilde(y_tilde)
    rel_times = df["_rel_time"].values
    treated_rows = omega_1_mask.values
    n = len(df)

    # Candidate horizons: pre-periods are included only when pretrends=True.
    if self.pretrends:
        horizon_source = rel_times[~df["_never_treated"].values]
    else:
        horizon_source = rel_times[treated_rows]
    all_horizons = sorted({int(h) for h in horizon_source if np.isfinite(h)})

    if self.horizon_max is not None:
        all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max]

    # Optional balancing: restrict to cohorts observed at every required horizon.
    if balance_e is None:
        balance_mask = np.ones(n, dtype=bool)
    else:
        cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
        balanced_cohorts = set()
        if all_horizons:
            required_range = set(range(-balance_e, max(all_horizons) + 1))
            balanced_cohorts = {
                g
                for g, horizons in cohort_rel_times.items()
                if required_range.issubset(horizons)
            }
        if not balanced_cohorts:
            warnings.warn(
                f"No cohorts satisfy balance_e={balance_e} requirement. "
                "Event study results will contain only the reference period. "
                "Consider reducing balance_e.",
                UserWarning,
                stacklevel=2,
            )
            return {ref_period: _reference_entry()}
        balance_mask = df[first_treat].isin(balanced_cohorts).values

    # Proposition 5: without never-treated units, horizons at or beyond the
    # spread between the last and first cohort are not identified.
    has_never_treated = df["_never_treated"].any()
    h_bar = np.inf
    if not has_never_treated and len(treatment_groups) > 1:
        h_bar = max(treatment_groups) - min(treatment_groups)

    # Such horizons are reported as NaN, but only when treated rows actually
    # exist there; rows with NaN relative time never equal h and are not counted.
    prop5_horizons = []
    prop5_effects: Dict[int, Dict[str, Any]] = {}
    if h_bar < np.inf:
        for h in all_horizons:
            if h == ref_period or not h >= h_bar:
                continue
            actual_n = int(np.sum((rel_times == h) & treated_rows & balance_mask))
            if actual_n > 0:
                prop5_horizons.append(h)
                prop5_effects[h] = _nan_entry(actual_n)

    # Horizons that get a dummy: everything except the reference and Prop 5 ones.
    prop5_set = set(prop5_horizons)
    est_horizons = [h for h in all_horizons if h != ref_period and h not in prop5_set]

    if not est_horizons:
        return {ref_period: _reference_entry()}

    # Design matrix: one indicator column per estimated horizon, no intercept.
    # Rows with undefined relative time, unusable y_tilde, or an unbalanced
    # cohort stay all-zero and so add nothing to X'X or X'y.
    horizon_to_col = {h: j for j, h in enumerate(est_horizons)}
    k = len(est_horizons)
    X_2 = np.zeros((n, k))

    for row, (in_balance, is_bad, rel) in enumerate(zip(balance_mask, nan_mask, rel_times)):
        if not in_balance or is_bad or not np.isfinite(rel):
            continue
        col = horizon_to_col.get(int(rel))
        if col is not None:
            X_2[row, col] = 1.0

    coef, _, _ = solve_ols(
        X_2,
        y_tilde,
        return_vcov=False,
        weights=survey_weights,
        weight_type=survey_weight_type,
    )
    eps_2 = y_tilde - np.dot(X_2, coef)

    # Joint GMM variance of all horizon coefficients.
    V = self._compute_gmm_variance(
        df=df,
        unit=unit,
        time=time,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        unit_fe=unit_fe,
        time_fe=time_fe,
        delta_hat=delta_hat,
        kept_cov_mask=kept_cov_mask,
        X_2=X_2,
        eps_2=eps_2,
        cluster_ids=df[cluster_var].values,
        survey_weights=survey_weights,
        resolved_survey=resolved_survey,
    )

    event_study_effects: Dict[int, Dict[str, Any]] = {ref_period: _reference_entry()}

    for h in est_horizons:
        j = horizon_to_col[h]
        n_obs = int(np.sum(X_2[:, j]))

        # A column with no rows has no estimable coefficient.
        if n_obs == 0:
            event_study_effects[h] = _nan_entry(0)
            continue

        effect = float(coef[j])
        se = float(np.sqrt(max(V[j, j], 0.0)))
        t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha, df=survey_df)

        event_study_effects[h] = {
            "effect": effect,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_obs": n_obs,
        }

    # Append the unidentified (Proposition 5) horizons last.
    event_study_effects.update(prop5_effects)

    if prop5_horizons:
        warnings.warn(
            f"Horizons {prop5_horizons} are not identified without "
            f"never-treated units (Proposition 5). Set to NaN.",
            UserWarning,
            stacklevel=2,
        )

    return event_study_effects
1398
+
1399
def _stage2_group(
    self,
    df: pd.DataFrame,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    omega_1_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    grand_mean: float,
    delta_hat: Optional[np.ndarray],
    cluster_var: str,
    treatment_groups: List[Any],
    kept_cov_mask: Optional[np.ndarray],
    survey_weights: Optional[np.ndarray] = None,
    survey_weight_type: str = "pweight",
    survey_df: Optional[int] = None,
    resolved_survey=None,
) -> Dict[Any, Dict[str, Any]]:
    """Group (cohort) Stage 2: regress y_tilde on cohort dummies.

    Builds one indicator column per entry of ``treatment_groups`` (no
    intercept), set to 1 for treated rows of that cohort whose ``_y_tilde``
    is finite. Coefficients come from ``solve_ols`` and standard errors from
    ``_compute_gmm_variance``.

    ``grand_mean`` is accepted but not read here.

    Returns
    -------
    dict
        Cohort -> dict with keys ``effect``, ``se``, ``t_stat``, ``p_value``,
        ``conf_int``, ``n_obs``. Cohorts with no usable rows get NaN
        statistics and ``n_obs`` 0.
    """
    y_tilde = df["_y_tilde"].values.copy()
    nan_mask = self._mask_nan_ytilde(y_tilde)
    n = len(df)

    # Column j of the design corresponds to treatment_groups[j].
    group_to_col = {g: j for j, g in enumerate(treatment_groups)}
    k = len(treatment_groups)
    X_2 = np.zeros((n, k))

    row_flags = zip(omega_1_mask.values, nan_mask, df[first_treat].values)
    for row, (is_treated, is_bad, cohort) in enumerate(row_flags):
        if not is_treated or is_bad:
            continue
        col = group_to_col.get(cohort)
        if col is not None:
            X_2[row, col] = 1.0

    coef, _, _ = solve_ols(
        X_2,
        y_tilde,
        return_vcov=False,
        weights=survey_weights,
        weight_type=survey_weight_type,
    )
    eps_2 = y_tilde - np.dot(X_2, coef)

    V = self._compute_gmm_variance(
        df=df,
        unit=unit,
        time=time,
        covariates=covariates,
        omega_0_mask=omega_0_mask,
        unit_fe=unit_fe,
        time_fe=time_fe,
        delta_hat=delta_hat,
        kept_cov_mask=kept_cov_mask,
        X_2=X_2,
        eps_2=eps_2,
        cluster_ids=df[cluster_var].values,
        survey_weights=survey_weights,
        resolved_survey=resolved_survey,
    )

    group_effects: Dict[Any, Dict[str, Any]] = {}
    for g in treatment_groups:
        j = group_to_col[g]
        n_obs = int(np.sum(X_2[:, j]))

        if n_obs > 0:
            effect = float(coef[j])
            se = float(np.sqrt(max(V[j, j], 0.0)))
            t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha, df=survey_df)
            group_effects[g] = {
                "effect": effect,
                "se": se,
                "t_stat": t_stat,
                "p_value": p_val,
                "conf_int": ci,
                "n_obs": n_obs,
            }
        else:
            # No usable treated rows for this cohort: nothing to report.
            group_effects[g] = {
                "effect": np.nan,
                "se": np.nan,
                "t_stat": np.nan,
                "p_value": np.nan,
                "conf_int": (np.nan, np.nan),
                "n_obs": 0,
            }

    return group_effects
1497
+
1498
+ # =========================================================================
1499
+ # GMM score computation
1500
+ # =========================================================================
1501
+
1502
@staticmethod
def _compute_gmm_scores(
    c_by_cluster: np.ndarray,
    gamma_hat: np.ndarray,
    s2_by_cluster: np.ndarray,
) -> np.ndarray:
    """
    Combine Stage 1 and Stage 2 cluster scores into GMM influence scores.

    Computes S_g = gamma_hat' c_g - X'_{2g} eps_{2g} for every cluster at
    once. Floating-point warnings from the matrix product are suppressed,
    and any NaN or infinite entry of the product is replaced by 0 before
    the Stage 2 scores are subtracted.

    Parameters
    ----------
    c_by_cluster : np.ndarray, shape (G, p)
        Per-cluster Stage 1 scores.
    gamma_hat : np.ndarray, shape (p, k)
        Cross-moment correction matrix.
    s2_by_cluster : np.ndarray, shape (G, k)
        Per-cluster Stage 2 scores.

    Returns
    -------
    np.ndarray, shape (G, k)
        Per-cluster influence scores.
    """
    with np.errstate(invalid="ignore", divide="ignore", over="ignore"):
        stage1_term = np.dot(c_by_cluster, gamma_hat)
    stage1_term = np.nan_to_num(stage1_term, nan=0.0, posinf=0.0, neginf=0.0)
    return stage1_term - s2_by_cluster
1532
+
1533
+ # =========================================================================
1534
+ # GMM Sandwich Variance (Butts & Gardner 2022)
1535
+ # =========================================================================
1536
+
1537
def _compute_gmm_variance(
    self,
    df: pd.DataFrame,
    unit: str,
    time: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
    unit_fe: Dict[Any, float],
    time_fe: Dict[Any, float],
    delta_hat: Optional[np.ndarray],
    kept_cov_mask: Optional[np.ndarray],
    X_2: np.ndarray,
    eps_2: np.ndarray,
    cluster_ids: np.ndarray,
    survey_weights: Optional[np.ndarray] = None,
    resolved_survey=None,
) -> np.ndarray:
    """
    Compute GMM sandwich variance (Butts & Gardner 2022).

    Matches the R `did2s` source code implementation: uses the GLOBAL
    Hessian inverse (not per-cluster) and NO finite-sample adjustments.

    The per-observation influence function is:
        IF_i = (X'_2 X_2)^{-1} [gamma_hat' x_{10i} eps_{10i} - x_{2i} eps_{2i}]

    where gamma_hat = (X'_{10} X_{10})^{-1} (X'_1 X_2) uses the GLOBAL
    cross-moment.

    The cluster-robust variance is:
        V = (X'_2 X_2)^{-1} (sum_g S_g S'_g) (X'_2 X_2)^{-1}
        S_g = gamma_hat' c_g - X'_{2g} eps_{2g}
        c_g = X'_{10g} eps_{10g}

    With survey weights W (diagonal):
        Bread: (X'_2 W X_2)^{-1}
        gamma_hat: (X'_{10} W X_{10})^{-1} (X'_1 W X_2)
        c_g = sum_{i in g} w_i * x_{10i} * eps_{10i}
        s2_g = sum_{i in g} w_i * x_{2i} * eps_{2i}

    When ``resolved_survey`` carries strata or an FPC, the meat is computed
    by ``_compute_stratified_meat_from_psu_scores`` instead of ``S' S``.
    Each cluster's stratum and FPC are read from the first observation
    belonging to that cluster.

    Parameters
    ----------
    df : pd.DataFrame
        Full panel; must contain the ``_y_tilde`` column.
    unit, time : str
        Column names of the unit and time identifiers.
    covariates : list of str, optional
        Covariate columns used in Stage 1.
    omega_0_mask : pd.Series
        Boolean mask of untreated rows.
    unit_fe, time_fe : dict
        Stage 1 fixed-effect mappings; missing keys are treated as 0 here.
    delta_hat : np.ndarray, optional
        Stage 1 covariate coefficients aligned with ``covariates``.
    kept_cov_mask : np.ndarray, optional
        Boolean mask of covariates that were kept in Stage 1; dropped
        covariates are excluded from the Stage 1 design.
    X_2 : np.ndarray, shape (n, k)
        Stage 2 design matrix (treatment indicators).
    eps_2 : np.ndarray, shape (n,)
        Stage 2 residuals.
    cluster_ids : np.ndarray, shape (n,)
        Cluster identifiers.
    survey_weights : np.ndarray, optional
        Survey weights of shape (n,). When None, unweighted.
    resolved_survey : optional
        Survey design object; its ``strata``, ``fpc`` and ``lonely_psu``
        attributes are read when present.

    Returns
    -------
    np.ndarray, shape (k, k)
        Variance-covariance matrix. All-NaN when the stratified variance is
        unidentified (a single unstratified PSU, or no stratum contributed
        variance and none was a legitimate zero).
    """
    n = len(df)
    k = X_2.shape[1]

    # Exclude rank-deficient covariates
    cov_list = covariates
    if covariates and kept_cov_mask is not None and not np.all(kept_cov_mask):
        cov_list = [c for c, keep in zip(covariates, kept_cov_mask) if keep]

    # Build sparse FE design matrices X_1 (all obs) and X_10 (untreated only).
    # The unit/time index mappings returned alongside are not needed here.
    X_1_sparse, X_10_sparse, _, _ = self._build_fe_design(
        df, unit, time, cov_list, omega_0_mask
    )

    p = X_1_sparse.shape[1]

    # eps_10 = Y - X_10 @ gamma_hat
    # Untreated: stage 1 residual (Y - fitted). Treated: Y (X_10 rows = 0).
    # Reconstruct Y from y_tilde: Y = y_tilde + fitted_stage1
    alpha_i = df[unit].map(unit_fe).values
    beta_t = df[time].map(time_fe).values
    alpha_i = np.where(pd.isna(alpha_i), 0.0, alpha_i).astype(float)
    beta_t = np.where(pd.isna(beta_t), 0.0, beta_t).astype(float)
    fitted_1 = alpha_i + beta_t
    if delta_hat is not None and cov_list:
        if kept_cov_mask is not None and not np.all(kept_cov_mask):
            kept_delta = delta_hat[kept_cov_mask]
        else:
            kept_delta = delta_hat
        fitted_1 = fitted_1 + np.dot(df[cov_list].values, kept_delta)

    y_tilde = df["_y_tilde"].values
    y_vals = y_tilde + fitted_1  # reconstruct Y

    # eps_10: for untreated, stage 1 residual; for treated, Y_i (since X_10 rows = 0)
    eps_10 = np.empty(n)
    omega_0 = omega_0_mask.values
    eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0]  # Stage 1 residual
    eps_10[~omega_0] = y_vals[~omega_0]  # x_{10i} = 0, so eps_10 = Y

    # 1. gamma_hat = (X'_{10} W X_{10})^{-1} (X'_1 W X_2)  [p x k]
    # With survey weights, both cross-products need W
    if survey_weights is not None:
        XtWX_10 = X_10_sparse.T @ X_10_sparse.multiply(survey_weights[:, None])
        Xt1_WX2 = X_1_sparse.T @ (X_2 * survey_weights[:, None])
    else:
        XtWX_10 = X_10_sparse.T @ X_10_sparse  # (p x p) sparse
        Xt1_WX2 = X_1_sparse.T @ X_2  # (p x k) dense

    try:
        solve_XtX = sparse_factorized(XtWX_10.tocsc())
        if Xt1_WX2.ndim == 1:
            gamma_hat = solve_XtX(Xt1_WX2).reshape(-1, 1)
        else:
            gamma_hat = np.column_stack(
                [solve_XtX(Xt1_WX2[:, j]) for j in range(Xt1_WX2.shape[1])]
            )
    except RuntimeError:
        # Singular matrix — fall back to dense least-squares
        gamma_hat = np.linalg.lstsq(XtWX_10.toarray(), Xt1_WX2, rcond=None)[0]
        if gamma_hat.ndim == 1:
            gamma_hat = gamma_hat.reshape(-1, 1)

    # 2. Per-cluster Stage 1 scores: c_g = sum_{i in g} w_i * x_{10i} * eps_{10i}
    # Only untreated obs have non-zero X_10 rows
    # With survey weights: multiply eps_10 by survey_weights before sparse multiply
    if survey_weights is not None:
        weighted_eps_10 = survey_weights * eps_10
    else:
        weighted_eps_10 = eps_10
    weighted_X10 = X_10_sparse.multiply(weighted_eps_10[:, None])  # sparse element-wise

    # unique_clusters is sorted; cluster_indices maps each row to its cluster;
    # first_obs_idx[g] is the first row belonging to cluster g (used below to
    # read PSU-level strata / FPC without rescanning cluster_ids per cluster).
    unique_clusters, first_obs_idx, cluster_indices = np.unique(
        cluster_ids, return_index=True, return_inverse=True
    )
    G = len(unique_clusters)

    n_elements = weighted_X10.shape[0] * weighted_X10.shape[1]
    c_by_cluster = np.zeros((G, p))
    if n_elements > _SPARSE_DENSE_THRESHOLD:
        # Per-column path: limits peak memory for large FE matrices
        weighted_X10_csc = weighted_X10.tocsc()
        for j_col in range(p):
            col_data = weighted_X10_csc.getcol(j_col).toarray().ravel()
            np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data)
    else:
        # Dense path: faster for moderate-size matrices
        weighted_X10_dense = weighted_X10.toarray()
        for j_col in range(p):
            np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col])

    # 3. Per-cluster Stage 2 scores: s2_g = sum_{i in g} w_i * x_{2i} * eps_{2i}
    if survey_weights is not None:
        weighted_eps_2 = survey_weights * eps_2
    else:
        weighted_eps_2 = eps_2
    weighted_X2 = X_2 * weighted_eps_2[:, None]  # (n x k) dense
    s2_by_cluster = np.zeros((G, k))
    for j_col in range(k):
        np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col])

    # 4. S_g = gamma_hat' c_g - X'_{2g} eps_{2g}
    S = self._compute_gmm_scores(c_by_cluster, gamma_hat, s2_by_cluster)

    # 5. Meat: sum_g S_g S'_g = S' S
    _use_stratified_meat = resolved_survey is not None and (
        resolved_survey.strata is not None or resolved_survey.fpc is not None
    )
    if _use_stratified_meat:
        from diff_diff.survey import _compute_stratified_meat_from_psu_scores

        # Build PSU→stratum and PSU→FPC mappings from observation-level arrays.
        # cluster_ids used here match resolved_survey.psu (via _inject_cluster_as_psu).
        G_meat = G

        # Strata: synthesize single stratum when strata is None (unstratified FPC)
        if resolved_survey.strata is not None:
            psu_strata = np.asarray(resolved_survey.strata)[first_obs_idx]
        else:
            psu_strata = np.zeros(G_meat, dtype=int)

        # FPC: map observation-level FPC to PSU level
        psu_fpc = None
        if resolved_survey.fpc is not None:
            psu_fpc = np.asarray(resolved_survey.fpc)[first_obs_idx].astype(np.float64)

        # Unstratified single-PSU: variance is unidentified (matches
        # _compute_stratified_psu_meat at survey.py:1225 which returns
        # zero meat with no variance_computed flag for n_psu < 2).
        if resolved_survey.strata is None and G_meat < 2:
            return np.full((k, k), np.nan)

        # Rows of S are already in unique_clusters order: they were
        # accumulated with cluster_indices from the np.unique call above,
        # so psu_strata / psu_fpc line up with S row by row.
        meat, _var_computed, _legit_zero = _compute_stratified_meat_from_psu_scores(
            psu_scores=S,
            psu_strata=psu_strata,
            fpc_per_psu=psu_fpc,
            lonely_psu=resolved_survey.lonely_psu,
        )
        # If no variance was computed and no legitimate zeros, variance
        # is unidentified — return NaN VCV so caller gets NaN SE.
        if not _var_computed and _legit_zero == 0:
            return np.full((k, k), np.nan)
    else:
        with np.errstate(invalid="ignore", over="ignore"):
            meat = S.T @ S  # (k x k)

    # 6. Bread: (X'_2 W X_2)^{-1}
    with np.errstate(invalid="ignore", over="ignore", divide="ignore"):
        if survey_weights is not None:
            XtWX_2 = X_2.T @ (X_2 * survey_weights[:, None])
        else:
            XtWX_2 = X_2.T @ X_2
        try:
            bread = np.linalg.solve(XtWX_2, np.eye(k))
        except np.linalg.LinAlgError:
            # Singular Stage 2 Gram matrix: fall back to least squares.
            bread = np.linalg.lstsq(XtWX_2, np.eye(k), rcond=None)[0]

        # 7. V = bread @ meat @ bread
        V = bread @ meat @ bread
    return V
1760
+
1761
def _build_fe_design(
    self,
    df: pd.DataFrame,
    unit: str,
    time: str,
    covariates: Optional[List[str]],
    omega_0_mask: pd.Series,
) -> Tuple[sparse.csr_matrix, sparse.csr_matrix, Dict[Any, int], Dict[Any, int]]:
    """
    Build sparse FE design matrices X_1 (all obs) and X_10 (untreated rows only).

    Column layout: [unit_0, ..., unit_{U-2}, time_0, ..., time_{T-2}, cov_1, ..., cov_C]
    (Drop first unit and first time for identification; "first" means the
    smallest value in ``np.unique`` order, so FE column ``j`` is the dummy for
    the unit/time whose index in the returned mapping is ``j + 1``.)

    X_10 is identical to X_1 except that rows where ``omega_0_mask`` is False
    are zeroed out.

    Parameters
    ----------
    df : pd.DataFrame
        Observation-level data; one design-matrix row per row of ``df``.
    unit : str
        Name of the unit identifier column.
    time : str
        Name of the time period column.
    covariates : list of str, optional
        Covariate columns appended after the FE dummies. ``None`` or an empty
        list adds no covariate columns.
    omega_0_mask : pd.Series
        Row mask (same length as ``df``); truthy entries are kept in X_10.

    Returns
    -------
    X_1_sparse : sparse.csr_matrix, shape (n, p)
    X_10_sparse : sparse.csr_matrix, shape (n, p)
    unit_to_idx : dict
        Unit value -> position in the sorted unique units.
    time_to_idx : dict
        Time value -> position in the sorted unique times.
    """
    n = len(df)
    unit_vals = df[unit].values
    time_vals = df[time].values
    # Force a genuine boolean array: with an integer 0/1 mask, ``&`` and ``~``
    # below would turn into fancy indexing / bitwise-not and silently select
    # the wrong rows.
    omega_0 = np.asarray(omega_0_mask.values, dtype=bool)

    # ``return_inverse`` yields each observation's position in the sorted
    # unique values, i.e. exactly what a per-row dict lookup would produce,
    # computed once and shared by both matrices.
    all_units, unit_codes = np.unique(unit_vals, return_inverse=True)
    all_times, time_codes = np.unique(time_vals, return_inverse=True)
    unit_to_idx = {u: i for i, u in enumerate(all_units)}
    time_to_idx = {t: i for i, t in enumerate(all_times)}
    n_units = len(all_units)
    n_times = len(all_times)
    n_cov = len(covariates) if covariates else 0
    n_fe_cols = (n_units - 1) + (n_times - 1)

    row_ids = np.arange(n)
    cov_values = df[covariates].values if n_cov > 0 else None

    def _build_rows(mask=None):
        """Build the sparse design matrix, zeroing rows where ``mask`` is False."""
        # Index 0 is the dropped reference category, so only codes > 0 get a dummy.
        u_mask = unit_codes > 0
        t_mask = time_codes > 0
        if mask is not None:
            u_mask = u_mask & mask
            t_mask = t_mask & mask

        rows = np.concatenate([row_ids[u_mask], row_ids[t_mask]])
        cols = np.concatenate(
            [
                unit_codes[u_mask] - 1,
                # Time dummies start right after the (n_units - 1) unit dummies.
                (n_units - 1) + time_codes[t_mask] - 1,
            ]
        )
        data = np.ones(len(rows))
        A_fe = sparse.csr_matrix((data, (rows, cols)), shape=(n, n_fe_cols))

        if n_cov == 0:
            return A_fe

        # Copy so that masking never writes into the caller's DataFrame.
        cov_data = cov_values.copy()
        if mask is not None:
            cov_data[~mask] = 0.0
        A_cov = sparse.csr_matrix(cov_data)
        return sparse.hstack([A_fe, A_cov], format="csr")

    X_1 = _build_rows(mask=None)
    X_10 = _build_rows(mask=omega_0)

    return X_1, X_10, unit_to_idx, time_to_idx
1839
+
1840
+ # =========================================================================
1841
+ # sklearn-compatible interface
1842
+ # =========================================================================
1843
+
1844
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's parameters as a name -> value dict (sklearn-compatible)."""
    # Listed in the order the keys appear in the returned dict.
    param_names = (
        "anticipation",
        "alpha",
        "cluster",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "rank_deficient_action",
        "horizon_max",
        "pretrends",
    )
    return {name: getattr(self, name) for name in param_names}
1857
+
1858
def set_params(self, **params) -> "TwoStageDiD":
    """Set estimator parameters (sklearn-compatible).

    All keys are validated before any attribute is written, so a call that
    raises leaves the estimator exactly as it was.

    Parameters
    ----------
    **params
        Attribute names mapped to their new values. A name is accepted when
        the instance already has an attribute with that name.

    Returns
    -------
    TwoStageDiD
        ``self``, so calls can be chained.

    Raises
    ------
    ValueError
        If any key is not an existing attribute of the estimator.
    """
    # Validate first: assigning while iterating would leave earlier keys
    # applied when a later key turns out to be unknown.
    for key in params:
        if not hasattr(self, key):
            raise ValueError(f"Unknown parameter: {key}")
    for key, value in params.items():
        setattr(self, key, value)
    return self
1866
+
1867
def summary(self) -> str:
    """Return the summary text produced by the stored results object.

    Raises
    ------
    RuntimeError
        If ``is_fitted_`` is falsy.
    """
    if self.is_fitted_:
        fitted_results = self.results_
        assert fitted_results is not None
        return fitted_results.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
1873
+
1874
def print_summary(self) -> None:
    """Print the text returned by ``summary()`` to stdout."""
    report = self.summary()
    print(report)
1877
+
1878
+
1879
+ # =============================================================================
1880
+ # Convenience function
1881
+ # =============================================================================
1882
+
1883
+
1884
def two_stage_did(
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    survey_design: object = None,
    **kwargs,
) -> TwoStageDiDResults:
    """
    Run two-stage DiD estimation in a single call.

    Builds a ``TwoStageDiD`` estimator from ``**kwargs`` and returns the
    result of its ``fit()`` method.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    first_treat : str
        Column indicating first treatment period (0 for never-treated).
    covariates : list of str, optional
        Covariate column names.
    aggregate : str, optional
        Aggregation mode: None, "simple", "event_study", "group", "all".
    balance_e : int, optional
        Balance event study to cohorts observed at all relative times.
    survey_design : SurveyDesign, optional
        Survey design specification for design-based inference. Supports
        pweight only (aweight/fweight raise ValueError). Supports strata,
        PSU, and FPC for design-based GMM sandwich variance. Strata enters
        survey df for t-distribution inference.
        Both analytical (n_bootstrap=0) and bootstrap inference are supported.
    **kwargs
        Additional keyword arguments forwarded to the TwoStageDiD constructor.

    Returns
    -------
    TwoStageDiDResults
        Estimation results.

    Examples
    --------
    >>> from diff_diff import two_stage_did, generate_staggered_data
    >>> data = generate_staggered_data(seed=42)
    >>> results = two_stage_did(data, 'outcome', 'unit', 'period',
    ...                         'first_treat', aggregate='event_study')
    >>> results.print_summary()
    """
    # Constructor options come only from **kwargs; everything named in the
    # signature is a fit-time argument.
    estimator = TwoStageDiD(**kwargs)
    fit_options = {
        "outcome": outcome,
        "unit": unit,
        "time": time,
        "first_treat": first_treat,
        "covariates": covariates,
        "aggregate": aggregate,
        "balance_e": balance_e,
        "survey_design": survey_design,
    }
    return estimator.fit(data, **fit_options)