marginaleffects 0.5.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/PKG-INFO +2 -3
  2. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/autodiff/__init__.py +4 -13
  3. marginaleffects-0.6.0/marginaleffects/autodiff/glm/__init__.py +1 -0
  4. marginaleffects-0.6.0/marginaleffects/autodiff/lower.py +219 -0
  5. marginaleffects-0.6.0/marginaleffects/autodiff/ops.py +60 -0
  6. marginaleffects-0.6.0/marginaleffects/autodiff/pipeline.py +237 -0
  7. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/by.py +43 -32
  8. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/classes/model.py +3 -0
  9. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/comparisons.py +210 -117
  10. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/datagrid.py +1 -1
  11. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/estimands.py +4 -4
  12. marginaleffects-0.6.0/marginaleffects/hypothesis_compile.py +213 -0
  13. marginaleffects-0.6.0/marginaleffects/plan.py +170 -0
  14. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/plot/common.py +31 -36
  15. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/predictions.py +94 -86
  16. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/comparison.py +2 -3
  17. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/newdata.py +10 -5
  18. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/variables.py +9 -15
  19. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/vcov.py +2 -1
  20. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/settings.py +7 -0
  21. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/statsmodels/model.py +29 -61
  22. marginaleffects-0.6.0/marginaleffects/test/core.py +17 -0
  23. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/test/formula.py +1 -1
  24. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/uncertainty.py +26 -27
  25. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/utils.py +1 -1
  26. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects.egg-info/PKG-INFO +2 -3
  27. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects.egg-info/SOURCES.txt +11 -8
  28. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects.egg-info/requires.txt +1 -3
  29. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/pyproject.toml +2 -4
  30. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_autodiff.py +35 -1
  31. marginaleffects-0.6.0/tests/test_autodiff_lower.py +180 -0
  32. marginaleffects-0.6.0/tests/test_autodiff_pipeline.py +240 -0
  33. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_bugfix.py +18 -0
  34. marginaleffects-0.6.0/tests/test_comparison_plan.py +162 -0
  35. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_comparisons.py +1 -1
  36. marginaleffects-0.6.0/tests/test_hypothesis_compile.py +51 -0
  37. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_jss.py +2 -5
  38. marginaleffects-0.6.0/tests/test_plan.py +82 -0
  39. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_plot_comparisons.py +2 -5
  40. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_plot_predictions.py +3 -5
  41. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_plot_slopes.py +13 -6
  42. marginaleffects-0.6.0/tests/test_prediction_plan.py +98 -0
  43. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_pyfixest.py +1 -1
  44. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_quantreg.py +0 -1
  45. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/utilities.py +13 -0
  46. marginaleffects-0.5.0/marginaleffects/autodiff/comparisons.py +0 -69
  47. marginaleffects-0.5.0/marginaleffects/autodiff/dispatch.py +0 -310
  48. marginaleffects-0.5.0/marginaleffects/autodiff/glm/__init__.py +0 -5
  49. marginaleffects-0.5.0/marginaleffects/autodiff/glm/comparisons.py +0 -195
  50. marginaleffects-0.5.0/marginaleffects/autodiff/glm/predictions.py +0 -131
  51. marginaleffects-0.5.0/marginaleffects/autodiff/linear/__init__.py +0 -4
  52. marginaleffects-0.5.0/marginaleffects/autodiff/linear/comparisons.py +0 -147
  53. marginaleffects-0.5.0/marginaleffects/autodiff/linear/predictions.py +0 -71
  54. marginaleffects-0.5.0/marginaleffects/autodiff/utils.py +0 -31
  55. marginaleffects-0.5.0/marginaleffects/test/core.py +0 -155
  56. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/README.md +0 -0
  57. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/benchmarks/__init__.py +0 -0
  58. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/benchmarks/benchmark_autodiff.py +0 -0
  59. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/__init__.py +0 -0
  60. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/autodiff/glm/families.py +0 -0
  61. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/classes/__init__.py +0 -0
  62. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/classes/result.py +0 -0
  63. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/datasets.py +0 -0
  64. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/docstrings/__init__.py +0 -0
  65. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/docstrings/params.py +0 -0
  66. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/docstrings/qmd.py +0 -0
  67. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/formula.py +0 -0
  68. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/linearmodels/__init__.py +0 -0
  69. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/linearmodels/model.py +0 -0
  70. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/plot/__init__.py +0 -0
  71. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/plot/comparisons.py +0 -0
  72. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/plot/predictions.py +0 -0
  73. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/plot/slopes.py +0 -0
  74. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/pyfixest/__init__.py +0 -0
  75. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/pyfixest/model.py +0 -0
  76. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/__init__.py +0 -0
  77. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/by.py +0 -0
  78. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/categorical.py +0 -0
  79. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/deprecated.py +0 -0
  80. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/hypothesis_null.py +0 -0
  81. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/sanitize_model.py +0 -0
  82. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/utils.py +0 -0
  83. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sanitize/validation.py +0 -0
  84. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sklearn/__init__.py +0 -0
  85. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/sklearn/model.py +0 -0
  86. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/slopes.py +0 -0
  87. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/statsmodels/__init__.py +0 -0
  88. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/test/__init__.py +0 -0
  89. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/test/equivalence.py +0 -0
  90. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/test/joint.py +0 -0
  91. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/test/main.py +0 -0
  92. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects/transform.py +0 -0
  93. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects.egg-info/dependency_links.txt +0 -0
  94. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/marginaleffects.egg-info/top_level.txt +0 -0
  95. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/setup.cfg +0 -0
  96. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/__init__.py +0 -0
  97. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/helpers.py +0 -0
  98. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_analytic.py +0 -0
  99. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_by.py +0 -0
  100. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_categorical.py +0 -0
  101. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_categorical_validation.py +0 -0
  102. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_comparisons_interaction.py +0 -0
  103. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_datagrid_01.py +0 -0
  104. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_datagrid_02.py +0 -0
  105. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_equivalence.py +0 -0
  106. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_formula.py +0 -0
  107. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_formulaic_utils.py +0 -0
  108. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_hypotheses.py +0 -0
  109. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_hypotheses_joint.py +0 -0
  110. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_hypothesis.py +0 -0
  111. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_linearmodels_panelols.py +0 -0
  112. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_missing.py +0 -0
  113. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_newdata.py +0 -0
  114. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_predictions.py +0 -0
  115. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_sklearn.py +0 -0
  116. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_slopes.py +0 -0
  117. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels.py +0 -0
  118. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_logit.py +0 -0
  119. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_mixedlm.py +0 -0
  120. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_mnlogit.py +0 -0
  121. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_negativebinomial.py +0 -0
  122. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_ols.py +0 -0
  123. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_ordinal.py +0 -0
  124. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_poisson.py +0 -0
  125. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_probit.py +0 -0
  126. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_vcov.py +0 -0
  127. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_statsmodels_wls.py +0 -0
  128. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_typical.py +0 -0
  129. {marginaleffects-0.5.0 → marginaleffects-0.6.0}/tests/test_utils.py +0 -0
@@ -1,11 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: marginaleffects
3
- Version: 0.5.0
3
+ Version: 0.6.0
4
4
  Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
5
5
  License-Expression: GPL-3.0-or-later
6
6
  Requires-Python: >=3.10
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: formulaic>=1.0.2
9
+ Requires-Dist: jax>=0.4.0
9
10
  Requires-Dist: narwhals>=1.34.0
10
11
  Requires-Dist: numpy>=2.0.0
11
12
  Requires-Dist: patsy>=1.0.1
@@ -14,8 +15,6 @@ Requires-Dist: pydantic>=2.10.3
14
15
  Requires-Dist: plotnine>=0.14.5
15
16
  Requires-Dist: scipy>=1.14.1
16
17
  Requires-Dist: pyarrow>=19.0.1
17
- Provides-Extra: autodiff
18
- Requires-Dist: jax>=0.4.0; extra == "autodiff"
19
18
  Provides-Extra: test
20
19
  Requires-Dist: duckdb>=1.1.2; extra == "test"
21
20
  Requires-Dist: matplotlib>=3.7.1; extra == "test"
@@ -29,18 +29,14 @@ if _JAX_AVAILABLE:
29
29
 
30
30
  jax.config.update("jax_enable_x64", True)
31
31
 
32
- # Import submodules to make them accessible
33
- from . import linear as linear
34
32
  from . import glm as glm
33
+ from . import pipeline as pipeline
35
34
 
36
- # Re-export types for convenience
37
- from .comparisons import ComparisonType as ComparisonType
38
35
  from .glm.families import Family as Family, Link as Link
39
36
 
40
37
  __all__ = [
41
- "linear",
42
38
  "glm",
43
- "ComparisonType",
39
+ "pipeline",
44
40
  "Family",
45
41
  "Link",
46
42
  ]
@@ -50,13 +46,8 @@ else:
50
46
  def __getattr__(self, name):
51
47
  _raise_jax_error()
52
48
 
53
- linear = _DummyModule()
54
49
  glm = _DummyModule()
55
-
56
- # Create dummy enums that raise errors
57
- class ComparisonType:
58
- def __getattribute__(self, name):
59
- _raise_jax_error()
50
+ pipeline = _DummyModule()
60
51
 
61
52
  class Family:
62
53
  def __getattribute__(self, name):
@@ -66,4 +57,4 @@ else:
66
57
  def __getattribute__(self, name):
67
58
  _raise_jax_error()
68
59
 
69
- __all__ = []
60
+ __all__ = ["glm", "pipeline", "Family", "Link"]
@@ -0,0 +1 @@
1
+ from . import families as families
@@ -0,0 +1,219 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import warnings
5
+
6
+ import numpy as np
7
+
8
+ from ..settings import is_autodiff_enabled, is_autodiff_forced
9
+ from ..uncertainty import get_se
10
+ from .ops import COMPARISON_OPS
11
+
12
+
13
@dataclass
class Lowered:
    """Outcome of lowering a plan for the autodiff pipeline.

    When ``ok`` is True, ``kwargs`` holds the keyword arguments passed to
    ``pipeline.compute``. Otherwise ``reason`` names the unsupported feature
    (it may be empty, in which case no warning text is produced).
    """

    ok: bool  # True when lowering succeeded
    kwargs: dict | None = None  # arguments for pipeline.compute on success
    reason: str = ""  # fragment used in the fallback warning on failure
18
+
19
+
20
@dataclass
class AutodiffResult:
    """Standard errors and Jacobian produced by a successful autodiff pass."""

    std_error: np.ndarray  # standard errors derived from the Jacobian and vcov
    jacobian: np.ndarray  # derivative of the estimates w.r.t. the coefficients
24
+
25
+
26
def _fail(reason: str) -> Lowered:
    """Build an unsuccessful :class:`Lowered` carrying ``reason`` and no kwargs."""
    failure = Lowered(ok=False, kwargs=None, reason=reason)
    return failure
28
+
29
+
30
def _model_args(model):
    """Fetch the model adapter's autodiff arguments.

    Returns ``(args, None)`` on success. If the adapter returns ``None`` the
    result is ``(None, failure)`` with an empty reason; if it returns a string,
    that string becomes the failure reason.
    """
    autodiff_args = model.get_autodiff_args()
    if autodiff_args is None or isinstance(autodiff_args, str):
        # A string is the adapter's explanation of what is unsupported.
        return None, _fail(autodiff_args or "")
    return autodiff_args, None
37
+
38
+
39
def _coefs_or_failure(model):
    """Return the coefficients as a flat float vector, or a failure if any is NaN."""
    flat = np.asarray(model.get_coef(), dtype=float).reshape(-1)
    if not np.isnan(flat).any():
        return flat, None
    return None, _fail("models with NA coefficients")
44
+
45
+
46
def _design_or_failure(X, coefs, n_pred):
    """Validate a design matrix.

    It must be 2-D, with one column per coefficient and ``n_pred`` rows.
    Returns ``(float_matrix, None)`` when valid, else ``(None, failure)``.
    """
    design = np.asarray(X, dtype=float)
    well_formed = (
        design.ndim == 2
        and design.shape[1] == coefs.size
        and design.shape[0] == n_pred
    )
    if well_formed:
        return design, None
    return None, _fail("this model/data configuration")
51
+
52
+
53
def _hypothesis_or_failure(plan):
    """Extract the hypothesis matrix from ``plan``.

    Returns ``(None, None)`` when there is no hypothesis, ``(H, None)`` for a
    matrix-form hypothesis, and ``(None, failure)`` for any other form.
    """
    hyp = plan.hyp
    if hyp is None:
        return None, None
    if hyp.kind == "matrix":
        return np.asarray(hyp.H, dtype=float), None
    return None, _fail("this form of the `hypothesis` argument")
59
+
60
+
61
+ def _has_nan(x) -> bool:
62
+ return x is not None and np.isnan(np.asarray(x, dtype=float)).any()
63
+
64
+
65
def lower_predictions(plan, model) -> Lowered:
    """Lower a predictions plan into keyword arguments for ``pipeline.compute``.

    The checks run in order:

    - the model exposes autodiff arguments;
    - no coefficient is NaN;
    - ``plan.exog`` matches the coefficient count and ``plan.n_pred``;
    - there is no grouped/multi-equation alignment;
    - no prediction is missing;
    - any hypothesis is in matrix form.

    The first failed check is returned as an unsuccessful :class:`Lowered`.

    On success the kwargs hold the adapter's autodiff arguments plus ``X``,
    the optional aggregation arrays (``agg_segments``, ``agg_num_segments``,
    ``agg_weights``) and the hypothesis matrix ``H``.
    """
    args, failure = _model_args(model)
    if failure is not None:
        return failure
    coefs, failure = _coefs_or_failure(model)
    if failure is not None:
        return failure
    X, failure = _design_or_failure(plan.exog, coefs, plan.n_pred)
    if failure is not None:
        return failure
    if plan.align is not None:
        return _fail("models with grouped/multi-equation outcomes")
    if plan.has_na:
        return _fail("missing values in predictions")
    H, failure = _hypothesis_or_failure(plan)
    if failure is not None:
        return failure

    agg_segments = None
    agg_num_segments = None
    agg_weights = None
    if plan.agg is not None:
        # Start every row at segment -1 instead of np.empty garbage. Rows that
        # no group claims are then deterministically dropped by
        # jax.ops.segment_sum, which ignores ids outside [0, num_segments).
        # Otherwise they could land in an arbitrary segment.
        agg_segments = np.full(plan.n_pred, -1, dtype=np.int32)
        agg_weights = np.ones(plan.n_pred, dtype=float)
        for i, group in enumerate(plan.agg):
            agg_segments[group.idx] = i
            if group.w is not None:
                if _has_nan(group.w):
                    return _fail("missing values in weights")
                agg_weights[group.idx] = np.asarray(group.w, dtype=float)
        agg_num_segments = len(plan.agg)

    kwargs = {
        **args,
        "X": X,
        "agg_segments": agg_segments,
        "agg_num_segments": agg_num_segments,
        # agg_weights is only ever populated together with agg_segments.
        "agg_weights": agg_weights,
        "H": H,
    }
    return Lowered(True, kwargs=kwargs)
106
+
107
+
108
def lower_comparisons(plan, model) -> Lowered:
    """Lower a comparisons plan into keyword arguments for ``pipeline.compute``.

    The rows of the hi/lo design matrices are reordered so each comparison
    group occupies a contiguous block. Each group becomes one pipeline op
    ``{"op", "n", "w"}``. The first unsupported feature found is returned as
    an unsuccessful :class:`Lowered`.
    """
    args, failure = _model_args(model)
    if failure is not None:
        return failure
    coefs, failure = _coefs_or_failure(model)
    if failure is not None:
        return failure

    # Validate the "hi" design first, then the "lo" design.
    designs = []
    for attr in ("exog_hi", "exog_lo"):
        design, failure = _design_or_failure(getattr(plan, attr), coefs, plan.n_pred)
        if failure is not None:
            return failure
        designs.append(design)
    X_hi, X_lo = designs

    if plan.align is not None:
        return _fail("models with grouped/multi-equation outcomes")
    if plan.has_na:
        return _fail("missing values in predictions")

    groups = plan.groups
    # A group without a registry key cannot be mapped onto a pipeline op.
    if any(g.fun_key is None for g in groups):
        return _fail("custom comparison functions")
    if plan.need_y:
        return _fail("elasticities")
    unknown = [g.fun_key for g in groups if g.fun_key not in COMPARISON_OPS]
    if unknown:
        return _fail(f"comparison='{unknown[0]}'")

    H, failure = _hypothesis_or_failure(plan)
    if failure is not None:
        return failure

    if groups:
        order = np.concatenate([g.idx for g in groups]).astype(int)
    else:
        order = np.asarray([], dtype=int)
    # The concatenated group indices must span exactly n_pred positions.
    if order.size != plan.n_pred:
        return _fail("this model/data configuration")

    ops = []
    for g in groups:
        spec = COMPARISON_OPS[g.fun_key]
        weights = None
        if spec.weighted:
            if _has_nan(g.w):
                return _fail("missing values in weights")
            if g.w is not None:
                weights = np.asarray(g.w, dtype=float)
        ops.append({"op": spec.pipeline_op, "n": len(g.idx), "w": weights})

    return Lowered(
        True,
        kwargs={
            **args,
            "X_hi": X_hi[order],
            "X_lo": X_lo[order],
            "ops": ops,
            "H": H,
        },
    )
163
+
164
+
165
+ def _warn_unsupported(reason):
166
+ if reason:
167
+ warnings.warn(
168
+ "Automatic differentiation does not support "
169
+ f"{reason}. Reverting to finite differences.",
170
+ UserWarning,
171
+ stacklevel=3,
172
+ )
173
+
174
+
175
def autodiff_try(plan, model, V, estimate, kind):
    """Attempt to compute standard errors by automatic differentiation.

    Parameters
    ----------
    plan : plan to lower, or None (autodiff is then skipped).
    model : model adapter providing ``get_autodiff_args`` and ``get_coef``.
    V : coefficient variance-covariance matrix, or None (autodiff skipped).
    estimate : estimates produced by the standard (non-autodiff) pipeline.
    kind : ``"predictions"``; any other value is lowered as comparisons.

    Returns
    -------
    AutodiffResult or None
        None means "fall back to finite differences". That happens when
        autodiff is disabled, the plan cannot be lowered, the pipeline
        raises, or its estimates do not match ``estimate``. When autodiff is
        forced, each fallback after the enabled check emits a UserWarning.
    """
    if plan is None or V is None or not is_autodiff_enabled():
        return None

    warn_on_fallback = is_autodiff_forced()
    lowered = (
        lower_predictions(plan, model)
        if kind == "predictions"
        else lower_comparisons(plan, model)
    )
    if not lowered.ok:
        if warn_on_fallback:
            _warn_unsupported(lowered.reason)
        return None

    try:
        # Importing inside the try means an import failure (e.g. JAX
        # unavailable) is handled by the fallback below.
        from . import pipeline

        result = pipeline.compute(beta=model.get_coef(), **lowered.kwargs)
    except Exception as exc:
        if warn_on_fallback:
            warnings.warn(
                "Automatic differentiation failed "
                f"({exc}). Reverting to finite differences.",
                UserWarning,
                stacklevel=3,
            )
        return None

    autodiff_estimate = np.asarray(result["estimate"], dtype=float).reshape(-1)
    estimate = np.asarray(estimate, dtype=float).reshape(-1)
    # Compare shapes first. np.allclose broadcasts, so a size mismatch would
    # either raise (escaping this best-effort helper) or, with a length-1
    # side, compare every element against one value and wrongly "match".
    estimates_match = autodiff_estimate.shape == estimate.shape and np.allclose(
        autodiff_estimate, estimate, rtol=1e-8, atol=1e-8
    )
    if not estimates_match:
        if warn_on_fallback:
            warnings.warn(
                "Automatic differentiation estimates did not match the standard "
                "pipeline. Reverting to finite differences.",
                UserWarning,
                stacklevel=3,
            )
        return None

    # NOTE(review): the Jacobian's columns follow model.get_coef() ordering.
    # Each model adapter is assumed to return V in that same order.
    J = np.asarray(result["jacobian"], dtype=float)
    se = get_se(J, V)
    se[se == 0] = np.nan
    return AutodiffResult(std_error=se, jacobian=J)
@@ -0,0 +1,60 @@
1
+ """Shared comparison operation registry for autodiff lowering and execution."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+
10
+ ArrayFn = Callable[[list[Any]], Any]
11
+ WMeanFn = Callable[[Any, Any | None], Any]
12
+ EstimateFn = Callable[[Any, Any, Any | None, ArrayFn, WMeanFn], Any]
13
+
14
+
15
@dataclass(frozen=True)
class PipelineOp:
    """A comparison kernel executed inside the JAX pipeline.

    ``estimate`` maps ``(hi, lo, w, array, wmean)`` to the group's estimates.
    ``scalar`` marks kernels that collapse a whole group to a single value.
    """

    estimate: EstimateFn
    scalar: bool

    def output_size(self, n: int) -> int:
        """Number of estimates this kernel yields for a group of ``n`` rows."""
        if self.scalar:
            return 1
        return n
22
+
23
+
24
@dataclass(frozen=True)
class ComparisonOp:
    """Maps a user-facing ``comparison`` key onto a pipeline kernel."""

    pipeline_op: str  # key into PIPELINE_OPS
    weighted: bool  # whether the group's weights are forwarded to the kernel
28
+
29
+
30
+ def _difference(hi, lo, _w, _array, _wmean):
31
+ return hi - lo
32
+
33
+
34
+ def _ratio(hi, lo, _w, _array, _wmean):
35
+ return hi / lo
36
+
37
+
38
+ def _differenceavg(hi, lo, w, array, wmean):
39
+ return array([wmean(hi, w) - wmean(lo, w)])
40
+
41
+
42
+ def _ratioavg(hi, lo, w, array, wmean):
43
+ return array([wmean(hi, w) / wmean(lo, w)])
44
+
45
+
46
# Kernels available inside the JAX pipeline. Row-wise kernels keep one
# estimate per row; the *avg kernels collapse a group to one value.
PIPELINE_OPS = dict(
    difference=PipelineOp(estimate=_difference, scalar=False),
    ratio=PipelineOp(estimate=_ratio, scalar=False),
    differenceavg=PipelineOp(estimate=_differenceavg, scalar=True),
    ratioavg=PipelineOp(estimate=_ratioavg, scalar=True),
)
52
+
53
# User-facing comparison key -> (pipeline kernel, forward group weights?).
# The *wts variants reuse the averaging kernels but pass the group weights.
COMPARISON_OPS = {
    key: ComparisonOp(pipeline_op=kernel, weighted=weighted)
    for key, kernel, weighted in (
        ("difference", "difference", False),
        ("ratio", "ratio", False),
        ("differenceavg", "differenceavg", False),
        ("ratioavg", "ratioavg", False),
        ("differenceavgwts", "differenceavg", True),
        ("ratioavgwts", "ratioavg", True),
    )
}
@@ -0,0 +1,237 @@
1
+ """Composable JAX pipeline for R-side plan lowering."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ import jax
9
+ import jax.numpy as jnp
10
+
11
+ from .glm.families import Family, Link, linkinv, resolve_link
12
+ from .ops import PIPELINE_OPS
13
+
14
+
15
# GLM family names -> Family enum. Both "gamma" and "Gamma" are accepted.
_FAMILY = dict(
    gaussian=Family.GAUSSIAN,
    binomial=Family.BINOMIAL,
    poisson=Family.POISSON,
    gamma=Family.GAMMA,
    Gamma=Family.GAMMA,
)
22
+
23
# Link-function names -> Link enum.
_LINK = dict(
    identity=Link.IDENTITY,
    log=Link.LOG,
    logit=Link.LOGIT,
    probit=Link.PROBIT,
    inverse=Link.INVERSE,
    sqrt=Link.SQRT,
    cloglog=Link.CLOGLOG,
)
32
+
33
+
34
def _resolve_family_link(family: str | None, link: str | None) -> int | None:
    """Map family/link names to the link code later passed to ``linkinv``.

    Returns None when neither name is given. Raises ValueError for an unknown
    family or link name (including a link without a family).
    """
    if family is None and link is None:
        return None
    try:
        family_type = _FAMILY[family]
        link_type = None if link is None else _LINK[link]
    except KeyError as exc:
        raise ValueError(f"Unsupported GLM family/link: {family}/{link}") from exc
    return resolve_link(family_type, link_type)
43
+
44
+
45
+ def _wmean(x, w):
46
+ if w is None:
47
+ return jnp.mean(x)
48
+ w = jnp.asarray(w, dtype=jnp.float64)
49
+ return jnp.sum(x * w) / jnp.sum(w)
50
+
51
+
52
+ def _asarray1(x):
53
+ return jnp.asarray(x, dtype=jnp.float64)
54
+
55
+
56
+ def _apply_agg(est, segments, num_segments, weights):
57
+ if segments is None:
58
+ return est
59
+ segments = jnp.asarray(segments, dtype=jnp.int32)
60
+ if weights is None:
61
+ weights = jnp.ones_like(est, dtype=jnp.float64)
62
+ else:
63
+ weights = jnp.asarray(weights, dtype=jnp.float64)
64
+ numer = jax.ops.segment_sum(est * weights, segments, num_segments=num_segments)
65
+ denom = jax.ops.segment_sum(weights, segments, num_segments=num_segments)
66
+ return numer / denom
67
+
68
+
69
def _comparison_estimates(mu_hi, mu_lo, ops_meta, ops_weights):
    """Apply each comparison kernel to its contiguous slice of predictions.

    ``ops_meta`` is a sequence of ``(name, n, has_w)``. Each op consumes the
    next ``n`` rows of ``mu_hi``/``mu_lo``, and takes the next entry of
    ``ops_weights`` only when ``has_w`` is true. The per-op results are
    concatenated; an empty float64 array is returned when there are no ops.
    """
    remaining_weights = iter(ops_weights)
    pieces = []
    offset = 0
    for name, n, has_w in ops_meta:
        window = slice(offset, offset + n)
        w = next(remaining_weights) if has_w else None
        kernel = PIPELINE_OPS[name].estimate
        pieces.append(kernel(mu_hi[window], mu_lo[window], w, _asarray1, _wmean))
        offset += n
    if pieces:
        return jnp.concatenate(pieces)
    return jnp.asarray([], dtype=jnp.float64)
84
+
85
+
86
+ def _predict(model_type, link_type, x, b):
87
+ eta = x @ b
88
+ if model_type == "glm":
89
+ return linkinv(link_type, eta)
90
+ return eta
91
+
92
+
93
# XLA compiles the whole estimate + Jacobian computation. Traced arrays affect
# the jit cache key only through their shape/dtype, so repeated calls at the
# same problem size reuse the compiled program. A new shape or plan structure
# costs one compilation.
@partial(
    jax.jit,
    static_argnames=(
        "model_type",
        "link_type",
        "ops_meta",
        "agg_num_segments",
        "use_fwd",
    ),
)
def _estimate_and_jacobian(
    beta,
    model_type,
    link_type,
    X,
    X_hi,
    X_lo,
    ops_meta,
    ops_weights,
    est_keep,
    agg_segments,
    agg_num_segments,
    agg_weights,
    H,
    use_fwd,
):
    """Return ``(estimate, jacobian)`` of the lowered computation at ``beta``.

    If ``X`` is given, predictions are computed from it. Otherwise
    comparisons are built from ``X_hi``/``X_lo`` and ``ops_meta``. The
    estimates are then optionally subset by ``est_keep``, aggregated by
    segment, and multiplied by ``H``. ``use_fwd`` selects forward-mode
    (``jacfwd``) over reverse-mode (``jacrev``) differentiation. The Jacobian
    is reshaped to ``(estimate.size, beta.size)``.
    """

    def transformed(coefs):
        if X is None:
            mu_hi = _predict(model_type, link_type, X_hi, coefs)
            mu_lo = _predict(model_type, link_type, X_lo, coefs)
            out = _comparison_estimates(mu_hi, mu_lo, ops_meta, ops_weights)
        else:
            out = _predict(model_type, link_type, X, coefs)
        if est_keep is not None:
            out = out[est_keep]
        out = _apply_agg(out, agg_segments, agg_num_segments, agg_weights)
        if H is not None:
            out = out @ H
        return jnp.atleast_1d(out)

    estimate = transformed(beta)
    differentiate = jax.jacfwd if use_fwd else jax.jacrev
    jacobian = differentiate(transformed)(beta)
    return estimate, jnp.reshape(jacobian, (estimate.size, beta.size))
146
+
147
+
148
def compute(
    beta,
    model_type,
    family=None,
    link=None,
    X=None,
    X_hi=None,
    X_lo=None,
    ops=None,
    est_keep=None,
    agg_segments=None,
    agg_num_segments=None,
    agg_weights=None,
    H=None,
):
    """Return estimate and Jacobian for a lowered autodiff plan.

    Supply either ``X`` (predictions) or both ``X_hi`` and ``X_lo``
    (comparisons). For comparisons, ``ops`` lists ``{"op", "n", "w"}`` dicts
    that consume consecutive row blocks. The estimates are then optionally
    subset by ``est_keep``, aggregated by ``agg_segments``/``agg_weights``,
    and multiplied by the hypothesis matrix ``H``.

    Returns
    -------
    dict
        ``"estimate"``: float64 array of estimates.
        ``"jacobian"``: float64 array of shape ``(n_estimates, beta.size)``.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``model_type``, the family/link or a comparison op is unsupported;
        - no design matrix is supplied;
        - ``X_hi`` and ``X_lo`` differ in shape;
        - the comparison op sizes do not add up to the rows of ``X_hi``.
    """
    beta = jnp.asarray(beta, dtype=jnp.float64)
    if model_type not in ("linear", "glm"):
        raise ValueError(f"Unsupported model_type: {model_type}")
    if X is None and (X_hi is None or X_lo is None):
        raise ValueError(
            "Either X (predictions) or both X_hi and X_lo (comparisons) "
            "must be supplied."
        )
    link_type = _resolve_family_link(family, link) if model_type == "glm" else None

    if X is not None:
        X = jnp.asarray(X, dtype=jnp.float64)
    if X_hi is not None:
        X_hi = jnp.asarray(X_hi, dtype=jnp.float64)
    if X_lo is not None:
        X_lo = jnp.asarray(X_lo, dtype=jnp.float64)
    if X is None and X_hi.shape != X_lo.shape:
        raise ValueError(
            f"X_hi and X_lo must have the same shape; got {X_hi.shape} "
            f"and {X_lo.shape}."
        )
    if est_keep is not None:
        est_keep = jnp.asarray(est_keep, dtype=jnp.int32)
    if agg_segments is not None:
        agg_segments = jnp.asarray(agg_segments, dtype=jnp.int32)
        agg_num_segments = int(agg_num_segments)
    else:
        agg_num_segments = None
    if agg_weights is not None:
        agg_weights = jnp.asarray(agg_weights, dtype=jnp.float64)
    if H is not None:
        H = jnp.asarray(H, dtype=jnp.float64)

    # Split ops into a hashable static structure (the jit cache key) and
    # traced weight arrays. Also tally the pre-keep estimate length and the
    # number of design rows the ops consume.
    ops_meta = ()
    ops_weights = ()
    n_est = X.shape[0] if X is not None else 0
    comparison_rows = 0
    for op in ops or []:
        name = op["op"]
        if name not in PIPELINE_OPS:
            raise ValueError(f"Unsupported comparison op: {name}")
        spec = PIPELINE_OPS[name]
        n = int(op["n"])
        w = op.get("w")
        ops_meta = ops_meta + ((name, n, w is not None),)
        if w is not None:
            ops_weights = ops_weights + (jnp.asarray(w, dtype=jnp.float64),)
        n_est += spec.output_size(n)
        comparison_rows += n

    # Ops slice consecutive rows. A size mismatch would silently truncate or
    # ignore rows instead of failing.
    if X is None and comparison_rows != X_hi.shape[0]:
        raise ValueError(
            f"Comparison ops cover {comparison_rows} rows but X_hi has "
            f"{X_hi.shape[0]} rows."
        )

    # The output size is known from shapes alone, so forward vs reverse mode
    # can be chosen before tracing. Forward mode is used when the inputs are
    # no more numerous than the outputs.
    n_out = n_est
    if est_keep is not None:
        n_out = est_keep.shape[0]
    if agg_segments is not None:
        n_out = agg_num_segments
    if H is not None:
        n_out = H.shape[1]
    use_fwd = beta.size <= max(n_out, 1)

    estimate, jacobian = _estimate_and_jacobian(
        beta,
        model_type,
        link_type,
        X,
        X_hi,
        X_lo,
        ops_meta,
        ops_weights,
        est_keep,
        agg_segments,
        agg_num_segments,
        agg_weights,
        H,
        use_fwd,
    )
    return {
        "estimate": np.asarray(estimate, dtype=np.float64),
        "jacobian": np.asarray(jacobian, dtype=np.float64),
    }