diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1285 @@
1
+ """WooldridgeDiD: Extended Two-Way Fixed Effects (ETWFE) estimator.
2
+
3
+ Implements Wooldridge (2025, 2023) ETWFE, faithful to the Stata jwdid package.
4
+
5
+ References
6
+ ----------
7
+ Wooldridge (2025). Two-Way Fixed Effects, the Two-Way Mundlak Regression,
8
+ and Difference-in-Differences Estimators. Empirical Economics, 69(5), 2545-2587.
9
+ Wooldridge (2023). Simple approaches to nonlinear difference-in-differences
10
+ with panel data. The Econometrics Journal, 26(3), C31-C66.
11
+ Friosavila (2021). jwdid: Stata module. SSC s459114.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ from diff_diff.linalg import compute_robust_vcov, solve_logit, solve_ols, solve_poisson
22
+ from diff_diff.utils import safe_inference, within_transform
23
+ from diff_diff.wooldridge_results import WooldridgeDiDResults
24
+
25
# Estimation back-ends accepted by WooldridgeDiD(method=...).
_VALID_METHODS = (
    "ols",
    "logit",
    "poisson",
)
# Comparison-group definitions accepted by WooldridgeDiD(control_group=...).
_VALID_CONTROL_GROUPS = (
    "never_treated",
    "not_yet_treated",
)
# Multiplier distributions accepted by WooldridgeDiD(bootstrap_weights=...).
_VALID_BOOTSTRAP_WEIGHTS = (
    "rademacher",
    "webb",
    "mammen",
)
28
+
29
+
30
def _logistic(x: np.ndarray) -> np.ndarray:
    """Return the logistic (sigmoid) function 1 / (1 + exp(-x)), elementwise.

    Evaluated in a numerically stable way: only ``exp(-|x|)`` is ever
    computed, so very negative inputs no longer overflow ``exp`` (the naive
    form raises overflow warnings for x < ~-709 even though the result is 0).
    NaN inputs propagate to NaN; +/-inf map to 1/0.
    """
    x = np.asarray(x, dtype=float)
    e = np.exp(-np.abs(x))  # always in (0, 1]: no overflow possible
    # x >= 0: 1 / (1 + e^{-x});  x < 0: e^{x} / (1 + e^{x})  (e == e^{x} there)
    return np.where(x >= 0, 1.0, e) / (1.0 + e)
32
+
33
+
34
def _logistic_deriv(x: np.ndarray) -> np.ndarray:
    """Return the derivative of the logistic function, p * (1 - p), elementwise.

    Uses the closed form ``e / (1 + e)**2`` with ``e = exp(-|x|)``. It is
    exact by symmetry of the derivative. It never overflows, and it keeps full
    relative precision in the tails, where ``p * (1 - p)`` cancels to exactly 0
    once ``p`` rounds to 1.
    """
    x = np.asarray(x, dtype=float)
    e = np.exp(-np.abs(x))
    return e / (1.0 + e) ** 2
37
+
38
+
39
def _compute_weighted_agg(
    gt_effects: Dict,
    gt_weights: Dict,
    gt_keys: List,
    gt_vcov: Optional[np.ndarray],
    alpha: float,
    df: Optional[int] = None,
) -> Dict:
    """Compute simple (overall) weighted average ATT and SE via delta method.

    Only post-treatment cells (``t >= g``) that actually have an estimate in
    ``gt_effects`` enter the average, each weighted by ``gt_weights[(g, t)]``
    (a missing weight counts as 0). Numerator and denominator use the same
    set of cells, so a cell without an estimate cannot bias the ATT toward 0.

    Parameters
    ----------
    gt_effects : dict mapping (g, t) -> dict with at least an ``"att"`` entry.
    gt_weights : dict mapping (g, t) -> aggregation weight.
    gt_keys : list of (g, t) keys, in the row/column order of ``gt_vcov``.
    gt_vcov : covariance matrix of the cell estimates ordered as ``gt_keys``,
        or None (the SE is then NaN).
    alpha : significance level for the confidence interval.
    df : optional degrees of freedom forwarded to ``safe_inference``.

    Returns
    -------
    dict with keys ``att``, ``se``, ``t_stat``, ``p_value``, ``conf_int``.
    """
    # Positions (in gt_keys / gt_vcov order) of estimated post-treatment cells.
    post_pos = [
        i for i, (g, t) in enumerate(gt_keys) if t >= g and (g, t) in gt_effects
    ]
    post_keys = [gt_keys[i] for i in post_pos]
    w_total = sum(gt_weights.get(k, 0) for k in post_keys)

    if w_total == 0:
        att = float("nan")
        se = float("nan")
    else:
        att = sum(gt_weights.get(k, 0) * gt_effects[k]["att"] for k in post_keys) / w_total
        if gt_vcov is not None:
            w_vec = np.array([gt_weights.get(k, 0) / w_total for k in post_keys], dtype=float)
            # Restrict to the post sub-block: zero-weight cells must not
            # contribute NaN through 0 * NaN in the quadratic form.
            sub_vcov = gt_vcov[np.ix_(post_pos, post_pos)]
            var = float(w_vec @ sub_vcov @ w_vec)
            se = float(np.sqrt(max(var, 0.0)))
        else:
            se = float("nan")

    t_stat, p_value, conf_int = safe_inference(att, se, alpha=alpha, df=df)
    return {"att": att, "se": se, "t_stat": t_stat, "p_value": p_value, "conf_int": conf_int}
69
+
70
+
71
def _resolve_survey_for_wooldridge(survey_design, sample, cluster_ids, cluster_name):
    """Resolve a survey design for the WooldridgeDiD fitters.

    Mirrors the resolution chain used by ``DifferenceInDifferences.fit()``:
    resolve the design against ``sample``, reject unsupported variance/weight
    types, inject the effective cluster as PSU, and recompute survey metadata
    when a PSU is present.

    Returns
    -------
    tuple
        ``(resolved, survey_weights, survey_weight_type, survey_metadata,
        df_inf)`` where ``df_inf`` is ``resolved.df_survey`` (None when no
        survey design is resolved).

    Raises
    ------
    NotImplementedError
        For replicate-weight variance designs.
    ValueError
        When the design's weight type is not ``"pweight"``.
    """
    from diff_diff.survey import (
        _inject_cluster_as_psu,
        _resolve_effective_cluster,
        _resolve_survey_for_fit,
        compute_survey_metadata,
    )

    resolved, survey_weights, survey_weight_type, survey_metadata = _resolve_survey_for_fit(
        survey_design, sample
    )
    if resolved is None:
        # No survey design in play: nothing to validate, no survey df.
        return resolved, survey_weights, survey_weight_type, survey_metadata, None

    if resolved.uses_replicate_variance:
        raise NotImplementedError(
            "WooldridgeDiD does not yet support replicate-weight variance. "
            "Use TSL (strata/PSU/FPC) instead."
        )
    if resolved.weight_type != "pweight":
        raise ValueError(
            f"WooldridgeDiD survey support requires weight_type='pweight', "
            f"got '{resolved.weight_type}'. The survey variance math "
            f"assumes probability weights (pweight)."
        )

    effective_cluster = _resolve_effective_cluster(resolved, cluster_ids, cluster_name)
    if effective_cluster is not None:
        resolved = _inject_cluster_as_psu(resolved, effective_cluster)

    if resolved.psu is not None and survey_metadata is not None:
        if survey_design.weights:
            raw_w = sample[survey_design.weights].values.astype(np.float64)
        else:
            raw_w = np.ones(len(sample), dtype=np.float64)
        survey_metadata = compute_survey_metadata(resolved, raw_w)

    return resolved, survey_weights, survey_weight_type, survey_metadata, resolved.df_survey
113
+
114
+
115
def _filter_sample(
    data: pd.DataFrame,
    unit: str,
    time: str,
    cohort: str,
    control_group: str,
    anticipation: int,
) -> pd.DataFrame:
    """Return the analysis sample following jwdid selection rules.

    Works on a copy of ``data`` whose missing cohort values are set to 0
    (never treated). Every row of a treated unit (cohort > 0) is kept, pre-
    and post-treatment, so the fixed effects are estimated on full
    histories. Never-treated rows (cohort == 0) are always kept. With any
    ``control_group`` other than ``"never_treated"``, rows whose cohort
    exceeds the current period are kept as well. ``control_group`` also
    shapes the interaction matrix (see _build_interaction_matrix).
    ``unit`` and ``anticipation`` are accepted but not used here.
    """
    out = data.copy()
    out[cohort] = out[cohort].fillna(0)
    first_treat = out[cohort]

    keep = first_treat > 0          # treated units: all periods
    keep = keep | (first_treat == 0)  # never-treated units
    if control_group != "never_treated":
        # not-yet-treated at this period
        keep = keep | (first_treat > out[time])
    return out[keep].copy()
143
+
144
+
145
def _build_interaction_matrix(
    data: pd.DataFrame,
    cohort: str,
    time: str,
    anticipation: int,
    control_group: str = "not_yet_treated",
    method: str = "ols",
) -> Tuple[np.ndarray, List[str], List[Tuple[Any, Any]]]:
    """Build the saturated cohort×time interaction design matrix.

    Which (g, t) cells get an indicator column:

    * ``never_treated`` with OLS: every (g, t) pair of every treated cohort.
      Pre-treatment observations of treated units are absorbed into their
      own indicators, so only never-treated rows remain in the omitted
      category; the t < g coefficients act as placebo / pre-trend tests.
    * anything else (``not_yet_treated``, or nonlinear methods): only cells
      with t >= g - anticipation. For nonlinear paths (explicit cohort and
      time dummies) including every cell would make each cohort dummy
      exactly collinear with the sum of its cell indicators.

    Returns
    -------
    X_int : (n, n_cells) float matrix of 0/1 cell indicators
    col_names : labels "g{g}_t{t}" in column order
    gt_keys : (g, t) tuples in column order
    """
    cohort_arr = data[cohort].values
    time_arr = data[time].values
    treated_cohorts = sorted(c for c in data[cohort].unique() if c > 0)
    periods = sorted(data[time].unique())

    keep_pre_cells = method == "ols" and control_group == "never_treated"
    gt_keys = [
        (g, t)
        for g in treated_cohorts
        for t in periods
        if keep_pre_cells or t >= g - anticipation
    ]
    if not gt_keys:
        return np.empty((len(data), 0)), [], []

    col_names = [f"g{g}_t{t}" for g, t in gt_keys]
    X_int = np.column_stack(
        [((cohort_arr == g) & (time_arr == t)).astype(float) for g, t in gt_keys]
    )
    return X_int, col_names, gt_keys
201
+
202
+
203
def _prepare_covariates(
    data: pd.DataFrame,
    exovar: Optional[List[str]],
    xtvar: Optional[List[str]],
    xgvar: Optional[List[str]],
    cohort: str,
    time: str,
    demean_covariates: bool,
    groups: List[Any],
) -> Optional[np.ndarray]:
    """Build covariate matrix following jwdid covariate type conventions.

    Column order: the ``exovar`` columns as-is; then the ``xtvar`` columns,
    demeaned within cohort×period cells when ``demean_covariates`` is True;
    then, for each cohort in ``groups`` and each ``xgvar`` column, that
    column multiplied by the cohort indicator.

    Returns None when no covariates are requested, else an (n, k) array.
    """
    blocks: List[np.ndarray] = []

    if exovar:
        blocks.append(data[exovar].values.astype(float))

    if xtvar:
        if not demean_covariates:
            blocks.append(data[xtvar].values.astype(float))
        else:
            # Cell key "<cohort>_<period>" for within-cohort×period means.
            cell_id = data[cohort].astype(str) + "_" + data[time].astype(str)
            centred = [
                (data[col] - data[col].groupby(cell_id).transform("mean")).values.astype(float)
                for col in xtvar
            ]
            blocks.append(np.column_stack(centred))

    if xgvar:
        for g in groups:
            in_cohort = (data[cohort] == g).values.astype(float)
            blocks.extend((in_cohort * data[col].values).reshape(-1, 1) for col in xgvar)

    return np.hstack(blocks) if blocks else None
242
+
243
+
244
+ class WooldridgeDiD:
245
+ """Extended Two-Way Fixed Effects (ETWFE) DiD estimator.
246
+
247
+ Implements the Wooldridge (2021) saturated cohort×time regression and
248
+ Wooldridge (2023) nonlinear extensions (logit, Poisson). Produces all
249
+ four ``jwdid_estat`` aggregation types: simple, group, calendar, event.
250
+
251
+ Parameters
252
+ ----------
253
+ method : {"ols", "logit", "poisson"}
254
+ Estimation method. "ols" for continuous outcomes; "logit" for binary
255
+ or fractional outcomes; "poisson" for count data.
256
+ control_group : {"not_yet_treated", "never_treated"}
257
+ Which units serve as the comparison group. "not_yet_treated" (jwdid
258
+ default) uses all untreated observations at each time period;
259
+ "never_treated" uses only units never treated throughout the sample.
260
+ anticipation : int
261
+ Number of periods before treatment onset to include as treatment cells
262
+ (anticipation effects). 0 means no anticipation.
263
+ demean_covariates : bool
264
+ If True (jwdid default), ``xtvar`` covariates are demeaned within each
265
+ cohort×period cell before entering the regression. Set to False to
266
+ replicate jwdid's ``xasis`` option.
267
+ alpha : float
268
+ Significance level for confidence intervals.
269
+ cluster : str or None
270
+ Column name to use for cluster-robust SEs. Defaults to the ``unit``
271
+ identifier passed to ``fit()``.
272
+ n_bootstrap : int
273
+ Number of bootstrap replications. 0 disables bootstrap.
274
+ bootstrap_weights : {"rademacher", "webb", "mammen"}
275
+ Bootstrap weight distribution.
276
+ seed : int or None
277
+ Random seed for reproducibility.
278
+ rank_deficient_action : {"warn", "error", "silent"}
279
+ How to handle rank-deficient design matrices.
280
+ """
281
+
282
def __init__(
    self,
    method: str = "ols",
    control_group: str = "not_yet_treated",
    anticipation: int = 0,
    demean_covariates: bool = True,
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    bootstrap_weights: str = "rademacher",
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
) -> None:
    """Validate the categorical options and store every hyper-parameter.

    Raises
    ------
    ValueError
        If ``method``, ``control_group`` or ``bootstrap_weights`` is not a
        supported value, or ``anticipation`` is negative.
    """
    # Checks run in parameter order, before anything is stored.
    if method not in _VALID_METHODS:
        raise ValueError(f"method must be one of {_VALID_METHODS}, got {method!r}")
    if control_group not in _VALID_CONTROL_GROUPS:
        raise ValueError(
            f"control_group must be one of {_VALID_CONTROL_GROUPS}, got {control_group!r}"
        )
    if anticipation < 0:
        raise ValueError(f"anticipation must be >= 0, got {anticipation}")
    if bootstrap_weights not in _VALID_BOOTSTRAP_WEIGHTS:
        raise ValueError(
            f"bootstrap_weights must be one of {_VALID_BOOTSTRAP_WEIGHTS}, "
            f"got {bootstrap_weights!r}"
        )

    hyper_params = {
        "method": method,
        "control_group": control_group,
        "anticipation": anticipation,
        "demean_covariates": demean_covariates,
        "alpha": alpha,
        "cluster": cluster,
        "n_bootstrap": n_bootstrap,
        "bootstrap_weights": bootstrap_weights,
        "seed": seed,
        "rank_deficient_action": rank_deficient_action,
    }
    for attr_name, attr_value in hyper_params.items():
        setattr(self, attr_name, attr_value)

    # Fitted state; set by fit().
    self.is_fitted_: bool = False
    self._results: Optional[WooldridgeDiDResults] = None
322
+
323
@property
def results_(self) -> WooldridgeDiDResults:
    """Results of the last successful ``fit()``.

    Raises RuntimeError when the estimator has not been fitted yet.
    """
    if self.is_fitted_:
        return self._results  # type: ignore[return-value]
    raise RuntimeError("Call fit() before accessing results_")
328
+
329
def get_params(self) -> Dict[str, Any]:
    """Return the constructor hyper-parameters as a dict (sklearn-compatible).

    Keys appear in constructor-signature order.
    """
    param_names = (
        "method",
        "control_group",
        "anticipation",
        "demean_covariates",
        "alpha",
        "cluster",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "rank_deficient_action",
    )
    return {name: getattr(self, name) for name in param_names}
343
+
344
def set_params(self, **params: Any) -> "WooldridgeDiD":
    """Set estimator parameters (sklearn-compatible). Returns self.

    Only the hyper-parameters reported by ``get_params()`` can be set. This
    prevents overwriting methods such as ``fit``, fitted state such as
    ``is_fitted_``, or the read-only ``results_`` property. The merged
    configuration is validated *before* any attribute is modified, so a
    rejected call leaves the estimator unchanged.

    Raises
    ------
    ValueError
        If a key is not an estimator parameter, or the resulting
        configuration has an invalid ``method``, ``control_group``,
        ``anticipation`` or ``bootstrap_weights``.
    """
    current = self.get_params()
    for key in params:
        if key not in current:
            raise ValueError(f"Unknown parameter: {key!r}")

    # Validate the would-be configuration without touching self yet.
    merged = {**current, **params}
    method = merged["method"]
    control_group = merged["control_group"]
    anticipation = merged["anticipation"]
    bootstrap_weights = merged["bootstrap_weights"]
    if method not in _VALID_METHODS:
        raise ValueError(f"method must be one of {_VALID_METHODS}, got {method!r}")
    if control_group not in _VALID_CONTROL_GROUPS:
        raise ValueError(
            f"control_group must be one of {_VALID_CONTROL_GROUPS}, "
            f"got {control_group!r}"
        )
    if anticipation < 0:
        raise ValueError(f"anticipation must be >= 0, got {anticipation}")
    if bootstrap_weights not in _VALID_BOOTSTRAP_WEIGHTS:
        raise ValueError(
            f"bootstrap_weights must be one of {_VALID_BOOTSTRAP_WEIGHTS}, "
            f"got {bootstrap_weights!r}"
        )

    for key, value in params.items():
        setattr(self, key, value)
    return self
366
+
367
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    cohort: str,
    exovar: Optional[List[str]] = None,
    xtvar: Optional[List[str]] = None,
    xgvar: Optional[List[str]] = None,
    survey_design=None,
) -> WooldridgeDiDResults:
    """Fit the ETWFE model. See class docstring for parameter details.

    Parameters
    ----------
    data : DataFrame with panel data (long format)
    outcome : outcome column name
    unit : unit identifier column
    time : time period column
    cohort : first treatment period (0 or NaN = never treated)
    exovar : covariates whose main-effect columns enter as-is, without the
        cohort×period demeaning applied to ``xtvar``. Like the other
        covariates, they are also interacted with the treatment cells, the
        cohort indicators and the time indicators.
    xtvar : time-varying covariates (demeaned within cohort×period cells
        when ``demean_covariates=True``)
    xgvar : covariates interacted with each cohort indicator
    survey_design : SurveyDesign, optional
        Survey design specification for complex survey data. Supports
        stratified, clustered, and weighted designs via Taylor Series
        Linearization (TSL). Replicate-weight designs raise
        ``NotImplementedError``.

    Returns
    -------
    WooldridgeDiDResults
        The fitted results. They are also stored on the estimator and
        exposed through ``results_``.

    Raises
    ------
    ValueError
        If ``cohort`` varies within a unit, if bootstrap is requested with a
        nonlinear ``method`` or together with ``survey_design``, or if no
        treated cohorts, comparison observations, or treatment cells exist.
    """
    df = data.copy()
    df[cohort] = df[cohort].fillna(0)

    # 0a. Validate cohort is time-invariant within unit
    cohort_per_unit = df.groupby(unit)[cohort].nunique()
    bad_units = cohort_per_unit[cohort_per_unit > 1]
    if len(bad_units) > 0:
        example = bad_units.index[0]
        raise ValueError(
            f"Cohort column '{cohort}' is not time-invariant within unit. "
            f"Unit {example!r} has {int(bad_units.iloc[0])} distinct cohort "
            f"values. The cohort column must be constant within each unit."
        )

    # 0b. Reject bootstrap for nonlinear methods (not implemented)
    if self.n_bootstrap > 0 and self.method != "ols":
        raise ValueError(
            f"Bootstrap inference is only supported for method='ols'. "
            f"Got method={self.method!r} with n_bootstrap={self.n_bootstrap}. "
            f"Set n_bootstrap=0 for analytic SEs."
        )

    # 0c. Reject bootstrap + survey (no survey-aware bootstrap variant)
    if self.n_bootstrap > 0 and survey_design is not None:
        raise ValueError(
            "Bootstrap inference is not supported with survey_design. "
            "Set n_bootstrap=0 for analytic survey SEs."
        )

    # 1. Filter to analysis sample
    sample = _filter_sample(df, unit, time, cohort, self.control_group, self.anticipation)

    # 1b. Identification checks
    groups = sorted(g for g in sample[cohort].unique() if g > 0)
    if len(groups) == 0:
        raise ValueError(
            "No treated cohorts found in data. Ensure the cohort column "
            "contains values > 0 for treated units."
        )
    if self.control_group == "never_treated" and not (sample[cohort] == 0).any():
        raise ValueError(
            "control_group='never_treated' but no never-treated units "
            "(cohort == 0) found. Use 'not_yet_treated' or add "
            "never-treated units."
        )
    if self.control_group == "not_yet_treated":
        # Verify at least some untreated comparison observations exist
        has_untreated = (sample[cohort] == 0).any() or (
            (sample[cohort] - self.anticipation) > sample[time]
        ).any()
        if not has_untreated:
            raise ValueError(
                "control_group='not_yet_treated' but no untreated comparison "
                "observations exist. All units are treated at all observed "
                "time periods. Use 'never_treated' with a never-treated group."
            )

    # 2. Build interaction matrix (cell indicators come first in the design)
    X_int, int_col_names, gt_keys = _build_interaction_matrix(
        sample,
        cohort=cohort,
        time=time,
        anticipation=self.anticipation,
        control_group=self.control_group,
        method=self.method,
    )
    if X_int.shape[1] == 0:
        raise ValueError(
            "No valid treatment cells found. Check that treated units "
            "have post-treatment observations in the data."
        )

    # 3. Covariates
    X_cov = _prepare_covariates(
        sample,
        exovar=exovar,
        xtvar=xtvar,
        xgvar=xgvar,
        cohort=cohort,
        time=time,
        demean_covariates=self.demean_covariates,
        groups=groups,
    )

    all_regressors = int_col_names.copy()
    if X_cov is None:
        X_design = X_int
    else:
        cohort_vals = sample[cohort].values
        time_vals = sample[time].values

        # One label per X_cov column, in _prepare_covariates' order: exovar,
        # xtvar, then one column per (cohort, xgvar) pair.
        cov_labels = (
            list(exovar or [])
            + list(xtvar or [])
            + [f"D{g}_x_{col}" for g in groups for col in (xgvar or [])]
        )

        # Treatment × demeaned-covariate interactions (W2025 Eq. 5.3): each
        # covariate is centred within each treated cohort before being
        # multiplied by that cohort's cell indicators. Never-treated rows
        # are left as-is; every cell indicator is zero for them.
        X_cov_demeaned = X_cov.copy()
        if self.demean_covariates:
            for j in range(X_cov.shape[1]):
                for g in groups:
                    mask = cohort_vals == g
                    if mask.any():
                        X_cov_demeaned[mask, j] -= X_cov[mask, j].mean()

        interact_cols = []
        interact_names = []
        for i, gt_name in enumerate(int_col_names):
            for j in range(X_cov_demeaned.shape[1]):
                interact_cols.append(X_int[:, i] * X_cov_demeaned[:, j])
                cov_label = cov_labels[j] if j < len(cov_labels) else f"cov{j}"
                interact_names.append(f"{gt_name}_x_{cov_label}")

        # Cohort × covariate interactions (W2025 Eq. 5.3: D_g × X)
        # exovar/xtvar get automatic D_g × X; xgvar already has D_g × X
        cov_cols_for_dg = list(exovar or []) + list(xtvar or [])
        cohort_cov_cols = []
        cohort_cov_names = []
        for g in groups if cov_cols_for_dg else []:
            g_ind = (cohort_vals == g).astype(float)
            for col in cov_cols_for_dg:
                cohort_cov_cols.append(g_ind * sample[col].values.astype(float))
                cohort_cov_names.append(f"D{g}_x_{col}")

        # Time × covariate interactions (W2025 Eq. 5.3: f_t × X)
        # All covariates get f_t × X; the first period is dropped for
        # identification.
        all_cov_cols = list(exovar or []) + list(xtvar or []) + list(xgvar or [])
        times_sorted = sorted(sample[time].unique())
        time_cov_cols = []
        time_cov_names = []
        for t in times_sorted[1:]:
            t_ind = (time_vals == t).astype(float)
            for col in all_cov_cols:
                time_cov_cols.append(t_ind * sample[col].values.astype(float))
                time_cov_names.append(f"ft{t}_x_{col}")

        # Assemble: [cell_indicators, cell×cov, D_g×X, f_t×X, raw_cov]
        blocks = [X_int]
        if interact_cols:
            blocks.append(np.column_stack(interact_cols))
            all_regressors.extend(interact_names)
        if cohort_cov_cols:
            blocks.append(np.column_stack(cohort_cov_cols))
            all_regressors.extend(cohort_cov_names)
        if time_cov_cols:
            blocks.append(np.column_stack(time_cov_cols))
            all_regressors.extend(time_cov_names)
        blocks.append(X_cov)
        all_regressors.extend(f"_cov_{i}" for i in range(X_cov.shape[1]))
        X_design = np.hstack(blocks)

    # 4. Dispatch to the estimation back-end
    fitter_args = (
        sample,
        outcome,
        unit,
        time,
        cohort,
        X_design,
        all_regressors,
        gt_keys,
        int_col_names,
        groups,
    )
    if self.method == "ols":
        results = self._fit_ols(*fitter_args, survey_design=survey_design)
    else:
        n_cov_interact = X_cov.shape[1] if X_cov is not None else 0
        nonlinear_fit = self._fit_logit if self.method == "logit" else self._fit_poisson
        results = nonlinear_fit(
            *fitter_args,
            n_cov_interact=n_cov_interact,
            survey_design=survey_design,
        )

    self._results = results
    self.is_fitted_ = True
    return results
601
+
602
def _count_control_units(self, sample: pd.DataFrame, unit: str, cohort: str, time: str) -> int:
    """Count control units consistent with control_group setting.

    Always counts never-treated units (cohort == 0). With
    ``control_group == "not_yet_treated"``, it also adds treated units
    observed at least once strictly before ``cohort - anticipation``.
    """
    first_treat = sample[cohort]
    n_controls = int(sample.loc[first_treat == 0, unit].nunique())
    if self.control_group != "not_yet_treated":
        return n_controls
    before_onset = (first_treat > 0) & (sample[time] < first_treat - self.anticipation)
    return n_controls + int(sample.loc[before_onset, unit].nunique())
612
+
613
def _fit_ols(
    self,
    sample: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    cohort: str,
    X_design: np.ndarray,
    col_names: List[str],
    gt_keys: List[Tuple],
    int_col_names: List[str],
    groups: List[Any],
    survey_design=None,
) -> WooldridgeDiDResults:
    """OLS path: within-transform FE, solve_ols, cluster SE.

    Steps:

    1. Absorb unit and time fixed effects by within-transforming the outcome
       and every design column; the transform is survey-weighted when a
       survey design is given.
    2. Solve OLS and read the (g, t) effects off the first ``len(gt_keys)``
       coefficients.
    3. Compute SEs. They are cluster-robust, clustering on ``self.cluster``
       or ``unit``. When a survey design resolves, they come from
       ``compute_survey_vcov`` instead.
    4. When ``n_bootstrap > 0``, replace the overall-ATT inference with a
       cluster-level multiplier bootstrap.

    Parameters
    ----------
    sample : analysis sample, row-aligned with ``X_design``.
    outcome, unit, time, cohort : column names in ``sample``.
    X_design : (n, k) design matrix whose leading columns are the cell
        indicators in ``gt_keys`` order.
    col_names : labels of the ``k`` design columns (passed to ``solve_ols``).
    gt_keys : (g, t) keys of the cell-indicator columns.
    int_col_names : labels of the cell-indicator columns (not used here).
    groups : treated cohorts, stored on the results.
    survey_design : optional SurveyDesign.

    Returns
    -------
    WooldridgeDiDResults
    """
    import warnings

    # Reset index so numpy positional indexing matches pandas groupby
    sample = sample.reset_index(drop=True)
    # Cluster IDs (default: unit level) — needed before survey resolution
    cluster_col = self.cluster if self.cluster else unit
    cluster_ids = sample[cluster_col].values

    # Resolve survey design, inject cluster as PSU only when user explicitly set cluster=
    survey_cluster_ids = cluster_ids if self.cluster else None
    resolved, survey_weights, survey_weight_type, survey_metadata, df_inf = (
        _resolve_survey_for_wooldridge(survey_design, sample, survey_cluster_ids, self.cluster)
    )

    # 4. Within-transform: absorb unit + time FE. The working frame is built
    # with one concat; inserting hundreds of covariate-interaction columns
    # one at a time fragments the DataFrame (pandas may warn about it).
    x_names = [f"_x{i}" for i in range(X_design.shape[1])]
    all_vars = [outcome] + x_names
    tmp = pd.concat(
        [
            sample[[unit, time]],
            pd.DataFrame({outcome: sample[outcome].values}, index=sample.index),
            pd.DataFrame(X_design, columns=x_names, index=sample.index),
        ],
        axis=1,
    )

    # Alternating-projection demeaning (exact for balanced and unbalanced
    # panels). Survey weights change the weighted FWL projection, so all
    # columns are demeaned together with the same weights.
    wt_weights = survey_weights if survey_weights is not None else np.ones(len(tmp))

    # Guard: zero-weight unit/time groups cause 0/0 in within_transform
    if survey_weights is not None and np.any(survey_weights == 0):
        sw_series = pd.Series(survey_weights, index=sample.index)
        for grp_col, grp_label in [(unit, "unit"), (time, "time period")]:
            grp_sums = sw_series.groupby(sample[grp_col]).sum()
            zero_grps = grp_sums[grp_sums == 0].index.tolist()
            if zero_grps:
                raise ValueError(
                    f"Survey weights sum to zero for {grp_label}(s) "
                    f"{zero_grps[:3]}. Cannot compute weighted "
                    f"within-transformation. Remove zero-weight "
                    f"{grp_label}s or use non-zero weights."
                )

    transformed = within_transform(
        tmp, all_vars, unit=unit, time=time, suffix="_demeaned",
        weights=wt_weights,
    )

    y = transformed[f"{outcome}_demeaned"].values
    X = transformed[[f"{name}_demeaned" for name in x_names]].values

    # 6. Solve OLS (skip cluster-robust vcov when survey will provide TSL vcov)
    coefs, resids, vcov = solve_ols(
        X,
        y,
        cluster_ids=cluster_ids,
        return_vcov=(resolved is None),
        rank_deficient_action=self.rank_deficient_action,
        column_names=col_names,
        weights=survey_weights,
        weight_type=survey_weight_type,
    )

    # Survey TSL vcov replaces cluster-robust vcov; dropped (NaN)
    # coefficients get NaN rows/columns.
    if resolved is not None:
        from diff_diff.survey import compute_survey_vcov
        nan_mask_ols = np.isnan(coefs)
        if np.any(nan_mask_ols):
            kept = ~nan_mask_ols
            vcov_kept = compute_survey_vcov(X[:, kept], resids, resolved)
            vcov = np.full((len(coefs), len(coefs)), np.nan)
            kept_idx = np.where(kept)[0]
            vcov[np.ix_(kept_idx, kept_idx)] = vcov_kept
        else:
            vcov = compute_survey_vcov(X, resids, resolved)

    # 7. Extract β_{g,t} and build gt_effects dict. Cell sizes (the
    # aggregation weights) are counted once via groupby instead of one
    # full boolean scan per cell.
    cell_sizes = sample.groupby([cohort, time]).size().to_dict()
    gt_effects: Dict[Tuple, Dict] = {}
    gt_weights: Dict[Tuple, int] = {}
    for idx, (g, t) in enumerate(gt_keys):
        if idx >= len(coefs):
            break
        # Skip cells whose coefficient was dropped (rank deficiency)
        if np.isnan(coefs[idx]):
            continue
        att = float(coefs[idx])
        se = float(np.sqrt(max(vcov[idx, idx], 0.0))) if vcov is not None else float("nan")
        t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_inf)
        gt_effects[(g, t)] = {
            "att": att,
            "se": se,
            "t_stat": t_stat,
            "p_value": p_value,
            "conf_int": conf_int,
        }
        gt_weights[(g, t)] = int(cell_sizes.get((g, t), 0))

    # Extract vcov submatrix for identified β_{g,t} only (skip NaN/dropped)
    gt_keys_ordered = list(gt_effects.keys())
    if vcov is not None and gt_keys_ordered:
        orig_indices = [i for i, k in enumerate(gt_keys) if k in gt_effects]
        gt_vcov = vcov[np.ix_(orig_indices, orig_indices)]
    else:
        gt_vcov = None

    # 8. Simple aggregation (always computed)
    overall = _compute_weighted_agg(
        gt_effects, gt_weights, gt_keys_ordered, gt_vcov, self.alpha, df=df_inf
    )

    # Metadata
    n_treated = int(sample[sample[cohort] > 0][unit].nunique())
    n_control = self._count_control_units(sample, unit, cohort, time)
    all_times = sorted(sample[time].unique().tolist())

    results = WooldridgeDiDResults(
        group_time_effects=gt_effects,
        overall_att=overall["att"],
        overall_se=overall["se"],
        overall_t_stat=overall["t_stat"],
        overall_p_value=overall["p_value"],
        overall_conf_int=overall["conf_int"],
        method=self.method,
        control_group=self.control_group,
        groups=groups,
        time_periods=all_times,
        n_obs=len(sample),
        n_treated_units=n_treated,
        n_control_units=n_control,
        alpha=self.alpha,
        anticipation=self.anticipation,
        survey_metadata=survey_metadata,
        _gt_weights=gt_weights,
        _gt_vcov=gt_vcov,
        _gt_keys=gt_keys_ordered,
        _df_survey=df_inf,
    )

    # 9. Optional multiplier bootstrap (overrides analytic SE for overall ATT)
    if self.n_bootstrap > 0:
        post_set = {(g, t) for (g, t) in gt_keys_ordered if t >= g}
        w_total_b = sum(gt_weights.get(k, 0) for k in post_set)
        boot_atts: List[float] = []
        if w_total_b > 0:
            rng = np.random.default_rng(self.seed)
            # Draw weights at the analytic cluster level (not always unit);
            # cluster_pos maps each observation to its cluster's draw.
            unique_boot_clusters, cluster_pos = np.unique(cluster_ids, return_inverse=True)
            n_boot_clusters = len(unique_boot_clusters)
            # Multiplier support/probabilities are loop-invariant. p=None
            # reproduces the original unweighted draws (same RNG stream).
            if self.bootstrap_weights == "rademacher":
                support, probs = [-1.0, 1.0], None
            elif self.bootstrap_weights == "webb":
                support = [-np.sqrt(1.5), -1.0, -np.sqrt(0.5), np.sqrt(0.5), 1.0, np.sqrt(1.5)]
                probs = None
            else:  # mammen
                phi = (1 + np.sqrt(5)) / 2
                support = [-(phi - 1), phi]
                probs = [phi / np.sqrt(5), (phi - 1) / np.sqrt(5)]
            for _ in range(self.n_bootstrap):
                cl_weights = rng.choice(support, size=n_boot_clusters, p=probs)
                y_boot = y + cl_weights[cluster_pos] * resids
                coefs_b, _, _ = solve_ols(
                    X,
                    y_boot,
                    cluster_ids=cluster_ids,
                    return_vcov=False,
                    rank_deficient_action="silent",
                )
                att_b = (
                    sum(
                        gt_weights.get(k, 0) * float(coefs_b[i])
                        for i, k in enumerate(gt_keys)
                        if k in post_set and i < len(coefs_b)
                    )
                    / w_total_b
                )
                boot_atts.append(att_b)
        if len(boot_atts) >= 2:
            boot_se = float(np.std(boot_atts, ddof=1))
            t_stat_b, p_b, ci_b = safe_inference(results.overall_att, boot_se, alpha=self.alpha)
            results.overall_se = boot_se
            results.overall_t_stat = t_stat_b
            results.overall_p_value = p_b
            results.overall_conf_int = ci_b
        elif boot_atts:
            # A sample SD (ddof=1) needs at least two replications.
            warnings.warn(
                "n_bootstrap=1 is too few replications to estimate a bootstrap "
                "standard error; keeping the analytic SE for the overall ATT.",
                UserWarning,
                stacklevel=2,
            )

    return results
816
+
817
+ def _fit_logit(
818
+ self,
819
+ sample: pd.DataFrame,
820
+ outcome: str,
821
+ unit: str,
822
+ time: str,
823
+ cohort: str,
824
+ X_int: np.ndarray,
825
+ col_names: List[str],
826
+ gt_keys: List[Tuple],
827
+ int_col_names: List[str],
828
+ groups: List[Any],
829
+ n_cov_interact: int = 0,
830
+ survey_design=None,
831
+ ) -> WooldridgeDiDResults:
832
+ """Logit path: cohort + time additive FEs + solve_logit + ASF ATT.
833
+
834
+ Matches Stata jwdid method(logit): logit y [treatment_interactions]
835
+ i.gvar i.tvar — cohort main effects + time main effects (additive),
836
+ not cohort×time saturated group FEs.
837
+ """
838
+ n_int = len(int_col_names)
839
+
840
+ # Design matrix: treatment interactions + cohort FEs + time FEs
841
+ # This matches Stata's `i.gvar i.tvar` specification.
842
+ cohort_dummies = pd.get_dummies(sample[cohort], drop_first=True).values.astype(float)
843
+ time_dummies = pd.get_dummies(sample[time], drop_first=True).values.astype(float)
844
+ X_full = np.hstack([X_int, cohort_dummies, time_dummies])
845
+
846
+ y = sample[outcome].values.astype(float)
847
+ if not np.all(np.isfinite(y)):
848
+ raise ValueError("Outcome contains non-finite values (NaN/Inf).")
849
+ if np.any(y < 0) or np.any(y > 1):
850
+ raise ValueError(
851
+ f"method='logit' requires outcomes in [0, 1]. "
852
+ f"Got range [{y.min():.4f}, {y.max():.4f}]."
853
+ )
854
+ cluster_col = self.cluster if self.cluster else unit
855
+ cluster_ids = sample[cluster_col].values
856
+
857
+ # Resolve survey design, inject cluster as PSU only when user explicitly set cluster=
858
+ survey_cluster_ids = cluster_ids if self.cluster else None
859
+ resolved, survey_weights, survey_weight_type, survey_metadata, df_inf = (
860
+ _resolve_survey_for_wooldridge(survey_design, sample, survey_cluster_ids, self.cluster)
861
+ )
862
+ _has_survey = resolved is not None
863
+
864
+ beta, probs = solve_logit(
865
+ X_full,
866
+ y,
867
+ rank_deficient_action=self.rank_deficient_action,
868
+ weights=survey_weights,
869
+ )
870
+ # solve_logit prepends intercept — beta[0] is intercept, beta[1:] are X_full cols
871
+ beta_int_cols = beta[1 : n_int + 1] # treatment interaction coefficients
872
+
873
+ # Handle rank-deficient designs: identify kept columns, compute vcov
874
+ # on reduced design, then expand back
875
+ nan_mask = np.isnan(beta)
876
+ beta_clean = np.where(nan_mask, 0.0, beta)
877
+ kept_beta = ~nan_mask
878
+
879
+ # QMLE sandwich vcov
880
+ resids = y - probs
881
+ X_with_intercept = np.column_stack([np.ones(len(y)), X_full])
882
+
883
+ if _has_survey:
884
+ # X_tilde trick: transform design matrix so compute_survey_vcov
885
+ # produces the correct QMLE sandwich for nonlinear models.
886
+ # Bread: (X_tilde'WX_tilde)^{-1} = (X'diag(w*V)X)^{-1}
887
+ # Scores: w*X_tilde*r_tilde = w*X*(y-mu)
888
+ from diff_diff.survey import compute_survey_vcov
889
+ V = probs * (1 - probs)
890
+ sqrt_V = np.sqrt(np.clip(V, 1e-20, None))
891
+ X_tilde = X_with_intercept * sqrt_V[:, None]
892
+ r_tilde = resids / sqrt_V
893
+ if np.any(nan_mask):
894
+ X_tilde_r = X_tilde[:, kept_beta]
895
+ vcov_reduced = compute_survey_vcov(X_tilde_r, r_tilde, resolved)
896
+ k_full = len(beta)
897
+ vcov_full = np.full((k_full, k_full), np.nan)
898
+ kept_idx = np.where(kept_beta)[0]
899
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
900
+ else:
901
+ vcov_full = compute_survey_vcov(X_tilde, r_tilde, resolved)
902
+ else:
903
+ # Cluster-robust QMLE sandwich (non-survey path)
904
+ if np.any(nan_mask):
905
+ X_reduced = X_with_intercept[:, kept_beta]
906
+ vcov_reduced = compute_robust_vcov(
907
+ X_reduced,
908
+ resids,
909
+ cluster_ids=cluster_ids,
910
+ weights=probs * (1 - probs),
911
+ weight_type="aweight",
912
+ )
913
+ k_full = len(beta)
914
+ vcov_full = np.full((k_full, k_full), np.nan)
915
+ kept_idx = np.where(kept_beta)[0]
916
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
917
+ else:
918
+ vcov_full = compute_robust_vcov(
919
+ X_with_intercept,
920
+ resids,
921
+ cluster_ids=cluster_ids,
922
+ weights=probs * (1 - probs),
923
+ weight_type="aweight",
924
+ )
925
+ beta = beta_clean
926
+
927
+ # Survey-weighted averaging helpers for ASF computation
928
+ def _avg(a, cell_mask):
929
+ if survey_weights is not None:
930
+ return float(np.average(a, weights=survey_weights[cell_mask]))
931
+ return float(np.mean(a))
932
+
933
+ def _avg_ax0(a, cell_mask):
934
+ if survey_weights is not None:
935
+ return np.average(a, weights=survey_weights[cell_mask], axis=0)
936
+ return np.mean(a, axis=0)
937
+
938
+ # ASF ATT(g,t) for treated units in each cell
939
+ gt_effects: Dict[Tuple, Dict] = {}
940
+ gt_weights: Dict[Tuple, int] = {}
941
+ gt_grads: Dict[Tuple, np.ndarray] = {} # store per-cell gradients for aggregate SE
942
+ for idx, (g, t) in enumerate(gt_keys):
943
+ if idx >= n_int:
944
+ break
945
+ cell_mask = (sample[cohort] == g) & (sample[time] == t)
946
+ if cell_mask.sum() == 0:
947
+ continue
948
+ # Skip cells whose interaction coefficient was dropped (rank deficiency)
949
+ # Skip cells where all survey weights are zero (non-estimable)
950
+ if survey_weights is not None and np.sum(survey_weights[cell_mask]) == 0:
951
+ continue
952
+ delta = beta_int_cols[idx]
953
+ if np.isnan(delta):
954
+ continue
955
+ eta_base = X_with_intercept[cell_mask] @ beta
956
+ # Counterfactual: zero the FULL treatment block for cell (g,t).
957
+ # This includes the scalar cell effect δ_{g,t} AND any cell ×
958
+ # covariate interaction effects ξ_{g,t,j} * x_hat_j (W2023 Eq. 3.15).
959
+ delta_total = np.full(cell_mask.sum(), float(delta))
960
+ for j in range(n_cov_interact):
961
+ coef_pos = 1 + n_int + idx * n_cov_interact + j
962
+ if coef_pos < len(beta):
963
+ x_hat_j = X_with_intercept[cell_mask, coef_pos]
964
+ delta_total = delta_total + beta[coef_pos] * x_hat_j
965
+ eta_0 = eta_base - delta_total
966
+ att = _avg(_logistic(eta_base) - _logistic(eta_0), cell_mask)
967
+ # Delta method gradient: d(ATT)/d(β)
968
+ # for nuisance p: mean_i[(Λ'(η_1) - Λ'(η_0)) * X_p]
969
+ # for cell intercept: mean_i[Λ'(η_1)]
970
+ # for cell × cov j: mean_i[Λ'(η_1) * x_hat_j]
971
+ d_diff = _logistic_deriv(eta_base) - _logistic_deriv(eta_0)
972
+ grad = _avg_ax0(X_with_intercept[cell_mask] * d_diff[:, None], cell_mask)
973
+ grad[1 + idx] = _avg(_logistic_deriv(eta_base), cell_mask)
974
+ for j in range(n_cov_interact):
975
+ coef_pos = 1 + n_int + idx * n_cov_interact + j
976
+ if coef_pos < len(beta):
977
+ x_hat_j = X_with_intercept[cell_mask, coef_pos]
978
+ grad[coef_pos] = _avg(_logistic_deriv(eta_base) * x_hat_j, cell_mask)
979
+ # Compute SE in reduced parameter space if rank-deficient
980
+ if np.any(nan_mask):
981
+ grad_r = grad[kept_beta]
982
+ se = float(np.sqrt(max(grad_r @ vcov_reduced @ grad_r, 0.0)))
983
+ else:
984
+ se = float(np.sqrt(max(grad @ vcov_full @ grad, 0.0)))
985
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_inf)
986
+ gt_effects[(g, t)] = {
987
+ "att": att,
988
+ "se": se,
989
+ "t_stat": t_stat,
990
+ "p_value": p_value,
991
+ "conf_int": conf_int,
992
+ }
993
+ gt_weights[(g, t)] = int(cell_mask.sum())
994
+ # Store gradient in reduced space for aggregate SE
995
+ gt_grads[(g, t)] = grad[kept_beta] if np.any(nan_mask) else grad
996
+
997
+ gt_keys_ordered = [k for k in gt_keys if k in gt_effects]
998
+ # Use reduced vcov for all downstream SE computations
999
+ _vcov_se = vcov_reduced if np.any(nan_mask) else vcov_full
1000
+ # ATT-level covariance: J @ vcov @ J' where J rows are per-cell gradients
1001
+ if gt_keys_ordered:
1002
+ J = np.array([gt_grads[k] for k in gt_keys_ordered])
1003
+ gt_vcov = J @ _vcov_se @ J.T
1004
+ else:
1005
+ gt_vcov = None
1006
+
1007
+ # Overall SE via joint delta method: ∇β(overall_att) = Σ w_k/w_total * grad_k
1008
+ post_keys = [(g, t) for (g, t) in gt_keys_ordered if t >= g]
1009
+ w_total = sum(gt_weights.get(k, 0) for k in post_keys)
1010
+ if w_total > 0 and post_keys:
1011
+ overall_att = sum(gt_weights[k] * gt_effects[k]["att"] for k in post_keys) / w_total
1012
+ agg_grad = sum((gt_weights[k] / w_total) * gt_grads[k] for k in post_keys)
1013
+ overall_se = float(np.sqrt(max(agg_grad @ _vcov_se @ agg_grad, 0.0)))
1014
+ t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha, df=df_inf)
1015
+ overall = {
1016
+ "att": overall_att,
1017
+ "se": overall_se,
1018
+ "t_stat": t_stat,
1019
+ "p_value": p_value,
1020
+ "conf_int": conf_int,
1021
+ }
1022
+ else:
1023
+ overall = _compute_weighted_agg(
1024
+ gt_effects, gt_weights, gt_keys_ordered, None, self.alpha, df=df_inf
1025
+ )
1026
+
1027
+ return WooldridgeDiDResults(
1028
+ group_time_effects=gt_effects,
1029
+ overall_att=overall["att"],
1030
+ overall_se=overall["se"],
1031
+ overall_t_stat=overall["t_stat"],
1032
+ overall_p_value=overall["p_value"],
1033
+ overall_conf_int=overall["conf_int"],
1034
+ method=self.method,
1035
+ control_group=self.control_group,
1036
+ groups=groups,
1037
+ time_periods=sorted(sample[time].unique().tolist()),
1038
+ n_obs=len(sample),
1039
+ n_treated_units=int(sample[sample[cohort] > 0][unit].nunique()),
1040
+ n_control_units=self._count_control_units(sample, unit, cohort, time),
1041
+ alpha=self.alpha,
1042
+ anticipation=self.anticipation,
1043
+ survey_metadata=survey_metadata,
1044
+ _gt_weights=gt_weights,
1045
+ _gt_vcov=gt_vcov,
1046
+ _gt_keys=gt_keys_ordered,
1047
+ _df_survey=df_inf,
1048
+ )
1049
+
1050
+ def _fit_poisson(
1051
+ self,
1052
+ sample: pd.DataFrame,
1053
+ outcome: str,
1054
+ unit: str,
1055
+ time: str,
1056
+ cohort: str,
1057
+ X_int: np.ndarray,
1058
+ col_names: List[str],
1059
+ gt_keys: List[Tuple],
1060
+ int_col_names: List[str],
1061
+ groups: List[Any],
1062
+ n_cov_interact: int = 0,
1063
+ survey_design=None,
1064
+ ) -> WooldridgeDiDResults:
1065
+ """Poisson path: cohort + time additive FEs + solve_poisson + ASF ATT.
1066
+
1067
+ Matches Stata jwdid method(poisson): poisson y [treatment_interactions]
1068
+ i.gvar i.tvar — cohort main effects + time main effects (additive),
1069
+ not cohort×time saturated group FEs.
1070
+ """
1071
+ n_int = len(int_col_names)
1072
+
1073
+ # Design matrix: intercept + treatment interactions + cohort FEs + time FEs.
1074
+ # Matches Stata's `i.gvar i.tvar` + treatment interaction specification.
1075
+ # solve_poisson does not prepend an intercept, so we include one explicitly.
1076
+ intercept = np.ones((len(sample), 1))
1077
+ cohort_dummies = pd.get_dummies(sample[cohort], drop_first=True).values.astype(float)
1078
+ time_dummies = pd.get_dummies(sample[time], drop_first=True).values.astype(float)
1079
+ X_full = np.hstack([intercept, X_int, cohort_dummies, time_dummies])
1080
+ # Treatment interaction coefficients start at column index 1.
1081
+
1082
+ y = sample[outcome].values.astype(float)
1083
+ if not np.all(np.isfinite(y)):
1084
+ raise ValueError("Outcome contains non-finite values (NaN/Inf).")
1085
+ if np.any(y < 0):
1086
+ raise ValueError(
1087
+ f"method='poisson' requires non-negative outcomes. "
1088
+ f"Got minimum value {y.min():.4f}."
1089
+ )
1090
+ cluster_col = self.cluster if self.cluster else unit
1091
+ cluster_ids = sample[cluster_col].values
1092
+
1093
+ # Resolve survey design, inject cluster as PSU only when user explicitly set cluster=
1094
+ survey_cluster_ids = cluster_ids if self.cluster else None
1095
+ resolved, survey_weights, survey_weight_type, survey_metadata, df_inf = (
1096
+ _resolve_survey_for_wooldridge(survey_design, sample, survey_cluster_ids, self.cluster)
1097
+ )
1098
+ _has_survey = resolved is not None
1099
+
1100
+ beta, mu_hat = solve_poisson(
1101
+ X_full, y,
1102
+ rank_deficient_action=self.rank_deficient_action,
1103
+ weights=survey_weights,
1104
+ )
1105
+
1106
+ # Handle rank-deficient designs: compute vcov on reduced design.
1107
+ # Preserve raw interaction coefficients BEFORE zeroing NaN so the
1108
+ # NaN check in the ASF loop correctly skips dropped cells.
1109
+ nan_mask = np.isnan(beta)
1110
+ beta_int_raw = beta[1 : 1 + n_int].copy() # before zeroing
1111
+ beta_clean = np.where(nan_mask, 0.0, beta)
1112
+ kept_beta = ~nan_mask
1113
+
1114
+ # QMLE sandwich vcov
1115
+ resids = y - mu_hat
1116
+
1117
+ if _has_survey:
1118
+ # X_tilde trick for nonlinear survey vcov (V = mu for Poisson)
1119
+ from diff_diff.survey import compute_survey_vcov
1120
+ sqrt_V = np.sqrt(np.clip(mu_hat, 1e-20, None))
1121
+ X_tilde = X_full * sqrt_V[:, None]
1122
+ r_tilde = resids / sqrt_V
1123
+ if np.any(nan_mask):
1124
+ X_tilde_r = X_tilde[:, kept_beta]
1125
+ vcov_reduced = compute_survey_vcov(X_tilde_r, r_tilde, resolved)
1126
+ k_full = len(beta)
1127
+ vcov_full = np.full((k_full, k_full), np.nan)
1128
+ kept_idx = np.where(kept_beta)[0]
1129
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
1130
+ else:
1131
+ vcov_full = compute_survey_vcov(X_tilde, r_tilde, resolved)
1132
+ else:
1133
+ # Cluster-robust QMLE sandwich (non-survey path)
1134
+ if np.any(nan_mask):
1135
+ X_reduced = X_full[:, kept_beta]
1136
+ vcov_reduced = compute_robust_vcov(
1137
+ X_reduced,
1138
+ resids,
1139
+ cluster_ids=cluster_ids,
1140
+ weights=mu_hat,
1141
+ weight_type="aweight",
1142
+ )
1143
+ k_full = len(beta)
1144
+ vcov_full = np.full((k_full, k_full), np.nan)
1145
+ kept_idx = np.where(kept_beta)[0]
1146
+ vcov_full[np.ix_(kept_idx, kept_idx)] = vcov_reduced
1147
+ else:
1148
+ vcov_full = compute_robust_vcov(
1149
+ X_full,
1150
+ resids,
1151
+ cluster_ids=cluster_ids,
1152
+ weights=mu_hat,
1153
+ weight_type="aweight",
1154
+ )
1155
+ beta = beta_clean
1156
+
1157
+ # Treatment interaction coefficients (from cleaned beta for computation)
1158
+ beta_int = beta[1 : 1 + n_int]
1159
+
1160
+ # Survey-weighted averaging helpers for ASF computation
1161
+ def _avg(a, cell_mask):
1162
+ if survey_weights is not None:
1163
+ return float(np.average(a, weights=survey_weights[cell_mask]))
1164
+ return float(np.mean(a))
1165
+
1166
+ def _avg_ax0(a, cell_mask):
1167
+ if survey_weights is not None:
1168
+ return np.average(a, weights=survey_weights[cell_mask], axis=0)
1169
+ return np.mean(a, axis=0)
1170
+
1171
+ # ASF ATT(g,t) for treated units in each cell.
1172
+ # eta_base = X_full @ beta already includes the treatment effect (D_{g,t}=1).
1173
+ # Counterfactual: eta_0 = eta_base - delta (treatment switched off).
1174
+ # ATT = E[exp(η_1)] - E[exp(η_0)] = E[exp(η_base)] - E[exp(η_base - δ)]
1175
+ gt_effects: Dict[Tuple, Dict] = {}
1176
+ gt_weights: Dict[Tuple, int] = {}
1177
+ gt_grads: Dict[Tuple, np.ndarray] = {} # per-cell gradients for aggregate SE
1178
+ for idx, (g, t) in enumerate(gt_keys):
1179
+ if idx >= n_int:
1180
+ break
1181
+ cell_mask = (sample[cohort] == g) & (sample[time] == t)
1182
+ if cell_mask.sum() == 0:
1183
+ continue
1184
+ # Skip cells whose interaction coefficient was dropped (rank deficiency).
1185
+ # Use raw coefficients (before NaN->0 zeroing) to detect dropped cells.
1186
+ if np.isnan(beta_int_raw[idx]):
1187
+ continue
1188
+ # Skip cells where all survey weights are zero (non-estimable)
1189
+ if survey_weights is not None and np.sum(survey_weights[cell_mask]) == 0:
1190
+ continue
1191
+ delta = beta_int[idx]
1192
+ if np.isnan(delta):
1193
+ continue
1194
+ eta_base = np.clip(X_full[cell_mask] @ beta, -500, 500)
1195
+ # Counterfactual: zero the FULL treatment block (W2023 Eq. 3.15)
1196
+ delta_total = np.full(cell_mask.sum(), float(delta))
1197
+ for j in range(n_cov_interact):
1198
+ coef_pos = 1 + n_int + idx * n_cov_interact + j
1199
+ if coef_pos < len(beta):
1200
+ x_hat_j = X_full[cell_mask, coef_pos]
1201
+ delta_total = delta_total + beta[coef_pos] * x_hat_j
1202
+ eta_0 = eta_base - delta_total
1203
+ mu_1 = np.exp(eta_base)
1204
+ mu_0 = np.exp(eta_0)
1205
+ att = _avg(mu_1 - mu_0, cell_mask)
1206
+ # Delta method gradient:
1207
+ # for nuisance p: mean_i[(μ_1 - μ_0) * X_p]
1208
+ # for cell intercept: mean_i[μ_1]
1209
+ # for cell × cov j: mean_i[μ_1 * x_hat_j]
1210
+ diff_mu = mu_1 - mu_0
1211
+ grad = _avg_ax0(X_full[cell_mask] * diff_mu[:, None], cell_mask)
1212
+ grad[1 + idx] = _avg(mu_1, cell_mask)
1213
+ for j in range(n_cov_interact):
1214
+ coef_pos = 1 + n_int + idx * n_cov_interact + j
1215
+ if coef_pos < len(beta):
1216
+ x_hat_j = X_full[cell_mask, coef_pos]
1217
+ grad[coef_pos] = _avg(mu_1 * x_hat_j, cell_mask)
1218
+ # Compute SE in reduced parameter space if rank-deficient
1219
+ if np.any(nan_mask):
1220
+ grad_r = grad[kept_beta]
1221
+ se = float(np.sqrt(max(grad_r @ vcov_reduced @ grad_r, 0.0)))
1222
+ else:
1223
+ se = float(np.sqrt(max(grad @ vcov_full @ grad, 0.0)))
1224
+ t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_inf)
1225
+ gt_effects[(g, t)] = {
1226
+ "att": att,
1227
+ "se": se,
1228
+ "t_stat": t_stat,
1229
+ "p_value": p_value,
1230
+ "conf_int": conf_int,
1231
+ }
1232
+ gt_weights[(g, t)] = int(cell_mask.sum())
1233
+ gt_grads[(g, t)] = grad[kept_beta] if np.any(nan_mask) else grad
1234
+
1235
+ gt_keys_ordered = [k for k in gt_keys if k in gt_effects]
1236
+ _vcov_se = vcov_reduced if np.any(nan_mask) else vcov_full
1237
+ # ATT-level covariance: J @ vcov @ J' where J rows are per-cell gradients
1238
+ if gt_keys_ordered:
1239
+ J = np.array([gt_grads[k] for k in gt_keys_ordered])
1240
+ gt_vcov = J @ _vcov_se @ J.T
1241
+ else:
1242
+ gt_vcov = None
1243
+
1244
+ # Overall SE via joint delta method
1245
+ post_keys = [(g, t) for (g, t) in gt_keys_ordered if t >= g]
1246
+ w_total = sum(gt_weights.get(k, 0) for k in post_keys)
1247
+ if w_total > 0 and post_keys:
1248
+ overall_att = sum(gt_weights[k] * gt_effects[k]["att"] for k in post_keys) / w_total
1249
+ agg_grad = sum((gt_weights[k] / w_total) * gt_grads[k] for k in post_keys)
1250
+ overall_se = float(np.sqrt(max(agg_grad @ _vcov_se @ agg_grad, 0.0)))
1251
+ t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha, df=df_inf)
1252
+ overall = {
1253
+ "att": overall_att,
1254
+ "se": overall_se,
1255
+ "t_stat": t_stat,
1256
+ "p_value": p_value,
1257
+ "conf_int": conf_int,
1258
+ }
1259
+ else:
1260
+ overall = _compute_weighted_agg(
1261
+ gt_effects, gt_weights, gt_keys_ordered, None, self.alpha, df=df_inf
1262
+ )
1263
+
1264
+ return WooldridgeDiDResults(
1265
+ group_time_effects=gt_effects,
1266
+ overall_att=overall["att"],
1267
+ overall_se=overall["se"],
1268
+ overall_t_stat=overall["t_stat"],
1269
+ overall_p_value=overall["p_value"],
1270
+ overall_conf_int=overall["conf_int"],
1271
+ method=self.method,
1272
+ control_group=self.control_group,
1273
+ groups=groups,
1274
+ time_periods=sorted(sample[time].unique().tolist()),
1275
+ n_obs=len(sample),
1276
+ n_treated_units=int(sample[sample[cohort] > 0][unit].nunique()),
1277
+ n_control_units=self._count_control_units(sample, unit, cohort, time),
1278
+ alpha=self.alpha,
1279
+ anticipation=self.anticipation,
1280
+ survey_metadata=survey_metadata,
1281
+ _gt_weights=gt_weights,
1282
+ _gt_vcov=gt_vcov,
1283
+ _gt_keys=gt_keys_ordered,
1284
+ _df_survey=df_inf,
1285
+ )