diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1161 @@
1
+ """
2
+ Difference-in-Differences estimators with sklearn-like API.
3
+
4
+ This module contains the core DiD estimators:
5
+ - DifferenceInDifferences: Basic 2x2 DiD estimator
6
+ - MultiPeriodDiD: Event-study style DiD with period-specific treatment effects
7
+
8
+ Additional estimators are in separate modules:
9
+ - TwoWayFixedEffects: See diff_diff.twfe
10
+ - SyntheticDiD: See diff_diff.synthetic_did
11
+
12
+ For backward compatibility, all estimators are re-exported from this module.
13
+ """
14
+
15
+ import warnings
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ from diff_diff.linalg import (
22
+ LinearRegression,
23
+ compute_r_squared,
24
+ compute_robust_vcov,
25
+ solve_ols,
26
+ )
27
+ from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
28
+ from diff_diff.utils import (
29
+ WildBootstrapResults,
30
+ compute_confidence_interval,
31
+ compute_p_value,
32
+ demean_by_group,
33
+ validate_binary,
34
+ wild_bootstrap_se,
35
+ )
36
+
37
+
38
+ class DifferenceInDifferences:
39
+ """
40
+ Difference-in-Differences estimator with sklearn-like interface.
41
+
42
+ Estimates the Average Treatment effect on the Treated (ATT) using
43
+ the canonical 2x2 DiD design or panel data with two-way fixed effects.
44
+
45
+ Parameters
46
+ ----------
47
+ formula : str, optional
48
+ R-style formula for the model (e.g., "outcome ~ treated * post").
49
+ If provided, overrides column name parameters.
50
+ robust : bool, default=True
51
+ Whether to use heteroskedasticity-robust standard errors (HC1).
52
+ cluster : str, optional
53
+ Column name for cluster-robust standard errors.
54
+ alpha : float, default=0.05
55
+ Significance level for confidence intervals.
56
+ inference : str, default="analytical"
57
+ Inference method: "analytical" for standard asymptotic inference,
58
+ or "wild_bootstrap" for wild cluster bootstrap (recommended when
59
+ number of clusters is small, <50).
60
+ n_bootstrap : int, default=999
61
+ Number of bootstrap replications when inference="wild_bootstrap".
62
+ bootstrap_weights : str, default="rademacher"
63
+ Type of bootstrap weights: "rademacher" (standard), "webb"
64
+ (recommended for <10 clusters), or "mammen" (skewness correction).
65
+ seed : int, optional
66
+ Random seed for reproducibility when using bootstrap inference.
67
+ If None (default), results will vary between runs.
68
+ rank_deficient_action : str, default "warn"
69
+ Action when design matrix is rank-deficient (linearly dependent columns):
70
+ - "warn": Issue warning and drop linearly dependent columns (default)
71
+ - "error": Raise ValueError
72
+ - "silent": Drop columns silently without warning
73
+
74
+ Attributes
75
+ ----------
76
+ results_ : DiDResults
77
+ Estimation results after calling fit().
78
+ is_fitted_ : bool
79
+ Whether the model has been fitted.
80
+
81
+ Examples
82
+ --------
83
+ Basic usage with a DataFrame:
84
+
85
+ >>> import pandas as pd
86
+ >>> from diff_diff import DifferenceInDifferences
87
+ >>>
88
+ >>> # Create sample data
89
+ >>> data = pd.DataFrame({
90
+ ... 'outcome': [10, 11, 15, 18, 9, 10, 12, 13],
91
+ ... 'treated': [1, 1, 1, 1, 0, 0, 0, 0],
92
+ ... 'post': [0, 0, 1, 1, 0, 0, 1, 1]
93
+ ... })
94
+ >>>
95
+ >>> # Fit the model
96
+ >>> did = DifferenceInDifferences()
97
+ >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')
98
+ >>>
99
+ >>> # View results
100
+ >>> print(results.att) # ATT estimate
101
+ >>> results.print_summary() # Full summary table
102
+
103
+ Using formula interface:
104
+
105
+ >>> did = DifferenceInDifferences()
106
+ >>> results = did.fit(data, formula='outcome ~ treated * post')
107
+
108
+ Notes
109
+ -----
110
+ The ATT is computed using the standard DiD formula:
111
+
112
+ ATT = (E[Y|D=1,T=1] - E[Y|D=1,T=0]) - (E[Y|D=0,T=1] - E[Y|D=0,T=0])
113
+
114
+ Or equivalently via OLS regression:
115
+
116
+ Y = α + β₁*D + β₂*T + β₃*(D×T) + ε
117
+
118
+ Where β₃ is the ATT.
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ robust: bool = True,
124
+ cluster: Optional[str] = None,
125
+ alpha: float = 0.05,
126
+ inference: str = "analytical",
127
+ n_bootstrap: int = 999,
128
+ bootstrap_weights: str = "rademacher",
129
+ seed: Optional[int] = None,
130
+ rank_deficient_action: str = "warn",
131
+ ):
132
+ self.robust = robust
133
+ self.cluster = cluster
134
+ self.alpha = alpha
135
+ self.inference = inference
136
+ self.n_bootstrap = n_bootstrap
137
+ self.bootstrap_weights = bootstrap_weights
138
+ self.seed = seed
139
+ self.rank_deficient_action = rank_deficient_action
140
+
141
+ self.is_fitted_ = False
142
+ self.results_ = None
143
+ self._coefficients = None
144
+ self._vcov = None
145
+ self._bootstrap_results = None # Store WildBootstrapResults if used
146
+
147
+ def fit(
148
+ self,
149
+ data: pd.DataFrame,
150
+ outcome: Optional[str] = None,
151
+ treatment: Optional[str] = None,
152
+ time: Optional[str] = None,
153
+ formula: Optional[str] = None,
154
+ covariates: Optional[List[str]] = None,
155
+ fixed_effects: Optional[List[str]] = None,
156
+ absorb: Optional[List[str]] = None,
157
+ ) -> DiDResults:
158
+ """
159
+ Fit the Difference-in-Differences model.
160
+
161
+ Parameters
162
+ ----------
163
+ data : pd.DataFrame
164
+ DataFrame containing the outcome, treatment, and time variables.
165
+ outcome : str
166
+ Name of the outcome variable column.
167
+ treatment : str
168
+ Name of the treatment group indicator column (0/1).
169
+ time : str
170
+ Name of the post-treatment period indicator column (0/1).
171
+ formula : str, optional
172
+ R-style formula (e.g., "outcome ~ treated * post").
173
+ If provided, overrides outcome, treatment, and time parameters.
174
+ covariates : list, optional
175
+ List of covariate column names to include as linear controls.
176
+ fixed_effects : list, optional
177
+ List of categorical column names to include as fixed effects.
178
+ Creates dummy variables for each category (drops first level).
179
+ Use for low-dimensional fixed effects (e.g., industry, region).
180
+ absorb : list, optional
181
+ List of categorical column names for high-dimensional fixed effects.
182
+ Uses within-transformation (demeaning) instead of dummy variables.
183
+ More efficient for large numbers of categories (e.g., firm, individual).
184
+
185
+ Returns
186
+ -------
187
+ DiDResults
188
+ Object containing estimation results.
189
+
190
+ Raises
191
+ ------
192
+ ValueError
193
+ If required parameters are missing or data validation fails.
194
+
195
+ Examples
196
+ --------
197
+ Using fixed effects (dummy variables):
198
+
199
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
200
+ ... fixed_effects=['state', 'industry'])
201
+
202
+ Using absorbed fixed effects (within-transformation):
203
+
204
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
205
+ ... absorb=['firm_id'])
206
+ """
207
+ # Parse formula if provided
208
+ if formula is not None:
209
+ outcome, treatment, time, covariates = self._parse_formula(formula, data)
210
+ elif outcome is None or treatment is None or time is None:
211
+ raise ValueError(
212
+ "Must provide either 'formula' or all of 'outcome', 'treatment', and 'time'"
213
+ )
214
+
215
+ # Validate inputs
216
+ self._validate_data(data, outcome, treatment, time, covariates)
217
+
218
+ # Validate binary variables BEFORE any transformations
219
+ validate_binary(data[treatment].values, "treatment")
220
+ validate_binary(data[time].values, "time")
221
+
222
+ # Validate fixed effects and absorb columns
223
+ if fixed_effects:
224
+ for fe in fixed_effects:
225
+ if fe not in data.columns:
226
+ raise ValueError(f"Fixed effect column '{fe}' not found in data")
227
+ if absorb:
228
+ for ab in absorb:
229
+ if ab not in data.columns:
230
+ raise ValueError(f"Absorb column '{ab}' not found in data")
231
+
232
+ # Handle absorbed fixed effects (within-transformation)
233
+ working_data = data.copy()
234
+ absorbed_vars = []
235
+ n_absorbed_effects = 0
236
+
237
+ if absorb:
238
+ # Apply within-transformation for each absorbed variable
239
+ # Only demean outcome and covariates, NOT treatment/time indicators
240
+ # Treatment is typically time-invariant (within unit), and time is
241
+ # unit-invariant, so demeaning them would create multicollinearity
242
+ vars_to_demean = [outcome] + (covariates or [])
243
+ for ab_var in absorb:
244
+ working_data, n_fe = demean_by_group(
245
+ working_data, vars_to_demean, ab_var, inplace=True
246
+ )
247
+ n_absorbed_effects += n_fe
248
+ absorbed_vars.append(ab_var)
249
+
250
+ # Extract variables (may be demeaned if absorb was used)
251
+ y = working_data[outcome].values.astype(float)
252
+ d = working_data[treatment].values.astype(float)
253
+ t = working_data[time].values.astype(float)
254
+
255
+ # Create interaction term
256
+ dt = d * t
257
+
258
+ # Build design matrix
259
+ X = np.column_stack([np.ones(len(y)), d, t, dt])
260
+ var_names = ["const", treatment, time, f"{treatment}:{time}"]
261
+
262
+ # Add covariates if provided
263
+ if covariates:
264
+ for cov in covariates:
265
+ X = np.column_stack([X, working_data[cov].values.astype(float)])
266
+ var_names.append(cov)
267
+
268
+ # Add fixed effects as dummy variables
269
+ if fixed_effects:
270
+ for fe in fixed_effects:
271
+ # Create dummies, drop first category to avoid multicollinearity
272
+ # Use working_data to be consistent with absorbed FE if both are used
273
+ dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
274
+ for col in dummies.columns:
275
+ X = np.column_stack([X, dummies[col].values.astype(float)])
276
+ var_names.append(col)
277
+
278
+ # Extract ATT index (coefficient on interaction term)
279
+ att_idx = 3 # Index of interaction term
280
+ att_var_name = f"{treatment}:{time}"
281
+ assert var_names[att_idx] == att_var_name, (
282
+ f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, "
283
+ f"but found '{var_names[att_idx]}'"
284
+ )
285
+
286
+ # Always use LinearRegression for initial fit (unified code path)
287
+ # For wild bootstrap, we don't need cluster SEs from the initial fit
288
+ cluster_ids = data[self.cluster].values if self.cluster is not None else None
289
+ reg = LinearRegression(
290
+ include_intercept=False, # Intercept already in X
291
+ robust=self.robust,
292
+ cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None,
293
+ alpha=self.alpha,
294
+ rank_deficient_action=self.rank_deficient_action,
295
+ ).fit(X, y, df_adjustment=n_absorbed_effects)
296
+
297
+ coefficients = reg.coefficients_
298
+ residuals = reg.residuals_
299
+ fitted = reg.fitted_values_
300
+ att = coefficients[att_idx]
301
+
302
+ # Get inference - either from bootstrap or analytical
303
+ if self.inference == "wild_bootstrap" and self.cluster is not None:
304
+ # Override with wild cluster bootstrap inference
305
+ se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
306
+ X, y, residuals, cluster_ids, att_idx
307
+ )
308
+ else:
309
+ # Use analytical inference from LinearRegression
310
+ vcov = reg.vcov_
311
+ inference = reg.get_inference(att_idx)
312
+ se = inference.se
313
+ t_stat = inference.t_stat
314
+ p_value = inference.p_value
315
+ conf_int = inference.conf_int
316
+
317
+ r_squared = compute_r_squared(y, residuals)
318
+
319
+ # Count observations
320
+ n_treated = int(np.sum(d))
321
+ n_control = int(np.sum(1 - d))
322
+
323
+ # Create coefficient dictionary
324
+ coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}
325
+
326
+ # Determine inference method and bootstrap info
327
+ inference_method = "analytical"
328
+ n_bootstrap_used = None
329
+ n_clusters_used = None
330
+ if self._bootstrap_results is not None:
331
+ inference_method = "wild_bootstrap"
332
+ n_bootstrap_used = self._bootstrap_results.n_bootstrap
333
+ n_clusters_used = self._bootstrap_results.n_clusters
334
+
335
+ # Store results
336
+ self.results_ = DiDResults(
337
+ att=att,
338
+ se=se,
339
+ t_stat=t_stat,
340
+ p_value=p_value,
341
+ conf_int=conf_int,
342
+ n_obs=len(y),
343
+ n_treated=n_treated,
344
+ n_control=n_control,
345
+ alpha=self.alpha,
346
+ coefficients=coef_dict,
347
+ vcov=vcov,
348
+ residuals=residuals,
349
+ fitted_values=fitted,
350
+ r_squared=r_squared,
351
+ inference_method=inference_method,
352
+ n_bootstrap=n_bootstrap_used,
353
+ n_clusters=n_clusters_used,
354
+ )
355
+
356
+ self._coefficients = coefficients
357
+ self._vcov = vcov
358
+ self.is_fitted_ = True
359
+
360
+ return self.results_
361
+
362
+ def _fit_ols(
363
+ self, X: np.ndarray, y: np.ndarray
364
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
365
+ """
366
+ Fit OLS regression.
367
+
368
+ This method is kept for backwards compatibility. Internally uses the
369
+ unified solve_ols from diff_diff.linalg for optimized computation.
370
+
371
+ Parameters
372
+ ----------
373
+ X : np.ndarray
374
+ Design matrix.
375
+ y : np.ndarray
376
+ Outcome vector.
377
+
378
+ Returns
379
+ -------
380
+ tuple
381
+ (coefficients, residuals, fitted_values, r_squared)
382
+ """
383
+ # Use unified OLS backend
384
+ coefficients, residuals, fitted, _ = solve_ols(X, y, return_fitted=True, return_vcov=False)
385
+ r_squared = compute_r_squared(y, residuals)
386
+
387
+ return coefficients, residuals, fitted, r_squared
388
+
389
+ def _run_wild_bootstrap_inference(
390
+ self,
391
+ X: np.ndarray,
392
+ y: np.ndarray,
393
+ residuals: np.ndarray,
394
+ cluster_ids: np.ndarray,
395
+ coefficient_index: int,
396
+ ) -> Tuple[float, float, Tuple[float, float], float, np.ndarray, WildBootstrapResults]:
397
+ """
398
+ Run wild cluster bootstrap inference.
399
+
400
+ Parameters
401
+ ----------
402
+ X : np.ndarray
403
+ Design matrix.
404
+ y : np.ndarray
405
+ Outcome vector.
406
+ residuals : np.ndarray
407
+ OLS residuals.
408
+ cluster_ids : np.ndarray
409
+ Cluster identifiers for each observation.
410
+ coefficient_index : int
411
+ Index of the coefficient to compute inference for.
412
+
413
+ Returns
414
+ -------
415
+ tuple
416
+ (se, p_value, conf_int, t_stat, vcov, bootstrap_results)
417
+ """
418
+ bootstrap_results = wild_bootstrap_se(
419
+ X,
420
+ y,
421
+ residuals,
422
+ cluster_ids,
423
+ coefficient_index=coefficient_index,
424
+ n_bootstrap=self.n_bootstrap,
425
+ weight_type=self.bootstrap_weights,
426
+ alpha=self.alpha,
427
+ seed=self.seed,
428
+ return_distribution=False,
429
+ )
430
+ self._bootstrap_results = bootstrap_results
431
+
432
+ se = bootstrap_results.se
433
+ p_value = bootstrap_results.p_value
434
+ conf_int = (bootstrap_results.ci_lower, bootstrap_results.ci_upper)
435
+ t_stat = bootstrap_results.t_stat_original
436
+
437
+ # Also compute vcov for storage (using cluster-robust for consistency)
438
+ vcov = compute_robust_vcov(X, residuals, cluster_ids)
439
+
440
+ return se, p_value, conf_int, t_stat, vcov, bootstrap_results
441
+
442
+ def _parse_formula(
443
+ self, formula: str, data: pd.DataFrame
444
+ ) -> Tuple[str, str, str, Optional[List[str]]]:
445
+ """
446
+ Parse R-style formula.
447
+
448
+ Supports basic formulas like:
449
+ - "outcome ~ treatment * time"
450
+ - "outcome ~ treatment + time + treatment:time"
451
+ - "outcome ~ treatment * time + covariate1 + covariate2"
452
+
453
+ Parameters
454
+ ----------
455
+ formula : str
456
+ R-style formula string.
457
+ data : pd.DataFrame
458
+ DataFrame to validate column names against.
459
+
460
+ Returns
461
+ -------
462
+ tuple
463
+ (outcome, treatment, time, covariates)
464
+ """
465
+ # Split into LHS and RHS
466
+ if "~" not in formula:
467
+ raise ValueError("Formula must contain '~' to separate outcome from predictors")
468
+
469
+ lhs, rhs = formula.split("~")
470
+ outcome = lhs.strip()
471
+
472
+ # Parse RHS
473
+ rhs = rhs.strip()
474
+
475
+ # Check for interaction term
476
+ if "*" in rhs:
477
+ # Handle "treatment * time" syntax
478
+ parts = rhs.split("*")
479
+ if len(parts) != 2:
480
+ raise ValueError("Currently only supports single interaction (treatment * time)")
481
+
482
+ treatment = parts[0].strip()
483
+ time = parts[1].strip()
484
+
485
+ # Check for additional covariates after interaction
486
+ if "+" in time:
487
+ time_parts = time.split("+")
488
+ time = time_parts[0].strip()
489
+ covariates = [p.strip() for p in time_parts[1:]]
490
+ else:
491
+ covariates = None
492
+
493
+ elif ":" in rhs:
494
+ # Handle explicit interaction syntax
495
+ terms = [t.strip() for t in rhs.split("+")]
496
+ interaction_term = None
497
+ main_effects = []
498
+ covariates = []
499
+
500
+ for term in terms:
501
+ if ":" in term:
502
+ interaction_term = term
503
+ else:
504
+ main_effects.append(term)
505
+
506
+ if interaction_term is None:
507
+ raise ValueError("Formula must contain an interaction term (treatment:time)")
508
+
509
+ treatment, time = [t.strip() for t in interaction_term.split(":")]
510
+
511
+ # Remaining terms after treatment and time are covariates
512
+ for term in main_effects:
513
+ if term != treatment and term != time:
514
+ covariates.append(term)
515
+
516
+ covariates = covariates if covariates else None
517
+ else:
518
+ raise ValueError(
519
+ "Formula must contain interaction term. "
520
+ "Use 'outcome ~ treatment * time' or 'outcome ~ treatment + time + treatment:time'"
521
+ )
522
+
523
+ # Validate columns exist
524
+ for col in [outcome, treatment, time]:
525
+ if col not in data.columns:
526
+ raise ValueError(f"Column '{col}' not found in data")
527
+
528
+ if covariates:
529
+ for cov in covariates:
530
+ if cov not in data.columns:
531
+ raise ValueError(f"Covariate '{cov}' not found in data")
532
+
533
+ return outcome, treatment, time, covariates
534
+
535
+ def _validate_data(
536
+ self,
537
+ data: pd.DataFrame,
538
+ outcome: str,
539
+ treatment: str,
540
+ time: str,
541
+ covariates: Optional[List[str]] = None,
542
+ ) -> None:
543
+ """Validate input data."""
544
+ # Check DataFrame
545
+ if not isinstance(data, pd.DataFrame):
546
+ raise TypeError("data must be a pandas DataFrame")
547
+
548
+ # Check required columns exist
549
+ required_cols = [outcome, treatment, time]
550
+ if covariates:
551
+ required_cols.extend(covariates)
552
+
553
+ missing_cols = [col for col in required_cols if col not in data.columns]
554
+ if missing_cols:
555
+ raise ValueError(f"Missing columns in data: {missing_cols}")
556
+
557
+ # Check for missing values
558
+ for col in required_cols:
559
+ if data[col].isna().any():
560
+ raise ValueError(f"Column '{col}' contains missing values")
561
+
562
+ # Check for sufficient variation
563
+ if data[treatment].nunique() < 2:
564
+ raise ValueError("Treatment variable must have both 0 and 1 values")
565
+ if data[time].nunique() < 2:
566
+ raise ValueError("Time variable must have both 0 and 1 values")
567
+
568
+ def predict(self, data: pd.DataFrame) -> np.ndarray:
569
+ """
570
+ Predict outcomes using fitted model.
571
+
572
+ Parameters
573
+ ----------
574
+ data : pd.DataFrame
575
+ DataFrame with same structure as training data.
576
+
577
+ Returns
578
+ -------
579
+ np.ndarray
580
+ Predicted values.
581
+ """
582
+ if not self.is_fitted_:
583
+ raise RuntimeError("Model must be fitted before calling predict()")
584
+
585
+ # This is a placeholder - would need to store column names
586
+ # for full implementation
587
+ raise NotImplementedError(
588
+ "predict() is not yet implemented. "
589
+ "Use results_.fitted_values for training data predictions."
590
+ )
591
+
592
+ def get_params(self) -> Dict[str, Any]:
593
+ """
594
+ Get estimator parameters (sklearn-compatible).
595
+
596
+ Returns
597
+ -------
598
+ Dict[str, Any]
599
+ Estimator parameters.
600
+ """
601
+ return {
602
+ "robust": self.robust,
603
+ "cluster": self.cluster,
604
+ "alpha": self.alpha,
605
+ "inference": self.inference,
606
+ "n_bootstrap": self.n_bootstrap,
607
+ "bootstrap_weights": self.bootstrap_weights,
608
+ "seed": self.seed,
609
+ "rank_deficient_action": self.rank_deficient_action,
610
+ }
611
+
612
+ def set_params(self, **params) -> "DifferenceInDifferences":
613
+ """
614
+ Set estimator parameters (sklearn-compatible).
615
+
616
+ Parameters
617
+ ----------
618
+ **params
619
+ Estimator parameters.
620
+
621
+ Returns
622
+ -------
623
+ self
624
+ """
625
+ for key, value in params.items():
626
+ if hasattr(self, key):
627
+ setattr(self, key, value)
628
+ else:
629
+ raise ValueError(f"Unknown parameter: {key}")
630
+ return self
631
+
632
+ def summary(self) -> str:
633
+ """
634
+ Get summary of estimation results.
635
+
636
+ Returns
637
+ -------
638
+ str
639
+ Formatted summary.
640
+ """
641
+ if not self.is_fitted_:
642
+ raise RuntimeError("Model must be fitted before calling summary()")
643
+ assert self.results_ is not None
644
+ return self.results_.summary()
645
+
646
+ def print_summary(self) -> None:
647
+ """Print summary to stdout."""
648
+ print(self.summary())
649
+
650
+
651
+ class MultiPeriodDiD(DifferenceInDifferences):
652
+ """
653
+ Multi-Period Difference-in-Differences estimator.
654
+
655
+ Extends the standard DiD to handle multiple pre-treatment and
656
+ post-treatment time periods, providing period-specific treatment
657
+ effects as well as an aggregate average treatment effect.
658
+
659
+ Parameters
660
+ ----------
661
+ robust : bool, default=True
662
+ Whether to use heteroskedasticity-robust standard errors (HC1).
663
+ cluster : str, optional
664
+ Column name for cluster-robust standard errors.
665
+ alpha : float, default=0.05
666
+ Significance level for confidence intervals.
667
+
668
+ Attributes
669
+ ----------
670
+ results_ : MultiPeriodDiDResults
671
+ Estimation results after calling fit().
672
+ is_fitted_ : bool
673
+ Whether the model has been fitted.
674
+
675
+ Examples
676
+ --------
677
+ Basic usage with multiple time periods:
678
+
679
+ >>> import pandas as pd
680
+ >>> from diff_diff import MultiPeriodDiD
681
+ >>>
682
+ >>> # Create sample panel data with 6 time periods
683
+ >>> # Periods 0-2 are pre-treatment, periods 3-5 are post-treatment
684
+ >>> data = create_panel_data() # Your data
685
+ >>>
686
+ >>> # Fit the model
687
+ >>> did = MultiPeriodDiD()
688
+ >>> results = did.fit(
689
+ ... data,
690
+ ... outcome='sales',
691
+ ... treatment='treated',
692
+ ... time='period',
693
+ ... post_periods=[3, 4, 5] # Specify which periods are post-treatment
694
+ ... )
695
+ >>>
696
+ >>> # View period-specific effects
697
+ >>> for period, effect in results.period_effects.items():
698
+ ... print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})")
699
+ >>>
700
+ >>> # View average treatment effect
701
+ >>> print(f"Average ATT: {results.avg_att:.3f}")
702
+
703
+ Notes
704
+ -----
705
+ The model estimates:
706
+
707
+ Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_{t≠ref} δ_t*(D_i × 1{t}) + ε_it
708
+
709
+ Where:
710
+ - D_i is the treatment indicator
711
+ - Period_t are time period dummies (all non-reference periods)
712
+ - D_i × 1{t} are treatment-by-period interactions (all non-reference)
713
+ - δ_t are the period-specific treatment effects
714
+ - The reference period (default: last pre-period) has δ_ref = 0 by construction
715
+
716
+ Pre-treatment δ_t test the parallel trends assumption (should be ≈ 0).
717
+ Post-treatment δ_t estimate dynamic treatment effects.
718
+ The average ATT is computed from post-treatment δ_t only.
719
+ """
720
+
721
+ def fit( # type: ignore[override]
722
+ self,
723
+ data: pd.DataFrame,
724
+ outcome: str,
725
+ treatment: str,
726
+ time: str,
727
+ post_periods: Optional[List[Any]] = None,
728
+ covariates: Optional[List[str]] = None,
729
+ fixed_effects: Optional[List[str]] = None,
730
+ absorb: Optional[List[str]] = None,
731
+ reference_period: Any = None,
732
+ unit: Optional[str] = None,
733
+ ) -> MultiPeriodDiDResults:
734
+ """
735
+ Fit the Multi-Period Difference-in-Differences model.
736
+
737
+ Parameters
738
+ ----------
739
+ data : pd.DataFrame
740
+ DataFrame containing the outcome, treatment, and time variables.
741
+ outcome : str
742
+ Name of the outcome variable column.
743
+ treatment : str
744
+ Name of the treatment group indicator column (0/1). Should be a
745
+ time-invariant ever-treated indicator (D_i = 1 for all periods of
746
+ treated units). If treatment is time-varying (D_it), pre-period
747
+ interaction coefficients will be unidentified.
748
+ time : str
749
+ Name of the time period column (can have multiple values).
750
+ post_periods : list
751
+ List of time period values that are post-treatment.
752
+ All other periods are treated as pre-treatment.
753
+ covariates : list, optional
754
+ List of covariate column names to include as linear controls.
755
+ fixed_effects : list, optional
756
+ List of categorical column names to include as fixed effects.
757
+ absorb : list, optional
758
+ List of categorical column names for high-dimensional fixed effects.
759
+ reference_period : any, optional
760
+ The reference (omitted) time period for the period dummies.
761
+ Defaults to the last pre-treatment period (e=-1 convention).
762
+ unit : str, optional
763
+ Name of the unit identifier column. When provided, checks whether
764
+ treatment timing varies across units and warns if staggered adoption
765
+ is detected (suggests CallawaySantAnna instead). Does NOT affect
766
+ standard error computation -- use the ``cluster`` parameter for
767
+ cluster-robust SEs.
768
+
769
+ Returns
770
+ -------
771
+ MultiPeriodDiDResults
772
+ Object containing period-specific and average treatment effects.
773
+
774
+ Raises
775
+ ------
776
+ ValueError
777
+ If required parameters are missing or data validation fails.
778
+ """
779
+ # Warn if wild bootstrap is requested but not supported
780
+ if self.inference == "wild_bootstrap":
781
+ warnings.warn(
782
+ "Wild bootstrap inference is not yet supported for MultiPeriodDiD. "
783
+ "Using analytical inference instead.",
784
+ UserWarning,
785
+ )
786
+
787
+ # Validate basic inputs
788
+ if outcome is None or treatment is None or time is None:
789
+ raise ValueError("Must provide 'outcome', 'treatment', and 'time'")
790
+
791
+ # Validate columns exist
792
+ self._validate_data(data, outcome, treatment, time, covariates)
793
+
794
+ # Validate treatment is binary
795
+ validate_binary(data[treatment].values, "treatment")
796
+
797
+ # Validate unit column and check for staggered adoption
798
+ if unit is not None:
799
+ if unit not in data.columns:
800
+ raise ValueError(f"Unit column '{unit}' not found in data")
801
+
802
+ # Check for staggered treatment timing and absorbing treatment
803
+ unit_time_sorted = data.sort_values([unit, time])
804
+ adoption_times = {}
805
+ has_reversal = False
806
+ for u, group in unit_time_sorted.groupby(unit):
807
+ d_vals = group[treatment].values
808
+ # Check for treatment reversal (non-absorbing treatment)
809
+ if not has_reversal and len(d_vals) > 1 and np.any(np.diff(d_vals) < 0):
810
+ warnings.warn(
811
+ f"Treatment reversal detected (unit '{u}' transitions from "
812
+ f"treated to untreated). MultiPeriodDiD assumes treatment is "
813
+ f"an absorbing state (once treated, always treated). "
814
+ f"Treatment reversals violate this assumption and may "
815
+ f"produce unreliable estimates.",
816
+ UserWarning,
817
+ stacklevel=2,
818
+ )
819
+ has_reversal = True
820
+ # Only use units with observed 0→1 transition for adoption timing
821
+ # (skip units that are always treated — can't determine adoption time)
822
+ if 0 in d_vals and 1 in d_vals:
823
+ adoption_times[u] = group.loc[group[treatment] == 1, time].iloc[0]
824
+
825
+ if len(adoption_times) > 0:
826
+ unique_adoption = len(set(adoption_times.values()))
827
+ if unique_adoption > 1:
828
+ warnings.warn(
829
+ "Treatment timing varies across units (staggered adoption "
830
+ "detected). MultiPeriodDiD assumes simultaneous adoption "
831
+ "and may produce biased estimates with staggered treatment. "
832
+ "Consider using CallawaySantAnna or SunAbraham instead.",
833
+ UserWarning,
834
+ stacklevel=2,
835
+ )
836
+
837
+ # Check for time-varying treatment (D_it instead of D_i)
838
+ # If any unit has a 0→1 transition, the treatment column is D_it.
839
+ # MultiPeriodDiD expects a time-invariant ever-treated indicator.
840
+ warnings.warn(
841
+ "Treatment indicator varies within units (time-varying "
842
+ "treatment detected). MultiPeriodDiD's event-study "
843
+ "specification expects a time-invariant ever-treated "
844
+ "indicator (D_i = 1 for all periods of eventually-treated "
845
+ "units). With time-varying treatment, pre-period "
846
+ "interaction coefficients will be unidentified. Consider: "
847
+ f"df['ever_treated'] = df.groupby('{unit}')['{treatment}']"
848
+ ".transform('max')",
849
+ UserWarning,
850
+ stacklevel=2,
851
+ )
852
+
853
+ # Get all unique time periods
854
+ all_periods = sorted(data[time].unique())
855
+
856
+ if len(all_periods) < 2:
857
+ raise ValueError("Time variable must have at least 2 unique periods")
858
+
859
+ # Determine pre and post periods
860
+ if post_periods is None:
861
+ # Default: last half of periods are post-treatment
862
+ mid_point = len(all_periods) // 2
863
+ post_periods = all_periods[mid_point:]
864
+ pre_periods = all_periods[:mid_point]
865
+ else:
866
+ post_periods = list(post_periods)
867
+ pre_periods = [p for p in all_periods if p not in post_periods]
868
+
869
+ if len(post_periods) == 0:
870
+ raise ValueError("Must have at least one post-treatment period")
871
+
872
+ if len(pre_periods) == 0:
873
+ raise ValueError("Must have at least one pre-treatment period")
874
+
875
+ if len(pre_periods) < 2:
876
+ warnings.warn(
877
+ "Only one pre-treatment period available. At least 2 pre-periods "
878
+ "are needed to assess parallel trends. The treatment effect estimate "
879
+ "is still valid, but pre-period coefficients for parallel trends "
880
+ "testing are not available.",
881
+ UserWarning,
882
+ stacklevel=2,
883
+ )
884
+
885
+ # Validate post_periods are in the data
886
+ for p in post_periods:
887
+ if p not in all_periods:
888
+ raise ValueError(f"Post-period '{p}' not found in time column")
889
+
890
+ # Determine reference period (omitted dummy)
891
+ if reference_period is None:
892
+ # Default: last pre-period (e=-1 convention, matches fixest)
893
+ if len(pre_periods) > 1:
894
+ warnings.warn(
895
+ f"The default reference_period has changed from the first "
896
+ f"pre-period ({pre_periods[0]}) to the last pre-period "
897
+ f"({pre_periods[-1]}) to match the standard e=-1 convention "
898
+ f"(as used by fixest, did, etc.). "
899
+ f"To silence this warning, pass "
900
+ f"reference_period={pre_periods[-1]} explicitly.",
901
+ FutureWarning,
902
+ stacklevel=2,
903
+ )
904
+ reference_period = pre_periods[-1]
905
+ elif reference_period not in all_periods:
906
+ raise ValueError(f"Reference period '{reference_period}' not found in time column")
907
+
908
+ # Disallow post-period reference (downstream logic assumes reference is pre-period)
909
+ if reference_period in post_periods:
910
+ raise ValueError(
911
+ f"reference_period={reference_period} is a post-treatment period. "
912
+ f"The reference period must be a pre-treatment period "
913
+ f"(e.g., the last pre-period {pre_periods[-1]}). "
914
+ f"Post-period references are not supported because the reference "
915
+ f"period is excluded from estimation, which would bias avg_att "
916
+ f"and break downstream inference."
917
+ )
918
+
919
+ # Validate fixed effects and absorb columns
920
+ if fixed_effects:
921
+ for fe in fixed_effects:
922
+ if fe not in data.columns:
923
+ raise ValueError(f"Fixed effect column '{fe}' not found in data")
924
+ if absorb:
925
+ for ab in absorb:
926
+ if ab not in data.columns:
927
+ raise ValueError(f"Absorb column '{ab}' not found in data")
928
+
929
+ # Handle absorbed fixed effects (within-transformation)
930
+ working_data = data.copy()
931
+ n_absorbed_effects = 0
932
+
933
+ if absorb:
934
+ vars_to_demean = [outcome] + (covariates or [])
935
+ for ab_var in absorb:
936
+ working_data, n_fe = demean_by_group(
937
+ working_data, vars_to_demean, ab_var, inplace=True
938
+ )
939
+ n_absorbed_effects += n_fe
940
+
941
+ # Extract outcome and treatment
942
+ y = working_data[outcome].values.astype(float)
943
+ d = working_data[treatment].values.astype(float)
944
+ t = working_data[time].values
945
+
946
+ # Build design matrix
947
+ # Start with intercept and treatment main effect
948
+ X = np.column_stack([np.ones(len(y)), d])
949
+ var_names = ["const", treatment]
950
+
951
+ # Add period dummies (excluding reference period)
952
+ non_ref_periods = [p for p in all_periods if p != reference_period]
953
+ period_dummy_indices = {} # Map period -> column index in X
954
+
955
+ for period in non_ref_periods:
956
+ period_dummy = (t == period).astype(float)
957
+ X = np.column_stack([X, period_dummy])
958
+ var_names.append(f"period_{period}")
959
+ period_dummy_indices[period] = X.shape[1] - 1
960
+
961
+ # Add treatment × period interactions for ALL non-reference periods
962
+ # Pre-period interactions test parallel trends; post-period interactions
963
+ # estimate dynamic treatment effects
964
+ interaction_indices = {} # Map period -> column index in X
965
+
966
+ for period in non_ref_periods:
967
+ interaction = d * (t == period).astype(float)
968
+ X = np.column_stack([X, interaction])
969
+ var_names.append(f"{treatment}:period_{period}")
970
+ interaction_indices[period] = X.shape[1] - 1
971
+
972
+ # Add covariates if provided
973
+ if covariates:
974
+ for cov in covariates:
975
+ X = np.column_stack([X, working_data[cov].values.astype(float)])
976
+ var_names.append(cov)
977
+
978
+ # Add fixed effects as dummy variables
979
+ if fixed_effects:
980
+ for fe in fixed_effects:
981
+ dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
982
+ for col in dummies.columns:
983
+ X = np.column_stack([X, dummies[col].values.astype(float)])
984
+ var_names.append(col)
985
+
986
+ # Fit OLS using unified backend
987
+ # Pass cluster_ids to solve_ols for proper vcov computation
988
+ # This handles rank-deficient matrices by returning NaN for dropped columns
989
+ cluster_ids = data[self.cluster].values if self.cluster is not None else None
990
+
991
+ # Note: Wild bootstrap for multi-period effects is complex (multiple coefficients)
992
+ # For now, we use analytical inference even if inference="wild_bootstrap"
993
+ coefficients, residuals, fitted, vcov = solve_ols(
994
+ X,
995
+ y,
996
+ return_fitted=True,
997
+ return_vcov=True,
998
+ cluster_ids=cluster_ids,
999
+ column_names=var_names,
1000
+ rank_deficient_action=self.rank_deficient_action,
1001
+ )
1002
+ r_squared = compute_r_squared(y, residuals)
1003
+
1004
+ # Degrees of freedom using effective rank (non-NaN coefficients)
1005
+ k_effective = int(np.sum(~np.isnan(coefficients)))
1006
+ df = len(y) - k_effective - n_absorbed_effects
1007
+
1008
+ # For non-robust, non-clustered case, we need homoskedastic vcov
1009
+ # solve_ols returns HC1 by default, so compute homoskedastic if needed
1010
+ if not self.robust and self.cluster is None:
1011
+ n = len(y)
1012
+ mse = np.sum(residuals**2) / (n - k_effective)
1013
+ # Use solve() instead of inv() for numerical stability
1014
+ # Only compute for identified columns (non-NaN coefficients)
1015
+ identified_mask = ~np.isnan(coefficients)
1016
+ if np.all(identified_mask):
1017
+ vcov = np.linalg.solve(X.T @ X, mse * np.eye(X.shape[1]))
1018
+ else:
1019
+ # For rank-deficient case, compute vcov on reduced matrix then expand
1020
+ X_reduced = X[:, identified_mask]
1021
+ vcov_reduced = np.linalg.solve(
1022
+ X_reduced.T @ X_reduced, mse * np.eye(X_reduced.shape[1])
1023
+ )
1024
+ # Expand to full size with NaN for dropped columns
1025
+ vcov = np.full((X.shape[1], X.shape[1]), np.nan)
1026
+ vcov[np.ix_(identified_mask, identified_mask)] = vcov_reduced
1027
+
1028
+ # Extract period-specific treatment effects for ALL non-reference periods
1029
+ period_effects = {}
1030
+ post_effect_values = []
1031
+ post_effect_indices = []
1032
+
1033
+ for period in non_ref_periods:
1034
+ idx = interaction_indices[period]
1035
+ effect = coefficients[idx]
1036
+ se = np.sqrt(vcov[idx, idx])
1037
+ if np.isfinite(se) and se > 0:
1038
+ t_stat = effect / se
1039
+ p_value = compute_p_value(t_stat, df=df)
1040
+ conf_int = compute_confidence_interval(effect, se, self.alpha, df=df)
1041
+ else:
1042
+ t_stat = np.nan
1043
+ p_value = np.nan
1044
+ conf_int = (np.nan, np.nan)
1045
+
1046
+ period_effects[period] = PeriodEffect(
1047
+ period=period,
1048
+ effect=effect,
1049
+ se=se,
1050
+ t_stat=t_stat,
1051
+ p_value=p_value,
1052
+ conf_int=conf_int,
1053
+ )
1054
+
1055
+ if period in post_periods:
1056
+ post_effect_values.append(effect)
1057
+ post_effect_indices.append(idx)
1058
+
1059
+ # Compute average treatment effect (post-periods only)
1060
+ # R-style NA propagation: if ANY post-period effect is NaN, average is undefined
1061
+ effect_arr = np.array(post_effect_values)
1062
+
1063
+ if np.any(np.isnan(effect_arr)):
1064
+ # Some period effects are NaN (unidentified) - cannot compute valid average
1065
+ # This follows R's default behavior where mean(c(1, 2, NA)) returns NA
1066
+ avg_att = np.nan
1067
+ avg_se = np.nan
1068
+ avg_t_stat = np.nan
1069
+ avg_p_value = np.nan
1070
+ avg_conf_int = (np.nan, np.nan)
1071
+ else:
1072
+ # All effects identified - compute average normally
1073
+ avg_att = float(np.mean(effect_arr))
1074
+
1075
+ # Standard error of average: need to account for covariance
1076
+ n_post = len(post_periods)
1077
+ sub_vcov = vcov[np.ix_(post_effect_indices, post_effect_indices)]
1078
+ avg_var = np.sum(sub_vcov) / (n_post**2)
1079
+
1080
+ if np.isnan(avg_var) or avg_var < 0:
1081
+ # Vcov has NaN (dropped columns) - propagate NaN
1082
+ avg_se = np.nan
1083
+ avg_t_stat = np.nan
1084
+ avg_p_value = np.nan
1085
+ avg_conf_int = (np.nan, np.nan)
1086
+ else:
1087
+ avg_se = float(np.sqrt(avg_var))
1088
+ if np.isfinite(avg_se) and avg_se > 0:
1089
+ avg_t_stat = avg_att / avg_se
1090
+ avg_p_value = compute_p_value(avg_t_stat, df=df)
1091
+ avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df)
1092
+ else:
1093
+ # Zero SE (degenerate case)
1094
+ avg_t_stat = np.nan
1095
+ avg_p_value = np.nan
1096
+ avg_conf_int = (np.nan, np.nan)
1097
+
1098
+ # Count observations
1099
+ n_treated = int(np.sum(d))
1100
+ n_control = int(np.sum(1 - d))
1101
+
1102
+ # Create coefficient dictionary
1103
+ coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}
1104
+
1105
+ # Store results
1106
+ self.results_ = MultiPeriodDiDResults(
1107
+ period_effects=period_effects,
1108
+ avg_att=avg_att,
1109
+ avg_se=avg_se,
1110
+ avg_t_stat=avg_t_stat,
1111
+ avg_p_value=avg_p_value,
1112
+ avg_conf_int=avg_conf_int,
1113
+ n_obs=len(y),
1114
+ n_treated=n_treated,
1115
+ n_control=n_control,
1116
+ pre_periods=pre_periods,
1117
+ post_periods=post_periods,
1118
+ alpha=self.alpha,
1119
+ coefficients=coef_dict,
1120
+ vcov=vcov,
1121
+ residuals=residuals,
1122
+ fitted_values=fitted,
1123
+ r_squared=r_squared,
1124
+ reference_period=reference_period,
1125
+ interaction_indices=interaction_indices,
1126
+ )
1127
+
1128
+ self._coefficients = coefficients
1129
+ self._vcov = vcov
1130
+ self.is_fitted_ = True
1131
+
1132
+ return self.results_
1133
+
1134
+ def summary(self) -> str:
1135
+ """
1136
+ Get summary of estimation results.
1137
+
1138
+ Returns
1139
+ -------
1140
+ str
1141
+ Formatted summary.
1142
+ """
1143
+ if not self.is_fitted_:
1144
+ raise RuntimeError("Model must be fitted before calling summary()")
1145
+ assert self.results_ is not None
1146
+ return self.results_.summary()
1147
+
1148
+
1149
# Re-export estimators from submodules for backward compatibility.
# These can also be imported directly from their respective modules:
#   - from diff_diff.twfe import TwoWayFixedEffects
#   - from diff_diff.synthetic_did import SyntheticDiD
# NOTE(review): these imports sit at the bottom of the module, presumably to
# avoid a circular import (the submodules likely import names defined above) —
# confirm before moving them; E402 is suppressed for that placement.
from diff_diff.synthetic_did import SyntheticDiD  # noqa: E402
from diff_diff.twfe import TwoWayFixedEffects  # noqa: E402

# Public API of this module (controls `from <module> import *` and signals
# which names are supported for external use).
__all__ = [
    "DifferenceInDifferences",
    "MultiPeriodDiD",
    "TwoWayFixedEffects",
    "SyntheticDiD",
]