diff-diff 2.2.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1047 @@
1
+ """
2
+ Difference-in-Differences estimators with sklearn-like API.
3
+
4
+ This module contains the core DiD estimators:
5
+ - DifferenceInDifferences: Basic 2x2 DiD estimator
6
+ - MultiPeriodDiD: Event-study style DiD with period-specific treatment effects
7
+
8
+ Additional estimators are in separate modules:
9
+ - TwoWayFixedEffects: See diff_diff.twfe
10
+ - SyntheticDiD: See diff_diff.synthetic_did
11
+
12
+ For backward compatibility, all estimators are re-exported from this module.
13
+ """
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ from diff_diff.linalg import (
21
+ LinearRegression,
22
+ compute_r_squared,
23
+ compute_robust_vcov,
24
+ solve_ols,
25
+ )
26
+ from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
27
+ from diff_diff.utils import (
28
+ WildBootstrapResults,
29
+ compute_confidence_interval,
30
+ compute_p_value,
31
+ demean_by_group,
32
+ validate_binary,
33
+ wild_bootstrap_se,
34
+ )
35
+
36
+
37
+ class DifferenceInDifferences:
38
+ """
39
+ Difference-in-Differences estimator with sklearn-like interface.
40
+
41
+ Estimates the Average Treatment effect on the Treated (ATT) using
42
+ the canonical 2x2 DiD design or panel data with two-way fixed effects.
43
+
44
+ Parameters
45
+ ----------
46
+     formula : str, optional
47
+         R-style formula (e.g., "outcome ~ treated * post") accepted by
48
+         ``fit()``, not by ``__init__``; when supplied it overrides the column-name parameters of ``fit()``.
49
+ robust : bool, default=True
50
+ Whether to use heteroskedasticity-robust standard errors (HC1).
51
+ cluster : str, optional
52
+ Column name for cluster-robust standard errors.
53
+ alpha : float, default=0.05
54
+ Significance level for confidence intervals.
55
+ inference : str, default="analytical"
56
+ Inference method: "analytical" for standard asymptotic inference,
57
+ or "wild_bootstrap" for wild cluster bootstrap (recommended when
58
+ number of clusters is small, <50).
59
+ n_bootstrap : int, default=999
60
+ Number of bootstrap replications when inference="wild_bootstrap".
61
+ bootstrap_weights : str, default="rademacher"
62
+ Type of bootstrap weights: "rademacher" (standard), "webb"
63
+ (recommended for <10 clusters), or "mammen" (skewness correction).
64
+ seed : int, optional
65
+ Random seed for reproducibility when using bootstrap inference.
66
+ If None (default), results will vary between runs.
67
+ rank_deficient_action : str, default "warn"
68
+ Action when design matrix is rank-deficient (linearly dependent columns):
69
+ - "warn": Issue warning and drop linearly dependent columns (default)
70
+ - "error": Raise ValueError
71
+ - "silent": Drop columns silently without warning
72
+
73
+ Attributes
74
+ ----------
75
+ results_ : DiDResults
76
+ Estimation results after calling fit().
77
+ is_fitted_ : bool
78
+ Whether the model has been fitted.
79
+
80
+ Examples
81
+ --------
82
+ Basic usage with a DataFrame:
83
+
84
+ >>> import pandas as pd
85
+ >>> from diff_diff import DifferenceInDifferences
86
+ >>>
87
+ >>> # Create sample data
88
+ >>> data = pd.DataFrame({
89
+ ... 'outcome': [10, 11, 15, 18, 9, 10, 12, 13],
90
+ ... 'treated': [1, 1, 1, 1, 0, 0, 0, 0],
91
+ ... 'post': [0, 0, 1, 1, 0, 0, 1, 1]
92
+ ... })
93
+ >>>
94
+ >>> # Fit the model
95
+ >>> did = DifferenceInDifferences()
96
+ >>> results = did.fit(data, outcome='outcome', treatment='treated', time='post')
97
+ >>>
98
+ >>> # View results
99
+ >>> print(results.att) # ATT estimate
100
+ >>> results.print_summary() # Full summary table
101
+
102
+ Using formula interface:
103
+
104
+ >>> did = DifferenceInDifferences()
105
+ >>> results = did.fit(data, formula='outcome ~ treated * post')
106
+
107
+ Notes
108
+ -----
109
+ The ATT is computed using the standard DiD formula:
110
+
111
+ ATT = (E[Y|D=1,T=1] - E[Y|D=1,T=0]) - (E[Y|D=0,T=1] - E[Y|D=0,T=0])
112
+
113
+ Or equivalently via OLS regression:
114
+
115
+ Y = α + β₁*D + β₂*T + β₃*(D×T) + ε
116
+
117
+ Where β₃ is the ATT.
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ robust: bool = True,
123
+ cluster: Optional[str] = None,
124
+ alpha: float = 0.05,
125
+ inference: str = "analytical",
126
+ n_bootstrap: int = 999,
127
+ bootstrap_weights: str = "rademacher",
128
+ seed: Optional[int] = None,
129
+ rank_deficient_action: str = "warn",
130
+ ):
131
+ self.robust = robust
132
+ self.cluster = cluster
133
+ self.alpha = alpha
134
+ self.inference = inference
135
+ self.n_bootstrap = n_bootstrap
136
+ self.bootstrap_weights = bootstrap_weights
137
+ self.seed = seed
138
+ self.rank_deficient_action = rank_deficient_action
139
+
140
+ self.is_fitted_ = False
141
+ self.results_ = None
142
+ self._coefficients = None
143
+ self._vcov = None
144
+ self._bootstrap_results = None # Store WildBootstrapResults if used
145
+
146
+ def fit(
147
+ self,
148
+ data: pd.DataFrame,
149
+ outcome: Optional[str] = None,
150
+ treatment: Optional[str] = None,
151
+ time: Optional[str] = None,
152
+ formula: Optional[str] = None,
153
+ covariates: Optional[List[str]] = None,
154
+ fixed_effects: Optional[List[str]] = None,
155
+ absorb: Optional[List[str]] = None
156
+ ) -> DiDResults:
157
+ """
158
+ Fit the Difference-in-Differences model.
159
+
160
+ Parameters
161
+ ----------
162
+ data : pd.DataFrame
163
+ DataFrame containing the outcome, treatment, and time variables.
164
+ outcome : str
165
+ Name of the outcome variable column.
166
+ treatment : str
167
+ Name of the treatment group indicator column (0/1).
168
+ time : str
169
+ Name of the post-treatment period indicator column (0/1).
170
+ formula : str, optional
171
+ R-style formula (e.g., "outcome ~ treated * post").
172
+ If provided, overrides outcome, treatment, and time parameters.
173
+ covariates : list, optional
174
+ List of covariate column names to include as linear controls.
175
+ fixed_effects : list, optional
176
+ List of categorical column names to include as fixed effects.
177
+ Creates dummy variables for each category (drops first level).
178
+ Use for low-dimensional fixed effects (e.g., industry, region).
179
+ absorb : list, optional
180
+ List of categorical column names for high-dimensional fixed effects.
181
+ Uses within-transformation (demeaning) instead of dummy variables.
182
+ More efficient for large numbers of categories (e.g., firm, individual).
183
+
184
+ Returns
185
+ -------
186
+ DiDResults
187
+ Object containing estimation results.
188
+
189
+ Raises
190
+ ------
191
+ ValueError
192
+ If required parameters are missing or data validation fails.
193
+
194
+ Examples
195
+ --------
196
+ Using fixed effects (dummy variables):
197
+
198
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
199
+ ... fixed_effects=['state', 'industry'])
200
+
201
+ Using absorbed fixed effects (within-transformation):
202
+
203
+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
204
+ ... absorb=['firm_id'])
205
+ """
206
+ # Parse formula if provided
207
+ if formula is not None:
208
+ outcome, treatment, time, covariates = self._parse_formula(formula, data)
209
+ elif outcome is None or treatment is None or time is None:
210
+ raise ValueError(
211
+ "Must provide either 'formula' or all of 'outcome', 'treatment', and 'time'"
212
+ )
213
+
214
+ # Validate inputs
215
+ self._validate_data(data, outcome, treatment, time, covariates)
216
+
217
+ # Validate binary variables BEFORE any transformations
218
+ validate_binary(data[treatment].values, "treatment")
219
+ validate_binary(data[time].values, "time")
220
+
221
+ # Validate fixed effects and absorb columns
222
+ if fixed_effects:
223
+ for fe in fixed_effects:
224
+ if fe not in data.columns:
225
+ raise ValueError(f"Fixed effect column '{fe}' not found in data")
226
+ if absorb:
227
+ for ab in absorb:
228
+ if ab not in data.columns:
229
+ raise ValueError(f"Absorb column '{ab}' not found in data")
230
+
231
+ # Handle absorbed fixed effects (within-transformation)
232
+ working_data = data.copy()
233
+ absorbed_vars = []
234
+ n_absorbed_effects = 0
235
+
236
+ if absorb:
237
+ # Apply within-transformation for each absorbed variable
238
+ # Only demean outcome and covariates, NOT treatment/time indicators
239
+ # Treatment is typically time-invariant (within unit), and time is
240
+ # unit-invariant, so demeaning them would create multicollinearity
241
+ vars_to_demean = [outcome] + (covariates or [])
242
+ for ab_var in absorb:
243
+ working_data, n_fe = demean_by_group(
244
+ working_data, vars_to_demean, ab_var, inplace=True
245
+ )
246
+ n_absorbed_effects += n_fe
247
+ absorbed_vars.append(ab_var)
248
+
249
+ # Extract variables (may be demeaned if absorb was used)
250
+ y = working_data[outcome].values.astype(float)
251
+ d = working_data[treatment].values.astype(float)
252
+ t = working_data[time].values.astype(float)
253
+
254
+ # Create interaction term
255
+ dt = d * t
256
+
257
+ # Build design matrix
258
+ X = np.column_stack([np.ones(len(y)), d, t, dt])
259
+ var_names = ["const", treatment, time, f"{treatment}:{time}"]
260
+
261
+ # Add covariates if provided
262
+ if covariates:
263
+ for cov in covariates:
264
+ X = np.column_stack([X, working_data[cov].values.astype(float)])
265
+ var_names.append(cov)
266
+
267
+ # Add fixed effects as dummy variables
268
+ if fixed_effects:
269
+ for fe in fixed_effects:
270
+ # Create dummies, drop first category to avoid multicollinearity
271
+ # Use working_data to be consistent with absorbed FE if both are used
272
+ dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
273
+ for col in dummies.columns:
274
+ X = np.column_stack([X, dummies[col].values.astype(float)])
275
+ var_names.append(col)
276
+
277
+ # Extract ATT index (coefficient on interaction term)
278
+ att_idx = 3 # Index of interaction term
279
+ att_var_name = f"{treatment}:{time}"
280
+ assert var_names[att_idx] == att_var_name, (
281
+ f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, "
282
+ f"but found '{var_names[att_idx]}'"
283
+ )
284
+
285
+ # Always use LinearRegression for initial fit (unified code path)
286
+ # For wild bootstrap, we don't need cluster SEs from the initial fit
287
+ cluster_ids = data[self.cluster].values if self.cluster is not None else None
288
+ reg = LinearRegression(
289
+ include_intercept=False, # Intercept already in X
290
+ robust=self.robust,
291
+ cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None,
292
+ alpha=self.alpha,
293
+ rank_deficient_action=self.rank_deficient_action,
294
+ ).fit(X, y, df_adjustment=n_absorbed_effects)
295
+
296
+ coefficients = reg.coefficients_
297
+ residuals = reg.residuals_
298
+ fitted = reg.fitted_values_
299
+ att = coefficients[att_idx]
300
+
301
+ # Get inference - either from bootstrap or analytical
302
+ if self.inference == "wild_bootstrap" and self.cluster is not None:
303
+ # Override with wild cluster bootstrap inference
304
+ se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
305
+ X, y, residuals, cluster_ids, att_idx
306
+ )
307
+ else:
308
+ # Use analytical inference from LinearRegression
309
+ vcov = reg.vcov_
310
+ inference = reg.get_inference(att_idx)
311
+ se = inference.se
312
+ t_stat = inference.t_stat
313
+ p_value = inference.p_value
314
+ conf_int = inference.conf_int
315
+
316
+ r_squared = compute_r_squared(y, residuals)
317
+
318
+ # Count observations
319
+ n_treated = int(np.sum(d))
320
+ n_control = int(np.sum(1 - d))
321
+
322
+ # Create coefficient dictionary
323
+ coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}
324
+
325
+ # Determine inference method and bootstrap info
326
+ inference_method = "analytical"
327
+ n_bootstrap_used = None
328
+ n_clusters_used = None
329
+ if self._bootstrap_results is not None:
330
+ inference_method = "wild_bootstrap"
331
+ n_bootstrap_used = self._bootstrap_results.n_bootstrap
332
+ n_clusters_used = self._bootstrap_results.n_clusters
333
+
334
+ # Store results
335
+ self.results_ = DiDResults(
336
+ att=att,
337
+ se=se,
338
+ t_stat=t_stat,
339
+ p_value=p_value,
340
+ conf_int=conf_int,
341
+ n_obs=len(y),
342
+ n_treated=n_treated,
343
+ n_control=n_control,
344
+ alpha=self.alpha,
345
+ coefficients=coef_dict,
346
+ vcov=vcov,
347
+ residuals=residuals,
348
+ fitted_values=fitted,
349
+ r_squared=r_squared,
350
+ inference_method=inference_method,
351
+ n_bootstrap=n_bootstrap_used,
352
+ n_clusters=n_clusters_used,
353
+ )
354
+
355
+ self._coefficients = coefficients
356
+ self._vcov = vcov
357
+ self.is_fitted_ = True
358
+
359
+ return self.results_
360
+
361
+ def _fit_ols(
362
+ self, X: np.ndarray, y: np.ndarray
363
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
364
+ """
365
+ Fit OLS regression.
366
+
367
+ This method is kept for backwards compatibility. Internally uses the
368
+ unified solve_ols from diff_diff.linalg for optimized computation.
369
+
370
+ Parameters
371
+ ----------
372
+ X : np.ndarray
373
+ Design matrix.
374
+ y : np.ndarray
375
+ Outcome vector.
376
+
377
+ Returns
378
+ -------
379
+ tuple
380
+ (coefficients, residuals, fitted_values, r_squared)
381
+ """
382
+ # Use unified OLS backend
383
+ coefficients, residuals, fitted, _ = solve_ols(
384
+ X, y, return_fitted=True, return_vcov=False
385
+ )
386
+ r_squared = compute_r_squared(y, residuals)
387
+
388
+ return coefficients, residuals, fitted, r_squared
389
+
390
+ def _run_wild_bootstrap_inference(
391
+ self,
392
+ X: np.ndarray,
393
+ y: np.ndarray,
394
+ residuals: np.ndarray,
395
+ cluster_ids: np.ndarray,
396
+ coefficient_index: int,
397
+ ) -> Tuple[float, float, Tuple[float, float], float, np.ndarray, WildBootstrapResults]:
398
+ """
399
+ Run wild cluster bootstrap inference.
400
+
401
+ Parameters
402
+ ----------
403
+ X : np.ndarray
404
+ Design matrix.
405
+ y : np.ndarray
406
+ Outcome vector.
407
+ residuals : np.ndarray
408
+ OLS residuals.
409
+ cluster_ids : np.ndarray
410
+ Cluster identifiers for each observation.
411
+ coefficient_index : int
412
+ Index of the coefficient to compute inference for.
413
+
414
+ Returns
415
+ -------
416
+ tuple
417
+ (se, p_value, conf_int, t_stat, vcov, bootstrap_results)
418
+ """
419
+ bootstrap_results = wild_bootstrap_se(
420
+ X, y, residuals, cluster_ids,
421
+ coefficient_index=coefficient_index,
422
+ n_bootstrap=self.n_bootstrap,
423
+ weight_type=self.bootstrap_weights,
424
+ alpha=self.alpha,
425
+ seed=self.seed,
426
+ return_distribution=False
427
+ )
428
+ self._bootstrap_results = bootstrap_results
429
+
430
+ se = bootstrap_results.se
431
+ p_value = bootstrap_results.p_value
432
+ conf_int = (bootstrap_results.ci_lower, bootstrap_results.ci_upper)
433
+ t_stat = bootstrap_results.t_stat_original
434
+
435
+ # Also compute vcov for storage (using cluster-robust for consistency)
436
+ vcov = compute_robust_vcov(X, residuals, cluster_ids)
437
+
438
+ return se, p_value, conf_int, t_stat, vcov, bootstrap_results
439
+
440
+ def _parse_formula(
441
+ self, formula: str, data: pd.DataFrame
442
+ ) -> Tuple[str, str, str, Optional[List[str]]]:
443
+ """
444
+ Parse R-style formula.
445
+
446
+ Supports basic formulas like:
447
+ - "outcome ~ treatment * time"
448
+ - "outcome ~ treatment + time + treatment:time"
449
+ - "outcome ~ treatment * time + covariate1 + covariate2"
450
+
451
+ Parameters
452
+ ----------
453
+ formula : str
454
+ R-style formula string.
455
+ data : pd.DataFrame
456
+ DataFrame to validate column names against.
457
+
458
+ Returns
459
+ -------
460
+ tuple
461
+ (outcome, treatment, time, covariates)
462
+ """
463
+ # Split into LHS and RHS
464
+ if "~" not in formula:
465
+ raise ValueError("Formula must contain '~' to separate outcome from predictors")
466
+
467
+ lhs, rhs = formula.split("~")
468
+ outcome = lhs.strip()
469
+
470
+ # Parse RHS
471
+ rhs = rhs.strip()
472
+
473
+ # Check for interaction term
474
+ if "*" in rhs:
475
+ # Handle "treatment * time" syntax
476
+ parts = rhs.split("*")
477
+ if len(parts) != 2:
478
+ raise ValueError("Currently only supports single interaction (treatment * time)")
479
+
480
+ treatment = parts[0].strip()
481
+ time = parts[1].strip()
482
+
483
+ # Check for additional covariates after interaction
484
+ if "+" in time:
485
+ time_parts = time.split("+")
486
+ time = time_parts[0].strip()
487
+ covariates = [p.strip() for p in time_parts[1:]]
488
+ else:
489
+ covariates = None
490
+
491
+ elif ":" in rhs:
492
+ # Handle explicit interaction syntax
493
+ terms = [t.strip() for t in rhs.split("+")]
494
+ interaction_term = None
495
+ main_effects = []
496
+ covariates = []
497
+
498
+ for term in terms:
499
+ if ":" in term:
500
+ interaction_term = term
501
+ else:
502
+ main_effects.append(term)
503
+
504
+ if interaction_term is None:
505
+ raise ValueError("Formula must contain an interaction term (treatment:time)")
506
+
507
+ treatment, time = [t.strip() for t in interaction_term.split(":")]
508
+
509
+ # Remaining terms after treatment and time are covariates
510
+ for term in main_effects:
511
+ if term != treatment and term != time:
512
+ covariates.append(term)
513
+
514
+ covariates = covariates if covariates else None
515
+ else:
516
+ raise ValueError(
517
+ "Formula must contain interaction term. "
518
+ "Use 'outcome ~ treatment * time' or 'outcome ~ treatment + time + treatment:time'"
519
+ )
520
+
521
+ # Validate columns exist
522
+ for col in [outcome, treatment, time]:
523
+ if col not in data.columns:
524
+ raise ValueError(f"Column '{col}' not found in data")
525
+
526
+ if covariates:
527
+ for cov in covariates:
528
+ if cov not in data.columns:
529
+ raise ValueError(f"Covariate '{cov}' not found in data")
530
+
531
+ return outcome, treatment, time, covariates
532
+
533
+ def _validate_data(
534
+ self,
535
+ data: pd.DataFrame,
536
+ outcome: str,
537
+ treatment: str,
538
+ time: str,
539
+ covariates: Optional[List[str]] = None
540
+ ) -> None:
541
+ """Validate input data."""
542
+ # Check DataFrame
543
+ if not isinstance(data, pd.DataFrame):
544
+ raise TypeError("data must be a pandas DataFrame")
545
+
546
+ # Check required columns exist
547
+ required_cols = [outcome, treatment, time]
548
+ if covariates:
549
+ required_cols.extend(covariates)
550
+
551
+ missing_cols = [col for col in required_cols if col not in data.columns]
552
+ if missing_cols:
553
+ raise ValueError(f"Missing columns in data: {missing_cols}")
554
+
555
+ # Check for missing values
556
+ for col in required_cols:
557
+ if data[col].isna().any():
558
+ raise ValueError(f"Column '{col}' contains missing values")
559
+
560
+ # Check for sufficient variation
561
+ if data[treatment].nunique() < 2:
562
+ raise ValueError("Treatment variable must have both 0 and 1 values")
563
+ if data[time].nunique() < 2:
564
+ raise ValueError("Time variable must have both 0 and 1 values")
565
+
566
+ def predict(self, data: pd.DataFrame) -> np.ndarray:
567
+ """
568
+ Predict outcomes using fitted model.
569
+
570
+ Parameters
571
+ ----------
572
+ data : pd.DataFrame
573
+ DataFrame with same structure as training data.
574
+
575
+ Returns
576
+ -------
577
+ np.ndarray
578
+ Predicted values.
579
+ """
580
+ if not self.is_fitted_:
581
+ raise RuntimeError("Model must be fitted before calling predict()")
582
+
583
+ # This is a placeholder - would need to store column names
584
+ # for full implementation
585
+ raise NotImplementedError(
586
+ "predict() is not yet implemented. "
587
+ "Use results_.fitted_values for training data predictions."
588
+ )
589
+
590
+ def get_params(self) -> Dict[str, Any]:
591
+ """
592
+ Get estimator parameters (sklearn-compatible).
593
+
594
+ Returns
595
+ -------
596
+ Dict[str, Any]
597
+ Estimator parameters.
598
+ """
599
+ return {
600
+ "robust": self.robust,
601
+ "cluster": self.cluster,
602
+ "alpha": self.alpha,
603
+ "inference": self.inference,
604
+ "n_bootstrap": self.n_bootstrap,
605
+ "bootstrap_weights": self.bootstrap_weights,
606
+ "seed": self.seed,
607
+ "rank_deficient_action": self.rank_deficient_action,
608
+ }
609
+
610
+ def set_params(self, **params) -> "DifferenceInDifferences":
611
+ """
612
+ Set estimator parameters (sklearn-compatible).
613
+
614
+ Parameters
615
+ ----------
616
+ **params
617
+ Estimator parameters.
618
+
619
+ Returns
620
+ -------
621
+ self
622
+ """
623
+ for key, value in params.items():
624
+ if hasattr(self, key):
625
+ setattr(self, key, value)
626
+ else:
627
+ raise ValueError(f"Unknown parameter: {key}")
628
+ return self
629
+
630
+ def summary(self) -> str:
631
+ """
632
+ Get summary of estimation results.
633
+
634
+ Returns
635
+ -------
636
+ str
637
+ Formatted summary.
638
+ """
639
+ if not self.is_fitted_:
640
+ raise RuntimeError("Model must be fitted before calling summary()")
641
+ assert self.results_ is not None
642
+ return self.results_.summary()
643
+
644
    def print_summary(self) -> None:
        """Print the formatted results summary to stdout.

        Delegates to summary(), which raises RuntimeError when the model
        has not been fitted yet.
        """
        print(self.summary())
647
+
648
+
649
+ class MultiPeriodDiD(DifferenceInDifferences):
650
+ """
651
+ Multi-Period Difference-in-Differences estimator.
652
+
653
+ Extends the standard DiD to handle multiple pre-treatment and
654
+ post-treatment time periods, providing period-specific treatment
655
+ effects as well as an aggregate average treatment effect.
656
+
657
+ Parameters
658
+ ----------
659
+ robust : bool, default=True
660
+ Whether to use heteroskedasticity-robust standard errors (HC1).
661
+ cluster : str, optional
662
+ Column name for cluster-robust standard errors.
663
+ alpha : float, default=0.05
664
+ Significance level for confidence intervals.
665
+
666
+ Attributes
667
+ ----------
668
+ results_ : MultiPeriodDiDResults
669
+ Estimation results after calling fit().
670
+ is_fitted_ : bool
671
+ Whether the model has been fitted.
672
+
673
+ Examples
674
+ --------
675
+ Basic usage with multiple time periods:
676
+
677
+ >>> import pandas as pd
678
+ >>> from diff_diff import MultiPeriodDiD
679
+ >>>
680
+ >>> # Create sample panel data with 6 time periods
681
+ >>> # Periods 0-2 are pre-treatment, periods 3-5 are post-treatment
682
+ >>> data = create_panel_data() # Your data
683
+ >>>
684
+ >>> # Fit the model
685
+ >>> did = MultiPeriodDiD()
686
+ >>> results = did.fit(
687
+ ... data,
688
+ ... outcome='sales',
689
+ ... treatment='treated',
690
+ ... time='period',
691
+ ... post_periods=[3, 4, 5] # Specify which periods are post-treatment
692
+ ... )
693
+ >>>
694
+ >>> # View period-specific effects
695
+ >>> for period, effect in results.period_effects.items():
696
+ ... print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})")
697
+ >>>
698
+ >>> # View average treatment effect
699
+ >>> print(f"Average ATT: {results.avg_att:.3f}")
700
+
701
+ Notes
702
+ -----
703
+ The model estimates:
704
+
705
+ Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_t∈post δ_t*(D_i × Post_t) + ε_it
706
+
707
+ Where:
708
+ - D_i is the treatment indicator
709
+ - Period_t are time period dummies
710
+ - D_i × Post_t are treatment-by-post-period interactions
711
+ - δ_t are the period-specific treatment effects
712
+
713
+ The average ATT is computed as the mean of the δ_t coefficients.
714
+ """
715
+
716
    def fit(  # type: ignore[override]
        self,
        data: pd.DataFrame,
        outcome: str,
        treatment: str,
        time: str,
        post_periods: Optional[List[Any]] = None,
        covariates: Optional[List[str]] = None,
        fixed_effects: Optional[List[str]] = None,
        absorb: Optional[List[str]] = None,
        reference_period: Any = None
    ) -> MultiPeriodDiDResults:
        """
        Fit the Multi-Period Difference-in-Differences model.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing the outcome, treatment, and time variables.
        outcome : str
            Name of the outcome variable column.
        treatment : str
            Name of the treatment group indicator column (0/1).
        time : str
            Name of the time period column (can have multiple values).
        post_periods : list
            List of time period values that are post-treatment.
            All other periods are treated as pre-treatment.
            If None, the later half of the sorted periods is used.
        covariates : list, optional
            List of covariate column names to include as linear controls.
        fixed_effects : list, optional
            List of categorical column names to include as fixed effects.
        absorb : list, optional
            List of categorical column names for high-dimensional fixed effects.
        reference_period : any, optional
            The reference (omitted) time period for the period dummies.
            Defaults to the first pre-treatment period.

        Returns
        -------
        MultiPeriodDiDResults
            Object containing period-specific and average treatment effects.

        Raises
        ------
        ValueError
            If required parameters are missing or data validation fails.
        """
        # Warn if wild bootstrap is requested but not supported
        if self.inference == "wild_bootstrap":
            import warnings
            warnings.warn(
                "Wild bootstrap inference is not yet supported for MultiPeriodDiD. "
                "Using analytical inference instead.",
                UserWarning
            )

        # Validate basic inputs
        if outcome is None or treatment is None or time is None:
            raise ValueError(
                "Must provide 'outcome', 'treatment', and 'time'"
            )

        # Validate columns exist
        self._validate_data(data, outcome, treatment, time, covariates)

        # Validate treatment is binary
        validate_binary(data[treatment].values, "treatment")

        # Get all unique time periods (sorted so "first half" / "last half"
        # below is well-defined)
        all_periods = sorted(data[time].unique())

        if len(all_periods) < 2:
            raise ValueError("Time variable must have at least 2 unique periods")

        # Determine pre and post periods
        if post_periods is None:
            # Default: last half of periods are post-treatment
            mid_point = len(all_periods) // 2
            post_periods = all_periods[mid_point:]
            pre_periods = all_periods[:mid_point]
        else:
            post_periods = list(post_periods)
            pre_periods = [p for p in all_periods if p not in post_periods]

        if len(post_periods) == 0:
            raise ValueError("Must have at least one post-treatment period")

        if len(pre_periods) == 0:
            raise ValueError("Must have at least one pre-treatment period")

        # Validate post_periods are in the data
        for p in post_periods:
            if p not in all_periods:
                raise ValueError(f"Post-period '{p}' not found in time column")

        # Determine reference period (omitted dummy)
        if reference_period is None:
            reference_period = pre_periods[0]
        elif reference_period not in all_periods:
            raise ValueError(f"Reference period '{reference_period}' not found in time column")

        # Validate fixed effects and absorb columns
        if fixed_effects:
            for fe in fixed_effects:
                if fe not in data.columns:
                    raise ValueError(f"Fixed effect column '{fe}' not found in data")
        if absorb:
            for ab in absorb:
                if ab not in data.columns:
                    raise ValueError(f"Absorb column '{ab}' not found in data")

        # Handle absorbed fixed effects (within-transformation).
        # Only the outcome and covariates are demeaned, mirroring
        # DifferenceInDifferences.fit.
        working_data = data.copy()
        n_absorbed_effects = 0

        if absorb:
            vars_to_demean = [outcome] + (covariates or [])
            for ab_var in absorb:
                working_data, n_fe = demean_by_group(
                    working_data, vars_to_demean, ab_var, inplace=True
                )
                n_absorbed_effects += n_fe

        # Extract outcome and treatment
        y = working_data[outcome].values.astype(float)
        d = working_data[treatment].values.astype(float)
        # NOTE: t keeps its original dtype (period labels may be non-numeric)
        t = working_data[time].values

        # Build design matrix
        # Start with intercept and treatment main effect
        X = np.column_stack([np.ones(len(y)), d])
        var_names = ["const", treatment]

        # Add period dummies (excluding reference period)
        non_ref_periods = [p for p in all_periods if p != reference_period]
        # Map period -> column index in X.
        # NOTE(review): not read after construction within this method; kept
        # for debugging/introspection.
        period_dummy_indices = {}

        for period in non_ref_periods:
            period_dummy = (t == period).astype(float)
            X = np.column_stack([X, period_dummy])
            var_names.append(f"period_{period}")
            period_dummy_indices[period] = X.shape[1] - 1

        # Add treatment × post-period interactions
        # These are our coefficients of interest
        interaction_indices = {}  # Map post-period -> column index in X

        for period in post_periods:
            interaction = d * (t == period).astype(float)
            X = np.column_stack([X, interaction])
            var_names.append(f"{treatment}:period_{period}")
            interaction_indices[period] = X.shape[1] - 1

        # Add covariates if provided
        if covariates:
            for cov in covariates:
                X = np.column_stack([X, working_data[cov].values.astype(float)])
                var_names.append(cov)

        # Add fixed effects as dummy variables
        if fixed_effects:
            for fe in fixed_effects:
                dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
                for col in dummies.columns:
                    X = np.column_stack([X, dummies[col].values.astype(float)])
                    var_names.append(col)

        # Fit OLS using unified backend
        # Pass cluster_ids to solve_ols for proper vcov computation
        # This handles rank-deficient matrices by returning NaN for dropped columns
        cluster_ids = data[self.cluster].values if self.cluster is not None else None

        # Note: Wild bootstrap for multi-period effects is complex (multiple coefficients)
        # For now, we use analytical inference even if inference="wild_bootstrap"
        coefficients, residuals, fitted, vcov = solve_ols(
            X, y,
            return_fitted=True,
            return_vcov=True,
            cluster_ids=cluster_ids,
            column_names=var_names,
            rank_deficient_action=self.rank_deficient_action,
        )
        r_squared = compute_r_squared(y, residuals)

        # Degrees of freedom using effective rank (non-NaN coefficients),
        # further adjusted for absorbed fixed effects
        k_effective = int(np.sum(~np.isnan(coefficients)))
        df = len(y) - k_effective - n_absorbed_effects

        # For non-robust, non-clustered case, we need homoskedastic vcov
        # solve_ols returns HC1 by default, so compute homoskedastic if needed
        if not self.robust and self.cluster is None:
            n = len(y)
            mse = np.sum(residuals**2) / (n - k_effective)
            # Use solve() instead of inv() for numerical stability
            # Only compute for identified columns (non-NaN coefficients)
            identified_mask = ~np.isnan(coefficients)
            if np.all(identified_mask):
                vcov = np.linalg.solve(X.T @ X, mse * np.eye(X.shape[1]))
            else:
                # For rank-deficient case, compute vcov on reduced matrix then expand
                X_reduced = X[:, identified_mask]
                vcov_reduced = np.linalg.solve(X_reduced.T @ X_reduced, mse * np.eye(X_reduced.shape[1]))
                # Expand to full size with NaN for dropped columns
                vcov = np.full((X.shape[1], X.shape[1]), np.nan)
                vcov[np.ix_(identified_mask, identified_mask)] = vcov_reduced

        # Extract period-specific treatment effects
        period_effects = {}
        effect_values = []
        effect_indices = []

        for period in post_periods:
            idx = interaction_indices[period]
            effect = coefficients[idx]
            se = np.sqrt(vcov[idx, idx])
            # NOTE(review): if se is exactly 0 this divides by zero (numpy
            # warns and yields inf/nan) — confirm upstream guarantees se > 0.
            t_stat = effect / se
            p_value = compute_p_value(t_stat, df=df)
            conf_int = compute_confidence_interval(effect, se, self.alpha, df=df)

            period_effects[period] = PeriodEffect(
                period=period,
                effect=effect,
                se=se,
                t_stat=t_stat,
                p_value=p_value,
                conf_int=conf_int
            )
            effect_values.append(effect)
            effect_indices.append(idx)

        # Compute average treatment effect
        # R-style NA propagation: if ANY period effect is NaN, average is undefined
        effect_arr = np.array(effect_values)

        if np.any(np.isnan(effect_arr)):
            # Some period effects are NaN (unidentified) - cannot compute valid average
            # This follows R's default behavior where mean(c(1, 2, NA)) returns NA
            avg_att = np.nan
            avg_se = np.nan
            avg_t_stat = np.nan
            avg_p_value = np.nan
            avg_conf_int = (np.nan, np.nan)
        else:
            # All effects identified - compute average normally
            avg_att = float(np.mean(effect_arr))

            # Standard error of average: need to account for covariance
            # Var(mean) = (1'Σ1) / k², where Σ is the sub-vcov of the effects
            n_post = len(post_periods)
            sub_vcov = vcov[np.ix_(effect_indices, effect_indices)]
            avg_var = np.sum(sub_vcov) / (n_post ** 2)

            if np.isnan(avg_var) or avg_var < 0:
                # Vcov has NaN (dropped columns) - propagate NaN
                avg_se = np.nan
                avg_t_stat = np.nan
                avg_p_value = np.nan
                avg_conf_int = (np.nan, np.nan)
            else:
                avg_se = float(np.sqrt(avg_var))
                if avg_se > 0:
                    avg_t_stat = avg_att / avg_se
                    avg_p_value = compute_p_value(avg_t_stat, df=df)
                    avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df)
                else:
                    # Zero SE (degenerate case)
                    avg_t_stat = np.nan
                    avg_p_value = np.nan
                    avg_conf_int = (np.nan, np.nan)

        # Count observations
        n_treated = int(np.sum(d))
        n_control = int(np.sum(1 - d))

        # Create coefficient dictionary
        coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}

        # Store results
        self.results_ = MultiPeriodDiDResults(
            period_effects=period_effects,
            avg_att=avg_att,
            avg_se=avg_se,
            avg_t_stat=avg_t_stat,
            avg_p_value=avg_p_value,
            avg_conf_int=avg_conf_int,
            n_obs=len(y),
            n_treated=n_treated,
            n_control=n_control,
            pre_periods=pre_periods,
            post_periods=post_periods,
            alpha=self.alpha,
            coefficients=coef_dict,
            vcov=vcov,
            residuals=residuals,
            fitted_values=fitted,
            r_squared=r_squared,
        )

        self._coefficients = coefficients
        self._vcov = vcov
        self.is_fitted_ = True

        return self.results_
1019
+
1020
+ def summary(self) -> str:
1021
+ """
1022
+ Get summary of estimation results.
1023
+
1024
+ Returns
1025
+ -------
1026
+ str
1027
+ Formatted summary.
1028
+ """
1029
+ if not self.is_fitted_:
1030
+ raise RuntimeError("Model must be fitted before calling summary()")
1031
+ assert self.results_ is not None
1032
+ return self.results_.summary()
1033
+
1034
+
1035
+ # Re-export estimators from submodules for backward compatibility
1036
+ # These can also be imported directly from their respective modules:
1037
+ # - from diff_diff.twfe import TwoWayFixedEffects
1038
+ # - from diff_diff.synthetic_did import SyntheticDiD
1039
+ from diff_diff.synthetic_did import SyntheticDiD # noqa: E402
1040
+ from diff_diff.twfe import TwoWayFixedEffects # noqa: E402
1041
+
1042
+ __all__ = [
1043
+ "DifferenceInDifferences",
1044
+ "MultiPeriodDiD",
1045
+ "TwoWayFixedEffects",
1046
+ "SyntheticDiD",
1047
+ ]