diff-diff 2.1.0__cp39-cp39-macosx_11_0_arm64.whl

@@ -0,0 +1,1000 @@
1
+ """
2
+ Difference-in-Differences estimators with sklearn-like API.
3
+
4
+ This module contains the core DiD estimators:
5
+ - DifferenceInDifferences: Basic 2x2 DiD estimator
6
+ - MultiPeriodDiD: Event-study style DiD with period-specific treatment effects
7
+
8
+ Additional estimators are in separate modules:
9
+ - TwoWayFixedEffects: See diff_diff.twfe
10
+ - SyntheticDiD: See diff_diff.synthetic_did
11
+
12
+ For backward compatibility, all estimators are re-exported from this module.
13
+ """
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ from diff_diff.linalg import (
21
+ LinearRegression,
22
+ compute_r_squared,
23
+ compute_robust_vcov,
24
+ solve_ols,
25
+ )
26
+ from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
27
+ from diff_diff.utils import (
28
+ WildBootstrapResults,
29
+ compute_confidence_interval,
30
+ compute_p_value,
31
+ demean_by_group,
32
+ validate_binary,
33
+ wild_bootstrap_se,
34
+ )
35
+
36
+
37
+ class DifferenceInDifferences:
38
+ """
39
+ Difference-in-Differences estimator with sklearn-like interface.
40
+
41
+ Estimates the Average Treatment Effect on the Treated (ATT) using
42
+ the canonical 2x2 DiD design or panel data with two-way fixed effects.
43
+
44
+ Parameters
45
+ ----------
46
+ formula : str, optional
47
+ R-style formula for the model (e.g., "outcome ~ treated * post").
48
+ Passed to fit() rather than the constructor; if provided, overrides the column name parameters.
49
+ robust : bool, default=True
50
+ Whether to use heteroskedasticity-robust standard errors (HC1).
51
+ cluster : str, optional
52
+ Column name for cluster-robust standard errors.
53
+ alpha : float, default=0.05
54
+ Significance level for confidence intervals.
55
+ inference : str, default="analytical"
56
+ Inference method: "analytical" for standard asymptotic inference,
57
+ or "wild_bootstrap" for wild cluster bootstrap (recommended when
58
+ number of clusters is small, <50).
59
+ n_bootstrap : int, default=999
60
+ Number of bootstrap replications when inference="wild_bootstrap".
61
+ bootstrap_weights : str, default="rademacher"
62
+ Type of bootstrap weights: "rademacher" (standard), "webb"
63
+ (recommended for <10 clusters), or "mammen" (skewness correction).
64
+ seed : int, optional
65
+ Random seed for reproducibility when using bootstrap inference.
66
+ If None (default), results will vary between runs.
67
+
68
+ Attributes
69
+ ----------
70
+ results_ : DiDResults
71
+ Estimation results after calling fit().
72
+ is_fitted_ : bool
73
+ Whether the model has been fitted.
74
+
75
+ Examples
76
+ --------
77
+ Basic usage with a DataFrame:
78
+
79
+ >>> import pandas as pd
80
+ >>> from diff_diff import DifferenceInDifferences
81
+ >>>
82
+ >>> # Create sample data
83
+ >>> data = pd.DataFrame({
84
+ ... 'outcome': [10, 11, 15, 18, 9, 10, 12, 13],
85
+ ... 'treated': [1, 1, 1, 1, 0, 0, 0, 0],
86
+ ... 'post': [0, 0, 1, 1, 0, 0, 1, 1]
87
+ ... })
88
+ >>>
89
+ >>> # Fit the model
90
+ >>> did = DifferenceInDifferences()
91
+ >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')
92
+ >>>
93
+ >>> # View results
94
+ >>> print(results.att) # ATT estimate
95
+ >>> results.print_summary() # Full summary table
96
+
97
+ Using formula interface:
98
+
99
+ >>> did = DifferenceInDifferences()
100
+ >>> results = did.fit(data, formula='outcome ~ treated * post')
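+
+ Cluster-robust or wild cluster bootstrap inference (illustrative sketch;
+ assumes the data contain a 'state' column identifying clusters):
+
+ >>> did = DifferenceInDifferences(cluster='state', inference='wild_bootstrap',
+ ...                               n_bootstrap=999, seed=42)
+ >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')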
101
+
102
+ Notes
103
+ -----
104
+ The ATT is computed using the standard DiD formula:
105
+
106
+ ATT = (E[Y|D=1,T=1] - E[Y|D=1,T=0]) - (E[Y|D=0,T=1] - E[Y|D=0,T=0])
107
+
108
+ Or equivalently via OLS regression:
109
+
110
+ Y = α + β₁*D + β₂*T + β₃*(D×T) + ε
111
+
112
+ Where β₃ is the ATT.
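+
+ As a worked example with the sample data from the Examples section, the
+ group-period means are 10.5 (treated, pre), 16.5 (treated, post),
+ 9.5 (control, pre), and 12.5 (control, post), so
+ ATT = (16.5 - 10.5) - (12.5 - 9.5) = 3.0.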
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ robust: bool = True,
118
+ cluster: Optional[str] = None,
119
+ alpha: float = 0.05,
120
+ inference: str = "analytical",
121
+ n_bootstrap: int = 999,
122
+ bootstrap_weights: str = "rademacher",
123
+ seed: Optional[int] = None
124
+ ):
125
+ self.robust = robust
126
+ self.cluster = cluster
127
+ self.alpha = alpha
128
+ self.inference = inference
129
+ self.n_bootstrap = n_bootstrap
130
+ self.bootstrap_weights = bootstrap_weights
131
+ self.seed = seed
132
+
133
+ self.is_fitted_ = False
134
+ self.results_ = None
135
+ self._coefficients = None
136
+ self._vcov = None
137
+ self._bootstrap_results = None # Store WildBootstrapResults if used
138
+
139
+ def fit(
140
+ self,
141
+ data: pd.DataFrame,
142
+ outcome: Optional[str] = None,
143
+ treatment: Optional[str] = None,
144
+ time: Optional[str] = None,
145
+ formula: Optional[str] = None,
146
+ covariates: Optional[List[str]] = None,
147
+ fixed_effects: Optional[List[str]] = None,
148
+ absorb: Optional[List[str]] = None
149
+ ) -> DiDResults:
150
+ """
151
+ Fit the Difference-in-Differences model.
152
+
153
+ Parameters
154
+ ----------
155
+ data : pd.DataFrame
156
+ DataFrame containing the outcome, treatment, and time variables.
157
+ outcome : str, optional
158
+ Name of the outcome variable column. Required unless formula is provided.
159
+ treatment : str, optional
160
+ Name of the treatment group indicator column (0/1). Required unless formula is provided.
161
+ time : str, optional
162
+ Name of the post-treatment period indicator column (0/1). Required unless formula is provided.
163
+ formula : str, optional
164
+ R-style formula (e.g., "outcome ~ treated * post").
165
+ If provided, overrides outcome, treatment, and time parameters.
166
+ covariates : list, optional
167
+ List of covariate column names to include as linear controls.
168
+ fixed_effects : list, optional
169
+ List of categorical column names to include as fixed effects.
170
+ Creates dummy variables for each category (drops first level).
171
+ Use for low-dimensional fixed effects (e.g., industry, region).
172
+ absorb : list, optional
173
+ List of categorical column names for high-dimensional fixed effects.
174
+ Uses within-transformation (demeaning) instead of dummy variables.
175
+ More efficient for large numbers of categories (e.g., firm, individual).
176
+
177
+ Returns
178
+ -------
179
+ DiDResults
180
+ Object containing estimation results.
181
+
182
+ Raises
183
+ ------
184
+ ValueError
185
+ If required parameters are missing or data validation fails.
186
+
187
+ Examples
188
+ --------
189
+ Using fixed effects (dummy variables):
190
+
191
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
192
+ ... fixed_effects=['state', 'industry'])
193
+
194
+ Using absorbed fixed effects (within-transformation):
195
+
196
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
197
+ ... absorb=['firm_id'])
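+
+ Adding linear covariate controls (illustrative sketch; assumes 'age' and
+ 'size' columns exist in the data):
+
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
+ ...         covariates=['age', 'size'])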
198
+ """
199
+ # Reset bootstrap state from any previous fit so inference reporting is accurate
+ self._bootstrap_results = None
+
+ # Parse formula if provided
200
+ if formula is not None:
201
+ outcome, treatment, time, covariates = self._parse_formula(formula, data)
202
+ elif outcome is None or treatment is None or time is None:
203
+ raise ValueError(
204
+ "Must provide either 'formula' or all of 'outcome', 'treatment', and 'time'"
205
+ )
206
+
207
+ # Validate inputs
208
+ self._validate_data(data, outcome, treatment, time, covariates)
209
+
210
+ # Validate binary variables BEFORE any transformations
211
+ validate_binary(data[treatment].values, "treatment")
212
+ validate_binary(data[time].values, "time")
213
+
214
+ # Validate fixed effects and absorb columns
215
+ if fixed_effects:
216
+ for fe in fixed_effects:
217
+ if fe not in data.columns:
218
+ raise ValueError(f"Fixed effect column '{fe}' not found in data")
219
+ if absorb:
220
+ for ab in absorb:
221
+ if ab not in data.columns:
222
+ raise ValueError(f"Absorb column '{ab}' not found in data")
223
+
224
+ # Handle absorbed fixed effects (within-transformation)
225
+ working_data = data.copy()
226
+ absorbed_vars = []
227
+ n_absorbed_effects = 0
228
+
229
+ if absorb:
230
+ # Apply within-transformation for each absorbed variable
231
+ # Only demean the outcome and covariates, NOT the treatment/time indicators.
232
+ # Treatment is typically constant within unit and the post indicator constant
233
+ # within period, so demeaning them by the absorbed groups can zero them out.
234
+ vars_to_demean = [outcome] + (covariates or [])
235
+ for ab_var in absorb:
236
+ working_data, n_fe = demean_by_group(
237
+ working_data, vars_to_demean, ab_var, inplace=True
238
+ )
239
+ n_absorbed_effects += n_fe
240
+ absorbed_vars.append(ab_var)
241
+
242
+ # Extract variables (may be demeaned if absorb was used)
243
+ y = working_data[outcome].values.astype(float)
244
+ d = working_data[treatment].values.astype(float)
245
+ t = working_data[time].values.astype(float)
246
+
247
+ # Create interaction term
248
+ dt = d * t
249
+
250
+ # Build design matrix
251
+ X = np.column_stack([np.ones(len(y)), d, t, dt])
252
+ var_names = ["const", treatment, time, f"{treatment}:{time}"]
253
+
254
+ # Add covariates if provided
255
+ if covariates:
256
+ for cov in covariates:
257
+ X = np.column_stack([X, working_data[cov].values.astype(float)])
258
+ var_names.append(cov)
259
+
260
+ # Add fixed effects as dummy variables
261
+ if fixed_effects:
262
+ for fe in fixed_effects:
263
+ # Create dummies, drop first category to avoid multicollinearity
264
+ # Use working_data to be consistent with absorbed FE if both are used
265
+ dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
266
+ for col in dummies.columns:
267
+ X = np.column_stack([X, dummies[col].values.astype(float)])
268
+ var_names.append(col)
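+ # Illustrative note (hypothetical column): a 'state' fixed effect with
+ # categories {'CA', 'NY', 'TX'} adds dummy regressors state_NY and state_TX,
+ # with 'CA' absorbed into the intercept because drop_first=True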
269
+
270
+ # Extract ATT index (coefficient on interaction term)
271
+ att_idx = 3 # Index of interaction term
272
+ att_var_name = f"{treatment}:{time}"
273
+ assert var_names[att_idx] == att_var_name, (
274
+ f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, "
275
+ f"but found '{var_names[att_idx]}'"
276
+ )
277
+
278
+ # Always use LinearRegression for initial fit (unified code path)
279
+ # For wild bootstrap, we don't need cluster SEs from the initial fit
280
+ cluster_ids = data[self.cluster].values if self.cluster is not None else None
281
+ reg = LinearRegression(
282
+ include_intercept=False, # Intercept already in X
283
+ robust=self.robust,
284
+ cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None,
285
+ alpha=self.alpha,
286
+ ).fit(X, y, df_adjustment=n_absorbed_effects)
287
+
288
+ coefficients = reg.coefficients_
289
+ residuals = reg.residuals_
290
+ fitted = reg.fitted_values_
291
+ att = coefficients[att_idx]
292
+
293
+ # Get inference - either from bootstrap or analytical
294
+ if self.inference == "wild_bootstrap" and self.cluster is not None:
295
+ # Override with wild cluster bootstrap inference
296
+ se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
297
+ X, y, residuals, cluster_ids, att_idx
298
+ )
299
+ else:
300
+ # Use analytical inference from LinearRegression
301
+ vcov = reg.vcov_
302
+ inference = reg.get_inference(att_idx)
303
+ se = inference.se
304
+ t_stat = inference.t_stat
305
+ p_value = inference.p_value
306
+ conf_int = inference.conf_int
307
+
308
+ r_squared = compute_r_squared(y, residuals)
309
+
310
+ # Count observations
311
+ n_treated = int(np.sum(d))
312
+ n_control = int(np.sum(1 - d))
313
+
314
+ # Create coefficient dictionary
315
+ coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}
316
+
317
+ # Determine inference method and bootstrap info
318
+ inference_method = "analytical"
319
+ n_bootstrap_used = None
320
+ n_clusters_used = None
321
+ if self._bootstrap_results is not None:
322
+ inference_method = "wild_bootstrap"
323
+ n_bootstrap_used = self._bootstrap_results.n_bootstrap
324
+ n_clusters_used = self._bootstrap_results.n_clusters
325
+
326
+ # Store results
327
+ self.results_ = DiDResults(
328
+ att=att,
329
+ se=se,
330
+ t_stat=t_stat,
331
+ p_value=p_value,
332
+ conf_int=conf_int,
333
+ n_obs=len(y),
334
+ n_treated=n_treated,
335
+ n_control=n_control,
336
+ alpha=self.alpha,
337
+ coefficients=coef_dict,
338
+ vcov=vcov,
339
+ residuals=residuals,
340
+ fitted_values=fitted,
341
+ r_squared=r_squared,
342
+ inference_method=inference_method,
343
+ n_bootstrap=n_bootstrap_used,
344
+ n_clusters=n_clusters_used,
345
+ )
346
+
347
+ self._coefficients = coefficients
348
+ self._vcov = vcov
349
+ self.is_fitted_ = True
350
+
351
+ return self.results_
352
+
353
+ def _fit_ols(
354
+ self, X: np.ndarray, y: np.ndarray
355
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
356
+ """
357
+ Fit OLS regression.
358
+
359
+ This method is kept for backwards compatibility. Internally uses the
360
+ unified solve_ols from diff_diff.linalg for optimized computation.
361
+
362
+ Parameters
363
+ ----------
364
+ X : np.ndarray
365
+ Design matrix.
366
+ y : np.ndarray
367
+ Outcome vector.
368
+
369
+ Returns
370
+ -------
371
+ tuple
372
+ (coefficients, residuals, fitted_values, r_squared)
373
+ """
374
+ # Use unified OLS backend
375
+ coefficients, residuals, fitted, _ = solve_ols(
376
+ X, y, return_fitted=True, return_vcov=False
377
+ )
378
+ r_squared = compute_r_squared(y, residuals)
379
+
380
+ return coefficients, residuals, fitted, r_squared
381
+
382
+ def _run_wild_bootstrap_inference(
383
+ self,
384
+ X: np.ndarray,
385
+ y: np.ndarray,
386
+ residuals: np.ndarray,
387
+ cluster_ids: np.ndarray,
388
+ coefficient_index: int,
389
+ ) -> Tuple[float, float, Tuple[float, float], float, np.ndarray, WildBootstrapResults]:
390
+ """
391
+ Run wild cluster bootstrap inference.
392
+
393
+ Parameters
394
+ ----------
395
+ X : np.ndarray
396
+ Design matrix.
397
+ y : np.ndarray
398
+ Outcome vector.
399
+ residuals : np.ndarray
400
+ OLS residuals.
401
+ cluster_ids : np.ndarray
402
+ Cluster identifiers for each observation.
403
+ coefficient_index : int
404
+ Index of the coefficient to compute inference for.
405
+
406
+ Returns
407
+ -------
408
+ tuple
409
+ (se, p_value, conf_int, t_stat, vcov, bootstrap_results)
410
+ """
411
+ bootstrap_results = wild_bootstrap_se(
412
+ X, y, residuals, cluster_ids,
413
+ coefficient_index=coefficient_index,
414
+ n_bootstrap=self.n_bootstrap,
415
+ weight_type=self.bootstrap_weights,
416
+ alpha=self.alpha,
417
+ seed=self.seed,
418
+ return_distribution=False
419
+ )
420
+ self._bootstrap_results = bootstrap_results
421
+
422
+ se = bootstrap_results.se
423
+ p_value = bootstrap_results.p_value
424
+ conf_int = (bootstrap_results.ci_lower, bootstrap_results.ci_upper)
425
+ t_stat = bootstrap_results.t_stat_original
426
+
427
+ # Also compute vcov for storage (using cluster-robust for consistency)
428
+ vcov = compute_robust_vcov(X, residuals, cluster_ids)
429
+
430
+ return se, p_value, conf_int, t_stat, vcov, bootstrap_results
431
+
432
+ def _parse_formula(
433
+ self, formula: str, data: pd.DataFrame
434
+ ) -> Tuple[str, str, str, Optional[List[str]]]:
435
+ """
436
+ Parse R-style formula.
437
+
438
+ Supports basic formulas like:
439
+ - "outcome ~ treatment * time"
440
+ - "outcome ~ treatment + time + treatment:time"
441
+ - "outcome ~ treatment * time + covariate1 + covariate2"
442
+
443
+ Parameters
444
+ ----------
445
+ formula : str
446
+ R-style formula string.
447
+ data : pd.DataFrame
448
+ DataFrame to validate column names against.
449
+
450
+ Returns
451
+ -------
452
+ tuple
453
+ (outcome, treatment, time, covariates)
454
+ """
455
+ # Split into LHS and RHS
456
+ if "~" not in formula:
457
+ raise ValueError("Formula must contain '~' to separate outcome from predictors")
458
+
459
+ lhs, rhs = formula.split("~")
460
+ outcome = lhs.strip()
461
+
462
+ # Parse RHS
463
+ rhs = rhs.strip()
464
+
465
+ # Check for interaction term
466
+ if "*" in rhs:
467
+ # Handle "treatment * time" syntax
468
+ parts = rhs.split("*")
469
+ if len(parts) != 2:
470
+ raise ValueError("Currently only supports single interaction (treatment * time)")
471
+
472
+ treatment = parts[0].strip()
473
+ time = parts[1].strip()
474
+
475
+ # Check for additional covariates after interaction
476
+ if "+" in time:
477
+ time_parts = time.split("+")
478
+ time = time_parts[0].strip()
479
+ covariates = [p.strip() for p in time_parts[1:]]
480
+ else:
481
+ covariates = None
482
+
483
+ elif ":" in rhs:
484
+ # Handle explicit interaction syntax
485
+ terms = [t.strip() for t in rhs.split("+")]
486
+ interaction_term = None
487
+ main_effects = []
488
+ covariates = []
489
+
490
+ for term in terms:
491
+ if ":" in term:
492
+ interaction_term = term
493
+ else:
494
+ main_effects.append(term)
495
+
496
+ if interaction_term is None:
497
+ raise ValueError("Formula must contain an interaction term (treatment:time)")
498
+
499
+ treatment, time = [t.strip() for t in interaction_term.split(":")]
500
+
501
+ # Remaining terms after treatment and time are covariates
502
+ for term in main_effects:
503
+ if term != treatment and term != time:
504
+ covariates.append(term)
505
+
506
+ covariates = covariates if covariates else None
507
+ else:
508
+ raise ValueError(
509
+ "Formula must contain interaction term. "
510
+ "Use 'outcome ~ treatment * time' or 'outcome ~ treatment + time + treatment:time'"
511
+ )
512
+
513
+ # Validate columns exist
514
+ for col in [outcome, treatment, time]:
515
+ if col not in data.columns:
516
+ raise ValueError(f"Column '{col}' not found in data")
517
+
518
+ if covariates:
519
+ for cov in covariates:
520
+ if cov not in data.columns:
521
+ raise ValueError(f"Covariate '{cov}' not found in data")
522
+
523
+ return outcome, treatment, time, covariates
524
+
525
+ def _validate_data(
526
+ self,
527
+ data: pd.DataFrame,
528
+ outcome: str,
529
+ treatment: str,
530
+ time: str,
531
+ covariates: Optional[List[str]] = None
532
+ ) -> None:
533
+ """Validate input data."""
534
+ # Check DataFrame
535
+ if not isinstance(data, pd.DataFrame):
536
+ raise TypeError("data must be a pandas DataFrame")
537
+
538
+ # Check required columns exist
539
+ required_cols = [outcome, treatment, time]
540
+ if covariates:
541
+ required_cols.extend(covariates)
542
+
543
+ missing_cols = [col for col in required_cols if col not in data.columns]
544
+ if missing_cols:
545
+ raise ValueError(f"Missing columns in data: {missing_cols}")
546
+
547
+ # Check for missing values
548
+ for col in required_cols:
549
+ if data[col].isna().any():
550
+ raise ValueError(f"Column '{col}' contains missing values")
551
+
552
+ # Check for sufficient variation
553
+ if data[treatment].nunique() < 2:
554
+ raise ValueError("Treatment variable must have both 0 and 1 values")
555
+ if data[time].nunique() < 2:
556
+ raise ValueError("Time variable must have both 0 and 1 values")
557
+
558
+ def predict(self, data: pd.DataFrame) -> np.ndarray:
559
+ """
560
+ Predict outcomes using fitted model.
561
+
562
+ Parameters
563
+ ----------
564
+ data : pd.DataFrame
565
+ DataFrame with same structure as training data.
566
+
567
+ Returns
568
+ -------
569
+ np.ndarray
570
+ Predicted values.
571
+ """
572
+ if not self.is_fitted_:
573
+ raise RuntimeError("Model must be fitted before calling predict()")
574
+
575
+ # This is a placeholder - would need to store column names
576
+ # for full implementation
577
+ raise NotImplementedError(
578
+ "predict() is not yet implemented. "
579
+ "Use results_.fitted_values for training data predictions."
580
+ )
581
+
582
+ def get_params(self) -> Dict[str, Any]:
583
+ """
584
+ Get estimator parameters (sklearn-compatible).
585
+
586
+ Returns
587
+ -------
588
+ Dict[str, Any]
589
+ Estimator parameters.
590
+ """
591
+ return {
592
+ "robust": self.robust,
593
+ "cluster": self.cluster,
594
+ "alpha": self.alpha,
595
+ "inference": self.inference,
596
+ "n_bootstrap": self.n_bootstrap,
597
+ "bootstrap_weights": self.bootstrap_weights,
598
+ "seed": self.seed,
599
+ }
600
+
601
+ def set_params(self, **params) -> "DifferenceInDifferences":
602
+ """
603
+ Set estimator parameters (sklearn-compatible).
604
+
605
+ Parameters
606
+ ----------
607
+ **params
608
+ Estimator parameters.
609
+
610
+ Returns
611
+ -------
612
+ self
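+
+ Examples
+ --------
+ A minimal round-trip sketch (all parameters shown are existing
+ constructor attributes):
+
+ >>> did = DifferenceInDifferences()
+ >>> did = did.set_params(alpha=0.10, robust=False)
+ >>> did.get_params()['alpha']
+ 0.1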
613
+ """
614
+ for key, value in params.items():
615
+ if hasattr(self, key):
616
+ setattr(self, key, value)
617
+ else:
618
+ raise ValueError(f"Unknown parameter: {key}")
619
+ return self
620
+
621
+ def summary(self) -> str:
622
+ """
623
+ Get summary of estimation results.
624
+
625
+ Returns
626
+ -------
627
+ str
628
+ Formatted summary.
629
+ """
630
+ if not self.is_fitted_:
631
+ raise RuntimeError("Model must be fitted before calling summary()")
632
+ assert self.results_ is not None
633
+ return self.results_.summary()
634
+
635
+ def print_summary(self) -> None:
636
+ """Print summary to stdout."""
637
+ print(self.summary())
638
+
639
+
640
+ class MultiPeriodDiD(DifferenceInDifferences):
641
+ """
642
+ Multi-Period Difference-in-Differences estimator.
643
+
644
+ Extends the standard DiD to handle multiple pre-treatment and
645
+ post-treatment time periods, providing period-specific treatment
646
+ effects as well as an aggregate average treatment effect.
647
+
648
+ Parameters
649
+ ----------
650
+ robust : bool, default=True
651
+ Whether to use heteroskedasticity-robust standard errors (HC1).
652
+ cluster : str, optional
653
+ Column name for cluster-robust standard errors.
654
+ alpha : float, default=0.05
655
+ Significance level for confidence intervals.
656
+
657
+ Attributes
658
+ ----------
659
+ results_ : MultiPeriodDiDResults
660
+ Estimation results after calling fit().
661
+ is_fitted_ : bool
662
+ Whether the model has been fitted.
663
+
664
+ Examples
665
+ --------
666
+ Basic usage with multiple time periods:
667
+
668
+ >>> import pandas as pd
669
+ >>> from diff_diff import MultiPeriodDiD
670
+ >>>
671
+ >>> # Create sample panel data with 6 time periods
672
+ >>> # Periods 0-2 are pre-treatment, periods 3-5 are post-treatment
673
+ >>> data = create_panel_data() # Your data
674
+ >>>
675
+ >>> # Fit the model
676
+ >>> did = MultiPeriodDiD()
677
+ >>> results = did.fit(
678
+ ... data,
679
+ ... outcome='sales',
680
+ ... treatment='treated',
681
+ ... time='period',
682
+ ... post_periods=[3, 4, 5] # Specify which periods are post-treatment
683
+ ... )
684
+ >>>
685
+ >>> # View period-specific effects
686
+ >>> for period, effect in results.period_effects.items():
687
+ ... print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})")
688
+ >>>
689
+ >>> # View average treatment effect
690
+ >>> print(f"Average ATT: {results.avg_att:.3f}")
691
+
692
+ Notes
693
+ -----
694
+ The model estimates:
695
+
696
+ Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_t∈post δ_t*(D_i × Post_t) + ε_it
697
+
698
+ Where:
699
+ - D_i is the treatment indicator
700
+ - Period_t are time period dummies
701
+ - D_i × Post_t are treatment-by-post-period interactions
702
+ - δ_t are the period-specific treatment effects
703
+
704
+ The average ATT is computed as the mean of the δ_t coefficients.
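+
+ The standard error of the average effect accounts for the covariance between
+ the period-specific estimates: Var(avg ATT) = (1/m²) * Σ_jk Cov(δ_j, δ_k),
+ where m is the number of post-treatment periods.
+
+ Wild cluster bootstrap inference is not yet supported for this estimator;
+ if requested via the inherited constructor parameters, a warning is issued
+ and analytical inference is used instead.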
705
+ """
706
+
707
+ def fit( # type: ignore[override]
708
+ self,
709
+ data: pd.DataFrame,
710
+ outcome: str,
711
+ treatment: str,
712
+ time: str,
713
+ post_periods: Optional[List[Any]] = None,
714
+ covariates: Optional[List[str]] = None,
715
+ fixed_effects: Optional[List[str]] = None,
716
+ absorb: Optional[List[str]] = None,
717
+ reference_period: Any = None
718
+ ) -> MultiPeriodDiDResults:
719
+ """
720
+ Fit the Multi-Period Difference-in-Differences model.
721
+
722
+ Parameters
723
+ ----------
724
+ data : pd.DataFrame
725
+ DataFrame containing the outcome, treatment, and time variables.
726
+ outcome : str
727
+ Name of the outcome variable column.
728
+ treatment : str
729
+ Name of the treatment group indicator column (0/1).
730
+ time : str
731
+ Name of the time period column (can have multiple values).
732
+ post_periods : list, optional
733
+ List of time period values that are post-treatment; all others are pre-treatment.
734
+ If omitted, the later half of the sorted unique periods is used as post-treatment.
735
+ covariates : list, optional
736
+ List of covariate column names to include as linear controls.
737
+ fixed_effects : list, optional
738
+ List of categorical column names to include as fixed effects.
739
+ absorb : list, optional
740
+ List of categorical column names for high-dimensional fixed effects.
741
+ reference_period : any, optional
742
+ The reference (omitted) time period for the period dummies.
743
+ Defaults to the first pre-treatment period.
744
+
745
+ Returns
746
+ -------
747
+ MultiPeriodDiDResults
748
+ Object containing period-specific and average treatment effects.
749
+
750
+ Raises
751
+ ------
752
+ ValueError
753
+ If required parameters are missing or data validation fails.
754
+ """
755
+ # Warn if wild bootstrap is requested but not supported
756
+ if self.inference == "wild_bootstrap":
757
+ import warnings
758
+ warnings.warn(
759
+ "Wild bootstrap inference is not yet supported for MultiPeriodDiD. "
760
+ "Using analytical inference instead.",
761
+ UserWarning
762
+ )
763
+
764
+ # Validate basic inputs
765
+ if outcome is None or treatment is None or time is None:
766
+ raise ValueError(
767
+ "Must provide 'outcome', 'treatment', and 'time'"
768
+ )
769
+
770
+ # Validate columns exist
771
+ self._validate_data(data, outcome, treatment, time, covariates)
772
+
773
+ # Validate treatment is binary
774
+ validate_binary(data[treatment].values, "treatment")
775
+
776
+ # Get all unique time periods
777
+ all_periods = sorted(data[time].unique())
778
+
779
+ if len(all_periods) < 2:
780
+ raise ValueError("Time variable must have at least 2 unique periods")
781
+
782
+ # Determine pre and post periods
783
+ if post_periods is None:
784
+ # Default: last half of periods are post-treatment
785
+ mid_point = len(all_periods) // 2
786
+ post_periods = all_periods[mid_point:]
787
+ pre_periods = all_periods[:mid_point]
788
+ else:
789
+ post_periods = list(post_periods)
790
+ pre_periods = [p for p in all_periods if p not in post_periods]
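+ # Example of the default split: with periods [0, 1, 2, 3, 4, 5] and
+ # post_periods=None, mid_point is 3, so [3, 4, 5] are post-treatment
+ # and [0, 1, 2] are pre-treatment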
791
+
792
+ if len(post_periods) == 0:
793
+ raise ValueError("Must have at least one post-treatment period")
794
+
795
+ if len(pre_periods) == 0:
796
+ raise ValueError("Must have at least one pre-treatment period")
797
+
798
+ # Validate post_periods are in the data
799
+ for p in post_periods:
800
+ if p not in all_periods:
801
+ raise ValueError(f"Post-period '{p}' not found in time column")
802
+
803
+ # Determine reference period (omitted dummy)
804
+ if reference_period is None:
805
+ reference_period = pre_periods[0]
806
+ elif reference_period not in all_periods:
807
+ raise ValueError(f"Reference period '{reference_period}' not found in time column")
808
+
809
+ # Validate fixed effects and absorb columns
810
+ if fixed_effects:
811
+ for fe in fixed_effects:
812
+ if fe not in data.columns:
813
+ raise ValueError(f"Fixed effect column '{fe}' not found in data")
814
+ if absorb:
815
+ for ab in absorb:
816
+ if ab not in data.columns:
817
+ raise ValueError(f"Absorb column '{ab}' not found in data")
818
+
819
+ # Handle absorbed fixed effects (within-transformation)
820
+ working_data = data.copy()
821
+ n_absorbed_effects = 0
822
+
823
+ if absorb:
824
+ vars_to_demean = [outcome] + (covariates or [])
825
+ for ab_var in absorb:
826
+ working_data, n_fe = demean_by_group(
827
+ working_data, vars_to_demean, ab_var, inplace=True
828
+ )
829
+ n_absorbed_effects += n_fe
830
+
831
+ # Extract outcome and treatment
832
+ y = working_data[outcome].values.astype(float)
833
+ d = working_data[treatment].values.astype(float)
834
+ t = working_data[time].values
835
+
836
+ # Build design matrix
837
+ # Start with intercept and treatment main effect
838
+ X = np.column_stack([np.ones(len(y)), d])
839
+ var_names = ["const", treatment]
840
+
841
+ # Add period dummies (excluding reference period)
842
+ non_ref_periods = [p for p in all_periods if p != reference_period]
843
+ period_dummy_indices = {} # Map period -> column index in X
844
+
845
+ for period in non_ref_periods:
846
+ period_dummy = (t == period).astype(float)
847
+ X = np.column_stack([X, period_dummy])
848
+ var_names.append(f"period_{period}")
849
+ period_dummy_indices[period] = X.shape[1] - 1
850
+
851
+ # Add treatment × post-period interactions
852
+ # These are our coefficients of interest
853
+ interaction_indices = {} # Map post-period -> column index in X
854
+
855
+ for period in post_periods:
856
+ interaction = d * (t == period).astype(float)
857
+ X = np.column_stack([X, interaction])
858
+ var_names.append(f"{treatment}:period_{period}")
859
+ interaction_indices[period] = X.shape[1] - 1
860
+
861
+ # Add covariates if provided
862
+ if covariates:
863
+ for cov in covariates:
864
+ X = np.column_stack([X, working_data[cov].values.astype(float)])
865
+ var_names.append(cov)
866
+
867
+ # Add fixed effects as dummy variables
868
+ if fixed_effects:
869
+ for fe in fixed_effects:
870
+ dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
871
+ for col in dummies.columns:
872
+ X = np.column_stack([X, dummies[col].values.astype(float)])
873
+ var_names.append(col)
874
+
875
+ # Fit OLS using unified backend
876
+ coefficients, residuals, fitted, _ = solve_ols(
877
+ X, y, return_fitted=True, return_vcov=False
878
+ )
879
+ r_squared = compute_r_squared(y, residuals)
880
+
881
+ # Degrees of freedom
882
+ df = len(y) - X.shape[1] - n_absorbed_effects
883
+
884
+ # Compute standard errors
885
+ # Note: Wild bootstrap for multi-period effects is complex (multiple coefficients)
886
+ # For now, we use analytical inference even if inference="wild_bootstrap"
887
+ if self.cluster is not None:
888
+ cluster_ids = data[self.cluster].values
889
+ vcov = compute_robust_vcov(X, residuals, cluster_ids)
890
+ elif self.robust:
891
+ vcov = compute_robust_vcov(X, residuals)
892
+ else:
893
+ n = len(y)
894
+ k = X.shape[1]
895
+ mse = np.sum(residuals**2) / (n - k)
896
+ # Use solve() instead of inv() for numerical stability
897
+ # solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse
898
+ vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
899
+
900
+ # Extract period-specific treatment effects
901
+ period_effects = {}
902
+ effect_values = []
903
+ effect_indices = []
904
+
905
+ for period in post_periods:
906
+ idx = interaction_indices[period]
907
+ effect = coefficients[idx]
908
+ se = np.sqrt(vcov[idx, idx])
909
+ t_stat = effect / se
910
+ p_value = compute_p_value(t_stat, df=df)
911
+ conf_int = compute_confidence_interval(effect, se, self.alpha, df=df)
912
+
913
+ period_effects[period] = PeriodEffect(
914
+ period=period,
915
+ effect=effect,
916
+ se=se,
917
+ t_stat=t_stat,
918
+ p_value=p_value,
919
+ conf_int=conf_int
920
+ )
921
+ effect_values.append(effect)
922
+ effect_indices.append(idx)
923
+
924
+ # Compute average treatment effect
925
+ # Average ATT = mean of period-specific effects
926
+ avg_att = np.mean(effect_values)
927
+
928
+ # Standard error of average: need to account for covariance
929
+ # Var(avg) = (1/n^2) * sum of all elements in the sub-covariance matrix
930
+ n_post = len(post_periods)
931
+ sub_vcov = vcov[np.ix_(effect_indices, effect_indices)]
932
+ avg_var = np.sum(sub_vcov) / (n_post ** 2)
933
+ avg_se = np.sqrt(avg_var)
934
+
935
+ avg_t_stat = avg_att / avg_se if avg_se > 0 else 0.0
936
+ avg_p_value = compute_p_value(avg_t_stat, df=df)
937
+ avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df)
938
+
939
+ # Count observations
940
+ n_treated = int(np.sum(d))
941
+ n_control = int(np.sum(1 - d))
942
+
943
+ # Create coefficient dictionary
944
+ coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}
945
+
946
+ # Store results
947
+ self.results_ = MultiPeriodDiDResults(
948
+ period_effects=period_effects,
949
+ avg_att=avg_att,
950
+ avg_se=avg_se,
951
+ avg_t_stat=avg_t_stat,
952
+ avg_p_value=avg_p_value,
953
+ avg_conf_int=avg_conf_int,
954
+ n_obs=len(y),
955
+ n_treated=n_treated,
956
+ n_control=n_control,
957
+ pre_periods=pre_periods,
958
+ post_periods=post_periods,
959
+ alpha=self.alpha,
960
+ coefficients=coef_dict,
961
+ vcov=vcov,
962
+ residuals=residuals,
963
+ fitted_values=fitted,
964
+ r_squared=r_squared,
965
+ )
966
+
967
+ self._coefficients = coefficients
968
+ self._vcov = vcov
969
+ self.is_fitted_ = True
970
+
971
+ return self.results_
972
+
973
+ def summary(self) -> str:
974
+ """
975
+ Get summary of estimation results.
976
+
977
+ Returns
978
+ -------
979
+ str
980
+ Formatted summary.
981
+ """
982
+ if not self.is_fitted_:
983
+ raise RuntimeError("Model must be fitted before calling summary()")
984
+ assert self.results_ is not None
985
+ return self.results_.summary()
986
+
987
+
988
+ # Re-export estimators from submodules for backward compatibility
989
+ # These can also be imported directly from their respective modules:
990
+ # - from diff_diff.twfe import TwoWayFixedEffects
991
+ # - from diff_diff.synthetic_did import SyntheticDiD
992
+ from diff_diff.synthetic_did import SyntheticDiD # noqa: E402
993
+ from diff_diff.twfe import TwoWayFixedEffects # noqa: E402
994
+
995
+ __all__ = [
996
+ "DifferenceInDifferences",
997
+ "MultiPeriodDiD",
998
+ "TwoWayFixedEffects",
999
+ "SyntheticDiD",
1000
+ ]