diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/datasets.py ADDED
@@ -0,0 +1,815 @@
1
+ """
2
+ Real-world datasets for Difference-in-Differences analysis.
3
+
4
+ This module provides functions to load classic econometrics datasets
5
+ commonly used for teaching and demonstrating DiD methods.
6
+
7
+ All datasets are downloaded from public sources and cached locally
8
+ for subsequent use.
9
+ """
10
+
11
+ from io import StringIO
12
+ from pathlib import Path
13
+ from typing import Dict
14
+ from urllib.error import HTTPError, URLError
15
+ from urllib.request import urlopen
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+
21
# Per-user directory in which downloaded dataset CSV files are kept.
_CACHE_DIR = Path.home().joinpath(".cache", "diff_diff", "datasets")


def _get_cache_path(name: str) -> Path:
    """
    Return the cache file path for the dataset called ``name``.

    The cache directory (including any missing parents) is created as a
    side effect, so the returned ``<name>.csv`` path can be written to
    immediately. The file itself is not created.
    """
    cache_dir = _CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir.joinpath(f"{name}.csv")
29
+
30
+
31
def _download_with_cache(
    url: str,
    name: str,
    force_download: bool = False,
) -> str:
    """
    Download a UTF-8 text file and cache it locally.

    Parameters
    ----------
    url : str
        Location of the file to fetch.
    name : str
        Dataset name; selects the cache file (``<name>.csv``).
    force_download : bool, default=False
        If True, ignore an existing cache entry and fetch again.

    Returns
    -------
    str
        The file content, read from the cache or freshly downloaded.

    Raises
    ------
    RuntimeError
        If the download fails and no cached copy is available.
    """
    import warnings
    from http.client import HTTPException

    cache_path = _get_cache_path(name)

    # The payload is decoded as UTF-8, so the cache must be read and written
    # as UTF-8 too; the platform default (e.g. cp1252 on Windows) can fail on
    # or corrupt non-ASCII characters.
    if cache_path.exists() and not force_download:
        return cache_path.read_text(encoding="utf-8")

    try:
        with urlopen(url, timeout=30) as response:
            content = response.read().decode("utf-8")
    except (OSError, HTTPException, UnicodeDecodeError) as e:
        # OSError covers HTTPError/URLError as well as socket timeouts and
        # connection resets raised while the body is being read;
        # HTTPException covers truncated responses.
        if cache_path.exists():
            # Use cached version if download fails
            return cache_path.read_text(encoding="utf-8")
        raise RuntimeError(
            f"Failed to download dataset '{name}' from {url}: {e}\n"
            "Check your internet connection or try again later."
        ) from e

    # Caching is best-effort: a read-only or full disk must not discard
    # content that was downloaded successfully.
    try:
        cache_path.write_text(content, encoding="utf-8")
    except OSError as e:
        warnings.warn(
            f"Could not write dataset cache file {cache_path}: {e}",
            RuntimeWarning,
            stacklevel=2,
        )
    return content
55
+
56
+
57
def clear_cache() -> None:
    """
    Delete every cached dataset CSV file.

    Only ``*.csv`` files directly inside the cache directory are removed;
    the directory itself is kept. A confirmation line is printed whether or
    not the directory exists.
    """
    cached_files = _CACHE_DIR.glob("*.csv") if _CACHE_DIR.exists() else ()
    for csv_file in cached_files:
        csv_file.unlink()
    print(f"Cleared cache at {_CACHE_DIR}")
63
+
64
+
65
def load_card_krueger(force_download: bool = False) -> pd.DataFrame:
    """
    Load the Card & Krueger (1994) minimum wage dataset.

    This classic dataset examines the effect of New Jersey's 1992 minimum wage
    increase on employment in fast-food restaurants, using Pennsylvania as
    a control group.

    The study is a canonical example of the Difference-in-Differences method.

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Dataset with columns:
        - store_id : int - Unique store identifier
        - state : str - 'NJ' (New Jersey, treated) or 'PA' (Pennsylvania, control)
        - chain : str - Fast food chain ('bk', 'kfc', 'roys', 'wendys')
        - emp_pre : float - Full-time equivalent employment before (Feb 1992)
        - emp_post : float - Full-time equivalent employment after (Nov 1992)
        - wage_pre : float - Starting wage before
        - wage_post : float - Starting wage after
        - treated : int - 1 if NJ, 0 if PA
        - emp_change : float - Change in employment (emp_post - emp_pre)

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded or parsed. In that case a
        simulated stand-in generated from summary statistics is returned
        instead of the real data.

    Notes
    -----
    The minimum wage in New Jersey increased from $4.25 to $5.05 on April 1, 1992.
    Pennsylvania's minimum wage remained at $4.25.

    Original finding: No significant negative effect of minimum wage increase
    on employment (ATT ≈ +2.8 FTE employees).

    References
    ----------
    Card, D., & Krueger, A. B. (1994). Minimum Wages and Employment: A Case Study
    of the Fast-Food Industry in New Jersey and Pennsylvania. *American Economic
    Review*, 84(4), 772-793.

    Examples
    --------
    >>> from diff_diff.datasets import load_card_krueger
    >>> from diff_diff import DifferenceInDifferences
    >>>
    >>> # Load and prepare data
    >>> ck = load_card_krueger()
    >>> ck_long = ck.melt(
    ...     id_vars=['store_id', 'state', 'treated'],
    ...     value_vars=['emp_pre', 'emp_post'],
    ...     var_name='period', value_name='employment'
    ... )
    >>> ck_long['post'] = (ck_long['period'] == 'emp_post').astype(int)
    >>>
    >>> # Estimate DiD
    >>> did = DifferenceInDifferences()
    >>> results = did.fit(ck_long, outcome='employment', treatment='treated', time='post')
    """
    import warnings

    # CSV copy of the Card-Krueger data hosted on GitHub.
    url = "https://raw.githubusercontent.com/causaldata/causal_datasets/main/card_krueger/card_krueger.csv"

    try:
        content = _download_with_cache(url, "card_krueger", force_download)
        df = pd.read_csv(StringIO(content))
    except (RuntimeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # Fallback: simulated data. Warn so it is never mistaken for the
        # real dataset.
        warnings.warn(
            f"Could not load the 'card_krueger' dataset from {url} ({e}); "
            "falling back to simulated data generated from summary statistics.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_card_krueger_data()

    # Standardize column names and add convenience columns
    df = df.rename(columns={"sheet": "store_id"})

    # Derive the state label from a 0/1 New Jersey indicator when needed
    if "state" not in df.columns and "nj" in df.columns:
        df["state"] = np.where(df["nj"] == 1, "NJ", "PA")

    # Each derived column is added only when its inputs are present, so an
    # unexpected upstream schema does not raise KeyError here.
    if "treated" not in df.columns and "state" in df.columns:
        df["treated"] = (df["state"] == "NJ").astype(int)

    if "emp_change" not in df.columns and "emp_post" in df.columns and "emp_pre" in df.columns:
        df["emp_change"] = df["emp_post"] - df["emp_pre"]

    return df
155
+
156
+
157
def _construct_card_krueger_data() -> pd.DataFrame:
    """
    Simulate a Card-Krueger-style dataset from summary statistics.

    This is a fallback when the online source is unavailable. The rows are
    random draws centred on the summary statistics noted in the comments
    below, not the original survey records.

    Returns
    -------
    pd.DataFrame
        One row per simulated store with columns ``store_id``, ``state``,
        ``chain``, ``emp_pre``, ``emp_post``, ``wage_pre``, ``wage_post``,
        ``treated`` and ``emp_change``.
    """
    # Private generator seeded with the publication year. It yields exactly
    # the same draws as ``np.random.seed(1994)`` followed by the module-level
    # functions, but leaves the caller's global NumPy random state untouched.
    rng = np.random.RandomState(1994)

    # Each (mean, sd) pair parameterises one normal draw per store.
    state_specs = [
        {
            # New Jersey stores (treated)
            # Mean emp before: 20.44, after: 21.03
            # Mean wage before: 4.61, after: 5.08
            "state": "NJ",
            "n_stores": {"bk": 85, "kfc": 62, "roys": 48, "wendys": 36},
            "emp_pre": (20.44, 8.5),
            "emp_change": (0.59, 7.0),  # Change ≈ 0.59
            "wage_pre": (4.61, 0.35),
            "wage_post": (5.08, 0.12),
        },
        {
            # Pennsylvania stores (control)
            # Mean emp before: 23.33, after: 21.17
            # Mean wage before: 4.63, after: 4.62
            "state": "PA",
            "n_stores": {"bk": 30, "kfc": 20, "roys": 14, "wendys": 15},
            "emp_pre": (23.33, 8.2),
            "emp_change": (-2.16, 7.0),  # Change ≈ -2.16
            "wage_pre": (4.63, 0.35),
            "wage_post": (4.62, 0.35),
        },
    ]

    stores = []
    for spec in state_specs:
        for chain, n_stores in spec["n_stores"].items():
            for _ in range(n_stores):
                # The order of the four draws per store is fixed so the
                # generated values stay reproducible.
                emp_pre = rng.normal(*spec["emp_pre"])
                emp_post = emp_pre + rng.normal(*spec["emp_change"])
                # Employment cannot be negative
                emp_pre = max(0, emp_pre)
                emp_post = max(0, emp_post)

                stores.append(
                    {
                        "store_id": len(stores) + 1,  # 1-based running id
                        "state": spec["state"],
                        "chain": chain,
                        "emp_pre": round(emp_pre, 1),
                        "emp_post": round(emp_post, 1),
                        "wage_pre": round(rng.normal(*spec["wage_pre"]), 2),
                        "wage_post": round(rng.normal(*spec["wage_post"]), 2),
                    }
                )

    df = pd.DataFrame(stores)
    df["treated"] = (df["state"] == "NJ").astype(int)
    df["emp_change"] = df["emp_post"] - df["emp_pre"]
    return df
222
+
223
+
224
def load_castle_doctrine(force_download: bool = False) -> pd.DataFrame:
    """
    Load Castle Doctrine / Stand Your Ground laws dataset.

    This dataset tracks the staggered adoption of Castle Doctrine (Stand Your
    Ground) laws across U.S. states, which expanded self-defense rights.
    It's commonly used to demonstrate heterogeneous treatment timing methods
    like Callaway-Sant'Anna or Sun-Abraham.

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Panel dataset with columns:
        - state : str - State abbreviation
        - year : int - Year (2000-2010)
        - first_treat : int - Year of law adoption (0 = never adopted)
        - homicide_rate : float - Homicides per 100,000 population
        - population : int - State population
        - income : float - Per capita income
        - treated : int - 1 if law in effect, 0 otherwise
        - cohort : int - Alias for first_treat

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded or parsed. In that case a
        simulated stand-in is returned instead of the real data.

    Notes
    -----
    Castle Doctrine laws remove the duty to retreat before using deadly force
    in self-defense. States adopted these laws at different times between
    2005 and 2009, creating a staggered treatment design.

    References
    ----------
    Cheng, C., & Hoekstra, M. (2013). Does Strengthening Self-Defense Law Deter
    Crime or Escalate Violence? Evidence from Expansions to Castle Doctrine.
    *Journal of Human Resources*, 48(3), 821-854.

    Examples
    --------
    >>> from diff_diff.datasets import load_castle_doctrine
    >>> from diff_diff import CallawaySantAnna
    >>>
    >>> castle = load_castle_doctrine()
    >>> cs = CallawaySantAnna(control_group="never_treated")
    >>> results = cs.fit(
    ...     castle,
    ...     outcome="homicide_rate",
    ...     unit="state",
    ...     time="year",
    ...     first_treat="first_treat"
    ... )
    """
    import warnings

    url = "https://raw.githubusercontent.com/causaldata/causal_datasets/main/castle/castle.csv"

    try:
        content = _download_with_cache(url, "castle_doctrine", force_download)
        df = pd.read_csv(StringIO(content))
    except (RuntimeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # Fallback: simulated data. Warn so it is never mistaken for the
        # real dataset.
        warnings.warn(
            f"Could not load the 'castle_doctrine' dataset from {url} ({e}); "
            "falling back to simulated data generated from documented patterns.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_castle_doctrine_data()

    # Standardize column names
    rename_map = {
        "sid": "state_id",
        "cdl": "treated",
    }
    df = df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns})

    # Add convenience columns
    if "first_treat" not in df.columns and "effyear" in df.columns:
        # A missing effective year is coded 0 (never adopted)
        df["first_treat"] = df["effyear"].fillna(0).astype(int)

    if "cohort" not in df.columns and "first_treat" in df.columns:
        df["cohort"] = df["first_treat"]

    # Ensure treated indicator exists; it is derived only when both inputs
    # are present so an unexpected schema does not raise KeyError.
    if (
        "treated" not in df.columns
        and "first_treat" in df.columns
        and "year" in df.columns
    ):
        df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(int)

    return df
306
+
307
+
308
def _construct_castle_doctrine_data() -> pd.DataFrame:
    """
    Simulate a Castle Doctrine panel from documented adoption years.

    This is a fallback when the online source is unavailable. Adoption
    years come from the table below; outcomes and covariates are random
    draws, not observed data.

    Returns
    -------
    pd.DataFrame
        One row per state and year (2000-2010) with columns ``state``,
        ``year``, ``first_treat``, ``homicide_rate``, ``population``,
        ``income``, ``treated`` and ``cohort``.
    """
    # Private generator seeded with the publication year. It yields exactly
    # the same draws as ``np.random.seed(2013)`` followed by the module-level
    # functions, but leaves the caller's global NumPy random state untouched.
    rng = np.random.RandomState(2013)

    # States and their Castle Doctrine adoption years
    # 0 = never adopted during the study period
    state_adoption = {
        "AL": 2006,
        "AK": 2006,
        "AZ": 2006,
        "FL": 2005,
        "GA": 2006,
        "IN": 2006,
        "KS": 2006,
        "KY": 2006,
        "LA": 2006,
        "MI": 2006,
        "MS": 2006,
        "MO": 2007,
        "MT": 2009,
        "NH": 2011,
        "NC": 2011,
        "ND": 2007,
        "OH": 2008,
        "OK": 2006,
        "PA": 2011,
        "SC": 2006,
        "SD": 2006,
        "TN": 2007,
        "TX": 2007,
        "UT": 2010,
        "WV": 2008,
        # Control states (never adopted or adopted after 2010)
        "CA": 0,
        "CO": 0,
        "CT": 0,
        "DE": 0,
        "HI": 0,
        "ID": 0,
        "IL": 0,
        "IA": 0,
        "ME": 0,
        "MD": 0,
        "MA": 0,
        "MN": 0,
        "NE": 0,
        "NV": 0,
        "NJ": 0,
        "NM": 0,
        "NY": 0,
        "OR": 0,
        "RI": 0,
        "VT": 0,
        "VA": 0,
        "WA": 0,
        "WI": 0,
        "WY": 0,
    }

    # Adoptions after the panel ends (2010) are recoded as never adopted
    state_adoption = {k: (v if v <= 2010 else 0) for k, v in state_adoption.items()}

    data = []
    for state, first_treat in state_adoption.items():
        # State-level baseline characteristics
        base_homicide = rng.uniform(3.0, 8.0)
        pop = rng.randint(500000, 20000000)
        base_income = rng.uniform(30000, 50000)

        for year in range(2000, 2011):
            # Time trend
            time_effect = (year - 2005) * 0.1

            # Treatment effect (approximately +8% increase in homicide rate)
            in_effect = first_treat > 0 and year >= first_treat
            treatment_effect = base_homicide * 0.08 if in_effect else 0

            # Draw order per row (homicide noise, population jitter, income
            # noise) is fixed so the generated values stay reproducible.
            homicide = max(
                0, base_homicide + time_effect + treatment_effect + rng.normal(0, 0.5)
            )

            data.append(
                {
                    "state": state,
                    "year": year,
                    "first_treat": first_treat,
                    "homicide_rate": round(homicide, 2),
                    "population": pop + year * 10000 + rng.randint(-5000, 5000),
                    "income": round(
                        base_income * (1 + 0.02 * (year - 2000)) + rng.normal(0, 1000), 0
                    ),
                    "treated": int(in_effect),
                }
            )

    df = pd.DataFrame(data)
    df["cohort"] = df["first_treat"]
    return df
412
+
413
+
414
def load_divorce_laws(force_download: bool = False) -> pd.DataFrame:
    """
    Load unilateral divorce laws dataset.

    This dataset tracks the staggered adoption of unilateral (no-fault) divorce
    laws across U.S. states. It's a classic example for studying staggered
    DiD methods and was used in Stevenson & Wolfers (2006).

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Panel dataset with columns:
        - state : str - State abbreviation
        - year : int - Year
        - first_treat : int - Year unilateral divorce became available (0 = never)
        - divorce_rate : float - Divorces per 1,000 population
        - female_lfp : float - Female labor force participation rate
        - suicide_rate : float - Female suicide rate
        - treated : int - 1 if law in effect, 0 otherwise
        - cohort : int - Alias for first_treat

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded or parsed. In that case a
        simulated stand-in is returned instead of the real data.

    Notes
    -----
    Unilateral divorce laws allow one spouse to obtain a divorce without the
    other's consent. States adopted these laws at different times, primarily
    between 1969 and 1985.

    References
    ----------
    Stevenson, B., & Wolfers, J. (2006). Bargaining in the Shadow of the Law:
    Divorce Laws and Family Distress. *Quarterly Journal of Economics*,
    121(1), 267-288.

    Wolfers, J. (2006). Did Unilateral Divorce Laws Raise Divorce Rates?
    A Reconciliation and New Results. *American Economic Review*, 96(5), 1802-1820.

    Examples
    --------
    >>> from diff_diff.datasets import load_divorce_laws
    >>> from diff_diff import CallawaySantAnna, SunAbraham
    >>>
    >>> divorce = load_divorce_laws()
    >>> cs = CallawaySantAnna(control_group="never_treated")
    >>> results = cs.fit(
    ...     divorce,
    ...     outcome="divorce_rate",
    ...     unit="state",
    ...     time="year",
    ...     first_treat="first_treat"
    ... )
    """
    import warnings

    # Try to load from causaldata repository
    url = "https://raw.githubusercontent.com/causaldata/causal_datasets/main/divorce/divorce.csv"

    try:
        content = _download_with_cache(url, "divorce_laws", force_download)
        df = pd.read_csv(StringIO(content))
    except (RuntimeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # Fallback: simulated data. Warn so it is never mistaken for the
        # real dataset.
        warnings.warn(
            f"Could not load the 'divorce_laws' dataset from {url} ({e}); "
            "falling back to simulated data generated from documented patterns.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_divorce_laws_data()

    # Standardize column names
    if "stfips" in df.columns:
        df = df.rename(columns={"stfips": "state_id"})

    if "first_treat" not in df.columns and {"unilateral", "state", "year"}.issubset(
        df.columns
    ):
        # First treatment year = earliest year with unilateral == 1 per state.
        # States with no such row are absent from the lookup, map to NaN and
        # are then coded 0 (never treated).
        adoption_year = df.loc[df["unilateral"] == 1].groupby("state")["year"].min()
        df["first_treat"] = df["state"].map(adoption_year).fillna(0).astype(int)

    if "cohort" not in df.columns and "first_treat" in df.columns:
        df["cohort"] = df["first_treat"]

    if "treated" not in df.columns:
        if "unilateral" in df.columns:
            df["treated"] = df["unilateral"]
        elif "first_treat" in df.columns and "year" in df.columns:
            df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(
                int
            )

    return df
503
+
504
+
505
def _construct_divorce_laws_data() -> pd.DataFrame:
    """
    Simulate a unilateral divorce law panel from documented adoption years.

    This is a fallback when the online source is unavailable. Adoption
    years come from the table below; outcomes are random draws with
    stylised treatment effects, not observed data.

    Returns
    -------
    pd.DataFrame
        One row per state and year (1968-1988) with columns ``state``,
        ``year``, ``first_treat``, ``divorce_rate``, ``female_lfp``,
        ``suicide_rate``, ``cohort`` and ``treated``.
    """
    # Private generator seeded with the publication year. It yields exactly
    # the same draws as ``np.random.seed(2006)`` followed by the module-level
    # functions, but leaves the caller's global NumPy random state untouched.
    rng = np.random.RandomState(2006)

    # State adoption years for unilateral divorce (from Wolfers 2006)
    # 0 = never adopted
    state_adoption = {
        "AK": 1935,
        "AL": 1971,
        "AZ": 1973,
        "CA": 1970,
        "CO": 1972,
        "CT": 1973,
        "DE": 1968,
        "FL": 1971,
        "GA": 1973,
        "HI": 1973,
        "IA": 1970,
        "ID": 1971,
        "IN": 1973,
        "KS": 1969,
        "KY": 1972,
        "MA": 1975,
        "ME": 1973,
        "MI": 1972,
        "MN": 1974,
        "MO": 0,
        "MT": 1975,
        "NC": 0,
        "ND": 1971,
        "NE": 1972,
        "NH": 1971,
        "NJ": 0,
        "NM": 1973,
        "NV": 1967,
        "NY": 0,
        "OH": 0,
        "OK": 1975,
        "OR": 1971,
        "PA": 0,
        "RI": 1975,
        "SD": 1985,
        "TN": 0,
        "TX": 1970,
        "UT": 1987,
        "VA": 0,
        "WA": 1973,
        "WI": 1978,
        "WV": 1984,
        "WY": 1977,
    }

    # Keep never-adopters and states adopting within 1968-1990. States that
    # adopted earlier (AK, NV) are dropped from the panel entirely.
    state_adoption = {k: v for k, v in state_adoption.items() if v == 0 or (1968 <= v <= 1990)}

    data = []
    for state, first_treat in state_adoption.items():
        # State-level baselines
        base_divorce = rng.uniform(2.0, 6.0)
        base_lfp = rng.uniform(0.35, 0.55)
        base_suicide = rng.uniform(4.0, 8.0)

        for year in range(1968, 1989):
            # Time trends
            time_trend = (year - 1978) * 0.05

            # Treatment effects (from literature)
            # Short-run increase in divorce rate, then return to trend
            if first_treat > 0 and year >= first_treat:
                years_since = year - first_treat
                # Initial spike then fade out
                if years_since <= 2:
                    divorce_effect = 0.5
                elif years_since <= 5:
                    divorce_effect = 0.3
                elif years_since <= 10:
                    divorce_effect = 0.1
                else:
                    divorce_effect = 0.0
                # Small positive effect on female LFP
                lfp_effect = 0.02
                # Reduction in female suicide
                suicide_effect = -0.5
            else:
                divorce_effect = 0
                lfp_effect = 0
                suicide_effect = 0

            # Draw order per row (divorce, LFP, suicide noise) is fixed so
            # the generated values stay reproducible.
            data.append(
                {
                    "state": state,
                    "year": year,
                    # After the filter above this is always 0 or >= 1968
                    "first_treat": first_treat,
                    "divorce_rate": round(
                        max(0, base_divorce + time_trend + divorce_effect + rng.normal(0, 0.3)),
                        2,
                    ),
                    # Participation rate is clamped to [0, 1]
                    "female_lfp": round(
                        min(
                            1,
                            max(
                                0,
                                base_lfp
                                + 0.01 * (year - 1968)
                                + lfp_effect
                                + rng.normal(0, 0.02),
                            ),
                        ),
                        3,
                    ),
                    "suicide_rate": round(
                        max(0, base_suicide + suicide_effect + rng.normal(0, 0.5)), 2
                    ),
                }
            )

    df = pd.DataFrame(data)
    df["cohort"] = df["first_treat"]
    df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(int)
    return df
631
+
632
+
633
def load_mpdta(force_download: bool = False) -> pd.DataFrame:
    """
    Load the Minimum Wage Panel Dataset for DiD Analysis (mpdta).

    This is a simulated dataset from the R `did` package that mimics
    county-level employment data under staggered minimum wage increases.
    It's designed specifically for teaching the Callaway-Sant'Anna estimator.

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Panel dataset with columns:
        - countyreal : int - County identifier
        - year : int - Year (2003-2007)
        - lpop : float - Log population
        - lemp : float - Log employment (outcome)
        - first_treat : int - Year of minimum wage increase (0 = never)
        - treat : int - 1 if ever treated, 0 otherwise
        - cohort : int - Alias for first_treat

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded or parsed. In that case a
        locally simulated stand-in is returned; its values differ from the
        dataset shipped with the R package.

    Notes
    -----
    This dataset is included in the R `did` package and is commonly used
    in tutorials demonstrating the Callaway-Sant'Anna estimator.

    References
    ----------
    Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-differences with
    multiple time periods. *Journal of Econometrics*, 225(2), 200-230.

    Examples
    --------
    >>> from diff_diff.datasets import load_mpdta
    >>> from diff_diff import CallawaySantAnna
    >>>
    >>> mpdta = load_mpdta()
    >>> cs = CallawaySantAnna()
    >>> results = cs.fit(
    ...     mpdta,
    ...     outcome="lemp",
    ...     unit="countyreal",
    ...     time="year",
    ...     first_treat="first_treat"
    ... )
    """
    import warnings

    # CSV copy of mpdta from the `did` package's GitHub repository
    url = "https://raw.githubusercontent.com/bcallaway11/did/master/data-raw/mpdta.csv"

    try:
        content = _download_with_cache(url, "mpdta", force_download)
        df = pd.read_csv(StringIO(content))
    except (RuntimeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # Fallback: simulated data. Warn so it is never mistaken for the
        # dataset distributed with the R package.
        warnings.warn(
            f"Could not load the 'mpdta' dataset from {url} ({e}); "
            "falling back to locally simulated data.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_mpdta_data()

    # Standardize column names (R-style dotted name -> snake_case)
    if "first.treat" in df.columns:
        df = df.rename(columns={"first.treat": "first_treat"})

    # Ensure cohort column exists
    if "cohort" not in df.columns and "first_treat" in df.columns:
        df["cohort"] = df["first_treat"]

    return df
701
+
702
+
703
def _construct_mpdta_data() -> pd.DataFrame:
    """
    Simulate an mpdta-like county panel.

    This is a fallback when the online source is unavailable. It uses the
    same column layout as the R `did` package's ``mpdta`` but the values
    are generated here, so they do not match the R dataset.

    Returns
    -------
    pd.DataFrame
        500 counties x 5 years (2003-2007) with columns ``countyreal``,
        ``year``, ``lpop``, ``lemp``, ``first_treat``, ``treat`` and
        ``cohort``.
    """
    # Private generator seeded with the publication year. It yields exactly
    # the same draws as ``np.random.seed(2021)`` followed by the module-level
    # functions, but leaves the caller's global NumPy random state untouched.
    rng = np.random.RandomState(2021)

    n_counties = 500
    years = [2003, 2004, 2005, 2006, 2007]

    # Treatment cohorts: 2004, 2006, 2007, or never (0)
    cohorts = [0, 2004, 2006, 2007]
    cohort_probs = [0.4, 0.2, 0.2, 0.2]

    data = []
    for county in range(1, n_counties + 1):
        # County-level draws: cohort, then baseline log population and a
        # baseline log employment that sits 1.5-2.5 below it.
        first_treat = rng.choice(cohorts, p=cohort_probs)
        base_lpop = rng.normal(12.0, 1.0)
        base_lemp = base_lpop - rng.uniform(1.5, 2.5)

        for year in years:
            time_effect = (year - 2003) * 0.02

            # Treatment effect (heterogeneous by cohort)
            if first_treat > 0 and year >= first_treat:
                if first_treat == 2004:
                    te = -0.04 + (year - first_treat) * 0.01
                elif first_treat == 2006:
                    te = -0.03 + (year - first_treat) * 0.01
                else:  # 2007
                    te = -0.025
            else:
                te = 0

            # Draw order per row (lpop noise, then lemp noise) is fixed so
            # the generated values stay reproducible.
            data.append(
                {
                    "countyreal": county,
                    "year": year,
                    "lpop": round(base_lpop + rng.normal(0, 0.05), 4),
                    "lemp": round(base_lemp + time_effect + te + rng.normal(0, 0.02), 4),
                    "first_treat": first_treat,
                    "treat": int(first_treat > 0),
                }
            )

    df = pd.DataFrame(data)
    df["cohort"] = df["first_treat"]
    return df
752
+
753
+
754
def list_datasets() -> Dict[str, str]:
    """
    Describe the datasets this module can load.

    Returns
    -------
    dict
        Maps each dataset name to a one-line description, in a fixed order.

    Examples
    --------
    >>> from diff_diff.datasets import list_datasets
    >>> for name, desc in list_datasets().items():
    ...     print(f"{name}: {desc}")
    """
    catalog: Dict[str, str] = {}
    catalog["card_krueger"] = "Card & Krueger (1994) minimum wage dataset - classic 2x2 DiD"
    catalog["castle_doctrine"] = "Castle Doctrine laws - staggered adoption across states"
    catalog["divorce_laws"] = "Unilateral divorce laws - staggered adoption (Stevenson-Wolfers)"
    catalog["mpdta"] = "Minimum wage panel data - simulated CS example from R `did` package"
    return catalog
775
+
776
+
777
def load_dataset(name: str, force_download: bool = False) -> pd.DataFrame:
    """
    Load one of the bundled datasets by its name.

    Parameters
    ----------
    name : str
        Dataset identifier. `list_datasets()` shows the accepted names.
    force_download : bool, default=False
        Passed through to the loader; if True the data is fetched again
        even when a cached copy exists.

    Returns
    -------
    pd.DataFrame
        Whatever the matching ``load_*`` function returns.

    Raises
    ------
    ValueError
        If ``name`` does not match any known dataset.

    Examples
    --------
    >>> from diff_diff.datasets import load_dataset, list_datasets
    >>> print(list_datasets())
    >>> df = load_dataset("card_krueger")
    """
    # Dispatch table: dataset name -> loader function
    loaders = {
        "card_krueger": load_card_krueger,
        "castle_doctrine": load_castle_doctrine,
        "divorce_laws": load_divorce_laws,
        "mpdta": load_mpdta,
    }

    loader = loaders.get(name)
    if loader is None:
        available = ", ".join(loaders)
        raise ValueError(f"Unknown dataset '{name}'. Available: {available}")

    return loader(force_download=force_download)