diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/results.py ADDED
@@ -0,0 +1,918 @@
1
+ """
2
+ Results classes for difference-in-differences estimation.
3
+
4
+ Provides statsmodels-style output with a more Pythonic interface.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+
14
+ def _format_survey_block(sm, width: int) -> list:
15
+ """Format survey design metadata block for summary() output.
16
+
17
+ Parameters
18
+ ----------
19
+ sm : SurveyMetadata
20
+ Survey metadata from results object.
21
+ width : int
22
+ Total width for separator lines and centering.
23
+ """
24
+ label_width = 30 if width >= 80 else 25
25
+ lines = [
26
+ "",
27
+ "-" * width,
28
+ "Survey Design".center(width),
29
+ "-" * width,
30
+ f"{'Weight type:':<{label_width}} {sm.weight_type:>10}",
31
+ ]
32
+ if getattr(sm, "replicate_method", None) is not None:
33
+ lines.append(f"{'Replicate method:':<{label_width}} {sm.replicate_method:>10}")
34
+ if getattr(sm, "n_replicates", None) is not None:
35
+ lines.append(f"{'Replicates:':<{label_width}} {sm.n_replicates:>10}")
36
+ else:
37
+ if sm.n_strata is not None:
38
+ lines.append(f"{'Strata:':<{label_width}} {sm.n_strata:>10}")
39
+ if sm.n_psu is not None:
40
+ lines.append(f"{'PSU/Cluster:':<{label_width}} {sm.n_psu:>10}")
41
+ lines.append(f"{'Effective sample size:':<{label_width}} {sm.effective_n:>10.1f}")
42
+ lines.append(f"{'Kish DEFF (weights):':<{label_width}} {sm.design_effect:>10.2f}")
43
+ if sm.df_survey is not None:
44
+ lines.append(f"{'Survey d.f.:':<{label_width}} {sm.df_survey:>10}")
45
+ lines.append("-" * width)
46
+ return lines
47
+
48
+
49
@dataclass
class DiDResults:
    """
    Results from a Difference-in-Differences estimation.

    Provides easy access to coefficients, standard errors, confidence intervals,
    and summary statistics in a Pythonic way.

    Attributes
    ----------
    att : float
        Average Treatment effect on the Treated (ATT).
    se : float
        Standard error of the ATT estimate.
    t_stat : float
        T-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT, computed at estimation time with
        ``alpha``.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units/observations.
    n_control : int
        Number of control units/observations.
    alpha : float
        Significance level used during estimation (default 0.05). Also used
        by ``is_significant`` and to label ``conf_int`` in ``summary()``.
    coefficients : dict[str, float], optional
        Named coefficient estimates, if provided by the estimator.
    vcov : np.ndarray, optional
        Variance-covariance matrix, if provided by the estimator.
    residuals, fitted_values : np.ndarray, optional
        Model residuals / fitted values, if provided by the estimator.
    r_squared : float, optional
        R-squared; printed by ``summary()`` when set.
    inference_method : str
        ``"analytical"`` by default; any other value is printed by ``summary()``.
    n_bootstrap, n_clusters : int, optional
        Bootstrap replications / cluster count, reported when set.
    bootstrap_distribution : np.ndarray, optional
        Bootstrap draws (excluded from the dataclass repr).
    survey_metadata : SurveyMetadata, optional
        Survey design metadata (from ``diff_diff.survey``).
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    alpha: float = 0.05
    coefficients: Optional[Dict[str, float]] = field(default=None)
    vcov: Optional[np.ndarray] = field(default=None)
    residuals: Optional[np.ndarray] = field(default=None)
    fitted_values: Optional[np.ndarray] = field(default=None)
    r_squared: Optional[float] = field(default=None)
    # Bootstrap inference fields
    inference_method: str = field(default="analytical")
    n_bootstrap: Optional[int] = field(default=None)
    n_clusters: Optional[int] = field(default=None)
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None)

    def __repr__(self) -> str:
        """Concise string representation."""
        return (
            f"DiDResults(ATT={self.att:.4f}{self.significance_stars}, "
            f"SE={self.se:.4f}, "
            f"p={self.p_value:.4f})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / |ATT|. NaN when ATT is 0 or SE non-finite."""
        if not (np.isfinite(self.se) and self.se >= 0):
            return np.nan
        if not np.isfinite(self.att) or self.att == 0:
            return np.nan
        return self.se / abs(self.att)

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Requested significance level. The stored ``conf_int`` was computed
            at estimation time with ``self.alpha`` and cannot be recomputed
            here, so it is always labelled with that level. If a different
            ``alpha`` is requested, a note saying so is printed below the
            interval rather than mislabelling it.

        Returns
        -------
        str
            Formatted summary table.
        """
        # Label the interval with the level it was actually computed at.
        # Rounding avoids float truncation (int((1 - 0.07) * 100) == 92,
        # int((1 - 0.025) * 100) == 97); ``:g`` still renders 95.0 as "95".
        conf_level = f"{round((1 - self.alpha) * 100, 6):g}"

        lines = [
            "=" * 70,
            "Difference-in-Differences Estimation Results".center(70),
            "=" * 70,
            "",
            f"{'Observations:':<25} {self.n_obs:>10}",
            f"{'Treated:':<25} {self.n_treated:>10}",
            f"{'Control:':<25} {self.n_control:>10}",
        ]

        if self.r_squared is not None:
            lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}")

        # Add survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, 70))

        # Add inference method info (only for non-default inference)
        if self.inference_method != "analytical":
            lines.append(f"{'Inference method:':<25} {self.inference_method:>10}")
            if self.n_bootstrap is not None:
                lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")
            if self.n_clusters is not None:
                lines.append(f"{'Number of clusters:':<25} {self.n_clusters:>10}")

        lines.extend(
            [
                "",
                "-" * 70,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}",
                "-" * 70,
                f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
                "-" * 70,
                "",
                f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
            ]
        )
        if alpha is not None and alpha != self.alpha:
            lines.append(
                f"(Requested alpha={alpha} not applied; interval uses alpha={self.alpha})"
            )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")

        # Add significance codes
        lines.extend(
            [
                "",
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 70,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all estimation results. ``n_bootstrap``,
            ``n_clusters`` and the survey keys are included only when set.
        """
        result = {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": self.conf_int[0],
            "conf_int_upper": self.conf_int[1],
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "r_squared": self.r_squared,
            "inference_method": self.inference_method,
        }
        if self.n_bootstrap is not None:
            result["n_bootstrap"] = self.n_bootstrap
        if self.n_clusters is not None:
            result["n_clusters"] = self.n_clusters
        if self.survey_metadata is not None:
            sm = self.survey_metadata
            result["weight_type"] = sm.weight_type
            result["effective_n"] = sm.effective_n
            result["design_effect"] = sm.design_effect
            result["sum_weights"] = sm.sum_weights
            result["n_strata"] = sm.n_strata
            result["n_psu"] = sm.n_psu
            result["df_survey"] = sm.df_survey
        return result

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Single-row DataFrame built from ``to_dict()``.
        """
        return pd.DataFrame([self.to_dict()])

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)
249
+
250
+
251
+ def _get_significance_stars(p_value: float) -> str:
252
+ """Return significance stars based on p-value.
253
+
254
+ Returns empty string for NaN p-values (unidentified coefficients from
255
+ rank-deficient matrices).
256
+ """
257
+ import numpy as np
258
+
259
+ if np.isnan(p_value):
260
+ return ""
261
+ if p_value < 0.001:
262
+ return "***"
263
+ elif p_value < 0.01:
264
+ return "**"
265
+ elif p_value < 0.05:
266
+ return "*"
267
+ elif p_value < 0.1:
268
+ return "."
269
+ return ""
270
+
271
+
272
@dataclass
class PeriodEffect:
    """
    Estimated treatment effect for one time period.

    Attributes
    ----------
    period : any
        Identifier of the time period.
    effect : float
        Point estimate of the treatment effect in this period.
    se : float
        Standard error of ``effect``.
    t_stat : float
        T-statistic of ``effect``.
    p_value : float
        P-value for the null hypothesis that the effect is zero.
    conf_int : tuple[float, float]
        (lower, upper) confidence interval bounds for the effect.
    """

    period: Any
    effect: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]

    @property
    def significance_stars(self) -> str:
        """Significance marker derived from ``p_value``."""
        return _get_significance_stars(self.p_value)

    @property
    def is_significant(self) -> bool:
        """True when ``p_value`` is below the fixed 0.05 level (False for NaN)."""
        return bool(0.05 > self.p_value)

    def __repr__(self) -> str:
        """Compact one-line representation with significance marker."""
        head = f"PeriodEffect(period={self.period}, effect={self.effect:.4f}"
        tail = f"SE={self.se:.4f}, p={self.p_value:.4f})"
        return f"{head}{self.significance_stars}, {tail}"
317
+
318
+
319
+ @dataclass
320
+ class MultiPeriodDiDResults:
321
+ """
322
+ Results from a Multi-Period Difference-in-Differences estimation.
323
+
324
+ Provides access to period-specific treatment effects as well as
325
+ an aggregate average treatment effect.
326
+
327
+ Attributes
328
+ ----------
329
+ period_effects : dict[any, PeriodEffect]
330
+ Dictionary mapping period identifiers to their PeriodEffect objects.
331
+ Contains all estimated period effects (pre and post, excluding
332
+ the reference period which is normalized to zero).
333
+ avg_att : float
334
+ Average Treatment effect on the Treated across post-periods only.
335
+ avg_se : float
336
+ Standard error of the average ATT.
337
+ avg_t_stat : float
338
+ T-statistic for the average ATT.
339
+ avg_p_value : float
340
+ P-value for the null hypothesis that average ATT = 0.
341
+ avg_conf_int : tuple[float, float]
342
+ Confidence interval for the average ATT.
343
+ n_obs : int
344
+ Number of observations used in estimation.
345
+ n_treated : int
346
+ Number of treated units/observations.
347
+ n_control : int
348
+ Number of control units/observations.
349
+ pre_periods : list
350
+ List of pre-treatment period identifiers.
351
+ post_periods : list
352
+ List of post-treatment period identifiers.
353
+ reference_period : any, optional
354
+ The reference (omitted) period. Its coefficient is zero by
355
+ construction and it is excluded from ``period_effects``.
356
+ interaction_indices : dict, optional
357
+ Mapping from period identifier to column index in the full
358
+ variance-covariance matrix. Used internally for sub-VCV
359
+ extraction (e.g., by HonestDiD and PreTrendsPower).
360
+ """
361
+
362
+ period_effects: Dict[Any, PeriodEffect]
363
+ avg_att: float
364
+ avg_se: float
365
+ avg_t_stat: float
366
+ avg_p_value: float
367
+ avg_conf_int: Tuple[float, float]
368
+ n_obs: int
369
+ n_treated: int
370
+ n_control: int
371
+ pre_periods: List[Any]
372
+ post_periods: List[Any]
373
+ alpha: float = 0.05
374
+ coefficients: Optional[Dict[str, float]] = field(default=None)
375
+ vcov: Optional[np.ndarray] = field(default=None)
376
+ residuals: Optional[np.ndarray] = field(default=None)
377
+ fitted_values: Optional[np.ndarray] = field(default=None)
378
+ r_squared: Optional[float] = field(default=None)
379
+ reference_period: Optional[Any] = field(default=None)
380
+ interaction_indices: Optional[Dict[Any, int]] = field(default=None, repr=False)
381
+ # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
382
+ survey_metadata: Optional[Any] = field(default=None)
383
+
384
+ def __repr__(self) -> str:
385
+ """Concise string representation."""
386
+ sig = _get_significance_stars(self.avg_p_value)
387
+ return (
388
+ f"MultiPeriodDiDResults(avg_ATT={self.avg_att:.4f}{sig}, "
389
+ f"SE={self.avg_se:.4f}, "
390
+ f"n_post_periods={len(self.post_periods)})"
391
+ )
392
+
393
+ @property
394
+ def pre_period_effects(self) -> Dict[Any, PeriodEffect]:
395
+ """Pre-period effects only (for parallel trends assessment)."""
396
+ return {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods}
397
+
398
+ @property
399
+ def post_period_effects(self) -> Dict[Any, PeriodEffect]:
400
+ """Post-period effects only."""
401
+ return {p: pe for p, pe in self.period_effects.items() if p in self.post_periods}
402
+
403
+ @property
404
+ def coef_var(self) -> float:
405
+ """Coefficient of variation: SE / |overall ATT|. NaN when ATT is 0 or SE non-finite."""
406
+ if not (np.isfinite(self.avg_se) and self.avg_se >= 0):
407
+ return np.nan
408
+ if not np.isfinite(self.avg_att) or self.avg_att == 0:
409
+ return np.nan
410
+ return self.avg_se / abs(self.avg_att)
411
+
412
+ def summary(self, alpha: Optional[float] = None) -> str:
413
+ """
414
+ Generate a formatted summary of the estimation results.
415
+
416
+ Parameters
417
+ ----------
418
+ alpha : float, optional
419
+ Significance level for confidence intervals. Defaults to the
420
+ alpha used during estimation.
421
+
422
+ Returns
423
+ -------
424
+ str
425
+ Formatted summary table.
426
+ """
427
+ alpha = alpha or self.alpha
428
+ conf_level = int((1 - alpha) * 100)
429
+
430
+ lines = [
431
+ "=" * 80,
432
+ "Multi-Period Difference-in-Differences Estimation Results".center(80),
433
+ "=" * 80,
434
+ "",
435
+ f"{'Observations:':<25} {self.n_obs:>10}",
436
+ f"{'Treated observations:':<25} {self.n_treated:>10}",
437
+ f"{'Control observations:':<25} {self.n_control:>10}",
438
+ f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
439
+ f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
440
+ ]
441
+
442
+ if self.r_squared is not None:
443
+ lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}")
444
+
445
+ # Add survey design info
446
+ if self.survey_metadata is not None:
447
+ sm = self.survey_metadata
448
+ lines.extend(_format_survey_block(sm, 80))
449
+
450
+ # Pre-period effects (parallel trends test)
451
+ pre_effects = {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods}
452
+ if pre_effects:
453
+ lines.extend(
454
+ [
455
+ "",
456
+ "-" * 80,
457
+ "Pre-Period Effects (Parallel Trends Test)".center(80),
458
+ "-" * 80,
459
+ f"{'Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
460
+ "-" * 80,
461
+ ]
462
+ )
463
+
464
+ for period in self.pre_periods:
465
+ if period in self.period_effects:
466
+ pe = self.period_effects[period]
467
+ stars = pe.significance_stars
468
+ lines.append(
469
+ f"{str(period):<15} {pe.effect:>12.4f} {pe.se:>12.4f} "
470
+ f"{pe.t_stat:>10.3f} {pe.p_value:>10.4f} {stars:>6}"
471
+ )
472
+
473
+ # Show reference period
474
+ if self.reference_period is not None:
475
+ lines.append(
476
+ f"[ref: {self.reference_period}]"
477
+ f"{'0.0000':>21} {'---':>12} {'---':>10} {'---':>10} {'':>6}"
478
+ )
479
+
480
+ lines.append("-" * 80)
481
+
482
+ # Post-period treatment effects
483
+ lines.extend(
484
+ [
485
+ "",
486
+ "-" * 80,
487
+ "Post-Period Treatment Effects".center(80),
488
+ "-" * 80,
489
+ f"{'Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
490
+ "-" * 80,
491
+ ]
492
+ )
493
+
494
+ for period in self.post_periods:
495
+ pe = self.period_effects[period]
496
+ stars = pe.significance_stars
497
+ lines.append(
498
+ f"{str(period):<15} {pe.effect:>12.4f} {pe.se:>12.4f} "
499
+ f"{pe.t_stat:>10.3f} {pe.p_value:>10.4f} {stars:>6}"
500
+ )
501
+
502
+ # Average effect
503
+ lines.extend(
504
+ [
505
+ "-" * 80,
506
+ "",
507
+ "-" * 80,
508
+ "Average Treatment Effect (across post-periods)".center(80),
509
+ "-" * 80,
510
+ f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
511
+ "-" * 80,
512
+ f"{'Avg ATT':<15} {self.avg_att:>12.4f} {self.avg_se:>12.4f} "
513
+ f"{self.avg_t_stat:>10.3f} {self.avg_p_value:>10.4f} {self.significance_stars:>6}",
514
+ "-" * 80,
515
+ "",
516
+ f"{conf_level}% Confidence Interval: [{self.avg_conf_int[0]:.4f}, {self.avg_conf_int[1]:.4f}]",
517
+ ]
518
+ )
519
+
520
+ cv = self.coef_var
521
+ if np.isfinite(cv):
522
+ lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")
523
+
524
+ # Add significance codes
525
+ lines.extend(
526
+ [
527
+ "",
528
+ "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
529
+ "=" * 80,
530
+ ]
531
+ )
532
+
533
+ return "\n".join(lines)
534
+
535
+ def print_summary(self, alpha: Optional[float] = None) -> None:
536
+ """Print the summary to stdout."""
537
+ print(self.summary(alpha))
538
+
539
+ def get_effect(self, period) -> PeriodEffect:
540
+ """
541
+ Get the treatment effect for a specific period.
542
+
543
+ Parameters
544
+ ----------
545
+ period : any
546
+ The period identifier.
547
+
548
+ Returns
549
+ -------
550
+ PeriodEffect
551
+ The treatment effect for the specified period.
552
+
553
+ Raises
554
+ ------
555
+ KeyError
556
+ If the period is not found in post-treatment periods.
557
+ """
558
+ if period not in self.period_effects:
559
+ if hasattr(self, "reference_period") and period == self.reference_period:
560
+ raise KeyError(
561
+ f"Period '{period}' is the reference period (coefficient "
562
+ f"normalized to zero by construction). Its effect is 0.0 with "
563
+ f"no associated uncertainty."
564
+ )
565
+ raise KeyError(
566
+ f"Period '{period}' not found. "
567
+ f"Available periods: {list(self.period_effects.keys())}"
568
+ )
569
+ return self.period_effects[period]
570
+
571
+ def to_dict(self) -> Dict[str, Any]:
572
+ """
573
+ Convert results to a dictionary.
574
+
575
+ Returns
576
+ -------
577
+ Dict[str, Any]
578
+ Dictionary containing all estimation results.
579
+ """
580
+ result: Dict[str, Any] = {
581
+ "avg_att": self.avg_att,
582
+ "avg_se": self.avg_se,
583
+ "avg_t_stat": self.avg_t_stat,
584
+ "avg_p_value": self.avg_p_value,
585
+ "avg_conf_int_lower": self.avg_conf_int[0],
586
+ "avg_conf_int_upper": self.avg_conf_int[1],
587
+ "n_obs": self.n_obs,
588
+ "n_treated": self.n_treated,
589
+ "n_control": self.n_control,
590
+ "n_pre_periods": len(self.pre_periods),
591
+ "n_post_periods": len(self.post_periods),
592
+ "r_squared": self.r_squared,
593
+ "reference_period": self.reference_period,
594
+ }
595
+
596
+ # Add period-specific effects
597
+ for period, pe in self.period_effects.items():
598
+ result[f"effect_period_{period}"] = pe.effect
599
+ result[f"se_period_{period}"] = pe.se
600
+ result[f"pval_period_{period}"] = pe.p_value
601
+
602
+ # Add survey metadata if present
603
+ if self.survey_metadata is not None:
604
+ sm = self.survey_metadata
605
+ result["weight_type"] = sm.weight_type
606
+ result["effective_n"] = sm.effective_n
607
+ result["design_effect"] = sm.design_effect
608
+ result["sum_weights"] = sm.sum_weights
609
+ result["n_strata"] = sm.n_strata
610
+ result["n_psu"] = sm.n_psu
611
+ result["df_survey"] = sm.df_survey
612
+
613
+ return result
614
+
615
+ def to_dataframe(self) -> pd.DataFrame:
616
+ """
617
+ Convert period-specific effects to a pandas DataFrame.
618
+
619
+ Returns
620
+ -------
621
+ pd.DataFrame
622
+ DataFrame with one row per estimated period (pre and post).
623
+ """
624
+ rows = []
625
+ for period, pe in self.period_effects.items():
626
+ rows.append(
627
+ {
628
+ "period": period,
629
+ "effect": pe.effect,
630
+ "se": pe.se,
631
+ "t_stat": pe.t_stat,
632
+ "p_value": pe.p_value,
633
+ "conf_int_lower": pe.conf_int[0],
634
+ "conf_int_upper": pe.conf_int[1],
635
+ "is_significant": pe.is_significant,
636
+ "is_post": period in self.post_periods,
637
+ }
638
+ )
639
+ return pd.DataFrame(rows)
640
+
641
+ @property
642
+ def is_significant(self) -> bool:
643
+ """Check if the average ATT is statistically significant at the alpha level."""
644
+ return bool(self.avg_p_value < self.alpha)
645
+
646
+ @property
647
+ def significance_stars(self) -> str:
648
+ """Return significance stars for the average ATT based on p-value."""
649
+ return _get_significance_stars(self.avg_p_value)
650
+
651
+
652
@dataclass
class SyntheticDiDResults:
    """
    Results from a Synthetic Difference-in-Differences estimation.

    Combines DiD with synthetic control by re-weighting control units to match
    pre-treatment trends of treated units.

    Attributes
    ----------
    att : float
        Average Treatment effect on the Treated (ATT).
    se : float
        Standard error of the ATT estimate (bootstrap or placebo-based).
    t_stat : float
        T-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT, computed at estimation time with
        ``alpha``.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units/observations.
    n_control : int
        Number of control units/observations.
    unit_weights : dict
        Dictionary mapping control unit IDs to their synthetic weights.
    time_weights : dict
        Dictionary mapping pre-treatment periods to their time weights.
    pre_periods : list
        List of pre-treatment period identifiers.
    post_periods : list
        List of post-treatment period identifiers.
    alpha : float
        Significance level used during estimation (default 0.05).
    variance_method : str
        Method used for variance estimation: "bootstrap" or "placebo".
    noise_level, zeta_omega, zeta_lambda : float, optional
        Noise level and the unit-/time-weight zeta values, printed by
        ``summary()`` when set.
    pre_treatment_fit : float, optional
        Pre-treatment fit, printed by ``summary()`` as RMSE.
    placebo_effects : np.ndarray, optional
        Placebo effect draws, if provided by the estimator.
    n_bootstrap : int, optional
        Bootstrap replications; printed when ``variance_method == "bootstrap"``.
    survey_metadata : SurveyMetadata, optional
        Survey design metadata (from ``diff_diff.survey``).
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    unit_weights: Dict[Any, float]
    time_weights: Dict[Any, float]
    pre_periods: List[Any]
    post_periods: List[Any]
    alpha: float = 0.05
    variance_method: str = field(default="placebo")
    noise_level: Optional[float] = field(default=None)
    zeta_omega: Optional[float] = field(default=None)
    zeta_lambda: Optional[float] = field(default=None)
    pre_treatment_fit: Optional[float] = field(default=None)
    placebo_effects: Optional[np.ndarray] = field(default=None)
    n_bootstrap: Optional[int] = field(default=None)
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.p_value)
        return (
            f"SyntheticDiDResults(ATT={self.att:.4f}{sig}, "
            f"SE={self.se:.4f}, "
            f"p={self.p_value:.4f})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / |ATT|. NaN when ATT is 0 or SE non-finite."""
        if not (np.isfinite(self.se) and self.se >= 0):
            return np.nan
        if not np.isfinite(self.att) or self.att == 0:
            return np.nan
        return self.se / abs(self.att)

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Requested significance level. The stored ``conf_int`` was computed
            at estimation time with ``self.alpha`` and cannot be recomputed
            here, so it is always labelled with that level. If a different
            ``alpha`` is requested, a note saying so is printed below the
            interval rather than mislabelling it.

        Returns
        -------
        str
            Formatted summary table.
        """
        # Label the interval with the level it was actually computed at.
        # Rounding avoids float truncation (int((1 - 0.07) * 100) == 92);
        # ``:g`` still renders 95.0 as "95".
        conf_level = f"{round((1 - self.alpha) * 100, 6):g}"

        lines = [
            "=" * 75,
            "Synthetic Difference-in-Differences Estimation Results".center(75),
            "=" * 75,
            "",
            f"{'Observations:':<25} {self.n_obs:>10}",
            f"{'Treated:':<25} {self.n_treated:>10}",
            f"{'Control:':<25} {self.n_control:>10}",
            f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
            f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
        ]

        if self.zeta_omega is not None:
            lines.append(f"{'Zeta (unit weights):':<25} {self.zeta_omega:>10.4f}")
        if self.zeta_lambda is not None:
            lines.append(f"{'Zeta (time weights):':<25} {self.zeta_lambda:>10.6f}")
        if self.noise_level is not None:
            lines.append(f"{'Noise level:':<25} {self.noise_level:>10.4f}")

        if self.pre_treatment_fit is not None:
            lines.append(f"{'Pre-treatment fit (RMSE):':<25} {self.pre_treatment_fit:>10.4f}")

        # Variance method info
        lines.append(f"{'Variance method:':<25} {self.variance_method:>10}")
        if self.variance_method == "bootstrap" and self.n_bootstrap is not None:
            lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")

        # Add survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, 75))

        lines.extend(
            [
                "",
                "-" * 75,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}",
                "-" * 75,
                f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
                "-" * 75,
                "",
                f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
            ]
        )
        if alpha is not None and alpha != self.alpha:
            lines.append(
                f"(Requested alpha={alpha} not applied; interval uses alpha={self.alpha})"
            )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/|ATT|):':<25} {cv:>10.4f}")

        # Show up to the five largest unit weights
        if self.unit_weights:
            sorted_weights = sorted(self.unit_weights.items(), key=lambda x: x[1], reverse=True)
            lines.extend(
                [
                    "",
                    "-" * 75,
                    "Top Unit Weights (Synthetic Control)".center(75),
                    "-" * 75,
                ]
            )
            for unit, weight in sorted_weights[:5]:
                if weight > 0.001:  # Only show meaningful weights
                    lines.append(f"  Unit {unit}: {weight:.4f}")

            # Show how many units have non-trivial weight
            n_nonzero = sum(1 for w in self.unit_weights.values() if w > 0.001)
            lines.append(f"  ({n_nonzero} units with weight > 0.001)")

        # Add significance codes
        lines.extend(
            [
                "",
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 75,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all estimation results. ``n_bootstrap`` and
            the survey keys are included only when set.
        """
        result = {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": self.conf_int[0],
            "conf_int_upper": self.conf_int[1],
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "n_pre_periods": len(self.pre_periods),
            "n_post_periods": len(self.post_periods),
            "variance_method": self.variance_method,
            "noise_level": self.noise_level,
            "zeta_omega": self.zeta_omega,
            "zeta_lambda": self.zeta_lambda,
            "pre_treatment_fit": self.pre_treatment_fit,
        }
        if self.n_bootstrap is not None:
            result["n_bootstrap"] = self.n_bootstrap
        if self.survey_metadata is not None:
            sm = self.survey_metadata
            result["weight_type"] = sm.weight_type
            result["effective_n"] = sm.effective_n
            result["design_effect"] = sm.design_effect
            result["sum_weights"] = sm.sum_weights
            result["n_strata"] = sm.n_strata
            result["n_psu"] = sm.n_psu
            result["df_survey"] = sm.df_survey
        return result

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Single-row DataFrame built from ``to_dict()``.
        """
        return pd.DataFrame([self.to_dict()])

    def get_unit_weights_df(self) -> pd.DataFrame:
        """
        Get unit weights as a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Columns ``unit`` and ``weight``, sorted by descending weight.
            Empty (but with both columns) when there are no unit weights.
        """
        # Explicit columns keep the schema when the dict is empty; without
        # them sort_values("weight") raises KeyError.
        return pd.DataFrame(
            list(self.unit_weights.items()), columns=["unit", "weight"]
        ).sort_values("weight", ascending=False)

    def get_time_weights_df(self) -> pd.DataFrame:
        """
        Get time weights as a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Columns ``period`` and ``weight`` in dictionary order. Empty (but
            with both columns) when there are no time weights.
        """
        return pd.DataFrame(list(self.time_weights.items()), columns=["period", "weight"])

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)