diff-diff 2.0.4__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1493 @@
1
+ """
2
+ Honest DiD sensitivity analysis (Rambachan & Roth 2023).
3
+
4
+ Provides robust inference for difference-in-differences designs when
5
+ parallel trends may be violated. Instead of assuming parallel trends
6
+ holds exactly, this module allows for bounded violations and computes
7
+ partially identified treatment effect bounds.
8
+
9
+ References
10
+ ----------
11
+ Rambachan, A., & Roth, J. (2023). A More Credible Approach to Parallel Trends.
12
+ The Review of Economic Studies, 90(5), 2555-2591.
13
+ https://doi.org/10.1093/restud/rdad018
14
+
15
+ See Also
16
+ --------
17
+ https://github.com/asheshrambachan/HonestDiD - R package implementation
18
+ """
19
+
20
+ from dataclasses import dataclass, field
21
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import pandas as pd
25
+ from scipy import optimize, stats
26
+
27
+ from diff_diff.results import (
28
+ MultiPeriodDiDResults,
29
+ )
30
+
31
+ # =============================================================================
32
+ # Delta Restriction Classes
33
+ # =============================================================================
34
+
35
+
36
@dataclass
class DeltaSD:
    """
    Smoothness restriction on trend violations (Delta^{SD}).

    The second differences of the violation path are bounded:
        |delta_{t+1} - 2*delta_t + delta_{t-1}| <= M

    With M=0 the violations must lie on a straight line, i.e. the
    pre-trend is extrapolated linearly; a larger M permits more
    curvature in the violation path.

    Parameters
    ----------
    M : float
        Upper bound on the absolute second difference (0 = linear only).

    Examples
    --------
    >>> delta = DeltaSD(M=0.5)
    >>> delta.M
    0.5
    """

    M: float = 0.0

    def __post_init__(self):
        # An absolute value can never be below zero, so a negative bound
        # would describe an empty restriction set.
        if self.M < 0:
            raise ValueError(f"M must be non-negative, got M={self.M}")

    def __repr__(self) -> str:
        return "DeltaSD(M={})".format(self.M)
68
+
69
+
70
@dataclass
class DeltaRM:
    """
    Relative magnitudes restriction on trend violations (Delta^{RM}).

    Each post-treatment violation is capped at Mbar times the largest
    absolute pre-treatment violation:
        |delta_post| <= Mbar * max(|delta_pre|)

    Mbar=0 imposes exact parallel trends after treatment; Mbar=1 lets a
    post-period violation be as large as the worst pre-period one.

    Parameters
    ----------
    Mbar : float
        Multiplier applied to the maximum pre-period violation.

    Examples
    --------
    >>> delta = DeltaRM(Mbar=1.0)
    >>> delta.Mbar
    1.0
    """

    Mbar: float = 1.0

    def __post_init__(self):
        # A negative multiplier would yield a negative cap on |delta_post|.
        if self.Mbar < 0:
            raise ValueError(f"Mbar must be non-negative, got Mbar={self.Mbar}")

    def __repr__(self) -> str:
        return "DeltaRM(Mbar={})".format(self.Mbar)
103
+
104
+
105
@dataclass
class DeltaSDRM:
    """
    Combined smoothness and relative magnitudes restriction.

    Both conditions must hold simultaneously:
      1. Smoothness: |delta_{t+1} - 2*delta_t + delta_{t-1}| <= M
      2. Relative magnitudes: |delta_post| <= Mbar * max(|delta_pre|)

    Because it is the intersection of the two sets, it is at least as
    restrictive as either one on its own.

    Parameters
    ----------
    M : float
        Upper bound on the absolute second difference (smoothness).
    Mbar : float
        Multiplier on the maximum pre-period violation (relative magnitudes).

    Examples
    --------
    >>> delta = DeltaSDRM(M=0.5, Mbar=1.0)
    """

    M: float = 0.0
    Mbar: float = 1.0

    def __post_init__(self):
        # Check M before Mbar so the first offending field is reported.
        for name in ("M", "Mbar"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {name}={value}")

    def __repr__(self) -> str:
        return "DeltaSDRM(M={}, Mbar={})".format(self.M, self.Mbar)
139
+
140
+
141
# Any one of the supported restriction specifications on trend violations.
DeltaType = Union[
    DeltaSD,
    DeltaRM,
    DeltaSDRM,
]
142
+
143
+
144
+ # =============================================================================
145
+ # Results Classes
146
+ # =============================================================================
147
+
148
+
149
@dataclass
class HonestDiDResults:
    """
    Results from Honest DiD sensitivity analysis.

    Contains bounds on the treatment effect under the specified
    restrictions on violations of parallel trends.

    Attributes
    ----------
    lb : float
        Lower bound of identified set.
    ub : float
        Upper bound of identified set.
    ci_lb : float
        Lower bound of robust confidence interval.
    ci_ub : float
        Upper bound of robust confidence interval.
    M : float
        The restriction parameter value used.
    method : str
        The type of restriction ("smoothness", "relative_magnitude", or "combined").
    original_estimate : float
        The original point estimate (under parallel trends).
    original_se : float
        The original standard error.
    alpha : float
        Significance level for confidence interval.
    ci_method : str
        Method used for CI construction ("FLCI" or "C-LF").
    original_results : Any
        The original estimation results object.
    event_study_bounds : dict, optional
        Optional per-period bounds, keyed by period identifier.
    """

    lb: float
    ub: float
    ci_lb: float
    ci_ub: float
    M: float
    method: str
    original_estimate: float
    original_se: float
    alpha: float = 0.05
    ci_method: str = "FLCI"
    original_results: Optional[Any] = field(default=None, repr=False)
    # Event study bounds (optional)
    event_study_bounds: Optional[Dict[Any, Dict[str, float]]] = field(
        default=None, repr=False
    )

    def __repr__(self) -> str:
        # Reuse significance_stars so repr and is_significant never disagree.
        return (
            f"HonestDiDResults(bounds=[{self.lb:.4f}, {self.ub:.4f}], "
            f"CI=[{self.ci_lb:.4f}, {self.ci_ub:.4f}]{self.significance_stars}, "
            f"M={self.M})"
        )

    @property
    def is_significant(self) -> bool:
        """
        Check if the CI excludes zero (effect is robust to violations).

        A CI with a NaN endpoint is treated as not significant. Every
        comparison with NaN is False, so without this guard a failed CI
        computation would be reported as excluding zero.
        """
        if np.isnan(self.ci_lb) or np.isnan(self.ci_ub):
            return False
        return not (self.ci_lb <= 0 <= self.ci_ub)

    @property
    def significance_stars(self) -> str:
        """
        Return significance indicator if robust CI excludes zero.

        Note: Unlike point estimation, partial identification does not yield
        a single p-value. This returns "*" if the robust CI excludes zero
        at the specified alpha level, indicating the effect is robust to
        the assumed violations of parallel trends.
        """
        return "*" if self.is_significant else ""

    @property
    def identified_set_width(self) -> float:
        """Width of the identified set."""
        return self.ub - self.lb

    @property
    def ci_width(self) -> float:
        """Width of the confidence interval."""
        return self.ci_ub - self.ci_lb

    def summary(self) -> str:
        """
        Generate formatted summary of sensitivity analysis results.

        Returns
        -------
        str
            Formatted summary.
        """
        # The ``:g`` format removes float noise, so alpha=0.07 prints "93"
        # rather than the int()-truncated "92". It also keeps genuine
        # fractional levels, so alpha=0.025 prints "97.5".
        conf_pct = (1 - self.alpha) * 100
        ci_label = f"{conf_pct:g}% Robust CI:"

        method_names = {
            "smoothness": "Smoothness (Delta^SD)",
            "relative_magnitude": "Relative Magnitudes (Delta^RM)",
            "combined": "Combined (Delta^SDRM)",
        }
        method_display = method_names.get(self.method, self.method)

        lines = [
            "=" * 70,
            "Honest DiD Sensitivity Analysis Results".center(70),
            "(Rambachan & Roth 2023)".center(70),
            "=" * 70,
            "",
            f"{'Method:':<30} {method_display}",
            f"{'Restriction parameter (M):':<30} {self.M:.4f}",
            f"{'CI method:':<30} {self.ci_method}",
            "",
            "-" * 70,
            "Original Estimate (under parallel trends)".center(70),
            "-" * 70,
            f"{'Point estimate:':<30} {self.original_estimate:.4f}",
            f"{'Standard error:':<30} {self.original_se:.4f}",
            "",
            "-" * 70,
            "Robust Results (allowing for violations)".center(70),
            "-" * 70,
            f"{'Identified set:':<30} [{self.lb:.4f}, {self.ub:.4f}]",
            f"{ci_label:<30} [{self.ci_lb:.4f}, {self.ci_ub:.4f}]",
            "",
            f"{'Effect robust to violations:':<30} {'Yes' if self.is_significant else 'No'}",
            "",
        ]

        # Interpretation
        lines.extend([
            "-" * 70,
            "Interpretation".center(70),
            "-" * 70,
        ])

        if self.method == "relative_magnitude":
            lines.append(
                f"Post-treatment violations bounded at {self.M:.1f}x max pre-period violation."
            )
        elif self.method == "smoothness":
            if self.M == 0:
                lines.append("Violations follow linear extrapolation of pre-trends.")
            else:
                lines.append(
                    f"Violation curvature (second diff) bounded by {self.M:.4f} per period."
                )
        else:
            lines.append(
                f"Combined smoothness (M={self.M:.2f}) and relative magnitude bounds."
            )

        if self.is_significant:
            # A significant CI lies entirely on one side of zero.
            if self.ci_lb > 0:
                lines.append(f"Effect remains POSITIVE even with violations up to M={self.M}.")
            else:
                lines.append(f"Effect remains NEGATIVE even with violations up to M={self.M}.")
        else:
            lines.append(
                f"Cannot rule out zero effect when allowing violations up to M={self.M}."
            )

        lines.extend(["", "=" * 70])

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert results to dictionary (excludes the stored result objects)."""
        return {
            "lb": self.lb,
            "ub": self.ub,
            "ci_lb": self.ci_lb,
            "ci_ub": self.ci_ub,
            "M": self.M,
            "method": self.method,
            "original_estimate": self.original_estimate,
            "original_se": self.original_se,
            "alpha": self.alpha,
            "ci_method": self.ci_method,
            "is_significant": self.is_significant,
            "identified_set_width": self.identified_set_width,
            "ci_width": self.ci_width,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to a one-row DataFrame."""
        return pd.DataFrame([self.to_dict()])
340
+
341
+
342
+ @dataclass
343
+ class SensitivityResults:
344
+ """
345
+ Results from sensitivity analysis over a grid of M values.
346
+
347
+ Contains bounds and confidence intervals for each M value,
348
+ plus the breakdown value.
349
+
350
+ Attributes
351
+ ----------
352
+ M_values : np.ndarray
353
+ Grid of M parameter values.
354
+ bounds : List[Tuple[float, float]]
355
+ List of (lb, ub) identified set bounds for each M.
356
+ robust_cis : List[Tuple[float, float]]
357
+ List of (ci_lb, ci_ub) robust CIs for each M.
358
+ breakdown_M : float
359
+ Smallest M where robust CI includes zero.
360
+ method : str
361
+ Type of restriction used.
362
+ original_estimate : float
363
+ Original point estimate.
364
+ original_se : float
365
+ Original standard error.
366
+ alpha : float
367
+ Significance level.
368
+ """
369
+
370
+ M_values: np.ndarray
371
+ bounds: List[Tuple[float, float]]
372
+ robust_cis: List[Tuple[float, float]]
373
+ breakdown_M: Optional[float]
374
+ method: str
375
+ original_estimate: float
376
+ original_se: float
377
+ alpha: float = 0.05
378
+
379
+ def __repr__(self) -> str:
380
+ breakdown_str = f"{self.breakdown_M:.4f}" if self.breakdown_M else "None"
381
+ return (
382
+ f"SensitivityResults(n_M={len(self.M_values)}, "
383
+ f"breakdown_M={breakdown_str})"
384
+ )
385
+
386
+ @property
387
+ def has_breakdown(self) -> bool:
388
+ """Check if there is a finite breakdown value."""
389
+ return self.breakdown_M is not None
390
+
391
+ def summary(self) -> str:
392
+ """Generate formatted summary."""
393
+ lines = [
394
+ "=" * 70,
395
+ "Honest DiD Sensitivity Analysis".center(70),
396
+ "=" * 70,
397
+ "",
398
+ f"{'Method:':<30} {self.method}",
399
+ f"{'Original estimate:':<30} {self.original_estimate:.4f}",
400
+ f"{'Original SE:':<30} {self.original_se:.4f}",
401
+ f"{'M values tested:':<30} {len(self.M_values)}",
402
+ "",
403
+ ]
404
+
405
+ if self.breakdown_M is not None:
406
+ lines.append(f"{'Breakdown value:':<30} {self.breakdown_M:.4f}")
407
+ lines.append("")
408
+ lines.append(
409
+ f"Result is robust to violations up to M = {self.breakdown_M:.4f}"
410
+ )
411
+ else:
412
+ lines.append(f"{'Breakdown value:':<30} None (always significant)")
413
+
414
+ lines.extend([
415
+ "",
416
+ "-" * 70,
417
+ f"{'M':<10} {'Lower Bound':>12} {'Upper Bound':>12} {'CI Lower':>12} {'CI Upper':>12}",
418
+ "-" * 70,
419
+ ])
420
+
421
+ for i, M in enumerate(self.M_values):
422
+ lb, ub = self.bounds[i]
423
+ ci_lb, ci_ub = self.robust_cis[i]
424
+ lines.append(f"{M:<10.4f} {lb:>12.4f} {ub:>12.4f} {ci_lb:>12.4f} {ci_ub:>12.4f}")
425
+
426
+ lines.extend(["", "=" * 70])
427
+
428
+ return "\n".join(lines)
429
+
430
+ def print_summary(self) -> None:
431
+ """Print summary to stdout."""
432
+ print(self.summary())
433
+
434
+ def to_dataframe(self) -> pd.DataFrame:
435
+ """Convert to DataFrame with one row per M value."""
436
+ rows = []
437
+ for i, M in enumerate(self.M_values):
438
+ lb, ub = self.bounds[i]
439
+ ci_lb, ci_ub = self.robust_cis[i]
440
+ rows.append({
441
+ "M": M,
442
+ "lb": lb,
443
+ "ub": ub,
444
+ "ci_lb": ci_lb,
445
+ "ci_ub": ci_ub,
446
+ "is_significant": not (ci_lb <= 0 <= ci_ub),
447
+ })
448
+ return pd.DataFrame(rows)
449
+
450
+ def plot(self, ax=None, show_bounds: bool = True, show_ci: bool = True,
451
+ breakdown_line: bool = True, **kwargs):
452
+ """
453
+ Plot sensitivity analysis results.
454
+
455
+ Parameters
456
+ ----------
457
+ ax : matplotlib.axes.Axes, optional
458
+ Axes to plot on. If None, creates new figure.
459
+ show_bounds : bool
460
+ Whether to show identified set bounds.
461
+ show_ci : bool
462
+ Whether to show confidence intervals.
463
+ breakdown_line : bool
464
+ Whether to show vertical line at breakdown value.
465
+ **kwargs
466
+ Additional arguments passed to plotting functions.
467
+
468
+ Returns
469
+ -------
470
+ ax : matplotlib.axes.Axes
471
+ The axes with the plot.
472
+ """
473
+ try:
474
+ import matplotlib.pyplot as plt
475
+ except ImportError:
476
+ raise ImportError("matplotlib is required for plotting")
477
+
478
+ if ax is None:
479
+ fig, ax = plt.subplots(figsize=(10, 6))
480
+
481
+ M = self.M_values
482
+ bounds_arr = np.array(self.bounds)
483
+ ci_arr = np.array(self.robust_cis)
484
+
485
+ # Plot original estimate
486
+ ax.axhline(y=self.original_estimate, color='black', linestyle='-',
487
+ linewidth=1.5, label='Original estimate', alpha=0.7)
488
+
489
+ # Plot zero line
490
+ ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5)
491
+
492
+ if show_bounds:
493
+ ax.fill_between(M, bounds_arr[:, 0], bounds_arr[:, 1],
494
+ alpha=0.3, color='blue', label='Identified set')
495
+
496
+ if show_ci:
497
+ ax.plot(M, ci_arr[:, 0], 'b-', linewidth=1.5, label='Robust CI')
498
+ ax.plot(M, ci_arr[:, 1], 'b-', linewidth=1.5)
499
+
500
+ if breakdown_line and self.breakdown_M is not None:
501
+ ax.axvline(x=self.breakdown_M, color='red', linestyle=':',
502
+ linewidth=2, label=f'Breakdown (M={self.breakdown_M:.2f})')
503
+
504
+ ax.set_xlabel('M (restriction parameter)')
505
+ ax.set_ylabel('Treatment Effect')
506
+ ax.set_title('Sensitivity Analysis: Treatment Effect Bounds')
507
+ ax.legend(loc='best')
508
+
509
+ return ax
510
+
511
+
512
+ # =============================================================================
513
+ # Helper Functions
514
+ # =============================================================================
515
+
516
+
517
def _extract_event_study_params(
    results: Union[MultiPeriodDiDResults, Any]
) -> Tuple[np.ndarray, np.ndarray, int, int, List[Any], List[Any]]:
    """
    Pull coefficients and their covariance out of an estimation result.

    Two result types are recognised:

    * ``MultiPeriodDiDResults``: only the post-period entries of
      ``period_effects`` are read, so ``beta_hat`` has one value per post
      period. ``results.vcov`` is returned unchanged when it is not None.
      Otherwise a diagonal matrix of squared standard errors is built.
    * ``CallawaySantAnnaResults`` (imported lazily from
      ``diff_diff.staggered``): every relative time in
      ``event_study_effects`` is used, sorted ascending. Negative times are
      pre-periods and non-negative times are post-periods. The covariance
      returned is always diagonal (squared SEs), so cross-period covariance
      is not carried over.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Estimation results with event study structure.

    Returns
    -------
    beta_hat : np.ndarray
        Event study coefficients, in the order described above.
    sigma : np.ndarray
        Variance-covariance matrix of the coefficients.
    num_pre_periods : int
        Number of pre-treatment periods.
    num_post_periods : int
        Number of post-treatment periods.
    pre_periods : list
        Pre-period identifiers.
    post_periods : list
        Post-period identifiers.

    Raises
    ------
    ValueError
        If a CallawaySantAnnaResults object has no event_study_effects.
    TypeError
        If ``results`` is not a supported type, or if
        ``diff_diff.staggered`` cannot be imported.
    """
    if isinstance(results, MultiPeriodDiDResults):
        pre_periods = results.pre_periods
        post_periods = results.post_periods

        # NOTE: pre-period coefficients are not read here; a full event study
        # would also need them. Only post-period effects are collected.
        effect_se_pairs = [
            (pe.effect, pe.se)
            for pe in (results.period_effects[p] for p in post_periods)
        ]
        beta_hat = np.array([effect for effect, _ in effect_se_pairs])
        n_pre = len(pre_periods) if pre_periods else 0

        if results.vcov is None:
            # No covariance available: assume independence across periods.
            sigma = np.diag(np.array([se for _, se in effect_se_pairs]) ** 2)
        else:
            sigma = results.vcov

        return beta_hat, sigma, n_pre, len(post_periods), pre_periods, post_periods

    try:
        from diff_diff.staggered import CallawaySantAnnaResults

        if isinstance(results, CallawaySantAnnaResults):
            event_effects = results.event_study_effects
            if event_effects is None:
                raise ValueError(
                    "CallawaySantAnnaResults must have event_study_effects for HonestDiD. "
                    "Re-run CallawaySantAnna.fit() with aggregate='event_study' to compute "
                    "event study effects."
                )

            rel_times = sorted(event_effects.keys())
            pre_times = [t for t in rel_times if t < 0]
            post_times = [t for t in rel_times if t >= 0]

            rows = [
                (event_effects[t]['effect'], event_effects[t]['se'])
                for t in rel_times
            ]
            beta_hat = np.array([effect for effect, _ in rows])
            sigma = np.diag(np.array([se for _, se in rows]) ** 2)

            return (
                beta_hat, sigma,
                len(pre_times), len(post_times),
                pre_times, post_times
            )
    except ImportError:
        # The staggered module is unavailable; fall through to the TypeError.
        pass

    raise TypeError(
        f"Unsupported results type: {type(results)}. "
        "Expected MultiPeriodDiDResults or CallawaySantAnnaResults."
    )
615
+
616
+
617
+ def _construct_A_sd(num_periods: int) -> np.ndarray:
618
+ """
619
+ Construct constraint matrix for smoothness (second differences).
620
+
621
+ For T periods, creates matrix A such that:
622
+ A @ delta gives the second differences.
623
+
624
+ Parameters
625
+ ----------
626
+ num_periods : int
627
+ Number of time periods.
628
+
629
+ Returns
630
+ -------
631
+ A : np.ndarray
632
+ Constraint matrix of shape (num_periods - 2, num_periods).
633
+ """
634
+ if num_periods < 3:
635
+ return np.zeros((0, num_periods))
636
+
637
+ n_constraints = num_periods - 2
638
+ A = np.zeros((n_constraints, num_periods))
639
+
640
+ for i in range(n_constraints):
641
+ # Second difference: delta_{t+1} - 2*delta_t + delta_{t-1}
642
+ A[i, i] = 1 # delta_{t-1}
643
+ A[i, i + 1] = -2 # delta_t
644
+ A[i, i + 2] = 1 # delta_{t+1}
645
+
646
+ return A
647
+
648
+
649
def _construct_constraints_sd(
    num_pre_periods: int,
    num_post_periods: int,
    M: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build the linear inequalities describing the smoothness set Delta^SD.

    The two-sided condition ``|D @ delta| <= M`` is rewritten as the
    one-sided stacked system ``[D; -D] @ delta <= M``. Here ``D`` is the
    second-difference operator over all pre and post periods. This is the
    ``A_ub`` / ``b_ub`` form expected by ``scipy.optimize.linprog``.

    Parameters
    ----------
    num_pre_periods : int
        Number of pre-treatment periods.
    num_post_periods : int
        Number of post-treatment periods.
    M : float
        Smoothness parameter (bound on each absolute second difference).

    Returns
    -------
    A_ineq : np.ndarray
        Stacked constraint matrix; (0, T) when T = total periods < 3.
    b_ineq : np.ndarray
        Right-hand side, every entry equal to ``M``; empty when T < 3.
    """
    n_total = num_pre_periods + num_post_periods
    second_diff = _construct_A_sd(n_total)
    n_rows = second_diff.shape[0]

    if not n_rows:
        return np.zeros((0, n_total)), np.zeros(0)

    stacked = np.concatenate((second_diff, -second_diff), axis=0)
    return stacked, np.full(2 * n_rows, M)
687
+
688
+
689
+ def _construct_constraints_rm(
690
+ num_pre_periods: int,
691
+ num_post_periods: int,
692
+ Mbar: float,
693
+ max_pre_violation: float
694
+ ) -> Tuple[np.ndarray, np.ndarray]:
695
+ """
696
+ Construct relative magnitudes constraint matrices.
697
+
698
+ Parameters
699
+ ----------
700
+ num_pre_periods : int
701
+ Number of pre-treatment periods.
702
+ num_post_periods : int
703
+ Number of post-treatment periods.
704
+ Mbar : float
705
+ Relative magnitude scaling factor.
706
+ max_pre_violation : float
707
+ Maximum absolute pre-period violation (estimated from data).
708
+
709
+ Returns
710
+ -------
711
+ A_ineq : np.ndarray
712
+ Inequality constraint matrix.
713
+ b_ineq : np.ndarray
714
+ Inequality constraint vector.
715
+ """
716
+ total_periods = num_pre_periods + num_post_periods
717
+
718
+ # Bound post-period violations: |delta_post| <= Mbar * max_pre_violation
719
+ bound = Mbar * max_pre_violation
720
+
721
+ # Create constraints for each post-period
722
+ # delta_post[i] <= bound and -delta_post[i] <= bound
723
+ n_constraints = 2 * num_post_periods
724
+ A_ineq = np.zeros((n_constraints, total_periods))
725
+ b_ineq = np.full(n_constraints, bound)
726
+
727
+ for i in range(num_post_periods):
728
+ post_idx = num_pre_periods + i
729
+ A_ineq[2 * i, post_idx] = 1 # delta <= bound
730
+ A_ineq[2 * i + 1, post_idx] = -1 # -delta <= bound
731
+
732
+ return A_ineq, b_ineq
733
+
734
+
735
+ def _solve_bounds_lp(
736
+ beta_post: np.ndarray,
737
+ l_vec: np.ndarray,
738
+ A_ineq: np.ndarray,
739
+ b_ineq: np.ndarray,
740
+ num_pre_periods: int,
741
+ lp_method: str = 'highs'
742
+ ) -> Tuple[float, float]:
743
+ """
744
+ Solve for identified set bounds using linear programming.
745
+
746
+ The parameter of interest is theta = l' @ (beta_post - delta_post).
747
+ We find min and max over delta in the constraint set.
748
+
749
+ Note: The optimization is over delta for ALL periods (pre + post), but
750
+ only the post-period components contribute to the objective function.
751
+ This correctly handles smoothness constraints that link pre and post periods.
752
+
753
+ Parameters
754
+ ----------
755
+ beta_post : np.ndarray
756
+ Post-period coefficient estimates.
757
+ l_vec : np.ndarray
758
+ Weighting vector for aggregation.
759
+ A_ineq : np.ndarray
760
+ Inequality constraint matrix (for all periods).
761
+ b_ineq : np.ndarray
762
+ Inequality constraint vector.
763
+ num_pre_periods : int
764
+ Number of pre-periods (for indexing).
765
+ lp_method : str
766
+ LP solver method for scipy.optimize.linprog. Default 'highs' requires
767
+ scipy >= 1.6.0. Alternatives: 'interior-point', 'revised simplex'.
768
+
769
+ Returns
770
+ -------
771
+ lb : float
772
+ Lower bound.
773
+ ub : float
774
+ Upper bound.
775
+ """
776
+ num_post = len(beta_post)
777
+ total_periods = A_ineq.shape[1] if A_ineq.shape[0] > 0 else num_pre_periods + num_post
778
+
779
+ # theta = l' @ beta_post - l' @ delta_post
780
+ # We optimize over delta (all periods including pre for smoothness constraints)
781
+
782
+ # Extract post-period part of constraints
783
+ # For delta in R^total_periods, we want min/max of -l' @ delta_post
784
+ # where delta_post = delta[num_pre_periods:]
785
+
786
+ c = np.zeros(total_periods)
787
+ c[num_pre_periods:num_pre_periods + num_post] = -l_vec # min -l'@delta = max l'@delta
788
+
789
+ # For upper bound: max l'@(beta - delta) = l'@beta + max(-l'@delta)
790
+ # For lower bound: min l'@(beta - delta) = l'@beta + min(-l'@delta)
791
+
792
+ if A_ineq.shape[0] == 0:
793
+ # No constraints - unbounded
794
+ return -np.inf, np.inf
795
+
796
+ # Solve for lower bound of -l'@delta (which gives upper bound of theta)
797
+ try:
798
+ result_min = optimize.linprog(
799
+ c, A_ub=A_ineq, b_ub=b_ineq,
800
+ bounds=(None, None),
801
+ method=lp_method
802
+ )
803
+ if result_min.success:
804
+ min_val = result_min.fun
805
+ else:
806
+ min_val = -np.inf
807
+ except (ValueError, TypeError):
808
+ # Optimization failed - return unbounded
809
+ min_val = -np.inf
810
+
811
+ # Solve for upper bound of -l'@delta (which gives lower bound of theta)
812
+ try:
813
+ result_max = optimize.linprog(
814
+ -c, A_ub=A_ineq, b_ub=b_ineq,
815
+ bounds=(None, None),
816
+ method=lp_method
817
+ )
818
+ if result_max.success:
819
+ max_val = -result_max.fun
820
+ else:
821
+ max_val = np.inf
822
+ except (ValueError, TypeError):
823
+ # Optimization failed - return unbounded
824
+ max_val = np.inf
825
+
826
+ theta_base = np.dot(l_vec, beta_post)
827
+ lb = theta_base + min_val # = l'@beta + min(-l'@delta) = min(l'@(beta-delta))
828
+ ub = theta_base + max_val # = l'@beta + max(-l'@delta) = max(l'@(beta-delta))
829
+
830
+ return lb, ub
831
+
832
+
833
+ def _compute_flci(
834
+ lb: float,
835
+ ub: float,
836
+ se: float,
837
+ alpha: float = 0.05
838
+ ) -> Tuple[float, float]:
839
+ """
840
+ Compute Fixed Length Confidence Interval (FLCI).
841
+
842
+ The FLCI extends the identified set by a critical value times
843
+ the standard error on each side.
844
+
845
+ Parameters
846
+ ----------
847
+ lb : float
848
+ Lower bound of identified set.
849
+ ub : float
850
+ Upper bound of identified set.
851
+ se : float
852
+ Standard error of the estimator.
853
+ alpha : float
854
+ Significance level.
855
+
856
+ Returns
857
+ -------
858
+ ci_lb : float
859
+ Lower bound of confidence interval.
860
+ ci_ub : float
861
+ Upper bound of confidence interval.
862
+
863
+ Raises
864
+ ------
865
+ ValueError
866
+ If se <= 0 or alpha is not in (0, 1).
867
+ """
868
+ if se <= 0:
869
+ raise ValueError(f"Standard error must be positive, got se={se}")
870
+ if not (0 < alpha < 1):
871
+ raise ValueError(f"alpha must be between 0 and 1, got alpha={alpha}")
872
+
873
+ z = stats.norm.ppf(1 - alpha / 2)
874
+ ci_lb = lb - z * se
875
+ ci_ub = ub + z * se
876
+ return ci_lb, ci_ub
877
+
878
+
879
+ def _compute_clf_ci(
880
+ beta_post: np.ndarray,
881
+ sigma_post: np.ndarray,
882
+ l_vec: np.ndarray,
883
+ Mbar: float,
884
+ max_pre_violation: float,
885
+ alpha: float = 0.05,
886
+ n_draws: int = 1000
887
+ ) -> Tuple[float, float, float, float]:
888
+ """
889
+ Compute Conditional Least Favorable (C-LF) confidence interval.
890
+
891
+ For relative magnitudes, accounts for estimation of max_pre_violation.
892
+
893
+ Parameters
894
+ ----------
895
+ beta_post : np.ndarray
896
+ Post-period coefficient estimates.
897
+ sigma_post : np.ndarray
898
+ Variance-covariance matrix for post-period coefficients.
899
+ l_vec : np.ndarray
900
+ Weighting vector.
901
+ Mbar : float
902
+ Relative magnitude parameter.
903
+ max_pre_violation : float
904
+ Estimated max pre-period violation.
905
+ alpha : float
906
+ Significance level.
907
+ n_draws : int
908
+ Number of Monte Carlo draws for conditional CI.
909
+
910
+ Returns
911
+ -------
912
+ lb : float
913
+ Lower bound of identified set.
914
+ ub : float
915
+ Upper bound of identified set.
916
+ ci_lb : float
917
+ Lower bound of confidence interval.
918
+ ci_ub : float
919
+ Upper bound of confidence interval.
920
+ """
921
+ # For simplicity, use FLCI approach with adjustment for estimation uncertainty
922
+ # A full implementation would condition on the estimated max_pre_violation
923
+
924
+ theta = np.dot(l_vec, beta_post)
925
+ se = np.sqrt(l_vec @ sigma_post @ l_vec)
926
+
927
+ bound = Mbar * max_pre_violation
928
+
929
+ # Simple bounds: theta +/- bound
930
+ lb = theta - bound
931
+ ub = theta + bound
932
+
933
+ # CI with estimation uncertainty
934
+ z = stats.norm.ppf(1 - alpha / 2)
935
+ ci_lb = lb - z * se
936
+ ci_ub = ub + z * se
937
+
938
+ return lb, ub, ci_lb, ci_ub
939
+
940
+
941
+ # =============================================================================
942
+ # Main Class
943
+ # =============================================================================
944
+
945
+
946
+ class HonestDiD:
947
+ """
948
+ Honest DiD sensitivity analysis (Rambachan & Roth 2023).
949
+
950
+ Computes robust inference for difference-in-differences allowing
951
+ for bounded violations of parallel trends.
952
+
953
+ Parameters
954
+ ----------
955
+ method : {"smoothness", "relative_magnitude", "combined"}
956
+ Type of restriction on trend violations:
957
+ - "smoothness": Bounds on second differences (Delta^SD)
958
+ - "relative_magnitude": Post violations <= M * max pre violation (Delta^RM)
959
+ - "combined": Both restrictions (Delta^SDRM)
960
+ M : float, optional
961
+ Restriction parameter. Interpretation depends on method:
962
+ - smoothness: Max second difference
963
+ - relative_magnitude: Scaling factor for max pre-period violation
964
+ Default is 1.0 for relative_magnitude, 0.0 for smoothness.
965
+ alpha : float
966
+ Significance level for confidence intervals.
967
+ l_vec : array-like or None
968
+ Weighting vector for scalar parameter (length = num_post_periods).
969
+ If None, uses uniform weights (average effect).
970
+
971
+ Examples
972
+ --------
973
+ >>> from diff_diff import MultiPeriodDiD
974
+ >>> from diff_diff.honest_did import HonestDiD
975
+ >>>
976
+ >>> # Fit event study
977
+ >>> mp_did = MultiPeriodDiD()
978
+ >>> results = mp_did.fit(data, outcome='y', treatment='treated',
979
+ ... time='period', post_periods=[4,5,6,7])
980
+ >>>
981
+ >>> # Sensitivity analysis with relative magnitudes
982
+ >>> honest = HonestDiD(method='relative_magnitude', M=1.0)
983
+ >>> bounds = honest.fit(results)
984
+ >>> print(bounds.summary())
985
+ >>>
986
+ >>> # Sensitivity curve over M values
987
+ >>> sensitivity = honest.sensitivity_analysis(results, M_grid=[0, 0.5, 1, 1.5, 2])
988
+ >>> sensitivity.plot()
989
+ """
990
+
991
+ def __init__(
992
+ self,
993
+ method: Literal["smoothness", "relative_magnitude", "combined"] = "relative_magnitude",
994
+ M: Optional[float] = None,
995
+ alpha: float = 0.05,
996
+ l_vec: Optional[np.ndarray] = None,
997
+ ):
998
+ self.method = method
999
+ self.alpha = alpha
1000
+ self.l_vec = l_vec
1001
+
1002
+ # Set default M based on method
1003
+ if M is None:
1004
+ self.M = 1.0 if method == "relative_magnitude" else 0.0
1005
+ else:
1006
+ self.M = M
1007
+
1008
+ self._validate_params()
1009
+
1010
+ def _validate_params(self):
1011
+ """Validate initialization parameters."""
1012
+ if self.method not in ["smoothness", "relative_magnitude", "combined"]:
1013
+ raise ValueError(
1014
+ f"method must be 'smoothness', 'relative_magnitude', or 'combined', "
1015
+ f"got method='{self.method}'"
1016
+ )
1017
+ if self.M < 0:
1018
+ raise ValueError(f"M must be non-negative, got M={self.M}")
1019
+ if not 0 < self.alpha < 1:
1020
+ raise ValueError(f"alpha must be between 0 and 1, got alpha={self.alpha}")
1021
+
1022
def get_params(self) -> Dict[str, Any]:
    """
    Return the estimator configuration as a dictionary.

    Returns
    -------
    dict
        Mapping with keys ``"method"``, ``"M"``, ``"alpha"`` and
        ``"l_vec"`` holding the current attribute values.
    """
    names = ("method", "M", "alpha", "l_vec")
    return {name: getattr(self, name) for name in names}
1030
+
1031
def set_params(self, **params) -> "HonestDiD":
    """
    Update estimator parameters in place.

    Only the names returned by ``get_params`` are accepted. The update is
    all-or-nothing:
    - If any name is unknown, nothing is changed.
    - If the new combination fails validation, the previous values are
      restored before the error propagates.

    Parameters
    ----------
    **params
        New values keyed by parameter name.

    Returns
    -------
    HonestDiD
        ``self``, to allow call chaining.

    Raises
    ------
    ValueError
        If a name is not a configurable parameter, or the resulting
        configuration is rejected by ``_validate_params``.
    """
    previous = self.get_params()

    # Check against get_params() instead of hasattr(): hasattr() also
    # matches methods such as ``fit``, which must not be overwritten.
    for key in params:
        if key not in previous:
            raise ValueError(f"Invalid parameter: {key}")

    for key, value in params.items():
        setattr(self, key, value)

    try:
        self._validate_params()
    except Exception:
        # Roll back so a rejected update leaves the estimator usable.
        for key, value in previous.items():
            setattr(self, key, value)
        raise
    return self
1040
+
1041
def fit(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
) -> HonestDiDResults:
    """
    Compute bounds and robust confidence intervals.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    M : float, optional
        Override the M parameter for this fit. Must be non-negative.
        Defaults to ``self.M``.

    Returns
    -------
    HonestDiDResults
        Results containing bounds and robust confidence intervals.

    Raises
    ------
    ValueError
        If ``M`` is negative or NaN, or if ``l_vec`` does not have one
        entry per post-period.
    """
    M = M if M is not None else self.M
    # The override bypasses _validate_params, so check it here. NaN fails
    # every comparison and is therefore rejected as well.
    if not M >= 0:
        raise ValueError(f"M must be non-negative, got M={M}")

    # Extract event study parameters
    (beta_hat, sigma, num_pre, num_post,
     pre_periods, _post_periods) = _extract_event_study_params(results)

    # beta_hat may hold only post-period effects or the full event study
    if len(beta_hat) == num_post:
        beta_post = beta_hat
    elif len(beta_hat) == num_pre + num_post:
        beta_post = beta_hat[num_pre:]
    else:
        # Unrecognized layout: treat every coefficient as a post-period effect
        beta_post = beta_hat
        num_post = len(beta_hat)

    # Covariance block matching the post-period coefficients
    if sigma.shape[0] == num_post and sigma.shape[0] == len(beta_post):
        sigma_post = sigma
    elif sigma.shape[0] == num_pre + num_post:
        sigma_post = sigma[num_pre:, num_pre:]
    else:
        # Fallback: leading square block sized to beta_post
        sigma_post = sigma[:len(beta_post), :len(beta_post)]

    # Update num_post to match actual data
    num_post = len(beta_post)

    # Set up weighting vector
    if self.l_vec is None:
        l_vec = np.ones(num_post) / num_post  # Uniform weights
    else:
        # Flatten so row/column vectors such as shape (1, n) are accepted
        l_vec = np.asarray(self.l_vec, dtype=float).ravel()
        if len(l_vec) != num_post:
            raise ValueError(
                f"l_vec must have length {num_post}, got {len(l_vec)}"
            )

    # Original (non-robust) estimate and SE of the scalar target l_vec' beta
    original_estimate = np.dot(l_vec, beta_post)
    original_se = np.sqrt(l_vec @ sigma_post @ l_vec)

    # Compute bounds based on method
    if self.method == "smoothness":
        lb, ub, ci_lb, ci_ub = self._compute_smoothness_bounds(
            beta_post, sigma_post, l_vec, num_pre, num_post, M
        )
        ci_method = "FLCI"

    elif self.method == "relative_magnitude":
        lb, ub, ci_lb, ci_ub = self._compute_rm_bounds(
            beta_post, sigma_post, l_vec, num_pre, num_post, M,
            pre_periods, results
        )
        ci_method = "C-LF"

    else:  # combined
        lb, ub, ci_lb, ci_ub = self._compute_combined_bounds(
            beta_post, sigma_post, l_vec, num_pre, num_post, M,
            pre_periods, results
        )
        ci_method = "FLCI"

    return HonestDiDResults(
        lb=lb,
        ub=ub,
        ci_lb=ci_lb,
        ci_ub=ci_ub,
        M=M,
        method=self.method,
        original_estimate=original_estimate,
        original_se=original_se,
        alpha=self.alpha,
        ci_method=ci_method,
        original_results=results,
    )
1140
+
1141
def _compute_smoothness_bounds(
    self,
    beta_post: np.ndarray,
    sigma_post: np.ndarray,
    l_vec: np.ndarray,
    num_pre: int,
    num_post: int,
    M: float
) -> Tuple[float, float, float, float]:
    """
    Identified-set bounds and FLCI for ``l_vec' beta`` under smoothness.

    The constraint system comes from ``_construct_constraints_sd``. The
    bounds come from the LP in ``_solve_bounds_lp``, and the interval is
    widened by ``_compute_flci`` using the standard error of
    ``l_vec' beta_post``.

    Returns
    -------
    tuple of float
        ``(lb, ub, ci_lb, ci_ub)``.
    """
    constraint_matrix, constraint_rhs = _construct_constraints_sd(num_pre, num_post, M)
    lower, upper = _solve_bounds_lp(
        beta_post, l_vec, constraint_matrix, constraint_rhs, num_pre
    )

    # Standard error of the scalar target l_vec' beta_post
    target_se = np.sqrt(l_vec @ sigma_post @ l_vec)
    ci_lower, ci_upper = _compute_flci(lower, upper, target_se, self.alpha)

    return lower, upper, ci_lower, ci_upper
1162
+
1163
+ def _compute_rm_bounds(
1164
+ self,
1165
+ beta_post: np.ndarray,
1166
+ sigma_post: np.ndarray,
1167
+ l_vec: np.ndarray,
1168
+ num_pre: int,
1169
+ num_post: int,
1170
+ Mbar: float,
1171
+ pre_periods: List,
1172
+ results: Any
1173
+ ) -> Tuple[float, float, float, float]:
1174
+ """Compute bounds under relative magnitudes restriction."""
1175
+ # Estimate max pre-period violation from pre-trends
1176
+ # For relative magnitudes, we use the pre-period coefficients
1177
+ max_pre_violation = self._estimate_max_pre_violation(results, pre_periods)
1178
+
1179
+ if max_pre_violation == 0:
1180
+ # No pre-period violations detected - use point estimate
1181
+ theta = np.dot(l_vec, beta_post)
1182
+ se = np.sqrt(l_vec @ sigma_post @ l_vec)
1183
+ z = stats.norm.ppf(1 - self.alpha / 2)
1184
+ return theta, theta, theta - z * se, theta + z * se
1185
+
1186
+ # Compute bounds
1187
+ lb, ub, ci_lb, ci_ub = _compute_clf_ci(
1188
+ beta_post, sigma_post, l_vec, Mbar, max_pre_violation, self.alpha
1189
+ )
1190
+
1191
+ return lb, ub, ci_lb, ci_ub
1192
+
1193
def _compute_combined_bounds(
    self,
    beta_post: np.ndarray,
    sigma_post: np.ndarray,
    l_vec: np.ndarray,
    num_pre: int,
    num_post: int,
    M: float,
    pre_periods: List,
    results: Any
) -> Tuple[float, float, float, float]:
    """
    Bounds and FLCI under the combined smoothness + relative-magnitudes restriction.

    The same ``M`` is used as the smoothness bound and as ``Mbar``. The
    identified set is the intersection of the two individual sets. If that
    intersection is empty, it is collapsed to the point estimate
    ``l_vec' beta_post``; this is a heuristic, not a formal result. The
    FLCI is then built around the combined bounds.

    Returns
    -------
    tuple of float
        ``(lb, ub, ci_lb, ci_ub)``.
    """
    # Only the identified sets are needed here; their CIs are discarded.
    sd_lower, sd_upper, _, _ = self._compute_smoothness_bounds(
        beta_post, sigma_post, l_vec, num_pre, num_post, M
    )
    rm_lower, rm_upper, _, _ = self._compute_rm_bounds(
        beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results
    )

    lower = max(sd_lower, rm_lower)
    upper = min(sd_upper, rm_upper)
    if upper < lower:
        lower = upper = np.dot(l_vec, beta_post)

    target_se = np.sqrt(l_vec @ sigma_post @ l_vec)
    ci_lower, ci_upper = _compute_flci(lower, upper, target_se, self.alpha)
    return lower, upper, ci_lower, ci_upper
1229
+
1230
def _estimate_max_pre_violation(
    self,
    results: Any,
    pre_periods: List
) -> float:
    """
    Scale of pre-treatment violations used by the relative-magnitudes bounds.

    The value returned depends on the results type:
    - MultiPeriodDiDResults: the largest absolute coefficient stored under
      ``"treated:period_{p}"`` for ``p`` in ``pre_periods``; if none is
      found, ``results.avg_se``.
    - CallawaySantAnnaResults (when ``diff_diff.staggered`` is importable):
      the largest absolute event-study effect at negative relative time;
      if none is found, ``results.overall_se``.
    - Anything else: the constant 0.1.
    """
    if isinstance(results, MultiPeriodDiDResults):
        coefficients = getattr(results, 'coefficients', None)
        if coefficients:
            candidate_keys = (f"treated:period_{p}" for p in pre_periods)
            magnitudes = [
                abs(coefficients[key]) for key in candidate_keys if key in coefficients
            ]
            if magnitudes:
                return max(magnitudes)
        # Fallback: use avg_se as a scale
        return results.avg_se

    try:
        from diff_diff.staggered import CallawaySantAnnaResults
        if isinstance(results, CallawaySantAnnaResults):
            event_effects = results.event_study_effects
            if event_effects:
                lead_magnitudes = [
                    abs(event_effects[rel_time]['effect'])
                    for rel_time in event_effects
                    if rel_time < 0
                ]
                if lead_magnitudes:
                    return max(lead_magnitudes)
            return results.overall_se
    except ImportError:
        pass

    # Default fallback for unrecognized result types
    return 0.1
1277
+
1278
def sensitivity_analysis(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M_grid: Optional[List[float]] = None,
) -> SensitivityResults:
    """
    Perform sensitivity analysis over a grid of M values.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    M_grid : list of float, optional
        Grid of M values to evaluate. If None, a default grid is used:
        0 to 2 in steps of 0.25 for ``"relative_magnitude"``, and
        ``[0, 0.1, 0.2, 0.3, 0.5, 0.75, 1.0]`` otherwise.

    Returns
    -------
    SensitivityResults
        Results containing bounds and CIs for each M value.
    """
    if M_grid is None:
        if self.method == "relative_magnitude":
            M_grid = [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
        else:
            M_grid = [0, 0.1, 0.2, 0.3, 0.5, 0.75, 1.0]

    # np.array (not asarray) copies, so later edits to the caller's grid
    # cannot leak into the returned results.
    M_values = np.array(M_grid, dtype=float)
    fits = [self.fit(results, M=M) for M in M_values]
    bounds_list = [(fitted.lb, fitted.ub) for fitted in fits]
    ci_list = [(fitted.ci_lb, fitted.ci_ub) for fitted in fits]

    # Find breakdown value
    breakdown_M = self._find_breakdown(results, M_values, ci_list)

    # original_estimate / original_se do not depend on M, so reuse a grid
    # fit instead of refitting. Only an empty grid needs an extra fit.
    reference = fits[0] if fits else self.fit(results, M=0)

    return SensitivityResults(
        M_values=M_values,
        bounds=bounds_list,
        robust_cis=ci_list,
        breakdown_M=breakdown_M,
        method=self.method,
        original_estimate=reference.original_estimate,
        original_se=reference.original_se,
        alpha=self.alpha,
    )
1330
+
1331
+ def _find_breakdown(
1332
+ self,
1333
+ results: Any,
1334
+ M_values: np.ndarray,
1335
+ ci_list: List[Tuple[float, float]]
1336
+ ) -> Optional[float]:
1337
+ """
1338
+ Find the breakdown value where CI first includes zero.
1339
+
1340
+ Uses binary search for precision.
1341
+ """
1342
+ # Check if any CI includes zero
1343
+ includes_zero = [ci_lb <= 0 <= ci_ub for ci_lb, ci_ub in ci_list]
1344
+
1345
+ if not any(includes_zero):
1346
+ # Always significant - no breakdown
1347
+ return None
1348
+
1349
+ if all(includes_zero):
1350
+ # Never significant - breakdown at 0
1351
+ return 0.0
1352
+
1353
+ # Find first transition point
1354
+ for i, (inc, M) in enumerate(zip(includes_zero, M_values)):
1355
+ if inc and (i == 0 or not includes_zero[i - 1]):
1356
+ # Binary search between M_values[i-1] and M_values[i]
1357
+ if i == 0:
1358
+ return 0.0
1359
+
1360
+ lo, hi = M_values[i - 1], M_values[i]
1361
+
1362
+ for _ in range(20): # 20 iterations for precision
1363
+ mid = (lo + hi) / 2
1364
+ result = self.fit(results, M=mid)
1365
+ if result.ci_lb <= 0 <= result.ci_ub:
1366
+ hi = mid
1367
+ else:
1368
+ lo = mid
1369
+
1370
+ return (lo + hi) / 2
1371
+
1372
+ return None
1373
+
1374
def breakdown_value(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    tol: float = 0.01,
    M_max: float = 10.0,
) -> Optional[float]:
    """
    Find the breakdown value directly using bisection.

    The breakdown value is the smallest M in ``[0, M_max]`` at which the
    robust confidence interval includes zero.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    tol : float
        Width of the final bracketing interval. Must be positive.
    M_max : float
        Upper end of the search interval (default 10.0, the previous
        hard-coded value). Must be finite and non-negative.

    Returns
    -------
    float or None
        The breakdown value (midpoint of the final bracket). Returns 0.0 if
        the CI already includes zero at M=0, or None if the effect is
        still significant at ``M_max``.

    Raises
    ------
    ValueError
        If ``tol`` is not positive, or ``M_max`` is negative or not finite.
    """
    # With tol <= 0 the loop below never terminates: floating-point
    # bisection stops shrinking once lo and hi are adjacent floats.
    if not tol > 0:
        raise ValueError(f"tol must be positive, got tol={tol}")
    if not (M_max >= 0 and np.isfinite(M_max)):
        raise ValueError(f"M_max must be finite and non-negative, got M_max={M_max}")

    def ci_covers_zero(M: float) -> bool:
        fitted = self.fit(results, M=M)
        return fitted.ci_lb <= 0 <= fitted.ci_ub

    if ci_covers_zero(0):
        return 0.0
    if not ci_covers_zero(M_max):
        return None  # Significant over the whole search interval

    lo, hi = 0.0, float(M_max)
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if ci_covers_zero(mid):
            hi = mid
        else:
            lo = mid

    return (lo + hi) / 2
1419
+
1420
+
1421
+ # =============================================================================
1422
+ # Convenience Functions
1423
+ # =============================================================================
1424
+
1425
+
1426
def compute_honest_did(
    results: Union[MultiPeriodDiDResults, Any],
    method: str = "relative_magnitude",
    M: float = 1.0,
    alpha: float = 0.05,
) -> HonestDiDResults:
    """
    One-call shortcut: build a ``HonestDiD`` estimator and fit it.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    method : str
        Restriction type: "smoothness", "relative_magnitude" or "combined".
    M : float
        Restriction parameter.
    alpha : float
        Significance level.

    Returns
    -------
    HonestDiDResults
        Bounds and robust confidence intervals.

    Examples
    --------
    >>> bounds = compute_honest_did(event_study_results, method='relative_magnitude', M=1.0)
    >>> print(f"Robust CI: [{bounds.ci_lb:.3f}, {bounds.ci_ub:.3f}]")
    """
    estimator = HonestDiD(method=method, M=M, alpha=alpha)
    fitted = estimator.fit(results)
    return fitted
1458
+
1459
+
1460
def sensitivity_plot(
    results: Union[MultiPeriodDiDResults, Any],
    method: str = "relative_magnitude",
    M_grid: Optional[List[float]] = None,
    alpha: float = 0.05,
    ax=None,
    **kwargs
):
    """
    Run a sensitivity analysis over ``M_grid`` and plot it.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    method : str
        Restriction type.
    M_grid : list of float, optional
        Grid of M values. None selects the estimator's default grid.
    alpha : float
        Significance level.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on.
    **kwargs
        Extra keyword arguments forwarded to the sensitivity results'
        ``plot`` method.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Whatever the sensitivity results' ``plot`` method returns
        (documented as the axes holding the plot).
    """
    analysis = HonestDiD(method=method, alpha=alpha).sensitivity_analysis(
        results, M_grid=M_grid
    )
    return analysis.plot(ax=ax, **kwargs)