diff-diff 2.1.8__tar.gz → 2.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {diff_diff-2.1.8 → diff_diff-2.2.0}/PKG-INFO +6 -1
  2. {diff_diff-2.1.8 → diff_diff-2.2.0}/README.md +5 -0
  3. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/__init__.py +1 -1
  4. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/_backend.py +16 -4
  5. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/linalg.py +49 -8
  6. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/trop.py +1032 -30
  7. {diff_diff-2.1.8 → diff_diff-2.2.0}/pyproject.toml +1 -1
  8. diff_diff-2.2.0/rust/Cargo.lock +1540 -0
  9. diff_diff-2.2.0/rust/Cargo.toml +34 -0
  10. {diff_diff-2.1.8 → diff_diff-2.2.0}/rust/src/bootstrap.rs +3 -3
  11. {diff_diff-2.1.8 → diff_diff-2.2.0}/rust/src/lib.rs +6 -2
  12. {diff_diff-2.1.8 → diff_diff-2.2.0}/rust/src/linalg.rs +199 -67
  13. {diff_diff-2.1.8 → diff_diff-2.2.0}/rust/src/trop.rs +722 -26
  14. {diff_diff-2.1.8 → diff_diff-2.2.0}/rust/src/weights.rs +5 -5
  15. diff_diff-2.1.8/rust/Cargo.lock +0 -2321
  16. diff_diff-2.1.8/rust/Cargo.toml +0 -43
  17. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/bacon.py +0 -0
  18. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/datasets.py +0 -0
  19. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/diagnostics.py +0 -0
  20. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/estimators.py +0 -0
  21. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/honest_did.py +0 -0
  22. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/power.py +0 -0
  23. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/prep.py +0 -0
  24. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/prep_dgp.py +0 -0
  25. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/pretrends.py +0 -0
  26. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/results.py +0 -0
  27. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/staggered.py +0 -0
  28. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/staggered_aggregation.py +0 -0
  29. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/staggered_bootstrap.py +0 -0
  30. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/staggered_results.py +0 -0
  31. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/sun_abraham.py +0 -0
  32. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/synthetic_did.py +0 -0
  33. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/triple_diff.py +0 -0
  34. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/twfe.py +0 -0
  35. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/utils.py +0 -0
  36. {diff_diff-2.1.8 → diff_diff-2.2.0}/diff_diff/visualization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.1.8
3
+ Version: 2.2.0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -1302,6 +1302,7 @@ trop = TROP(
1302
1302
 
1303
1303
  ```python
1304
1304
  TROP(
1305
+ method='twostep', # Estimation method: 'twostep' (default) or 'joint'
1305
1306
  lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5])
1306
1307
  lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5])
1307
1308
  lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10])
@@ -1314,6 +1315,10 @@ TROP(
1314
1315
  )
1315
1316
  ```
1316
1317
 
1318
+ **Estimation methods:**
1319
+ - `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive.
1320
+ - `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects.
1321
+
1317
1322
  **Convenience function:**
1318
1323
 
1319
1324
  ```python
@@ -1267,6 +1267,7 @@ trop = TROP(
1267
1267
 
1268
1268
  ```python
1269
1269
  TROP(
1270
+ method='twostep', # Estimation method: 'twostep' (default) or 'joint'
1270
1271
  lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5])
1271
1272
  lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5])
1272
1273
  lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10])
@@ -1279,6 +1280,10 @@ TROP(
1279
1280
  )
1280
1281
  ```
1281
1282
 
1283
+ **Estimation methods:**
1284
+ - `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive.
1285
+ - `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects.
1286
+
1282
1287
  **Convenience function:**
1283
1288
 
1284
1289
  ```python
@@ -136,7 +136,7 @@ from diff_diff.datasets import (
136
136
  load_mpdta,
137
137
  )
138
138
 
139
- __version__ = "2.1.8"
139
+ __version__ = "2.2.0"
140
140
  __all__ = [
141
141
  # Estimators
142
142
  "DifferenceInDifferences",
@@ -23,10 +23,13 @@ try:
23
23
  project_simplex as _rust_project_simplex,
24
24
  solve_ols as _rust_solve_ols,
25
25
  compute_robust_vcov as _rust_compute_robust_vcov,
26
- # TROP estimator acceleration
26
+ # TROP estimator acceleration (twostep method)
27
27
  compute_unit_distance_matrix as _rust_unit_distance_matrix,
28
28
  loocv_grid_search as _rust_loocv_grid_search,
29
29
  bootstrap_trop_variance as _rust_bootstrap_trop_variance,
30
+ # TROP estimator acceleration (joint method)
31
+ loocv_grid_search_joint as _rust_loocv_grid_search_joint,
32
+ bootstrap_trop_variance_joint as _rust_bootstrap_trop_variance_joint,
30
33
  )
31
34
  _rust_available = True
32
35
  except ImportError:
@@ -36,10 +39,13 @@ except ImportError:
36
39
  _rust_project_simplex = None
37
40
  _rust_solve_ols = None
38
41
  _rust_compute_robust_vcov = None
39
- # TROP estimator acceleration
42
+ # TROP estimator acceleration (twostep method)
40
43
  _rust_unit_distance_matrix = None
41
44
  _rust_loocv_grid_search = None
42
45
  _rust_bootstrap_trop_variance = None
46
+ # TROP estimator acceleration (joint method)
47
+ _rust_loocv_grid_search_joint = None
48
+ _rust_bootstrap_trop_variance_joint = None
43
49
 
44
50
  # Determine final backend based on environment variable and availability
45
51
  if _backend_env == 'python':
@@ -50,10 +56,13 @@ if _backend_env == 'python':
50
56
  _rust_project_simplex = None
51
57
  _rust_solve_ols = None
52
58
  _rust_compute_robust_vcov = None
53
- # TROP estimator acceleration
59
+ # TROP estimator acceleration (twostep method)
54
60
  _rust_unit_distance_matrix = None
55
61
  _rust_loocv_grid_search = None
56
62
  _rust_bootstrap_trop_variance = None
63
+ # TROP estimator acceleration (joint method)
64
+ _rust_loocv_grid_search_joint = None
65
+ _rust_bootstrap_trop_variance_joint = None
57
66
  elif _backend_env == 'rust':
58
67
  # Force Rust mode - fail if not available
59
68
  if not _rust_available:
@@ -73,8 +82,11 @@ __all__ = [
73
82
  '_rust_project_simplex',
74
83
  '_rust_solve_ols',
75
84
  '_rust_compute_robust_vcov',
76
- # TROP estimator acceleration
85
+ # TROP estimator acceleration (twostep method)
77
86
  '_rust_unit_distance_matrix',
78
87
  '_rust_loocv_grid_search',
79
88
  '_rust_bootstrap_trop_variance',
89
+ # TROP estimator acceleration (joint method)
90
+ '_rust_loocv_grid_search_joint',
91
+ '_rust_bootstrap_trop_variance_joint',
80
92
  ]
@@ -251,10 +251,10 @@ def _solve_ols_rust(
251
251
  cluster_ids: Optional[np.ndarray] = None,
252
252
  return_vcov: bool = True,
253
253
  return_fitted: bool = False,
254
- ) -> Union[
254
+ ) -> Optional[Union[
255
255
  Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
256
256
  Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
257
- ]:
257
+ ]]:
258
258
  """
259
259
  Rust backend implementation of solve_ols for full-rank matrices.
260
260
 
@@ -296,15 +296,30 @@ def _solve_ols_rust(
296
296
  Fitted values if return_fitted=True.
297
297
  vcov : np.ndarray, optional
298
298
  Variance-covariance matrix if return_vcov=True.
299
+ None
300
+ If Rust backend detects numerical instability and caller should
301
+ fall back to Python backend.
299
302
  """
300
303
  # Convert cluster_ids to int64 for Rust (handles string/categorical IDs)
301
304
  if cluster_ids is not None:
302
305
  cluster_ids = _factorize_cluster_ids(cluster_ids)
303
306
 
304
- # Call Rust backend
305
- coefficients, residuals, vcov = _rust_solve_ols(
306
- X, y, cluster_ids=cluster_ids, return_vcov=return_vcov
307
- )
307
+ # Call Rust backend with fallback on numerical instability
308
+ try:
309
+ coefficients, residuals, vcov = _rust_solve_ols(
310
+ X, y, cluster_ids=cluster_ids, return_vcov=return_vcov
311
+ )
312
+ except ValueError as e:
313
+ error_msg = str(e).lower()
314
+ if "numerically unstable" in error_msg or "singular" in error_msg:
315
+ warnings.warn(
316
+ f"Rust backend detected numerical instability: {e}. "
317
+ "Falling back to Python backend.",
318
+ UserWarning,
319
+ stacklevel=3,
320
+ )
321
+ return None # Signal caller to use Python fallback
322
+ raise
308
323
 
309
324
  # Convert to numpy arrays
310
325
  coefficients = np.asarray(coefficients)
@@ -468,12 +483,15 @@ def solve_ols(
468
483
  # This saves O(nk²) QR overhead but won't detect rank-deficient matrices
469
484
  if skip_rank_check:
470
485
  if HAS_RUST_BACKEND and _rust_solve_ols is not None:
471
- return _solve_ols_rust(
486
+ result = _solve_ols_rust(
472
487
  X, y,
473
488
  cluster_ids=cluster_ids,
474
489
  return_vcov=return_vcov,
475
490
  return_fitted=return_fitted,
476
491
  )
492
+ if result is not None:
493
+ return result
494
+ # Fall through to NumPy on numerical instability
477
495
  # Fall through to Python without rank check (user guarantees full rank)
478
496
  return _solve_ols_numpy(
479
497
  X, y,
@@ -499,6 +517,7 @@ def solve_ols(
499
517
  # Routing strategy:
500
518
  # - Full-rank + Rust available → fast Rust backend (SVD-based solve)
501
519
  # - Rank-deficient → Python backend (proper NA handling, valid SEs)
520
+ # - Rust numerical instability → Python fallback (via None return)
502
521
  # - No Rust → Python backend (works for all cases)
503
522
  if HAS_RUST_BACKEND and _rust_solve_ols is not None and not is_rank_deficient:
504
523
  result = _solve_ols_rust(
@@ -508,6 +527,19 @@ def solve_ols(
508
527
  return_fitted=return_fitted,
509
528
  )
510
529
 
530
+ # Check for None: Rust backend detected numerical instability and
531
+ # signaled us to fall back to Python backend
532
+ if result is None:
533
+ return _solve_ols_numpy(
534
+ X, y,
535
+ cluster_ids=cluster_ids,
536
+ return_vcov=return_vcov,
537
+ return_fitted=return_fitted,
538
+ rank_deficient_action=rank_deficient_action,
539
+ column_names=column_names,
540
+ _precomputed_rank_info=None, # Force fresh rank detection
541
+ )
542
+
511
543
  # Check for NaN vcov: Rust SVD may detect rank-deficiency that QR missed
512
544
  # for ill-conditioned matrices (QR and SVD have different numerical properties).
513
545
  # When this happens, fall back to Python's R-style handling.
@@ -732,7 +764,7 @@ def compute_robust_vcov(
732
764
  try:
733
765
  return _rust_compute_robust_vcov(X, residuals, cluster_ids_int)
734
766
  except ValueError as e:
735
- # Translate Rust LAPACK errors to consistent Python error messages
767
+ # Translate Rust errors to consistent Python error messages or fallback
736
768
  error_msg = str(e)
737
769
  if "Matrix inversion failed" in error_msg:
738
770
  raise ValueError(
@@ -740,6 +772,15 @@ def compute_robust_vcov(
740
772
  "This indicates perfect multicollinearity. Check your fixed effects "
741
773
  "and covariates for linear dependencies."
742
774
  ) from e
775
+ if "numerically unstable" in error_msg.lower():
776
+ # Fall back to NumPy on numerical instability (with warning)
777
+ warnings.warn(
778
+ f"Rust backend detected numerical instability: {e}. "
779
+ "Falling back to Python backend for variance computation.",
780
+ UserWarning,
781
+ stacklevel=2,
782
+ )
783
+ return _compute_robust_vcov_numpy(X, residuals, cluster_ids)
743
784
  raise
744
785
 
745
786
  # Fallback to NumPy implementation