diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,1501 @@
1
+ """
2
+ Difference-in-Differences estimators with sklearn-like API.
3
+
4
+ This module contains the core DiD estimators:
5
+ - DifferenceInDifferences: Basic 2x2 DiD estimator
6
+ - MultiPeriodDiD: Event-study style DiD with period-specific treatment effects
7
+
8
+ Additional estimators are in separate modules:
9
+ - TwoWayFixedEffects: See diff_diff.twfe
10
+ - SyntheticDiD: See diff_diff.synthetic_did
11
+
12
+ For backward compatibility, all estimators are re-exported from this module.
13
+ """
14
+
15
+ import warnings
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ from diff_diff.linalg import (
22
+ LinearRegression,
23
+ _expand_vcov_with_nan,
24
+ compute_r_squared,
25
+ compute_robust_vcov,
26
+ solve_ols,
27
+ )
28
+ from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
29
+ from diff_diff.utils import (
30
+ WildBootstrapResults,
31
+ demean_by_group,
32
+ safe_inference,
33
+ validate_binary,
34
+ wild_bootstrap_se,
35
+ )
36
+
37
+
38
class DifferenceInDifferences:
    """
    Difference-in-Differences estimator with sklearn-like interface.

    Estimates the Average Treatment effect on the Treated (ATT) using
    the canonical 2x2 DiD design or panel data with two-way fixed effects.

    Parameters
    ----------
    robust : bool, default=True
        Whether to use heteroskedasticity-robust standard errors (HC1).
    cluster : str, optional
        Column name for cluster-robust standard errors.
    alpha : float, default=0.05
        Significance level for confidence intervals.
    inference : str, default="analytical"
        Inference method: "analytical" for standard asymptotic inference,
        or "wild_bootstrap" for wild cluster bootstrap (recommended when
        number of clusters is small, <50). The bootstrap path is only taken
        when ``cluster`` is set; otherwise the analytical results are used.
    n_bootstrap : int, default=999
        Number of bootstrap replications when inference="wild_bootstrap".
    bootstrap_weights : str, default="rademacher"
        Type of bootstrap weights: "rademacher" (standard), "webb"
        (recommended for <10 clusters), or "mammen" (skewness correction).
    seed : int, optional
        Random seed for reproducibility when using bootstrap inference.
        If None (default), results will vary between runs.
    rank_deficient_action : str, default "warn"
        Action when design matrix is rank-deficient (linearly dependent columns):
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning

    Attributes
    ----------
    results_ : DiDResults
        Estimation results after calling fit().
    is_fitted_ : bool
        Whether the model has been fitted.

    Examples
    --------
    Basic usage with a DataFrame:

    >>> import pandas as pd
    >>> from diff_diff import DifferenceInDifferences
    >>>
    >>> # Create sample data
    >>> data = pd.DataFrame({
    ...     'outcome': [10, 11, 15, 18, 9, 10, 12, 13],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'post': [0, 0, 1, 1, 0, 0, 1, 1]
    ... })
    >>>
    >>> # Fit the model
    >>> did = DifferenceInDifferences()
    >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')
    >>>
    >>> # View results
    >>> print(results.att)  # ATT estimate
    >>> results.print_summary()  # Full summary table

    Using formula interface (the R-style formula is an argument of ``fit()``,
    not of the constructor):

    >>> did = DifferenceInDifferences()
    >>> results = did.fit(data, formula='outcome ~ treated * post')

    Notes
    -----
    The ATT is computed using the standard DiD formula:

        ATT = (E[Y|D=1,T=1] - E[Y|D=1,T=0]) - (E[Y|D=0,T=1] - E[Y|D=0,T=0])

    Or equivalently via OLS regression:

        Y = α + β₁*D + β₂*T + β₃*(D×T) + ε

    Where β₃ is the ATT.
    """

    def __init__(
        self,
        robust: bool = True,
        cluster: Optional[str] = None,
        alpha: float = 0.05,
        inference: str = "analytical",
        n_bootstrap: int = 999,
        bootstrap_weights: str = "rademacher",
        seed: Optional[int] = None,
        rank_deficient_action: str = "warn",
    ):
        # Hyper-parameters are stored verbatim (sklearn convention) so that
        # get_params()/set_params() round-trip them unchanged.
        self.robust = robust
        self.cluster = cluster
        self.alpha = alpha
        self.inference = inference
        self.n_bootstrap = n_bootstrap
        self.bootstrap_weights = bootstrap_weights
        self.seed = seed
        self.rank_deficient_action = rank_deficient_action

        # Fitted state, populated by fit().
        self.is_fitted_ = False
        self.results_ = None
        self._coefficients = None
        self._vcov = None
        self._bootstrap_results = None  # Store WildBootstrapResults if used

    def fit(
        self,
        data: pd.DataFrame,
        outcome: Optional[str] = None,
        treatment: Optional[str] = None,
        time: Optional[str] = None,
        formula: Optional[str] = None,
        covariates: Optional[List[str]] = None,
        fixed_effects: Optional[List[str]] = None,
        absorb: Optional[List[str]] = None,
        survey_design=None,
    ) -> "DiDResults":
        """
        Fit the Difference-in-Differences model.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing the outcome, treatment, and time variables.
        outcome : str
            Name of the outcome variable column.
        treatment : str
            Name of the treatment group indicator column (0/1).
        time : str
            Name of the post-treatment period indicator column (0/1).
        formula : str, optional
            R-style formula (e.g., "outcome ~ treated * post").
            If provided, overrides the outcome, treatment, time AND
            covariates arguments: only covariates named in the formula
            are used.
        covariates : list, optional
            List of covariate column names to include as linear controls.
        fixed_effects : list, optional
            List of categorical column names to include as fixed effects.
            Creates dummy variables for each category (drops first level).
            Use for low-dimensional fixed effects (e.g., industry, region).
            Cannot be combined with ``absorb``.
        absorb : list, optional
            List of categorical column names for high-dimensional fixed effects.
            Uses within-transformation (demeaning) instead of dummy variables.
            More efficient for large numbers of categories (e.g., firm, individual).
            Cannot be combined with ``fixed_effects``; more than one absorbed
            variable is rejected when survey weights are in use.
        survey_design : SurveyDesign, optional
            Survey design specification for design-based inference. When provided,
            uses Taylor Series Linearization for variance estimation and
            applies sampling weights to the regression.

        Returns
        -------
        DiDResults
            Object containing estimation results.

        Raises
        ------
        ValueError
            If required parameters are missing or data validation fails.

        Examples
        --------
        Using fixed effects (dummy variables):

        >>> did.fit(data, outcome='sales', treatment='treated', time='post',
        ...         fixed_effects=['state', 'industry'])

        Using absorbed fixed effects (within-transformation):

        >>> did.fit(data, outcome='sales', treatment='treated', time='post',
        ...         absorb=['firm_id'])
        """
        # Forget bootstrap output from any previous fit. Without this, a
        # refit that takes the analytical or replicate path would still be
        # reported as "wild_bootstrap" with stale replication/cluster counts.
        self._bootstrap_results = None

        # Parse formula if provided
        if formula is not None:
            outcome, treatment, time, covariates = self._parse_formula(formula, data)
        elif outcome is None or treatment is None or time is None:
            raise ValueError(
                "Must provide either 'formula' or all of 'outcome', 'treatment', and 'time'"
            )

        # Validate inputs
        self._validate_data(data, outcome, treatment, time, covariates)

        # Validate binary variables BEFORE any transformations
        validate_binary(data[treatment].values, "treatment")
        validate_binary(data[time].values, "time")

        # Validate fixed effects and absorb columns
        if fixed_effects:
            for fe in fixed_effects:
                if fe not in data.columns:
                    raise ValueError(f"Fixed effect column '{fe}' not found in data")
        if absorb:
            for ab in absorb:
                if ab not in data.columns:
                    raise ValueError(f"Absorb column '{ab}' not found in data")

        # Resolve survey design if provided
        from diff_diff.survey import _resolve_effective_cluster, _resolve_survey_for_fit

        resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
            _resolve_survey_for_fit(survey_design, data, self.inference)
        )
        _uses_replicate = (
            resolved_survey is not None and resolved_survey.uses_replicate_variance
        )
        if _uses_replicate and self.inference == "wild_bootstrap":
            raise ValueError(
                "Cannot use inference='wild_bootstrap' with replicate-weight "
                "survey designs. Replicate weights provide their own variance "
                "estimation."
            )

        # Handle absorbed fixed effects (within-transformation)
        working_data = data.copy()
        absorbed_vars = []
        n_absorbed_effects = 0

        # Save raw treatment counts before absorb demeaning
        n_treated_raw = int(np.sum(data[treatment].values.astype(float)))
        n_control_raw = len(data) - n_treated_raw

        # Reject multi-absorb with survey weights (single-pass demeaning is
        # not the correct weighted FWL projection for N > 1 dimensions)
        if absorb and len(absorb) > 1 and survey_weights is not None:
            raise ValueError(
                f"Multiple absorbed fixed effects (absorb={absorb}) with survey "
                "weights is not supported. Single-pass sequential demeaning is not "
                "the correct weighted FWL projection for multiple absorbed dimensions. "
                "Use absorb with a single variable, or use fixed_effects= instead."
            )

        if absorb and fixed_effects:
            raise ValueError(
                "Cannot use both absorb and fixed_effects. "
                "The absorb within-transformation does not residualize "
                "fixed_effects dummies, violating the FWL theorem. "
                "Use absorb alone (for high-dimensional FE) "
                "or fixed_effects alone (for low-dimensional FE)."
            )

        if absorb:
            # FWL theorem: demean ALL regressors alongside outcome.
            # Regressors collinear with absorbed FE (e.g., treatment after
            # absorbing unit FE) will zero out and be handled by rank-deficiency.
            working_data["_treat_time"] = working_data[treatment].values.astype(
                float
            ) * working_data[time].values.astype(float)
            vars_to_demean = [outcome, treatment, time, "_treat_time"] + (covariates or [])
            for ab_var in absorb:
                working_data, n_fe = demean_by_group(
                    working_data,
                    vars_to_demean,
                    ab_var,
                    inplace=True,
                    weights=survey_weights,
                )
                n_absorbed_effects += n_fe
                absorbed_vars.append(ab_var)

        # Extract variables (may be demeaned if absorb was used)
        y = working_data[outcome].values.astype(float)
        d = working_data[treatment].values.astype(float)
        t = working_data[time].values.astype(float)

        # Create interaction term. With absorb the interaction was built from
        # the raw indicators and demeaned as its own column, so it must not be
        # recomputed from the demeaned d and t.
        if absorb:
            dt = working_data["_treat_time"].values.astype(float)
        else:
            dt = d * t

        # Build design matrix: collect every column first and stack once
        # instead of re-copying the whole matrix for each added column.
        design_columns = [np.ones(len(y)), d, t, dt]
        var_names = ["const", treatment, time, f"{treatment}:{time}"]

        # Add covariates if provided
        if covariates:
            for cov in covariates:
                design_columns.append(working_data[cov].values.astype(float))
                var_names.append(cov)

        # Add fixed effects as dummy variables
        if fixed_effects:
            for fe in fixed_effects:
                # Create dummies, drop first category to avoid multicollinearity.
                # absorb and fixed_effects are mutually exclusive (checked
                # above), so working_data holds the raw FE column here.
                dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
                for col in dummies.columns:
                    design_columns.append(dummies[col].values.astype(float))
                    var_names.append(col)

        X = np.column_stack(design_columns)

        # Extract ATT index (coefficient on interaction term)
        att_idx = 3  # Index of interaction term
        att_var_name = f"{treatment}:{time}"
        assert var_names[att_idx] == att_var_name, (
            f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, "
            f"but found '{var_names[att_idx]}'"
        )

        # Always use LinearRegression for initial fit (unified code path)
        # For wild bootstrap, we don't need cluster SEs from the initial fit
        cluster_ids = data[self.cluster].values if self.cluster is not None else None

        # When survey PSU is present, it overrides cluster for variance estimation
        effective_cluster_ids = _resolve_effective_cluster(
            resolved_survey, cluster_ids, self.cluster
        )

        # Inject cluster as effective PSU for survey variance estimation
        if resolved_survey is not None and effective_cluster_ids is not None:
            from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata

            resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
            if resolved_survey.psu is not None and survey_metadata is not None:
                raw_w = (
                    data[survey_design.weights].values.astype(np.float64)
                    if survey_design.weights
                    else np.ones(len(data), dtype=np.float64)
                )
                survey_metadata = compute_survey_metadata(resolved_survey, raw_w)

        # When absorb + replicate: pass survey_design=None to prevent
        # LinearRegression from computing replicate vcov on already-demeaned
        # data (demeaning depends on weights, so replicate refits must re-demean).
        _lr_survey = resolved_survey
        if _uses_replicate and absorbed_vars:
            _lr_survey = None

        reg = LinearRegression(
            include_intercept=False,  # Intercept already in X
            robust=self.robust,
            cluster_ids=effective_cluster_ids if self.inference != "wild_bootstrap" else None,
            alpha=self.alpha,
            rank_deficient_action=self.rank_deficient_action,
            weights=survey_weights,
            weight_type=survey_weight_type,
            survey_design=_lr_survey,
        ).fit(X, y, df_adjustment=n_absorbed_effects)

        coefficients = reg.coefficients_
        residuals = reg.residuals_
        fitted = reg.fitted_values_
        assert coefficients is not None
        att = coefficients[att_idx]

        # Get inference - replicate absorb override, bootstrap, or analytical
        if _uses_replicate and absorbed_vars:
            # Estimator-level replicate variance: re-demean + re-solve per replicate
            from diff_diff.survey import compute_replicate_refit_variance

            _absorb_list = list(absorbed_vars)  # capture for closure

            # Handle rank-deficient nuisance: refit only identified columns
            # (dropped columns are reported as NaN coefficients).
            _id_mask = ~np.isnan(coefficients)
            _id_cols = np.where(_id_mask)[0]

            def _refit_did_absorb(w_r):
                # Rebuild the demeaned design on the observations that carry
                # positive weight in this replicate, then re-solve.
                nz = w_r > 0
                wd = data[nz].copy()
                w_nz = w_r[nz]
                wd["_treat_time"] = (
                    wd[treatment].values.astype(float) * wd[time].values.astype(float)
                )
                vars_dm = [outcome, treatment, time, "_treat_time"] + (covariates or [])
                for ab_var in _absorb_list:
                    wd, _ = demean_by_group(wd, vars_dm, ab_var, inplace=True, weights=w_nz)
                y_r = wd[outcome].values.astype(float)
                cols_r = [
                    np.ones(len(y_r)),
                    wd[treatment].values.astype(float),
                    wd[time].values.astype(float),
                    wd["_treat_time"].values.astype(float),
                ]
                if covariates:
                    for cov in covariates:
                        cols_r.append(wd[cov].values.astype(float))
                X_r = np.column_stack(cols_r)
                coef_r, _, _ = solve_ols(
                    X_r[:, _id_cols],
                    y_r,
                    weights=w_nz,
                    weight_type=survey_weight_type,
                    rank_deficient_action="silent",
                    return_vcov=False,
                )
                return coef_r

            vcov_reduced, _n_valid_rep = compute_replicate_refit_variance(
                _refit_did_absorb, coefficients[_id_mask], resolved_survey
            )
            vcov = _expand_vcov_with_nan(vcov_reduced, len(coefficients), _id_cols)
            se = float(np.sqrt(max(vcov[att_idx, att_idx], 0.0)))
            _df_rep = (
                survey_metadata.df_survey
                if survey_metadata and survey_metadata.df_survey
                else 0  # rank-deficient replicate → NaN inference
            )
            if _n_valid_rep < resolved_survey.n_replicates:
                # Some replicates were unusable: base the degrees of freedom
                # on the replicates that actually contributed.
                _df_rep = _n_valid_rep - 1 if _n_valid_rep > 1 else 0
                if survey_metadata is not None:
                    survey_metadata.df_survey = _df_rep if _df_rep > 0 else None
            t_stat, p_value, conf_int = safe_inference(
                att, se, alpha=self.alpha, df=_df_rep
            )
        elif self.inference == "wild_bootstrap" and self.cluster is not None:
            # Override with wild cluster bootstrap inference
            se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
                X, y, residuals, cluster_ids, att_idx
            )
        else:
            # Use analytical inference from LinearRegression
            # (handles replicate vcov for no-absorb path automatically)
            vcov = reg.vcov_
            inference = reg.get_inference(att_idx)
            se = inference.se
            t_stat = inference.t_stat
            p_value = inference.p_value
            conf_int = inference.conf_int

        r_squared = compute_r_squared(y, residuals)

        # Count observations (use raw counts to avoid demeaned values from absorb)
        n_treated = n_treated_raw
        n_control = n_control_raw

        # Create coefficient dictionary
        coef_dict = dict(zip(var_names, coefficients))

        # Determine inference method and bootstrap info
        inference_method = "analytical"
        n_bootstrap_used = None
        n_clusters_used = None
        if self._bootstrap_results is not None:
            inference_method = "wild_bootstrap"
            n_bootstrap_used = self._bootstrap_results.n_bootstrap
            n_clusters_used = self._bootstrap_results.n_clusters

        # Store results
        self.results_ = DiDResults(
            att=att,
            se=se,
            t_stat=t_stat,
            p_value=p_value,
            conf_int=conf_int,
            n_obs=len(y),
            n_treated=n_treated,
            n_control=n_control,
            alpha=self.alpha,
            coefficients=coef_dict,
            vcov=vcov,
            residuals=residuals,
            fitted_values=fitted,
            r_squared=r_squared,
            inference_method=inference_method,
            n_bootstrap=n_bootstrap_used,
            n_clusters=n_clusters_used,
            survey_metadata=survey_metadata,
        )

        self._coefficients = coefficients
        self._vcov = vcov
        self.is_fitted_ = True

        return self.results_

    def _fit_ols(
        self, X: np.ndarray, y: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
        """
        Fit OLS regression.

        This method is kept for backwards compatibility. Internally uses the
        unified solve_ols from diff_diff.linalg for optimized computation.

        Parameters
        ----------
        X : np.ndarray
            Design matrix.
        y : np.ndarray
            Outcome vector.

        Returns
        -------
        tuple
            (coefficients, residuals, fitted_values, r_squared)
        """
        # Use unified OLS backend
        coefficients, residuals, fitted, _ = solve_ols(X, y, return_fitted=True, return_vcov=False)
        r_squared = compute_r_squared(y, residuals)

        return coefficients, residuals, fitted, r_squared

    def _run_wild_bootstrap_inference(
        self,
        X: np.ndarray,
        y: np.ndarray,
        residuals: np.ndarray,
        cluster_ids: np.ndarray,
        coefficient_index: int,
    ) -> Tuple[float, float, Tuple[float, float], float, np.ndarray, "WildBootstrapResults"]:
        """
        Run wild cluster bootstrap inference.

        Side effect: the bootstrap result object is stored on
        ``self._bootstrap_results`` so fit() can report the inference method.

        Parameters
        ----------
        X : np.ndarray
            Design matrix.
        y : np.ndarray
            Outcome vector.
        residuals : np.ndarray
            OLS residuals.
        cluster_ids : np.ndarray
            Cluster identifiers for each observation.
        coefficient_index : int
            Index of the coefficient to compute inference for.

        Returns
        -------
        tuple
            (se, p_value, conf_int, t_stat, vcov, bootstrap_results)
        """
        bootstrap_results = wild_bootstrap_se(
            X,
            y,
            residuals,
            cluster_ids,
            coefficient_index=coefficient_index,
            n_bootstrap=self.n_bootstrap,
            weight_type=self.bootstrap_weights,
            alpha=self.alpha,
            seed=self.seed,
            return_distribution=False,
        )
        self._bootstrap_results = bootstrap_results

        se = bootstrap_results.se
        p_value = bootstrap_results.p_value
        conf_int = (bootstrap_results.ci_lower, bootstrap_results.ci_upper)
        t_stat = bootstrap_results.t_stat_original

        # Also compute vcov for storage (using cluster-robust for consistency)
        vcov = compute_robust_vcov(X, residuals, cluster_ids)

        return se, p_value, conf_int, t_stat, vcov, bootstrap_results

    def _parse_formula(
        self, formula: str, data: pd.DataFrame
    ) -> Tuple[str, str, str, Optional[List[str]]]:
        """
        Parse R-style formula.

        Supports basic formulas like:
        - "outcome ~ treatment * time"
        - "outcome ~ treatment + time + treatment:time"
        - "outcome ~ treatment * time + covariate1 + covariate2"
        - "outcome ~ covariate1 + treatment * time" (covariates may appear
          before or after the interaction)

        Exactly one two-way interaction term is required. Terms that merely
        repeat the treatment or time main effect are not treated as
        covariates.

        Parameters
        ----------
        formula : str
            R-style formula string.
        data : pd.DataFrame
            DataFrame to validate column names against.

        Returns
        -------
        tuple
            (outcome, treatment, time, covariates); covariates is None when
            the formula names none.

        Raises
        ------
        ValueError
            If the formula is malformed, lacks a single two-way interaction,
            or references a column that is not in ``data``.
        """
        # Split into LHS and RHS
        if "~" not in formula:
            raise ValueError("Formula must contain '~' to separate outcome from predictors")
        if formula.count("~") > 1:
            raise ValueError("Formula must contain exactly one '~'")

        lhs, rhs = formula.split("~")
        outcome = lhs.strip()

        # Additive terms of the right-hand side, in the order written
        terms = [term.strip() for term in rhs.split("+")]
        interaction_terms = [term for term in terms if "*" in term or ":" in term]

        if not interaction_terms:
            raise ValueError(
                "Formula must contain interaction term. "
                "Use 'outcome ~ treatment * time' or 'outcome ~ treatment + time + treatment:time'"
            )
        if len(interaction_terms) > 1:
            raise ValueError("Currently only supports single interaction (treatment * time)")

        interaction = interaction_terms[0]
        # "a * b" implies the main effects; "a:b" is the bare interaction.
        # The model always includes both main effects, so the two are
        # parsed identically here.
        separator = "*" if "*" in interaction else ":"
        factors = [factor.strip() for factor in interaction.split(separator)]
        if len(factors) != 2 or not all(factors):
            raise ValueError("Currently only supports single interaction (treatment * time)")
        treatment, time = factors

        # Everything that is neither the interaction nor a restated main
        # effect is a covariate.
        extra_terms = [
            term
            for term in terms
            if term != interaction and term != treatment and term != time
        ]
        covariates = extra_terms if extra_terms else None

        # Validate columns exist
        for col in [outcome, treatment, time]:
            if col not in data.columns:
                raise ValueError(f"Column '{col}' not found in data")

        if covariates:
            for cov in covariates:
                if cov not in data.columns:
                    raise ValueError(f"Covariate '{cov}' not found in data")

        return outcome, treatment, time, covariates

    def _validate_data(
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        time: str,
        covariates: Optional[List[str]] = None,
    ) -> None:
        """
        Validate input data.

        Raises
        ------
        TypeError
            If ``data`` is not a pandas DataFrame.
        ValueError
            If a required column is absent, contains missing values, or the
            treatment/time column has fewer than two distinct values.
        """
        # Check DataFrame
        if not isinstance(data, pd.DataFrame):
            raise TypeError("data must be a pandas DataFrame")

        # Check required columns exist
        required_cols = [outcome, treatment, time]
        if covariates:
            required_cols.extend(covariates)

        missing_cols = [col for col in required_cols if col not in data.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in data: {missing_cols}")

        # Check for missing values
        for col in required_cols:
            if data[col].isna().any():
                raise ValueError(f"Column '{col}' contains missing values")

        # Check for sufficient variation
        if data[treatment].nunique() < 2:
            raise ValueError("Treatment variable must have both 0 and 1 values")
        if data[time].nunique() < 2:
            raise ValueError("Time variable must have both 0 and 1 values")

    def predict(self, data: pd.DataFrame) -> np.ndarray:
        """
        Predict outcomes using fitted model.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame with same structure as training data.

        Returns
        -------
        np.ndarray
            Predicted values.

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        NotImplementedError
            Always, once fitted: prediction is not implemented yet.
        """
        if not self.is_fitted_:
            raise RuntimeError("Model must be fitted before calling predict()")

        # This is a placeholder - would need to store column names
        # for full implementation
        raise NotImplementedError(
            "predict() is not yet implemented. "
            "Use results_.fitted_values for training data predictions."
        )

    def get_params(self) -> Dict[str, Any]:
        """
        Get estimator parameters (sklearn-compatible).

        Returns
        -------
        Dict[str, Any]
            Estimator parameters.
        """
        return {
            "robust": self.robust,
            "cluster": self.cluster,
            "alpha": self.alpha,
            "inference": self.inference,
            "n_bootstrap": self.n_bootstrap,
            "bootstrap_weights": self.bootstrap_weights,
            "seed": self.seed,
            "rank_deficient_action": self.rank_deficient_action,
        }

    def set_params(self, **params) -> "DifferenceInDifferences":
        """
        Set estimator parameters (sklearn-compatible).

        Only hyper-parameters may be set. Methods, private attributes and
        fitted state (names ending in ``_``) are rejected. All names are
        checked before anything is assigned, so a rejected call leaves the
        estimator unchanged.

        Parameters
        ----------
        **params
            Estimator parameters.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a name is not an estimator parameter.
        """
        settable = set(self.get_params())
        # Also accept plain hyper-parameter attributes that a subclass may
        # define in its own __init__ without extending get_params().
        settable.update(
            name
            for name in vars(self)
            if not name.startswith("_") and not name.endswith("_")
        )

        for key in params:
            if key not in settable:
                raise ValueError(f"Unknown parameter: {key}")

        for key, value in params.items():
            setattr(self, key, value)
        return self

    def summary(self) -> str:
        """
        Get summary of estimation results.

        Returns
        -------
        str
            Formatted summary.

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        """
        if not self.is_fitted_:
            raise RuntimeError("Model must be fitted before calling summary()")
        assert self.results_ is not None
        return self.results_.summary()

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())
+
789
+
790
class MultiPeriodDiD(DifferenceInDifferences):
    """
    Multi-Period Difference-in-Differences estimator.

    Extends the standard DiD to handle multiple pre-treatment and
    post-treatment time periods, providing period-specific treatment
    effects as well as an aggregate average treatment effect.

    Parameters
    ----------
    robust : bool, default=True
        Whether to use heteroskedasticity-robust standard errors (HC1).
    cluster : str, optional
        Column name for cluster-robust standard errors.
    alpha : float, default=0.05
        Significance level for confidence intervals.

    Attributes
    ----------
    results_ : MultiPeriodDiDResults
        Estimation results after calling fit().
    is_fitted_ : bool
        Whether the model has been fitted.

    Examples
    --------
    Basic usage with multiple time periods:

    >>> import pandas as pd
    >>> from diff_diff import MultiPeriodDiD
    >>>
    >>> # Create sample panel data with 6 time periods
    >>> # Periods 0-2 are pre-treatment, periods 3-5 are post-treatment
    >>> data = create_panel_data()  # Your data
    >>>
    >>> # Fit the model
    >>> did = MultiPeriodDiD()
    >>> results = did.fit(
    ...     data,
    ...     outcome='sales',
    ...     treatment='treated',
    ...     time='period',
    ...     post_periods=[3, 4, 5]  # Specify which periods are post-treatment
    ... )
    >>>
    >>> # View period-specific effects
    >>> for period, effect in results.period_effects.items():
    ...     print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})")
    >>>
    >>> # View average treatment effect
    >>> print(f"Average ATT: {results.avg_att:.3f}")

    Notes
    -----
    The model estimates:

        Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_{t≠ref} δ_t*(D_i × 1{t}) + ε_it

    Where:
    - D_i is the treatment indicator
    - Period_t are time period dummies (all non-reference periods)
    - D_i × 1{t} are treatment-by-period interactions (all non-reference)
    - δ_t are the period-specific treatment effects
    - The reference period (default: last pre-period) has δ_ref = 0 by construction

    Pre-treatment δ_t test the parallel trends assumption (should be ≈ 0).
    Post-treatment δ_t estimate dynamic treatment effects.
    The average ATT is computed from post-treatment δ_t only.
    """

    def fit(  # type: ignore[override]
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        time: str,
        post_periods: Optional[List[Any]] = None,
        covariates: Optional[List[str]] = None,
        fixed_effects: Optional[List[str]] = None,
        absorb: Optional[List[str]] = None,
        reference_period: Any = None,
        unit: Optional[str] = None,
        survey_design=None,
    ) -> MultiPeriodDiDResults:
        """
        Fit the Multi-Period Difference-in-Differences model.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing the outcome, treatment, and time variables.
        outcome : str
            Name of the outcome variable column.
        treatment : str
            Name of the treatment group indicator column (0/1). Should be a
            time-invariant ever-treated indicator (D_i = 1 for all periods of
            treated units). If treatment is time-varying (D_it), pre-period
            interaction coefficients will be unidentified.
        time : str
            Name of the time period column (can have multiple values).
        post_periods : list, optional
            List of time period values that are post-treatment.
            All other periods are treated as pre-treatment. When omitted,
            the sorted unique periods are split at ``len(periods) // 2`` and
            the second part is used as the post-treatment periods.
        covariates : list, optional
            List of covariate column names to include as linear controls.
        fixed_effects : list, optional
            List of categorical column names to include as fixed effects.
            Cannot be combined with ``absorb``.
        absorb : list, optional
            List of categorical column names for high-dimensional fixed effects.
            More than one absorbed variable is rejected when survey weights
            are in use.
        reference_period : any, optional
            The reference (omitted) time period for the period dummies.
            Defaults to the last pre-treatment period (e=-1 convention).
            Must be a pre-treatment period.
        unit : str, optional
            Name of the unit identifier column. When provided, checks whether
            treatment timing varies across units and warns if staggered adoption
            is detected (suggests CallawaySantAnna instead). Does NOT affect
            standard error computation -- use the ``cluster`` parameter for
            cluster-robust SEs.
        survey_design : SurveyDesign, optional
            Survey design specification for design-based inference. When provided,
            uses Taylor Series Linearization for variance estimation and
            applies sampling weights to the regression.

        Returns
        -------
        MultiPeriodDiDResults
            Object containing period-specific and average treatment effects.

        Raises
        ------
        ValueError
            If required parameters are missing or data validation fails.

        Warns
        -----
        UserWarning
            When wild bootstrap is requested (analytical inference is used
            instead), when treatment reversal, staggered adoption or a
            time-varying treatment column is detected via ``unit``, when only
            one pre-treatment period exists, or when the degrees of freedom
            are non-positive.
        FutureWarning
            When ``reference_period`` is omitted and more than one
            pre-treatment period exists.
        """
        # Fall back to analytical inference if wild bootstrap requested
        # (must happen before _resolve_survey_for_fit which rejects bootstrap+survey)
        effective_inference = self.inference
        if self.inference == "wild_bootstrap":
            warnings.warn(
                "Wild bootstrap inference is not yet supported for MultiPeriodDiD. "
                "Using analytical inference instead.",
                UserWarning,
                stacklevel=2,
            )
            effective_inference = "analytical"

        # Validate basic inputs
        if outcome is None or treatment is None or time is None:
            raise ValueError("Must provide 'outcome', 'treatment', and 'time'")

        # Validate columns exist
        self._validate_data(data, outcome, treatment, time, covariates)

        # Validate treatment is binary
        validate_binary(data[treatment].values, "treatment")

        # Validate unit column and check for staggered adoption
        if unit is not None:
            if unit not in data.columns:
                raise ValueError(f"Unit column '{unit}' not found in data")

            # Check for staggered treatment timing and absorbing treatment
            unit_time_sorted = data.sort_values([unit, time])
            adoption_times = {}
            has_reversal = False
            for u, group in unit_time_sorted.groupby(unit):
                d_vals = group[treatment].values
                # Check for treatment reversal (non-absorbing treatment);
                # only the first offending unit triggers the warning.
                if not has_reversal and len(d_vals) > 1 and np.any(np.diff(d_vals) < 0):
                    warnings.warn(
                        f"Treatment reversal detected (unit '{u}' transitions from "
                        f"treated to untreated). MultiPeriodDiD assumes treatment is "
                        f"an absorbing state (once treated, always treated). "
                        f"Treatment reversals violate this assumption and may "
                        f"produce unreliable estimates.",
                        UserWarning,
                        stacklevel=2,
                    )
                    has_reversal = True
                # Only use units with observed 0→1 transition for adoption timing
                # (skip units that are always treated — can't determine adoption time)
                if 0 in d_vals and 1 in d_vals:
                    adoption_times[u] = group.loc[group[treatment] == 1, time].iloc[0]

            if len(adoption_times) > 0:
                unique_adoption = len(set(adoption_times.values()))
                if unique_adoption > 1:
                    warnings.warn(
                        "Treatment timing varies across units (staggered adoption "
                        "detected). MultiPeriodDiD assumes simultaneous adoption "
                        "and may produce biased estimates with staggered treatment. "
                        "Consider using CallawaySantAnna or SunAbraham instead.",
                        UserWarning,
                        stacklevel=2,
                    )

                # Check for time-varying treatment (D_it instead of D_i)
                # If any unit has a 0→1 transition, the treatment column is D_it.
                # MultiPeriodDiD expects a time-invariant ever-treated indicator.
                warnings.warn(
                    "Treatment indicator varies within units (time-varying "
                    "treatment detected). MultiPeriodDiD's event-study "
                    "specification expects a time-invariant ever-treated "
                    "indicator (D_i = 1 for all periods of eventually-treated "
                    "units). With time-varying treatment, pre-period "
                    "interaction coefficients will be unidentified. Consider: "
                    f"df['ever_treated'] = df.groupby('{unit}')['{treatment}']"
                    ".transform('max')",
                    UserWarning,
                    stacklevel=2,
                )

        # Get all unique time periods
        all_periods = sorted(data[time].unique())

        if len(all_periods) < 2:
            raise ValueError("Time variable must have at least 2 unique periods")

        # Determine pre and post periods
        if post_periods is None:
            # Default: last half of periods are post-treatment
            mid_point = len(all_periods) // 2
            post_periods = all_periods[mid_point:]
            pre_periods = all_periods[:mid_point]
        else:
            post_periods = list(post_periods)
            pre_periods = [p for p in all_periods if p not in post_periods]

        if len(post_periods) == 0:
            raise ValueError("Must have at least one post-treatment period")

        if len(pre_periods) == 0:
            raise ValueError("Must have at least one pre-treatment period")

        if len(pre_periods) < 2:
            warnings.warn(
                "Only one pre-treatment period available. At least 2 pre-periods "
                "are needed to assess parallel trends. The treatment effect estimate "
                "is still valid, but pre-period coefficients for parallel trends "
                "testing are not available.",
                UserWarning,
                stacklevel=2,
            )

        # Validate post_periods are in the data
        for p in post_periods:
            if p not in all_periods:
                raise ValueError(f"Post-period '{p}' not found in time column")

        # Determine reference period (omitted dummy)
        if reference_period is None:
            # Default: last pre-period (e=-1 convention, matches fixest)
            if len(pre_periods) > 1:
                warnings.warn(
                    f"The default reference_period has changed from the first "
                    f"pre-period ({pre_periods[0]}) to the last pre-period "
                    f"({pre_periods[-1]}) to match the standard e=-1 convention "
                    f"(as used by fixest, did, etc.). "
                    f"To silence this warning, pass "
                    f"reference_period={pre_periods[-1]} explicitly.",
                    FutureWarning,
                    stacklevel=2,
                )
            reference_period = pre_periods[-1]
        elif reference_period not in all_periods:
            raise ValueError(f"Reference period '{reference_period}' not found in time column")

        # Disallow post-period reference (downstream logic assumes reference is pre-period)
        if reference_period in post_periods:
            raise ValueError(
                f"reference_period={reference_period} is a post-treatment period. "
                f"The reference period must be a pre-treatment period "
                f"(e.g., the last pre-period {pre_periods[-1]}). "
                f"Post-period references are not supported because the reference "
                f"period is excluded from estimation, which would bias avg_att "
                f"and break downstream inference."
            )

        # Validate fixed effects and absorb columns
        if fixed_effects:
            for fe in fixed_effects:
                if fe not in data.columns:
                    raise ValueError(f"Fixed effect column '{fe}' not found in data")
        if absorb:
            for ab in absorb:
                if ab not in data.columns:
                    raise ValueError(f"Absorb column '{ab}' not found in data")

        # Resolve survey design if provided
        from diff_diff.survey import _resolve_effective_cluster, _resolve_survey_for_fit

        resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
            _resolve_survey_for_fit(survey_design, data, effective_inference)
        )
        _uses_replicate_mp = (
            resolved_survey is not None and resolved_survey.uses_replicate_variance
        )
        # NOTE(review): effective_inference has already been downgraded from
        # "wild_bootstrap" to "analytical" at the top of this method, so this
        # guard cannot fire here; it is kept as a defensive check.
        if _uses_replicate_mp and effective_inference == "wild_bootstrap":
            raise ValueError(
                "Cannot use inference='wild_bootstrap' with replicate-weight "
                "survey designs. Replicate weights provide their own variance "
                "estimation."
            )

        # Handle absorbed fixed effects (within-transformation)
        working_data = data.copy()
        n_absorbed_effects = 0

        # Save raw treatment counts before absorb demeaning
        n_treated_raw = int(np.sum(data[treatment].values.astype(float)))
        n_control_raw = len(data) - n_treated_raw

        # Reject multi-absorb with survey weights (single-pass demeaning is
        # not the correct weighted FWL projection for N > 1 dimensions)
        if absorb and len(absorb) > 1 and survey_weights is not None:
            raise ValueError(
                f"Multiple absorbed fixed effects (absorb={absorb}) with survey "
                "weights is not supported. Single-pass sequential demeaning is not "
                "the correct weighted FWL projection for multiple absorbed dimensions. "
                "Use absorb with a single variable, or use fixed_effects= instead."
            )

        if absorb and fixed_effects:
            raise ValueError(
                "Cannot use both absorb and fixed_effects. "
                "The absorb within-transformation does not residualize "
                "fixed_effects dummies, violating the FWL theorem. "
                "Use absorb alone (for high-dimensional FE) "
                "or fixed_effects alone (for low-dimensional FE)."
            )

        # Pre-compute non_ref_periods (needed for absorb demeaning)
        non_ref_periods = [p for p in all_periods if p != reference_period]

        if absorb:
            # FWL theorem: demean ALL regressors alongside outcome.
            # Regressors collinear with absorbed FE (e.g., treatment after
            # absorbing unit FE) will zero out and be handled by rank-deficiency.
            d_raw = working_data[treatment].values.astype(float)
            t_raw = working_data[time].values
            working_data["_did_treatment"] = d_raw
            for period in non_ref_periods:
                in_period = (t_raw == period).astype(float)
                working_data[f"_did_period_{period}"] = in_period
                working_data[f"_did_interact_{period}"] = d_raw * in_period
            vars_to_demean = (
                [outcome, "_did_treatment"]
                + [f"_did_period_{p}" for p in non_ref_periods]
                + [f"_did_interact_{p}" for p in non_ref_periods]
                + (covariates or [])
            )
            for ab_var in absorb:
                working_data, n_fe = demean_by_group(
                    working_data,
                    vars_to_demean,
                    ab_var,
                    inplace=True,
                    weights=survey_weights,
                )
                n_absorbed_effects += n_fe

        # Extract outcome and treatment (may be demeaned if absorb was used)
        y = working_data[outcome].values.astype(float)
        if absorb:
            d = working_data["_did_treatment"].values.astype(float)
        else:
            d = working_data[treatment].values.astype(float)
        t = working_data[time].values

        # Build design matrix.
        # Columns are collected in a list and stacked once at the end, which
        # avoids re-copying the whole matrix for every added regressor.
        # Start with intercept and treatment main effect
        columns = [np.ones(len(y)), d]
        var_names = ["const", treatment]

        # Add period dummies (excluding reference period)
        for period in non_ref_periods:
            if absorb:
                period_dummy = working_data[f"_did_period_{period}"].values.astype(float)
            else:
                period_dummy = (t == period).astype(float)
            columns.append(period_dummy)
            var_names.append(f"period_{period}")

        # Add treatment × period interactions for ALL non-reference periods
        # Pre-period interactions test parallel trends; post-period interactions
        # estimate dynamic treatment effects
        interaction_indices = {}  # Map period -> column index in X

        for period in non_ref_periods:
            if absorb:
                interaction = working_data[f"_did_interact_{period}"].values.astype(float)
            else:
                interaction = d * (t == period).astype(float)
            columns.append(interaction)
            var_names.append(f"{treatment}:period_{period}")
            interaction_indices[period] = len(columns) - 1

        # Add covariates if provided
        if covariates:
            for cov in covariates:
                columns.append(working_data[cov].values.astype(float))
                var_names.append(cov)

        # Add fixed effects as dummy variables
        if fixed_effects:
            for fe in fixed_effects:
                dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
                for col in dummies.columns:
                    columns.append(dummies[col].values.astype(float))
                    var_names.append(col)

        X = np.column_stack(columns)

        # Fit OLS using unified backend
        # Pass cluster_ids to solve_ols for proper vcov computation
        # This handles rank-deficient matrices by returning NaN for dropped columns
        cluster_ids = data[self.cluster].values if self.cluster is not None else None

        # When survey PSU is present, it overrides cluster for variance estimation
        effective_cluster_ids = _resolve_effective_cluster(
            resolved_survey, cluster_ids, self.cluster
        )

        # Inject cluster as effective PSU for survey variance estimation
        if resolved_survey is not None and effective_cluster_ids is not None:
            from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata

            resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
            if resolved_survey.psu is not None and survey_metadata is not None:
                raw_w = (
                    data[survey_design.weights].values.astype(np.float64)
                    if survey_design.weights
                    else np.ones(len(data), dtype=np.float64)
                )
                survey_metadata = compute_survey_metadata(resolved_survey, raw_w)

        # Determine if survey vcov should be used
        _use_survey_vcov = resolved_survey is not None and resolved_survey.needs_survey_vcov

        # Note: Wild bootstrap for multi-period effects is complex (multiple coefficients)
        # For now, we use analytical inference even if inference="wild_bootstrap"
        coefficients, residuals, fitted, vcov = solve_ols(
            X,
            y,
            return_fitted=True,
            return_vcov=not _use_survey_vcov,
            cluster_ids=effective_cluster_ids,
            column_names=var_names,
            rank_deficient_action=self.rank_deficient_action,
            weights=survey_weights,
            weight_type=survey_weight_type,
        )

        # Compute survey vcov if applicable
        _n_valid_rep_mp = None
        if _use_survey_vcov and _uses_replicate_mp and absorb:
            # Absorb + replicate: estimator-level refit (demeaning depends on weights)
            from diff_diff.survey import compute_replicate_refit_variance

            _absorb_list_mp = list(absorb)
            # Handle rank-deficient nuisance: refit only identified columns
            _id_mask_mp = ~np.isnan(coefficients)
            _id_cols_mp = np.where(_id_mask_mp)[0]
            # Same variable list as the main fit; built once instead of per replicate.
            _vars_dm_mp = (
                [outcome, "_did_treatment"]
                + [f"_did_period_{p}" for p in non_ref_periods]
                + [f"_did_interact_{p}" for p in non_ref_periods]
                + (covariates or [])
            )

            def _refit_mp_absorb(w_r):
                """Re-demean and re-estimate on rows with positive replicate weight."""
                nz = w_r > 0
                wd = data[nz].copy()
                w_nz = w_r[nz]
                d_raw_ = wd[treatment].values.astype(float)
                t_raw_ = wd[time].values
                wd["_did_treatment"] = d_raw_
                for period_ in non_ref_periods:
                    in_period_ = (t_raw_ == period_).astype(float)
                    wd[f"_did_period_{period_}"] = in_period_
                    wd[f"_did_interact_{period_}"] = d_raw_ * in_period_
                for ab_var_ in _absorb_list_mp:
                    wd, _ = demean_by_group(wd, _vars_dm_mp, ab_var_, inplace=True, weights=w_nz)
                y_r = wd[outcome].values.astype(float)
                d_r = wd["_did_treatment"].values.astype(float)
                # Column order must mirror X: const, treatment, period dummies,
                # interactions, covariates.
                cols_r = [np.ones(len(y_r)), d_r]
                cols_r.extend(
                    wd[f"_did_period_{period_}"].values.astype(float)
                    for period_ in non_ref_periods
                )
                cols_r.extend(
                    wd[f"_did_interact_{period_}"].values.astype(float)
                    for period_ in non_ref_periods
                )
                if covariates:
                    cols_r.extend(wd[cov_].values.astype(float) for cov_ in covariates)
                X_r = np.column_stack(cols_r)
                coef_r, _, _ = solve_ols(
                    X_r[:, _id_cols_mp],
                    y_r,
                    weights=w_nz,
                    weight_type=survey_weight_type,
                    rank_deficient_action="silent",
                    return_vcov=False,
                )
                return coef_r

            vcov_reduced_mp, _n_valid_rep_mp = compute_replicate_refit_variance(
                _refit_mp_absorb, coefficients[_id_mask_mp], resolved_survey
            )
            vcov = _expand_vcov_with_nan(vcov_reduced_mp, len(coefficients), _id_cols_mp)
        elif _use_survey_vcov and _uses_replicate_mp:
            # No absorb + replicate: X is fixed, use compute_replicate_vcov directly
            from diff_diff.survey import compute_replicate_vcov

            nan_mask = np.isnan(coefficients)
            if np.any(nan_mask):
                kept_cols = np.where(~nan_mask)[0]
                if len(kept_cols) > 0:
                    vcov_reduced, _n_valid_rep_mp = compute_replicate_vcov(
                        X[:, kept_cols],
                        y,
                        coefficients[kept_cols],
                        resolved_survey,
                        weight_type=survey_weight_type,
                    )
                    vcov = _expand_vcov_with_nan(vcov_reduced, X.shape[1], kept_cols)
                else:
                    vcov = np.full((X.shape[1], X.shape[1]), np.nan)
                    _n_valid_rep_mp = 0
            else:
                vcov, _n_valid_rep_mp = compute_replicate_vcov(
                    X,
                    y,
                    coefficients,
                    resolved_survey,
                    weight_type=survey_weight_type,
                )
        elif _use_survey_vcov:
            from diff_diff.survey import compute_survey_vcov

            nan_mask = np.isnan(coefficients)
            if np.any(nan_mask):
                kept_cols = np.where(~nan_mask)[0]
                if len(kept_cols) > 0:
                    vcov_reduced = compute_survey_vcov(X[:, kept_cols], residuals, resolved_survey)
                    vcov = _expand_vcov_with_nan(vcov_reduced, X.shape[1], kept_cols)
                else:
                    vcov = np.full((X.shape[1], X.shape[1]), np.nan)
            else:
                vcov = compute_survey_vcov(X, residuals, resolved_survey)
        r_squared = compute_r_squared(y, residuals)

        # Degrees of freedom: survey df overrides standard df
        k_effective = int(np.sum(~np.isnan(coefficients)))
        # For fweights, df uses sum(w) - k (effective sample size)
        n_eff_df = len(y)
        if survey_weights is not None and survey_weight_type == "fweight":
            n_eff_df = int(round(np.sum(survey_weights)))
        df = n_eff_df - k_effective - n_absorbed_effects
        if resolved_survey is not None and resolved_survey.df_survey is not None:
            df = resolved_survey.df_survey
        # Replicate df: rank-deficient → NaN inference; dropped replicates → n_valid-1
        if _uses_replicate_mp:
            if resolved_survey.df_survey is None:
                df = 0  # rank-deficient replicate → NaN inference
            if _n_valid_rep_mp is not None and _n_valid_rep_mp < resolved_survey.n_replicates:
                df = _n_valid_rep_mp - 1 if _n_valid_rep_mp > 1 else 0
                if survey_metadata is not None:
                    survey_metadata.df_survey = df if df > 0 else None

        # Guard: fall back to normal distribution if df is non-positive
        # Skip for replicate designs — df=0 is intentional for NaN inference
        if df is not None and df <= 0 and not _uses_replicate_mp:
            warnings.warn(
                f"Degrees of freedom is non-positive (df={df}). "
                "Using normal distribution instead of t-distribution for inference.",
                UserWarning,
                stacklevel=2,
            )
            df = None

        # For non-robust, non-clustered case, we need homoskedastic vcov
        # solve_ols returns HC1 by default, so compute homoskedastic if needed
        if not self.robust and self.cluster is None and survey_weights is None:
            n = len(y)
            # NOTE(review): this denominator does not subtract
            # n_absorbed_effects, unlike `df` above — confirm this is intended.
            mse = np.sum(residuals**2) / (n - k_effective)
            # Use solve() instead of inv() for numerical stability
            # Only compute for identified columns (non-NaN coefficients)
            identified_mask = ~np.isnan(coefficients)
            if np.all(identified_mask):
                vcov = np.linalg.solve(X.T @ X, mse * np.eye(X.shape[1]))
            else:
                # For rank-deficient case, compute vcov on reduced matrix then expand
                X_reduced = X[:, identified_mask]
                vcov_reduced = np.linalg.solve(
                    X_reduced.T @ X_reduced, mse * np.eye(X_reduced.shape[1])
                )
                # Expand to full size with NaN for dropped columns
                vcov = np.full((X.shape[1], X.shape[1]), np.nan)
                vcov[np.ix_(identified_mask, identified_mask)] = vcov_reduced

        # Extract period-specific treatment effects for ALL non-reference periods
        period_effects = {}
        post_effect_values = []
        post_effect_indices = []

        assert vcov is not None
        for period in non_ref_periods:
            idx = interaction_indices[period]
            effect = coefficients[idx]
            se = np.sqrt(vcov[idx, idx])
            t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=df)

            period_effects[period] = PeriodEffect(
                period=period,
                effect=effect,
                se=se,
                t_stat=t_stat,
                p_value=p_value,
                conf_int=conf_int,
            )

            if period in post_periods:
                post_effect_values.append(effect)
                post_effect_indices.append(idx)

        # Compute average treatment effect (post-periods only)
        # R-style NA propagation: if ANY post-period effect is NaN, average is undefined
        effect_arr = np.array(post_effect_values)

        if np.any(np.isnan(effect_arr)):
            # Some period effects are NaN (unidentified) - cannot compute valid average
            # This follows R's default behavior where mean(c(1, 2, NA)) returns NA
            avg_att = np.nan
            avg_se = np.nan
            avg_t_stat = np.nan
            avg_p_value = np.nan
            avg_conf_int = (np.nan, np.nan)
        else:
            # All effects identified - compute average normally
            avg_att = float(np.mean(effect_arr))

            # Standard error of average: need to account for covariance.
            # Count the effects that were actually averaged rather than
            # len(post_periods): a caller-supplied list with repeated entries
            # would otherwise shrink the variance of the mean.
            n_post = len(post_effect_indices)
            sub_vcov = vcov[np.ix_(post_effect_indices, post_effect_indices)]
            avg_var = np.sum(sub_vcov) / (n_post**2)

            if np.isnan(avg_var) or avg_var < 0:
                # Vcov has NaN (dropped columns) - propagate NaN
                avg_se = np.nan
                avg_t_stat = np.nan
                avg_p_value = np.nan
                avg_conf_int = (np.nan, np.nan)
            else:
                avg_se = float(np.sqrt(avg_var))
                avg_t_stat, avg_p_value, avg_conf_int = safe_inference(
                    avg_att, avg_se, alpha=self.alpha, df=df
                )

        # Count observations (use raw counts to avoid demeaned values from absorb)
        n_treated = n_treated_raw
        n_control = n_control_raw

        # Create coefficient dictionary
        coef_dict = dict(zip(var_names, coefficients))

        # Store results
        self.results_ = MultiPeriodDiDResults(
            period_effects=period_effects,
            avg_att=avg_att,
            avg_se=avg_se,
            avg_t_stat=avg_t_stat,
            avg_p_value=avg_p_value,
            avg_conf_int=avg_conf_int,
            n_obs=len(y),
            n_treated=n_treated,
            n_control=n_control,
            pre_periods=pre_periods,
            post_periods=post_periods,
            alpha=self.alpha,
            coefficients=coef_dict,
            vcov=vcov,
            residuals=residuals,
            fitted_values=fitted,
            r_squared=r_squared,
            reference_period=reference_period,
            interaction_indices=interaction_indices,
            survey_metadata=survey_metadata,
        )

        self._coefficients = coefficients
        self._vcov = vcov
        self.is_fitted_ = True

        return self.results_

    def summary(self) -> str:
        """
        Get summary of estimation results.

        Returns
        -------
        str
            Formatted summary.

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet.
        """
        if not self.is_fitted_:
            raise RuntimeError("Model must be fitted before calling summary()")
        assert self.results_ is not None
        return self.results_.summary()
1487
+
1488
+
1489
+ # Re-export estimators from submodules for backward compatibility
1490
+ # These can also be imported directly from their respective modules:
1491
+ # - from diff_diff.twfe import TwoWayFixedEffects
1492
+ # - from diff_diff.synthetic_did import SyntheticDiD
1493
+ from diff_diff.synthetic_did import SyntheticDiD # noqa: E402
1494
+ from diff_diff.twfe import TwoWayFixedEffects # noqa: E402
1495
+
1496
# Public names exported by `from <this module> import *`: the two estimators
# defined in this module plus the two re-exported from their submodules.
__all__ = [
    "DifferenceInDifferences",
    "MultiPeriodDiD",
    "TwoWayFixedEffects",
    "SyntheticDiD",
]