diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1545 @@
1
+ """
2
+ Staggered Triple Difference (DDD) estimator.
3
+
4
+ Implements Ortiz-Villavicencio & Sant'Anna (2025) for staggered adoption
5
+ settings with an eligibility dimension, combining group-time DDD effects
6
+ via GMM-optimal weighting.
7
+
8
+ Core pairwise DiD computation matches R's triplediff::compute_did() exactly
9
+ (Riesz/Hajek normalization, separate M1/M3 OR corrections, hessian = (X'WX)^{-1}*n).
10
+ """
11
+
12
+ import warnings
13
+ from typing import Any, Dict, List, Optional, Tuple
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from diff_diff.linalg import (
19
+ _check_propensity_diagnostics,
20
+ solve_logit,
21
+ )
22
+ from diff_diff.staggered_aggregation import (
23
+ CallawaySantAnnaAggregationMixin,
24
+ )
25
+ from diff_diff.staggered_bootstrap import (
26
+ CallawaySantAnnaBootstrapMixin,
27
+ )
28
+ from diff_diff.staggered_triple_diff_results import StaggeredTripleDiffResults
29
+ from diff_diff.utils import safe_inference
30
+
31
+ __all__ = [
32
+ "StaggeredTripleDifference",
33
+ "StaggeredTripleDiffResults",
34
+ ]
35
+
36
+ # Type alias for pre-computed structures
37
+ PrecomputedData = Dict[str, Any]
38
+
39
+
40
+ class StaggeredTripleDifference(
41
+ CallawaySantAnnaBootstrapMixin,
42
+ CallawaySantAnnaAggregationMixin,
43
+ ):
44
+ """
45
+ Staggered Triple Difference (DDD) estimator.
46
+
47
+ Computes group-time average treatment effects ATT(g,t) for settings
48
+ with staggered adoption and a binary eligibility dimension, using the
49
+ three-DiD decomposition of Ortiz-Villavicencio & Sant'Anna (2025).
50
+
51
+ Multiple comparison groups are combined via GMM-optimal (inverse-variance)
52
+ weighting. Event study, group, and overall aggregations are supported.
53
+
54
+ Parameters
55
+ ----------
56
+ estimation_method : str, default="dr"
57
+ Estimation method: "dr" (doubly robust), "ipw" (inverse probability
58
+ weighting), or "reg" (regression adjustment).
59
+ alpha : float, default=0.05
60
+ Significance level.
61
+ anticipation : int, default=0
62
+ Number of anticipation periods.
63
+ base_period : str, default="varying"
64
+ Base period selection: "varying" (consecutive comparisons) or
65
+ "universal" (always vs g-1-anticipation).
66
+ n_bootstrap : int, default=0
67
+ Number of multiplier bootstrap repetitions. 0 disables bootstrap.
68
+ bootstrap_weights : str, default="rademacher"
69
+ Bootstrap weight distribution: "rademacher", "mammen", or "webb".
70
+ seed : int or None, default=None
71
+ Random seed for reproducibility.
72
+ cband : bool, default=True
73
+ Whether to compute simultaneous confidence bands.
74
+ pscore_trim : float, default=0.01
75
+ Propensity score trimming bound.
76
+ cluster : str or None, default=None
77
+ Column name for cluster-robust standard errors.
78
+ rank_deficient_action : str, default="warn"
79
+ Action for rank-deficient design matrices: "warn", "error", "silent".
80
+ epv_threshold : float, default=10
81
+ Minimum events per variable for propensity score logistic regression.
82
+ A warning is emitted when EPV falls below this threshold.
83
+ pscore_fallback : str, default="error"
84
+ Action when propensity score estimation fails: "error" (raise) or
85
+ "unconditional" (fall back to unconditional propensity).
86
+
87
+ References
88
+ ----------
89
+ Ortiz-Villavicencio, M. & Sant'Anna, P.H.C. (2025). "Better Understanding
90
+ Triple Differences Estimators." arXiv:2505.09942.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ estimation_method: str = "dr",
96
+ control_group: str = "notyettreated",
97
+ alpha: float = 0.05,
98
+ anticipation: int = 0,
99
+ base_period: str = "varying",
100
+ n_bootstrap: int = 0,
101
+ bootstrap_weights: str = "rademacher",
102
+ seed: Optional[int] = None,
103
+ cband: bool = True,
104
+ pscore_trim: float = 0.01,
105
+ cluster: Optional[str] = None,
106
+ rank_deficient_action: str = "warn",
107
+ epv_threshold: float = 10,
108
+ pscore_fallback: str = "error",
109
+ ):
110
+ if estimation_method not in ["dr", "ipw", "reg"]:
111
+ raise ValueError(
112
+ f"estimation_method must be 'dr', 'ipw', or 'reg', " f"got '{estimation_method}'"
113
+ )
114
+ if control_group not in ["nevertreated", "notyettreated"]:
115
+ raise ValueError(
116
+ f"control_group must be 'nevertreated' or 'notyettreated', "
117
+ f"got '{control_group}'"
118
+ )
119
+ if not (0 < pscore_trim < 0.5):
120
+ raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}")
121
+ if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
122
+ raise ValueError(
123
+ f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
124
+ f"got '{bootstrap_weights}'"
125
+ )
126
+ if rank_deficient_action not in ["warn", "error", "silent"]:
127
+ raise ValueError(
128
+ f"rank_deficient_action must be 'warn', 'error', or 'silent', "
129
+ f"got '{rank_deficient_action}'"
130
+ )
131
+ if base_period not in ["varying", "universal"]:
132
+ raise ValueError(
133
+ f"base_period must be 'varying' or 'universal', " f"got '{base_period}'"
134
+ )
135
+ if epv_threshold <= 0:
136
+ raise ValueError(f"epv_threshold must be > 0, got {epv_threshold}")
137
+ if pscore_fallback not in ["error", "unconditional"]:
138
+ raise ValueError(
139
+ f"pscore_fallback must be 'error' or 'unconditional', "
140
+ f"got '{pscore_fallback}'"
141
+ )
142
+
143
+ self.estimation_method = estimation_method
144
+ self.control_group = control_group
145
+ self.alpha = alpha
146
+ self.anticipation = anticipation
147
+ self.base_period = base_period
148
+ self.n_bootstrap = n_bootstrap
149
+ self.bootstrap_weights = bootstrap_weights
150
+ self.seed = seed
151
+ self.cband = cband
152
+ self.pscore_trim = pscore_trim
153
+ self.cluster = cluster
154
+ self.rank_deficient_action = rank_deficient_action
155
+ self.epv_threshold = epv_threshold
156
+ self.pscore_fallback = pscore_fallback
157
+
158
+ self.is_fitted_ = False
159
+ self.results_: Optional[StaggeredTripleDiffResults] = None
160
+
161
def get_params(self) -> Dict[str, Any]:
    """Return the constructor parameters and their current values (sklearn-compatible)."""
    # Order matches the constructor signature; the dict preserves it.
    param_names = (
        "estimation_method",
        "control_group",
        "alpha",
        "anticipation",
        "base_period",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "cband",
        "pscore_trim",
        "cluster",
        "rank_deficient_action",
        "epv_threshold",
        "pscore_fallback",
    )
    return {name: getattr(self, name) for name in param_names}
179
+
180
def set_params(self, **params) -> "StaggeredTripleDifference":
    """
    Set estimator parameters (sklearn-compatible).

    Parameters
    ----------
    **params
        Parameter names (as returned by ``get_params``) mapped to new values.
        Values are stored as given; they are not re-validated here.

    Returns
    -------
    StaggeredTripleDifference
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If any key is not a known parameter. No attribute is modified in
        that case.
    """
    valid_params = self.get_params()
    # Check every key before mutating anything, so a bad key cannot leave
    # the estimator half-updated.
    for key in params:
        if key not in valid_params:
            raise ValueError(f"Unknown parameter: {key}")
    for key, value in params.items():
        setattr(self, key, value)
    return self
190
+
191
+ # ------------------------------------------------------------------
192
+ # fit()
193
+ # ------------------------------------------------------------------
194
+
195
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    eligibility: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    survey_design: object = None,
) -> StaggeredTripleDiffResults:
    """
    Fit the staggered triple difference estimator.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    first_treat : str
        Column with the enabling period for each unit's group.
        Use 0 or np.inf for never-enabled units.
    eligibility : str
        Binary eligibility indicator column (0/1, time-invariant).
    covariates : list of str, optional
        Covariate column names.
    aggregate : str, optional
        Aggregation method: "event_study", "group", "simple", or "all".
    balance_e : int, optional
        Event time to balance on for event study.
    survey_design : SurveyDesign, optional
        Survey design specification for complex survey data. When
        provided, uses survey weights for estimation (weighted Riesz
        representers, weighted logit, weighted OLS) and design-based
        variance for aggregated SEs (overall, event study, group) via
        Taylor Series Linearization or replicate weights. Requires
        ``weight_type='pweight'``.

    Returns
    -------
    StaggeredTripleDiffResults
        Also stored on ``self.results_``; ``self.is_fitted_`` is set to
        True on success.

    Raises
    ------
    ValueError
        If the survey design does not use ``weight_type='pweight'``, if
        ``aggregate`` is not a recognised value, if input validation
        fails, or if no group-time effect could be computed.
    NotImplementedError
        If ``n_bootstrap > 0`` is combined with a replicate-weight survey
        design.
    """
    from diff_diff.survey import (
        _resolve_survey_for_fit,
        _validate_unit_constant_survey,
        compute_survey_metadata,
    )

    resolved_survey, _, _, survey_metadata = _resolve_survey_for_fit(
        survey_design, data, "analytical"
    )

    if resolved_survey is not None:
        _validate_unit_constant_survey(data, unit, survey_design)
        if resolved_survey.weight_type != "pweight":
            raise ValueError(
                f"StaggeredTripleDifference survey support requires "
                f"weight_type='pweight', got '{resolved_survey.weight_type}'. "
                f"The survey variance math assumes probability weights."
            )
    if aggregate is not None and aggregate not in [
        "event_study",
        "group",
        "simple",
        "all",
    ]:
        raise ValueError(
            f"aggregate must be 'event_study', 'group', 'simple', or 'all', "
            f"got '{aggregate}'"
        )

    # Whether the survey design carries replicate weights; used both for
    # the bootstrap compatibility check and for the df fallback below.
    uses_replicate = bool(
        resolved_survey is not None
        and getattr(resolved_survey, "uses_replicate_variance", False)
    )

    df = data.copy()
    self._validate_inputs(df, outcome, unit, time, first_treat, eligibility, covariates)

    # Reject replicate-weight designs for bootstrap — replicate variance
    # is an analytical alternative, not compatible with bootstrap.
    # Checked up front so the user does not wait for the full estimation
    # before learning the configuration is unsupported.
    if self.n_bootstrap > 0 and uses_replicate:
        raise NotImplementedError(
            "StaggeredTripleDifference bootstrap (n_bootstrap > 0) is not "
            "supported with replicate-weight survey designs. Replicate "
            "weights provide analytical variance; use n_bootstrap=0 instead."
        )

    if self.cluster is not None:
        warnings.warn(
            "cluster parameter is accepted but cluster-robust analytical SEs "
            "are not yet implemented for staggered DDD. Use n_bootstrap > 0 "
            "for unit-level clustered inference via multiplier bootstrap.",
            UserWarning,
            stacklevel=2,
        )

    # Normalise the cohort column: work on "first_treat" and encode
    # never-enabled units (np.inf) as 0.
    if first_treat != "first_treat":
        df["first_treat"] = df[first_treat]
    df["first_treat"] = df["first_treat"].replace([np.inf], 0)

    precomputed = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        eligibility,
        covariates,
        resolved_survey=resolved_survey,
    )

    # Recompute survey metadata from unit-level resolved survey
    if resolved_survey is not None and survey_metadata is not None:
        resolved_survey_unit = precomputed.get("resolved_survey_unit")
        if resolved_survey_unit is not None:
            unit_w = resolved_survey_unit.weights
            survey_metadata = compute_survey_metadata(resolved_survey_unit, unit_w)

    # Survey df for t-distribution critical values
    df_survey = precomputed.get("df_survey")
    if df_survey is None and uses_replicate:
        df_survey = 0  # Forces NaN inference for undefined replicate df

    has_survey = resolved_survey is not None

    treatment_groups = precomputed["treatment_groups"]
    time_periods = precomputed["time_periods"]
    all_units = precomputed["all_units"]
    time_to_col = precomputed["time_to_col"]
    unit_cohorts = precomputed["unit_cohorts"]
    eligibility_per_unit = precomputed["eligibility_per_unit"]
    n_units = len(all_units)

    pscore_cache: Dict = {}
    # Skip Cholesky OR cache when survey weights present (X'WX != X'X)
    cho_cache: Dict = {} if not has_survey else None

    group_time_effects: Dict[Tuple, Dict[str, Any]] = {}
    influence_func_info: Dict[Tuple, Dict[str, Any]] = {}
    comparison_group_counts: Dict[Tuple, int] = {}
    gmm_weights_store: Dict[Tuple, Dict] = {}
    epv_diagnostics: Optional[Dict[Tuple, Dict[str, Any]]] = (
        {} if (covariates and self.estimation_method in ("ipw", "dr")) else None
    )

    # Loop-invariant: does the panel contain a never-enabled cohort (coded 0)?
    has_never_enabled = bool(np.any(unit_cohorts == 0))
    survey_w = precomputed.get("survey_weights")

    for g in treatment_groups:
        # In universal mode, skip the reference period (t == g-1-anticipation)
        # so it's omitted from GT estimation. The event-study mixin injects
        # a synthetic reference row with effect=0, matching CS behavior.
        if self.base_period == "universal":
            universal_base = g - 1 - self.anticipation
            valid_periods = [t for t in time_periods if t != universal_base]
        else:
            valid_periods = time_periods

        # These masks depend on g only, so build them once per cohort.
        cohort_g_mask = unit_cohorts == g
        treated_mask = cohort_g_mask & (eligibility_per_unit == 1)
        n_treated = int(np.sum(treated_mask))

        for t in valid_periods:
            base_period_val = self._get_base_period(g, t)
            if base_period_val is None:
                continue
            if base_period_val not in time_to_col:
                warnings.warn(
                    f"Base period {base_period_val} for (g={g}, t={t}) is "
                    "outside the observed panel. Skipping this cell.",
                    UserWarning,
                    stacklevel=2,
                )
                continue
            if t not in time_to_col:
                continue

            if self.control_group == "nevertreated":
                # Only use never-enabled cohort as comparison
                valid_gc = [0] if has_never_enabled else []
            else:
                # Use all valid comparison cohorts (not-yet-treated + never)
                # Threshold accounts for anticipation: cohorts that start
                # treatment within the anticipation window are contaminated.
                nyt_threshold = max(t, base_period_val) + self.anticipation
                valid_gc = [gc for gc in treatment_groups if gc > nyt_threshold and gc != g]
                if has_never_enabled:
                    valid_gc = [0] + valid_gc

            if not valid_gc:
                warnings.warn(
                    f"No valid comparison groups for (g={g}, t={t}), skipping.",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            if n_treated == 0:
                continue

            att_vec = []
            inf_raw = []  # unrescaled IFs
            gc_labels = []
            gc_cell_sizes = []  # size_gt_ctrl per surviving gc

            for gc in valid_gc:
                result = self._compute_ddd_gt_gc(
                    precomputed,
                    g,
                    gc,
                    t,
                    base_period_val,
                    covariates,
                    pscore_cache,
                    cho_cache,
                    epv_diagnostics=epv_diagnostics,
                )
                if result is None:
                    continue
                att_gc, inf_gc, size_gt_ctrl = result
                if not np.isfinite(att_gc):
                    continue

                att_vec.append(att_gc)
                inf_raw.append(inf_gc)
                gc_labels.append(gc)
                gc_cell_sizes.append(size_gt_ctrl)

            if not att_vec:
                continue

            # Compute size_gt from SURVIVING comparison cohorts only
            # (not from all initially valid gc's). The cell is all of
            # cohort g plus every surviving comparison cohort; the
            # eligible-treated units are already contained in cohort g.
            surviving_units = cohort_g_mask.copy()
            for gc in gc_labels:
                surviving_units |= unit_cohorts == gc
            if survey_w is not None:
                size_gt = float(np.sum(survey_w[surviving_units]))
            else:
                size_gt = float(np.sum(surviving_units))

            # Apply IF rescaling now that size_gt is known
            inf_matrix = []
            for inf_gc, size_gt_ctrl in zip(inf_raw, gc_cell_sizes):
                if size_gt_ctrl > 0:
                    inf_gc = inf_gc * (size_gt / size_gt_ctrl)
                inf_matrix.append(inf_gc)

            att_gmm, inf_gmm, gmm_w, se_gt = self._combine_gmm(
                np.array(att_vec),
                np.array(inf_matrix),
                n_units,
            )

            if not np.isfinite(att_gmm):
                continue

            # R's single-gc SE uses size_gt in denominator, not n_total.
            # For multi-gc (GMM), the size_gt factor is already in Omega
            # via the per-gc rescaling, so n_total is correct.
            if len(gc_labels) == 1:
                se_gt = float(np.sqrt(np.sum(inf_gmm**2) / size_gt**2))

            if not np.isfinite(se_gt) or se_gt <= 0:
                se_gt = np.nan

            t_stat, p_value, conf_int = safe_inference(
                att_gmm, se_gt, alpha=self.alpha, df=df_survey
            )

            # Rescale IF for mixin compatibility.
            # R stores IF * (n/size_gt) in inf_func_mat, then uses
            # SE = sqrt(sum(IF^2)/n^2) = sqrt(sum(psi^2)) with psi = IF/n.
            # We need psi = IF_rescaled / n so mixin's sqrt(sum(psi^2)) works.
            # IF is already at size_gt/size_gt_ctrl scale from above.
            # Apply the final n/size_gt factor, then divide by n for mixin.
            inf_gmm_rescaled = inf_gmm * (n_units / size_gt)
            inf_gmm_scaled = inf_gmm_rescaled / n_units

            treated_idx = np.where(treated_mask)[0]
            treated_inf = inf_gmm_scaled[treated_idx]
            nonzero_mask = (inf_gmm_scaled != 0) & ~treated_mask
            control_idx = np.where(nonzero_mask)[0]
            control_inf = inf_gmm_scaled[control_idx]
            n_control = int(np.sum(nonzero_mask))

            group_time_effects[(g, t)] = {
                "effect": att_gmm,
                "se": se_gt,
                "t_stat": t_stat,
                "p_value": p_value,
                "conf_int": conf_int,
                "n_treated": n_treated,
                "n_control": n_control,
            }
            influence_func_info[(g, t)] = {
                "treated_idx": treated_idx,
                "control_idx": control_idx,
                "treated_units": all_units[treated_idx],
                "control_units": all_units[control_idx],
                "treated_inf": treated_inf,
                "control_inf": control_inf,
            }
            comparison_group_counts[(g, t)] = len(gc_labels)
            gmm_weights_store[(g, t)] = dict(zip(gc_labels, gmm_w.tolist()))

    # Consolidated EPV summary warning
    if epv_diagnostics:
        low_epv = {k: v for k, v in epv_diagnostics.items() if v.get("is_low")}
        if low_epv:
            n_affected = len(low_epv)
            n_total = len(epv_diagnostics)
            # Single pass: the cell with the smallest EPV and its key.
            min_g, min_entry = min(low_epv.items(), key=lambda kv: kv[1]["epv"])
            warnings.warn(
                f"Low Events Per Variable (EPV) detected in "
                f"{n_affected} of {n_total} cohort-time cell(s). "
                f"Minimum EPV: {min_entry['epv']:.1f} (cohort g={min_g[0]}). "
                f"Consider estimation_method='reg' or fewer covariates. "
                f"Call results.epv_summary() for per-cohort details.",
                UserWarning,
                stacklevel=2,
            )

    if not group_time_effects:
        raise ValueError(
            "No valid group-time effects could be computed. "
            "Check that the data has sufficient variation in treatment "
            "timing and eligibility."
        )

    # For aggregation: use eligible-treated-only cohort assignments so
    # WIF weights match the point estimate weights (n_treated per cohort,
    # i.e. P(S=g, Q=1)). This matches the paper's Eq 4.13 which defines
    # aggregation weights over the treated population (G_i defined only
    # for Q=1 units). Ineligible units get cohort=0 so they don't
    # contribute to pg for any treatment group.
    # Both precomputed["unit_cohorts"] AND df["first_treat"] must be
    # zeroed for ineligible units because the WIF code reads both.
    precomputed_agg = dict(precomputed)
    cohorts_for_agg = precomputed["unit_cohorts"].copy()
    cohorts_for_agg[eligibility_per_unit == 0] = 0
    precomputed_agg["unit_cohorts"] = cohorts_for_agg

    df_agg = df.copy()
    df_agg.loc[df_agg[eligibility] == 0, "first_treat"] = 0

    # Overall ATT via aggregation mixin
    overall_att, overall_se, overall_effective_df = self._aggregate_simple(
        group_time_effects, influence_func_info, df_agg, unit, precomputed_agg
    )
    # Use per-statistic effective df from replicate aggregation if available;
    # otherwise fall back to the original df from the survey design.
    if overall_effective_df is not None:
        df_survey = overall_effective_df
        if survey_metadata is not None:
            survey_metadata.df_survey = df_survey
    overall_t_stat, overall_p_value, overall_conf_int = safe_inference(
        overall_att, overall_se, alpha=self.alpha, df=df_survey
    )

    # Aggregations
    event_study_effects = None
    group_effects = None
    if aggregate in ("event_study", "all"):
        event_study_effects = self._aggregate_event_study(
            group_time_effects,
            influence_func_info,
            treatment_groups,
            time_periods,
            balance_e,
            df_agg,
            unit,
            precomputed_agg,
        )
    if aggregate in ("group", "all"):
        group_effects = self._aggregate_by_group(
            group_time_effects,
            influence_func_info,
            treatment_groups,
            precomputed_agg,
            df_agg,
            unit,
        )

    # Bootstrap
    bootstrap_results = None
    cband_crit_value = None
    if self.n_bootstrap > 0:
        bootstrap_results = self._run_multiplier_bootstrap(
            group_time_effects,
            influence_func_info,
            aggregate,
            balance_e,
            treatment_groups,
            time_periods,
            df_agg,
            unit,
            precomputed_agg,
            self.cband,
        )
        if bootstrap_results is not None:
            overall_se = bootstrap_results.overall_att_se
            overall_t_stat, overall_p_value, overall_conf_int = safe_inference(
                overall_att, overall_se, alpha=self.alpha, df=df_survey
            )
            # The bootstrap CI and p-value replace the analytical ones;
            # only the t-statistic from safe_inference is kept.
            overall_conf_int = bootstrap_results.overall_att_ci
            overall_p_value = bootstrap_results.overall_att_p_value
            if bootstrap_results.cband_crit_value is not None:
                cband_crit_value = bootstrap_results.cband_crit_value

            # Update group-time effects with bootstrap SEs
            if bootstrap_results.group_time_ses:
                for gt_key in group_time_effects:
                    if gt_key in bootstrap_results.group_time_ses:
                        group_time_effects[gt_key]["se"] = bootstrap_results.group_time_ses[
                            gt_key
                        ]
                        group_time_effects[gt_key]["conf_int"] = (
                            bootstrap_results.group_time_cis[gt_key]
                        )
                        group_time_effects[gt_key]["p_value"] = (
                            bootstrap_results.group_time_p_values[gt_key]
                        )
                        t_val, _, _ = safe_inference(
                            group_time_effects[gt_key]["effect"],
                            bootstrap_results.group_time_ses[gt_key],
                            alpha=self.alpha,
                            df=df_survey,
                        )
                        group_time_effects[gt_key]["t_stat"] = t_val

            # Update event-study effects with bootstrap SEs
            if event_study_effects and bootstrap_results.event_study_ses:
                for e_key in event_study_effects:
                    if e_key in bootstrap_results.event_study_ses:
                        event_study_effects[e_key]["se"] = bootstrap_results.event_study_ses[
                            e_key
                        ]
                        event_study_effects[e_key]["conf_int"] = (
                            bootstrap_results.event_study_cis[e_key]
                        )
                        event_study_effects[e_key]["p_value"] = (
                            bootstrap_results.event_study_p_values[e_key]
                        )
                        t_val, _, _ = safe_inference(
                            event_study_effects[e_key]["effect"],
                            bootstrap_results.event_study_ses[e_key],
                            alpha=self.alpha,
                            df=df_survey,
                        )
                        event_study_effects[e_key]["t_stat"] = t_val
                        if cband_crit_value is not None:
                            bs_se = bootstrap_results.event_study_ses[e_key]
                            eff = event_study_effects[e_key]["effect"]
                            event_study_effects[e_key]["cband_conf_int"] = (
                                eff - cband_crit_value * bs_se,
                                eff + cband_crit_value * bs_se,
                            )

            # Update group effects with bootstrap SEs
            if (
                group_effects
                and bootstrap_results.group_effect_ses is not None
                and bootstrap_results.group_effect_cis is not None
                and bootstrap_results.group_effect_p_values is not None
            ):
                grp_keys = [g for g in group_effects if g in bootstrap_results.group_effect_ses]
                for g_key in grp_keys:
                    group_effects[g_key]["se"] = bootstrap_results.group_effect_ses[g_key]
                    group_effects[g_key]["conf_int"] = bootstrap_results.group_effect_cis[g_key]
                    group_effects[g_key]["p_value"] = bootstrap_results.group_effect_p_values[
                        g_key
                    ]
                    t_val, _, _ = safe_inference(
                        group_effects[g_key]["effect"],
                        bootstrap_results.group_effect_ses[g_key],
                        alpha=self.alpha,
                        df=df_survey,
                    )
                    group_effects[g_key]["t_stat"] = t_val

    # Sample composition reported on the results object.
    n_treated_units = int(np.sum((unit_cohorts > 0) & (eligibility_per_unit == 1)))
    n_control_units = n_units - n_treated_units
    n_never_enabled = int(np.sum(unit_cohorts == 0))
    n_eligible = int(np.sum(eligibility_per_unit == 1))
    n_ineligible = int(np.sum(eligibility_per_unit == 0))

    self.results_ = StaggeredTripleDiffResults(
        group_time_effects=group_time_effects,
        overall_att=overall_att,
        overall_se=overall_se,
        overall_t_stat=overall_t_stat,
        overall_p_value=overall_p_value,
        overall_conf_int=overall_conf_int,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=n_treated_units,
        n_control_units=n_control_units,
        n_never_enabled=n_never_enabled,
        n_eligible=n_eligible,
        n_ineligible=n_ineligible,
        alpha=self.alpha,
        control_group=self.control_group,
        base_period=self.base_period,
        estimation_method=self.estimation_method,
        event_study_effects=event_study_effects,
        group_effects=group_effects,
        bootstrap_results=bootstrap_results,
        cband_crit_value=cband_crit_value,
        pscore_trim=self.pscore_trim,
        survey_metadata=survey_metadata,
        comparison_group_counts=comparison_group_counts,
        gmm_weights=gmm_weights_store,
        epv_diagnostics=epv_diagnostics if epv_diagnostics else None,
        epv_threshold=self.epv_threshold,
        pscore_fallback=self.pscore_fallback,
    )
    self.is_fitted_ = True
    return self.results_
725
+
726
+ # ------------------------------------------------------------------
727
+ # Validation
728
+ # ------------------------------------------------------------------
729
+
730
+ def _validate_inputs(
731
+ self,
732
+ df: pd.DataFrame,
733
+ outcome: str,
734
+ unit: str,
735
+ time: str,
736
+ first_treat: str,
737
+ eligibility: str,
738
+ covariates: Optional[List[str]],
739
+ ) -> None:
740
+ """Validate input data."""
741
+ required_cols = [outcome, unit, time, first_treat, eligibility]
742
+ if covariates:
743
+ required_cols.extend(covariates)
744
+ missing = [c for c in required_cols if c not in df.columns]
745
+ if missing:
746
+ raise ValueError(f"Missing columns: {missing}")
747
+
748
+ elig_vals = df[eligibility].dropna().unique()
749
+ if not set(elig_vals).issubset({0, 1, 0.0, 1.0}):
750
+ raise ValueError(
751
+ f"Eligibility column '{eligibility}' must be binary (0/1). "
752
+ f"Found values: {sorted(elig_vals)}"
753
+ )
754
+ elig_by_unit = df.groupby(unit)[eligibility].nunique()
755
+ varying = elig_by_unit[elig_by_unit > 1]
756
+ if len(varying) > 0:
757
+ raise ValueError(
758
+ f"Eligibility must be time-invariant within units. "
759
+ f"Found {len(varying)} units with varying eligibility."
760
+ )
761
+ for col in [outcome, first_treat, eligibility]:
762
+ if df[col].isna().any():
763
+ raise ValueError(f"Column '{col}' contains missing values.")
764
+
765
+ # Reject non-finite outcomes (Inf/-Inf)
766
+ if not np.all(np.isfinite(df[outcome])):
767
+ raise ValueError(
768
+ f"Column '{outcome}' contains non-finite values (Inf/-Inf). "
769
+ "All outcome values must be finite."
770
+ )
771
+
772
+ # Reject non-finite covariates
773
+ if covariates:
774
+ for cov in covariates:
775
+ if df[cov].isna().any():
776
+ raise ValueError(f"Covariate '{cov}' contains missing values.")
777
+ if not np.all(np.isfinite(df[cov])):
778
+ raise ValueError(f"Covariate '{cov}' contains non-finite values.")
779
+ if df[eligibility].nunique() < 2:
780
+ raise ValueError(
781
+ "Need both eligible (Q=1) and ineligible (Q=0) units. "
782
+ f"Only found Q={df[eligibility].unique()[0]}."
783
+ )
784
+
785
+ # Check unique (unit, time) pairs — no duplicate rows
786
+ dup = df.duplicated(subset=[unit, time], keep=False)
787
+ if dup.any():
788
+ raise ValueError(
789
+ f"Duplicate (unit, time) rows found. "
790
+ f"{int(dup.sum())} duplicates detected. Panel must have unique rows."
791
+ )
792
+
793
+ # Check balanced panel — every unit observed in exactly the global period set
794
+ global_periods = set(df[time].unique())
795
+ n_global_periods = len(global_periods)
796
+ unit_period_sets = df.groupby(unit)[time].apply(set)
797
+ mismatched = unit_period_sets[unit_period_sets != global_periods]
798
+ if len(mismatched) > 0:
799
+ raise ValueError(
800
+ "Unbalanced panel detected. All units must be observed in "
801
+ f"all {n_global_periods} periods. "
802
+ f"Found {len(mismatched)} units with different period sets."
803
+ )
804
+
805
+ # Check time-invariant first_treat
806
+ ft_by_unit = df.groupby(unit)[first_treat].nunique()
807
+ varying_ft = ft_by_unit[ft_by_unit > 1]
808
+ if len(varying_ft) > 0:
809
+ raise ValueError(
810
+ f"first_treat must be time-invariant within units. "
811
+ f"Found {len(varying_ft)} units with varying first_treat."
812
+ )
813
+
814
+ # Check time-invariant covariates
815
+ if covariates:
816
+ for cov in covariates:
817
+ cov_nunique = df.groupby(unit)[cov].nunique()
818
+ varying_cov = cov_nunique[cov_nunique > 1]
819
+ if len(varying_cov) > 0:
820
+ raise ValueError(
821
+ f"Covariate '{cov}' must be time-invariant within units. "
822
+ f"Found {len(varying_cov)} units with varying values."
823
+ )
824
+
825
+ # ------------------------------------------------------------------
826
+ # Precomputation
827
+ # ------------------------------------------------------------------
828
+
829
def _precompute_structures(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    eligibility: str,
    covariates: Optional[List[str]],
    resolved_survey=None,
) -> PrecomputedData:
    """Build the unit-level arrays used by the per-cell estimators.

    Parameters
    ----------
    df : long-format panel, one row per (unit, time).
    outcome, unit, time, eligibility : column names in ``df``.
    covariates : optional covariate column names; the value on each
        unit's first row is used.
    resolved_survey : optional resolved survey design. When given, its
        per-row weights are reduced to one weight per unit and the design
        is collapsed via ``collapse_survey_to_unit_level``.

    Returns
    -------
    Mapping with the sorted unit ids and periods, their index lookups, the
    (n_units, n_periods) outcome matrix, per-unit cohort and eligibility
    arrays, the sorted positive cohorts, the optional covariate matrix and
    the optional unit-level survey information.

    Notes
    -----
    The cohort is read from a column literally named ``"first_treat"``.
    """
    all_units = np.array(sorted(df[unit].unique()))
    time_periods = sorted(df[time].unique())
    n_units = len(all_units)
    n_periods = len(time_periods)

    unit_to_idx = {u: i for i, u in enumerate(all_units)}
    time_to_col = {t: j for j, t in enumerate(time_periods)}

    # Vectorized scatter of the long panel into a wide matrix; cells that
    # have no row stay NaN.
    outcome_matrix = np.full((n_units, n_periods), np.nan)
    row_pos = pd.Index(all_units).get_indexer(df[unit])
    col_pos = pd.Index(time_periods).get_indexer(df[time])
    outcome_matrix[row_pos, col_pos] = df[outcome].to_numpy(dtype=float)

    unit_df = df.groupby(unit).first().reindex(all_units)
    unit_cohorts = unit_df["first_treat"].values.astype(float)
    eligibility_per_unit = unit_df[eligibility].values.astype(int)

    treatment_groups = sorted([g for g in np.unique(unit_cohorts) if g > 0])

    covariate_matrix = None
    if covariates:
        # One pass over the frame: keep each unit's first row, then align
        # to the sorted unit order. Repeated covariate names collapse to a
        # single column, in first-occurrence order.
        cov_names = list(dict.fromkeys(covariates))
        first_rows = df.drop_duplicates(subset=[unit], keep="first").set_index(
            unit, drop=False
        )
        covariate_matrix = (
            first_rows[cov_names].reindex(all_units).to_numpy(dtype=float)
        )

    # Extract per-unit survey weights and collapse design to unit level
    survey_weights_arr = None
    resolved_survey_unit = None
    if resolved_survey is not None:
        from diff_diff.survey import collapse_survey_to_unit_level

        survey_weights_arr = (
            pd.Series(resolved_survey.weights, index=df.index)
            .groupby(df[unit])
            .first()
            .reindex(all_units)
            .values.astype(np.float64)
        )
        # Normalize to sum=n for aggregation/rescaling (matches pweight
        # convention). Raw weights preserved in resolved_survey_unit for
        # replicate w_r/w_full ratios, which are inherently scale-invariant.
        sw_sum = np.sum(survey_weights_arr)
        if sw_sum > 0:
            survey_weights_arr = survey_weights_arr * (len(survey_weights_arr) / sw_sum)
        resolved_survey_unit = collapse_survey_to_unit_level(
            resolved_survey, df, unit, all_units
        )

    return {
        "all_units": all_units,
        "unit_to_idx": unit_to_idx,
        "time_periods": time_periods,
        "time_to_col": time_to_col,
        "outcome_matrix": outcome_matrix,
        "unit_cohorts": unit_cohorts,
        "eligibility_per_unit": eligibility_per_unit,
        "treatment_groups": treatment_groups,
        "covariate_matrix": covariate_matrix,
        "n_units": n_units,
        "n_periods": n_periods,
        "survey_weights": survey_weights_arr,
        "resolved_survey_unit": resolved_survey_unit,
        "df_survey": (
            resolved_survey_unit.df_survey if resolved_survey_unit is not None else None
        ),
    }
913
+
914
+ # ------------------------------------------------------------------
915
+ # Base period
916
+ # ------------------------------------------------------------------
917
+
918
def _get_base_period(self, g: Any, t: Any) -> Optional[Any]:
    """Return the reference period for cohort ``g`` evaluated at time ``t``.

    With ``base_period == "universal"`` the reference is always the last
    period before treatment, shifted back by ``anticipation``. Otherwise
    periods earlier than ``g - anticipation`` use the immediately preceding
    period ``t - 1``, and all later periods use the universal reference.
    """
    varying = self.base_period != "universal"
    if varying and t < g - self.anticipation:
        return t - 1
    return g - 1 - self.anticipation
927
+
928
+ # ------------------------------------------------------------------
929
+ # Three-DiD DDD for one (g, g_c, t) triple
930
+ # ------------------------------------------------------------------
931
+
932
def _compute_ddd_gt_gc(
    self,
    precomputed: PrecomputedData,
    g: Any,
    g_c: Any,
    t: Any,
    base_period_val: Any,
    covariates: Optional[List[str]],
    pscore_cache: Dict,
    cho_cache: Optional[Dict],
    epv_diagnostics: Optional[Dict] = None,
) -> Optional[Tuple[float, np.ndarray, float]]:
    """
    Triple-difference ATT for one (g, g_c, t) triple via three pairwise DiDs.

    The estimate is ``DiD_A + DiD_B - DiD_C``, where each DiD compares the
    treated-eligible cell (S=g, Q=1) with one of the other three cells.

    Parameters
    ----------
    precomputed : output of ``_precompute_structures``.
    g, g_c : treated cohort and comparison cohort.
    t, base_period_val : evaluation period and reference period.
    covariates : not read here; covariates come from ``precomputed``.
    pscore_cache, cho_cache : caches forwarded to ``_run_pairwise_did``.
    epv_diagnostics : optional dict; when given, the lowest-EPV diagnostic
        seen so far for ``(g, t)`` is stored under that key.

    Returns
    -------
    ``(att_ddd, inf_full, size_gt_ctrl)`` with ``inf_full`` of length
    ``n_units``, or ``None`` when a cell is empty (by count or by survey
    weight), an outcome change is non-finite, or a pairwise DiD fails.
    """
    outcomes = precomputed["outcome_matrix"]
    col_of = precomputed["time_to_col"]
    cohorts = precomputed["unit_cohorts"]
    elig = precomputed["eligibility_per_unit"]
    cov_mat = precomputed["covariate_matrix"]
    n_units = precomputed["n_units"]
    survey_weights = precomputed.get("survey_weights")

    t_col = col_of[t]
    b_col = col_of[base_period_val]

    # The four cells of this (g, g_c) comparison
    in_g = cohorts == g
    in_gc = cohorts == g_c
    is_elig = elig == 1
    is_inelig = elig == 0
    treated_mask = in_g & is_elig  # subgroup 4
    sub_a_mask = in_g & is_inelig  # subgroup 3
    sub_b_mask = in_gc & is_elig  # subgroup 2
    sub_c_mask = in_gc & is_inelig  # subgroup 1

    # (label stem, mask) in the order treated, A, B, C
    cells = (
        (f"(S={g},Q=1", treated_mask),
        (f"(S={g},Q=0", sub_a_mask),
        (f"(S={g_c},Q=1", sub_b_mask),
        (f"(S={g_c},Q=0", sub_c_mask),
    )
    n_treated, n_a, n_b, n_c = (int(np.sum(mask)) for _, mask in cells)

    empty = [
        f"{stem})"
        for (stem, _), count in zip(cells, (n_treated, n_a, n_b, n_c))
        if count == 0
    ]
    # Zero survey-weight mass after subpopulation filtering = effectively empty
    if not empty and survey_weights is not None:
        empty = [
            f"{stem},mass=0)"
            for stem, mask in cells
            if np.sum(survey_weights[mask]) <= 0
        ]
    if empty:
        warnings.warn(
            f"Empty subgroup(s) {', '.join(empty)} for "
            f"(g={g}, g_c={g_c}, t={t}). "
            "Comparison unidentified, skipping.",
            UserWarning,
            stacklevel=3,
        )
        return None

    if min(n_treated, n_a, n_b, n_c) < 5:
        warnings.warn(
            f"Small cell size for (g={g}, g_c={g_c}, t={t}). Estimates may be unreliable.",
            UserWarning,
            stacklevel=3,
        )

    # Outcome changes must be finite in every cell
    delta_y_all = outcomes[:, t_col] - outcomes[:, b_col]
    finite = np.isfinite(delta_y_all)
    if not all(np.all(finite[mask]) for _, mask in cells):
        return None

    # Three pairwise DiDs, run in the order A, B, C:
    #   A: subgroup 4 vs 3, B: subgroup 4 vs 2, C: subgroup 4 vs 1
    collect_epv = epv_diagnostics is not None
    specs = (
        ("a", sub_a_mask, (g, g, 0, base_period_val), g),
        ("b", sub_b_mask, (g, g_c, 1, base_period_val), g_c),
        ("c", sub_c_mask, (g, g_c, 0, base_period_val), g_c),
    )
    pair_masks = []
    did_results = []
    epv_parts = []
    for tag, comparison_mask, pscore_key, cho_cohort in specs:
        pair_mask = treated_mask | comparison_mask
        epv_out = {} if collect_epv else None
        did_results.append(
            self._run_pairwise_did(
                delta_y_all,
                pair_mask,
                treated_mask,
                comparison_mask,
                cov_mat,
                pscore_cache,
                pscore_key,
                cho_cache,
                (tag, g, cho_cohort, base_period_val),
                survey_weights=survey_weights,
                context_label=f"cohort g={g}, DiD_{tag.upper()} (g_c={g_c})",
                epv_diagnostics_out=epv_out,
            )
        )
        pair_masks.append(pair_mask)
        epv_parts.append(epv_out)

    # Keep the worst (lowest EPV) diagnostic across the three DiDs, and
    # across every g_c that contributes to the same (g, t) key.
    if collect_epv:
        candidates = [d for d in epv_parts if d]
        if candidates:
            worst = min(candidates, key=lambda d: d.get("epv", float("inf")))
            existing = epv_diagnostics.get((g, t))
            if existing is None or worst.get("epv", float("inf")) < existing.get(
                "epv", float("inf")
            ):
                epv_diagnostics[(g, t)] = worst

    if any(res is None for res in did_results):
        return None

    (att_a, inf_a), (att_b, inf_b), (att_c, inf_c) = did_results
    att_ddd = att_a + att_b - att_c

    # Three-DiD IF combination: w_j = n_cell / n_pair_j (R's att_dr convention).
    # Cell sizes are survey-weight sums when weights exist, head counts otherwise.
    if survey_weights is not None:
        size_4, size_3, size_2, size_1 = (
            float(np.sum(survey_weights[mask])) for _, mask in cells
        )
    else:
        size_4, size_3, size_2, size_1 = n_treated, n_a, n_b, n_c
    cell_size = size_4 + size_3 + size_2 + size_1
    pair_a_size = size_4 + size_3
    pair_b_size = size_4 + size_2
    pair_c_size = size_4 + size_1
    w_3 = cell_size / pair_a_size if pair_a_size > 0 else 1.0
    w_2 = cell_size / pair_b_size if pair_b_size > 0 else 1.0
    w_1 = cell_size / pair_c_size if pair_c_size > 0 else 1.0
    size_gt_ctrl = float(cell_size)

    # Scatter pair-level IFs into an n_units-length vector
    inf_full = np.zeros(n_units)
    inf_full[np.where(pair_masks[0])[0]] += w_3 * inf_a
    inf_full[np.where(pair_masks[1])[0]] += w_2 * inf_b
    inf_full[np.where(pair_masks[2])[0]] -= w_1 * inf_c

    return att_ddd, inf_full, size_gt_ctrl
1130
+
1131
+ # ------------------------------------------------------------------
1132
+ # Pairwise DiD (matches R's compute_did)
1133
+ # ------------------------------------------------------------------
1134
+
1135
def _run_pairwise_did(
    self,
    delta_y_all: np.ndarray,
    pair_mask: np.ndarray,
    treated_mask: np.ndarray,
    control_mask: np.ndarray,
    covariate_matrix: Optional[np.ndarray],
    pscore_cache: Dict,
    pscore_key: Any,
    cho_cache: Optional[Dict],
    cho_key: Any,
    survey_weights: Optional[np.ndarray] = None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
) -> Optional[Tuple[float, np.ndarray]]:
    """
    Run one pairwise DiD on the units selected by ``pair_mask``.

    Restricts the inputs to the pair, fits whichever nuisance models the
    estimation method needs (propensity score for "ipw"/"dr", outcome
    regression for "reg"/"dr", neither without covariates) and delegates
    the ATT and influence function to ``_compute_did_panel``.

    Returns
    -------
    ``(att, inf_func)`` with ``inf_func`` of length ``pair_mask.sum()``,
    ordered by ascending unit index, or ``None`` when the pair is empty
    or lacks treated or control units.
    """
    members = np.where(pair_mask)[0]
    n_pair = len(members)
    if n_pair == 0:
        return None

    delta_y = delta_y_all[members]
    is_treated = treated_mask[members].astype(float)
    is_control = control_mask[members].astype(float)
    sw_pair = None if survey_weights is None else survey_weights[members]

    if int(np.sum(is_treated)) == 0 or int(np.sum(is_control)) == 0:
        return None

    # Design matrix with intercept, only when covariates are in play
    covX = None
    if covariate_matrix is not None and self.estimation_method != "none":
        covX = np.column_stack([np.ones(n_pair), covariate_matrix[members]])

    # Nuisance defaults: no propensity model, zero outcome-regression fit
    pscore = None
    hessian = None
    or_delta = np.zeros(n_pair)

    if covX is not None:
        if self.estimation_method in ("ipw", "dr"):
            pscore, hessian = self._compute_pscore(
                is_treated,
                covX,
                pscore_cache,
                pscore_key,
                survey_weights=sw_pair,
                context_label=context_label,
                epv_diagnostics_out=epv_diagnostics_out,
            )
        if self.estimation_method in ("reg", "dr"):
            # Callers pass cho_cache=None when survey weights are present
            or_delta = self._compute_or(
                delta_y,
                is_control,
                covX,
                cho_cache,
                cho_key,
                survey_weights=sw_pair,
            )

    return self._compute_did_panel(
        delta_y,
        is_treated,
        is_control,
        covX,
        pscore,
        hessian,
        or_delta,
        survey_weights=sw_pair,
    )
1220
+
1221
+ # ------------------------------------------------------------------
1222
+ # Core DR/IPW/RA computation (matches R's compute_did exactly)
1223
+ # ------------------------------------------------------------------
1224
+
1225
def _compute_did_panel(
    self,
    delta_y: np.ndarray,
    PA4: np.ndarray,
    PAa: np.ndarray,
    covX: Optional[np.ndarray],
    pscore: Optional[np.ndarray],
    hessian: Optional[np.ndarray],
    or_delta: np.ndarray,
    survey_weights: Optional[np.ndarray] = None,
) -> Tuple[float, np.ndarray]:
    """
    Pairwise DiD ATT and influence function.
    Follows the structure of R's triplediff::compute_did().

    Parameters
    ----------
    delta_y : outcome changes for 2-cell subset (n_pair,)
    PA4 : treated indicator (n_pair,)
    PAa : control indicator (n_pair,)
    covX : covariate matrix with intercept (n_pair, p) or None
    pscore : propensity scores (n_pair,) or None
    hessian : (X'WX)^{-1} * n_pair or None
    or_delta : OR predictions (n_pair,), zeros if no covariates
    survey_weights : per-observation survey weights (n_pair,) or None

    Returns
    -------
    (att, inf_func) where inf_func has length n_pair. When either the
    treated or the control weights have non-positive mean, returns
    ``(nan, zeros(n_pair))``.

    Notes
    -----
    The propensity-score correction is applied only when a propensity
    score was actually used to build the control weights; a hessian
    supplied without a propensity score is ignored rather than causing
    an unbound-variable error.
    """
    n_pair = len(delta_y)
    est = self.estimation_method

    # Riesz representers: plain indicators for "reg" or when no propensity
    # score exists; odds-weighted controls otherwise.
    w_treat = PA4.copy()
    pscore_safe = None
    if est == "reg" or pscore is None:
        w_control = PAa.copy()
    else:
        pscore_safe = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
        w_control = pscore_safe * PAa / (1 - pscore_safe)

    # Incorporate survey weights into Riesz representers
    if survey_weights is not None:
        w_treat = w_treat * survey_weights
        w_control = w_control * survey_weights

    # ATT via Hajek normalization
    resid = delta_y - or_delta
    riesz_treat = w_treat * resid
    riesz_control = w_control * resid

    mean_w_treat = np.mean(w_treat)
    mean_w_control = np.mean(w_control)

    if mean_w_treat <= 0 or mean_w_control <= 0:
        return float("nan"), np.zeros(n_pair)

    att_treat = np.mean(riesz_treat) / mean_w_treat
    att_control = np.mean(riesz_control) / mean_w_control
    dr_att = att_treat - att_control

    # Base IF
    inf_treat_did = riesz_treat - w_treat * att_treat
    inf_control_did = riesz_control - w_control * att_control

    # PS correction (IPW and DR only). Requires the clipped score that
    # defined w_control, so it is skipped when no score was used.
    inf_control_pscore = 0.0
    if (
        est != "reg"
        and pscore_safe is not None
        and hessian is not None
        and covX is not None
    ):
        M2 = np.mean((w_control * (resid - att_control))[:, None] * covX, axis=0)
        if survey_weights is not None:
            score_ps = survey_weights[:, None] * (PA4 - pscore_safe)[:, None] * covX
        else:
            score_ps = (PA4 - pscore_safe)[:, None] * covX
        asy_lin_rep_ps = score_ps @ hessian
        inf_control_pscore = asy_lin_rep_ps @ M2

    # OR correction (reg and DR only)
    inf_treat_or = 0.0
    inf_cont_or = 0.0
    if est != "ipw" and covX is not None:
        M1 = np.mean(w_treat[:, None] * covX, axis=0)
        M3 = np.mean(w_control[:, None] * covX, axis=0)

        if survey_weights is not None:
            or_x = (PAa * survey_weights)[:, None] * covX
            or_ex = (PAa * survey_weights * resid)[:, None] * covX
        else:
            or_x = PAa[:, None] * covX
            or_ex = (PAa * resid)[:, None] * covX
        XpX = or_x.T @ covX / n_pair

        # Least squares takes over when the normal matrix is singular
        try:
            asy_linear_or = (np.linalg.solve(XpX, or_ex.T)).T
        except np.linalg.LinAlgError:
            asy_linear_or = (np.linalg.lstsq(XpX, or_ex.T, rcond=None)[0]).T

        inf_treat_or = -(asy_linear_or @ M1)
        inf_cont_or = -(asy_linear_or @ M3)

    # Final IF assembly
    inf_control = (inf_control_did + inf_control_pscore + inf_cont_or) / mean_w_control
    inf_treat = (inf_treat_did + inf_treat_or) / mean_w_treat
    inf_func = inf_treat - inf_control

    return float(dr_att), inf_func
1331
+
1332
+ # ------------------------------------------------------------------
1333
+ # Nuisance parameter computation
1334
+ # ------------------------------------------------------------------
1335
+
1336
def _compute_pscore(
    self,
    PA4: np.ndarray,
    covX: np.ndarray,
    pscore_cache: Dict,
    pscore_key: Any,
    survey_weights: Optional[np.ndarray] = None,
    context_label: str = "",
    epv_diagnostics_out: Optional[dict] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Estimate P(PA4=1|X) by logistic regression. Returns (pscore, hessian).

    Coefficients are cached under ``pscore_key`` together with the fit's
    EPV diagnostics; a cache hit recomputes the scores from the stored
    coefficients and replays the diagnostics into ``epv_diagnostics_out``.

    hessian = (X'WX)^{-1} * n_pair with W = p(1-p), multiplied by the
    survey weights when they are provided.

    If the fit raises ``LinAlgError`` or ``ValueError`` and neither
    ``pscore_fallback`` nor ``rank_deficient_action`` is "error", a
    warning is issued and the (survey-weighted) treated share is returned
    as a constant score, clipped to the trimming bounds, with ``None`` in
    place of the hessian. Otherwise the exception is re-raised.
    """
    n_pair = len(PA4)
    cache_entry = pscore_cache.get(pscore_key)

    if cache_entry is not None:
        beta_cached, diag_cached = cache_entry
        linear = np.clip(np.dot(covX, beta_cached), -500, 500)
        pscore = 1 / (1 + np.exp(-linear))
        if epv_diagnostics_out is not None and diag_cached:
            epv_diagnostics_out.update(diag_cached)
    else:
        diag = {}
        try:
            # solve_logit adds its own intercept, so drop ours
            beta_logistic, pscore = solve_logit(
                covX[:, 1:],
                PA4,
                rank_deficient_action=self.rank_deficient_action,
                weights=survey_weights,
                epv_threshold=self.epv_threshold,
                context_label=context_label,
                diagnostics_out=diag,
            )
            _check_propensity_diagnostics(pscore, self.pscore_trim)
            # NaN coefficients (rank-deficient columns) are stored as zero
            # so later cache hits do not propagate NaN.
            pscore_cache[pscore_key] = (
                np.where(np.isfinite(beta_logistic), beta_logistic, 0.0),
                diag,
            )
        except (np.linalg.LinAlgError, ValueError):
            if self.pscore_fallback == "error" or self.rank_deficient_action == "error":
                raise
            ctx = f" for {context_label}" if context_label else ""
            warnings.warn(
                f"Propensity score estimation failed{ctx}. "
                "Falling back to unconditional propensity "
                "(propensity model ignores covariates; outcome "
                "regression still uses them for DR). "
                "Consider estimation_method='reg' to avoid "
                "propensity scores entirely.",
                UserWarning,
                stacklevel=5,
            )
            # Treated share, survey-weighted over positive-weight rows
            # when such rows exist
            positive = survey_weights > 0 if survey_weights is not None else None
            if positive is not None and np.any(positive):
                p_uc = np.average(PA4[positive], weights=survey_weights[positive])
            else:
                p_uc = np.mean(PA4)
            constant_score = np.clip(
                np.full(n_pair, p_uc), self.pscore_trim, 1 - self.pscore_trim
            )
            # No hessian for unconditional fallback
            return constant_score, None
        if epv_diagnostics_out is not None and diag:
            epv_diagnostics_out.update(diag)

    pscore = np.clip(pscore, 1e-6, 1 - 1e-6)

    # Hessian: (X'WX)^{-1} * n
    irls_weights = pscore * (1 - pscore)
    if survey_weights is not None:
        irls_weights = irls_weights * survey_weights
    XWX = covX.T @ (irls_weights[:, None] * covX)
    try:
        hessian = np.linalg.inv(XWX) * n_pair
    except np.linalg.LinAlgError:
        hessian = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0] * n_pair

    return pscore, hessian
1427
+
1428
def _compute_or(
    self,
    delta_y: np.ndarray,
    PAa: np.ndarray,
    covX: np.ndarray,
    cho_cache: Optional[Dict],
    cho_key: Any,
    survey_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Fit the control-group outcome regression and predict for every pair unit.

    Parameters
    ----------
    delta_y : outcome changes for the pair (n_pair,)
    PAa : control indicator (n_pair,); rows with ``PAa > 0`` are fitted
    covX : design matrix with intercept (n_pair, p)
    cho_cache : optional dict of Cholesky factors of X'X keyed by
        ``cho_key``; ``False`` marks a key whose factorization failed
    cho_key : cache key for this control set
    survey_weights : optional per-row weights (n_pair,)

    Returns
    -------
    ``covX @ beta`` for all pair units, or zeros when there are no
    control rows.

    Notes
    -----
    The Cholesky fast path solves the *unweighted* normal equations, so it
    is used only when no survey weights are given, even if a cache was
    passed. All other cases go through ``solve_ols`` (WLS when weighted),
    which honors ``self.rank_deficient_action``; non-finite coefficients
    it returns are replaced by zero.
    """
    control_rows = PAa > 0
    if int(np.sum(control_rows)) == 0:
        return np.zeros(len(delta_y))

    X_control = covX[control_rows]
    y_control = delta_y[control_rows]
    sw_control = survey_weights[control_rows] if survey_weights is not None else None

    beta = None
    if cho_cache is not None and sw_control is None:
        from scipy import linalg as sp_linalg

        factor = cho_cache.get(cho_key)
        if factor is None and cho_key not in cho_cache:
            # First visit for this key: factor X'X once and remember the
            # outcome, including failure, so it is not retried.
            try:
                factor = sp_linalg.cho_factor(X_control.T @ X_control)
            except np.linalg.LinAlgError:
                factor = False
            cho_cache[cho_key] = factor
        if factor is not None and factor is not False:
            beta = sp_linalg.cho_solve(factor, X_control.T @ y_control)
            if not np.all(np.isfinite(beta)):
                beta = None

    if beta is None:
        # Fallback (or survey path): use solve_ols with optional weights
        from diff_diff.linalg import solve_ols as _solve_ols

        beta, _, _ = _solve_ols(
            X_control,
            y_control,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_control,
        )
        beta = np.where(np.isfinite(beta), beta, 0.0)

    return covX @ beta
1493
+
1494
+ # ------------------------------------------------------------------
1495
+ # GMM-optimal combination (matches R's att_gt GMM procedure)
1496
+ # ------------------------------------------------------------------
1497
+
1498
def _combine_gmm(
    self,
    att_vec: np.ndarray,
    inf_func_matrix: np.ndarray,
    n_units: int,
) -> Tuple[float, np.ndarray, np.ndarray, float]:
    """
    Pool the estimates from several comparison groups with GMM weights.

    Weights are ``Omega^{-1} 1 / (1' Omega^{-1} 1)``, where ``Omega`` is
    the sample covariance of the rows of ``inf_func_matrix`` (one row per
    comparison group). A single estimate is returned unchanged with
    weight 1. If ``1' Omega^{-1} 1`` is non-positive or non-finite, equal
    weights are used instead.

    Returns (att_gmm, inf_gmm, weights, se_gmm).
    """

    def _if_based_se(inf: np.ndarray) -> float:
        # sqrt(sum(IF^2) / n^2)
        return float(np.sqrt(np.sum(inf**2) / n_units**2))

    k = len(att_vec)
    if k == 1:
        only_inf = inf_func_matrix[0].copy()
        return float(att_vec[0]), only_inf, np.array([1.0]), _if_based_se(only_inf)

    # Sample covariance across comparison groups (np.cov default, ddof=1)
    Omega = np.cov(inf_func_matrix)
    ones = np.ones(k)
    try:
        Omega_inv = np.linalg.inv(Omega)
    except np.linalg.LinAlgError:
        warnings.warn(
            "Singular covariance matrix in GMM combination. Using pseudoinverse.",
            UserWarning,
            stacklevel=3,
        )
        Omega_inv = np.linalg.pinv(Omega)

    denom = float(ones @ Omega_inv @ ones)
    well_defined = denom > 0 and np.isfinite(denom)

    weights = (Omega_inv @ ones) / denom if well_defined else np.full(k, 1.0 / k)
    att_gmm = float(weights @ att_vec)
    inf_gmm = weights @ inf_func_matrix
    if well_defined:
        # sqrt(1 / (n * sum(inv_OMEGA)))
        se_gmm = float(np.sqrt(1.0 / (n_units * denom)))
    else:
        se_gmm = _if_based_se(inf_gmm)

    return att_gmm, inf_gmm, weights, se_gmm